[
  {
    "path": ".gitignore",
    "content": "# IDE\n.idea/\n.vscode/\n.claude/\n.gemini/\n*.swp\n*.swo\n*~\n\n# OS\n.DS_Store\nThumbs.db\n\n# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# Virtual environments\n.venv/\nvenv/\nENV/\nenv/\n\n# Logs\n*.log\nlogs/\ntmp_ray/\n\n# Jupyter\n.ipynb_checkpoints/\n\n# Testing\n.pytest_cache/\n.coverage\nhtmlcov/\n.tox/\n.nox/\n\n# ML/DL\nwandb/\nmlruns/\n*.ckpt\n*.pt\n*.pth\n*.bin\n*.safetensors\noutput/\ncheckpoints/\nckpt/\n\n# Data\n# *.parquet\n# *.csv\n# *.json\n# *.jsonl\n\n# Ray\nray_results/"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  <h1>OpenOneRec</h1>\n  <p align=\"center\">\n    <strong>An Open Foundation Model and Benchmark to Accelerate Generative Recommendation</strong>\n  </p>\n  <p align=\"center\">\n    <a href=\"https://huggingface.co/OpenOneRec\">\n        <img alt=\"Hugging Face\" src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OneRec-ffc107?color=ffc107&logoColor=white\" />\n    </a>\n    <a href=\"https://github.com/Kuaishou-OneRec/OpenOneRec\">\n        <img alt=\"GitHub Code\" src=\"https://img.shields.io/badge/GitHub-OpenOneRec-black?logo=github\" />\n    </a>\n     <a href=\"https://arxiv.org/abs/2512.24762\">\n        <img alt=\"Paper\" src=\"https://img.shields.io/badge/Paper-ArXiv-b31b1b?logo=arxiv\" />\n    </a>\n    <a href=\"#license\">\n        <img alt=\"License\" src=\"https://img.shields.io/badge/License-Apache%202.0-green\" />\n    </a>\n  </p>\n</div>\n<br>\n\n## 📖 Introduction\n\n**OpenOneRec** is an open-source framework designed to bridge the gap between traditional recommendation systems and Large Language Models (LLMs). While Generative Recommendation has shown promise, existing models often struggle with isolated data silos and a lack of reasoning capabilities.\n\nTo address this, we introduce a unified framework that comprises:\n* **RecIF-Bench**: The first holistic Recommendation Instruction-Following Benchmark, containing **100M interactions** from 200k users across heterogeneous domains (Short Video, Ads, Product).\n* **OneRec-Foundation Models**: A family of models (1.7B & 8B) built on the Qwen3 backbone. The series includes **Standard** versions trained on our open-source dataset and **Pro** versions enhanced with a hundred-billion-token industrial corpus from Kuaishou.\n* **Full-Stack Pipeline**: We open-source our comprehensive training pipeline, including data processing, co-pretraining, and post-training, to ensure full reproducibility and facilitate scaling law research in recommendation.\n\n## 🔥 News\n\n* **[2026.1.1]** 📑 **The technical report** has been released.\n* **[2026.1.1]** 🎉 **OneRec-Foundation** models (1.7B, 8B) are now available on Hugging Face!\n* **[2026.1.1]** 🚀 **RecIF-Bench** dataset and evaluation scripts are open-sourced.\n\n## 📊 RecIF-Bench\n\nWe propose **RecIF-Bench** to rigorously assess the synergy between instruction following and domain-specific recommendation. 
It organizes 8 distinct tasks into a four-layer capability hierarchy:\n\n* **Layer 0: Semantic Alignment** (Item Understanding) \n* **Layer 1: Fundamental Prediction** (Short Video Rec, Ad Rec, Product Rec, Label Prediction) \n* **Layer 2: Instruction Following** (Interactive Rec, Label-Conditional Rec) \n* **Layer 3: Reasoning** (Recommendation Explanation) \n\nThe benchmark aggregates data from three domains: **Short Video** (Content), **Ads** (Commercial), and **Product** (E-commerce).\n\n## 🤖 Model Zoo\n\nThe OpenOneRec-Foundation series is built upon the Qwen architecture, enhanced with **Itemic Tokens** for modality alignment and trained via a multi-stage protocol.\n\n| Model | Backbone | Parameters | Description | Link |\n| :--- | :--- | :--- | :--- | :--- |\n| **OneRec-1.7B** | Qwen3-1.7B | 1.7B | Standard version trained on open-source data (~33B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-1.7B) |\n| **OneRec-8B** | Qwen3-8B | 8B | Standard version trained on open-source data (~33B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-8B) |\n| **OneRec-1.7B-Pro** | Qwen3-1.7B | 1.7B | Scaled-up version with expanded datasets (~130B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-1.7B-pro) |\n| **OneRec-8B-Pro** | Qwen3-8B | 8B | Scaled-up version with expanded datasets (~130B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-8B-pro) |\n\n## 🏗️ Method & Architecture\n\nOpenOneRec reframes recommendation as a general-purpose sequence modeling paradigm.\n\n### 1. Items as Tokens\nTo bridge the modality gap, we treat items as a distinct modality using **Itemic Tokens** derived from hierarchical vector quantization. This allows the LLM to process interaction history as a cohesive context sequence.\n\n### 2. Training Pipeline\nOur framework utilizes the following recipe:\n* **Pre-Training**: Integrates collaborative signals via Itemic-Text Alignment and Full-Parameter Co-Pretraining.\n* **Post-Training**:\n    * *Stage 1*: Multi-task Supervised Fine-tuning for basic instruction following.\n    * *Stage 2*: On-policy Distillation to restore general reasoning performance.\n    * *Stage 3*: Reinforcement Learning to enhance recommendation capabilities.\n\n<div align=\"center\">\n  <img src=\"assets/main_framework.png\" width=\"80%\" alt=\"OpenOneRec Overall Framework\" />\n  <br>\n  <em>Figure: The Overall Framework of OpenOneRec.</em>\n</div>\n\n## 📈 Performance\n\n### Results on RecIF-Bench\nOpenOneRec-Foundation achieves **State-of-the-Art (SOTA)** results across RecIF-Bench tasks, significantly outperforming baselines like LC-Rec and TIGER.\n\n| Task | Metric | SASRec | TIGER | LC-Rec | OneRec-1.7B | OneRec-8B | OneRec-1.7B-Pro | **OneRec-8B-Pro** |\n| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n| **Short Video Rec** | Recall@32 | 0.0119 | 0.0132 | 0.0180 | 0.0272 | 0.0355 | 0.0274 | **0.0369** |\n| **Ad Rec** | Recall@32 | 0.0293 | 0.0581 | 0.0723 | 0.0707 | 0.0877 | 0.0735 | **0.0964** |\n| **Product Rec** | Recall@32 | 0.0175 | 0.0283 | 0.0416 | 0.0360 | 0.0470 | 0.0405 | **0.0538** |\n| **Label-Cond. Rec** | Recall@32 | 0.0140 | 0.0123 | 0.0170 | 0.0184 | 0.0228 | 0.0182 | **0.0235** |\n| **Label Pred.** | AUC | 0.6244 | 0.6675 | 0.6139 | 0.6184 | 0.6615 | 0.6071 | **0.6912** |\n| **Interactive Rec** | Recall@32 | -- | -- | 0.2394 | 0.1941 | 0.3032 | 0.2024 | **0.3458** |\n| **Item Und.** | LLM Score | -- | -- | 0.2517 | 0.3175 | 0.3202 | 0.3133 | **0.3209** |\n| **Rec. 
Explanation** | LLM Score | -- | -- | 3.9350 | 3.3540 | 3.6774 | 3.5060 | **4.0381** |\n\n<div align=\"center\">\n  <img src=\"assets/benchmark.png\" width=\"80%\" alt=\"Holistic Performance Overview of OpenOneRec.\" />\n  <br>\n  <em>Holistic Performance Overview of OpenOneRec.</em>\n</div>\n\n### Cross-Domain Transferability\nOn the **Amazon Benchmark** (10 datasets), OpenOneRec demonstrates exceptional zero-shot/few-shot transfer capabilities, achieving an average **26.8% improvement** in Recall@10 over the second-best method.\n\n| Domain | SASRec | TIGER | LC-Rec | **Ours** |\n| :--- | :--- | :--- | :--- | :--- |\n| Baby | 0.0381 | 0.0318 | 0.0344 | **0.0513** |\n| Beauty | 0.0639 | 0.0628 | 0.0764 | **0.0924** |\n| Cell Phones | 0.0782 | 0.0786 | 0.0883 | **0.1036** |\n| Grocery | 0.0789 | 0.0691 | 0.0790 | **0.1029** |\n| Health | 0.0506 | 0.0534 | 0.0616 | **0.0768** |\n| Home | 0.0212 | 0.0216 | 0.0293 | **0.0390** |\n| Pet Supplies | 0.0607 | 0.0542 | 0.0612 | **0.0834** |\n| Sports | 0.0389 | 0.0331 | 0.0418 | **0.0547** |\n| Tools | 0.0437 | 0.0344 | 0.0438 | **0.0593** |\n| Toys | 0.0658 | 0.0527 | 0.0549 | **0.0953** |\n\n*Metric: Recall@10. Ours refers to OneRec-Foundation with the text-augmented itemic-token strategy. For implementation details, please refer to [GRLM](https://github.com/ZY0025/GRLM).*\n\n## 🚀 Quick Start\n\n*Code release and detailed usage instructions are coming soon.*\n\nCurrently, you can load our models using `transformers>=4.51.0`:\n\n```python\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nmodel_name = \"OpenOneRec/OneRec-8B\"\n\n# load the tokenizer and the model\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nmodel = AutoModelForCausalLM.from_pretrained(\n    model_name,\n    torch_dtype=\"auto\",\n    device_map=\"auto\"\n)\n\n# prepare the model input\n# case - prompt with itemic tokens (the Chinese prompt says: here is a video <SID>, summarize what this video is about)\nprompt = \"这是一个视频：<|sid_begin|><s_a_340><s_b_6566><s_c_5603><|sid_end|>，帮我总结一下这个视频讲述了什么内容\"\nmessages = [\n    {\"role\": \"user\", \"content\": prompt}\n]\ntext = tokenizer.apply_chat_template(\n    messages,\n    tokenize=False,\n    add_generation_prompt=True,\n    enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.\n)\nmodel_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n\n# conduct text completion\n# Note: In our experience, default decoding settings may be unstable for small models.\n# For 1.7B, we suggest: top_p=0.95, top_k=20, temperature=0.75 (within the range 0.6 to 0.8)\ngenerated_ids = model.generate(\n    **model_inputs,\n    max_new_tokens=32768\n)\noutput_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n\n# parse the thinking content\ntry:\n    # rindex finding 151668 (</think>)\n    index = len(output_ids) - output_ids[::-1].index(151668)\nexcept ValueError:\n    index = 0\n\nthinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\ncontent = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n\nprint(\"thinking content:\", thinking_content)\nprint(\"content:\", content)\n```\n
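\nIf you want to apply the suggested sampling settings explicitly, you can pass them straight to `generate` (a minimal variation of the call above using standard `transformers` sampling arguments; the values follow the suggestion for the 1.7B model and may need tuning):\n\n```python\n# Sketch: explicit sampling configuration for smaller checkpoints\ngenerated_ids = model.generate(\n    **model_inputs,\n    max_new_tokens=32768,\n    do_sample=True,\n    temperature=0.75,  # suggested range: 0.6 to 0.8\n    top_p=0.95,\n    top_k=20,\n)\n```\n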
\n## 🛣️ Roadmap / Under Development\n\nWe are actively working on the following features:\n\n- [ ] **General-domain data**: scripts to fetch and preprocess public general-domain corpora used in `data/general_text`.\n- [ ] **Reproducible environments**: training pipeline Docker/Apptainer images for easier end-to-end reproduction.\n- [ ] **One-click reproduction**: further code cleanup and streamlined training recipes for an end-to-end “run from scratch” experience.\n- [ ] **Docs & tutorials**: improved documentation, tutorials, and best-practice guides.\n- [ ] **Unified VeRL integration**: consolidate RL and distillation codepaths into a single, consistent VeRL-based implementation.\n- [ ] **More model sizes**: support additional pretraining scales and configurations beyond current checkpoints.\n\nContributions are welcome! Please refer to the detailed documentation in each module.\n\n## 📜 Citation\nIf you find our work helpful, please cite our technical report:\n\n```bibtex\n@misc{OpenOneRec,\n      title={OpenOneRec Technical Report},\n      author={Guorui Zhou and Honghui Bao and Jiaming Huang and Jiaxin Deng and Jinghao Zhang and Junda She and Kuo Cai and Lejian Ren and Lu Ren and Qiang Luo and Qianqian Wang and Qigen Hu and Rongzhou Zhang and Ruiming Tang and Shiyao Wang and Wuchao Li and Xiangyu Wu and Xinchen Luo and Xingmei Wang and Yifei Hu and Yunfan Wu and Zhanyu Liu and Zhiyang Zhang and Zixing Zhang and Bo Chen and Bin Wen and Chaoyi Ma and Chengru Song and Chenglong Chu and Defu Lian and Fan Yang and Feng Jiang and Hongtao Cheng and Huanjie Wang and Kun Gai and Pengfei Zheng and Qiang Wang and Rui Huang and Siyang Mao and Tingting Gao and Wei Yuan and Yan Wang and Yang Zhou and Yi Su and Zexuan Cheng and Zhixin Ling and Ziming Li},\n      year={2025},\n      eprint={2512.24762},\n      archivePrefix={arXiv},\n      primaryClass={cs.IR}\n}\n```\n\n## 🛡️ License\nThe code in this repository is licensed under the Apache 2.0 License. The model weights are subject to their specific license agreements.\n\n## 🙏 Acknowledgements\n\nOpenOneRec is built upon and inspired by the open-source ecosystem. 
We would like to thank:\n\n- **Qwen3**: for providing the base architecture and model initialization that OpenOneRec builds upon.\n- **General-domain data sources**: for the public corpora referenced in [`data/general_text`](https://github.com/Kuaishou-OneRec/OpenOneRec/tree/main/data/general_text) used for mixed-domain training.\n- **VeRL & PyTorch distributed training**: for the training infrastructure and scalable primitives (e.g., **FSDP**) used in post-training and large-scale runs.\n\nWe sincerely thank these projects for their outstanding work.\n"
  },
  {
    "path": "benchmarks/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2025] [OneRec Team]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "benchmarks/README.md",
    "content": "# Benchmark\n\n\n## Quick Start\n\n### Step 1: Install Dependencies\n\n```bash\ncd benchmarks\n\nconda create -n benchmark python=3.10 \nconda activate benchmark\npip install uv\nuv pip install torch==2.5.1 transformers==4.52.0 vllm==0.7.3\npip install -r requirements.txt\npip install -e . --no-deps --no-build-isolation\n```\n\n### Step 2: Start Ray Cluster (Optional)\n\n```bash\n# Initialize multi-node multi-GPU environment\n# Skip this step if using single-node multi-GPU setup\nbash scripts/init_ray_cluster.sh\n```\n\n\n### Step 3: Configure LLM API\n\nEdit `api/config/llm_config.json` to fill in your Gemini configuration:\n\n```json\n{\n  \"gemini\": {\n    \"project\": \"<your-project>\",\n    \"location\": \"<your-location>\",\n    \"credentials_path\": \"<path-to-credentials>\",\n    ...\n  }\n}\n```\n\n**Note**: Only `project`, `location`, and `credentials_path` need to be configured. \n\nTest the configuration:\n\n```python\nfrom api import get_client_from_config\n\n# Create client\nclient = get_client_from_config(\"gemini\")\n\n# Generate text\nresponse = client.generate(\"Tell me a joke\")\nprint(response)\n```\n\n### Step 4: Run Evaluation\n\n```bash\nexport BENCHMARK_BASE_DIR=\".\"\nexport BENCHMARK_DATA_DIR=\"../raw_data/onerec_data/benchmark_data\"\nexport DATA_VERSION=\"v1.0\"\n\nbash eval_script.sh <model_path> <result_name> <enable_thinking>\n```\n\n**Parameters**:\n| Parameter | Description | Example |\n|-----------|-------------|---------|\n| model_path | Path to the model to evaluate | `model_output/sft/global_step10/converted` |\n| result_name | Name identifier for output directory | `sft_nonthink` |\n| enable_thinking | `true` or `false` | `false` |\n\n**Examples**:\n```bash\n# Without thinking mode\nbash eval_script.sh \\\n    /path/to/model \\\n    model_nonthink \\\n    false\n\n# With thinking mode\nbash eval_script.sh \\\n    /path/to/model \\\n    model_think \\\n    true\n```\n\nFor debugging purposes, you can add `--sample_size 10` to each python command in `eval_script.sh` to run evaluation on a smaller subset of data.\n\n\n### Step 5: View Results\n\nAfter evaluation completes, results are saved in:\n```\n./results/v1.0/results_<result_name>/\n```\n\nLog files are located at:\n```\n./auto_eval_logs/v1.0/<result_name>.log\n```\n\n\n---\n\n## Evaluation Tasks\n\n| Task Name | Source | Description |\n|-----------|--------|-------------|\n| ad | Kuaishou Internal | 27,677 | Predict next clicked advertisement |\n| product | Kuaishou Internal | 27,910 | Predict next clicked product |\n| interactive | Kuaishou Internal | 1,000 | Predict next interacted video |\n| video | Kuaishou Internal | 38,781  | Next video prediction |\n| label_cond | Kuaishou Internal | 34,891 | Predict next video given specified consumption behavior |\n| label_pred | Kuaishou Internal | 346,190 | Predict user engagement with video content |\n| item_understand | Kuaishou Internal | 500 | Video SID to Caption generation task |\n| rec_reason | Kuaishou Internal | 470 | Recommendation reason inference |\n\n\n\n"
  },
  {
    "path": "benchmarks/api/README.md",
    "content": "# Unified LLM API Wrapper\n\nThis is a unified LLM API wrapper library that provides a clean and elegant interface for calling different large language models.\n\n## Supported Models\n\n- **Claude** - Anthropic Claude models\n- **Gemini** - Google Vertex AI Gemini models\n- **DeepSeek** - DeepSeek models via Baidu Qianfan platform\n\n## Model Pricing Comparison \n\n- Claude: https://claude.com/pricing\n- Gemini: https://ai.google.dev/gemini-api/docs/pricing\n- DeepSeek: https://api-docs.deepseek.com/quick_start/pricing\n\n\n## Quick Start\n\n### Installation\n\n```bash\npip install openai google-cloud-aiplatform anthropic tqdm\n```\n\n### Using Configuration File\n\nFirst, edit `api/config/llm_config.json` to fill in your configuration:\n\nThen use the following code to test:\n\n```python\nfrom api import get_client_from_config\n\n# Create client\nclient = get_client_from_config(\"gemini\")\n\n# Generate text\nresponse = client.generate(\"Tell me a joke\")\nprint(response)\n```\n\n"
  },
  {
    "path": "benchmarks/api/__init__.py",
    "content": "\"\"\"\nUnified LLM API Wrapper\nSupports convenient calling of Gemini, DeepSeek, and Claude models\n\"\"\"\nimport json\nfrom pathlib import Path\nfrom typing import List, Dict, Any, Optional\n\nfrom .base import BaseLLMClient\nfrom .gemini import GeminiClient\nfrom .deepseek import DeepSeekClient\nfrom .claude import ClaudeClient\n\n\n# Model mapping\nMODEL_CLASSES = {\n    \"gemini\": GeminiClient,\n    \"deepseek\": DeepSeekClient,\n    \"claude\": ClaudeClient,\n}\n\n\ndef load_config(config_path: str = None) -> Dict[str, Any]:\n    \"\"\"\n    Load configuration from JSON file\n\n    Args:\n        config_path: Configuration file path, defaults to api/config/llm_config.json\n\n    Returns:\n        dict: Configuration dictionary\n\n    Raises:\n        FileNotFoundError: Configuration file does not exist\n        json.JSONDecodeError: Configuration file format error\n    \"\"\"\n    if config_path is None:\n        current_dir = Path(__file__).parent\n        config_path = current_dir / \"config\" / \"llm_config.json\"\n\n    config_path = Path(config_path)\n    if not config_path.exists():\n        raise FileNotFoundError(f\"Configuration file does not exist: {config_path}\")\n\n    with open(config_path, 'r', encoding='utf-8') as f:\n        return json.load(f)\n\n\ndef get_client(model: str, **config) -> BaseLLMClient:\n    \"\"\"\n    Factory function: Create LLM client instance\n\n    Args:\n        model: Model name (\"gemini\" or \"deepseek\")\n        **config: Model-specific configuration parameters\n\n    Returns:\n        BaseLLMClient: Client instance\n\n    Raises:\n        ValueError: Unsupported model type\n\n    Example:\n        >>> client = get_client(\"gemini\",\n        ...                    project=\"your-project\",\n        ...                    location=\"us-central1\")\n        >>> result = client.generate(\"Tell me a joke\")\n    \"\"\"\n    model = model.lower()\n    if model not in MODEL_CLASSES:\n        raise ValueError(\n            f\"Unsupported model: {model}. \"\n            f\"Supported models: {', '.join(MODEL_CLASSES.keys())}\"\n        )\n\n    client_class = MODEL_CLASSES[model]\n    return client_class(**config)\n\n\ndef get_client_from_config(\n    model: str,\n    config_path: Optional[str] = None\n) -> BaseLLMClient:\n    \"\"\"\n    Create LLM client from configuration file\n\n    Args:\n        model: Model name (\"gemini\" or \"deepseek\")\n        config_path: Configuration file path, defaults to api/config/llm_config.json\n\n    Returns:\n        BaseLLMClient: Client instance\n\n    Raises:\n        ValueError: Model configuration not found in configuration file\n\n    Example:\n        >>> client = get_client_from_config(\"gemini\")\n        >>> result = client.generate(\"Tell me a joke\")\n    \"\"\"\n    config = load_config(config_path)\n    model = model.lower()\n\n    if model not in config:\n        raise ValueError(\n            f\"Model '{model}' configuration not found in configuration file. 
\"\n            f\"Available models: {', '.join(config.keys())}\"\n        )\n\n    model_config = config[model]\n    return get_client(model, **model_config)\n\n\ndef batch_generate(\n    prompts: List[str],\n    model: str,\n    max_workers: int = 5,\n    show_progress: bool = True,\n    config_path: Optional[str] = None,\n    **config\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Batch generate text (with concurrent support)\n\n    Args:\n        prompts: List of prompts\n        model: Model name (\"gemini\" or \"deepseek\")\n        max_workers: Maximum number of concurrent threads, default 5\n        show_progress: Whether to show progress bar, default True\n        config_path: Configuration file path (if provided, use configuration file first)\n        **config: Model configuration parameters (if not using configuration file)\n\n    Returns:\n        List[Dict]: List of results, each element contains:\n            - prompt: Original prompt\n            - result: Generated text (on success)\n            - error: Error message (on failure)\n            - success: Whether successful\n\n    Example:\n        >>> # Using configuration file\n        >>> results = batch_generate(\n        ...     prompts=[\"Question 1\", \"Question 2\", \"Question 3\"],\n        ...     model=\"gemini\",\n        ...     max_workers=3\n        ... )\n\n        >>> # Direct configuration\n        >>> results = batch_generate(\n        ...     prompts=[\"Question 1\", \"Question 2\"],\n        ...     model=\"deepseek\",\n        ...     api_key=\"your-key\",\n        ...     appid=\"your-appid\"\n        ... )\n    \"\"\"\n    if config_path:\n        client = get_client_from_config(model, config_path)\n    else:\n        client = get_client(model, **config)\n\n    return client.batch_generate(\n        prompts=prompts,\n        max_workers=max_workers,\n        show_progress=show_progress\n    )\n\n\n# Export all public interfaces\n__all__ = [\n    # Classes\n    \"BaseLLMClient\",\n    \"GeminiClient\",\n    \"DeepSeekClient\",\n    \"ClaudeClient\",\n    # Functions\n    \"get_client\",\n    \"get_client_from_config\",\n    \"batch_generate\",\n    \"load_config\",\n]\n"
  },
  {
    "path": "benchmarks/api/base.py",
    "content": "\"\"\"\nBase LLM Client Definition\nProvides unified interface specification with retry mechanism and batch processing\n\"\"\"\nfrom abc import ABC, abstractmethod\nfrom typing import Optional, Dict, Any, List\nimport time\nimport random\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\n\n\nclass BaseLLMClient(ABC):\n    \"\"\"\n    Base class for LLM clients, defining unified interface\n\n    All concrete LLM clients (Gemini, DeepSeek, etc.) should inherit from this class\n    Provides unified retry mechanism and batch processing capabilities\n    \"\"\"\n\n    def __init__(self, **config):\n        \"\"\"\n        Initialize client\n\n        Args:\n            **config: Model-specific configuration parameters\n        \"\"\"\n        self.config = config\n        self.max_retries = config.get(\"max_retries\", 3)\n        self.retry_delay = config.get(\"retry_delay\", 2)\n        self._setup()\n\n    @abstractmethod\n    def _setup(self):\n        \"\"\"Setup client (subclasses implement specific initialization logic)\"\"\"\n        pass\n\n    @abstractmethod\n    def _call_api(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Call API to generate text (subclasses implement specific API call logic)\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter\n            max_tokens: Maximum number of tokens to generate\n            **kwargs: Other model-specific parameters\n\n        Returns:\n            Generated text content\n\n        Raises:\n            Exception: Raised when API call fails\n        \"\"\"\n        pass\n\n    def _is_retryable_error(self, error_msg: str) -> bool:\n        \"\"\"\n        Determine if error is retryable\n\n        Args:\n            error_msg: Error message\n\n        Returns:\n            bool: Whether the error is retryable\n        \"\"\"\n        retryable_keywords = [\n            '503', '429', '500', 'timeout', 'timed out', 'deadline',\n            'unavailable', 'failed to connect', 'connection',\n            'rate limit', 'overload'\n        ]\n        return any(keyword in error_msg.lower() for keyword in retryable_keywords)\n\n    def _generate_with_retry(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Generation method with retry mechanism (template method)\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter\n            max_tokens: Maximum number of tokens to generate\n            **kwargs: Other parameters\n\n        Returns:\n            str: Generated text content\n\n        Raises:\n            Exception: Raised when API call fails\n        \"\"\"\n        if not prompt or not prompt.strip():\n            raise ValueError(\"prompt cannot be empty\")\n\n        last_error = None\n\n        for attempt in range(self.max_retries):\n            try:\n                if attempt > 0:\n                    delay = self.retry_delay * (2 ** (attempt - 1))\n                    jitter = random.uniform(0, delay * 0.3)\n                    time.sleep(delay + jitter)\n\n                return self._call_api(prompt, temperature, max_tokens, **kwargs)\n\n            except Exception as e:\n                last_error = e\n                error_msg = str(e)\n\n         
       is_retryable = self._is_retryable_error(error_msg)\n\n                if attempt == self.max_retries - 1 or not is_retryable:\n                    raise Exception(f\"{self.__class__.__name__} API call failed: {error_msg}\")\n\n                print(f\"{self.__class__.__name__} API call failed \"\n                      f\"(attempt {attempt + 1}/{self.max_retries}), \"\n                      f\"will retry in ~{self.retry_delay * (2 ** attempt)}s plus jitter: {error_msg[:100]}\")\n\n        raise Exception(f\"Maximum retry attempts reached ({self.max_retries}): {last_error}\")\n\n    def generate(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Generate text content (public interface)\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter (controls randomness)\n            max_tokens: Maximum number of tokens to generate\n            **kwargs: Other model-specific parameters\n\n        Returns:\n            str: Generated text content\n\n        Raises:\n            ValueError: Parameter error\n            Exception: API call failed\n        \"\"\"\n        return self._generate_with_retry(prompt, temperature, max_tokens, **kwargs)\n\n    def batch_generate(\n        self,\n        prompts: List[str],\n        max_workers: int = 5,\n        show_progress: bool = True,\n        **kwargs\n    ) -> List[Dict[str, Any]]:\n        \"\"\"\n        Batch generate text (with concurrent support)\n\n        Args:\n            prompts: List of prompts\n            max_workers: Maximum number of concurrent threads, default 5\n            show_progress: Whether to show progress bar, default True\n            **kwargs: Other parameters to pass to generate\n\n        Returns:\n            List[Dict]: List of results, each element contains:\n                - prompt: Original prompt\n                - result: Generated text (on success)\n                - error: Error message (on failure)\n                - success: Whether successful\n        \"\"\"\n        try:\n            from tqdm import tqdm\n            has_tqdm = True\n        except ImportError:\n            has_tqdm = False\n            if show_progress:\n                print(\"Warning: tqdm not installed, cannot show progress bar\")\n\n        def process_prompt(prompt: str, index: int) -> Dict[str, Any]:\n            try:\n                result = self.generate(prompt, **kwargs)\n                return {\n                    \"index\": index,\n                    \"prompt\": prompt,\n                    \"result\": result,\n                    \"success\": True\n                }\n            except Exception as e:\n                return {\n                    \"index\": index,\n                    \"prompt\": prompt,\n                    \"error\": str(e),\n                    \"success\": False\n                }\n\n        with ThreadPoolExecutor(max_workers=max_workers) as executor:\n            future_to_index = {\n                executor.submit(process_prompt, prompt, i): i\n                for i, prompt in enumerate(prompts)\n            }\n            if show_progress and has_tqdm:\n                progress = tqdm(\n                    as_completed(future_to_index),\n                    total=len(prompts),\n                    desc=f\"Generating ({self.__class__.__name__})\"\n                )\n            else:\n                progress = as_completed(future_to_index)\n\n  
          temp_results = []\n            for future in progress:\n                try:\n                    result = future.result()\n                    temp_results.append(result)\n                except Exception as e:\n                    index = future_to_index[future]\n                    temp_results.append({\n                        \"index\": index,\n                        \"prompt\": prompts[index],\n                        \"error\": f\"Task execution failed: {str(e)}\",\n                        \"success\": False\n                    })\n\n        results = sorted(temp_results, key=lambda x: x[\"index\"])\n\n        for r in results:\n            r.pop(\"index\", None)\n\n        return results\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(config={self.config})\"\n"
  },
  {
    "path": "benchmarks/api/claude.py",
    "content": "\"\"\"\nClaude API Client Implementation\nBased on Anthropic official SDK\n\"\"\"\nfrom typing import Optional\nfrom anthropic import Anthropic\nfrom .base import BaseLLMClient\n\n\nclass ClaudeClient(BaseLLMClient):\n    \"\"\"\n    Claude API Client\n\n    Example:\n        >>> client = ClaudeClient(\n        ...     api_key=\"your-api-key\",\n        ...     model_name=\"claude-sonnet-4-20250514\"\n        ... )\n        >>> response = client.generate(\"Tell me a joke\")\n    \"\"\"\n\n    def _setup(self):\n        \"\"\"Initialize Claude client\"\"\"\n        self.api_key = self.config.get(\"api_key\")\n        self.model_name = self.config.get(\"model_name\", \"claude-sonnet-4-20250514\")\n        self.base_url = self.config.get(\"base_url\")\n        self.default_max_tokens = self.config.get(\"max_new_tokens\", 1024)\n        self.default_temperature = self.config.get(\"temperature\", 1.0)\n\n        if not self.api_key:\n            raise ValueError(\"api_key is a required parameter\")\n\n        client_kwargs = {\"api_key\": self.api_key}\n        if self.base_url:\n            client_kwargs[\"base_url\"] = self.base_url\n\n        self.client = Anthropic(**client_kwargs)\n\n    def _call_api(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Call Claude API to generate text\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter (0.0-1.0), default 1.0\n            max_tokens: Maximum number of tokens to generate, default 1024\n            **kwargs: Other Claude-specific parameters, such as:\n                - system: System prompt\n                - top_p: Nucleus sampling parameter\n                - top_k: Top-k sampling parameter\n\n        Returns:\n            str: Generated text content\n\n        Raises:\n            Exception: Raised when API call fails\n        \"\"\"\n        if temperature is None:\n            temperature = self.default_temperature\n        if max_tokens is None:\n            max_tokens = self.default_max_tokens\n\n        system = kwargs.pop(\"system\", None)\n\n        request_params = {\n            \"model\": self.model_name,\n            \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n            \"max_tokens\": max_tokens,\n        }\n\n        if temperature is not None:\n            request_params[\"temperature\"] = temperature\n        if system:\n            request_params[\"system\"] = system\n\n        for key in [\"top_p\", \"top_k\", \"stop_sequences\"]:\n            if key in kwargs:\n                request_params[key] = kwargs.pop(key)\n\n        response = self.client.messages.create(**request_params)\n\n        if response and response.content:\n            text_blocks = [\n                block.text for block in response.content\n                if hasattr(block, 'text')\n            ]\n            if text_blocks:\n                return \"\".join(text_blocks)\n            else:\n                raise Exception(\"API returned empty response\")\n        else:\n            raise Exception(\"API returned invalid response\")\n"
  },
  {
    "path": "benchmarks/api/config/llm_config.json",
    "content": "{\n  \"gemini\": {\n    \"project\": \"\",\n    \"location\": \"\",\n    \"model_name\": \"gemini-2.5-flash-lite\",\n    \"credentials_path\": \"\",\n    \"max_new_tokens\": 10000,\n    \"temperature\": 0.01,\n    \"max_retries\": 3,\n    \"retry_delay\": 2\n  },\n  \"deepseek\": {\n    \"api_key\": \"\",\n    \"base_url\": \"\",\n    \"model_name\": \"deepseek-r1\",\n    \"appid\": \"\",\n    \"max_new_tokens\": 10000,\n    \"temperature\": 0.01,\n    \"max_retries\": 3,\n    \"retry_delay\": 2\n  },\n  \"claude\": {\n    \"api_key\": \"\",\n    \"base_url\": \"\",\n    \"model_name\": \"\",\n    \"max_new_tokens\": 10000,\n    \"temperature\": 0.01,\n    \"max_retries\": 3,\n    \"retry_delay\": 2\n  }\n}\n"
  },
  {
    "path": "benchmarks/api/deepseek.py",
    "content": "\"\"\"\nDeepSeek API Client Implementation\nCall DeepSeek model through Baidu Qianfan platform\n\"\"\"\nfrom typing import Optional\nfrom openai import OpenAI\nfrom .base import BaseLLMClient\n\n\nclass DeepSeekClient(BaseLLMClient):\n    \"\"\"\n    DeepSeek API Client (through Baidu Qianfan platform)\n\n    Example:\n        >>> client = DeepSeekClient(\n        ...     api_key=\"your-api-key\",\n        ...     base_url=\"https://qianfan.baidubce.com/v2\",\n        ...     model_name=\"deepseek-r1\",\n        ...     appid=\"your-appid\"\n        ... )\n        >>> response = client.generate(\"Tell me a joke\")\n    \"\"\"\n\n    def _setup(self):\n        \"\"\"Initialize DeepSeek client\"\"\"\n        self.api_key = self.config.get(\"api_key\")\n        self.base_url = self.config.get(\"base_url\", \"https://qianfan.baidubce.com/v2\")\n        self.model_name = self.config.get(\"model_name\", \"deepseek-r1\")\n        self.appid = self.config.get(\"appid\")\n        self.default_max_tokens = self.config.get(\"max_new_tokens\", 300)\n        self.default_temperature = self.config.get(\"temperature\", 0.7)\n\n        if not self.api_key:\n            raise ValueError(\"api_key is a required parameter\")\n        if not self.appid:\n            raise ValueError(\"appid is a required parameter\")\n\n        self.client = OpenAI(\n            api_key=self.api_key,\n            base_url=self.base_url,\n            default_headers={\"appid\": self.appid}\n        )\n\n    def _call_api(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Call DeepSeek API to generate text\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter (0.0-2.0), default from config or 0.7\n            max_tokens: Maximum number of tokens to generate, default from config or 300\n            **kwargs: Other DeepSeek-specific parameters\n\n        Returns:\n            str: Generated text content\n\n        Raises:\n            Exception: Raised when API call fails\n        \"\"\"\n        if temperature is None:\n            temperature = self.default_temperature\n        if max_tokens is None:\n            max_tokens = self.default_max_tokens\n\n        request_params = {\n            \"model\": self.model_name,\n            \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n            \"temperature\": temperature,\n            \"max_tokens\": max_tokens,\n            \"stream\": False\n        }\n\n        request_params.update(kwargs)\n\n        response = self.client.chat.completions.create(**request_params)\n\n        if response and response.choices:\n            content = response.choices[0].message.content\n            if content:\n                return content\n            else:\n                raise Exception(\"API returned empty response\")\n        else:\n            raise Exception(\"API returned invalid response\")\n"
  },
  {
    "path": "benchmarks/api/example.py",
    "content": "\"\"\"\nLLM API Usage Examples\nDemonstrates various calling methods and use cases\n\"\"\"\n\n# ============================================================================\n# Example 1: Using Configuration File (Simplest)\n# ============================================================================\ndef example1_use_config():\n    \"\"\"Load and use from configuration file\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 1: Using Configuration File\")\n    print(\"=\" * 60)\n\n    # Create client from configuration file\n    client = get_client_from_config(\"gemini\")\n\n    # Generate text\n    response = client.generate(\"Explain what AI is in one sentence\")\n    print(f\"Answer: {response}\\n\")\n\n\n# ============================================================================\n# Example 2: Direct Parameters\n# ============================================================================\ndef example2_direct_params():\n    \"\"\"Pass configuration parameters directly\"\"\"\n    from api import get_client\n\n    print(\"=\" * 60)\n    print(\"Example 2: Direct Parameters\")\n    print(\"=\" * 60)\n\n    # Gemini\n    gemini_client = get_client(\n        \"gemini\",\n        project=\"your-project\",\n        location=\"us-central1\",\n        model_name=\"gemini-2.5-pro\",\n        credentials_path=\"path/to/credentials.json\"\n    )\n\n    # DeepSeek\n    deepseek_client = get_client(\n        \"deepseek\",\n        api_key=\"your-api-key\",\n        appid=\"your-appid\",\n        base_url=\"https://qianfan.baidubce.com/v2\"\n    )\n\n    # Usage\n    response = gemini_client.generate(\"Hello\")\n    print(f\"Gemini: {response}\\n\")\n\n\n# ============================================================================\n# Example 3: Batch Generation (Concurrent)\n# ============================================================================\ndef example3_batch_generate():\n    \"\"\"Batch text generation with concurrent support\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 3: Batch Generation (Concurrent)\")\n    print(\"=\" * 60)\n\n    prompts = [\n        \"What is machine learning?\",\n        \"Explain deep learning\",\n        \"Principles of neural networks\",\n        \"What is natural language processing?\",\n        \"Applications of computer vision\"\n    ]\n\n    # Use client instance's batch_generate method (recommended)\n    client = get_client_from_config(\"gemini\")\n    results = client.batch_generate(\n        prompts=prompts,\n        max_workers=3,  # 3 concurrent threads\n        show_progress=True  # Show progress bar\n    )\n\n    # Process results\n    for i, item in enumerate(results, 1):\n        print(f\"\\nQuestion {i}: {item['prompt']}\")\n        if item['success']:\n            print(f\"Answer: {item['result'][:100]}...\")\n        else:\n            print(f\"Error: {item['error']}\")\n\n\n# ============================================================================\n# Example 4: Custom Generation Parameters\n# ============================================================================\ndef example4_custom_params():\n    \"\"\"Custom generation parameters\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 4: Custom Generation Parameters\")\n    print(\"=\" * 60)\n\n    client = get_client_from_config(\"deepseek\")\n\n    # Creative generation (high temperature)\n    creative = client.generate(\n    
    \"Write a poem about spring\",\n        temperature=0.9,\n        max_tokens=200\n    )\n    print(f\"Creative output:\\n{creative}\\n\")\n\n    # Precise generation (low temperature)\n    precise = client.generate(\n        \"What is 1+1?\",\n        temperature=0.1,\n        max_tokens=50\n    )\n    print(f\"Precise output:\\n{precise}\\n\")\n\n\n# ============================================================================\n# Example 5: Error Handling\n# ============================================================================\ndef example5_error_handling():\n    \"\"\"Demonstrate error handling\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 5: Error Handling\")\n    print(\"=\" * 60)\n\n    try:\n        client = get_client_from_config(\"gemini\")\n\n        # Normal call\n        response = client.generate(\"Hello\")\n        print(f\"Success: {response}\")\n\n        # Empty prompt (will raise ValueError)\n        response = client.generate(\"\")\n\n    except ValueError as e:\n        print(f\"Parameter error: {e}\")\n    except Exception as e:\n        print(f\"API call failed: {e}\")\n\n\n# ============================================================================\n# Example 6: Switch Models\n# ============================================================================\ndef example6_switch_models():\n    \"\"\"Switch between different models\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 6: Switch Models\")\n    print(\"=\" * 60)\n\n    question = \"What is quantum computing?\"\n\n    for model_name in [\"gemini\", \"deepseek\"]:\n        try:\n            client = get_client_from_config(model_name)\n            response = client.generate(question)\n            print(f\"\\n{model_name.upper()}'s answer:\")\n            print(response[:150] + \"...\")\n        except Exception as e:\n            print(f\"\\n{model_name} call failed: {e}\")\n\n\n# ============================================================================\n# Example 7: Real Application - User Profile Generation\n# ============================================================================\ndef example7_user_portrait():\n    \"\"\"Real application: Generate user profile based on user behavior\"\"\"\n    from api import get_client_from_config\n\n    print(\"=\" * 60)\n    print(\"Example 7: User Profile Generation\")\n    print(\"=\" * 60)\n\n    # User behavior data\n    user_behavior = \"\"\"\n    User's recently watched videos:\n    1. Machine Learning Tutorial\n    2. Python Programming Tips\n    3. Deep Learning Practical Projects\n    4. Data Analysis Case Studies\n    5. Latest AI Trends\n    \"\"\"\n\n    prompt = f\"\"\"Based on the following user behavior data, generate a concise user profile:\n\n{user_behavior}\n\nRequirements:\n1. Summarize user's areas of interest\n2. Infer user's skill level\n3. 
Provide 3-5 precise tags\n\"\"\"\n\n    client = get_client_from_config(\"gemini\")\n    portrait = client.generate(prompt, temperature=0.5)\n\n    print(\"User Profile:\")\n    print(portrait)\n\n\n# ============================================================================\n# Example 8: Direct Import of Classes\n# ============================================================================\ndef example8_direct_import():\n    \"\"\"Import client classes directly\"\"\"\n    from api import GeminiClient, DeepSeekClient\n\n    print(\"=\" * 60)\n    print(\"Example 8: Direct Import of Client Classes\")\n    print(\"=\" * 60)\n\n    # Direct instantiation\n    gemini = GeminiClient(\n        project=\"your-project\",\n        location=\"us-central1\"\n    )\n\n    deepseek = DeepSeekClient(\n        api_key=\"your-key\",\n        appid=\"your-appid\"\n    )\n\n    print(\"Clients created successfully\")\n    print(f\"Gemini client: {gemini}\")\n    print(f\"DeepSeek client: {deepseek}\")\n\n\n# ============================================================================\n# Main Function\n# ============================================================================\ndef main():\n    \"\"\"Run all examples\"\"\"\n    examples = [\n        (\"Using Configuration File\", example1_use_config),\n        (\"Direct Parameters\", example2_direct_params),\n        (\"Batch Generation\", example3_batch_generate),\n        (\"Custom Parameters\", example4_custom_params),\n        (\"Error Handling\", example5_error_handling),\n        (\"Switch Models\", example6_switch_models),\n        (\"User Profile Generation\", example7_user_portrait),\n        (\"Direct Import Classes\", example8_direct_import),\n    ]\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"LLM API Usage Examples\")\n    print(\"=\" * 60)\n    print(\"\\nAvailable examples:\")\n    for i, (name, _) in enumerate(examples, 1):\n        print(f\"{i}. {name}\")\n\n    print(\"\\nNote: Please ensure api/config/llm_config.json is configured before running\")\n    print(\"\\n\" + \"=\" * 60 + \"\\n\")\n\n    # Uncomment the lines below to run specific examples\n    # example1_use_config()\n    # example2_direct_params()\n    # example3_batch_generate()\n    # example4_custom_params()\n    # example5_error_handling()\n    # example6_switch_models()\n    # example7_user_portrait()\n    # example8_direct_import()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
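  {
    "path": "benchmarks/api/example_config_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical helper, not a shipped module): shows one\nplausible shape for api/config/llm_config.json. The \"gemini\" keys are the\nones read in GeminiClient._setup(); the \"deepseek\" keys match the\nDeepSeekClient constructor used in the usage examples. The name-keyed\ntop-level layout is an assumption inferred from get_client_from_config().\n\"\"\"\nimport json\n\nEXAMPLE_CONFIG = {\n    \"gemini\": {\n        \"project\": \"your-project\",        # required by GeminiClient._setup\n        \"location\": \"us-central1\",        # required by GeminiClient._setup\n        \"model_name\": \"gemini-2.5-pro\",   # optional; this is the default\n        \"credentials_path\": \"path/to/credentials.json\",  # optional\n        \"max_new_tokens\": 1024,            # optional generation default\n        \"temperature\": 0.7,                # optional generation default\n    },\n    \"deepseek\": {\n        \"api_key\": \"your-key\",    # matches DeepSeekClient(api_key=...)\n        \"appid\": \"your-appid\",    # matches DeepSeekClient(appid=...)\n    },\n}\n\nif __name__ == \"__main__\":\n    # Write the sketch so get_client_from_config() can pick it up.\n    with open(\"api/config/llm_config.json\", \"w\", encoding=\"utf-8\") as f:\n        json.dump(EXAMPLE_CONFIG, f, indent=2)\n    print(\"Wrote example config to api/config/llm_config.json\")\n"
  },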
  {
    "path": "benchmarks/api/gemini.py",
    "content": "\"\"\"\nGemini API Client Implementation\nBased on Google Vertex AI's Gemini model\n\"\"\"\nimport os\nfrom typing import Optional\nfrom vertexai.generative_models import GenerativeModel\nimport vertexai\nfrom .base import BaseLLMClient\n\n\nclass GeminiClient(BaseLLMClient):\n    \"\"\"\n    Gemini API Client\n\n    Example:\n        >>> client = GeminiClient(\n        ...     project=\"your-project\",\n        ...     location=\"us-central1\",\n        ...     model_name=\"gemini-2.5-pro\",\n        ...     credentials_path=\"path/to/credentials.json\"\n        ... )\n        >>> response = client.generate(\"Tell me a joke\")\n    \"\"\"\n\n    def _setup(self):\n        \"\"\"Initialize Gemini client\"\"\"\n        self.project = self.config.get(\"project\")\n        self.location = self.config.get(\"location\")\n        self.model_name = self.config.get(\"model_name\", \"gemini-2.5-pro\")\n        credentials_path = self.config.get(\"credentials_path\")\n        self.default_max_tokens = self.config.get(\"max_new_tokens\")\n        self.default_temperature = self.config.get(\"temperature\")\n\n        if credentials_path:\n            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path\n\n        if not self.project or not self.location:\n            raise ValueError(\"project and location are required parameters\")\n\n        vertexai.init(project=self.project, location=self.location)\n        self.model = GenerativeModel(self.model_name)\n\n    def _call_api(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        \"\"\"\n        Call Gemini API to generate text\n\n        Args:\n            prompt: Input prompt\n            temperature: Temperature parameter (0.0-1.0)\n            max_tokens: Maximum number of tokens to generate\n            **kwargs: Other Gemini-specific parameters\n\n        Returns:\n            str: Generated text content\n\n        Raises:\n            Exception: Raised when API call fails\n        \"\"\"\n        if temperature is None:\n            temperature = self.default_temperature\n        if max_tokens is None:\n            max_tokens = self.default_max_tokens\n\n        generation_config = {}\n        if temperature is not None:\n            generation_config[\"temperature\"] = temperature\n        if max_tokens is not None:\n            generation_config[\"max_output_tokens\"] = max_tokens\n\n        if generation_config:\n            response = self.model.generate_content(\n                prompt,\n                generation_config=generation_config\n            )\n        else:\n            response = self.model.generate_content(prompt)\n\n        if response and response.text:\n            return response.text\n        else:\n            raise Exception(\"API returned empty response\")\n"
  },
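  {
    "path": "benchmarks/api/example_custom_client_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical, not a shipped client): the smallest\nBaseLLMClient subclass, mirroring the _setup()/_call_api() hooks that\nGeminiClient implements. It assumes BaseLLMClient exposes constructor\nkwargs via self.config and routes generate() through _call_api(); both\nassumptions are inferred from gemini.py rather than stated in base.py.\n\"\"\"\nfrom typing import Optional\nfrom .base import BaseLLMClient\n\n\nclass EchoClient(BaseLLMClient):\n    \"\"\"Toy client that echoes prompts; handy for wiring tests offline.\"\"\"\n\n    def _setup(self):\n        # Read optional defaults the same way GeminiClient._setup does.\n        self.default_max_tokens = self.config.get(\"max_new_tokens\")\n        self.default_temperature = self.config.get(\"temperature\")\n\n    def _call_api(\n        self,\n        prompt: str,\n        temperature: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        **kwargs\n    ) -> str:\n        # temperature is accepted for interface parity but ignored here.\n        # Fall back to the configured default, as GeminiClient._call_api does.\n        if max_tokens is None:\n            max_tokens = self.default_max_tokens\n        text = f\"[echo] {prompt}\"\n        return text[:max_tokens] if max_tokens else text\n"
  },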
  {
    "path": "benchmarks/benchmark/__init__.py",
    "content": "from benchmark.benchmark import Benchmark\nfrom benchmark.base_generator  import Generator\nfrom benchmark.generation_runner import GenerationRunner\n\n__version__ = \"0.1.0\"\n\n__all__ = [\n    \"Benchmark\",\n    \"Generator\",\n    \"GenerationRunner\",\n]\n\n"
  },
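  {
    "path": "benchmarks/example_run_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of driving the public API re-exported by\nbenchmark/__init__.py, following the usage example in Benchmark's own\ndocstring. MyGenerator and the paths below are placeholders; any Generator\nsubclass that implements _generate_standard() (see base_generator.py)\nwould work here.\n\"\"\"\nfrom benchmark import Benchmark\n\n# Hypothetical concrete generator (placeholder import).\nfrom my_generators import MyGenerator\n\nif __name__ == \"__main__\":\n    benchmark = Benchmark(\n        model_path=\"path/to/model\",  # used to build the tokenizer\n        task_types=None,             # None = every task in the latest table\n        splits=[\"test\"],\n        data_dir=\"./data\",\n    )\n    generator = MyGenerator(\"path/to/model\")\n    benchmark.run(generator=generator, output_dir=\"./results\")\n"
  },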
  {
    "path": "benchmarks/benchmark/base_generator.py",
    "content": "import os\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, List, Any, Optional\nfrom collections import defaultdict\nfrom benchmark.console import *\n\n\n# Global configuration: tasks that should disable optimizations (long prompts may cause issues)\n# Used by vLLM-based generators to control chunked_prefill and prefix_caching\nDISABLE_OPTIMIZATIONS_FOR_TASKS = [\"rec_reason\", \"interactive\"]\n\n\nclass Generator(ABC):\n    \"\"\"\n    Abstract base class for generation models\n\n    All generation models should inherit from this class.\n    Subclasses must implement _generate_standard() to support the generate() method.\n    \"\"\"\n    \n    def __init__(\n        self,\n        **kwargs\n    ):\n        \"\"\"\n        Args:\n            num_return_sequences: Number of candidates to generate per prompt\n            max_new_tokens: Maximum number of tokens to generate\n            **kwargs: Other generation parameters\n        \"\"\"\n        pass\n\n    def __str__(self) -> str:\n        \"\"\"\n        Return model name (for directory naming, remove path separators)\n\n        This method is shared across all generator implementations.\n        Subclasses must set self.model_name for this method to work.\n\n        Returns:\n            str: Model name\n        \"\"\"\n        return os.path.basename(self.model_name.rstrip('/'))\n\n    def generate(\n        self,\n        prompts: Dict[str, str],\n        **kwargs\n    ) -> tuple:\n        \"\"\"\n        Batch text generation\n\n        Supports two-stage generation for recommendation tasks:\n        - Stage 1: Generate thinking content with top_p/top_k sampling (if thinking enabled)\n        - Stage 2: Generate SID sequences with beam search and prompt_token\n\n        This method is shared across all generator implementations to reduce code duplication.\n        Subclasses must implement _generate_standard() for this method to work.\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            **kwargs: Optional generation parameters (will override initialization parameters)\n\n        Returns:\n            Tuple of two dicts:\n            - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}\n            - Second dict: {sample_id: [cum_logprob_1, cum_logprob_2, ...]} (only for beam search)\n        \"\"\"\n        prompt_token = kwargs.get(\"prompt_token\", None)\n        enable_thinking = kwargs.get(\"enable_thinking\", False)\n        max_new_thinking_tokens = kwargs.get(\"max_new_thinking_tokens\", None)\n        target_tokens = kwargs.get(\"target_tokens\", None)\n\n        # Check if this is a classification task (has target_tokens parameter)\n        is_classification = target_tokens is not None\n\n        # Generation logic based on task type:\n        # A: has max_new_thinking_tokens + has prompt_token (recommendation tasks)\n        # B: has max_new_thinking_tokens + no prompt_token (caption tasks)\n        # C: no max_new_thinking_tokens (standard tasks)\n        # D: classification task + no think\n        # E: classification task + think\n        if is_classification:\n            # Classification task scenarios (D & E)\n            if enable_thinking:\n                # E: Classification with thinking\n                console.print(\n                    f\"Two-stage classification with thinking enabled: thinking (max_new_thinking_tokens={max_new_thinking_tokens}) + logprobs extraction for {target_tokens}\",\n                    style=warning_style,\n        
        )\n                return self._generate_two_stage_classification_with_thinking(prompts, **kwargs)\n            else:\n                # D: Classification without thinking\n                console.print(\n                    f\"Classification task: extracting logprobs for tokens {target_tokens}\",\n                    style=warning_style,\n                )\n                # Remove target_tokens from kwargs to avoid passing it twice\n                kwargs_classification = kwargs.copy()\n                kwargs_classification.pop(\"target_tokens\", None)\n                results, _, mfu_stats = self.extract_token_logprobs(prompts, target_tokens, **kwargs_classification)\n\n                self.mfu_stats = mfu_stats\n                return results, {}\n        elif max_new_thinking_tokens:\n            if enable_thinking:\n                # A & B with thinking: two-stage generation\n                console.print(\n                    f\"Two-stage generation enabled: thinking (max_new_thinking_tokens={max_new_thinking_tokens}) + prompt_token ({prompt_token})\",\n                    style=warning_style,\n                )\n                return self._generate_two_stage_with_thinking(prompts, **kwargs)\n            else:\n                # A & B without thinking\n                if prompt_token:\n                    # A without thinking: single-stage with prompt_token (beam search)\n                    console.print(\n                        f\"Single-stage generation with prompt_token ({prompt_token})\",\n                        style=warning_style,\n                    )\n                    prompts_with_token = {\n                        sample_id: prompt + prompt_token\n                        for sample_id, prompt in prompts.items()\n                    }\n                    results, logprobs, mfu_stats = self._generate_standard(prompts_with_token, **kwargs)\n                    self.mfu_stats = mfu_stats\n                    return results, logprobs\n                else:\n                    # B without thinking: single-stage sampling\n                    console.print(\n                        f\"Warning: max_new_thinking_tokens={max_new_thinking_tokens} is set but \"\n                        f\"enable_thinking=False and prompt_token=None. 
The max_new_thinking_tokens parameter will be ignored.\",\n                        style=warning_style,\n                    )\n                    results, logprobs, mfu_stats = self._generate_standard(prompts, **kwargs)\n                    self.mfu_stats = mfu_stats\n                    return results, logprobs\n        else:\n            # C: standard single-stage sampling\n            results, logprobs, mfu_stats = self._generate_standard(prompts, **kwargs)\n            self.mfu_stats = mfu_stats\n            return results, logprobs\n\n\n    def get_hardware_info(self) -> Dict[str, Any]:\n        \"\"\"\n        Get GPU hardware information for MFU calculation\n\n        Default implementation that works for all generators.\n        Handles both single-machine and Ray-based multi-machine setups.\n\n        Returns:\n            Dictionary containing:\n            - gpu_model: str, GPU model name\n            - gpu_count: int, total number of GPUs used\n            - gpu_tflops: float, theoretical peak TFLOPS for BF16/FP16\n            - tensor_parallel_size: int, tensor parallelism size\n            - gpu_memory_total_gb: float, total GPU memory in GB\n        \"\"\"\n        from benchmark.gpu_utils import get_gpu_info\n\n        gpu_info = get_gpu_info()\n\n        # Calculate total GPU count\n        tensor_parallel_size = getattr(self, 'tensor_parallel_size', 1)\n\n        # For Ray-based generators, multiply by number of workers\n        if hasattr(self, 'workers') and self.workers:\n            num_workers = len(self.workers)\n            total_gpus = num_workers * tensor_parallel_size\n        else:\n            # For single-machine generators\n            total_gpus = tensor_parallel_size\n\n        gpu_info[\"gpu_count\"] = total_gpus\n        gpu_info[\"tensor_parallel_size\"] = tensor_parallel_size\n\n        # Add worker info for Ray-based generators\n        if hasattr(self, 'workers'):\n            gpu_info[\"num_workers\"] = len(self.workers) if self.workers else 0\n\n        return gpu_info\n\n    def _generate_two_stage_with_thinking(\n        self,\n        prompts: Dict[str, str],\n        **kwargs\n    ) -> tuple:\n        \"\"\"\n        Two-stage generation with thinking mode\n\n        Stage 1: Generate thinking content with top_p/top_k sampling until </think>\n        Stage 2: Continue generation (with prompt_token if provided, beam search or sampling)\n\n        This method is shared across all generator implementations to reduce code duplication.\n        Subclasses must implement _generate_standard() for this method to work.\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            **kwargs: Optional generation parameters\n\n        Returns:\n            Tuple of two dicts:\n            - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}\n            - Second dict: {sample_id: [cum_logprob_1, cum_logprob_2, ...]} (only for beam search)\n        \"\"\"\n        prompt_token = kwargs.get(\"prompt_token\", None)\n        console.print(\n            \"Stage 1/2: Generating thinking content with top_p/top_k sampling...\",\n            style=warning_style,\n        )\n\n        # Stage 1: Build kwargs for thinking generation (remove beam search, add stop)\n        kwargs_stage1 = kwargs.copy()\n        kwargs_stage1.pop(\"num_beams\", None)  # Remove beam search to force sampling mode\n        kwargs_stage1[\"stop\"] = [\"</think>\"]  # Stop at </think> tag\n\n        # Use num_return_thinking_sequences for stage 1 if 
specified\n        num_return_thinking = kwargs.get(\"num_return_thinking_sequences\", 1)\n        kwargs_stage1[\"num_return_sequences\"] = num_return_thinking\n\n        # Use max_new_thinking_tokens for stage 1 if specified\n        max_new_thinking_tokens = kwargs.get(\"max_new_thinking_tokens\", 1000)\n        kwargs_stage1[\"max_new_tokens\"] = max_new_thinking_tokens\n\n        # Call _generate_standard for stage 1 (ignoring logprobs as they're not used)\n        stage1_results, _, stage1_mfu_stats = self._generate_standard(prompts, **kwargs_stage1)\n\n        # Prepare prompts for stage 2 by appending thinking + prompt_token\n        # Each sample will have multiple thinking candidates\n        stage2_prompts = {}\n        sample_to_thinking_count = {}  # Track how many thinking candidates each sample has\n\n        for sample_id, thinking_list in stage1_results.items():\n            # Use ALL thinking candidates (not just the first one)\n            sample_to_thinking_count[sample_id] = len(thinking_list)\n\n            for idx, thinking_text in enumerate(thinking_list):\n                # Create unique ID for each thinking candidate\n                thinking_sample_id = f\"{sample_id}_thinking_{idx}\"\n\n                # Append </think> + prompt_token (if provided)\n                # If model didn't generate </think>, treat entire output as thinking\n                if prompt_token:\n                    full_thinking = thinking_text + \"</think>\\n\" + prompt_token\n                else:\n                    full_thinking = thinking_text + \"</think>\\n\"\n                stage2_prompt = prompts[sample_id] + full_thinking\n                stage2_prompts[thinking_sample_id] = stage2_prompt\n\n        # Stage 2: Determine generation mode based on num_beams\n        kwargs_stage2 = kwargs.copy()\n        original_num_sequences = kwargs.get(\"num_return_sequences\", 1)\n        original_num_beams = kwargs.get(\"num_beams\", None)\n\n        # Determine if stage 2 uses beam search or sampling\n        use_beam_search_stage2 = original_num_beams is not None\n\n        if use_beam_search_stage2:\n            # Beam search mode: num_beams is directly used per thinking candidate\n            beams_per_thinking = original_num_beams\n\n            # Validate configuration: total sequences should match\n            if original_num_sequences != beams_per_thinking * num_return_thinking:\n                raise ValueError(\n                    f\"Configuration error: num_return_sequences ({original_num_sequences}) must equal \"\n                    f\"num_beams ({beams_per_thinking}) * num_return_thinking_sequences ({num_return_thinking}) = \"\n                    f\"{beams_per_thinking * num_return_thinking}. 
\"\n                    f\"Please adjust your parameters accordingly.\"\n                )\n\n            kwargs_stage2[\"num_return_sequences\"] = beams_per_thinking\n            kwargs_stage2[\"num_beams\"] = beams_per_thinking\n\n            console.print(\n                f\"Stage 2/2: Generating sequences with beam search for {len(stage2_prompts)} thinking candidates...\",\n                style=warning_style,\n            )\n            console.print(\n                f\"Each thinking candidate will use beam_width={beams_per_thinking}, return {beams_per_thinking} sequences \"\n                f\"({num_return_thinking} thinking × {beams_per_thinking} = {num_return_thinking * beams_per_thinking} total per sample)\",\n                style=warning_style,\n            )\n        else:\n            # Sampling mode: each thinking generates 1 result\n            kwargs_stage2[\"num_return_sequences\"] = 1\n            kwargs_stage2.pop(\"num_beams\", None)  # Remove num_beams to use sampling\n\n            console.print(\n                f\"Stage 2/2: Generating sequences with sampling for {len(stage2_prompts)} thinking candidates...\",\n                style=warning_style,\n            )\n            console.print(\n                f\"Each thinking candidate will generate 1 sequence \"\n                f\"({num_return_thinking} thinking × 1 = {num_return_thinking} total per sample)\",\n                style=warning_style,\n            )\n\n        # Call _generate_standard for stage 2\n        stage2_results, stage2_logprobs, stage2_mfu_stats = self._generate_standard(stage2_prompts, **kwargs_stage2)\n\n        # Merge mfu_stats from both stages\n        self.mfu_stats = {}\n        for sample_id, stats in stage1_mfu_stats.items():\n            self.mfu_stats[sample_id] = {\n                \"input_tokens\": stats[\"input_tokens\"].copy(),\n                \"output_tokens\": stats[\"output_tokens\"].copy(),\n                \"times\": stats[\"times\"].copy()\n            }\n\n        # Group stage2 stats by original_id first\n        stage2_by_original = defaultdict(lambda: {\"input_tokens\": [], \"output_tokens\": [], \"times\": []})\n        for thinking_id, stats in stage2_mfu_stats.items():\n            original_id = thinking_id.rsplit(\"_thinking_\", 1)[0]\n            stage2_by_original[original_id][\"input_tokens\"].extend(stats[\"input_tokens\"])\n            stage2_by_original[original_id][\"output_tokens\"].extend(stats[\"output_tokens\"])\n            stage2_by_original[original_id][\"times\"].extend(stats[\"times\"])\n\n        # Aggregate: sum tokens, max time\n        for original_id, stats in stage2_by_original.items():\n            self.mfu_stats[original_id][\"input_tokens\"].append(sum(stats[\"input_tokens\"]))\n            self.mfu_stats[original_id][\"output_tokens\"].append(sum(stats[\"output_tokens\"]))\n            self.mfu_stats[original_id][\"times\"].append(max(stats[\"times\"]))\n\n        # Merge results back by original sample_id\n        # Combine thinking + prompt_token + SID into final generation\n        final_results = defaultdict(list)\n        final_logprobs = defaultdict(list)\n\n        for thinking_sample_id, sid_sequences in stage2_results.items():\n            # Extract original sample_id and thinking index\n            # Format: \"sampleID_thinking_N\"\n            parts = thinking_sample_id.rsplit(\"_thinking_\", 1)\n            original_sample_id = parts[0]\n            thinking_idx = int(parts[1])\n\n            # Get the corresponding thinking 
text from stage 1\n            thinking_text = stage1_results[original_sample_id][thinking_idx]\n\n            # Combine thinking + prompt_token + SID for each sequence\n            for sid_seq in sid_sequences:\n                # Format: <think>thinking_text</think>\\n<|sid_begin|>sid_sequence\n                combined = f\"{thinking_text}</think>\\n{prompt_token or ''}{sid_seq}\"\n                final_results[original_sample_id].append(combined)\n\n            # Also merge logprobs if available (from stage 2 beam search)\n            if thinking_sample_id in stage2_logprobs:\n                final_logprobs[original_sample_id].extend(stage2_logprobs[thinking_sample_id])\n\n        return (dict(final_results), dict(final_logprobs))\n\n    def _generate_two_stage_classification_with_thinking(\n        self,\n        prompts: Dict[str, str],\n        **kwargs\n    ) -> tuple:\n        \"\"\"\n        Two-stage generation for classification tasks with thinking mode\n\n        Stage 1: Generate thinking content with top_p/top_k sampling until </think>\n        Stage 2: Extract logprobs for target tokens for each thinking candidate\n\n        This method is shared across all generator implementations to reduce code duplication.\n        Subclasses must implement _generate_standard() and extract_token_logprobs() for this method to work.\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            **kwargs: Optional generation parameters\n\n        Returns:\n            Tuple of two dicts:\n            - First dict: {sample_id: [\"<think>thinking_1</think>\\n{'是': 0.8, '否': 0.2}\", ...]}\n            - Second dict: {} (empty, no logprobs for classification)\n        \"\"\"\n        # target_tokens is guaranteed to be in kwargs (checked in generate() method)\n        target_tokens = kwargs[\"target_tokens\"]\n\n        console.print(\n            \"Stage 1/2: Generating thinking content with top_p/top_k sampling...\",\n            style=warning_style,\n        )\n\n        # Stage 1: Build kwargs for thinking generation (remove beam search, add stop)\n        kwargs_stage1 = kwargs.copy()\n        kwargs_stage1.pop(\"num_beams\", None)  # Remove beam search to force sampling mode\n        kwargs_stage1.pop(\"target_tokens\", None)  # Remove target_tokens for stage 1\n        kwargs_stage1[\"stop\"] = [\"</think>\"]  # Stop at </think> tag\n\n        # Use num_return_thinking_sequences for stage 1 if specified\n        num_return_thinking = kwargs.get(\"num_return_thinking_sequences\", 1)\n        kwargs_stage1[\"num_return_sequences\"] = num_return_thinking\n\n        # Use max_new_thinking_tokens for stage 1 if specified\n        max_new_thinking_tokens = kwargs.get(\"max_new_thinking_tokens\", 1000)\n        kwargs_stage1[\"max_new_tokens\"] = max_new_thinking_tokens\n\n        # Call _generate_standard for stage 1 (ignoring logprobs as they're not used)\n        stage1_results, _, stage1_mfu_stats = self._generate_standard(prompts, **kwargs_stage1)\n\n        # Prepare prompts for stage 2 by appending thinking + </think>\n        # Each sample will have multiple thinking candidates\n        stage2_prompts = {}\n        sample_to_thinking_count = {}  # Track how many thinking candidates each sample has\n\n        for sample_id, thinking_list in stage1_results.items():\n            # Use ALL thinking candidates (not just the first one)\n            sample_to_thinking_count[sample_id] = len(thinking_list)\n\n            for idx, thinking_text in enumerate(thinking_list):\n         
       # Create unique ID for each thinking candidate\n                thinking_sample_id = f\"{sample_id}_thinking_{idx}\"\n\n                # Append </think> to complete the thinking tag\n                full_thinking = thinking_text + f\"</think>\\n\"\n                stage2_prompt = prompts[sample_id] + full_thinking\n                stage2_prompts[thinking_sample_id] = stage2_prompt\n\n        console.print(\n            f\"Stage 2/2: Extracting logprobs for {len(stage2_prompts)} thinking candidates...\",\n            style=warning_style,\n        )\n        console.print(\n            f\"Each thinking candidate will extract logprobs for tokens {target_tokens} \"\n            f\"({num_return_thinking} thinking total per sample)\",\n            style=warning_style,\n        )\n\n        # Build kwargs for stage 2 (remove target_tokens to avoid duplication)\n        kwargs_stage2 = kwargs.copy()\n        kwargs_stage2.pop(\"target_tokens\", None)\n\n        # Call extract_token_logprobs for stage 2\n        stage2_probs, _, stage2_mfu_stats = self.extract_token_logprobs(stage2_prompts, target_tokens, **kwargs_stage2)\n\n        # Merge mfu_stats from both stages\n        self.mfu_stats = {}\n        for sample_id, stats in stage1_mfu_stats.items():\n            self.mfu_stats[sample_id] = {\n                \"input_tokens\": stats[\"input_tokens\"].copy(),\n                \"output_tokens\": stats[\"output_tokens\"].copy(),\n                \"times\": stats[\"times\"].copy()\n            }\n\n        # Group stage2 stats by original_id first\n        stage2_by_original = defaultdict(lambda: {\"input_tokens\": [], \"output_tokens\": [], \"times\": []})\n        for thinking_id, stats in stage2_mfu_stats.items():\n            original_id = thinking_id.rsplit(\"_thinking_\", 1)[0]\n            stage2_by_original[original_id][\"input_tokens\"].extend(stats[\"input_tokens\"])\n            stage2_by_original[original_id][\"output_tokens\"].extend(stats[\"output_tokens\"])\n            stage2_by_original[original_id][\"times\"].extend(stats[\"times\"])\n\n        # Aggregate: sum tokens, max time\n        for original_id, stats in stage2_by_original.items():\n            self.mfu_stats[original_id][\"input_tokens\"].append(sum(stats[\"input_tokens\"]))\n            self.mfu_stats[original_id][\"output_tokens\"].append(sum(stats[\"output_tokens\"]))\n            self.mfu_stats[original_id][\"times\"].append(max(stats[\"times\"]))\n\n        # Merge results back by original sample_id\n        # Combine thinking + probabilities into final generation\n        final_results = defaultdict(list)\n\n        for thinking_sample_id, json_str_list in stage2_probs.items():\n            # Extract original sample_id and thinking index\n            # Format: \"sampleID_thinking_N\"\n            parts = thinking_sample_id.rsplit(\"_thinking_\", 1)\n            original_sample_id = parts[0]\n            thinking_idx = int(parts[1])\n\n            # Get the corresponding thinking text from stage 1\n            thinking_text = stage1_results[original_sample_id][thinking_idx]\n\n            # Extract JSON string from list (extract_token_logprobs returns [json_str])\n            json_str = json_str_list[0]\n\n            # Combine thinking + probabilities (json_str is already formatted)\n            # Format: \"<think>thinking_text</think>\\n{\\\"是\\\": 0.8, \\\"否\\\": 0.2}\"\n            combined = f\"{thinking_text}</think>\\n{json_str}\"\n            final_results[original_sample_id].append(combined)\n\n       
 return (dict(final_results), {})\n\n\nclass HfTransformersMixin:\n    \"\"\"\n    Mixin for HuggingFace Transformers functionality\n    \n    Provides common parameter building logic for HuggingFace Transformers generate() API.\n    This mixin can be combined with Generator or RayMixin to create HuggingFace-based generators.\n    \"\"\"\n    \n    def _build_sampling_params(self, **kwargs) -> tuple:\n        \"\"\"\n        Build HuggingFace sampling/generation parameters\n        \n        Args:\n            **kwargs: Optional parameters to override default values\n        \n        Returns:\n            Tuple of (gen_kwargs dict, stop_sequences list)\n        \"\"\"\n        n = kwargs.get(\"num_return_sequences\")\n        max_tokens = kwargs.get(\"max_new_tokens\")\n        num_beams = kwargs.get(\"num_beams\", None)\n        use_beam_search = num_beams is not None\n        \n        stop_sequences = kwargs.get(\"stop\", [])\n        \n        if use_beam_search:\n            # Beam search mode\n            if n and n > num_beams:\n                raise ValueError(\n                    f\"num_return_sequences ({n}) cannot be greater than num_beams ({num_beams}). \"\n                    f\"Beam search can only return at most {num_beams} sequences. \"\n                    f\"Please set num_return_sequences <= num_beams or increase num_beams.\"\n                )\n\n            gen_kwargs = {\n                \"num_beams\": num_beams,\n                \"num_return_sequences\": n if n else num_beams,\n                \"max_new_tokens\": max_tokens,\n                \"do_sample\": False,\n                \"output_scores\": True,\n                \"return_dict_in_generate\": True,\n            }\n            if \"repetition_penalty\" in kwargs:\n                gen_kwargs[\"repetition_penalty\"] = kwargs[\"repetition_penalty\"]\n        else:\n            # Sampling mode\n            # HF generate() rejects vLLM-style presence/frequency penalties as\n            # unused model kwargs, and its top-k warper requires a positive int\n            # (0 disables top-k), so drop those keys and map the vLLM sentinel\n            # -1 to 0 here.\n            top_k = kwargs.get(\"top_k\", 0)\n            gen_kwargs = {\n                \"num_return_sequences\": n,\n                \"max_new_tokens\": max_tokens,\n                \"temperature\": kwargs.get(\"temperature\", 0.7),\n                \"top_p\": kwargs.get(\"top_p\", 0.9),\n                \"top_k\": 0 if top_k is None or top_k < 0 else top_k,\n                \"repetition_penalty\": kwargs.get(\"repetition_penalty\", 1.0),\n                \"do_sample\": kwargs.get(\"do_sample\", True),\n            }\n        \n        return gen_kwargs, stop_sequences\n\n\nclass VllmMixin:\n    \"\"\"\n    Mixin for vLLM functionality\n    \n    Provides common parameter building logic for vLLM generate() API.\n    This mixin can be combined with Generator or RayMixin to create vLLM-based generators.\n    \"\"\"\n    \n    def _build_sampling_params(self, **kwargs):\n        \"\"\"\n        Build vLLM sampling parameters\n        \n        Args:\n            **kwargs: Optional parameters to override default values\n        \n        Returns:\n            SamplingParams or BeamSearchParams object\n        \"\"\"\n        from vllm import SamplingParams\n        from vllm.sampling_params import BeamSearchParams\n        \n        temperature = kwargs.get(\"temperature\", 0.7)\n        top_p = kwargs.get(\"top_p\", 0.9)\n        top_k = kwargs.get(\"top_k\", -1)\n        repetition_penalty = kwargs.get(\"repetition_penalty\", 1.0)\n        presence_penalty = kwargs.get(\"presence_penalty\", 0.0)\n        frequency_penalty = 
kwargs.get(\"frequency_penalty\", 0.0)\n        max_tokens = kwargs.get(\"max_new_tokens\")\n        n = kwargs.get(\"num_return_sequences\", 1)\n        stop = kwargs.get(\"stop\", None)\n        \n        num_beams = kwargs.get(\"num_beams\", None)\n        use_beam_search = num_beams  is not None\n        \n        if use_beam_search:\n            # Beam search: set beam_width to max(num_beams, n)\n            actual_beam_width = max(num_beams, n)\n            params = BeamSearchParams(\n                beam_width=actual_beam_width,\n                max_tokens=max_tokens,\n            )\n        else:\n            # Sampling mode\n            params = SamplingParams(\n                n=n,\n                temperature=temperature,\n                top_p=top_p,\n                top_k=top_k,\n                repetition_penalty=repetition_penalty,\n                presence_penalty=presence_penalty,\n                frequency_penalty=frequency_penalty,\n                max_tokens=max_tokens,\n                stop=stop,\n            )\n        \n        return params\n\n    \n    def _should_enable_optimizations(self) -> bool:\n        \"\"\"\n        Determine whether to enable optimizations based on task types and force flags\n        \n        This method is primarily used by vLLM-based generators to control\n        chunked_prefill and prefix_caching optimizations.\n        \n        Returns:\n            True if should enable optimizations, False otherwise\n        \"\"\"\n        # Priority 1: Force flags\n        if self.force_enable_optimizations:\n            return True\n        if self.force_disable_optimizations:\n            return False\n        \n        # Priority 2: Check if any task in task_types requires disabling optimizations\n        if hasattr(self, 'task_types') and self.task_types:\n            for task_type in self.task_types:\n                if task_type in DISABLE_OPTIMIZATIONS_FOR_TASKS:\n                    return False\n        \n        # Default: enable optimizations\n        return True\n\n\nclass RayMixin:\n    \"\"\"\n    Mixin for Ray distributed computing functionality\n    \n    Provides Ray cluster management, GPU allocation, and resource cleanup\n    for distributed generators. 
This is a mixin class designed to be combined\n    with other generator classes using multiple inheritance.\n    \"\"\"\n    \n    def _initialize_ray_cluster(self):\n        \"\"\"Initialize Ray cluster connection\"\"\"\n        import ray\n\n        if ray.is_initialized():\n            console.print(\n                \"  ✓ Ray already initialized\",\n                style=success_style,\n            )\n            return\n\n        console.print(\n            \"  Initializing Ray cluster connection...\",\n            style=subhead_style_2,\n        )\n\n        # Determine connection mode\n        if self.ray_address == \"local\":\n            # Local mode (single machine)\n            ray.init(ignore_reinit_error=True)\n            console.print(\n                \"  ✓ Ray initialized in local mode\",\n                style=success_style,\n            )\n        elif self.ray_address == \"auto\":\n            # Auto-detect mode\n            try:\n                ray.init(address=\"auto\", ignore_reinit_error=True)\n                console.print(\n                    \"  ✓ Ray connected to existing cluster (auto-detected)\",\n                    style=success_style,\n                )\n            except Exception:\n                # Fallback to local mode\n                console.print(\n                    \"  [yellow]No existing cluster found, initializing local mode...[/yellow]\",\n                    style=warning_style,\n                )\n                ray.init(ignore_reinit_error=True)\n                console.print(\n                    \"  ✓ Ray initialized in local mode\",\n                    style=success_style,\n                )\n        else:\n            # Specific address\n            ray.init(address=self.ray_address, ignore_reinit_error=True)\n            console.print(\n                f\"  ✓ Ray connected to cluster at {self.ray_address}\",\n                style=success_style,\n            )\n    \n    def _determine_gpu_ids_from_cluster(self) -> List[Dict[str, Any]]:\n        \"\"\"\n        Determine GPU resources from Ray cluster\n        \n        Returns:\n            List of GPU info dicts: [{\"node_id\": str, \"gpu_index\": int}, ...]\n        \"\"\"\n        import ray\n        \n        # Get all nodes in cluster\n        nodes = ray.nodes()\n        \n        # Collect GPU information from all nodes\n        gpu_list = []\n        \n        for node in nodes:\n            if not node['Alive']:\n                continue\n            \n            node_id = node['NodeID']\n            node_resources = node.get('Resources', {})\n            \n            # Count GPUs on this node\n            num_gpus_on_node = int(node_resources.get('GPU', 0))\n            \n            if num_gpus_on_node > 0:\n                # Add GPU entries for this node\n                for gpu_idx in range(num_gpus_on_node):\n                    gpu_list.append({\n                        \"node_id\": node_id,\n                        \"node_ip\": node.get('NodeManagerAddress', 'unknown'),\n                        \"gpu_index\": gpu_idx,\n                        \"global_index\": len(gpu_list)  # Global GPU index across cluster\n                    })\n        \n        if not gpu_list:\n            raise RuntimeError(\"No GPUs detected in Ray cluster\")\n        \n        # Apply user filters if specified\n        if self.gpu_ids is not None:\n            # In cluster mode, gpu_ids refers to global indices\n            filtered_list = []\n            for idx in self.gpu_ids:\n     
           if idx < len(gpu_list):\n                    filtered_list.append(gpu_list[idx])\n                else:\n                    console.print(\n                        f\"  [yellow]Warning:[/yellow] GPU index {idx} out of range (max: {len(gpu_list)-1}), skipping\",\n                        style=warning_style,\n                    )\n            gpu_list = filtered_list\n        elif self.num_gpus is not None:\n            # Limit to first num_gpus\n            if self.num_gpus < len(gpu_list):\n                gpu_list = gpu_list[:self.num_gpus]\n            elif self.num_gpus > len(gpu_list):\n                console.print(\n                    f\"  [yellow]Warning:[/yellow] Requested {self.num_gpus} GPUs, but only {len(gpu_list)} available in cluster\",\n                    style=warning_style,\n                )\n        \n        return gpu_list\n    \n    def _group_gpus_for_workers(\n        self,\n        gpu_list: List[Dict[str, Any]],\n        tensor_parallel_size: int\n    ) -> tuple:\n        \"\"\"\n        Group GPUs for workers, ensuring same-node constraint for tensor parallelism\n        \n        Args:\n            gpu_list: List of GPU info dicts\n            tensor_parallel_size: Number of GPUs per worker\n        \n        Returns:\n            (worker_gpu_groups, worker_node_assignments)\n            - worker_gpu_groups: List of GPU index lists for each worker\n            - worker_node_assignments: List of node IDs for each worker\n        \"\"\"\n        if len(gpu_list) % tensor_parallel_size != 0:\n            raise ValueError(\n                f\"Number of GPUs ({len(gpu_list)}) must be divisible by tensor_parallel_size ({tensor_parallel_size})\"\n            )\n        \n        num_workers = len(gpu_list) // tensor_parallel_size\n        worker_gpu_groups = []\n        worker_node_assignments = []\n        \n        if tensor_parallel_size == 1:\n            # Simple case: one GPU per worker\n            for gpu_info in gpu_list:\n                worker_gpu_groups.append([gpu_info[\"gpu_index\"]])\n                worker_node_assignments.append(gpu_info[\"node_id\"])\n        else:\n            # Complex case: multiple GPUs per worker\n            # Need to ensure all GPUs in a group are on the same node\n            \n            if not self.allow_cross_node_tensor_parallel:\n                # Group by node first\n                node_to_gpus = {}\n                for gpu_info in gpu_list:\n                    node_id = gpu_info[\"node_id\"]\n                    if node_id not in node_to_gpus:\n                        node_to_gpus[node_id] = []\n                    node_to_gpus[node_id].append(gpu_info)\n                \n                # Create workers from each node\n                for node_id, node_gpus in node_to_gpus.items():\n                    # Group GPUs on this node\n                    for i in range(0, len(node_gpus), tensor_parallel_size):\n                        if i + tensor_parallel_size <= len(node_gpus):\n                            gpu_group = [gpu[\"gpu_index\"] for gpu in node_gpus[i:i+tensor_parallel_size]]\n                            worker_gpu_groups.append(gpu_group)\n                            worker_node_assignments.append(node_id)\n                \n                if len(worker_gpu_groups) != num_workers:\n                    raise ValueError(\n                        f\"Cannot create {num_workers} workers with tensor_parallel_size={tensor_parallel_size} \"\n                        f\"while ensuring same-node 
constraint. Got {len(worker_gpu_groups)} workers instead. \"\n                        f\"Try setting --allow_cross_node_tensor_parallel or adjust tensor_parallel_size.\"\n                    )\n            else:\n                # Allow cross-node tensor parallel (not recommended)\n                console.print(\n                    \"  [yellow]Warning: Cross-node tensor parallelism enabled. This may cause performance degradation.[/yellow]\",\n                    style=warning_style,\n                )\n                for i in range(num_workers):\n                    start_idx = i * tensor_parallel_size\n                    end_idx = start_idx + tensor_parallel_size\n                    gpu_group = [gpu_list[j][\"gpu_index\"] for j in range(start_idx, end_idx)]\n                    worker_gpu_groups.append(gpu_group)\n                    # Use first GPU's node as primary node\n                    worker_node_assignments.append(gpu_list[start_idx][\"node_id\"])\n        \n        return worker_gpu_groups, worker_node_assignments\n    \n    def _display_cluster_info(self, gpu_list: List[Dict[str, Any]], num_workers: int):\n        \"\"\"Display cluster and GPU information\"\"\"\n        import ray\n\n        # Get cluster info\n        nodes = ray.nodes()\n        alive_nodes = [n for n in nodes if n['Alive']]\n\n        console.print(\n            f\"  Cluster nodes: [green]{len(alive_nodes)}[/green]\",\n            style=subhead_style_2,\n        )\n\n        # Group GPUs by node\n        node_gpu_count = {}\n        for gpu_info in gpu_list:\n            node_ip = gpu_info[\"node_ip\"]\n            node_gpu_count[node_ip] = node_gpu_count.get(node_ip, 0) + 1\n\n        for node_ip, count in node_gpu_count.items():\n            console.print(\n                f\"    - Node {node_ip}: {count} GPU(s)\",\n                style=subhead_style_2,\n            )\n\n        console.print(\n            f\"  Total GPUs: [green]{len(gpu_list)}[/green]\",\n            style=subhead_style_2,\n        )\n        console.print(\n            f\"  Tensor Parallel Size: [green]{self.tensor_parallel_size}[/green]\",\n            style=subhead_style_2,\n        )\n        console.print(\n            f\"  Worker count: [green]{num_workers}[/green]\",\n            style=subhead_style_2,\n        )\n\n        # Display worker assignments\n        console.print(\n            f\"  Worker GPU assignments:\",\n            style=subhead_style_2,\n        )\n        for i, (gpu_group, node_id) in enumerate(zip(self.worker_gpu_groups, self.worker_node_assignments)):\n            # Find node IP for this node_id\n            node_ip = \"unknown\"\n            for gpu_info in gpu_list:\n                if gpu_info[\"node_id\"] == node_id:\n                    node_ip = gpu_info[\"node_ip\"]\n                    break\n            console.print(\n                f\"    - Worker {i}: GPUs {gpu_group} on node {node_ip}\",\n                style=subhead_style_2,\n            )\n    \n    def cleanup(self):\n        \"\"\"\n        Explicitly cleanup resources and release GPU memory\n        \n        Called after generation tasks complete to release GPU memory occupied by Ray Workers.\n        This is useful for avoiding OOM errors during subsequent metric calculations.\n        \"\"\"\n        import ray\n        \n        console.print(\n            \"\\nReleasing Ray Workers and resources...\",\n            style=warning_style,\n        )\n        \n        try:\n            # 1. 
Cleanup all Workers\n            if hasattr(self, 'workers') and self.workers:\n                for i, worker in enumerate(self.workers):\n                    try:\n                        ray.kill(worker)\n                        console.print(\n                            f\"  ✓ Worker {i} terminated\",\n                            style=success_style,\n                        )\n                    except Exception as e:\n                        console.print(\n                            f\"  ⚠ Worker {i} cleanup failed: {e}\",\n                            style=err_style,\n                        )\n                self.workers = []\n\n            # 2. Shut down Ray (optional)\n            if ray.is_initialized():\n                console.print(\n                    \"  Shutting down Ray...\",\n                    style=subhead_style_2,\n                )\n                ray.shutdown()\n                console.print(\n                    \"  ✓ Ray shut down\",\n                    style=subhead_style_2,\n                )\n\n            console.print(\n                \"✓ Resource cleanup completed\\n\",\n                style=success_style,\n            )\n\n        except Exception as e:\n            console.print(\n                f\"✗ Cleanup process error: {e}\",\n                style=err_style,\n            )"
  },
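  {
    "path": "benchmarks/benchmark/example_generator_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical, not a shipped generator): the smallest\nGenerator subclass satisfying the contract that generate() relies on. The\n(results, logprobs, mfu_stats) return shape is taken from how\nbase_generator.py consumes _generate_standard(); the echo behaviour and the\nwhitespace token count are placeholders.\n\"\"\"\nimport time\nfrom typing import Dict\n\nfrom benchmark.base_generator import Generator\n\n\nclass EchoGenerator(Generator):\n    \"\"\"Returns each prompt verbatim; useful for smoke-testing the pipeline.\"\"\"\n\n    def __init__(self, model_name: str = \"echo-model\", **kwargs):\n        super().__init__(**kwargs)\n        # The base class __str__ requires self.model_name to be set.\n        self.model_name = model_name\n\n    def _generate_standard(self, prompts: Dict[str, str], **kwargs) -> tuple:\n        start = time.time()\n        n = kwargs.get(\"num_return_sequences\", 1) or 1\n        results = {sid: [f\"[echo] {p}\"] * n for sid, p in prompts.items()}\n        logprobs = {}  # only beam-search generators populate this\n        elapsed = time.time() - start\n        # mfu_stats mirrors the per-sample shape generate() merges:\n        # lists of input/output token counts and wall-clock times.\n        mfu_stats = {\n            sid: {\n                \"input_tokens\": [len(p.split())],\n                \"output_tokens\": [len(p.split())],\n                \"times\": [elapsed],\n            }\n            for sid, p in prompts.items()\n        }\n        return results, logprobs, mfu_stats\n"
  },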
  {
    "path": "benchmarks/benchmark/benchmark.py",
    "content": "import os\nimport json\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom pathlib import Path\nfrom datetime import datetime\n\nfrom benchmark.console import *\nfrom benchmark.generation_runner import GenerationRunner\nfrom benchmark.base_generator  import Generator\nfrom benchmark.tasks import (\n    BenchmarkTable,\n    LATEST_BENCHMARK_VERSION,\n    check_benchmark_version,\n    check_task_types,\n    check_splits,\n)\nfrom benchmark.tasks.v1_0.registry import get_loader, get_evaluator, get_task_config\n\n\nclass DataLoaderWrapper:\n    \"\"\"Wrapper for unified data loading interface\"\"\"\n    def __init__(self, model_path: str, benchmark_version: str, data_dir: str, enable_thinking: Optional[bool] = None):\n        self.model_path = model_path\n        self._tokenizer = self._create_tokenizer(model_path) if model_path else None\n\n        self.benchmark_version = benchmark_version\n        self.data_dir = data_dir\n        self.enable_thinking = enable_thinking\n        self._loader_cache = {}\n    \n    def _create_tokenizer(self, model_path: str):\n        \"\"\"Create tokenizer from model path\"\"\"\n        try:\n            from transformers import AutoTokenizer\n            tokenizer = AutoTokenizer.from_pretrained(\n                model_path,\n                trust_remote_code=True\n            )\n            console.print(f\"[green]Tokenizer loaded from: {model_path}[/green]\")\n            return tokenizer\n        except Exception as e:\n            raise RuntimeError(f\"Failed to load tokenizer from {model_path}: {e}\")\n    \n    def load_data(self, task_name: str, split: str = \"test\", sample_size: Optional[Any] = None):\n        \"\"\"Load data using new loader system\"\"\"\n        if task_name not in self._loader_cache:\n            self._loader_cache[task_name] = get_loader(\n                task_name=task_name,\n                data_dir=self.data_dir,\n                tokenizer=self._tokenizer,\n                enable_thinking=self.enable_thinking,\n            )\n\n        loader = self._loader_cache[task_name]\n        return loader.load_data(split=split, sample_size=sample_size)\n\n\n\nclass Benchmark:\n    \"\"\"\n    Benchmark Generation Task Evaluation Framework\n    \n    Usage Example:\n        from benchmark import Benchmark\n        from your_generator import YourGenerator\n        \n        benchmark = Benchmark(\n            data_dir=\"./data\"\n        )\n        \n        generator = YourGenerator(\"your-model-path\")\n        \n        benchmark.run(\n            generator=generator,\n            output_dir=\"./results\"\n        )\n    \"\"\"\n    \n    def __init__(\n        self,\n        model_path: Optional[str] = None,\n        task_types: Optional[List[str]] = None,\n        splits: Optional[List[str]] = None,\n        data_dir: Optional[str] = None,\n        enable_thinking: Optional[bool] = None,\n    ):\n        \"\"\"Initialize evaluation framework\"\"\"\n        self.benchmark_version = LATEST_BENCHMARK_VERSION\n        self.data_dir = data_dir\n        self.task_types = check_task_types(task_types, self.benchmark_version)\n        self.splits = check_splits(splits, self.benchmark_version)\n        self.data_loader = DataLoaderWrapper(\n            model_path=model_path,\n            benchmark_version=self.benchmark_version,\n            data_dir=data_dir,\n            enable_thinking=enable_thinking,\n        )\n    \n    @staticmethod\n    def print_benchmark_table():\n        \"\"\"Print all available 
benchmark versions and tasks\"\"\"\n        for benchmark_version in BenchmarkTable:\n            console.print(\n                head_print(f\"Benchmark Dataset Version: {benchmark_version}\"),\n                style=head_style,\n                justify=\"center\",\n            )\n\n            task_types_list = list(BenchmarkTable[benchmark_version].keys())\n            total_task_types = len(task_types_list)\n            \n            for task_idx, task_type in enumerate(task_types_list, start=1):\n                console.print(\n                    f\"\\nTask Type [{task_idx}/{total_task_types}]: {task_type}\\n\", \n                    style=subhead_style, \n                    justify=\"center\"\n                )\n                 \n                task_config = BenchmarkTable[benchmark_version][task_type]\n                \n                console.print(\n                    f\"Dataset Name: {task_config.get('name', task_type)}\",\n                    style=row_style,\n                    justify=\"center\",\n                )\n                console.print(\n                    f\"Source: {task_config.get('source', 'N/A')}\",\n                    style=row_style,\n                    justify=\"center\",\n                )\n                console.print(\n                    f\"Splits: {task_config.get('splits', [])}\",\n                    style=row_style,\n                    justify=\"center\",\n                )\n                console.print(\n                    f\"Sample Size: {task_config.get('sample_size', 'N/A')}\",\n                    style=row_style,\n                    justify=\"center\",\n                )\n                console.print(\n                    f\"Description: {task_config.get('description', 'N/A')}\",\n                    style=row_style,\n                    justify=\"center\",\n                )\n    \n    @staticmethod\n    def check_generator(generator):\n        \"\"\"Verify that generator implements required methods\"\"\"\n        required_methods = [\"__str__\", \"generate\"]\n        for method in required_methods:\n            if not hasattr(generator, method):\n                raise ValueError(f\"Generator should have `{method}` method.\")\n            if method != \"__str__\" and not callable(getattr(generator, method, None)):\n                raise ValueError(f\"Generator.{method} should be callable.\")\n    \n    def run(\n        self,\n        generator: Generator,\n        output_dir: str = \"./results\",\n        overwrite: bool = False,\n        **kwargs\n    ):\n        \"\"\"Run benchmark evaluation\"\"\"\n        self.check_generator(generator)\n        console.print(f\"\\n\\nStarting generation\\n\\n\", style=head_style, justify=\"center\")\n        \n        generation_runner = GenerationRunner(self.data_loader, overwrite=overwrite)\n        total_tasks = 0\n        completed_tasks = 0\n        task_table = BenchmarkTable[self.benchmark_version]\n\n        for task_name in self.task_types:\n            if task_name not in task_table:\n                continue\n            task_config = task_table[task_name]\n            available_splits = task_config.get(\"splits\", [\"test\"])\n            for split in self.splits:\n                if split in available_splits:\n                    total_tasks += 1\n        \n        for task_name in self.task_types:\n            if task_name not in task_table:\n                console.print(f\"Task does not exist: {task_name}\")\n                continue\n            \n            task_config = 
task_table[task_name]\n            available_splits = task_config.get(\"splits\", [\"test\"])\n            \n            # Iterate through all splits\n            for split in self.splits:\n                if split not in available_splits:\n                    console.print(f\"Split does not exist: {split} (task: {task_name})\")\n                    continue\n                \n                # Determine displayed sample size\n                sample_size_param = kwargs.get('sample_size')\n                if sample_size_param is not None:\n                    if sample_size_param == \"full\":\n                        display_sample_size = task_config.get('size', 'N/A')\n                    else:\n                        display_sample_size = int(sample_size_param)\n                else:\n                    display_sample_size = task_config.get('sample_size', 'N/A')\n                \n                console.print(\n                    f\"\\nTask [{completed_tasks + 1}/{total_tasks}]: {task_name} | Split: {split} | Sample Size: {display_sample_size}\\n\",\n                    style=subhead_style,\n                    justify=\"center\",\n                )\n                \n                try:\n                    task_gen_config = task_config.get(\"generation_config\", {})\n                    prompt_config = task_config.get(\"prompt_config\", {})\n                    \n                    # Merge generation parameters (priority: user input > task config > Generator init parameters)\n                    # Filter out None values from kwargs to avoid overwriting task config\n                    valid_kwargs = {k: v for k, v in kwargs.items() if v is not None}\n                    merged_kwargs = {**task_gen_config, **prompt_config, **valid_kwargs}\n\n                    print(f\"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]\")\n\n                    # Execute generation (without computing metrics)\n                    generation_runner(\n                        task_name=task_name,\n                        split=split,\n                        results_save_dir=output_dir,\n                        generator=generator,\n                        **merged_kwargs\n                    )\n                    \n                    completed_tasks += 1\n                \n                    \n                except Exception as e:\n                    import traceback\n                    console.print(f\"✗ Task failed: {task_name}/{split}\", style=err_style)\n                    console.print(f\"✗ Error type: {type(e).__name__}\", style=err_style)\n                    console.print(f\"✗ Error message: {str(e)}\", style=err_style)\n                    console.print(\"✗ Full stack trace:\", style=err_style)\n                    console.print(traceback.format_exc(), style=dim_style)\n        \n        console.print(f\"Total tasks: {total_tasks}\")\n        console.print(f\"Completed tasks: {completed_tasks}\")\n        console.print(f\"Failed tasks: {total_tasks - completed_tasks}\")\n        console.print(f\"Results saved to: {output_dir}\")\n    \n    @staticmethod\n    def _evaluate_single_task(\n        task_name: str,\n        task_dir: str,\n        generation_file: str,\n        split: str,\n        data_dir: str,\n        overwrite: bool,\n        valid_kwargs: Dict[str, Any],\n        cached_metrics: Optional[Dict[str, Any]] = None\n    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:\n        \"\"\"Evaluate a single task split\"\"\"\n        \n      
  # Read generation results\n        with open(generation_file, 'r', encoding='utf-8') as f:\n            gen_data = json.load(f)\n        \n        if \"samples\" not in gen_data:\n            raise ValueError(\"Generation result file missing 'samples' field (new format).\")\n        \n        samples = gen_data[\"samples\"]\n        \n        # Load task configuration\n        if task_name not in BenchmarkTable[LATEST_BENCHMARK_VERSION]:\n            console.print(f\"⚠ Warning: Task '{task_name}' not found in BenchmarkTable[{LATEST_BENCHMARK_VERSION}], skipping...\", style=warning_style)\n            return None, None, None\n        \n        try:\n            evaluator_class = get_evaluator(task_name=task_name)\n            task_config = get_task_config(task_name=task_name)\n            task_config['evaluation_config'].update(valid_kwargs)\n            evaluator = evaluator_class(\n                samples=samples,\n                task_name=task_name,\n                predictions_dir=task_dir,\n                debug=True,  # Enable debug mode for detailed info\n                task_config=task_config,\n                data_dir=data_dir,\n                overwrite=overwrite,\n                cached_metrics=cached_metrics\n            )\n            \n            console.print(f\"Using {evaluator_class.__name__} for {task_name}\")\n            metrics, per_sample_metrics = evaluator.evaluate()\n\n            # Compute MFU metrics if hardware info and token stats are available\n            try:\n                from benchmark.tasks.v1_0.mfu_evaluator import compute_mfu_from_generation_data\n                mfu_metrics = compute_mfu_from_generation_data(gen_data)\n                if mfu_metrics:\n                    metrics.update(mfu_metrics)\n                    # Display MFU for each stage\n                    if \"mfu\" in mfu_metrics:\n                        mfu_list = mfu_metrics[\"mfu\"]\n                        if len(mfu_list) == 1:\n                            console.print(f\"✓ MFU: {mfu_list[0]:.2%}\", style=success_style)\n                        else:\n                            mfu_values = [f\"Stage{i+1}: {mfu:.2%}\" for i, mfu in enumerate(mfu_list)]\n                            console.print(f\"✓ MFU (multi-stage): {', '.join(mfu_values)}\", style=success_style)\n            except Exception as e:\n                console.print(f\"⚠ Warning: MFU calculation failed: {e}\", style=warning_style)\n\n            # Update samples with per-sample metrics\n            for sample_id, sample_metrics in per_sample_metrics.items():\n                if sample_id in samples:\n                    samples[sample_id].update(sample_metrics)\n            \n            # Write updated data back to generation result file\n            gen_data[\"samples\"] = samples\n            with open(generation_file, 'w', encoding='utf-8') as f:\n                json.dump(gen_data, f, indent=2, ensure_ascii=False)\n            console.print(f\"Updated sample metrics to: {generation_file}\")\n            \n            return gen_data, metrics, samples\n        \n        except Exception as e:\n            console.print(f\"✗ Error evaluating {task_name}: {e}\", style=err_style)\n            console.print(f\"Skipping task {task_name}\", style=warning_style)\n            return None, None, None\n    \n    @staticmethod\n    def _create_debug_file(generation_file: str, gen_data: Dict[str, Any], samples: Dict[str, Any], overwrite: bool = False) -> None:\n        \"\"\"Create debug file with first 100 
samples\"\"\"\n\n        debug_file = f\"{generation_file}.debug\"\n        if overwrite or not os.path.exists(debug_file):\n            sorted_ids = sorted(samples.keys())\n            debug_sample_ids = sorted_ids[:100]\n            debug_samples = {id: samples[id] for id in debug_sample_ids}\n            \n            debug_data = {\n                \"model_name\": gen_data.get(\"model_name\", \"\"),\n                \"task_name\": gen_data.get(\"task_name\", \"\"),\n                \"split\": gen_data.get(\"split\", \"\"),\n                \"total_time\": gen_data.get(\"total_time\", 0),\n                \"avg_time_per_sample\": gen_data.get(\"avg_time_per_sample\", 0),\n                \"samples\": debug_samples,\n            }\n            \n            with open(debug_file, 'w', encoding='utf-8') as f:\n                json.dump(debug_data, f, indent=2, ensure_ascii=False)\n            console.print(f\"Created debug file: {debug_file}\")\n    \n    @staticmethod\n    def _calculate_model_total_time(model_results: Dict[str, Any]) -> float:\n        \"\"\"Calculate total time for all tasks of a model\"\"\"\n        model_total_time = 0\n        for task_name, task_results in model_results.items():\n            if task_name.startswith(\"_\"):\n                continue\n            for split, split_metrics in task_results.items():\n                model_total_time += split_metrics.get(\"total_time\", 0)\n        return model_total_time\n    \n    @staticmethod\n    def _save_results_as_json(eval_results: Dict[str, Any], output_path: str) -> None:\n        \"\"\"Save evaluation results as JSON\"\"\"\n        \n        os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else \".\", exist_ok=True)\n        with open(output_path, 'w', encoding='utf-8') as f:\n            json.dump(eval_results, f, indent=2, ensure_ascii=False)\n        console.print(f\"\\n\\n✓ Results Saved to {output_path}\\n\\n\", style=success_style, justify=\"center\")\n    \n    @staticmethod\n    def _load_existing_results(output_path: str, task_types: List[str] = None) -> dict:\n        \"\"\"Load existing evaluation results from JSON file for incremental update\"\"\"\n        eval_results = {}\n\n        if os.path.exists(output_path) and output_path.endswith('.json'):\n            try:\n                with open(output_path, 'r', encoding='utf-8') as f:\n                    eval_results = json.load(f)\n                console.print(f\"✓ Loaded existing results from {output_path}\", style=success_style, justify=\"center\")\n                if task_types is not None:\n                    console.print(f\"  Will update only specified tasks: {', '.join(task_types)}\", style=success_style, justify=\"center\")\n            except Exception as e:\n                console.print(f\"⚠ Warning: Failed to load existing results: {e}\", style=err_style, justify=\"center\")\n                console.print(f\"  Starting with empty results\", style=err_style, justify=\"center\")\n                eval_results = {}\n\n        return eval_results\n\n    @staticmethod\n    def evaluate_dev(\n        generation_results_dir: str,\n        output_path: str = \"./eval_results.json\",\n        data_dir: str = None,\n        overwrite: bool = False,\n        task_types: List[str] = None,\n        **kwargs\n    ):\n        \"\"\"Batch evaluate generated results and generate report\"\"\"\n        valid_kwargs = {k: v for k, v in kwargs.items() if v is not None}\n\n        console.print(f\"\\n\\nMetric Calculation\\n\", 
style=head_style, justify=\"center\")\n        console.print(f\"Result Directory: {generation_results_dir}\\n\\n\", style=head_style, justify=\"center\")\n\n        if not os.path.exists(generation_results_dir):\n            console.print(f\"✗ Error: Result Directory Not Found: {generation_results_dir}\", style=err_style, justify=\"center\")\n            return\n\n        eval_results = Benchmark._load_existing_results(output_path, task_types)\n        \n        for model_name in os.listdir(generation_results_dir):\n            model_dir = os.path.join(generation_results_dir, model_name)\n            if not os.path.isdir(model_dir):\n                continue\n\n            if model_name not in eval_results:\n                eval_results[model_name] = {}\n            \n            all_tasks = [t for t in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, t))]\n\n            if task_types is not None:\n                all_tasks = [t for t in all_tasks if t in task_types]\n\n            total_tasks_count = len(all_tasks)\n            for task_idx, task_name in enumerate(all_tasks, start=1):\n                task_dir = os.path.join(model_dir, task_name)\n\n                console.print(f\"\\nTask [{task_idx}/{total_tasks_count}]: {task_name}\\n\", style=subhead_style, justify=\"center\")\n\n                if task_name not in eval_results[model_name]:\n                    eval_results[model_name][task_name] = {}\n                for filename in os.listdir(task_dir):\n                    if not filename.endswith('_generated.json'):\n                        continue\n                    \n                    split = filename.replace('_generated.json', '')\n                    generation_file = os.path.join(task_dir, filename)\n\n                    cached_metrics = eval_results.get(model_name, {}).get(task_name, {}).get(split, {})\n\n                    # Evaluate single task\n                    gen_data, metrics, samples = Benchmark._evaluate_single_task(\n                        task_name=task_name,\n                        task_dir=task_dir,\n                        generation_file=generation_file,\n                        split=split,\n                        data_dir=data_dir,\n                        overwrite=overwrite,\n                        valid_kwargs=valid_kwargs,\n                        cached_metrics=cached_metrics\n                    )\n                    \n                    if gen_data is None:\n                        continue\n\n                    Benchmark._create_debug_file(generation_file, gen_data, samples, overwrite)\n                    eval_results[model_name][task_name][split] = {\n                        **metrics,\n                        \"total_time\": gen_data.get(\"total_time\", 0),\n                        \"avg_time_per_sample\": gen_data.get(\"avg_time_per_sample\", 0),\n                    }\n            \n            model_total_time = Benchmark._calculate_model_total_time(eval_results[model_name])\n            eval_results[model_name][\"_total_time\"] = model_total_time\n            console.print(f\"\\n✓ Total time: {model_total_time:.2f}s ({model_total_time/60:.2f}min)\\n\", style=success_style)\n        \n        Benchmark._save_results_as_json(eval_results, output_path)\n"
  },
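  {
    "path": "benchmarks/examples/example_evaluate_dev.py",
    "content": "\"\"\"\nHypothetical usage sketch -- NOT part of the original pipeline.\n\nShows how the two-phase flow implemented by the Benchmark class above is\ntypically driven: the generation loop writes\n<results_dir>/<model_name>/<task_name>/<split>_generated.json, and\nBenchmark.evaluate_dev() then walks that directory, computes metrics with the\ntask-specific evaluators, and writes an aggregated eval_results.json.\nThe import path and directory names below are assumptions.\n\"\"\"\n\nfrom benchmark.benchmark import Benchmark  # assumed module path\n\nif __name__ == \"__main__\":\n    # Incremental evaluation: existing entries in eval_results.json are kept,\n    # and with overwrite=False cached metrics are reused where possible.\n    # task_types=None evaluates every task directory found under each model.\n    Benchmark.evaluate_dev(\n        generation_results_dir=\"./generation_results\",\n        output_path=\"./eval_results.json\",\n        data_dir=\"./data\",\n        overwrite=False,\n        task_types=None,\n    )\n"
  },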
  {
    "path": "benchmarks/benchmark/checkpoint_utils.py",
    "content": "\"\"\"\nPT format model checkpoint loading tool\n\nSupports loading PyTorch model checkpoints in non-safetensor format\n\"\"\"\n\nimport torch\nimport hashlib\nfrom pathlib import Path\nfrom typing import Dict, Optional, Tuple, List\nfrom difflib import SequenceMatcher\nfrom benchmark.console import console\n\n\ndef match_checkpoint_keys_to_model(\n    checkpoint_keys: List[str],\n    model_keys: List[str],\n    similarity_threshold: float = 0.8\n) -> Dict[str, str]:\n    \"\"\"\n    Intelligently match checkpoint key names to model key names\n    \n    Args:\n        checkpoint_keys: List of key names in checkpoint\n        model_keys: List of key names in model\n        similarity_threshold: Similarity threshold\n    \n    Returns:\n        Mapping dictionary {checkpoint_key: model_key}\n    \"\"\"\n    mapping = {}\n    \n    for ckpt_key in checkpoint_keys:\n        # Try exact match first\n        if ckpt_key in model_keys:\n            mapping[ckpt_key] = ckpt_key\n            continue\n        \n        # Try matching by removing \"model.\" prefix\n        if ckpt_key.startswith(\"model.\"):\n            clean_key = ckpt_key[6:]  # Remove \"model.\"\n            if clean_key in model_keys:\n                mapping[ckpt_key] = clean_key\n                continue\n        \n        # Try matching by adding \"model.\" prefix\n        prefixed_key = f\"model.{ckpt_key}\"\n        if prefixed_key in model_keys:\n            mapping[ckpt_key] = prefixed_key\n            continue\n        \n        # Use similarity matching\n        best_match = None\n        best_score = 0.0\n        \n        for model_key in model_keys:\n            score = SequenceMatcher(None, ckpt_key, model_key).ratio()\n            if score > best_score and score >= similarity_threshold:\n                best_score = score\n                best_match = model_key\n        \n        if best_match:\n            mapping[ckpt_key] = best_match\n            console.print(f\"Similarity match: {ckpt_key} -> {best_match} (score: {best_score:.2f})\")\n    \n    return mapping\n\n\ndef check_embedding_weight_sharing(\n    state_dict: Dict[str, torch.Tensor],\n    verbose: bool = True\n) -> Tuple[bool, Optional[str], Optional[str]]:\n    \"\"\"\n    Check if embed_tokens and lm_head weights are shared\n    \n    Args:\n        state_dict: Model state dictionary\n        verbose: Whether to print detailed information\n    \n    Returns:\n        (is_shared, embed_key, lm_head_key)\n    \"\"\"\n    # Find embed_tokens and lm_head keys\n    embed_key = None\n    lm_head_key = None\n    \n    for key in state_dict.keys():\n        if \"embed_tokens.weight\" in key:\n            embed_key = key\n        elif \"lm_head.weight\" in key:\n            lm_head_key = key\n    \n    if not embed_key or not lm_head_key:\n        if verbose:\n            console.print(f\"Complete weight pair not found: embed_tokens={embed_key}, lm_head={lm_head_key}\")\n        return False, embed_key, lm_head_key\n    \n    embed_tensor = state_dict[embed_key]\n    lm_head_tensor = state_dict[lm_head_key]\n    \n    if verbose:\n        console.print(f\"embed_tokens.weight shape: {embed_tensor.shape}\")\n        console.print(f\"lm_head.weight shape: {lm_head_tensor.shape}\")\n    \n    # Check if completely identical\n    is_shared = torch.equal(embed_tensor, lm_head_tensor)\n    \n    if verbose:\n        if is_shared:\n            console.print(\"✓ embed_tokens and lm_head weights are identical (shared weights)\")\n        else:\n        
    console.print(\"✗ embed_tokens and lm_head weights are different\")\n            # Calculate difference statistics\n            diff = (embed_tensor != lm_head_tensor).sum().item()\n            total = embed_tensor.numel()\n            console.print(f\"  Different elements: {diff}/{total} ({diff/total*100:.2f}%)\")\n    \n    return is_shared, embed_key, lm_head_key\n\n\ndef handle_weight_tying(\n    state_dict: Dict[str, torch.Tensor],\n    model_keys: List[str],\n    new_state_dict: Dict[str, str]\n) -> Dict[str, torch.Tensor]:\n    \"\"\"\n    Handle weight tying situations\n    \n    In some models, embed_tokens and lm_head weights are tied\n    \n    Args:\n        state_dict: Original state dictionary\n        model_keys: List of model key names\n        new_state_dict: Already mapped new state dictionary\n    \n    Returns:\n        Updated state dictionary\n    \"\"\"\n    # Scenario 1: checkpoint has embed_tokens but no lm_head\n    if any(\"embed_tokens.weight\" in k for k in state_dict.keys()):\n        embed_key = next((k for k in state_dict.keys() if \"embed_tokens.weight\" in k), None)\n        \n        # Check if lm_head is missing in new_state_dict\n        has_lm_head = any(\"lm_head.weight\" in k for k in new_state_dict.keys())\n        \n        if not has_lm_head and embed_key:\n            # Try to find lm_head key in model\n            lm_head_candidates = [\"lm_head.weight\", \"model.lm_head.weight\"]\n            for candidate in lm_head_candidates:\n                if candidate in model_keys:\n                    new_state_dict[candidate] = state_dict[embed_key]\n                    console.print(f\"✓ Weight tying: using {embed_key} to initialize {candidate}\")\n                    break\n    \n    # Scenario 2: checkpoint has lm_head but no embed_tokens\n    if any(\"lm_head.weight\" in k for k in state_dict.keys()):\n        lm_head_key = next((k for k in state_dict.keys() if \"lm_head.weight\" in k), None)\n        \n        # Check if embed_tokens is missing in new_state_dict\n        has_embed = any(\"embed_tokens.weight\" in k for k in new_state_dict.keys())\n        \n        if not has_embed and lm_head_key:\n            # Try to find embed_tokens key in model\n            embed_candidates = [\"embed_tokens.weight\", \"model.embed_tokens.weight\"]\n            for candidate in embed_candidates:\n                if candidate in model_keys:\n                    new_state_dict[candidate] = state_dict[lm_head_key]\n                    console.print(f\"✓ Weight tying: using {lm_head_key} to initialize {candidate}\")\n                    break\n    \n    return new_state_dict\n\n\ndef load_weights_from_pt(\n    model: torch.nn.Module,\n    checkpoint_path: str,\n    device: str = \"cpu\",\n    strict: bool = False,\n    check_weight_sharing: bool = True,\n    handle_weight_tying_flag: bool = True\n) -> Tuple[List[str], List[str]]:\n    \"\"\"\n    Load PT format checkpoint into model\n    \n    Args:\n        model: Target model\n        checkpoint_path: Checkpoint file path\n        device: Loading device\n        strict: Whether to load strictly (requires all keys to match)\n        check_weight_sharing: Whether to check weight sharing\n        handle_weight_tying_flag: Whether to handle weight tying\n    \n    Returns:\n        (missing_keys, unexpected_keys) Missing keys and unexpected keys\n    \"\"\"\n    console.print(f\"Loading checkpoint: {checkpoint_path}\")\n    \n    # 1. 
Load checkpoint\n    try:\n        state_dict = torch.load(checkpoint_path, map_location=device)\n    except Exception as e:\n        console.print(f\"Failed to load checkpoint: {e}\")\n        raise\n    \n    # 2. Extract model state dictionary\n    if 'model_state_dict' in state_dict:\n        console.print(\"Detected 'model_state_dict' key, extracting nested state dictionary\")\n        state_dict = state_dict['model_state_dict']\n    elif 'state_dict' in state_dict:\n        console.print(\"Detected 'state_dict' key, extracting nested state dictionary\")\n        state_dict = state_dict['state_dict']\n\n    checkpoint_keys = list(state_dict.keys())\n    model_keys = list(model.state_dict().keys())\n\n    console.print(f\"Checkpoint key count: {len(checkpoint_keys)}\")\n    console.print(f\"Model key count: {len(model_keys)}\")\n\n    if check_weight_sharing:\n        check_embedding_weight_sharing(state_dict, verbose=True)\n\n    console.print(\"Starting to match checkpoint key names to model key names...\")\n    key_mapping = match_checkpoint_keys_to_model(checkpoint_keys, model_keys)\n\n    matched_count = len(key_mapping)\n    console.print(f\"Successfully matched: {matched_count}/{len(checkpoint_keys)} keys\")\n\n    new_state_dict = {}\n    skipped_keys = []\n\n    for ckpt_key in checkpoint_keys:\n        target_key = key_mapping.get(ckpt_key)\n        if target_key is None:\n            skipped_keys.append(ckpt_key)\n            continue\n        new_state_dict[target_key] = state_dict[ckpt_key]\n\n    if skipped_keys:\n        console.print(f\"Skipped {len(skipped_keys)} unmatched keys\")\n        if len(skipped_keys) <= 10:\n            console.print(f\"Skipped keys: {skipped_keys}\")\n\n    if handle_weight_tying_flag:\n        new_state_dict = handle_weight_tying(state_dict, model_keys, new_state_dict)\n\n    console.print(\"Loading state dictionary into model...\")\n    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=strict)\n\n    if missing_keys:\n        console.print(f\"Missing keys ({len(missing_keys)}): {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}\")\n    else:\n        console.print(\"✓ No missing keys\")\n    \n    if unexpected_keys:\n        console.print(f\"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:10]}{'...' 
if len(unexpected_keys) > 10 else ''}\")\n    else:\n        console.print(\"✓ No unexpected keys\")\n    \n    console.print(f\"✓ Checkpoint loading completed\")\n    \n    return missing_keys, unexpected_keys\n\n\ndef build_model_from_pt(\n    config_path: str,\n    checkpoint_path: str,\n    device: str = \"cuda\",\n    torch_dtype: Optional[torch.dtype] = None,\n    trust_remote_code: bool = True\n) -> torch.nn.Module:\n    \"\"\"\n    Create model from config and load PT checkpoint\n    \n    This is the unified function used by both HfTransformersGenerator and RayHfTransformersGenerator.\n    \n    Args:\n        config_path: Model configuration path\n        checkpoint_path: PT checkpoint path\n        device: Target device\n        torch_dtype: Data type\n        trust_remote_code: Whether to trust remote code\n    \n    Returns:\n        Model with checkpoint loaded\n    \"\"\"\n    from transformers import AutoConfig, AutoModelForCausalLM\n\n    config = AutoConfig.from_pretrained(\n        config_path,\n        trust_remote_code=trust_remote_code\n    )\n\n    model = AutoModelForCausalLM.from_config(\n        config,\n        trust_remote_code=trust_remote_code\n    )\n\n    if torch_dtype is not None:\n        model = model.to(torch_dtype)\n    if device != 'cpu':\n        model = model.to(device)\n\n    target_load_device = device if device != 'cpu' else 'cpu'\n    load_weights_from_pt(\n        model=model,\n        checkpoint_path=checkpoint_path,\n        device=target_load_device,\n        strict=False,\n        check_weight_sharing=True,\n        handle_weight_tying_flag=True\n    )\n    \n    return model\n\n\ndef build_model_from_hf(\n    model_name_or_path: str,\n    device: str = \"cuda\",\n    torch_dtype: Optional[torch.dtype] = None,\n    trust_remote_code: bool = True,\n    use_device_map: bool = True\n) -> torch.nn.Module:\n    \"\"\"\n    Load pretrained model from HuggingFace\n    \n    This is the unified function used by both HfTransformersGenerator and RayHfTransformersGenerator.\n    \n    Args:\n        model_name_or_path: Model name or path\n        device: Target device\n        torch_dtype: Data type\n        trust_remote_code: Whether to trust remote code\n        use_device_map: Whether to use device_map=\"auto\" for multi-GPU\n    \n    Returns:\n        Loaded model\n    \"\"\"\n    from transformers import AutoModelForCausalLM\n\n    should_use_device_map = use_device_map and device != \"cpu\" and \"cuda\" in device\n\n    model = AutoModelForCausalLM.from_pretrained(\n        model_name_or_path,\n        torch_dtype=torch_dtype,\n        device_map=\"auto\" if should_use_device_map else None,\n        trust_remote_code=trust_remote_code\n    )\n\n    if not should_use_device_map:\n        model = model.to(device)\n\n    return model\n\n\ndef export_pt_to_safetensor(\n    config_path: str,\n    checkpoint_path: str,\n    output_dir: Optional[str] = None,\n    trust_remote_code: bool = True,\n    use_cache: bool = True\n) -> str:\n    \"\"\"\n    Convert PT checkpoint to HuggingFace format for vLLM compatibility\n\n    Args:\n        config_path: Model configuration path (HuggingFace model path or local config)\n        checkpoint_path: PT checkpoint path\n        output_dir: Output directory for converted model (optional, will use /tmp if not specified)\n        trust_remote_code: Whether to trust remote code\n        use_cache: Whether to use cached conversion (skip if already converted)\n\n    Returns:\n        Path to converted HuggingFace 
format model\n    \"\"\"\n    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n\n    hash_input = f\"{config_path}_{checkpoint_path}\".encode('utf-8')\n    hash_suffix = hashlib.md5(hash_input).hexdigest()[:16]\n\n    if output_dir is None:\n        output_dir = f\"/tmp/hf_checkpoint_{hash_suffix}\"\n\n    temp_model_path = Path(output_dir) / \"converted_model\"\n\n    if use_cache and temp_model_path.exists():\n        has_config = (temp_model_path / \"config.json\").exists()\n        has_weights = (\n            (temp_model_path / \"model.safetensors\").exists() or\n            (temp_model_path / \"pytorch_model.bin\").exists() or\n            any(temp_model_path.glob(\"*.safetensors\")) or\n            any(temp_model_path.glob(\"pytorch_model*.bin\"))\n        )\n\n        if has_config and has_weights:\n            console.print(\n                f\"✓ Found converted model, skipping conversion\",\n            )\n            console.print(\n                f\"  Converted model path: {temp_model_path}\",\n            )\n            return str(temp_model_path)\n\n    # Create output directory\n    temp_model_path.mkdir(parents=True, exist_ok=True)\n    console.print(f\"  Output directory: {temp_model_path}\")\n\n    try:\n        # 1. Load configuration\n        console.print(\"  [1/4] Loading model configuration...\")\n        config = AutoConfig.from_pretrained(\n            config_path,\n            trust_remote_code=trust_remote_code\n        )\n\n        # 2. Create model from config\n        console.print(\"  [2/4] Initializing model...\")\n        model = AutoModelForCausalLM.from_config(\n            config,\n            trust_remote_code=trust_remote_code\n        )\n\n        # 3. Load checkpoint\n        console.print(\"  [3/4] Loading PT checkpoint...\")\n        load_weights_from_pt(\n            model=model,\n            checkpoint_path=checkpoint_path,\n            device='cpu',\n            strict=False,\n            check_weight_sharing=True,\n            handle_weight_tying_flag=True\n        )\n\n        # 4. Save as HuggingFace format\n        console.print(\"  [4/4] Saving as HuggingFace format...\")\n        model.save_pretrained(temp_model_path, safe_serialization=True)\n\n        # Save tokenizer\n        tokenizer = AutoTokenizer.from_pretrained(\n            config_path,\n            trust_remote_code=trust_remote_code\n        )\n        tokenizer.save_pretrained(temp_model_path)\n\n        console.print(f\"✓ Model conversion completed: {temp_model_path}\")\n\n        return str(temp_model_path)\n\n    except Exception as e:\n        console.print(f\"✗ Conversion failed: {e}\")\n        # Clean up on failure\n        import shutil\n        if temp_model_path.exists():\n            shutil.rmtree(temp_model_path)\n        raise"
  },
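  {
    "path": "benchmarks/examples/example_checkpoint_utils.py",
    "content": "\"\"\"\nHypothetical usage sketch -- NOT part of the original pipeline.\n\nDemonstrates the utilities from benchmark/checkpoint_utils.py. All paths and\nthe config source are placeholders; the function signatures are taken from\nthe module itself.\n\"\"\"\n\nimport torch\n\nfrom benchmark.checkpoint_utils import (\n    build_model_from_pt,\n    export_pt_to_safetensor,\n    match_checkpoint_keys_to_model,\n)\n\nif __name__ == \"__main__\":\n    # 1. Key matching in isolation: exact match first, then \"model.\" prefix\n    #    stripping/adding, then SequenceMatcher similarity as a fallback.\n    mapping = match_checkpoint_keys_to_model(\n        checkpoint_keys=[\"model.embed_tokens.weight\"],\n        model_keys=[\"embed_tokens.weight\"],\n    )\n    assert mapping == {\"model.embed_tokens.weight\": \"embed_tokens.weight\"}\n\n    # 2. Build a model directly from a raw .pt checkpoint (placeholder paths).\n    model = build_model_from_pt(\n        config_path=\"Qwen/Qwen3-1.7B\",\n        checkpoint_path=\"./ckpt/model.pt\",\n        device=\"cuda\",\n        torch_dtype=torch.bfloat16,\n    )\n\n    # 3. Or convert once to safetensors so vLLM-style backends can load it;\n    #    repeated calls hit the hash-based cache directory.\n    hf_dir = export_pt_to_safetensor(\n        config_path=\"Qwen/Qwen3-1.7B\",\n        checkpoint_path=\"./ckpt/model.pt\",\n    )\n    print(f\"Converted model saved at: {hf_dir}\")\n"
  },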
  {
    "path": "benchmarks/benchmark/console.py",
    "content": "from rich.console import Console\nfrom pyfiglet import Figlet\n\nconsole = Console()\nerr_style = \"bold red\"\nwarning_style = \"bold yellow\"\nsuccess_style = \"green\"\ndim_style = \"dim\"\n\n# benchmark dataset\nf = Figlet(font='digital')\nhead_print = lambda x : f.renderText(x)\n\nhead_style = \"bold white on blue\"\nsubhead_style = \"bold black on bright_blue\"\nrow_style = \"black on bright_white\"\n\n\n# Generator styles\nhead_style_2 = \"bold white on magenta\"\nsubhead_style_2 = \"white\""
  },
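  {
    "path": "benchmarks/examples/example_console.py",
    "content": "\"\"\"\nHypothetical usage sketch -- NOT part of the original pipeline.\n\nShows how the shared Rich console and style constants above are used by the\nbenchmark modules.\n\"\"\"\n\nfrom benchmark.console import console, err_style, head_print, head_style\n\nconsole.print(head_print(\"RecIF-Bench\"))                  # pyfiglet banner\nconsole.print(\"Generation\", style=head_style, justify=\"center\")\nconsole.print(\"✗ Something went wrong\", style=err_style)\n"
  },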
  {
    "path": "benchmarks/benchmark/generation_runner.py",
    "content": "\"\"\"\nGeneration Runner\n\nResponsible for:\n1. Loading test data via data loader\n2. Calling Generator to produce model outputs  \n3. Saving generation results to JSON files\n\nNote: Does NOT compute evaluation metrics (handled by task-specific evaluators)\n\"\"\"\n\nimport json\nimport os\nimport time\nfrom typing import Dict, List, Optional, Any\nfrom pathlib import Path\n\nfrom benchmark.console import *\nfrom benchmark.base_generator import Generator\nfrom benchmark.tasks.v1_0.base_loader import BaseLoader\n\n\nclass GenerationRunner:\n    \"\"\"\n    Generation task runner\n    \n    Orchestrates the generation phase of evaluation:\n    - Loads test data via data loader\n    - Calls generator to produce model outputs\n    - Saves generation results to disk\n    \n    Evaluation metrics are computed separately by task-specific evaluators.\n    \"\"\"\n    \n    def __init__(\n        self,\n        data_loader: BaseLoader,\n        overwrite: bool = False\n    ):\n        \"\"\"\n        Args:\n            data_loader: Data loader (any object with load_data method)\n            overwrite: Whether to overwrite existing results\n        \"\"\"\n        self.data_loader = data_loader\n        self.overwrite = overwrite\n        self.benchmark_version = data_loader.benchmark_version\n    \n    def __call__(\n        self,\n        task_name: str,\n        split: str,\n        results_save_dir: str,\n        generator: Generator,\n        **kwargs\n    ) -> None:\n        \"\"\"\n        Execute generation pipeline\n        \n        This method is responsible for generation and saving only,\n        NOT for computing evaluation metrics.\n        \n        Args:\n            task_name: Task name\n            split: Dataset split\n            results_save_dir: Results save directory\n            generator: Generator instance\n            **kwargs: Generation parameters\n        \n        Returns:\n            None\n        \"\"\"\n        model_name = str(generator)\n        results_dir = os.path.join(\n            results_save_dir,\n            model_name,\n            task_name\n        )\n        os.makedirs(results_dir, exist_ok=True)\n        \n        generation_file = os.path.join(results_dir, f\"{split}_generated.json\")\n        \n        # Check if generation results already exist\n        if os.path.exists(generation_file) and not self.overwrite:\n            console.print(f\"Generation results already exist, skipping: {generation_file}\")\n            console.print(\"To regenerate, please set overwrite=True\")\n            \n            return None\n        \n        start_time = time.time()\n\n        # Extract sample_size parameter (don't pass to generator)\n        sample_size_param = kwargs.pop('sample_size', None)\n\n        # 1. Load data\n        test_data = self.data_loader.load_data(task_name=task_name, split=split, sample_size=sample_size_param)\n\n        # 2. Extract prompts and references\n        prompts = {id: data[\"prompt\"] for id, data in test_data.items()}\n        references = {id: data[\"ground_truth\"] for id, data in test_data.items()}\n\n        # 3. 
Generate text (unified entry point)\n        # All tasks now go through the unified generate() method\n        # For classification tasks, target_tokens is already in kwargs from generation_config\n        generations, logprobs = generator.generate(prompts, **kwargs)\n        \n        end_time = time.time()\n\n        total_time = end_time - start_time\n        num_samples = len(test_data)\n        avg_time_per_sample = total_time / num_samples if num_samples > 0 else 0\n        console.print(f\"Total time: {total_time:.2f}s, Average per sample: {avg_time_per_sample:.4f}s\")\n\n        # 4. Collect hardware info and MFU statistics (for MFU calculation)\n        console.print(\"[MFU DEBUG] Starting MFU data collection...\")\n        \n        hardware_info = None\n        mfu_stats = None\n        \n        try:\n            # Check if generator has get_hardware_info method\n            if not hasattr(generator, 'get_hardware_info'):\n                console.print(\"[MFU ERROR] generator does NOT have get_hardware_info() method!\")\n                console.print(f\"[MFU ERROR] Generator type: {type(generator)}\")\n                console.print(f\"[MFU ERROR] Generator class: {generator.__class__.__name__}\")\n            else:\n                hardware_info = generator.get_hardware_info()\n                if hardware_info:\n                    console.print(f\"[MFU DEBUG] GPU Model: {hardware_info.get('gpu_model')}\")\n                    console.print(f\"[MFU DEBUG] GPU Count: {hardware_info.get('gpu_count')}\")\n                    console.print(f\"[MFU DEBUG] GPU TFLOPs: {hardware_info.get('gpu_tflops')}\")\n                else:\n                    console.print(\"[MFU WARNING] hardware_info is None!\")\n            \n            # Check if generator has mfu_stats attribute\n            if not hasattr(generator, 'mfu_stats'):\n                console.print(\"[MFU WARNING] generator does NOT have 'mfu_stats' attribute!\")\n            else:\n                mfu_stats = getattr(generator, 'mfu_stats', None)\n                if mfu_stats:\n                    console.print(f\"[MFU DEBUG] mfu_stats sample count: {len(mfu_stats)}\")\n                    if len(mfu_stats) > 0:\n                        first_key = list(mfu_stats.keys())[0]\n                        first_stats = mfu_stats[first_key]\n                        console.print(f\"[MFU DEBUG] First sample: {first_key}\")\n                        console.print(f\"[MFU DEBUG]   input_tokens: {first_stats.get('input_tokens', 'MISSING')}\")\n                        console.print(f\"[MFU DEBUG]   output_tokens: {first_stats.get('output_tokens', 'MISSING')}\")\n                        console.print(f\"[MFU DEBUG]   times: {first_stats.get('times', 'MISSING')}\")\n                else:\n                    console.print(\"[MFU WARNING] mfu_stats is None!\")\n                    \n        except Exception as e:\n            console.print(f\"Warning: Failed to collect hardware info or MFU stats: {e}\", style=warning_style)\n        \n        num_params_value = getattr(generator, 'num_params', None)\n        console.print(f\"[MFU DEBUG] num_params value: {num_params_value}\")\n\n        # 5. 
Save generation results\n        self.save_generations(\n            model_name=model_name,\n            task_name=task_name,\n            split=split,\n            generations=generations,\n            references=references,\n            logprobs=logprobs,\n            test_data=test_data,\n            output_path=generation_file,\n            total_time=total_time,\n            avg_time_per_sample=avg_time_per_sample,\n            hardware_info=hardware_info,\n            mfu_stats=mfu_stats,\n            num_params=getattr(generator, 'num_params', None),\n        )\n        \n        console.print(f\"Generation results saved to: {generation_file}\")\n        \n        return None\n    \n    @staticmethod\n    def save_generations(\n        model_name: str,\n        task_name: str,\n        split: str,\n        generations: Dict[str, List[str]],\n        references: Dict[str, str],\n        logprobs: Dict[str, List[float]],\n        test_data: Dict[str, Dict[str, Any]],\n        output_path: str,\n        total_time: float,\n        avg_time_per_sample: float,\n        hardware_info: Optional[Dict[str, Any]] = None,\n        mfu_stats: Optional[Dict[str, Dict[str, List[int]]]] = None,\n        num_params: Optional[float] = None,\n    ):\n        \"\"\"\n        Save generation results (excluding evaluation metrics)\n        \n        Result format:\n        {\n            \"model_name\": \"...\",\n            \"task_name\": \"...\",\n            \"split\": \"...\",\n            \"total_time\": \"...\",\n            \"avg_time_per_sample\": \"...\",\n            \"samples\": {\n                \"<sample_id>\": {\n                    \"prompt\": \"...\",\n                    \"generations\": [\"...\", \"...\"],\n                    \"ground_truth\": \"...\",\n                    \"metadata\": {...}  # Contains metadata from original data\n                },\n                ...\n            }\n        }\n        \"\"\"\n        # Check if this is a classification task (label_pred)\n        is_classification_task = task_name == \"label_pred\"\n        \n        samples: Dict[str, Any] = {}\n        for id, gens in generations.items():\n            sample_data = {\n                \"prompt\": test_data.get(id, {}).get(\"prompt\", \"\"),\n                \"generations\": gens,\n                \"ground_truth\": references.get(id, \"\"),\n            }\n\n            if id in logprobs and logprobs[id]:\n                sample_data[\"logprobs\"] = logprobs[id]\n\n            # Add MFU statistics for this sample (for MFU calculation)\n            if mfu_stats and id in mfu_stats:\n                sample_data[\"input_tokens\"] = mfu_stats[id].get(\"input_tokens\", [])\n                sample_data[\"output_tokens\"] = mfu_stats[id].get(\"output_tokens\", [])\n                sample_data[\"times\"] = mfu_stats[id].get(\"times\", [])\n\n            if is_classification_task and id in test_data:\n                metadata = test_data[id].get(\"metadata\", {})\n                if \"uid\" in metadata:\n                    sample_data[\"user_id\"] = metadata[\"uid\"]\n\n            if id in test_data and \"metadata\" in test_data[id]:\n                sample_data[\"metadata\"] = test_data[id][\"metadata\"]\n\n            samples[id] = sample_data\n\n        data = {\n            \"model_name\": model_name,\n            \"task_name\": task_name,\n            \"split\": split,\n            \"total_time\": total_time,\n            \"avg_time_per_sample\": avg_time_per_sample,\n            \"samples\": 
samples,\n        }\n\n        # Add hardware info and token statistics (for MFU calculation)\n        if hardware_info:\n            data[\"hardware_info\"] = hardware_info\n        else:\n            console.print(\"[MFU DEBUG] ❌ Skipping hardware_info (None or empty)\")\n\n        if num_params:\n            data[\"num_params\"] = num_params\n        else:\n            console.print(\"[MFU DEBUG] ❌ Skipping num_params (None or 0)\")\n\n        # Save mfu_stats_aggregate for multi-stage MFU calculation\n        # Compute aggregate statistics from per-sample mfu_stats\n        if mfu_stats:\n            # Determine number of stages from first sample\n            num_stages = 0\n            for sample_stats in mfu_stats.values():\n                num_stages = len(sample_stats.get(\"input_tokens\", []))\n                console.print(f\"[MFU DEBUG] Determined num_stages: {num_stages}\")\n                break\n\n            # New structure: dict with lists instead of array of dicts\n            data[\"mfu_stats_aggregate\"] = {\n                \"total_input_tokens\": [],\n                \"total_output_tokens\": [],\n                \"total_time\": []\n            }\n\n            for stage_idx in range(num_stages):\n                total_input_tokens = 0\n                total_output_tokens = 0\n\n                # Aggregate token stats across all samples for this stage\n                for sample_stats in mfu_stats.values():\n                    input_tokens_list = sample_stats.get(\"input_tokens\", [])\n                    output_tokens_list = sample_stats.get(\"output_tokens\", [])\n\n                    if stage_idx < len(input_tokens_list):\n                        total_input_tokens += input_tokens_list[stage_idx]\n                    if stage_idx < len(output_tokens_list):\n                        total_output_tokens += output_tokens_list[stage_idx]\n\n                # Calculate stage time as max across all samples\n                # Ray workers run in parallel, so stage time = slowest worker time\n                stage_times = []\n                for sample_stats in mfu_stats.values():\n                    times_list = sample_stats.get(\"times\", [])\n                    if stage_idx < len(times_list):\n                        stage_times.append(times_list[stage_idx])\n\n                # Use max time if available, otherwise 0.0\n                stage_time = max(stage_times) if stage_times else 0.0\n\n                data[\"mfu_stats_aggregate\"][\"total_input_tokens\"].append(total_input_tokens)\n                data[\"mfu_stats_aggregate\"][\"total_output_tokens\"].append(total_output_tokens)\n                data[\"mfu_stats_aggregate\"][\"total_time\"].append(stage_time)\n                \n        else:\n            console.print(\"[MFU DEBUG] ❌ Skipping mfu_stats processing (None or empty)\")\n        \n        os.makedirs(os.path.dirname(output_path), exist_ok=True)\n        \n        with open(output_path, 'w', encoding='utf-8') as f:\n            json.dump(data, f, indent=2, ensure_ascii=False)"
  },
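  {
    "path": "benchmarks/examples/example_mfu_aggregation.py",
    "content": "\"\"\"\nHypothetical sketch -- NOT part of the original pipeline.\n\nReproduces, on toy data, the per-stage aggregation that\nGenerationRunner.save_generations performs when building\n\"mfu_stats_aggregate\": token counts are summed across samples, while each\nstage's time is the max across samples, since parallel Ray workers finish\nwhen the slowest one does. The input mimics the per-sample\n{\"input_tokens\": [...], \"output_tokens\": [...], \"times\": [...]} layout.\n\"\"\"\n\nmfu_stats = {\n    \"0\": {\"input_tokens\": [120, 40], \"output_tokens\": [32, 8], \"times\": [1.2, 0.4]},\n    \"1\": {\"input_tokens\": [100, 50], \"output_tokens\": [30, 9], \"times\": [1.5, 0.3]},\n}\n\nnum_stages = len(next(iter(mfu_stats.values()))[\"input_tokens\"])\naggregate = {\"total_input_tokens\": [], \"total_output_tokens\": [], \"total_time\": []}\n\nfor stage in range(num_stages):\n    aggregate[\"total_input_tokens\"].append(\n        sum(s[\"input_tokens\"][stage] for s in mfu_stats.values()))\n    aggregate[\"total_output_tokens\"].append(\n        sum(s[\"output_tokens\"][stage] for s in mfu_stats.values()))\n    aggregate[\"total_time\"].append(\n        max(s[\"times\"][stage] for s in mfu_stats.values()))\n\n# -> input [220, 90], output [62, 17], time [1.5, 0.4]\nprint(aggregate)\n"
  },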
  {
    "path": "benchmarks/benchmark/gpu_utils.py",
    "content": "\"\"\"\nGPU hardware detection and FLOPS calculation utilities for MFU computation.\n\"\"\"\n\nfrom typing import Dict, Any, Optional\nfrom benchmark.console import console\n\n# GPU theoretical peak FLOPS (TFLOPS) for BF16/FP16\n# Source: Official vendor specifications\nGPU_TFLOPS_MAP = {\n    # NVIDIA A100 series\n    \"A100-SXM4-40GB\": 312.0,\n    \"A100-SXM4-80GB\": 312.0,\n    \"A100-PCIE-40GB\": 312.0,\n    \"A100-PCIE-80GB\": 312.0,\n\n    # NVIDIA A800 series (China-specific A100 variant)\n    \"A800-SXM4-80GB\": 312.0,\n    \"A800\": 312.0,\n\n    # NVIDIA H100 series\n    \"H100-SXM5-80GB\": 989.0,\n    \"H100-PCIE-80GB\": 756.0,\n    \"H100\": 989.0,\n\n    # NVIDIA V100 series\n    \"V100-SXM2-16GB\": 125.0,\n    \"V100-SXM2-32GB\": 125.0,\n    \"V100-PCIE-16GB\": 112.0,\n    \"V100-PCIE-32GB\": 112.0,\n\n    # NVIDIA A40\n    \"A40\": 149.7,\n\n    # NVIDIA A30\n    \"A30\": 165.0,\n\n    # NVIDIA A10\n    \"A10\": 125.0,\n\n    # NVIDIA RTX series\n    \"RTX 4090\": 82.6,\n    \"RTX 4080\": 48.7,\n    \"RTX 3090\": 35.6,\n    \"RTX 3080\": 29.8,\n}\n\n\ndef _normalize_gpu_name(gpu_name: str) -> str:\n    \"\"\"\n    Normalize GPU name for lookup in TFLOPS map.\n\n    Args:\n        gpu_name: Raw GPU name from torch.cuda\n\n    Returns:\n        Normalized GPU name\n    \"\"\"\n    gpu_name = gpu_name.strip()\n\n    # Try exact match first\n    if gpu_name in GPU_TFLOPS_MAP:\n        return gpu_name\n\n    # Try fuzzy matching\n    gpu_name_upper = gpu_name.upper()\n\n    # Match A100 variants\n    if \"A100\" in gpu_name_upper:\n        if \"80GB\" in gpu_name_upper or \"80G\" in gpu_name_upper:\n            return \"A100-SXM4-80GB\"\n        else:\n            return \"A100-SXM4-40GB\"\n\n    # Match A800\n    if \"A800\" in gpu_name_upper:\n        return \"A800\"\n\n    # Match H100 variants\n    if \"H100\" in gpu_name_upper:\n        if \"PCIE\" in gpu_name_upper or \"PCIe\" in gpu_name_upper:\n            return \"H100-PCIE-80GB\"\n        else:\n            return \"H100-SXM5-80GB\"\n\n    # Match V100 variants\n    if \"V100\" in gpu_name_upper:\n        if \"32GB\" in gpu_name_upper or \"32G\" in gpu_name_upper:\n            return \"V100-SXM2-32GB\"\n        else:\n            return \"V100-SXM2-16GB\"\n\n    # Match other GPUs\n    for known_gpu in GPU_TFLOPS_MAP.keys():\n        if known_gpu.upper() in gpu_name_upper:\n            return known_gpu\n\n    return gpu_name\n\n\ndef get_gpu_tflops(gpu_name: str) -> Optional[float]:\n    \"\"\"\n    Get theoretical peak TFLOPS for a given GPU model.\n\n    Args:\n        gpu_name: GPU model name\n\n    Returns:\n        TFLOPS value for BF16/FP16, or None if unknown\n    \"\"\"\n    normalized_name = _normalize_gpu_name(gpu_name)\n    return GPU_TFLOPS_MAP.get(normalized_name)\n\n\ndef get_gpu_info() -> Dict[str, Any]:\n    \"\"\"\n    Detect GPU hardware information using PyTorch.\n\n    Returns:\n        Dictionary containing:\n        - gpu_available: bool, whether GPU is available\n        - gpu_count: int, number of GPUs\n        - gpu_model: str, GPU model name\n        - gpu_memory_total_gb: float, total GPU memory in GB\n        - gpu_tflops: float, theoretical peak TFLOPS for BF16/FP16\n    \"\"\"\n    try:\n        import torch\n    except ImportError:\n        console.print(\"PyTorch not available, cannot detect GPU info\")\n        return {\n            \"gpu_available\": False,\n            \"gpu_count\": 0,\n            \"gpu_model\": \"unknown\",\n            \"gpu_memory_total_gb\": 
0.0,\n            \"gpu_tflops\": None,\n        }\n\n    if not torch.cuda.is_available():\n        console.print(\"CUDA not available\")\n        return {\n            \"gpu_available\": False,\n            \"gpu_count\": 0,\n            \"gpu_model\": \"unknown\",\n            \"gpu_memory_total_gb\": 0.0,\n            \"gpu_tflops\": None,\n        }\n\n    gpu_count = torch.cuda.device_count()\n\n    # Get properties of the first GPU (assume homogeneous cluster)\n    gpu_props = torch.cuda.get_device_properties(0)\n    gpu_model = gpu_props.name\n    gpu_memory_total_gb = gpu_props.total_memory / (1024 ** 3)  # Convert bytes to GB\n\n    # Get TFLOPS\n    gpu_tflops = get_gpu_tflops(gpu_model)\n\n    if gpu_tflops is None:\n        console.print(\n            f\"Unknown GPU model '{gpu_model}', cannot determine TFLOPS. \"\n            f\"Please add it to GPU_TFLOPS_MAP in gpu_utils.py\"\n        )\n\n    gpu_info = {\n        \"gpu_available\": True,\n        \"gpu_count\": gpu_count,\n        \"gpu_model\": gpu_model,\n        \"gpu_memory_total_gb\": round(gpu_memory_total_gb, 2),\n        \"gpu_tflops\": gpu_tflops,\n    }\n\n    console.print(f\"Detected GPU: {gpu_model} x {gpu_count}, {gpu_tflops} TFLOPS (BF16/FP16)\")\n\n    return gpu_info\n"
  },
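  {
    "path": "benchmarks/examples/example_gpu_mfu.py",
    "content": "\"\"\"\nHypothetical sketch -- NOT part of the original pipeline.\n\nCombines the GPU detection above with the common MFU approximation for\ndecoder-only inference: roughly 2 * num_params FLOPs per processed token\n(forward pass only). The exact formula in the benchmark's mfu_evaluator may\ndiffer; this is only the textbook estimate.\n\"\"\"\n\nfrom benchmark.gpu_utils import get_gpu_info, get_gpu_tflops\n\n\ndef rough_mfu(num_params: float, total_tokens: int, wall_time_s: float,\n              gpu_count: int, gpu_tflops: float) -> float:\n    \"\"\"MFU = achieved FLOPs/s divided by the cluster's theoretical peak.\"\"\"\n    achieved_flops = 2.0 * num_params * total_tokens          # forward-only estimate\n    peak_flops = gpu_count * gpu_tflops * 1e12 * wall_time_s  # TFLOPS -> FLOPs\n    return achieved_flops / peak_flops\n\n\n# Fuzzy normalization maps raw vendor strings onto GPU_TFLOPS_MAP entries.\nprint(get_gpu_tflops(\"NVIDIA A100-SXM4-80GB\"))  # -> 312.0\n\ninfo = get_gpu_info()\nif info[\"gpu_available\"] and info[\"gpu_tflops\"]:\n    # e.g. a 1.7B-parameter model processing 1M tokens in 60s:\n    mfu = rough_mfu(1.7e9, 1_000_000, 60.0, info[\"gpu_count\"], info[\"gpu_tflops\"])\n    print(f\"MFU ~= {mfu:.2%}\")\n"
  },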
  {
    "path": "benchmarks/benchmark/tasks/__init__.py",
    "content": "\"\"\"\nTasks definition for Benchmark\n\"\"\"\n\nfrom .tasks import (\n    BenchmarkTable,\n    check_benchmark_version,\n    check_task_types,\n    check_splits,\n    LATEST_BENCHMARK_VERSION,\n)\n\n__all__ = [\n    \"BenchmarkTable\",\n    \"check_benchmark_version\",\n    \"check_task_types\",\n    \"check_splits\",\n    \"LATEST_BENCHMARK_VERSION\",\n]\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/tasks.py",
    "content": "\"\"\"\nTask table and utility functions for Benchmark\n\"\"\"\n\nfrom typing import List, Optional, Tuple\nfrom benchmark.tasks.v1_0.registry import TaskTable as TaskTable_v1_0\n\nLATEST_BENCHMARK_VERSION = \"v1.0\"\n\n\nBenchmarkTable = {\n    \"v1.0\": TaskTable_v1_0,\n}\n\n\ndef get_available_benchmark_versions() -> List[str]:\n    \"\"\"Get all available benchmark versions\"\"\"\n    return sorted(list(BenchmarkTable.keys()))\n\n\ndef get_available_task_types(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:\n    \"\"\"Get all task types for the specified version\"\"\"\n    task_table = BenchmarkTable[benchmark_version]\n    return sorted(list(task_table.keys()))\n\n\ndef get_available_domains(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:\n    \"\"\"Get all domains for the specified version\"\"\"\n    domains = set()\n    for task_table in BenchmarkTable[benchmark_version].values():\n        for domain in task_table.keys():\n            domains.add(domain)\n    return sorted(list(domains))\n\n\ndef get_available_languages(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:\n    \"\"\"Get all languages for the specified version\"\"\"\n    languages = set()\n    for task_table in BenchmarkTable[benchmark_version].values():\n        for task in task_table.values():\n            for lang in task.keys():\n                languages.add(lang)\n    return sorted(list(languages))\n\n\ndef check_benchmark_version(benchmark_version: Optional[str]) -> str:\n    \"\"\"\n    Validate if benchmark version is valid\n    \n    Args:\n        benchmark_version: Version to validate, returns latest version if None\n        \n    Returns:\n        str: Valid benchmark version\n        \n    Raises:\n        ValueError: If version is invalid\n    \"\"\"\n    if benchmark_version is None:\n        benchmark_version = LATEST_BENCHMARK_VERSION\n    else:\n        available_benchmark_versions = get_available_benchmark_versions()\n\n        if benchmark_version not in available_benchmark_versions:\n            raise ValueError(\n                f\"Invalid benchmark version: {benchmark_version}. Available versions: {', '.join(available_benchmark_versions)}\"\n            )\n\n    return benchmark_version\n\n\ndef check_task_types(\n    task_types: Optional[List[str]],\n    benchmark_version: str = LATEST_BENCHMARK_VERSION,\n) -> List[str]:\n    \"\"\"\n    Validate if task types are valid\n    \n    Args:\n        task_types: List of task types to validate, returns all task types if None\n        benchmark_version: Benchmark version\n        \n    Returns:\n        List[str]: Valid task types list\n        \n    Raises:\n        ValueError: If task type is invalid\n    \"\"\"\n    available_task_types = get_available_task_types(benchmark_version)\n    if task_types is None:\n        task_types = available_task_types\n    else:\n        if isinstance(task_types, str):\n            task_types = [task_types]\n        task_types = sorted(list(set(task_types)))\n        task_types = [task_type.lower() for task_type in task_types]\n        for task_type in task_types:\n            if task_type not in available_task_types:\n                raise ValueError(\n                    f\"{benchmark_version} | Invalid task type: {task_type}. 
Available task types: {', '.join(available_task_types)}\"\n                )\n    return task_types\n\n\ndef check_splits(\n    splits: Optional[List[str]],\n    benchmark_version: str = LATEST_BENCHMARK_VERSION,\n) -> List[str]:\n    \"\"\"\n    Validate if dataset splits are valid\n    \n    Args:\n        splits: List of splits to validate, returns all splits if None\n        benchmark_version: Benchmark version\n        \n    Returns:\n        List[str]: Valid splits list\n        \n    Raises:\n        ValueError: If split is invalid\n    \"\"\"\n    # Only allow test split\n    available_splits = [\"test\"]\n    \n    if splits is None:\n        splits = available_splits\n    else:\n        if isinstance(splits, str):\n            splits = [splits]\n        splits = sorted(list(set(splits)))\n        splits = [split.lower() for split in splits]\n    \n    for split in splits:\n        if split not in available_splits:\n            raise ValueError(\n                f\"{benchmark_version} | Invalid split: {split}. Available splits: {', '.join(available_splits)}\"\n            )\n    return splits\n\n"
  },
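  {
    "path": "benchmarks/examples/example_task_validation.py",
    "content": "\"\"\"\nHypothetical usage sketch -- NOT part of the original pipeline.\n\nExercises the validation helpers above. Real task names come from\nbenchmark.tasks.v1_0.registry.TaskTable, so none are hard-coded here.\n\"\"\"\n\nfrom benchmark.tasks import check_benchmark_version, check_splits, check_task_types\n\nversion = check_benchmark_version(None)   # None falls back to the latest (\"v1.0\")\nsplits = check_splits(\"test\", version)    # a bare string is wrapped into [\"test\"]\n\ntry:\n    check_splits([\"train\"], version)      # only the \"test\" split is allowed\nexcept ValueError as e:\n    print(e)\n\nall_tasks = check_task_types(None, version)  # None -> every registered task type\nprint(all_tasks)\n"
  },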
  {
    "path": "benchmarks/benchmark/tasks/v1_0/__init__.py",
    "content": "\"\"\"\nv1.0 Version Task Definitions\n\"\"\"\n\nfrom .registry import TaskTable\n\n__all__ = [\"TaskTable\"]\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/base_evaluator.py",
    "content": "\"\"\"\nBase Evaluator for all task evaluators\n\nProvides common interface for evaluation logic.\n\"\"\"\n\nimport json\nimport os\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, Any, Tuple, Optional, List\nfrom benchmark.console import console, success_style\n\n\nclass BaseEval(ABC):\n    \"\"\"Base class for all task evaluators\"\"\"\n\n    def __init__(\n        self,\n        samples: Dict[str, Dict[str, Any]],\n        task_name: Optional[str] = None,\n        predictions_dir: Optional[str] = None,\n        debug: bool = False,\n        task_config: Optional[Dict[str, Any]] = None,\n        data_dir: Optional[str] = None,\n        overwrite: bool = False,\n        cached_metrics: Optional[Dict[str, Any]] = None\n    ):\n        \"\"\"\n        Initialize base evaluator\n\n        Args:\n            samples: Dictionary of samples from test_generated.json\n                Format: {\n                    sample_id: {\n                        \"prompt\": \"...\",\n                        \"generations\": [\"...\"],\n                        \"ground_truth\": \"...\",\n                        \"metadata\": {...}\n                    }\n                }\n            task_name: Task name (e.g., \"math_500\")\n            predictions_dir: Directory to save debug files (optional)\n            debug: Whether to save debug information\n            task_config: Task configuration dictionary (optional)\n            data_dir: Data directory path (optional)\n            overwrite: Whether to overwrite existing metrics and recompute from scratch\n            cached_metrics: Existing overall metrics from eval_results (optional)\n        \"\"\"\n        self.samples = samples\n        self.task_name = task_name\n        self.predictions_dir = predictions_dir\n        self.debug = debug\n        self.task_config = task_config or {}\n        self.data_dir = data_dir\n        self.overwrite = overwrite\n        self.cached_metrics = cached_metrics or {}\n    \n    def evaluate(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Evaluate the samples and return metrics\n\n        This method provides a simplified two-level caching-aware evaluation flow:\n        1. If overwrite=True, always recompute from scratch\n        2. If cached overall metrics exist in eval_results, return them with empty per_sample_metrics\n        3. 
Otherwise, compute from scratch\n\n        Subclasses should override:\n        - required_metrics property: Return list of overall metric names\n        - _compute_metrics_from_scratch(): Compute all metrics from scratch\n\n        Returns:\n            Tuple of (metrics, per_sample_metrics)\n        \"\"\"\n\n        # If overwrite=True, always recompute from scratch\n        if self.overwrite:\n            console.print(\"[cyan]Overwrite=True, recomputing all metrics from scratch...[/cyan]\")\n            return self._compute_metrics_from_scratch()\n\n        # If cached overall metrics exist, use them\n        if self._has_all_required_metrics():\n            console.print(\"[cyan]Using existing overall metrics from eval_results...[/cyan]\")\n            # Return cached metrics with empty per_sample_metrics (not needed when using cache)\n            return self.cached_metrics, {}\n\n        # Otherwise, compute from scratch\n        console.print(\"[cyan]Computing metrics from scratch...[/cyan]\")\n        return self._compute_metrics_from_scratch()\n\n    def _all_samples_have_keys(self, required_keys: List[str]) -> bool:\n        \"\"\"Check if all samples have required keys\"\"\"\n        for sample in self.samples.values():\n            for key in required_keys:\n                if key not in sample:\n                    return False\n        return True\n\n    @property\n    def required_metrics(self) -> Optional[List[str]]:\n        \"\"\"Define required overall metric keys\"\"\"\n        return None\n\n    def _has_all_required_metrics(self) -> bool:\n        \"\"\"Check if cached_metrics contains all required keys (override for custom logic)\"\"\"\n        if self.required_metrics is not None:\n            return all(key in self.cached_metrics for key in self.required_metrics)\n        return False\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"Compute metrics from scratch (override in subclasses)\"\"\"\n        raise NotImplementedError(\"Subclasses must implement _compute_metrics_from_scratch()\")\n\n\n    def _save_debug_json(\n        self,\n        debug_info: Dict[str, Any],\n        filename: str = \"debug.json\"\n    ) -> Optional[str]:\n        \"\"\"Save debug information to JSON file\"\"\"\n        if not self.predictions_dir:\n            return None\n\n        debug_filename = os.path.join(self.predictions_dir, filename)\n        os.makedirs(os.path.dirname(debug_filename), exist_ok=True)\n\n        with open(debug_filename, 'w', encoding='utf-8') as f:\n            json.dump(debug_info, f, indent=2, ensure_ascii=False)\n            \n        console.print(f\"✓ Debug information saved to: {debug_filename}\", style=success_style)\n        return debug_filename\n"
  },
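  {
    "path": "benchmarks/examples/example_custom_evaluator.py",
    "content": "\"\"\"\nHypothetical sketch -- NOT part of the original pipeline.\n\nMinimal subclass contract for BaseEval: declare required_metrics so cached\nresults in eval_results.json can short-circuit evaluation, and implement\n_compute_metrics_from_scratch(). The evaluator and metric here are toys.\n\"\"\"\n\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom benchmark.tasks.v1_0.base_evaluator import BaseEval\n\n\nclass ExactMatchEval(BaseEval):\n    \"\"\"Toy evaluator: share of samples whose first generation equals the GT\"\"\"\n\n    @property\n    def required_metrics(self) -> Optional[List[str]]:\n        # With overwrite=False, evaluate() returns cached_metrics directly\n        # whenever all of these keys are already present in the cache.\n        return [\"exact_match\"]\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        per_sample: Dict[str, Dict[str, Any]] = {}\n        hits = 0\n        for sample_id, sample in self.samples.items():\n            pred = (sample.get(\"generations\") or [\"\"])[0]\n            correct = pred.strip() == str(sample.get(\"ground_truth\", \"\")).strip()\n            hits += int(correct)\n            per_sample[sample_id] = {\"exact_match\": float(correct)}\n        metrics = {\"exact_match\": hits / max(len(self.samples), 1)}\n        return metrics, per_sample\n\n\n# Usage on toy data:\n# ev = ExactMatchEval(samples={\"0\": {\"generations\": [\"yes\"], \"ground_truth\": \"yes\"}})\n# print(ev.evaluate())  # -> ({'exact_match': 1.0}, {'0': {'exact_match': 1.0}})\n"
  },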
  {
    "path": "benchmarks/benchmark/tasks/v1_0/base_loader.py",
    "content": "\"\"\"\nBase Loader for all task data loaders\n\nProvides common functionality for data loading, sampling, and file path resolution.\n\"\"\"\n\nimport os\nimport json\nimport pandas as pd\nfrom typing import Dict, Any, Optional\nfrom abc import ABC\n\nfrom benchmark.console import *\n\n\nclass BaseLoader(ABC):\n    \"\"\"Base class for all task data loaders\"\"\"\n    \n    def __init__(\n        self,\n        task_config: Dict[str, Any],\n        data_dir: Optional[str] = None,\n        tokenizer: Optional[Any] = None,\n        enable_thinking: Optional[bool] = None,\n    ):\n        \"\"\"Initialize base loader\"\"\"\n        self.task_config = task_config\n        self.data_dir = data_dir\n        self.tokenizer = tokenizer\n        self.enable_thinking = enable_thinking\n        self.task_name = task_config.get(\"name\", \"unknown\")\n\n        # Validate tokenizer is provided for messages-based format\n        if self.tokenizer is None:\n            raise ValueError(\n                f\"{self.task_name} requires tokenizer for messages-based format. \"\n                f\"Please provide model_path when initializing Benchmark.\\n\"\n                f\"Example: Benchmark(task_types=['{self.task_name}'], model_path='your-model-path')\"\n            )\n\n    def load_data(self, split: str = \"test\", sample_size: Optional[Any] = None) -> Dict[str, Dict[str, Any]]:\n        \"\"\"\n        Load data for the task in messages-based format\n\n        Args:\n            split: Dataset split (default \"test\")\n            sample_size: Override sample size (can be int, \"full\", or None to use task config)\n\n        Returns:\n            Dictionary mapping sample_id to sample data:\n            {\n                sample_id: {\n                    \"prompt\": \"formatted prompt from apply_chat_template\",\n                    \"ground_truth\": \"answer\",\n                    \"metadata\": {\n                        \"row_index\": idx,\n                        \"messages\": [...]\n                    }\n                }\n            }\n        \"\"\"\n        # Determine effective sample size\n        if sample_size is not None:\n            if sample_size == \"full\":\n                effective_sample_size = self.task_config.get(\"size\")\n            else:\n                effective_sample_size = int(sample_size)\n        else:\n            effective_sample_size = self.task_config.get(\"sample_size\")\n        \n        full_size = self.task_config.get(\"size\")\n\n        # Try to load cached sample dataframe\n        df = None\n        if effective_sample_size is not None and full_size is not None and effective_sample_size < full_size:\n            df = self._load_sample_dataframe(split, effective_sample_size)\n\n        # If no cache, load and sample original data\n        if df is None:\n            df = self._load_dataframe(split)\n\n            # Perform sampling if needed\n            if effective_sample_size is not None and effective_sample_size < len(df):\n                df = self._sample_data(df, effective_sample_size)\n\n                # Save sampled data\n                if full_size is not None and effective_sample_size < full_size:\n                    self._save_sample_data(df, split, effective_sample_size)\n\n        if 'messages' not in df.columns:\n            raise ValueError(\n                f\"{self.task_name} requires 'messages' column in data file. 
\"\n                f\"Found columns: {list(df.columns)}\\n\"\n                f\"Please ensure your data is in messages-based format.\"\n            )\n\n        if 'metadata' not in df.columns:\n            raise ValueError(\n                f\"{self.task_name} requires 'metadata' column in data file. \"\n                f\"Found columns: {list(df.columns)}\\n\"\n                f\"Please ensure your data is in messages-based format.\"\n            )\n\n        console.print(f\"[green]Processing {self.task_name} data in messages-based format[/green]\")\n\n        result = self._process_dataframe(df)\n\n        return result\n    \n    @staticmethod\n    def _is_empty_value(value) -> bool:\n        \"\"\"Check if a value is None, NaN, or empty\"\"\"\n        if value is None:\n            return True\n        \n        if isinstance(value, float):\n            try:\n                return pd.isna(value)\n            except (ValueError, TypeError):\n                return False\n        \n        if isinstance(value, str):\n            return len(value.strip()) == 0\n        \n        try:\n            if hasattr(value, '__len__'):\n                return len(value) == 0\n        except (ValueError, TypeError):\n            pass\n        \n        return False\n\n    @staticmethod\n    def _convert_messages_format(messages: list) -> list:\n        \"\"\"\n        Convert message format.\n\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]} \n        -> \n        {\"role\": \"user\", \"content\": \"...\"}\n        \"\"\"\n        converted = []\n        for msg in messages:\n            content = msg.get(\"content\")\n            if isinstance(content, list):\n                # Extract text from content list\n                text_parts = []\n                for item in content:\n                    if isinstance(item, dict) and item.get(\"type\") == \"text\":\n                        text_parts.append(item.get(\"text\", \"\"))\n                converted.append({\n                    \"role\": msg.get(\"role\"),\n                    \"content\": \"\".join(text_parts)\n                })\n            else:\n                # Already in old format\n                converted.append(msg)\n        return converted\n\n    def _load_custom_chat_template(self):\n        \"\"\"Load custom chat template based on configuration\"\"\"\n        if not self.tokenizer:\n            return\n\n        prompt_config = self.task_config.get(\"prompt_config\", {})\n        custom_template = prompt_config.get(\"custom_chat_template\")\n\n        template_path = os.path.join(\n            os.path.dirname(__file__),\n            custom_template\n        )\n\n        if not os.path.exists(template_path):\n            raise FileNotFoundError(f\"✗ Custom chat template not found: {template_path}\")\n\n        with open(template_path, \"r\", encoding=\"utf-8\") as f:\n            self.tokenizer.chat_template = f.read()\n        console.print(f\"✓ Loaded custom chat template: {custom_template}\", style=success_style)\n\n    def _get_data_file_path(self, split: str) -> str:\n        \"\"\"Get data file path for the given split\"\"\"\n        if self.data_dir:\n            base_dir = self.data_dir\n        else:\n            base_dir = \"./data\"\n        \n        filename = f\"{self.task_name}_{split}.parquet\"\n        \n        possible_paths = [\n            os.path.join(base_dir, self.task_name, filename),\n        ]\n        \n        for file_path in possible_paths:\n            if 
os.path.exists(file_path):\n                return file_path\n        \n        return possible_paths[0]\n    \n    def _get_sample_data_file_path(self, split: str, sample_size: int) -> str:\n        \"\"\"Get sample data file path\"\"\"\n        if self.data_dir:\n            base_dir = self.data_dir\n        else:\n            base_dir = \"./data\"\n        \n        possible_paths = [\n            os.path.join(base_dir, self.task_name, f\"{self.task_name}_{split}_sample_{sample_size}.parquet\"),\n            os.path.join(base_dir, f\"{self.task_name}_{split}_sample_{sample_size}.parquet\"),\n        ]\n        \n        for path in possible_paths:\n            if os.path.exists(path):\n                return path\n        \n        return possible_paths[0]\n    \n    def _load_dataframe(self, split: str) -> pd.DataFrame:\n        \"\"\"Load DataFrame from data file\"\"\"\n        data_file = self._get_data_file_path(split)\n        \n        if not os.path.exists(data_file):\n            raise FileNotFoundError(f\"Data file not found: {data_file}\")\n        \n        console.print(f\"Loading data file: {data_file}\")\n        \n        if data_file.endswith('.parquet'):\n            df = pd.read_parquet(data_file)\n        else:\n            raise ValueError(f\"Unsupported file format: {data_file}\")\n        \n        return df\n    \n    def _sample_data(self, df: pd.DataFrame, sample_size: int) -> pd.DataFrame:\n        \"\"\"Sample data from DataFrame\"\"\"\n        if sample_size >= len(df):\n            return df\n        \n        console.print(f\"Sampling {sample_size} samples (total: {len(df)})\")\n        return df.head(sample_size)\n    \n    def _save_sample_data(\n        self,\n        df: pd.DataFrame,\n        split: str,\n        sample_size: int\n    ):\n        \"\"\"Save sample data in parquet format\"\"\"\n        sample_file = self._get_sample_data_file_path(split, sample_size)\n\n        sample_dir = os.path.dirname(sample_file)\n        if sample_dir:\n            os.makedirs(sample_dir, exist_ok=True)\n\n        df.to_parquet(sample_file, index=False)\n        console.print(f\"Sample data saved to: {sample_file}\")\n    \n    def _load_sample_dataframe(self, split: str, sample_size: int) -> Optional[pd.DataFrame]:\n        \"\"\"Load sample dataframe from cache if exists\"\"\"\n        sample_file = self._get_sample_data_file_path(split, sample_size)\n\n        if not os.path.exists(sample_file):\n            return None\n\n        console.print(f\"Loading sample data from cache: {sample_file}\")\n\n        df = pd.read_parquet(sample_file)\n        return df\n\n    def _process_dataframe(self, df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:\n        \"\"\"Process DataFrame and convert to model input format\"\"\"\n        self._load_custom_chat_template()\n\n        result = {}\n\n        prompt_config = self.task_config.get(\"prompt_config\", {})\n        # Command-line parameter has higher priority than config\n        if self.enable_thinking is not None:\n            enable_thinking = self.enable_thinking\n        else:\n            enable_thinking = prompt_config.get(\"enable_thinking\", False)\n\n        console.print(f\"[cyan]Auto Thinking: {'✓ Enabled' if enable_thinking else '✗ Disabled'}[/cyan]\")\n\n        for idx, row in df.iterrows():\n            sample_id = str(idx)\n\n            messages = row.get('messages')\n            if self._is_empty_value(messages):\n                console.print(f\"Sample {sample_id}: messages is empty, skipping\")\n    
            continue\n\n            if isinstance(messages, str):\n                try:\n                    messages = json.loads(messages)\n                except Exception:\n                    console.print(f\"Sample {sample_id}: failed to parse messages, skipping\")\n                    continue\n\n            messages = self._convert_messages_format(messages)\n\n            try:\n                formatted_prompt = self.tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                    add_generation_prompt=True,\n                    enable_thinking=enable_thinking,\n                )\n            except Exception as e:\n                console.print(f\"Sample {sample_id}: failed to apply chat template: {e}, skipping\")\n                continue\n\n            metadata_raw = row.get('metadata')\n            if self._is_empty_value(metadata_raw):\n                console.print(f\"Sample {sample_id}: metadata is empty, skipping\")\n                continue\n\n            if isinstance(metadata_raw, str):\n                try:\n                    metadata_dict = json.loads(metadata_raw)\n                except Exception:\n                    console.print(f\"Sample {sample_id}: failed to parse metadata, skipping\")\n                    continue\n            elif isinstance(metadata_raw, dict):\n                metadata_dict = metadata_raw\n            else:\n                console.print(f\"Sample {sample_id}: invalid metadata format, skipping\")\n                continue\n\n            answer = metadata_dict.get('answer')\n            if self._is_empty_value(answer):\n                console.print(f\"Sample {sample_id}: answer is empty in metadata, skipping\")\n                continue\n\n            ground_truth_str = str(answer).strip()\n\n            result_item = {\n                \"prompt\": formatted_prompt,\n                \"ground_truth\": ground_truth_str,\n                \"metadata\": self._make_metadata_serializable(idx, metadata_dict)\n            }\n\n            result[sample_id] = result_item\n\n        console.print(f\"[green]Loaded {len(result)} samples for {self.task_name}[/green]\")\n\n        return result\n\n    def _make_metadata_serializable(\n        self,\n        idx: Any,\n        metadata_dict: dict,\n    ) -> dict:\n        \"\"\"Convert metadata to JSON-serializable format\"\"\"\n        del metadata_dict[\"answer\"]\n\n        metadata = {\n            \"row_index\": int(idx) if hasattr(idx, '__int__') else str(idx),\n            **metadata_dict,\n        }\n\n\n        return metadata\n\n"
  },
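  {
    "path": "benchmarks/benchmark/tasks/v1_0/examples/messages_format_demo.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical helper, not part of the pipeline).\n\nShows, standalone, the conversion performed by the data loader's\n_convert_messages_format: the content-list message format\n{\"role\": ..., \"content\": [{\"type\": \"text\", \"text\": ...}]} is flattened into\nthe plain-string format expected by tokenizer.apply_chat_template. The file\npath and names here are assumptions for illustration only.\n\"\"\"\n\n\ndef convert_messages_format(messages: list) -> list:\n    \"\"\"Standalone copy of the loader's conversion logic.\"\"\"\n    converted = []\n    for msg in messages:\n        content = msg.get(\"content\")\n        if isinstance(content, list):\n            # Concatenate every text part into a single string\n            text = \"\".join(\n                item.get(\"text\", \"\")\n                for item in content\n                if isinstance(item, dict) and item.get(\"type\") == \"text\"\n            )\n            converted.append({\"role\": msg.get(\"role\"), \"content\": text})\n        else:\n            # Already a plain string: keep as-is\n            converted.append(msg)\n    return converted\n\n\nif __name__ == \"__main__\":\n    messages = [\n        {\"role\": \"user\", \"content\": [\n            {\"type\": \"text\", \"text\": \"Describe \"},\n            {\"type\": \"text\", \"text\": \"this video.\"},\n        ]},\n        {\"role\": \"assistant\", \"content\": \"A short cooking tutorial.\"},\n    ]\n    for msg in convert_messages_format(messages):\n        print(msg)\n    # -> {'role': 'user', 'content': 'Describe this video.'}\n    # -> {'role': 'assistant', 'content': 'A short cooking tutorial.'}\n"
  },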
  {
    "path": "benchmarks/benchmark/tasks/v1_0/item_understand/__init__.py",
    "content": "\"\"\"\nItem Understand Task Module\n\"\"\"\n\nfrom .config import ITEM_UNDERSTAND_CONFIG\nfrom .evaluator import ItemUnderstandEvaluator\nfrom . import utils\n\n__all__ = [\n    \"ITEM_UNDERSTAND_CONFIG\",\n    \"ItemUnderstandEvaluator\",\n    \"utils\",\n]\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/item_understand/config.py",
    "content": "\"\"\"\nItem Understand Task Configuration\n\"\"\"\n\n# Item Understand Task Configuration\nITEM_UNDERSTAND_CONFIG = {\n    \"name\": \"item_understand\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 500,\n    \"sample_size\": 500,\n    \"description\": \"Video SID to Caption generation task\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": {\n        \"enable_thinking\": False,  # Enable thinking mode for apply_chat_template\n        \"custom_chat_template\": \"qwen3_soft_switch.jinja2\",  # Custom jinja2 template (file in v1_0 directory)\n    },\n    # Generation parameter configuration\n    \"generation_config\": {\n        \"num_return_sequences\": 1,\n        \"max_new_tokens\": 128,\n        \"temperature\": 0.01,\n        \"top_p\": 0.95,\n        \"repetition_penalty\": 1.0,\n        \"do_sample\": False,\n        \"num_return_thinking_sequences\": 1,\n        \"max_new_thinking_tokens\": 1000,\n    },\n    \"evaluation_config\": {\n        \"metrics\": [\"macro_wip_double_weighted_f1\", \"micro_wip_double_weighted_f1\"],\n        \"bertscore_model_type\": \"bert-base-chinese\",\n        \"bertscore_num_layers\": 9,\n        \"bertscore_lang\": \"zh\",\n        # WIP (Weighted Information Points) evaluation config\n        \"wip_enabled\": True,                      # Whether to enable WIP evaluation\n        \"wip_judge_model\": \"gemini\",             # Judge LLM type: gemini/deepseek/claude\n        \"wip_max_workers\": 1,                      # Concurrent workers for LLM calls\n        \"wip_core_threshold\": 5,                   # Core threshold for importance score (1-5)\n        \"wip_max_samples\": 500,                    # Max samples to evaluate (None for all)\n    }\n}\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/item_understand/evaluator.py",
    "content": "\"\"\"\nItem Understand Evaluator\n\nEvaluates model predictions on Item Understand task using WIP (LLM-as-Judge).\n\"\"\"\n\nimport os\nfrom typing import Dict, Any, Tuple, List\n\nfrom benchmark.console import console\nfrom benchmark.tasks.v1_0.base_evaluator import BaseEval\n\n\nclass ItemUnderstandEvaluator(BaseEval):\n    \"\"\"Item Understand task evaluator\"\"\"\n\n    @property\n    def required_metrics(self) -> List[str]:\n        \"\"\"Define required overall metrics for Item Understand evaluation\"\"\"\n        return [\"macro_wip_double_weighted_f1\"]\n\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Compute all metrics from scratch\n\n        Returns:\n            Tuple of (metrics, per_sample_metrics)\n        \"\"\"\n        total_samples = len(self.samples)\n\n        # Prepare data for evaluation\n        sample_ids = list(self.samples.keys())\n        predictions = []\n        references = []\n\n        for sample_id in sample_ids:\n            sample = self.samples[sample_id]\n\n            # Get ground truth\n            ground_truth = sample.get(\"ground_truth\", \"\")\n            references.append(ground_truth)\n\n            # Get model prediction (first generation)\n            generations = sample.get(\"generations\", [])\n            if not generations:\n                prediction = \"\"\n            else:\n                prediction = generations[0]\n            predictions.append(prediction)\n\n        # Get evaluation config\n        eval_config = self.task_config.get(\"evaluation_config\", {})\n\n        # Build per-sample metrics\n        per_sample_metrics = {}\n        for sample_id in sample_ids:\n            per_sample_metrics[sample_id] = {}\n\n        # Build overall metrics\n        metrics = {\n            \"num_samples\": total_samples\n        }\n\n        # WIP Evaluation (if enabled)\n        wip_enabled = eval_config.get(\"wip_enabled\", False)\n        if wip_enabled:\n            console.print(\"[cyan]WIP evaluation enabled, starting LLM-as-Judge evaluation...[/cyan]\")\n            wip_metrics, wip_per_sample = self._evaluate_wip(\n                sample_ids=sample_ids,\n                predictions=predictions,\n                references=references,\n                eval_config=eval_config\n            )\n\n            # Merge WIP metrics into overall metrics\n            metrics.update(wip_metrics)\n\n            # Merge WIP per-sample metrics\n            for sample_id in sample_ids:\n                if sample_id in wip_per_sample:\n                    per_sample_metrics[sample_id].update(wip_per_sample[sample_id])\n\n        # Save debug information if requested\n        if self.debug and self.predictions_dir:\n            self._save_debug_info(metrics, per_sample_metrics, predictions, references)\n\n        return metrics, per_sample_metrics\n\n    def _evaluate_wip(\n        self,\n        sample_ids: list,\n        predictions: list,\n        references: list,\n        eval_config: Dict[str, Any]\n    ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Perform WIP (Weighted Information Points) evaluation using LLM-as-Judge.\n\n        Args:\n            sample_ids: List of sample IDs\n            predictions: List of prediction texts\n            references: List of reference texts\n            eval_config: Evaluation configuration\n\n        Returns:\n            Tuple of (wip_metrics, wip_per_sample_metrics)\n      
  \"\"\"\n        try:\n            from api import get_client_from_config\n            from benchmark.tasks.v1_0.item_understand.utils import evaluate_wip\n        except ImportError as e:\n            console.print(f\"[red]Failed to import WIP evaluation modules: {e}[/red]\")\n            return {}, {}\n\n        # Get WIP config\n        wip_judge_model = eval_config.get(\"wip_judge_model\", \"deepseek\")\n        wip_max_workers = eval_config.get(\"wip_max_workers\", 5)\n        wip_max_samples = eval_config.get(\"wip_max_samples\", 100)\n        wip_core_threshold = eval_config.get(\"wip_core_threshold\", 5)\n        wip_gt_cache_dir = os.path.join(self.data_dir, self.task_name)  # Use data_dir / task_name as GT cache directory\n\n        # Use BERTScore config from evaluation_config (not separate wip config)\n        bertscore_model = eval_config.get(\"bertscore_model_type\", \"bert-base-chinese\")\n        bertscore_num_layers = eval_config.get(\"bertscore_num_layers\", 9)\n\n        # Create LLM client\n        try:\n            llm_client = get_client_from_config(wip_judge_model)\n            console.print(f\"[green]Using {wip_judge_model} as WIP judge[/green]\")\n        except Exception as e:\n            console.print(f\"[red]Failed to create LLM client for WIP evaluation: {e}[/red]\")\n            return {}, {}\n\n        # Prepare data as dicts\n        predictions_dict = {id: pred for id, pred in zip(sample_ids, predictions)}\n        references_dict = {id: ref for id, ref in zip(sample_ids, references)}\n\n        # Get model name for cache file naming\n        # Try to extract from llm_client config\n        model_name = getattr(llm_client, 'model_name', wip_judge_model)\n\n        # Run WIP evaluation\n        try:\n            wip_metrics, wip_per_sample = evaluate_wip(\n                predictions=predictions_dict,\n                references=references_dict,\n                llm_client=llm_client,\n                max_workers=wip_max_workers,\n                max_samples=wip_max_samples,\n                gt_cache_dir=wip_gt_cache_dir,\n                model_name=model_name,\n                save_dir=self.predictions_dir,\n                bertscore_model=bertscore_model,\n                bertscore_num_layers=bertscore_num_layers,\n                core_threshold=wip_core_threshold,\n            )\n\n            console.print(f\"[green]WIP evaluation completed: {wip_metrics.get('wip_num_samples', 0)} samples evaluated[/green]\")\n            return wip_metrics, wip_per_sample\n\n        except Exception as e:\n            console.print(f\"[red]WIP evaluation failed: {e}[/red]\")\n            import traceback\n            traceback.print_exc()\n            return {}, {}\n\n    def _save_debug_info(\n        self,\n        metrics: Dict[str, Any],\n        per_sample_metrics: Dict[str, Dict[str, Any]],\n        predictions: list,\n        references: list\n    ):\n        \"\"\"\n        Save detailed debug information to file\n\n        Args:\n            metrics: Overall metrics\n            per_sample_metrics: Per-sample metrics\n            predictions: List of predictions\n            references: List of references\n        \"\"\"\n        # Prepare debug info\n        debug_info = {\n            \"overall_metrics\": metrics,\n            \"per_sample_metrics\": per_sample_metrics,\n            \"sample_count\": len(predictions),\n        }\n\n        # Add some examples\n        sample_ids = list(self.samples.keys())\n        debug_info[\"examples\"] = []\n        for 
i in range(min(10, len(sample_ids))):\n            sample_id = sample_ids[i]\n            debug_info[\"examples\"].append({\n                \"sample_id\": sample_id,\n                \"prediction\": predictions[i],\n                \"reference\": references[i],\n                \"wip_unweighted_f1\": per_sample_metrics[sample_id].get(\"wip_unweighted_f1\"),\n                \"wip_unweighted_core_f1\": per_sample_metrics[sample_id].get(\"wip_unweighted_core_f1\"),\n                \"wip_importance_weighted_f1\": per_sample_metrics[sample_id].get(\"wip_importance_weighted_f1\"),\n                \"wip_importance_weighted_core_f1\": per_sample_metrics[sample_id].get(\"wip_importance_weighted_core_f1\"),\n                \"wip_double_weighted_f1\": per_sample_metrics[sample_id].get(\"wip_double_weighted_f1\"),\n                \"wip_double_weighted_core_f1\": per_sample_metrics[sample_id].get(\"wip_double_weighted_core_f1\"),\n            })\n\n        # Save to file using base class method\n        self._save_debug_json(debug_info, filename=\"debug.json\")\n\n        # Print summary statistics\n        console.print(f\"Total samples: {metrics['num_samples']}\")\n\n        # Print WIP metrics if available\n        if metrics.get('macro_wip_unweighted_f1') is not None:\n            console.print(f\"Macro WIP Unweighted F1: {metrics['macro_wip_unweighted_f1']:.4f}\")\n        if metrics.get('macro_wip_double_weighted_f1') is not None:\n            console.print(f\"Macro WIP Double-weighted F1: {metrics['macro_wip_double_weighted_f1']:.4f}\")\n"
  },
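  {
    "path": "benchmarks/benchmark/tasks/v1_0/item_understand/mock_judge_demo.py",
    "content": "\"\"\"\nHypothetical test helper (not part of the pipeline).\n\nThe WIP utilities only rely on two things from the judge client returned by\nget_client_from_config: a generate(prompt) -> str method and an optional\nmodel_name attribute (used for cache file naming). A canned-response mock is\ntherefore enough to exercise the extraction/matching plumbing offline; the\nclass and file names here are assumptions for illustration.\n\"\"\"\n\nimport json\n\n\nclass MockJudgeClient:\n    \"\"\"Returns a fixed, well-formed extraction payload for any prompt.\"\"\"\n\n    model_name = \"mock-judge\"\n\n    def generate(self, prompt: str) -> str:\n        # A real judge answers with extraction or matching JSON depending on\n        # the prompt; this mock always reports an empty extraction result.\n        return json.dumps({\"wips\": []}, ensure_ascii=False)\n\n\nif __name__ == \"__main__\":\n    from benchmark.tasks.v1_0.item_understand.utils import extract_wips_single\n\n    wips, error = extract_wips_single(\"a test caption\", MockJudgeClient())\n    print(wips, error)  # -> [] None\n"
  },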
  {
    "path": "benchmarks/benchmark/tasks/v1_0/item_understand/utils.py",
    "content": "import json\nimport os\nimport re\nfrom typing import Dict, List, Any, Optional, Tuple\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom pathlib import Path\n\nimport pandas as pd\nfrom tqdm import tqdm\n\nfrom benchmark.console import console\n\n\nWIP_EXTRACTION_PROMPT = \"\"\"你是一位顶级的【信息抽取专家】，擅长从非结构化的文本中解析出结构化的信息。\n\n### 你的核心任务\n你的任务是分析我提供的描述性文字，并将其分解为结构化的【原子化且唯一】的\"信息点\"列表。\n\n### 输出结构\n对于列表中的每一个信息点，你必须提供：\n1.  **info_point**: 一个简洁的、陈述事实的短语。\n2.  **importance_score**: 一个 [1, 5] 之间的【整数】，代表该信息点的重要性。\n\n---\n### 关键原则 (必须遵守)\n\n1.  **原子性 (Atomic):** 每个 `info_point` 应只包含一个独立的事实。\n    * (好): `{{\"info_point\": \"女孩在吃饭\", \"importance_score\": 4}}`\n    * (差): `{{\"info_point\": \"女孩在吃饭，妈妈在旁边看\", \"importance_score\": 4}}`\n2.  **唯一性 (Unique):** 确保你提取的每个 `info_point` 都是**概念上唯一**的。\n3.  **合并 (Consolidate):** 如果原始文本中的多个短语描述的是【同一个核心思想】，你【必须】将它们合并成一个单一的、最具代表性的 `info_point`。\n    * (例如): 如果文本说 \"活动环境是温馨的\" 和 \"视频色彩营造温馨氛围\"，你应该只提取一个，如：`{{\"info_point\": \"视频氛围温馨\", \"importance_score\": 5}}`。\n    * **不要创建重复或语义高度重叠的条目。**\n\n---\n### 打分指南 (1-5分制)\n\n* **5分 (绝对核心):** 视频的\"灵魂\"。如果缺少这个点，整个摘要就毫无意义。（例如：\"如何制作煎蛋卷\"、\"XX游戏的评测\"）\n* **4分 (关键信息):** 视频的\"骨架\"。关键的事件、步骤或场景。（例如：\"打散三个鸡蛋\"、\"使用了不粘锅\"、\"游戏画面评测\"）\n* **3分 (重要细节):** 视频的\"肉\"。支撑骨架的具体、重要的细节。（例如：\"加入了盐和胡椒\"、\"用中火加热黄油\"、\"角色动作流畅\"）\n* **2分 (补充细节):** 补充性的上下文或次要信息。（例如：\"煎蛋卷折叠了三次\"、\"背景音乐很好听\"）\n* **1分 (琐碎信息):** 琐碎的、风格化的或背景性的描述。（例如：\"主持人穿着蓝色围裙\"、\"视频光线很好\"）\n\n---\n### 格式与示例\n\n你的输出必须是【纯粹的 JSON 格式】，可以被 `json.loads` 直接解析。JSON应包含一个 \"wips\" 键，其值为一个列表。如果文本中没有可提取的信息点，请返回 `{{\"wips\": []}}`。\n\n**[示例输入]**\n这是一段关于如何制作法式煎蛋卷的教程视频。主持人首先将三个鸡蛋打入碗中，并加入了盐和一小撮胡椒进行搅拌。视频强调了使用中火和不粘锅的重要性。接着，她在锅中融化了一块黄油，然后倒入蛋液。在烹饪过程中，她不断晃动平底锅，并将边缘的蛋液推向中心。最后，她将煎蛋卷折叠成三折，盛入盘中。整个过程非常快速。\n\n**[示例输出]**\n```json\n{{\n  \"wips\": [\n    {{\n      \"info_point\": \"教程：如何制作法式煎蛋卷\",\n      \"importance_score\": 5\n    }},\n    {{\n      \"info_point\": \"使用三个鸡蛋，加盐和胡椒搅拌\",\n      \"importance_score\": 3\n    }},\n    {{\n      \"info_point\": \"强调使用中火\",\n      \"importance_score\": 4\n    }},\n    {{\n      \"info_point\": \"使用不粘锅和黄油\",\n      \"importance_score\": 4\n    }},\n    {{\n      \"info_point\": \"晃动锅并将蛋液边缘推向中心\",\n      \"importance_score\": 3\n    }},\n    {{\n      \"info_point\": \"煎蛋卷被折叠成三折\",\n      \"importance_score\": 2\n    }},\n    {{\n      \"info_point\": \"烹饪过程快速\",\n      \"importance_score\": 1\n    }}\n  ]\n}}\n```\n\n现在，请开始分析我提供的描述性文字:\n{}\n\n你的输出结果 (请严格按照上述要求返回一个格式规整的 JSON，可以被 json.loads 直接解析。请不要在 JSON 数据前后添加任何额外的解释性文字或代码块标记): \"\"\"\n\n\nWIP_MATCHING_PROMPT = \"\"\"你是一位极其严谨的**语义匹配专家**。你的任务是精确地对比两组关于同一个视频摘要的结构化信息点 (WIPs)，并找出它们之间的匹配关系。\n\n**背景信息:**\n- **Ground Truth WIPs (GT列表)**: 这是视频摘要的\"事实标准\"，代表视频中真实存在的所有核心信息。每个点都有一个 [1-5] 的重要性分数 (`importance_score`)。\n- **Model-Generated WIPs (模型列表)**: 这是由一个AI模型生成的摘要信息点，代表它\"声称\"在视频中看到的内容。每个点也有一个 [1-5] 的重要性分数。\n\n**你的核心任务:**\n对比这两个列表，并输出一个包含三类结果的JSON对象：\n1.  **`matches`**: 一个匹配对的列表。对于\"模型列表\"中的每一个项，如果在\"GT列表\"中找到了一个**语义上非常相似**的对应项，就将它们配对。\n2.  **`unmatched_model_wips` (幻觉)**: \"模型列表\"中，那些在\"GT列表\"里找不到任何合理对应项的条目。这些代表了模型的**幻觉 (False Positives)**。\n3.  **`unmatched_gt_wips` (漏报)**: \"GT列表\"中，那些没有被\"模型列表\"中任何条目匹配到的条目。这些代表了模型的**漏报 (False Negatives)**。\n\n**至关重要的匹配规则:**\n1.  **语义核心**: 匹配的核心是 `info_point` 的语义。\n2.  **部分匹配**: 如果两个 `info_point` 语义上\"部分重叠\"但\"不完全相同\"，你【也应该】将它们匹配。\n    * (例如): GT的 `\"一场激烈精彩的篮球比赛\"` 和 Gen的 `\"球员在打篮球\"` 应该被【匹配】(因为核心\"篮球\"匹配上了)。\n    * (例如): GT的 `\"评测《魔龙巢穴：暗影崛起》\"` 和 Gen的 `\"评测《魔龙巢穴：冰封王座》\"` 应该被【匹配】(因为核心\"《魔龙巢穴》评测\"匹配上了)。\n3.  
**一对一匹配**: 找出最佳的匹配组合。\n\n---\n**[输出结构示例]**\n\n**[输入]**\n- GT列表: `[\n    {{\"info_point\": \"节气是秋分\", \"importance_score\": 5}},\n    {{\"info_point\": \"农民在收割稻谷\", \"importance_score\": 4}}\n  ]`\n- 模型列表: `[\n    {{\"info_point\": \"这是一个关于秋分的视频\", \"importance_score\": 4}},\n    {{\"info_point\": \"狗在田里跑\", \"importance_score\": 1}}\n  ]`\n\n**[你的输出]**\n```json\n{{\n  \"matches\": [\n    {{\n      \"model_wip\": {{\"info_point\": \"这是一个关于秋分的视频\", \"importance_score\": 4}},\n      \"gt_wip\": {{\"info_point\": \"节气是秋分\", \"importance_score\": 5}}\n    }}\n  ],\n  \"unmatched_model_wips\": [\n    {{\n      \"info_point\": \"狗在田里跑\",\n      \"importance_score\": 1\n    }}\n  ],\n  \"unmatched_gt_wips\": [\n    {{\n      \"info_point\": \"农民在收割稻谷\",\n      \"importance_score\": 4\n    }}\n  ]\n}}\n```\n\n现在，请开始你的匹配工作:\n\n[Ground Truth WIPs (GT列表)]\n\n{}\n\n[Model-Generated WIPs (模型列表)]\n\n{}\n\n你的匹配结果 (请严格按照上述要求返回一个格式规整的 JSON，可以被 json.loads 直接解析。请不要在 JSON 数据前后添加任何额外的解释性文字或代码块标记): \"\"\"\n\n\ndef extract_json_from_response(response: str) -> Optional[Dict]:\n    \"\"\"\n    Extract JSON from LLM response (simplified version for well-behaved LLMs).\n    \"\"\"\n    if not response:\n        return None\n\n    # Strip Markdown code fences if present. A character-set based\n    # lstrip('```json') would also eat leading 'j'/'s'/'o'/'n' characters of\n    # unfenced responses, so explicit regexes are used instead.\n    cleaned = response.strip()\n    if cleaned.startswith(\"```\"):\n        cleaned = re.sub(r\"^```(?:json)?\\s*\", \"\", cleaned)\n        cleaned = re.sub(r\"\\s*```$\", \"\", cleaned)\n\n    try:\n        return json.loads(cleaned)\n    except json.JSONDecodeError:\n        console.print(f\"[yellow]Failed to parse JSON from response:[/yellow] {response}\")\n        return None\n\n\ndef extract_wips_single(\n    text: str,\n    llm_client\n) -> Tuple[Optional[List[Dict]], Optional[str]]:\n    \"\"\"\n    Extract WIPs from a single text using LLM.\n\n    Args:\n        text: Input text to extract WIPs from\n        llm_client: LLM client instance (with built-in retry mechanism)\n\n    Returns:\n        Tuple of (wips_list, error_message)\n        - wips_list: List of WIP dicts if successful, None if failed\n        - error_message: Error message if failed, None if successful\n    \"\"\"\n    prompt = WIP_EXTRACTION_PROMPT.format(text)\n\n    try:\n        response = llm_client.generate(prompt)\n        result = extract_json_from_response(response)\n\n        if result is not None and \"wips\" in result:\n            return result[\"wips\"], None\n\n        return None, \"Failed to parse JSON from response\"\n\n    except Exception as e:\n        return None, f\"API error: {str(e)}\"\n\n\ndef extract_wips_batch(\n    texts: Dict[str, str],\n    llm_client,\n    max_workers: int = 5,\n    desc: str = \"Extracting WIPs\"\n) -> Tuple[Dict[str, List[Dict]], Dict[str, str]]:\n    \"\"\"\n    Extract WIPs from multiple texts in parallel.\n\n    Args:\n        texts: Dict of {sample_id: text}\n        llm_client: LLM client instance (with built-in retry mechanism)\n        max_workers: Number of concurrent workers\n        desc: Progress bar description\n\n    Returns:\n        Tuple of (results, errors):\n        - results: Dict of {sample_id: wips_list}\n        - errors: Dict of {sample_id: error_message}\n    \"\"\"\n    results = {}\n    errors = {}\n\n    def process_single(sample_id: str, text: str):\n        wips, error = extract_wips_single(text, llm_client)\n        return sample_id, wips, error\n\n    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = {\n            executor.submit(process_single, sid, text): sid\n            for sid, text in texts.items()\n        }\n\n        for future in tqdm(as_completed(futures), total=len(futures), desc=desc):\n            sample_id, wips, error = future.result()\n            if wips:\n                
results[sample_id] = wips\n            if error:\n                errors[sample_id] = error\n\n    # Statistics: count valid (non-empty) extraction results\n    total_attempted = len(texts)\n    total_parsed = len(results)\n    valid_results = sum(1 for wips in results.values() if wips)  # Count non-empty lists\n\n    console.print(f\"[cyan]{desc} statistics: {total_attempted} attempted, {total_parsed} parsed, {valid_results} valid (non-empty)[/cyan]\")\n\n    return results, errors\n\n\ndef match_wips_single(\n    gt_wips: List[Dict],\n    model_wips: List[Dict],\n    llm_client\n) -> Tuple[Optional[Dict], Optional[str]]:\n    \"\"\"\n    Match WIPs between ground truth and model generation.\n\n    Args:\n        gt_wips: Ground truth WIPs list\n        model_wips: Model-generated WIPs list\n        llm_client: LLM client instance (with built-in retry mechanism)\n\n    Returns:\n        Tuple of (match_result, error_message)\n    \"\"\"\n    gt_str = json.dumps(gt_wips, ensure_ascii=False, indent=2)\n    model_str = json.dumps(model_wips, ensure_ascii=False, indent=2)\n    prompt = WIP_MATCHING_PROMPT.format(gt_str, model_str)\n\n    try:\n        response = llm_client.generate(prompt)\n        result = extract_json_from_response(response)\n\n        if result is not None and all(k in result for k in [\"matches\", \"unmatched_model_wips\", \"unmatched_gt_wips\"]):\n            return result, None\n\n        return None, \"Failed to parse match JSON from response\"\n\n    except Exception as e:\n        return None, f\"API error: {str(e)}\"\n\n\ndef match_wips_batch(\n    gt_wips_dict: Dict[str, List[Dict]],\n    model_wips_dict: Dict[str, List[Dict]],\n    llm_client,\n    max_workers: int = 5\n) -> Tuple[Dict[str, Dict], Dict[str, str]]:\n    \"\"\"\n    Match WIPs for multiple samples in parallel.\n\n    Args:\n        gt_wips_dict: Dict of {sample_id: gt_wips_list}\n        model_wips_dict: Dict of {sample_id: model_wips_list}\n        llm_client: LLM client instance (with built-in retry mechanism)\n        max_workers: Number of concurrent workers\n\n    Returns:\n        Tuple of (results, errors)\n    \"\"\"\n    results = {}\n    errors = {}\n\n    # Only match samples that have both GT and model WIPs (and both are non-empty)\n    common_ids = {\n        id for id in (set(gt_wips_dict.keys()) & set(model_wips_dict.keys()))\n        if gt_wips_dict[id] and model_wips_dict[id]\n    }\n\n    def process_single(sample_id: str):\n        gt_wips = gt_wips_dict[sample_id]\n        model_wips = model_wips_dict[sample_id]\n        match_result, error = match_wips_single(gt_wips, model_wips, llm_client)\n        return sample_id, match_result, error\n\n    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = {\n            executor.submit(process_single, sid): sid\n            for sid in common_ids\n        }\n\n        for future in tqdm(as_completed(futures), total=len(futures), desc=\"Matching WIPs\"):\n            sample_id, match_result, error = future.result()\n            if match_result is not None:\n                results[sample_id] = match_result\n            if error is not None:\n                errors[sample_id] = error\n\n    # Statistics: count valid (non-empty) match results\n    total_attempted = len(common_ids)\n    total_parsed = len(results)\n    valid_results = 0\n\n    for sample_id, match_result in results.items():\n        # Check if result is not empty (has at least one non-empty field)\n        if match_result:\n            matches = 
match_result.get(\"matches\", [])\n            unmatched_model = match_result.get(\"unmatched_model_wips\", [])\n            unmatched_gt = match_result.get(\"unmatched_gt_wips\", [])\n\n            # Consider valid if result has any content\n            if matches or unmatched_model or unmatched_gt:\n                valid_results += 1\n\n    console.print(f\"[cyan]Matching statistics: {total_attempted} attempted, {total_parsed} parsed, {valid_results} valid (non-empty)[/cyan]\")\n\n    return results, errors\n\n\ndef get_wip_score_int(wip: Optional[Dict]) -> int:\n    \"\"\"Get importance score from WIP, defaulting to 1.\"\"\"\n    if not wip:\n        return 1\n    return wip.get(\"importance_score\", 1)\n\n\ndef calculate_unweighted_metrics(match_results: Dict[str, Dict], core_threshold: int = 5) -> Dict[str, Any]:\n    \"\"\"\n    Calculate unweighted metrics (count-based) with macro and per-sample versions.\n\n    Args:\n        match_results: Dict of {sample_id: match_result}\n        core_threshold: Threshold for core WIPs (importance_score >= threshold)\n\n    Returns:\n        Dict with macro F1, core versions, and per-sample F1s (unweighted)\n    \"\"\"\n    if not match_results:\n        return {}\n\n    # Per-sample metrics (for macro calculation)\n    per_sample = {}\n\n    for sample_id, result in match_results.items():\n        if not result:\n            per_sample[sample_id] = {\"overall_f1\": 0.0, \"core_f1\": 0.0}\n            continue\n\n        # Sample-level counts\n        sample_tp = len(result.get(\"matches\", []))\n        sample_fp = len(result.get(\"unmatched_model_wips\", []))\n        sample_fn = len(result.get(\"unmatched_gt_wips\", []))\n\n        sample_core_tp = 0\n        sample_core_fp = 0\n        sample_core_fn = 0\n\n        # Core: count only WIPs with importance_score >= threshold\n        for match in result.get(\"matches\", []):\n            gt_wip = match.get(\"gt_wip\", {})\n            if get_wip_score_int(gt_wip) >= core_threshold:\n                sample_core_tp += 1\n\n        for fp_wip in result.get(\"unmatched_model_wips\", []):\n            if get_wip_score_int(fp_wip) >= core_threshold:\n                sample_core_fp += 1\n\n        for fn_wip in result.get(\"unmatched_gt_wips\", []):\n            if get_wip_score_int(fn_wip) >= core_threshold:\n                sample_core_fn += 1\n\n        # Calculate per-sample F1s\n        sample_overall_f1 = 2 * sample_tp / (2 * sample_tp + sample_fp + sample_fn) if (2 * sample_tp + sample_fp + sample_fn) > 0 else 0.0\n        sample_core_f1 = 2 * sample_core_tp / (2 * sample_core_tp + sample_core_fp + sample_core_fn) if (2 * sample_core_tp + sample_core_fp + sample_core_fn) > 0 else 0.0\n\n        per_sample[sample_id] = {\n            \"overall_f1\": sample_overall_f1,\n            \"core_f1\": sample_core_f1,\n        }\n\n    # Calculate macro F1 (average of per-sample F1s)\n    valid_samples = [v for v in per_sample.values() if v]\n    macro_f1 = sum(s[\"overall_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n    macro_core_f1 = sum(s[\"core_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n\n    return {\n        \"macro_wip_unweighted_f1\": macro_f1,\n        \"macro_wip_unweighted_core_f1\": macro_core_f1,\n        \"per_sample\": per_sample,\n    }\n\n\ndef calculate_importance_weighted_metrics(\n    match_results: Dict[str, Dict],\n    core_threshold: int = 5\n) -> Dict[str, Any]:\n    \"\"\"\n    Calculate importance-weighted 
metrics (weighted by importance_score only) with macro and per-sample versions.\n\n    Args:\n        match_results: Dict of {sample_id: match_result}\n        core_threshold: Threshold for core WIPs (importance_score >= threshold)\n\n    Returns:\n        Dict with macro F1, core versions, and per-sample F1s (importance-weighted)\n    \"\"\"\n    if not match_results:\n        return {}\n\n    # Per-sample metrics (for macro calculation)\n    per_sample = {}\n\n    for sample_id, result in match_results.items():\n        if not result:\n            per_sample[sample_id] = {\"overall_f1\": 0.0, \"core_f1\": 0.0}\n            continue\n\n        # Sample-level metrics\n        sample_tp, sample_fp, sample_fn = 0.0, 0.0, 0.0\n        sample_core_tp, sample_core_fp, sample_core_fn = 0.0, 0.0, 0.0\n\n        # TP from matches (use GT score)\n        for match in result.get(\"matches\", []):\n            gt_wip = match.get(\"gt_wip\")\n            gt_score = get_wip_score_int(gt_wip)\n            sample_tp += gt_score\n            if gt_score >= core_threshold:\n                sample_core_tp += gt_score\n\n        # FP from unmatched model WIPs\n        for fp_wip in result.get(\"unmatched_model_wips\", []):\n            fp_score = get_wip_score_int(fp_wip)\n            sample_fp += fp_score\n            if fp_score >= core_threshold:\n                sample_core_fp += fp_score\n\n        # FN from unmatched GT WIPs\n        for fn_wip in result.get(\"unmatched_gt_wips\", []):\n            fn_score = get_wip_score_int(fn_wip)\n            sample_fn += fn_score\n            if fn_score >= core_threshold:\n                sample_core_fn += fn_score\n\n        # Calculate per-sample F1s\n        sample_overall_f1 = 2 * sample_tp / (2 * sample_tp + sample_fp + sample_fn) if (2 * sample_tp + sample_fp + sample_fn) > 0 else 0.0\n        sample_core_f1 = 2 * sample_core_tp / (2 * sample_core_tp + sample_core_fp + sample_core_fn) if (2 * sample_core_tp + sample_core_fp + sample_core_fn) > 0 else 0.0\n\n        per_sample[sample_id] = {\n            \"overall_f1\": sample_overall_f1,\n            \"core_f1\": sample_core_f1,\n        }\n\n    # Calculate macro F1 (average of per-sample F1s)\n    valid_samples = [v for v in per_sample.values() if v]\n    macro_f1 = sum(s[\"overall_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n    macro_core_f1 = sum(s[\"core_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n\n    return {\n        \"macro_wip_importance_weighted_f1\": macro_f1,\n        \"macro_wip_importance_weighted_core_f1\": macro_core_f1,\n        \"per_sample\": per_sample,\n    }\n\n\ndef calculate_double_weighted_metrics(\n    match_results: Dict[str, Dict],\n    core_threshold: int = 5,\n) -> Dict[str, Any]:\n    \"\"\"\n    Calculate double-weighted metrics using V6.2 logic (importance_score × match_quality) with macro and per-sample versions.\n\n    NOTE: This function now uses pre-computed match_quality from match results (no BERTScore computation here).\n\n    V6.2 Logic:\n    - For matched pairs:\n        - TP = gt_score × match_quality\n        - FN = gt_score × (1 - match_quality)\n        - FP = model_score × (1 - match_quality)\n    - For unmatched GT WIPs: FN += gt_score (complete miss)\n    - For unmatched model WIPs: FP += model_score (complete hallucination)\n\n    Args:\n        match_results: Dict of {sample_id: match_result} (with pre-computed match_quality)\n        core_threshold: Threshold for core WIPs 
(importance_score >= threshold)\n\n    Returns:\n        Dict with macro F1, core versions, and per-sample F1s (double-weighted)\n    \"\"\"\n    if not match_results:\n        return {}\n\n    # Per-sample metrics (for macro calculation)\n    per_sample = {}\n\n    for sample_id, result in match_results.items():\n        if not result:\n            per_sample[sample_id] = {\"overall_f1\": 0.0, \"core_f1\": 0.0}\n            continue\n\n        # Sample-level metrics\n        sample_tp, sample_fp, sample_fn = 0.0, 0.0, 0.0\n        sample_core_tp, sample_core_fp, sample_core_fn = 0.0, 0.0, 0.0\n\n        # Process matched pairs using pre-computed match_quality\n        matches = result.get(\"matches\", [])\n        for match in matches:\n            gt_wip = match.get(\"gt_wip\", {})\n            model_wip = match.get(\"model_wip\", {})\n            match_quality = match.get(\"match_quality\")\n\n            # Skip if match_quality not computed\n            if match_quality is None:\n                continue\n\n            gt_score = get_wip_score_int(gt_wip)\n            model_score = get_wip_score_int(model_wip)\n\n            # V6.2 formulas for all WIPs\n            tp_contrib = gt_score * match_quality\n            fn_contrib = gt_score * (1 - match_quality)\n            fp_contrib = model_score * (1 - match_quality)\n\n            sample_tp += tp_contrib\n            sample_fn += fn_contrib\n            sample_fp += fp_contrib\n\n            # Core: V6.2 formulas only for WIPs with importance_score >= threshold\n            if gt_score >= core_threshold:\n                sample_core_tp += tp_contrib\n                sample_core_fn += fn_contrib\n            if model_score >= core_threshold:\n                sample_core_fp += fp_contrib\n\n        # Complete misses (unmatched GT WIPs)\n        for fn_wip in result.get(\"unmatched_gt_wips\", []):\n            fn_score = get_wip_score_int(fn_wip)\n            sample_fn += fn_score\n            if fn_score >= core_threshold:\n                sample_core_fn += fn_score\n\n        # Complete hallucinations (unmatched model WIPs)\n        for fp_wip in result.get(\"unmatched_model_wips\", []):\n            fp_score = get_wip_score_int(fp_wip)\n            sample_fp += fp_score\n            if fp_score >= core_threshold:\n                sample_core_fp += fp_score\n\n        # Calculate per-sample F1s\n        sample_overall_f1 = 2 * sample_tp / (2 * sample_tp + sample_fp + sample_fn) if (2 * sample_tp + sample_fp + sample_fn) > 0 else 0.0\n        sample_core_f1 = 2 * sample_core_tp / (2 * sample_core_tp + sample_core_fp + sample_core_fn) if (2 * sample_core_tp + sample_core_fp + sample_core_fn) > 0 else 0.0\n\n        per_sample[sample_id] = {\n            \"overall_f1\": sample_overall_f1,\n            \"core_f1\": sample_core_f1,\n        }\n\n    # Calculate macro F1 (average of per-sample F1s)\n    valid_samples = [v for v in per_sample.values() if v]\n    macro_f1 = sum(s[\"overall_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n    macro_core_f1 = sum(s[\"core_f1\"] for s in valid_samples) / len(valid_samples) if valid_samples else 0.0\n\n    return {\n        \"macro_wip_double_weighted_f1\": macro_f1,\n        \"macro_wip_double_weighted_core_f1\": macro_core_f1,\n        \"per_sample\": per_sample,\n    }\n\n\ndef save_wip_detailed_results(\n    save_dir: str,\n    sample_ids: List[str],\n    gt_wips_dict: Dict[str, List[Dict]],\n    model_wips_dict: Dict[str, List[Dict]],\n    match_results_dict: 
Dict[str, Dict],\n    per_sample_f1s_dict: Dict[str, Dict],\n    predictions_dict: Dict[str, str],\n    references_dict: Dict[str, str],\n    filename: str = \"wip_results.json\"\n):\n    \"\"\"\n    Save detailed WIP evaluation results to file.\n\n    This saves:\n    - prediction: Model prediction text for each sample\n    - reference: Ground truth reference text for each sample\n    - gt_wips: Ground truth information points for each sample\n    - model_wips: Model-generated information points for each sample\n    - match_result: Matching results (matches, unmatched_gt_wips, unmatched_model_wips)\n    - per_sample_f1s: 6 F1 scores for each sample (3 types × 2 versions)\n\n    Args:\n        save_dir: Directory to save the file\n        sample_ids: List of sample IDs\n        gt_wips_dict: Dict of {sample_id: gt_wips_list}\n        model_wips_dict: Dict of {sample_id: model_wips_list}\n        match_results_dict: Dict of {sample_id: match_result}\n        per_sample_f1s_dict: Dict of {sample_id: f1_scores_dict} with 6 F1 scores\n        predictions_dict: Dict of {sample_id: prediction_text}\n        references_dict: Dict of {sample_id: reference_text}\n        filename: Output filename\n    \"\"\"\n    os.makedirs(save_dir, exist_ok=True)\n    save_path = os.path.join(save_dir, filename)\n\n    detailed_results = {}\n\n    for sample_id in sample_ids:\n        detailed_results[sample_id] = {\n            \"prediction\": predictions_dict.get(sample_id, \"\"),\n            \"reference\": references_dict.get(sample_id, \"\"),\n            \"gt_wips\": gt_wips_dict.get(sample_id, []),\n            \"model_wips\": model_wips_dict.get(sample_id, []),\n            \"match_result\": match_results_dict.get(sample_id, {}),\n            \"f1_scores\": per_sample_f1s_dict.get(sample_id, {}),\n        }\n\n    with open(save_path, 'w', encoding='utf-8') as f:\n        json.dump(detailed_results, f, ensure_ascii=False, indent=2)\n\n    console.print(f\"[green]WIP detailed results saved to {save_path}[/green]\")\n\n\ndef get_gt_cache_path(cache_dir: str, model_name: str) -> str:\n    \"\"\"Get the path for GT WIPs cache file.\"\"\"\n    return os.path.join(cache_dir, f\"test_gt_wip_{model_name}.parquet\")\n\n\ndef load_wip_results_cache(cache_path: str) -> Optional[Dict[str, Any]]:\n    \"\"\"\n    Load previously saved WIP results (model_wips and match_results).\n\n    Args:\n        cache_path: Path to wip_{model_name}.json file\n\n    Returns:\n        Dict with \"model_wips\" and \"match_results\" or None if not found\n    \"\"\"\n    if not os.path.exists(cache_path):\n        return None\n\n    try:\n        with open(cache_path, 'r', encoding='utf-8') as f:\n            data = json.load(f)\n\n        # Extract model_wips and match_results from saved file\n        model_wips = {}\n        match_results = {}\n\n        for sample_id, sample_data in data.items():\n            # Load model_wips only if it is a non-empty list\n            if \"model_wips\" in sample_data and sample_data[\"model_wips\"]:\n                model_wips[sample_id] = sample_data[\"model_wips\"]\n            # Load match_result only if it's not empty (empty dict means no matching was done)\n            if \"match_result\" in sample_data and sample_data[\"match_result\"]:\n                match_results[sample_id] = sample_data[\"match_result\"]\n\n        return {\n            \"model_wips\": model_wips,\n            \"match_results\": match_results,\n        }\n\n    except Exception as e:\n        console.print(f\"[yellow]Failed 
to load WIP results cache: {e}[/yellow]\")\n        return None\n\n\ndef load_gt_wips_cache(cache_path: str) -> Optional[Dict[str, List[Dict]]]:\n    \"\"\"\n    Load GT WIPs from cache file.\n\n    Args:\n        cache_path: Path to parquet cache file\n\n    Returns:\n        Dict of {sample_id: wips_list} or None if not found\n    \"\"\"\n    if not os.path.exists(cache_path):\n        return None\n\n    try:\n        df = pd.read_parquet(cache_path)\n        result = {}\n        for _, row in df.iterrows():\n            sample_id = str(row[\"sample_id\"])\n            wips = row[\"wips\"]\n            if isinstance(wips, str):\n                wips = json.loads(wips)\n            if wips:\n                result[sample_id] = wips\n        console.print(f\"[green]Loaded GT WIPs cache from {cache_path} ({len(result)} samples)[/green]\")\n        return result\n    except Exception as e:\n        console.print(f\"[yellow]Failed to load GT cache: {e}[/yellow]\")\n        return None\n\n\ndef save_gt_wips_cache(gt_wips: Dict[str, List[Dict]], cache_path: str):\n    \"\"\"\n    Save GT WIPs to cache file.\n\n    Args:\n        gt_wips: Dict of {sample_id: wips_list}\n        cache_path: Path to save parquet file\n    \"\"\"\n    os.makedirs(os.path.dirname(cache_path), exist_ok=True)\n\n    data = []\n    for sample_id, wips in gt_wips.items():\n        data.append({\n            \"sample_id\": sample_id,\n            \"wips\": json.dumps(wips, ensure_ascii=False)\n        })\n\n    df = pd.DataFrame(data)\n    df.to_parquet(cache_path, index=False)\n    console.print(f\"[green]Saved GT WIPs cache to {cache_path} ({len(gt_wips)} samples)[/green]\")\n\n\ndef _load_or_extract_gt_wips(\n    sample_ids: List[str],\n    references: Dict[str, str],\n    llm_client,\n    max_workers: int,\n    gt_cache_dir: Optional[str],\n    model_name: str,\n) -> Dict[str, List[Dict]]:\n    \"\"\"Load GT WIPs from cache or extract if missing.\"\"\"\n    gt_wips = None\n    if gt_cache_dir:\n        cache_path = get_gt_cache_path(gt_cache_dir, model_name)\n        full_gt_cache = load_gt_wips_cache(cache_path)\n\n        if full_gt_cache is not None:\n            gt_wips = {id: full_gt_cache[id] for id in sample_ids if id in full_gt_cache}\n\n            missing_ids = set(sample_ids) - set(gt_wips.keys())\n            if missing_ids:\n                console.print(f\"[yellow]Missing {len(missing_ids)} samples in GT cache, extracting...[/yellow]\")\n                missing_refs = {id: references[id] for id in missing_ids}\n                new_gt_wips, gt_errors = extract_wips_batch(\n                    missing_refs, llm_client, max_workers, \"Extracting GT WIPs\"\n                )\n                gt_wips.update(new_gt_wips)\n                full_gt_cache.update(new_gt_wips)\n                save_gt_wips_cache(full_gt_cache, cache_path)\n\n                if gt_errors:\n                    console.print(f\"[red]GT extraction errors: {len(gt_errors)} samples[/red]\")\n\n    if gt_wips is None:\n        console.print(\"[cyan]Extracting GT WIPs...[/cyan]\")\n        gt_wips, gt_errors = extract_wips_batch(\n            references, llm_client, max_workers, \"Extracting GT WIPs\"\n        )\n\n        if gt_cache_dir:\n            cache_path = get_gt_cache_path(gt_cache_dir, model_name)\n            save_gt_wips_cache(gt_wips, cache_path)\n\n        if gt_errors:\n            console.print(f\"[red]GT extraction errors: {len(gt_errors)} samples[/red]\")\n\n    return gt_wips\n\n# Extract text after last </think> tag if 
present\ndef extract_after_think(text: str) -> str:\n    \"\"\"Extract text after the last </think> tag\"\"\"\n    if '</think>' in text:\n        return text.split('</think>')[-1].strip()\n    return text\n\ndef _load_or_extract_model_wips(\n    sample_ids: List[str],\n    predictions: Dict[str, str],\n    gt_wips: Dict[str, List[Dict]],\n    llm_client,\n    max_workers: int,\n    save_dir: Optional[str],\n    model_name: str,\n) -> Dict[str, List[Dict]]:\n    \"\"\"Load Model WIPs from cache or extract if missing (incremental).\"\"\"\n    model_wips = {}\n\n    if save_dir:\n        wip_cache_path = os.path.join(save_dir, f\"wip_{model_name}.json\")\n        cached_data = load_wip_results_cache(wip_cache_path)\n        if cached_data:\n            model_wips = cached_data.get(\"model_wips\", {})\n            console.print(f\"[green]Loaded {len(model_wips)} cached model_wips[/green]\")\n\n    missing_model_ids = set(sample_ids) - set(model_wips.keys())\n    missing_model_ids = {id for id in missing_model_ids if id in gt_wips}\n\n    if missing_model_ids:\n        console.print(f\"[cyan]Extracting Model WIPs for {len(missing_model_ids)} missing samples...[/cyan]\")\n        missing_predictions = {id: extract_after_think(predictions[id]) for id in missing_model_ids}\n        new_model_wips, model_errors = extract_wips_batch(\n            missing_predictions, llm_client, max_workers, \"Extracting Model WIPs\"\n        )\n        model_wips.update(new_model_wips)\n\n        if model_errors:\n            console.print(f\"[red]Model extraction errors: {len(model_errors)} samples[/red]\")\n    else:\n        console.print(f\"[green]All {len(sample_ids)} samples already have Model WIPs (from cache)[/green]\")\n\n    return model_wips\n\n\ndef _load_or_match_wips(\n    sample_ids: List[str],\n    gt_wips: Dict[str, List[Dict]],\n    model_wips: Dict[str, List[Dict]],\n    llm_client,\n    max_workers: int,\n    save_dir: Optional[str],\n    model_name: str,\n) -> Dict[str, Dict]:\n    \"\"\"Load match results from cache or match if missing (incremental).\"\"\"\n    match_results = {}\n\n    if save_dir:\n        wip_cache_path = os.path.join(save_dir, f\"wip_{model_name}.json\")\n        cached_data = load_wip_results_cache(wip_cache_path)\n        if cached_data:\n            match_results = cached_data.get(\"match_results\", {})\n            console.print(f\"[green]Loaded {len(match_results)} cached match_results[/green]\")\n\n    missing_match_ids = set(sample_ids) - set(match_results.keys())\n    # Only match if both gt_wips and model_wips exist and are non-empty\n    missing_match_ids = {\n        id for id in missing_match_ids\n        if id in gt_wips and id in model_wips and gt_wips[id] and model_wips[id]\n    }\n\n    if missing_match_ids:\n        console.print(f\"[cyan]Matching WIPs for {len(missing_match_ids)} missing samples...[/cyan]\")\n        missing_gt_wips = {id: gt_wips[id] for id in missing_match_ids}\n        missing_model_wips = {id: model_wips[id] for id in missing_match_ids}\n        new_match_results, match_errors = match_wips_batch(\n            missing_gt_wips, missing_model_wips, llm_client, max_workers\n        )\n        match_results.update(new_match_results)\n\n        if match_errors:\n            console.print(f\"[red]Matching errors: {len(match_errors)} samples[/red]\")\n    else:\n        console.print(f\"[green]All {len(sample_ids)} samples already have match results (from cache)[/green]\")\n\n    return match_results\n\n\ndef 
_compute_bertscore_incremental(\n    sample_ids: List[str],\n    match_results: Dict[str, Dict],\n    bertscore_model: str,\n    bertscore_num_layers: int,\n) -> None:\n    \"\"\"Compute BERTScore for matches that don't have it yet (incremental, in-place update).\"\"\"\n    import torch\n    from bert_score import BERTScorer\n\n    console.print(\"[cyan]Computing BERTScore for matched pairs...[/cyan]\")\n\n    all_gt_texts = []\n    all_model_texts = []\n    sample_match_indices = []\n\n    for sample_id in sample_ids:\n        if sample_id in match_results:\n            matches = match_results[sample_id].get(\"matches\", [])\n            for match_idx, match in enumerate(matches):\n                if match.get(\"match_quality\") is not None:\n                    continue\n\n                gt_wip = match.get(\"gt_wip\")\n                model_wip = match.get(\"model_wip\")\n\n                # Skip if either wip is None or not a dict\n                if not gt_wip or not isinstance(gt_wip, dict) or not model_wip or not isinstance(model_wip, dict):\n                    continue\n\n                gt_text = gt_wip.get(\"info_point\", \"\")\n                model_text = model_wip.get(\"info_point\", \"\")\n\n                if gt_text and model_text:\n                    batch_idx = len(all_gt_texts)\n                    all_gt_texts.append(gt_text)\n                    all_model_texts.append(model_text)\n                    sample_match_indices.append((sample_id, match_idx, batch_idx))\n\n    if all_gt_texts and all_model_texts:\n        console.print(f\"[cyan]Computing BERTScore for {len(all_gt_texts)} new matched pairs...[/cyan]\")\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        scorer = BERTScorer(\n            model_type=bertscore_model,\n            num_layers=bertscore_num_layers,\n            device=device,\n            lang=\"zh\",\n            rescale_with_baseline=False,\n        )\n\n        try:\n            P, R, F1 = scorer.score(all_model_texts, all_gt_texts)\n            match_qualities = F1.tolist()\n\n            for sample_id, match_idx, batch_idx in sample_match_indices:\n                match_results[sample_id][\"matches\"][match_idx][\"match_quality\"] = match_qualities[batch_idx]\n\n            console.print(f\"[green]Computed BERTScore for {len(match_qualities)} matched pairs[/green]\")\n        except Exception as e:\n            console.print(f\"[red]BERTScore computation failed: {e}[/red]\")\n            for sample_id, match_idx, _ in sample_match_indices:\n                match_results[sample_id][\"matches\"][match_idx][\"match_quality\"] = None\n    else:\n        console.print(f\"[green]All matches already have BERTScore (from cache)[/green]\")\n\n\ndef evaluate_wip(\n    predictions: Dict[str, str],\n    references: Dict[str, str],\n    llm_client,\n    max_workers: int = 5,\n    max_samples: Optional[int] = None,\n    gt_cache_dir: Optional[str] = None,\n    model_name: str = \"unknown\",\n    save_dir: Optional[str] = None,\n    bertscore_model: str = \"bert-base-chinese\",\n    bertscore_num_layers: int = 9,\n    core_threshold: int = 5,\n) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n    \"\"\"\n    Main WIP evaluation function with six types of F1 metrics (3 types × 2 versions).\n\n    Computes:\n    1. Unweighted F1 (count-based) - overall and core\n    2. Importance-weighted F1 (weighted by importance_score only) - overall and core\n    3. 
Double-weighted F1 (V6.2 logic: importance_score × match_quality) - overall and core\n\n    Args:\n        predictions: Dict of {sample_id: prediction_text}\n        references: Dict of {sample_id: reference_text}\n        llm_client: LLM client instance for judge (with built-in retry mechanism)\n        max_workers: Number of concurrent workers\n        max_samples: Maximum samples to evaluate (None for all)\n        gt_cache_dir: Directory for GT WIPs cache\n        model_name: Model name for cache file naming\n        save_dir: Directory to save results\n        bertscore_model: BERT model for computing match quality\n        bertscore_num_layers: Number of layers for BERTScore\n        core_threshold: Threshold for core WIPs (importance_score >= threshold)\n\n    Returns:\n        Tuple of (metrics, per_sample_metrics):\n        - metrics: Dict with 6 F1 scores and sample count (flattened)\n        - per_sample_metrics: Dict of {sample_id: {6 F1 scores}}\n    \"\"\"\n    # Select samples (sorted by sample_id for consistency)\n    all_sample_ids = sorted(set(predictions.keys()) & set(references.keys()))\n\n    if max_samples is not None and max_samples < len(all_sample_ids):\n        sample_ids = all_sample_ids[:max_samples]\n        console.print(f\"[cyan]Selected {len(sample_ids)} samples for WIP evaluation (sorted by sample_id)[/cyan]\")\n    else:\n        sample_ids = all_sample_ids\n        console.print(f\"[cyan]Evaluating all {len(sample_ids)} samples for WIP[/cyan]\")\n\n    selected_predictions = {id: predictions[id] for id in sample_ids}\n    selected_references = {id: references[id] for id in sample_ids}\n\n    # Step 1: Load/Extract GT WIPs\n    gt_wips = _load_or_extract_gt_wips(\n        sample_ids, selected_references, llm_client, max_workers, gt_cache_dir, model_name\n    )\n\n    # Step 2: Load/Extract Model WIPs (incremental)\n    model_wips = _load_or_extract_model_wips(\n        sample_ids, selected_predictions, gt_wips, llm_client, max_workers, save_dir, model_name\n    )\n\n    # Step 3: Load/Match WIPs (incremental)\n    match_results = _load_or_match_wips(\n        sample_ids, gt_wips, model_wips, llm_client, max_workers, save_dir, model_name\n    )\n\n    # Step 3.5: Compute BERTScore (incremental)\n    _compute_bertscore_incremental(sample_ids, match_results, bertscore_model, bertscore_num_layers)\n\n    # Step 4: Calculate three types of metrics (each with overall, core, and per-sample)\n    console.print(\"[cyan]Computing all metrics (unweighted, importance-weighted, double-weighted)...[/cyan]\")\n\n    # 4.1: Unweighted metrics\n    unweighted_metrics = calculate_unweighted_metrics(\n        match_results,\n        core_threshold=core_threshold\n    )\n\n    # 4.2: Importance-weighted metrics\n    importance_metrics = calculate_importance_weighted_metrics(\n        match_results,\n        core_threshold=core_threshold\n    )\n\n    # 4.3: Double-weighted metrics (using pre-computed BERTScore)\n    double_metrics = calculate_double_weighted_metrics(\n        match_results,\n        core_threshold=core_threshold\n    )\n\n    # Flattened overall metrics (6 F1 scores: 3 types × 2 versions)\n    metrics = {\n        # Macro F1 (average of per-sample F1s)\n        \"macro_wip_unweighted_f1\": unweighted_metrics.get(\"macro_wip_unweighted_f1\", 0.0),\n        \"macro_wip_unweighted_core_f1\": unweighted_metrics.get(\"macro_wip_unweighted_core_f1\", 0.0),\n        \"macro_wip_importance_weighted_f1\": importance_metrics.get(\"macro_wip_importance_weighted_f1\", 
0.0),\n        \"macro_wip_importance_weighted_core_f1\": importance_metrics.get(\"macro_wip_importance_weighted_core_f1\", 0.0),\n        \"macro_wip_double_weighted_f1\": double_metrics.get(\"macro_wip_double_weighted_f1\", 0.0),\n        \"macro_wip_double_weighted_core_f1\": double_metrics.get(\"macro_wip_double_weighted_core_f1\", 0.0),\n        \"wip_num_samples\": len(match_results),\n    }\n\n    # Merge per-sample metrics from all three types (6 F1 scores per sample, same as before since per-sample is used for macro)\n    per_sample_metrics = {}\n    for sample_id in sample_ids:\n        unweighted_per_sample = unweighted_metrics.get(\"per_sample\", {}).get(sample_id, {\"overall_f1\": 0.0, \"core_f1\": 0.0})\n        importance_per_sample = importance_metrics.get(\"per_sample\", {}).get(sample_id, {\"overall_f1\": 0.0, \"core_f1\": 0.0})\n        double_per_sample = double_metrics.get(\"per_sample\", {}).get(sample_id, {\"overall_f1\": 0.0, \"core_f1\": 0.0})\n\n        per_sample_metrics[sample_id] = {\n            \"wip_unweighted_f1\": unweighted_per_sample[\"overall_f1\"],\n            \"wip_unweighted_core_f1\": unweighted_per_sample[\"core_f1\"],\n            \"wip_importance_weighted_f1\": importance_per_sample[\"overall_f1\"],\n            \"wip_importance_weighted_core_f1\": importance_per_sample[\"core_f1\"],\n            \"wip_double_weighted_f1\": double_per_sample[\"overall_f1\"],\n            \"wip_double_weighted_core_f1\": double_per_sample[\"core_f1\"],\n        }\n\n    # Step 5: Save detailed results to file\n    if save_dir:\n        console.print(\"[cyan]Saving WIP detailed results...[/cyan]\")\n        save_wip_detailed_results(\n            save_dir=save_dir,\n            sample_ids=sample_ids,\n            gt_wips_dict=gt_wips,\n            model_wips_dict=model_wips,\n            match_results_dict=match_results,\n            per_sample_f1s_dict=per_sample_metrics,\n            predictions_dict=selected_predictions,\n            references_dict=selected_references,\n            filename=f\"wip_{model_name}.json\"\n        )\n\n    console.print(f\"[green]WIP evaluation completed: {metrics['wip_num_samples']} samples[/green]\")\n    console.print(f\"[green]  Macro Unweighted F1: {metrics['macro_wip_unweighted_f1']:.4f} (Core: {metrics['macro_wip_unweighted_core_f1']:.4f})[/green]\")\n    console.print(f\"[green]  Macro Importance-weighted F1: {metrics['macro_wip_importance_weighted_f1']:.4f} (Core: {metrics['macro_wip_importance_weighted_core_f1']:.4f})[/green]\")\n    console.print(f\"[green]  Macro Double-weighted F1: {metrics['macro_wip_double_weighted_f1']:.4f} (Core: {metrics['macro_wip_double_weighted_core_f1']:.4f})[/green]\")\n\n    return metrics, per_sample_metrics\n"
  },
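  {
    "path": "benchmarks/examples/wip_macro_f1_sketch.py",
    "content": "\"\"\"\nEditor's illustrative sketch -- not part of the original benchmark code.\n\nThe WIP evaluation reports six macro F1 scores ({unweighted, importance-weighted,\ndouble-weighted} x {overall, core}), each defined as the mean of the corresponding\nper-sample F1. The per-sample F1s come from the calculate_*_metrics helpers in the\nreal pipeline; the toy values below are made up purely to show the aggregation step.\n\"\"\"\n\n# Hand-written per-sample metrics in the same shape the WIP evaluation builds\nper_sample_metrics = {\n    \"s1\": {\"wip_unweighted_f1\": 0.50, \"wip_unweighted_core_f1\": 0.40},\n    \"s2\": {\"wip_unweighted_f1\": 0.80, \"wip_unweighted_core_f1\": 1.00},\n}\n\n\ndef macro(metric_name: str) -> float:\n    \"\"\"Macro aggregation: a simple mean of one per-sample metric.\"\"\"\n    values = [m[metric_name] for m in per_sample_metrics.values()]\n    return sum(values) / len(values) if values else 0.0\n\n\nif __name__ == \"__main__\":\n    print(f\"macro_wip_unweighted_f1      = {macro('wip_unweighted_f1'):.4f}\")       # 0.6500\n    print(f\"macro_wip_unweighted_core_f1 = {macro('wip_unweighted_core_f1'):.4f}\")  # 0.7000\n"
  },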
  {
    "path": "benchmarks/benchmark/tasks/v1_0/label_pred/__init__.py",
    "content": "\"\"\"\nLabel Prediction Task Module\n\nClassification task for predicting user engagement with video content.\nUses logprobs-based classification with AUC and wuAUC metrics.\n\"\"\"\n\nfrom .config import LABEL_PRED_CONFIG\nfrom .evaluator import LabelPredEvaluator\nfrom . import utils\n\n__all__ = [\n    \"LABEL_PRED_CONFIG\",\n    \"LabelPredEvaluator\",\n    \"utils\",\n]\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/label_pred/config.py",
    "content": "\"\"\"\nLabel Prediction Task Configuration\n\nThis is a classification task for predicting user engagement with video content.\nUses logprobs-based classification with AUC and wuAUC metrics.\n\"\"\"\n\n# Label Pred Task Configuration\nLABEL_PRED_CONFIG = {\n    \"name\": \"label_pred\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 346190,\n    \"sample_size\": 346190,\n    \"description\": \"Predict user engagement with video content (yes/no classification)\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": {\n        \"enable_thinking\": False,  # Enable thinking mode for apply_chat_template\n        \"custom_chat_template\": \"qwen3_soft_switch.jinja2\",  # Custom jinja2 template (file in v1_0 directory)\n    },\n    \"generation_config\": {\n        \"max_new_tokens\": 1,\n        \"temperature\": 1,\n        \"top_p\": 1,\n        \"top_k\": -1,\n        \"do_sample\": True,\n        \"num_return_sequences\": 1,\n        \"return_logprobs\": True,  # Need to return logprobs for probability extraction\n        \"logprobs\": 10000,  # Return top-10 logprobs to ensure \"是\" and \"否\" are included\n        \"target_tokens\": [\"是\", \"否\"],  # Target tokens for logprobs extraction (classification)\n        \"max_new_thinking_tokens\": 1000\n    },\n    \"evaluation_config\": {\n        \"metrics\": [\"auc\"],\n    },\n    \"task_type\": \"logprobs_classification\",  # Special task type\n}\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/label_pred/evaluator.py",
    "content": "\"\"\"\nLabel Prediction Task Evaluator\n\nEvaluator for label_pred classification task.\nComputes AUC metric from logprobs-based predictions.\n\"\"\"\n\nfrom typing import Dict, Any, Tuple, List\n\nfrom benchmark.console import console\nfrom benchmark.tasks.v1_0.base_evaluator import BaseEval\nfrom benchmark.tasks.v1_0.label_pred.utils import (\n    extract_label_from_answer,\n    extract_probability_from_logprobs,\n    calculate_auc,\n    get_debug_info,\n)\n\n\nclass LabelPredEvaluator(BaseEval):\n    \"\"\"\n    Label prediction task evaluator\n\n    This is a classification task for predicting user engagement.\n    Uses logprobs-based predictions to compute AUC metric.\n\n    Metrics:\n    - AUC: Area Under ROC Curve\n    \"\"\"\n\n    @property\n    def required_metrics(self) -> List[str]:\n        \"\"\"Define required overall metrics for label prediction evaluation\"\"\"\n        return [\"auc\"]\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Compute all evaluation metrics from scratch\n\n        Extracts probabilities from logprobs and computes AUC metric.\n        Also stores per-sample metrics back into self.samples for caching.\n\n        Returns:\n            Tuple of (metrics, per_sample_metrics):\n            - metrics: Overall metrics including auc, etc.\n            - per_sample_metrics: Per-sample evaluation results\n        \"\"\"\n        total_samples = len(self.samples)\n\n        # Extract predictions and labels\n        predictions = {}  # {sample_id: probability}\n        labels = {}  # {sample_id: 0 or 1}\n        \n        # Per-sample metrics\n        per_sample_metrics = {}\n        \n        # Debug information collection\n        debug_info = {\n            \"correct_predictions\": [],\n            \"incorrect_predictions\": [],\n            \"invalid_samples\": [],\n        }\n        \n        for sample_id, sample in self.samples.items():\n            # Get ground truth answer\n            ground_truth = sample.get(\"ground_truth\", \"\")\n            \n            # Extract label from ground truth\n            label = extract_label_from_answer(ground_truth)\n            \n            if label == -1:\n                # Invalid label\n                console.print(f\"[yellow]Sample {sample_id}: unrecognized answer '{ground_truth}'[/yellow]\")\n                if self.debug:\n                    debug_info[\"invalid_samples\"].append({\n                        \"sample_id\": sample_id,\n                        \"ground_truth\": ground_truth,\n                        \"reason\": \"unrecognized_label\"\n                    })\n                continue\n            \n            labels[sample_id] = label\n\n            # Get model prediction (logprobs dictionary)\n            # For label_pred, generations contains {token: probability} dict\n            generations = sample.get(\"generations\", {})\n\n            # Variables to store probability extraction results\n            predicted_prob = 0.5\n            parsed_probs = None\n            normalized_probs = None\n\n            if not generations:\n                # No generation - log as invalid sample\n                console.print(f\"[yellow]Sample {sample_id}: no generation found[/yellow]\")\n                if self.debug:\n                    debug_info[\"invalid_samples\"].append({\n                        \"sample_id\": sample_id,\n                        \"ground_truth\": ground_truth,\n                       
 \"reason\": \"no_generation\"\n                    })\n                # Skip this sample - don't include in predictions\n                continue\n            else:\n                try:\n                    # Extract probability for positive class (\"是\")\n                    # Now returns dict with parsed_probs, normalized_probs, and score\n                    prob_result = extract_probability_from_logprobs(\n                        generations,\n                        positive_token=\"是\",\n                        negative_token=\"否\",\n                        sample_id=sample_id\n                    )\n\n                    predicted_prob = prob_result[\"score\"]\n                    parsed_probs = prob_result[\"parsed_probs\"]\n                    normalized_probs = prob_result[\"normalized_probs\"]\n\n                except ValueError as e:\n                    # Parsing failed - log detailed error and skip sample\n                    console.print(f\"[red]Sample {sample_id}: {str(e)}[/red]\")\n                    if self.debug:\n                        debug_info[\"invalid_samples\"].append({\n                            \"sample_id\": sample_id,\n                            \"ground_truth\": ground_truth,\n                            \"reason\": \"parsing_error\",\n                            \"error\": str(e)\n                        })\n                    # Skip this sample - don't include in predictions\n                    continue\n\n            predictions[sample_id] = predicted_prob\n\n            # Store per-sample metrics (both in return dict and in self.samples for caching)\n            sample_metrics = {\n                \"label\": label,\n                \"predicted_prob\": predicted_prob,\n            }\n            per_sample_metrics[sample_id] = sample_metrics\n\n            # Cache metrics in self.samples for future use, including debug info\n            self.samples[sample_id][\"label\"] = label\n            self.samples[sample_id][\"predicted_prob\"] = predicted_prob\n\n            # Add new debug fields to sample for tracking\n            self.samples[sample_id][\"y_true\"] = label\n            self.samples[sample_id][\"y_score\"] = predicted_prob\n            if parsed_probs is not None:\n                self.samples[sample_id][\"parsed_probs\"] = parsed_probs\n            if normalized_probs is not None:\n                self.samples[sample_id][\"normalized_probs\"] = normalized_probs\n\n            # Debug information collection\n            if self.debug:\n                debug_item = get_debug_info(\n                    sample_id=sample_id,\n                    logprobs_dict=parsed_probs,\n                    predicted_prob=predicted_prob,\n                    ground_truth=ground_truth,\n                    label=label,\n                )\n                # Determine if prediction is correct\n                # Correct: (predicted_prob > 0.5 and label = 1) OR (predicted_prob <= 0.5 and label = 0)\n                is_correct = (predicted_prob > 0.5 and label == 1) or (predicted_prob <= 0.5 and label == 0)\n\n                if is_correct:\n                    debug_info[\"correct_predictions\"].append(debug_item)\n                else:\n                    debug_info[\"incorrect_predictions\"].append(debug_item)\n        \n        # Calculate AUC\n        auc = calculate_auc(predictions, labels)\n\n        # Prepare overall metrics\n        metrics = {\n            \"auc\": auc,\n            \"total_samples\": total_samples,\n            \"valid_samples\": 
len(labels),\n            \"invalid_samples\": len(debug_info[\"invalid_samples\"]) if self.debug else 0,\n        }\n\n        # Save debug information if requested\n        if self.debug and self.predictions_dir:\n            self._save_debug_info(debug_info, metrics)\n\n        return metrics, per_sample_metrics\n\n    def _save_debug_info(\n        self,\n        debug_info: Dict[str, Any],\n        metrics: Dict[str, Any],\n    ):\n        \"\"\"\n        Save detailed debug information to file\n\n        Args:\n            debug_info: Debug information dictionary\n            metrics: Overall metrics\n        \"\"\"\n        # Add statistics to debug_info\n        debug_info[\"statistics\"] = {\n            \"total_samples\": metrics[\"total_samples\"],\n            \"valid_samples\": metrics[\"valid_samples\"],\n            \"correct_predictions_count\": len(debug_info[\"correct_predictions\"]),\n            \"incorrect_predictions_count\": len(debug_info[\"incorrect_predictions\"]),\n            \"invalid_samples_count\": len(debug_info[\"invalid_samples\"]),\n        }\n\n        # Add metrics\n        debug_info[\"metrics\"] = metrics\n\n        # Save debug info to file using base class method\n        self._save_debug_json(debug_info, filename=\"debug.json\")\n\n        console.print(f\"Total samples: {metrics['total_samples']}\")\n        console.print(f\"Valid samples: {metrics['valid_samples']}\")\n        console.print(f\"Correct predictions: {len(debug_info['correct_predictions'])}\")\n        console.print(f\"Incorrect predictions: {len(debug_info['incorrect_predictions'])}\")\n\n        # Calculate and display accuracy if we have valid predictions\n        total_predictions = len(debug_info['correct_predictions']) + len(debug_info['incorrect_predictions'])\n        if total_predictions > 0:\n            accuracy = len(debug_info['correct_predictions']) / total_predictions * 100\n            console.print(f\"Accuracy: {accuracy:.2f}%\")\n\n        console.print(f\"Invalid samples: {len(debug_info['invalid_samples'])}\")\n\n        # Print metrics\n        console.print(\"\\n[bold]Metrics:[/bold]\")\n        console.print(f\"  AUC: {metrics['auc']:.4f}\")\n\n        # Show some invalid sample examples\n        if debug_info[\"invalid_samples\"]:\n            console.print(f\"\\n[yellow]Invalid sample examples (first 3):[/yellow]\")\n            for i, item in enumerate(debug_info[\"invalid_samples\"][:3]):\n                console.print(f\"  Example {i+1}:\")\n                console.print(f\"    Sample ID: {item['sample_id']}\")\n                console.print(f\"    Reason: {item['reason']}\")\n                console.print(f\"    Ground truth: {item['ground_truth']}\")\n                console.print()\n\n"
  },
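  {
    "path": "benchmarks/examples/label_pred_auc_sketch.py",
    "content": "\"\"\"\nEditor's illustrative sketch -- not part of the original benchmark code.\n\nLabelPredEvaluator reduces every sample to a positive-class probability P(\"是\")\nplus a binary label, then scores the whole set with AUC. This toy script repeats\nthat final step on hand-made values, using the same sklearn call as calculate_auc\nand the same correctness rule as the evaluator's debug path (predicted_prob > 0.5\ncounts as predicting the positive class).\n\"\"\"\n\nimport numpy as np\nfrom sklearn.metrics import roc_auc_score\n\n# Toy per-sample outputs: {sample_id: P(\"是\")} and {sample_id: label}\npredictions = {\"a\": 0.9, \"b\": 0.2, \"c\": 0.6, \"d\": 0.4}\nlabels = {\"a\": 1, \"b\": 0, \"c\": 0, \"d\": 1}\n\n# Align on the shared, sorted sample ids, as calculate_auc does\nsample_ids = sorted(set(predictions) & set(labels))\ny_true = np.array([labels[i] for i in sample_ids])\ny_score = np.array([predictions[i] for i in sample_ids])\n\nprint(f\"AUC: {roc_auc_score(y_true, y_score):.4f}\")  # 3 of 4 pos/neg pairs ranked correctly -> 0.7500\n\n# Debug-style accuracy: prob > 0.5 is treated as predicting \"是\"\ncorrect = sum((s > 0.5) == bool(t) for s, t in zip(y_score, y_true))\nprint(f\"Accuracy: {correct / len(y_true):.2%}\")  # samples c and d are misclassified -> 50.00%\n"
  },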
  {
    "path": "benchmarks/benchmark/tasks/v1_0/label_pred/utils.py",
    "content": "\"\"\"\nLabel Prediction Task Utilities\n\nFunctions for label extraction, probability processing, and AUC/wuAUC computation.\n\"\"\"\n\nimport json\nimport numpy as np\nfrom typing import Dict, Tuple, Any, List\nfrom sklearn.metrics import roc_auc_score\nfrom benchmark.console import console\n\n\ndef extract_label_from_answer(answer: str) -> int:\n    \"\"\"\n    Extract binary label from answer string\n    \n    Args:\n        answer: Answer string (e.g., \"是<|im_end|>\" or \"否<|im_end|>\")\n    \n    Returns:\n        1 if positive (\"是\"), 0 if negative (\"否\"), -1 if unrecognized\n    \n    Examples:\n        >>> extract_label_from_answer(\"是<|im_end|>\")\n        1\n        >>> extract_label_from_answer(\"否\")\n        0\n    \"\"\"\n    if \"是\" in answer:\n        return 1\n    elif \"否\" in answer:\n        return 0\n    else:\n        return -1\n\n\ndef extract_probability_from_logprobs(\n    generations: List[str],\n    positive_token: str = \"是\",\n    negative_token: str = \"否\",\n    sample_id: str = None\n) -> Dict[str, Any]:\n    \"\"\"\n    Extract probability for positive class from generations list containing reasoning and JSON probabilities.\n    Applies softmax normalization to ensure probabilities sum to 1.\n\n    Args:\n        generations: List of strings, each containing reasoning text followed by JSON probabilities\n        positive_token: Token representing positive class (default \"是\")\n        negative_token: Token representing negative class (default \"否\")\n        sample_id: Optional sample ID for error messages\n\n    Returns:\n        Dictionary containing:\n        - parsed_probs: Original parsed probabilities before normalization\n        - normalized_probs: Softmax normalized probabilities\n        - score: Final positive class probability (after normalization)\n\n    Raises:\n        ValueError: If JSON parsing fails or required tokens are missing\n\n    Examples:\n        >>> generations = ['</think>\\\\n{\"是\": 0.7, \"否\": 0.3}']\n        >>> result = extract_probability_from_logprobs(generations)\n        >>> result['score']\n        0.7\n        >>> result['normalized_probs']\n        {'是': 0.7, '否': 0.3}\n    \"\"\"\n    parsed_list = []\n    normalized_list = []\n    scores = []\n\n    for idx, generation in enumerate(generations):\n        # Extract JSON part: check for </think> tag first\n        if \"</think>\" in generation:\n            # Extract content after </think>\n            json_str = generation.split(\"</think>\")[-1].strip()\n        else:\n            # No </think> tag, try to parse the entire string\n            json_str = generation.strip()\n\n        # Parse JSON and extract probability\n        try:\n            probs_dict = json.loads(json_str)\n\n            # Validate that it's a dict and contains required tokens\n            if not isinstance(probs_dict, dict):\n                raise ValueError(f\"Parsed JSON is not a dictionary: {type(probs_dict)}\")\n\n            if positive_token not in probs_dict:\n                raise ValueError(f\"Positive token '{positive_token}' not found in probabilities: {probs_dict}\")\n\n            if negative_token not in probs_dict:\n                raise ValueError(f\"Negative token '{negative_token}' not found in probabilities: {probs_dict}\")\n\n            # Extract probabilities\n            p_pos = float(probs_dict[positive_token])\n            p_neg = float(probs_dict[negative_token])\n\n            # Apply softmax normalization (ensure probabilities sum to 1)\n     
       total = p_pos + p_neg\n            if total <= 0:\n                raise ValueError(f\"Sum of probabilities is non-positive: {total}\")\n\n            p_pos_normalized = p_pos / total\n            p_neg_normalized = p_neg / total\n\n            # Store results\n            parsed_list.append({positive_token: p_pos, negative_token: p_neg})\n            normalized_list.append({positive_token: p_pos_normalized, negative_token: p_neg_normalized})\n            scores.append(p_pos_normalized)\n\n        except (json.JSONDecodeError, TypeError, AttributeError, ValueError, KeyError) as e:\n            # Raise detailed exception\n            error_msg = f\"Failed to parse generation\"\n            if sample_id:\n                error_msg += f\" for sample_id '{sample_id}'\"\n            error_msg += f\" at index {idx}:\\n\"\n            error_msg += f\"  Generation: {generation[:200]}...\" if len(generation) > 200 else f\"  Generation: {generation}\\n\"\n            error_msg += f\"\\n  Error: {str(e)}\"\n            raise ValueError(error_msg)\n\n    # If no valid probabilities were found (empty list), raise error\n    if not scores:\n        error_msg = \"No valid probabilities found in generations\"\n        if sample_id:\n            error_msg += f\" for sample_id '{sample_id}'\"\n        raise ValueError(error_msg)\n\n    # Average across all valid elements (usually just one element)\n    if len(scores) == 1:\n        return {\n            \"parsed_probs\": parsed_list[0],\n            \"normalized_probs\": normalized_list[0],\n            \"score\": scores[0]\n        }\n    else:\n        # Average the probabilities for each token\n        avg_parsed = {\n            positive_token: sum(p[positive_token] for p in parsed_list) / len(parsed_list),\n            negative_token: sum(p[negative_token] for p in parsed_list) / len(parsed_list)\n        }\n        avg_normalized = {\n            positive_token: sum(p[positive_token] for p in normalized_list) / len(normalized_list),\n            negative_token: sum(p[negative_token] for p in normalized_list) / len(normalized_list)\n        }\n        avg_score = sum(scores) / len(scores)\n\n        return {\n            \"parsed_probs\": avg_parsed,\n            \"normalized_probs\": avg_normalized,\n            \"score\": avg_score\n        }\n\n\ndef calculate_auc(\n    predictions: Dict[str, float],\n    labels: Dict[str, int]\n) -> float:\n    \"\"\"\n    Calculate AUC (Area Under ROC Curve) using sklearn\n\n    Args:\n        predictions: Predicted probabilities, format: {sample_id: probability}\n        labels: Ground truth labels (0 or 1), format: {sample_id: label}\n\n    Returns:\n        AUC value (float between 0 and 1)\n    \"\"\"\n    if not predictions or not labels:\n        console.print(\"[red]✗ Predictions or labels are empty[/red]\")\n        return 0.0\n\n    # Align predictions and labels\n    sample_ids = sorted(set(predictions.keys()) & set(labels.keys()))\n\n    if len(sample_ids) == 0:\n        console.print(\"[red]✗ No overlapping samples between predictions and labels[/red]\")\n        return 0.0\n\n    y_true = np.array([labels[id] for id in sample_ids])\n    y_scores = np.array([predictions[id] for id in sample_ids])\n\n    # Check if we have both positive and negative samples\n    if len(np.unique(y_true)) < 2:\n        console.print(\"[yellow]⚠ Only one class present in labels, AUC is not defined[/yellow]\")\n        return 0.5\n\n    try:\n        # Calculate AUC using sklearn\n        auc = roc_auc_score(y_true, 
y_scores)\n        return float(auc)\n    except ValueError as e:\n        console.print(f\"[red]✗ Error calculating AUC: {e}[/red]\")\n        return 0.5\n\n\ndef get_debug_info(\n    sample_id: str,\n    logprobs_dict: Dict[str, float],\n    predicted_prob: float,\n    ground_truth: str,\n    label: int,\n    user_id: str = \"\"\n) -> Dict[str, Any]:\n    \"\"\"\n    Prepare debug information for a sample\n    \n    Args:\n        sample_id: Sample ID\n        logprobs_dict: Dictionary of token probabilities\n        predicted_prob: Predicted probability for positive class\n        ground_truth: Ground truth answer string\n        label: Ground truth label (0 or 1)\n        user_id: User ID (optional)\n    \n    Returns:\n        Debug information dictionary\n    \"\"\"\n    debug_item = {\n        \"sample_id\": sample_id,\n        \"ground_truth\": ground_truth,\n        \"label\": label,\n        \"predicted_prob\": predicted_prob,\n        \"logprobs\": logprobs_dict,\n    }\n    \n    if user_id:\n        debug_item[\"user_id\"] = user_id\n    \n    return debug_item\n\n"
  },
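  {
    "path": "benchmarks/examples/label_pred_probability_parsing_sketch.py",
    "content": "\"\"\"\nEditor's illustrative sketch -- not part of the original benchmark code.\n\nextract_probability_from_logprobs expects each generation to end in a JSON object\nscoring \"是\"/\"否\", optionally preceded by a </think> block, and then rescales the\ntwo values so they sum to 1. This script replays those parsing steps on one\nhand-written generation so the expected input format is easy to see.\n\"\"\"\n\nimport json\n\ngeneration = '<think>\\nuser finishes most similar videos\\n</think>\\n{\"是\": 0.6, \"否\": 0.2}'\n\n# Step 1: keep only the text after the last </think> tag\njson_str = generation.split(\"</think>\")[-1].strip()\n\n# Step 2: parse the JSON scores for the two target tokens\nprobs = json.loads(json_str)\np_pos, p_neg = float(probs[\"是\"]), float(probs[\"否\"])\n\n# Step 3: rescale so the two scores sum to 1 (0.6 / 0.8 = 0.75)\ntotal = p_pos + p_neg\nprint(f\"P(是) = {p_pos / total:.2f}\")  # -> 0.75\n"
  },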
  {
    "path": "benchmarks/benchmark/tasks/v1_0/mfu_evaluator.py",
    "content": "\"\"\"\nMFU (Model FLOPs Utilization) Evaluator\n\nComputes MFU metric based on:\n- Model parameters\n- Token statistics\n- GPU hardware information\n- Generation time\n\nMFU = (num_params × 2 × total_tokens) / (gpu_flops × gpu_count × time_s)\n\"\"\"\n\nfrom typing import Dict, Any, Optional\n\nfrom benchmark.console import console, warning_style, dim_style\n\n\ndef compute_mfu(\n    num_params: float,\n    total_tokens: int,\n    gpu_tflops: float,\n    gpu_count: int,\n    time_seconds: float,\n) -> float:\n    \"\"\"\n    Compute Model FLOPs Utilization (MFU)\n\n    Formula:\n        MFU = (num_params × 2 × total_tokens) / (gpu_flops × gpu_count × time_s)\n\n    Args:\n        num_params: Number of model parameters\n        total_tokens: Total tokens processed (input + output)\n        gpu_tflops: GPU theoretical peak TFLOPS (for BF16/FP16)\n        gpu_count: Number of GPUs used\n        time_seconds: Total time in seconds\n\n    Returns:\n        MFU value (0-1, typically 0.01-0.5 for inference)\n    \"\"\"\n    if time_seconds <= 0:\n        console.print(\"⚠ Time is zero or negative, cannot compute MFU\", style=warning_style)\n        return 0.0\n\n    if gpu_tflops is None or gpu_tflops <= 0:\n        console.print(\"⚠ GPU TFLOPS is not available, cannot compute MFU\", style=warning_style)\n        return 0.0\n\n    if num_params is None or num_params <= 0:\n        console.print(\"⚠ Model parameters not specified, cannot compute MFU\", style=warning_style)\n        return 0.0\n\n    # Convert TFLOPS to FLOPS\n    gpu_flops = gpu_tflops * 1e12\n\n    # Compute total FLOPs required\n    # For inference: FLOPs ≈ 2 × num_params × num_tokens\n    total_flops = num_params * 2 * total_tokens\n\n    # Compute theoretical peak FLOPs available\n    theoretical_flops = gpu_flops * gpu_count * time_seconds\n\n    # MFU = actual FLOPs / theoretical FLOPs\n    mfu = total_flops / theoretical_flops\n\n    return mfu\n\n\ndef compute_mfu_from_generation_data(gen_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:\n    \"\"\"\n    Compute MFU metrics from generation result data\n\n    Args:\n        gen_data: Generation result data from JSON file, containing:\n            - num_params: Model parameters\n            - mfu_stats_aggregate: MFU statistics aggregate (dict with lists)\n            - hardware_info: GPU hardware info\n            - total_time: Total generation time\n\n    Returns:\n        Dictionary containing MFU metrics, or None if cannot compute\n        For multi-stage generation, returns MFU_stages as a list\n    \"\"\"\n    # Extract required fields\n    num_params = gen_data.get(\"num_params\")\n    mfu_stats_aggregate = gen_data.get(\"mfu_stats_aggregate\", {})\n    hardware_info = gen_data.get(\"hardware_info\", {})\n    total_time = gen_data.get(\"total_time\", 0)\n\n    # Validate required data\n    if not num_params:\n        console.print(\"[DEBUG] MFU: Model parameters not available in generation data, skipping MFU calculation\", style=dim_style)\n        return None\n\n    if not mfu_stats_aggregate or len(mfu_stats_aggregate.get(\"total_time\", [])) == 0:\n        console.print(\"[DEBUG] MFU: MFU statistics not available in generation data, skipping MFU calculation\", style=dim_style)\n        return None\n\n    if not hardware_info:\n        console.print(\"[DEBUG] MFU: Hardware info not available in generation data, skipping MFU calculation\", style=dim_style)\n        return None\n\n    # Extract hardware info\n    gpu_tflops = 
hardware_info.get(\"gpu_tflops\")\n    gpu_count = hardware_info.get(\"gpu_count\", 1)\n    gpu_model = hardware_info.get(\"gpu_model\", \"unknown\")\n\n    if gpu_tflops is None:\n        console.print(f\"⚠ GPU TFLOPS not available for {gpu_model}, cannot compute MFU\", style=warning_style)\n        return None\n\n    # Extract lists from aggregate stats\n    total_input_tokens_list = mfu_stats_aggregate.get(\"total_input_tokens\", [])\n    total_output_tokens_list = mfu_stats_aggregate.get(\"total_output_tokens\", [])\n    total_time_list = mfu_stats_aggregate.get(\"total_time\", [])\n\n    # Validate list lengths are consistent\n    if not (len(total_input_tokens_list) == len(total_output_tokens_list) == len(total_time_list)):\n        console.print(\n            f\"⚠ Inconsistent list lengths in mfu_stats_aggregate: \"\n            f\"input_tokens={len(total_input_tokens_list)}, \"\n            f\"output_tokens={len(total_output_tokens_list)}, \"\n            f\"times={len(total_time_list)}\",\n            style=warning_style\n        )\n        return None\n\n    num_stages = len(total_time_list)\n\n    # Compute MFU for each stage\n    mfu_list = []\n\n    for stage_idx in range(num_stages):\n        stage_num = stage_idx + 1\n        total_input_tokens = total_input_tokens_list[stage_idx] if stage_idx < len(total_input_tokens_list) else 0\n        total_output_tokens = total_output_tokens_list[stage_idx] if stage_idx < len(total_output_tokens_list) else 0\n        stage_time = total_time_list[stage_idx]\n        total_tokens = total_input_tokens + total_output_tokens\n\n        if total_tokens == 0:\n            console.print(f\"⚠ Stage {stage_num}: Total tokens is zero, aborting MFU calculation\", style=warning_style)\n            return None\n\n        if stage_time <= 0:\n            console.print(f\"⚠ Stage {stage_num}: Stage time is zero or negative, aborting MFU calculation\", style=warning_style)\n            return None\n\n        # Compute MFU for this stage using per-stage time\n        mfu = compute_mfu(\n            num_params=num_params,\n            total_tokens=total_tokens,\n            gpu_tflops=gpu_tflops,\n            gpu_count=gpu_count,\n            time_seconds=stage_time,\n        )\n\n        mfu_list.append(round(mfu, 6))\n\n    if len(mfu_list) == 0:\n        console.print(\"⚠ No valid stages for MFU calculation\", style=warning_style)\n        return None\n\n    # Create metrics with symmetric list structure\n    mfu_metrics = {\n        \"mfu\": mfu_list,\n        \"gpu_model\": gpu_model,\n        \"gpu_count\": gpu_count,\n        \"num_params\": num_params,\n        \"total_input_tokens\": total_input_tokens_list,\n        \"total_output_tokens\": total_output_tokens_list,\n        \"stage_time\": total_time_list,\n    }\n\n    return mfu_metrics\n\n\n"
  },
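  {
    "path": "benchmarks/examples/mfu_worked_example.py",
    "content": "\"\"\"\nEditor's illustrative sketch -- not part of the original benchmark code.\n\nWorks the formula from mfu_evaluator.compute_mfu by hand:\n\n    MFU = (num_params x 2 x total_tokens) / (gpu_flops x gpu_count x time_s)\n\nAll numbers below are made up for the example; the 312 TFLOPS figure is the\ncommonly quoted A100 BF16 tensor-core peak. In the real pipeline these values\ncome from num_params, hardware_info, and mfu_stats_aggregate.\n\"\"\"\n\nnum_params = 8e9          # e.g. an 8B-parameter model\ntotal_tokens = 2_000_000  # input + output tokens processed in the run\ngpu_tflops = 312.0        # assumed BF16 peak of a single A100\ngpu_count = 8\ntime_seconds = 600.0\n\n# Inference FLOPs estimate: ~2 FLOPs per parameter per token\ntotal_flops = 2 * num_params * total_tokens                # 3.2e16\n# Theoretical peak FLOPs available over the whole run\npeak_flops = gpu_tflops * 1e12 * gpu_count * time_seconds  # ~1.4976e18\n\nprint(f\"MFU = {total_flops / peak_flops:.4f}\")  # ~0.0214\n"
  },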
  {
    "path": "benchmarks/benchmark/tasks/v1_0/qwen3.jinja2",
    "content": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif 
%}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/qwen3_soft_switch.jinja2",
    "content": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {%- set suffix = '' %}\n        {%- if message.role == \"user\" and loop.index0 == ns.last_query_index %}\n            {%- if enable_thinking is defined and enable_thinking is false %}\n                {%- set suffix = '/no_think' %}\n            {%- else %}\n                {%- set suffix = '/think' %}\n            {%- endif %}\n        {%- endif %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + suffix + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                
{%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}"
  },
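  {
    "path": "benchmarks/examples/soft_switch_template_sketch.py",
    "content": "\"\"\"\nEditor's illustrative sketch -- not part of the original benchmark code.\n\nqwen3_soft_switch.jinja2 mainly differs from the stock qwen3.jinja2 template in\nthat it appends a \"/think\" or \"/no_think\" suffix to the last user message,\ndriven by the enable_thinking flag, in addition to the empty <think> block\nemitted with the generation prompt. This script renders the template with plain\njinja2 to make the suffix visible; the loader path assumes it is run from the\nrepository root.\n\"\"\"\n\nfrom jinja2 import Environment, FileSystemLoader\n\nenv = Environment(loader=FileSystemLoader(\"benchmarks/benchmark/tasks/v1_0\"))\ntemplate = env.get_template(\"qwen3_soft_switch.jinja2\")\n\nmessages = [{\"role\": \"user\", \"content\": \"hello\"}]\n\nfor flag in (True, False):\n    rendered = template.render(\n        messages=messages,\n        tools=None,\n        add_generation_prompt=True,\n        enable_thinking=flag,\n    )\n    print(f\"--- enable_thinking={flag} ---\")\n    print(rendered)\n"
  },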
  {
    "path": "benchmarks/benchmark/tasks/v1_0/rec_reason/__init__.py",
    "content": "\"\"\"\nRecommendation Reason Task Module\n\"\"\"\n\nfrom .config import REC_REASON_CONFIG\nfrom .evaluator import RecoReasonEvaluator\nfrom . import utils\n\n__all__ = [\n    \"REC_REASON_CONFIG\",\n    \"RecoReasonEvaluator\",\n    \"utils\",\n]\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/rec_reason/config.py",
    "content": "\"\"\"\nRecommendation Reason Task Configuration\n\"\"\"\n\n# Recommendation Reason Task Configuration\nREC_REASON_CONFIG = {\n    \"name\": \"rec_reason\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 470,\n    \"sample_size\": 470,\n    \"description\": \"Recommendation reason inference\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": {\n        \"enable_thinking\": True,  # Enable thinking mode for apply_chat_template\n        \"custom_chat_template\": \"qwen3_soft_switch.jinja2\",  # Custom jinja2 template (file in v1_0 directory)\n    },\n    \"generation_config\": {\n        \"num_return_sequences\": 1,\n        \"max_new_tokens\": 2000,\n        \"temperature\": 0.01,\n        \"top_p\": 0.95,\n        \"repetition_penalty\": 1.1,\n        \"do_sample\": False,\n        \"num_return_thinking_sequences\": 1,\n        \"max_new_thinking_tokens\": 10000,\n    },\n    \"evaluation_config\": {\n        \"metrics\": [\"avg_score\"],\n        # LLM multi-dimensional evaluation config\n        \"llm_eval_enabled\": True,                  # Whether to enable LLM evaluation\n        \"llm_judge_model\": \"gemini\",               # Judge LLM type: gemini/deepseek/claude\n        \"llm_max_workers\": 1,                      # Concurrent workers for LLM calls\n        \"llm_max_samples\": 470,                    # Max samples to evaluate (None for all)\n    }\n}\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/rec_reason/evaluator.py",
    "content": "\"\"\"\nRecommendation Reason Evaluator\n\nEvaluates model predictions on Recommendation Reason task using LLM-based multi-dimensional evaluation.\n\"\"\"\n\nimport os\nfrom typing import Dict, Any, Tuple, List\n\nfrom benchmark.console import console\nfrom benchmark.tasks.v1_0.base_evaluator import BaseEval\nfrom benchmark.tasks.v1_0.rec_reason.utils import extract_after_think, evaluate_reasoning\n\n\nclass RecoReasonEvaluator(BaseEval):\n    \"\"\"Recommendation Reason task evaluator\"\"\"\n\n    @property\n    def required_metrics(self) -> List[str]:\n        \"\"\"Define required overall metrics for Recommendation Reason evaluation\"\"\"\n        return [\"llm_score\"]\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Compute all metrics from scratch\n\n        Returns:\n            Tuple of (metrics, per_sample_metrics)\n        \"\"\"\n        total_samples = len(self.samples)\n\n        # Prepare data for evaluation\n        sample_ids = list(self.samples.keys())\n        predictions = []\n        references = []\n\n        for sample_id in sample_ids:\n            sample = self.samples[sample_id]\n\n            # Get ground truth\n            ground_truth = sample.get(\"ground_truth\", \"\")\n            references.append(ground_truth)\n\n            # Get model prediction (first generation)\n            generations = sample.get(\"generations\", [])\n            if not generations:\n                prediction = \"\"\n            else:\n                # Extract text after </think> tag if present\n                prediction = extract_after_think(generations[0])\n            predictions.append(prediction)\n\n        # Get evaluation config\n        eval_config = self.task_config.get(\"evaluation_config\", {})\n\n        # Build per-sample metrics\n        per_sample_metrics = {}\n        for sample_id in sample_ids:\n            per_sample_metrics[sample_id] = {}\n\n        # Build overall metrics\n        metrics = {\n            \"num_samples\": total_samples,\n        }\n\n        # LLM Evaluation (if enabled)\n        llm_eval_enabled = eval_config.get(\"llm_eval_enabled\", False)\n        if llm_eval_enabled:\n            console.print(\"[cyan]LLM evaluation enabled, starting multi-dimensional evaluation...[/cyan]\")\n            llm_metrics, llm_per_sample = self._evaluate_reasoning(\n                sample_ids=sample_ids,\n                predictions=predictions,\n                references=references,\n                eval_config=eval_config\n            )\n\n            # Merge LLM metrics into overall metrics\n            metrics.update(llm_metrics)\n\n            # Merge LLM per-sample metrics\n            for sample_id in sample_ids:\n                if sample_id in llm_per_sample:\n                    per_sample_metrics[sample_id].update(llm_per_sample[sample_id])\n\n        # Save debug information if requested\n        if self.debug and self.predictions_dir:\n            self._save_debug_info(metrics, per_sample_metrics, predictions, references)\n\n        return metrics, per_sample_metrics\n\n    def _evaluate_reasoning(\n        self,\n        sample_ids: list,\n        predictions: list,\n        references: list,\n        eval_config: Dict[str, Any]\n    ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n        \"\"\"\n        Perform LLM-based multi-dimensional evaluation.\n\n        Args:\n            sample_ids: List of sample IDs\n            predictions: List of 
prediction texts\n            references: List of reference texts\n            eval_config: Evaluation configuration\n\n        Returns:\n            Tuple of (llm_metrics, llm_per_sample_metrics)\n        \"\"\"\n        try:\n            from api import get_client_from_config\n        except ImportError as e:\n            console.print(f\"[red]Failed to import LLM evaluation modules: {e}[/red]\")\n            return {}, {}\n\n        # Get LLM eval config\n        llm_judge_model = eval_config.get(\"llm_judge_model\", \"gemini\")\n        llm_max_workers = eval_config.get(\"llm_max_workers\", 3)\n        llm_max_samples = eval_config.get(\"llm_max_samples\", 300)\n\n        # Create LLM client\n        try:\n            llm_client = get_client_from_config(llm_judge_model)\n            console.print(f\"[green]Using {llm_judge_model} as LLM judge[/green]\")\n        except Exception as e:\n            console.print(f\"[red]Failed to create LLM client for evaluation: {e}[/red]\")\n            return {}, {}\n\n        # Prepare data as dicts\n        predictions_dict = {id: pred for id, pred in zip(sample_ids, predictions)}\n        references_dict = {id: ref for id, ref in zip(sample_ids, references)}\n\n        # Get model name for cache file naming\n        model_name = getattr(llm_client, 'model_name', llm_judge_model)\n\n        # Run LLM evaluation\n        try:\n            llm_metrics, llm_per_sample = evaluate_reasoning(\n                predictions=predictions_dict,\n                references=references_dict,\n                llm_client=llm_client,\n                max_workers=llm_max_workers,\n                max_samples=llm_max_samples,\n                model_name=model_name,\n                save_dir=self.predictions_dir,\n            )\n\n            console.print(f\"[green]LLM evaluation completed: {llm_metrics.get('llm_eval_num_samples', 0)} samples evaluated[/green]\")\n            return llm_metrics, llm_per_sample\n\n        except Exception as e:\n            console.print(f\"[red]LLM evaluation failed: {e}[/red]\")\n            import traceback\n            traceback.print_exc()\n            return {}, {}\n\n    def _save_debug_info(\n        self,\n        metrics: Dict[str, Any],\n        per_sample_metrics: Dict[str, Dict[str, Any]],\n        predictions: list,\n        references: list\n    ):\n        \"\"\"\n        Save detailed debug information to file\n\n        Args:\n            metrics: Overall metrics\n            per_sample_metrics: Per-sample metrics\n            predictions: List of predictions\n            references: List of references\n        \"\"\"\n        # Prepare debug info\n        debug_info = {\n            \"overall_metrics\": metrics,\n            \"per_sample_metrics\": per_sample_metrics,\n            \"sample_count\": len(predictions),\n        }\n\n        # Add some examples\n        sample_ids = list(self.samples.keys())\n        debug_info[\"examples\"] = []\n        for i in range(min(10, len(sample_ids))):\n            sample_id = sample_ids[i]\n            debug_info[\"examples\"].append({\n                \"sample_id\": sample_id,\n                \"prediction\": predictions[i][:500] + \"...\" if len(predictions[i]) > 500 else predictions[i],\n                \"reference\": references[i][:500] + \"...\" if len(references[i]) > 500 else references[i],\n                \"llm_score\": per_sample_metrics[sample_id].get(\"llm_score\"),\n                \"llm_reason\": per_sample_metrics[sample_id].get(\"llm_reason\"),\n            
})\n\n        # Save to file using base class method\n        self._save_debug_json(debug_info, filename=\"debug.json\")\n\n        # Print summary statistics\n        console.print(f\"Total samples: {metrics['num_samples']}\")\n\n        # Print LLM eval metrics if available\n        if metrics.get('llm_score') is not None:\n            console.print(f\"LLM Eval Score: {metrics['llm_score']:.4f}\")\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/rec_reason/utils.py",
    "content": "\"\"\"\nRecommendation Reason LLM Evaluation Utilities\n\nProvides functions for extracting refined reasoning and multi-dimensional LLM evaluation.\n\"\"\"\n\nimport json\nimport os\nimport re\nfrom typing import Dict, List, Optional, Tuple, Any\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\n\nimport pandas as pd\nfrom tqdm import tqdm\n\nfrom benchmark.console import console\n\n\nEVALUATION_PROMPT = \"\"\"你是一位专业的推荐系统评估专家。你的任务是评估一个AI模型生成的\"推荐理由\"与\"标准答案\"的匹配程度。\n\n### 评估任务\n请对模型生成的推荐理由进行综合评分（1-5分）。\n\n**核心评估原则：**\n请严格按照以下步骤进行思考和评估：\n1.  **核心要素提取**：\n    - 从【标准答案】中提取：推荐的核心动机（用户为什么看）+ 推荐的内容类型（看的是什么）。\n    - 从【模型生成】中提取：推荐的核心动机 + 推荐的内容类型。\n2.  **噪音过滤（关键步骤）**：\n    - 忽略具体的措辞差异（同义词替换）。\n    - **忽略与推荐逻辑无关的用户画像细节**（例如：具体的年龄数字、与推理逻辑和视频内容无关的兴趣等）。\n3.  **匹配度分析**：\n    - 对比核心动机：是否抓住了相同的推荐的核心动机？\n    - 对比内容方向：推荐的视频类别/主题是否一致？\n4.  **评分**：基于评分标准给出最终得分。\n\n**评分标准：**\n- 5分：核心逻辑与内容方向完全一致。即使表达方式不同，但语义内核完全相同。\n- 4分：核心逻辑正确，内容方向正确。可能遗漏了标准答案中极次要的补充信息，或包含了无伤大雅的冗余信息。\n- 3分：大方向（如视频类型）正确，但对“用户为什么喜欢”的归因不够准确，或遗漏了关键的转化动机。\n- 2分：推荐逻辑有明显误读，或者推荐的内容类型与标准答案有偏差（例如：把“学习教程”理解成了“娱乐搞笑”）。\n- 1分：逻辑和内容完全错误，或生成了风马牛不相及的内容。\n\n### 输入\n\n**[标准答案]**\n{}\n\n**[模型生成]**\n{}\n\n### 输出格式\n你的输出必须是【纯粹的 JSON 格式】，可以被 `json.loads` 直接解析。\n\n```json\n{{\n  \"llm_score\": <1-5的整数>,\n  \"llm_reason\": \"<简短的打分理由，不超过50字>\"\n}}\n```\n\n你的评估结果 (请严格按照上述要求返回一个格式规整的 JSON，可以被 json.loads 直接解析。请不要在 JSON 数据前后添加任何额外的解释性文字或代码块标记): \"\"\"\n\n\ndef extract_refined_reasoning(text: str) -> str:\n    \"\"\"\n    Extract the refined reasoning section from the full text.\n\n    Finds the last occurrence of \"精炼推理\" and extracts the text after it.\n\n    Args:\n        text: Full text containing the reasoning\n\n    Returns:\n        Extracted refined reasoning text, or original text if pattern not found\n    \"\"\"\n    if not text:\n        return \"\"\n\n    # Find the last occurrence of \"精炼推理\"\n    keyword = \"精炼推理\"\n    last_pos = text.rfind(keyword)\n\n    if last_pos != -1:\n        # Extract text after \"精炼推理\"\n        after_keyword = text[last_pos + len(keyword):]\n        # Remove leading punctuation, whitespace, and markdown symbols\n        after_keyword = re.sub(r'^[\\s\\*#：:\\n]+', '', after_keyword)\n        if after_keyword.strip():\n            return after_keyword.strip()\n\n    # If \"精炼推理\" not found, return original text\n    return text.strip()\n\n\ndef extract_after_think(text: str) -> str:\n    \"\"\"Extract text after the last </think> tag if present.\"\"\"\n    if '</think>' in text:\n        return text.split('</think>')[-1].strip()\n    return text\n\n\ndef extract_json_from_response(response: str) -> Optional[Dict]:\n    \"\"\"\n    Extract JSON from LLM response.\n\n    Args:\n        response: LLM response text\n\n    Returns:\n        Parsed JSON dict or None if parsing fails\n    \"\"\"\n    if not response:\n        return None\n\n    try:\n        response = response.strip()\n        # Remove markdown code blocks if present\n        if response.startswith('```json'):\n            response = response[7:]\n        elif response.startswith('```'):\n            response = response[3:]\n        if response.endswith('```'):\n            response = response[:-3]\n\n        return json.loads(response.strip())\n    except json.JSONDecodeError:\n        # Try to find JSON object in the response\n        match = re.search(r'\\{[^{}]*\\}', response, re.DOTALL)\n        if match:\n            try:\n                return json.loads(match.group())\n            except json.JSONDecodeError:\n                
pass\n        console.print(f\"[yellow]Failed to parse JSON: {response[:200]}...[/yellow]\")\n        return None\n\n\ndef evaluate_single(\n    gt_reasoning: str,\n    model_reasoning: str,\n    llm_client\n) -> Tuple[Optional[Dict], Optional[str]]:\n    \"\"\"\n    Evaluate a single sample using LLM.\n\n    Args:\n        gt_reasoning: Ground truth refined reasoning\n        model_reasoning: Model-generated refined reasoning\n        llm_client: LLM client instance\n\n    Returns:\n        Tuple of (evaluation_result, error_message)\n    \"\"\"\n    prompt = EVALUATION_PROMPT.format(gt_reasoning, model_reasoning)\n\n    try:\n        response = llm_client.generate(prompt)\n        result = extract_json_from_response(response)\n\n        if result is not None and \"llm_score\" in result:\n            # Ensure score is in valid range\n            score = result[\"llm_score\"]\n            if not isinstance(score, (int, float)) or score < 1 or score > 5:\n                result[\"llm_score\"] = 3  # Default to middle score if invalid\n            return result, None\n\n        return None, f\"Failed to parse JSON or missing 'llm_score': {response[:100]}\"\n\n    except Exception as e:\n        return None, f\"API error: {str(e)}\"\n\n\ndef evaluate_batch(\n    gt_reasonings: Dict[str, str],\n    model_reasonings: Dict[str, str],\n    llm_client,\n    max_workers: int = 5,\n    desc: str = \"Evaluating reasoning\"\n) -> Tuple[Dict[str, Dict], Dict[str, str]]:\n    \"\"\"\n    Evaluate multiple samples in parallel.\n\n    Args:\n        gt_reasonings: Dict of {sample_id: gt_reasoning}\n        model_reasonings: Dict of {sample_id: model_reasoning}\n        llm_client: LLM client instance\n        max_workers: Number of concurrent workers\n        desc: Progress bar description\n\n    Returns:\n        Tuple of (results, errors)\n    \"\"\"\n    results = {}\n    errors = {}\n\n    # Only evaluate samples that have both GT and model reasoning\n    common_ids = set(gt_reasonings.keys()) & set(model_reasonings.keys())\n    common_ids = {id for id in common_ids if gt_reasonings[id] and model_reasonings[id]}\n\n    def process_single(sample_id: str):\n        gt = gt_reasonings[sample_id]\n        model = model_reasonings[sample_id]\n        result, error = evaluate_single(gt, model, llm_client)\n        return sample_id, result, error\n\n    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = {\n            executor.submit(process_single, sid): sid\n            for sid in common_ids\n        }\n\n        for future in tqdm(as_completed(futures), total=len(futures), desc=desc):\n            sample_id, result, error = future.result()\n            if result is not None:\n                results[sample_id] = result\n            if error is not None:\n                errors[sample_id] = error\n\n    # Statistics\n    total_attempted = len(common_ids)\n    total_success = len(results)\n\n    console.print(f\"[cyan]{desc} statistics: {total_attempted} attempted, {total_success} successful[/cyan]\")\n\n    return results, errors\n\n\ndef calculate_metrics(eval_results: Dict[str, Dict]) -> Dict[str, Any]:\n    \"\"\"\n    Calculate micro and macro metrics from evaluation results.\n\n    Args:\n        eval_results: Dict of {sample_id: evaluation_result}\n\n    Returns:\n        Dict with micro/macro scores\n    \"\"\"\n    if not eval_results:\n        return {}\n\n    # Collect scores\n    scores = []\n    for sample_id, result in eval_results.items():\n        if 
\"llm_score\" in result:\n            score = result[\"llm_score\"]\n            if isinstance(score, (int, float)) and 1 <= score <= 5:\n                scores.append(score)\n\n    metrics = {}\n\n    if scores:\n        # micro and macro are the same for single score\n        avg_score = sum(scores) / len(scores)\n        metrics[\"micro_llm_score\"] = avg_score\n        metrics[\"macro_llm_score\"] = avg_score\n        metrics[\"llm_score\"] = avg_score\n\n    metrics[\"llm_eval_num_samples\"] = len(eval_results)\n\n    return metrics\n\n\ndef get_per_sample_metrics(eval_results: Dict[str, Dict]) -> Dict[str, Dict[str, Any]]:\n    \"\"\"\n    Extract per-sample metrics from evaluation results.\n\n    Args:\n        eval_results: Dict of {sample_id: evaluation_result}\n\n    Returns:\n        Dict of {sample_id: {llm_score, llm_reason}}\n    \"\"\"\n    per_sample = {}\n\n    for sample_id, result in eval_results.items():\n        sample_metrics = {}\n\n        if \"llm_score\" in result:\n            sample_metrics[\"llm_score\"] = result[\"llm_score\"]\n\n        if \"llm_reason\" in result:\n            sample_metrics[\"llm_reason\"] = result[\"llm_reason\"]\n\n        per_sample[sample_id] = sample_metrics\n\n    return per_sample\n\n\ndef get_cache_path(save_dir: str, model_name: str) -> str:\n    \"\"\"Get the path for evaluation results cache file.\"\"\"\n    return os.path.join(save_dir, f\"llm_eval_{model_name}.json\")\n\n\ndef load_eval_cache(cache_path: str) -> Optional[Dict[str, Dict]]:\n    \"\"\"\n    Load evaluation results from cache.\n\n    Args:\n        cache_path: Path to cache file\n\n    Returns:\n        Dict of {sample_id: evaluation_result} or None if not found\n    \"\"\"\n    if not os.path.exists(cache_path):\n        return None\n\n    try:\n        with open(cache_path, 'r', encoding='utf-8') as f:\n            data = json.load(f)\n\n        # Extract evaluation results\n        eval_results = {}\n        for sample_id, sample_data in data.items():\n            if \"eval_result\" in sample_data and sample_data[\"eval_result\"]:\n                eval_results[sample_id] = sample_data[\"eval_result\"]\n\n        console.print(f\"[green]Loaded {len(eval_results)} cached evaluation results[/green]\")\n        return eval_results\n\n    except Exception as e:\n        console.print(f\"[yellow]Failed to load evaluation cache: {e}[/yellow]\")\n        return None\n\n\ndef save_eval_results(\n    save_dir: str,\n    sample_ids: List[str],\n    gt_reasonings: Dict[str, str],\n    model_reasonings: Dict[str, str],\n    eval_results: Dict[str, Dict],\n    model_name: str\n):\n    \"\"\"\n    Save evaluation results to file.\n\n    Args:\n        save_dir: Directory to save the file\n        sample_ids: List of sample IDs\n        gt_reasonings: Dict of {sample_id: gt_reasoning}\n        model_reasonings: Dict of {sample_id: model_reasoning}\n        eval_results: Dict of {sample_id: evaluation_result}\n        model_name: Model name for filename\n    \"\"\"\n    os.makedirs(save_dir, exist_ok=True)\n    save_path = get_cache_path(save_dir, model_name)\n\n    detailed_results = {}\n\n    for sample_id in sample_ids:\n        detailed_results[sample_id] = {\n            \"gt_reasoning\": gt_reasonings.get(sample_id, \"\"),\n            \"model_reasoning\": model_reasonings.get(sample_id, \"\"),\n            \"eval_result\": eval_results.get(sample_id, {}),\n        }\n\n    with open(save_path, 'w', encoding='utf-8') as f:\n        json.dump(detailed_results, f, 
ensure_ascii=False, indent=2)\n\n    console.print(f\"[green]Evaluation results saved to {save_path}[/green]\")\n\n\ndef evaluate_reasoning(\n    predictions: Dict[str, str],\n    references: Dict[str, str],\n    llm_client,\n    max_workers: int = 5,\n    max_samples: Optional[int] = None,\n    model_name: str = \"unknown\",\n    save_dir: Optional[str] = None,\n) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:\n    \"\"\"\n    Main evaluation function for recommendation reasoning.\n\n    Args:\n        predictions: Dict of {sample_id: prediction_text}\n        references: Dict of {sample_id: reference_text}\n        llm_client: LLM client instance for evaluation\n        max_workers: Number of concurrent workers\n        max_samples: Maximum samples to evaluate (None for all)\n        model_name: Model name for cache file naming\n        save_dir: Directory to save results\n\n    Returns:\n        Tuple of (metrics, per_sample_metrics)\n    \"\"\"\n    # Select samples\n    all_sample_ids = sorted(set(predictions.keys()) & set(references.keys()))\n\n    if max_samples is not None and max_samples < len(all_sample_ids):\n        sample_ids = all_sample_ids[:max_samples]\n        console.print(f\"[cyan]Selected {len(sample_ids)} samples for LLM evaluation[/cyan]\")\n    else:\n        sample_ids = all_sample_ids\n        console.print(f\"[cyan]Evaluating all {len(sample_ids)} samples[/cyan]\")\n\n    # Extract refined reasoning from both GT and model outputs\n    console.print(\"[cyan]Extracting refined reasoning...[/cyan]\")\n    gt_reasonings = {}\n    model_reasonings = {}\n\n    for sample_id in sample_ids:\n        # Extract from reference (GT)\n        gt_text = references.get(sample_id, \"\")\n        gt_reasonings[sample_id] = extract_refined_reasoning(gt_text)\n\n        # Extract from prediction (model output)\n        pred_text = predictions.get(sample_id, \"\")\n        pred_text = extract_after_think(pred_text)  # Remove <think> tags first\n        model_reasonings[sample_id] = extract_refined_reasoning(pred_text)\n\n    # Load cached results if available\n    eval_results = {}\n    if save_dir:\n        cache_path = get_cache_path(save_dir, model_name)\n        cached_results = load_eval_cache(cache_path)\n        if cached_results:\n            eval_results = {k: v for k, v in cached_results.items() if k in sample_ids}\n\n    # Find samples that need evaluation\n    missing_ids = set(sample_ids) - set(eval_results.keys())\n    missing_ids = {\n        id for id in missing_ids\n        if gt_reasonings.get(id) and model_reasonings.get(id)\n    }\n\n    if missing_ids:\n        console.print(f\"[cyan]Evaluating {len(missing_ids)} samples with LLM...[/cyan]\")\n        missing_gt = {id: gt_reasonings[id] for id in missing_ids}\n        missing_model = {id: model_reasonings[id] for id in missing_ids}\n\n        new_results, errors = evaluate_batch(\n            missing_gt, missing_model, llm_client, max_workers\n        )\n        eval_results.update(new_results)\n\n        if errors:\n            console.print(f\"[red]Evaluation errors: {len(errors)} samples[/red]\")\n    else:\n        console.print(f\"[green]All {len(sample_ids)} samples already evaluated (from cache)[/green]\")\n\n    # Calculate metrics\n    metrics = calculate_metrics(eval_results)\n    per_sample_metrics = get_per_sample_metrics(eval_results)\n\n    # Save results\n    if save_dir:\n        save_eval_results(\n            save_dir, sample_ids, gt_reasonings, model_reasonings, eval_results, 
model_name\n        )\n\n    # Print summary\n    console.print(f\"[green]LLM evaluation completed: {metrics.get('llm_eval_num_samples', 0)} samples[/green]\")\n    console.print(f\"[green]  LLM Score: {metrics.get('llm_score', 0):.4f}[/green]\")\n\n    return metrics, per_sample_metrics\n"
  },
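  {
    "path": "benchmarks/examples/llm_eval_metrics_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the shipped benchmark): the on-disk format\nwritten by save_eval_results() and the score averaging performed by\ncalculate_metrics(). All sample IDs, reasonings, and scores are toy values.\n\"\"\"\n\n# Shape of llm_eval_<model_name>.json: one entry per sample.\ndetailed_results = {\n    \"sample_0\": {\n        \"gt_reasoning\": \"The user favors cooking tutorials.\",\n        \"model_reasoning\": \"The user prefers cooking content.\",\n        \"eval_result\": {\"llm_score\": 4, \"llm_reason\": \"Close match.\"},\n    },\n    \"sample_1\": {\n        \"gt_reasoning\": \"The user clicks sports highlights.\",\n        \"model_reasoning\": \"The user likes travel vlogs.\",\n        \"eval_result\": {\"llm_score\": 2, \"llm_reason\": \"Partially aligned.\"},\n    },\n}\n\n# calculate_metrics() keeps scores in [1, 5] and averages them; with a single\n# score per sample, the micro and macro averages coincide.\nscores = [\n    entry[\"eval_result\"][\"llm_score\"]\n    for entry in detailed_results.values()\n    if 1 <= entry[\"eval_result\"][\"llm_score\"] <= 5\n]\nprint(sum(scores) / len(scores))  # 3.0\n"
  },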
  {
    "path": "benchmarks/benchmark/tasks/v1_0/recommendation/__init__.py",
    "content": "\"\"\"\nRecommendation Task Module\n\nUniversal module for all recommendation tasks including:\n- label_cond: Predict next video given specified consumption behavior\n- video: Next video prediction\n- product: Predict next clicked product\n- ad: Predict next clicked advertisement\n\"\"\"\n\nfrom .config import (\n    LABEL_COND_CONFIG,\n    VIDEO_CONFIG,\n    PRODUCT_CONFIG,\n    AD_CONFIG,\n    INTERACTIVE_CONFIG,\n    RECOMMENDATION_PROMPT_CONFIG,\n    RECOMMENDATION_TASK_CONFIGS,\n    RECOMMENDATION_GENERATION_CONFIG,\n    RECOMMENDATION_EVALUATION_CONFIG,\n)\nfrom .evaluator import RecommendationEvaluator\nfrom . import utils\n\n__all__ = [\n    # Configs\n    \"LABEL_COND_CONFIG\",\n    \"VIDEO_CONFIG\",\n    \"PRODUCT_CONFIG\",\n    \"AD_CONFIG\",\n    \"INTERACTIVE_CONFIG\",\n    \"RECOMMENDATION_PROMPT_CONFIG\",\n    \"RECOMMENDATION_TASK_CONFIGS\",\n    \"RECOMMENDATION_GENERATION_CONFIG\",\n    \"RECOMMENDATION_EVALUATION_CONFIG\",\n    # Classes\n    \"RecommendationEvaluator\",\n    # Utils module\n    \"utils\",\n]\n\n"
  },
  {
    "path": "benchmarks/benchmark/tasks/v1_0/recommendation/config.py",
    "content": "\"\"\"\nRecommendation Task Configurations\n\nThis module contains configurations for all recommendation tasks including:\n- label_cond: Predict next video given specified consumption behavior\n- video: Next video prediction\n- product: Predict next clicked product\n- ad: Predict next clicked advertisement\n\"\"\"\n\n# Common prompt config for recommendation tasks\nRECOMMENDATION_PROMPT_CONFIG = {\n    \"enable_thinking\": False,\n    \"custom_chat_template\": \"qwen3_soft_switch.jinja2\",\n}\n\n# Common generation config for recommendation tasks\nRECOMMENDATION_GENERATION_CONFIG = {\n    \"num_return_sequences\": 128,\n    \"max_new_tokens\": 3,\n    \"temperature\": 0.6,\n    \"top_p\": 0.95,\n    \"top_k\": 50,  \n    \"presence_penalty\": 0,\n    \"frequency_penalty\": 0,\n    \"prompt_token\": \"<|sid_begin|>\",  # Token to append for two-stage generation\n    \"max_new_thinking_tokens\": 1000,\n    \"num_return_thinking_sequences\": 8,  # Number of thinking candidates to generate in stage 1\n    \"num_beams\": 16,\n}\n\n# Common evaluation config for recommendation tasks\nRECOMMENDATION_EVALUATION_CONFIG = {\n    \"metrics\": [\"pass@k\", \"position1_pass@k\", \"recall@k\"],\n    \"k_values\": [1, 32],\n    \"select_k\": \"first_k\",  # Strategy for selecting k predictions: 'first_k' or 'random_k'\n\n    # PID-based evaluation settings\n    \"evaluation_mode\": \"both\",  # Evaluation mode: 'sid', 'pid', or 'both'\n    \"sid_to_pid_strategy\": \"most_popular_after_downsampling\",  # Strategy for SID->PID conversion: 'most_popular_originally', 'most_popular_after_downsampling', or 'random'\n}\n\n# Label Cond Task Configuration\nLABEL_COND_CONFIG = {\n    \"name\": \"label_cond\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 34891,\n    \"sample_size\": 34891,\n    \"description\": \"Predict next video given specified consumption behavior\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": RECOMMENDATION_PROMPT_CONFIG.copy(),\n    \"generation_config\": RECOMMENDATION_GENERATION_CONFIG.copy(),\n    \"evaluation_config\": RECOMMENDATION_EVALUATION_CONFIG.copy(),\n}\n\n# SID USER Doc Task Configuration\nVIDEO_CONFIG = {\n    \"name\": \"video\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 38781,\n    \"sample_size\": 38781,\n    \"description\": \"Next video prediction\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": RECOMMENDATION_PROMPT_CONFIG.copy(),\n    \"generation_config\": RECOMMENDATION_GENERATION_CONFIG.copy(),\n    \"evaluation_config\": RECOMMENDATION_EVALUATION_CONFIG.copy(),\n}\n\n# Product Task Configuration\nPRODUCT_CONFIG = {\n    \"name\": \"product\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 27910,\n    \"sample_size\": 27910,\n    \"description\": \"Predict next clicked product\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": RECOMMENDATION_PROMPT_CONFIG.copy(),\n    \"generation_config\": RECOMMENDATION_GENERATION_CONFIG.copy(),\n    \"evaluation_config\": RECOMMENDATION_EVALUATION_CONFIG.copy(),\n}\n\n# Ad Task Configuration\nAD_CONFIG = {\n    \"name\": \"ad\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 
27677,\n    \"sample_size\": 27677,\n    \"description\": \"Predict next clicked advertisement\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": RECOMMENDATION_PROMPT_CONFIG.copy(),\n    \"generation_config\": RECOMMENDATION_GENERATION_CONFIG.copy(),\n    \"evaluation_config\": RECOMMENDATION_EVALUATION_CONFIG.copy(),\n}\n\n# Interactive Task Configuration\nINTERACTIVE_CONFIG = {\n    \"name\": \"interactive\",\n    \"source\": \"Kuaishou Internal\",\n    \"splits\": [\"test\"],\n    \"size\": 1000,\n    \"sample_size\": 1000,\n    \"description\": \"Predict next interacted video\",\n    \"data_fields\": {\n        \"messages_field\": \"messages\",\n        \"metadata_field\": \"metadata\",\n    },\n    \"prompt_config\": RECOMMENDATION_PROMPT_CONFIG.copy(),\n    \"generation_config\": RECOMMENDATION_GENERATION_CONFIG.copy(),\n    \"evaluation_config\": RECOMMENDATION_EVALUATION_CONFIG.copy(),\n}\n\n# Task configuration mapping\nRECOMMENDATION_TASK_CONFIGS = {\n    \"label_cond\": LABEL_COND_CONFIG,\n    \"video\": VIDEO_CONFIG,\n    \"product\": PRODUCT_CONFIG,\n    \"ad\": AD_CONFIG,\n    \"interactive\": INTERACTIVE_CONFIG,\n}\n\n"
  },
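  {
    "path": "benchmarks/examples/recommendation_config_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the shipped benchmark): derive the metric\nnames implied by a task's evaluation_config, mirroring the logic of\nRecommendationEvaluator.required_metrics.\n\"\"\"\n\nfrom benchmark.tasks.v1_0.recommendation.config import VIDEO_CONFIG\n\n\ndef metric_names(task_config):\n    \"\"\"Return the overall metric names the evaluator will report.\"\"\"\n    eval_config = task_config.get(\"evaluation_config\", {})\n    k_values = eval_config.get(\"k_values\", [128])\n    mode = eval_config.get(\"evaluation_mode\", \"sid\")\n\n    names = []\n    if mode in (\"sid\", \"both\"):\n        names += [f\"{m}@{k}\" for k in k_values for m in (\"pass\", \"position1_pass\", \"recall\")]\n    if mode in (\"pid\", \"both\"):\n        names += [f\"pid_{m}@{k}\" for k in k_values for m in (\"pass\", \"position1_pass\", \"recall\")]\n    return names\n\n\nif __name__ == \"__main__\":\n    # With k_values=[1, 32] and evaluation_mode=\"both\": 12 metric names.\n    print(metric_names(VIDEO_CONFIG))\n"
  },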
  {
    "path": "benchmarks/benchmark/tasks/v1_0/recommendation/evaluator.py",
    "content": "\"\"\"\nRecommendation Task Evaluator\n\nUniversal evaluator for all recommendation tasks.\nComputes Pass@k and Position1_Pass@k metrics.\n\"\"\"\n\nimport json\nfrom typing import Dict, Any, Tuple, List\n\nfrom benchmark.console import console, warning_style\nfrom benchmark.tasks.v1_0.base_evaluator import BaseEval\nfrom benchmark.tasks.v1_0.recommendation import utils as utils_sid\nfrom benchmark.tasks.v1_0.recommendation import utils_by_pid as utils_pid\n\n\nclass RecommendationEvaluator(BaseEval):\n    \"\"\"\n    Universal evaluator for recommendation tasks\n\n    Supports:\n    - label_cond: Predict next video given specified consumption behavior\n    - video: Next video prediction\n    - product: Predict next clicked product\n    - ad: Predict next clicked advertisement\n\n    Metrics:\n    - Pass@k: Check if any of top-k predictions match any ground truth SID\n    - Position1_Pass@k: Check if any of top-k predictions match the first ground truth SID\n    \"\"\"\n\n    @property\n    def required_metrics(self) -> List[str]:\n        \"\"\"Define required overall metrics for Recommendation evaluation\"\"\"\n        k_values = self.task_config.get(\"evaluation_config\", {}).get(\"k_values\", [128])\n        evaluation_mode = self.task_config.get(\"evaluation_config\", {}).get(\"evaluation_mode\", \"sid\")\n\n        metrics = []\n\n        if evaluation_mode in (\"sid\", \"both\"):\n            for k in k_values:\n                metrics.extend([f\"pass@{k}\", f\"position1_pass@{k}\", f\"recall@{k}\"])\n\n        if evaluation_mode in (\"pid\", \"both\"):\n            for k in k_values:\n                metrics.extend([f\"pid_pass@{k}\", f\"pid_position1_pass@{k}\", f\"pid_recall@{k}\"])\n\n        return metrics\n\n    def _select_generations_by_strategy(\n        self,\n        generations: List[str],\n        logprobs: List[float],\n        strategy: str\n    ) -> List[str]:\n        \"\"\"\n        Select and reorder generations based on the specified strategy\n\n        Args:\n            generations: List of generation strings\n            logprobs: List of cumulative logprobs for each generation\n            strategy: Selection strategy ('first_k' or 'top_k_by_logprobs')\n\n        Returns:\n            Reordered list of generations\n\n        Raises:\n            ValueError: If strategy is 'top_k_by_logprobs' but logprobs data is invalid\n        \"\"\"\n        if strategy == \"first_k\":\n            # Keep original order\n            return generations\n        elif strategy == \"top_k_by_logprobs\":\n            # Validate logprobs data\n            if not logprobs:\n                raise ValueError(\n                    f\"Strategy 'top_k_by_logprobs' requires logprobs data, but logprobs is empty. \"\n                    f\"Please ensure the generation was run with logprobs enabled.\"\n                )\n            if len(logprobs) != len(generations):\n                raise ValueError(\n                    f\"Strategy 'top_k_by_logprobs' requires logprobs length to match generations length. 
\"\n                    f\"Got logprobs length {len(logprobs)}, generations length {len(generations)}.\"\n                )\n\n            # Sort generations by logprobs in descending order (higher logprob = better)\n            paired = list(zip(generations, logprobs))\n            paired_sorted = sorted(paired, key=lambda x: x[1], reverse=True)\n\n            # Deduplicate while preserving order (keep first occurrence with highest logprob)\n            seen = set()\n            unique_generations = []\n            for gen, _ in paired_sorted:\n                if gen not in seen:\n                    seen.add(gen)\n                    unique_generations.append(gen)\n\n            return unique_generations\n        else:\n            raise ValueError(\n                f\"Unknown selection strategy: '{strategy}'. \"\n                f\"Supported strategies: 'first_k', 'top_k_by_logprobs'\"\n            )\n\n    def _evaluate_single_mode(\n        self,\n        k_values: List[int],\n        evaluation_mode: str,\n        select_k_strategy: str,\n        code_to_pid: Dict[int, List[Tuple[int, float]]] = None,\n        sid_to_pid_strategy: str = \"most_popular\"\n    ) -> Tuple[Dict[str, int], Dict[str, int], Dict[str, float], Dict[str, Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:\n        \"\"\"\n        Evaluate samples using a single mode (SID or PID)\n\n        Args:\n            k_values: List of k values to compute\n            evaluation_mode: Either 'sid' or 'pid'\n            select_k_strategy: Selection strategy for generations\n            code_to_pid: PID mapping dictionary (required for 'pid' mode)\n            sid_to_pid_strategy: Strategy for SID->PID conversion (\"most_popular\" or \"random\")\n\n        Returns:\n            Tuple of (pass_counts, position1_pass_counts, recall_sums, per_sample_metrics, debug_info_lists)\n        \"\"\"\n        # Select utils module based on mode\n        if evaluation_mode == \"sid\":\n            utils = utils_sid\n        elif evaluation_mode == \"pid\":\n            if code_to_pid is None:\n                raise ValueError(\"code_to_pid is required for PID evaluation mode\")\n            utils = utils_pid\n        else:\n            raise ValueError(f\"Invalid evaluation_mode: {evaluation_mode}\")\n\n        # Initialize counters\n        pass_at_k_counts = {k: 0 for k in k_values}\n        position1_pass_at_k_counts = {k: 0 for k in k_values}\n        recall_at_k_sums = {k: 0.0 for k in k_values}\n\n        # Per-sample metrics collection\n        per_sample_metrics = {}\n\n        # Debug information collection\n        debug_info = {\n            \"passed_samples\": [],\n            \"failed_samples\": [],\n            \"no_generation_samples\": [],\n        }\n\n        # Helper function to create failed metrics\n        def create_failed_metrics():\n            \"\"\"Create metrics dict for failed samples (all False/0.0)\"\"\"\n            metrics = {}\n            for k in k_values:\n                metrics[f\"pass@{k}\"] = False\n                metrics[f\"position1_pass@{k}\"] = False\n                metrics[f\"recall@{k}\"] = 0.0\n            return metrics\n\n        for sample_id, sample in self.samples.items():\n            # Get model predictions\n            generations = sample.get(\"generations\", [])\n            logprobs = sample.get(\"logprobs\", [])\n\n            if not generations:\n                # No generation, treat as failure\n                per_sample_metrics[sample_id] = create_failed_metrics()\n\n    
            if self.debug:\n                    debug_info[\"no_generation_samples\"].append({\n                        \"sample_id\": sample_id,\n                        \"ground_truth\": sample.get(\"ground_truth\", \"\") if evaluation_mode == \"sid\" else sample.get(\"metadata\", {}).get(\"answer_pid\", []),\n                    })\n                continue\n\n            # Get ground truth based on mode\n            if evaluation_mode == \"sid\":\n                ground_truth = sample.get(\"ground_truth\", \"\")\n                ground_truth_ids = utils.extract_ids_from_answer(ground_truth)\n                first_ground_truth_id = utils.extract_first_id_from_answer(ground_truth)\n            else:  # pid mode\n                # Try answer_pid first, fallback to answer_iid if not available\n                ground_truth_pids = sample.get(\"metadata\", {}).get(\"answer_pid\")\n                if ground_truth_pids is None:\n                    ground_truth_pids = sample.get(\"metadata\", {}).get(\"answer_iid\", [])\n                if isinstance(ground_truth_pids, str):\n                    ground_truth_pids = json.loads(ground_truth_pids)\n                ground_truth_ids = utils.extract_ids_from_answer(ground_truth_pids)\n                first_ground_truth_id = utils.extract_first_id_from_answer(ground_truth_pids)\n\n            if not ground_truth_ids:\n                console.print(f\"Sample {sample_id}: no valid ID found in ground truth ({evaluation_mode} mode)\", style=warning_style)\n                per_sample_metrics[sample_id] = create_failed_metrics()\n                continue\n\n            # Apply selection strategy to reorder generations\n            selected_generations = self._select_generations_by_strategy(\n                generations=generations,\n                logprobs=logprobs,\n                strategy=select_k_strategy\n            )\n\n            # Extract predicted IDs from selected generations\n            if evaluation_mode == \"sid\":\n                predicted_ids = [utils.extract_id_from_generation(gen) for gen in selected_generations]\n            else:  # pid mode\n                predicted_ids = [utils.extract_id_from_generation(gen, code_to_pid, sid_to_pid_strategy) for gen in selected_generations]\n\n            # Compute metrics for each k\n            sample_pass_results = {}\n            sample_position1_pass_results = {}\n            sample_recall_results = {}\n\n            for k in k_values:\n                # Compute Pass@k\n                pass_result = utils.compute_pass_at_k(predicted_ids, ground_truth_ids, k)\n                sample_pass_results[f\"pass@{k}\"] = pass_result\n                if pass_result:\n                    pass_at_k_counts[k] += 1\n\n                # Compute Position1_Pass@k\n                position1_pass_result = utils.compute_position1_pass_at_k(\n                    predicted_ids, first_ground_truth_id, k\n                )\n                sample_position1_pass_results[f\"position1_pass@{k}\"] = position1_pass_result\n                if position1_pass_result:\n                    position1_pass_at_k_counts[k] += 1\n\n                # Compute Recall@k\n                recall_result = utils.compute_recall_at_k(predicted_ids, ground_truth_ids, k)\n                sample_recall_results[f\"recall@{k}\"] = recall_result\n                recall_at_k_sums[k] += recall_result\n\n            # Store per-sample metrics\n            sample_metrics = {\n                **sample_pass_results,\n                
**sample_position1_pass_results,\n                **sample_recall_results\n            }\n\n            # For PID mode, save pid_generations (convert None/invalid to -1)\n            if evaluation_mode == \"pid\":\n                pid_generations = [pid if pid is not None else -1 for pid in predicted_ids]\n                sample_metrics[\"generations\"] = pid_generations\n\n            per_sample_metrics[sample_id] = sample_metrics\n\n            # Debug information collection\n            if self.debug:\n                metadata = sample.get(\"metadata\", {})\n                raw_prompt = metadata.get(\"raw_prompt\", \"\")\n\n                if evaluation_mode == \"sid\":\n                    debug_item = utils.get_debug_info(\n                        sample_id=sample_id,\n                        generations=generations,\n                        ground_truth=sample.get(\"ground_truth\", \"\"),\n                        pass_results=sample_pass_results,\n                        position1_pass_results=sample_position1_pass_results,\n                        raw_prompt=raw_prompt,\n                    )\n                else:  # pid mode\n                    answer_pid = metadata.get(\"answer_pid\", metadata.get(\"answer_iid\", []))\n                    if isinstance(answer_pid, str):\n                        answer_pid = json.loads(answer_pid)\n                    debug_item = utils.get_debug_info(\n                        sample_id=sample_id,\n                        generations=generations,\n                        ground_truth=answer_pid,\n                        pass_results=sample_pass_results,\n                        position1_pass_results=sample_position1_pass_results,\n                        code_to_pid=code_to_pid,\n                        strategy=sid_to_pid_strategy,\n                        raw_prompt=raw_prompt,\n                    )\n\n                # Check if any pass@k is True\n                if any(sample_pass_results.values()):\n                    debug_info[\"passed_samples\"].append(debug_item)\n                else:\n                    debug_info[\"failed_samples\"].append(debug_item)\n\n        return pass_at_k_counts, position1_pass_at_k_counts, recall_at_k_sums, per_sample_metrics, debug_info\n\n    def _calculate_metrics_from_counts(\n        self,\n        pass_counts: Dict[int, int],\n        position1_pass_counts: Dict[int, int],\n        recall_sums: Dict[int, float],\n        total_samples: int,\n        k_values: List[int],\n        prefix: str = \"\"\n    ) -> Dict[str, float]:\n        \"\"\"\n        Calculate metrics from counts\n\n        Args:\n            pass_counts: Pass@k counts for each k\n            position1_pass_counts: Position1_Pass@k counts for each k\n            recall_sums: Recall@k sums for each k\n            total_samples: Total number of samples\n            k_values: List of k values\n            prefix: Prefix for metric names (e.g., \"pid_\")\n\n        Returns:\n            Dictionary of calculated metrics\n        \"\"\"\n        metrics = {}\n        for k in k_values:\n            metrics[f\"{prefix}pass@{k}\"] = pass_counts[k] / total_samples if total_samples > 0 else 0.0\n            metrics[f\"{prefix}position1_pass@{k}\"] = position1_pass_counts[k] / total_samples if total_samples > 0 else 0.0\n            metrics[f\"{prefix}recall@{k}\"] = recall_sums[k] / total_samples if total_samples > 0 else 0.0\n        return metrics\n\n    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, 
Any]]]:\n        \"\"\"\n        Compute all evaluation metrics from scratch\n\n        Returns:\n            Tuple of (metrics, per_sample_metrics)\n        \"\"\"\n        total_samples = len(self.samples)\n\n        # Get configuration\n        evaluation_config = self.task_config.get('evaluation_config', {})\n        k_values = evaluation_config.get(\"k_values\", [128])\n        select_k_strategy = evaluation_config.get('select_k', 'first_k')\n        evaluation_mode = evaluation_config.get('evaluation_mode', 'both')\n        sid_to_pid_strategy = evaluation_config.get('sid_to_pid_strategy', 'most_popular_after_downsampling')\n\n        # Load PID mapping if needed\n        code_to_pid = None\n        if evaluation_mode in (\"pid\", \"both\"):\n            from pathlib import Path\n            task_name = self.task_config.get(\"name\", \"\")\n            if task_name == \"product\":\n                mapping_filename = \"sid2iid.json\"\n            else:\n                mapping_filename = \"sid2pid.json\"\n            pid_mapping_path = str(Path(self.data_dir) / mapping_filename)\n            console.print(f\"[cyan]Loading PID mapping from {pid_mapping_path}...[/cyan]\")\n            code_to_pid = utils_pid.load_pid_mapping(pid_mapping_path)\n\n        # Define evaluation modes to run\n        # Format: (mode_name, metric_prefix, debug_filename, log_message)\n        modes_config = {\n            \"sid\": [(\"sid\", \"\", \"debug.json\", \"Evaluating using SID mode...\")],\n            \"pid\": [(\"pid\", \"pid_\", \"debug_pid.json\", \"Evaluating using PID mode...\")],\n            \"both\": [\n                (\"sid\", \"\", \"debug_sid.json\", \"  Running SID evaluation...\"),\n                (\"pid\", \"pid_\", \"debug_pid.json\", \"  Running PID evaluation...\")\n            ]\n        }\n\n        if evaluation_mode not in modes_config:\n            raise ValueError(f\"Invalid evaluation_mode: '{evaluation_mode}'. 
Must be 'sid', 'pid', or 'both'\")\n\n        if evaluation_mode == \"both\":\n            console.print(\"[cyan]Evaluating using both SID and PID modes...[/cyan]\")\n\n        # Initialize metrics\n        metrics = {\"total_samples\": total_samples}\n        per_sample_metrics = {}\n        all_debug_info = {}\n\n        # Run evaluation for each configured mode\n        for mode_name, metric_prefix, debug_filename, log_message in modes_config[evaluation_mode]:\n            console.print(f\"[cyan]{log_message}[/cyan]\")\n\n            # Run evaluation\n            pass_counts, position1_pass_counts, recall_sums, mode_per_sample_metrics, debug_info = self._evaluate_single_mode(\n                k_values=k_values,\n                evaluation_mode=mode_name,\n                select_k_strategy=select_k_strategy,\n                code_to_pid=code_to_pid if mode_name == \"pid\" else None,\n                sid_to_pid_strategy=sid_to_pid_strategy if mode_name == \"pid\" else \"most_popular\"\n            )\n\n            # Calculate and add metrics\n            mode_metrics = self._calculate_metrics_from_counts(\n                pass_counts, position1_pass_counts, recall_sums,\n                total_samples, k_values, metric_prefix\n            )\n            metrics.update(mode_metrics)\n\n            # Merge per-sample metrics with appropriate prefix\n            for sample_id, sample_metric in mode_per_sample_metrics.items():\n                if sample_id not in per_sample_metrics:\n                    per_sample_metrics[sample_id] = {}\n                # Add metrics with prefix (for PID mode) or without (for SID mode)\n                if metric_prefix:\n                    # PID mode: add prefix to metric names\n                    for metric_name, metric_value in sample_metric.items():\n                        prefixed_name = f\"{metric_prefix}{metric_name}\"\n                        per_sample_metrics[sample_id][prefixed_name] = metric_value\n                else:\n                    # SID mode: no prefix\n                    per_sample_metrics[sample_id].update(sample_metric)\n\n            # Store debug info for later saving\n            if self.debug and self.predictions_dir:\n                all_debug_info[mode_name] = (debug_info, debug_filename, mode_metrics)\n\n        # Save debug info\n        if self.debug and self.predictions_dir:\n            for mode_name, (debug_info, debug_filename, mode_metrics) in all_debug_info.items():\n                # For single mode, include all metrics; for both mode, filter by prefix\n                if evaluation_mode == \"both\":\n                    prefix = \"pid_\" if mode_name == \"pid\" else \"\"\n                    filtered_metrics = {\n                        k: v for k, v in mode_metrics.items()\n                        if k == \"total_samples\" or k.startswith(prefix)\n                    }\n                    filtered_metrics[\"total_samples\"] = total_samples\n                else:\n                    filtered_metrics = dict(metrics)\n\n                self._save_debug_info(debug_info, filtered_metrics, debug_filename)\n\n        # Record configuration\n        metrics[\"select_k_strategy\"] = select_k_strategy\n        metrics[\"evaluation_mode\"] = evaluation_mode\n        if evaluation_mode in (\"pid\", \"both\"):\n            metrics[\"sid_to_pid_strategy\"] = sid_to_pid_strategy\n\n        return metrics, per_sample_metrics\n\n    def _save_debug_info(self, debug_info: Dict[str, Any], metrics: Dict[str, Any], debug_filename: 
str = None):\n        \"\"\"\n        Save detailed debug information to file\n\n        Args:\n            debug_info: Debug information dictionary\n            metrics: Overall metrics\n            debug_filename: Optional custom filename (absolute path or relative to predictions_dir)\n        \"\"\"\n        # Add statistics to debug_info\n        debug_info[\"statistics\"] = {\n            \"total_samples\": metrics.get(\"total_samples\", 0),\n            \"passed_samples_count\": len(debug_info.get(\"passed_samples\", [])),\n            \"failed_samples_count\": len(debug_info.get(\"failed_samples\", [])),\n            \"no_generation_samples_count\": len(debug_info.get(\"no_generation_samples\", [])),\n        }\n\n        # Add metrics\n        debug_info[\"metrics\"] = metrics\n\n        # Use default filename if not specified\n        if debug_filename is None:\n            debug_filename = \"debug.json\"\n\n        # Save debug info to file using base class method\n        self._save_debug_json(debug_info, filename=debug_filename)\n\n        console.print(f\"Total samples: {metrics['total_samples']}\")\n        console.print(f\"Passed samples: {len(debug_info['passed_samples'])}\")\n        console.print(f\"Failed samples: {len(debug_info['failed_samples'])}\")\n        console.print(f\"No generation samples: {len(debug_info['no_generation_samples'])}\")\n\n        # Print metrics\n        console.print(\"\\n[bold]Metrics:[/bold]\")\n        for metric_name, metric_value in metrics.items():\n            if metric_name != \"total_samples\":\n                console.print(f\"  {metric_name}: {metric_value}\")\n\n        # Show some failed examples\n        if debug_info[\"failed_samples\"]:\n            console.print(f\"\\n[yellow]Failed sample examples (first 3):[/yellow]\")\n            for i, item in enumerate(debug_info[\"failed_samples\"][:3]):\n                console.print(f\"  Example {i+1}:\")\n                console.print(f\"    Sample ID: {item['sample_id']}\")\n                # Handle both SID and PID modes\n                if 'ground_truth_sids' in item:\n                    console.print(f\"    Ground truth SIDs: {item['ground_truth_sids']}\")\n                elif 'ground_truth_pids' in item:\n                    console.print(f\"    Ground truth PIDs: {item['ground_truth_pids']}\")\n                console.print(f\"    Top 5 generations: {item['top_10_generations'][:5]}\")\n                console.print()\n\n"
  },
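  {
    "path": "benchmarks/examples/recommendation_scoring_sketch.py",
    "content": "\"\"\"\nIllustrative walk-through (not part of the shipped benchmark) of the\nper-sample SID scoring performed inside\nRecommendationEvaluator._evaluate_single_mode, using toy generations.\n\"\"\"\n\nfrom benchmark.tasks.v1_0.recommendation import utils\n\nground_truth = \"<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>\"\ngenerations = [\"<|sid_begin|>123<|sid_end|>\", \"999\", \"<|sid_begin|>456<|sid_end|>\"]\n\ngt_ids = utils.extract_ids_from_answer(ground_truth)         # ['123', '456']\nfirst_gt = utils.extract_first_id_from_answer(ground_truth)  # '123'\npred_ids = [utils.extract_id_from_generation(g) for g in generations]\n\nfor k in (1, 32):\n    print(\n        k,\n        utils.compute_pass_at_k(pred_ids, set(gt_ids), k),         # True, True\n        utils.compute_position1_pass_at_k(pred_ids, first_gt, k),  # True, True\n        utils.compute_recall_at_k(pred_ids, set(gt_ids), k),       # 0.5, 1.0\n    )\n"
  },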
  {
    "path": "benchmarks/benchmark/tasks/v1_0/recommendation/utils.py",
    "content": "\"\"\"\nRecommendation Task Utilities\n\nFunctions for SID extraction and recommendation metrics computation.\n\"\"\"\n\nfrom typing import Set, Dict, List, Any\n\n\ndef extract_ids_from_answer(answer: str) -> list[str]:\n    \"\"\"Extract all SIDs from answer field, preserving original order.\n\n    Returns a deduplicated list that keeps the first occurrence order.\n\n    >>> extract_ids_from_answer(\"<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>\")\n    ['123', '456']\n    \"\"\"\n    seen: set[str] = set()\n    correct_answers: list[str] = []\n    for part in answer.split('<|sid_begin|>'):\n        if '<|sid_end|>' in part:\n            sid = part.split('<|sid_end|>')[0].strip()\n            if sid and sid not in seen:\n                correct_answers.append(sid)\n                seen.add(sid)\n    return correct_answers\n\n\ndef extract_first_id_from_answer(answer: str) -> str:\n    \"\"\"\n    Extract the first SID from answer field\n    \n    Args:\n        answer: String containing multiple <|sid_begin|>...<|sid_end|> patterns\n    \n    Returns:\n        The first extracted SID, or empty string if none found\n    \n    Examples:\n        >>> extract_first_id_from_answer(\"<|sid_begin|>123<|sid_end|><|sid_begin|>456<|sid_end|>\")\n        '123'\n    \"\"\"\n    for part in answer.split('<|sid_begin|>'):\n        if '<|sid_end|>' in part:\n            sid = part.split('<|sid_end|>')[0].strip()\n            if sid:\n                return sid\n    return \"\"\n\n\ndef extract_id_from_generation(generation: str) -> str:\n    \"\"\"\n    Extract SID from model generation\n\n    The generation may contain:\n    - SID directly: \"123\"\n    - Wrapped in tags: \"<|sid_begin|>123<|sid_end|>\"\n    - With thinking: \"<think>...</think>\\\\n<|sid_begin|>123\" (two-stage generation)\n\n    Args:\n        generation: Model generation string\n\n    Returns:\n        Extracted SID, or the stripped generation if no pattern found\n\n    Examples:\n        >>> extract_id_from_generation(\"<|sid_begin|>123<|sid_end|>\")\n        '123'\n        >>> extract_id_from_generation(\"123\")\n        '123'\n        >>> extract_id_from_generation(\"<think>reasoning</think>\\\\n<|sid_begin|>123\")\n        '123'\n    \"\"\"\n    generation = generation.strip()\n\n    # If generation contains </think>, only process content after it\n    if '</think>' in generation:\n        generation = generation.split('</think>')[-1].strip()\n\n    # Try to extract from <|sid_begin|>...<|sid_end|> pattern\n    if '<|sid_begin|>' in generation:\n        for part in generation.split('<|sid_begin|>'):\n            if '<|sid_end|>' in part:\n                sid = part.split('<|sid_end|>')[0].strip()\n                if sid:\n                    return sid\n            elif part.strip():  # No end marker, take the content after begin marker\n                return part.strip()\n\n    # Otherwise, return the stripped generation\n    return generation\n\n\ndef compute_pass_at_k(\n    predicted_sids: List[str],\n    ground_truth_sids: Set[str],\n    k: int\n) -> bool:\n    \"\"\"\n    Compute Pass@k for a single sample\n\n    Pass@k definition:\n    - Take the first k candidate SIDs from predictions\n    - If any of these k SIDs appears in the ground truth SIDs, return True\n\n    Args:\n        predicted_sids: List of predicted SIDs (already extracted from generations)\n        ground_truth_sids: Set of ground truth SIDs\n        k: Number of top predictions to consider\n\n    Returns:\n        True if any of 
the top-k predictions match ground truth, False otherwise\n    \"\"\"\n    if not predicted_sids or not ground_truth_sids:\n        return False\n\n    # Take first k predicted SIDs\n    top_k_sids = predicted_sids[:k]\n\n    # Check if any matches ground truth\n    for sid in top_k_sids:\n        if sid in ground_truth_sids:\n            return True\n\n    return False\n\n\ndef compute_position1_pass_at_k(\n    predicted_sids: List[str],\n    first_ground_truth_sid: str,\n    k: int\n) -> bool:\n    \"\"\"\n    Compute Position1_Pass@k for a single sample\n\n    Position1_Pass@k definition:\n    - Take the first k candidate SIDs from predictions\n    - Only consider the first SID in the ground truth\n    - If any of these k SIDs matches the first ground truth, return True\n\n    Args:\n        predicted_sids: List of predicted SIDs (already extracted from generations)\n        first_ground_truth_sid: The first ground truth SID\n        k: Number of top predictions to consider\n\n    Returns:\n        True if any of the top-k predictions match the first ground truth, False otherwise\n    \"\"\"\n    if not predicted_sids or not first_ground_truth_sid:\n        return False\n\n    # Take first k predicted SIDs\n    top_k_sids = predicted_sids[:k]\n\n    # Check if any matches the first ground truth\n    for sid in top_k_sids:\n        if sid == first_ground_truth_sid:\n            return True\n\n    return False\n\n\ndef compute_recall_at_k(\n    predicted_sids: List[str],\n    ground_truth_sids: Set[str],\n    k: int\n) -> float:\n    \"\"\"\n    Compute Recall@k for a single sample\n\n    Recall@k definition:\n    - Take the first k candidate SIDs from predictions\n    - Count how many unique ground truth SIDs are hit by these k SIDs\n    - Return the ratio: hit_count / total_ground_truth_count\n\n    Args:\n        predicted_sids: List of predicted SIDs (already extracted from generations)\n        ground_truth_sids: Ground truth SIDs (a set, or any iterable of SIDs)\n        k: Number of top predictions to consider\n\n    Returns:\n        Recall@k score (0.0 to 1.0)\n\n    Examples:\n        >>> predicted_sids = [\"123\", \"456\", \"999\", \"888\"]\n        >>> ground_truth_sids = {\"123\", \"456\", \"789\"}\n        >>> compute_recall_at_k(predicted_sids, ground_truth_sids, k=2)\n        0.6667  # Hit 2 out of 3 ground truth SIDs\n        >>> compute_recall_at_k(predicted_sids, ground_truth_sids, k=4)\n        0.6667  # Still hit only 2, since 789 is not in top-4\n    \"\"\"\n    if not predicted_sids or not ground_truth_sids:\n        return 0.0\n\n    # Take first k predicted SIDs\n    top_k_sids = predicted_sids[:k]\n\n    # Convert to set and filter out empty strings\n    predicted_sids_set = set(sid for sid in top_k_sids if sid)\n\n    # Count how many ground truth SIDs are hit\n    # (set.intersection accepts any iterable, so callers may pass a list)\n    hit_count = len(predicted_sids_set.intersection(ground_truth_sids))\n\n    # Calculate recall\n    recall = hit_count / len(ground_truth_sids)\n\n    return recall\n\n\ndef get_unique_generations(\n    generations: List[str],\n    max_count: int,\n    logprobs: List[float] = None,\n    exclude_sids: Set[str] = None,\n    sources: List[str] = None\n):\n    \"\"\"\n    Get first N unique SIDs from generations, optionally sorted by logprobs\n\n    This function extracts unique SIDs, optionally sorting by logprobs first.\n    Useful for merging results from multiple generation runs.\n\n    Args:\n        generations: List of model generation strings (may contain <|sid_begin|>...<|sid_end|> or <think>...</think>)\n        
max_count: Maximum number of unique SIDs to return\n        logprobs: Optional list of log probabilities (same length as generations). If provided, sorts by logprobs (descending) before extracting unique SIDs\n        exclude_sids: Optional set of SIDs to exclude from results\n        sources: Optional list of source labels (same length as generations). If provided, returns tuple (sids, sources)\n\n    Returns:\n        List of unique SIDs (up to max_count), sorted by logprobs if provided, otherwise in generation order\n        If sources provided, returns tuple (List[str], List[str]) of (unique_sids, corresponding_sources)\n\n    Examples:\n        >>> gens = [\"<|sid_begin|>123<|sid_end|>\", \"456\", \"<think>...</think>\\\\n123\", \"789\", \"456\", \"999\"]\n        >>> get_unique_generations(gens, max_count=3)\n        ['123', '456', '789']\n        >>> get_unique_generations(gens, max_count=3, logprobs=[-0.5, -1.2, -0.8, -0.3, -1.5, -2.0])\n        ['789', '123', '456']  # Sorted by logprobs first\n        >>> get_unique_generations(gens, max_count=3, exclude_sids={'456', '789'})\n        ['123', '999']  # Excluded '456' and '789'\n        >>> get_unique_generations(gens, max_count=3, sources=['a', 'b', 'a', 'c', 'b', 'd'])\n        (['123', '456', '789'], ['a', 'b', 'c'])\n    \"\"\"\n    # Track sources if provided\n    track_sources = sources is not None and len(sources) == len(generations)\n\n    # If logprobs provided, sort generations by logprobs (descending)\n    if logprobs is not None and len(logprobs) == len(generations):\n        # Create tuples and sort by logprob (descending)\n        if track_sources:\n            gen_data = list(zip(generations, logprobs, sources))\n            gen_data.sort(key=lambda x: x[1], reverse=True)\n            sorted_generations = [gen for gen, _, _ in gen_data]\n            sorted_sources = [src for _, _, src in gen_data]\n        else:\n            gen_logprob_pairs = list(zip(generations, logprobs))\n            gen_logprob_pairs.sort(key=lambda x: x[1], reverse=True)\n            sorted_generations = [gen for gen, _ in gen_logprob_pairs]\n            sorted_sources = None\n    else:\n        sorted_generations = generations\n        sorted_sources = sources if track_sources else None\n\n    seen = set()\n    unique_sids = []\n    unique_sources = [] if track_sources else None\n    exclude = exclude_sids or set()\n\n    for i, gen in enumerate(sorted_generations):\n        # Skip empty strings\n        if not gen or not gen.strip():\n            continue\n\n        # Extract SID from generation text\n        sid = extract_id_from_generation(gen)\n\n        # Skip if SID is empty, already seen, or in exclude list\n        if not sid or sid in seen or sid in exclude:\n            continue\n\n        unique_sids.append(sid)\n        seen.add(sid)\n\n        if track_sources:\n            unique_sources.append(sorted_sources[i])\n\n        # Stop if we've collected enough unique SIDs\n        if len(unique_sids) >= max_count:\n            break\n\n    if track_sources:\n        return unique_sids, unique_sources\n    return unique_sids\n\n\ndef get_debug_info(\n    sample_id: str,\n    generations: List[str],\n    ground_truth: str,\n    pass_results: Dict[str, bool],\n    position1_pass_results: Dict[str, bool],\n    raw_prompt: str = \"\"\n) -> Dict[str, Any]:\n    \"\"\"\n    Prepare debug information for a sample\n\n    Args:\n        sample_id: Sample ID\n        generations: List of generated SIDs\n        ground_truth: Ground truth 
answer string\n        pass_results: Pass@k results for this sample\n        position1_pass_results: Position1_Pass@k results for this sample\n        raw_prompt: Raw prompt (optional)\n\n    Returns:\n        Debug information dictionary\n    \"\"\"\n    ground_truth_sids = extract_ids_from_answer(ground_truth)\n    first_ground_truth_sid = extract_first_id_from_answer(ground_truth)\n\n    # Extract top-k generated IDs\n    top_k_sids = [extract_id_from_generation(gen) for gen in generations[:10]]  # Show top-10\n\n    debug_item = {\n        \"sample_id\": sample_id,\n        \"ground_truth_sids\": list(ground_truth_sids),\n        \"first_ground_truth_sid\": first_ground_truth_sid,\n        \"top_10_generations\": top_k_sids,\n        \"pass_results\": pass_results,\n        \"position1_pass_results\": position1_pass_results,\n    }\n\n    if raw_prompt:\n        debug_item[\"raw_prompt_snippet\"] = raw_prompt[:200] + \"...\" if len(raw_prompt) > 200 else raw_prompt\n\n    return debug_item\n"
  },
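  {
    "path": "benchmarks/examples/merge_generations_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the shipped benchmark): merging candidates\nfrom two generation runs with get_unique_generations, ranked by toy\ncumulative logprobs.\n\"\"\"\n\nfrom benchmark.tasks.v1_0.recommendation.utils import get_unique_generations\n\nrun_a = [\"<|sid_begin|>123<|sid_end|>\", \"<|sid_begin|>456<|sid_end|>\"]\nrun_b = [\"<|sid_begin|>456<|sid_end|>\", \"<|sid_begin|>789<|sid_end|>\"]\n\n# Higher logprob ranks first; duplicate SIDs keep their best-ranked occurrence.\nmerged = get_unique_generations(\n    run_a + run_b,\n    max_count=3,\n    logprobs=[-0.2, -1.0, -0.4, -0.9],\n)\nprint(merged)  # ['123', '456', '789']\n"
  },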
  {
    "path": "benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py",
    "content": "\"\"\"\nRecommendation Task Utilities (PID-based)\n\nFunctions for PID extraction and recommendation metrics computation using PIDs.\n\"\"\"\n\nimport re\nimport json\nimport random\nfrom typing import Set, Dict, List, Any, Tuple, Optional\nfrom pathlib import Path\nfrom collections import Counter\n\n\n# Encoding constants for (code1, code2, code3) -> single int\n# Each code is in range [0, 8192], needs 13 bits\nCODE_MULTIPLIER_1 = 8192 * 8192  # 67108864\nCODE_MULTIPLIER_2 = 8192\n\n\ndef load_pid_mapping(mapping_path: str) -> Dict[int, List[Dict[str, int]]]:\n    \"\"\"\n    Load SID to PID mapping from JSON file\n\n    Args:\n        mapping_path: Path to the JSON file containing SID to PID mapping\n\n    Returns:\n        Dictionary mapping encoded SID (int) to list of PID info dictionaries\n        Format: {encoded_sid: [{\"pid\": pid1, \"count\": count1, \"count_after_downsample\": count2}, ...]}\n        PIDs are sorted by original count in descending order\n    \"\"\"\n    mapping_path = Path(mapping_path)\n    if not mapping_path.exists():\n        raise FileNotFoundError(f\"PID mapping file not found: {mapping_path}\")\n\n    with open(mapping_path, 'r') as f:\n        sid_to_pid_json = json.load(f)\n    \n    # Convert string keys back to integers\n    code_to_pid = {int(k): v for k, v in sid_to_pid_json.items()}\n\n    print(f\"[INFO] Loaded {len(code_to_pid)} SID to PID mappings from {mapping_path}\")\n    return code_to_pid\n\n\ndef encode_sid(c1: int, c2: int, c3: int) -> int:\n    \"\"\"\n    Encode (code1, code2, code3) into a single integer key\n\n    Args:\n        c1, c2, c3: SID code components\n\n    Returns:\n        Encoded integer key\n    \"\"\"\n    return c1 * CODE_MULTIPLIER_1 + c2 * CODE_MULTIPLIER_2 + c3\n\n\ndef extract_sid_codes_from_text(text: str) -> Optional[Tuple[int, int, int]]:\n    \"\"\"\n    Extract SID codes from text using regex pattern\n\n    Args:\n        text: Input text containing SID patterns like <|sid_begin|><s_a_1><s_b_2><s_c_3><|sid_end|>\n\n    Returns:\n        Tuple (a, b, c) representing extracted SID codes, or None if not found\n        Expects exactly one SID in the text\n    \"\"\"\n    pattern = r'<s_a_(\\d+)><s_b_(\\d+)><s_c_(\\d+)>'\n    matches = re.findall(pattern, text)\n    if not matches:\n        return None\n    if len(matches) > 1:\n        # Log warning but use first match\n        print(f\"[WARNING] Expected 1 SID code, got {len(matches)}, using first\")\n    return (int(matches[0][0]), int(matches[0][1]), int(matches[0][2]))\n\n\ndef _get_id_from_info(info: Dict[str, int]) -> int:\n    \"\"\"\n    Extract ID from info dict, supporting both 'pid' and 'iid' keys.\n\n    Args:\n        info: Dictionary containing either 'pid' or 'iid' key\n\n    Returns:\n        The ID value (int)\n    \"\"\"\n    return info.get(\"pid\", info.get(\"iid\", 0))\n\n\ndef apply_sid_to_pid_strategy(pid_info_list: List[Dict[str, int]], strategy: str) -> int:\n    \"\"\"\n    Apply strategy to select a single PID from a list\n\n    Args:\n        pid_info_list: List of PID info dictionaries\n                      Format: [{\"pid\": pid1, \"count\": count1, \"count_after_downsample\": count2}, ...]\n                      or [{\"iid\": iid1, \"count\": count1, \"count_after_downsample\": count2}, ...] 
for product\n        strategy: One of \"most_popular_originally\", \"most_popular_after_downsampling\", or \"random\"\n\n    Returns:\n        Selected PID/IID (int), or 0 if list is empty\n\n    Strategies:\n        - \"most_popular_originally\": Return the PID with highest original count (already sorted)\n        - \"most_popular_after_downsampling\": Return the PID with highest downsampled count (random if tie)\n        - \"random\": Randomly select one PID from the list\n    \"\"\"\n    if not pid_info_list:\n        return 0\n\n    if strategy == \"most_popular_originally\":\n        # Return the first PID/IID (highest original count, already sorted)\n        return _get_id_from_info(pid_info_list[0])\n    elif strategy == \"most_popular_after_downsampling\":\n        # Find max downsampled count\n        max_count = max(info[\"count_after_downsample\"] for info in pid_info_list)\n        # Get all PIDs/IIDs with max downsampled count\n        max_pids = [_get_id_from_info(info) for info in pid_info_list if info[\"count_after_downsample\"] == max_count]\n        # Randomly select one if there are ties\n        return random.choice(max_pids)\n    elif strategy == \"random\":\n        # Randomly select a PID/IID\n        return random.choice([_get_id_from_info(info) for info in pid_info_list])\n    else:\n        raise ValueError(f\"Unknown strategy: {strategy}. Must be 'most_popular_originally', 'most_popular_after_downsampling', or 'random'\")\n\n\ndef extract_ids_from_answer(answer: list[int]) -> list[int]:\n    \"\"\"Extract all PIDs from answer field, preserving original order.\n\n    Returns a deduplicated list that keeps the first occurrence order.\n\n    >>> extract_ids_from_answer([123, 456, 123, 789])\n    [123, 456, 789]\n    \"\"\"\n    seen: set[int] = set()\n    correct_answers: list[int] = []\n    for pid in answer:\n        if pid != 0 and pid not in seen:\n            correct_answers.append(pid)\n            seen.add(pid)\n    return correct_answers\n\n\ndef extract_first_id_from_answer(answer: List[int]) -> int:\n    \"\"\"\n    Extract the first PID from answer field\n\n    Examples:\n        >>> extract_first_id_from_answer([123, 456, 789])\n        123\n    \"\"\"\n    valid_pids = [pid for pid in answer if pid != 0]\n    return valid_pids[0] if valid_pids else 0\n\n\ndef extract_id_from_generation(\n    generation: str,\n    code_to_pid: Dict[int, List[Dict[str, int]]],\n    strategy: str = \"most_popular_originally\"\n) -> int:\n    \"\"\"\n    Extract PID from model generation\n\n    The generation may contain:\n    - SID wrapped in tags: \"<|sid_begin|><s_a_1><s_b_2><s_c_3><|sid_end|>\"\n    - With thinking: \"<think>...</think>\\\\n<|sid_begin|><s_a_1>...\"\n\n    Args:\n        generation: Model generation string (contains exactly one SID)\n        code_to_pid: Mapping dictionary {encoded_sid: [{\"pid\": pid, \"count\": ..., \"count_after_downsample\": ...}, ...]}\n        strategy: Strategy for selecting PID (\"most_popular_originally\", \"most_popular_after_downsampling\", or \"random\")\n\n    Returns:\n        Extracted PID (int), or 0 if not found\n\n    Examples:\n        >>> extract_id_from_generation(\"<|sid_begin|><s_a_1><s_b_2><s_c_3><|sid_end|>\", code_to_pid)\n        12345  # Assuming this SID maps to PID 12345\n    \"\"\"\n    generation = generation.strip()\n\n    # If generation contains </think>, only process content after it\n    if '</think>' in generation:\n        generation = generation.split('</think>')[-1].strip()\n\n    # Extract SID 
codes from the generation (should be exactly one)\n    sid_codes = extract_sid_codes_from_text(generation)\n\n    if sid_codes is None:\n        return 0\n\n    # Encode SID and look up PID list\n    encoded = encode_sid(*sid_codes)\n    pid_freq_list = code_to_pid.get(encoded, [])\n\n    # Apply strategy to select PID\n    return apply_sid_to_pid_strategy(pid_freq_list, strategy)\n\n\ndef compute_pass_at_k(\n    predicted_ids: List[int],\n    ground_truth_ids: Set[int],\n    k: int\n) -> bool:\n    \"\"\"\n    Compute Pass@k for a single sample using PIDs\n\n    Pass@k definition:\n    - Take the first k candidate PIDs from predictions\n    - If any of these k PIDs appears in the ground truth PIDs, return True\n\n    Args:\n        predicted_ids: List of predicted PIDs (already extracted from generations)\n        ground_truth_ids: Set of ground truth PIDs\n        k: Number of top predictions to consider\n\n    Returns:\n        True if any of the top-k predictions match ground truth, False otherwise\n    \"\"\"\n    if not predicted_ids or not ground_truth_ids:\n        return False\n\n    # Take first k predicted PIDs\n    top_k_ids = predicted_ids[:k]\n\n    # Check if any matches ground truth\n    for pid in top_k_ids:\n        if pid != 0 and pid in ground_truth_ids:\n            return True\n\n    return False\n\n\ndef compute_position1_pass_at_k(\n    predicted_ids: List[int],\n    first_ground_truth_id: int,\n    k: int\n) -> bool:\n    \"\"\"\n    Compute Position1_Pass@k for a single sample using PIDs\n\n    Position1_Pass@k definition:\n    - Take the first k candidate PIDs from predictions\n    - Only consider the first PID in the ground truth\n    - If any of these k PIDs matches the first ground truth, return True\n\n    Args:\n        predicted_ids: List of predicted PIDs (already extracted from generations)\n        first_ground_truth_id: The first ground truth PID\n        k: Number of top predictions to consider\n\n    Returns:\n        True if any of the top-k predictions match the first ground truth, False otherwise\n    \"\"\"\n    if not predicted_ids or not first_ground_truth_id or first_ground_truth_id == 0:\n        return False\n\n    # Take first k predicted PIDs\n    top_k_ids = predicted_ids[:k]\n\n    # Check if any matches the first ground truth\n    for pid in top_k_ids:\n        if pid != 0 and pid == first_ground_truth_id:\n            return True\n\n    return False\n\n\ndef compute_recall_at_k(\n    predicted_ids: List[int],\n    ground_truth_ids: Set[int],\n    k: int\n) -> float:\n    \"\"\"\n    Compute Recall@k for a single sample using PIDs\n\n    Recall@k definition:\n    - Take the first k candidate PIDs from predictions\n    - Count how many unique ground truth PIDs are hit by these k PIDs\n    - Return the ratio: hit_count / total_ground_truth_count\n\n    Args:\n        predicted_ids: List of predicted PIDs (already extracted from generations)\n        ground_truth_ids: Set of ground truth PIDs\n        k: Number of top predictions to consider\n\n    Returns:\n        Recall@k score (0.0 to 1.0)\n\n    Examples:\n        >>> predicted_ids = [123, 456, 999, 888]\n        >>> ground_truth_ids = {123, 456, 789}\n        >>> compute_recall_at_k(predicted_ids, ground_truth_ids, k=2)\n        0.6667  # Hit 2 out of 3 ground truth PIDs\n    \"\"\"\n    if not predicted_ids or not ground_truth_ids:\n        return 0.0\n\n    # Take first k predicted PIDs\n    top_k_ids = predicted_ids[:k]\n\n    # Convert to set and filter out zeros\n    
predicted_ids_set = set(pid for pid in top_k_ids if pid != 0)\n\n    # Count how many ground truth PIDs are hit\n    # (set.intersection accepts any iterable, so callers may pass a list)\n    hit_count = len(predicted_ids_set.intersection(ground_truth_ids))\n\n    # Calculate recall\n    recall = hit_count / len(ground_truth_ids)\n\n    return recall\n\n\ndef get_unique_generations(\n    generations: List[str],\n    max_count: int,\n    code_to_pid: Dict[int, List[Dict[str, int]]],\n    strategy: str = \"most_popular_originally\",\n    logprobs: List[float] = None,\n    exclude_ids: Set[int] = None,\n    sources: List[str] = None\n):\n    \"\"\"\n    Get first N unique PIDs from generations, optionally sorted by logprobs\n\n    This function extracts unique PIDs, optionally sorting by logprobs first.\n    Useful for merging results from multiple generation runs.\n\n    Args:\n        generations: List of model generation strings containing SID patterns\n        max_count: Maximum number of unique PIDs to return\n        code_to_pid: Mapping dictionary {encoded_sid: [{\"pid\": pid, \"count\": ..., \"count_after_downsample\": ...}, ...]}\n        strategy: Strategy for selecting PID (\"most_popular_originally\", \"most_popular_after_downsampling\", or \"random\")\n        logprobs: Optional list of log probabilities (same length as generations)\n        exclude_ids: Optional set of PIDs to exclude from results\n        sources: Optional list of source labels (same length as generations)\n\n    Returns:\n        List of unique PIDs (up to max_count), sorted by logprobs if provided\n        If sources provided, returns tuple (List[int], List[str]) of (unique_pids, corresponding_sources)\n    \"\"\"\n    # Track sources if provided\n    track_sources = sources is not None and len(sources) == len(generations)\n\n    # If logprobs provided, sort generations by logprobs (descending)\n    if logprobs is not None and len(logprobs) == len(generations):\n        # Create tuples and sort by logprob (descending)\n        if track_sources:\n            gen_data = list(zip(generations, logprobs, sources))\n            gen_data.sort(key=lambda x: x[1], reverse=True)\n            sorted_generations = [gen for gen, _, _ in gen_data]\n            sorted_sources = [src for _, _, src in gen_data]\n        else:\n            gen_logprob_pairs = list(zip(generations, logprobs))\n            gen_logprob_pairs.sort(key=lambda x: x[1], reverse=True)\n            sorted_generations = [gen for gen, _ in gen_logprob_pairs]\n            sorted_sources = None\n    else:\n        sorted_generations = generations\n        sorted_sources = sources if track_sources else None\n\n    seen = set()\n    unique_pids = []\n    unique_sources = [] if track_sources else None\n    exclude = exclude_ids or set()\n\n    for i, gen in enumerate(sorted_generations):\n        # Skip empty strings\n        if not gen or not gen.strip():\n            continue\n\n        # Extract PID from generation text\n        pid = extract_id_from_generation(gen, code_to_pid, strategy)\n\n        # Skip if PID is 0 (not found), already seen, or in exclude list\n        if pid == 0 or pid in seen or pid in exclude:\n            continue\n\n        unique_pids.append(pid)\n        seen.add(pid)\n\n        if track_sources:\n            unique_sources.append(sorted_sources[i])\n\n        # Stop if we've collected enough unique PIDs\n        if len(unique_pids) >= max_count:\n            break\n\n    if track_sources:\n        return unique_pids, unique_sources\n    return unique_pids\n\n\ndef get_debug_info(\n    
sample_id: str,\n    generations: List[str],\n    ground_truth: List[int],\n    pass_results: Dict[str, bool],\n    position1_pass_results: Dict[str, bool],\n    code_to_pid: Dict[int, List[Dict[str, int]]],\n    strategy: str = \"most_popular_originally\",\n    raw_prompt: str = \"\"\n) -> Dict[str, Any]:\n    \"\"\"\n    Prepare debug information for a sample (PID-based)\n\n    Args:\n        sample_id: Sample ID\n        generations: List of model generation strings containing SID patterns\n        ground_truth: Ground truth PID list\n        pass_results: Pass@k results for this sample\n        position1_pass_results: Position1_Pass@k results for this sample\n        code_to_pid: Mapping dictionary {encoded_sid: [{\"pid\": pid, \"count\": ..., \"count_after_downsample\": ...}, ...]}\n        strategy: Strategy for selecting PID (\"most_popular_originally\", \"most_popular_after_downsampling\", or \"random\")\n        raw_prompt: Raw prompt (optional)\n\n    Returns:\n        Debug information dictionary\n    \"\"\"\n    ground_truth_ids = extract_ids_from_answer(ground_truth)\n    first_ground_truth_id = extract_first_id_from_answer(ground_truth)\n\n    # Extract top-k generated PIDs\n    top_k_ids = [extract_id_from_generation(gen, code_to_pid, strategy) for gen in generations[:10]]\n\n    debug_item = {\n        \"sample_id\": sample_id,\n        \"ground_truth_pids\": list(ground_truth_ids),\n        \"first_ground_truth_pid\": first_ground_truth_id,\n        \"top_10_generations\": top_k_ids,\n        \"pass_results\": pass_results,\n        \"position1_pass_results\": position1_pass_results,\n    }\n\n    if raw_prompt:\n        debug_item[\"raw_prompt_snippet\"] = raw_prompt[:200] + \"...\" if len(raw_prompt) > 200 else raw_prompt\n\n    return debug_item\n"
  },
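  {
    "path": "benchmarks/examples/sid_to_pid_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the shipped benchmark) of the SID->PID\nlookup in utils_by_pid: encode the three SID codes, look up candidate PIDs,\nthen apply a selection strategy. The mapping below is fabricated.\n\"\"\"\n\nfrom benchmark.tasks.v1_0.recommendation.utils_by_pid import (\n    apply_sid_to_pid_strategy,\n    encode_sid,\n    extract_id_from_generation,\n)\n\n# Toy mapping: one SID onto which two PIDs were quantized.\ncode_to_pid = {\n    encode_sid(1, 2, 3): [\n        {\"pid\": 111, \"count\": 90, \"count_after_downsample\": 10},\n        {\"pid\": 222, \"count\": 50, \"count_after_downsample\": 25},\n    ]\n}\n\npid_list = code_to_pid[encode_sid(1, 2, 3)]\nprint(apply_sid_to_pid_strategy(pid_list, \"most_popular_originally\"))          # 111\nprint(apply_sid_to_pid_strategy(pid_list, \"most_popular_after_downsampling\"))  # 222\nprint(extract_id_from_generation(\"<|sid_begin|><s_a_1><s_b_2><s_c_3><|sid_end|>\", code_to_pid))  # 111\n"
  },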
  {
    "path": "benchmarks/benchmark/tasks/v1_0/registry.py",
    "content": "\"\"\"\nTask Registry - Unified Task Registration\n\nThis module consolidates:\n- loader_factory.py\n- evaluator_factory.py  \n- tasks.py\n\nPurpose: Each task is defined in ONE place only, avoiding duplication across multiple files.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Type, Dict, Any, Optional\n\n# ===== Import all configs =====\nfrom .label_pred.config import LABEL_PRED_CONFIG\nfrom .item_understand.config import ITEM_UNDERSTAND_CONFIG\nfrom .rec_reason.config import REC_REASON_CONFIG\nfrom .recommendation.config import (\n    LABEL_COND_CONFIG,\n    VIDEO_CONFIG,\n    PRODUCT_CONFIG,\n    AD_CONFIG,\n    INTERACTIVE_CONFIG,\n)\n\n# ===== Import base loader =====\nfrom .base_loader import BaseLoader\n\n# ===== Import all evaluators =====\nfrom .label_pred.evaluator import LabelPredEvaluator\nfrom .item_understand.evaluator import ItemUnderstandEvaluator\nfrom .rec_reason.evaluator import RecoReasonEvaluator\nfrom .recommendation.evaluator import RecommendationEvaluator\n\n\n@dataclass\nclass TaskRegistration:\n    \"\"\"Task registration information\"\"\"\n    name: str\n    config: Dict[str, Any]\n    evaluator_class: Type\n    category: str  # \"general\", \"recommendation\", \"caption\"\n\n\n# ========================================\n# Unified Task Registry\n# ========================================\nTASK_REGISTRY: Dict[str, TaskRegistration] = {\n    \"label_cond\": TaskRegistration(\n        name=\"label_cond\",\n        config=LABEL_COND_CONFIG,\n        evaluator_class=RecommendationEvaluator,\n        category=\"recommendation\"\n    ),\n    \"video\": TaskRegistration(\n        name=\"video\",\n        config=VIDEO_CONFIG,\n        evaluator_class=RecommendationEvaluator,\n        category=\"recommendation\"\n    ),\n    \"product\": TaskRegistration(\n        name=\"product\",\n        config=PRODUCT_CONFIG,\n        evaluator_class=RecommendationEvaluator,\n        category=\"recommendation\"\n    ),\n    \"ad\": TaskRegistration(\n        name=\"ad\",\n        config=AD_CONFIG,\n        evaluator_class=RecommendationEvaluator,\n        category=\"recommendation\"\n    ),\n    \"interactive\": TaskRegistration(\n        name=\"interactive\",\n        config=INTERACTIVE_CONFIG,\n        evaluator_class=RecommendationEvaluator,\n        category=\"recommendation\"\n    ),\n    \"label_pred\": TaskRegistration(\n        name=\"label_pred\",\n        config=LABEL_PRED_CONFIG,\n        evaluator_class=LabelPredEvaluator,\n        category=\"recommendation\"\n    ),\n    \"item_understand\": TaskRegistration(\n        name=\"item_understand\",\n        config=ITEM_UNDERSTAND_CONFIG,\n        evaluator_class=ItemUnderstandEvaluator,\n        category=\"caption\"\n    ),\n    \"rec_reason\": TaskRegistration(\n        name=\"rec_reason\",\n        config=REC_REASON_CONFIG,\n        evaluator_class=RecoReasonEvaluator,\n        category=\"caption\"\n    ),\n}\n\n\n# ========================================\n# Factory Functions\n# ========================================\n\ndef get_loader(task_name: str, data_dir: str, tokenizer: Optional[Any] = None, enable_thinking: Optional[bool] = None):\n    \"\"\"\n    Get loader instance for a task\n\n    Replaces loader_factory.get_loader()\n\n    Args:\n        task_name: Name of the task\n        benchmark_version: Version of the benchmark (used for task selection, not passed to loader)\n        data_dir: Data directory path\n        tokenizer: Tokenizer instance (optional, required for 
message-based formats)\n        enable_thinking: Enable thinking mode (optional, overrides task config if set)\n\n    Returns:\n        Loader instance\n\n    Raises:\n        ValueError: If task_name is not registered\n    \"\"\"\n    if task_name not in TASK_REGISTRY:\n        available_tasks = \", \".join(TASK_REGISTRY.keys())\n        raise ValueError(\n            f\"Unknown task: {task_name}. \"\n            f\"Available tasks: {available_tasks}\"\n        )\n\n    reg = TASK_REGISTRY[task_name]\n\n    # Create loader instance with aligned parameters\n    return BaseLoader(\n        task_config=reg.config,\n        data_dir=data_dir,\n        tokenizer=tokenizer,\n        enable_thinking=enable_thinking\n    )\n\n\ndef get_evaluator(task_name: str):\n    \"\"\"\n    Get evaluator class for a task\n    \n    Replaces evaluator_factory.get_evaluator()\n    \n    Args:\n        task_name: Name of the task\n        \n    Returns:\n        Evaluator class (not instance)\n        \n    Raises:\n        ValueError: If task_name is not registered\n    \"\"\"\n    if task_name not in TASK_REGISTRY:\n        available_tasks = \", \".join(TASK_REGISTRY.keys())\n        raise ValueError(\n            f\"Unknown task: {task_name}. \"\n            f\"Available tasks: {available_tasks}\"\n        )\n    \n    return TASK_REGISTRY[task_name].evaluator_class\n\n\ndef get_task_config(task_name: str) -> Dict[str, Any]:\n    \"\"\"\n    Get task configuration\n    \n    Args:\n        task_name: Name of the task\n        \n    Returns:\n        Task configuration dictionary\n        \n    Raises:\n        ValueError: If task_name is not registered\n    \"\"\"\n    if task_name not in TASK_REGISTRY:\n        available_tasks = \", \".join(TASK_REGISTRY.keys())\n        raise ValueError(\n            f\"Unknown task: {task_name}. \"\n            f\"Available tasks: {available_tasks}\"\n        )\n    \n    return TASK_REGISTRY[task_name].config\n\n\ndef get_all_tasks() -> list:\n    \"\"\"\n    Get list of all registered task names\n    \n    Returns:\n        List of task names\n    \"\"\"\n    return list(TASK_REGISTRY.keys())\n\n\ndef get_tasks_by_category(category: str) -> list:\n    \"\"\"\n    Get tasks filtered by category\n    \n    Args:\n        category: Category name (\"general\", \"recommendation\", \"caption\")\n        \n    Returns:\n        List of task names in the specified category\n    \"\"\"\n    return [\n        name for name, reg in TASK_REGISTRY.items()\n        if reg.category == category\n    ]\n\n# ========================================\n# Backward Compatibility\n# ========================================\n\n# Replaces tasks.py - TaskTable\nTaskTable = {name: reg.config for name, reg in TASK_REGISTRY.items()}"
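\n\n\n# ========================================\n# Example Usage (illustrative)\n# ========================================\n# A minimal sketch of how this registry is typically consumed; the \"./data\"\n# path below is a placeholder, and \"video\" is one of the registered tasks.\n#\n#   loader = get_loader(\"video\", data_dir=\"./data\")\n#   evaluator_cls = get_evaluator(\"video\")\n#   rec_tasks = get_tasks_by_category(\"recommendation\")\n"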
  },
  {
    "path": "benchmarks/eval_script.sh",
    "content": "#!/bin/bash\n\n# Set common variables\nMODEL_PATH=$1\nVERSION=\"${VERSION:-v1.0}\"\nBASE_OUTPUT_DIR=\"${BENCHMARK_BASE_DIR}/results/${VERSION}/results_${2}\"\nBASE_LOG_NAME=\"${BENCHMARK_BASE_DIR}/auto_eval_logs/${VERSION}/$2\"\nENABLE_THINKING=$3\n\n# Read configuration from environment variables (set by eval_script.py)\n# Fallback to hardcoded paths if not set\nBENCHMARK_BASE_DIR=\"${BENCHMARK_BASE_DIR:-/home/user/benchmark}\"\nDATA_VERSION=\"${DATA_VERSION:-v1.0}\"\n\nBENCHMARK_DATA_DIR=\"${BENCHMARK_DATA_DIR:-${BENCHMARK_BASE_DIR}/data_${DATA_VERSION}}\"\nDATA_DIR=\"$BENCHMARK_DATA_DIR\"\n\n# Create output directory and log directory\nmkdir -p \"$(dirname \"${BASE_LOG_NAME}\")\"\nmkdir -p \"$BASE_OUTPUT_DIR\"\n\n# Write debug info to log file\n{\n    echo \"========== Task Configuration ==========\"\n    echo \"DATA_DIR: $DATA_DIR\"\n    echo \"Enable Thinking: $ENABLE_THINKING\"\n    echo \"========================================\"\n} >> \"${BASE_LOG_NAME}.log\"\n\n# Build thinking arguments\nTHINKING_ARGS=\"\"\nif [ \"$ENABLE_THINKING\" = \"true\" ]; then\n    THINKING_ARGS=\"--enable_thinking\"\nfi\n\necho \"Thinking args: $THINKING_ARGS\"\n\necho \"Running all tasks\"\n\n# Task: rec_reason\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types rec_reason \\\n    --gpu_memory_utilization 0.9 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 5 \\\n    --overwrite \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: item_understand\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types item_understand \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 250 \\\n    --overwrite \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: ad\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types ad \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 1875 \\\n    --overwrite \\\n    --num_beams 32 --num_return_sequences 32 --num_return_thinking_sequences 1 \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: product\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types product \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 1875 \\\n    --overwrite \\\n    --num_beams 32 --num_return_sequences 32 --num_return_thinking_sequences 1 \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: label_cond\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types label_cond \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 1875 \\\n    --overwrite \\\n    --num_beams 32 --num_return_sequences 32 --num_return_thinking_sequences 1 \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: video\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types video \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    
--output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 1875 \\\n    --overwrite \\\n    --num_beams 32 --num_return_sequences 32 --num_return_thinking_sequences 1 \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: interactive\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types interactive \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 250 \\\n    --overwrite \\\n    --num_beams 32 --num_return_sequences 32 --num_return_thinking_sequences 1 \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\n# Task: label_pred\npython3 -u scripts/ray-vllm/evaluate.py \\\n    --task_types label_pred \\\n    --gpu_memory_utilization 0.8 \\\n    --model_path \"$MODEL_PATH\" \\\n    --data_dir \"$DATA_DIR\" \\\n    --output_dir \"${BASE_OUTPUT_DIR}\" \\\n    --dtype bfloat16 \\\n    --worker_batch_size 3200 \\\n    --max_logprobs 10000 \\\n    --overwrite \\\n    $THINKING_ARGS >> \"${BASE_LOG_NAME}.log\" 2>&1\n\necho \"All tasks completed successfully\"\n"
  },
  {
    "path": "benchmarks/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=45\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"onerec-benchamrk\"\nversion = \"0.1.0\"\ndescription = \"OneRec Benchmark\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = {text = \"Apache License 2.0\"}\n\n# Core dependencies - pinned to specific versions from pip list\ndependencies = [\n    \"torch==2.5.1\",\n    \"transformers==4.52.0\",\n    \"ray==2.43.0\",\n    \"vllm==0.7.3\",\n    \"gradio==4.44.1\",\n    \"datasets==3.6.0\",\n    \"safetensors==0.5.3\",\n    \"numpy==1.26.4\",\n    \"peft==0.15.2\",\n    \"accelerate==1.8.1\",\n    \"bert_score\",\n    \"pyfiglet\",\n    \"pylatexenc\",\n    \"scikit-learn\",\n    \"vertexai\",\n    \"openai\",\n    \"anthropic\"\n]\n\n[tool.setuptools]\npackages = [\"benchmark\", \"api\", \"scripts\"]\n\n[tool.setuptools.package-data]\n\"*\" = [\"*.json\", \"*.yaml\", \"*.yml\"]\n"
  },
  {
    "path": "benchmarks/requirements.txt",
    "content": "absl-py==2.1.0\naccelerate==1.8.1\naiodns==3.6.1\naiohappyeyeballs==2.6.1\naiohttp==3.11.14\naiohttp-cors==0.8.0\naiosignal==1.3.2\nairportsdata==20250224\nannotated-types==0.7.0\nanthropic==0.75.0\nantlr4-python3-runtime==4.13.2\nanyio==4.9.0\nAPScheduler==3.11.1\nastor==0.8.1\nasttokens==3.0.0\nasync-timeout==5.0.1\nattrs==25.3.0\nav==14.0.1\nbert-score==0.3.13\nblake3==1.0.4\nblinker==1.4\nboto3==1.35.97\nbotocore==1.35.98\nbraceexpand==0.1.7\nbuild==1.2.2.post1\ncachetools==4.2.4\ncchardet==2.1.7\ncertifi==2021.10.8\ncffi==2.0.0\ncharset-normalizer==2.0.12\ncityhash==0.2.4.post11\nclick==8.1.8\ncloudpickle==3.1.1\ncolorful==0.5.6\ncompressed-tensors==0.9.1\ncontourpy==1.3.1\ncramjam==2.10.0\ncryptography==3.4.8\ncupy-cuda12x==13.4.1\ncycler==0.12.1\ndatasets==3.6.0\ndecorator==5.1.1\ndecord==0.6.0\ndeepspeed==0.16.2\ndepyf==0.18.0\ndill==0.3.8\ndiskcache==5.6.3\ndistlib==0.3.9\ndistro==1.7.0\ndnspython==2.7.0\ndocstring_parser==0.17.0\neinops==0.8.0\nemail_validator==2.2.0\nexceptiongroup==1.2.2\nexecuting==2.1.0\nfastapi==0.115.11\nfastapi-cli==0.0.7\nfastparquet==2024.2.0\nfastrlock==0.8.3\nfilelock==3.18.0\nfonttools==4.55.3\nfrozenlist==1.5.0\nfsspec==2024.2.0\nfunc-timeout==4.3.5\ngguf==0.10.0\nh11==0.14.0\nhf-xet==1.2.1\nhiredis==2.4.0\nhjson==3.1.0\nhttpcore==1.0.7\nhttplib2==0.20.2\nhttptools==0.6.4\nhttpx==0.28.1\nhuggingface-hub==0.36.0\nidna==3.3\nimportlib-metadata==4.6.4\niniconfig==2.1.0\ninteregular==0.3.3\nipython==8.30.0\njedi==0.19.2\njeepney==0.7.1\nJinja2==3.1.6\njiter==0.9.0\njmespath==1.0.1\njoblib==1.5.2\njsonschema==4.23.0\njsonschema-specifications==2024.10.1\nkazoo==2.10.0\nkeyring==23.5.0\nkiwisolver==1.4.7\nlark==1.2.2\nlatex2sympy2_extended==1.10.2\nlaunchpadlib==1.10.16\nlazr.restfulclient==0.14.4\nlazr.uri==1.0.6\nllvmlite==0.43.0\nlm-format-enforcer==0.10.11\nlxml==4.9.4\nlz4==3.1.10\nMarkdown==3.7\nmarkdown-it-py==3.0.0\nMarkupSafe==3.0.2\nmath-verify==0.8.0\nmatplotlib==3.10.0\nmatplotlib-inline==0.1.7\nmdurl==0.1.2\nmistral_common==1.5.4\nmore-itertools==8.10.0\nmpi4py==4.1.1\nmpmath==1.3.0\nmsgpack==1.1.0\nmsgspec==0.19.0\nmultidict==6.2.0\nmultiprocess==0.70.16\nnest-asyncio==1.6.0\nnetworkx==3.2.1\nninja==1.11.1.3\nnltk==3.9.2\nnumba==0.60.0\nnumpy==1.26.4\nnvidia-cublas-cu11==11.11.3.6\nnvidia-cuda-cupti-cu11==11.8.87\nnvidia-cuda-nvrtc-cu11==11.8.89\nnvidia-cuda-runtime-cu11==11.8.89\nnvidia-cudnn-cu11==9.1.0.70\nnvidia-cufft-cu11==10.9.0.58\nnvidia-curand-cu11==10.3.0.86\nnvidia-cusolver-cu11==11.4.1.48\nnvidia-cusparse-cu11==11.7.5.86\nnvidia-ml-py==13.590.44\nnvidia-nccl-cu11==2.21.5\nnvidia-nvtx-cu11==11.8.86\noauthlib==3.2.0\nopenai==1.67.0\nopencensus==0.11.4\nopencensus-context==0.1.3\nopencv-python-headless==4.11.0.86\noutlines==0.1.11\noutlines_core==0.1.26\npackaging==24.2\npandas==2.2.3\nparso==0.8.4\npartial-json-parser==0.2.1.1.post5\npeft==0.15.2\npexpect==4.9.0\npillow==11.0.0\nplatformdirs==4.3.7\npluggy==1.5.0\nprettytable==2.5.0\nprometheus-fastapi-instrumentator==7.1.0\nprometheus_client==0.21.1\nprompt_toolkit==3.0.48\npropcache==0.3.0\nproto-plus==1.26.1\npsutil==7.1.3\nptyprocess==0.7.0\npure_eval==0.2.3\npy-cpuinfo==9.0.0\npy-spy==0.4.0\npyarrow==18.1.0\npyasn1==0.6.1\npyasn1_modules==0.4.1\npybind11==2.13.6\npycares==4.11.0\npycountry==24.6.1\npycparser==2.23\npycryptodome==3.23.0\npydantic==2.10.4\npydantic_core==2.27.2\npyfiglet==1.0.4\nPygments==2.18.0\nPyJWT==2.3.0\npylatexenc==2.10\npynvml==13.0.1\npyparsing==2.4.7\npyproject_hooks==1.2.0\npysmhasher==0.2.5\npytest==8.3.5\npython-dateutil==2.9.0.post0\n
python-dotenv==1.0.1\npython-multipart==0.0.20\npython-snappy==0.6.1\npytz==2021.3\npytz-deprecation-shim==0.1.0.post0\nPyYAML==6.0.2\npyzmq==26.3.0\nqwen-vl-utils==0.0.8\nray==2.43.0\nredis==4.6.0\nreferencing==0.36.2\nregex==2024.11.6\nrich==13.9.4\nrich-toolkit==0.13.2\nrouge-score==0.1.2\nrpds-py==0.23.1\nrsa==4.9\ns3transfer==0.10.4\nsafetensors==0.5.3\nscikit-learn==1.7.2\nscipy==1.15.3\nSecretStorage==3.3.1\nsentencepiece==0.2.0\nsetuptools-scm==9.2.2\nshapely==2.1.2\nshellingham==1.5.4\nsix==1.16.0\nsmart-open==7.1.0\nsniffio==1.3.1\nsqlparse==0.4.4\nssh-import-id==5.11\nstack-data==0.6.3\nstarlette==0.46.1\nsympy==1.13.1\ntensorboard==2.18.0\ntensorboard-data-server==0.7.2\nthreadpoolctl==3.6.0\ntiktoken==0.9.0\ntimm==1.0.15\ntokenizers==0.21.1\ntomli==2.2.1\ntorchao==0.11.0\ntorchdata==0.10.1\ntornado==6.5.4\ntqdm==4.67.1\ntraitlets==5.14.3\ntransformers==4.52.0\ntriton==3.1.0\ntyper==0.15.2\ntyping_extensions==4.12.2\ntzdata==2024.2\ntzlocal==4.3.1\nunpaddedbase64==2.1.0\nurllib3==1.26.8\nuvicorn==0.34.0\nuvloop==0.21.0\nvertexai==1.71.1\nvirtualenv==20.29.3\nwadllib==1.3.6\nwatchfiles==1.0.4\nwcwidth==0.2.13\nwebdataset==0.2.100\nwebsockets==15.0.1\nWerkzeug==3.1.3\nwrapt==1.17.2\nxformers==0.0.28.post3\nxgrammar==0.1.11\nxmltodict==0.12.0\nxxhash==3.6.0\nyarl==1.18.3\nzipp==1.0.0\n"
  },
  {
    "path": "benchmarks/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/scripts/eval_dev_results.py",
    "content": "import argparse\n\nfrom benchmark import Benchmark\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--output_dir\",\n        type=str, required=True,\n        help=\"The directory where the generation results are saved.\"\n    )\n    parser.add_argument(\n        \"--data_dir\",\n        type=str, \n        default=None\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Whether to overwrite existing metrics and recompute from scratch\"\n    )\n    parser.add_argument(\n        \"--task_types\",\n        type=str,\n        nargs='+',\n        default=None,\n        help=\"Task name list (e.g., item_understand rec_reason). If not specified, all tasks will be evaluated.\"\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = get_args()\n    eval_results_path = f\"{args.output_dir}/eval_results.json\"\n    Benchmark.evaluate_dev(\n        generation_results_dir=args.output_dir,\n        output_path=eval_results_path,\n        data_dir=args.data_dir,\n        overwrite=args.overwrite,\n        task_types=args.task_types\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/scripts/init_ray.sh",
    "content": "#!/bin/bash\n# Single Node Ray Initialization Script\n# Usage: bash init_ray.sh <HEAD_NODE_IP> <PORT> <RANK>\n#   HEAD_NODE_IP: IP address of the head node\n#   PORT: Ray port (default: 6379)\n#   RANK: Node rank (0 for head, >0 for workers)\n\nset -e\n\n# Parse arguments\nHEAD_NODE_IP=${1:-\"127.0.0.1\"}\nPORT=${2:-6379}\nRANK=${3:-0}\n\n# Configuration\nNUM_CPUS=${NUM_CPUS:-\"\"}\nNUM_GPUS=${NUM_GPUS:-\"\"}\nOBJECT_STORE_MEMORY=${OBJECT_STORE_MEMORY:-\"\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"benchmark\"}\n\n# Colors\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $(hostname): $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $(hostname): $1\"\n}\n\n# Activate conda environment\nif [ -f \"/root/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"/root/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/miniconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/miniconda3/etc/profile.d/conda.sh\"\nfi\n\nif command -v conda &> /dev/null; then\n    conda activate ${CONDA_ENV_NAME} 2>/dev/null || log_warn \"Could not activate conda env: ${CONDA_ENV_NAME}\"\nfi\n\n# Build ray start command options\nRAY_OPTS=\"\"\nif [ -n \"${NUM_CPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-cpus=${NUM_CPUS}\"\nfi\nif [ -n \"${NUM_GPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-gpus=${NUM_GPUS}\"\nfi\nif [ -n \"${OBJECT_STORE_MEMORY}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --object-store-memory=${OBJECT_STORE_MEMORY}\"\nfi\n\n# Stop existing Ray instance\nray stop --force 2>/dev/null || true\nsleep 2\n\n# Start Ray\nif [ \"${RANK}\" -eq 0 ]; then\n    log_info \"Starting Ray HEAD node on port ${PORT}...\"\n    ray start --head --port=${PORT} ${RAY_OPTS}\nelse\n    log_info \"Starting Ray WORKER node, connecting to ${HEAD_NODE_IP}:${PORT}...\"\n    ray start --address=${HEAD_NODE_IP}:${PORT} ${RAY_OPTS}\nfi\n\nsleep 3\n\n# Check status\nlog_info \"Ray node started. Checking status...\"\nray status\n"
  },
  {
    "path": "benchmarks/scripts/init_ray_cluster.sh",
    "content": "#!/bin/bash\n# Multi-node Ray Cluster Initialization Script\n# Usage: bash init_ray_cluster.sh [--stop]\n#   --stop: Stop Ray on all nodes instead of starting\n\nset -e\n\nSCRIPT_DIR=$(cd $(dirname $0); pwd)\nPROJECT_DIR=${SCRIPT_DIR}\n\n# Configuration\nPORT=${RAY_PORT:-6379}\nHOSTFILE=${HOSTFILE:-\"/etc/mpi/hostfile\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"benchmark\"}\nLOG_DIR=\"${PROJECT_DIR}/logs/ray\"\n\n# Colors\nRED='\\033[0;31m'\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $1\"\n}\n\nlog_error() {\n    echo -e \"${RED}[ERROR]${NC} $1\"\n}\n\n# Generate conda initialization command that works with both anaconda and miniconda\nget_conda_init_cmd() {\n    cat << 'EOF'\nfor conda_sh in /root/miniconda3/etc/profile.d/conda.sh \\\n                /root/anaconda3/etc/profile.d/conda.sh \\\n                $HOME/miniconda3/etc/profile.d/conda.sh \\\n                $HOME/anaconda3/etc/profile.d/conda.sh \\\n                /opt/conda/etc/profile.d/conda.sh; do\n    [ -f \"$conda_sh\" ] && source \"$conda_sh\" && break\ndone\nEOF\n}\n\n# Function to stop Ray on all nodes\nstop_cluster() {\n    log_info \"Stopping Ray on all nodes...\"\n\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_warn \"Hostfile not found, stopping local Ray only\"\n        ray stop --force 2>/dev/null || true\n        return\n    fi\n\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    for node in ${ALL_NODES}; do\n        log_info \"Stopping Ray on ${node}...\"\n        ssh -n ${node} \"$(get_conda_init_cmd) && conda activate ${CONDA_ENV_NAME} && ray stop --force\" 2>/dev/null &\n    done\n\n    wait\n    log_info \"Ray stopped on all nodes\"\n}\n\n# Function to start Ray cluster\nstart_cluster() {\n    # Check hostfile\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_error \"Hostfile not found: ${HOSTFILE}\"\n        log_info \"Please create a hostfile with one IP per line\"\n        log_info \"Example:\"\n        echo \"  192.168.1.100\"\n        echo \"  192.168.1.101\"\n        echo \"  192.168.1.102\"\n        exit 1\n    fi\n\n    # Get head node (first line)\n    HEAD_NODE=$(awk 'NR==1 {print $1}' ${HOSTFILE})\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    log_info \"Head node: ${HEAD_NODE}\"\n    log_info \"Ray port: ${PORT}\"\n    log_info \"Conda env: ${CONDA_ENV_NAME}\"\n    echo \"\"\n    log_info \"Nodes in cluster:\"\n    echo \"${ALL_NODES}\"\n    echo \"\"\n\n    # Create log directory\n    mkdir -p \"${LOG_DIR}\"\n\n    # Stop existing Ray instances first\n    log_info \"Stopping any existing Ray instances...\"\n    stop_cluster\n    sleep 3\n\n    # Start head node first (synchronously)\n    log_info \"Starting Ray HEAD on ${HEAD_NODE}...\"\n    ssh -n ${HEAD_NODE} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} 0\" \\\n        > \"${LOG_DIR}/ray_${HEAD_NODE}.log\" 2>&1\n\n    if [ $? -ne 0 ]; then\n        log_error \"Failed to start Ray HEAD. 
Check ${LOG_DIR}/ray_${HEAD_NODE}.log\"\n        exit 1\n    fi\n    log_info \"Ray HEAD started successfully\"\n\n    # Wait for head to be ready\n    sleep 5\n\n    # Start worker nodes (asynchronously)\n    rank=1\n    for node in ${ALL_NODES}; do\n        if [ \"${node}\" == \"${HEAD_NODE}\" ]; then\n            continue\n        fi\n\n        log_info \"Starting Ray WORKER on ${node} (rank ${rank})...\"\n        ssh -n ${node} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} ${rank}\" \\\n            > \"${LOG_DIR}/ray_${node}.log\" 2>&1 &\n        rank=$((rank + 1))\n    done\n\n    # Wait for all workers\n    log_info \"Waiting for all workers to join...\"\n    wait\n    sleep 3\n\n    # Check cluster status\n    echo \"\"\n    log_info \"Ray cluster initialization complete!\"\n    log_info \"Logs saved to: ${LOG_DIR}/\"\n    echo \"\"\n    log_info \"Cluster status:\"\n    ssh -n ${HEAD_NODE} \"$(get_conda_init_cmd) && conda activate ${CONDA_ENV_NAME} && ray status\"\n}\n\n# Main\ncase \"${1}\" in\n    --stop)\n        stop_cluster\n        ;;\n    *)\n        start_cluster\n        ;;\nesac\n"
  },
  {
    "path": "benchmarks/scripts/ray-vllm/evaluate.py",
    "content": "from transformers import HfArgumentParser\nimport torch\n\nfrom benchmark import Benchmark\nfrom benchmark.console import *\nfrom utils.generator import RayVllmGenerator\nfrom utils.arguments import (\n    ModelConfig,\n    InfrastructureConfig,\n    InferenceConfig,\n    GenerationConfig,\n    PromptConfig,\n    BenchmarkConfig\n)\n\n\ndef main():\n    parser = HfArgumentParser([\n        ModelConfig,\n        InfrastructureConfig,\n        InferenceConfig,\n        GenerationConfig,\n        PromptConfig,\n        BenchmarkConfig\n    ])\n    model_config, infra_config, inference_config, generation_config, prompt_config, benchmark_config = \\\n        parser.parse_args_into_dataclasses()\n\n    # 1. Initialize Benchmark\n    benchmark = Benchmark(\n        model_path=model_config.model_path,\n        task_types=benchmark_config.task_types,\n        splits=benchmark_config.splits,\n        data_dir=benchmark_config.data_dir,\n        enable_thinking=prompt_config.enable_thinking,\n    )\n    # Benchmark.print_benchmark_table()\n\n    # 2. Initialize Ray + vLLM generator (Multi-Node Support)\n    generator = RayVllmGenerator(\n        model_name_or_path=model_config.model_path,\n        checkpoint_path=model_config.checkpoint_path,\n        trust_remote_code=model_config.trust_remote_code,\n        dtype=model_config.dtype,\n        max_model_len=model_config.max_model_len,\n        max_logprobs=model_config.max_logprobs,\n        gpu_memory_utilization=infra_config.gpu_memory_utilization,\n        tensor_parallel_size=infra_config.tensor_parallel_size,\n        ray_address=infra_config.ray_address,  # Ray cluster address\n        allow_cross_node_tensor_parallel=infra_config.allow_cross_node_tensor_parallel,  # Cross-node TP\n        num_gpus=infra_config.num_gpus,\n        gpu_ids=infra_config.gpu_ids,\n        force_enable_optimizations=inference_config.force_enable_optimizations,\n        force_disable_optimizations=inference_config.force_disable_optimizations,\n        worker_batch_size=inference_config.worker_batch_size,\n        task_types=benchmark_config.task_types\n    )\n\n    # 3. Generate text\n    benchmark.run(\n        generator=generator,\n        output_dir=benchmark_config.output_dir,\n        overwrite=benchmark_config.overwrite,\n        # Generation parameters\n        enable_thinking=prompt_config.enable_thinking,\n        num_beams=generation_config.num_beams,\n        num_return_sequences=generation_config.num_return_sequences,\n        temperature=generation_config.temperature,\n        top_p=generation_config.top_p,\n        top_k=generation_config.top_k,\n        presence_penalty=generation_config.presence_penalty,\n        num_return_thinking_sequences=generation_config.num_return_thinking_sequences,\n        sample_size=benchmark_config.sample_size,\n    )\n\n    # 4. Release GPU memory occupied by vLLM\n    console.print(\"\\nReleasing vLLM GPU memory...\", style=warning_style)\n    generator.cleanup()\n    del generator\n    import gc\n    gc.collect()\n\n    # Clear CUDA cache\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.synchronize()\n    console.print(\"✓ GPU memory release completed\\n\", style=success_style)\n\n    # 5. 
Calculate evaluation metrics\n    eval_results_path = f\"{benchmark_config.output_dir}/eval_results.json\"\n    Benchmark.evaluate_dev(\n        generation_results_dir=benchmark_config.output_dir,\n        output_path=eval_results_path,\n        data_dir=benchmark_config.data_dir,\n        overwrite=benchmark_config.overwrite,\n        task_types=benchmark_config.task_types\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "benchmarks/scripts/ray-vllm/utils/__init__.py",
    "content": "# Ray-vLLM Utils\n\n"
  },
  {
    "path": "benchmarks/scripts/ray-vllm/utils/arguments.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Optional, List\n\n\n@dataclass\nclass ModelConfig:\n    \"\"\"Model loading and initialization parameters\"\"\"\n    model_path: str = field(\n        metadata={\"help\": \"Model path or HuggingFace model name (e.g., Qwen/Qwen2-7B)\", \"required\": True}\n    )\n    checkpoint_path: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"PT checkpoint path (optional, for loading .pt format models, will auto-convert to HuggingFace format)\"}\n    )\n    dtype: str = field(\n        default='bfloat16',\n        metadata={\"help\": \"Model data type: auto, half, float16, bfloat16, float, float32\"}\n    )\n    max_model_len: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Maximum model length (optional, for limiting context length)\"}\n    )\n    trust_remote_code: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether to trust remote code\"}\n    )\n    max_logprobs: int = field(\n        default=384,\n        metadata={\"help\": \"Maximum number of log probabilities to return (for beam search and logprob extraction)\"}\n    )\n\n\n@dataclass\nclass InfrastructureConfig:\n    \"\"\"Hardware and distributed computing configuration\"\"\"\n    # GPU allocation\n    num_gpus: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of GPUs to use (default uses all visible GPUs)\"}\n    )\n    gpu_ids: Optional[List[int]] = field(\n        default=None,\n        metadata={\"help\": \"List of GPU IDs to use (e.g., [0,2,4], default uses all visible GPUs)\"}\n    )\n    gpu_memory_utilization: float = field(\n        default=0.5,\n        metadata={\"help\": \"GPU memory utilization (0-1, recommended 0.8)\"}\n    )\n    # Parallelism\n    tensor_parallel_size: int = field(\n        default=1,\n        metadata={\"help\": \"Tensor parallel size (default 1, single GPU per worker)\"}\n    )\n    allow_cross_node_tensor_parallel: bool = field(\n        default=False,\n        metadata={\"help\": \"Allow tensor parallelism across different nodes (not recommended due to network latency)\"}\n    )\n    # Ray cluster\n    ray_address: Optional[str] = field(\n        default=\"auto\",\n        metadata={\"help\": \"Ray cluster address: 'auto' (auto-detect), 'local' (single machine), or 'ray://head_ip:10001' (specific cluster address)\"}\n    )\n\n\n@dataclass\nclass InferenceConfig:\n    \"\"\"Inference execution and optimization parameters\"\"\"\n    # vLLM optimizations (chunked_prefill, prefix_caching)\n    force_enable_optimizations: bool = field(\n        default=False,\n        metadata={\"help\": \"Force enable chunked_prefill and prefix_caching for all tasks (overrides task-specific settings)\"}\n    )\n    force_disable_optimizations: bool = field(\n        default=False,\n        metadata={\"help\": \"Force disable chunked_prefill and prefix_caching for all tasks (overrides task-specific settings)\"}\n    )\n    # Batch processing\n    worker_batch_size: int = field(\n        default=4,\n        metadata={\"help\": \"Batch size for each worker to process prompts (reduce this if KV cache is insufficient)\"}\n    )\n\n\n@dataclass\nclass GenerationConfig:\n    \"\"\"Text generation parameters (sampling, beam search)\"\"\"\n    # Beam search\n    num_beams: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of beams for beam search\"}\n    )\n    # Sampling\n    num_return_sequences: Optional[int] 
= field(\n        default=None,\n        metadata={\"help\": \"Number of sequences to return\"}\n    )\n    temperature: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"Sampling temperature\"}\n    )\n    top_p: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"Top-p (nucleus) sampling probability\"}\n    )\n    top_k: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Top-k sampling\"}\n    )\n    presence_penalty: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"Presence penalty for sampling (-2.0 to 2.0, positive values penalize new tokens based on whether they appear in the text so far)\"}\n    )\n    # Two-stage generation (thinking mode)\n    num_return_thinking_sequences: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of thinking candidates to generate in stage 1\"}\n    )\n\n\n@dataclass\nclass PromptConfig:\n    \"\"\"Prompt formatting and template parameters\"\"\"\n    # Thinking mode (affects both template and generation)\n    enable_thinking: bool = field(\n        default=False,\n        metadata={\"help\": \"Enable thinking mode for apply_chat_template (overrides task config if set)\"}\n    )\n\n\n@dataclass\nclass BenchmarkConfig:\n    \"\"\"Benchmark execution and evaluation parameters\"\"\"\n    # Task selection\n    task_types: Optional[List[str]] = field(\n        default=None,\n        metadata={\"help\": \"Task name list (e.g., item_understand rec_reason)\"}\n    )\n    sample_size: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Sample size for evaluation (e.g., 'full' for all data, or a number like '100')\"}\n    )\n    splits: List[str] = field(\n        default_factory=lambda: ['test'],\n        metadata={\"help\": \"Dataset split list\"}\n    )\n    # Data I/O\n    data_dir: str = field(\n        default='./data',\n        metadata={\"help\": \"Data directory path\"}\n    )\n    output_dir: str = field(\n        default='./results',\n        metadata={\"help\": \"Output directory for results\"}\n    )\n    overwrite: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether to overwrite existing results\"}\n    )\n"
  },
  {
    "path": "benchmarks/scripts/ray-vllm/utils/generator.py",
    "content": "import os\nimport ray\nimport math\nimport json\nfrom typing import Dict, List, Any, Optional\nfrom vllm import LLM, SamplingParams\nfrom vllm.sampling_params import BeamSearchParams\n\nfrom benchmark.base_generator import Generator, RayMixin, VllmMixin, DISABLE_OPTIMIZATIONS_FOR_TASKS\nfrom benchmark.checkpoint_utils import export_pt_to_safetensor\nfrom benchmark.console import *\n\n\nclass VllmWorker:\n    \"\"\"\n    vLLM Worker that can use one or more GPUs\n    \n    Each Worker is responsible for:\n    - Loading one vLLM model instance (potentially across multiple GPUs with tensor parallelism)\n    - Processing inference tasks assigned to it\n    - Returning generation results\n    \"\"\"\n    \n    def __init__(\n        self,\n        worker_id: int,\n        model_path: str,\n        gpu_ids: List[int],\n        gpu_memory_utilization: float = 0.9,\n        trust_remote_code: bool = True,\n        dtype: str = \"auto\",\n        max_model_len: Optional[int] = None,\n        tensor_parallel_size: int = 1,\n        enable_optimizations: bool = True,\n        **kwargs\n    ):\n        \"\"\"\n        Args:\n            worker_id: Worker ID\n            model_path: Model path (converted HuggingFace format)\n            gpu_ids: List of GPU IDs assigned to this worker\n            gpu_memory_utilization: GPU memory utilization\n            trust_remote_code: Whether to trust remote code\n            dtype: Data type\n            max_model_len: Maximum model length\n            tensor_parallel_size: Tensor parallel size (must match len(gpu_ids))\n            enable_optimizations: Whether to enable chunked_prefill and prefix_caching\n            **kwargs: Other vLLM parameters\n        \"\"\"\n        self.worker_id = worker_id\n        self.gpu_ids = gpu_ids\n        \n        # Set environment variable so current process only sees specified GPUs\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(map(str, gpu_ids))\n        \n        opt_status = \"optimized\" if enable_optimizations else \"standard\"\n        gpu_str = \",\".join(map(str, gpu_ids))\n        print(f\"  [Worker {worker_id}] Initializing ({opt_status})... 
(GPU {gpu_str}, TP={tensor_parallel_size})\")\n        \n        # Initialize vLLM\n        vllm_kwargs = {\n            \"model\": model_path,\n            \"tensor_parallel_size\": tensor_parallel_size,\n            \"gpu_memory_utilization\": gpu_memory_utilization,\n            \"trust_remote_code\": trust_remote_code,\n            \"dtype\": dtype,\n            \"enable_chunked_prefill\": enable_optimizations,\n            \"enable_prefix_caching\": enable_optimizations,\n            \"max_logprobs\": kwargs.get(\"max_logprobs\", 384),  # Support beam search, need large enough logprobs\n        }\n\n        if max_model_len is not None:\n            vllm_kwargs[\"max_model_len\"] = max_model_len\n\n        vllm_kwargs.update(kwargs)\n        \n        try:\n            self.llm = LLM(**vllm_kwargs)\n            self.tokenizer = self.llm.get_tokenizer()\n            print(f\"  [Worker {worker_id}] ✓ Initialized successfully (GPU {gpu_str}, TP={tensor_parallel_size})\")\n        except Exception as e:\n            print(f\"  [Worker {worker_id}] ✗ Initialization failed: {e}\")\n            raise\n    \n    def get_model_parameters(self) -> Optional[float]:\n        \"\"\"\n        Get model parameter count from the worker's vLLM instance\n        \n        Returns:\n            float: Total number of parameters, or None if unable to count\n        \"\"\"\n        try:\n            model_executor = self.llm.llm_engine.model_executor\n            if hasattr(model_executor, 'driver_worker'):\n                model = model_executor.driver_worker.model_runner.model\n            else:\n                model = model_executor.model\n            \n            # Count parameters\n            total_params = sum(p.numel() for p in model.parameters())\n            return float(total_params)\n        except Exception as e:\n            print(f\"  [Worker {self.worker_id}] Warning: Failed to count parameters: {e}\")\n            return None\n    \n    def generate_batch(\n        self,\n        prompts: Dict[str, str],\n        sampling_params: Dict[str, Any],\n        worker_batch_size: int = 8\n    ) -> tuple:\n        \"\"\"\n        Batch text generation (internal batch processing to avoid vLLM scheduler issues)\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            sampling_params: Sampling parameter dictionary\n            worker_batch_size: Worker internal batch size (default 8)\n\n        Returns:\n            Tuple of three dicts:\n            - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}\n            - Second dict: {sample_id: [cum_logprob_1, cum_logprob_2, ...]} (only for beam search)\n            - Third dict: {sample_id: {\"input_tokens\": [int], \"output_tokens\": [int], \"times\": [float]}} (lists for multi-stage support)\n        \"\"\"\n        import time\n        stage_start_time = time.time()\n\n        if not prompts:\n            return ({}, {}, {})\n\n        # Determine whether to use BeamSearchParams or SamplingParams based on parameters\n        if sampling_params.get(\"use_beam_search\", False):\n            # Beam search mode\n            params_dict = {\n                \"beam_width\": sampling_params.get(\"beam_width\", 1),\n                \"max_tokens\": sampling_params.get(\"max_tokens\", 128),\n            }\n            sp = BeamSearchParams(**params_dict)\n        else:\n            # Sampling mode - remove parameters not belonging to SamplingParams\n            # stop parameter is already included in the dict 
comprehension\n            params_dict = {k: v for k, v in sampling_params.items()\n                          if k not in [\"use_beam_search\", \"beam_width\", \"return_logprobs\"]}\n\n            # If return_logprobs is enabled, add logprobs parameter\n            if sampling_params.get(\"return_logprobs\", False):\n                params_dict[\"logprobs\"] = 1  # Enable logprobs for cumulative calculation\n\n            sp = SamplingParams(**params_dict)\n\n        # Prepare input\n        sample_ids = list(prompts.keys())\n        prompt_texts = list(prompts.values())\n\n        # Batch processing to avoid vLLM scheduler issues\n        all_results = {}\n        all_logprobs = {}  # Store cum_logprobs for beam search\n        all_mfu_stats = {}  # Store MFU statistics for MFU calculation\n        num_batches = (len(sample_ids) + worker_batch_size - 1) // worker_batch_size\n        \n        for batch_idx in range(num_batches):\n            start_idx = batch_idx * worker_batch_size\n            end_idx = min(start_idx + worker_batch_size, len(sample_ids))\n            \n            batch_sample_ids = sample_ids[start_idx:end_idx]\n            batch_prompt_texts = prompt_texts[start_idx:end_idx]\n            \n            try:\n                # If using beam search, need to record each prompt's length\n                batch_prompt_lengths = []\n                if isinstance(sp, BeamSearchParams):\n                    for text in batch_prompt_texts:\n                        prompt_tokens = self.tokenizer.encode(text, add_special_tokens=True)\n                        batch_prompt_lengths.append(len(prompt_tokens))\n                \n                # Choose different generation method based on parameter type\n                if isinstance(sp, BeamSearchParams):\n                    # Beam search needs to use beam_search method\n                    # beam_search input format is [{\"prompt\": \"text\"}]\n                    batch_prompt_dicts = [{\"prompt\": text} for text in batch_prompt_texts]\n                    batch_outputs = self.llm.beam_search(batch_prompt_dicts, sp)\n                else:\n                    # Sampling mode uses generate method\n                    batch_outputs = self.llm.generate(batch_prompt_texts, sp)\n                \n                # Organize results\n                for idx, (sample_id, output) in enumerate(zip(batch_sample_ids, batch_outputs)):\n                    if isinstance(sp, BeamSearchParams):\n                        # Beam search returns complete token IDs (including prompt), need to remove prompt part before decoding\n                        prompt_length = batch_prompt_lengths[idx]\n                        generated_texts = [\n                            self.tokenizer.decode(seq.tokens[prompt_length:], skip_special_tokens=True)\n                            for seq in output.sequences\n                        ]\n                        # Extract cum_logprob for each sequence\n                        cum_logprobs = [seq.cum_logprob for seq in output.sequences]\n                        all_results[sample_id] = generated_texts\n                        all_logprobs[sample_id] = cum_logprobs\n\n                        # Collect MFU stats for beam search\n                        input_tokens = prompt_length\n                        output_tokens_list = [len(seq.tokens) - prompt_length for seq in output.sequences]\n                        all_mfu_stats[sample_id] = {\n                            \"input_tokens\": [input_tokens],\n                    
        \"output_tokens\": [sum(output_tokens_list)]\n                        }\n                    else:\n                        generated_texts = [out.text for out in output.outputs]\n                        all_results[sample_id] = generated_texts\n\n                        # If return_logprobs is enabled, calculate cumulative logprobs\n                        if sampling_params.get(\"return_logprobs\", False):\n                            cum_logprobs = []\n                            for out in output.outputs:\n                                # Calculate cumulative logprob by summing all token logprobs\n                                cum_logprob = 0.0\n                                if out.logprobs and out.token_ids:\n                                    # Iterate through each position and get the logprob of the actual generated token\n                                    for i, token_logprobs in enumerate(out.logprobs):\n                                        if token_logprobs and i < len(out.token_ids):\n                                            # Get the actual token ID that was generated at this position\n                                            actual_token_id = out.token_ids[i]\n                                            # Look up the logprob for this specific token\n                                            if actual_token_id in token_logprobs:\n                                                cum_logprob += token_logprobs[actual_token_id].logprob\n                                cum_logprobs.append(cum_logprob)\n                            all_logprobs[sample_id] = cum_logprobs\n\n                        # Collect MFU stats for sampling mode\n                        prompt_text = batch_prompt_texts[idx]\n                        input_tokens = len(self.tokenizer.encode(prompt_text, add_special_tokens=True))\n                        output_tokens_list = [len(out.token_ids) for out in output.outputs]\n                        all_mfu_stats[sample_id] = {\n                            \"input_tokens\": [input_tokens],\n                            \"output_tokens\": [sum(output_tokens_list)]\n                        }\n                \n            except Exception as e:\n                # When a single batch fails, return empty string and print detailed error\n                import traceback\n                print(f\"\\n[Worker {self.worker_id}] Batch {batch_idx}/{num_batches} generation failed:\")\n                print(f\"  Error type: {type(e).__name__}\")\n                print(f\"  Error message: {str(e)}\")\n                print(f\"  Batch size: {len(batch_sample_ids)}\")\n                if batch_prompt_texts:\n                    prompt_lens = [len(self.tokenizer.encode(t, add_special_tokens=True)) for t in batch_prompt_texts]\n                    print(f\"  Prompt token length range: min={min(prompt_lens)}, max={max(prompt_lens)}, avg={sum(prompt_lens)/len(prompt_lens):.1f}\")\n                print(f\"  Full stack trace:\\n{traceback.format_exc()}\")\n                \n                num_return = sampling_params.get(\"n\", 1)\n                if sampling_params.get(\"use_beam_search\", False):\n                    num_return = sampling_params.get(\"beam_width\", 1)\n                for sample_id in batch_sample_ids:\n                    all_results[sample_id] = [\"\"] * num_return\n                    # If beam search, also set empty logprobs\n                    if sampling_params.get(\"use_beam_search\", False):\n                        
all_logprobs[sample_id] = [0.0] * num_return\n                    # Don't include failed samples in MFU stats (they would have times=[0.0] which breaks MFU calculation)\n\n        # Calculate stage time\n        stage_elapsed_time = time.time() - stage_start_time\n\n        # Add time to all samples (same time for all samples in this worker)\n        for sample_id in all_mfu_stats:\n            all_mfu_stats[sample_id][\"times\"] = [stage_elapsed_time]\n\n        return (all_results, all_logprobs, all_mfu_stats)\n    \n    def extract_token_logprobs_batch(\n        self,\n        prompts: Dict[str, str],\n        target_tokens: List[str],\n        sampling_params: Dict[str, Any],\n        worker_batch_size: int = 8\n    ) -> tuple:\n        \"\"\"\n        Extract logprobs for specific target tokens\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            target_tokens: List of target tokens (e.g., [\"是\", \"否\"])\n            sampling_params: Sampling parameter dictionary\n            worker_batch_size: Worker internal batch size\n\n        Returns:\n            Tuple of two dicts:\n            - First dict: {sample_id: [json_string]} where json_string is formatted probabilities\n            - Second dict: {sample_id: {\"input_tokens\": [int], \"output_tokens\": [int], \"times\": [float]}}\n        \"\"\"\n        import time\n        stage_start_time = time.time()\n\n        if not prompts:\n            return ({}, {})\n        \n        # Get token IDs for target tokens\n        target_token_ids = {}\n        for token in target_tokens:\n            token_ids = self.tokenizer.encode(token, add_special_tokens=False)\n            if len(token_ids) == 1:\n                target_token_ids[token] = token_ids[0]\n            else:\n                print(f\"  [Worker {self.worker_id}] Warning: Token '{token}' is encoded as multiple tokens: {token_ids}\")\n                # For multi-token case, we only use the first token for now\n                target_token_ids[token] = token_ids[0]\n        \n        # Build sampling parameters with logprobs enabled\n        params_dict = {\n            \"n\": sampling_params.get(\"n\", 1),\n            \"max_tokens\": sampling_params.get(\"max_tokens\", 1),\n            \"temperature\": sampling_params.get(\"temperature\", 1.0),\n            \"top_p\": sampling_params.get(\"top_p\", 1.0),\n            \"top_k\": sampling_params.get(\"top_k\", -1),\n            \"repetition_penalty\": sampling_params.get(\"repetition_penalty\", 1.0),\n            \"presence_penalty\": sampling_params.get(\"presence_penalty\", 0.0),\n            \"frequency_penalty\": sampling_params.get(\"frequency_penalty\", 0.0),\n            \"logprobs\": sampling_params.get(\"logprobs\", 10),\n        }\n        sp = SamplingParams(**params_dict)\n        \n        # Prepare input\n        sample_ids = list(prompts.keys())\n        prompt_texts = list(prompts.values())\n        \n        # Batch processing\n        all_results = {}\n        all_mfu_stats = {}\n        num_batches = (len(sample_ids) + worker_batch_size - 1) // worker_batch_size\n\n        for batch_idx in range(num_batches):\n            start_idx = batch_idx * worker_batch_size\n            end_idx = min(start_idx + worker_batch_size, len(sample_ids))\n\n            batch_sample_ids = sample_ids[start_idx:end_idx]\n            batch_prompt_texts = prompt_texts[start_idx:end_idx]\n\n            try:\n                # Generate with logprobs\n                batch_outputs = 
self.llm.generate(batch_prompt_texts, sp)\n\n                # Extract logprobs for target tokens\n                for idx, (sample_id, output) in enumerate(zip(batch_sample_ids, batch_outputs)):\n                    token_probs = {}\n\n                    # Get logprobs from the first generated token\n                    if output.outputs and len(output.outputs) > 0:\n                        first_output = output.outputs[0]\n                        if first_output.logprobs and len(first_output.logprobs) > 0:\n                            # Get logprobs dict for the first token\n                            first_token_logprobs = first_output.logprobs[0]\n\n                            # Extract probabilities for target tokens\n                            for token, token_id in target_token_ids.items():\n                                if token_id in first_token_logprobs:\n                                    logprob = first_token_logprobs[token_id].logprob\n                                    prob = math.exp(logprob)\n                                    token_probs[token] = prob\n                                else:\n                                    # Token not in top-k, assign very small probability\n                                    token_probs[token] = 1e-10\n\n                    all_results[sample_id] = [json.dumps(token_probs, ensure_ascii=False)]\n\n                    prompt_text = batch_prompt_texts[idx]\n                    input_tokens = len(self.tokenizer.encode(prompt_text, add_special_tokens=True))\n                    # Classification only generates 1 token\n                    output_tokens = 1\n                    all_mfu_stats[sample_id] = {\n                        \"input_tokens\": [input_tokens],\n                        \"output_tokens\": [output_tokens]\n                    }\n\n            except Exception as e:\n                import traceback\n                print(f\"\\n[Worker {self.worker_id}] Batch {batch_idx}/{num_batches} logprobs extraction failed:\")\n                print(f\"  Error: {str(e)}\")\n                print(f\"  Full stack trace:\\n{traceback.format_exc()}\")\n\n                for sample_id in batch_sample_ids:\n                    token_probs = {token: 0.0 for token in target_tokens}\n                    all_results[sample_id] = [json.dumps(token_probs, ensure_ascii=False)]\n                    # Don't include failed samples in MFU stats\n\n        stage_elapsed_time = time.time() - stage_start_time\n        for sample_id in all_mfu_stats:\n            all_mfu_stats[sample_id][\"times\"] = [stage_elapsed_time]\n\n        return (all_results, all_mfu_stats)\n\n\nclass RayVllmGenerator(RayMixin, VllmMixin, Generator):\n    \"\"\"\n    Ray-based Multi-GPU vLLM Generator (Data Parallel)\n    \"\"\"\n    \n    def __init__(\n        self,\n        model_name_or_path: str,\n        checkpoint_path: Optional[str] = None,\n        num_return_sequences: int = 2,\n        max_new_tokens: int = 128,\n        temperature: float = 0.7,\n        top_p: float = 0.9,\n        top_k: int = -1,\n        repetition_penalty: float = 1.0,\n        presence_penalty: float = 0.0,\n        frequency_penalty: float = 0.0,\n        do_sample: bool = True,\n        gpu_memory_utilization: float = 0.9,\n        trust_remote_code: bool = True,\n        dtype: str = \"auto\",\n        max_model_len: Optional[int] = None,\n        max_logprobs: int = 384,\n        tensor_parallel_size: int = 1,\n        num_gpus: Optional[int] = None,\n        gpu_ids: 
Optional[List[int]] = None,\n        task_types: Optional[List[str]] = None,\n        force_enable_optimizations: bool = False,\n        force_disable_optimizations: bool = False,\n        worker_batch_size: int = 4,\n        ray_address: Optional[str] = \"auto\",\n        allow_cross_node_tensor_parallel: bool = False,\n        **kwargs\n    ):\n        \"\"\"\n        Args:\n            model_name_or_path: Model name or path\n            checkpoint_path: PT checkpoint path (optional)\n            num_return_sequences: Number of candidate sequences per prompt\n            max_new_tokens: Maximum number of tokens to generate\n            temperature: Sampling temperature\n            top_p: Nucleus sampling parameter\n            top_k: Top-k sampling parameter\n            repetition_penalty: Repetition penalty\n            presence_penalty: Presence penalty (penalizes tokens that appeared in the text)\n            frequency_penalty: Frequency penalty (penalizes tokens based on frequency)\n            do_sample: Whether to sample\n            gpu_memory_utilization: GPU memory utilization\n            trust_remote_code: Whether to trust remote code\n            dtype: Model data type\n            max_model_len: Maximum model length\n            max_logprobs: Maximum number of log probabilities to return (for beam search and logprob extraction)\n            tensor_parallel_size: Tensor parallel size (default 1, single GPU per worker)\n            num_gpus: Number of GPUs to use (default uses all cluster GPUs)\n            gpu_ids: List of GPU IDs to use (only for single-node mode)\n            task_types: List of task types to evaluate (for auto optimization control)\n            force_enable_optimizations: Force enable optimizations for all tasks\n            force_disable_optimizations: Force disable optimizations for all tasks\n            worker_batch_size: Batch size for each worker (reduce if KV cache is insufficient)\n            ray_address: Ray cluster address ('auto', 'local', or specific address)\n            allow_cross_node_tensor_parallel: Allow tensor parallel across nodes (not recommended)\n            **kwargs: Other parameters\n        \"\"\"\n        super().__init__(\n            num_return_sequences=num_return_sequences,\n            max_new_tokens=max_new_tokens,\n            temperature=temperature,\n            top_p=top_p,\n            top_k=top_k,\n            repetition_penalty=repetition_penalty,\n            presence_penalty=presence_penalty,\n            frequency_penalty=frequency_penalty,\n            do_sample=do_sample,\n            **kwargs\n        )\n        \n        self.model_name = model_name_or_path\n        self.checkpoint_path = checkpoint_path\n        self.gpu_memory_utilization = gpu_memory_utilization\n        self.trust_remote_code = trust_remote_code\n        self.dtype = dtype\n        self.max_model_len = max_model_len\n        self.tensor_parallel_size = tensor_parallel_size\n        self.worker_batch_size = worker_batch_size\n        self.task_types = task_types or []\n        self.force_enable_optimizations = force_enable_optimizations\n        self.force_disable_optimizations = force_disable_optimizations\n        self.ray_address = ray_address\n        self.allow_cross_node_tensor_parallel = allow_cross_node_tensor_parallel\n        self.num_gpus = num_gpus\n        self.gpu_ids = gpu_ids\n\n        console.print(\n            \"\\nLoading Model\\n\",\n            style=head_style_2,\n            justify=\"center\",\n        )\n      
  console.print(\n            f\"  Using Ray + vLLM (Multi-Node) to load model: [cyan]{model_name_or_path}[/cyan]\",\n            style=subhead_style_2,\n        )\n        \n        # 1. Initialize Ray cluster connection\n        self._initialize_ray_cluster()\n        \n        # 2. Determine GPUs to use (from cluster)\n        all_gpu_ids = self._determine_gpu_ids_from_cluster()\n        \n        # 3. Group GPUs for workers (ensuring same-node constraint if needed)\n        self.worker_gpu_groups, self.worker_node_assignments = self._group_gpus_for_workers(\n            all_gpu_ids, tensor_parallel_size\n        )\n        num_workers = len(self.worker_gpu_groups)\n        \n        # Display cluster and GPU information\n        self._display_cluster_info(all_gpu_ids, num_workers)\n        \n        # 4. Handle PT checkpoint (main process converts)\n        if checkpoint_path:\n            console.print(\n                f\"  checkpoint: [yellow]{checkpoint_path}[/yellow]\",\n                style=subhead_style_2,\n            )\n            console.print(\n                \"  [yellow]Converting PT checkpoint to HuggingFace format in main process...[/yellow]\",\n                style=subhead_style_2,\n            )\n            model_path = export_pt_to_safetensor(\n                config_path=model_name_or_path,\n                checkpoint_path=checkpoint_path,\n                trust_remote_code=trust_remote_code\n            )\n        else:\n            model_path = model_name_or_path\n        \n        # 5. Create Workers\n        console.print(\n            f\"  Creating {num_workers} vLLM Workers...\",\n            style=subhead_style_2,\n        )\n        self.workers = []\n        \n        vllm_kwargs = {\n            \"gpu_memory_utilization\": gpu_memory_utilization,\n            \"trust_remote_code\": trust_remote_code,\n            \"dtype\": dtype,\n            \"tensor_parallel_size\": tensor_parallel_size,\n            \"max_logprobs\": max_logprobs,\n        }\n\n        if max_model_len is not None:\n            vllm_kwargs[\"max_model_len\"] = max_model_len\n\n        vllm_kwargs.update(kwargs)\n        \n        # Determine whether to enable optimizations (default behavior, can be overridden by force flags)\n        # Check if any task in task_types requires disabling optimizations\n        enable_optimizations = self._should_enable_optimizations()\n        \n        # Create Ray remote class with dynamic GPU count and scheduling strategy\n        # Use scheduling_strategy to place workers on specific nodes\n        for i, (gpu_group, node_id) in enumerate(zip(self.worker_gpu_groups, self.worker_node_assignments)):\n            # Create worker with node placement constraint\n            VllmWorkerRemote = ray.remote(num_gpus=tensor_parallel_size)(VllmWorker)\n            \n            # If we have node assignment, use scheduling strategy\n            if node_id is not None:\n                from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy\n                scheduling_strategy = NodeAffinitySchedulingStrategy(\n                    node_id=node_id,\n                    soft=False  # Hard constraint: must be on this node\n                )\n                worker = VllmWorkerRemote.options(\n                    scheduling_strategy=scheduling_strategy\n                ).remote(\n                    worker_id=i,\n                    model_path=model_path,\n                    gpu_ids=gpu_group,\n                    
enable_optimizations=enable_optimizations,\n                    **vllm_kwargs\n                )\n            else:\n                # No node constraint, let Ray decide\n                worker = VllmWorkerRemote.remote(\n                    worker_id=i,\n                    model_path=model_path,\n                    gpu_ids=gpu_group,\n                    enable_optimizations=enable_optimizations,\n                    **vllm_kwargs\n                )\n            self.workers.append(worker)\n        \n        # Wait for all Workers to initialize (an empty generate_batch call forces actor construction)\n        console.print(\n            \"  Waiting for all Workers to initialize...\\n\\n\",\n            style=subhead_style_2,\n        )\n        ray.get([worker.generate_batch.remote({}, {}) for worker in self.workers])\n        \n        console.print(\n            \"✓ All Workers initialized successfully\\n\",\n            style=success_style,\n        )\n        \n        # Print optimization configuration summary\n        if force_enable_optimizations:\n            console.print(\n                \"⚙️ Optimization mode: FORCED ENABLED (chunked_prefill & prefix_caching enabled for all tasks)\",\n                style=warning_style,\n            )\n        elif force_disable_optimizations:\n            console.print(\n                \"⚙️ Optimization mode: FORCED DISABLED (chunked_prefill & prefix_caching disabled for all tasks)\",\n                style=warning_style,\n            )\n\n        self.num_params = self._count_model_parameters()\n        \n\n    def _count_model_parameters(self) -> Optional[float]:\n        \"\"\"\n        Override VllmMixin._count_model_parameters() for Ray-based generators.\n        \n        In the Ray-based architecture, vLLM instances live in worker processes,\n        so we query the first worker for the model parameter count.\n        \n        Returns:\n            float or None: Total number of parameters\n        \"\"\"\n        tensor_parallel_size = getattr(self, 'tensor_parallel_size', 1)\n        if tensor_parallel_size > 1:\n            console.print(\n                f\"Warning: Tensor parallel (size={tensor_parallel_size}) detected. 
\"\n                f\"Skipping parameter count (would only count local shard).\",\n                style=warning_style,\n            )\n            return None\n        \n        # Query the first worker to get parameter count\n        try:\n            num_params = ray.get(self.workers[0].get_model_parameters.remote())\n            console.print(\n                f\"✓ Model parameters: {num_params / 1e9:.2f}B\\n\",\n                style=success_style,\n            )\n            return num_params\n        except Exception as e:\n            console.print(\n                f\"Warning: Failed to get parameter count from worker: {e}\",\n                style=warning_style,\n            )\n            return None\n\n    def _generate_standard(\n        self,\n        prompts: Dict[str, str],\n        **kwargs\n    ) -> tuple:\n        \"\"\"\n        Standard single-stage generation (round-robin assignment to multiple Workers).\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            **kwargs: Optional generation parameters, including:\n                - worker_batch_size: Worker internal batch size, used to avoid vLLM scheduler issues (taken from the constructor's worker_batch_size, default 4)\n                - return_logprobs: Whether to return cumulative logprobs for sampling mode (default False)\n\n        Returns:\n            Tuple of three dicts:\n            - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}\n            - Second dict: {sample_id: [cum_logprob_1, cum_logprob_2, ...]} (for beam search or when return_logprobs=True)\n            - Third dict: {sample_id: {\"input_tokens\": [int], \"output_tokens\": [int], \"times\": [float]}} (lists for multi-stage support)\n        \"\"\"\n        # Auto-enable return_logprobs if prompt_token is used (for recommendation tasks)\n        has_prompt_token = bool(kwargs.get(\"prompt_token\", None))\n        \n        # Build sampling parameters using mixin method\n        sampling_params_obj = self._build_sampling_params(**kwargs)\n        \n        # Convert to dict for passing to workers\n        if hasattr(sampling_params_obj, 'beam_width'):\n            # BeamSearchParams\n            use_beam_search = True\n            sampling_params = {\n                \"use_beam_search\": True,\n                \"beam_width\": sampling_params_obj.beam_width,\n                \"max_tokens\": sampling_params_obj.max_tokens,\n            }\n        else:\n            # SamplingParams\n            use_beam_search = False\n            sampling_params = {\n                \"n\": sampling_params_obj.n,\n                \"max_tokens\": sampling_params_obj.max_tokens,\n                \"temperature\": sampling_params_obj.temperature,\n                \"top_p\": sampling_params_obj.top_p,\n                \"top_k\": sampling_params_obj.top_k,\n                \"repetition_penalty\": sampling_params_obj.repetition_penalty,\n                \"presence_penalty\": sampling_params_obj.presence_penalty,\n                \"frequency_penalty\": sampling_params_obj.frequency_penalty,\n                \"return_logprobs\": kwargs.get(\"return_logprobs\", has_prompt_token),\n            }\n            # Add stop parameter if specified\n            if sampling_params_obj.stop:\n                sampling_params[\"stop\"] = sampling_params_obj.stop\n\n        console.print(\n            \"Starting generation...\",\n            style=subhead_style_2,\n        )\n        if use_beam_search:\n            console.print(\n                f\"Sampling parameters 
(beam search): beam_width={sampling_params['beam_width']}, \"\n                f\"max_tokens={sampling_params['max_tokens']}\",\n                style=subhead_style_2,\n            )\n        else:\n            console.print(\n                f\"Sampling parameters: n={sampling_params['n']}, max_tokens={sampling_params['max_tokens']}, \"\n                f\"temperature={sampling_params['temperature']}, top_p={sampling_params['top_p']}, top_k={sampling_params['top_k']}, \"\n                f\"repetition_penalty={sampling_params['repetition_penalty']}, \"\n                f\"presence_penalty={sampling_params['presence_penalty']}, \"\n                f\"frequency_penalty={sampling_params['frequency_penalty']}, \"\n                f\"return_logprobs={sampling_params['return_logprobs']}\",\n                style=subhead_style_2,\n            )\n        \n        # Round-robin assign tasks to Workers\n        sample_ids = list(prompts.keys())\n        num_workers = len(self.workers)\n        worker_tasks = [dict() for _ in range(num_workers)]\n        \n        for i, sample_id in enumerate(sample_ids):\n            worker_idx = i % num_workers\n            worker_tasks[worker_idx][sample_id] = prompts[sample_id]\n        \n        console.print(\n            f\"Task distribution: {[len(task) for task in worker_tasks]}\",\n            style=subhead_style_2,\n        )\n        console.print(\n            f\"Worker batch size: {self.worker_batch_size}\",\n            style=subhead_style_2,\n        )\n        \n        # Execute in parallel\n        futures = []\n        for worker, task in zip(self.workers, worker_tasks):\n            if task:  # Only submit non-empty tasks\n                future = worker.generate_batch.remote(task, sampling_params, self.worker_batch_size)\n                futures.append(future)\n        \n        # Collect results\n        worker_results = ray.get(futures)\n\n        # Merge results (each worker_result is a tuple of (texts_dict, logprobs_dict, mfu_stats_dict))\n        results = {}\n        logprobs = {}\n        mfu_stats = {}\n        for worker_result in worker_results:\n            texts_dict, logprobs_dict, mfu_stats_dict = worker_result\n            results.update(texts_dict)\n            logprobs.update(logprobs_dict)\n            mfu_stats.update(mfu_stats_dict)\n\n        console.print(\n            \"✓ Generation completed\",\n            style=success_style,\n        )\n\n        return (results, logprobs, mfu_stats)\n\n    \n    def extract_token_logprobs(\n        self,\n        prompts: Dict[str, str],\n        target_tokens: List[str],\n        **kwargs\n    ) -> tuple:\n        \"\"\"\n        Extract logprobs for specific target tokens (round-robin assignment to multiple Workers).\n\n        Args:\n            prompts: {sample_id: prompt_text}\n            target_tokens: List of target tokens to extract probabilities for (e.g., [\"是\", \"否\"])\n            **kwargs: Optional parameters including generation config\n\n        Returns:\n            Tuple of three dicts:\n            - First dict: {sample_id: [json_string]} where json_string is formatted probabilities\n            - Second dict: {} (empty, no beam search logprobs for classification)\n            - Third dict: {sample_id: {\"input_tokens\": [int], \"output_tokens\": [int], \"times\": [float]}}\n        \"\"\"\n        console.print(\n            f\"Extracting logprobs for tokens: {target_tokens}\",\n            style=subhead_style_2,\n        )\n        
console.print(\n            f\"Worker batch size: {self.worker_batch_size}\",\n            style=subhead_style_2,\n        )\n\n        if not prompts:\n            return ({}, {}, {})\n        \n        # Build sampling parameters\n        sampling_params = {\n            \"n\": kwargs.get(\"num_return_sequences\", 1),\n            \"max_tokens\": kwargs.get(\"max_new_tokens\", 1),\n            \"temperature\": kwargs.get(\"temperature\", 1.0),\n            \"top_p\": kwargs.get(\"top_p\", 1.0),\n            \"top_k\": kwargs.get(\"top_k\", -1),\n            \"repetition_penalty\": kwargs.get(\"repetition_penalty\", 1.0),\n            \"presence_penalty\": kwargs.get(\"presence_penalty\", 0.0),\n            \"frequency_penalty\": kwargs.get(\"frequency_penalty\", 0.0),\n            \"logprobs\": kwargs.get(\"logprobs\", 10),\n        }\n        \n        console.print(\n            f\"Sampling parameters: n={sampling_params['n']}, max_tokens={sampling_params['max_tokens']}, \"\n            f\"temperature={sampling_params['temperature']}, top_p={sampling_params['top_p']}, \"\n            f\"top_k={sampling_params['top_k']}, repetition_penalty={sampling_params['repetition_penalty']}, \"\n            f\"presence_penalty={sampling_params['presence_penalty']}, frequency_penalty={sampling_params['frequency_penalty']}, \"\n            f\"logprobs={sampling_params['logprobs']}\",\n            style=subhead_style_2,\n        )\n        \n        # Round-robin assign tasks to Workers\n        sample_ids = list(prompts.keys())\n        num_workers = len(self.workers)\n        worker_tasks = [dict() for _ in range(num_workers)]\n        \n        for i, sample_id in enumerate(sample_ids):\n            worker_idx = i % num_workers\n            worker_tasks[worker_idx][sample_id] = prompts[sample_id]\n        \n        console.print(\n            f\"Task distribution: {[len(task) for task in worker_tasks]}\",\n            style=subhead_style_2,\n        )\n\n        # Get Worker internal batch size\n        worker_batch_size = self.worker_batch_size\n\n        # Execute in parallel\n        futures = []\n        for worker, task in zip(self.workers, worker_tasks):\n            if task:  # Only submit non-empty tasks\n                future = worker.extract_token_logprobs_batch.remote(\n                    task, target_tokens, sampling_params, worker_batch_size\n                )\n                futures.append(future)\n        \n        # Collect results\n        worker_results = ray.get(futures)\n\n        # Merge results (each worker_result is a tuple of (probs_dict, mfu_stats_dict))\n        results = {}\n        mfu_stats = {}\n        for worker_result in worker_results:\n            probs_dict, mfu_stats_dict = worker_result\n            results.update(probs_dict)\n            mfu_stats.update(mfu_stats_dict)\n\n        console.print(\n            \"✓ Logprobs extraction completed\",\n            style=success_style,\n        )\n\n        return (results, {}, mfu_stats)\n    \n"
  },
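  {
    "path": "examples/round_robin_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only -- this file is hypothetical and not part of the release.\nIt is a pure-Python rendering of the round-robin task assignment and result\nmerging that the Ray-based multi-worker generator above performs across its\nvLLM workers; no Ray or vLLM is needed to run it.\n\"\"\"\n\nfrom typing import Dict, List, Tuple\n\n\ndef round_robin_assign(prompts: Dict[str, str], num_workers: int) -> List[Dict[str, str]]:\n    \"\"\"Distribute {sample_id: prompt} over workers in round-robin order.\"\"\"\n    worker_tasks = [dict() for _ in range(num_workers)]\n    for i, sample_id in enumerate(prompts):\n        worker_tasks[i % num_workers][sample_id] = prompts[sample_id]\n    return worker_tasks\n\n\ndef merge_results(worker_results: List[Tuple[dict, dict, dict]]) -> Tuple[dict, dict, dict]:\n    \"\"\"Merge per-worker (texts, logprobs, mfu_stats) tuples; sample ids are disjoint.\"\"\"\n    texts, logprobs, mfu_stats = {}, {}, {}\n    for texts_dict, logprobs_dict, mfu_stats_dict in worker_results:\n        texts.update(texts_dict)\n        logprobs.update(logprobs_dict)\n        mfu_stats.update(mfu_stats_dict)\n    return texts, logprobs, mfu_stats\n\n\nif __name__ == \"__main__\":\n    prompts = {f\"sample_{i}\": f\"prompt {i}\" for i in range(5)}\n    print([len(t) for t in round_robin_assign(prompts, num_workers=2)])  # [3, 2]\n"
  },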
  {
    "path": "data/README.md",
    "content": "# Dataset Documentation\n\nThis directory contains data processing scripts and dataset format specifications for the OpenOneRec project.\n\n## Table of Contents\n\n- [Quick Start](#quick-start) - Get started quickly with dataset download and processing\n- [Directory Structure](#directory-structure)\n- [Dataset Format Specification](#dataset-format-specification)\n- [Notes](#notes)\n\n## Directory Structure\n\n- **general_text/**: General text data used in training, including pretraining and SFT datasets for mathematics, code, reasoning, and other domains\n- **onerec_data/**: Recommendation scenario data and corresponding processing scripts that convert raw recommendation data into LLM pretraining and SFT training formats\n\n### General Text Data (general_text)\n\nThe general text data directory contains information about the main general text datasets used in the project.\n\nThe `pretrain.csv` and `sft.csv` files list all HuggingFace dataset URLs and their corresponding sample counts. For easier reproduction, we have also released our processed datasets on HuggingFace:\n\n- [Pretraining Dataset on HuggingFace](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-General-Pretrain)\n- [SFT Dataset on HuggingFace](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-General-SFT)\n\n> **NOTE**: The processed data on HuggingFace currently does not include some datasets (Nemotron_CC_Math_v1, Nemotron_Pretraining_Code_v1, Nemotron_CC_v2). We will provide a data processing script later to facilitate reproduction.\n\n### OneRec Business Data (onerec_data)\n\nThe OneRec business data directory contains data processing scripts for recommendation systems, converting raw data into LLM pretraining and SFT training formats. It includes data processing scripts for various recommendation scenarios such as video recommendation, user profiling, interactive recommendation, label prediction, and cross-domain recommendation.\n\n- [OpenOneRec Dataset on HuggingFace](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-RecIF)\n\n## Dataset Format Specification\n\nTo standardize data processing, we use a unified Parquet data format. 
Each Parquet file contains the following fields:\n\n### Field Description\n\n| Field | Type | Required | Default | Description | Requirements |\n|-------|------|----------|---------|-------------|--------------|\n| uuid | str | Yes | Auto-generated UUID | Unique identifier | Must be a valid UUID format, must be unique within the same dataset |\n| source | str | Yes | - | Data source identifier | Cannot be an empty string |\n| metadata | str | No | \"{}\" | JSON-formatted metadata dictionary | Must be a valid JSON dictionary string |\n| images | str | No | \"{}\" | (Deprecated) This project trains on text only; this field is not used | - |\n| videos | str | No | \"{}\" | (Deprecated) This project trains on text only; this field is not used | - |\n| messages | str | No | None | JSON-formatted message list for conversation format data | Must be a valid JSON array, each message must have role and content fields |\n| segments | str | No | None | JSON-formatted segment list for segmented data | Must be a valid JSON array, each segment must have a type field |\n| image | str | No | None | (Deprecated) This project trains on text only; this field is not used | - |\n| video | str | No | None | (Deprecated) This project trains on text only; this field is not used | - |\n| text | str | No | None | Text content | No special requirements |\n| label | str | No | None | Label information; when `image`, `video`, or `text` is present, this is the corresponding label | No special requirements |\n\n### Data Format Examples\n\nThe data format supports two main types:\n- **Segments Format**: For regular text data, using the `segments` field to store text segment lists\n- **Chat Format**: For conversation data, using the `messages` field to store conversation message lists\n\n**Chat Format Data (Conversation Data):**\n\n| Field | Value |\n|-------|-------|\n| uuid | 550e8400-e29b-41d4-a716-446655440001 |\n| source | conversation_dataset |\n| metadata | '{}' |\n| images | '{}' |\n| videos | '{}' |\n| messages | '[{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"What is machine learning?\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"Machine learning is a subset of artificial intelligence.\"}]}]' |\n\n**Segments Format Data (Regular Text):**\n\n| Field | Value |\n|-------|-------|\n| uuid | 550e8400-e29b-41d4-a716-446655440002 |\n| source | document_dataset |\n| metadata | '{}' |\n| images | '{}' |\n| videos | '{}' |\n| segments | '[{\"type\": \"text\", \"text\": \"Introduction paragraph...\"}, {\"type\": \"text\", \"text\": \"Main content...\"}]' |\n\n### Field Validation Rules\n\n| Validation Item | Rule Description |\n|-----------------|------------------|\n| JSON Field Validation | metadata must be a valid JSON dictionary string; images and videos fields (deprecated) should be set to \"{}\" |\n| Message Format Validation | messages field (if present) must contain a valid message list, each message must have role and content fields |\n| Role Validation | Message role must be one of user, assistant, or system |\n| Content Type Validation | The type in message content must be text (this project trains on text only; image and video types are not supported) |\n| Segment Format Validation | segments field (if present) must contain a valid segment list, each segment must have a type field, type should be \"text\" |\n\n### File Size Recommendations\n\nFor efficient DataLoader loading, it is recommended that each Parquet file contain approximately **1000 samples**. 
If the data volume is large, you can use sharding to split the data into multiple files. The recommended file naming format is:\n\n```\npart-00000-of-00010.parquet\npart-00001-of-00010.parquet\n...\npart-00009-of-00010.parquet\n```\n\n## Quick Start\n\n### 1. Download Datasets\n\nFirst, download the corresponding datasets from HuggingFace:\n\n- [Pretraining General Text Dataset](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-General-Pretrain)\n- [SFT General Text Dataset](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-General-SFT)\n- [OneRec Recommendation Dataset](https://huggingface.co/datasets/OpenOneRec/OpenOneRec-RecIF)\n\nYou can download the datasets using the following commands (run from the **project root directory**):\n\n```bash\npip install huggingface_hub\n\nexport HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN>\n\nhf download OpenOneRec/OpenOneRec-General-Pretrain \\\n    --repo-type dataset \\\n    --token $HF_TOKEN \\\n    --local-dir ./raw_data/general_text/pretrain\n\nhf download OpenOneRec/OpenOneRec-General-SFT \\\n    --repo-type dataset \\\n    --token $HF_TOKEN \\\n    --local-dir ./raw_data/general_text/sft\n\nhf download OpenOneRec/OpenOneRec-RecIF \\\n    --repo-type dataset \\\n    --token $HF_TOKEN \\\n    --local-dir ./raw_data/onerec_data\n```\n\n### 2. Process Recommendation Data\n\nRun:\n\n```bash\ncd data/onerec_data\nbash run.sh\n```\n\n### 3. Pretraining and SFT Data Sharding\n\nShard the generated data by calling the prepare scripts. Edit `prepare_pretrain.sh` or `prepare_sft.sh` and modify the following configuration:\n\n```bash\nGENERAL_TEXT_PATH=\"data/general_text\"      # General text data path\nREC_DATA_PATH=\"data/onerec_data/output\"   # Recommendation data output path\nOUTPUT_DIR=\"./output/split_data\"          # Final output path\nMAX_ROWS=1000                             # Number of samples per file\n```\n\nThen run:\n\n```bash\n# Process pretraining data\nbash prepare_pretrain.sh\n\n# Process SFT data\nbash prepare_sft.sh\n```\n\n### 4. Distillation Data Processing\n\nPrepares data for on-policy distillation. Edit `prepare_distillation.sh` and modify the following configuration:\n\n```bash\nINPUT_PATH=\"data/general_text\"                    # General text data path\nOUTPUT_FILE=\"./output/onpolicy_distillation.parquet\"  # Output file path\nNUM_SAMPLES=200000                                # Number of samples to draw\nSEED=42                                           # Random seed\n```\n\nThen run:\n\n```bash\nbash prepare_distillation.sh\n```\n\n### 5. RL Data Processing\n\nPrepares data for reinforcement learning (RL) training: it merges multiple RL task datasets and splits them into training and test sets. 
Edit `prepare_rl.sh` and modify the following configuration:\n\n```bash\nREC_DATA_PATH=\"data/onerec_data\"                  # OneRec dataset path\nOUTPUT_DIR=\"./output/rl_data\"                     # Output directory path\nTEST_SIZE=1000                                     # Number of test samples per subtask\nSEED=42                                            # Random seed\n```\n\nThe script processes the following 5 RL task datasets:\n- `sft_video_rec.parquet` - Video recommendation task\n- `sft_ad_rec.parquet` - Ad recommendation task\n- `sft_product_rec.parquet` - Product recommendation task\n- `sft_interactive_rec.parquet` - Interactive recommendation task\n- `sft_label_cond_rec.parquet` - Label-conditioned recommendation task\n\nThen run:\n\n```bash\nbash prepare_rl.sh\n```\n\nOutput:\n- `./output/rl_data/train.parquet` - Training set (remaining data after merging all tasks)\n- `./output/rl_data/test.parquet` - Test set (`TEST_SIZE` samples randomly sampled from each subtask)\n\n\n## Notes\n\n* All scripts process only `split=0` (training set) data by default\n* For a minimal, hypothetical example of constructing a row that satisfies the format specification, see `examples/build_parquet_row_sketch.py` below"
  },
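  {
    "path": "examples/build_parquet_row_sketch.py",
    "content": "\"\"\"\nMinimal sketch (hypothetical file, not shipped with the repo): construct one\nchat-format row that satisfies the dataset format specification in\ndata/README.md and write it to a shard following the recommended naming\nscheme. Assumes pandas with a Parquet engine (e.g. pyarrow) is installed.\n\"\"\"\n\nimport json\nimport uuid\n\nimport pandas as pd\n\n# Chat-format row: messages is a JSON string; every message carries role and\n# content, and each content part is type \"text\" (the project is text-only).\nrow = {\n    \"uuid\": str(uuid.uuid4()),\n    \"source\": \"conversation_dataset\",\n    \"metadata\": \"{}\",\n    \"images\": \"{}\",  # deprecated field, kept as \"{}\"\n    \"videos\": \"{}\",  # deprecated field, kept as \"{}\"\n    \"messages\": json.dumps(\n        [\n            {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"What is machine learning?\"}]},\n            {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"A subset of artificial intelligence.\"}]},\n        ],\n        ensure_ascii=False,\n    ),\n}\n\n# One shard; the README recommends roughly 1000 rows per file.\ndf = pd.DataFrame([row])\ndf.to_parquet(\"part-00000-of-00001.parquet\", index=False)\nprint(df.iloc[0][\"uuid\"])\n"
  },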
  {
    "path": "data/general_text/pretrain.csv",
    "content": "﻿dataname,sample_num,huggingface_repo\r\nNemotron_CC_Math_v1,15440682 ,https://huggingface.co/datasets/nvidia/Nemotron-CC-Math-v1\r\nNemotron_Pretraining_Code_v1,5329298 ,https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-Code-v1\r\nNemotron_CC_v2,2306412 ,https://huggingface.co/datasets/nvidia/Nemotron-CC-v2\r\nreasoning_v1_20m,1666229 ,https://huggingface.co/datasets/glaiveai/reasoning-v1-20m\r\nOpenMathReasoning,477179 ,https://huggingface.co/datasets/nvidia/OpenMathReasoning\r\nNuminaMath-QwQ-CoT-5M,324270 ,https://huggingface.co/datasets/PrimeIntellect/NuminaMath-QwQ-CoT-5M\r\nOpenCodeReasoning,109292 ,https://huggingface.co/datasets/nvidia/OpenCodeReasoning\r\nKodCode_V1_SFT_R1,39211 ,https://huggingface.co/datasets/KodCode/KodCode-V1-SFT-R1\r\nChinese-Reasoning-Distil-Data,30000 ,https://huggingface.co/datasets/Mxode/Chinese-Reasoning-Distil-Data\r\nmedical-o1-reasoning-SFT,7000 ,https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT\r\nBespoke-Stratos-17k,2000 ,https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k"
  },
  {
    "path": "data/general_text/sft.csv",
    "content": "﻿dataname,sample_num,huggingface_repo\r\nOpenMathReasoning,510163,https://huggingface.co/datasets/nvidia/OpenMathReasoning\r\nR1-Distill-SFT,502818,https://huggingface.co/datasets/ServiceNow-AI/R1-Distill-SFT\r\nInfinity_Instruct,446773,https://huggingface.co/datasets/BAAI/Infinity-Instruct\r\nOpenCoderReasoning,437768,https://huggingface.co/datasets/nvidia/OpenCodeReasoning\r\nChinese-Reasoning-Distil-Data,179037,https://huggingface.co/datasets/Mxode/Chinese-Reasoning-Distil-Data\r\nReasoning_Multi_subject_RLVR,172108,https://huggingface.co/datasets/punwaiw/multi-subject-rlvr-final-reasoning-traces\r\nReasoning_KodCode_V1_SFT_R1,163908,https://huggingface.co/datasets/KodCode/KodCode-V1-SFT-R1\r\nDeepMath103K,92886,https://huggingface.co/datasets/zwhe99/DeepMath-103K\r\nmedical-o1-reasoning-SFT,50245,https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT"
  },
  {
    "path": "data/onerec_data/README.md",
    "content": "# OneRec Data Processing Scripts\n\nThis directory contains data processing scripts for the OneRec project, converting raw data into LLM pretraining and SFT training formats.\n\n## Directory Structure\n\n```\ndata/\n├── pretrain/               # Pretrain data processing scripts\n│   ├── video_rec.py        # Video recommendation pretrain\n│   ├── user_profile.py     # User profile pretrain\n│   └── item_understand.py  # item understanding alignment pretrain\n├── sft/                    # SFT data processing scripts\n│   ├── video_rec.py        # Video recommendation\n│   ├── interactive_rec.py  # Interactive recommendation\n│   ├── label_cond_rec.py   # Label conditional recommendation\n│   ├── label_pred.py       # Label prediction (binary classification)\n│   ├── ad_rec.py           # Ad recommendation (cross-domain)\n│   ├── product_rec.py      # Product recommendation (cross-domain)\n│   ├── item_understand.py      # Item understand \n│   └── reco_reason.py      # Recommendation reasoning\n├── run.sh                  # Main execution script\n└── README.md\n```\n\n## Quick Start\n\n### 1. Configure Input Paths\n\nEdit `run.sh` to set the following paths:\n\n```bash\nINPUT_METADATA=\"path/to/onerec_bench_release.parquet\"\nPID2SID_MAPPING=\"path/to/video_ad_pid2sid.parquet\"\nPRODUCT_PID2SID_MAPPING=\"path/to/product_pid2sid.parquet\"\nCAPTION_INPUT=\"path/to/pid2caption.parquet\"\nOUTPUT_BASE_DIR=\"./output\"\n```\n\n### 2. Select Tasks to Run\n\nUncomment the tasks you want to run in `run.sh`:\n\n```bash\n# Pretrain tasks\nRUN_PRETRAIN_VIDEO_REC=1\nRUN_PRETRAIN_USER_PROFILE=1\nRUN_PRETRAIN_SID2CAPTION=1\n\n# SFT tasks\nRUN_SFT_VIDEO_REC=1\nRUN_SFT_INTERACTIVE_REC=1\n# ...\n```\n\n### 3. Run\n\n```bash\ncd data\nbash run.sh\n```\n\n## Task Descriptions\n\n### Pretrain Tasks\n\n| Task | Script | Description |\n|------|--------|-------------|\n| video_rec | `pretrain/video_rec.py` | Concatenate user history SID sequence with target SID sequence for sequence modeling pretrain |\n| user_profile | `pretrain/user_profile.py` | Use `inter_user_profile_with_sid` field as pretrain text |\n| item_understand | `pretrain/item_understand.py` | Build item understanding alignment data using various template formats |\n\n### SFT Tasks\n\n| Task | Script | Description |\n|------|--------|-------------|\n| video_rec | `sft/video_rec.py` | Predict next video based on user browsing history |\n| interactive_rec | `sft/interactive_rec.py` | Recommend content based on user profile and search keywords |\n| label_cond_rec | `sft/label_cond_rec.py` | Predict items by interaction type (like/follow/forward/etc.) 
|\n| label_pred | `sft/label_pred.py` | Binary classification: predict if user will watch a video for long |\n| ad_rec | `sft/ad_rec.py` | Cross-domain: predict ad clicks based on video and ad history |\n| product_rec | `sft/product_rec.py` | Cross-domain: predict product clicks based on video and product history |\n| item_understand | `sft/item_understand.py` | Generate video description from SID |\n| reco_reason | `sft/reco_reason.py` | Generate recommendation reasoning: analyze user interests and explain recommendations |\n\n## Output Format\n\n### Pretrain Format\n\n```json\n{\n  \"source\": \"RecIF_VideoRec_Pretrain\",\n  \"uuid\": \"xxx\",\n  \"segments\": [{\"type\": \"text\", \"text\": \"...\"}],\n  \"metadata\": {\"uid\": 123}\n}\n```\n\n### SFT Format\n\n```json\n{\n  \"source\": \"RecIF_VideoRec\",\n  \"uuid\": \"xxx\",\n  \"messages\": [\n    {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]},\n    {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]},\n    {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}\n  ],\n  \"metadata\": {\"uid\": 123}\n}\n```\n\n## SID Format\n\nAll scripts use a unified SID format:\n\n```\n<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>\n```\n\nWhere `c0`, `c1`, `c2` are triplet codes obtained from the `pid2sid` mapping table.\n\n## Dependencies\n\n- pandas\n- numpy\n- tqdm\n\n## Running Individual Scripts\n\nEach script can also be run independently:\n\n```bash\n# Example: Run video_rec SFT task\npython sft/video_rec.py \\\n    --input /path/to/metadata.parquet \\\n    --pid2sid /path/to/pid2sid.parquet \\\n    --output_dir ./output \\\n    --seed 42\n```\n\n## Notes\n\n1. All scripts process only `split=0` (training set) data by default\n2. Output files are named as `{task_type}_{task_name}.parquet`\n3. Cross-domain tasks (product_rec) require additional pid2sid mapping files\n"
  },
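  {
    "path": "examples/sid_roundtrip_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical file, not part of the pipeline): format and\nparse the unified SID string documented in data/onerec_data/README.md.\nFormatting mirrors the SID_FORMAT constant used by the processing scripts;\nthe regex-based parser is our own assumption for recovering triplet codes\nfrom a concatenated SID sequence.\n\"\"\"\n\nimport re\n\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nSID_PATTERN = re.compile(r'<\\|sid_begin\\|><s_a_(\\d+)><s_b_(\\d+)><s_c_(\\d+)><\\|sid_end\\|>')\n\n\ndef format_sid(code) -> str:\n    \"\"\"Render a (c0, c1, c2) triplet from the pid2sid mapping as an SID string.\"\"\"\n    return SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n\n\ndef parse_sids(text: str):\n    \"\"\"Extract all (c0, c1, c2) triplets from a concatenated SID sequence.\"\"\"\n    return [tuple(int(c) for c in m) for m in SID_PATTERN.findall(text)]\n\n\nif __name__ == \"__main__\":\n    sid = format_sid((12, 34, 56))\n    assert parse_sids(sid + sid) == [(12, 34, 56), (12, 34, 56)]\n    print(sid)\n"
  },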
  {
    "path": "data/onerec_data/pretrain/item_understand.py",
    "content": "\"\"\"\nItem Understand Pretrain Task\nInput: caption parquet (pid, dense_caption) + pid2sid parquet\nOutput: LLM Pretrain format parquet (segments)\n\nTask: Build pretrain data with SID and caption using various templates.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\n\n# Pretrain format templates\nPRETRAIN_TEMPLATES = [\n    # Format 1: JSON format\n    lambda sid, caption: json.dumps({\"视频ID\": sid, \"视频内容\": caption}, ensure_ascii=False),\n    # Format 2: Display format\n    lambda sid, caption: f\"视频{sid} 展示了以下内容：{caption}\",\n    # Format 3: Full description format\n    lambda sid, caption: f\"视频{sid} 的内容完整描述如下：{caption}\",\n]\n\n\n# ============== Core Functions ==============\ndef pid_to_sid(pid, pid2sid: dict) -> str:\n    \"\"\"Convert a single pid to SID string.\"\"\"\n    if pid not in pid2sid:\n        return \"\"\n    code = pid2sid[pid]\n    return SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n\n\ndef build_segments(sid: str, caption: str) -> str:\n    \"\"\"Build segments format JSON string for pretrain.\"\"\"\n    template = random.choice(PRETRAIN_TEMPLATES)\n    text = template(sid, caption)\n    segments = [{\"type\": \"text\", \"text\": text}]\n    return json.dumps(segments, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    pid = row['pid']\n    dense_caption = row['dense_caption']\n\n    # Check data validity\n    if dense_caption is None or (isinstance(dense_caption, float) and pd.isna(dense_caption)):\n        return None\n    if not dense_caption:\n        return None\n\n    # Convert pid to SID\n    sid = pid_to_sid(pid, pid2sid)\n    if not sid:\n        return None\n\n    return {\n        'source': 'RecIF_ItemUnderstand_Pretrain',\n        'uuid': str(uuid.uuid4()),\n        'segments': build_segments(sid, dense_caption),\n        'metadata': json.dumps({'pid': int(pid), 'sid': sid}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Item Understand Pretrain Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input caption parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load caption data\n    print(f\"Loading caption data from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. 
Save results\n    df_output = pd.DataFrame(results)\n    output_path = output_dir / 'train.parquet'\n    df_output.to_parquet(output_path, index=False)\n\n    print(f\"Saved: {output_path} ({len(df_output):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/pretrain/user_profile.py",
    "content": "\"\"\"\nUser Profile Pretrain Task\nInput: metadata parquet\nOutput: LLM Pretrain format parquet (segments)\n\nTask: Directly use inter_user_profile_with_sid as pretrain text.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n\ndef process_row(row) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    user_profile = row.get('inter_user_profile_with_sid')\n\n    # Check data validity\n    if user_profile is None or (isinstance(user_profile, float) and pd.isna(user_profile)):\n        return None\n    if not user_profile or not isinstance(user_profile, str):\n        return None\n\n    segments = [{\"type\": \"text\", \"text\": user_profile}]\n\n    return {\n        'source': 'RecIF_UserProfile_Pretrain',\n        'uuid': str(uuid.uuid4()),\n        'segments': json.dumps(segments, ensure_ascii=False),\n        'metadata': json.dumps({'uid': int(row['uid'])}, ensure_ascii=False)\n    }\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"User Profile Pretrain Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    args = parser.parse_args()\n\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row)\n        if result:\n            results.append(result)\n\n    # Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/pretrain/video_rec.py",
    "content": "\"\"\"\nVideo Recommendation Pretrain Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM Pretrain format parquet (segments instead of messages)\n\nTask: Directly concatenate history SIDs and target SIDs without prompts.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nHIST_MAX_LEN = 512\nTARGET_MAX_LEN = 10\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_segments(hist_sids: str, target_sids: str) -> str:\n    \"\"\"Build segments format JSON string for pretrain.\"\"\"\n    text = f\"{hist_sids}{target_sids}\"\n    segments = [{\"type\": \"text\", \"text\": text}]\n    return json.dumps(segments, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    hist_pids = row['hist_video_pid']\n    target_pids = row['target_video_pid']\n\n    # Check data validity\n    if hist_pids is None or (isinstance(hist_pids, float) and pd.isna(hist_pids)):\n        return None\n    if target_pids is None or (isinstance(target_pids, float) and pd.isna(target_pids)):\n        return None\n\n    # Truncate and convert to SID\n    hist_sids = pids_to_sids(hist_pids[-HIST_MAX_LEN:], pid2sid)\n    target_sids = pids_to_sids(target_pids[:TARGET_MAX_LEN], pid2sid)\n\n    if not hist_sids or not target_sids:\n        return None\n\n    return {\n        'source': 'RecIF_VideoRec_Pretrain',\n        'uuid': str(uuid.uuid4()),\n        'segments': build_segments(hist_sids, target_sids),\n        'metadata': json.dumps({'uid': int(row['uid'])}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Video Recommendation Pretrain Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    args = parser.parse_args()\n\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. 
Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/run.sh",
    "content": "#!/bin/bash\n# RecIF Data Processing Script\n# Generate all pretrain and SFT data\n\nset -e\n\n# ============== Task Selection ==============\n# Comment out tasks you don't want to run\n\n# Pretrain tasks\nRUN_PRETRAIN_VIDEO_REC=1\nRUN_PRETRAIN_USER_PROFILE=1\nRUN_PRETRAIN_ITEM_UNDERSTAND=1\n\n# SFT tasks\nRUN_SFT_VIDEO_REC=1\nRUN_SFT_INTERACTIVE_REC=1\nRUN_SFT_LABEL_COND_REC=1\nRUN_SFT_LABEL_PRED=1\nRUN_SFT_AD_REC=1\nRUN_SFT_PRODUCT_REC=1\nRUN_SFT_ITEM_UNDERSTAND=1\nRUN_SFT_REC_REASON=1\n\n# ============== Configuration ==============\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\nINPUT_METADATA=\"../../raw_data/onerec_data/onerec_bench_release.parquet\"\nPID2SID_MAPPING=\"../../raw_data/onerec_data/video_ad_pid2sid.parquet\"\nPRODUCT_PID2SID_MAPPING=\"../../raw_data/onerec_data/product_pid2sid.parquet\"\nCAPTION_INPUT=\"../../raw_data/onerec_data/pid2caption.parquet\"\nOUTPUT_BASE_DIR=\"../../output\"\n\nSEED=42\n\n# ============== Helper Function ==============\nrun_task() {\n    local task_type=$1\n    local task_name=$2\n    local script_path=$3\n    shift 3\n    local extra_args=\"$@\"\n\n    local output_file=\"${OUTPUT_BASE_DIR}/${task_type}_${task_name}.parquet\"\n    local temp_dir=$(mktemp -d)\n\n    echo \"  Output: ${output_file}\"\n    python3 \"${script_path}\" --output_dir \"${temp_dir}\" ${extra_args}\n\n    if [ -f \"${temp_dir}/train.parquet\" ]; then\n        mv \"${temp_dir}/train.parquet\" \"${output_file}\"\n    fi\n    rm -rf \"${temp_dir}\"\n}\n\n# ============== Main ==============\necho \"========================================\"\necho \"RecIF Data Processing\"\necho \"========================================\"\necho \"Metadata: ${INPUT_METADATA}\"\necho \"PID2SID: ${PID2SID_MAPPING}\"\necho \"Caption: ${CAPTION_INPUT}\"\necho \"Output: ${OUTPUT_BASE_DIR}\"\necho \"\"\n\nmkdir -p \"${OUTPUT_BASE_DIR}\"\n\n# ============== Pretrain Tasks ==============\necho \"========================================\"\necho \"Pretrain Tasks\"\necho \"========================================\"\n\nif [ \"${RUN_PRETRAIN_VIDEO_REC}\" = \"1\" ]; then\n    echo \"[pretrain] video_rec...\"\n    run_task \"pretrain\" \"video_rec\" \"${SCRIPT_DIR}/pretrain/video_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\"\nfi\n\nif [ \"${RUN_PRETRAIN_USER_PROFILE}\" = \"1\" ]; then\n    echo \"[pretrain] user_profile...\"\n    run_task \"pretrain\" \"user_profile\" \"${SCRIPT_DIR}/pretrain/user_profile.py\" \\\n        --input \"${INPUT_METADATA}\"\nfi\n\nif [ \"${RUN_PRETRAIN_ITEM_UNDERSTAND}\" = \"1\" ]; then\n    echo \"[pretrain] item_understand...\"\n    run_task \"pretrain\" \"item_understand\" \"${SCRIPT_DIR}/pretrain/item_understand.py\" \\\n        --input \"${CAPTION_INPUT}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\n# ============== SFT Tasks ==============\necho \"\"\necho \"========================================\"\necho \"SFT Tasks\"\necho \"========================================\"\n\nif [ \"${RUN_SFT_VIDEO_REC}\" = \"1\" ]; then\n    echo \"[sft] video_rec...\"\n    run_task \"sft\" \"video_rec\" \"${SCRIPT_DIR}/sft/video_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_INTERACTIVE_REC}\" = \"1\" ]; then\n    echo \"[sft] interactive_rec...\"\n    run_task \"sft\" \"interactive_rec\" \"${SCRIPT_DIR}/sft/interactive_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ 
\"${RUN_SFT_LABEL_COND_REC}\" = \"1\" ]; then\n    echo \"[sft] label_cond_rec...\"\n    run_task \"sft\" \"label_cond_rec\" \"${SCRIPT_DIR}/sft/label_cond_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_LABEL_PRED}\" = \"1\" ]; then\n    echo \"[sft] label_pred...\"\n    run_task \"sft\" \"label_pred\" \"${SCRIPT_DIR}/sft/label_pred.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_AD_REC}\" = \"1\" ]; then\n    echo \"[sft] ad_rec...\"\n    run_task \"sft\" \"ad_rec\" \"${SCRIPT_DIR}/sft/ad_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_PRODUCT_REC}\" = \"1\" ]; then\n    echo \"[sft] product_rec...\"\n    run_task \"sft\" \"product_rec\" \"${SCRIPT_DIR}/sft/product_rec.py\" \\\n        --input \"${INPUT_METADATA}\" --pid2sid \"${PID2SID_MAPPING}\" \\\n        --product_pid2sid \"${PRODUCT_PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_ITEM_UNDERSTAND}\" = \"1\" ]; then\n    echo \"[sft] item_understand...\"\n    run_task \"sft\" \"item_understand\" \"${SCRIPT_DIR}/sft/item_understand.py\" \\\n        --input \"${CAPTION_INPUT}\" --pid2sid \"${PID2SID_MAPPING}\" --seed ${SEED}\nfi\n\nif [ \"${RUN_SFT_REC_REASON}\" = \"1\" ]; then\n    echo \"[sft] rec_reason...\"\n    run_task \"sft\" \"rec_reason\" \"${SCRIPT_DIR}/sft/rec_reason.py\" \\\n        --input \"${INPUT_METADATA}\"\nfi\n\n# ============== Summary ==============\necho \"\"\necho \"========================================\"\necho \"Summary\"\necho \"========================================\"\nls -lh \"${OUTPUT_BASE_DIR}\"/*.parquet 2>/dev/null || echo \"No parquet files found\"\necho \"\"\necho \"Done!\"\n"
  },
  {
    "path": "data/onerec_data/sft/ad_rec.py",
    "content": "\"\"\"\nAd Recommendation Task (Cross-domain)\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Predict ad videos the user will click based on video watch history and ad click history.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nVIDEO_HIST_MAX_LEN = 100\nAD_HIST_MAX_LEN = 200\nTARGET_MAX_LEN = 10\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个智能广告推荐助手，能够根据用户的视频观看历史和广告点击行为，预测用户接下来可能点击的广告视频。\",\n    \"你是一个广告点击预测专家，擅长分析用户的观看习惯和广告点击偏好，预测用户的广告兴趣。\",\n    \"你是一个个性化广告推荐系统，能够基于用户的视频观看历史和广告点击记录，预测用户未来可能点击的广告。\",\n    \"你是一个用户行为分析助手，专注于理解用户的内容偏好和广告兴趣，推荐相关广告视频。\",\n    \"你是一个广告推荐引擎，通过学习用户的视频观看和广告点击历史，预测用户对广告的兴趣。\",\n]\n\n# Video watch history prompts (Chinese)\nVIDEO_WATCH_PROMPTS = [\n    \"用户观看过的视频：\",\n    \"用户浏览过的视频内容：\",\n    \"用户长时间观看的视频：\",\n    \"用户感兴趣的视频：\",\n]\n\n# Ad click history prompts (Chinese)\nAD_CLICK_PROMPTS = [\n    \"用户点击过的广告视频：\",\n    \"用户浏览过的广告视频：\",\n    \"用户感兴趣的广告视频：\",\n    \"用户历史广告点击记录：\",\n]\n\n# Task prompts (Chinese)\nTASK_PROMPTS = [\n    \"请根据用户的观看和广告点击历史，预测用户接下来可能点击的广告视频。\",\n    \"基于以上记录，推荐用户可能感兴趣并点击的广告视频。\",\n    \"分析用户的行为偏好，预测用户下一步会点击哪些广告视频。\",\n    \"根据用户的视频观看和广告点击习惯，推荐用户可能点击的广告视频。\",\n    \"请推荐用户接下来可能感兴趣并点击的广告视频。\",\n]\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_messages(user_content: str, task_prompt: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_content + \"\\n\" + task_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    hist_ad_pids = row['hist_ad_pid']\n    target_ad_pids = row['target_ad_pid']\n\n    # Check data validity\n    if target_ad_pids is None or (isinstance(target_ad_pids, float) and pd.isna(target_ad_pids)):\n        return None\n    if len(target_ad_pids) == 0:\n        return None\n\n    # Build user content parts\n    user_content_parts = []\n\n    # 1. 
Process video watch history (long-view videos)\n    hist_longview_video_list = row['hist_longview_video_list']\n    if hist_longview_video_list is not None and not (isinstance(hist_longview_video_list, float) and pd.isna(hist_longview_video_list)):\n        if len(hist_longview_video_list) > 0:\n            # Keep the most recent videos (rightmost in the list)\n            video_sids = pids_to_sids(hist_longview_video_list[-VIDEO_HIST_MAX_LEN:], pid2sid)\n            if video_sids:\n                video_prompt = random.choice(VIDEO_WATCH_PROMPTS)\n                user_content_parts.append(f\"{video_prompt}{video_sids}\")\n\n    # 2. Process ad click history\n    if hist_ad_pids is not None and not (isinstance(hist_ad_pids, float) and pd.isna(hist_ad_pids)):\n        if len(hist_ad_pids) > 0:\n            # Keep the most recent ads (rightmost in the list)\n            ad_sids = pids_to_sids(hist_ad_pids[-AD_HIST_MAX_LEN:], pid2sid)\n            if ad_sids:\n                ad_prompt = random.choice(AD_CLICK_PROMPTS)\n                user_content_parts.append(f\"{ad_prompt}{ad_sids}\")\n\n    # Need at least one type of history\n    if not user_content_parts:\n        return None\n\n    # 3. Process target ad videos\n    answer = pids_to_sids(target_ad_pids[:TARGET_MAX_LEN], pid2sid)\n    if not answer:\n        return None\n\n    # Build final messages\n    user_content = \"\\n\".join(user_content_parts)\n    task_prompt = random.choice(TASK_PROMPTS)\n\n    return {\n        'source': 'RecIF_AdRec',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(user_content, task_prompt, answer),\n        'metadata': json.dumps({'uid': int(row['uid'])}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Ad Recommendation Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/interactive_rec.py",
    "content": "\"\"\"\nInteractive Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Given user profile (inter_user_profile_with_sid) and search keyword,\npredict items the user will interact with.\n\"\"\"\n\nimport pandas as pd\nimport json\nimport uuid\nimport random\nimport argparse\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nTARGET_MAX_LEN = 10\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个智能推荐助手，能够根据用户的兴趣画像和当前对话需求，精准推荐用户可能感兴趣的内容。\",\n    \"你是一个个性化推荐专家，擅长理解用户画像和对话意图，提供精准的内容推荐。\",\n    \"你是一个对话式推荐系统，能够基于用户的兴趣特征和搜索意图，推荐最相关的内容。\",\n    \"你是一个交互式内容推荐引擎，专注于理解用户画像和对话上下文，提供个性化推荐。\",\n    \"你是一个智能内容顾问，通过分析用户兴趣和对话关键词，推荐符合需求的内容。\",\n    \"你是一位资深的推荐算法专家，精通用户画像分析和个性化匹配，能够为每位用户提供量身定制的内容推荐。\",\n    \"你是一个具备深度学习能力的推荐引擎，可以准确捕捉用户兴趣点，并结合实时需求给出最优推荐方案。\",\n    \"你是一个智能化的内容匹配系统，擅长从海量信息中筛选出与用户画像和查询意图高度契合的内容。\",\n    \"你是一个AI驱动的推荐助理，能够综合分析用户的历史偏好和当前需求，提供精准且多元化的内容推荐。\",\n    \"你是一个智慧型推荐顾问，通过理解用户的兴趣图谱和语义意图，实现千人千面的个性化推荐。\",\n]\n\n# User prompts (Chinese)\nUSER_PROMPTS = [\n    \"用户画像：\\n{user_profile}\\n\\n用户查询：{keyword}\\n\\n请推荐相关内容。\",\n    \"用户兴趣：\\n{user_profile}\\n\\n搜索关键词：{keyword}\\n\\n请根据用户需求推荐内容。\",\n    \"用户特征：\\n{user_profile}\\n\\n当前需求：{keyword}\\n\\n请提供个性化推荐。\",\n    \"【用户画像】\\n{user_profile}\\n\\n【用户输入】\\n{keyword}\\n\\n基于以上信息，推荐合适的内容。\",\n    \"用户的兴趣偏好：\\n{user_profile}\\n\\n用户正在寻找：{keyword}\\n\\n请推荐最相关的内容。\",\n    \"这是用户的兴趣画像：\\n{user_profile}\\n\\n用户现在想了解关于\\\"{keyword}\\\"的内容，能帮忙推荐一些吗？\",\n    \"用户平时喜欢：\\n{user_profile}\\n\\n现在用户搜索了\\\"{keyword}\\\"，请根据用户的兴趣推荐相关内容。\",\n    \"根据用户画像显示，用户兴趣如下：\\n{user_profile}\\n\\n用户刚刚输入了\\\"{keyword}\\\"，麻烦推荐一些合适的内容。\",\n    \"用户的兴趣领域包括：\\n{user_profile}\\n\\n用户正在查找\\\"{keyword}\\\"相关的内容，请给出推荐。\",\n    \"用户的个人画像如下：\\n{user_profile}\\n\\n用户搜索了\\\"{keyword}\\\"这个关键词，请推荐一些相关的内容。\",\n]\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_messages(user_profile: str, keyword: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n    user_prompt = random.choice(USER_PROMPTS).format(user_profile=user_profile, keyword=keyword)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> list:\n    \"\"\"Process a single row of data. 
Returns a list of results (one per keyword).\"\"\"\n    user_profile = row.get('inter_user_profile_with_sid')\n    inter_keyword_to_items = row['inter_keyword_to_items']\n\n    # Check user profile validity\n    if user_profile is None or (isinstance(user_profile, float) and pd.isna(user_profile)):\n        return []\n    if not user_profile or not isinstance(user_profile, str):\n        return []\n\n    # Check keyword_to_items validity\n    if inter_keyword_to_items is None or (isinstance(inter_keyword_to_items, float) and pd.isna(inter_keyword_to_items)):\n        return []\n\n    # Parse JSON string if needed\n    if isinstance(inter_keyword_to_items, str):\n        try:\n            inter_keyword_to_items = json.loads(inter_keyword_to_items)\n        except json.JSONDecodeError:\n            return []\n\n    if not isinstance(inter_keyword_to_items, dict) or len(inter_keyword_to_items) == 0:\n        return []\n\n    results = []\n    for keyword, item_ids in inter_keyword_to_items.items():\n        if not keyword or not item_ids:\n            continue\n\n        # Convert target items to SIDs\n        answer = pids_to_sids(item_ids[:TARGET_MAX_LEN], pid2sid)\n        if not answer:\n            continue\n\n        result = {\n            'source': 'RecIF_InteractiveRec',\n            'uuid': str(uuid.uuid4()),\n            'messages': build_messages(user_profile, keyword, answer),\n            'metadata': json.dumps({'uid': int(row['uid']), 'keyword': keyword}, ensure_ascii=False)\n        }\n        results.append(result)\n\n    return results\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Interactive Recommendation Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        row_results = process_row(row, pid2sid)\n        for result in row_results:\n            results.append(result)\n\n    # 4. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/item_understand.py",
    "content": "\"\"\"\nItem Understand Task\nInput: caption parquet (pid, dense_caption) + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Given a video SID, generate its description/caption.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一名视频描述生成器，请根据下面的视频token生成视频描述。\",\n    \"你是一个专业的视频内容分析助手，能够理解视频token并生成准确的描述。\",\n    \"你是一位视频理解专家，擅长将视频token转换为详细的文字描述。\",\n    \"作为视频内容解析助手，你需要根据视频token提供精准的内容描述。\",\n    \"你是一个智能视频解说员，可以根据视频token创建生动的描述。\",\n    \"你具备理解视频token并生成高质量描述的能力。\",\n    \"你是视频内容描述专家，能够将视频token转化为易懂的文字说明。\",\n    \"作为AI视频分析助手，你可以根据视频token生成详细准确的描述。\",\n]\n\n# User prompts (Chinese)\nUSER_PROMPTS = [\n    \"请描述 {sid} 的内容\",\n    \"这段视频 {sid} 展示了什么？\",\n    \"请解释 {sid} 中的内容\",\n    \"能否说明 {sid} 里发生了什么？\",\n    \"请分析 {sid} 的具体内容\",\n    \"{sid} 这个视频讲的是什么？\",\n    \"请详细描述 {sid}\",\n    \"告诉我 {sid} 的内容是什么\",\n    \"请为 {sid} 生成描述\",\n    \"{sid} 包含哪些内容？\",\n    \"请说明视频 {sid} 的主要内容\",\n    \"描述一下 {sid} 中展现的场景\",\n    \"{sid} 这段内容是关于什么的？\",\n    \"请解读 {sid} 的视频内容\",\n    \"能描述下 {sid} 吗？\",\n    \"{sid} 里面有什么？\",\n    \"请对 {sid} 进行内容说明\",\n    \"这个 {sid} 是什么内容？\",\n    \"分析 {sid} 并给出描述\",\n    \"请阐述 {sid} 的内容细节\",\n]\n\n\n# ============== Core Functions ==============\ndef pid_to_sid(pid, pid2sid: dict) -> str:\n    \"\"\"Convert a single pid to SID string.\"\"\"\n    if pid not in pid2sid:\n        return \"\"\n    code = pid2sid[pid]\n    return SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n\n\ndef build_messages(sid: str, caption: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n    user_prompt = random.choice(USER_PROMPTS).format(sid=sid)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": caption}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    pid = row['pid']\n    dense_caption = row['dense_caption']\n\n    # Check data validity\n    if dense_caption is None or (isinstance(dense_caption, float) and pd.isna(dense_caption)):\n        return None\n    if not dense_caption:\n        return None\n\n    # Convert pid to SID\n    sid = pid_to_sid(pid, pid2sid)\n    if not sid:\n        return None\n\n    return {\n        'source': 'RecIF_ItemUnderstand',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(sid, dense_caption),\n        'metadata': json.dumps({'pid': int(pid), 'sid': sid}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Item Understand Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input caption parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    
random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load caption data\n    print(f\"Loading caption data from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. Save results\n    df_output = pd.DataFrame(results)\n    output_path = output_dir / 'train.parquet'\n    df_output.to_parquet(output_path, index=False)\n\n    print(f\"Saved: {output_path} ({len(df_output):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/label_cond_rec.py",
    "content": "\"\"\"\nLabel Conditional Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Predict items that users will interact with under specific behavior types\n(longview/like/follow/forward/not_interested).\n\"\"\"\n\nimport pandas as pd\nimport numpy as np\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nINTERACTION_MAX_LEN = 10  # Max items per interaction type\n\n# Interaction types\nINTERACTION_TYPES = [\"longview\", \"like\", \"follow\", \"forward\", \"not_interested\"]\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个智能推荐助手，能够根据用户对不同内容的互动行为，精准推荐用户可能感兴趣的下一个内容。\",\n    \"你是一个内容推荐专家，擅长分析用户的互动模式，预测用户的内容偏好。\",\n    \"你是一个个性化推荐系统，能够基于用户的历史互动行为，预测用户未来可能产生的互动。\",\n    \"你是一个用户行为分析助手，专注于理解用户的兴趣偏好，并推荐相关内容。\",\n]\n\n# Interaction type descriptions (Chinese)\nINTERACTION_PROMPTS = {\n    \"longview\": [\"用户长时观看过以下内容：\", \"用户完整观看过的内容：\", \"用户深度浏览过以下内容：\"],\n    \"like\": [\"用户点赞过以下内容：\", \"用户喜欢的内容：\", \"获得用户点赞的内容：\"],\n    \"follow\": [\"用户关注过以下内容的作者：\", \"用户关注了这些内容的创作者：\"],\n    \"forward\": [\"用户转发过以下内容：\", \"用户分享过的内容：\", \"用户向他人推荐的内容：\"],\n    \"not_interested\": [\"用户表示不感兴趣的内容：\", \"用户标记为不感兴趣的内容：\"],\n}\n\n# Task prompts for each interaction type (Chinese)\nTASK_PROMPTS = {\n    \"longview\": [\"请根据用户的互动行为，推荐用户可能会长时观看的内容。\", \"基于以上互动记录，预测用户会完整观看的内容。\"],\n    \"like\": [\"请根据用户的互动行为，推荐用户可能会点赞的内容。\", \"基于用户的互动偏好，预测用户会给哪些内容点赞。\"],\n    \"follow\": [\"请根据用户的互动行为，推荐用户可能会关注其作者的内容。\", \"基于用户的关注偏好，预测用户会关注哪些内容的创作者。\"],\n    \"forward\": [\"请根据用户的互动行为，推荐用户可能会转发的内容。\", \"基于用户的分享习惯，预测用户会转发的内容。\"],\n    \"not_interested\": [\"请根据用户的互动行为，预测用户可能会表示不感兴趣的内容。\", \"基于用户的偏好，预测用户可能会标记为不感兴趣的内容。\"],\n}\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_messages(user_content: str, task_prompt: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_content + \"\\n\" + task_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    hist_pids = row['hist_video_pid']\n    target_pids = row['target_video_pid']\n\n    # Check data validity\n    if hist_pids is None or (isinstance(hist_pids, float) and pd.isna(hist_pids)):\n        return None\n    if target_pids is None or (isinstance(target_pids, float) and pd.isna(target_pids)):\n        return None\n\n    # Build user interaction history description\n    user_content_parts = []\n    for interaction in INTERACTION_TYPES:\n        hist_col = f'hist_video_{interaction}'\n        if hist_col not in row or row[hist_col] is 
None:\n            continue\n\n        mask = row[hist_col]\n        if isinstance(mask, float) and pd.isna(mask):\n            continue\n\n        # Filter pids with interaction based on mask\n        if len(mask) == len(hist_pids):\n            mask_array = np.array(mask)\n            pids_array = np.array(hist_pids)\n            interaction_pids = pids_array[mask_array == 1].tolist()\n            interaction_pids = interaction_pids[-INTERACTION_MAX_LEN:]\n\n            if interaction_pids:\n                sids = pids_to_sids(interaction_pids, pid2sid)\n                if sids:\n                    prompt = random.choice(INTERACTION_PROMPTS[interaction])\n                    user_content_parts.append(f\"{prompt}{sids}\")\n\n    if not user_content_parts:\n        return None\n\n    # Randomly select a target interaction type with data\n    available_targets = []\n    for interaction in INTERACTION_TYPES:\n        target_col = f'target_video_{interaction}'\n        if target_col not in row or row[target_col] is None:\n            continue\n\n        target_mask = row[target_col]\n        if isinstance(target_mask, float) and pd.isna(target_mask):\n            continue\n\n        if len(target_mask) == len(target_pids) and sum(1 for x in target_mask if x == 1) > 0:\n            available_targets.append(interaction)\n\n    if not available_targets:\n        return None\n\n    # Randomly select an interaction type as target\n    selected_interaction = random.choice(available_targets)\n    target_col = f'target_video_{selected_interaction}'\n    target_mask = row[target_col]\n\n    # Filter target_pids\n    target_mask_array = np.array(target_mask)\n    target_pids_array = np.array(target_pids)\n    filtered_target_pids = target_pids_array[target_mask_array == 1].tolist()\n    filtered_target_pids = filtered_target_pids[:INTERACTION_MAX_LEN]\n\n    # Convert to SID\n    answer = pids_to_sids(filtered_target_pids, pid2sid)\n    if not answer:\n        return None\n\n    # Build final messages\n    user_content = \"\\n\".join(user_content_parts)\n    task_prompt = random.choice(TASK_PROMPTS[selected_interaction])\n\n    return {\n        'source': 'RecIF_LabelCondRec',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(user_content, task_prompt, answer),\n        'metadata': json.dumps({'uid': int(row['uid']), 'target_interaction': selected_interaction}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Label Conditional Recommendation Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. 
Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/label_pred.py",
    "content": "\"\"\"\nLabel Prediction Task (Point-wise Classification)\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Predict whether a user will \"longview\" (watch for a long time) a candidate video.\nBinary classification: \"是\" (yes) or \"否\" (no).\n\"\"\"\n\nimport pandas as pd\nimport numpy as np\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nINTERACTION_MAX_LEN = 10\nTARGET_MAX_LEN = 10\n\n# Interaction types\nINTERACTION_TYPES = [\"longview\", \"like\", \"follow\", \"forward\", \"not_interested\"]\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个内容推荐专家，擅长分析用户的互动模式，预测用户的内容偏好。\",\n    \"你是一个个性化推荐系统，能够基于用户的历史互动行为，预测用户未来可能产生的互动。\",\n    \"你是一个用户行为分析助手，专注于理解用户的兴趣偏好，并推荐相关内容。\",\n    \"你是一个内容推荐引擎，通过学习用户的互动历史，预测用户对新内容的反应。\",\n]\n\n# Interaction type descriptions (Chinese)\nINTERACTION_PROMPTS = {\n    \"longview\": [\"用户长时观看过以下内容：\", \"用户完整观看过的内容：\", \"用户深度浏览过以下内容：\"],\n    \"like\": [\"用户点赞过以下内容：\", \"用户喜欢的内容：\", \"获得用户点赞的内容：\"],\n    \"follow\": [\"用户关注过以下内容的作者：\", \"用户关注了这些内容的创作者：\"],\n    \"forward\": [\"用户转发过以下内容：\", \"用户分享过的内容：\", \"用户向他人推荐的内容：\"],\n    \"not_interested\": [\"用户表示不感兴趣的内容：\", \"用户标记为不感兴趣的内容：\"],\n}\n\n# Classification question prompts (Chinese)\nCLASSIFICATION_QUESTIONS = [\n    \"请判断用户是否会长时观看视频{candidate_sid}？\",\n    \"用户会完整观看视频{candidate_sid}吗？\",\n    \"预测用户是否会深度观看视频{candidate_sid}。\",\n    \"视频{candidate_sid}能够吸引用户长时间观看吗？\",\n    \"用户会花时间仔细观看视频{candidate_sid}吗？\",\n]\n\n# Classification answers\nPOSITIVE_ANSWER = \"是\"\nNEGATIVE_ANSWER = \"否\"\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef pid_to_sid(pid, pid2sid: dict) -> str:\n    \"\"\"Convert a single pid to SID string.\"\"\"\n    if pid in pid2sid:\n        code = pid2sid[pid]\n        return SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n    return \"\"\n\n\ndef build_messages(user_content: str, question: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_content + \"\\n\" + question}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> list:\n    \"\"\"Process a single row of data. 
Returns a list of results (one per candidate video).\"\"\"\n    hist_pids = row['hist_video_pid']\n    target_pids = row['target_video_pid']\n\n    # Check data validity\n    if hist_pids is None or (isinstance(hist_pids, float) and pd.isna(hist_pids)):\n        return []\n    if target_pids is None or (isinstance(target_pids, float) and pd.isna(target_pids)):\n        return []\n    if len(target_pids) == 0:\n        return []\n\n    # Build user interaction history description\n    user_content_parts = []\n    for interaction in INTERACTION_TYPES:\n        hist_col = f'hist_video_{interaction}'\n        if hist_col not in row or row[hist_col] is None:\n            continue\n\n        mask = row[hist_col]\n        if isinstance(mask, float) and pd.isna(mask):\n            continue\n\n        # Filter pids with interaction based on mask\n        if len(mask) == len(hist_pids):\n            mask_array = np.array(mask)\n            pids_array = np.array(hist_pids)\n            interaction_pids = pids_array[mask_array == 1].tolist()\n            interaction_pids = interaction_pids[-INTERACTION_MAX_LEN:]\n\n            if interaction_pids:\n                sids = pids_to_sids(interaction_pids, pid2sid)\n                if sids:\n                    prompt = random.choice(INTERACTION_PROMPTS[interaction])\n                    user_content_parts.append(f\"{prompt}{sids}\")\n\n    if not user_content_parts:\n        return []\n\n    # Get target longview mask\n    target_longview_col = 'target_video_longview'\n    if target_longview_col not in row or row[target_longview_col] is None:\n        return []\n\n    target_longview_mask = row[target_longview_col]\n    if isinstance(target_longview_mask, float) and pd.isna(target_longview_mask):\n        return []\n\n    if len(target_longview_mask) != len(target_pids):\n        return []\n\n    # Limit target candidates\n    limited_target_pids = target_pids[:TARGET_MAX_LEN]\n    limited_longview_mask = target_longview_mask[:TARGET_MAX_LEN]\n\n    # Build user content\n    user_content = \"\\n\".join(user_content_parts)\n\n    # Generate one sample per candidate video\n    results = []\n    for candidate_pid, label in zip(limited_target_pids, limited_longview_mask):\n        label = int(label)\n\n        # Convert candidate pid to SID\n        candidate_sid = pid_to_sid(candidate_pid, pid2sid)\n        if not candidate_sid:\n            continue\n\n        # Build question with candidate SID\n        question = random.choice(CLASSIFICATION_QUESTIONS).format(candidate_sid=candidate_sid)\n\n        # Determine answer based on label\n        answer = POSITIVE_ANSWER if label == 1 else NEGATIVE_ANSWER\n\n        result = {\n            'source': 'RecIF_LabelPred',\n            'uuid': str(uuid.uuid4()),\n            'messages': build_messages(user_content, question, answer),\n            'metadata': json.dumps({\n                'uid': int(row['uid']),\n                'label': label\n            }, ensure_ascii=False)\n        }\n        results.append(result)\n\n    return results\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Label Prediction Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', 
type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    positive_count, negative_count = 0, 0\n\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        row_results = process_row(row, pid2sid)\n        for result in row_results:\n            metadata = json.loads(result['metadata'])\n            label = metadata['label']\n            results.append(result)\n            if label == 1:\n                positive_count += 1\n            else:\n                negative_count += 1\n\n    # 4. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows, pos={positive_count:,}, neg={negative_count:,})\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/product_rec.py",
    "content": "\"\"\"\nProduct Recommendation Task (Cross-domain)\nInput: metadata parquet + video_pid2sid parquet + product_pid2sid parquet\nOutput: LLM SFT training format parquet\n\nTask: Predict product the user will click based on video watch history and product click history.\nNote: Video and product use different pid2sid mappings (different domains).\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nVIDEO_HIST_MAX_LEN = 100\nPRODUCT_HIST_MAX_LEN = 100\nTARGET_MAX_LEN = 10\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个智能跨域推荐助手，能够根据用户观看的视频内容和历史购物行为，预测用户接下来可能点击的商品。\",\n    \"你是一个跨域推荐专家，擅长分析用户的观看习惯和购物偏好，预测用户的商品兴趣。\",\n    \"你是一个个性化推荐系统，能够基于用户的视频观看历史和购物记录，预测用户未来可能购买的商品。\",\n    \"你是一个用户行为分析助手，专注于理解用户的内容偏好和购物兴趣，推荐相关商品。\",\n    \"你是一个跨域推荐引擎，通过学习用户的视频观看和购物历史，预测用户对商品的兴趣。\",\n]\n\n# Video watch history prompts (Chinese)\nVIDEO_WATCH_PROMPTS = [\n    \"用户观看过的视频：\",\n    \"用户浏览过的视频内容：\",\n    \"用户长时间观看的视频：\",\n    \"用户感兴趣的视频：\",\n]\n\n# Product click history prompts (Chinese)\nPRODUCT_CLICK_PROMPTS = [\n    \"用户点击过的商品：\",\n    \"用户浏览过的商品：\",\n    \"用户感兴趣的商品：\",\n    \"用户历史购物记录：\",\n]\n\n# Task prompts (Chinese)\nTASK_PROMPTS = [\n    \"请根据用户的观看和购物历史，预测用户接下来可能点击的商品。\",\n    \"基于以上记录，推荐用户可能感兴趣并点击的商品。\",\n    \"分析用户的行为偏好，预测用户下一步会点击哪些商品。\",\n    \"根据用户的视频观看和购物习惯，推荐用户可能点击的商品。\",\n    \"请推荐用户接下来可能感兴趣并点击的商品。\",\n]\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_messages(user_content: str, task_prompt: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_content + \"\\n\" + task_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, video_pid2sid: dict, product_pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    hist_product_pids = row['hist_goods_pid']\n    target_product_pids = row['target_goods_pid']\n\n    # Check data validity\n    if target_product_pids is None or (isinstance(target_product_pids, float) and pd.isna(target_product_pids)):\n        return None\n    if len(target_product_pids) == 0:\n        return None\n\n    # Build user content parts\n    user_content_parts = []\n\n    # 1. 
Process video watch history (long-view videos, use video_pid2sid)\n    hist_longview_video_list = row['hist_longview_video_list']\n    if hist_longview_video_list is not None and not (isinstance(hist_longview_video_list, float) and pd.isna(hist_longview_video_list)):\n        if len(hist_longview_video_list) > 0:\n            # Keep the most recent videos (rightmost in the list)\n            video_sids = pids_to_sids(hist_longview_video_list[-VIDEO_HIST_MAX_LEN:], video_pid2sid)\n            if video_sids:\n                video_prompt = random.choice(VIDEO_WATCH_PROMPTS)\n                user_content_parts.append(f\"{video_prompt}{video_sids}\")\n\n    # 2. Process product click history (use product_pid2sid)\n    if hist_product_pids is not None and not (isinstance(hist_product_pids, float) and pd.isna(hist_product_pids)):\n        if len(hist_product_pids) > 0:\n            # Keep the most recent products (rightmost in the list)\n            product_sids = pids_to_sids(hist_product_pids[-PRODUCT_HIST_MAX_LEN:], product_pid2sid)\n            if product_sids:\n                product_prompt = random.choice(PRODUCT_CLICK_PROMPTS)\n                user_content_parts.append(f\"{product_prompt}{product_sids}\")\n\n    # Need at least one type of history\n    if not user_content_parts:\n        return None\n\n    # 3. Process target product (use product_pid2sid)\n    answer = pids_to_sids(target_product_pids[:TARGET_MAX_LEN], product_pid2sid)\n    if not answer:\n        return None\n\n    # Build final messages\n    user_content = \"\\n\".join(user_content_parts)\n    task_prompt = random.choice(TASK_PROMPTS)\n\n    return {\n        'source': 'RecIF_ProductRec',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(user_content, task_prompt, answer),\n        'metadata': json.dumps({'uid': int(row['uid'])}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Product Recommendation Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='Video pid2sid mapping parquet path')\n    parser.add_argument('--product_pid2sid', type=str, required=True, help='Product pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load video pid2sid mapping\n    print(f\"Loading video pid2sid from {args.pid2sid}...\")\n    df_video_pid2sid = pd.read_parquet(args.pid2sid)\n    video_pid2sid = dict(zip(df_video_pid2sid['pid'], df_video_pid2sid['sid']))\n    print(f\"  Loaded {len(video_pid2sid):,} video mappings\")\n\n    # 2. Load product pid2sid mapping\n    print(f\"Loading product pid2sid from {args.product_pid2sid}...\")\n    df_product_pid2sid = pd.read_parquet(args.product_pid2sid)\n    product_pid2sid = dict(zip(df_product_pid2sid['pid'], df_product_pid2sid['sid']))\n    print(f\"  Loaded {len(product_pid2sid):,} product mappings\")\n\n    # 3. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 4. 
Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row, video_pid2sid, product_pid2sid)\n        if result:\n            results.append(result)\n\n    # 5. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/rec_reason.py",
    "content": "\"\"\"\nRecommendation Reasoning Task\nInput: rec_reason parquet (user_profile_with_sid, gsu_caption, target_caption, cot, etc.)\nOutput: LLM SFT training format parquet\n\nTask: Given user profile, watch history captions, and target video caption,\ngenerate reasoning for why the user would click the target video.\n\"\"\"\n\nimport pandas as pd\nimport argparse\nimport json\nimport uuid\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nUSER_PROMPT_TEMPLATE = \"\"\"{user_profile}\n\n[历史观看视频内容]\n{gsu_caption}\n\n[用户点击下一个视频内容]\n{target_video_caption}\n\n请在思考的时候分析总结用户兴趣，重点根据用户观看视频内容进行推理，给出下一个点击的理由及视频的基本内容，下一个点击视频需要与给定的一致，注意虽然给出了下一个点击的视频但应该体现出推理得到而不是直接知道的。\n最后再用一段话输出精炼的推理过程。\n生成格式严格按照两大部分，标题分别是：预测分析；精炼推理。\n\"\"\"\n\n\n# ============== Core Functions ==============\ndef build_messages(user_prompt: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    messages = [\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef is_valid_str(val) -> bool:\n    \"\"\"Check if value is a valid non-empty string.\"\"\"\n    if val is None:\n        return False\n    if isinstance(val, float) and pd.isna(val):\n        return False\n    if isinstance(val, str) and val.strip():\n        return True\n    return False\n\n\ndef process_row(row) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    user_profile = row.get('inter_user_profile_with_sid')\n    gsu_caption = row.get('reco_gsu_caption')\n    target_caption = row.get('reco_target_caption')\n    answer = row.get('reco_cot')\n\n    # Check data validity\n    if not is_valid_str(user_profile):\n        return None\n    if not is_valid_str(target_caption):\n        return None\n    if not is_valid_str(answer):\n        return None\n\n    gsu_caption = str(gsu_caption) \n\n    # Build user prompt\n    user_prompt = USER_PROMPT_TEMPLATE.format(\n        user_profile=user_profile,\n        gsu_caption=gsu_caption,\n        target_video_caption=target_caption\n    )\n\n    metadata = {\n        'uid': int(row['uid']) if 'uid' in row else None,\n        'target_pid': int(row['target_pid']) if 'target_pid' in row else None,\n    }\n\n    return {\n        'source': 'RecIF_RecoReason',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(user_prompt, answer),\n        'metadata': json.dumps(metadata, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Recommendation Reasoning Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input rec_reason parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    args = parser.parse_args()\n\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # Load data\n    print(f\"Loading data from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row.get('split', 0) != 0:\n            continue\n        result = process_row(row)\n        if result:\n            results.append(result)\n\n    # Save 
results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/onerec_data/sft/video_rec.py",
    "content": "\"\"\"\nVideo Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\"\"\"\n\nimport pandas as pd\nimport numpy as np\nimport argparse\nimport json\nimport uuid\nimport random\nfrom pathlib import Path\nfrom tqdm import tqdm\n\n# ============== Configuration ==============\nSID_FORMAT = '<|sid_begin|><s_a_{c0}><s_b_{c1}><s_c_{c2}><|sid_end|>'\nHIST_MAX_LEN = 512\nTARGET_MAX_LEN = 10\n\n# System prompts (Chinese)\nSYSTEM_PROMPTS = [\n    \"你是一个智能推荐助手，能够根据用户的浏览历史预测用户可能感兴趣的下一个内容。\",\n    \"你是一名内容推荐专家，擅长分析用户浏览行为并预测用户偏好。\",\n    \"作为推荐系统助手，你需要根据用户历史浏览记录推荐合适的内容。\",\n    \"你具备理解用户浏览模式并生成个性化推荐的能力。\",\n    \"你是一个专业的内容推荐助手，能够根据用户过往浏览记录推荐相关内容。\",\n]\n\n# User prompts (Chinese)\nUSER_PROMPTS = [\n    \"根据以下用户浏览记录，请预测用户接下来可能观看的内容：\\n{query}\",\n    \"用户浏览了以下内容：\\n{query}\\n请预测用户的下一个观看意向。\",\n    \"以下是用户的浏览历史：\\n{query}\\n请推荐用户可能感兴趣的下一个内容。\",\n    \"用户历史浏览记录如下：\\n{query}\\n分析并预测用户接下来会观看什么内容。\",\n    \"{query}\\n根据上述浏览记录，推测用户的下一个观看目标。\",\n]\n\n\n# ============== Core Functions ==============\ndef pids_to_sids(pids, pid2sid: dict) -> str:\n    \"\"\"Convert a list of pids to SID string.\"\"\"\n    if pids is None or (isinstance(pids, float) and pd.isna(pids)):\n        return \"\"\n    sids = []\n    for pid in pids:\n        if pid in pid2sid:\n            code = pid2sid[pid]\n            sid = SID_FORMAT.format(c0=code[0], c1=code[1], c2=code[2])\n            sids.append(sid)\n    return ''.join(sids)\n\n\ndef build_messages(query: str, answer: str) -> str:\n    \"\"\"Build messages format JSON string.\"\"\"\n    system_prompt = random.choice(SYSTEM_PROMPTS)\n    user_prompt = random.choice(USER_PROMPTS).format(query=query)\n\n    messages = [\n        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n        {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_prompt}]},\n        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": answer}]}\n    ]\n    return json.dumps(messages, ensure_ascii=False)\n\n\ndef process_row(row, pid2sid: dict) -> dict:\n    \"\"\"Process a single row of data.\"\"\"\n    hist_pids = row['hist_video_pid']\n    target_pids = row['target_video_pid']\n\n    # Check data validity\n    if hist_pids is None or (isinstance(hist_pids, float) and pd.isna(hist_pids)):\n        return None\n    if target_pids is None or (isinstance(target_pids, float) and pd.isna(target_pids)):\n        return None\n\n    # Truncate and convert to SID (keep most recent history)\n    query = pids_to_sids(hist_pids[-HIST_MAX_LEN:], pid2sid)\n    answer = pids_to_sids(target_pids[:TARGET_MAX_LEN], pid2sid)\n\n    if not query or not answer:\n        return None\n\n    return {\n        'source': 'RecIF_VideoRec',\n        'uuid': str(uuid.uuid4()),\n        'messages': build_messages(query, answer),\n        'metadata': json.dumps({'uid': int(row['uid'])}, ensure_ascii=False)\n    }\n\n\n# ============== Main Function ==============\ndef main():\n    parser = argparse.ArgumentParser(description=\"Video Recommendation Task Data Processing\")\n    parser.add_argument('--input', type=str, required=True, help='Input metadata parquet path')\n    parser.add_argument('--pid2sid', type=str, required=True, help='pid2sid mapping parquet path')\n    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    
output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # 1. Load pid2sid mapping\n    print(f\"Loading pid2sid from {args.pid2sid}...\")\n    df_pid2sid = pd.read_parquet(args.pid2sid)\n    pid2sid = dict(zip(df_pid2sid['pid'], df_pid2sid['sid']))\n    print(f\"  Loaded {len(pid2sid):,} mappings\")\n\n    # 2. Load metadata\n    print(f\"Loading metadata from {args.input}...\")\n    df = pd.read_parquet(args.input)\n    print(f\"  Loaded {len(df):,} rows\")\n\n    # 3. Process data (train only, split=0)\n    print(\"Processing...\")\n    results = []\n    for _, row in tqdm(df.iterrows(), total=len(df)):\n        if row['split'] != 0:\n            continue\n        result = process_row(row, pid2sid)\n        if result:\n            results.append(result)\n\n    # 4. Save results\n    df_train = pd.DataFrame(results)\n    train_path = output_dir / 'train.parquet'\n    df_train.to_parquet(train_path, index=False)\n\n    print(f\"Saved: {train_path} ({len(df_train):,} rows)\")\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/prepare_distillation.sh",
    "content": "#!/bin/bash\n# Data sampling script: Sample specified number of samples from general dataset for on-policy distillation\n\nset -e\n\n# Configuration\nINPUT_PATH=\"../raw_data/general_text/sft\"\nOUTPUT_FILE=\"../output/onpolicy_distillation.parquet\"\nTEMP_FILE=\"../output/onpolicy_distillation_temp.parquet\"\nNUM_SAMPLES=200000\nSEED=42\nENGINE=\"pyarrow\"\n\n# Check if paths exist\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\nif [ ! -e \"${INPUT_PATH}\" ]; then\n    echo \"Error: Input path does not exist: ${INPUT_PATH}\"\n    exit 1\nfi\n\n# Step 1: Sample data\necho \"Step 1: Sampling data...\"\npython3 \"${SCRIPT_DIR}/scripts/sample_data.py\" \\\n    --input \"${INPUT_PATH}\" \\\n    --output \"${TEMP_FILE}\" \\\n    --num_samples \"${NUM_SAMPLES}\" \\\n    --seed \"${SEED}\" \\\n    --engine \"${ENGINE}\"\n\n# Step 2: Fix unicode encoding\necho \"\"\necho \"Step 2: Fixing unicode encoding...\"\npython3 \"${SCRIPT_DIR}/scripts/parquet_unicode_fix.py\" \\\n    --input \"${TEMP_FILE}\" \\\n    --output \"${OUTPUT_FILE}\" \\\n    --engine \"${ENGINE}\"\n\n# Clean up temporary files\nif [ -f \"${TEMP_FILE}\" ]; then\n    rm \"${TEMP_FILE}\"\n    echo \"Temporary files cleaned up\"\nfi\n\necho \"\"\necho \"Processing completed! Output file: ${OUTPUT_FILE}\"\n\n"
  },
  {
    "path": "data/prepare_pretrain.sh",
    "content": "#!/bin/bash\n# Data splitting script: Merge general text and recommendation data, then split by every 1000 samples\n\nset -e\n\n# Configuration\n# Both general and onerec use datasets starting with pretrain\nGENERAL_TEXT_PATH=\"../raw_data/general_text/pretrain\"\nREC_DATA_PATH=\"../raw_data/onerec_data\"\nOUTPUT_DIR=\"../output/split_data_pretrain\"\nMAX_ROWS=1000\nENGINE=\"pyarrow\"\n\n# Check if paths exist\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\nif [ ! -e \"${GENERAL_TEXT_PATH}\" ]; then\n    echo \"Error: General text path does not exist: ${GENERAL_TEXT_PATH}\"\n    exit 1\nfi\n\nif [ ! -e \"${REC_DATA_PATH}\" ]; then\n    echo \"Error: Recommendation data path does not exist: ${REC_DATA_PATH}\"\n    exit 1\nfi\n\n# Execute\npython3 \"${SCRIPT_DIR}/scripts/split_data.py\" \\\n    --general_text_path \"${GENERAL_TEXT_PATH}\" \\\n    --rec_data_path \"${REC_DATA_PATH}\" \\\n    --output_dir \"${OUTPUT_DIR}\" \\\n    --max_rows \"${MAX_ROWS}\" \\\n    --engine \"${ENGINE}\"\n\n"
  },
  {
    "path": "data/prepare_rl.sh",
    "content": "#!/bin/bash\n# RL data splitting script: Merge multiple RL task datasets and split into training and test sets\n\nset -e\n\n# Configuration\n# onerec dataset output path, rl uses datasets starting with sft\nREC_DATA_PATH=\"../output\"\n\n# Tasks that RL depends on\nVIDEO_REC=${REC_DATA_PATH}/sft_video_rec.parquet\nAD_REC=${REC_DATA_PATH}/sft_ad_rec.parquet\nPRODUCT_REC=${REC_DATA_PATH}/sft_product_rec.parquet\nINTERACTIVE_REC=${REC_DATA_PATH}/sft_interactive_rec.parquet\nLABEL_COND_REC=${REC_DATA_PATH}/sft_label_cond_rec.parquet\n\n# Output configuration\nOUTPUT_DIR=\"../output/rl_data\"\nTEST_SIZE=1000\nSEED=42\nENGINE=\"pyarrow\"\n\n# Get script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\n# Define all task files to process\ndeclare -a TASK_FILES=(\n    \"${VIDEO_REC}\"\n    \"${AD_REC}\"\n    \"${PRODUCT_REC}\"\n    \"${INTERACTIVE_REC}\"\n    \"${LABEL_COND_REC}\"\n)\n\n# Check if input files exist\necho \"Checking input files...\"\nMISSING_FILES=0\nfor file in \"${TASK_FILES[@]}\"; do\n    if [ ! -f \"${file}\" ]; then\n        echo \"Warning: File does not exist: ${file}\"\n        MISSING_FILES=$((MISSING_FILES + 1))\n    fi\ndone\n\nif [ ${MISSING_FILES} -eq ${#TASK_FILES[@]} ]; then\n    echo \"Error: All input files do not exist\"\n    exit 1\nfi\n\n# Execute train_test_split, merge all files and process them together\necho \"\"\necho \"Starting RL data splitting...\"\necho \"==========================================\"\necho \"Input files:\"\nfor file in \"${TASK_FILES[@]}\"; do\n    if [ -f \"${file}\" ]; then\n        echo \"  - ${file}\"\n    fi\ndone\necho \"Output directory: ${OUTPUT_DIR}\"\necho \"Test set size: ${TEST_SIZE}\"\necho \"==========================================\"\n\npython3 \"${SCRIPT_DIR}/scripts/train_test_split.py\" \\\n    --input_files \"${TASK_FILES[@]}\" \\\n    --test_size \"${TEST_SIZE}\" \\\n    --output_dir \"${OUTPUT_DIR}\" \\\n    --seed \"${SEED}\" \\\n    --engine \"${ENGINE}\" \\\n    --test_filename \"test.parquet\" \\\n    --train_filename \"train.parquet\"\n\necho \"\"\necho \"==========================================\"\necho \"RL data processing completed!\"\necho \"Output directory: ${OUTPUT_DIR}\"\necho \"  - train.parquet (training set)\"\necho \"  - test.parquet (test set)\"\necho \"==========================================\"\n"
  },
  {
    "path": "data/prepare_sft.sh",
    "content": "#!/bin/bash\n# Data splitting script: Merge general text and recommendation data, then split by every 1000 samples\n\nset -e\n\n# Configuration\n# Both general and onerec use datasets starting with sft\nGENERAL_TEXT_PATH=\"../raw_data/general_text/sft\"\nREC_DATA_PATH=\"../raw_data/onerec_data\"\nOUTPUT_DIR=\"../output/split_data_sft\"\nMAX_ROWS=1000\nENGINE=\"pyarrow\"\n\n# Check if paths exist\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\nif [ ! -e \"${GENERAL_TEXT_PATH}\" ]; then\n    echo \"Error: General text path does not exist: ${GENERAL_TEXT_PATH}\"\n    exit 1\nfi\n\nif [ ! -e \"${REC_DATA_PATH}\" ]; then\n    echo \"Error: Recommendation data path does not exist: ${REC_DATA_PATH}\"\n    exit 1\nfi\n\n# Execute\npython3 \"${SCRIPT_DIR}/scripts/split_data.py\" \\\n    --general_text_path \"${GENERAL_TEXT_PATH}\" \\\n    --rec_data_path \"${REC_DATA_PATH}\" \\\n    --output_dir \"${OUTPUT_DIR}\" \\\n    --max_rows \"${MAX_ROWS}\" \\\n    --engine \"${ENGINE}\"\n\n"
  },
  {
    "path": "data/scripts/parquet_unicode_fix.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Parquet Unicode Fix Script\n\nFix unicode Chinese garbled text issues in messages and segments fields of parquet files.\nSupports single file or batch directory processing.\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport os\nimport sys\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nimport pandas as pd\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\ndef decode_unicode_json(json_str: Optional[Union[str, bytes]]) -> Optional[str]:\n    \"\"\"Decode unicode characters in JSON string.\n\n    Args:\n        json_str: JSON string that may contain unicode encoding\n\n    Returns:\n        Decoded JSON string\n    \"\"\"\n    if json_str is None or pd.isna(json_str):\n        return json_str\n\n    # Handle bytes type\n    if isinstance(json_str, bytes):\n        json_str = json_str.decode('utf-8', errors='ignore')\n\n    # If already a string and doesn't contain unicode escape sequences, return directly\n    if isinstance(json_str, str) and '\\\\u' not in json_str:\n        return json_str\n    \n    try:\n        # JSON load (automatically decode unicode)\n        json_obj = json.loads(json_str)\n\n        # JSON dump with ensure_ascii disabled (preserve Chinese characters)\n        decoded_str = json.dumps(\n            json_obj,\n            ensure_ascii=False,  # Key: don't convert Chinese to unicode\n            indent=None,         # Keep original compact format\n            separators=(',', ':')  # Keep original separator format\n        )\n        return decoded_str\n\n    except json.JSONDecodeError:\n        # Return original string when JSON parsing fails\n        return json_str\n    except Exception as e:\n        logger.debug(f\"Error processing JSON string: {e}\")\n        return json_str\n\ndef find_parquet_files(directory: str, recursive: bool = True) -> List[str]:\n    \"\"\"\n    Find all parquet files in the directory\n\n    Args:\n        directory: Directory path\n        recursive: Whether to recursively search subdirectories, default True\n\n    Returns:\n        List of parquet file paths\n    \"\"\"\n    parquet_files = []\n    directory_path = Path(directory)\n    \n    if not directory_path.exists():\n        raise FileNotFoundError(f\"Directory does not exist: {directory}\")\n\n    if not directory_path.is_dir():\n        raise ValueError(f\"Path is not a directory: {directory}\")\n\n    pattern = \"**/*.parquet\" if recursive else \"*.parquet\"\n    parquet_files = [str(p) for p in directory_path.glob(pattern) if p.is_file()]\n\n    logger.info(f\"Found {len(parquet_files)} parquet files in directory {directory}\")\n    return sorted(parquet_files)\n\ndef get_output_path(input_path: str, output_base: str, input_base: Optional[str] = None) -> str:\n    \"\"\"\n    Generate output path based on input path and output base path\n\n    Args:\n        input_path: Input file path\n        output_base: Output base path (file or directory)\n        input_base: Input base path (to maintain relative path structure), if None uses input file's directory\n\n    Returns:\n        Output file path\n    \"\"\"\n    input_path_obj = Path(input_path)\n    output_base_obj = Path(output_base)\n\n    # If output base path is a file, return directly\n    if output_base_obj.is_file() or (not output_base_obj.exists() and not 
output_base_obj.suffix == ''):\n        return str(output_base_obj)\n\n    # If output base path is a directory\n    if input_base:\n        # Maintain relative path structure\n        input_base_obj = Path(input_base)\n        try:\n            relative_path = input_path_obj.relative_to(input_base_obj)\n            output_path = output_base_obj / relative_path\n        except ValueError:\n            # If unable to calculate relative path, use filename\n            output_path = output_base_obj / input_path_obj.name\n    else:\n        # Use input file's directory as base\n        output_path = output_base_obj / input_path_obj.name\n\n    return str(output_path)\n\ndef process_parquet_file(\n    input_path: str,\n    output_path: str,\n    engine: str = 'pyarrow',\n    fields: Optional[List[str]] = None\n) -> None:\n    \"\"\"Process parquet file to fix unicode Chinese garbled text in specified fields.\n\n    Args:\n        input_path: Input parquet file path\n        output_path: Output parquet file path\n        engine: Engine for reading/writing parquet, options: 'pyarrow' or 'fastparquet'\n        fields: List of fields to process, defaults to ['messages', 'segments']\n    \"\"\"\n    if not os.path.exists(input_path):\n        raise FileNotFoundError(f\"Input file does not exist: {input_path}\")\n\n    if fields is None:\n        fields = ['messages', 'segments']\n\n    # Read parquet file\n    logger.info(f\"Reading file: {input_path}\")\n    df = pd.read_parquet(input_path, engine=engine)\n    logger.info(f\"Total rows: {len(df)}\")\n    \n    # Check and process fields\n    processed_fields = []\n    for field in fields:\n        if field in df.columns:\n            logger.debug(f\"Processing field: {field}\")\n            df[field] = df[field].apply(decode_unicode_json)\n            processed_fields.append(field)\n        else:\n            logger.debug(f\"Field does not exist, skipping: {field}\")\n    \n    if not processed_fields:\n        logger.warning(f\"No fields to process found: {fields}\")\n        # If no fields to process, copy file directly\n        if input_path != output_path:\n            import shutil\n            Path(output_path).parent.mkdir(parents=True, exist_ok=True)\n            shutil.copy2(input_path, output_path)\n            logger.info(f\"File copied to: {output_path}\")\n        return\n\n    logger.info(f\"Processed fields: {processed_fields}\")\n\n    # Save processed file\n    Path(output_path).parent.mkdir(parents=True, exist_ok=True)\n    df.to_parquet(\n        output_path,\n        engine=engine,\n        index=False,\n        compression='snappy'\n    )\n    logger.info(f\"File saved successfully: {output_path}\")\n\ndef process_directory(input_dir: str, output_dir: str, engine: str = 'pyarrow', recursive: bool = True, overwrite: bool = False) -> None:\n    \"\"\"\n    Batch process all parquet files in the directory\n\n    Args:\n        input_dir: Input directory path\n        output_dir: Output directory path\n        engine: Parquet processing engine\n        recursive: Whether to recursively process subdirectories\n        overwrite: Whether to overwrite original files (if True, output_dir is ignored and input files are overwritten directly)\n    \"\"\"\n    # Find all parquet files\n    parquet_files = find_parquet_files(input_dir, recursive=recursive)\n\n    if not parquet_files:\n        logger.warning(f\"No parquet files found in directory {input_dir}\")\n        return\n\n    # Create output directory (if needed)\n    if not 
overwrite:\n        output_path_obj = Path(output_dir)\n        output_path_obj.mkdir(parents=True, exist_ok=True)\n        logger.info(f\"Output directory: {output_dir}\")\n    \n    # Process each file\n    total_files = len(parquet_files)\n    success_count = 0\n    fail_count = 0\n\n    for input_file in tqdm(parquet_files, desc=\"Processing files\"):\n        try:\n            if overwrite:\n                # Overwrite original file\n                output_file = input_file\n            else:\n                # Generate output path, maintain directory structure\n                output_file = get_output_path(input_file, output_dir, input_dir)\n                # Ensure output directory exists\n                Path(output_file).parent.mkdir(parents=True, exist_ok=True)\n            \n            process_parquet_file(input_file, output_file, engine)\n            success_count += 1\n\n        except Exception as e:\n            fail_count += 1\n            logger.error(f\"File processing failed: {input_file}, error: {e}\", exc_info=True)\n            continue\n\n    # Output statistics\n    logger.info(f\"\\n{'='*60}\")\n    logger.info(f\"Batch processing completed!\")\n    logger.info(f\"Total files: {total_files}\")\n    logger.info(f\"Success: {success_count}\")\n    logger.info(f\"Failed: {fail_count}\")\n    logger.info(f\"{'='*60}\")\n\ndef main():\n    # Parse command line arguments\n    parser = argparse.ArgumentParser(\n        description='Process unicode Chinese garbled text in messages and segments fields of parquet files (supports single file or batch directory processing)'\n    )\n    parser.add_argument(\n        '-i', '--input',\n        required=True,\n        help='Input parquet file path or directory path (required)'\n    )\n    parser.add_argument(\n        '-o', '--output',\n        required=True,\n        help='Output parquet file path or directory path (required)'\n    )\n    parser.add_argument(\n        '-e', '--engine',\n        choices=['pyarrow', 'fastparquet'],\n        default='pyarrow',\n        help='Parquet processing engine, default uses pyarrow'\n    )\n    parser.add_argument(\n        '--no-recursive',\n        action='store_true',\n        help='When processing directory, do not recursively process subdirectories (only process files in current directory)'\n    )\n    parser.add_argument(\n        '--overwrite',\n        action='store_true',\n        help='Overwrite original files (only effective when input is directory, will ignore output path)'\n    )\n    \n    args = parser.parse_args()\n\n    # Execute processing\n    try:\n        input_path = Path(args.input)\n\n        if not input_path.exists():\n            logger.error(f\"Input path does not exist: {args.input}\")\n            exit(1)\n        \n        # Determine if input is file or directory\n        if input_path.is_file():\n            # Single file processing mode\n            logger.info(\"Single file processing mode\")\n            if Path(args.output).is_dir():\n                # If output is directory, create file with same name in directory\n                output_file = Path(args.output) / input_path.name\n            else:\n                output_file = args.output\n            \n            process_parquet_file(\n                input_path=str(input_path),\n                output_path=str(output_file),\n                engine=args.engine\n            )\n            logger.info(\"All operations completed!\")\n\n        elif input_path.is_dir():\n            # Directory batch 
processing mode\n            logger.info(\"Directory batch processing mode\")\n            if args.overwrite:\n                logger.info(\"Will overwrite original files\")\n                process_directory(\n                    input_dir=str(input_path),\n                    output_dir=\"\",  # Will not be used\n                    engine=args.engine,\n                    recursive=not args.no_recursive,\n                    overwrite=True\n                )\n            else:\n                output_path = Path(args.output)\n                if output_path.exists() and output_path.is_file():\n                    logger.error(f\"When input is directory, output should also be directory, but output path is file: {args.output}\")\n                    exit(1)\n                process_directory(\n                    input_dir=str(input_path),\n                    output_dir=str(output_path),\n                    engine=args.engine,\n                    recursive=not args.no_recursive,\n                    overwrite=False\n                )\n            logger.info(\"All operations completed!\")\n        else:\n            logger.error(f\"Input path is neither file nor directory: {args.input}\")\n            exit(1)\n            \n    except KeyboardInterrupt:\n        logger.info(\"\\nOperation cancelled by user\")\n        sys.exit(1)\n    except Exception as e:\n        logger.error(f\"Program execution failed: {e}\", exc_info=True)\n        sys.exit(1)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "data/scripts/sample_data.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Data Sampling Script\n\nSample specified number of samples from one or more paths (directories or files) containing parquet files,\nand save as a single parquet file.\n\"\"\"\n\nimport argparse\nimport logging\nimport random\nimport sys\nfrom pathlib import Path\nfrom typing import List\n\nimport pandas as pd\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n\ndef find_parquet_files(directory: str, recursive: bool = True) -> List[str]:\n    \"\"\"Find all parquet files in the directory.\n\n    Args:\n        directory: Directory path\n        recursive: Whether to recursively search subdirectories\n\n    Returns:\n        List of parquet file paths\n    \"\"\"\n    dir_path = Path(directory)\n    if not dir_path.exists():\n        raise FileNotFoundError(f\"Directory does not exist: {directory}\")\n\n    if not dir_path.is_dir():\n        raise ValueError(f\"Path is not a directory: {directory}\")\n    \n    pattern = \"**/*.parquet\" if recursive else \"*.parquet\"\n    parquet_files = [str(p) for p in dir_path.glob(pattern) if p.is_file()]\n    \n    return sorted(parquet_files)\n\n\ndef collect_parquet_files(input_paths: List[str], recursive: bool = True) -> List[str]:\n    \"\"\"Collect all parquet file paths.\n\n    Args:\n        input_paths: List of input paths (can be files or directories)\n        recursive: Whether to recursively search subdirectories\n\n    Returns:\n        List of parquet file paths\n    \"\"\"\n    all_files = []\n    \n    for input_path in input_paths:\n        path = Path(input_path)\n\n        if not path.exists():\n            logger.warning(f\"Path does not exist, skipping: {input_path}\")\n            continue\n\n        if path.is_file():\n            if path.suffix.lower() == '.parquet':\n                all_files.append(str(path))\n            else:\n                logger.warning(f\"Not a parquet file, skipping: {input_path}\")\n        elif path.is_dir():\n            files = find_parquet_files(str(path), recursive=recursive)\n            all_files.extend(files)\n        else:\n            logger.warning(f\"Unknown path type, skipping: {input_path}\")\n\n    return sorted(list(set(all_files)))  # Remove duplicates and sort\n\n\ndef load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow') -> pd.DataFrame:\n    \"\"\"Load all parquet files and merge them.\n\n    Args:\n        file_paths: List of parquet file paths\n        engine: Parquet engine, 'pyarrow' or 'fastparquet'\n\n    Returns:\n        Merged DataFrame\n    \"\"\"\n    if not file_paths:\n        logger.warning(\"No parquet files found\")\n        return pd.DataFrame()\n\n    logger.info(f\"Found {len(file_paths)} parquet files, starting to load...\")\n\n    dataframes = []\n    for file_path in tqdm(file_paths, desc=\"Loading files\"):\n        try:\n            df = pd.read_parquet(file_path, engine=engine)\n            logger.debug(f\"  Loaded {file_path}: {len(df)} rows\")\n            dataframes.append(df)\n        except Exception as e:\n            logger.error(f\"  Failed to load {file_path}: {e}\")\n            continue\n    \n    if not dataframes:\n        logger.warning(\"No files loaded successfully\")\n        return pd.DataFrame()\n\n    # Merge all DataFrames\n    logger.info(\"Merging all data...\")\n    combined_df = 
pd.concat(dataframes, ignore_index=True)\n    logger.info(f\"Merge completed, total {len(combined_df)} rows\")\n\n    return combined_df\n\n\ndef sample_dataframe(df: pd.DataFrame, num_samples: int, seed: int = None) -> pd.DataFrame:\n    \"\"\"Sample specified number of samples from DataFrame.\n\n    Args:\n        df: DataFrame to sample from\n        num_samples: Number of samples\n        seed: Random seed\n\n    Returns:\n        Sampled DataFrame\n    \"\"\"\n    if len(df) == 0:\n        logger.warning(\"DataFrame is empty, cannot sample\")\n        return pd.DataFrame()\n\n    if num_samples <= 0:\n        raise ValueError(f\"num_samples must be greater than 0, current value: {num_samples}\")\n\n    total_rows = len(df)\n\n    if num_samples >= total_rows:\n        logger.warning(f\"Sample size ({num_samples}) is greater than or equal to total rows ({total_rows}), returning all data\")\n        return df.copy()\n    \n    # Set random seed\n    if seed is not None:\n        random.seed(seed)\n        logger.info(f\"Using random seed: {seed}\")\n\n    # Random sampling\n    logger.info(f\"Sampling {num_samples} rows from {total_rows} rows...\")\n    sampled_indices = random.sample(range(total_rows), num_samples)\n    sampled_df = df.iloc[sampled_indices].copy()\n\n    logger.info(f\"Sampling completed, total {len(sampled_df)} rows\")\n\n    return sampled_df\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Sample specified number of samples from one or more paths containing parquet files, and save as a single parquet file'\n    )\n    parser.add_argument(\n        '--input',\n        type=str,\n        nargs='+',\n        required=True,\n        help='Input paths (can be files or directories), multiple paths can be specified'\n    )\n    parser.add_argument(\n        '--output',\n        type=str,\n        required=True,\n        help='Output parquet file path'\n    )\n    parser.add_argument(\n        '--num_samples',\n        type=int,\n        required=True,\n        help='Number of samples'\n    )\n    parser.add_argument(\n        '--seed',\n        type=int,\n        default=None,\n        help='Random seed (optional)'\n    )\n    parser.add_argument(\n        '--engine',\n        choices=['pyarrow', 'fastparquet'],\n        default='pyarrow',\n        help='Parquet processing engine (default: pyarrow)'\n    )\n    parser.add_argument(\n        '--no-recursive',\n        action='store_true',\n        help='Do not recursively search for files in subdirectories'\n    )\n    \n    args = parser.parse_args()\n\n    # Validate parameters\n    if args.num_samples <= 0:\n        logger.error(f\"num_samples must be greater than 0, current value: {args.num_samples}\")\n        sys.exit(1)\n    \n    try:\n        # 1. Collect all parquet files\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 1: Collecting parquet files...\")\n        parquet_files = collect_parquet_files(\n            args.input,\n            recursive=not args.no_recursive\n        )\n\n        if not parquet_files:\n            logger.error(\"No parquet files found\")\n            sys.exit(1)\n\n        logger.info(f\"Found {len(parquet_files)} parquet files\")\n        \n        # 2. 
Load all files\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 2: Loading parquet files...\")\n        combined_df = load_all_parquet_files(parquet_files, engine=args.engine)\n\n        if len(combined_df) == 0:\n            logger.error(\"No data loaded\")\n            sys.exit(1)\n        \n        # 3. Sample data\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 3: Sampling data...\")\n        sampled_df = sample_dataframe(\n            combined_df,\n            num_samples=args.num_samples,\n            seed=args.seed\n        )\n\n        if len(sampled_df) == 0:\n            logger.error(\"Sampled data is empty\")\n            sys.exit(1)\n        \n        # 4. Save results\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 4: Saving results...\")\n        output_path = Path(args.output)\n        output_path.parent.mkdir(parents=True, exist_ok=True)\n        \n        sampled_df.to_parquet(\n            output_path,\n            engine='pyarrow',\n            index=False,\n            compression='snappy'\n        )\n\n        logger.info(f\"Results saved to: {output_path}\")\n\n        # 5. Output statistics\n        logger.info(\"=\" * 60)\n        logger.info(\"Processing completed!\")\n        logger.info(f\"Input files: {len(parquet_files)}\")\n        logger.info(f\"Original data rows: {len(combined_df)}\")\n        logger.info(f\"Sampled rows: {len(sampled_df)}\")\n        logger.info(f\"Output file: {output_path}\")\n        logger.info(\"=\" * 60)\n        \n    except KeyboardInterrupt:\n        logger.info(\"\\nOperation cancelled by user\")\n        sys.exit(1)\n    except Exception as e:\n        logger.error(f\"Program execution failed: {e}\", exc_info=True)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
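A minimal invocation sketch for `data/scripts/sample_data.py`; the flags match the script's argparse definition above, but the paths and sample count are illustrative:

```bash
# Draw a reproducible 5,000-row sample from a directory plus one extra file
python3 data/scripts/sample_data.py \
    --input ./raw_parquet_dir ./extra/part_x.parquet \
    --output ./sampled/sample_5k.parquet \
    --num_samples 5000 \
    --seed 42
```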
  {
    "path": "data/scripts/split_data.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Data splitting script\n\nMerge general text data and recommendation data, then split into multiple files with 1000 samples each.\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport sys\nfrom pathlib import Path\nfrom typing import List\n\nimport pandas as pd\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n\ndef find_parquet_files(directory: str, recursive: bool = True) -> List[str]:\n    \"\"\"Find all parquet files in the directory.\n\n    Args:\n        directory: Directory path\n        recursive: Whether to recursively search subdirectories\n\n    Returns:\n        List of parquet file paths\n    \"\"\"\n    dir_path = Path(directory)\n    if not dir_path.exists():\n        raise FileNotFoundError(f\"Directory does not exist: {directory}\")\n\n    if not dir_path.is_dir():\n        raise ValueError(f\"Path is not a directory: {directory}\")\n    \n    pattern = \"**/*.parquet\" if recursive else \"*.parquet\"\n    parquet_files = [str(p) for p in dir_path.glob(pattern) if p.is_file()]\n    \n    return sorted(parquet_files)\n\n\ndef load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow') -> pd.DataFrame:\n    \"\"\"Load and merge all parquet files.\n\n    Args:\n        file_paths: List of parquet file paths\n        engine: Parquet engine, 'pyarrow' or 'fastparquet'\n\n    Returns:\n        Merged DataFrame\n    \"\"\"\n    if not file_paths:\n        logger.warning(\"No parquet files found\")\n        return pd.DataFrame()\n\n    logger.info(f\"Found {len(file_paths)} parquet files, starting to load...\")\n\n    dataframes = []\n    for file_path in tqdm(file_paths, desc=\"Loading files\"):\n        try:\n            df = pd.read_parquet(file_path, engine=engine)\n            logger.debug(f\"  Loaded {file_path}: {len(df)} rows\")\n            dataframes.append(df)\n        except Exception as e:\n            logger.error(f\"  Failed to load {file_path}: {e}\")\n            continue\n\n    if not dataframes:\n        logger.warning(\"No files loaded successfully\")\n        return pd.DataFrame()\n\n    # Merge all DataFrames\n    logger.info(\"Merging all data...\")\n    combined_df = pd.concat(dataframes, ignore_index=True)\n    logger.info(f\"Merge complete, total {len(combined_df)} rows\")\n\n    return combined_df\n\n\ndef split_dataframe(df: pd.DataFrame, max_rows: int, output_dir: str, prefix: str = \"part\") -> List[str]:\n    \"\"\"Split DataFrame into multiple files by fixed number of rows.\n\n    Args:\n        df: DataFrame to split\n        max_rows: Maximum number of rows per file\n        output_dir: Output directory\n        prefix: Output file prefix\n\n    Returns:\n        List of output file paths\n    \"\"\"\n    if len(df) == 0:\n        logger.warning(\"DataFrame is empty, no need to split\")\n        return []\n\n    if max_rows <= 0:\n        raise ValueError(f\"max_rows must be greater than 0, current value: {max_rows}\")\n    \n    # Create output directory\n    output_dir_path = Path(output_dir)\n    output_dir_path.mkdir(parents=True, exist_ok=True)\n\n    # Calculate number of files needed\n    total_rows = len(df)\n    num_chunks = (total_rows + max_rows - 1) // max_rows  # Round up\n    logger.info(f\"Splitting data into {num_chunks} files (max {max_rows} rows per file)\")\n\n    # Use fixed 5-digit 
format to ensure consistent file naming\n    # Format: part-00000-of-00010.parquet\n    num_digits = 5\n\n    # Split and save\n    output_files = []\n    for chunk_idx in tqdm(range(num_chunks), desc=\"Splitting files\"):\n        start_idx = chunk_idx * max_rows\n        end_idx = min(start_idx + max_rows, total_rows)\n\n        # Extract data chunk\n        chunk_df = df.iloc[start_idx:end_idx]\n\n        # Generate output filename, format: part-00000-of-00010.parquet\n        output_filename = f\"{prefix}-{chunk_idx:0{num_digits}d}-of-{num_chunks:0{num_digits}d}.parquet\"\n        output_path = output_dir_path / output_filename\n\n        # Save file\n        chunk_df.to_parquet(\n            output_path,\n            engine='pyarrow',\n            index=False,\n            compression='snappy'\n        )\n\n        output_files.append(str(output_path))\n        logger.debug(f\"  Saved file {chunk_idx + 1}/{num_chunks}: {output_path} (rows {start_idx} to {end_idx - 1})\")\n\n    logger.info(f\"Successfully split into {len(output_files)} files\")\n    return output_files\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Merge general text data and recommendation data, then split into multiple files with 1000 samples each'\n    )\n    parser.add_argument(\n        '--general_text_path',\n        type=str,\n        required=True,\n        help='General text data path (directory or file)'\n    )\n    parser.add_argument(\n        '--rec_data_path',\n        type=str,\n        required=True,\n        help='Recommendation data path (directory or file)'\n    )\n    parser.add_argument(\n        '--output_dir',\n        type=str,\n        required=True,\n        help='Output directory path'\n    )\n    parser.add_argument(\n        '--max_rows',\n        type=int,\n        default=1000,\n        help='Maximum number of rows per file (default: 1000)'\n    )\n    parser.add_argument(\n        '--engine',\n        choices=['pyarrow', 'fastparquet'],\n        default='pyarrow',\n        help='Parquet processing engine (default: pyarrow)'\n    )\n    parser.add_argument(\n        '--no-recursive',\n        action='store_true',\n        help='Do not recursively search for files in subdirectories'\n    )\n    \n    args = parser.parse_args()\n\n    # Validate parameters\n    if args.max_rows <= 0:\n        logger.error(f\"max_rows must be greater than 0, current value: {args.max_rows}\")\n        sys.exit(1)\n    \n    try:\n        # 1. Find all parquet files\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 1: Finding general text data files...\")\n        general_text_path = Path(args.general_text_path)\n        if general_text_path.is_file():\n            general_text_files = [str(general_text_path)]\n        else:\n            general_text_files = find_parquet_files(\n                args.general_text_path,\n                recursive=not args.no_recursive\n            )\n        logger.info(f\"Found {len(general_text_files)} general text files\")\n\n        logger.info(\"Step 2: Finding recommendation data files...\")\n        rec_data_path = Path(args.rec_data_path)\n        if rec_data_path.is_file():\n            rec_data_files = [str(rec_data_path)]\n        else:\n            rec_data_files = find_parquet_files(\n                args.rec_data_path,\n                recursive=not args.no_recursive\n            )\n        logger.info(f\"Found {len(rec_data_files)} recommendation data files\")\n        \n        # 2. 
Load all files\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 3: Loading general text data...\")\n        general_text_df = load_all_parquet_files(general_text_files, engine=args.engine)\n\n        logger.info(\"Step 4: Loading recommendation data...\")\n        rec_data_df = load_all_parquet_files(rec_data_files, engine=args.engine)\n        \n        # 3. Merge data\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 5: Merging data...\")\n        if len(general_text_df) == 0 and len(rec_data_df) == 0:\n            logger.error(\"No data loaded\")\n            sys.exit(1)\n\n        if len(general_text_df) == 0:\n            combined_df = rec_data_df\n            logger.info(\"Using only recommendation data\")\n        elif len(rec_data_df) == 0:\n            combined_df = general_text_df\n            logger.info(\"Using only general text data\")\n        else:\n            combined_df = pd.concat([general_text_df, rec_data_df], ignore_index=True)\n            logger.info(f\"Merge complete: general text {len(general_text_df)} rows + recommendation data {len(rec_data_df)} rows = total {len(combined_df)} rows\")\n        \n        # 4. Split data\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 6: Splitting data...\")\n        output_files = split_dataframe(\n            combined_df,\n            max_rows=args.max_rows,\n            output_dir=args.output_dir,\n            prefix=\"part\"\n        )\n        \n        # 5. Generate file list JSON\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 7: Generating file list JSON...\")\n        output_dir_path = Path(args.output_dir)\n        json_file_path = output_dir_path / \"file_list.json\"\n\n        # Convert file paths to absolute paths (absolute paths are more reliable)\n        file_list = [str(Path(f).absolute()) for f in output_files]\n\n        with open(json_file_path, 'w', encoding='utf-8') as f:\n            json.dump(file_list, f, indent=2, ensure_ascii=False)\n\n        logger.info(f\"File list saved to: {json_file_path} ({len(file_list)} files)\")\n        \n        # 6. Output statistics\n        logger.info(\"=\" * 60)\n        logger.info(\"Processing complete!\")\n        logger.info(f\"Input files: general text {len(general_text_files)} files, recommendation data {len(rec_data_files)} files\")\n        logger.info(f\"Total data rows: {len(combined_df)}\")\n        logger.info(f\"Output files: {len(output_files)}\")\n        logger.info(f\"Output directory: {args.output_dir}\")\n        logger.info(f\"File list JSON: {json_file_path}\")\n        logger.info(\"=\" * 60)\n        \n    except KeyboardInterrupt:\n        logger.info(\"\\nOperation cancelled by user\")\n        sys.exit(1)\n    except Exception as e:\n        logger.error(f\"Program execution failed: {e}\", exc_info=True)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
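A sketch of how `data/scripts/split_data.py` might be driven (paths are illustrative). The emitted `file_list.json` is a plain JSON array of absolute parquet paths, which is what the `sources` field of the dataset configs further below points at:

```bash
python3 data/scripts/split_data.py \
    --general_text_path ./output/general_text \
    --rec_data_path ./output/rec_data \
    --output_dir ./output/split_data_pretrain \
    --max_rows 1000
# Writes part-00000-of-000NN.parquet chunks plus file_list.json into the output dir
```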
  {
    "path": "data/scripts/train_test_split.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Train/Test Split Script\n\nRandomly selects N samples from multiple parquet files as the test set, with remaining data as the training set.\nBoth datasets are shuffled before saving.\n\"\"\"\n\nimport argparse\nimport logging\nimport sys\nfrom pathlib import Path\nfrom typing import List\n\nimport pandas as pd\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n\ndef load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow') -> pd.DataFrame:\n    \"\"\"Load and merge all parquet files.\n\n    Args:\n        file_paths: List of parquet file paths\n        engine: Parquet engine, 'pyarrow' or 'fastparquet'\n\n    Returns:\n        Merged DataFrame\n    \"\"\"\n    if not file_paths:\n        logger.warning(\"No parquet files found\")\n        return pd.DataFrame()\n\n    logger.info(f\"Found {len(file_paths)} parquet files, starting to load...\")\n\n    dataframes = []\n    for file_path in tqdm(file_paths, desc=\"Loading files\"):\n        try:\n            df = pd.read_parquet(file_path, engine=engine)\n            logger.debug(f\"  Loaded {file_path}: {len(df)} rows\")\n            dataframes.append(df)\n        except Exception as e:\n            logger.error(f\"  Failed to load {file_path}: {e}\")\n            continue\n\n    if not dataframes:\n        logger.warning(\"No files loaded successfully\")\n        return pd.DataFrame()\n\n    # Merge all DataFrames\n    logger.info(\"Merging all data...\")\n    combined_df = pd.concat(dataframes, ignore_index=True)\n    logger.info(f\"Merge complete, total {len(combined_df)} rows\")\n\n    return combined_df\n\n\ndef split_train_test(\n    df: pd.DataFrame,\n    test_size: int,\n    seed: int = None\n) -> tuple:\n    \"\"\"Split DataFrame into training and test sets.\n\n    Args:\n        df: DataFrame to split\n        test_size: Number of test samples\n        seed: Random seed\n\n    Returns:\n        (train_df, test_df) tuple\n    \"\"\"\n    if len(df) == 0:\n        logger.warning(\"DataFrame is empty, cannot split\")\n        return pd.DataFrame(), pd.DataFrame()\n\n    if test_size <= 0:\n        raise ValueError(f\"test_size must be greater than 0, current value: {test_size}\")\n\n    total_rows = len(df)\n\n    if test_size >= total_rows:\n        logger.warning(\n            f\"Test size ({test_size}) is greater than or equal to total rows ({total_rows}), \"\n            f\"using all data as test set, training set will be empty\"\n        )\n        return pd.DataFrame(), df.copy()\n\n    # Use pandas sample method for random sampling, ensuring reproducibility\n    if seed is not None:\n        logger.info(f\"Using random seed: {seed}\")\n\n    logger.info(f\"Randomly selecting {test_size} rows from {total_rows} rows as test set...\")\n\n    # Use pandas sample method to randomly select test set\n    test_df = df.sample(n=test_size, random_state=seed).copy()\n    # Get test set indices\n    test_indices = set(test_df.index)\n    # Remaining data as training set\n    train_df = df.drop(test_indices).copy()\n\n    logger.info(f\"Split complete: training set {len(train_df)} rows, test set {len(test_df)} rows\")\n\n    return train_df, test_df\n\n\ndef shuffle_dataframe(df: pd.DataFrame, seed: int = None) -> pd.DataFrame:\n    \"\"\"Shuffle DataFrame.\n\n    Args:\n        df: DataFrame to 
shuffle\n        seed: Random seed (for reproducibility)\n\n    Returns:\n        Shuffled DataFrame\n    \"\"\"\n    if len(df) == 0:\n        return df.copy()\n\n    # Use sample method for shuffling (frac=1 means sampling all data, i.e., shuffling)\n    # random_state parameter ensures reproducibility\n    shuffled_df = df.sample(frac=1, random_state=seed).reset_index(drop=True)\n\n    return shuffled_df\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Randomly select N samples from multiple parquet files as test set, remaining data as training set'\n    )\n    parser.add_argument(\n        '--input_files',\n        type=str,\n        nargs='+',\n        required=True,\n        help='List of input parquet file paths (can specify multiple files)'\n    )\n    parser.add_argument(\n        '--test_size',\n        type=int,\n        required=True,\n        help='Number of test samples'\n    )\n    parser.add_argument(\n        '--output_dir',\n        type=str,\n        required=True,\n        help='Output directory path'\n    )\n    parser.add_argument(\n        '--seed',\n        type=int,\n        default=None,\n        help='Random seed (optional, for reproducibility)'\n    )\n    parser.add_argument(\n        '--engine',\n        choices=['pyarrow', 'fastparquet'],\n        default='pyarrow',\n        help='Parquet processing engine (default: pyarrow)'\n    )\n    parser.add_argument(\n        '--test_filename',\n        type=str,\n        default='test.parquet',\n        help='Test set output filename (default: test.parquet)'\n    )\n    parser.add_argument(\n        '--train_filename',\n        type=str,\n        default='train.parquet',\n        help='Training set output filename (default: train.parquet)'\n    )\n    \n    args = parser.parse_args()\n\n    # Validate parameters\n    if args.test_size <= 0:\n        logger.error(f\"test_size must be greater than 0, current value: {args.test_size}\")\n        sys.exit(1)\n    \n    # Validate input files exist\n    input_files = []\n    for file_path in args.input_files:\n        path = Path(file_path)\n        if not path.exists():\n            logger.warning(f\"File does not exist, skipping: {file_path}\")\n            continue\n        if not path.is_file():\n            logger.warning(f\"Path is not a file, skipping: {file_path}\")\n            continue\n        if path.suffix.lower() != '.parquet':\n            logger.warning(f\"Not a parquet file, skipping: {file_path}\")\n            continue\n        input_files.append(str(path))\n    \n    if not input_files:\n        logger.error(\"No valid parquet files found\")\n        sys.exit(1)\n    \n    try:\n        # 1. Load all parquet files\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 1: Loading parquet files...\")\n        combined_df = load_all_parquet_files(input_files, engine=args.engine)\n\n        if len(combined_df) == 0:\n            logger.error(\"No data loaded\")\n            sys.exit(1)\n        \n        # 2. Split training and test sets\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 2: Splitting training and test sets...\")\n        train_df, test_df = split_train_test(\n            combined_df,\n            test_size=args.test_size,\n            seed=args.seed\n        )\n        \n        if len(test_df) == 0:\n            logger.error(\"Test set is empty, cannot continue\")\n            sys.exit(1)\n        \n        # 3. 
Shuffle data\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 3: Shuffling data...\")\n        \n        # Use different seed offsets for training and test sets to ensure different shuffle results\n        # If seed is provided, use different offsets; otherwise use None for both (completely random)\n        train_seed = (args.seed + 1000) if args.seed is not None else None\n        test_seed = (args.seed + 2000) if args.seed is not None else None\n        \n        logger.info(\"Shuffling training set...\")\n        train_df = shuffle_dataframe(train_df, seed=train_seed)\n\n        logger.info(\"Shuffling test set...\")\n        test_df = shuffle_dataframe(test_df, seed=test_seed)\n        \n        # 4. Save results\n        logger.info(\"=\" * 60)\n        logger.info(\"Step 4: Saving results...\")\n        output_dir = Path(args.output_dir)\n        output_dir.mkdir(parents=True, exist_ok=True)\n        \n        test_path = output_dir / args.test_filename\n        train_path = output_dir / args.train_filename\n\n        logger.info(f\"Saving test set to: {test_path}\")\n        test_df.to_parquet(\n            test_path,\n            engine='pyarrow',\n            index=False,\n            compression='snappy'\n        )\n        \n        if len(train_df) > 0:\n            logger.info(f\"Saving training set to: {train_path}\")\n            train_df.to_parquet(\n                train_path,\n                engine='pyarrow',\n                index=False,\n                compression='snappy'\n            )\n        else:\n            logger.warning(\"Training set is empty, skipping save\")\n        \n        # 5. Output statistics\n        logger.info(\"=\" * 60)\n        logger.info(\"Processing complete!\")\n        logger.info(f\"Number of input files: {len(input_files)}\")\n        logger.info(f\"Original data rows: {len(combined_df)}\")\n        logger.info(f\"Training set rows: {len(train_df)}\")\n        logger.info(f\"Test set rows: {len(test_df)}\")\n        logger.info(f\"Output directory: {output_dir}\")\n        logger.info(f\"Training set file: {train_path}\")\n        logger.info(f\"Test set file: {test_path}\")\n        if args.seed is not None:\n            logger.info(f\"Random seed: {args.seed}\")\n        logger.info(\"=\" * 60)\n        \n    except KeyboardInterrupt:\n        logger.info(\"\\nOperation cancelled by user\")\n        sys.exit(1)\n    except Exception as e:\n        logger.error(f\"Program execution failed: {e}\", exc_info=True)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
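Likewise, an illustrative run of `data/scripts/train_test_split.py`; the shell glob expands into the explicit file list the `--input_files` argument expects, and the sizes and paths are assumptions:

```bash
python3 data/scripts/train_test_split.py \
    --input_files ./output/split_data_sft/part-*.parquet \
    --test_size 2000 \
    --output_dir ./output/train_test \
    --seed 42
```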
  {
    "path": "pretrain/.gitignore",
    "content": "# Python\n__pycache__/\n*.pyc\n*.so\n*.egg-info\n*.pylintrc\n\n# Build\nbuild\n\n# IDE\n.vscode/\n.idea/\n*~\n\n# OS\n.DS_Store\n\n# Project specific\nkeys.txt\nhostfile\n\n# Data files\n*.parquet\n*.bin\n*.pt\n*.npy\n\n\n# Logs and outputs\n*.log\n*.err\n*.out\n\n# System files\n*.swp\n*.deb\n.git\n"
  },
  {
    "path": "pretrain/README.md",
    "content": "# OpenOneRec Pretraining Module\n\nThe OpenOneRec pretraining module is based on the Qwen3 architecture, supporting a two-stage pretraining pipeline (Itemic-Text Alignment → Full-parameter Co-Pretraining) and SFT training workflow.\n\n> **⚠️ Important Notice**\n>\n> The distributed training in this module **relies on MPI (Message Passing Interface)** for multi-node communication. The current training scripts use `mpirun` to launch distributed training, requiring proper MPI environment configuration (e.g., OpenMPI) and hostfile setup.\n>\n> To simplify environment configuration and improve reproducibility, we plan to release in future versions:\n> - **Pre-configured Docker/Apptainer images**: Including all necessary dependencies and MPI environment\n> - **torchrun-based training scripts**: Providing an easier way to launch distributed training\n>\n> Before the images and torchrun versions are released, please ensure your environment has MPI properly installed and configured.\n\n\n## Quick Start\n\n### Prerequisites\n\n- **Hardware**: CUDA-enabled GPUs (multi-GPU or multi-node recommended)\n- **Software**:\n  - Python 3.8+\n  - PyTorch (with FSDP and distributed training support)\n  - OpenMPI or compatible MPI implementation\n  - NCCL (for GPU communication)\n- **Data**: Training data converted to Parquet format (refer to `../data/README.md`)\n- **Model**: Qwen3 base model (HuggingFace format)\n\n### 1. Environment Setup\n\nFirst, configure the training environment:\n\n```bash\n# Set environment variables\nsource set_env.sh\n```\n\nThis script sets necessary environment variables, including Python path, CUDA path, etc.\n\n### 2. Qwen3 Model Vocabulary Expansion\n\nBefore starting training, you need to expand the vocabulary of the Qwen3 base model to support recommendation system-specific item ID encoding (itemic tokens).\n\n#### 2.1 Configure Parameters\n\nEdit `scripts/expand_qwen3_vocab.sh` and set the following parameters:\n\n```bash\nHF_MODEL_DIR=/path/to/Qwen3-0.6B          # Original Qwen3 HuggingFace model path\nOUTPUT_MODEL_DIR=/path/to/Qwen3-0.6B_itemic  # Output model path with expanded vocabulary\nITEMIC_LAYER_N=3                          # Number of layers for itemic tokens\nVOCAB_SIZE_PER_LAYER=8192                 # Vocabulary size expansion per layer\n```\n\n#### 2.2 Execute Expansion\n\n```bash\nbash scripts/expand_qwen3_vocab.sh\n```\n\nThis script will:\n- Add new itemic tokens on top of the original vocabulary\n- Align vocabulary size to multiples of 256\n- Initialize embedding weights for new tokens\n- Save the expanded model to the specified directory\n\n**Note**: The expanded model path needs to be used in the data configuration file for subsequent training (`base_model_dir` field).\n\n### 3. Data Preparation\n\nTraining data needs to be converted to Parquet format. Please refer to `../data/README.md` for format specifications.\n\nData configuration is specified through JSON files located in the `examples/dataset_config/` directory.\n\n#### Data Configuration Format\n\nEach data configuration file contains the following main fields:\n\n```json\n{\n    \"name\": \"chat_completion_parquet\",\n    \"sources\": \"/path/to/file_list.json\",\n    \"base_model_dir\": \"/path/to/Qwen3-1.7B_itemic\",\n    \"max_length\": 30000,\n    \"num_epochs\": 3,\n    \"num_workers\": 2,\n    \"itemic_id_range\": [151669, 176246],\n    \"add_think_pattern\": false,\n    \"local_shuffle_buffer_size\": 100000\n    ...\n}\n```\n\n### 4. 
Training\n\nTraining scripts are located in the `examples/` directory, and data configuration files are in the `examples/dataset_config/` directory.\n\n#### 4.1 Stage1 Pretraining\n\nStage1 is mainly used for training itemic embeddings, typically freezing LLM parameters and only optimizing the embedding layer.\n\n```bash\n# Edit examples/pretrain_stg1.sh to set model path, output path, and other parameters\nbash examples/pretrain_stg1.sh\n```\n\nMain training parameters (configured in `pretrain_stg1.sh`):\n- `--dataset_config examples/dataset_config/stg1.json`: Specify data configuration\n- `--freeze_llm`: Freeze LLM parameters\n- `--start_optimize_embedding_index 151669`: Start optimizing embeddings from the specified token ID\n- `--model_dir`: Base model path with expanded vocabulary\n- `--output_dir`: Model output path\n\n**Note**: After training, convert the checkpoint to HuggingFace format (see [Model Conversion](#model-conversion)).\n\n#### 4.2 Stage2 Pretraining\n\nStage2 is used for full-parameter pretraining to further optimize model performance. This stage unfreezes all model parameters and performs co-pretraining on a mixed domain of recommendation data and general text data.\n\n```bash\n# Edit examples/pretrain_stg2.sh to set model path, output path, and other parameters\n# MODEL_DIR should point to the converted hf model path from Stage1 training output\nbash examples/pretrain_stg2.sh\n```\n\nMain training parameters (configured in `pretrain_stg2.sh`):\n- `--dataset_config examples/dataset_config/pretrain.json`: Specify data configuration (including recommendation data and general text data)\n- `--model_dir`: Converted model path from Stage1 output\n- `--output_dir`: Model output path\n- Note: **Does not include** `--freeze_llm` parameter, indicating full-parameter training\n\n**Note**: After training, convert the checkpoint to HuggingFace format (see [Model Conversion](#model-conversion)).\n\n#### 4.3 SFT Fine-tuning\n\nSFT (Supervised Fine-Tuning) is used for instruction fine-tuning to improve model performance on specific tasks. 
This stage performs supervised learning on instruction-following data, enabling the model to better understand and execute recommendation-related instructions.\n\n```bash\n# Edit examples/posttrain_sft.sh to set model path, output path, and other parameters\n# MODEL_DIR should point to the converted hf model path from Stage2 training output\nbash examples/posttrain_sft.sh\n```\n\nMain training parameters (configured in `posttrain_sft.sh`):\n- `--dataset_config examples/dataset_config/sft.json`: Specify SFT data configuration\n- `--model_dir`: Converted model path from Stage2 output\n- `--output_dir`: Model output path\n- `add_think_pattern: true` in data configuration enables thinking mode, which automatically adds `<think>` `</think>` tags and `/think` and `/no_think` instructions (for reasoning tasks)\n\n**Note**: After training, convert the checkpoint to HuggingFace format (see [Model Conversion](#model-conversion)).\n\n## Training Configuration\n\n### Data Configuration Fields\n\n| Field | Type | Description |\n|-------|------|-------------|\n| `name` | str | Data loader name, default is `\"chat_completion_parquet\"` |\n| `sources` | str | Data file list path (JSON file) or directory path list |\n| `base_model_dir` | str | Base model path (with expanded vocabulary), used for tokenizing data |\n| `max_length` | int | Maximum sequence length |\n| `num_epochs` | int | Number of training epochs |\n| `num_workers` | int | Number of dataloader workers |\n| `model_class` | str | Model class name, default is `\"Qwen3ForCausalLM\"` |\n| `itemic_id_range` | list | Itemic token ID range `[start, end]`, only used for metrics statistics |\n| `only_assistant_loss` | bool | Whether to only compute loss for assistant responses, applies to chat format data |\n| `local_shuffle_buffer_size` | int | Local sample-level shuffle buffer size |\n| `add_think_pattern` | bool | Whether to add think tags (add `/think` `/no_think` in prompt, and `<think>` `</think>` in response) |\n\nNotes:\n* The default dataset is implemented based on torch.utils.data.IterableDataset\n* By default, one GPU is bound to one process, each process creates `num_workers` workers. The dataset distributes files from `sources` to each worker at file granularity based on total worker count. 
The file list is shuffled before distribution, and sample-level shuffle is performed according to `local_shuffle_buffer_size` when reading data\n* If `num_epochs` > 1, file distribution is performed twice, with the file list reshuffled each time\n\n\n### Training Parameters\n\nMain training parameters are passed via command line to `recipes/train_qwen3.py`:\n\n| Parameter | Description |\n|-----------|-------------|\n| `--model_dir` | Base model path (HuggingFace format) |\n| `--output_dir` | Model output path |\n| `--dataset_config` | Data configuration file path |\n| `--freeze_llm` | Whether to freeze LLM parameters |\n| `--learning_rate` | Learning rate |\n| `--max_length` | Sequence length per step |\n| `--min_lr` | Minimum learning rate |\n| `--lr_scheduler_type` | Learning rate scheduler type (e.g., `cosine`) |\n| `--num_training_steps` | Number of training steps |\n| `--save_checkpoint_per_step` | Save checkpoint every N steps |\n| `--minibatch_size` | LLM head chunk size for chunked loss computation to save memory |\n| `--resume_from` | Checkpoint directory path to resume training from |\n| `--resume_from_tag` | Checkpoint tag to resume from (e.g., `global_step1000`) |\n| `--resume_training_state` | Whether to restore full training state (including optimizer, lr scheduler, and dataloader state) |\n| `--start_optimize_embedding_index` | Start optimizing embeddings from the specified token ID (for Stage1 training, typically set to the starting ID of itemic tokens, e.g., 151669) |\n| `--use_tie_weights` | Tie embedding and lm_head weights (required for smaller models like 0.6B / 1.7B / 4B to align with Qwen3 model configuration) |\n\nNotes:\n* `resume_from` loads checkpoints produced by this framework. When it is set, it takes priority and `model_dir` supplies only the model structure for initialization; when it is not set, the model weights are loaded from `model_dir` as well\n* `num_training_steps` only affects the lr decay schedule: when training reaches `num_training_steps`, the lr has decayed to the minimum, but training does not stop there. It is recommended to derive the maximum number of steps from the total token count and `max_length` (a back-of-the-envelope sketch follows this file)\n* `max_length` is the maximum sequence length per GPU per step; the framework packs samples up to this length\n\n## Utility Scripts\n\n### Model Conversion\n\nConvert trained checkpoints to HuggingFace format:\n\n```bash\nbash scripts/convert_checkpoint_to_hf.sh <base_model_dir> <model_home> <step>\n```\n\nParameter description:\n- `base_model_dir`: Qwen base model directory with expanded vocabulary (output from vocabulary expansion stage)\n- `model_home`: Training output directory (i.e., `OUTPUT_DIR` in training script)\n- `step`: Checkpoint step number to convert\n\n**Example:**\n```bash\n# Assuming the vocabulary-expanded model is in ./qwen_extended\n# Training output is in ./output\n# Converting the checkpoint at step 4000\nbash scripts/convert_checkpoint_to_hf.sh ./qwen_extended ./output 4000\n```\n\nConversion process:\n1. The script automatically locates the `{model_home}/step{step}/global_step{step}` directory\n2. Reads the training checkpoint from that directory\n3. 
Saves the converted HuggingFace format model to `{model_home}/step{step}/global_step{step}/converted/`\n\nThe converted model can be directly used for:\n- Loading and inference with HuggingFace Transformers\n- Subsequent SFT or other fine-tuning stages\n- Model evaluation and deployment\n\n### Model Testing\n\nTest the converted HuggingFace model:\n\n```bash\nbash scripts/test_hf_model.sh <hf_model_dir>\n```\n\nParameter description:\n- `hf_model_dir`: Converted HuggingFace model directory\n\n**Example:**\n```bash\n# Test the converted model at step 4000\nbash scripts/test_hf_model.sh ./output/step4000/global_step4000/converted/\n```\n\nThis script will verify:\n- Whether model weights are loaded correctly\n- Whether forward pass works normally\n- Whether generation functionality is available\n\n### Training Monitoring\n\nLogs and outputs during training:\n\n- **Standard output/error**: Saved in `$OUTPUT_DIR/stdout.log` and `$OUTPUT_DIR/stderr.log`\n- **Training logs**: Contains loss values, learning rate, training steps, and other information\n- **TensorBoard**: The model supports TensorBoard visualization. You can start TensorBoard with:\n  ```bash\n  tensorboard --logdir=$OUTPUT_DIR\n  ```\n- **Checkpoint**: Saved at configured step intervals (`--save_checkpoint_per_step`)\n\n### Checkpoint Management\n\nCheckpoints are saved periodically during training with the following directory structure:\n\n```\noutput_dir/\n├── step50/\n│   └── global_step50/\n│       ├── model/          # Model weights\n│       ├── optimizer/      # Optimizer state\n│       └── ...\n├── step100/\n│   └── global_step100/\n│       └── ...\n└── ...\n```\n\n**Resuming Training**:\nTo resume training from a checkpoint, add the following to the training script:\n```bash\n--resume_from $OUTPUT_DIR/step1000 \\\n--resume_from_tag global_step1000 \\\n--resume_training_state\n```\n\n## Notes\n\n\n1. **MPI Environment**:\n   - Training scripts use `mpirun` for multi-node distributed training, requiring OpenMPI or compatible MPI implementation\n   - Proper hostfile configuration is required (e.g., `/etc/mpi/hostfile`), with one node address per line\n   - Ensure passwordless SSH access between all nodes\n   - Training scripts automatically read environment variables like `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE`, etc.\n\n2. **Data Format**:\n   - Ensure training data conforms to Parquet format specifications, refer to `../data/README.md`\n   - It is recommended that each Parquet file contains approximately 1000 samples for efficient loading and shuffling\n   - Data file lists are specified through JSON files, supporting both local paths and HDFS paths\n\n3. **Vocabulary Expansion**:\n   - Vocabulary expansion must be performed before training, using the expanded model as `base_model_dir`\n   - The expanded model path needs to be specified in the `base_model_dir` field of the data configuration file\n   - Ensure `itemic_id_range` is consistent with the configuration during vocabulary expansion\n\n4. **Model Size**:\n   - For smaller models like 0.6B / 1.7B / 4B, the `--use_tie_weights` parameter is required to align with Qwen3 model configuration\n   - Different model sizes may require different learning rate and training step configurations\n\n## Related Documentation\n\n- [OpenOneRec Main README](../README.md): Project overview and complete workflow\n- [Data Format Specification](../data/README.md): Training data format requirements and preprocessing methods"
  },
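To make the `num_training_steps` note in the README above concrete, a back-of-the-envelope sketch; the corpus size and GPU count here are assumptions for illustration, not recommendations:

```python
# Each GPU packs roughly max_length tokens per step, so an upper bound on useful
# training steps is total_tokens * num_epochs / (max_length * num_gpus).
total_tokens = 10_000_000_000  # assumed corpus size (illustrative)
num_epochs = 4                 # matches the example dataset configs
max_length = 32768             # --max_length in the example training scripts
num_gpus = 32                  # assumed cluster size (illustrative)

max_steps = total_tokens * num_epochs // (max_length * num_gpus)
print(max_steps)  # -> 38146; set --num_training_steps at or below this bound
```

Since `num_training_steps` only controls where the cosine schedule bottoms out, choosing a value near this bound keeps the learning rate from decaying to `min_lr` prematurely.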
  {
    "path": "pretrain/examples/dataset_config/pretrain.json",
    "content": "{\n    \"name\": \"chat_completion_parquet\",\n    \"sources\": \"../output/split_data_pretrain/file_list.json\",\n    \"only_assistant_loss\": false,\n    \"max_length\": 30000,\n    \"base_model_dir\": \"/code/hf_models/Qwen3-1.7B_itemic\",\n    \"num_workers\": 2,\n    \"num_epochs\": 4,\n    \"cut_to_pad\": 1,\n    \"model_class\": \"Qwen3ForCausalLM\",\n    \"full_attention\": false,\n    \"local_shuffle_buffer_size\": 10000,\n    \"max_sample_length\": 30000,\n    \"local_shuffle_random_fetch\": 0.0001,\n    \"itemic_id_range\": [151669, 176246]\n}\n"
  },
  {
    "path": "pretrain/examples/dataset_config/sft.json",
    "content": "{\n    \"name\": \"chat_completion_parquet\",\n    \"sources\": \"../output/split_data_sft/file_list.json\",\n    \"only_assistant_loss\": false,\n    \"max_length\": 30000,\n    \"base_model_dir\": \"/code/hf_models/Qwen3-1.7B_itemic\",\n    \"num_workers\": 2,\n    \"num_epochs\": 4,\n    \"cut_to_pad\": 1,\n    \"model_class\": \"Qwen3ForCausalLM\",\n    \"full_attention\": false,\n    \"local_shuffle_buffer_size\": 10000,\n    \"max_sample_length\": 30000,\n    \"local_shuffle_random_fetch\": 0.0001,\n    \"itemic_id_range\": [151669, 176246],\n    \"add_think_pattern\": true\n}\n"
  },
  {
    "path": "pretrain/examples/posttrain_sft.sh",
    "content": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\n# MODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nSTAGE2_OUTPUT_DIR=/code/onerec_pretrain/model_output/stg2_opt_utils_big\nMODEL_DIR=${STAGE2_OUTPUT_DIR}/step5000/global_step5000/converted\nOUTPUT_DIR=/code/onerec_pretrain/model_output/sft_output\n\nmkdir -p $OUTPUT_DIR\nmkdir -p /tmp/_wids_cache\n\nnnode=$(wc -l < /etc/mpi/hostfile_seq)\n\nset -x\n\nSCRIPT_FILE=$(readlink -f $0)\necho `date '+%Y-%m-%d %H:%M:%S'` >> $OUTPUT_DIR/task_info.log\necho \"script: ${SCRIPT_FILE}\" >> $OUTPUT_DIR/task_info.log\necho \"=========================\" >> $OUTPUT_DIR/task_info.log\n\necho \"Output: $OUTPUT_DIR\"\n\nexport PYTHONPATH=$PWD:$PYTHONPATH\n\nsource set_env.sh\n\nhostfile=/etc/mpi/hostfile_seq\nTCP_NIC=$(ifconfig | grep -B1 \" \"$(hostname -i)\" \" | grep -o \"^\\w*\")\n\nMASTER_ADDR=$MY_NODE_IP\nMASTER_PORT=8499\n\nmpirun --allow-run-as-root \\\n    -hostfile $hostfile \\\n    -mca btl self,tcp -mca pml ob1 \\\n    -mca plm_rsh_num_concurrent 600 \\\n    -mca routed_radix 600 \\\n    -mca btl_tcp_if_include $TCP_NIC \\\n    -mca oob_tcp_if_include $TCP_NIC \\\n    -mca btl_openib_allow_ib false \\\n    -mca opal_set_max_sys_limits 1 \\\n    -x OMPI_MCA_btl=self,tcp \\\n    -x OMPI_MCA_pml=ob1 \\\n    -x OMPI_MCA_btl_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_oob_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_btl_openib_allow_ib=false \\\n    -x NCCL_IB_DISABLE=0 \\\n    -x NCCL_IB_GID_INDEX=3 \\\n    -x NCCL_SOCKET_IFNAME=$TCP_NIC \\\n    -x NCCL_IB_HCA=mlx5 \\\n    -x NCCL_DEBUG=WARN \\\n    -x NCCL_IB_QPS_PER_CONNECTION=4 \\\n    -x NCCL_NET_OVERHEAD=1000 \\\n    -x NCCL_IB_TIMEOUT=20 \\\n    -x LD_PRELOAD=$LD_PRELOAD \\\n    -x http_proxy=\"\" \\\n    -x https_proxy=\"\" \\\n    -x HOROVOD_MPI_THREADS_DISABLE=1 \\\n    -x MPI_THREAD_SINGLE=1 \\\n    -x NO_COLOR=1 \\\n    -x TERM=dumb \\\n    -x COLORTERM=0 \\\n    -x PYTHONIOENCODING=utf-8 \\\n    -x LD_LIBRARY_PATH=$LIBRARY_PATH \\\n    -x PATH \\\n    -x PYTHONPATH=$PYTHONPATH \\\n    -x JAVA_HOME=$JAVA_HOME \\\n    -x HIVE_HOME=$HIVE_HOME \\\n    -x CLASSPATH=$CLASSPATH \\\n    -x HADOOP_USER_NAME=$HADOOP_USER_NAME \\\n    -x HADOOP_HOME=$HADOOP_HOME \\\n    -x SPARK_HOME=$SPARK_HOME \\\n    -x MASTER_ADDR=$MASTER_ADDR \\\n    -x MASTER_PORT=$MASTER_PORT \\\n    -x TOKENIZERS_PARALLELISM=false \\\n    with_nccl_local_env \\\n    bash -c \"bash scripts/numa_runner.sh python3 recipes/train_qwen3.py \\\n        --model_dir $MODEL_DIR \\\n        --output_dir $OUTPUT_DIR \\\n        --dataset_config examples/dataset_config/sft.json \\\n        --use_tie_weights \\\n        --model_class Qwen3ForCausalLM \\\n        --monitor_datasource_loss \\\n        --monitor_datasource_cnt \\\n        --max_length 32768 \\\n        --learning_rate 2e-4 \\\n        --min_lr 1e-4 \\\n        --weight_decay 0.1 \\\n        --max_grad_norm 1.0 \\\n        --lr_scheduler_type cosine \\\n        --num_warmup_steps 500 \\\n        --num_training_steps 5000 \\\n        --save_checkpoint_per_step 50 \\\n        --minibatch_size 16384 \\\n        --logging_per_step 5 \\\n        --use_fp32_weight \\\n        --seed 19260817 \\\n        --enable_profiler \\\n        --enable_gradient_checkpointing \\\n        --use_chunked_loss_computer \\\n    \" > $OUTPUT_DIR/stdout.log 2>$OUTPUT_DIR/stderr.log &\n\n        # --resume_from $PRETRAIN_OUTPUT_DIR/step5000 \\\n        # --resume_from_tag global_step5000 \\"
  },
  {
    "path": "pretrain/examples/pretrain_stg1.sh",
    "content": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\nMODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nOUTPUT_DIR=/code/onerec_pretrain/model_output/stg1_opt_utils_big\nmkdir -p $OUTPUT_DIR\nmkdir -p /tmp/_wids_cache\n\nnnode=$(wc -l < /etc/mpi/hostfile_seq)\n\nset -x\n\nSCRIPT_FILE=$(readlink -f $0)\necho `date '+%Y-%m-%d %H:%M:%S'` >> $OUTPUT_DIR/task_info.log\necho \"script: ${SCRIPT_FILE}\" >> $OUTPUT_DIR/task_info.log\necho \"=========================\" >> $OUTPUT_DIR/task_info.log\n\necho \"Output: $OUTPUT_DIR\"\n\nexport PYTHONPATH=$PWD:$PYTHONPATH\n\nsource set_env.sh\n\nhostfile=/etc/mpi/hostfile_seq\nTCP_NIC=$(ifconfig | grep -B1 \" \"$(hostname -i)\" \" | grep -o \"^\\w*\")\n\nMASTER_ADDR=$MY_NODE_IP\nMASTER_PORT=8499\n\nmpirun --allow-run-as-root \\\n    -hostfile $hostfile \\\n    -mca btl self,tcp -mca pml ob1 \\\n    -mca plm_rsh_num_concurrent 600 \\\n    -mca routed_radix 600 \\\n    -mca btl_tcp_if_include $TCP_NIC \\\n    -mca oob_tcp_if_include $TCP_NIC \\\n    -mca btl_openib_allow_ib false \\\n    -mca opal_set_max_sys_limits 1 \\\n    -x OMPI_MCA_btl=self,tcp \\\n    -x OMPI_MCA_pml=ob1 \\\n    -x OMPI_MCA_btl_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_oob_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_btl_openib_allow_ib=false \\\n    -x NCCL_IB_DISABLE=0 \\\n    -x NCCL_IB_GID_INDEX=3 \\\n    -x NCCL_SOCKET_IFNAME=$TCP_NIC \\\n    -x NCCL_IB_HCA=mlx5 \\\n    -x NCCL_DEBUG=WARN \\\n    -x NCCL_IB_QPS_PER_CONNECTION=4 \\\n    -x NCCL_NET_OVERHEAD=1000 \\\n    -x NCCL_IB_TIMEOUT=20 \\\n    -x LD_PRELOAD=$LD_PRELOAD \\\n    -x http_proxy=\"\" \\\n    -x https_proxy=\"\" \\\n    -x HOROVOD_MPI_THREADS_DISABLE=1 \\\n    -x MPI_THREAD_SINGLE=1 \\\n    -x NO_COLOR=1 \\\n    -x TERM=dumb \\\n    -x COLORTERM=0 \\\n    -x PYTHONIOENCODING=utf-8 \\\n    -x LD_LIBRARY_PATH=$LIBRARY_PATH \\\n    -x PATH \\\n    -x PYTHONPATH=$PYTHONPATH \\\n    -x JAVA_HOME=$JAVA_HOME \\\n    -x HIVE_HOME=$HIVE_HOME \\\n    -x CLASSPATH=$CLASSPATH \\\n    -x HADOOP_USER_NAME=$HADOOP_USER_NAME \\\n    -x HADOOP_HOME=$HADOOP_HOME \\\n    -x SPARK_HOME=$SPARK_HOME \\\n    -x MASTER_ADDR=$MASTER_ADDR \\\n    -x MASTER_PORT=$MASTER_PORT \\\n    -x TOKENIZERS_PARALLELISM=false \\\n    with_nccl_local_env \\\n    bash -c \"bash scripts/numa_runner.sh python3 recipes/train_qwen3.py \\\n        --model_dir $MODEL_DIR \\\n        --output_dir $OUTPUT_DIR \\\n        --dataset_config examples/dataset_config/pretrain.json \\\n        --freeze_llm \\\n        --use_tie_weights \\\n        --start_optimize_embedding_index 151669 \\\n        --model_class Qwen3ForCausalLM \\\n        --monitor_datasource_loss \\\n        --monitor_datasource_cnt \\\n        --max_length 32768 \\\n        --learning_rate 2e-4 \\\n        --min_lr 1e-4 \\\n        --weight_decay 0.1 \\\n        --max_grad_norm 1.0 \\\n        --lr_scheduler_type cosine \\\n        --num_warmup_steps 200 \\\n        --num_training_steps 2000 \\\n        --save_checkpoint_per_step 50 \\\n        --minibatch_size 16384 \\\n        --logging_per_step 5 \\\n        --use_fp32_weight \\\n        --seed 19260817 \\\n        --enable_profiler \\\n        --enable_gradient_checkpointing \\\n        --use_chunked_loss_computer \\\n    \" > $OUTPUT_DIR/stdout.log 2>$OUTPUT_DIR/stderr.log &"
  },
  {
    "path": "pretrain/examples/pretrain_stg2.sh",
    "content": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\n# MODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nSTAGE1_OUTPUT_DIR=/code/onerec_pretrain/model_output/stg1_opt_utils_big\nMODEL_DIR=${STAGE1_OUTPUT_DIR}/step2000/global_step2000/converted\nOUTPUT_DIR=/code/onerec_pretrain/model_output/stg2_opt_utils_big\n\n\nmkdir -p $OUTPUT_DIR\nmkdir -p /tmp/_wids_cache\n\nnnode=$(wc -l < /etc/mpi/hostfile_seq)\n\nset -x\n\nSCRIPT_FILE=$(readlink -f $0)\necho `date '+%Y-%m-%d %H:%M:%S'` >> $OUTPUT_DIR/task_info.log\necho \"script: ${SCRIPT_FILE}\" >> $OUTPUT_DIR/task_info.log\necho \"=========================\" >> $OUTPUT_DIR/task_info.log\n\necho \"Output: $OUTPUT_DIR\"\n\nexport PYTHONPATH=$PWD:$PYTHONPATH\n\nsource set_env.sh\n\nhostfile=/etc/mpi/hostfile_seq\nTCP_NIC=$(ifconfig | grep -B1 \" \"$(hostname -i)\" \" | grep -o \"^\\w*\")\n\nMASTER_ADDR=$MY_NODE_IP\nMASTER_PORT=8499\n\nmpirun --allow-run-as-root \\\n    -hostfile $hostfile \\\n    -mca btl self,tcp -mca pml ob1 \\\n    -mca plm_rsh_num_concurrent 600 \\\n    -mca routed_radix 600 \\\n    -mca btl_tcp_if_include $TCP_NIC \\\n    -mca oob_tcp_if_include $TCP_NIC \\\n    -mca btl_openib_allow_ib false \\\n    -mca opal_set_max_sys_limits 1 \\\n    -x OMPI_MCA_btl=self,tcp \\\n    -x OMPI_MCA_pml=ob1 \\\n    -x OMPI_MCA_btl_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_oob_tcp_if_include=$TCP_NIC \\\n    -x OMPI_MCA_btl_openib_allow_ib=false \\\n    -x NCCL_IB_DISABLE=0 \\\n    -x NCCL_IB_GID_INDEX=3 \\\n    -x NCCL_SOCKET_IFNAME=$TCP_NIC \\\n    -x NCCL_IB_HCA=mlx5 \\\n    -x NCCL_DEBUG=WARN \\\n    -x NCCL_IB_QPS_PER_CONNECTION=4 \\\n    -x NCCL_NET_OVERHEAD=1000 \\\n    -x NCCL_IB_TIMEOUT=20 \\\n    -x LD_PRELOAD=$LD_PRELOAD \\\n    -x http_proxy=\"\" \\\n    -x https_proxy=\"\" \\\n    -x HOROVOD_MPI_THREADS_DISABLE=1 \\\n    -x MPI_THREAD_SINGLE=1 \\\n    -x NO_COLOR=1 \\\n    -x TERM=dumb \\\n    -x COLORTERM=0 \\\n    -x PYTHONIOENCODING=utf-8 \\\n    -x LD_LIBRARY_PATH=$LIBRARY_PATH \\\n    -x PATH \\\n    -x PYTHONPATH=$PYTHONPATH \\\n    -x JAVA_HOME=$JAVA_HOME \\\n    -x HIVE_HOME=$HIVE_HOME \\\n    -x CLASSPATH=$CLASSPATH \\\n    -x HADOOP_USER_NAME=$HADOOP_USER_NAME \\\n    -x HADOOP_HOME=$HADOOP_HOME \\\n    -x SPARK_HOME=$SPARK_HOME \\\n    -x MASTER_ADDR=$MASTER_ADDR \\\n    -x MASTER_PORT=$MASTER_PORT \\\n    -x TOKENIZERS_PARALLELISM=false \\\n    with_nccl_local_env \\\n    bash -c \"bash scripts/numa_runner.sh python3 recipes/train_qwen3.py \\\n        --model_dir $MODEL_DIR \\\n        --output_dir $OUTPUT_DIR \\\n        --dataset_config examples/dataset_config/pretrain.json \\\n        --use_tie_weights \\\n        --model_class Qwen3ForCausalLM \\\n        --monitor_datasource_loss \\\n        --monitor_datasource_cnt \\\n        --max_length 32768 \\\n        --learning_rate 2e-4 \\\n        --min_lr 1e-4 \\\n        --weight_decay 0.1 \\\n        --max_grad_norm 1.0 \\\n        --lr_scheduler_type cosine \\\n        --num_warmup_steps 500 \\\n        --num_training_steps 5000 \\\n        --save_checkpoint_per_step 50 \\\n        --minibatch_size 16384 \\\n        --logging_per_step 5 \\\n        --use_fp32_weight \\\n        --seed 19260817 \\\n        --enable_profiler \\\n        --enable_gradient_checkpointing \\\n        --use_chunked_loss_computer \\\n    \" > $OUTPUT_DIR/stdout.log 2>$OUTPUT_DIR/stderr.log &"
  },
  {
    "path": "pretrain/onerec_llm/__init__.py",
    "content": ""
  },
  {
    "path": "pretrain/onerec_llm/data/__init__.py",
    "content": ""
  },
  {
    "path": "pretrain/onerec_llm/data/dataloaders.py",
    "content": "\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom onerec_llm.data.qwen3_dataset import Qwen3ChatCompletionParquetDataset\n\ndef get_chat_completion_parquet_dataloader(sources: str,\n                                          max_length,\n                                          base_model_dir,\n                                          num_epochs=1,\n                                          shuffle_seed=1024,\n                                          num_workers=8,\n                                          datasource_config={},\n                                          **kwargs):\n    model_type = kwargs.get('model_class','Qwen3ForCausalLM')\n    ModelDataset = {'Qwen3ForCausalLM': Qwen3ChatCompletionParquetDataset}\n    num_readers = kwargs.get(\"num_readers\", 1)\n    shuffle_window = kwargs.get(\"shuffle_window\", 0)\n\n    def input_creator():\n        return ModelDataset[model_type](\n            sources = sources,\n            num_workers = num_workers,\n            num_epochs = num_epochs,\n            shuffle_seed = shuffle_seed,\n            max_length = max_length,\n            base_model_dir=base_model_dir,\n            datasource_config=datasource_config,\n            num_readers=num_readers,\n            shuffle_window=shuffle_window,\n            **kwargs\n            )\n\n    dataset = input_creator()\n    dataloader = StatefulDataLoader(\n        dataset=dataset,\n        shuffle=False,\n        batch_size=1,\n        num_workers=num_workers,\n        collate_fn=lambda x: x[0],\n    )\n    return dataloader\n\n\ndef get_dataloader(name: str, **kwargs):\n    if name == \"chat_completion_parquet\":\n        return get_chat_completion_parquet_dataloader(\n            **kwargs\n        )\n    else:\n        raise NotImplementedError(\"Unsupported dataloader.\")\n\n\n"
  },
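A minimal sketch of feeding one of the JSON dataset configs above into `get_dataloader`; whether the training recipe consumes configs exactly this way is an assumption, but the keyword flow matches the signatures in this module:

```python
import json

from onerec_llm.data.dataloaders import get_dataloader

# Path is illustrative; any of the examples/dataset_config/*.json files has this shape
with open("examples/dataset_config/pretrain.json") as f:
    cfg = json.load(f)

loader_name = cfg.pop("name")  # "chat_completion_parquet" selects the parquet loader
dataloader = get_dataloader(loader_name, **cfg)  # extra keys ride along via **kwargs

batch = next(iter(dataloader))  # one packed batch; collate_fn already unwrapped it
```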
  {
    "path": "pretrain/onerec_llm/data/local_shuffle_buffer.py",
    "content": "\"\"\"\nLocal shuffle buffer for data randomization during iteration.\n\nThis module provides a fixed-size buffer that randomizes the order of data samples\nusing hash-based indexing with SortedDict for efficient random access.\n\"\"\"\n\nimport hashlib\nimport logging\nimport threading\nimport traceback\nfrom collections import defaultdict\n\nfrom sortedcontainers import SortedDict\n\nlogger = logging.getLogger(__name__)\n\n\nclass LocalShuffleBuffer:\n    \"\"\"\n    A buffer class to implement local data shuffling.\n    \n    Maintains a fixed-size buffer to randomize the order of data samples during iteration.\n    Uses hash-based indexing with SortedDict for efficient random access.\n    \n    Attributes:\n        buffer_size: Maximum capacity of the buffer\n        random_fetch: Probability to randomly fetch a sample before buffer is full\n        buffer: SortedDict storing samples (key: hash, value: sample)\n        count: Statistics counter (adds, conflicts, buffer_epoch)\n        buffer_multiply: Large multiplier to avoid hash collisions across epochs\n        lock: Thread lock for thread-safe operations\n    \"\"\"\n    \n    def __init__(self, buffer_size: int = 2048, random_fetch: float = 0.01) -> None:\n        \"\"\"\n        Initialize the LocalShuffleBuffer.\n        \n        Args:\n            buffer_size: Maximum capacity of the buffer (default: 2048)\n            random_fetch: Probability to randomly fetch a sample before buffer is full (0.0-1.0, default: 0.01)\n        \"\"\"\n        if buffer_size <= 0:\n            raise ValueError(f\"buffer_size must be positive, got {buffer_size}\")\n        if not 0.0 <= random_fetch <= 1.0:\n            raise ValueError(f\"random_fetch must be between 0.0 and 1.0, got {random_fetch}\")\n        \n        self.buffer_size = buffer_size\n        self.random_fetch = random_fetch\n        self.buffer = SortedDict()  # key: hash, value: sample\n        self.count = defaultdict(int)\n        self.count[\"buffer_epoch\"] = 0\n        # Large multiplier (0xffffffffffffffff) to avoid hash collisions across epochs\n        self.buffer_multiply = int('f' * 16, 16)\n        self.lock = threading.Lock()\n\n    def _calc_sample_hash(self, obj: dict, buffer_epoch: int = None) -> int:\n        \"\"\"\n        Calculate a unique hash for a sample to use as buffer key.\n        \n        Maps sample identifier to integer with random-like distribution using MD5 hash.\n        Adds epoch-based offset to prevent cross-epoch hash collisions.\n        \n        Args:\n            obj: Sample object containing \"uuid\" and \"source\" keys\n            buffer_epoch: Optional epoch index. 
If None, uses current buffer_epoch\n            \n        Returns:\n            Integer hash value\n        \"\"\"\n        if buffer_epoch is None:\n            buffer_epoch = self.count[\"buffer_epoch\"]\n        \n        # Create unique string from sample identifiers\n        unique_str = f\"{obj['uuid']}{obj['source']}@ep{buffer_epoch}\"\n        \n        # Generate MD5 hash and convert to integer (use first 16 hex chars = 64 bits)\n        hash_obj = hashlib.md5(unique_str.encode('utf-8'))\n        hex_str = hash_obj.hexdigest()[:16]\n        base_hash = int(hex_str, 16)\n        \n        # Add epoch-based offset to prevent cross-epoch collisions\n        return base_hash + self.buffer_multiply * buffer_epoch\n\n    def add(self, obj: dict, fn: str = None, epoch: int = None) -> bool:\n        \"\"\"\n        Add a sample to the buffer.\n        \n        Args:\n            obj: Sample object to add to buffer (must contain \"uuid\" and \"source\" keys)\n            fn: Optional filename/identifier for logging\n            epoch: Optional epoch index\n            \n        Returns:\n            True if sample was added and buffer isn't ready for extraction,\n            False if extraction should occur (buffer full or random fetch triggered)\n        \"\"\"\n        try:\n            # Calculate hash for the sample\n            obj_hash = self._calc_sample_hash(obj, buffer_epoch=epoch)\n            self.count[\"add\"] += 1\n            \n            # Update buffer epoch every buffer_size additions\n            if self.count[\"add\"] % self.buffer_size == 0:\n                self.count[\"buffer_epoch\"] += 1\n\n            # Handle hash collisions (duplicate unique identifiers)\n            if obj_hash in self.buffer:\n                self.count[\"conflict\"] += 1\n                # Log warning periodically for collision rate\n                if self.count[\"conflict\"] % 100 == 0:\n                    conflict_rate = self.count[\"conflict\"] / self.count[\"add\"]\n                    logger.warning(\n                        f\"{'=' * 30}\\n\"\n                        f\"Potential duplicate samples with same uuid/source! 
\"\n                        f\"uuid={obj['uuid']}, source={obj['source']}, fn={fn}, \"\n                        f\"conflict_rate={conflict_rate:.4f}, add_count={self.count['add']}\\n\"\n                        f\"{'=' * 30}\"\n                    )\n            \n            with self.lock:\n                self.buffer[obj_hash] = obj\n\n            # Random fetch trigger: small probability to extract before buffer is full\n            # This prevents downstream timeout errors\n            if (obj_hash % 10000) < int(10000 * self.random_fetch):\n                return False  # Trigger extraction\n            \n            # Check if buffer has reached capacity\n            return len(self.buffer) < self.buffer_size\n                \n        except Exception as e:\n            logger.error(f\"Error in LocalShuffleBuffer.add(): {traceback.format_exc()}\")\n            raise\n\n    def get(self) -> dict:\n        \"\"\"\n        Extract a sample from the buffer.\n        \n        Returns:\n            A sample object from the buffer\n            \n        Raises:\n            ValueError: If buffer is empty\n        \"\"\"\n        if len(self.buffer) == 0:\n            raise ValueError(\"Cannot get sample from empty buffer\")\n\n        with self.lock:\n            # Pop first item from SortedDict (provides random-like access due to hashing)\n            # popitem(0) removes the first (smallest) key-value pair\n            return self.buffer.popitem(0)[1]\n\n    def __len__(self) -> int:\n        \"\"\"Return current number of samples in the buffer.\"\"\"\n        return len(self.buffer)\n"
  },
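  {
    "path": "pretrain/onerec_llm/data/examples/local_shuffle_buffer_demo.py",
    "content": "\"\"\"Illustrative usage sketch for LocalShuffleBuffer (hypothetical example file, not part of the training pipeline; the toy sample stream below is made up).\n\nDemonstrates the add()/get() contract used by the parquet reader: add() returns True while the buffer should keep filling, and False once the caller should pop a sample out.\n\"\"\"\n\nfrom onerec_llm.data.local_shuffle_buffer import LocalShuffleBuffer\n\n\ndef iterate_with_local_shuffle(rows, buffer_size=8):\n    \"\"\"Yield rows in a locally shuffled order via the buffer contract.\"\"\"\n    buf = LocalShuffleBuffer(buffer_size=buffer_size, random_fetch=0.01)\n    for row in rows:\n        # Each row is keyed by an MD5 hash of its uuid/source, so popping the\n        # smallest key from the SortedDict yields a pseudo-random order.\n        if buf.add(row):\n            continue  # buffer not yet ready for extraction, keep filling\n        yield buf.get()\n    # Drain the remaining samples once the input stream is exhausted.\n    while len(buf) > 0:\n        yield buf.get()\n\n\nif __name__ == \"__main__\":\n    rows = [{\"uuid\": f\"u{i:03d}\", \"source\": \"toy\"} for i in range(32)]\n    shuffled = list(iterate_with_local_shuffle(rows))\n    assert len(shuffled) == len(rows)\n    print([r[\"uuid\"] for r in shuffled])\n"
  },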
  {
    "path": "pretrain/onerec_llm/data/qwen3_dataset.py",
    "content": "import logging\n\nimport os\nimport json\nimport time\nimport traceback\nimport random\nimport re\n\nimport multiprocessing\nimport numpy as np\n\nimport webdataset as wds\nfrom easydict import EasyDict as edict \nfrom typing import Union, Iterable, Optional, List, Dict, Tuple, Any\n\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.utils.data import IterableDataset\n\nfrom transformers import AutoTokenizer, AutoConfig\n\nfrom onerec_llm.data.local_shuffle_buffer import LocalShuffleBuffer\n\nfrom onerec_llm.utils.common import print_rank_0\nfrom onerec_llm.utils.worker_utils import pytorch_worker_info\nfrom onerec_llm.utils.data_utils import shell_hdfs_ls, load_parquet_file\n\nfrom onerec_llm.models.qwen3.configuration_qwen3 import Qwen3Config\n\n\nlogger = logging.getLogger(__name__)\n\ndef set_kwargs(self, kwargs, **_kwargs):\n    kwargs.update(_kwargs)\n    self.kwargs = edict(kwargs)\n    for k, v in kwargs.items():\n        setattr(self, k, v)\n\nclass Qwen3ChatCompletionDataset(IterableDataset):\n    def __init__(self, **kwargs):\n        set_kwargs(self, kwargs)\n        print_rank_0(f\"ChatCompletionDataset init with kwargs={kwargs}\")\n\n        try:\n            model_config = AutoConfig.from_pretrained(self.kwargs.base_model_dir)\n        except Exception:\n            model_config = Qwen3Config.from_pretrained(self.kwargs.base_model_dir)\n\n        self.pad_token_id = model_config.pad_token_id\n        self.dataset, self.total_samples = self._build_source_dataset(self.sources)\n\n        # for data_source monitor\n        self.source_sample_cnt = {}\n        self.source_error_cnt = {}\n        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_dir, trust_remote_code=True)\n        self.max_sample_length = min(self.max_length, self.kwargs.get(\"max_sample_length\", 9999999))\n        assert self.max_length > 0\n\n        # Chat template tokens\n        self.im_start_token = \"<|im_start|>\"\n        self.im_end_token = \"<|im_end|>\"\n        self.im_start_token_id = self.tokenizer.encode(self.im_start_token)[0]\n        self.im_end_token_id = self.tokenizer.encode(self.im_end_token)[0]\n\n        # Derive chat template patterns from tokenizer instead of hardcoded token ids.\n        self.assistant_start_pattern = self.tokenizer.encode(\n            f\"{self.im_start_token}assistant\\n\",\n            add_special_tokens=False,\n        )\n        self.im_end_pattern = self.tokenizer.encode(\n            f\"{self.im_end_token}\\n\",\n            add_special_tokens=False,\n        )\n        if not self.im_end_pattern:\n            self.im_end_pattern = self.tokenizer.encode(\n                self.im_end_token,\n                add_special_tokens=False,\n            )\n\n        self.add_think_pattern = self.kwargs.get(\"add_think_pattern\", False)\n        if self.add_think_pattern:\n            logger.info(f\"Thinking pattern enabled: add_think_pattern={self.add_think_pattern}\")\n\n        self.itemic_id_range = self.kwargs.get(\"itemic_id_range\", None)\n        if self.itemic_id_range is not None:\n            assert len(self.itemic_id_range) == 2, \"itemic_id_range must be a list of two elements\"\n            assert self.itemic_id_range[0] < self.itemic_id_range[1], \"itemic_id_range[0] must be less than itemic_id_range[1]\"\n\n    def _build_source_dataset(self, sources):\n        \"\"\"Build WebDataset from source configuration files.\n        \n        Args:\n            sources: String 
(comma-separated) or list of JSON config file paths\n            \n        Returns:\n            tuple: (dataset, total_samples)\n        \"\"\"\n        if isinstance(sources, str):\n            sources = sources.split(\",\")\n        \n        # Read URLs from configuration files\n        urls = []\n        total_samples = 0\n        for source in sources:\n            with open(source, encoding=\"utf-8\") as f:\n                index = json.loads(f.read())[\"shardlist\"]\n                source_dir = os.path.dirname(source)\n                for item in index:\n                    urls.append(os.path.join(source_dir, item[\"url\"]))\n                    total_samples += item[\"nsamples\"]\n\n        # Sort, shuffle and broadcast URLs across all ranks\n        urls.sort()\n        random.shuffle(urls)\n        url_list = [urls]\n        dist.broadcast_object_list(url_list, src=0)\n        urls = url_list[0]\n        logger.info(f\"[RANK{dist.get_rank()}] Loaded {len(urls)} URLs, total_samples={total_samples}\")\n\n        # Build WebDataset\n        dataset = wds.WebDataset(\n            urls,\n            handler=wds.warn_and_continue,\n            resampled=True,\n            shardshuffle=True,\n            cache_dir=\"/tmp/_wids_cache\",\n            nodesplitter=wds.split_by_node,\n            workersplitter=wds.split_by_worker\n        )\n        \n        dataset = dataset.shuffle(\n            self.shuffle_size, \n            initial=self.shuffle_initial_size\n        ).decode(\"pil\", handler=wds.warn_and_continue)\n\n        return dataset, total_samples\n    \n    def _convert_messages(self, messages):\n        msg_list = []\n        for msg in messages:\n            content = msg['content']\n            if isinstance(content, str):\n                msg_list.append({\n                    'role': msg['role'],\n                    'content': content\n                })\n            elif isinstance(content, dict) and 'type' in content and content['type'] == 'text':\n                msg_list.append({\n                    'role': msg['role'],\n                    'content': content['text']\n                })\n            elif isinstance(content, list) and len(content) > 0:\n                content_text = \"\"\n                for c in content:\n                    if isinstance(c, dict) and 'type' in c and c['type'] == 'text':\n                        content_text += c['text']\n                    elif isinstance(c, str):\n                        content_text += c\n                    else:\n                        continue\n                msg_list.append({\n                    'role': msg['role'],\n                    'content': content_text\n                })\n            else:\n                raise ValueError(f\"Unsupported content type: {type(content)}\")\n        \n        if self.add_think_pattern:\n            # Process thinking pattern: add /think or /no_think suffix to user messages\n            # based on whether assistant message contains reasoning content\n            for i in range(len(msg_list)):\n                if msg_list[i]['role'] == 'assistant':\n                    assistant_content = msg_list[i]['content']\n                    \n                    # Find corresponding user message (typically the previous one)\n                    user_idx = i - 1\n                    if user_idx < 0 or msg_list[user_idx]['role'] != 'user':\n                        continue\n                    \n                    # Check if assistant content contains <think> tags\n        
            pattern = r'<think>(.*?)</think>'\n                    match = re.search(pattern, assistant_content, re.DOTALL)\n                    \n                    if match is None:\n                        # No reasoning tags found: add empty tags and mark as /no_think\n                        msg_list[user_idx]['content'] += \"/no_think\"\n                        msg_list[i]['content'] = \"<think>\\n</think>\\n\" + assistant_content\n                    else:\n                        # Reasoning tags found: check if they contain actual content\n                        reasoning_content = match.group(1)\n                        if reasoning_content.strip():\n                            # Has reasoning content: mark as /think\n                            msg_list[user_idx]['content'] += \"/think\"\n                        else:\n                            # Empty reasoning tags: mark as /no_think\n                            msg_list[user_idx]['content'] += \"/no_think\"\n            \n        return msg_list\n\n    def _get_assistant_mask(self, batch_input_ids: torch.Tensor,\n                       start_pattern: Optional[List[int]],\n                       end_pattern: Optional[List[int]]):\n        \"\"\"\n        Generate mask for assistant tokens in chat format.\n        \n        Args:\n            batch_input_ids: Input token IDs\n            start_pattern: Pattern to identify start of assistant response\n            end_pattern: Pattern to identify end of assistant response\n        \n        Returns:\n            mask: Boolean mask indicating which tokens to compute loss on\n        \"\"\"\n        if not start_pattern:\n            start_pattern = self.assistant_start_pattern\n        if not end_pattern:\n            end_pattern = self.im_end_pattern\n\n        masks = []\n        for input_ids in batch_input_ids:\n            ids = input_ids.tolist()\n            mask = [0] * len(ids)\n            start_len = len(start_pattern)\n            end_len = len(end_pattern)\n            i = 0\n\n            while i <= len(ids) - start_len:\n                if ids[i:i + start_len] != start_pattern:\n                    i += 1\n                    continue\n\n                content_start = i + start_len\n                j = content_start\n                found_end = False\n                while j <= len(ids) - end_len:\n                    if ids[j:j + end_len] == end_pattern:\n                        found_end = True\n                        break\n                    j += 1\n\n                if not found_end:\n                    for k in range(content_start, len(ids)):\n                        mask[k] = 1\n                    break\n\n                for k in range(content_start, j):\n                    mask[k] = 1\n\n                i = j + end_len\n\n            masks.append(mask)\n        return torch.tensor(masks, dtype=torch.long)\n    \n    def _get_rope_index_qwen3(\n                                self,\n                                input_ids: torch.LongTensor,\n                            ) -> torch.Tensor:\n        position_ids = torch.arange(input_ids.shape[1], device=input_ids.device)\n        position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1)\n        return position_ids\n    \n    def _process_completion(self, sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n        \"\"\"\n        Process segments format data into model inputs.\n        \n        Args:\n            sample: Sample containing segments with pre-tokenized tokens\n          
  \n        Returns:\n            Dictionary containing input_ids, attention_mask, labels, etc.\n        \"\"\"\n        segments = sample[\"json\"][\"segments\"]\n\n        segments_text = \"\"\n\n        for segment in segments:\n            if segment[\"type\"] == \"text\":\n                segments_text += segment[\"text\"]\n            else:\n                logger.error(f\"segment type is not text, skip: {segment}\")\n                continue\n        \n        # Note: do not use self.tokenizer.eos_token as it's always set to <im_end>\n        # References: \n        # 1. https://huggingface.co/Qwen/Qwen3-8B/blob/main/tokenizer_config.json#L232\n        # 2. https://qwen.readthedocs.io/zh-cn/latest/getting_started/concepts.html#control-tokens\n        segments_text += self.tokenizer.pad_token\n        \n        # Tokenize\n        inputs = self.tokenizer(\n            segments_text,\n            return_tensors=\"pt\",\n            padding=False,\n            truncation=False\n        )\n\n        input_ids = inputs[\"input_ids\"]\n        \n        # Check length\n        if input_ids.shape[-1] > self.max_length:\n            raise ValueError(f\"Sample too long: {input_ids.shape[-1]} > {self.max_length}\")\n        \n        # Mask EOS token\n        inputs[\"loss_mask\"] = torch.ones_like(input_ids)\n        inputs[\"loss_mask\"][..., -1] = 0\n\n        # itemic id index mask\n        itemic_id_mask = torch.zeros_like(input_ids)\n        if self.itemic_id_range is not None:\n            itemic_id_mask[(input_ids >= self.itemic_id_range[0]) & (input_ids <= self.itemic_id_range[1])] = 1\n        inputs[\"itemic_id_mask\"] = itemic_id_mask\n        \n        # Generate position IDs\n        inputs[\"position_ids\"] = self._get_rope_index_qwen3(input_ids)\n        \n        return inputs\n\n    def _process_chat(self, sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n        \"\"\"\n        Process messages format data into model inputs.\n        \n        Args:\n            sample: Sample containing messages in the new format\n            \n        Returns:\n            Dictionary containing input_ids, attention_mask, labels, etc.\n        \"\"\"\n        msg_key = \"message\" if \"message\" in sample[\"json\"] else \"messages\"\n        messages = sample[\"json\"][msg_key]\n\n        msg_converted = self._convert_messages(messages)\n        \n        # Convert messages to text using chat template\n        text = self.tokenizer.apply_chat_template(\n            msg_converted, \n            tokenize=False, \n            add_generation_prompt=False\n        )\n        \n        # Add EOS token\n        text += self.tokenizer.pad_token\n\n        # Tokenize\n        inputs = self.tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=False,\n            truncation=False\n        )\n\n        input_ids = inputs[\"input_ids\"]        \n        # Check length\n        if input_ids.shape[-1] > self.max_length:\n            raise ValueError(f\"Sample too long: {input_ids.shape[-1]} > {self.max_length}\")\n        \n        inputs[\"loss_mask\"] = self._get_assistant_mask(\n            input_ids,\n            start_pattern=self.assistant_start_pattern,\n            end_pattern=self.im_end_pattern,\n        )\n        \n        # Mask EOS token\n        inputs[\"loss_mask\"][..., -1] = 0\n\n        # itemic id index mask\n        itemic_id_mask = torch.zeros_like(input_ids)\n        if self.itemic_id_range is not None:\n            itemic_id_mask[(input_ids >= 
self.itemic_id_range[0]) & (input_ids <= self.itemic_id_range[1])] = 1\n        inputs[\"itemic_id_mask\"] = itemic_id_mask\n        \n        # Generate position IDs\n        inputs[\"position_ids\"] = self._get_rope_index_qwen3(input_ids)\n        \n        return inputs\n\n    def _process(self, sample, source_name=None):\n        if \"segments\" in sample[\"json\"] and sample[\"json\"][\"segments\"] is not None:\n            inputs = self._process_completion(sample)\n        else:\n            inputs = self._process_chat(sample)\n\n        # Validate before touching the result; checking after the assignment\n        # would make this guard dead code.\n        if not inputs:\n            raise ValueError(\"Empty inputs, skip\")\n        inputs['epoch_idx'] = sample['epoch_idx']\n        \n        # Check if sample exceeds max_sample_length (always <= max_length)\n        if inputs[\"input_ids\"].shape[-1] > self.max_sample_length:\n            logger.warning(f\"Sample exceeds max_sample_length={self.max_sample_length}, length={inputs['input_ids'].shape[-1]}\")\n            raise ValueError(\n                f\"Unable to generate sample within max_sample_length={self.max_sample_length}\"\n            )\n        \n        return inputs\n\n    def _cut_sample(self, inputs, packable_length):\n        inputs[\"input_ids\"] = inputs[\"input_ids\"][:, :packable_length]\n        inputs[\"attention_mask\"] = inputs[\"attention_mask\"][:, :packable_length]\n        inputs[\"loss_mask\"] = inputs[\"loss_mask\"][:, :packable_length]\n        inputs[\"position_ids\"] = inputs[\"position_ids\"][..., :packable_length]\n        inputs[\"itemic_id_mask\"] = inputs[\"itemic_id_mask\"][:, :packable_length]\n        return inputs\n\n    def _append_sample_packing(self,\n                                inputs: Dict[str, torch.Tensor],\n                                packed_input_ids: List[torch.Tensor],\n                                packed_position_ids: List[torch.Tensor],\n                                packed_loss_mask: List[torch.Tensor],\n                                packed_itemic_id_mask: List[torch.Tensor],\n                                packed_sample_idx: List[torch.Tensor],\n                                cu_seqlens: List[int],\n                                sample_idx: Optional[int] = None,\n                                ):\n        packable_length = self.max_length - cu_seqlens[-1]\n        # Return 0 (not None) so the caller's `valid_seq_len +=` stays valid.\n        if packable_length == 0: return 0\n\n        if self.cut_to_pad and inputs['input_ids'].shape[1] > packable_length:\n            inputs = self._cut_sample(inputs, packable_length)\n\n        packed_input_ids.append(inputs[\"input_ids\"].flatten())\n        packed_loss_mask.append(inputs[\"loss_mask\"].flatten())\n        packed_position_ids.append(inputs[\"position_ids\"])\n        packed_itemic_id_mask.append(inputs[\"itemic_id_mask\"].flatten())\n\n        if sample_idx is None:\n            sample_idx = len(cu_seqlens) - 1\n\n        packed_sample_idx.append(\n            torch.full_like(packed_input_ids[-1], sample_idx))\n\n        cu_seqlens.append(cu_seqlens[-1] + len(inputs[\"input_ids\"][0]))\n        return len(inputs[\"input_ids\"][0])\n\n    def _packing(self, buffer: List[Dict[str, torch.Tensor]]):\n        packed_input_ids: List[torch.Tensor] = []\n        packed_position_ids: List[torch.Tensor] = []\n        packed_loss_mask: List[torch.Tensor] = []\n        packed_itemic_id_mask: List[torch.Tensor] = []\n        packed_sample_idx: List[torch.Tensor] = []\n        cu_seqlens: List[int] = [0]\n        epochs = []\n        valid_seq_len = 0\n        for _, inputs in enumerate(buffer):\n            
epochs.append(inputs.get(\"epoch_idx\", None))\n            valid_seq_len += self._append_sample_packing(inputs,\n                                            packed_input_ids,\n                                            packed_position_ids,\n                                            packed_loss_mask,\n                                            packed_itemic_id_mask,\n                                            packed_sample_idx,\n                                            cu_seqlens,\n                                            )\n\n        packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0)\n        packed_loss_mask = torch.cat(packed_loss_mask, dim=0).unsqueeze(0)\n        packed_itemic_id_mask = torch.cat(packed_itemic_id_mask, dim=0).unsqueeze(0)\n        packed_position_ids = torch.cat(packed_position_ids, dim=-1)\n        packed_sample_idx = torch.cat(packed_sample_idx, dim=0).unsqueeze(0)\n\n        max_length = max(self.max_length, packed_input_ids.numel())\n        padding_len = (max_length + 7) // 8 * 8 + 64 - packed_input_ids.numel()\n        assert padding_len > 0, f\"padding_len should be greater than 0, got {padding_len}\"\n        packed_input_ids = F.pad(\n            packed_input_ids, (0, padding_len),\n            value=self.tokenizer.pad_token_id)\n        packed_sample_idx = F.pad(packed_sample_idx, (0, padding_len), value=-1)\n        packed_position_ids = F.pad(packed_position_ids, (0, padding_len), value=0)\n        packed_loss_mask = F.pad(packed_loss_mask, (0, padding_len), value=0)\n        packed_itemic_id_mask = F.pad(packed_itemic_id_mask, (0, padding_len), value=False)\n        cu_seqlens.append(cu_seqlens[-1] + padding_len)\n\n        if self.kwargs.get(\"full_attention\", False):\n            packed_position_ids = self._get_rope_index_qwen3(packed_input_ids)\n            cu_seqlens = [0, cu_seqlens[-1]]\n\n        epochs = [x for x in epochs if x is not None]\n        inputs = {\n            \"input_ids\": packed_input_ids,\n            \"position_ids\": packed_position_ids,\n            \"loss_mask\": packed_loss_mask,\n            \"itemic_id_mask\": packed_itemic_id_mask,\n            \"cu_seqlens\": torch.tensor(cu_seqlens, dtype=torch.int32),\n            \"sample_idx\": packed_sample_idx.to(torch.int32),\n            # max(..., 1) guards the degenerate case where no sample carried an epoch_idx\n            \"epoch_idx\": torch.tensor([sum(epochs) / max(len(epochs), 1)], dtype=torch.float32),\n        }\n        return inputs\n\n    def __iter__(self):\n        if self.dataset is None:\n            self.dataset, self.total_samples = self._build_source_dataset(self.sources)\n\n        buffer = []\n        source_list = []\n        cur_length = 0\n        ds_iter = iter(self.dataset)\n        while True:\n            # Pre-bind the identifiers used by the error handler so an exception\n            # raised by next(ds_iter) itself cannot cause a NameError below.\n            sample_key = sample_url = \"\"\n            source_name = \"None\"\n            try:\n                sample = next(ds_iter)\n                sample_key = sample[\"__key__\"] if \"__key__\" in sample else \"\"\n                sample_url = sample[\"__url__\"] if \"__url__\" in sample else \"\"\n\n                try:\n                    source_name = sample[\"json\"][\"source\"]\n                except Exception:\n                    source_name = \"None\"\n\n                self.source_sample_cnt.setdefault(source_name, 0)\n                self.source_sample_cnt[source_name] += 1\n            \n                inputs = self._process(sample, source_name)\n            except StopIteration:\n                # Finite sources (e.g. the parquet reader) end here; without this\n                # clause the generic handler below would swallow StopIteration\n                # and loop forever.\n                break\n            except Exception:\n                self.source_error_cnt.setdefault(source_name, 0)\n                self.source_error_cnt[source_name] += 1\n                error_ratio = self.source_error_cnt[source_name] * 1.0 / \\\n                
    max(self.source_sample_cnt.get(source_name, 1), 1)\n                \n                rank, world_size, worker, num_workers = pytorch_worker_info()\n                logger.error(\n                    f\"Qwen3ChatCompletionDataset process sample error. worker=r{rank}_w{worker}, \"\n                    f\"{source_name=}, {error_ratio=}, {sample_key=}, {sample_url=}, sample=\\n{str(sample)[:50]}, \"\n                    f\"errmsg={traceback.format_exc()}\")\n                continue\n\n            sample_length = inputs[\"input_ids\"].shape[-1]\n            if cur_length + sample_length >= self.max_length:\n                if self.cut_to_pad:\n                    buffer.append(inputs)\n                    source_list.append(source_name)\n                    packed_inputs = self._packing(buffer)\n\n                    packed_inputs[\"data_source\"] = source_list\n                    buffer = []\n                    source_list = []\n                    cur_length = 0\n                else:\n                    packed_inputs = self._packing(buffer)\n                    packed_inputs[\"data_source\"] = source_list\n                    buffer = [inputs]\n                    source_list = [source_name]\n                    cur_length = sample_length\n\n                # A packed batch can still carry zero supervised tokens (e.g. every\n                # sample was truncated away during processing); skip it once here.\n                if packed_inputs[\"loss_mask\"].sum() == 0:\n                    logger.warning(\"Skipping packed batch with no valid loss tokens.\")\n                    continue\n\n                yield packed_inputs\n\n            else:\n                buffer.append(inputs)\n                source_list.append(source_name)\n                cur_length += sample_length\n\nclass Qwen3NaiveParquetDataset(IterableDataset):\n    \"\"\"Naive parquet dataset for Qwen3 that handles file reading and parsing.\"\"\"\n    \n    def __init__(self, data_files, num_workers, **kwargs):\n        set_kwargs(self, kwargs, data_files=data_files, num_workers=num_workers)\n        self.local_shuffle_buffer = LocalShuffleBuffer(buffer_size=self.kwargs.get(\"local_shuffle_buffer_size\", 81920), \n                                                        random_fetch=self.kwargs.get(\"local_shuffle_random_fetch\", 0.00001))\n    \n        manager = multiprocessing.Manager()\n        def make_dict(): return manager.dict()\n\n        self.finish_dict_all = make_dict()\n        for i in range(self.num_workers):\n            self.finish_dict_all[i] = make_dict()\n    \n    def _parser(self, raw_row_data, file_url):\n        \"\"\"Parse a single row from parquet file.\"\"\"\n        try:\n            messages = None\n            segments = None\n            \n            if \"messages\" in raw_row_data:\n                messages = raw_row_data[\"messages\"]\n                if isinstance(messages, str):\n                    messages = json.loads(messages)\n\n            if \"segments\" in raw_row_data:\n                segments = raw_row_data[\"segments\"]\n                if isinstance(segments, str):\n                    segments = json.loads(segments)\n\n            data_source = raw_row_data[\"source\"]\n            key = raw_row_data[\"uuid\"]\n            \n            samples = {\n                \"__key__\": key,\n                \"__url__\": 
file_url,\n            }\n\n            sample_data = {\n                \"source\": data_source,\n            }\n\n            if messages is not None and isinstance(messages, list) and len(messages) > 0:\n                sample_data[\"messages\"] = messages\n            elif segments is not None and isinstance(segments, list) and len(segments) > 0:\n                sample_data[\"segments\"] = segments\n            elif messages is not None and isinstance(messages, np.ndarray):\n                sample_data[\"messages\"] = messages.tolist()\n            else:\n                raise NotImplementedError(f\"Unsupported sample, message type is {type(messages)}, message={messages}, segments type is {type(segments)}, segments={segments}\")\n\n            samples[\"json\"] = sample_data\n            \n            return samples\n        except Exception as e:\n            logger.error(f\"Qwen3NaiveParquetDataset parse sample error: {str(e)}\")\n            return None\n\n    def __iter__local_shuffle(self):\n        rank, world_size, worker, num_workers = pytorch_worker_info()\n        finish_dict = self.finish_dict_all[worker]\n        assert num_workers == self.num_workers\n\n        total_num_workers = num_workers * world_size\n        local_worker_idx = rank * num_workers + worker\n        fn_list = [fn for idx, fn in enumerate(self.data_files) if idx % total_num_workers == local_worker_idx]\n        logger.warning(\n            f\"ParquetDataset Info: {rank=}, {world_size=}, {worker=}, {num_workers=}, {len(fn_list)=}\"\n        )\n        \n        def get_sample():\n            for fn_index, (fn, epoch_idx) in enumerate(fn_list):\n                try:\n                    df = load_parquet_file(fn).read_row_group(0).to_pandas()\n                except Exception as e:\n                    logger.warning(\n                        f\"ParquetDataset Info: {rank=}, {world_size=}, {worker=}, {num_workers=}, {fn} failed, \" + \\\n                        f\"traceback=\\n{traceback.format_exc()}\"\n                    )\n                    continue\n                df['epoch_idx'] = epoch_idx\n                df['fn_idx'] = fn_index\n                df['__fn__'] = fn\n                df['sample_index'] = range(len(df))\n                for i, (_, row) in enumerate(df.iterrows()):\n                    sample_bit = 1 << row['sample_index']\n                    # Parentheses are required: `!=` binds tighter than `&`, so the\n                    # unparenthesized form would AND against a boolean and skip\n                    # almost nothing.\n                    if (sample_bit & finish_dict.get((row['__fn__'], row['epoch_idx']), 0)) != 0:\n                        logger.debug(f\"[Rank{rank}-Worker{worker}] Skipping already processed sample: \"\n                                    f\"{row['__fn__']}-epoch{row['epoch_idx']}-sample{row['sample_index']}\")\n                        continue\n                    if self.local_shuffle_buffer.add(row, fn, epoch_idx): continue\n                    row = self.local_shuffle_buffer.get()\n                    yield row\n\n            while len(self.local_shuffle_buffer) > 0:\n                row = self.local_shuffle_buffer.get()\n                yield row\n\n        for row in get_sample():\n            sample_bit = 1 << row['sample_index']\n\n            key = (row['__fn__'], row['epoch_idx'])\n            if key not in finish_dict:\n                finish_dict[key] = 0\n            finish_dict[key] |= sample_bit\n\n            sample = self._parser(row, row['__fn__'])\n            if sample is None:\n                # _parser returns None on malformed rows; skip before indexing it.\n                continue\n            sample['epoch_idx'] = torch.tensor(row['epoch_idx'])\n            yield sample\n\n    def __iter__(self):\n        for sample in self.__iter__local_shuffle():\n            if sample is 
None: continue\n            yield sample\n    \n    def state_dict(self):\n        \"\"\"Get state dict for checkpointing.\"\"\"\n        rank, world_size, worker, num_workers = pytorch_worker_info()\n\n        state_dict = {\n            \"finish_dict\": dict(self.finish_dict_all[worker]),\n        }\n        return state_dict\n    \n    def load_state_dict(self, state_dict):\n        \"\"\"Load state dict from checkpoint.\"\"\"\n        rank, world_size, worker, num_workers = pytorch_worker_info()\n        \n        finish_dict = state_dict[\"finish_dict\"]\n        \n        # Convert to regular dict to support old checkpoint format\n        tmp_finish_dict = dict(finish_dict)\n        \n        # Clear current state and update\n        self.finish_dict_all[worker].clear()\n        self.finish_dict_all[worker].update(tmp_finish_dict)\n        logger.info(f\"[rank{rank}-worker{worker}] Loaded checkpoint successfully. finish_dict_size={len(tmp_finish_dict)}\")\n\nclass Qwen3ChatCompletionParquetDataset(Qwen3ChatCompletionDataset):\n    def __init__(self, sources, num_workers, shuffle_seed=1024, num_epochs=1, **kwargs):\n        self.rng = random.Random(shuffle_seed)\n        self.num_workers = num_workers\n        self.num_epochs = num_epochs\n        self.cut_to_pad = kwargs.get(\"cut_to_pad\", True)\n        self.kwargs = kwargs\n        self.num_readers = kwargs.get(\"num_readers\", 1)\n        self.shuffle_window = kwargs.get(\"shuffle_window\", 0)\n        super().__init__(sources=sources, **kwargs)\n\n    def _build_source_dataset(self, sources):\n        data_file_list = []\n        if dist.get_rank() == 0:\n            data_files = []\n            if isinstance(sources, str) and sources.endswith(\".json\"):\n                with open(sources, \"r\") as fp:\n                    data_files = json.loads(fp.read())\n                    data_files = [fn for fn in data_files if fn.endswith(\".parquet\")]\n            elif isinstance(sources, list):\n                for source in sources:\n                    hdfs_files = shell_hdfs_ls(source)\n                    data_files += [fn for fn in hdfs_files if fn.endswith(\".parquet\")]\n            # repeat\n            for i in range(self.num_epochs):\n                data_files.sort()\n                self.rng.shuffle(data_files)\n                data_file_list += [(fn, i) for fn in data_files]\n            logger.info(f\"ParquetDataset rank{dist.get_rank()}: original_file_num={len(data_files)}, total_file_num={len(data_file_list)}\")\n\n        t = [data_file_list]\n        dist.broadcast_object_list(t, src=0)\n        data_file_list = t[0]\n\n        logger.info(f\"ParquetDataset rank{dist.get_rank()}: file_num={len(data_file_list)}\")\n        if len(data_file_list) == 0:\n            raise ValueError(f\"no datafile found!\")\n\n        dataset = Qwen3NaiveParquetDataset(data_file_list, self.num_workers, **self.kwargs)\n        return dataset, -1\n\n    def state_dict(self):\n        if self.dataset is None:\n            return {}\n        return self.dataset.state_dict()\n    \n    def load_state_dict(self, state_dict):\n        if self.dataset is None:\n            return\n        self.dataset.load_state_dict(state_dict)"
  },
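  {
    "path": "pretrain/onerec_llm/data/examples/packing_layout_demo.py",
    "content": "\"\"\"Minimal, self-contained sketch of the cu_seqlens packing layout produced by Qwen3ChatCompletionDataset._packing (hypothetical example file; toy tensors, constants simplified).\n\nVariable-length samples are concatenated into one row, cu_seqlens records the segment boundaries, and the row is padded so attention kernels consuming cu_seqlens can treat each segment as an independent sequence.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nPAD_TOKEN_ID = 0  # assumption: stands in for tokenizer.pad_token_id\n\n\ndef pack_samples(samples, max_length=32):\n    \"\"\"Pack 1-D LongTensors into one padded row plus int32 cu_seqlens.\"\"\"\n    cu_seqlens = [0]\n    pieces = []\n    for ids in samples:\n        pieces.append(ids)\n        cu_seqlens.append(cu_seqlens[-1] + ids.numel())\n\n    packed = torch.cat(pieces).unsqueeze(0)\n    # Mirror the dataset's padding rule: round up to a multiple of 8 plus some\n    # slack, and register the pad region as its own final segment.\n    target = (max(max_length, packed.numel()) + 7) // 8 * 8 + 64\n    padding_len = target - packed.numel()\n    packed = F.pad(packed, (0, padding_len), value=PAD_TOKEN_ID)\n    cu_seqlens.append(cu_seqlens[-1] + padding_len)\n    return packed, torch.tensor(cu_seqlens, dtype=torch.int32)\n\n\nif __name__ == \"__main__\":\n    samples = [torch.arange(1, 6), torch.arange(1, 4), torch.arange(1, 10)]\n    packed, cu = pack_samples(samples)\n    print(packed.shape, cu.tolist())  # segment boundaries at 5, 8, 17, then padding\n"
  },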
  {
    "path": "pretrain/onerec_llm/losses/__init__.py",
    "content": "from onerec_llm.losses.ce import CrossEntropyLoss, ChunkedLossComputer\n\n__all__ = [\n  \"CrossEntropyLoss\",\n  \"ChunkedLossComputer\",\n]\n"
  },
  {
    "path": "pretrain/onerec_llm/losses/ce.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom onerec_llm.utils.time_tracker import TimeTracker\n\n# ===================================================================\n# Cross-Entropy Loss Function\n# ===================================================================\n\nclass CrossEntropyLoss(nn.Module):\n    \"\"\"\n    An efficient CrossEntropyLoss module that avoids redundant calculations.\n    It first computes per-token losses and then manually applies the reduction.\n    (Based on the user-provided, superior implementation).\n    \"\"\"\n    def __init__(self,\n                 ignore_index: int = -100,\n                 return_token_loss: bool = False,\n                 shift_labels: bool = True,\n                 reduction: str = \"mean\"):\n        super().__init__()\n        self.ignore_index = ignore_index\n        self.return_token_loss = return_token_loss\n        self.reduction = reduction\n        self.shift_labels = shift_labels\n\n    def forward(self, logits: torch.Tensor, labels: torch.Tensor):\n        \"\"\"\n        Args:\n            logits (torch.Tensor): A single tensor of shape (..., vocab_size).\n            labels (torch.Tensor): Ground truth labels.\n        \"\"\"\n        vocab_size = logits.shape[-1]\n        \n        if self.shift_labels:\n          logits = logits[:, :-1, :]\n          labels = labels[:, 1:]\n\n        # Reshape for cross-entropy calculation\n        logits_flat = logits.float().reshape(-1, vocab_size)\n        labels_flat = labels.reshape(-1)\n\n        # Step 1: Compute per-token loss. This is the base for all other calculations.\n        per_token_loss = F.cross_entropy(\n            logits_flat,\n            labels_flat,\n            ignore_index=self.ignore_index,\n            reduction=\"none\"\n        )\n        \n        # Step 2: Manually apply reduction to get the final loss.\n        loss = per_token_loss.sum()\n        if self.reduction == \"mean\":\n            # Ensure we divide by the number of valid (non-ignored) tokens\n            total_elements = (labels_flat != self.ignore_index).sum()\n            if total_elements > 0:\n                loss /= total_elements\n            else: # Handle case where all tokens are ignored\n                loss.zero_()\n\n        # Return what's requested\n        if self.return_token_loss:\n            return loss, per_token_loss\n        \n        return loss\n\n\n# ===================================================================\n# Memory-Efficient Chunked Loss Computer\n# ===================================================================\n\nclass ChunkedLossComputer:\n    \"\"\"\n    Memory-efficient chunked loss computer for solving OOM issues caused by large lm_head in LLMs.\n\n    By computing the input sequence in chunks and manually accumulating gradients,\n    it avoids allocating huge intermediate tensors for the entire sequence at once.\n\n    Note: The returned loss has already been backpropagated and detached,\n    and cannot be used for operations requiring gradients.\n    \"\"\"\n    def __init__(self, lm_head: nn.Module, loss_fn: nn.Module, minibatch_size: int, shift_labels: bool = True):\n        \"\"\"\n        Initialize the chunked loss computer.\n\n        Args:\n            lm_head: The output layer of the language model (typically nn.Linear)\n            loss_fn: Loss function, must return (avg_loss, per_token_loss) tuple\n            minibatch_size: Size of each chunk, used to control memory usage\n            
shift_labels: Whether to shift labels (for autoregressive models)\n        \"\"\"\n        if not isinstance(lm_head, nn.Module) or not isinstance(loss_fn, nn.Module):\n            raise TypeError(\"lm_head and loss_fn must be instances of nn.Module\")\n            \n        self.lm_head = lm_head\n        self.loss_fn = loss_fn\n        self.minibatch_size = minibatch_size\n        self.shift_labels = shift_labels\n        self.loss_info = {}\n        self.ticker = TimeTracker()\n\n    def forward_and_backward(self, input: torch.Tensor, labels: torch.Tensor, loss_fn_args: dict = {}):\n        \"\"\"\n        Execute chunked forward and backward propagation.\n\n        Args:\n            input: Input tensor with shape [batch_size, seq_len, hidden_dim]\n            labels: Label tensor with shape [batch_size, seq_len]\n            loss_fn_args: Additional arguments passed to the loss function\n\n        Returns:\n            tuple[torch.Tensor, torch.Tensor]: (final_avg_loss, per_token_loss)\n\n        Note: The returned loss has already been backpropagated and detached,\n        and cannot be used for operations requiring gradients.\n        \"\"\"\n        self.ticker.tick(\"lm_head\")\n        params = list(self.lm_head.parameters())\n        grad_accs = [torch.zeros_like(p) for p in params]\n        grad_input_full = torch.zeros_like(input)\n\n        total_loss_sum_for_reporting = torch.tensor(0.0, device=input.device)\n        all_per_token_losses = []\n\n        seq_len = input.size(1)\n        \n        # Calculate total number of valid elements\n        labels_to_count = labels[:, 1:] if self.shift_labels else labels\n        total_elements = (labels_to_count != getattr(self.loss_fn, 'ignore_index', -100)).sum()\n        \n        if total_elements.item() == 0:\n            return torch.tensor(0.0, device=input.device), None\n\n        # Chunked forward and gradient accumulation\n        for i in range(0, seq_len, self.minibatch_size):\n            start, end = i, min(i + self.minibatch_size, seq_len)\n            input_chunk = input[:, start:end, :].detach().requires_grad_()\n            \n            logits_chunk = self.lm_head(input_chunk)\n\n            if self.shift_labels:\n                label_start, label_end = start + 1, end + 1\n                labels_chunk = labels[:, label_start:label_end]\n                # Ensure logits and labels have matching lengths\n                if logits_chunk.size(1) > labels_chunk.size(1):\n                    logits_chunk = logits_chunk[:, :labels_chunk.size(1), :]\n            else:\n                labels_chunk = labels[:, start:end]\n\n            if labels_chunk.numel() == 0:\n                continue\n\n            logits_flat = logits_chunk.reshape(-1, self.lm_head.out_features)\n            labels_flat = labels_chunk.reshape(-1)\n            \n            # Compute loss\n            loss_chunk_avg, per_token_loss_chunk = self.loss_fn(logits_flat, labels_flat, **loss_fn_args)\n\n            # Convert to sum loss for backward propagation\n            valid_tokens_in_chunk = (labels_flat != getattr(self.loss_fn, 'ignore_index', -100)).sum()\n            \n            if valid_tokens_in_chunk.item() == 0:\n                all_per_token_losses.append(per_token_loss_chunk.detach())\n                continue\n            \n            loss_chunk_sum = loss_chunk_avg * valid_tokens_in_chunk\n\n            # Manually compute gradients and accumulate\n            tensors_to_grad = [p for p in params if p.requires_grad] + [input_chunk]\n       
     grads = torch.autograd.grad(outputs=loss_chunk_sum, inputs=tensors_to_grad, retain_graph=False)\n        \n            grad_idx = 0\n            for j in range(len(params)):\n                if params[j].requires_grad:\n                    grad_accs[j] += grads[grad_idx]\n                    grad_idx += 1\n            grad_input_full[:, start:end, :] = grads[grad_idx]\n\n            total_loss_sum_for_reporting += loss_chunk_sum.detach()\n            all_per_token_losses.append(per_token_loss_chunk.detach())\n        \n        # Apply accumulated gradients\n        for j, p in enumerate(params):\n            if p.requires_grad:\n                p.grad = grad_accs[j] / total_elements\n\n        self.ticker.tick(\"llm\")        \n        input.backward(gradient=grad_input_full / total_elements)\n        self.ticker.tick(\"done\")\n        \n        final_avg_loss = (total_loss_sum_for_reporting / total_elements).detach()\n        per_token_loss = torch.cat(all_per_token_losses) if all_per_token_losses else None\n        final_avg_loss.requires_grad = True\n\n        self.loss_info = {\n            'loss': final_avg_loss,\n            'per_token_loss': per_token_loss\n        }\n        return final_avg_loss, per_token_loss\n"
  },
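  {
    "path": "pretrain/onerec_llm/losses/examples/chunked_loss_demo.py",
    "content": "\"\"\"Hypothetical sanity-check sketch (not part of the repo's test suite): the chunked forward/backward pass of ChunkedLossComputer should reproduce a plain full-sequence cross-entropy on toy shapes.\n\nChunkedLossComputer shifts labels itself, so the wrapped CrossEntropyLoss is built with shift_labels=False and return_token_loss=True (the wrapped loss must return an (avg_loss, per_token_loss) tuple).\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom onerec_llm.losses import ChunkedLossComputer, CrossEntropyLoss\n\ntorch.manual_seed(0)\nhidden, vocab, seq = 16, 32, 24\nlm_head = nn.Linear(hidden, vocab, bias=False)\n\nloss_fn = CrossEntropyLoss(return_token_loss=True, shift_labels=False)\ncomputer = ChunkedLossComputer(lm_head, loss_fn, minibatch_size=8)\n\nhidden_states = torch.randn(1, seq, hidden, requires_grad=True)\nlabels = torch.randint(0, vocab, (1, seq))\n\n# Runs lm_head chunk by chunk, accumulates lm_head grads and hidden_states.grad\n# manually, and returns a detached average loss.\navg_loss, per_token = computer.forward_and_backward(hidden_states, labels)\n\n# Reference: one full (unchunked) pass over the same tensors.\nref_logits = lm_head(hidden_states.detach())[:, :-1, :]\nref_loss = F.cross_entropy(ref_logits.reshape(-1, vocab), labels[:, 1:].reshape(-1))\nprint(float(avg_loss), float(ref_loss))  # should agree up to float error\n"
  },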
  {
    "path": "pretrain/onerec_llm/models/qwen3/__init__.py",
    "content": "# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import TYPE_CHECKING\n\nfrom transformers.utils.import_utils import _LazyModule, define_import_structure\n\nif TYPE_CHECKING:\n    from .configuration_qwen3 import *\n    from .modeling_qwen3 import *\nelse:\n    import sys\n\n    _file = globals()[\"__file__\"]\n    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)\n"
  },
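  {
    "path": "pretrain/onerec_llm/models/qwen3/examples/lazy_import_demo.py",
    "content": "\"\"\"Tiny illustrative check (hypothetical file): models/qwen3/__init__.py routes the package through transformers' _LazyModule, so exported symbols such as Qwen3Config are only materialized on first attribute access.\"\"\"\n\nimport sys\n\nfrom onerec_llm.models.qwen3 import Qwen3Config  # resolved lazily on access\n\nconfig = Qwen3Config(num_hidden_layers=2)\nprint(type(config).__name__, type(sys.modules[\"onerec_llm.models.qwen3\"]).__name__)\n"
  },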
  {
    "path": "pretrain/onerec_llm/models/qwen3/configuration_qwen3.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen3 model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a\n    Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`Qwen3Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 22016):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 32):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details checkout [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `32`.\n        head_dim (`int`, *optional*, defaults to 128):\n            The attention head dimension.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, will default to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n\n    ```python\n    >>> from transformers import Qwen3Model, Qwen3Config\n\n    >>> # Initializing a Qwen3 style configuration\n    >>> configuration = Qwen3Config()\n\n    >>> # Initializing a model from the Qwen3-8B style configuration\n    >>> model = Qwen3Model(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=4096,\n        intermediate_size=22016,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        head_dim=128,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n      
  self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window  # we check `use_sliding_window` in the modeling code\n        self.max_window_layers = max_window_layers\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.head_dim = head_dim\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n__all__ = [\"Qwen3Config\"]\n"
  },
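  {
    "path": "pretrain/onerec_llm/models/qwen3/examples/config_demo.py",
    "content": "\"\"\"Hypothetical configuration sketch: builds a small Qwen3Config with a GQA head layout and a legacy-style rope_scaling dict, showing the 'type' -> 'rope_type' backward-compatibility shim that runs before rope_config_validation.\"\"\"\n\nfrom onerec_llm.models.qwen3.configuration_qwen3 import Qwen3Config\n\nconfig = Qwen3Config(\n    hidden_size=1024,\n    intermediate_size=3072,\n    num_hidden_layers=4,\n    num_attention_heads=16,\n    num_key_value_heads=4,  # GQA: every 4 query heads share one KV head\n    head_dim=64,\n    rope_scaling={\"type\": \"yarn\", \"factor\": 2.0},  # legacy 'type' key\n)\n\nassert config.rope_scaling[\"rope_type\"] == \"yarn\"\nprint(config.num_attention_heads // config.num_key_value_heads, \"query heads per KV head\")\n"
  },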
  {
    "path": "pretrain/onerec_llm/models/qwen3/modeling_qwen3.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_qwen3.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom functools import partial\nfrom typing import Callable, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.generation.utils import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils.generic import LossKwargs, can_return_tuple\nfrom transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings\nfrom transformers.utils import logging\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom .configuration_qwen3 import Qwen3Config\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen3-8B\"\n_CONFIG_FOR_DOC = \"Qwen3Config\"\n\n\nclass Qwen3RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen3RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen3MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        
self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass Qwen3Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen3Config, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!\n        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen3DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen3Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)\n        self.mlp = Qwen3MLP(config)\n        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        if (\n            config.sliding_window and config._attn_implementation != \"flash_attention_2\"\n        ):  # diff with Llama is this warning\n            logger.warning_once(\n                f\"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; \"\n                \"unexpected results may be encountered.\"\n            )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        return outputs\n\n\nclass Qwen3RotaryEmbedding(nn.Module):\n    def __init__(self, config: Qwen3Config, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nQWEN3_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen3Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3 Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_START_DOCSTRING,\n)\nclass Qwen3PreTrainedModel(PreTrainedModel):\n    config_class = Qwen3Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen3DecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nQWEN3_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3 Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_START_DOCSTRING,\n)\nclass Qwen3Model(Qwen3PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3DecoderLayer`]\n\n    Args:\n        config: Qwen3Config\n    \"\"\"\n\n    def __init__(self, config: Qwen3Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen3RotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @can_return_tuple\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> BaseModelOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache\n        if not isinstance(past_key_values, (type(None), Cache)):\n            raise ValueError(\"The `past_key_values` should be either a `Cache` object or `None`.\")\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        # Packed-sequence path: `cu_seqlens` marks the cumulative boundaries of the samples packed\n        # together, so flash-attention varlen kwargs replace the dense 4D causal mask.\n        cu_seqlens = flash_attn_kwargs.pop(\"cu_seqlens\", None)\n        if cu_seqlens is not None:\n            cu_seqlens = cu_seqlens.to(dtype=torch.int32)\n            flash_attn_kwargs[\"cu_seq_lens_q\"] = cu_seqlens\n            flash_attn_kwargs[\"cu_seq_lens_k\"] = cu_seqlens\n            # the longest packed segment bounds the attention span\n            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()\n            flash_attn_kwargs[\"max_length_q\"] = max_seqlen\n            flash_attn_kwargs[\"max_length_k\"] = max_seqlen\n            causal_mask = None\n        else:\n            causal_mask = self._update_causal_mask(\n                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n            )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    partial(decoder_layer.__call__, **flash_attn_kwargs),\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                    **flash_attn_kwargs,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += 
(hidden_states,)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values if use_cache else None,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool = False,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and past_key_values is not None:\n                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]\n                if is_padding_right:\n                    raise ValueError(\n                        \"You are attempting to perform batched generation with padding_side='right'\"\n                        \" this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to \"\n                        \" call `tokenizer.padding_side  = 'left'` before tokenizing the input. \"\n                    )\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not (using_static_cache or using_sliding_window_cache)\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                sliding_window=self.config.sliding_window,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        # SlidingWindowCache or StaticCache\n        if using_sliding_window_cache or using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        # DynamicCache or no cache\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            device=device,\n            cache_position=cache_position,\n            
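# the config and cache type let the helper apply sliding-window masking when needed\n            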
batch_size=input_tensor.shape[0],\n            config=self.config,\n            past_key_values=past_key_values,\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\"]\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        config: Qwen3Config,\n        past_key_values: Cache,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`torch.Tensor`):\n                Batch size.\n            config (`Qwen3Config`):\n                The model's configuration class\n            past_key_values (`Cache`):\n                The cache class that is being used currently to generate\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            if config.sliding_window is not None:\n                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also\n                # the check is needed to verify is current checkpoint was trained with sliding window or not\n                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:\n          
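          # additionally mask keys at least `sliding_window` positions behind the query\n          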
          sliding_attend_mask = torch.arange(target_length, device=device) <= (\n                        cache_position.reshape(-1, 1) - config.sliding_window\n                    )\n                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)\n            causal_mask *= diagonal_attend_mask\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                if attention_mask.shape[-1] > target_length:\n                    attention_mask = attention_mask[:, :target_length]\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(\n                    causal_mask.device\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        return causal_mask\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\n\n\nclass Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n    # modules treated as a wrapping unit by the training stack (e.g. FSDP sharding / activation checkpointing)\n    wrap_modules = {Qwen3DecoderLayer}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.chunked_loss_computer = getattr(config, \"chunked_loss_computer\", False)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @can_return_tuple\n    @deprecate_kwarg(\"num_logits_to_keep\", version=\"4.50\", new_name=\"logits_to_keep\")\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM\n\n        >>> model = Qwen3ForCausalLM.from_pretrained(\"Qwen/Qwen3-8B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-8B\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n\n        if self.chunked_loss_computer:\n            # skip the full lm_head projection here; the configured chunked loss computer is\n            # expected to project hidden states to logits chunk-by-chunk to reduce peak memory\n            logits = hidden_states[:, slice_indices, :]\n        else:\n            logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3 Model transformer with a sequence classification head on top (linear layer).\n\n    
[`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it needs to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN3_START_DOCSTRING,\n)\nclass Qwen3ForSequenceClassification(Qwen3PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @can_return_tuple\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> SequenceClassifierOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n\n        transformer_outputs: BaseModelOutputWithPast = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n        hidden_states = transformer_outputs.last_hidden_state\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            last_non_pad_token = -1\n        elif input_ids is not None:\n            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id\n            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)\n            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)\n            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)\n        else:\n            last_non_pad_token = -1\n            logger.warning_once(\n                f\"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be \"\n                \"unexpected if using padding tokens in conjunction with `inputs_embeds`.\"\n            )\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN3_START_DOCSTRING,\n)\nclass Qwen3ForTokenClassification(Qwen3PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3Model(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @can_return_tuple\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> TokenClassifierOutput:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the token classification loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`.\n        \"\"\"\n\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n        sequence_output = outputs.last_hidden_state\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.config)\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Qwen3 Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    QWEN3_START_DOCSTRING,\n)\nclass Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = Qwen3Model(config)\n        self.qa_outputs = nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @can_return_tuple\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        **kwargs,\n    ) -> QuestionAnsweringModelOutput:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Positions outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n\n        outputs: BaseModelOutputWithPast = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        sequence_output = outputs.last_hidden_state\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        loss = None\n        if start_positions is not None and end_positions is not None:\n            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)\n\n        return QuestionAnsweringModelOutput(\n            loss=loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\n    \"Qwen3ForCausalLM\",\n    \"Qwen3ForQuestionAnswering\",\n    \"Qwen3Model\",\n    \"Qwen3PreTrainedModel\",\n    \"Qwen3ForSequenceClassification\",\n    \"Qwen3ForTokenClassification\",\n]\n"
  },
  {
    "path": "pretrain/onerec_llm/models/qwen3/modular_qwen3.py",
    "content": "# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen3 model.\"\"\"\n\nfrom typing import Callable, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\n\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils.generic import LossKwargs\nfrom transformers.utils import logging\n\nfrom transformers.models.gemma.modeling_gemma import GemmaMLP\nfrom transformers.models.llama.modeling_llama import (\n    LlamaAttention,\n    LlamaDecoderLayer,\n    LlamaForCausalLM,\n    LlamaForQuestionAnswering,\n    LlamaForSequenceClassification,\n    LlamaForTokenClassification,\n    LlamaRMSNorm,\n    apply_rotary_pos_emb,\n    eager_attention_forward,\n)\nfrom transformers.models.mistral.modeling_mistral import MistralModel\nfrom .configuration_qwen3 import Qwen3Config\n\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen3-8B\"\n\n\nclass Qwen3RMSNorm(LlamaRMSNorm):\n    pass\n\n\nclass Qwen3MLP(GemmaMLP):\n    pass\n\n\nclass Qwen3Attention(LlamaAttention):\n    def __init__(self, config: Qwen3Config, layer_idx: int):\n        super().__init__(config, layer_idx)\n        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!\n        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if 
past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen3DecoderLayer(LlamaDecoderLayer):\n    def __init__(self, config: Qwen3Config, layer_idx: int):\n        super().__init__()\n        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)\n        self.mlp = Qwen3MLP(config)\n        if (\n            config.sliding_window and config._attn_implementation != \"flash_attention_2\"\n        ):  # diff with Llama is this warning\n            logger.warning_once(\n                f\"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; \"\n                \"unexpected results may be encountered.\"\n            )\n\n\nclass Qwen3Model(MistralModel):  # mistral model creates sliding window\n    pass\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\n\n\nclass Qwen3ForCausalLM(LlamaForCausalLM):\n    def forward(\n        self,\n        **super_kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM\n\n        >>> model = Qwen3ForCausalLM.from_pretrained(\"Qwen/Qwen3-8B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-8B\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        return super().forward(**super_kwargs)\n\n\nclass Qwen3ForSequenceClassification(LlamaForSequenceClassification):\n    pass\n\n\nclass Qwen3ForTokenClassification(LlamaForTokenClassification):\n    pass\n\n\nclass Qwen3ForQuestionAnswering(LlamaForQuestionAnswering):\n    pass\n\n\n__all__ = [\n    \"Qwen3ForCausalLM\",\n    \"Qwen3ForQuestionAnswering\",\n    \"Qwen3Model\",\n    \"Qwen3PreTrainedModel\",  # noqa: F822\n    \"Qwen3ForSequenceClassification\",\n    \"Qwen3ForTokenClassification\",\n]\n"
  },
  {
    "path": "pretrain/onerec_llm/training/__init__.py",
    "content": "\"\"\"Training utilities for FSDP-based LLM training.\n\nThis package provides core training functionality including:\n- Distributed training with FSDP\n- Checkpoint management\n- Learning rate scheduling\n- Gradient computation and masking\n- Activation checkpointing\n\"\"\"\n\nfrom onerec_llm.training.activations import set_activation_checkpointing\nfrom onerec_llm.training.checkpoint import (\n    AppState,\n    DistributedCheckpointer,\n    load_checkpoint_to_state_dict,\n    load_hf_checkpoint,\n    load_safetensors,\n    safe_torch_load,\n)\nfrom onerec_llm.training.common import set_default_dtype\nfrom onerec_llm.training.distributed import (\n    load_from_full_model_state_dict,\n    shard_model,\n)\nfrom onerec_llm.training.gradients import (\n    EmbeddingGradientMasker,\n    clip_grad_by_value,\n    compute_fsdp_zero2_grad_norm,\n)\nfrom onerec_llm.training.lr_schedulers import get_cosine_scheduler, get_scheduler\n\n__all__ = [\n    # Activations\n    \"set_activation_checkpointing\",\n    # Checkpoint\n    \"AppState\",\n    \"DistributedCheckpointer\",\n    \"load_checkpoint_to_state_dict\",\n    \"load_hf_checkpoint\",\n    \"load_safetensors\",\n    \"safe_torch_load\",\n    # Common\n    \"set_default_dtype\",\n    # Distributed\n    \"load_from_full_model_state_dict\",\n    \"shard_model\",\n    # Gradients\n    \"EmbeddingGradientMasker\",\n    \"clip_grad_by_value\",\n    \"clip_grad_norm\",\n    \"compute_fsdp_zero2_grad_norm\",\n    # LR Schedulers\n    \"get_cosine_scheduler\",\n    \"get_scheduler\",\n]\n\n"
  },
  {
    "path": "pretrain/onerec_llm/training/activations.py",
    "content": "import torch.nn as nn\n\nfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n    apply_activation_checkpointing,\n)\nfrom torch.distributed.fsdp.wrap import ModuleWrapPolicy\n\ndef set_activation_checkpointing(\n    model: nn.Module, auto_wrap_policy, **kwargs\n) -> None:\n    \"\"\"Utility to apply activation checkpointing to the passed-in model.\n\n    Args:\n        model (nn.Module): Model to apply activation checkpointing to.\n        auto_wrap_policy (ACWrapPolicyType): Policy to wrap module.\n            This can either be a set of ``nn.Module`` types, in which case, modules of the specified type(s)\n            will be wrapped individually with activation checkpointing, or a ``callable`` policy describing\n            how to wrap the model with activation checkpointing. For more information on authoring custom\n            policies, please see this tutorial:\n            https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html#transformer-wrapping-policy.\n        **kwargs: additional arguments to pass to ``torch.distributed`` activation checkpointing.\n    \"\"\"\n    if isinstance(auto_wrap_policy, set):\n        auto_wrap_policy = ModuleWrapPolicy(auto_wrap_policy)\n    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)\n"
  },
  {
    "path": "pretrain/onerec_llm/training/checkpoint.py",
    "content": "from typing import Dict, Any, Union, Optional, Protocol, Callable\nimport re\nimport os\nimport gc\nimport glob\nimport time\nfrom pathlib import Path\nfrom concurrent.futures import Future\n\nimport torch\nimport torch.distributed as dist\nfrom safetensors import safe_open\n\nfrom torch.distributed.checkpoint import (\n    async_save,\n    FileSystemReader,\n    FileSystemWriter,\n    load,\n    save,\n)\nfrom torch.distributed.checkpoint.metadata import STATE_DICT_TYPE\nimport torch.distributed.checkpoint as dcp\nfrom torch.distributed.checkpoint.stateful import Stateful\nfrom torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict\nfrom torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner\nfrom safetensors.torch import load_file\nfrom tqdm import tqdm\n\nfrom onerec_llm.utils.distributed import get_world_size_and_rank\nfrom onerec_llm.utils.common import print_rank_0, print_rank_n\n\ndef load_safetensors(path: Union[Path, str]) -> Dict[str, torch.Tensor]:\n    \"\"\"Load safetensors file and return a dictionary of tensors.\n    \n    Args:\n        path: Path to the safetensors file.\n        \n    Returns:\n        Dictionary mapping tensor names to tensors.\n    \"\"\"\n    tensors = {}\n    with safe_open(path, framework=\"pt\", device=\"cpu\") as f:\n        for k in f.keys():\n            tensors[k] = f.get_tensor(k)\n    return tensors\n\n\ndef safe_torch_load(\n    checkpoint_path: Union[Path, str], \n    weights_only: bool = True, \n    mmap: bool = True\n) -> Dict[str, Any]:\n    \"\"\"\n    Utility to load a checkpoint file onto CPU in a safe manner. \n    Provides separate handling for safetensors files.\n\n    Args:\n        checkpoint_path: Path to the checkpoint file.\n        weights_only: Whether to load only tensors, primitive types, and dictionaries\n            (passthrough to torch.load). Default: True\n        mmap: Whether to mmap from disk into CPU memory. Default: True\n\n    Returns:\n        State dict from the checkpoint file.\n\n    Raises:\n        ValueError: If the checkpoint file is not found or cannot be loaded.\n    \"\"\"\n    try:\n        checkpoint_path_str = str(checkpoint_path)\n        if checkpoint_path_str.endswith(\".safetensors\"):\n            return load_safetensors(checkpoint_path)\n        else:\n            return torch.load(\n                checkpoint_path_str,\n                map_location=\"cpu\",\n                mmap=mmap,\n                weights_only=weights_only,\n            )\n    except Exception as e:\n        raise ValueError(f\"Unable to load checkpoint from {checkpoint_path}\") from e\n\ndef load_hf_checkpoint(\n    model_dir: str, \n    output_keys_file: Optional[str] = None\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Load HuggingFace format checkpoint from a directory.\n    \n    Args:\n        model_dir: Directory containing checkpoint files (.safetensors or .bin).\n        output_keys_file: Optional path to write checkpoint keys for debugging.\n            If None, keys are not written. 
Default: None.\n    \n    Returns:\n        Merged state dictionary containing all checkpoint weights.\n    \n    Raises:\n        ValueError: If checkpoint files are not found or contain non-tensor values.\n    \"\"\"\n    merged_state_dict: Dict[str, torch.Tensor] = {}\n    \n    # Try to find safetensors files first, fall back to .bin files\n    ckpt_paths = sorted(glob.glob(os.path.join(model_dir, \"*.safetensors\")))\n    if not ckpt_paths:\n        ckpt_paths = sorted(glob.glob(os.path.join(model_dir, \"*.bin\")))\n    \n    if not ckpt_paths:\n        raise ValueError(f\"No checkpoint files found in {model_dir}\")\n    \n    for cpt_idx, cpt_path in enumerate(ckpt_paths):\n        print_rank_0(f\"Loading checkpoint {cpt_idx + 1}/{len(ckpt_paths)}: {cpt_path}\")\n        state_dict = safe_torch_load(cpt_path)\n        \n        # Validate that all values are tensors\n        for key, value in state_dict.items():\n            if not isinstance(value, torch.Tensor):\n                raise ValueError(\n                    f\"Expected all values in the state dict to be torch.Tensor. \"\n                    f\"Found {key}={type(value)} instead.\"\n                )\n        \n        merged_state_dict.update(state_dict)\n        \n        # Free memory\n        del state_dict\n        gc.collect()\n    \n    # Optionally write keys to file for debugging\n    if output_keys_file:\n        with open(output_keys_file, \"w\", encoding=\"utf-8\") as f:\n            f.write(\"# Checkpoint file paths:\\n\")\n            for path in ckpt_paths:\n                f.write(f\"{path}\\n\")\n            f.write(\"\\n# State dict keys:\\n\")\n            for key in merged_state_dict.keys():\n                f.write(f\"{key}\\n\")\n    \n    return merged_state_dict\n\n\ndef load_checkpoint_to_state_dict(checkpoint_path: Union[str, os.PathLike]) -> Dict[str, torch.Tensor]:\n    \"\"\"Load checkpoint file or directory and return state_dict.\n    \n    Supports multiple checkpoint formats:\n    - .pth or .pt files (PyTorch format)\n    - .safetensors files (SafeTensors format)\n    - Directories containing .safetensors files (HuggingFace format)\n    - .distcp format directories (Distributed checkpoint format)\n    \n    Args:\n        checkpoint_path: Path to checkpoint file or directory.\n            Can be:\n            - .pth, .pt file path\n            - .safetensors file path\n            - Directory containing .safetensors files\n            - .distcp format directory\n    \n    Returns:\n        state_dict: Dictionary containing model weights\n    \n    Raises:\n        FileNotFoundError: If checkpoint path does not exist\n        ValueError: If checkpoint format is unsupported or invalid\n    \"\"\"\n    checkpoint_path = os.path.abspath(checkpoint_path)\n    \n    # Check if path exists\n    if not os.path.exists(checkpoint_path):\n        raise FileNotFoundError(f\"Checkpoint path does not exist: {checkpoint_path}\")\n    \n    # If it's a file\n    if os.path.isfile(checkpoint_path):\n        # Handle .pth files\n        if checkpoint_path.endswith(\".pth\") or checkpoint_path.endswith(\".pt\"):\n            print_rank_0(f\"Loading PyTorch checkpoint from {checkpoint_path}...\")\n            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n            # If state_dict contains nested 'model' or 'app.model' keys, extract them\n            if \"model\" in state_dict and isinstance(state_dict[\"model\"], dict):\n                state_dict = state_dict[\"model\"]\n         
   elif \"app\" in state_dict and \"model\" in state_dict[\"app\"]:\n                state_dict = state_dict[\"app\"][\"model\"]\n            return state_dict\n        \n        # Handle .safetensors files\n        elif checkpoint_path.endswith(\".safetensors\"):\n            print_rank_0(f\"Loading SafeTensors checkpoint from {checkpoint_path}...\")\n            return load_file(checkpoint_path)\n        \n        else:\n            raise ValueError(f\"Unsupported file format: {checkpoint_path}\")\n    \n    # If it's a directory\n    elif os.path.isdir(checkpoint_path):\n        # Check if it's a .distcp format directory\n        if any(file.endswith(\".distcp\") for file in os.listdir(checkpoint_path)) or \\\n           os.path.exists(os.path.join(checkpoint_path, \"checkpoint.json\")):\n            print_rank_0(f\"Loading DCP checkpoint from {checkpoint_path}...\")\n            # Use PyTorch's FileSystemReader to load DCP format\n            sd: STATE_DICT_TYPE = {}\n            from torch.distributed.checkpoint.state_dict_loader import _load_state_dict\n            _load_state_dict(\n                sd,\n                storage_reader=FileSystemReader(checkpoint_path),\n                planner=_EmptyStateDictLoadPlanner(),\n                no_dist=True,\n            )\n            # Extract model weights section\n            if \"app\" in sd and \"model\" in sd[\"app\"]:\n                return sd[\"app\"][\"model\"]\n            return sd\n        \n        # Check if it's a directory containing .safetensors files\n        safetensors_files = [f for f in os.listdir(checkpoint_path) if f.endswith(\".safetensors\")]\n        \n        if safetensors_files:\n            # Directly merge all .safetensors files\n            print_rank_0(f\"Loading and merging all SafeTensors files from {checkpoint_path}...\")\n            state_dict = {}\n            for safetensors_file in tqdm(safetensors_files, desc=\"Loading safetensors\"):\n                file_path = os.path.join(checkpoint_path, safetensors_file)\n                shard_state_dict = load_file(file_path)\n                # Update state_dict, merging all file contents\n                state_dict.update(shard_state_dict)\n            return state_dict\n        \n        else:\n            raise ValueError(f\"No supported checkpoint files found in directory: {checkpoint_path}\")\n    \n    else:\n        raise ValueError(f\"Invalid checkpoint path: {checkpoint_path}\")\n      \nclass CheckpointerInterface(Protocol):\n    \"\"\"Protocol interface for checkpoint loaders and savers.\"\"\"\n    \n    def load_checkpoint(self, **kwargs) -> Dict[str, Any]:\n        \"\"\"Load checkpoint from storage.\"\"\"\n        ...\n    \n    def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None:\n        \"\"\"Save checkpoint to storage.\"\"\"\n        ...\n\nclass DistributedCheckpointer(CheckpointerInterface):\n    \"\"\"\n    Checkpointer which reads and writes checkpoints in the DistributedCheckpointing format.\n\n    Args:\n        process_group: Optional process group to use for distributed saving/loading.\n            If None, the default process group will be used.\n            For checkpointing, gloo CPU-based backend is needed.\n    \"\"\"\n    \n    def __init__(\n        self,\n        process_group: Optional[dist.ProcessGroup] = None\n    ) -> None:\n        self._checkpoint_future: Optional[Future] = None\n        self._checkpoint_dir_prefix = \"global_step\"\n        _, self._rank = get_world_size_and_rank()\n        
self._process_group: Optional[dist.ProcessGroup] = process_group\n    \n    def get_latest_checkpoint(self, checkpoint_dir: str) -> Optional[str]:\n        \"\"\"Get the latest checkpoint directory path.\n        \n        Args:\n            checkpoint_dir: Directory containing checkpoint subdirectories.\n            \n        Returns:\n            Path to the latest checkpoint directory, or None if no checkpoints found.\n        \"\"\"\n        checkpoint_dir_pattern = re.compile(f\"{self._checkpoint_dir_prefix}(\\\\d+)\")\n        checkpoint_paths = []\n        \n        if not os.path.isdir(checkpoint_dir):\n            return None\n        \n        for name in os.listdir(checkpoint_dir):\n            if re.match(checkpoint_dir_pattern, name):\n                checkpoint_path = os.path.join(checkpoint_dir, name)\n                if os.path.isdir(checkpoint_path):\n                    checkpoint_paths.append(name)\n        \n        if checkpoint_paths:\n            latest_checkpoint_dir = sorted(\n                checkpoint_paths, \n                key=lambda x: int(x.split(\"_\")[-1])\n            )[-1]\n            return os.path.join(checkpoint_dir, latest_checkpoint_dir)\n        return None\n\n    def load_checkpoint(\n        self,\n        state_dict: STATE_DICT_TYPE,\n        checkpoint_path: Optional[str] = None,\n        checkpoint_dir: Optional[str] = None,\n        tag: Union[str, int] = \"latest\"\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Load a Distributed checkpoint.\n        \n        Args:\n            state_dict: State dictionary to load into.\n            checkpoint_path: Direct path to checkpoint. If provided, this takes precedence.\n            checkpoint_dir: Directory containing checkpoints.\n            tag: Checkpoint tag (e.g., \"latest\" or step number). Default: \"latest\".\n        \n        Returns:\n            Loaded state dictionary.\n        \n        Raises:\n            ValueError: If no checkpoint path can be determined.\n        \"\"\"\n        if not checkpoint_path:\n            if not checkpoint_dir:\n                raise ValueError(\"Either checkpoint_path or checkpoint_dir must be provided\")\n            \n            if tag == \"latest\":\n                checkpoint_path = self.get_latest_checkpoint(checkpoint_dir)\n                if not checkpoint_path:\n                    raise ValueError(f\"No checkpoint found in {checkpoint_dir}\")\n            else:\n                checkpoint_path = str(Path(checkpoint_dir) / str(tag))\n        \n        if not checkpoint_path or not os.path.exists(checkpoint_path):\n            raise ValueError(f\"Checkpoint path does not exist: {checkpoint_path}\")\n        \n        print_rank_0(f\"Loading checkpoint from {checkpoint_path}\")\n        \n        dcp.load(\n            state_dict=state_dict,\n            storage_reader=FileSystemReader(checkpoint_path),\n            process_group=self._process_group,\n        )\n        \n        return state_dict\n\n    def save_checkpoint(\n        self,\n        state_dict: STATE_DICT_TYPE,\n        output_dir: Union[str, Path],\n        tag: Optional[Union[str, int]] = None,\n        save_async: bool = False\n    ) -> None:\n        \"\"\"\n        Save a distributed checkpoint to storage.\n        \n        If ``save_async`` is True, the save happens asynchronously unblocking the GPUs sooner.\n        This should only be used for intermediate checkpoints. 
Final checkpoint must be synchronous\n        as the training job cannot terminate until the checkpoint is persisted.\n\n        Args:\n            state_dict: Checkpoint state dict to be written out to file.\n            output_dir: Directory to save the checkpoint.\n            tag: Checkpoint tag. Used to create the checkpoint directory name, generally step number.\n            save_async: If True, save the checkpoint asynchronously. Default: False.\n        \"\"\"\n        checkpoint_path = Path(output_dir)\n        if tag is not None:\n            checkpoint_path = checkpoint_path / f\"{self._checkpoint_dir_prefix}{tag}\"\n        \n        checkpoint_path_str = str(checkpoint_path)\n        print_rank_0(f\"Saving checkpoint to {checkpoint_path_str}\")\n        \n        # Wait for previous checkpoint to finish if still in progress\n        if self._checkpoint_future and not self._checkpoint_future.done():\n            wait_start = time.perf_counter()\n            print_rank_n(\n                f\"Rank {self._rank}: previous checkpoint has not finished. \"\n                f\"Checkpointing frequency is too high. Waiting...\",\n                rank=self._rank\n            )\n            self._checkpoint_future.result()\n            wait_time = time.perf_counter() - wait_start\n            print_rank_n(\n                f\"Rank {self._rank}: waited {wait_time:.2f} seconds \"\n                f\"for previous checkpoint to finish\",\n                rank=self._rank\n            )\n            self._checkpoint_future = None\n        \n        cp_start = time.perf_counter()\n        \n        if save_async:\n            def callback(f: Future) -> None:\n                if f.exception() is None:\n                    print_rank_n(\n                        f\"Rank {self._rank}: Checkpoint saved asynchronously \"\n                        f\"to {checkpoint_path_str} successfully.\",\n                        rank=self._rank\n                    )\n                else:\n                    print_rank_n(\n                        f\"Rank {self._rank}: Checkpoint failed to save asynchronously \"\n                        f\"to {checkpoint_path_str} with exception: {f.exception()}\",\n                        rank=self._rank\n                    )\n            \n            self._checkpoint_future = async_save(\n                state_dict=state_dict,\n                storage_writer=FileSystemWriter(\n                    checkpoint_path_str,\n                    thread_count=16\n                ),\n                process_group=self._process_group,\n            )\n            \n            blocked_time = time.perf_counter() - cp_start\n            print_rank_n(\n                f\"Rank {self._rank}: Trainer was blocked for {blocked_time:.2f} seconds \"\n                \"for checkpointing to start...\",\n                rank=self._rank\n            )\n            \n            self._checkpoint_future.add_done_callback(callback)\n        else:\n            print_rank_0(f\"Saving model checkpoint synchronously to {checkpoint_path_str}\")\n            save(\n                state_dict=state_dict,\n                storage_writer=FileSystemWriter(\n                    checkpoint_path_str,\n                    thread_count=4\n                ),\n                process_group=self._process_group,\n            )\n            print_rank_0(\n                \"The full model checkpoint, including all the weights and \"\n                \"configurations, has been saved successfully by the \"\n                
\"DistributedCheckpointer. \"\n                \"You can now use this checkpoint for further training.\"\n            )\n\nclass AppState(Stateful):\n  \"\"\"This is a useful wrapper for checkpointing the Application State. \n     Since this object is compliant with the Stateful protocol, DCP will \n     automatically call state_dict/load_stat_dict as needed in the \n     dcp.save/load APIs.\n\n  Note: We take advantage of this wrapper to hande calling distributed \n    state dict methods on the model and optimizer.\n  \"\"\"\n\n  def __init__(self, model, optimizer=None, call_back=None):\n    self.model = model\n    self.call_back = call_back\n\n  def set_call_back(self, cb):\n    self.call_back = cb\n    return self\n\n  def state_dict(self):\n    # this line automatically manages FSDP FQN's, as well as sets the \n    # default state dict type to FSDP.SHARDED_STATE_DICT\n    model_state_dict = \\\n      get_model_state_dict(self.model)\n    if self.call_back is not None:\n      model_state_dict = self.call_back(model_state_dict)\n    return {\n      \"model\": model_state_dict\n    }\n\n  def load_state_dict(self, state_dict):\n    # sets our state dicts on the model and optimizer, now that we've loaded\n    set_model_state_dict(\n      self.model,\n      model_state_dict=state_dict[\"model\"],\n    )\n\n"
  },
  {
    "path": "pretrain/onerec_llm/training/common.py",
    "content": "\"\"\"Common training utilities for distributed model training.\"\"\"\n\nfrom typing import Generator\n\nimport contextlib\nimport torch\n\n\n@contextlib.contextmanager\ndef set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:\n    \"\"\"Temporarily set torch's default dtype.\n    \n    Args:\n        dtype: The desired default dtype.\n    \"\"\"\n    old_dtype = torch.get_default_dtype()\n    torch.set_default_dtype(dtype)\n    try:\n        yield\n    finally:\n        torch.set_default_dtype(old_dtype)\n"
  },
  {
    "path": "pretrain/onerec_llm/training/distributed.py",
    "content": "\"\"\"Distributed training utilities for FSDP model sharding and checkpoint loading.\"\"\"\n\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\nfrom onerec_llm.utils.ds_utils import format_dict_or_list\nfrom onerec_llm.utils.distributed import get_world_size_and_rank\n\nfrom torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy\nfrom torch.distributed._tensor import distribute_tensor\nfrom torch.distributed.device_mesh import DeviceMesh\n\ndef shard_model(\n    model: nn.Module,\n    *,\n    cpu_offload: bool,\n    reshard_after_forward: bool = True,\n    dp_mesh: Optional[DeviceMesh] = None,\n    fp32_weight: bool = True,\n    model_class: str = 'Qwen3ForCausalLM',\n    fp32_reduce: bool = True\n) -> None:\n    \"\"\"Shard a model with FSDP using the PyTorch Distributed fully_shard API.\n    \n    Args:\n        model: Model to shard with FSDP.\n        cpu_offload: If True, FSDP will offload parameters to CPU.\n        reshard_after_forward: Whether to reshard after forward pass.\n        dp_mesh: Device mesh for FSDP sharding under multiple parallelism.\n        fp32_weight: If True, use fp32 for weights with bfloat16 params.\n        model_class: Model class name. Currently only supports 'Qwen3ForCausalLM'.\n        fp32_reduce: If True, use fp32 for gradient reduction.\n    \"\"\"\n    fsdp_kwargs = {\"reshard_after_forward\": reshard_after_forward, \"mesh\": dp_mesh}\n    \n    if fp32_weight:\n        fsdp_kwargs[\"mp_policy\"] = MixedPrecisionPolicy(\n            param_dtype=torch.bfloat16,\n            reduce_dtype=torch.float32 if fp32_reduce else torch.bfloat16\n        )\n    if cpu_offload:\n        fsdp_kwargs[\"offload_policy\"] = CPUOffloadPolicy()\n\n    if model_class == 'Qwen3ForCausalLM':\n        layers = list(model.model.layers)\n    else:\n        raise ValueError(f\"Unsupported model_class: {model_class}\")\n    \n    for layer in layers:\n        fully_shard(layer, **fsdp_kwargs)\n    \n    fully_shard(model, **fsdp_kwargs)\n\n    # Set up forward prefetch for layers\n    prev = None\n    for layer in reversed(layers):\n        if prev is not None:\n            layer.set_modules_to_forward_prefetch([prev])\n        prev = layer\n    model.set_modules_to_forward_prefetch([prev])\n\n\ndef load_from_full_model_state_dict(\n    model: \"FSDPModule\",\n    full_sd: Dict[str, Any],\n    allow_random_init_params: Optional[str] = None,\n    use_tie_weights: bool = False\n) -> None:\n    \"\"\"Load full state dict into an FSDP-sharded model.\n    \n    Args:\n        model: FSDP-sharded model to load into.\n        full_sd: Full (unsharded) state dictionary.\n        allow_random_init_params: Comma-separated parameter names to randomly initialize\n            if not found in full_sd. 
Default: None.\n        use_tie_weights: If True, tie lm_head.weight to model.embed_tokens.weight.\n    \"\"\"\n    if isinstance(allow_random_init_params, str):\n        allow_random_init_params = allow_random_init_params.split(',')\n    \n    meta_sharded_sd = model.state_dict()\n    sharded_sd = {}\n    \n    if dist.get_rank() == 0:\n        if use_tie_weights:\n            full_sd['lm_head.weight'] = full_sd['model.embed_tokens.weight']\n\n        extra_meta_sharded_sd = set(meta_sharded_sd.keys()) - set(full_sd.keys())\n        extra_full_ds = set(full_sd.keys()) - set(meta_sharded_sd.keys())\n        \n        extra_meta_sharded_sd = {\n            k: (v.shape, v.device, v.dtype) \n            for k, v in meta_sharded_sd.items() \n            if k in extra_meta_sharded_sd\n        }\n        extra_full_ds = {\n            k: (v.shape, v.device, v.dtype) \n            for k, v in full_sd.items() \n            if k in extra_full_ds\n        }\n\n        device0 = full_sd[list(full_sd)[0]]\n        for k in extra_meta_sharded_sd:\n            if allow_random_init_params is not None and k in allow_random_init_params:\n                full_sd[k] = torch.rand(extra_meta_sharded_sd[k][0]) * 0.1\n                if full_sd[k].ndim >= 2:\n                    nn.init.kaiming_normal_(full_sd[k], a=0, mode='fan_in', nonlinearity='relu')\n                else:\n                    nn.init.zeros_(full_sd[k])\n                full_sd[k] = full_sd[k].to(device0)\n\n        assert len(meta_sharded_sd) == len(full_sd), (\n            f\"Sharded State Dict doesn't equal to Full State Dict, \"\n            f\"{len(meta_sharded_sd)} vs {len(full_sd)}\\n\"\n            f\"extra_meta_sharded_sd={format_dict_or_list(extra_meta_sharded_sd)}, \"\n            f\"extra_full_ds={format_dict_or_list(extra_full_ds)}\"\n        )\n        assert sorted(list(meta_sharded_sd.keys())) == sorted(list(full_sd.keys())), \\\n            \"Keys of Sharded State Dict doesn't equal to Full State Dict\"\n\n    for param_name, sharded_meta_param in meta_sharded_sd.items():\n        if dist.get_rank() == 0:\n            full_tensor = full_sd[param_name].detach().cuda().type(sharded_meta_param.dtype)\n        else:\n            full_tensor = torch.empty(\n                sharded_meta_param.size(),\n                device=\"cuda\",\n                dtype=sharded_meta_param.dtype,\n            )\n        \n        mesh = sharded_meta_param.device_mesh\n        dist.broadcast(full_tensor, src=0, group=mesh.get_group(0))\n        dist.barrier()\n        \n        sharded_tensor = distribute_tensor(\n            full_tensor, mesh, sharded_meta_param.placements\n        )\n        sharded_sd[param_name] = nn.Parameter(sharded_tensor)\n\n    model.load_state_dict(sharded_sd, assign=True)\n"
  },
  {
    "path": "pretrain/onerec_llm/training/gradients.py",
    "content": "\"\"\"Gradient computation and manipulation utilities for training.\n\nThis module provides utilities for gradient processing including:\n- Gradient clipping\n- Gradient norm computation for FSDP models\n- Gradient masking for embedding layers in distributed training\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\n\n\ndef clip_grad_by_value(\n    model: torch.nn.Module, \n    clip_range: Optional[float] = None\n) -> None:\n    \"\"\"Clip gradients by value.\n    \n    Args:\n        model: The model whose gradients will be clipped.\n        clip_range: Maximum absolute value for gradients. If None, no clipping.\n    \"\"\"\n    if clip_range is not None:\n        torch.nn.utils.clip_grad_value_(model.parameters(), clip_range)\n\n\ndef clip_grad_norm(\n    model: torch.nn.Module,\n    max_grad_norm: Optional[float] = None\n) -> None:\n    \"\"\"Clip gradients by global L2 norm.\n    \n    Args:\n        model: The model whose gradients will be clipped.\n        max_grad_norm: Maximum allowed L2 norm. If None, no clipping.\n    \"\"\"\n    if max_grad_norm is not None:\n        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n\n\ndef compute_fsdp_zero2_grad_norm(\n    model: torch.nn.Module, \n    ignore_unused_parameters: bool = True\n) -> float:\n    \"\"\"Compute the global L2 norm of gradients for FSDP Zero-2 models.\n    \n    Args:\n        model: FSDP-wrapped model.\n        ignore_unused_parameters: If True, ignore parameters without gradients.\n    \n    Returns:\n        The global L2 norm of all gradients.\n    \"\"\"\n    total_sq = torch.tensor(0.0, device=next(model.parameters()).device)\n    \n    for param in model.parameters():\n        if param.grad is None:\n            if not ignore_unused_parameters:\n                raise ValueError(\n                    f\"Parameter {param} has no gradient. \"\n                    \"Please check if it is being used correctly.\"\n                )\n            continue\n        \n        local_grad = param.grad.to_local()\n        total_sq += torch.sum(local_grad ** 2)\n    \n    dist.all_reduce(total_sq, op=dist.ReduceOp.SUM, group=dist.group.WORLD)\n    grad_norm = torch.sqrt(total_sq).item()\n    \n    return grad_norm\n\n\nclass EmbeddingGradientMasker:\n    \"\"\"Freeze a portion of embedding parameters during distributed training.\n    \n    In distributed training with DTensor, embedding layers are sharded across ranks.\n    This class freezes the first `start_optimize_embedding_index` tokens in the vocabulary,\n    allowing only the remaining tokens to be optimized. This is useful for progressive\n    training strategies where only a subset of the vocabulary is optimized initially.\n    \n    Args:\n        model: The model containing embedding layers\n        config: Model config with vocab_size attribute\n        start_optimize_embedding_index: Index from which to start optimizing embeddings.\n            Tokens before this index will be frozen. 
If <= 0, no masking is applied.\n    \"\"\"\n    \n    def __init__(self, model, config, start_optimize_embedding_index):\n        self.model = model\n        self.config = config\n        self.start_optimize_embedding_index = start_optimize_embedding_index\n        self.embedding_params = []  # List of (name, param) tuples for embedding layers\n        self.saved_weights = {}  # Dict mapping param name -> frozen weight slice (torch.Tensor)\n\n        if start_optimize_embedding_index > 0:\n            self._find_embedding_parameters()\n            self._save_initial_weights()\n\n    def _find_embedding_parameters(self):\n        \"\"\"Find all embedding-related parameters (embed_tokens and lm_head).\"\"\"\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and (\"embed_tokens\" in name or \"lm_head\" in name):\n                self.embedding_params.append((name, param))\n\n    def _save_initial_weights(self):\n        \"\"\"Save frozen weight slices for each rank in distributed training.\n        \n        In distributed training, embedding parameters are sharded across ranks.\n        This method calculates which portion of the local shard needs to be frozen\n        and saves those weights for later restoration after optimizer steps.\n        \"\"\"\n        dp_world_size = dist.get_world_size()\n        dp_rank = dist.get_rank()\n        full_vocab_size = self.config.vocab_size\n        \n        # Calculate shard boundaries: each rank owns a contiguous slice of the vocabulary\n        shard_size = (full_vocab_size + dp_world_size - 1) // dp_world_size\n        shard_offset = dp_rank * shard_size\n\n        with torch.no_grad():\n            for name, param in self.embedding_params:\n                # Get local tensor from DTensor (param is a DTensor in distributed mode)\n                local_param_tensor = param.to_local()\n                local_shard_size = local_param_tensor.shape[0]\n                \n                # Calculate overlap between frozen range [0, start_optimize_embedding_index)\n                # and this rank's shard [shard_offset, shard_offset + local_shard_size)\n                overlap_start = shard_offset\n                overlap_end = min(self.start_optimize_embedding_index, shard_offset + local_shard_size)\n                \n                # Number of rows in this rank's shard that need to be frozen\n                num_local_rows = 0\n                if overlap_end > overlap_start:\n                    num_local_rows = int(overlap_end - overlap_start)\n\n                # Save the frozen slice for restoration after optimizer steps\n                if num_local_rows > 0:\n                    self.saved_weights[name] = local_param_tensor[:num_local_rows].clone()\n\n    def save_frozen_params(self):\n        \"\"\"Deprecated: Logic moved to __init__. Kept for backward compatibility.\"\"\"\n        pass\n\n    def apply_gradient_mask(self, optimizer=None):\n        \"\"\"Deprecated: We use restore strategy instead. 
Kept for backward compatibility.\"\"\"\n        pass\n\n    def restore_frozen_params(self):\n        \"\"\"Restore frozen parameters after optimizer.step().\n        \n        This should be called after each optimizer.step() to restore the frozen\n        portion of embedding weights that were modified by the optimizer.\n        Uses .to_local() to safely modify DTensor parameters in distributed training.\n        \"\"\"\n        if self.start_optimize_embedding_index <= 0 or not self.saved_weights:\n            return\n\n        with torch.no_grad():\n            for name, param in self.embedding_params:\n                if name in self.saved_weights:\n                    # Get local tensor from DTensor for modification\n                    local_param_tensor = param.to_local()\n                    \n                    saved_slice = self.saved_weights[name]\n                    num_to_restore = saved_slice.shape[0]\n\n                    if num_to_restore > 0:\n                        # Restore frozen weights by copying saved slice back\n                        local_param_tensor[:num_to_restore].copy_(saved_slice)\n\n"
  },
  {
    "path": "pretrain/onerec_llm/training/lr_schedulers.py",
    "content": "\"\"\"Learning rate schedulers for training.\"\"\"\n\nimport math\nfrom functools import partial\nfrom typing import Optional\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\ndef _get_cosine_schedule_with_warmup_lr_lambda(\n    current_step: int,\n    *,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    num_cycles: float,\n    num_stop_steps: int = 0,\n    min_lr_rate: float = 0.0\n) -> float:\n    \"\"\"Compute learning rate multiplier for cosine schedule with warmup.\n    \n    Args:\n        current_step: Current training step.\n        num_warmup_steps: Number of warmup steps.\n        num_training_steps: Total number of training steps.\n        num_cycles: Number of cosine cycles.\n        num_stop_steps: Number of steps to keep LR at 0 at the start.\n        min_lr_rate: Minimum learning rate as a fraction of max LR.\n    \n    Returns:\n        Learning rate multiplier (0.0 to 1.0).\n    \"\"\"\n    if num_stop_steps > 0 and current_step < num_stop_steps:\n        return 0.0\n    \n    if current_step < num_warmup_steps:\n        return float(current_step) / float(max(1, num_warmup_steps))\n    \n    if current_step > num_training_steps:\n        return min_lr_rate\n    \n    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n    factor = factor * (1 - min_lr_rate) + min_lr_rate\n    return max(0.0, factor)\n\ndef get_cosine_scheduler(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    num_cycles: float = 0.5,\n    num_stop_steps: int = 0,\n    last_epoch: int = -1,\n    min_lr: Optional[float] = None,\n    min_lr_rate: Optional[float] = None,\n    **kwargs\n) -> LambdaLR:\n    \"\"\"Create a cosine learning rate scheduler with warmup.\n    \n    Args:\n        optimizer: Optimizer to schedule.\n        num_warmup_steps: Number of warmup steps.\n        num_training_steps: Total number of training steps.\n        num_cycles: Number of cosine cycles. Default: 0.5.\n        num_stop_steps: Number of steps to keep LR at 0 at the start. Default: 0.\n        last_epoch: Last epoch index for resuming. Default: -1.\n        min_lr: Minimum learning rate (absolute value).\n        min_lr_rate: Minimum learning rate as fraction of max LR.\n    \n    Returns:\n        LambdaLR scheduler with cosine schedule.\n    \"\"\"\n    if min_lr is not None and min_lr_rate is not None:\n        raise ValueError(\"Only one of min_lr or min_lr_rate should be set\")\n    elif min_lr is not None:\n        min_lr_rate = min_lr / optimizer.defaults[\"lr\"]\n    elif min_lr_rate is None:\n        raise ValueError(\"One of min_lr or min_lr_rate must be set\")\n\n    lr_lambda = partial(\n        _get_cosine_schedule_with_warmup_lr_lambda,\n        num_warmup_steps=num_warmup_steps,\n        num_training_steps=num_training_steps,\n        num_cycles=num_cycles,\n        min_lr_rate=min_lr_rate,\n        num_stop_steps=num_stop_steps,\n    )\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\ndef get_scheduler(\n    name: str,\n    optimizer: Optimizer,\n    num_warmup_steps: Optional[int] = None,\n    num_training_steps: Optional[int] = None,\n    **kwargs\n) -> LambdaLR:\n    \"\"\"Get a learning rate scheduler by name.\n    \n    Args:\n        name: Scheduler name. 
Currently only supports \"cosine\".\n        optimizer: Optimizer to schedule.\n        num_warmup_steps: Number of warmup steps.\n        num_training_steps: Total number of training steps.\n        **kwargs: Additional arguments passed to the scheduler. For \"cosine\",\n            exactly one of min_lr or min_lr_rate must be provided.\n    \n    Returns:\n        Learning rate scheduler instance.\n    \n    Raises:\n        NotImplementedError: If scheduler name is not supported.\n    \"\"\"\n    if name == \"cosine\":\n        return get_cosine_scheduler(\n            optimizer=optimizer,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            **kwargs\n        )\n    else:\n        raise NotImplementedError(f\"Unsupported LR scheduler `{name}`\")\n
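\n\n# Illustrative usage (a sketch, not part of the original module); the\n# optimizer and step counts below are assumptions for demonstration.\nif __name__ == \"__main__\":\n    import torch\n\n    _opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)\n    _sched = get_scheduler(\n        \"cosine\",\n        optimizer=_opt,\n        num_warmup_steps=100,\n        num_training_steps=1000,\n        min_lr_rate=0.1,  # one of min_lr / min_lr_rate is required for \"cosine\"\n    )\n    for _ in range(5):\n        _opt.step()\n        _sched.step()\n    print(_sched.get_last_lr())\n"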
  },
  {
    "path": "pretrain/onerec_llm/utils/__init__.py",
    "content": "\"\"\"Utility functions for LLM training.\n\nThis package provides general-purpose utilities including:\n- Common utilities (printing, device operations, random seeds)\n- Distributed training base utilities\n- Data loading and processing\n- Debugging and formatting tools\n- Performance tracking (MFU, time tracking)\n- Gradient masking\n- Worker information\n\"\"\"\n\nfrom onerec_llm.utils.common import (\n    Timer,\n    dist_reduce_dict,\n    get_optimizer_grouped_parameters,\n    print_rank_0,\n    print_rank_n,\n    set_random_seed,\n    to_cuda,\n    to_device,\n)\nfrom onerec_llm.utils.distributed import (\n    get_rank,\n    get_world_size,\n    get_world_size_and_rank,\n    is_distributed,\n)\nfrom onerec_llm.utils.ds_utils import (\n    format_dict_or_list,\n    print_input_info,\n    tensor_statistics,\n)\nfrom onerec_llm.utils.mfu_stats import MFUStats\nfrom onerec_llm.utils.time_tracker import TimeTracker\nfrom onerec_llm.utils.worker_utils import get_worker_info, pytorch_worker_info\n\n__all__ = [\n    # Common\n    \"Timer\",\n    \"dist_reduce_dict\",\n    \"get_optimizer_grouped_parameters\",\n    \"print_rank_0\",\n    \"print_rank_n\",\n    \"set_random_seed\",\n    \"to_cuda\",\n    \"to_device\",\n    # Distributed\n    \"get_rank\",\n    \"get_world_size\",\n    \"get_world_size_and_rank\",\n    \"is_distributed\",\n    # Debug/Format\n    \"format_dict_or_list\",\n    \"print_input_info\",\n    \"tensor_statistics\",\n    # Performance tracking\n    \"MFUStats\",\n    \"TimeTracker\",\n    # Worker info\n    \"get_worker_info\",\n    \"pytorch_worker_info\",\n]"
  },
  {
    "path": "pretrain/onerec_llm/utils/common.py",
    "content": "\"\"\"Common utility functions for the onerec_llm package.\n\nThis module contains core utilities for:\n- Distributed training (printing, reduction)\n- Device operations\n- Optimizer configuration\n- Random seed setting\n- Timing utilities\n\"\"\"\n\nimport random\nimport time\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom rich import print\nfrom transformers import set_seed as set_transformers_seed\n\ndef print_rank_n(*msg, rank=0):\n    try:\n        _rank = dist.get_rank()\n    except Exception:\n        _rank = 0\n    if _rank == rank:\n        print(*msg)\n\ndef print_rank_0(*msg):\n    print_rank_n(*msg, rank=0)\n\ndef get_optimizer_grouped_parameters(model,\n                                     learning_rate: float,\n                                     weight_decay,\n                                     no_decay_name_list=[\n                                         \"bias\", \"LayerNorm.weight\", \"embedding.weight\", \"lm_head.weight\"\n                                     ],\n                                     ):\n    optimizer_grouped_parameters = []\n\n    llm_wd_params_group = []\n    llm_nowd_params_group = []\n\n    for n, p in model.named_parameters():\n        if p.requires_grad:\n            if any(nd in n for nd in no_decay_name_list):\n                # no weight decay params\n                llm_nowd_params_group.append((n, p))\n            else:\n                llm_wd_params_group.append((n, p))\n    \n    # for LLM\n    optimizer_grouped_parameters.append({\n        \"params\": [p for n, p in llm_wd_params_group],\n        \"weight_decay\": weight_decay,\n        \"lr\": learning_rate,\n    })\n\n    optimizer_grouped_parameters.append({\n        \"params\": [p for n, p in llm_nowd_params_group],\n        \"weight_decay\": 0.0,\n        \"lr\": learning_rate,\n    })\n\n    # remove empty params group\n    final_optimizer_grouped_parameters = []\n    for group in optimizer_grouped_parameters:\n        if len(group['params']) > 0:\n            final_optimizer_grouped_parameters.append(group)\n    return final_optimizer_grouped_parameters\n\ndef to_device(batch, device, non_blocking=True):\n    for key in list(batch.keys()):\n        if isinstance(batch[key], torch.Tensor):\n            batch[key] = batch[key].to(device=device, non_blocking=non_blocking)\n    return batch\n\ndef to_cuda(batch, non_blocking=True):\n    \"\"\"Move batch to CUDA device. 
This is a convenience wrapper around to_device.\"\"\"\n    return to_device(batch, device=torch.cuda.current_device(), non_blocking=non_blocking)\n\ndef set_random_seed(seed):\n    if seed is not None:\n        set_transformers_seed(seed)\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n\ndef dist_reduce_dict(local_dict, group=None):\n    gather_list = [None for _ in range(dist.get_world_size(group=group))]\n\n    dist.all_gather_object(\n        object_list=gather_list, obj=local_dict, group=group)\n\n    def reduce_dicts(dicts):\n        def _reduce(d1, d2):\n            for key, value in d2.items():\n                if isinstance(value, dict):\n                    if key not in d1:\n                        d1[key] = {}\n                    _reduce(d1[key], value)\n                else:\n                    if key in d1:\n                        d1[key] += value\n                    else:\n                        d1[key] = value\n            return d1\n\n        result = {}\n        for d in dicts:\n            result = _reduce(result, d)\n        return result\n\n    return reduce_dicts(gather_list)\n\nclass Timer:\n    def __init__(self, desc: str = \"\"):\n        self.desc = desc\n\n    def __enter__(self):\n        print_rank_0(f\"Start... {self.desc}\")\n        self.start = time.time()\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.end = time.time()\n        self.elapsed = self.end - self.start\n        print_rank_0(f\"End... {self.desc} elapsed: {self.elapsed:.3f}s\")\n
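\n\n# Illustrative usage (a sketch): Timer also works as a plain context manager\n# when torch.distributed is not initialized.\nif __name__ == \"__main__\":\n    with Timer(\"example sleep\"):\n        time.sleep(0.1)\n"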
  },
  {
    "path": "pretrain/onerec_llm/utils/data_utils.py",
    "content": "\"\"\"Data loading utilities for parquet files and HDFS.\"\"\"\n\nimport hashlib\nimport os\nimport subprocess\nimport time\nimport traceback\nfrom typing import Optional\n\nimport numpy as np\nimport pyarrow.parquet as pq\n\nfrom onerec_llm.utils.worker_utils import get_worker_info\nfrom onerec_llm.utils.distributed import get_world_size_and_rank\n\n\ndef calculate_text_hash(text):\n    \"\"\"Calculate SHA-256 hash of text.\n    \n    Args:\n        text: Input text string\n        \n    Returns:\n        Hexadecimal hash string\n    \"\"\"\n    hash_object = hashlib.sha256()\n    hash_object.update(text.encode('utf-8'))\n    return hash_object.hexdigest()\n\n\ndef shell_hdfs_ls(source_dir):\n    \"\"\"List files in HDFS directory.\n    \n    Args:\n        source_dir: HDFS directory path\n        \n    Returns:\n        list: List of file paths starting with 'viewfs://'\n    \"\"\"\n    try:\n        command = f\"hdfs dfs -ls {source_dir}\"\n        result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True)\n        files = []\n        for line in result.stdout.splitlines():\n            parts = line.split()\n            if len(parts) > 0 and parts[-1].startswith('viewfs://'):\n                files.append(parts[-1])\n        return files\n\n    except subprocess.CalledProcessError as e:\n        print(f\"Error occurred: {traceback.format_exc()}\")\n        return []\n\n\nclass FakeParquetFileFromFastParquetFile:\n    \"\"\"Wrapper for fastparquet ParquetFile to match pyarrow interface.\"\"\"\n    \n    def __init__(self, fast_parquet_file):\n        # Package version: mpirun --allow-run-as-root --hostfile /etc/mpi/hostfile --pernode bash -c \"pip3 install fastparquet==2024.2.0\"\n        from fastparquet import ParquetFile\n        self.fast_parquet_file = fast_parquet_file\n\n        # Put file opening logic first to prevent failure if file is deleted\n        self.res = ParquetFile(self.fast_parquet_file)\n        self.res.num_rows = len(self.res.to_pandas())\n        self.num_row_groups = 1\n\n    def read_row_group(self, i):\n        assert i == 0\n        return self.res\n\n\ndef load_parquet_file(\n    file_path: str,\n    retry: int = 5,\n    max_cache_files: int = 500,\n    parquet_backend: str = 'fast_parquet',\n    cache_dir: Optional[str] = None,\n    hadoop_cmd: Optional[str] = None\n) -> pq.ParquetFile:\n    \"\"\"Load a parquet file from local path or HDFS.\n    \n    This function handles two types of paths:\n    1. HDFS paths (viewfs:// or hdfs://): Downloads to cache and loads from cache\n    2. Local paths: Directly loads from the path\n    \n    Args:\n        file_path: Path to parquet file (can be local path or HDFS path)\n        retry: Number of retries when HDFS download fails\n        max_cache_files: Maximum number of files to keep in cache\n        parquet_backend: Parquet backend, 'fast_parquet' or 'pyarrow'\n        cache_dir: Cache directory path (default: /code/dataset_cache/{worker_id}_{rank_id})\n        hadoop_cmd: Hadoop command path (default: /home/hadoop/software/hadoop/bin/hadoop)\n        \n    Returns:\n        Loaded parquet file object\n        \n    Raises:\n        ValueError: If parquet_backend is invalid\n        FileNotFoundError: If file cannot be found or downloaded after retries\n    \"\"\"\n    if parquet_backend not in [\"fast_parquet\", \"pyarrow\"]:\n        raise ValueError(f\"Invalid parquet_backend: {parquet_backend}. 
Must be 'fast_parquet' or 'pyarrow'\")\n    \n    # Check if it's an HDFS path\n    is_hdfs_path = file_path.startswith(('viewfs://', 'hdfs://'))\n    \n    if is_hdfs_path:\n        # HDFS path: use cache and download logic\n        return _load_parquet_from_hdfs(\n            file_path, retry, max_cache_files, parquet_backend, cache_dir, hadoop_cmd\n        )\n    else:\n        # Local path: directly load (even if os.path.exists returns False,\n        # some file systems may support direct access)\n        try:\n            return _load_parquet_from_path(file_path, parquet_backend)\n        except Exception as e:\n            # If direct load fails and file doesn't exist, provide clear error\n            if not os.path.exists(file_path):\n                raise FileNotFoundError(f\"Local file not found: {file_path}\") from e\n            raise\n\n\ndef _load_parquet_from_hdfs(\n    file_path: str,\n    retry: int,\n    max_cache_files: int,\n    parquet_backend: str,\n    cache_dir: Optional[str],\n    hadoop_cmd: Optional[str]\n) -> pq.ParquetFile:\n    \"\"\"Load parquet file from HDFS using cache mechanism.\"\"\"\n    # Setup cache directory\n    # If cache_dir is None or empty string, use default cache directory\n    if not cache_dir:\n        worker_id = get_worker_info()[0]\n        rank_id = get_world_size_and_rank()[1]\n        cache_dir = f'/code/dataset_cache/{worker_id}_{rank_id}'\n    \n    os.makedirs(cache_dir, exist_ok=True)\n    \n    # Generate cache file path\n    filename = os.path.basename(file_path)\n    file_hash = calculate_text_hash(file_path)\n    cache_path = os.path.join(cache_dir, f\"{file_hash}_{filename}\")\n    \n    # Try to load from cache first\n    if os.path.exists(cache_path):\n        try:\n            return _load_parquet_from_path(cache_path, parquet_backend)\n        except Exception as e:\n            # Cache file might be corrupted, remove it and re-download\n            print(f\"Warning: Cached file {cache_path} is corrupted, removing: {e}\")\n            try:\n                os.remove(cache_path)\n            except Exception:\n                pass\n    \n    # Download from HDFS with retry\n    if hadoop_cmd is None:\n        hadoop_cmd = '/home/hadoop/software/hadoop/bin/hadoop'\n    \n    last_error = None\n    for attempt in range(retry):\n        try:\n            # Clean cache if needed before downloading\n            _clean_cache_if_needed(cache_dir, max_cache_files)\n            \n            # Download from HDFS\n            _download_from_hdfs(file_path, cache_path, hadoop_cmd)\n            \n            # Load downloaded file\n            return _load_parquet_from_path(cache_path, parquet_backend)\n            \n        except Exception as e:\n            last_error = e\n            if attempt < retry - 1:\n                # Linear backoff with random jitter\n                wait_time = 2 + np.random.randint(0, 5) + attempt\n                print(f\"Download attempt {attempt + 1}/{retry} failed: {e}. Retrying in {wait_time}s...\")\n                time.sleep(wait_time)\n            else:\n                print(f\"All {retry} download attempts failed for {file_path}\")\n    \n    # All retries failed\n    raise FileNotFoundError(\n        f\"Failed to load parquet file from HDFS after {retry} attempts. 
\"\n        f\"HDFS path: {file_path}, Cache: {cache_path}, Error: {last_error}\"\n    )\n\n\ndef _load_parquet_from_path(file_path: str, parquet_backend: str) -> pq.ParquetFile:\n    \"\"\"Load parquet file from given path.\"\"\"\n    if parquet_backend == 'pyarrow':\n        return pq.ParquetFile(file_path)\n    else:\n        return FakeParquetFileFromFastParquetFile(file_path)\n\n\ndef _clean_cache_if_needed(cache_dir: str, max_cache_files: int):\n    \"\"\"Clean old cache files if cache exceeds max_cache_files.\"\"\"\n    try:\n        files = [\n            os.path.join(cache_dir, f)\n            for f in os.listdir(cache_dir)\n            if os.path.isfile(os.path.join(cache_dir, f))\n        ]\n        \n        if len(files) <= max_cache_files:\n            return\n        \n        # Sort by creation time and remove oldest half\n        files.sort(key=os.path.getctime)\n        files_to_remove = files[:len(files) - max_cache_files // 2]\n        \n        for file_path in files_to_remove:\n            try:\n                os.remove(file_path)\n                print(f\"Removed old cached file: {file_path}\")\n            except Exception as e:\n                print(f\"Failed to remove cached file {file_path}: {e}\")\n    except Exception as e:\n        print(f\"Warning: Failed to clean cache: {e}\")\n\n\ndef _download_from_hdfs(hdfs_path: str, local_path: str, hadoop_cmd: str):\n    \"\"\"Download file from HDFS to local path.\"\"\"\n    cmd = [hadoop_cmd, 'fs', '-get', hdfs_path, local_path]\n    result = subprocess.run(\n        cmd,\n        capture_output=True,\n        text=True,\n        check=False\n    )\n    \n    if result.returncode != 0:\n        raise RuntimeError(\n            f\"HDFS download failed. Command: {' '.join(cmd)}, \"\n            f\"Return code: {result.returncode}, \"\n            f\"Error: {result.stderr}\"\n        )\n    \n    if not os.path.exists(local_path):\n        raise FileNotFoundError(f\"Downloaded file not found at {local_path}\")\n\n"
  },
  {
    "path": "pretrain/onerec_llm/utils/distributed.py",
    "content": "\"\"\"Distributed training base utilities.\n\nThis module provides fundamental distributed training utilities that can be used\nacross different modules without creating circular dependencies. For FSDP-specific\nutilities, see onerec_llm.training.distributed.\n\"\"\"\n\nimport os\nfrom typing import Tuple\n\nimport torch\nimport torch.distributed as dist\n\n\ndef get_world_size_and_rank() -> Tuple[int, int]:\n    \"\"\"Get the current world size and rank number.\n    \n    This function checks multiple sources in order:\n    1. PyTorch distributed (if initialized)\n    2. Environment variables (RANK, WORLD_SIZE)\n    3. Defaults to single process (1, 0)\n    \n    Returns:\n        Tuple of (world_size, rank).\n    \"\"\"\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        return torch.distributed.get_world_size(), torch.distributed.get_rank()\n    elif \"RANK\" in os.environ and \"WORLD_SIZE\" in os.environ:\n        return int(os.environ[\"WORLD_SIZE\"]), int(os.environ[\"RANK\"])\n    else:\n        return 1, 0\n\n\ndef get_rank() -> int:\n    \"\"\"Get the current process rank.\n    \n    Returns:\n        Process rank (0-based).\n    \"\"\"\n    _, rank = get_world_size_and_rank()\n    return rank\n\n\ndef get_world_size() -> int:\n    \"\"\"Get the current world size.\n    \n    Returns:\n        Number of processes in the distributed group.\n    \"\"\"\n    world_size, _ = get_world_size_and_rank()\n    return world_size\n\n\ndef is_distributed() -> bool:\n    \"\"\"Check if distributed training is initialized.\n    \n    Returns:\n        True if distributed training is available and initialized.\n    \"\"\"\n    return torch.distributed.is_available() and torch.distributed.is_initialized()\n\n"
  },
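  {
    "path": "pretrain/examples/distributed_usage_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original pipeline): using the base helpers from onerec_llm.utils.distributed so a script behaves identically under torchrun and a plain python launch. The work items are hypothetical.\"\"\"\n\nfrom onerec_llm.utils.distributed import get_world_size_and_rank, is_distributed\n\n\ndef main() -> None:\n    # Falls back to (1, 0) when neither torch.distributed nor RANK/WORLD_SIZE is set.\n    world_size, rank = get_world_size_and_rank()\n    items = list(range(100))\n    my_items = items[rank::world_size]  # simple static sharding across ranks\n    print(f'rank {rank}/{world_size} (initialized: {is_distributed()}) handles {len(my_items)} items')\n\n\nif __name__ == '__main__':\n    main()\n"
  },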
  {
    "path": "pretrain/onerec_llm/utils/ds_utils.py",
    "content": "\"\"\"Debug and formatting utilities for data structures and tensors.\"\"\"\n\nimport math\nimport os\nimport traceback\nfrom dataclasses import is_dataclass, asdict\nfrom typing import Any, Dict, List, Tuple, Union\n\nimport torch\n\n\ndef convert_dataclass_to_dict(obj: Any) -> Any:\n    \"\"\"Convert dataclass instance to dict, return other objects unchanged.\"\"\"\n    if is_dataclass(obj) and not isinstance(obj, type):\n        return asdict(obj)\n    return obj\n\n\ndef tensor_statistics(tensor: torch.Tensor, n: int = -1, **kwargs) -> Tuple[str, str, str, str]:\n    \"\"\"Compute tensor statistics at 4 granularity levels.\n    \n    Args:\n        tensor: PyTorch tensor of any shape\n        n: Partial range: -1 for first half, >0 for first n elements\n    \n    Returns:\n        Tuple of 4 formatted stat strings: full, partial, magnitude-based, 1/10 magnitude-based\n    \"\"\"\n    flattened = tensor.reshape(-1)\n    total_elements = flattened.numel()\n    \n    if total_elements == 0:\n        base = \"mean: NaN, variance: NaN, max: NaN, min: NaN, non-zeros: 0\"\n        return (\n            f\"Full - {base}\",\n            f\"Partial - {base}\",\n            f\"Magnitude-based - {base}\",\n            f\"1/10 Magnitude-based - {base}\"\n        )\n    \n    if n == -1:\n        part_count = (total_elements + 1) // 2\n        part_tensor = flattened[:part_count]\n        part_label = f\"first half ({part_count} elements)\"\n    elif isinstance(n, int) and n > 0:\n        if n > total_elements:\n            raise ValueError(f\"n={n} exceeds total elements ({total_elements})\")\n        part_count = n\n        part_tensor = flattened[:n]\n        part_label = f\"first {n} elements\"\n    else:\n        raise ValueError(f\"n must be -1 or positive integer, got: {n}\")\n    \n    if total_elements <= 1:\n        mag_count = 0\n        mag_label = \"no elements (total <= 1)\"\n        mag_tensor = flattened[:0]\n    else:\n        log_val = math.log10(total_elements)\n        k = int(log_val) - 1 if log_val.is_integer() else math.floor(log_val)\n        mag_count = 10 ** k\n        mag_count = min(mag_count, total_elements)\n        mag_tensor = flattened[:mag_count]\n        mag_label = f\"first {mag_count} elements (magnitude-based)\"\n    \n    line4_count = mag_count // 10\n    if line4_count <= 0:\n        line4_label = \"no elements (1/10 of magnitude-based <= 0)\"\n        line4_tensor = flattened[:0]\n    else:\n        line4_count = min(line4_count, total_elements)\n        line4_tensor = flattened[:line4_count]\n        line4_label = f\"first {line4_count} elements (1/10 of magnitude-based)\"\n    \n    def calc_stats(t: torch.Tensor) -> Tuple[float, float, float, float, int]:\n        \"\"\"Calculate mean, variance, max, min, non-zero count.\"\"\"\n        if t.numel() == 0:\n            return (float('nan'), float('nan'), float('nan'), float('nan'), 0)\n        return (\n            torch.mean(t.float()).item(),\n            torch.var(t.float(), unbiased=False).item(),\n            torch.max(t).item(),\n            torch.min(t).item(),\n            torch.count_nonzero(t).item()\n        )\n    \n    full_mean, full_var, full_max, full_min, full_nonzero = calc_stats(flattened)\n    part_mean, part_var, part_max, part_min, part_nonzero = calc_stats(part_tensor)\n    mag_mean, mag_var, mag_max, mag_min, mag_nonzero = calc_stats(mag_tensor)\n    line4_mean, line4_var, line4_max, line4_min, line4_nonzero = calc_stats(line4_tensor)\n    \n    def format_line(label: 
str, mean: float, var: float, max_val: float, \n                   min_val: float, nonzero: int) -> str:\n        return (f\"{label} - mean: {mean:.6f}, variance: {var:.6f}, \"\n                f\"max: {max_val:.6f}, min: {min_val:.6f}, non-zeros: {nonzero}\")\n    \n    line1 = format_line(\"Full\", full_mean, full_var, full_max, full_min, full_nonzero)\n    line2 = format_line(part_label, part_mean, part_var, part_max, part_min, part_nonzero)\n    line3 = format_line(mag_label, mag_mean, mag_var, mag_max, mag_min, mag_nonzero)\n    line4 = format_line(line4_label, line4_mean, line4_var, line4_max, line4_min, line4_nonzero)\n    \n    return line1, line2, line3, line4\n\n\ndef print_input_info(\n    data: Any, \n    prefix: str = \"\", \n    max_str_len: int = 50, \n    return_str: bool = False, \n    max_show: int = 4, \n    save_path: Union[str, None] = None, \n    **kwargs\n) -> Union[None, str]:\n    \"\"\"Recursively print or return detailed information about input data.\n    \n    Supports Tensor, dict, list, tuple, str, int, float. Can save data to disk.\n    \n    Args:\n        data: Data to print\n        prefix: Prefix for each line (indentation)\n        max_str_len: Max string display length\n        return_str: Return string instead of printing\n        max_show: Max elements for tensor preview\n        save_path: Optional path to save data (tensors detached to CPU)\n        **kwargs: Passed to tensor_statistics()\n    \n    Returns:\n        Formatted string if return_str=True, else None\n    \"\"\"\n    data = convert_dataclass_to_dict(data)\n    \n    def _detach_to_cpu(obj: Any) -> Any:\n        \"\"\"Recursively detach tensors and move to CPU.\"\"\"\n        if isinstance(obj, torch.Tensor):\n            return obj.detach().cpu()\n        elif isinstance(obj, (list, tuple)):\n            return type(obj)(_detach_to_cpu(item) for item in obj)\n        elif isinstance(obj, dict):\n            return {k: _detach_to_cpu(v) for k, v in obj.items()}\n        elif hasattr(obj, '__dict__'):\n            return {k: _detach_to_cpu(v) for k, v in obj.__dict__.items()}\n        else:\n            return obj\n    \n    if save_path is not None:\n        try:\n            data_to_save = _detach_to_cpu(data)\n            dirname = os.path.dirname(save_path)\n            if dirname:\n                os.makedirs(dirname, exist_ok=True)\n            torch.save(data_to_save, save_path)\n            print(f\"Saved data to: {save_path}\")\n        except Exception as e:\n            print(f\"Failed to save data to {save_path}: {e}\\n{traceback.format_exc()}\")\n    \n    lines: List[str] = []\n    \n    try:\n        data = dict(data)\n    except (TypeError, ValueError):\n        pass\n    \n    def add_line(text: str) -> None:\n        if return_str:\n            lines.append(text)\n        else:\n            print(text)\n    \n    def _process_nested_item(item: Any, item_prefix: str, max_str_len: int, \n                             return_str: bool, lines: List[str], **kwargs) -> None:\n        sub_result = print_input_info(item, item_prefix, max_str_len, return_str=True, **kwargs)\n        if return_str:\n            lines.extend(sub_result.split('\\n'))\n        else:\n            print(sub_result)\n    \n    if data is None:\n        add_line(f\"{prefix}None\")\n        return \"\\n\".join(lines) if return_str else None\n    \n    if isinstance(data, torch.Tensor):\n        flattened = data.flatten()\n        data_preview = 
f\"{flattened[:max_show].tolist()}...{flattened[-max_show:].tolist()}\"\n        base_info = (f\"{prefix}Tensor: shape={tuple(data.shape)}, dtype={data.dtype}, \"\n                    f\"device={data.device}, data={data_preview}\")\n        \n        if data.dtype == torch.bool:\n            total_elements = data.numel()\n            true_count = data.sum().item()\n            false_count = total_elements - true_count\n            true_ratio = true_count / total_elements * 100 if total_elements > 0 else 0\n            false_ratio = false_count / total_elements * 100 if total_elements > 0 else 0\n            \n            add_line(base_info)\n            add_line(f\"{prefix}  True:  count={true_count:,d} ({true_ratio:.2f}%)\")\n            add_line(f\"{prefix}  False: count={false_count:,d} ({false_ratio:.2f}%)\")\n        else:\n            add_line(base_info)\n            for idx, stat_line in enumerate(tensor_statistics(data, **kwargs)):\n                add_line(f\"{prefix}  stat{idx}:  {stat_line}\")\n    \n    elif isinstance(data, str):\n        display_str = data[:max_str_len] + \"...\" if len(data) > max_str_len else data\n        add_line(f\"{prefix}String: length={len(data)}, value='{display_str}'\")\n    \n    elif isinstance(data, (list, tuple)):\n        container_type = \"List\" if isinstance(data, list) else \"Tuple\"\n        add_line(f\"{prefix}{container_type}: length={len(data)}\")\n        for i, item in enumerate(data):\n            add_line(f\"{prefix}[{i}]:\")\n            _process_nested_item(item, prefix + \"  \", max_str_len, return_str, lines, **kwargs)\n    \n    elif isinstance(data, dict):\n        add_line(f\"{prefix}Dict: keys={len(data)}\")\n        for key, value in data.items():\n            add_line(f\"{prefix}'{key}':\")\n            _process_nested_item(value, prefix + \"  \", max_str_len, return_str, lines, **kwargs)\n    \n    elif isinstance(data, (int, float)):\n        add_line(f\"{prefix}{type(data).__name__}: {data}\")\n    \n    else:\n        data_str = str(data)\n        truncated = data_str[:max_show] + \"...\" + data_str[-max_show:] if len(data_str) > max_show * 2 else data_str\n        add_line(f\"{prefix}Other type ({type(data).__name__}): {truncated}\")\n    \n    return \"\\n\".join(lines) if return_str else None\n\n\ndef format_dict_or_list(obj: Any, indent_level: int = 0, indent_size: int = 2) -> str:\n    \"\"\"Format dict/list as readable string (alternative to json.dumps).\n    \n    Args:\n        obj: Dictionary, list, or other object\n        indent_level: Current indentation level\n        indent_size: Spaces per indentation level\n    \n    Returns:\n        Formatted string\n    \"\"\"\n    def format_value(value: Any, indent_level: int, indent_size: int) -> str:\n        if isinstance(value, (dict, list)):\n            return format_dict_or_list(value, indent_level, indent_size)\n        elif isinstance(value, str):\n            return f'\"{value}\"'\n        else:\n            return str(value)\n    \n    if isinstance(obj, dict):\n        formatted_items = []\n        indent = \" \" * indent_size * (indent_level + 1)\n        for key, value in obj.items():\n            formatted_value = format_value(value, indent_level + 1, indent_size)\n            formatted_items.append(f'{indent}\"{key}\": {formatted_value}')\n        \n        items_str = ',\\n'.join(formatted_items)\n        current_indent = \" \" * indent_size * indent_level\n        return f'{{\\n{items_str}\\n{current_indent}}}'\n    \n    elif isinstance(obj, 
list):\n        formatted_items = []\n        indent = \" \" * indent_size * (indent_level + 1)\n        for item in obj:\n            formatted_value = format_value(item, indent_level + 1, indent_size)\n            formatted_items.append(f'{indent}{formatted_value}')\n        \n        items_str = ',\\n'.join(formatted_items)\n        current_indent = \" \" * indent_size * indent_level\n        return f'[\\n{items_str}\\n{current_indent}]'\n    \n    else:\n        return str(obj)"
  },
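  {
    "path": "pretrain/examples/ds_utils_usage_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original pipeline): inspecting a nested batch with print_input_info and pretty-printing a config with format_dict_or_list. The batch layout here is hypothetical.\"\"\"\n\nimport torch\n\nfrom onerec_llm.utils.ds_utils import format_dict_or_list, print_input_info\n\nbatch = {\n    'input_ids': torch.randint(1, 32000, (2, 16)),\n    'loss_mask': torch.ones(2, 16, dtype=torch.bool),\n    'sources': ['short_video', 'product_rec'],\n}\n\n# Prints shape/dtype/device plus four statistics lines per tensor;\n# bool tensors get True/False counts instead of statistics.\nprint_input_info(batch, prefix='batch> ')\n\n# return_str=True collects the same report into a string (e.g., for a logger).\nreport = print_input_info(batch, return_str=True)\n\nprint(format_dict_or_list({'lr': 2e-4, 'betas': [0.9, 0.95]}))\n"
  },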
  {
    "path": "pretrain/onerec_llm/utils/mfu_stats.py",
    "content": "\"\"\"Model FLOPs Utilization (MFU) statistics and calculation utilities.\n\nThis module provides functionality to calculate FLOPs (Floating Point Operations)\nfor transformer models and compute MFU metrics for training performance monitoring.\n\"\"\"\n\nimport collections\nimport json\nimport os\nimport platform\nimport re\nimport subprocess\nfrom collections import defaultdict\nfrom functools import lru_cache\nfrom typing import Dict, List, Optional, Union\n\nimport easydict\n\n\ndef _sum_if_list(x: Union[int, List[int]]) -> int:\n    \"\"\"Sum if input is a list, otherwise return as-is.\"\"\"\n    return sum(x) if isinstance(x, list) else x\n\n\n@lru_cache(maxsize=1)\ndef _get_gpu_model() -> str:\n    \"\"\"Get NVIDIA GPU model name.\n    \n    Returns:\n        GPU model name, or \"Unknown\" if detection fails.\n    \"\"\"\n    try:\n        # Try nvidia-smi (most reliable method)\n        if platform.system() in [\"Linux\", \"Darwin\"]:\n            result = subprocess.run(\n                [\"nvidia-smi\", \"--query-gpu=name\", \"--format=csv,noheader\"],\n                capture_output=True,\n                text=True\n            )\n            if result.returncode == 0:\n                return result.stdout.strip()\n        \n        elif platform.system() == \"Windows\":\n            result = subprocess.run(\n                [\"nvidia-smi\", \"--query-gpu=name\", \"--format=csv,noheader\"],\n                capture_output=True,\n                text=True,\n                shell=True\n            )\n            if result.returncode == 0:\n                return result.stdout.strip()\n            \n            # Fallback: Windows Management Instrumentation\n            try:\n                import wmi\n                c = wmi.WMI()\n                gpus = c.Win32_VideoController()\n                for gpu in gpus:\n                    if \"NVIDIA\" in gpu.Name:\n                        return gpu.Name\n            except ImportError:\n                pass\n        \n        # Fallback: PyTorch CUDA\n        try:\n            import torch\n            if torch.cuda.is_available():\n                return torch.cuda.get_device_name(0)\n        except ImportError:\n            pass\n        \n        # Fallback: TensorFlow\n        try:\n            import tensorflow as tf\n            if tf.test.is_gpu_available():\n                gpus = tf.config.list_physical_devices('GPU')\n                if gpus:\n                    details = tf.config.experimental.get_device_details(gpus[0])\n                    return details.get('device_name', 'NVIDIA GPU')\n        except ImportError:\n            pass\n        \n        # Last resort: Check Linux driver file\n        if platform.system() == \"Linux\":\n            if os.path.exists(\"/proc/driver/nvidia/version\"):\n                with open(\"/proc/driver/nvidia/version\", \"r\") as f:\n                    first_line = f.readline().strip()\n                    match = re.search(r\"NVIDIA driver \\S+ for (\\S+)\", first_line)\n                    if match:\n                        return match.group(1)\n    \n    except Exception:\n        pass\n    \n    return \"Unknown\"\n\n\n@lru_cache(maxsize=1)\ndef _is_h800() -> bool:\n    \"\"\"Check if GPU is NVIDIA H800.\"\"\"\n    gpu_model = _get_gpu_model()\n    return gpu_model.split('\\n')[0].strip() == 'NVIDIA H800'\n\n\n@lru_cache(maxsize=1)\ndef _get_gpu_flops() -> float:\n    \"\"\"Get theoretical peak FLOPS for current GPU.\n    \n    Returns:\n        Peak FLOPS (H800: 
989 TFLOPS, others: 312 TFLOPS)\n    \"\"\"\n    return 989e12 if _is_h800() else 312e12\n\n\ndef _calculate_decoder_layer_flops(\n    num_head: int,\n    head_dim: int,\n    hidden_size: int,\n    intermediate_size: int,\n    kv_heads: Optional[int] = None,\n    is_causal: bool = False,\n    seq_len: Union[int, List[int]] = 1,\n    batch_size: int = 1,\n    linear_factor: int = 2,\n    attn_output_layers: int = 2\n) -> Dict:\n    \"\"\"Calculate FLOPs for a single transformer decoder layer.\n    \n    Args:\n        num_head: Number of attention heads\n        head_dim: Dimension per attention head\n        hidden_size: Hidden layer size\n        intermediate_size: FFN intermediate layer size\n        kv_heads: Number of KV attention heads (for Group Attention)\n        is_causal: Whether to use causal masking\n        seq_len: Input sequence length (int or list for variable lengths)\n        batch_size: Batch size\n        linear_factor: Linear computation factor (default: 2 for multiply-add)\n        attn_output_layers: Number of attention output layers\n    \n    Returns:\n        Dictionary containing FLOPs breakdown and total FLOPs\n    \"\"\"\n    if kv_heads is None:\n        kv_heads = num_head\n    \n    seq_len_per_sample = None if isinstance(seq_len, list) else seq_len // batch_size\n    total_seq_len = _sum_if_list(seq_len)\n    \n    # QKV projection FLOPs\n    q_flops = linear_factor * total_seq_len * hidden_size * (num_head * head_dim)\n    k_flops = linear_factor * total_seq_len * hidden_size * (kv_heads * head_dim)\n    v_flops = linear_factor * total_seq_len * hidden_size * (kv_heads * head_dim)\n    \n    # Attention scores FLOPs\n    if isinstance(seq_len, list):\n        attn_scores_flops = 0\n        for seq_len_per_sample in seq_len:\n            attn_scores_flops += (\n                linear_factor * num_head * seq_len_per_sample * \n                seq_len_per_sample * head_dim\n            )\n    else:\n        attn_scores_flops = (\n            linear_factor * num_head * seq_len_per_sample * \n            seq_len_per_sample * head_dim * batch_size\n        )\n    \n    # Causal masking reduces computation by half\n    if is_causal:\n        attn_scores_flops *= 0.5\n    \n    attn_v_flops = attn_scores_flops\n    \n    # Attention output projection\n    attn_out_flops = linear_factor * total_seq_len * (num_head * head_dim) * hidden_size\n    \n    # Total attention FLOPs\n    attention_flops = q_flops + k_flops + v_flops + attn_scores_flops + attn_v_flops + attn_out_flops\n    \n    # FFN FLOPs\n    ffn_flops = (\n        linear_factor * total_seq_len * hidden_size * \n        intermediate_size * attn_output_layers\n    )\n    \n    total_flops = attention_flops + ffn_flops\n    \n    return {\n        'total_flops': total_flops,\n        'attention': {\n            'q_proj': q_flops,\n            'k_proj': k_flops,\n            'v_proj': v_flops,\n            'attn_scores': attn_scores_flops,\n            'attn_v': attn_v_flops,\n            'attn_out': attn_out_flops,\n            'total': attention_flops\n        },\n        'ffn_flops': ffn_flops,\n        'batch_info': {\n            'batch_size': batch_size,\n            'seq_len_per_sample': seq_len_per_sample\n        }\n    }\n\n\ndef _calculate_decoder_layers_flops(\n    num_head: int,\n    head_dim: int,\n    hidden_size: int,\n    intermediate_size: int,\n    kv_heads: Optional[int] = None,\n    is_causal: bool = False,\n    seq_len: Union[int, List[int]] = 1,\n    num_layers: int = 1,\n    
linear_factor: int = 2,\n    batch_size: int = 1,\n    attn_output_layers: int = 2\n) -> Dict:\n    \"\"\"Calculate FLOPs for multiple transformer decoder layers.\n    \n    Args:\n        num_head: Number of attention heads\n        head_dim: Dimension per attention head\n        hidden_size: Hidden layer size\n        intermediate_size: FFN intermediate layer size\n        kv_heads: Number of KV attention heads\n        is_causal: Whether to use causal masking\n        seq_len: Input sequence length\n        num_layers: Number of decoder layers\n        linear_factor: Linear computation factor\n        batch_size: Batch size\n        attn_output_layers: Number of attention output layers\n    \n    Returns:\n        Dictionary containing per-layer and total FLOPs\n    \"\"\"\n    layers_flops = []\n    total_flops = 0\n    \n    for layer_idx in range(num_layers):\n        layer_flops = _calculate_decoder_layer_flops(\n            num_head=num_head,\n            head_dim=head_dim,\n            hidden_size=hidden_size,\n            intermediate_size=intermediate_size,\n            kv_heads=kv_heads,\n            is_causal=is_causal,\n            seq_len=seq_len,\n            linear_factor=linear_factor,\n            batch_size=batch_size,\n            attn_output_layers=attn_output_layers\n        )\n        layers_flops.append({\n            'layer_index': layer_idx,\n            **layer_flops\n        })\n        total_flops += layer_flops['total_flops']\n    \n    return {\n        'total_flops': total_flops,\n        'per_layer_flops': layers_flops[0] if layers_flops else {},\n        'avg_flops_per_layer': total_flops / num_layers if num_layers > 0 else 0,\n        'num_layers': num_layers,\n    }\n\n\ndef _calculate_llm_flops(llm_params: easydict.EasyDict) -> Dict:\n    \"\"\"Calculate total FLOPs for an LLM model.\n    \n    Args:\n        llm_params: Model parameters (EasyDict with model config)\n    \n    Returns:\n        Dictionary containing total FLOPs including LM head\n    \"\"\"\n    linear_factor = 2\n    \n    llm_flops = _calculate_decoder_layers_flops(\n        num_head=llm_params.num_head,\n        head_dim=llm_params.head_dim,\n        hidden_size=llm_params.hidden_size,\n        intermediate_size=llm_params.intermediate_size,\n        num_layers=llm_params.num_layers,\n        kv_heads=llm_params.get('kv_heads', None),\n        is_causal=llm_params.get('is_causal', True),\n        seq_len=llm_params.seq_len,\n        batch_size=llm_params.get('batch_size', 1),\n        linear_factor=linear_factor,\n        attn_output_layers=3\n    )\n    \n    # Add LM head FLOPs\n    lm_head_flops = (\n        linear_factor * _sum_if_list(llm_params.seq_len) * \n        (llm_params.hidden_size * llm_params.vocab_size)\n    )\n    llm_flops['total_flops'] += lm_head_flops\n    llm_flops['lm_head_flops'] = lm_head_flops\n    \n    return llm_flops\n\n\n@lru_cache(maxsize=32)\ndef _extract_model_params(config_path: str) -> easydict.EasyDict:\n    \"\"\"Extract transformer parameters from model config JSON.\n    \n    Supports Qwen3 architecture.\n    \n    Args:\n        config_path: Path to JSON config file\n    \n    Returns:\n        EasyDict containing transformer parameters\n    \n    Raises:\n        ValueError: If architecture is not supported\n    \"\"\"\n    with open(config_path, 'r') as f:\n        config = json.load(f)\n    \n    if 'architectures' in config and 'Qwen3ForCausalLM' in config['architectures']:\n        transformer_params = {\n            'num_head': 
config['num_attention_heads'],\n            'head_dim': config['head_dim'],\n            'hidden_size': config['hidden_size'],\n            'intermediate_size': config['intermediate_size'],\n            'kv_heads': config['num_key_value_heads'],\n            'num_layers': config['num_hidden_layers'],\n            'vocab_size': config['vocab_size']\n        }\n    else:\n        raise ValueError(\n            f'Unsupported architecture. Expected Qwen3ForCausalLM, '\n            f'got: {config.get(\"architectures\", \"unknown\")}'\n        )\n    \n    return easydict.EasyDict(transformer_params)\n\n\ndef _calc_mfu(\n    config_path: str,\n    total_seq_len: int,\n    llm_batch_size: int = 1,\n    secs_per_step: Optional[float] = None,\n    _gpu_flops: Optional[float] = None\n) -> Dict:\n    \"\"\"Calculate Model FLOPs Utilization (MFU) for LLM models.\n    \n    Args:\n        config_path: Path to model config JSON\n        total_seq_len: Total sequence length\n        llm_batch_size: Batch size for LLM\n        secs_per_step: Seconds per training step\n        _gpu_flops: GPU peak FLOPS (auto-detected if None)\n    \n    Returns:\n        Dictionary containing MFU metrics and FLOPs breakdown\n    \"\"\"\n    transformer_params = _extract_model_params(config_path)\n    \n    # Calculate LLM FLOPs\n    llm_params = easydict.EasyDict({\n        **transformer_params,\n        'is_causal': True,\n        'seq_len': total_seq_len,\n        'batch_size': llm_batch_size\n    })\n    \n    flops = _calculate_llm_flops(llm_params)\n    gpu_flops = _get_gpu_flops() if _gpu_flops is None else _gpu_flops\n    \n    # Add MFU metrics\n    flops['total_flops*3(T)'] = flops['total_flops'] * 3 / 1e12\n    flops['total_flops/gpu_flops'] = flops['total_flops'] * 3 / gpu_flops\n    flops['gpu_flops'] = gpu_flops\n    flops['llm_total_flops*3(T)'] = flops['total_flops*3(T)']\n    flops['llm_percentage'] = 100\n    \n    flops['input_args'] = easydict.EasyDict(\n        config_path=config_path,\n        total_seq_len=total_seq_len,\n        llm_batch_size=llm_batch_size,\n        secs_per_step=secs_per_step\n    )\n    \n    if secs_per_step is not None:\n        flops['mfu'] = flops['total_flops/gpu_flops'] / secs_per_step\n    \n    return flops\n\n\nclass MFUStats:\n    \"\"\"Model FLOPs Utilization statistics tracker for LLM training.\n    \n    Tracks token counts and computes MFU metrics for training performance monitoring.\n    \n    Args:\n        args: Training arguments containing model_dir and logging_per_step\n    \"\"\"\n    \n    def __init__(self, args):\n        self.tokens_for_mfu = collections.defaultdict(int)\n        self.mfu_per_step_per_gpu = None\n        self.args = args\n        self.total_mfu = defaultdict(int)\n    \n    def set(self, num_tokens: int, num_samples: int) -> None:\n        \"\"\"Accumulate token and sample counts for MFU calculation.\n        \n        Args:\n            num_tokens: Total number of tokens\n            num_samples: Number of samples\n        \"\"\"\n        self.tokens_for_mfu[\"num_tokens\"] += int(num_tokens)\n        self.tokens_for_mfu[\"num_samples\"] += int(num_samples)\n    \n    def mfu(self, secs: float, global_step: int) -> Dict[str, float]:\n        \"\"\"Compute MFU metrics for the current logging period.\n        \n        Args:\n            secs: Total seconds elapsed in this period\n            global_step: Current global training step\n        \n        Returns:\n            Dictionary containing MFU metrics for logging\n        \"\"\"\n   
     args = self.args\n        tokens_for_mfu = self.tokens_for_mfu\n        \n        # Calculate MFU arguments for text-only LLM\n        mfu_args = easydict.EasyDict(\n            total_seq_len=round(tokens_for_mfu[\"num_tokens\"] / args.logging_per_step),\n            llm_batch_size=round(tokens_for_mfu[\"num_samples\"] / args.logging_per_step),\n            secs_per_step=secs / args.logging_per_step\n        )\n        \n        config_path = os.path.join(args.model_dir, \"config.json\")\n        mfu_per_step_per_gpu = _calc_mfu(config_path, **mfu_args)\n        self.mfu_per_step_per_gpu = mfu_per_step_per_gpu\n        \n        # Accumulate total MFU\n        total_mfu = self.total_mfu\n        total_mfu['llm_total_flops*3(T)'] += (\n            mfu_per_step_per_gpu['llm_total_flops*3(T)'] * args.logging_per_step\n        )\n        total_mfu['mfu'] += mfu_per_step_per_gpu['mfu'] * args.logging_per_step\n        \n        # Build logging dictionary\n        # Current metrics: period-based MFU (current logging period)\n        # Average metrics: cumulative MFU (average over entire training, smoothed)\n        mfu_log_dict = {\n            \"perf/mfu_per_step_per_gpu_current\": mfu_per_step_per_gpu['mfu'],\n            \"perf/llm_flops_per_step_per_gpu_current\": mfu_per_step_per_gpu['llm_total_flops*3(T)'],\n            \"perf/mfu_per_step_per_gpu_avg\": total_mfu['mfu'] / global_step,\n            \"perf/llm_flops_per_step_per_gpu_avg\": total_mfu['llm_total_flops*3(T)'] / global_step,\n        }\n        \n        # Reset counters for next period\n        self.tokens_for_mfu = collections.defaultdict(int)\n        \n        return mfu_log_dict\n\n"
  },
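  {
    "path": "pretrain/examples/mfu_usage_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original pipeline): estimating MFU for a single training step with _calc_mfu. The config path is hypothetical; any Qwen3 config.json works. The module counts training FLOPs as 3x the forward pass (1x forward + 2x backward).\"\"\"\n\nfrom onerec_llm.utils.mfu_stats import _calc_mfu\n\nflops = _calc_mfu(\n    config_path='/path/to/Qwen3-1.7B/config.json',  # hypothetical path\n    total_seq_len=8192,  # tokens processed in this step on one GPU\n    llm_batch_size=2,  # samples packed into those tokens\n    secs_per_step=1.5,  # measured wall-clock seconds per step\n    _gpu_flops=312e12  # skip auto-detection; assume an A100-class BF16 peak\n)\nprint(f\"training FLOPs/step: {flops['total_flops*3(T)']:.1f} TFLOPs, MFU: {flops['mfu']:.2%}\")\n"
  },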
  {
    "path": "pretrain/onerec_llm/utils/time_tracker.py",
    "content": "\"\"\"Time tracking utilities for performance profiling.\"\"\"\n\nimport os\nimport time\nfrom typing import Dict, List, Literal, Optional\n\n\nclass TimeTracker:\n    \"\"\"Track time intervals between tick calls and compute rolling averages.\n    \n    This class records time intervals for named events and maintains a rolling\n    average of the last N intervals for each event. Supports both absolute\n    wall-clock time and CPU time tracking.\n    \n    Args:\n        n: Number of recent intervals to average (default: 1)\n        time_types: List of time types to track. Options: \"absolute\" (wall-clock)\n            or \"cpu\" (CPU time). Default: [\"absolute\"]\n    \n    Example:\n        >>> tracker = TimeTracker(n=10)\n        >>> tracker.tick(\"start\")\n        >>> # ... do some work ...\n        >>> tracker.tick(\"end\")\n        >>> stats = tracker.stat()  # Returns average intervals\n    \"\"\"\n    \n    def __init__(\n        self, \n        n: int = 1, \n        time_types: Optional[List[Literal[\"absolute\", \"cpu\"]]] = None\n    ):\n        if time_types is None:\n            time_types = [\"absolute\"]\n        \n        self.n = n\n        self.time_types = time_types\n        self.last_times: Dict[str, float] = {\n            \"absolute\": time.perf_counter(),\n            \"cpu\": os.times().user\n        }\n        self.interval_records: Dict[str, List[float]] = {}\n\n    def tick(self, name: str) -> None:\n        \"\"\"Record time interval for a named event.\n        \n        Records the time elapsed since the last tick() call for each configured\n        time type. Maintains a rolling window of the last N intervals.\n        \n        Args:\n            name: Name of the event to track\n        \"\"\"\n        for time_type in self.time_types:\n            # Get current time based on type\n            if time_type == \"absolute\":\n                current_time = time.perf_counter()\n            elif time_type == \"cpu\":\n                current_time = os.times().user\n            else:\n                raise ValueError(\n                    f\"Invalid time_type '{time_type}'. \"\n                    \"Allowed values are 'absolute' or 'cpu'.\"\n                )\n            \n            # Calculate interval\n            last_time = self.last_times[time_type]\n            interval = current_time - last_time\n            self.last_times[time_type] = current_time\n            \n            # Store interval in rolling window\n            key = f\"{time_type}@{name}\"\n            if key not in self.interval_records:\n                self.interval_records[key] = []\n            \n            intervals = self.interval_records[key]\n            intervals.append(interval)\n            \n            # Maintain rolling window of size n\n            if len(intervals) > self.n:\n                intervals.pop(0)\n\n    def stat(self) -> Dict[str, float]:\n        \"\"\"Get average time intervals for all tracked events.\n        \n        Returns:\n            Dictionary mapping event keys (format: \"{time_type}@{name}\") to\n            average interval values. Only includes events with recorded intervals.\n        \"\"\"\n        result: Dict[str, float] = {}\n        for key, intervals in self.interval_records.items():\n            if intervals:\n                result[key] = sum(intervals) / len(intervals)\n        return result"
  },
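  {
    "path": "pretrain/examples/time_tracker_usage_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original pipeline): timing the phases of a loop with TimeTracker. Each tick(name) records the time since the previous tick, so consecutive ticks partition an iteration into named phases. The sleeps stand in for real work.\"\"\"\n\nimport time\n\nfrom onerec_llm.utils.time_tracker import TimeTracker\n\ntracker = TimeTracker(n=10)  # rolling average over the last 10 intervals per phase\n\nfor step in range(30):\n    time.sleep(0.01)  # stand-in for data loading\n    tracker.tick('data')  # interval since the previous tick -> phase 'data'\n    time.sleep(0.02)  # stand-in for forward/backward\n    tracker.tick('fwd_bwd')\n\n# Keys look like 'absolute@data'; values are rolling-mean seconds.\nfor key, avg in tracker.stat().items():\n    print(f'{key}: {avg * 1000:.1f} ms')\n"
  },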
  {
    "path": "pretrain/onerec_llm/utils/worker_utils.py",
    "content": "\"\"\"Worker information utilities for PyTorch DataLoader and distributed training.\"\"\"\n\nimport os\nimport torch\nimport torch.distributed as dist\n\n\ndef get_worker_info():\n    \"\"\"Get PyTorch DataLoader worker information.\n    \n    This function prioritizes PyTorch DataLoader's worker info over environment\n    variables, as it provides accurate worker information in multi-process\n    DataLoader contexts.\n    \n    Returns:\n        tuple: (worker_id, num_workers)\n    \"\"\"\n    # Priority 1: Try to get from PyTorch DataLoader worker info\n    # This is the most reliable source in DataLoader worker processes\n    try:\n        import torch.utils.data\n        worker_info = torch.utils.data.get_worker_info()\n        if worker_info is not None:\n            return worker_info.id, worker_info.num_workers\n    except (ModuleNotFoundError, AttributeError):\n        pass\n    \n    # Priority 2: Fall back to environment variables (for non-DataLoader contexts)\n    if \"WORKER\" in os.environ and \"NUM_WORKERS\" in os.environ:\n        return int(os.environ[\"WORKER\"]), int(os.environ[\"NUM_WORKERS\"])\n    \n    # Default: single worker, worker_id = 0\n    return 0, 1\n\n\ndef pytorch_worker_info(group=None):\n    \"\"\"Return node and worker info for PyTorch and some distributed environments.\n\n    Args:\n        group: Optional process group for distributed environments. Defaults to None.\n\n    Returns:\n        tuple: (rank, world_size, worker, num_workers)\n    \"\"\"\n    # Get worker info (reuse get_worker_info to avoid code duplication)\n    worker, num_workers = get_worker_info()\n    \n    # Get rank and world_size\n    rank = 0\n    world_size = 1\n    \n    # Check environment variables first\n    if \"RANK\" in os.environ and \"WORLD_SIZE\" in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n    else:\n        # Try to get from PyTorch distributed\n        try:\n            if dist.is_available() and dist.is_initialized():\n                group = group or dist.group.WORLD\n                rank = dist.get_rank(group=group)\n                world_size = dist.get_world_size(group=group)\n        except (ModuleNotFoundError, AttributeError):\n            pass\n\n    return rank, world_size, worker, num_workers\n\n"
  },
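  {
    "path": "pretrain/examples/worker_sharding_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original pipeline): the usual rank/worker sharding pattern built on pytorch_worker_info, giving every (rank, dataloader worker) pair a disjoint slice of the input files. File names are hypothetical.\"\"\"\n\nfrom onerec_llm.utils.worker_utils import pytorch_worker_info\n\n\ndef shard_files(files):\n    rank, world_size, worker, num_workers = pytorch_worker_info()\n    shard = rank * num_workers + worker  # global shard index over all (rank, worker) pairs\n    num_shards = world_size * num_workers\n    return files[shard::num_shards]\n\n\nif __name__ == '__main__':\n    files = [f'part-{i:05d}.parquet' for i in range(16)]\n    # A standalone run sees (rank=0, world_size=1, worker=0, num_workers=1) -> all files.\n    print(shard_files(files))\n"
  },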
  {
    "path": "pretrain/recipes/train_qwen3.py",
    "content": "\"\"\"Qwen3 Training Script\n\nMulti-node, multi-GPU training script for Qwen3 models using FSDP (Fully Sharded Data Parallel).\nSupports distributed training, checkpointing, and comprehensive monitoring.\n\"\"\"\n\nimport os\nimport sys\n\nsys.path.append(\"./onerec_llm/models\")\n\nimport argparse\nimport collections\nimport contextlib\nimport datetime\nimport gc\nimport itertools\nimport json\nimport logging\nimport queue\nimport threading\nimport time\nfrom functools import partial\nfrom typing import Dict, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\nfrom accelerate import init_empty_weights\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.utils.tensorboard import SummaryWriter\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom onerec_llm.data.dataloaders import get_dataloader\nfrom onerec_llm.losses import CrossEntropyLoss, ChunkedLossComputer\nfrom onerec_llm.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM\nfrom onerec_llm.training.activations import set_activation_checkpointing\nfrom onerec_llm.training.checkpoint import (\n    AppState,\n    DistributedCheckpointer,\n    load_hf_checkpoint,\n)\nfrom onerec_llm.training.common import set_default_dtype\nfrom onerec_llm.training.gradients import (\n    EmbeddingGradientMasker,\n    clip_grad_by_value,\n    clip_grad_norm,\n    compute_fsdp_zero2_grad_norm,\n)\nfrom onerec_llm.training.distributed import (\n    load_from_full_model_state_dict,\n    shard_model,\n)\nfrom onerec_llm.training.lr_schedulers import get_scheduler\nfrom onerec_llm.utils.common import (\n    Timer,\n    dist_reduce_dict,\n    get_optimizer_grouped_parameters,\n    print_rank_0,\n    set_random_seed,\n    to_cuda,\n)\nfrom onerec_llm.utils.ds_utils import format_dict_or_list, print_input_info\nfrom onerec_llm.utils.mfu_stats import MFUStats\nfrom onerec_llm.utils.time_tracker import TimeTracker\n\n# Disable garbage collection for performance\ngc.disable()\n\n# Set CUDA memory allocation configuration\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n\n# Process group timeout (24 hours)\nPROCESS_GROUP_TIMEOUT = datetime.timedelta(minutes=60 * 24)\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n\nclass TrainingMetrics:\n    \"\"\"Manages training metrics accumulation and statistics.\n    \n    This class tracks metrics in two ways:\n    - Period metrics (period_*): Accumulated over a logging period (logging_per_step steps)\n    - Total metrics (total_*): Accumulated over the entire training run\n    \"\"\"\n    \n    def __init__(self):\n        self.reset_period_accumulators()\n        # Total metrics accumulated over entire training\n        self.total_num_tokens = 0\n        self.total_num_samples = 0\n        self.total_num_valid_tokens = 0\n        self.total_data_source_tokens = collections.defaultdict(int)\n        self.local_period_data_source_samples = collections.defaultdict(int)\n    \n    def reset_period_accumulators(self):\n        \"\"\"Reset accumulated metrics for the current logging period.\"\"\"\n        # Period metrics: accumulated over logging_per_step steps\n        self.period_sum_loss = 0.0\n        self.period_sum_itemic_token_loss = 0.0\n        self.period_sum_text_token_loss = 0.0\n        self.period_num_tokens = 0\n        self.period_num_samples = 0\n        self.period_num_valid_tokens = 0\n        
self.period_data_source_loss = collections.defaultdict(float)\n        self.period_data_source_tokens = collections.defaultdict(int)\n        self.period_valid_data_source_tokens = collections.defaultdict(int)\n        # Track number of steps in current period for averaging\n        self.period_num_steps = 0\n    \n    def update(self, num_tokens, num_samples, num_valid_tokens):\n        \"\"\"Update both period and total metrics.\"\"\"\n        # Update period metrics (for current logging period)\n        self.period_num_tokens += num_tokens\n        self.period_num_samples += num_samples\n        self.period_num_valid_tokens += num_valid_tokens\n        \n        # Update total metrics (for entire training)\n        self.total_num_tokens += num_tokens\n        self.total_num_samples += num_samples\n        self.total_num_valid_tokens += num_valid_tokens\n\n\nclass TensorBoardLogger:\n    \"\"\"Manages TensorBoard logging in a separate thread.\"\"\"\n    \n    def __init__(self, tb_writer: Optional[SummaryWriter]):\n        self.tb_writer = tb_writer\n        self.metrics_queue = queue.Queue(maxsize=8)\n        self.thread = None\n        \n        if tb_writer is not None and dist.get_rank() == 0:\n            self.thread = threading.Thread(\n                target=self._write_async,\n                args=(tb_writer, self.metrics_queue),\n                daemon=True\n            )\n            self.thread.start()\n    \n    def _write_async(self, tb_writer, metrics_queue):\n        \"\"\"Async TensorBoard writer thread.\"\"\"\n        while True:\n            global_step, log_dict, ticker_stats, ds_loss, ds_tokens, ds_samples = metrics_queue.get()\n            total_num_samples = log_dict[\"perf/total_num_samples\"]\n            total_num_valid_tokens = log_dict[\"perf/valid_total_num_tokens\"]\n            \n            # Log main metrics\n            for name, data in log_dict.items():\n                if data is not None and tb_writer:\n                    tb_writer.add_scalar(\n                        name, data, global_step=global_step, new_style=True\n                    )\n                    \n                    # Log training metrics by valid tokens\n                    if name.startswith(\"training/\"):\n                        tb_writer.add_scalar(\n                            f\"x_token_{name}\",\n                            data,\n                            global_step=total_num_valid_tokens,\n                            new_style=True\n                        )\n            \n            # Log ticker stats\n            for name, data in ticker_stats.items():\n                tb_writer.add_scalar(\n                    f\"ticker/{name}\", data, global_step=global_step, new_style=True\n                )\n            \n            # Log data source metrics\n            if ds_loss and tb_writer:\n                for key, loss_sum in ds_loss.items():\n                    tb_writer.add_scalar(\n                        f\"data_source_loss/{key}\",\n                        loss_sum / (ds_tokens.get(key, 0) + 1e-6),\n                        global_step=global_step,\n                        new_style=True\n                    )\n            \n            if ds_samples and tb_writer:\n                for key, samples in ds_samples.items():\n                    tb_writer.add_scalar(\n                        f\"data_source_sample_ratio/{key}\",\n                        1.0 * samples / total_num_samples,\n                        global_step=global_step,\n                        
new_style=True\n                    )\n                \n                total_tokens = sum(ds_tokens.values())\n                if total_tokens > 0:\n                    for key, num_tokens in ds_tokens.items():\n                        tb_writer.add_scalar(\n                            f\"data_source_token_ratio/{key}\",\n                            1.0 * num_tokens / total_tokens,\n                            global_step=global_step,\n                            new_style=True\n                        )\n    \n    def log(self, global_step, log_dict, ticker_stats, ds_loss, ds_tokens, ds_samples):\n        \"\"\"Queue metrics for async logging.\"\"\"\n        if self.tb_writer is not None:\n            self.metrics_queue.put((\n                global_step, log_dict, ticker_stats, ds_loss, ds_tokens, ds_samples\n            ))\n\n\ndef get_argument_parser() -> argparse.ArgumentParser:\n    \"\"\"Create and configure argument parser.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Qwen3 Training Script\")\n    \n    # Checkpoint arguments\n    parser.add_argument(\"--model_dir\", type=str, default=None,\n                       help=\"Directory of the pretrained model\")\n    parser.add_argument(\"--resume_from\", type=str, default=None,\n                       help=\"Checkpoint directory to resume from\")\n    parser.add_argument(\"--resume_from_tag\", type=str, default=None,\n                       help=\"Checkpoint tag to resume from\")\n    parser.add_argument(\"--resume_training_state\", action=\"store_true\",\n                       help=\"Whether to resume training state including optimizer, scheduler, and dataloader\")\n    parser.add_argument(\"--use_fp32_weight\", action=\"store_true\",\n                       help=\"Use fp32 for model weight updating\")\n    parser.add_argument(\"--use_fp32_reduce\", action=\"store_true\",\n                       help=\"Use fp32 for gradient reduction\")\n    parser.add_argument(\"--reshard_after_forward\", action=\"store_true\",\n                       help=\"Enable reshard_after_forward to enable Zero3\")\n    parser.add_argument(\"--save_checkpoint_per_step\", type=int, default=1000,\n                       help=\"Number of steps to save a checkpoint\")\n    parser.add_argument(\"--output_dir\", type=str, default=None,\n                       help=\"Directory to write the trained model\")\n    parser.add_argument(\"--model_class\", type=str, default=\"Qwen3ForCausalLM\",\n                       help=\"Model class name\")\n    \n    # Dataset arguments\n    parser.add_argument(\"--dataset_config\", type=str, default=None,\n                       help=\"Path to dataset configuration JSON file\")\n    parser.add_argument(\"--max_length\", type=int, default=None,\n                       help=\"Max tokens per sentence\")\n    parser.add_argument(\"--minibatch_size\", type=int, default=4096,\n                       help=\"Minibatch size\")\n    parser.add_argument(\"--start_optimize_embedding_index\", type=int, default=0,\n                       help=\"Start optimize embedding index for finetuning\")\n    \n    # Learning rate arguments\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine_with_min_lr\",\n                       help=\"Learning rate scheduler type\")\n    parser.add_argument(\"--num_warmup_steps\", type=int, default=0,\n                       help=\"Number of warmup steps\")\n    parser.add_argument(\"--num_training_steps\", type=int, default=1000,\n                       help=\"Number of training 
steps\")\n    parser.add_argument(\"--min_lr\", type=float, default=1e-6,\n                       help=\"Minimum learning rate after cosine schedule\")\n    \n    # Optimizer arguments\n    parser.add_argument(\"--learning_rate\", type=float, default=2e-4,\n                       help=\"Peak learning rate\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1,\n                       help=\"Weight decay for AdamW\")\n    parser.add_argument(\"--beta1\", type=float, default=0.9,\n                       help=\"Beta1 for AdamW\")\n    parser.add_argument(\"--beta2\", type=float, default=0.95,\n                       help=\"Beta2 for AdamW\")\n    \n    # Training arguments\n    parser.add_argument(\"--use_tie_weights\", action=\"store_true\",\n                       help=\"Tie embedding and lm_head weights\")\n    # parser.add_argument(\"--clip_range\", type=float, default=None,\n    #                    help=\"Gradient clipping range\")\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0,\n                       help=\"Max gradient norm (global L2); set to 0 to disable\")\n    parser.add_argument(\"--freeze_llm\", action=\"store_true\",\n                       help=\"Freeze all LLM parameters\")\n    parser.add_argument(\"--enable_gradient_checkpointing\", action=\"store_true\",\n                       help=\"Enable gradient checkpointing\")\n    parser.add_argument(\"--allow_random_init_params\", type=str, default='',\n                       help=\"Allow random initialization for specified parameters\")\n    parser.add_argument(\"--logging_per_step\", type=int, default=100,\n                       help=\"Number of steps to log training info\")\n    parser.add_argument(\"--seed\", type=int, default=123,\n                       help=\"Random seed\")\n    parser.add_argument(\"--monitor_datasource_loss\", action=\"store_true\",\n                       help=\"Monitor loss of each datasource\")\n    parser.add_argument(\"--monitor_datasource_cnt\", action=\"store_true\",\n                       help=\"Monitor count of each datasource\")\n    parser.add_argument(\"--use_chunked_loss_computer\", action=\"store_true\",\n                       help=\"Use chunked loss computer\")\n    \n    # Profiling arguments\n    parser.add_argument(\"--enable_profiler\", action=\"store_true\",\n                       help=\"Enable PyTorch profiler for performance analysis\")\n    \n    return parser\n\n\nclass StateDictConverter:\n    \"\"\"Converter for state dict transformations (identity by default).\"\"\"\n    \n    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n        \"\"\"Convert state dict (e.g., for loading).\"\"\"\n        return state_dict\n    \n    def revert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n        \"\"\"Revert state dict (e.g., for saving).\"\"\"\n        return state_dict\n\n\ndef _init_profiler(output_dir: str, enable: bool = False) -> Optional[torch.profiler.profile]:\n    \"\"\"Initialize PyTorch profiler.\n    \n    Args:\n        output_dir: Directory to save profiler traces\n        enable: Whether to enable the profiler. 
If False, returns None.\n    \n    Returns:\n        PyTorch profiler instance if enabled, None otherwise.\n    \"\"\"\n    if not enable:\n        return None\n    \n    if not os.path.exists(output_dir):\n        if dist.get_rank() == 0:\n            os.makedirs(output_dir, exist_ok=True)\n    \n    def trace_handler(prof):\n        prof.export_chrome_trace(\n            os.path.join(output_dir, f\"{prof.step_num}_w{dist.get_rank()}.json\")\n        )\n    \n    # Profiler schedule: wait 50 steps, warmup 1 step, profile 10 steps, repeat once\n    # This avoids profiling initialization overhead and captures representative performance\n    return torch.profiler.profile(\n        activities=[\n            torch.profiler.ProfilerActivity.CPU,\n            torch.profiler.ProfilerActivity.CUDA,\n        ],\n        schedule=torch.profiler.schedule(wait=50, warmup=1, active=10, repeat=1),\n        on_trace_ready=trace_handler,\n    )\n\n\ndef save_model_checkpoint(\n    save_dir: str,\n    tag: str,\n    global_step: int,\n    optimizer: torch.optim.Optimizer,\n    lr_scheduler,\n    dataloader: Optional[object],\n    app_state: AppState,\n    dist_checkpointer: DistributedCheckpointer,\n) -> None:\n    \"\"\"Save FSDP+TP model checkpoint.\n    \n    Args:\n        save_dir: Save directory\n        tag: Checkpoint tag\n        global_step: Global training step\n        optimizer: Optimizer instance\n        lr_scheduler: Learning rate scheduler\n        dataloader: Optional dataloader for state saving\n        app_state: Application state\n        dist_checkpointer: Distributed checkpointer\n    \"\"\"\n    if dist.get_rank() == 0:\n        os.makedirs(save_dir, exist_ok=True)\n    \n    ckpt_path = os.path.join(save_dir, tag)\n    if dist.get_rank() == 0:\n        os.makedirs(ckpt_path, exist_ok=True)\n        with open(os.path.join(save_dir, \"latest\"), \"w\") as f:\n            f.write(tag)\n    \n    try:\n        # Save model checkpoint\n        dist_checkpointer.save_checkpoint(\n            state_dict={\"app\": app_state},\n            output_dir=ckpt_path,\n            tag=str(global_step)\n        )\n        \n        # Save dataloader state\n        if dataloader is not None:\n            try:\n                dataloader_state = {\"dataloader_state_dict\": dataloader.state_dict()}\n                dataloader_path = os.path.join(ckpt_path, \"dataloader_ckpt\")\n                if dist.get_rank() == 0:\n                    os.makedirs(dataloader_path, exist_ok=True)\n                dist.barrier()\n                \n                torch.save(\n                    dataloader_state,\n                    os.path.join(dataloader_path, f\"rank{dist.get_rank()}.pt\")\n                )\n                print_rank_0(f\"Saved dataloader state to {dataloader_path}\")\n            except Exception as e:\n                logger.error(f\"Failed to save dataloader state: {e}\", exc_info=True)\n        \n        # Save optimizer and scheduler state\n        optimizer_path = os.path.join(ckpt_path, \"optimizer_ckpt\")\n        optimizer_state = {\n            \"optimizer_state_dict\": optimizer.state_dict(),\n            \"scheduler_state_dict\": lr_scheduler.state_dict(),\n        }\n        if dist.get_rank() == 0:\n            os.makedirs(optimizer_path, exist_ok=True)\n        dist.barrier()\n        torch.save(\n            optimizer_state,\n            os.path.join(optimizer_path, f\"rank{dist.get_rank()}.pt\")\n        )\n        print_rank_0(f\"Saved optimizer state to {optimizer_path}\")\n  
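      # NOTE: optimizer/scheduler state is written as one shard per rank\n        # (rank{N}.pt), so resuming it requires the same world size.\n  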
    \n    except Exception as e:\n        logger.error(f\"Failed to save checkpoint: {e}\", exc_info=True)\n        raise\n    finally:\n        dist.barrier()\n\n\ndef initialize_distributed() -> Tuple[int, int, int]:\n    \"\"\"Initialize distributed training environment.\n    \n    Returns:\n        Tuple of (rank, world_size, local_rank)\n    \"\"\"\n    rank = int(os.environ.get(\"OMPI_COMM_WORLD_RANK\", 0))\n    world_size = int(os.environ.get(\"OMPI_COMM_WORLD_SIZE\", 1))  # default to a single process when env vars are absent\n    local_rank = int(os.environ.get(\"OMPI_COMM_WORLD_LOCAL_RANK\", 0))\n    \n    torch.cuda.set_device(local_rank)\n    torch.distributed.init_process_group(\n        rank=rank,\n        world_size=world_size,\n        timeout=PROCESS_GROUP_TIMEOUT\n    )\n    \n    return rank, world_size, local_rank\n\n\ndef initialize_model(\n    args,\n    device_mesh: DeviceMesh,\n    state_dict: Optional[Dict[str, torch.Tensor]],\n    converter: StateDictConverter,\n) -> torch.nn.Module:\n    \"\"\"Initialize and shard model.\n    \n    Args:\n        args: Training arguments\n        device_mesh: Device mesh for distributed training\n        state_dict: Optional pretrained state dict\n        converter: State dict converter\n    \n    Returns:\n        Initialized and sharded model\n    \"\"\"\n    # Create model on meta device\n    with set_default_dtype(torch.bfloat16), torch.device(\"meta\"), init_empty_weights():\n        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n        config._attn_implementation = \"flash_attention_2\"\n        config.use_cache = False\n        config.chunked_loss_computer = args.use_chunked_loss_computer\n        model = eval(args.model_class)(config)\n    \n    # Verify all parameters are on meta device\n    for tensor in itertools.chain(model.parameters(), model.buffers()):\n        assert tensor.device == torch.device(\"meta\"), \"All tensors must be on meta device\"\n    \n    # Enable gradient checkpointing if requested\n    if args.enable_gradient_checkpointing:\n        print_rank_0(\"Enable gradient checkpointing\")\n        set_activation_checkpointing(\n            model, auto_wrap_policy=eval(args.model_class).wrap_modules\n        )\n    \n    # Convert to fp32 if needed\n    if args.use_fp32_weight:\n        model = model.float()\n    \n    # Shard model with FSDP\n    shard_model(\n        model=model,\n        cpu_offload=False,\n        reshard_after_forward=args.reshard_after_forward,\n        dp_mesh=device_mesh,\n        fp32_weight=args.use_fp32_weight,\n        model_class=args.model_class,\n        fp32_reduce=args.use_fp32_reduce\n    )\n    dist.barrier()\n    \n    # Load state dict\n    with Timer(\"Load state dict\"):\n        load_from_full_model_state_dict(\n            model=model,\n            full_sd=state_dict,\n            allow_random_init_params=args.allow_random_init_params,\n            use_tie_weights=args.use_tie_weights\n        )\n    \n    # Tie weights if requested\n    # Sharing weights between embedding and output projection can reduce parameters\n    # and improve training stability for some models\n    if args.use_tie_weights:\n        model.lm_head.weight = model.model.embed_tokens.weight\n        # Verify weight tying: check if there are any differences (should be ~0)\n        diff_weight = model.lm_head.weight - model.model.embed_tokens.weight\n        diff_weight_cnt = (diff_weight.full_tensor().abs() > 1e-6).float().sum()\n        print_rank_0(\n            f\"diff_weight_cnt: {diff_weight_cnt.item()}, \"\n   
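        # Both names now reference the same tied tensor, so the count should be 0.\n   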
         f\"diff_weight_ratio: {diff_weight_cnt.item() / model.lm_head.weight.numel():.4f}\"\n        )\n    \n    # Initialize RoPE\n    with torch.device(torch.cuda.current_device()):\n        for m in model.modules():\n            if hasattr(m, \"rope_init\"):\n                print_rank_0(\"Initialize RoPE\")\n                m.rope_init()\n            elif hasattr(m, \"inv_freq\"):\n                print_rank_0(f\"Initialize RoPE inv_freq for {m.__class__.__name__}\")\n                from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\n                rope_type = getattr(m, \"rope_type\", \"default\")\n                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]\n                inv_freq, attention_scaling = rope_init_fn(\n                    m.config, device=torch.cuda.current_device()\n                )\n                m.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n                m.attention_scaling = attention_scaling\n    \n    # Freeze parameters if requested\n    # When freeze_llm is enabled, only embedding and output head are trainable\n    # This is useful for embedding-only fine-tuning or when using start_optimize_embedding_index\n    if args.freeze_llm:\n        assert args.start_optimize_embedding_index > 0\n        for name, param in model.named_parameters():\n            if \"embed_tokens\" in name or \"lm_head\" in name:\n                param.requires_grad = True  # Only embeddings and output head are trainable\n            else:\n                param.requires_grad = False  # Freeze all transformer layers\n    \n    # Print trainable parameters\n    for name, param in model.named_parameters():\n        if param.requires_grad:\n            print_rank_0(f\"Trainable parameter: {name}\")\n    print_rank_0(\"=\" * 50)\n    \n    return model\n\n\ndef load_model_checkpoint(\n    args,\n    app_state: AppState,\n    dist_checkpointer: DistributedCheckpointer,\n    converter: StateDictConverter,\n) -> None:\n    \"\"\"Load model checkpoint from distributed checkpoint.\n    \n    Args:\n        args: Training arguments\n        app_state: Application state\n        dist_checkpointer: Distributed checkpointer\n        converter: State dict converter\n    \"\"\"\n    ckpt_path = os.path.join(args.resume_from, args.resume_from_tag)\n    if not os.path.exists(ckpt_path):\n        raise ValueError(f\"Checkpoint path {ckpt_path} does not exist\")\n    \n    state_dict = {\"app\": app_state.set_call_back(converter.convert)}\n    dist_checkpointer.load_checkpoint(\n        state_dict=state_dict,\n        checkpoint_dir=args.resume_from,\n        tag=args.resume_from_tag\n    )\n    print_rank_0(\"Successfully loaded model using distributed checkpoint\")\n\n\ndef load_optimizer_checkpoint(\n    args,\n    optimizer: torch.optim.Optimizer,\n    lr_scheduler,\n) -> None:\n    \"\"\"Load optimizer and scheduler state from checkpoint.\n    \n    Args:\n        args: Training arguments\n        optimizer: Optimizer instance\n        lr_scheduler: Learning rate scheduler\n    \"\"\"\n    optimizer_state_dict_path = os.path.join(\n        args.resume_from, \"optimizer_ckpt\", f\"rank{dist.get_rank()}.pt\"\n    )\n    if os.path.exists(optimizer_state_dict_path):\n        optimizer_state_dict = torch.load(optimizer_state_dict_path)\n        lr_scheduler.load_state_dict(optimizer_state_dict[\"scheduler_state_dict\"])\n        optimizer.load_state_dict(optimizer_state_dict[\"optimizer_state_dict\"])\n        print_rank_0(f\"Successfully loaded optimizer and 
scheduler state from {optimizer_state_dict_path}\")\n    else:\n        print_rank_0(f\"Warning: Optimizer checkpoint {optimizer_state_dict_path} does not exist\")\n\n\ndef load_dataloader_checkpoint(args) -> Optional[Dict]:\n    \"\"\"Load dataloader state from checkpoint.\n    \n    Args:\n        args: Training arguments\n    \n    Returns:\n        Dataloader state dict if found, None otherwise\n    \"\"\"\n    dataloader_resume_path = os.path.join(\n        args.resume_from, \"dataloader_ckpt\", f\"rank{dist.get_rank()}.pt\"\n    )\n    if os.path.exists(dataloader_resume_path):\n        try:\n            dataloader_state_dict = torch.load(dataloader_resume_path)[\"dataloader_state_dict\"]\n            print_rank_0(f\"Successfully loaded dataloader state from {dataloader_resume_path}\")\n            return dataloader_state_dict\n        except Exception as e:\n            print_rank_0(f\"Error loading dataloader checkpoint: {e}\")\n            return None\n    else:\n        print_rank_0(f\"Warning: Dataloader checkpoint {dataloader_resume_path} does not exist\")\n        print_rank_0(\"Will start training without resuming dataloader state\")\n        return None\n\n\ndef load_checkpoint(\n    args,\n    app_state: AppState,\n    dist_checkpointer: DistributedCheckpointer,\n    converter: StateDictConverter,\n    optimizer: torch.optim.Optimizer,\n    lr_scheduler,\n) -> Tuple[Optional[Dict], int]:\n    \"\"\"Load checkpoint if resuming training.\n    \n    This function orchestrates loading of model, optimizer, scheduler, and dataloader\n    checkpoints. It delegates to specialized functions for each component.\n    \n    Args:\n        args: Training arguments\n        app_state: Application state\n        dist_checkpointer: Distributed checkpointer\n        converter: State dict converter\n        optimizer: Optimizer instance\n        lr_scheduler: Learning rate scheduler\n    \n    Returns:\n        Tuple of (dataloader_state_dict, global_step)\n    \"\"\"\n    dataloader_state_dict = None\n    global_step = 0\n    \n    if args.resume_from_tag:\n        ckpt_path = os.path.join(args.resume_from, args.resume_from_tag)\n        if args.resume_training_state:\n            global_step = int(args.resume_from_tag.split(\"step\")[-1])\n            print_rank_0(\n                f\"Resume from checkpoint: {ckpt_path}, global_step={global_step}\"\n            )\n        else:\n            print_rank_0(\n                f\"Resume model weights only from checkpoint: {ckpt_path}, \"\n                \"global_step stays at 0\"\n            )\n        \n        # Load model checkpoint\n        load_model_checkpoint(args, app_state, dist_checkpointer, converter)\n        \n        # Load optimizer, scheduler, and dataloader state if requested\n        # Note: resume_training_state controls whether to restore the full training state\n        # including optimizer momentum, scheduler step, and dataloader position.\n        # This allows seamless continuation of training from a checkpoint.\n        if args.resume_training_state:\n            load_optimizer_checkpoint(args, optimizer, lr_scheduler)\n            dataloader_state_dict = load_dataloader_checkpoint(args)\n    \n    return dataloader_state_dict, global_step\n\n\ndef compute_forward_backward(\n    model: torch.nn.Module,\n    batch: Dict,\n    compute_loss_fn,\n    loss_fn: CrossEntropyLoss,\n    args,\n    embedding_masker: Optional[EmbeddingGradientMasker],\n    optimizer: torch.optim.Optimizer,\n) -> Tuple[torch.Tensor, 
torch.Tensor]:\n    \"\"\"Compute forward and backward pass.\n    \n    Args:\n        model: Model instance\n        batch: Input batch\n        compute_loss_fn: Loss computation function\n        loss_fn: Loss function instance\n        args: Training arguments\n        embedding_masker: Optional embedding gradient masker\n        optimizer: Optimizer instance\n    \n    Returns:\n        Tuple of (loss, per_token_loss)\n    \"\"\"\n    input_ids = batch[\"input_ids\"]\n    loss_mask = batch[\"loss_mask\"]\n    attention_mask = batch.get(\"attention_mask\", None)\n    cu_seqlens = batch.get(\"cu_seqlens\", None)\n    position_ids = batch.get(\"position_ids\", None)\n    \n    # Prepare labels\n    # Zero out padding tokens (input_ids <= 0) to avoid computing loss on them\n    input_ids = input_ids * (input_ids > 0).to(torch.int64, non_blocking=True)\n    # Forward pass\n    with Timer(\"Fwd\"):\n        output = model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            labels=None,\n            cu_seqlens=cu_seqlens,\n            position_ids=position_ids,\n        )\n        \n        logits = output.logits\n        \n        # Shift labels for next token prediction\n        # For causal LM, we predict token[i] given tokens[0:i], so labels need to be shifted\n        # by one position: label[i] should correspond to input[i+1]\n        pad = torch.full(\n            (input_ids.shape[0], 1),\n            loss_fn.ignore_index,\n            dtype=input_ids.dtype\n        ).to(device=input_ids.device, non_blocking=True)\n        labels = torch.cat([input_ids[:, 1:], pad], dim=-1)\n        # Update labels: use input_ids where loss_mask==1, ignore_index where loss_mask==0\n        # This allows selective loss computation on specific tokens (e.g., excluding special tokens)\n        labels = labels * loss_mask + loss_fn.ignore_index * (1 - loss_mask)\n        \n        loss, per_token_loss = compute_loss_fn(logits, labels=labels)\n        per_token_loss = per_token_loss.to(loss.device)\n    \n    # Backward pass\n    with Timer(\"bwd\"):\n        loss.backward()\n        \n        # Apply gradient mask for embedding layers if needed\n        # When start_optimize_embedding_index > 0, only embeddings with index >= threshold are trainable\n        # This allows progressive unfreezing of embeddings during training\n        if args.start_optimize_embedding_index > 0 and embedding_masker is not None:\n            embedding_masker.apply_gradient_mask(optimizer)\n        \n        # clip_grad_by_value(model, args.clip_range)\n        if args.max_grad_norm and args.max_grad_norm > 0:\n            clip_grad_norm(model, args.max_grad_norm)\n    \n    return loss, per_token_loss\n\n\ndef compute_metrics(\n    batch: Dict,\n    loss: torch.Tensor,\n    per_token_loss: torch.Tensor,\n    loss_mask: torch.Tensor,\n    loss_fn: CrossEntropyLoss,\n    args,\n    metrics: TrainingMetrics,\n) -> Tuple[float, float, float, int, int, int]:\n    \"\"\"Compute and accumulate training metrics.\n    \n    Args:\n        batch: Input batch\n        loss: Loss tensor\n        per_token_loss: Per-token loss tensor\n        loss_mask: Loss mask tensor\n        loss_fn: Loss function instance\n        args: Training arguments\n        metrics: Training metrics tracker\n    \n    Returns:\n        Tuple of (avg_loss, avg_itemic_token_loss, avg_text_token_loss,\n                 num_tokens, num_samples, num_valid_tokens)\n    \"\"\"\n    input_ids = batch[\"input_ids\"]\n    cu_seqlens 
= batch.get(\"cu_seqlens\", None)\n    itemic_id_mask = batch.get(\"itemic_id_mask\", None)\n    data_source = batch.get(\"data_source\", None)\n    sample_idx = batch[\"sample_idx\"]\n    \n    # Compute token metrics\n    token_count = input_ids.numel()\n    num_samples = len(cu_seqlens) - 1 if cu_seqlens is not None else 1\n    \n    # Calculate number of valid tokens (tokens with loss_mask == 1)\n    # Works for both 1D (flattened) and 2D (batch, seq_len) loss_mask\n    num_valid_tokens = (loss_mask == 1).sum().item()\n    \n    # Aggregate metrics across all ranks\n    token_metrics = torch.tensor(\n        [token_count, num_samples, num_valid_tokens]\n    ).cuda(non_blocking=True)\n    dist.all_reduce(\n        token_metrics, op=dist.ReduceOp.SUM, group=None\n    )\n    num_tokens, num_samples, num_valid_tokens = (\n        token_metrics.detach().cpu().numpy()\n    )\n    \n    # Update metrics\n    metrics.update(num_tokens, num_samples, num_valid_tokens)\n    metrics.period_num_steps += 1\n    \n    # Compute average loss for this step\n    avg_loss = loss.detach()\n    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)\n    avg_loss = avg_loss.item() / dist.get_world_size()\n    metrics.period_sum_loss += avg_loss\n    \n    # Compute itemic and text token losses\n    if itemic_id_mask is not None:\n        itemic_id_mask = itemic_id_mask.view(per_token_loss.shape)\n        avg_itemic_token_loss = (itemic_id_mask * per_token_loss).sum() / (itemic_id_mask.sum() + 1e-6)\n        avg_text_token_loss = ((1 - itemic_id_mask) * per_token_loss).sum() / ((1 - itemic_id_mask).sum() + 1e-6)\n        dist.all_reduce(avg_itemic_token_loss, op=dist.ReduceOp.SUM)\n        dist.all_reduce(avg_text_token_loss, op=dist.ReduceOp.SUM)\n        avg_itemic_token_loss = avg_itemic_token_loss.item() / dist.get_world_size()\n        avg_text_token_loss = avg_text_token_loss.item() / dist.get_world_size()\n    else:\n        avg_itemic_token_loss = 0.0\n        avg_text_token_loss = avg_loss\n    \n    metrics.period_sum_itemic_token_loss += avg_itemic_token_loss\n    metrics.period_sum_text_token_loss += avg_text_token_loss\n    \n    # Monitor data source metrics\n    if args.monitor_datasource_loss and data_source is not None:\n        local_sample_idx = sample_idx.squeeze()\n        unique_sample_idx = local_sample_idx.unique()\n        # Get local loss mask for valid token counting\n        local_loss_mask = loss_mask.squeeze()\n        for s_idx in unique_sample_idx:\n            if s_idx < 0:\n                continue\n            mask = local_sample_idx == s_idx\n            sum_loss = per_token_loss[mask].sum()\n            key = data_source[int(s_idx.item())]\n            metrics.period_data_source_loss[key] += sum_loss.item()\n            metrics.period_data_source_tokens[key] += mask.sum().item()\n            # Count valid tokens using loss mask\n            metrics.period_valid_data_source_tokens[key] += (\n                mask[local_loss_mask != 0].sum().item()\n            )\n    \n    if args.monitor_datasource_cnt and data_source is not None:\n        for data_source_name in data_source:\n            metrics.local_period_data_source_samples[data_source_name] += 1\n    \n    return avg_loss, avg_itemic_token_loss, avg_text_token_loss, int(num_tokens), int(num_samples), int(num_valid_tokens)\n\n\ndef log_training_step(\n    global_step: int,\n    metrics: TrainingMetrics,\n    args,\n    lr_scheduler,\n    grad_norm: float,\n    period_start_time: float,\n    training_start_time: float,\n    
mfu_stats: MFUStats,\n    step_time_tracker: TimeTracker,\n    iteration_time_tracker: TimeTracker,\n    epoch_idx: int,\n    tb_logger: TensorBoardLogger,\n    chunked_loss_computer: Optional[ChunkedLossComputer],\n) -> float:\n    \"\"\"Log training step metrics.\n    \n    Args:\n        global_step: Global training step\n        metrics: Training metrics tracker\n        args: Training arguments\n        lr_scheduler: Learning rate scheduler\n        grad_norm: Gradient norm\n        period_start_time: Start time of the current logging period\n        training_start_time: Start time of the entire training run\n        mfu_stats: MFU statistics tracker\n        step_time_tracker: Time tracker for training steps (tracks forward/backward/optimizer)\n        iteration_time_tracker: Time tracker for data iteration (tracks data loading)\n        epoch_idx: Current epoch index\n        tb_logger: TensorBoard logger\n        chunked_loss_computer: Optional chunked loss computer\n    \n    Returns:\n        Updated period_start_time for next logging period\n    \"\"\"\n    end_time = time.time()\n    model_lrs = lr_scheduler.get_last_lr()\n    learning_rate = model_lrs[0]\n    \n    # Compute performance metrics for the logging period\n    period_duration = end_time - period_start_time\n    period_num_steps = max(metrics.period_num_steps, 1)  # Avoid division by zero\n    \n    # Current period metrics (_current): reflect performance in the current logging period\n    # These metrics show recent performance and can fluctuate with short-term variations\n    sec_per_step = period_duration / period_num_steps\n    tokens_per_sec_per_gpu_current = (\n        metrics.period_num_tokens / period_duration / dist.get_world_size()\n    )\n    samples_per_sec_per_gpu_current = (\n        metrics.period_num_samples / period_duration / dist.get_world_size()\n    )\n    samples_per_step_per_gpu_current = (\n        metrics.period_num_samples / period_num_steps / dist.get_world_size()\n    )\n    valid_tokens_per_sec_per_gpu_current = (\n        metrics.period_num_valid_tokens / period_duration / dist.get_world_size()\n    )\n    \n    # Average metrics (_avg): reflect average performance over entire training\n    # These metrics smooth out short-term fluctuations and include all overhead\n    # (checkpoint saving, logging, etc.), providing a more stable view of overall performance\n    samples_per_step_per_gpu_avg = (\n        metrics.total_num_samples / dist.get_world_size() / max(global_step, 1)\n    )\n    samples_per_sec_per_gpu_avg = (\n        metrics.total_num_samples / dist.get_world_size() / (end_time - training_start_time)\n    )\n    tokens_per_step_per_gpu_avg = (\n        metrics.total_num_tokens / dist.get_world_size() / max(global_step, 1)\n    )\n    tokens_per_sec_per_gpu_avg = (\n        metrics.total_num_tokens / dist.get_world_size() / (end_time - training_start_time)\n    )\n    \n    # Compute average losses over the logging period\n    avg_loss = metrics.period_sum_loss / period_num_steps\n    avg_itemic_token_loss = metrics.period_sum_itemic_token_loss / period_num_steps\n    avg_text_token_loss = metrics.period_sum_text_token_loss / period_num_steps\n    \n    # Reduce data source metrics across all ranks\n    period_data_source_loss = dist_reduce_dict(metrics.period_data_source_loss)\n    period_data_source_tokens = dist_reduce_dict(metrics.period_data_source_tokens)\n    period_valid_data_source_tokens = dist_reduce_dict(metrics.period_valid_data_source_tokens)\n    
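# dist_reduce_dict (defined elsewhere in this repo) is assumed to all-reduce (sum)\n    # each value of the dict across ranks, roughly: build a tensor from the values,\n    # dist.all_reduce(t, op=dist.ReduceOp.SUM), then zip the result back into a dict.\n    # The per-data-source statistics logged below are therefore global aggregates,\n    # not per-rank numbers.\n    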
total_data_source_samples = dist_reduce_dict(\n        metrics.local_period_data_source_samples, group=None\n    )\n    # Update total data source tokens\n    for ds_key, ds_num_tokens in period_data_source_tokens.items():\n        metrics.total_data_source_tokens[ds_key] += ds_num_tokens\n    \n    # Build log dictionary\n    log_dict = {\n        \"training/loss\": avg_loss,\n        \"training/itemic_token_loss\": avg_itemic_token_loss,\n        \"training/text_token_loss\": avg_text_token_loss,\n        \"training/grad_norm\": grad_norm,\n        \"training/learning_rate\": learning_rate,\n        \"perf/sec_per_step\": sec_per_step,\n        \"perf/tokens_per_sec_per_gpu_current\": tokens_per_sec_per_gpu_current,\n        \"perf/samples_per_sec_per_gpu_current\": samples_per_sec_per_gpu_current,\n        \"perf/total_num_tokens\": metrics.total_num_tokens,\n        \"perf/total_num_samples\": metrics.total_num_samples,\n        \"perf/num_sample_per_gpu\": metrics.total_num_samples / dist.get_world_size(),\n        \"perf/samples_per_step_per_gpu_current\": samples_per_step_per_gpu_current,\n        # Note: num_sample_per_sec_per_gpu is the same as samples_per_sec_per_gpu_current\n        # Keeping for backward compatibility, but samples_per_sec_per_gpu_current should be used\n        \"perf/num_sample_per_sec_per_gpu\": samples_per_sec_per_gpu_current,\n        \"perf/valid_total_num_tokens\": metrics.total_num_valid_tokens,\n        \"perf/valid_tokens_per_sec_per_gpu_current\": valid_tokens_per_sec_per_gpu_current,\n        \"perf/valid_token_ratio\": metrics.total_num_valid_tokens / metrics.total_num_tokens,\n        **mfu_stats.mfu(period_duration, global_step),\n        \"perf/samples_per_step_per_gpu_avg\": samples_per_step_per_gpu_avg,\n        \"perf/samples_per_sec_per_gpu_avg\": samples_per_sec_per_gpu_avg,\n        \"perf/tokens_per_step_per_gpu_avg\": tokens_per_step_per_gpu_avg,\n        \"perf/tokens_per_sec_per_gpu_avg\": tokens_per_sec_per_gpu_avg,\n        \"perf/epoch_idx\": epoch_idx,\n    }\n    \n    # Get ticker statistics\n    ticker_stats = {}\n    for t in [step_time_tracker, iteration_time_tracker]:\n        ticker_stats.update(t.stat())\n    \n    # Log to TensorBoard\n    tb_logger.log(\n        global_step,\n        log_dict,\n        ticker_stats,\n        period_data_source_loss if args.monitor_datasource_loss else {},\n        period_data_source_tokens if args.monitor_datasource_cnt else {},\n        total_data_source_samples if args.monitor_datasource_cnt else {},\n    )\n    \n    # Print to console\n    print_rank_0(\n        f\"Step: {global_step}, Loss: {avg_loss:.4f}, \"\n        f\"Learning Rate: {learning_rate:.2e}, \"\n        f\"Grad Norm: {grad_norm:.4f}, \"\n        f\"Sec per Step: {sec_per_step:.4f}\",\n        format_dict_or_list(log_dict),\n        \"\\n\",\n        format_dict_or_list({\n            \"mfu_stats\": mfu_stats.mfu_per_step_per_gpu,\n            \"step_time_tracker\": step_time_tracker.stat()\n        }),\n        \"\\n\",\n        chunked_loss_computer.ticker.stat() if chunked_loss_computer else \"\",\n    )\n    \n    return end_time\n\n\ndef train():\n    \"\"\"Main training function.\"\"\"\n    parser = get_argument_parser()\n    args = parser.parse_args()\n    \n    # Validate arguments\n    assert args.learning_rate > 0.0, \"Learning rate must be positive\"\n    assert args.save_checkpoint_per_step > 0, \"save_checkpoint_per_step must be positive\"\n    \n    # Initialize distributed training\n    rank, world_size, 
local_rank = initialize_distributed()\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(dist.get_world_size(),))\n    \n    set_random_seed(args.seed)\n    \n    # Load dataset configuration\n    logger.info(f\"Loading dataset config from: {args.dataset_config}\")\n    with open(args.dataset_config, encoding=\"utf-8\") as f:\n        dataset_config = json.loads(f.read())\n    dataset = dataset_config.pop(\"name\")\n    dataset_config[\"model_class\"] = args.model_class\n    if args.max_length:\n        dataset_config[\"max_length\"] = args.max_length\n    \n    # Load pretrained checkpoint\n    converter = StateDictConverter()\n    state_dict = None\n    if dist.get_rank() == 0:\n        with set_default_dtype(torch.bfloat16):\n            state_dict = load_hf_checkpoint(args.model_dir)\n            state_dict = converter.convert(state_dict)\n    dist.barrier()\n    \n    # Save training arguments\n    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n    if dist.get_rank() == 0:\n        args_dict = vars(args)\n        args_str = json.dumps(args_dict, indent=4, ensure_ascii=False)\n        print_rank_0(f\"Training Arguments:\\n{args_str}\")\n        os.makedirs(args.output_dir, exist_ok=True)\n        with open(\n            os.path.join(args.output_dir, f\"args-{timestamp}.json\"),\n            'w', encoding=\"utf-8\"\n        ) as f:\n            f.write(args_str + \"\\n\")\n    \n    # Initialize TensorBoard\n    tb_writer = None\n    if dist.get_rank() == 0:\n        tb_writer = SummaryWriter(log_dir=os.path.join(args.output_dir, \"log\"))\n    \n    # Initialize model\n    model = initialize_model(args, device_mesh, state_dict, converter)\n    if state_dict is not None:\n        del state_dict\n    \n    # Initialize optimizer\n    optimizer_grouped_parameters = get_optimizer_grouped_parameters(\n        model, learning_rate=args.learning_rate, weight_decay=args.weight_decay\n    )\n    optimizer = torch.optim.AdamW(\n        optimizer_grouped_parameters,\n        lr=args.learning_rate,\n        betas=(args.beta1, args.beta2),\n        eps=1.0e-8\n    )\n    \n    # Initialize embedding gradient masker\n    # This allows selective training of embeddings based on token index\n    # Useful for fine-tuning where only certain token embeddings should be updated\n    embedding_masker = EmbeddingGradientMasker(\n        model, model.config, args.start_optimize_embedding_index\n    )\n    if args.start_optimize_embedding_index > 0:\n        # Save frozen embedding parameters to restore after optimizer step\n        # This prevents optimizer from updating frozen embeddings\n        embedding_masker.save_frozen_params()\n    \n    # Initialize learning rate scheduler\n    lr_scheduler = get_scheduler(\n        name=args.lr_scheduler_type,\n        optimizer=optimizer,\n        num_warmup_steps=args.num_warmup_steps,\n        num_training_steps=args.num_training_steps,\n        min_lr=args.min_lr\n    )\n    \n    # Initialize checkpointing\n    app_state = AppState(model=model)\n    dist_checkpointer = DistributedCheckpointer()\n    \n    # Load checkpoint if resuming\n    dataloader_state_dict, global_step = load_checkpoint(\n        args, app_state, dist_checkpointer, converter, optimizer, lr_scheduler\n    )\n    dist.barrier()\n    \n    # Load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)\n    \n    # Save dataset configuration\n    if dist.get_rank() == 0:\n        with open(\n            
os.path.join(args.output_dir, f\"dataset-{timestamp}.json\"),\n            'w', encoding=\"utf-8\"\n        ) as f:\n            f.write(json.dumps(dataset_config, ensure_ascii=False, indent=2) + \"\\n\")\n    \n    # Build dataloader\n    with Timer(\"Build dataloader\"):\n        try:\n            dataloader = get_dataloader(name=dataset, **dataset_config)\n        except Exception as e:\n            logger.error(f\"Failed to build dataloader: {e}\", exc_info=True)\n            raise\n        if args.resume_training_state and dataloader_state_dict is not None:\n            dataloader.load_state_dict(dataloader_state_dict)\n    \n    # Initialize profiler\n    torch_profiler = _init_profiler(\n        output_dir=os.path.join(args.output_dir, \"torch_profile\"),\n        enable=args.enable_profiler\n    )\n    \n    # Initialize loss function\n    loss_fn = CrossEntropyLoss(\n        ignore_index=-100, return_token_loss=True, shift_labels=False\n    )\n    compute_loss_fn = loss_fn\n    chunked_loss_computer = None\n    if args.use_chunked_loss_computer:\n        chunked_loss_computer = ChunkedLossComputer(\n            lm_head=model.lm_head,\n            loss_fn=loss_fn,\n            minibatch_size=args.minibatch_size,\n            shift_labels=False\n        )\n        compute_loss_fn = chunked_loss_computer.forward_and_backward\n    \n    # Initialize training state\n    training_start_time = time.time()\n    period_start_time = training_start_time\n    remaining_debug_samples = 1  # Number of sample batches to print for debugging\n    # Only reset global_step if not resuming from checkpoint\n    # If resume_from_tag exists, global_step is already set in load_checkpoint\n    if args.resume_from_tag is None:\n        global_step = 0\n    \n    metrics = TrainingMetrics()\n    mfu_stats = MFUStats(args)\n    # step_time_tracker: tracks time for training steps (forward/backward/optimizer)\n    step_time_tracker = TimeTracker(n=args.logging_per_step)\n    # iteration_time_tracker: tracks time for data iteration (data loading)\n    iteration_time_tracker = TimeTracker(n=args.logging_per_step)\n    tb_logger = TensorBoardLogger(tb_writer)\n    \n    # Create data iterator\n    data_iter = iter(dataloader)\n    get_next_batch = lambda: next(data_iter)\n    \n    # Training loop\n    while True:\n        with contextlib.ExitStack() as ctx:\n            if torch_profiler:\n                ctx.enter_context(torch_profiler)\n            \n            step_time_tracker.tick(\"enter_context(torch_profiler)\")\n            try:\n                batch = get_next_batch()\n            except StopIteration:\n                break\n            step_time_tracker.tick(\"next_batch\")\n            \n            # Show sample data for debugging\n            # Only print from first 8 ranks to avoid log spam (rank 0-7)\n            # Sleep based on rank to stagger output and make logs easier to read\n            if remaining_debug_samples > 0 and dist.get_rank() < 8:\n                with Timer(\"Show data\"):\n                    input_text = tokenizer.decode(batch['input_ids'][0])\n                    # Stagger output by rank to avoid interleaved prints (0.3s per rank)\n                    time.sleep(float(dist.get_rank()) * 0.3)\n                    print(f\"Input Text:\\n\\n{input_text}\\n\" + \"=\" * 100 + \"\\n\\n\")\n                    print_input_info(batch, f\"rank{dist.get_rank()}\")\n                    remaining_debug_samples -= 1\n            \n            # Move batch to CUDA\n            
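# A minimal sketch of what the to_cuda helper (defined elsewhere in this repo)\n            # is assumed to do: move every tensor in the batch dict onto the current\n            # device, e.g.\n            #   for k, v in batch.items():\n            #       if torch.is_tensor(v):\n            #           batch[k] = v.cuda(non_blocking=True)\n            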
to_cuda(batch)\n            step_time_tracker.tick(\"to_cuda(batch)\")\n            \n            # Update MFU stats\n            token_count = batch[\"input_ids\"].numel()\n            num_samples = len(batch.get(\"cu_seqlens\", [0, 1])) - 1\n            mfu_stats.set(num_tokens=token_count, num_samples=num_samples)\n            \n            # Forward and backward pass\n            loss, per_token_loss = compute_forward_backward(\n                model, batch, compute_loss_fn, loss_fn, args,\n                embedding_masker, optimizer\n            )\n            \n            # Compute metrics\n            epoch_idx = batch.get(\"epoch_idx\", torch.tensor([0])).cpu().item()\n            avg_loss, avg_itemic_token_loss, avg_text_token_loss, num_tokens, num_samples, num_valid_tokens = compute_metrics(\n                batch, loss, per_token_loss, batch[\"loss_mask\"], loss_fn, args, metrics\n            )\n            \n            step_time_tracker.tick(\"compute_metrics\")\n            \n            # Optimizer step\n            grad_norm = compute_fsdp_zero2_grad_norm(model)\n            optimizer.step()\n            \n            # Restore frozen parameters after optimizer step\n            # This ensures frozen embeddings are not modified by the optimizer\n            # even if they were included in the gradient computation\n            if args.start_optimize_embedding_index > 0:\n                embedding_masker.restore_frozen_params()\n            \n            lr_scheduler.step()\n            optimizer.zero_grad()\n            global_step += 1\n            step_time_tracker.tick(\"optimizer.step\")\n            \n            # Logging\n            if global_step % args.logging_per_step == 0:\n                period_start_time = log_training_step(\n                    global_step, metrics, args, lr_scheduler, grad_norm,\n                    period_start_time, training_start_time, mfu_stats, \n                    step_time_tracker, iteration_time_tracker,\n                    epoch_idx, tb_logger, chunked_loss_computer\n                )\n                metrics.reset_period_accumulators()\n            \n            # Save checkpoint\n            # Save at regular intervals (save_checkpoint_per_step) and at early steps (20, 200)\n            # Early checkpoints help verify training setup and catch issues early\n            should_save = (\n                (global_step % args.save_checkpoint_per_step == 0 and global_step > 0) or\n                global_step == 20 or  # Early checkpoint for initial verification\n                global_step == 200    # Early checkpoint for training stability check\n            )\n            \n            if should_save:\n                torch.cuda.empty_cache()\n                gc.collect()\n                \n                with Timer(\"save checkpoint\"):\n                    save_model_checkpoint(\n                        save_dir=args.output_dir,\n                        tag=f\"step{global_step}\",\n                        global_step=global_step,\n                        optimizer=optimizer,\n                        lr_scheduler=lr_scheduler,\n                        dataloader=dataloader,\n                        app_state=app_state.set_call_back(converter.revert),\n                        dist_checkpointer=dist_checkpointer\n                    )\n                step_time_tracker.tick(f\"save_ckpt*{args.save_checkpoint_per_step}\")\n            \n            iteration_time_tracker.tick(\"iteration_time_tracker\")\n            if 
torch_profiler:\n                torch_profiler.step()\n    \n    # Save final checkpoint\n    save_model_checkpoint(\n        save_dir=args.output_dir,\n        tag=f\"step{global_step}\",\n        global_step=global_step,\n        optimizer=optimizer,\n        lr_scheduler=lr_scheduler,\n        dataloader=dataloader,\n        app_state=app_state.set_call_back(converter.revert),\n        dist_checkpointer=dist_checkpointer\n    )\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "pretrain/scripts/convert_checkpoint_to_hf.sh",
    "content": "#!/bin/bash\n\nset -e\n\nBASE_MODEL_DIR=$1\nMODEL_HOME=$2\nSTEP=$3\nCKPT_DIR=${MODEL_HOME}/step${STEP}/global_step${STEP}\n\nOUTPUT_DIR=$CKPT_DIR/converted\n\npython3 tools/model_converter/convert_checkpoint_to_hf.py --checkpoint_dir $CKPT_DIR \\\n    --output_dir $OUTPUT_DIR \\\n    --source_hf_model_path $BASE_MODEL_DIR\n"
  },
  {
    "path": "pretrain/scripts/expand_qwen3_vocab.sh",
    "content": "#!/bin/bash\n\nset -e\n\nHF_MODEL_DIR=/code/onerec_pretrain/hf_models/Qwen3-0.6B\nOUTPUT_MODEL_DIR=/code/onerec_pretrain/hf_models/Qwen3-0.6B_itemic\nITEMIC_LAYER_N=3\nVOCAB_SIZE_PER_LAYER=8192\n\npython3 tools/model_converter/expand_qwen3_vocab.py \\\n    --hf_model_dir $HF_MODEL_DIR \\\n    --output_model_dir $OUTPUT_MODEL_DIR \\\n    --itemic_layer_n $ITEMIC_LAYER_N \\\n    --vocab_size_per_layer $VOCAB_SIZE_PER_LAYER\n\n\n"
  },
  {
    "path": "pretrain/scripts/killall.sh",
    "content": "#!/bin/bash\n\nmpirun --allow-run-as-root --hostfile /etc/mpi/hostfile --pernode bash -c \"pkill -9 python3\"\nmpirun --allow-run-as-root --hostfile /etc/mpi/hostfile --pernode bash -c \"pkill -9 pt_main_thread\" \nmpirun --allow-run-as-root --hostfile /etc/mpi/hostfile --pernode bash -c \"pkill -9 pt_data_worker\" "
  },
  {
    "path": "pretrain/scripts/numa_runner.sh",
    "content": "#!/bin/bash\n\n# Get local NUMA node count\nnum_numa=$(numactl -H | grep \"node [0-9] cpus\" | wc -l)\nif [ \"$num_numa\" -lt 1 ]; then\n  num_numa=1\nfi\n\n# Default to NUMA 0\nnuma_id=0\n\necho \"Bind to NUMA node $numa_id\"\n\n# Bind memory and CPU to NUMA node 0 when running command\nnumactl --membind=$numa_id --cpunodebind=$numa_id \"$@\""
  },
  {
    "path": "pretrain/scripts/test_cases_example.json",
    "content": "{\n  \"test_cases\": [\n    {\n      \"type\": \"text\",\n      \"input\": \"你好，请介绍一下你自己。\",\n      \"ground_truth\": \"\"\n    },\n    {\n      \"type\": \"chat\",\n      \"input\": [\n        {\"role\": \"user\", \"content\": \"写一首关于春天的短诗：\"}\n      ],\n      \"ground_truth\": \"\"\n    },\n    {\n      \"type\": \"chat\",\n      \"input\": [\n        {\"role\": \"system\", \"content\": \"你是一名视频描述生成器，请根据下面的视频token生成视频描述\"},\n        {\"role\": \"user\", \"content\": \"这是一个视频：<|sid_begin|><s_a_2919><s_b_5923><s_c_5443><|sid_end|>，帮我总结一下这个视频讲述了什么内容\"}\n      ],\n      \"ground_truth\": \"\"\n    }\n  ]\n}\n\n"
  },
  {
    "path": "pretrain/scripts/test_hf_model.sh",
    "content": "#!/bin/bash\n\n# HuggingFace Model Testing Script\n# Tests a HuggingFace model with text generation or chat mode\n# \n# Configuration:\n#   - MODEL_PATH: Path to HuggingFace model directory\n#   - TEST_FILE: Path to JSON test cases file (optional, use --use_default if not set)\n#   - Generation parameters: MAX_NEW_TOKENS, TEMPERATURE, TOP_P, REPETITION_PENALTY\n#   - Chat options: ENABLE_THINKING, SHOW_TEMPLATE, SHOW_INPUT_IDS\n#   - Output: COMPARE_GROUND_TRUTH\n\nset -e\n\n# Model path - receive from command line argument\nMODEL_PATH=\"$1\"\n\n# Check if MODEL_PATH is empty\nif [ -z \"${MODEL_PATH}\" ]; then\n    echo \"ERROR: MODEL_PATH cannot be empty\"\n    echo \"Usage: $0 <MODEL_PATH>\"\n    echo \"Example: $0 /path/to/model\"\n    exit 1\nfi\n\n# Check if model path exists\nif [ ! -e \"${MODEL_PATH}\" ]; then\n    echo \"WARNING: model path does not exist: ${MODEL_PATH}\"\n    exit 1\nfi\n\n# Test case: use default or specify a test file\n# Option 1: Use built-in default test cases\nUSE_DEFAULT=true\n\n# Option 2: Use custom test file (comment out USE_DEFAULT and uncomment below)\n# USE_DEFAULT=false\n# TEST_FILE=tools/model_test/test_cases_example.json\n\n# Generation parameters\nMAX_NEW_TOKENS=1024\nTEMPERATURE=0.7\nTOP_P=0.9\nREPETITION_PENALTY=1.2\n\n# Chat mode options\nENABLE_THINKING=false\nSHOW_TEMPLATE=false\nSHOW_INPUT_IDS=false\n\n# Output options\nCOMPARE_GROUND_TRUTH=false\n\n# Device and data type\nDEVICE=auto\nDTYPE=bf16\n\n# Build command\nCMD=\"python3 tools/model_test/test_hf_model.py\"\nCMD=\"$CMD --model_path $MODEL_PATH\"\nCMD=\"$CMD --device $DEVICE\"\nCMD=\"$CMD --dtype $DTYPE\"\nCMD=\"$CMD --max_new_tokens $MAX_NEW_TOKENS\"\nCMD=\"$CMD --temperature $TEMPERATURE\"\nCMD=\"$CMD --top_p $TOP_P\"\nCMD=\"$CMD --repetition_penalty $REPETITION_PENALTY\"\n\n# Test case source\nif [ \"$USE_DEFAULT\" = true ]; then\n    CMD=\"$CMD --use_default\"\nelif [ -n \"$TEST_FILE\" ]; then\n    CMD=\"$CMD --test_file $TEST_FILE\"\nfi\n\n# Chat mode options\n[ \"$ENABLE_THINKING\" = true ] && CMD=\"$CMD --enable_thinking\"\n[ \"$SHOW_TEMPLATE\" = true ] && CMD=\"$CMD --show_template\"\n[ \"$SHOW_INPUT_IDS\" = true ] && CMD=\"$CMD --show_input_ids\"\n\n# Output options\n[ \"$COMPARE_GROUND_TRUTH\" = true ] && CMD=\"$CMD --compare_ground_truth\"\n\n# Execute\neval $CMD\n\n\n"
  },
  {
    "path": "pretrain/set_env.sh",
    "content": "#!/bin/bash\n\n# Check if current shell is bash\nif [ -z \"$BASH_VERSION\" ]; then\n    echo \"This script must be run with bash. Please use 'bash script.bash' to run it.\" >&2\n    exit 1\nfi\n\n# Get current script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nENV_FILE=\"${SCRIPT_DIR}/.env\"\n\n# Check if .env file exists\nif [ ! -f \"${ENV_FILE}\" ]; then\n    echo \"Error: ${ENV_FILE} not found\" >&2\n    exit 1\nfi\n\n# Load environment variables\nset -a  # Automatically export all variables\nsource \"${ENV_FILE}\"\nset +a  # Disable automatic export\n\n# Print loaded environment variables\necho \"Loaded environment variables from ${ENV_FILE}:\"\ncat \"${ENV_FILE}\"\n\n# Install system dependencies\nPIP_CMD='pip'\nPROXY=\"http://oversea-squid1.jp.txyun:11080\"\nHOSTFILE=\"/etc/mpi/hostfile\"\n\n# Install numactl on all nodes\nmpirun --allow-run-as-root \\\n    --hostfile \"${HOSTFILE}\" \\\n    -x http_proxy=\"${PROXY}\" \\\n    -x https_proxy=\"${PROXY}\" \\\n    --pernode \\\n    bash -c \"apt-get install -y numactl\"\n\n# Install Python dependencies on all nodes\nmpirun --allow-run-as-root \\\n    --hostfile \"${HOSTFILE}\" \\\n    --pernode \\\n    bash -c \"${PIP_CMD} install transformers==4.53 && \\\n             ${PIP_CMD} install easydict && \\\n             ${PIP_CMD} install torchao==0.10 && \\\n             ${PIP_CMD} install sortedcontainers\"\n"
  },
  {
    "path": "pretrain/tests/test_qwen3_dataset_file_distribution.py",
    "content": "\"\"\"\nTest file distribution logic for Qwen3ChatCompletionParquetDataset in multi-process, multi-worker scenarios\n\nValidation points:\n1. Each file is processed by only one worker (no duplication)\n2. All files are processed (no omission)\n3. Works correctly under different rank and worker combinations\n\"\"\"\n\nimport unittest\nfrom unittest.mock import patch, MagicMock\nimport os\nimport sys\n\n\nclass TestFileDistribution(unittest.TestCase):\n    \"\"\"Test file distribution logic\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test environment\"\"\"\n        # Create mock file list\n        self.data_files = [\n            (f\"file_{i}.parquet\", 0) for i in range(100)  # 100 files, epoch=0\n        ]\n        self.num_workers = 4\n    \n    def _get_file_distribution(self, rank, world_size, worker, num_workers):\n        \"\"\"\n        Simulate file distribution logic, return file indices for this worker\n\n        Args:\n            rank: Process rank\n            world_size: Total number of processes\n            worker: Worker ID\n            num_workers: Number of workers per process\n\n        Returns:\n            list: File index list\n        \"\"\"\n        total_num_workers = num_workers * world_size\n        local_worker_idx = rank * num_workers + worker\n        fn_list = [\n            idx for idx, fn in enumerate(self.data_files) \n            if idx % total_num_workers == local_worker_idx\n        ]\n        return fn_list\n    \n    def test_file_distribution_no_overlap(self):\n        \"\"\"Test file distribution without overlap: each file is processed by only one worker\"\"\"\n        world_size = 2\n        num_workers = 4\n        \n        # Collect files assigned to all workers\n        all_assigned_files = set()\n\n        for rank in range(world_size):\n            for worker in range(num_workers):\n                assigned_files = self._get_file_distribution(rank, world_size, worker, num_workers)\n                file_indices = set(assigned_files)\n\n                # Check for overlap\n                overlap = all_assigned_files & file_indices\n                self.assertEqual(\n                    len(overlap), 0,\n                    f\"Rank {rank}, Worker {worker} assigned files overlap with existing assignments: {overlap}\"\n                )\n\n                all_assigned_files.update(file_indices)\n\n        # Verify all files are assigned\n        total_files = len(self.data_files)\n        self.assertEqual(\n            len(all_assigned_files), total_files,\n            f\"File assignment incomplete: expected {total_files} files, actually assigned {len(all_assigned_files)}\"\n        )\n    \n    def test_file_distribution_completeness(self):\n        \"\"\"Test file distribution completeness: all files are processed\"\"\"\n        world_size = 2\n        num_workers = 4\n\n        all_assigned_files = set()\n\n        for rank in range(world_size):\n            for worker in range(num_workers):\n                assigned_files = self._get_file_distribution(rank, world_size, worker, num_workers)\n                all_assigned_files.update(assigned_files)\n\n        # Verify all files are assigned\n        expected_files = set(range(len(self.data_files)))\n        self.assertEqual(\n            all_assigned_files, expected_files,\n            f\"File assignment incomplete: missing files {expected_files - all_assigned_files}\"\n        )\n    \n    def test_file_distribution_different_configs(self):\n        \"\"\"Test file 
distribution under different configurations\"\"\"\n        test_configs = [\n            (1, 1),   # Single process, single worker\n            (1, 4),   # Single process, 4 workers\n            (2, 2),   # 2 processes, 2 workers each\n            (4, 2),   # 4 processes, 2 workers each\n            (2, 8),   # 2 processes, 8 workers each\n        ]\n        \n        for world_size, num_workers in test_configs:\n            with self.subTest(world_size=world_size, num_workers=num_workers):\n                all_assigned_files = set()\n                \n                for rank in range(world_size):\n                    for worker in range(num_workers):\n                        assigned_files = self._get_file_distribution(\n                            rank, world_size, worker, num_workers\n                        )\n                        file_indices = set(assigned_files)\n\n                        # Check for overlap\n                        overlap = all_assigned_files & file_indices\n                        self.assertEqual(\n                            len(overlap), 0,\n                            f\"Config (world_size={world_size}, num_workers={num_workers}), \"\n                            f\"Rank {rank}, Worker {worker} has overlap: {overlap}\"\n                        )\n\n                        all_assigned_files.update(file_indices)\n\n                # Verify completeness\n                expected_files = set(range(len(self.data_files)))\n                self.assertEqual(\n                    all_assigned_files, expected_files,\n                    f\"Config (world_size={world_size}, num_workers={num_workers}) \"\n                    f\"file assignment incomplete: missing {expected_files - all_assigned_files}\"\n                )\n    \n    def test_file_distribution_balance(self):\n        \"\"\"Test file distribution load balancing (each worker should be assigned roughly equal number of files)\"\"\"\n        world_size = 2\n        num_workers = 4\n        total_workers = world_size * num_workers\n\n        file_counts = []\n        for rank in range(world_size):\n            for worker in range(num_workers):\n                assigned_files = self._get_file_distribution(rank, world_size, worker, num_workers)\n                file_counts.append(len(assigned_files))\n\n        # Calculate expected file count (should be roughly equal)\n        expected_per_worker = len(self.data_files) / total_workers\n        min_files = int(expected_per_worker)\n        max_files = int(expected_per_worker) + 1\n\n        # Verify each worker's file count is within reasonable range\n        for count in file_counts:\n            self.assertGreaterEqual(count, min_files, \"Too few files assigned\")\n            self.assertLessEqual(count, max_files, \"Too many files assigned\")\n\n        # Verify total count is correct\n        self.assertEqual(\n            sum(file_counts), len(self.data_files),\n            f\"Total file count mismatch: expected {len(self.data_files)}, actual {sum(file_counts)}\"\n        )\n    \n    def test_file_distribution_with_epochs(self):\n        \"\"\"Test file distribution with multiple epochs\"\"\"\n        # Create multi-epoch file list\n        data_files_multi_epoch = []\n        for epoch in range(3):\n            for i in range(20):\n                data_files_multi_epoch.append((f\"file_{i}.parquet\", epoch))\n\n        self.data_files = data_files_multi_epoch\n\n        world_size = 2\n        num_workers = 4\n\n        # Collect assignments by (file_idx, 
epoch)\n        all_assigned = set()\n\n        for rank in range(world_size):\n            for worker in range(num_workers):\n                assigned_indices = self._get_file_distribution(\n                    rank, world_size, worker, num_workers\n                )\n                # Convert indices to (filename, epoch) tuples\n                for idx in assigned_indices:\n                    file_name, epoch = self.data_files[idx]\n                    all_assigned.add((file_name, epoch))\n\n        # Verify all (file, epoch) combinations are assigned\n        expected = set((fn, ep) for fn, ep in self.data_files)\n        self.assertEqual(\n            all_assigned, expected,\n            f\"Multi-epoch file assignment incomplete: missing {expected - all_assigned}\"\n        )\n\n\nclass TestFileDistributionLogic(unittest.TestCase):\n    \"\"\"Test core algorithm of file distribution logic\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test environment\"\"\"\n        self.data_files = [\n            (f\"file_{i}.parquet\", 0) for i in range(50)\n        ]\n\n    def test_distribution_algorithm(self):\n        \"\"\"Test correctness of file distribution algorithm\"\"\"\n        # Simulate distribution logic in Qwen3NaiveParquetDataset.__iter__local_shuffle\n        rank = 0\n        world_size = 2\n        worker = 0\n        num_workers = 2\n\n        total_num_workers = num_workers * world_size\n        local_worker_idx = rank * num_workers + worker\n        fn_list = [\n            fn for idx, fn in enumerate(self.data_files)\n            if idx % total_num_workers == local_worker_idx\n        ]\n\n        # Verify file list is not empty\n        self.assertGreater(len(fn_list), 0, \"File list should not be empty\")\n\n        # Verify file indices are correct\n        expected_indices = [\n            idx for idx in range(len(self.data_files))\n            if idx % total_num_workers == local_worker_idx\n        ]\n        actual_indices = [\n            idx for idx, fn in enumerate(self.data_files) if fn in fn_list\n        ]\n        self.assertEqual(\n            set(actual_indices), set(expected_indices),\n            \"File index assignment is incorrect\"\n        )\n\n\ndef run_distribution_test_manual():\n    \"\"\"\n    Manually run file distribution test, print detailed assignment information\n    For debugging and verification\n    \"\"\"\n    print(\"=\" * 80)\n    print(\"File Distribution Test - Manual Verification\")\n    print(\"=\" * 80)\n\n    # Test configurations\n    data_files = [(f\"file_{i}.parquet\", 0) for i in range(100)]\n    test_configs = [\n        (1, 1, \"Single process, single worker\"),\n        (1, 4, \"Single process, 4 workers\"),\n        (2, 2, \"2 processes, 2 workers each\"),\n        (4, 2, \"4 processes, 2 workers each\"),\n        (2, 8, \"2 processes, 8 workers each\"),\n    ]\n    \n    for world_size, num_workers, desc in test_configs:\n        print(f\"\\nConfig: {desc} (world_size={world_size}, num_workers={num_workers})\")\n        print(\"-\" * 80)\n\n        total_num_workers = num_workers * world_size\n        all_assigned = {}\n\n        for rank in range(world_size):\n            for worker in range(num_workers):\n                local_worker_idx = rank * num_workers + worker\n                assigned_files = [\n                    idx for idx, fn in enumerate(data_files)\n                    if idx % total_num_workers == local_worker_idx\n                ]\n                all_assigned[(rank, worker)] = assigned_files\n\n   
             print(f\"  Rank {rank}, Worker {worker} (local_idx={local_worker_idx}): \"\n                      f\"{len(assigned_files)} files, index range: {min(assigned_files) if assigned_files else 'N/A'}-{max(assigned_files) if assigned_files else 'N/A'}\")\n\n        # Verify completeness\n        all_file_indices = set()\n        for assigned in all_assigned.values():\n            all_file_indices.update(assigned)\n\n        expected_indices = set(range(len(data_files)))\n        missing = expected_indices - all_file_indices\n        extra = all_file_indices - expected_indices\n\n        if missing:\n            print(f\"  X Missing file indices: {sorted(missing)}\")\n        if extra:\n            print(f\"  X Extra file indices: {sorted(extra)}\")\n        if not missing and not extra:\n            print(f\"  OK File assignment complete: all {len(data_files)} files correctly assigned\")\n\n        # Check for overlap\n        has_overlap = False\n        for (r1, w1), files1 in all_assigned.items():\n            for (r2, w2), files2 in all_assigned.items():\n                if (r1, w1) >= (r2, w2):  # Avoid duplicate checks\n                    continue\n                overlap = set(files1) & set(files2)\n                if overlap:\n                    print(f\"  X Overlap detected: Rank {r1}, Worker {w1} and Rank {r2}, Worker {w2} overlap files: {sorted(overlap)}\")\n                    has_overlap = True\n\n        if not has_overlap:\n            print(f\"  OK No overlap: all files processed by only one worker\")\n\n\nif __name__ == '__main__':\n    # Run unit tests\n    print(\"Running unit tests...\")\n    unittest.main(argv=[''], exit=False, verbosity=2)\n\n    # Run manual verification\n    print(\"\\n\" + \"=\" * 80)\n    run_distribution_test_manual()\n\n"
  },
  {
    "path": "pretrain/tools/model_converter/convert_checkpoint_to_hf.py",
    "content": "\"\"\"Checkpoint to HuggingFace Format Converter\n\nThis module provides utilities to convert PyTorch checkpoints (DCP or .pth files)\nto HuggingFace format (safetensors or bin files with sharding support).\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Dict, Optional, Union\n\nimport torch\nimport tqdm\nfrom safetensors.torch import save_file\nfrom torch.distributed.checkpoint import FileSystemReader\nfrom torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner\nfrom torch.distributed.checkpoint.metadata import STATE_DICT_TYPE\nfrom torch.distributed.checkpoint.state_dict_loader import _load_state_dict\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n# Constants\nSHARD_FNAME_TEMPLATE = \"model-{cpt_idx}-of-{num_shards}\"\nBYTES_PER_GB = 1024 * 1024 * 1024\nDEFAULT_MAX_GB_PER_SHARD = 5\nDEFAULT_DTYPE = \"bf16\"\n\n# Common HuggingFace config files to copy\nHF_CONFIG_FILES = [\n    \"config.json\",\n    \"tokenizer.json\",\n    \"tokenizer_config.json\",\n    \"tokenizer.model\",  # SentencePiece tokenizer model file\n    \"vocab.txt\",\n    \"vocab.json\",\n    \"merges.txt\",\n    \"special_tokens_map.json\",\n    \"added_tokens.json\",\n    \"generation_config.json\",\n    \"preprocessor_config.json\",  # For vision models\n]\n\n\ndef _get_torch_dtype(dtype_str: str) -> torch.dtype:\n    \"\"\"Convert dtype string to torch.dtype.\n    \n    Args:\n        dtype_str: Data type string (\"fp32\", \"fp16\", \"bf16\")\n        \n    Returns:\n        Corresponding torch.dtype\n        \n    Raises:\n        ValueError: If dtype_str is not supported\n    \"\"\"\n    dtype_map = {\n        \"fp32\": torch.float32,\n        \"fp16\": torch.float16,\n        \"bf16\": torch.bfloat16,\n    }\n    if dtype_str not in dtype_map:\n        raise ValueError(f\"Unsupported dtype: {dtype_str}. 
Supported: {list(dtype_map.keys())}\")\n    return dtype_map[dtype_str]\n\n\ndef _extract_state_dict_from_checkpoint(checkpoint: Dict, model_only: bool = True) -> Dict[str, torch.Tensor]:\n    \"\"\"Extract state_dict from checkpoint with various structures.\n    \n    Args:\n        checkpoint: Checkpoint dictionary\n        model_only: Whether to extract only model weights\n        \n    Returns:\n        State dictionary containing model weights\n    \"\"\"\n    if not isinstance(checkpoint, dict):\n        raise ValueError(f\"Unsupported checkpoint format: {type(checkpoint)}\")\n    \n    # Check for nested DCP-like structure\n    if model_only and \"app\" in checkpoint and \"model\" in checkpoint[\"app\"]:\n        logger.info(\"Found nested structure: checkpoint['app']['model']\")\n        return checkpoint[\"app\"][\"model\"]\n    elif \"model\" in checkpoint:\n        logger.info(\"Found structure: checkpoint['model']\")\n        return checkpoint[\"model\"]\n    elif \"state_dict\" in checkpoint:\n        logger.info(\"Found structure: checkpoint['state_dict']\")\n        return checkpoint[\"state_dict\"]\n    else:\n        # Assume entire dict is the state_dict\n        logger.info(\"Using entire checkpoint as state_dict\")\n        return checkpoint\n\n\ndef _convert_state_dict_to_shards(\n    state_dict: Dict[str, torch.Tensor],\n    output_dir: Union[str, os.PathLike],\n    use_safetensor: bool = True,\n    max_gb_per_shard: int = DEFAULT_MAX_GB_PER_SHARD,\n    dtype: str = DEFAULT_DTYPE\n) -> None:\n    \"\"\"Convert state_dict to sharded safetensors or bin files.\n    \n    Args:\n        state_dict: State dictionary containing model weights\n        output_dir: Output directory for sharded files\n        use_safetensor: Whether to use safetensors format (default: True)\n        max_gb_per_shard: Maximum size per shard in GB (default: 5)\n        dtype: Data type for conversion (\"fp32\", \"fp16\", \"bf16\", default: \"bf16\")\n        \n    Raises:\n        ValueError: If dtype is not supported\n    \"\"\"\n    torch_dtype = _get_torch_dtype(dtype)\n    logger.info(f\"Converting state_dict to {dtype} format\")\n    \n    # Convert data types\n    logger.info(\"Converting tensor data types...\")\n    for key in tqdm.tqdm(state_dict.keys(), desc=\"Converting dtypes\"):\n        state_dict[key] = state_dict[key].to(torch_dtype)\n    \n    # Split into shards\n    logger.info(f\"Splitting state_dict into shards (max {max_gb_per_shard} GB per shard)...\")\n    split_state_dicts: Dict[int, Dict[str, torch.Tensor]] = {}\n    shard_idx = 0\n    total_size = 0\n    current_size = 0\n    \n    max_bytes_per_shard = max_gb_per_shard * BYTES_PER_GB\n    \n    for key, weight in tqdm.tqdm(state_dict.items(), desc=\"Creating shards\"):\n        if shard_idx not in split_state_dicts:\n            split_state_dicts[shard_idx] = {}\n        \n        split_state_dicts[shard_idx][key] = weight\n        weight_size = weight.numel() * weight.element_size()\n        current_size += weight_size\n        total_size += weight_size\n        \n        if current_size >= max_bytes_per_shard:\n            shard_idx += 1\n            current_size = 0\n    \n    # Write shard files\n    num_shards = len(split_state_dicts)\n    weight_map: Dict[str, str] = {}\n    output_path_obj = Path(output_dir)\n    output_path_obj.mkdir(parents=True, exist_ok=True)\n    \n    logger.info(f\"Writing {num_shards} shard files...\")\n    for shard_idx, shard_state_dict in tqdm.tqdm(split_state_dicts.items(), 
desc=\"Writing shards\"):\n        shard_name = SHARD_FNAME_TEMPLATE.format(\n            cpt_idx=f\"{shard_idx}\".zfill(5),\n            num_shards=f\"{num_shards}\".zfill(5)\n        )\n        \n        if use_safetensor:\n            shard_path = output_path_obj / f\"{shard_name}.safetensors\"\n            save_file(shard_state_dict, shard_path, metadata={\"format\": \"pt\"})\n        else:\n            shard_path = output_path_obj / f\"{shard_name}.bin\"\n            torch.save(shard_state_dict, shard_path)\n        \n        # Update weight map\n        shard_filename = shard_path.name\n        for key in shard_state_dict.keys():\n            weight_map[key] = shard_filename\n        \n        shard_size_gb = os.path.getsize(shard_path) / BYTES_PER_GB\n        logger.info(f\"Shard {shard_idx + 1}/{num_shards}: {shard_size_gb:.2f} GiB saved to {shard_path}\")\n    \n    # Write index file\n    index_filename = \"model.safetensors.index.json\" if use_safetensor else \"model.bin.index.json\"\n    index_path = output_path_obj / index_filename\n    \n    index_data = {\n        \"metadata\": {\n            \"total_size\": total_size\n        },\n        \"weight_map\": weight_map,\n    }\n    \n    with open(index_path, \"w\", encoding=\"utf-8\") as f:\n        json.dump(index_data, f, indent=2)\n    \n    logger.info(f\"Index file saved to {index_path}\")\n    logger.info(f\"Total model size: {total_size / BYTES_PER_GB:.2f} GiB\")\n\n\ndef pth_to_hf_format(\n    pth_file_path: Union[str, os.PathLike],\n    output_dir: Union[str, os.PathLike],\n    model_only: bool = True,\n    use_safetensor: bool = True,\n    max_gb_per_shard: int = DEFAULT_MAX_GB_PER_SHARD,\n    dtype: str = DEFAULT_DTYPE\n) -> None:\n    \"\"\"Convert .pth file to HuggingFace format (safetensors or bin files).\n    \n    Args:\n        pth_file_path: Path to .pth checkpoint file\n        output_dir: Output directory for converted files\n        model_only: Whether to extract only model weights (default: True)\n        use_safetensor: Whether to use safetensors format (default: True)\n        max_gb_per_shard: Maximum size per shard in GB (default: 5)\n        dtype: Data type for conversion (default: \"bf16\")\n        \n    Raises:\n        FileNotFoundError: If pth_file_path does not exist\n        ValueError: If pth_file_path is not a .pth file or has unsupported format\n        \n    .. 
warning::\n        To avoid OOM, it's recommended to run this function on a single rank/process.\n    \"\"\"\n    pth_path = Path(pth_file_path)\n    \n    if not pth_path.exists():\n        raise FileNotFoundError(f\"PTH file not found: {pth_path}\")\n    \n    if pth_path.suffix != \".pth\":\n        raise ValueError(f\"Expected .pth file, got: {pth_path.suffix}\")\n    \n    logger.info(f\"Loading PTH file from {pth_path}...\")\n    checkpoint = torch.load(pth_path, map_location=\"cpu\")\n    \n    # Extract state_dict from checkpoint\n    state_dict = _extract_state_dict_from_checkpoint(checkpoint, model_only=model_only)\n    logger.info(f\"Loaded state_dict with {len(state_dict)} keys\")\n    \n    # Convert to HuggingFace format\n    _convert_state_dict_to_shards(\n        state_dict=state_dict,\n        output_dir=output_dir,\n        use_safetensor=use_safetensor,\n        max_gb_per_shard=max_gb_per_shard,\n        dtype=dtype\n    )\n\n\ndef dcp_to_hf_format(\n    dcp_checkpoint_dir: Union[str, os.PathLike],\n    output_dir: Union[str, os.PathLike],\n    model_only: bool = True,\n    use_safetensor: bool = True,\n    max_gb_per_shard: int = DEFAULT_MAX_GB_PER_SHARD,\n    dtype: str = DEFAULT_DTYPE\n) -> None:\n    \"\"\"Convert DCP (Distributed Checkpoint) to HuggingFace format.\n    \n    Args:\n        dcp_checkpoint_dir: Directory containing the DCP checkpoint\n        output_dir: Output directory for converted files\n        model_only: Whether to extract only model weights (default: True)\n        use_safetensor: Whether to use safetensors format (default: True)\n        max_gb_per_shard: Maximum size per shard in GB (default: 5)\n        dtype: Data type for conversion (default: \"bf16\")\n        \n    Raises:\n        FileNotFoundError: If dcp_checkpoint_dir does not exist\n        \n    .. 
warning::\n        To avoid OOM, it's recommended to run this function on a single rank/process.\n    \"\"\"\n    dcp_path = Path(dcp_checkpoint_dir)\n    \n    if not dcp_path.exists():\n        raise FileNotFoundError(f\"DCP checkpoint directory not found: {dcp_path}\")\n    \n    if not dcp_path.is_dir():\n        raise ValueError(f\"Expected directory, got: {dcp_path}\")\n    \n    logger.info(f\"Loading DCP checkpoint from {dcp_path}...\")\n    state_dict: STATE_DICT_TYPE = {}\n    \n    _load_state_dict(\n        state_dict,\n        storage_reader=FileSystemReader(str(dcp_path)),\n        planner=_EmptyStateDictLoadPlanner(),\n        no_dist=True,\n    )\n    \n    logger.info(\"DCP checkpoint loaded successfully\")\n    \n    if model_only:\n        if \"app\" not in state_dict or \"model\" not in state_dict[\"app\"]:\n            raise ValueError(\"Expected 'app.model' in DCP checkpoint when model_only=True\")\n        state_dict = state_dict[\"app\"][\"model\"]\n        logger.info(f\"Extracted model state_dict with {len(state_dict)} keys\")\n    \n    # Convert to HuggingFace format\n    _convert_state_dict_to_shards(\n        state_dict=state_dict,\n        output_dir=output_dir,\n        use_safetensor=use_safetensor,\n        max_gb_per_shard=max_gb_per_shard,\n        dtype=dtype\n    )\n\n\ndef copy_hf_config_files(\n    source_hf_model_path: Union[str, os.PathLike],\n    output_dir: Union[str, os.PathLike]\n) -> None:\n    \"\"\"Copy HuggingFace configuration files from source to output directory.\n    \n    Args:\n        source_hf_model_path: Path to source HuggingFace model directory\n        output_dir: Output directory where config files will be copied\n    \"\"\"\n    source_path = Path(source_hf_model_path)\n    output_path = Path(output_dir)\n    \n    if not source_path.exists():\n        logger.warning(f\"Source HuggingFace model path does not exist: {source_path}\")\n        return\n    \n    if not source_path.is_dir():\n        logger.warning(f\"Source path is not a directory: {source_path}\")\n        return\n    \n    output_path.mkdir(parents=True, exist_ok=True)\n    \n    copied_files = []\n    \n    # Copy known config files\n    for config_file in HF_CONFIG_FILES:\n        source_file = source_path / config_file\n        if source_file.exists():\n            dest_file = output_path / config_file\n            shutil.copy2(source_file, dest_file)\n            copied_files.append(config_file)\n            logger.debug(f\"Copied {config_file} to {output_path}\")\n    \n    # Copy additional JSON and TXT files (may be config files)\n    for pattern in [\"*.json\", \"*.txt\"]:\n        for source_file in source_path.glob(pattern):\n            # Skip already copied files and weight files\n            if (source_file.name in copied_files or \n                source_file.name.startswith(\"model-\") or\n                source_file.suffix in [\".bin\", \".safetensors\"]):\n                continue\n            \n            dest_file = output_path / source_file.name\n            if not dest_file.exists():  # Avoid overwriting already copied files\n                shutil.copy2(source_file, dest_file)\n                if source_file.name not in HF_CONFIG_FILES:\n                    logger.debug(f\"Copied additional file: {source_file.name}\")\n    \n    if copied_files:\n        logger.info(f\"Successfully copied {len(copied_files)} config files from {source_path} to {output_path}\")\n    else:\n        logger.warning(f\"No config files found in 
{source_path}\")\n\n\ndef get_argument_parser() -> argparse.ArgumentParser:\n    \"\"\"Create and configure argument parser.\n    \n    Returns:\n        Configured argument parser\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Convert PyTorch checkpoints (DCP or .pth) to HuggingFace format\"\n    )\n    \n    parser.add_argument(\n        \"--checkpoint_dir\",\n        type=str,\n        required=True,\n        help=\"Path to DCP checkpoint directory or .pth file\"\n    )\n    \n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        required=True,\n        help=\"Output directory for converted HuggingFace model\"\n    )\n    \n    parser.add_argument(\n        \"--source_hf_model_path\",\n        type=str,\n        default=None,\n        help=\"Path to original HuggingFace model to copy config files from (optional)\"\n    )\n    \n    parser.add_argument(\n        \"--use_safetensor\",\n        action=\"store_true\",\n        default=True,\n        help=\"Use safetensors format (default: True)\"\n    )\n    \n    parser.add_argument(\n        \"--no_safetensor\",\n        dest=\"use_safetensor\",\n        action=\"store_false\",\n        help=\"Use .bin format instead of safetensors\"\n    )\n    \n    parser.add_argument(\n        \"--max_gb_per_shard\",\n        type=int,\n        default=DEFAULT_MAX_GB_PER_SHARD,\n        help=f\"Maximum size per shard in GB (default: {DEFAULT_MAX_GB_PER_SHARD})\"\n    )\n    \n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=DEFAULT_DTYPE,\n        choices=[\"fp32\", \"fp16\", \"bf16\"],\n        help=f\"Data type for conversion (default: {DEFAULT_DTYPE})\"\n    )\n    \n    return parser\n\n\ndef main() -> None:\n    \"\"\"Main entry point for the script.\"\"\"\n    parser = get_argument_parser()\n    args = parser.parse_args()\n    \n    checkpoint_path = Path(args.checkpoint_dir)\n    \n    if not checkpoint_path.exists():\n        raise FileNotFoundError(f\"Checkpoint path does not exist: {checkpoint_path}\")\n    \n    # Auto-detect input type: .pth file or DCP checkpoint directory\n    if checkpoint_path.is_file() and checkpoint_path.suffix == \".pth\":\n        logger.info(f\"Detected PTH file: {checkpoint_path}\")\n        pth_to_hf_format(\n            pth_file_path=checkpoint_path,\n            output_dir=args.output_dir,\n            model_only=True,\n            use_safetensor=args.use_safetensor,\n            max_gb_per_shard=args.max_gb_per_shard,\n            dtype=args.dtype\n        )\n    elif checkpoint_path.is_dir():\n        logger.info(f\"Detected DCP checkpoint directory: {checkpoint_path}\")\n        dcp_to_hf_format(\n            dcp_checkpoint_dir=checkpoint_path,\n            output_dir=args.output_dir,\n            model_only=True,\n            use_safetensor=args.use_safetensor,\n            max_gb_per_shard=args.max_gb_per_shard,\n            dtype=args.dtype\n        )\n    else:\n        raise ValueError(\n            f\"Invalid checkpoint path: {checkpoint_path}. 
\"\n            \"Expected either a .pth file or a DCP checkpoint directory.\"\n        )\n    \n    # Copy config files if source model path is provided\n    if args.source_hf_model_path:\n        logger.info(f\"Copying config files from {args.source_hf_model_path} to {args.output_dir}\")\n        copy_hf_config_files(\n            source_hf_model_path=args.source_hf_model_path,\n            output_dir=args.output_dir\n        )\n    \n    logger.info(\"Conversion completed successfully!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "pretrain/tools/model_converter/expand_qwen3_vocab.py",
    "content": "\"\"\"Qwen3 Vocabulary Expansion Tool\n\nExpand the standard Qwen3 HuggingFace checkpoint vocabulary to support post-training.\nAdd new tokens and adjust model vocabulary size (aligned to multiples of 256).\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport os\nimport random\nimport sys\nfrom pathlib import Path\nfrom typing import List\n\nimport torch\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n\ndef _align_vocab_size(vocab_size: int, alignment: int = 256) -> int:\n    \"\"\"Align vocabulary size to the nearest multiple of alignment.\n    \n    Args:\n        vocab_size: Current vocabulary size\n        alignment: Alignment value (default: 256)\n        \n    Returns:\n        Aligned vocabulary size\n    \"\"\"\n    return ((vocab_size + alignment - 1) // alignment) * alignment\n\n\ndef _fix_chat_template(reco_model_dir: str, hf_model_dir: str) -> None:\n    \"\"\"Fix chat template in tokenizer config by copying from original model.\n    \n    Args:\n        reco_model_dir: Output model directory\n        hf_model_dir: Original HuggingFace model directory\n    \"\"\"\n    reco_tokenizer_config_path = os.path.join(reco_model_dir, \"tokenizer_config.json\")\n    hf_tokenizer_config_path = os.path.join(hf_model_dir, \"tokenizer_config.json\")\n    \n    if not os.path.exists(hf_tokenizer_config_path):\n        logger.warning(f\"Original tokenizer_config.json not found: {hf_tokenizer_config_path}\")\n        return\n    \n    if not os.path.exists(reco_tokenizer_config_path):\n        logger.warning(f\"Output tokenizer_config.json not found: {reco_tokenizer_config_path}\")\n        return\n    \n    # Load configs\n    with open(reco_tokenizer_config_path, \"r\", encoding=\"utf-8\") as f:\n        reco_config = json.load(f)\n    \n    with open(hf_tokenizer_config_path, \"r\", encoding=\"utf-8\") as f:\n        hf_config = json.load(f)\n    \n    # Copy chat template from original\n    if \"chat_template\" in hf_config:\n        reco_config[\"chat_template\"] = hf_config[\"chat_template\"]\n        \n        with open(reco_tokenizer_config_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(reco_config, f, indent=2, ensure_ascii=False)\n        \n        logger.info(\"Chat template copied from original model\")\n\n\ndef _test_expanded_vocab(model, tokenizer, new_tokens: List[str]) -> None:\n    \"\"\"Test the expanded vocabulary with sample tokens.\n    \n    Args:\n        model: Expanded model\n        tokenizer: Expanded tokenizer\n        new_tokens: List of newly added tokens\n    \"\"\"\n    if not new_tokens:\n        logger.info(\"No new tokens to test\")\n        return\n    \n    # Sample 3-5 tokens from new_tokens\n    num_samples = min(random.randint(3, 5), len(new_tokens))\n    sampled_tokens = random.sample(new_tokens, num_samples)\n    input_text = \" \".join(sampled_tokens) + \" Hello world\"\n    \n    try:\n        input_ids = tokenizer.encode(input_text, return_tensors='pt')\n        \n        # Test generation (use eval mode to avoid training-specific behavior)\n        model.eval()\n        with torch.no_grad():\n            output = model.generate(input_ids, max_new_tokens=10, do_sample=False)\n        \n        logger.info(\"Vocabulary expansion test:\")\n        logger.info(f\"  Input text: 
{input_text}\")\n        logger.info(f\"  Decoded input: {tokenizer.decode(input_ids[0], skip_special_tokens=True)}\")\n        logger.info(f\"  Input IDs shape: {input_ids.shape}\")\n        logger.info(f\"  Generated: {tokenizer.decode(output[0], skip_special_tokens=True)}\")\n        \n    except Exception as e:\n        logger.warning(f\"Vocabulary test failed: {e}\")\n\n\ndef expand_qwen3_vocab_for_pretraining(\n    hf_model_dir: str,\n    output_model_dir: str,\n    new_tokens: List[str]\n) -> None:\n    \"\"\"Expand Qwen3 vocabulary for pretraining by adding new tokens.\n    \n    This function:\n    1. Loads the original Qwen3 model and tokenizer\n    2. Adds new tokens to the tokenizer\n    3. Resizes model embeddings to aligned vocabulary size (multiple of 256)\n    4. Updates model configuration\n    5. Saves the expanded model, tokenizer, and config\n    6. Fixes chat template from original model\n    7. Tests the expanded vocabulary\n    \n    Args:\n        hf_model_dir: Path to original HuggingFace model directory\n        output_model_dir: Path to save expanded model\n        new_tokens: List of new tokens to add\n        \n    Raises:\n        FileNotFoundError: If model directory doesn't exist\n        ValueError: If new_tokens is empty\n    \"\"\"\n    if not new_tokens:\n        raise ValueError(\"new_tokens list cannot be empty\")\n    \n    if not os.path.exists(hf_model_dir):\n        raise FileNotFoundError(f\"Model directory does not exist: {hf_model_dir}\")\n    \n    # Create output directory\n    os.makedirs(output_model_dir, exist_ok=True)\n    logger.info(f\"Expanding vocabulary for pretraining\")\n    logger.info(f\"  Input model: {hf_model_dir}\")\n    logger.info(f\"  Output model: {output_model_dir}\")\n    logger.info(f\"  New tokens: {len(new_tokens)}\")\n    \n    # Step 1: Load original model components\n    logger.info(\"Loading original model components...\")\n    config = AutoConfig.from_pretrained(hf_model_dir, trust_remote_code=True)\n    model = AutoModelForCausalLM.from_pretrained(\n        hf_model_dir,\n        torch_dtype=torch.float32,  # Use float32 for compatibility\n        trust_remote_code=True\n    )\n    tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, trust_remote_code=True)\n    \n    original_vocab_size = len(tokenizer)\n    logger.info(f\"Original vocabulary size: {original_vocab_size}\")\n    \n    # Step 2: Add new tokens\n    logger.info(f\"Adding {len(new_tokens)} new tokens...\")\n    num_added = tokenizer.add_tokens(new_tokens)\n    logger.info(f\"Successfully added {num_added} tokens\")\n    \n    # Step 3: Calculate aligned vocabulary size\n    new_vocab_size = len(tokenizer)\n    target_vocab_size = _align_vocab_size(new_vocab_size, alignment=256)\n    logger.info(f\"New vocabulary size: {new_vocab_size}\")\n    logger.info(f\"Target vocabulary size (aligned to 256): {target_vocab_size}\")\n    \n    # Step 4: Resize model embeddings\n    logger.info(\"Resizing model token embeddings...\")\n    model.resize_token_embeddings(target_vocab_size)\n    \n    # Step 5: Update configuration\n    config.vocab_size = target_vocab_size\n    logger.info(f\"Updated config vocab_size to {target_vocab_size}\")\n    \n    # Step 6: Save expanded components\n    logger.info(\"Saving expanded model components...\")\n    tokenizer.save_pretrained(output_model_dir)\n    model.save_pretrained(output_model_dir)\n    config.save_pretrained(output_model_dir)\n    logger.info(\"Model components saved successfully\")\n    \n    # Step 7: 
Fix chat template\n    logger.info(\"Fixing chat template...\")\n    _fix_chat_template(output_model_dir, hf_model_dir)\n    \n    # Step 8: Test expanded vocabulary\n    logger.info(\"Testing expanded vocabulary...\")\n    _test_expanded_vocab(model, tokenizer, new_tokens)\n    \n    logger.info(f\"✓ Vocabulary expansion completed! Final vocab size: {target_vocab_size}\")\n\n\ndef generate_itemic_tokens(itemic_layer_n: int, vocab_size_per_layer: int) -> List[str]:\n    \"\"\"Generate itemic special tokens dynamically.\n    \n    IMPORTANT: Token order must strictly match gen_itemic_sp_tokens.py:\n    1. All <s_a_{i}> tokens (i from 0 to vocab_size_per_layer-1)\n    2. All <s_b_{i}> tokens (i from 0 to vocab_size_per_layer-1)\n    3. All <s_c_{i}> tokens (i from 0 to vocab_size_per_layer-1)\n    4. ... (for itemic_layer_n layers, in alphabetical order)\n    5. <|sid_begin|>\n    6. <|sid_end|>\n    \n    Args:\n        itemic_layer_n: Number of itemic layers (determines s_a, s_b, s_c, ...)\n        vocab_size_per_layer: Vocabulary size per layer (determines range of i)\n        \n    Returns:\n        List of generated tokens in strict order\n        \n    Raises:\n        ValueError: If itemic_layer_n or vocab_size_per_layer is invalid\n    \"\"\"\n    if itemic_layer_n <= 0:\n        raise ValueError(f\"itemic_layer_n must be positive, got {itemic_layer_n}\")\n    if vocab_size_per_layer <= 0:\n        raise ValueError(f\"vocab_size_per_layer must be positive, got {vocab_size_per_layer}\")\n    \n    # Generate layer names in alphabetical order: a, b, c, d, ...\n    # This ensures the same order as gen_itemic_sp_tokens.py\n    layer_names = [chr(ord('a') + i) for i in range(itemic_layer_n)]\n    \n    new_tokens = []\n    \n    # Generate tokens in strict order:\n    # For each layer (a, b, c, ...), generate all tokens with i from 0 to vocab_size_per_layer-1\n    # This matches the order: [*s_a_0..8191, *s_b_0..8191, *s_c_0..8191, ...]\n    for layer_name in layer_names:\n        for i in range(vocab_size_per_layer):\n            new_tokens.append(f\"<s_{layer_name}_{i}>\")\n    \n    # Add special tokens at the end (must be in this exact order)\n    new_tokens.append('<|sid_begin|>')\n    new_tokens.append('<|sid_end|>')\n    \n    total_tokens = itemic_layer_n * vocab_size_per_layer + 2\n    logger.info(f\"Generated {total_tokens} itemic tokens in strict order:\")\n    logger.info(f\"  Layers: {itemic_layer_n} ({', '.join([f's_{name}' for name in layer_names])})\")\n    logger.info(f\"  Vocab size per layer: {vocab_size_per_layer}\")\n    logger.info(f\"  Special tokens: <|sid_begin|>, <|sid_end|>\")\n    \n    return new_tokens\n\n\ndef load_tokens_from_file(tokens_file: str) -> List[str]:\n    \"\"\"Load tokens from a text file (one token per line).\n    \n    Args:\n        tokens_file: Path to text file containing tokens (one per line)\n        \n    Returns:\n        List of tokens (empty lines are skipped)\n        \n    Raises:\n        FileNotFoundError: If tokens file doesn't exist\n    \"\"\"\n    if not os.path.exists(tokens_file):\n        raise FileNotFoundError(f\"Tokens file does not exist: {tokens_file}\")\n    \n    new_tokens = []\n    line_count = 0\n    \n    with open(tokens_file, \"r\", encoding=\"utf-8\") as f:\n        for line in f:\n            line_count += 1\n            token = line.strip()\n            if token:  # Skip empty lines\n                new_tokens.append(token)\n    \n    logger.info(f\"Loaded {len(new_tokens)} tokens from {line_count} lines in 
{tokens_file}\")\n    return new_tokens\n\n\ndef main():\n    \"\"\"Main entry point for the script.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Expand Qwen3 vocabulary for pretraining by adding new tokens. '\n                    'Supports two modes: loading from file or generating itemic tokens dynamically.'\n    )\n    parser.add_argument(\n        \"--hf_model_dir\",\n        type=str,\n        required=True,\n        help=\"Path to original HuggingFace Qwen3 model directory\"\n    )\n    parser.add_argument(\n        \"--output_model_dir\",\n        type=str,\n        required=True,\n        help=\"Path to save expanded model directory\"\n    )\n    \n    # Itemic token generation parameters\n    parser.add_argument(\n        \"--itemic_layer_n\",\n        type=int,\n        required=True,\n        help=\"Number of itemic layers (e.g., 3 for s_a, s_b, s_c)\"\n    )\n    parser.add_argument(\n        \"--vocab_size_per_layer\",\n        type=int,\n        required=True,\n        help=\"Vocabulary size per layer (e.g., 8192 for tokens from 0 to 8191)\"\n    )\n    \n    args = parser.parse_args()\n    \n    try:\n        # Generate itemic tokens dynamically\n        logger.info(\"Generating itemic tokens dynamically...\")\n        new_tokens = generate_itemic_tokens(\n            itemic_layer_n=args.itemic_layer_n,\n            vocab_size_per_layer=args.vocab_size_per_layer\n        )\n        \n        if not new_tokens:\n            logger.error(\"No tokens to add\")\n            sys.exit(1)\n        \n        # Expand vocabulary\n        expand_qwen3_vocab_for_pretraining(\n            hf_model_dir=args.hf_model_dir,\n            output_model_dir=args.output_model_dir,\n            new_tokens=new_tokens\n        )\n        \n        logger.info(\"All operations completed successfully!\")\n        \n    except KeyboardInterrupt:\n        logger.info(\"\\nOperation cancelled by user\")\n        sys.exit(1)\n    except Exception as e:\n        logger.error(f\"Program execution failed: {e}\", exc_info=True)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "pretrain/tools/model_test/test_hf_model.py",
    "content": "#!/usr/bin/env python3\n\"\"\"HuggingFace Model Testing Tool\n\nA unified tool for testing HuggingFace models with both direct text generation\nand chat template modes. Supports thinking mode and ground truth comparison.\n\"\"\"\n\nimport argparse\nimport json\nimport logging\nimport sys\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[logging.StreamHandler()]\n)\nlogger = logging.getLogger(__name__)\n\n\ndef load_model(\n    model_path: str,\n    device: str = \"auto\",\n    torch_dtype: torch.dtype = torch.bfloat16\n) -> tuple:\n    \"\"\"Load HuggingFace model and tokenizer.\n    \n    Args:\n        model_path: Path to model directory\n        device: Device mapping (default: \"auto\")\n        torch_dtype: Data type for model (default: bfloat16)\n    \n    Returns:\n        Tuple of (model, tokenizer)\n    \"\"\"\n    logger.info(f\"Loading model from: {model_path}\")\n    \n    tokenizer = AutoTokenizer.from_pretrained(\n        model_path,\n        trust_remote_code=True\n    )\n    logger.info(\"Tokenizer loaded\")\n    \n    model = AutoModelForCausalLM.from_pretrained(\n        model_path,\n        torch_dtype=torch_dtype,\n        device_map=device,\n        trust_remote_code=True\n    )\n    logger.info(\"Model loaded\")\n    \n    return model, tokenizer\n\n\ndef print_model_info(model) -> None:\n    \"\"\"Print model information.\n    \n    Args:\n        model: Loaded model instance\n    \"\"\"\n    device = next(model.parameters()).device\n    dtype = next(model.parameters()).dtype\n    \n    logger.info(\"=\" * 60)\n    logger.info(\"Model Information:\")\n    logger.info(f\"  Device: {device}\")\n    logger.info(f\"  Data Type: {dtype}\")\n    logger.info(f\"  Vocab Size: {model.config.vocab_size}\")\n    logger.info(f\"  Hidden Size: {model.config.hidden_size}\")\n    if hasattr(model.config, 'num_hidden_layers'):\n        logger.info(f\"  Num Layers: {model.config.num_hidden_layers}\")\n    logger.info(\"=\" * 60)\n\n\ndef generate_text(\n    model,\n    tokenizer,\n    prompt: str,\n    max_new_tokens: int = 256,\n    temperature: float = 0.7,\n    top_p: float = 0.9,\n    repetition_penalty: float = 1.1,\n    do_sample: bool = True,\n    show_input_ids: bool = False\n) -> str:\n    \"\"\"Generate text from a direct prompt (without chat template).\n    \n    Args:\n        model: Model instance\n        tokenizer: Tokenizer instance\n        prompt: Input prompt text\n        max_new_tokens: Maximum number of tokens to generate\n        temperature: Sampling temperature\n        top_p: Top-p sampling parameter\n        repetition_penalty: Repetition penalty\n        do_sample: Whether to use sampling\n        show_input_ids: Whether to print input token IDs\n    \n    Returns:\n        Generated text (only the newly generated part)\n    \"\"\"\n    device = next(model.parameters()).device\n    inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n    \n    if show_input_ids:\n        logger.info(f\"Input IDs: {inputs['input_ids']}\")\n    \n    with torch.no_grad():\n        generate_ids = model.generate(\n            **inputs,\n            max_new_tokens=max_new_tokens,\n            temperature=temperature,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n            
do_sample=do_sample\n        )\n    \n    output = tokenizer.batch_decode(\n        generate_ids,\n        skip_special_tokens=False,\n        clean_up_tokenization_spaces=False\n    )[0]\n    \n    # Return only the newly generated part\n    generated_text = output[len(prompt):].strip()\n    return generated_text\n\n\ndef generate_chat(\n    model,\n    tokenizer,\n    messages: List[dict],\n    max_new_tokens: int = 1024,\n    temperature: float = 0.7,\n    top_p: float = 0.9,\n    repetition_penalty: float = 1.2,\n    enable_thinking: bool = False,\n    add_generation_prompt: bool = True,\n    show_template: bool = False\n) -> str:\n    \"\"\"Generate text using chat template.\n    \n    Args:\n        model: Model instance\n        tokenizer: Tokenizer instance\n        messages: List of message dicts with 'role' and 'content' keys\n        max_new_tokens: Maximum number of tokens to generate\n        temperature: Sampling temperature\n        top_p: Top-p sampling parameter\n        repetition_penalty: Repetition penalty\n        enable_thinking: Whether to enable thinking mode\n        add_generation_prompt: Whether to add generation prompt\n        show_template: Whether to print the formatted template\n    \n    Returns:\n        Generated text (only the newly generated part)\n    \"\"\"\n    # Apply chat template\n    template_kwargs = {\n        \"tokenize\": False,\n        \"add_generation_prompt\": add_generation_prompt,\n    }\n    \n    # Pass the flag explicitly: Qwen3 chat templates enable thinking by default,\n    # so setting the kwarg only when True would make it impossible to disable\n    # thinking mode from the command line\n    template_kwargs[\"enable_thinking\"] = enable_thinking\n    \n    text = tokenizer.apply_chat_template(messages, **template_kwargs)\n    \n    if show_template:\n        logger.info(f\"Chat Template:\\n{text}\\n\" + \"=\" * 60)\n    \n    # Tokenize and generate\n    inputs = tokenizer(\n        text,\n        return_tensors=\"pt\",\n        padding=False,\n        truncation=False\n    )\n    \n    device = next(model.parameters()).device\n    inputs = inputs.to(device)\n    \n    with torch.no_grad():\n        output = model.generate(\n            **inputs,\n            max_new_tokens=max_new_tokens,\n            temperature=temperature,\n            top_p=top_p,\n            repetition_penalty=repetition_penalty,\n            do_sample=True\n        )\n    \n    output_text = tokenizer.batch_decode(\n        output,\n        skip_special_tokens=False,\n        clean_up_tokenization_spaces=False\n    )[0]\n    \n    # Return only the newly generated part\n    generated_text = output_text[len(text):].strip()\n    return generated_text\n\n\ndef load_test_cases_from_file(file_path: Union[str, Path]) -> tuple:\n    \"\"\"Load test cases from JSON file.\n    \n    Expected format:\n    {\n        \"test_cases\": [\n            {\n                \"type\": \"text\" or \"chat\",\n                \"input\": \"prompt text\" or [{\"role\": \"...\", \"content\": \"...\"}],\n                \"ground_truth\": \"expected output\" (optional)\n            }\n        ]\n    }\n    \n    Args:\n        file_path: Path to JSON file\n    \n    Returns:\n        Tuple of (test_cases, ground_truths)\n    \"\"\"\n    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n        data = json.load(f)\n    \n    test_cases = []\n    ground_truths = []\n    \n    for item in data.get(\"test_cases\", []):\n        test_cases.append({\n            \"type\": item.get(\"type\", \"text\"),\n            \"input\": item[\"input\"]\n        })\n        ground_truths.append(item.get(\"ground_truth\", \"\"))\n    \n    return test_cases, 
ground_truths\n\n\ndef get_default_test_cases() -> tuple:\n    \"\"\"Get default test cases for demonstration.\n    \n    Returns:\n        Tuple of (test_cases, ground_truths)\n    \"\"\"\n    test_cases = [\n        {\n            \"type\": \"text\",\n            \"input\": \"你好，请介绍一下你自己。\"\n        },\n        {\n            \"type\": \"text\",\n            \"input\": \"视频<|sid_begin|><s_a_8084><s_b_243><s_c_2535><|sid_end|>的类型是：\"\n        },\n        {\n            \"type\": \"chat\",\n            \"input\": [{\"role\": \"user\", \"content\": \"写一首关于春天的短诗：\"}]\n        },\n        {\n            \"type\": \"chat\",\n            \"input\": [\n                {\"role\": \"system\", \"content\": \"你是一名视频描述生成器，请根据下面的视频token生成视频描述\"},\n                {\"role\": \"user\", \"content\": \"这是一个视频：<|sid_begin|><s_a_3482><s_b_3606><s_c_3239><|sid_end|>，帮我总结一下这个视频讲述了什么内容\"}\n            ]\n        },\n    ]\n    \n    ground_truths = [\"\", \"\", \"\", \"\"]\n    \n    return test_cases, ground_truths\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Test HuggingFace models with text generation or chat mode\"\n    )\n    \n    # Model arguments\n    parser.add_argument(\n        \"--model_path\",\n        type=str,\n        required=True,\n        help=\"Path to HuggingFace model directory\"\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=\"Device mapping (default: auto)\"\n    )\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=\"bf16\",\n        choices=[\"fp32\", \"fp16\", \"bf16\"],\n        help=\"Model data type (default: bf16)\"\n    )\n    \n    # Test case arguments\n    parser.add_argument(\n        \"--test_file\",\n        type=str,\n        default=None,\n        help=\"Path to JSON file containing test cases (optional)\"\n    )\n    parser.add_argument(\n        \"--use_default\",\n        action=\"store_true\",\n        help=\"Use default test cases if no test file provided\"\n    )\n    \n    # Generation arguments\n    parser.add_argument(\n        \"--max_new_tokens\",\n        type=int,\n        default=1024,\n        help=\"Maximum number of tokens to generate (default: 1024)\"\n    )\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        default=0.7,\n        help=\"Sampling temperature (default: 0.7)\"\n    )\n    parser.add_argument(\n        \"--top_p\",\n        type=float,\n        default=0.9,\n        help=\"Top-p sampling parameter (default: 0.9)\"\n    )\n    parser.add_argument(\n        \"--repetition_penalty\",\n        type=float,\n        default=1.2,\n        help=\"Repetition penalty (default: 1.2)\"\n    )\n    \n    # Chat mode arguments\n    parser.add_argument(\n        \"--enable_thinking\",\n        action=\"store_true\",\n        help=\"Enable thinking mode for chat template\"\n    )\n    parser.add_argument(\n        \"--no_generation_prompt\",\n        dest=\"add_generation_prompt\",\n        action=\"store_false\",\n        help=\"Disable generation prompt in chat template\"\n    )\n    parser.add_argument(\n        \"--show_template\",\n        action=\"store_true\",\n        help=\"Show formatted chat template\"\n    )\n    parser.add_argument(\n        \"--show_input_ids\",\n        action=\"store_true\",\n        help=\"Show input token IDs for text mode\"\n    )\n    \n    # Output arguments\n    parser.add_argument(\n        \"--compare_ground_truth\",\n        action=\"store_true\",\n  
      help=\"Compare output with ground truth if available\"\n    )\n    \n    args = parser.parse_args()\n    \n    # Convert dtype string to torch.dtype\n    dtype_map = {\n        \"fp32\": torch.float32,\n        \"fp16\": torch.float16,\n        \"bf16\": torch.bfloat16,\n    }\n    torch_dtype = dtype_map[args.dtype]\n    \n    # Load model\n    model, tokenizer = load_model(args.model_path, args.device, torch_dtype)\n    print_model_info(model)\n    \n    # Load test cases\n    if args.test_file:\n        logger.info(f\"Loading test cases from: {args.test_file}\")\n        test_cases, ground_truths = load_test_cases_from_file(args.test_file)\n    elif args.use_default:\n        logger.info(\"Using default test cases\")\n        test_cases, ground_truths = get_default_test_cases()\n    else:\n        logger.error(\"Either --test_file or --use_default must be provided\")\n        sys.exit(1)\n    \n    logger.info(f\"Loaded {len(test_cases)} test cases\\n\")\n    \n    # Run tests\n    logger.info(\"Starting tests...\\n\")\n    for i, (test_case, ground_truth) in enumerate(zip(test_cases, ground_truths), 1):\n        logger.info(\"=\" * 60)\n        logger.info(f\"Test {i}/{len(test_cases)}\")\n        logger.info(\"=\" * 60)\n        \n        test_type = test_case[\"type\"]\n        test_input = test_case[\"input\"]\n        \n        # Display input\n        if test_type == \"text\":\n            logger.info(f\"Input (text): {test_input}\\n\")\n        else:\n            logger.info(f\"Input (chat):\")\n            for msg in test_input:\n                logger.info(f\"  {msg['role']}: {msg['content'][:100]}...\")\n            logger.info(\"\")\n        \n        try:\n            # Generate\n            if test_type == \"text\":\n                generated = generate_text(\n                    model,\n                    tokenizer,\n                    test_input,\n                    max_new_tokens=args.max_new_tokens,\n                    temperature=args.temperature,\n                    top_p=args.top_p,\n                    repetition_penalty=args.repetition_penalty,\n                    show_input_ids=args.show_input_ids\n                )\n            else:  # chat mode\n                generated = generate_chat(\n                    model,\n                    tokenizer,\n                    test_input,\n                    max_new_tokens=args.max_new_tokens,\n                    temperature=args.temperature,\n                    top_p=args.top_p,\n                    repetition_penalty=args.repetition_penalty,\n                    enable_thinking=args.enable_thinking,\n                    add_generation_prompt=args.add_generation_prompt,\n                    show_template=args.show_template\n                )\n            \n            logger.info(f\"Output: {generated}\\n\")\n            \n            # Compare with ground truth if available\n            if args.compare_ground_truth and ground_truth:\n                logger.info(f\"Ground Truth: {ground_truth}\\n\")\n                if generated.strip() == ground_truth.strip():\n                    logger.info(\"✓ Match with ground truth\")\n                else:\n                    logger.info(\"✗ Does not match ground truth\")\n            \n        except Exception as e:\n            logger.error(f\"Generation failed: {e}\", exc_info=True)\n        \n        logger.info(\"-\" * 60 + \"\\n\")\n    \n    logger.info(\"=\" * 60)\n    logger.info(\"All tests completed!\")\n    logger.info(\"=\" * 60)\n\n\nif __name__ == 
\"__main__\":\n    main()\n\n"
  },
  {
    "path": "tokenizer/README.md",
    "content": "# Residual K-Means Tokenizer\n\nA residual K-means model for vector quantization. It encodes continuous embeddings into discrete codes through hierarchical clustering.\n\n> Public weights are available at [OpenOneRec/OneRec-tokenizer](https://huggingface.co/OpenOneRec/OneRec-tokenizer).\n\n\n> To utilize our foundation model, when using new datasets, the **embedding model** must be [Qwen3-8B-Embedding](https://huggingface.co/Qwen/Qwen3-Embedding-8B).\n\n## Files\n\n- `res_kmeans.py` - Model definition\n- `train_res_kmeans.py` - Training script\n- `infer_res_kmeans.py` - Inference script\n\n## Installation\n\n```bash\npip install torch numpy pandas pyarrow faiss tqdm\n```\n\n## Usage\n\n### Training\n\n```bash\npython train_res_kmeans.py \\\n    --data_path ./data/embeddings.parquet \\\n    --model_path ./checkpoints \\\n    --n_layers 3 \\\n    --codebook_size 8192 \\\n    --dim 4096\n```\n\n**Arguments:**\n- `--data_path`: Path to parquet file(s) with `embedding` column\n- `--model_path`: Directory to save the model\n- `--n_layers`: Number of residual layers (default: 3)\n- `--codebook_size`: Size of each codebook (default: 8192)\n- `--dim`: Embedding dimension (default: 4096)\n- `--seed`: Random seed (default: 42)\n\n### Inference\n\n```bash\npython infer_res_kmeans.py \\\n    --model_path ./checkpoints/model.pt \\\n    --emb_path ./data/embeddings.parquet \\\n    --output_path ./output/codes.parquet\n```\n\n**Arguments:**\n- `--model_path`: Path to trained model checkpoint\n- `--emb_path`: Path to parquet file with `pid` and `embedding` columns\n- `--output_path`: Output path (default: `{emb_path}_codes.parquet`)\n- `--batch_size`: Inference batch size (default: 10000)\n- `--device`: Device to use (default: cuda if available)\n- `--n_layers`: Number of layers to use (default: all)\n\n**Input format:** Parquet with columns `pid`, `embedding`\n\n**Output format:** Parquet with columns `pid`, `codes`\n"
  },
  {
    "path": "tokenizer/infer_res_kmeans.py",
    "content": "import argparse\nimport torch\nimport numpy as np\nimport pandas as pd\nfrom res_kmeans import ResKmeans\n\n\ndef load_embeddings(emb_path):\n    \"\"\"Load parquet file with pid and embedding columns\"\"\"\n    df = pd.read_parquet(emb_path)\n    pids = df['pid'].tolist()\n    emb = torch.tensor(np.stack(df['embedding'].values), dtype=torch.float32)\n    return pids, emb\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='ResKmeans Inference')\n    parser.add_argument('--model_path', type=str, required=True, help='model checkpoint path')\n    parser.add_argument('--emb_path', type=str, required=True, help='embedding file path')\n    parser.add_argument('--output_path', type=str, default=None, help='output path (default: emb_path + _codes.parquet)')\n    parser.add_argument('--batch_size', type=int, default=10000, help='inference batch size')\n    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')\n    parser.add_argument('--n_layers', type=int, default=None, help='number of layers to use (default: all layers)')\n    args = parser.parse_args()\n\n    # Load model\n    print(f\"Loading model from {args.model_path}\")\n    checkpoint = torch.load(args.model_path, map_location='cpu')\n\n    if isinstance(checkpoint, ResKmeans):\n        model = checkpoint\n    elif isinstance(checkpoint, dict):\n        # Restore from state_dict\n        if 'model' in checkpoint:\n            state_dict = checkpoint['model']\n        elif 'state_dict' in checkpoint:\n            state_dict = checkpoint['state_dict']\n        else:\n            state_dict = checkpoint\n\n        # Infer model parameters\n        n_layers = sum(1 for k in state_dict.keys() if k.startswith('centroids.'))\n        first_centroid = state_dict['centroids.0']\n        codebook_size, dim = first_centroid.shape\n\n        model = ResKmeans(n_layers=n_layers, codebook_size=codebook_size, dim=dim)\n        model.load_state_dict(state_dict)\n    else:\n        raise ValueError(\"Unknown checkpoint format\")\n\n    model = model.to(args.device)\n    model.eval()\n    print(f\"Model loaded: n_layers={model.n_layers}, codebook_size={model.codebook_size}, dim={model.dim}\")\n\n    # Load embeddings\n    print(f\"Loading embeddings from {args.emb_path}\")\n    pids, emb = load_embeddings(args.emb_path)\n    print(f\"Embeddings shape: {emb.shape}, num pids: {len(pids)}\")\n\n    # Inference\n    print(\"Encoding...\")\n    all_codes = []\n    with torch.no_grad():\n        for i in range(0, len(emb), args.batch_size):\n            batch = emb[i:i + args.batch_size].to(args.device)\n            codes = model.encode(batch, n_layers=args.n_layers)\n            all_codes.append(codes.cpu())\n            if (i // args.batch_size) % 10 == 0:\n                print(f\"  Processed {min(i + args.batch_size, len(emb))}/{len(emb)}\")\n\n    all_codes = torch.cat(all_codes, dim=0)\n    print(f\"Output codes shape: {all_codes.shape}\")\n\n    # Save results to parquet\n    output_path = args.output_path or args.emb_path.rsplit('.', 1)[0] + '_codes.parquet'\n    df_out = pd.DataFrame({\n        'pid': pids,\n        'codes': all_codes.numpy().tolist()\n    })\n    df_out.to_parquet(output_path, index=False)\n    print(f\"Codes saved to {output_path}\")\n\n    # Compute reconstruction loss\n    print(\"\\nComputing reconstruction loss...\")\n    with torch.no_grad():\n        sample_size = min(10000, len(emb))\n        sample_emb = emb[:sample_size].to(args.device)\n        
sample_codes = all_codes[:sample_size].to(args.device)\n        reconstructed = model.decode(sample_codes)\n        loss_info = model.calc_loss(sample_emb, reconstructed)\n        print(f\"Reconstruction loss (MSE): {loss_info['loss']:.6f}\")\n        print(f\"Relative loss: {loss_info['rel_loss']:.6f}\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tokenizer/res_kmeans.py",
    "content": "import torch\nfrom torch import nn\n\nclass ResKmeans(nn.Module):\n\n    def __init__(self, n_layers, codebook_size, dim, extra_kmeans_config=None, **kwargs):\n        super().__init__()\n        self.n_layers = n_layers\n        self.codebook_size = codebook_size\n        self.dim = dim\n        self.extra_kmeans_config = extra_kmeans_config\n        self.centroids = nn.ParameterList([\n            nn.Parameter(torch.zeros((codebook_size,dim), requires_grad=False))\n            for i in range(n_layers)\n        ])\n\n    def calc_loss(self, x, out, epsilon=1e-4):\n        loss = ((out - x) ** 2).mean()\n        rel_loss = (torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon)).mean()\n        return {'loss': loss.item(), 'rel_loss': rel_loss.item()}\n    \n    def train_kmeans(self, inputs, verbose=True):\n        import faiss\n        kmeans = faiss.Kmeans(self.dim, self.codebook_size, spherical=False, **self.extra_kmeans_config)\n        x = inputs.clone()\n        out = torch.zeros_like(x)\n        for l in range(self.n_layers):\n            kmeans.train(x)\n            _, I = kmeans.index.search(x, 1)\n            I = I.reshape([-1])\n            o = torch.tensor(kmeans.centroids[I])\n            out += o\n            if verbose:\n                losses = self.calc_loss(inputs, out)\n                print(l, losses)\n            x = x - o\n            self.centroids[l] = nn.Parameter(torch.tensor(kmeans.centroids.copy()), requires_grad=False)\n            print(f\"layer {l} finished\")\n    \n    def encode(self, x, n_layers=None):\n        if n_layers is None:\n            n_layers = self.n_layers\n        else:\n            assert n_layers <= self.n_layers\n        out = []\n        for l in range(n_layers):\n            x_norm_sq = x.pow(2.).sum(dim=1, keepdim=True)\n            codebook_t_norm_sq = self.centroids[l].T.pow(2.).sum(dim=0, keepdim=True)\n            distances = torch.addmm(x_norm_sq + codebook_t_norm_sq, x, self.centroids[l].T, alpha=-2.0)\n            code = distances.argmin(dim=-1)\n            x = x - self.centroids[l][code]\n            out.append(code)\n        out = torch.stack(out, dim=1)\n        return out\n    \n    def decode(self, code):\n        out = torch.zeros((code.shape[0], self.dim), dtype=torch.float32, device=code.device)\n        n_layers = code.shape[1]\n        assert n_layers <= self.n_layers\n        for l in range(n_layers):\n            c = code[:, l]\n            out += self.centroids[l][c]\n        return out\n"
  },
  {
    "path": "tokenizer/train_res_kmeans.py",
    "content": "import os\nimport argparse\nimport random\nimport numpy as np\nimport torch\nimport pyarrow.parquet as pq\nfrom tqdm import tqdm\nfrom res_kmeans import ResKmeans\n\n\ndef read_train_data(path, emb_dim):\n    \"\"\"Read training data from local parquet files\"\"\"\n    dataset = pq.ParquetDataset(path)\n\n    fragments = list(dataset.fragments)\n    random.shuffle(fragments)\n    print(f\"Total files: {len(fragments)}\")\n\n    embeddings = []\n    current_size = 0\n\n    for fragment in tqdm(fragments, desc=\"Reading files\"):\n        table = fragment.to_table(columns=['embedding'])\n        if table.num_rows == 0:\n            continue\n\n        emb_chunk = table['embedding'].to_numpy(zero_copy_only=False)\n        if emb_chunk.dtype == 'object':\n            emb_chunk = np.vstack(emb_chunk)\n\n        emb_chunk = emb_chunk[:, :emb_dim].astype(np.float32)\n        embeddings.append(emb_chunk)\n        current_size += len(emb_chunk)\n\n    result = np.concatenate(embeddings, axis=0)\n    print(f\"Final shape: {result.shape}\")\n    return result\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Train ResKmeans')\n    parser.add_argument('--data_path', type=str, required=True, help='training data path')\n    parser.add_argument('--model_path', type=str, required=True, help='model save path')\n    parser.add_argument('--n_layers', type=int, default=3, help='number of layers')\n    parser.add_argument('--codebook_size', type=int, default=8192, help='codebook size')\n    parser.add_argument('--dim', type=int, default=4096, help='embedding dimension')\n    parser.add_argument('--niter', type=int, default=20, help='kmeans iterations')\n    parser.add_argument('--seed', type=int, default=42, help='random seed')\n    args = parser.parse_args()\n\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    # Load data\n    embeddings = read_train_data(args.data_path, args.dim)\n\n    # Create and train model\n    model = ResKmeans(\n        n_layers=args.n_layers,\n        codebook_size=args.codebook_size,\n        dim=args.dim,\n    )\n    model.train_kmeans(torch.tensor(embeddings))\n\n    # Save model\n    os.makedirs(args.model_path, exist_ok=True)\n    save_path = os.path.join(args.model_path, \"model.pt\")\n    torch.save(model.state_dict(), save_path)\n    print(f\"Model saved to {save_path}\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "verl_distillation/LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "verl_distillation/README.md",
    "content": "## Overview\n\nThis repository is built on top of the open-source [**verl**](https://github.com/volcengine/verl) (HybridFlow RLHF/RL training framework) and adds support for **on-policy distillation**.\nIt is designed for scenarios where the **teacher and student use different vocabularies**, e.g., distilling from `Qwen3` (teacher) to a recommendation-pretrained model (student) that contains **extended itemic tokens**, while improving and preserving general-purpose capabilities.\n\n> **Note**: This repository is forked from [verl](https://github.com/volcengine/verl) at commit [`703a078`](https://github.com/volcengine/verl/commit/703a07856fe2544833dfce51136f386654574b30) and extended with on-policy distillation capabilities.\n\nThe high-level idea is briefly described in the OpenOneRec technical report, Section **5.2 On-policy Distillation for General Capability**: [OneRecBench.pdf](OneRecBench.pdf).\n\n## Key Features\n\n- **On-policy distillation entrypoint**: `recipe/onpolicy_distill/main_onpolicy_distill.py`\n- **Distillation trainer**: `recipe/onpolicy_distill/onpolicy_distill_trainer.py`\n- **Teacher/Student vocabulary mismatch support**\n  - Generates `distill_special_token_mask` during rollout\n  - Replaces/masks extended-vocab tokens during log-probability computation to improve training stability\n- **OneRec dataset adapter (parquet → chat)**: `verl/utils/dataset/onerec_dataset.py`\n  - Optionally appends `/think` or `/no_think` to the user prompt (force/auto modes)\n- **Algorithm and metrics extensions**\n  - `AdvantageEstimator.ON_POLICY_DISTILL`\n  - `compute_on_policy_distill_data_metrics(...)`\n\n## Quick Start\n\n### Installation\n```bash\n# Configure hostfile (multi-node)\ncat > /etc/mpi/hostfile << EOF\n192.168.1.100\n192.168.1.101\n192.168.1.102\nEOF\n\n# Install dependencies\n# For Single node\nbash deploy_env.sh\n# For Multi-node\nbash deploy_env.sh --all-nodes\n\n# Start Ray cluster\nbash init_ray_cluster.sh\n```\n\n\n### Required environment variables\n\n```bash\n# Required: model and data paths\nexport BASE_MODEL=/path/to/student_model\nexport TEACHER_MODEL=/path/to/teacher_model   # e.g. Qwen3-1.7B\n\n# Optional: extended-vocabulary distillation settings (defaults in the script)\nexport EXTEND_VOCAB_START_TOKEN=151669         # token_id >= this value is treated as an \"extended vocab token\"\nexport MASK_RESPONSE_IF_HAVE_EXTEND_TOKEN=False  # mask the whole response if any extended token appears\n\n# Optional: advantage clipping bounds for distillation (defaults in the script)\nexport DISTILL_ADV_MAX=5.0    # upper bound\nexport DISTILL_ADV_MIN=-30.0  # lower bound\n```\n\n**`EXTEND_VOCAB_START_TOKEN`**\nis used for teacher/student vocabulary mismatch. If the student model introduces additional tokens on top of the base vocabulary (e.g., item tokens for recommendation), set this threshold to the first extended token id.\nDuring rollout, the framework produces `distill_special_token_mask`; during log-probability computation, extended-vocab tokens are replaced/masked to maintain stability.\n\n**`DISTILL_ADV_MAX / DISTILL_ADV_MIN`**\nclip the distillation advantage to avoid extreme values when the teacher and student distributions differ substantially. 
### Launch training\n\nThe training entry script is located at `recipe/onpolicy_distill/run_qwen3_distill.sh`.\n\n```bash\nbash recipe/onpolicy_distill/run_qwen3_distill.sh /etc/mpi/hostfile\n```\n\nNotes:\n- The script defaults to **console-only logging** (`trainer.logger=[console]`). To use W&B, export `WANDB_API_KEY` and override `trainer.logger=[console,wandb]` in the script/CLI.\n- Hydra config entrypoint: `recipe/onpolicy_distill/config/onpolicy_distill_trainer.yaml` (reuses the base config from [verl](https://github.com/volcengine/verl)).\n\n## Data Format (parquet)\n\n`OneRecDataset` reads the `messages` field from parquet (either a list or a string-serialized list) and constructs:\n- `prompt`: all messages except the last one\n- `ground_truth`: the content of the last message (used for reward payload / analysis)\n\nIt is recommended to keep a `source` or `data_source` field for per-task statistics.\n\n## Key Implementation Details (for reproducibility)\n\n- **Distillation signal (reverse KL)**\n  - Implemented in `verl/trainer/ppo/core_algos.py` as:\n    \\(A = -(\\log p_{\\text{student}} - \\log p_{\\text{teacher}})\\)\n  - Enabled via the `compute_advantage(...)` branch in `verl/trainer/ppo/ray_trainer.py`, with support for `distill_adv_max_clip / distill_adv_min_clip`.\n\n- **Extended vocabulary handling**\n  - `extend_vocab_start_token`: tokens with id \\(\\ge\\) this threshold are treated as \"extended vocab tokens\"\n  - `ToolAgentLoop` emits `distill_special_token_mask` (optionally truncating/masking the response)\n  - `dp_actor.compute_log_prob(..., mask_special_token=True)` replaces/masks extended-vocab tokens and overwrites the corresponding log-prob entries (via `ref_log_prob_replace_val`)\n\n---\n\n## 🙏 Acknowledgements\n\nThis repository is built upon and extended from the open-source [**verl**](https://github.com/volcengine/verl) project. We sincerely thank the verl team for their excellent work on the HybridFlow RLHF/RL training framework, which provides a solid foundation for our on-policy distillation implementation.\n"
  },
  {
    "path": "verl_distillation/README_ORIGINAL.md",
    "content": "<div align=\"center\">\n 👋 Hi, everyone! \n    verl is a RL training library initiated by <b>ByteDance Seed team</b> and maintained by the verl community.\n    <br>\n    <br>\n</div>\n\n<div align=\"center\">\n\n<a href=\"https://deepwiki.com/volcengine/verl\"><img src=\"https://devin.ai/assets/deepwiki-badge.png\" alt=\"Ask DeepWiki.com\" style=\"height:20px;\"></a>\n[![GitHub Repo stars](https://img.shields.io/github/stars/volcengine/verl)](https://github.com/volcengine/verl/stargazers)\n[![Twitter](https://img.shields.io/twitter/follow/verl_project)](https://twitter.com/verl_project)\n<a href=\"https://join.slack.com/t/verl-project/shared_invite/zt-3c6mc2khw-v0lo6NfDPuFP6OnkrZwfqw\"><img src=\"https://img.shields.io/badge/Slack-verl-blueviolet?logo=slack&amp\"></a>\n<a href=\"https://arxiv.org/pdf/2409.19256\"><img src=\"https://img.shields.io/static/v1?label=EuroSys&message=Paper&color=red\"></a>\n[![Documentation](https://img.shields.io/badge/documentation-blue)](https://verl.readthedocs.io/en/latest/)\n<a href=\"https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG\"><img src=\"https://img.shields.io/badge/微信-green?logo=wechat&amp\"></a>\n\n</div>\n\n![seed logo](https://github.com/user-attachments/assets/c42e675e-497c-4508-8bb9-093ad4d1f216)\n\n<h1 style=\"text-align: center;\">verl: Volcano Engine Reinforcement Learning for LLMs</h1>\n\nverl is a flexible, efficient and production-ready RL training library for large language models (LLMs).\n\nverl is the open-source version of **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper.\n\nverl is flexible and easy to use with:\n\n- **Easy extension of diverse RL algorithms**: The hybrid-controller programming model enables flexible representation and efficient execution of complex post-training dataflows. Build RL dataflows such as GRPO, PPO in a few lines of code.\n\n- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as FSDP, Megatron-LM, vLLM, SGLang, etc\n\n- **Flexible device mapping**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.\n\n- Ready integration with popular HuggingFace models\n\nverl is fast with:\n\n- **State-of-the-art throughput**: SOTA LLM training and inference engine integrations and SOTA RL throughput.\n\n- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.\n\n</p>\n\n## News\n- [2025/08] verl is presented in the [PyTorch Expert Exchange Webinar](https://www.youtube.com/watch?v=Vd79NmmqY3Q&t=2s). [Slides](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/verl_talk_pytorch_2025_08.pdf) available.\n- [2025/07] The [ReTool](https://arxiv.org/pdf/2504.11536) recipe is fully open sourced. [Blog](https://www.notion.so/verl-reTool-recipe-Using-multi-round-conversations-and-code-sandboxing-to-improve-the-math-of-large-23a8b5b7feba80b386b2e5b5e3c1cde0)\n- [2025/07] The first verl meetup will be held at ICML Vancouver on July 16th! Please [join us](https://lu.ma/0ek2nyao) if you are at ICML! 
(onsite only)\n- [2025/06] verl with Megatron backend enables large MoE models such as [DeepSeek-671B and Qwen3-235B](https://verl.readthedocs.io/en/latest/perf/dpsk.html).\n- [2025/03] [DAPO](https://dapo-sia.github.io/) is the open-sourced SOTA RL algorithm that achieves 50 points on AIME 2024 based on the Qwen2.5-32B pre-trained model, surpassing the previous SOTA achieved by DeepSeek's GRPO (DeepSeek-R1-Zero-Qwen-32B). DAPO's training is fully powered by verl and the reproduction code is available in `recipe/dapo` now.\n<details><summary> more... </summary>\n<ul>\n  <li>[2025/04] [Seed-Thinking-v1.5](https://github.com/ByteDance-Seed/Seed-Thinking-v1.5/blob/main/seed-thinking-v1.5.pdf) tech report is released! Trained with verl, Seed-Thinking-v1.5 achieves 86.7 on AIME 2024, 55.0 on Codeforces and 77.3 on GPQA, demonstrating excellent reasoning abilities in STEM and coding. Beyond reasoning tasks, the method demonstrates notable generalization across diverse domains.</li>\n  <li>[2025/07] verl keynote at [AWS AI Hours Singapore](https://pages.awscloud.com/aws-ai-hours-sg.html#agenda) on 7/8, verl & verl-agent project updates at [Agent for SWE meetup](https://lu.ma/e498qhsi) by LF AI & Data Singapore on 7/11.</li>\n  <li>[2025/06] verl team will provide latest project updates at [PyTorch Day China](https://www.lfasiallc.com/pytorch-day-china/) on June 7th. Meet our dev team in Beijing!</li>\n  <li> [2025/04] [VAPO](https://arxiv.org/pdf/2504.05118) (value-based augmented PPO) paper covers our latest RL method for reasoning models. Trained from Qwen-32B-base model, VAPO achieves 60.4 on AIME 2024, outperforming DAPO-32B.</li>\n  <li>[2025/05] [PF-PPO](https://arxiv.org/abs/2409.06957), accepted to ICML 2025, is now supported in verl! PF-PPO enhances policy learning efficiency and robustness by filtering potentially noisy reward signals and reusing high-quality experiences via a replay buffer.</li>\n  <li>[2025/04] We will give a tutorial about latest post-training techniques and programming guide for verl at [ICLR 2025 Expo](https://iclr.cc/virtual/2025/calendar?filter_events=Expo+Talk+Panel&filter_rooms=), [SCI-FM workshop](https://open-foundation-model.github.io/) and [LMSys afterparty](https://lu.ma/d23nyynm). Talk materials available [here](https://github.com/eric-haibin-lin/verl-community/tree/main/iclr25). </li>\n  <li>[2025/03] verl v0.3.0.post1 is released! See [release note](https://github.com/volcengine/verl/releases/) for details. It achieves [~1.4x speedup](https://tongyx361.github.io/blogs/posts/verl-intro/#/verl-flexible-and-efficient-rl-for-llms) compared to prev versions.</li>\n  <li>[2025/05] verl will be presented at [A2M Shanghai](https://a2m.msup.com.cn/home/?aid=4488&city=shanghai) on 5/16 - 5/17.</li>\n  <li>[2025/05] verl will be presented at [GOSIM x PyTorch Day 2025](https://paris2025.gosim.org/). See you in Paris! </li>\n  <li>[2025/03] We introduced the programming model of verl at the [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg) and [verl intro and updates](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/verl-lmsys-meetup.pdf) at the [SGLang-LMSYS Org Meetup](https://lu.ma/ntjrr7ig) in Sunnyvale mid-March.</li>\n  <li>[2025/03] We will present verl(HybridFlow) at EuroSys 2025. See you in Rotterdam!</li>\n  <li>[2025/02] verl v0.2.0.post2 is released!</li>\n  <li>[2025/02] We presented verl in the <a href=\"https://lu.ma/ji7atxux\">Bytedance/NVIDIA/Anyscale Ray Meetup</a>. 
See you in San Jose!</li>\n  <li>[2025/01] [Doubao-1.5-pro](https://team.doubao.com/zh/special/doubao_1_5_pro) is released with SOTA-level performance on LLM & VLM. The RL scaling preview model is trained using verl, reaching OpenAI O1-level performance on math benchmarks (70.0 pass@1 on AIME).</li>\n  <li>[2024/12] verl is presented at Ray Forward 2024. Slides available <a href=\"https://github.com/eric-haibin-lin/verl-community/blob/main/slides/Ray_Forward_2024_%E5%B7%AB%E9%94%A1%E6%96%8C.pdf\">here</a></li>\n  <li>[2024/12] The team presented <a href=\"https://neurips.cc/Expo/Conferences/2024/workshop/100677\">Post-training LLMs: From Algorithms to Infrastructure</a> at NeurIPS 2024. <a href=\"https://github.com/eric-haibin-lin/verl-data/tree/neurips\">Slides</a> and <a href=\"https://neurips.cc/Expo/Conferences/2024/workshop/100677\">video</a> available.</li>\n  <li>[2024/10] verl is presented at Ray Summit. <a href=\"https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37\">YouTube video</a> available.</li>\n  <li>[2024/08] HybridFlow (verl) is accepted to EuroSys 2025.</li>\n</ul>\n</details>\n\n## Key Features\n\n- **FSDP**, **FSDP2** and **Megatron-LM** for training.\n- **vLLM**, **SGLang** and **HF Transformers** for rollout generation.\n- Compatible with Hugging Face Transformers and Modelscope Hub: [Qwen-3](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-8b.sh), Qwen-2.5, Llama3.1, Gemma2, DeepSeek-LLM, etc.\n- Supervised fine-tuning.\n- Reinforcement learning with [PPO](examples/ppo_trainer/), [GRPO](examples/grpo_trainer/), [GSPO](recipe/gspo/), [ReMax](examples/remax_trainer/), [REINFORCE++](https://verl.readthedocs.io/en/latest/examples/config.html#algorithm), [RLOO](examples/rloo_trainer/), [PRIME](recipe/prime/), [DAPO](recipe/dapo/), [DrGRPO](recipe/drgrpo), [KL_Cov & Clip_Cov](recipe/entropy), etc.\n  - Support for model-based reward and function-based reward (verifiable reward) for math, [coding](https://github.com/volcengine/verl/tree/main/recipe/dapo), etc.\n  - Support for vision-language models (VLMs) and [multi-modal RL](examples/grpo_trainer/run_qwen2_5_vl-7b.sh) with Qwen2.5-VL and Kimi-VL\n  - [Multi-turn with tool calling](https://github.com/volcengine/verl/tree/main/examples/sglang_multiturn)\n- LLM alignment recipes such as [Self-play preference optimization (SPPO)](https://github.com/volcengine/verl/tree/main/recipe/sppo)\n- Flash attention 2, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [sequence parallelism](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh).\n- Scales up to 671B models and hundreds of GPUs with [expert parallelism](https://github.com/volcengine/verl/pull/1467)\n- Multi-GPU [LoRA RL](https://verl.readthedocs.io/en/latest/advance/ppo_lora.html) support to save memory.\n- Experiment tracking with wandb, swanlab, mlflow and tensorboard.\n\n## Upcoming Features and Changes\n\n- Q3 Roadmap https://github.com/volcengine/verl/issues/2388\n- DeepSeek 671b optimizations with Megatron https://github.com/volcengine/verl/issues/1033\n- Multi-turn rollout and tool-use optimizations https://github.com/volcengine/verl/issues/1882\n- [Agent integration](https://github.com/volcengine/verl/tree/main/verl/experimental/agent_loop)\n- Async and off-policy architecture https://github.com/volcengine/verl/pull/2231\n- List of breaking changes since 
v0.4 https://github.com/volcengine/verl/discussions/2270\n\n## Getting Started\n\n<a href=\"https://verl.readthedocs.io/en/latest/index.html\"><b>Documentation</b></a>\n\n**Quickstart:**\n\n- [Installation](https://verl.readthedocs.io/en/latest/start/install.html)\n- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)\n- [Programming Guide](https://verl.readthedocs.io/en/latest/hybrid_flow.html) & [Tech Talk](https://hcqnc.xetlk.com/sl/3vACOK) (in Chinese)\n- [PPO in verl](https://verl.readthedocs.io/en/latest/algo/ppo.html)\n- [GRPO in verl](https://verl.readthedocs.io/en/latest/algo/grpo.html)\n\n**Running a PPO example step-by-step:**\n\n- [Prepare Data for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)\n- [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)\n- [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)\n- [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)\n\n**Reproducible algorithm baselines:**\n\n- [RL performance on coding, math](https://verl.readthedocs.io/en/latest/algo/baseline.html)\n\n**For code explanation and advanced usage (extension):**\n\n- PPO Trainer and Workers\n  - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)\n  - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html)\n  - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html)\n\n- Advanced Usage and Extension\n  - [Add Models with the FSDP Backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)\n  - [Add Models with the Megatron-LM Backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)\n  - [Multi-turn Rollout Support](https://verl.readthedocs.io/en/latest/sglang_multiturn/multiturn.html)\n  - [Search Tool Integration](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html)\n  - [Sandbox Fusion Integration](https://verl.readthedocs.io/en/latest/examples/sandbox_fusion_example.html)\n  - [Deployment using Separate GPU Resources](https://github.com/volcengine/verl/tree/main/examples/split_placement)\n  - [Extend to Other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html)\n  - [Ray API design tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html)\n\n**Blogs from the community**\n\n- [When Reasoning Models Break Tokenization: The Hidden Complexity of Multiturn Training](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/fast_tokenization/multiturn_tokenization_and_masking.md)\n- [verl deployment on AWS SageMaker](https://medium.com/@kaige.yang0110/run-verl-on-sagemaker-using-4x8-l40s-gpus-8e6d5c3c61d3)\n- [verl x SGLang Multi-turn Code Walkthrough](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/code-walk-through/readme_EN.md)\n- [Optimizing SGLang Memory Usage in verl](https://hebiao064.github.io/rl-memory-management)\n- [SGLang, verl, OpenBMB and Tsinghua University: Pioneering End-to-End Multi-Turn RLHF](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/verl-multiturn-rollout-Release.md)\n- [Reinforcement Learning from Human Feedback on AMD GPUs with verl and ROCm Integration](https://rocm.blogs.amd.com/artificial-intelligence/verl-large-scale/README.html)\n- [veMLP x verl
: Mastering Reinforcement Learning Training](https://mp.weixin.qq.com/s/7nbqxk4knMGd-hQE9ls2tA)\n- [Best Practices for Distributed GRPO Reinforcement Learning Training with verl](https://www.volcengine.com/docs/6459/1463942)\n- [A Brief Analysis of the HybridFlow (verl) Paper](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/readme.md)\n- [Up to 20x Higher Throughput! The Doubao LLM Team Releases a New RLHF Framework, Now Open Source!](https://team.doubao.com/en/blog/%E6%9C%80%E9%AB%98%E6%8F%90%E5%8D%8720%E5%80%8D%E5%90%9E%E5%90%90%E9%87%8F-%E8%B1%86%E5%8C%85%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%9B%A2%E9%98%9F%E5%8F%91%E5%B8%83%E5%85%A8%E6%96%B0-rlhf-%E6%A1%86%E6%9E%B6-%E7%8E%B0%E5%B7%B2%E5%BC%80%E6%BA%90)\n\n## Performance Tuning Guide\n\nPerformance is essential for on-policy RL algorithms. We have written a detailed [performance tuning guide](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) to help you optimize performance.\n\n## Upgrade to vLLM >= v0.8.2\n\nverl now supports vLLM>=0.8.2 when using FSDP as the training backend. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/README_vllm0.8.md) for the installation guide and more information. Please avoid vllm 0.7.x, which contains bugs that may lead to OOMs and unexpected errors.\n\n## Use Latest SGLang\n\nSGLang is fully supported with verl, and the SGLang RL Group is working extensively on building unique features, including multi-turn agentic RL, VLM RLHF, server-based RL, and partial rollout. Please refer to [this document](https://verl.readthedocs.io/en/latest/workers/sglang_worker.html) for the installation guide and more information.\n\n## Upgrade to FSDP2\n\nverl is fully embracing FSDP2! FSDP2 is recommended by the torch distributed team, providing better throughput and memory usage, and is composable with other features (e.g. torch.compile). To enable FSDP2, simply use verl main and set the following options:\n```\nactor_rollout_ref.ref.strategy=fsdp2\nactor_rollout_ref.actor.strategy=fsdp2\ncritic.strategy=fsdp2\nreward_model.strategy=fsdp2\n```\nFurthermore, FSDP2 CPU offloading is compatible with gradient accumulation. You can turn it on to save memory with `actor_rollout_ref.actor.fsdp_config.offload_policy=True`. For more details, see https://github.com/volcengine/verl/pull/1026\n\n## AMD Support (ROCm Kernel)\n\nverl now supports FSDP as the training engine (Megatron support coming soon) and integrates with both vLLM and SGLang as inference engines. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/amd_tutorial/amd_build_dockerfile_page.rst) for the installation guide and more information, and [this document](https://github.com/volcengine/verl/blob/main/docs/amd_tutorial/amd_vllm_page.rst) for vLLM performance tuning on ROCm.\n\n## Citation and acknowledgement\n\nIf you find the project helpful, please cite:\n\n- [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)\n- [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf)\n\n```bibtex\n@article{sheng2024hybridflow,\n  title   = {HybridFlow: A Flexible and Efficient RLHF Framework},\n  author  = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},\n  year    = {2024},\n  journal = {arXiv preprint arXiv: 2409.19256}\n}\n```\n\nverl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. 
The project is adopted and contributed by Bytedance, Anyscale, LMSys.org, [Alibaba Qwen team](https://github.com/QwenLM/), Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, University of Hong Kong, ke.com, [All Hands AI](https://www.all-hands.dev/), [ModelBest](http://modelbest.cn/), JD AI Lab, Microsoft Research, [StepFun](https://www.stepfun.com/), Amazon, LinkedIn, Meituan, [Camel-AI](https://www.camel-ai.org/), [OpenManus](https://github.com/OpenManus), Xiaomi, NVIDIA research, [Baichuan](https://www.baichuan-ai.com/home), [RedNote](https://www.xiaohongshu.com/), [SwissAI](https://www.swiss-ai.org/), [Moonshot AI (Kimi)](https://www.moonshot-ai.com/), Baidu, Snowflake, Skywork.ai, JetBrains, [IceSword Lab](https://www.iceswordlab.com), and many more.\n\n## Awesome work using verl\n\n- [TinyZero](https://github.com/Jiayi-Pan/TinyZero): a reproduction of the **DeepSeek R1 Zero** recipe for reasoning tasks ![GitHub Repo stars](https://img.shields.io/github/stars/Jiayi-Pan/TinyZero)\n- [SkyThought](https://github.com/NovaSky-AI/SkyThought): RL training for Sky-T1-7B by the NovaSky AI team. ![GitHub Repo stars](https://img.shields.io/github/stars/NovaSky-AI/SkyThought)\n- [simpleRL-reason](https://github.com/hkust-nlp/simpleRL-reason): SimpleRL-Zoo: Investigating and Taming Zero Reinforcement Learning for Open Base Models in the Wild ![GitHub Repo stars](https://img.shields.io/github/stars/hkust-nlp/simpleRL-reason)\n- [Easy-R1](https://github.com/hiyouga/EasyR1): **Multi-modal** RL training framework ![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/EasyR1)\n- [OpenManus-RL](https://github.com/OpenManus/OpenManus-RL): LLM Agents RL tuning framework for multiple agent environments. ![GitHub Repo stars](https://img.shields.io/github/stars/OpenManus/OpenManus-RL)\n- [rllm](https://github.com/agentica-project/rllm): async RL training with [verl-pipeline](https://github.com/agentica-project/verl-pipeline) ![GitHub Repo stars](https://img.shields.io/github/stars/agentica-project/rllm)\n- [RAGEN](https://github.com/ZihanWang314/ragen): a general-purpose reasoning **agent** training framework ![GitHub Repo stars](https://img.shields.io/github/stars/ZihanWang314/ragen)\n- [Search-R1](https://github.com/PeterGriffinJin/Search-R1): RL with reasoning and **searching (tool-call)** interleaved LLMs ![GitHub Repo stars](https://img.shields.io/github/stars/PeterGriffinJin/Search-R1)\n- [ReSearch](https://github.com/Agent-RL/ReSearch): Learning to **Re**ason with **Search** for LLMs via Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/Agent-RL/ReSearch)\n- [Skywork-OR1](https://github.com/SkyworkAI/Skywork-OR1): Skywork open reasoner series ![GitHub Repo stars](https://img.shields.io/github/stars/SkyworkAI/Skywork-OR1)\n- [ToRL](https://github.com/GAIR-NLP/ToRL): Scaling tool-integrated RL ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/ToRL)\n- [Absolute Zero Reasoner](https://github.com/LeapLabTHU/Absolute-Zero-Reasoner): [A self-play framework for reasoning with no human-curated data](https://arxiv.org/abs/2505.03335) ![GitHub Repo stars](https://img.shields.io/github/stars/LeapLabTHU/Absolute-Zero-Reasoner)\n- [verl-agent](https://github.com/langfengQ/verl-agent): A scalable training framework for **long-horizon LLM/VLM agents**, along with a new algorithm **GiGPO** ![GitHub Repo stars](https://img.shields.io/github/stars/langfengQ/verl-agent)\n- [RL-Factory](https://github.com/Simple-Efficient/RL-Factory): An easy and efficient RL 
post-training framework for Agentic Learning ![GitHub Repo stars](https://img.shields.io/github/stars/Simple-Efficient/RL-Factory)\n- [ReTool](https://retool-rl.github.io/): ReTool: reinforcement learning for strategic tool use in LLMs. Code release is in progress...\n- [verl-tool](https://github.com/TIGER-AI-Lab/verl-tool): A unified and easy-to-extend tool-agent training framework based on verl ![GitHub Repo stars](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)\n- [PRIME](https://github.com/PRIME-RL/PRIME): Process reinforcement through implicit rewards ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/PRIME)\n- [MemAgent](https://github.com/BytedTsinghua-SIA/MemAgent): MemAgent: Reshaping Long-Context LLM with Multi-Conv RL based Memory Agent ![GitHub Repo stars](https://img.shields.io/github/stars/BytedTsinghua-SIA/MemAgent)\n- [POLARIS](https://github.com/ChenxinAn-fdu/POLARIS): A post-training recipe for scaling RL on advanced reasoning models ![GitHub Repo stars](https://img.shields.io/github/stars/ChenxinAn-fdu/POLARIS)\n- [GUI-R1](https://github.com/ritzz-ai/GUI-R1): **GUI-R1**: A Generalist R1-style Vision-Language Action Model For **GUI Agents** ![GitHub Repo stars](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)\n- [DeepRetrieval](https://github.com/pat-jj/DeepRetrieval): RL Training of **Search Agent** with **Search/Retrieval Outcome** ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/DeepRetrieval)\n- [Code-R1](https://github.com/ganler/code-r1): Reproducing R1 for **Code** with Reliable Rewards ![GitHub Repo stars](https://img.shields.io/github/stars/ganler/code-r1)\n- [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling deep research via reinforcement learning in real-world environments ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)\n- [VAGEN](https://github.com/RAGEN-AI/VAGEN): Training VLM agents with multi-turn reinforcement learning ![GitHub Repo stars](https://img.shields.io/github/stars/RAGEN-AI/VAGEN)\n- [RM-R1](https://arxiv.org/abs/2505.02387): RL training of reasoning reward models ![GitHub Repo stars](https://img.shields.io/github/stars/RM-R1-UIUC/RM-R1)\n- [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance ![GitHub Repo stars](https://img.shields.io/github/stars/ElliottYan/LUFFY)\n- [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/zwhe99/DeepMath)\n- [PACS](https://github.com/ritzz-ai/PACS): Implicit Actor Critic Coupling via a Supervised Learning Framework for RLVR ![GitHub Repo stars](https://img.shields.io/github/stars/ritzz-ai/PACS)\n- [Entropy Mechanism of RL](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL): The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/Entropy-Mechanism-of-RL)\n- [LLaSA-TTS-GRPO](https://github.com/channel-io/ch-tts-llasa-rl-grpo): TTS fine-tuning with GRPO optimization based on LLASA models ![GitHub Repo stars](https://img.shields.io/github/stars/channel-io/ch-tts-llasa-rl-grpo)\n- [PF-PPO](https://arxiv.org/abs/2409.06957): Policy Filtration for PPO based on the reliability of reward signals for more efficient and robust RLHF.\n- [RACRO](https://github.com/gyhdog99/RACRO2): Build multi-modal reasoning models by decoupling reasoning into query-conditioned captioning and text-only reasoning 
![GitHub Repo stars](https://img.shields.io/github/stars/gyhdog99/RACRO2)\n- [Agent Lightning](https://github.com/microsoft/agent-lightning): A flexible and extensible framework that enables seamless agent optimization for any existing agent framework. ![GitHub Repo stars](https://img.shields.io/github/stars/microsoft/agent-lightning)\n- [VTool-R1](https://github.com/VTOOL-R1/vtool-r1): VLMs Learn to Think with Images via Reinforcement Learning on Multimodal Tool Use. ![GitHub Repo stars](https://img.shields.io/github/stars/VTOOL-R1/vtool-r1)\n- [Kimina-Prover-RL](https://github.com/project-numina/kimina-prover-rl/tree/main/recipe/kimina_prover_rl): Training pipeline for formal theorem proving, based on a paradigm inspired by DeepSeek-R1.\n- [RL-PLUS](https://github.com/YihongDong/RL-PLUS): Countering Capability Boundary Collapse of LLMs in Reinforcement Learning with Hybrid-policy Optimization.\n- [rStar2-Agent](https://github.com/microsoft/rStar): Using reinforcement learning with multi-step tool-calling for math tasks, rStar2-Agent-14B reaches frontier-level math reasoning in just 510 RL training steps ![GitHub Repo stars](https://img.shields.io/github/stars/microsoft/rStar)\n- [Vision-SR1](https://github.com/zli12321/Vision-SR1): Self-Rewarding Vision-Language Model via Reasoning Decomposition ![GitHub Repo stars](https://img.shields.io/github/stars/zli12321/Vision-SR1)\n- [SimpleVLA-RL](https://github.com/PRIME-RL/SimpleVLA-RL): SimpleVLA-RL: A Simple yet Effective Vision-Language Action Model for Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/SimpleVLA-RL)\n- [Table-R1](https://github.com/Table-R1/Table-R1): Table-R1: Inference-Time Scaling for Table Reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/Table-R1/Table-R1)\n- [Revisual-R1](https://github.com/CSfufu/Revisual-R1): Revisual-R1: Advancing Multimodal Reasoning From Optimized Cold Start to Staged Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/CSfufu/Revisual-R1)\n- [ARES](https://github.com/shawn0728/ARES): ARES: Multimodal Adaptive Reasoning via Difficulty-Aware Token-Level Entropy Shaping ![GitHub Repo stars](https://img.shields.io/github/stars/shawn0728/ARES)\n- [Meta-Bandit-LLM](https://github.com/sanxing-chen/meta-bandit-llm): Meta-Bandit-LLM: Long-horizon multiturn interactive training for meta-bandit agents ![GitHub Repo stars](https://img.shields.io/github/stars/sanxing-chen/meta-bandit-llm)\n- [PokeeResearch](https://github.com/Pokee-AI/PokeeResearchOSS): PokeeResearch: State-of-the-art 7B DeepResearch agent that leverages web search and content reading capabilities to answer complex questions using the most up-to-date information available online. ![GitHub Repo stars](https://img.shields.io/github/stars/Pokee-AI/PokeeResearchOSS)\n\nand many more awesome works listed in [recipe](recipe/README.md).\n\n## Contribution Guide\n\nSee the [contribution guide](CONTRIBUTING.md).\n\n## About [ByteDance Seed Team](https://team.doubao.com/)\n\nFounded in 2023, the ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society. 
You can get to know ByteDance Seed better through the following channels👇\n<div>\n  <a href=\"https://team.doubao.com/\">\n    <img src=\"https://img.shields.io/badge/Website-%231e37ff?style=for-the-badge&logo=bytedance&logoColor=white\"></a>\n  <a href=\"https://github.com/user-attachments/assets/469535a8-42f2-4797-acdf-4f7a1d4a0c3e\">\n    <img src=\"https://img.shields.io/badge/WeChat-07C160?style=for-the-badge&logo=wechat&logoColor=white\"></a>\n <a href=\"https://www.xiaohongshu.com/user/profile/668e7e15000000000303157d?xsec_token=ABl2-aqekpytY6A8TuxjrwnZskU-6BsMRE_ufQQaSAvjc%3D&xsec_source=pc_search\">\n    <img src=\"https://img.shields.io/badge/Xiaohongshu-%23FF2442?style=for-the-badge&logo=xiaohongshu&logoColor=white\"></a>\n  <a href=\"https://www.zhihu.com/org/dou-bao-da-mo-xing-tuan-dui/\">\n    <img src=\"https://img.shields.io/badge/zhihu-%230084FF?style=for-the-badge&logo=zhihu&logoColor=white\"></a>\n</div>\n\n---\n\nWe are HIRING! Send us an [email](mailto:the.verl.project@gmail.com) if you are interested in internship/FTE opportunities in RL for agents.\n"
  },
  {
    "path": "verl_distillation/deploy_env.sh",
    "content": "#!/bin/bash\n# Multi-node Environment Deployment Script\n# Usage: bash deploy_env.sh [--all-nodes]\n\nset -e\n\nSCRIPT_DIR=$(cd $(dirname $0); pwd)\nPROJECT_DIR=${SCRIPT_DIR}\n\n# Configuration\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"distill\"}\nPYTHON_VERSION=${PYTHON_VERSION:-\"3.10\"}\nHOSTFILE=${HOSTFILE:-\"/etc/mpi/hostfile\"}\n\n# Colors\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nlog_info() { echo -e \"${GREEN}[INFO]${NC} $1\"; }\nlog_warn() { echo -e \"${YELLOW}[WARN]${NC} $1\"; }\nlog_error() { echo -e \"${RED}[ERROR]${NC} $1\"; }\n\n# Initialize conda\ninit_conda() {\n    for conda_sh in /root/anaconda3/etc/profile.d/conda.sh \\\n                    /root/miniconda3/etc/profile.d/conda.sh \\\n                    $HOME/anaconda3/etc/profile.d/conda.sh \\\n                    $HOME/miniconda3/etc/profile.d/conda.sh \\\n                    /opt/conda/etc/profile.d/conda.sh; do\n        [ -f \"$conda_sh\" ] && source \"$conda_sh\" && return 0\n    done\n    command -v conda &>/dev/null\n}\n\n# Setup proxy\nsetup_proxy() {\n    log_info \"Setting up proxy...\"\n    unset -v http_proxy https_proxy no_proxy\n    export http_proxy=http://oversea-squid2.ko.txyun:11080\n    export https_proxy=http://oversea-squid2.ko.txyun:11080\n    export no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com\n}\n\n# Install on local node\ninstall_local() {\n    log_info \"Installing environment...\"\n\n    # Setup proxy first\n    setup_proxy\n\n    if ! init_conda; then\n        log_error \"Conda not found.\"\n        exit 1\n    fi\n\n    # Configure conda for stability\n    conda config --set remote_read_timeout_secs 600\n    conda config --set remote_connect_timeout_secs 60\n    conda config --set remote_max_retries 10\n\n    # Create or activate conda env\n    if conda env list | grep -q \"^${CONDA_ENV_NAME} \"; then\n        log_warn \"Environment '${CONDA_ENV_NAME}' exists, activating...\"\n    else\n        log_info \"Creating environment '${CONDA_ENV_NAME}'...\"\n        conda create -n ${CONDA_ENV_NAME} python=${PYTHON_VERSION} -y\n    fi\n\n    source $(conda info --base)/etc/profile.d/conda.sh\n    conda activate ${CONDA_ENV_NAME}\n\n    log_info \"Installing torch...\"\n    pip install torch==2.8.0\n    pip install --force-reinstall torchvision==0.23.0\n    pip install --force-reinstall torchaudio==2.8.0\n\n    # Install requirements\n    log_info \"Installing requirements.txt...\"\n    pip install -r ${PROJECT_DIR}/requirements.txt\n\n    # Install flash-attn separately\n    # log_info \"Installing flash-attn...\"\n    # pip install flash-attn==2.7.4.post1 --no-build-isolation\n    pip install flash-attn --no-build-isolation\n\n    \n    # Install verl package\n    log_info \"Installing verl package...\"\n    cd ${PROJECT_DIR}\n    pip install -e .\n\n    log_info \"Done!\"\n}\n\n# Deploy to all nodes\ndeploy_all_nodes() {\n    [ ! -f \"${HOSTFILE}\" ] && log_error \"Hostfile not found: ${HOSTFILE}\" && exit 1\n\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n    log_info \"Deploying to: ${ALL_NODES}\"\n\n    mkdir -p ./logs/deploy\n    for node in ${ALL_NODES}; do\n        log_info \"Deploying to ${node}...\"\n        ssh -n ${node} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/deploy_env.sh\" \\\n            > \"./logs/deploy/deploy_${node}.log\" 2>&1 &\n    done\n\n    wait\n    log_info \"Deployment completed! 
Check logs in ./logs/deploy/\"\n}\n\n# Main\ncase \"${1}\" in\n    --all-nodes) deploy_all_nodes ;;\n    *) install_local ;;\nesac\n"
  },
  {
    "path": "verl_distillation/docker/Apptainerfile.rocm",
    "content": "Bootstrap: docker\n\n# Support - Traing: fsdp; Inference: vllm\n# FROM: rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n# Support - Traing: fsdp; Inference: vllm, sglang\nFROM lmsysorg/sglang:v0.4.5-rocm630\n\n%environment\n    export PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n    export HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\n    export CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    export CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n\n%post\n    # Create source directory\n    mkdir -p /opt/src\n\n    # Uninstall and reinstall vllm\n    pip uninstall -y vllm\n    cd /opt/src\n    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git\n    cd vllm\n    MAX_JOBS=$(nproc) python3 setup.py install\n    cd /opt\n    rm -rf /opt/src/vllm\n\n    # Install dependencies\n    pip install \"tensordict<0.6\" --no-deps\n    pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        \"ray[data,train,tune,serve]\" \\\n        torchdata \\\n        transformers \\\n        wandb \\\n        orjson \\\n        pybind11\n\n    # Clone and install verl from GitHub\n    cd /opt\n    git clone https://github.com/volcengine/verl.git\n    cd verl\n    # Uncomment to use a specific version\n    # git checkout v0.3.0.post0\n    pip install -e . --no-deps\n\n    # Install torch_memory_saver\n    pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.extention.awsefa",
    "content": "# Base Image support aws EFA\n# Build Image with frameworks based on this\nFROM verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2\n\n# For aws instances with EFA net interface (Sagemaker AI Pod)\n#     install EFA driver:\n######## AWS EFA ############\nENV NCCL_VERSION=2.25.1-1\nENV DEBIAN_FRONTEND=noninteractive\nENV EFA_INSTALLER_VERSION=1.40.0\nENV AWS_OFI_NCCL_VERSION=1.14.2\nENV FI_EFA_SET_CUDA_SYNC_MEMOPS=0\nENV FI_PROVIDER=efa\n\nRUN apt update && apt install -y linux-image-generic libhwloc-dev\n\nRUN cd /tmp && \\\n    curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz  && \\\n    tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \\\n    cd aws-efa-installer && \\\n    ./efa_installer.sh -y -g --skip-kmod --skip-limit-conf --no-verify && \\\n    ldconfig && \\\n    rm -rf /tmp/aws-efa-installer /var/lib/apt/lists/*\n\n# NCCL EFA Plugin\nRUN cd /tmp && \\\n    curl -LO https://github.com/aws/aws-ofi-nccl/archive/refs/tags/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    tar -xzf /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    rm /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    mv aws-ofi-nccl-${AWS_OFI_NCCL_VERSION} aws-ofi-nccl && \\\n    cd /tmp/aws-ofi-nccl && \\\n    ./autogen.sh && \\\n    ./configure --prefix=/opt/amazon/efa \\\n    --with-libfabric=/opt/amazon/efa \\\n    --with-cuda=/usr/local/cuda \\\n    --enable-platform-aws \\\n    --with-mpi=/opt/amazon/openmpi && \\\n    make -j$(nproc) install && \\\n    rm -rf /tmp/aws-ofi/nccl\n\n# NCCL\nRUN echo \"/usr/local/lib\"      >> /etc/ld.so.conf.d/local.conf && \\\n    echo \"/opt/amazon/openmpi/lib\" >> /etc/ld.so.conf.d/efa.conf && \\\n    ldconfig\n\nENV OMPI_MCA_pml=^cm,ucx            \\\n    OMPI_MCA_btl=tcp,self           \\\n    OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent \\\n    OPAL_PREFIX=/opt/amazon/openmpi \\\n    NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent  \\\n    FI_EFA_USE_HUGE_PAGE=0\n\n# docker build -t verl:awsefa --label \"commit=$(git rev-parse --short HEAD)\" .\n# on aws:\n# docker run --ipc=host --privileged --name verldev --gpus all --network=host --shm-size=1800gb -itd verl:awsefa\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.ngc.vllm",
    "content": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:ngc-th2.4.0-cu124-vllm0.6.3-ray2.4-te1.7-v0.0.6\" -f docker/Dockerfile.ngc.vllm . --builder cloud-verlai-verl-builder --progress=plain --push\nFROM nvcr.io/nvidia/pytorch:24.05-py3\n\n# uninstall nv-pytorch fork\nRUN pip3 uninstall pytorch-quantization \\\n    pytorch-triton \\\n    torch \\\n    torch-tensorrt \\\n    torchvision \\\n    xgboost transformer_engine flash_attn \\\n    apex megatron-core -y\n\nRUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124\n\n# =============== Megatron dependencies (optional) =================\n# install apex, set MAX_JOBS to avoid OOMs\nRUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \\\n    --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" \\\n    git+https://github.com/NVIDIA/apex\n# =============== End of Megatron dependencies (optional) =================\n\nRUN pip3 install --no-cache-dir \\\n    accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    numpy \\\n    'pandas' \\\n    'peft' \\\n    'pyarrow>=15.0.0' \\\n    'pybind11' \\\n    'pylatexenc' \\\n    'ray>=2.10' \\\n    'tensordict<0.6' \\\n    'transformers' \\\n    'vllm==0.6.3.post1' \\\n    'wandb' \\\n    'tensorboard'\n\n# full dependencies\nRUN pip3 install pytest pre-commit py-spy pyext liger-kernel\n\n# =============== Megatron dependencies (optional) =================\n# install Transformer Engine, which requires FA 2.5.8. Do it in a separate step for docker cache\nRUN MAX_JOBS=4 NINJA_FLAGS=\"-j4\" pip3 install flash-attn==2.5.8 --no-cache-dir --no-build-isolation\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" TE_BUILD_WITH_NINJA=0 pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0\n# =============== End of Megatron dependencies (optional) =================\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.ngc.vllm0.8",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Install torch-2.6.0+cu124 + vllm-0.8.3\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --no-cache-dir \"vllm==0.8.3\" \"torch==2.6.0\" \"torchvision==0.21.0\" \"torchaudio==2.6.0\" \"tensordict==0.6.2\" torchdata \\\n    \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest py-spy pyext pre-commit ruff tensorboard\n\n# Install flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install verl\nRUN pip install --no-cache-dir verl[vllm] -U\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.ngc.vllm0.8.sagemaker",
    "content": "# Using a pre-built image from AWS DLC which contains the current version of python (3.10) and supported cuda version (12.1)\nFROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.1.0-transformers4.36.0-gpu-py310-cu121-ubuntu20.04\n\n# uninstall nv-pytorch fork\nRUN pip3 uninstall -y pytorch-quantization \\\n    pytorch-triton torch torch-tensorrt torchvision \\\n    xgboost transformer_engine flash_attn apex megatron-core\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Install torch-2.6.0 + vllm-0.8.2\nRUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata==0.11.0 \\\n    transformers>=4.49.0 accelerate datasets peft hf-transfer \\\n    ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest pre-commit py-spy pyext ruff tensorboard\n\n# Install flash_attn-2.7.4.post1\nRUN pip uninstall -y transformer-engine flash-attn && \\\n    pip install flash-attn==2.7.4.post1 --no-build-isolation\n\n# Fix cv2\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \\\n    pip install --no-cache-dir --upgrade optree>=0.13.0\n\n# Install verl\nRUN pip install --no-cache-dir verl[vllm] -U\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.rocm",
    "content": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247\"\n# FROM \"rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04\"\nFROM \"rlsys/rocm-6.3.4-patch:rocm6.3.4-numa-patch_ubuntu-22.04\"\n\nSHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\nENV MAX_JOBS=512\n\nENV PATH=\"/usr/local/python3.12/bin:$PATH\"\nRUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n    ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n############################################\n############################################\nRUN apt-get update\nRUN apt-get install -y pkg-config liblzma-dev\n############################################\n############################################\n\n\n###########################################\n##########Install TransformerEngine########\n###########################################\nWORKDIR /workspace/\n# transformer-engine install\n# https://github.com/ROCm/TransformerEngine\n\nRUN rm -rf TransformerEngine \nRUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\nWORKDIR /workspace/TransformerEngine\nRUN git checkout 236178e5\n# git checkout bb061ade\n# git checkout 864405c\n\nENV NVTE_FRAMEWORK=pytorch \nENV NVTE_ROCM_ARCH=gfx942 \nENV NVTE_USE_HIPBLASLT=1\nENV NVTE_USE_ROCM=1  \n\n# export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\nENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n\n\n# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS)\n\nRUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n####################################################################################\n################Install vllm - sglang require vllm 0.6.7 dependency#################\n####################################################################################\n#### Require vllm 0.6.7 - checkout 113274a0\nWORKDIR /workspace/\nRUN rm -rf vllm\nRUN pip uninstall -y vllm\n# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\nRUN git clone https://github.com/ROCm/vllm.git\n# git clone https://github.com/vllm-project/vllm.git\nWORKDIR /workspace/vllm\nRUN git checkout 113274a0\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n#ENV MAX_JOBS=512\nENV MAX_JOBS=${MAX_JOBS}\nRUN pip install \"boto3>=1.26.0\"\nRUN pip install setuptools_scm\n# will add src into py. 
 the cloned repo can be deleted afterwards\nRUN python3 setup.py install\nWORKDIR /workspace/\n####################################################################################\n####################################################################################\n####################################################################################\n\n\n\n###########################################\n#######Docker build workaround#############\n###########################################\nRUN pip install setuptools==75.8.0\n###########################################\n###########################################\n###########################################\n\n\n\n###########################################\n############build sglang###################\n###########################################\n# Set environment variables\nENV BASE_DIR=/sgl-workspace\nENV BUILD_TYPE=all\nENV SGL_REPO=https://github.com/sgl-project/sglang\nENV SGL_BRANCH=v0.4.6.post5\nENV TRITON_REPO=https://github.com/ROCm/triton.git\nENV TRITON_COMMIT=improve_fa_decode_3.0.0\nENV AITER_REPO=https://github.com/ROCm/aiter.git\nENV AITER_COMMIT=v0.1.2\n# v0.1.2 version - commit id: 9d11f47\n# ENV AITER_COMMIT=9d11f47\n\nENV HIP_FORCE_DEV_KERNARG=1\nENV HSA_NO_SCRATCH_RECLAIM=1\nENV SGLANG_SET_CPU_AFFINITY=1\nENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\nENV NCCL_MIN_NCHANNELS=112\nENV MOE_PADDING=1\nENV VLLM_FP8_PADDING=1\nENV VLLM_FP8_ACT_PADDING=1\nENV VLLM_FP8_WEIGHT_PADDING=1\nENV VLLM_FP8_REDUCE_CONV=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\nENV AMDGPU_TARGETS=gfx942\nENV ROCM_ARCH=gfx942\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n# Switch to working directory\nWORKDIR /sgl-workspace\n\n# Clean and create directory\nRUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n# Clone and build sglang\nRUN git clone ${SGL_REPO} \\\n    && cd sglang \\\n    && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n    && cd sgl-kernel \\\n    && rm -f pyproject.toml \\\n    && mv pyproject_rocm.toml pyproject.toml \\\n    && python setup_rocm.py install \\\n    && cd .. 
\\\n    && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n         python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n       else \\\n         python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n       fi \\\n    && cd /sgl-workspace \\\n    && cp -r /sgl-workspace/sglang /sglang \\\n    && python -m pip cache purge\n\n# Install common Python packages\nRUN pip install IPython orjson python-multipart torchao pybind11\n\n# Rebuild Triton\nRUN pip uninstall -y triton || true \\\n    && git clone ${TRITON_REPO} \\\n    && cd triton \\\n    && git checkout ${TRITON_COMMIT} \\\n    && cd python \\\n    && python3 setup.py install \\\n    && cd /sgl-workspace\n\n\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n# Build aiter\n#version: Commit 9d11f47\n    # && git checkout ${AITER_COMMIT} \\\nRUN pip uninstall -y aiter || true\nRUN git clone ${AITER_REPO} \\\n    && cd aiter \\\n    && git checkout ${AITER_COMMIT} \\\n    && git submodule sync \\\n    && git submodule update --init --recursive \\\n    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n    && cd /sgl-workspace\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n\n# Copy MI300X config \nRUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n         /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n         -type f -name '*MI300X*' | \\\n         xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n# Environment setup complete.\nRUN echo \"Environment setup complete.\"\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n\n###########################################\n###############vllm v0.8.5#################\n###########################################\n# ENV GITHUB_USERNAME=yushengsu-thu\n# ENV GITHUB_MAIL=yushengsu@gmail.com\n\n# RUN git config --global user.name \"${GITHUB_USERNAME}\" \\\n#     && git config --global user.email \"${GITHUB_MAIL}\" \n\nWORKDIR /workspace/\n\nENV VLLM_TARGET_DEVICE=rocm \nENV ROCM_PATH=/opt/rocm \nENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n\n# Find the repo path in: DockerFile/Dockerfile.rocm_yang\n# RUN git clone https://github.com/RLFoundation/vllm-patch.git\nRUN pip uninstall -y vllm || true\nRUN rm -rf vllm-patch\nRUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n    && cd vllm-patch \\\n    && git checkout v0.8.5-sleep-numa \\\n    && rm -rf build/ dist/ *.egg-info \\\n    && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n    && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n    # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n#########################################\n#### Install megatron-core###############\n#########################################\nRUN pip uninstall -y megatron-core && \\\n    git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git 
&& \\\n    cd Megatron-LM-amd_version && \\\n    pip install -vvv -e . && \\\n    cd /workspace/\n#########################################\n#########################################\n#########################################\n\n\n\n\n#######################################\n################apex###################\n#######################################\nWORKDIR /workspace/\nRUN pip uninstall -y apex && \\\n    git clone https://github.com/ROCm/apex.git && \\\n    cd apex && \\\n    python setup.py install && \\\n    cd /workspace/ \n#######################################\n#######################################\n#######################################\n\n\n\n\n################################################################################\n###########################Add torch_memory_saver###############################\n################################################################################\n# Set environment variables\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nRUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n################################################################################\n################################################################################\n################################################################################\n\n\n\n########################################\n######Install ray#######################\n########################################\n# need to add this patch: https://github.com/ray-project/ray/pull/53531/files\nRUN pip uninstall ray -y\nRUN pip install \"ray[data,train,tune,serve]>=2.47.0\" \n########################################\n########################################\n########################################\n\n\n\n##########################################\n#######Install other dependencies#########\n##########################################\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    torchdata \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nWORKDIR /workspace/\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n##########################################\n##########################################\n##########################################\n\n\n\nWORKDIR /workspace/\n\nCMD [\"/usr/bin/bash\"]\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.rocm7",
    "content": "# default base image\nARG REMOTE_VLLM=\"1\"\nARG COMMON_WORKDIR=/app\nARG BASE_IMAGE=rocm/vllm-dev:base\n\nFROM ${BASE_IMAGE} AS base\n\nARG ARG_PYTORCH_ROCM_ARCH\nENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}\n\n# Install some basic utilities\nRUN apt-get update -q -y && apt-get install -q -y \\\n    sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \\\n    apt-transport-https ca-certificates wget curl\n# Remove sccache\nRUN python3 -m pip install --upgrade pip\nRUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f \"$(which sccache)\"\nARG COMMON_WORKDIR\nWORKDIR ${COMMON_WORKDIR}\n\n\n# -----------------------\n# vLLM fetch stages\nFROM base AS fetch_vllm_0\nONBUILD COPY ./ vllm/\nFROM base AS fetch_vllm_1\n#ARG VLLM_REPO=\"https://github.com/ROCm/vllm.git\"\n#ARG VLLM_BRANCH=\"main\"\nARG VLLM_REPO=https://github.com/HollowMan6/vllm.git\nARG VLLM_BRANCH=\"sleep_amd\"\nONBUILD RUN git clone ${VLLM_REPO} \\\n            && cd vllm \\\n            && git checkout ${VLLM_BRANCH}\nFROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm\n\n# -----------------------\n# vLLM build stages\nFROM fetch_vllm AS build_vllm\n# Build vLLM\nRUN cd vllm \\\n    && python3 -m pip install -r requirements/rocm.txt \\\n    && python3 setup.py clean --all  \\\n    && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n    && VLLM_TARGET_DEVICE=rocm ROCM_PATH=/opt/rocm/ VLLM_GPU_LANG=HIP SETUPTOOLS_SCM_PRETEND_VERSION=0.11.0.dev python3 setup.py bdist_wheel --dist-dir=dist\n    #&& python3 setup.py bdist_wheel --dist-dir=dist\nFROM scratch AS export_vllm\nARG COMMON_WORKDIR\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite\n\n# -----------------------\n# Test vLLM image\nFROM base AS test\n\nRUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*\n\n# Install vLLM\n#RUN --mount=type=bind,from=export_vllm,src=/,target=/install \\\nCOPY --from=export_vllm /*.whl /install\nCOPY --from=export_vllm /requirements /install/requirements\nCOPY --from=export_vllm /benchmarks /install/benchmarks\nCOPY --from=export_vllm /tests /install/tests\nCOPY --from=export_vllm /examples /install/examples\nCOPY --from=export_vllm /.buildkite /install/.buildkite\n\nRUN cd /install \\\n    && pip install -U -r requirements/rocm.txt \\\n    && pip install -U -r requirements/rocm-test.txt \\\n    && pip uninstall -y vllm \\\n    && pip install *.whl\n\nWORKDIR /vllm-workspace\nARG COMMON_WORKDIR\nCOPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace\n\n# install development dependencies (for testing)\nRUN cd /vllm-workspace \\\n    && rm -rf vllm \\\n    && python3 -m pip install -e tests/vllm_test_utils \\\n    && python3 -m pip install lm-eval[api]==0.4.4 \\\n    && python3 -m pip install pytest-shard\n\n# -----------------------\n# Final vLLM image\nFROM base AS final\n\nRUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*\n# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.\n# Manually remove it so that later steps of numpy upgrade can continue\nRUN case \"$(which python3)\" in \\\n        
*\"/opt/conda/envs/py_3.9\"*) \\\n            rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \\\n        *) ;; esac\n\nRUN python3 -m pip install --upgrade huggingface-hub[cli]\n\n# Install vLLM\nRUN --mount=type=bind,from=export_vllm,src=/,target=/install \\\n    cd /install \\\n    && pip install -U -r requirements/rocm.txt \\\n    && pip uninstall -y vllm \\\n    && pip install *.whl\n\nARG COMMON_WORKDIR\n\n# Copy over the benchmark scripts as well\nCOPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks\nCOPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples\n\nENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1\nENV TOKENIZERS_PARALLELISM=false\n\n# ENV that can improve safe tensor loading, and end-to-end time\nENV SAFETENSORS_FAST_GPU=1\n\n# Performance environment variable.\nENV HIP_FORCE_DEV_KERNARG=1\n\n# -----------------------\n# Install verl\nARG VERL_REPO=https://github.com/volcengine/verl.git\nARG VERL_BRANCH=main\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    torchdata \\\n    wandb \\\n    orjson \\\n    pybind11\n\nWORKDIR /workspace/\nRUN git clone ${VERL_REPO} && \\\n    cd verl && \\\n    git checkout ${VERL_BRANCH} && \\\n    pip install -e .\n\nCMD [\"/bin/bash\"]\n\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.rocm_verl-0.3.0.post1",
    "content": "#  Build the docker in the repo dir:\n# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .\n# docker images # you can find your built docker\n\n\n# Support - Traing: fsdp; Inference: vllm\n# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n# Support - Traing: fsdp; Inference: vllm, sglang\nFROM lmsysorg/sglang:v0.4.6.post5-rocm630\n\n# Set working directory\n# WORKDIR $PWD/app\n\n# Set environment variables\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n\n# Install vllm\nRUN pip uninstall -y vllm && \\\n    rm -rf vllm && \\\n    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \\\n    cd vllm && \\\n    MAX_JOBS=$(nproc) python3 setup.py install && \\\n    cd .. && \\\n    rm -rf vllm\n\n# Copy the entire project directory\nCOPY . .\n\n# Install dependencies\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    \"ray[data,train,tune,serve]<2.45.0\" \\\n    torchdata \\\n    transformers \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n\n# Install torch_memory_saver\nRUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.rocm_verl-0.4.1",
    "content": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247\"\n# FROM \"rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04\"\nFROM \"rlsys/rocm-6.3.4-patch:rocm6.3.4-numa-patch_ubuntu-22.04\"\n\nSHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\nENV MAX_JOBS=512\n\nENV PATH=\"/usr/local/python3.12/bin:$PATH\"\nRUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n    ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n############################################\n############################################\nRUN apt-get update\nRUN apt-get install -y pkg-config liblzma-dev\n############################################\n############################################\n\n\n###########################################\n##########Install TransformerEngine########\n###########################################\nWORKDIR /workspace/\n# transformer-engine install\n# https://github.com/ROCm/TransformerEngine\n\nRUN rm -rf TransformerEngine \nRUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\nWORKDIR /workspace/TransformerEngine\nRUN git checkout 236178e5\n# git checkout bb061ade\n# git checkout 864405c\n\nENV NVTE_FRAMEWORK=pytorch \nENV NVTE_ROCM_ARCH=gfx942 \nENV NVTE_USE_HIPBLASLT=1\nENV NVTE_USE_ROCM=1  \n\n# export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\nENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n\n\n# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS)\n\nRUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n####################################################################################\n################Install vllm - sglang require vllm 0.6.7 dependency#################\n####################################################################################\n#### Require vllm 0.6.7 - checkout 113274a0\nWORKDIR /workspace/\nRUN rm -rf vllm\nRUN pip uninstall -y vllm\n# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\nRUN git clone https://github.com/ROCm/vllm.git\n# git clone https://github.com/vllm-project/vllm.git\nWORKDIR /workspace/vllm\nRUN git checkout 113274a0\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n#ENV MAX_JOBS=512\nENV MAX_JOBS=${MAX_JOBS}\nRUN pip install \"boto3>=1.26.0\"\nRUN pip install setuptools_scm\n# will add src into py. 
You can delete the repo\nRUN python3 setup.py install\nWORKDIR /workspace/\n####################################################################################\n####################################################################################\n####################################################################################\n\n\n\n###########################################\n############For docker hack################\n###########################################\nRUN pip install setuptools==75.8.0\n###########################################\n###########################################\n###########################################\n\n\n\n###########################################\n############build sglang###################\n###########################################\n# Set environment variables\nENV BASE_DIR=/sgl-workspace\nENV BUILD_TYPE=all\nENV SGL_REPO=https://github.com/sgl-project/sglang\nENV SGL_BRANCH=v0.4.6.post5\nENV TRITON_REPO=https://github.com/ROCm/triton.git\nENV TRITON_COMMIT=improve_fa_decode_3.0.0\nENV AITER_REPO=https://github.com/ROCm/aiter.git\nENV AITER_COMMIT=v0.1.2\n# v0.1.2 version - commit id: 9d11f47\n# ENV AITER_COMMIT=9d11f47\n\nENV HIP_FORCE_DEV_KERNARG=1\nENV HSA_NO_SCRATCH_RECLAIM=1\nENV SGLANG_SET_CPU_AFFINITY=1\nENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\nENV NCCL_MIN_NCHANNELS=112\nENV MOE_PADDING=1\nENV VLLM_FP8_PADDING=1\nENV VLLM_FP8_ACT_PADDING=1\nENV VLLM_FP8_WEIGHT_PADDING=1\nENV VLLM_FP8_REDUCE_CONV=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\nENV AMDGPU_TARGETS=gfx942\nENV ROCM_ARCH=gfx942\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n# Switch to working directory\nWORKDIR /sgl-workspace\n\n# Clean and create directory\nRUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n# Clone and build sglang\nRUN git clone ${SGL_REPO} \\\n    && cd sglang \\\n    && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n    && cd sgl-kernel \\\n    && rm -f pyproject.toml \\\n    && mv pyproject_rocm.toml pyproject.toml \\\n    && python setup_rocm.py install \\\n    && cd .. 
\\\n    && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n         python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n       else \\\n         python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n       fi \\\n    && cd /sgl-workspace \\\n    && cp -r /sgl-workspace/sglang /sglang \\\n    && python -m pip cache purge\n\n# Install common Python packages\nRUN pip install IPython orjson python-multipart torchao pybind11\n\n# Rebuild Triton\nRUN pip uninstall -y triton || true \\\n    && git clone ${TRITON_REPO} \\\n    && cd triton \\\n    && git checkout ${TRITON_COMMIT} \\\n    && cd python \\\n    && python3 setup.py install \\\n    && cd /sgl-workspace\n\n\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n# Build aiter\n#version: Commit 9d11f47\n    # && git checkout ${AITER_COMMIT} \\\nRUN pip uninstall -y aiter || true\nRUN git clone ${AITER_REPO} \\\n    && cd aiter \\\n    && git checkout ${AITER_COMMIT} \\\n    && git submodule sync \\\n    && git submodule update --init --recursive \\\n    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n    && cd /sgl-workspace\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n\n# Copy MI300X config \nRUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n         /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n         -type f -name '*MI300X*' | \\\n         xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n# Environment setup complete.\nRUN echo \"Environment setup complete.\"\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n\n###########################################\n###############vllm v0.8.5#################\n###########################################\n# ENV GITHUB_USERNAME=yushengsu-thu\n# ENV GITHUB_MAIL=yushengsu@gmail.com\n\n# RUN git config --global user.name \"${GITHUB_USERNAME}\" \\\n#     && git config --global user.email \"${GITHUB_MAIL}\" \n\nWORKDIR /workspace/\n\nENV VLLM_TARGET_DEVICE=rocm \nENV ROCM_PATH=/opt/rocm \nENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n\n# Find the repo path in: DockerFile/Dockerfile.rocm_yang\n# RUN git clone https://github.com/RLFoundation/vllm-patch.git\nRUN pip uninstall -y vllm || true\nRUN rm -rf vllm-patch\nRUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n    && cd vllm-patch \\\n    && git checkout v0.8.5-sleep-numa \\\n    && rm -rf build/ dist/ *.egg-info \\\n    && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n    && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n    # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n#########################################\n#### Install megatron-core###############\n#########################################\nRUN pip uninstall -y megatron-core && \\\n    git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git 
&& \\\n    cd Megatron-LM-amd_version && \\\n    pip install -vvv -e . && \\\n    cd /workspace/\n#########################################\n#########################################\n#########################################\n\n\n\n\n#######################################\n################apex###################\n#######################################\nWORKDIR /workspace/\nRUN pip uninstall -y apex && \\\n    git clone https://github.com/ROCm/apex.git && \\\n    cd apex && \\\n    python setup.py install && \\\n    cd /workspace/ \n#######################################\n#######################################\n#######################################\n\n\n\n\n################################################################################\n###########################Add torch_memory_saver###############################\n################################################################################\n# Set environment variables\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nRUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n################################################################################\n################################################################################\n################################################################################\n\n\n\n########################################\n######Install ray#######################\n########################################\n# need to add this patch: https://github.com/ray-project/ray/pull/53531/files\nRUN pip uninstall ray -y\nRUN pip install \"ray[data,train,tune,serve]>=2.47.0\" \n########################################\n########################################\n########################################\n\n\n\n##########################################\n#######Install other dependencies#########\n##########################################\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    torchdata \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nWORKDIR /workspace/\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n##########################################\n##########################################\n##########################################\n\n\n\nWORKDIR /workspace/\n\nCMD [\"/usr/bin/bash\"]\nCMD [\"/usr/bin/bash\"]\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.sglang",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.ustc.edu.cn/ubuntu/\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Change pip source\nARG PIP_INDEX=https://mirrors.aliyun.com/pypi/simple/\n\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip uninstall -y cuda-python && pip install \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Install torch-2.6.0\nRUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \\\n    transformers>=4.49.0 accelerate datasets peft hf_transfer \\\n    ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel \\\n    pytest pre-commit py-spy pyext\n\n# Install flash_attn-2.7.4.post1\nRUN pip uninstall -y transformer-engine flash-attn && \\\n    wget -v https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix cv2\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.vemlp.vllm.te",
    "content": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:$TAG\" -f docker/$FILE .\n\n# the one in docker.io is an alias for the one veturbo\n# FROM vemlp-cn-beijing.cr.volces.com/veturbo/pytorch:2.4-cu124\nFROM docker.io/haibinlin/verl:v0.0.5-th2.4.0-cu124-base\n\n# only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed\n# unset for now\nRUN pip3 config unset global.index-url\n\n# transformers 4.47.0 contains the following bug:\n# AttributeError: 'Gemma2Attention' object has no attribute '_flash_attn_uses_top_left_mask'\nRUN pip3 install --no-cache-dir \\\n    torch==2.4.0 \\\n    accelerate \\\n    codetiming \\\n    dill \\\n    hydra-core \\\n    numpy \\\n    pybind11 \\\n    tensordict \\\n    \"transformers <= 4.46.0\"\n\nRUN pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation\n\n# vllm depends on ray\nRUN pip3 install --no-cache-dir vllm==0.6.3 ray==2.10\n\n# install apex\nRUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \\\n    --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" \\\n    git+https://github.com/NVIDIA/apex\n\n# install Transformer Engine\n# - flash-attn pinned to 2.5.3 by TransformerEngine, switch to eric-haibin-lin/TransformerEngine.git@v1.7.0 to relax version req\n# - install with: MAX_JOBS=1 NINJA_FLAGS=\"-j1\" TE_BUILD_WITH_NINJA=0 to avoid OOM\n# - cudnn is required by TransformerEngine\n# RUN CUDNN_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cudnn \\\n#     pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" pip3 install flash-attn==2.5.3 --no-cache-dir --no-build-isolation\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7\n"
  },
  {
    "path": "verl_distillation/docker/Dockerfile.vllm.sglang.megatron.deepseek",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Reinstall CUDA 12.4\nRUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \\\n    mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600\n\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cuda-toolkit-12-4 && \\\n    rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    update-alternatives --set cuda /usr/local/cuda-12.4 && \\\n    rm -rf /usr/local/cuda-12.6\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\n# Install sglang-0.4.6.post1 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install --resume-retries 999 torch-memory-saver --no-cache-dir\n\nRUN pip install --resume-retries 999 --no-cache-dir \"vllm==0.8.5.post1\" \"torch==2.6.0\" \"torchvision==0.21.0\" \"torchaudio==2.6.0\" \"tensordict==0.6.2\" torchdata\n\nRUN pip install --resume-retries 999 --no-cache-dir \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install 
flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install Apex\nRUN git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1\n\n# Fix opencv\nRUN pip install opencv-python\n\nRUN pip install opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\n# Install verl\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\nRUN apt-get update && \\\n    apt-get install -y aria2 libfreeimage3 libfreeimage-dev zlib1g"
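# The "Install verl" step above is left as a placeholder; a minimal sketch of what it
# could look like, reusing the clone-and-editable-install pattern from the other
# Dockerfiles in this directory:
#   RUN git clone https://github.com/volcengine/verl.git && \
#       cd verl && pip3 install --no-deps -e .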
  },
  {
    "path": "verl_distillation/docker/README.md",
    "content": "# Dockerfiles of verl\n\nWe provide pre-built Docker images for quick setup. And from this version, we utilize a new image release hierarchy for productivity and stability.\n\nThe image types are divided into three large categories:\n\n- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA.\n- **Application Image**: Stable version with inference and training frameworks installed.\n- **Preview Image**: Unstable version with the latest frameworks and features.\n\nThe first two types of images are hosted on dockerhub [verlai/verl](https://hub.docker.com/r/verlai/verl) repository, while the preview images are hosted on community repository.\n\n> The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``.\n\n## Base Image\n\nThe stable base image is ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4`` with different CUDA versions.\n\nThe update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages.\n\n## Application Image\n\nFrom this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer.\nThere are 2 types of application images available:\n\n- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2``\n- **SGLang with FSDP and Megatron**: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`\n\nDocker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details.\n\nApplication images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks.\n\n## Community Image\n\nFor vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyouga/verl) repository and the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.\n\nFor SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.\n\nFor latest vLLM with Megatron, please refer to [iseekyan/verl](https://hub.docker.com/r/iseekyan/verl) repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.\n\nSee files under ``docker/`` for NGC-based image or if you want to build your own.\n\nNote that For aws instances with EFA net interface (Sagemaker AI Pod), you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``\n\n## Installation from Docker\n\nAfter pulling the desired Docker image and installing desired inference and training frameworks, you can run it with the following steps:\n\n1. Launch the desired Docker image and attach into it:\n\n```sh\ndocker create --runtime=nvidia --gpus all --net=host --shm-size=\"10g\" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl <image:tag> sleep infinity\ndocker start verl\ndocker exec -it verl bash\n```\n\n2. 
If you use the images provided, you only need to install verl itself, without dependencies:\n\n```sh\n# install the nightly version (recommended)\ngit clone https://github.com/volcengine/verl && cd verl\npip3 install --no-deps -e .\n```\n\n[Optional] If you want to switch between different frameworks, you can install verl with the following commands:\n\n```sh\n# install the nightly version with framework extras\ngit clone https://github.com/volcengine/verl && cd verl\npip3 install -e .[vllm]\npip3 install -e .[sglang]\n```\n"
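To sanity-check the installation from inside the container, a minimal sketch (which backend imports cleanly depends on the application image you chose):

```sh
python3 -c "import verl; print(verl.__file__)"
python3 -c "import vllm"   # or: import sglang
```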
  },
  {
    "path": "verl_distillation/docker/ascend/Dockerfile.ascend_8.2.rc1_a2",
    "content": "# 1. Base Image\nFROM swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.2.rc1-910b-ubuntu22.04-py3.11\n\n# 2. Pre-installation foundation vllm with architecture echo\nRUN ARCH=$(uname -m) && \\\n    echo \"export ARCH=$ARCH\" >> ~/.bashrc && \\\n    echo \"[LOG INFO] Current system architecture: $ARCH\"\n\n# 3. Install system dependencies\nRUN apt-get update -y && apt-get install -y --no-install-recommends \\\n    gcc g++ cmake libnuma-dev wget git curl jq vim build-essential \\\n    && rm -rf /var/lib/apt/lists/*\n\n# 4. Install vllm\nRUN ARCH=$(uname -m) && \\\n    echo \"[LOG INFO] Detected architecture: $ARCH\" && \\\n    if [ \"$ARCH\" = \"x86_64\" ]; then \\\n        echo \"[LOG INFO] Entering x86_64 branch: Setting pip extra index url\"; \\\n        pip config set global.extra-index-url \"https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi\"; \\\n    else \\\n        echo \"[LOG INFO] Entering aarch64 branch: No extra pip index url set\"; \\\n    fi && \\\n    git clone --depth 1 --branch v0.9.1 https://github.com/vllm-project/vllm && \\\n    cd vllm && \\\n    VLLM_TARGET_DEVICE=empty pip install -v -e . && \\\n    cd ..\n\n# 5. Install vllm_ascend\nRUN ARCH=$(uname -m) && \\\n    echo \"[LOG INFO] Configuring LD_LIBRARY_PATH for $ARCH\" && \\\n    if [ \"$ARCH\" = \"aarch64\" ]; then \\\n        export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/8.2.RC1/aarch64-linux/devlib/linux/aarch64:$LD_LIBRARY_PATH; \\\n    elif [ \"$ARCH\" = \"x86_64\" ]; then \\\n        export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/8.2.RC1/x86_64-linux/devlib/linux/x86_64/:$LD_LIBRARY_PATH; \\\n    fi && \\\n    source /usr/local/Ascend/ascend-toolkit/set_env.sh && \\\n    source /usr/local/Ascend/nnal/atb/set_env.sh && \\\n    git clone --depth 1 --branch v0.9.1 https://github.com/vllm-project/vllm-ascend.git && \\\n    cd vllm-ascend && \\\n    pip install -v -e . && \\\n    cd ..\n\n# 6. Install verl\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    git checkout main && \\\n    pip install -r requirements-npu.txt && \\\n    pip install -e . && \\\n    cd ..\n\n# 7. Install MindSpeed\nRUN git clone https://gitcode.com/Ascend/MindSpeed.git && \\\n    cd MindSpeed && \\\n    git checkout f2b0977e && \\\n    cd .. && \\\n    pip install -e MindSpeed\n\n# 8. Install Megatron-LM and configure PYTHONPATH\nRUN git clone https://github.com/NVIDIA/Megatron-LM.git && \\\n    cd Megatron-LM && \\\n    git checkout core_v0.12.1 && \\\n    cd .. && \\\n    echo \"export PYTHONPATH=\\$PYTHONPATH:/Megatron-LM\" >> ~/.bashrc\n\n# Show pip list and clear pip cache to reduce image size\nRUN pip list && pip cache purge\n\n# Setting Default Commands\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "verl_distillation/docker/ascend/Dockerfile.ascend_8.2.rc1_a3",
    "content": "# 1. Base Image\nFROM swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.2.rc1-a3-ubuntu22.04-py3.11\n\n# 2. Pre-installation foundation vllm with architecture echo\nRUN ARCH=$(uname -m) && \\\n    echo \"export ARCH=$ARCH\" >> ~/.bashrc && \\\n    echo \"[LOG INFO] Current system architecture: $ARCH\"\n\n# 3. Install system dependencies\nRUN apt-get update -y && apt-get install -y --no-install-recommends \\\n    gcc g++ cmake libnuma-dev wget git curl jq vim build-essential \\\n    && rm -rf /var/lib/apt/lists/*\n\n# 4. Install vllm\nRUN ARCH=$(uname -m) && \\\n    echo \"[LOG INFO] Detected architecture: $ARCH\" && \\\n    if [ \"$ARCH\" = \"x86_64\" ]; then \\\n        echo \"[LOG INFO] Entering x86_64 branch: Setting pip extra index url\"; \\\n        pip config set global.extra-index-url \"https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi\"; \\\n    else \\\n        echo \"[LOG INFO] Entering aarch64 branch: No extra pip index url set\"; \\\n    fi && \\\n    git clone --depth 1 --branch v0.9.1 https://github.com/vllm-project/vllm && \\\n    cd vllm && \\\n    VLLM_TARGET_DEVICE=empty pip install -v -e . && \\\n    cd ..\n\n# 5. Install vllm_ascend\nRUN ARCH=$(uname -m) && \\\n    echo \"[LOG INFO] Configuring LD_LIBRARY_PATH for $ARCH\" && \\\n    if [ \"$ARCH\" = \"aarch64\" ]; then \\\n        export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/8.2.RC1/aarch64-linux/devlib/linux/aarch64:$LD_LIBRARY_PATH; \\\n    elif [ \"$ARCH\" = \"x86_64\" ]; then \\\n        export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/8.2.RC1/x86_64-linux/devlib/linux/x86_64/:$LD_LIBRARY_PATH; \\\n    fi && \\\n    source /usr/local/Ascend/ascend-toolkit/set_env.sh && \\\n    source /usr/local/Ascend/nnal/atb/set_env.sh && \\\n    git clone --depth 1 --branch v0.9.1 https://github.com/vllm-project/vllm-ascend.git && \\\n    cd vllm-ascend && \\\n    pip install -v -e . && \\\n    cd ..\n\n# 6. Install verl\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    git checkout main && \\\n    pip install -r requirements-npu.txt && \\\n    pip install -e . && \\\n    cd ..\n\n# 7. Install MindSpeed\nRUN git clone https://gitcode.com/Ascend/MindSpeed.git && \\\n    cd MindSpeed && \\\n    git checkout f2b0977e && \\\n    cd .. && \\\n    pip install -e MindSpeed\n\n# 8. Install Megatron-LM and configure PYTHONPATH\nRUN git clone https://github.com/NVIDIA/Megatron-LM.git && \\\n    cd Megatron-LM && \\\n    git checkout core_v0.12.1 && \\\n    cd .. && \\\n    echo \"export PYTHONPATH=\\$PYTHONPATH:/Megatron-LM\" >> ~/.bashrc\n\n# Show pip list and clear pip cache to reduce image size\nRUN pip list && pip cache purge\n\n# Setting Default Commands\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
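# A post-build import check could follow here; a sketch (the Python module name
# `deep_ep` is assumed from the upstream DeepEP project):
#   RUN python -c "import deep_ep"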
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. 
Errors occur when these are not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-v2-cu124-cudnn9.8-torch2.6-fa2.8.0-te2.3\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Reinstall CUDA 12.4\nRUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \\\n    mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600\n\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cuda-toolkit-12-4 && \\\n    rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    update-alternatives --set cuda /usr/local/cuda-12.4 && \\\n    rm -rf /usr/local/cuda-12.6\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix 
packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0 && \\\n    dpkg -i ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\n# Fix opencv\nRUN pip install --resume-retries 999 --no-cache-dir opencv-python\n\nRUN pip install --resume-retries 999 --no-cache-dir opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\nRUN pip install --resume-retries 999 --no-cache-dir cuda-bindings\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\nRUN apt-get update && \\\n    apt-get install -y libfreeimage3 libfreeimage-dev zlib1g htop\n\n"
  },
  {
    "path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md",
    "content": "# verl image with verl v0.4.x\n\n## Important packages version\n\n```txt\ncuda==12.4\ncudnn==9.8.0\ntorch==2.6.0\nflash_attn=2.7.4\nsglang==0.4.6.post5\nvllm==0.8.5.post1\nnvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image: \n    - `verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4`\n- App image:\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2`: SGLang requires vLLM in 0.4.6.post5 version, vLLM can have some package conflicts with SGLang\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2-deepep`: Built with deepep\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2`\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2-deepep`: Built with deepep\n- Preview image:\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.13.0-te2.2-preview`\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.13.0-te2.2-preview`"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.sglang0.4.10.post2.mcore0.13",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.10\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.9rc1\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation \"sglang[all]==0.4.10.post2\"\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]==4.55.4\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.0\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.sglang0.4.9.post6.mcore0.13",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.10\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.9rc1\nRUN pip install --resume-retries 999  --no-cache-dir --no-build-isolation \"sglang[all]==0.4.9.post6\"\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]==4.55.4\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.0\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.vllm.mcore0.13",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.7.1+cu126 + vllm-0.10.0\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.10.0\n\n# Fix packages\n# transformers 4.54.0 still not support\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.55.4\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.0\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Fix qwen vl\nRUN pip3 install --no-cache-dir --no-deps trl"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.vllm.mcore0.15",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM iseekyan/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4-h100\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.7.1+cu126 + vllm-0.10.0\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.10.0\n\n# Fix packages\n# transformers 4.54.0 still not support\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.55.4\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.7\nRUN pip install onnxscript\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.15.0rc4\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge==v0.15.0\n\n# Fix qwen vl\nRUN pip3 install --no-cache-dir --no-deps trl"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.base.torch2.7.1",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1\n\n# Install flash-attn-2.7.4.post1, although built with torch2.6, it is compatible with torch2.7\n# https://github.com/Dao-AILab/flash-attention/issues/1644#issuecomment-2899396361\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" 
--config-settings \"--build-option=--cuda_ext\" --resume-retries 999 git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.52.3\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.7.4.post1\nsglang==0.4.9.post6\nvllm==0.8.5.post1\nnvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image:\n  - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4`: We offer a base image with deep ep built in, for vllm/sglang\n- App image:\n  - `verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2`\n  - `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`\n  - `iseekyan/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.15.0-te2.7`\n"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1\n\n# Install flash-attn-2.8.0.post2 (cxx11abi=True)\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" --resume-retries 999 git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c 
--always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.53\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0    ##\nsglang==0.4.8\nvllm==0.8.5.post1\nnvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image:\n    - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with deep ep built in\n- App image:\n    - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.12.2`\n    - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.13.0-preview`\n- vllm temporarily not support latest version"
  },
  {
    "path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:25.02-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128\n\n# Install flash-attn-2.8.0.post2 (cxx11abi=True)\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" --resume-retries 999 
git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pre-commit ruff\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.8\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0    ##\nsglang==0.4.8\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\nnvidia-cudnn-cu12==9.8.0.87\n```\n\n## Target\n\n- Base image:\n    - `verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with flash infer 0.2.6.post1 built in\n- App image:\n    - `verlai/verl:app-verl0.5-preview-sglang0.4.8-mcore0.13.0-preview`\n- vllm temporarily not support latest version\n\n## !!!Notice!!!\n\n- pyext is lack of maintainace and cannot work with python 3.12, consider using replacement and deprecating this package."
  },
  {
    "path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.app.sglang",
    "content": "FROM verlai/verl:base-verl0.6-cu128-cudnn9.8-torch2.8.0-fa2.7.4\n\nRUN pip install --no-cache-dir \"sglang[all]==0.5.2\"\nRUN pip install --no-cache-dir \"torch-memory-saver==0.0.9rc1\"\n"
  },
  {
    "path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.base",
    "content": "# Start from the NVIDIA official image (ubuntu-24.04 + cuda-12.8 + python-3.12)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-03.html\nFROM nvcr.io/nvidia/pytorch:25.03-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\nENV PIP_CONSTRAINT=\"\"\n\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    pip config set global.no-cache-dir \"true\" && \\\n    python -m pip install --upgrade pip\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install libxml2\nRUN apt-get update && \\\n    apt-get install -y libxml2 aria2 && \\\n    apt-get clean\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    transformer_engine flash_attn apex megatron-core \\\n    xgboost opencv grpcio\n\n# Fix packages\nRUN pip install --no-cache-dir tensordict torchdata \"transformers[hf_xet]==4.55.4\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pre-commit ruff\n\n# Fix cv2\nRUN rm -rf /usr/local/lib/python3.11/dist-packages/cv2\n\n# Install torch\nRUN pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu128\n\n# Install flash-attn\nRUN pip install --no-cache-dir --no-build-isolation flash_attn==2.7.4.post1\n\n# Install DeepEP\n# the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n# Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\n## Build deepep-nvshmem\nRUN apt-get install -y ninja-build cmake\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\nENV GDRCOPY_INCLUDE=/workspace/gdrcopy/include\n\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" git+https://github.com/NVIDIA/apex.git\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN git clone -b core_v0.13.0 https://github.com/NVIDIA/Megatron-LM.git && \\\n    cd Megatron-LM && pip3 install --no-deps -e .\n\n# Install mbridge\nRUN pip3 install --no-cache-dir git+https://github.com/ISEEKYAN/mbridge.git\n"
  },
  {
    "path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.vllm011.mcore_gpt-oss",
    "content": "FROM nvcr.io/nvidia/nemo:25.07.gpt_oss\n\nRUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm\n\nRUN pip install setuptools_scm\n\nRUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .\n\nRUN pip install cbor2 setproctitle blake3 openai_harmony pybase64 msgspec partial_json_parser py-cpuinfo diskcache gguf\n\nRUN pip install --upgrade transformers tokenizers\n\nRUN pip install codetiming tensordict mathruler pylatexenc\n\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_distillation/docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = verl\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "verl_distillation/docs/README.md",
    "content": "# verl documentations\n\n## Build the docs\n\n```bash\n# If you want to view auto-generated API docstring, please make sure verl is available in python path. For instance, install verl via:\n# pip install .. -e[test]\n\n# Install dependencies needed for building docs.\npip install -r requirements-docs.txt\n\n# Build the docs.\nmake clean\nmake html\n```\n\n## Open the docs with your browser\n\n```bash\npython -m http.server -d _build/html/\n```\nLaunch your browser and navigate to http://localhost:8000 to view the documentation. Alternatively you could drag the file `_build/html/index.html` to your local browser and view directly.\n"
  },
  {
    "path": "verl_distillation/docs/README_vllm0.7.md",
    "content": "# Upgrading to vllm >= 0.7\n\nNote: verl+vllm 0.8.3 is now stable. Please see ``docs/README_vllm0.8.md`` for upgrade guide.\n\n## Installation\n\nNote: At time of writing, verl+vllm 0.7.x supports **FSDP** for training and **vLLM** for rollout.\n\n```\n# Create the conda environment\nconda create -n distill python==3.10\nconda activate distill\n\n# Install verl\ngit clone https://github.com/volcengine/verl.git\ncd verl\npip3 install -e .\n\n# Install the latest stable version of vLLM\npip3 install vllm==0.7.3 \n\n# Install flash-attn\npip3 install flash-attn --no-build-isolation\n\n```\n\nNote that if you are installing lower versions of vLLM (0.7.0, 0.7.1, 0.7.2), you need to make some tiny patches manually on vllm (/path/to/site-packages/vllm after installation) after the above steps:\n\n- vllm/distributed/parallel_state.py: Remove the assertion below:\n\n```\nif (world_size\n        != tensor_model_parallel_size * pipeline_model_parallel_size):\n    raise RuntimeError(\n        f\"world_size ({world_size}) is not equal to \"\n        f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n        f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\")\n\n```\n\n- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ[\"LOCAL_RANK\"])`\n- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` in `pt_weights_iterator`\n\n## Features\n\n### Use cuda graph\n\nAfter installation, examples using FSDP as training backends can be used. By default, the `enforce_eager` is set to True, which disables the cuda graph. To enjoy cuda graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script:\n\n```\nactor_rollout_ref.rollout.enforce_eager=False \\\nactor_rollout_ref.rollout.free_cache_engine=True \\\n\n```\n\nFor a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 85 seconds with vLLM0.7.0. By enabling the cudagraph, the generation duration is further reduced to 62 seconds.\n\n**Note:** Currently, if the `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue on the stability of rollout generation time (Some iterations would see generation time bursts) using vLLM's V0 Engine.\n\n### Use vLLM V1 Engine\n\nUsing the vLLM V1 engine can avoid instability issues and achieve additional performance improvements. To use the V1 engine, you can first uninstall the previously installed vLLM and then follow the steps below to install the newer version.\n\n```\ngit clone https://github.com/vllm-project/vllm.git\ncd vllm\ngit checkout 2275784\nsed -i \"903a\\    data_parallel_size = world_size // pipeline_model_parallel_size // tensor_model_parallel_size\" ./vllm/distributed/parallel_state.py\nVLLM_USE_PRECOMPILED=1 pip install --editable .\n```\n\nThen you can enable the V1 engine by setting `export VLLM_USE_V1=1`. In some benchmark tests, the V1 engine demonstrates a 1.5x speed improvement over the vLLM V0 engine.\nThe stable support of the vLLM V1 engine is available on verl main.\n"
  },
  {
    "path": "verl_distillation/docs/README_vllm0.8.md",
    "content": "# Upgrading to vLLM >= 0.8\n\nLast updated: 05/04/2025.\n\n## Installation\n\nNote: This version of verl+vLLM 0.8+ supports **FSDP** for training and **vLLM** for rollout.\n\n```bash\n# Create the conda environment\nconda create -n distill python==3.10\nconda activate distill\n\n# Install verl\ngit clone https://github.com/volcengine/verl.git\ncd verl\npip3 install -e .\n\n# Install the latest stable version of vLLM\npip3 install vllm==0.8.3\n\n# Install flash-attn\npip3 install flash-attn --no-build-isolation\n\n```\n\nWe have a pre-built docker image for verl+vLLM 0.8.3. You can direct import it with the following command:\n\n```bash\ndocker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0\n```\n\n## Features\n\nvLLM 0.8+ supports cuda graph and V1 engine by default in verl. To enable these features, remember to add the following lines to the bash script:\n\n```bash\nactor_rollout_ref.rollout.enforce_eager=False \\\nactor_rollout_ref.rollout.free_cache_engine=True \\\n```\n\nand also **remove** the environment variable if it exists:\n\n## Notes\n\nWhen you just directly upgrade vllm>=0.8, some dependency packages may undergo version changes. If you encounter the following problems:\n\n```bash\nin <module> from torch.multiprocessing.reductions import ForkingPickler ImportError: cannot import name 'ForkingPickler' from 'torch.multiprocessing.reductions' (/opt/conda/lib/python3.11/site-packages/torch/multiprocessing/reductions.py)\n```\n\nYou need to upgrade `tensordict` to version 0.6.2 using the command `pip install tensordict==0.6.2`.\n"
  },
  {
    "path": "verl_distillation/docs/_static/custom.css",
    "content": "/* Make the documentation use full screen width */\n.wy-nav-content {\n    max-width: none !important;\n    width: 100% !important;\n    padding: 1.618em 3.236em !important;\n}\n\n/* Adjust the content wrapper - will be set by JavaScript */\n.wy-nav-content-wrap {\n    margin-left: 300px;\n    transition: margin-left 0.2s ease;\n    width: auto !important;\n    position: relative !important;\n    background: white !important;\n    min-height: 100vh !important;\n}\n\n/* Make the main content area responsive */\n.rst-content {\n    max-width: none !important;\n    width: 100% !important;\n}\n\n/* Optional: Adjust table widths to prevent overflow */\n.rst-content table.docutils {\n    width: 100% !important;\n    table-layout: auto !important;\n}\n\n/* Optional: Better code block width handling */\n.rst-content .highlight {\n    width: 100% !important;\n}\n\n/* Content area positioning already handled above */\n\n/* Optional: Improve readability with some margin on very wide screens */\n@media (min-width: 1400px) {\n    .wy-nav-content {\n        max-width: none !important;\n        margin: 0 auto !important;\n    }\n}\n\n/* Resizable sidebar styles */\n.wy-nav-side {\n    position: fixed !important;\n    top: 0 !important;\n    bottom: 0 !important;\n    left: 0 !important;\n    width: 300px;\n    min-width: 200px;\n    max-width: 600px;\n    display: flex;\n    flex-direction: column;\n    z-index: 200 !important;\n}\n\n/* Ensure sidebar header (logo, search) adapts to width */\n.wy-side-nav-search {\n    width: 100% !important;\n    box-sizing: border-box !important;\n    padding: 0.809em 0.809em !important;\n}\n\n.wy-side-nav-search input[type=\"text\"] {\n    width: 100% !important;\n    box-sizing: border-box !important;\n}\n\n/* Make logo/title area responsive */\n.wy-side-nav-search > div.version {\n    width: 100% !important;\n}\n\n.wy-side-nav-search > a {\n    width: 100% !important;\n    display: block !important;\n    white-space: nowrap !important;\n    overflow: hidden !important;\n    text-overflow: ellipsis !important;\n}\n\n/* Responsive adjustments for narrow sidebar */\n@media (max-width: 300px) {\n    .wy-side-nav-search > a {\n        font-size: 0.9em !important;\n    }\n    \n    .wy-side-nav-search input[type=\"text\"] {\n        font-size: 0.8em !important;\n    }\n}\n\n/* Ensure search input doesn't overflow */\n.wy-side-nav-search form {\n    width: 100% !important;\n    margin: 0 !important;\n}\n\n/* Make search icon responsive */\n.wy-side-nav-search .wy-dropdown {\n    width: 100% !important;\n}\n\n/* Adjust search results dropdown width */\n.wy-side-nav-search .wy-dropdown-menu {\n    width: 100% !important;\n    max-width: none !important;\n    left: 0 !important;\n    right: 0 !important;\n}\n\n/* Resize handle is created by JavaScript */\n\n/* Make sure the sidebar content doesn't overflow */\n.wy-side-scroll {\n    width: 100% !important;\n    flex: 1 !important;\n    overflow-y: auto !important;\n    overflow-x: hidden !important;\n    padding-right: 10px !important;\n    box-sizing: border-box !important;\n    scroll-behavior: auto !important; /* Prevent smooth scrolling on sidebar itself */\n}\n\n/* Ensure proper scroll behavior for main content area */\nhtml {\n    scroll-behavior: smooth !important;\n}\n\n/* Ensure anchor links work properly in main content */\n.wy-nav-content-wrap {\n    scroll-behavior: smooth !important;\n}\n\n/* Fix scroll to target for anchor links */\n.rst-content {\n    scroll-behavior: smooth !important;\n}\n\n/* 
Fix anchor scroll offset to account for fixed header */\n.rst-content .section {\n    scroll-margin-top: 60px;\n}\n\n/* Fix anchor scroll offset for headers */\n.rst-content h1, .rst-content h2, .rst-content h3, .rst-content h4, .rst-content h5, .rst-content h6 {\n    scroll-margin-top: 60px;\n}\n\n/* Fix anchor scroll offset for specific scroll targets */\n.rst-content .headerlink {\n    scroll-margin-top: 60px;\n}\n\n/* Fix sidebar navigation styling */\n.wy-menu-vertical {\n    width: 100% !important;\n}\n\n.wy-menu-vertical li {\n    width: 100% !important;\n}\n\n.wy-menu-vertical a {\n    width: 100% !important;\n    word-wrap: break-word !important;\n    white-space: normal !important;\n}\n\n/* Content area margin is handled by JavaScript */\n\n/* Custom drag handle (more visible) */\n.resize-handle {\n    position: absolute;\n    top: 0;\n    right: 0;\n    width: 8px;\n    height: 100%;\n    background: #ccc;\n    cursor: col-resize;\n    z-index: 1001;\n    opacity: 0.3;\n    transition: opacity 0.2s ease;\n}\n\n.resize-handle:hover {\n    opacity: 0.8;\n    background: #999;\n}\n\n.resize-handle::before {\n    content: '';\n    position: absolute;\n    top: 50%;\n    left: 50%;\n    width: 2px;\n    height: 20px;\n    background: #666;\n    transform: translate(-50%, -50%);\n    border-radius: 1px;\n}\n\n.resize-handle:hover::before {\n    background: #333;\n}\n\n/* Ensure smooth resizing */\n.wy-nav-side.resizing {\n    user-select: none;\n    pointer-events: none;\n}\n\n.wy-nav-side.resizing .wy-side-scroll {\n    overflow: hidden;\n}"
  },
  {
    "path": "verl_distillation/docs/_static/js/resizable-sidebar.js",
    "content": "// Resizable sidebar functionality\ndocument.addEventListener('DOMContentLoaded', function() {\n    const sidebar = document.querySelector('.wy-nav-side');\n    const content = document.querySelector('.wy-nav-content-wrap');\n    \n    if (!sidebar || !content) return;\n    \n    // Create resize handle\n    const resizeHandle = document.createElement('div');\n    resizeHandle.className = 'resize-handle';\n    sidebar.appendChild(resizeHandle);\n    \n    let isResizing = false;\n    let startX = 0;\n    let startWidth = 0;\n    \n    // Get initial width\n    const getInitialWidth = () => {\n        return 300; // Default width\n    };\n    \n    // Save width to localStorage\n    const saveWidth = (width) => {\n        localStorage.setItem('sidebar-width', width);\n    };\n    \n    // Load width from localStorage\n    const loadWidth = () => {\n        const savedWidth = localStorage.getItem('sidebar-width');\n        if (savedWidth) {\n            const width = parseInt(savedWidth, 10);\n            if (width >= 200 && width <= 600) {\n                return width;\n            }\n        }\n        return getInitialWidth();\n    };\n    \n    // Apply width to sidebar and content\n    const applyWidth = (width) => {\n        // Update sidebar width\n        sidebar.style.width = width + 'px';\n        \n        // Update content margin with !important to override any CSS\n        content.style.setProperty('margin-left', width + 'px', 'important');\n        \n        // Also update any other content wrapper that might exist\n        const contentInner = document.querySelector('.wy-nav-content');\n        if (contentInner) {\n            contentInner.style.setProperty('margin-left', '0px', 'important');\n        }\n        \n        // Force reflow and repaint\n        sidebar.offsetHeight;\n        content.offsetHeight;\n        \n        // Trigger window resize event to notify other components\n        window.dispatchEvent(new Event('resize'));\n    };\n    \n    // Initialize with saved width\n    const initialWidth = loadWidth();\n    applyWidth(initialWidth);\n    \n    // Mouse down on resize handle\n    resizeHandle.addEventListener('mousedown', (e) => {\n        isResizing = true;\n        startX = e.clientX;\n        startWidth = parseInt(window.getComputedStyle(sidebar).width, 10);\n        \n        sidebar.classList.add('resizing');\n        document.body.style.cursor = 'col-resize';\n        document.body.style.userSelect = 'none';\n        \n        // Add overlay to prevent iframe issues\n        const overlay = document.createElement('div');\n        overlay.style.cssText = `\n            position: fixed;\n            top: 0;\n            left: 0;\n            width: 100%;\n            height: 100%;\n            z-index: 9999;\n            cursor: col-resize;\n        `;\n        overlay.id = 'resize-overlay';\n        document.body.appendChild(overlay);\n        \n        e.preventDefault();\n    });\n    \n    // Mouse move\n    document.addEventListener('mousemove', (e) => {\n        if (!isResizing) return;\n        \n        const width = startWidth + e.clientX - startX;\n        const clampedWidth = Math.max(200, Math.min(600, width));\n        applyWidth(clampedWidth);\n    });\n    \n    // Mouse up\n    document.addEventListener('mouseup', () => {\n        if (!isResizing) return;\n        \n        isResizing = false;\n        sidebar.classList.remove('resizing');\n        document.body.style.cursor = '';\n        document.body.style.userSelect 
= '';\n        \n        // Remove overlay\n        const overlay = document.getElementById('resize-overlay');\n        if (overlay) {\n            overlay.remove();\n        }\n        \n        // Save the current width\n        const currentWidth = parseInt(window.getComputedStyle(sidebar).width, 10);\n        saveWidth(currentWidth);\n    });\n    \n    // Handle window resize - removed to prevent infinite loop\n    // The sidebar width is fixed and managed by drag functionality, no need to recalculate on window resize\n    \n    // Double-click to reset to default width\n    resizeHandle.addEventListener('dblclick', () => {\n        const defaultWidth = 300;\n        applyWidth(defaultWidth);\n        saveWidth(defaultWidth);\n    });\n});\n\n// Fix navigation issues - Using MutationObserver for reliable initialization\ndocument.addEventListener('DOMContentLoaded', function() {\n    let navigationFixed = false;\n    \n    function setupNavigationFix() {\n        if (navigationFixed) return;\n        \n        // Find all links in the sidebar\n        const sidebarLinks = document.querySelectorAll('.wy-menu-vertical a');\n        \n        // Only proceed if we have sidebar links\n        if (sidebarLinks.length === 0) return;\n        \n        console.log('Setting up navigation fix...');\n        \n        sidebarLinks.forEach(function(link) {\n            const href = link.getAttribute('href');\n            \n            // Clone the link to remove all existing event listeners\n            const newLink = link.cloneNode(true);\n            \n            // Add our own click handler\n            newLink.addEventListener('click', function(e) {\n                console.log('Link clicked:', href);\n                \n                // If it's an anchor link within the same page\n                if (href && href.startsWith('#') && href !== '#') {\n                    e.preventDefault();\n                    e.stopPropagation();\n                    \n                    const targetId = href.substring(1);\n                    const targetElement = document.getElementById(targetId);\n                    \n                    if (targetElement) {\n                        // Calculate offset for fixed header\n                        const headerHeight = 60;\n                        const elementPosition = targetElement.getBoundingClientRect().top;\n                        const offsetPosition = elementPosition + window.pageYOffset - headerHeight;\n                        \n                        window.scrollTo({\n                            top: offsetPosition,\n                            behavior: 'smooth'\n                        });\n                        \n                        // Update URL hash\n                        if (history.pushState) {\n                            history.pushState(null, null, '#' + targetId);\n                        } else {\n                            location.hash = '#' + targetId;\n                        }\n                    }\n                }\n                // For external links, navigate normally\n                else if (href && !href.startsWith('#') && !href.startsWith('javascript:')) {\n                    console.log('Navigating to external link:', href);\n                    window.location.href = href;\n                }\n            });\n            \n            // Replace the old link with the new one\n            link.parentNode.replaceChild(newLink, link);\n        });\n        \n        navigationFixed = true;\n        \n      
  // Handle initial page load with hash\n        if (window.location.hash) {\n            // Use requestAnimationFrame for better timing\n            requestAnimationFrame(() => {\n                const targetId = window.location.hash.substring(1);\n                const targetElement = document.getElementById(targetId);\n                if (targetElement) {\n                    const headerHeight = 60;\n                    const elementPosition = targetElement.getBoundingClientRect().top;\n                    const offsetPosition = elementPosition + window.pageYOffset - headerHeight;\n                    \n                    window.scrollTo({\n                        top: offsetPosition,\n                        behavior: 'smooth'\n                    });\n                }\n            });\n        }\n    }\n    \n    // Try to set up navigation fix immediately\n    setupNavigationFix();\n    \n    // If it didn't work, use MutationObserver to watch for when sidebar links are added\n    if (!navigationFixed) {\n        const observer = new MutationObserver(function(mutations) {\n            mutations.forEach(function(mutation) {\n                if (mutation.type === 'childList' && mutation.addedNodes.length > 0) {\n                    // Check if sidebar links were added\n                    const sidebarLinks = document.querySelectorAll('.wy-menu-vertical a');\n                    if (sidebarLinks.length > 0) {\n                        setupNavigationFix();\n                        if (navigationFixed) {\n                            observer.disconnect();\n                        }\n                    }\n                }\n            });\n        });\n        \n        // Start observing the document for changes\n        observer.observe(document.body, {\n            childList: true,\n            subtree: true\n        });\n        \n        // Fallback timeout in case MutationObserver doesn't work\n        setTimeout(function() {\n            if (!navigationFixed) {\n                setupNavigationFix();\n            }\n            observer.disconnect();\n        }, 5000);\n    }\n});"
  },
  {
    "path": "verl_distillation/docs/_static/js/runllm-widget.js",
    "content": "document.addEventListener(\"DOMContentLoaded\", function () {\n    var script = document.createElement(\"script\");\n    script.type = \"module\";\n    script.id = \"runllm-widget-script\";\n    script.src = \"https://widget.runllm.com\";\n    script.setAttribute(\"version\", \"stable\");\n    script.setAttribute(\"crossorigin\", \"true\");\n    script.setAttribute(\"runllm-keyboard-shortcut\", \"Mod+j\");\n    script.setAttribute(\"runllm-name\", \"verl Chatbot\");\n    script.setAttribute(\"runllm-position\", \"TOP_RIGHT\");\n    script.setAttribute(\"runllm-assistant-id\", \"679\");\n    script.async = true;\n    document.head.appendChild(script);\n  });"
  },
  {
    "path": "verl_distillation/docs/advance/agent_loop.rst",
    "content": "Agent Loop\n==========\n\nLast updated: 07/17/2025.\n\n.. versionadded:: 0.4.2\n   [status: alpha]\n\n.. warning::\n   Agent Loop is ready for use, but the API may change in future releaes.\n\nAgent Loop is designed as general interface for multi-turn rollout and agentic reinforcement learning.\n\n**Design goal**:\n\n- Plugable user defined agent loop\n- Provide standard request generate api with different inference frameworks\n- Provide request level load balance between multiple inference servers\n\n**Non-goal**:\n\n- How tool is defined and how to call tool\n\nIn high level overview, agent loop is given a prompt, run user defined loop: call LLM generate api, call tools, ...\nand return the final output. The final output is then calculated reward and used as trajectory for RL training.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_overview.svg?raw=true\n\n\nAPI Design\n----------\n\n``AgentLoopBase`` class is the abstraction of agent loop, and ``run`` method is the only interface that user need to implement.\nThe run method, given prompt messages in format: [{\"role\": \"user\"}, {\"content\": \"...\"}], and additional sampling params,\ncould do whatever user wants, such as\n\n- call LLM generate api\n- call tools: web search, database query, code sandbox, ...\n- environment interaction\n- reflection\n- ...\n\n.. code:: python\n\n   class AgentLoopBase(ABC):\n       @abstractmethod\n       async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n           \"\"\"Run agent loop to interact with LLM server and environment.\n\n           Args:\n               sampling_params (Dict[str, Any]): LLM sampling params.\n               **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.\n\n           Returns:\n               AgentLoopOutput: Agent loop output.\n           \"\"\"\n           raise NotImplementedError\n\nAfter running user defined loop, run method should return ``AgentLoopOutput``, including prompt token ids,\nresponse token ids, and response mask.\n\n.. code:: python\n\n   class AgentLoopOutput(BaseModel):\n       \"\"\"Agent loop output.\"\"\"\n\n       prompt_ids: list[int]\n       \"\"\"Prompt token ids.\"\"\"\n       response_ids: list[int]\n       \"\"\"Response token ids including LLM generated token, tool response token.\"\"\"\n       response_mask: list[int]\n       \"\"\"Response mask, 1 for LLM generated token, 0 for tool response token.\"\"\"\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_output.svg?raw=true\n\n.. note:: AgentLoopOutput only output one trajectory for a given prompt, multiple trajectories output is still under discussion.\n\nArchitecture Design\n-------------------\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_architecture.png?raw=true\n\nA single PPO step contain two phase: rollout and train. In rollout phase:\n\n1. PPOTrainer sample a batch from dataset and call ``AgentLoopManager.generate_sequences``.\n2. AgentLoopManager ``wake_up`` all async LLM server instances, which will sync weights between inference engine(vLLM/SGLang) and training engine(FSDP/Megatron-LM).\n3. AgentLoopManager split batch into chunks and send each chunk to ``AgentLoopWorker``.\n4. AgentLoopWorker receive chunk and for each prompt, spawn a user defined ``AgentLoopBase`` instance, run ``run`` coroutine until end and get ``AgentLoopOutput``.\n\n.. 
Architecture Design\n-------------------\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_architecture.png?raw=true\n\nA single PPO step contains two phases: rollout and train. In the rollout phase:\n\n1. PPOTrainer samples a batch from the dataset and calls ``AgentLoopManager.generate_sequences``.\n2. AgentLoopManager wakes up (``wake_up``) all async LLM server instances, which syncs weights between the inference engine (vLLM/SGLang) and the training engine (FSDP/Megatron-LM).\n3. AgentLoopManager splits the batch into chunks and sends each chunk to an ``AgentLoopWorker``.\n4. AgentLoopWorker receives a chunk and, for each prompt, spawns a user-defined ``AgentLoopBase`` instance and runs its ``run`` coroutine to completion to get an ``AgentLoopOutput``.\n\n.. tip::\n   AgentLoopWorker schedules multiple coroutines concurrently. If the number of AgentLoopWorkers equals the batch size, each worker is responsible for one prompt.\n\nInside the agent loop, when the user needs the LLM to generate a response:\n\n5. Call ``AsyncLLMServerManager.generate`` with prompt_ids.\n6. AsyncLLMServerManager selects the server instance with the fewest requests in the first turn and sends the request to it. (In subsequent turns, the request is sent to the same server instance.)\n7. AsyncLLMServer receives the request, issues ipc/rpc to the model_runner, and generates the response. (There are slight differences between vLLM and SGLang; see below.)\n\nWhen all prompts in all AgentLoopWorkers finish, AgentLoopManager gathers the results and returns them to PPOTrainer.\n\n8. AgentLoopManager puts all server instances to ``sleep``, which frees the kv cache and offloads weights to CPU memory.\n\nAsyncLLMServer\n~~~~~~~~~~~~~~\n\nAsyncLLMServer is the abstraction of an LLM server with two types of generation API:\n\n- `OpenAI chat completion <https://platform.openai.com/docs/api-reference/chat>`_: generate a response for the given chat conversation.\n- Token in token out: generate response ids for the given token ids.\n\nvLLM and SGLang AsyncLLMServers are officially supported; both implement the two APIs and are well tested.\nOther inference engines should be easy to plug in by implementing the ``AsyncServerBase`` class.\n\n.. code:: python\n\n   class AsyncServerBase(ABC):\n       @abstractmethod\n       async def chat_completion(self, raw_request: Request) -> JSONResponse:\n           \"\"\"OpenAI chat completion API.\n\n           Args:\n               raw_request (Request): raw json request\n           \n           Returns:\n               JSONResponse: json response\n\n           API reference: https://platform.openai.com/docs/api-reference/chat/create\n           \"\"\"\n           raise NotImplementedError\n\n       @abstractmethod\n       async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n           \"\"\"Generate response ids given prompt ids.\n\n           Args:\n               prompt_ids (List[int]): prompt ids\n               sampling_params (Dict[str, Any]): sampling params\n               request_id (str): request id\n\n           Returns:\n               List[int]: response ids\n           \"\"\"\n           raise NotImplementedError\n\n
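As a shape reference for plugging in a new engine, here is a toy, self-contained subclass (an echo server, not a real integration): the body of each method is where a real engine's async generation calls would go. It assumes ``AsyncServerBase`` from above is in scope and uses starlette's ``Request``/``JSONResponse``.\n\n.. code:: python\n\n   from typing import Any\n\n   from starlette.requests import Request\n   from starlette.responses import JSONResponse\n\n   class EchoAsyncServer(AsyncServerBase):\n       \"\"\"Toy plug-in that echoes instead of generating.\"\"\"\n\n       async def chat_completion(self, raw_request: Request) -> JSONResponse:\n           body = await raw_request.json()\n           # A real engine would run generation here; we echo the last message back.\n           content = body[\"messages\"][-1][\"content\"]\n           return JSONResponse({\"choices\": [{\"message\": {\"role\": \"assistant\", \"content\": content}}]})\n\n       async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n           # A real engine would call its async generation API with prompt_ids.\n           return list(prompt_ids)\n\n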
Chat completion vs Token in token out\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. warning::\n   The following conclusion is based on our recent experience and is still open to investigation and discussion.\n\nAlmost all agent frameworks (LangGraph, CrewAI, LlamaIndex, etc.) call the LLM through the OpenAI chat completion API and\nkeep the chat history as messages. So users may expect us to use the chat completion API in multi-turn rollout.\n\nBut based on our recent experience with single-turn training on DAPO and multi-turn training on `retool <https://github.com/volcengine/verl/tree/main/recipe/retool>`_,\nwe found that the token_ids obtained by applying the chat template to the final messages may not equal the token_ids obtained by concatenating prompt_ids and response_ids from each turn.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/multi_turn.png?raw=true\n\n**Where does this inconsistency happen?**\n\nFirst, the tool parser may alter the content. For example\n\n.. code:: json\n\n   {\"role\": \"assistant\", \"content\": \"Let me call a <tool_call>...</tool_call> and get the result\"}\n\nAfter tool_calls extraction, the message looks like this:\n\n.. code:: json\n\n   {\"role\": \"assistant\", \"content\": \"Let me call a and get the result\", \"tool_calls\": [{\"name\": \"foo\", \"arguments\": \"{}\"}]}\n\nRe-encoding the extracted message does not reproduce the original LLM-generated response_ids.\n\nSecond, the ``decode-encode`` round trip may also lead to inconsistency: `Agent-R1 issue#30 <https://github.com/0russwest0/Agent-R1/issues/30#issuecomment-2826155367>`_.\n\n**What is the impact of this inconsistency?**\n\nThis inconsistency is not a big problem for serving/agent systems, but it is critical for RL training:\nit causes the trajectory to deviate from the policy model's distribution. We have observed that applying the chat template\nto the final chat history messages can prevent PPO training from converging, even in the single-turn setting.\n\nvLLM\n^^^^\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/async_vllm.png?raw=true\n\nFor vLLM, the Async LLM Engine runs in the same process as the server, and the ModelRunner runs in the same process as the FSDP/Megatron-LM workers.\nThe Async LLM Engine communicates with the ModelRunner through ZeroMQ. When the server receives a request, it directly calls the engine to generate response_ids.\n\nSGLang\n^^^^^^\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/async_sglang.png?raw=true\n\nFor SGLang, the Async LLM Engine runs in the same process as FSDP/Megatron-LM worker-0, and it spawns multiple subprocesses as ModelRunners.\nThe Async LLM Engine likewise communicates with the ModelRunners through ZeroMQ. When the server receives a request, it makes a remote call to worker-0 to get the response_ids.\n\nAsyncLLMServerManager\n~~~~~~~~~~~~~~~~~~~~~\n\nAsyncLLMServerManager serves as a proxy to multiple AsyncLLMServer instances and provides:\n\n- load balancing: the first turn of a request is sent to the server instance with the fewest in-flight requests.\n- sticky sessions: a request_id is bound to a server instance, so the same request_id is routed to the same server instance in subsequent turns.\n\nAsyncLLMServerManager is passed to ``AgentLoopBase.__init__``; whenever users want to interact with the LLM in the agent loop,\nthey can call ``AsyncLLMServerManager.generate`` to generate response_ids.\n\n.. code:: python\n\n   class AsyncLLMServerManager:\n       async def generate(\n           self,\n           request_id,\n           *,\n           prompt_ids: list[int],\n           sampling_params: dict[str, Any],\n       ) -> list[int]:\n           \"\"\"Generate tokens from prompt ids.\n\n           Args:\n               request_id (str): request id for sticky session.\n               prompt_ids (List[int]): List of prompt token ids.\n               sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.\n\n           Returns:\n               List[int]: List of generated token ids.\n           \"\"\"\n           ...\n\nNext\n----\n\n- :doc:`Agentic RL Training<../start/agentic_rl>`: Quick-start agentic RL training with the gsm8k dataset.\n- `LangGraph MathExpression <https://github.com/volcengine/verl/tree/main/recipe/langgraph_agent/example>`_: Demonstrates how to use LangGraph to build an agent loop.\n- `Retool <https://github.com/volcengine/verl/tree/main/recipe/retool>`_: End-to-end reproduction of the retool paper using a tool agent.\n"
  },
  {
    "path": "verl_distillation/docs/advance/attention_implementation.rst",
    "content": ".. _attention-implementation-override:\n\nAttention Implementation Override\n==================================\n\nLast updated: 10/31/2025.\n\nBy default, VERL's FSDP workers use ``flash_attention_2`` as the attention implementation for improved performance. \nHowever, you can now override this setting to use different attention implementations based on your needs.\n\nSupported Attention Implementations\n-----------------------------------\n\nThe following attention implementations are supported (subject to model and hardware compatibility):\n\n- ``flash_attention_2``: High-performance attention implementation (default)\n- ``eager``: Standard PyTorch attention implementation\n- ``sdpa``: Scaled Dot-Product Attention (PyTorch native)\n\nWhen to Override\n----------------\n\nYou might want to override the attention implementation in the following scenarios:\n\n- **Debugging**: Use ``eager`` for easier debugging and better error messages\n- **Compatibility**: Some models or hardware configurations may not support ``flash_attention_2``\n- **Memory constraints**: Different implementations have different memory characteristics\n- **Performance tuning**: Testing different implementations for optimal performance\n\nConfiguration Examples\n-----------------------\n\nPPO Training with Eager Attention\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nTo override the attention implementation for the actor, rollout, and reference models:\n\n.. code:: bash\n\n    python3 ppo_trainer.py \\\n        +actor_rollout_ref.model.override_config.attn_implementation=eager \\\n        [other parameters...]\n\nPPO Training with SDPA Attention\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: bash\n\n    python3 ppo_trainer.py \\\n        +actor_rollout_ref.model.override_config.attn_implementation=sdpa \\\n        [other parameters...]\n\nCritic Model Override\n~~~~~~~~~~~~~~~~~~~~~\n\nFor training configurations that include a critic model, you can also override its attention implementation:\n\n.. code:: bash\n\n    python3 ppo_trainer.py \\\n        +actor_rollout_ref.model.override_config.attn_implementation=eager \\\n        +critic.model.override_config.attn_implementation=eager \\\n        [other parameters...]\n\nYAML Configuration\n~~~~~~~~~~~~~~~~~~\n\nYou can also specify the attention implementation in your YAML configuration file:\n\n.. code:: yaml\n\n    actor_rollout_ref:\n      model:\n        override_config:\n          attn_implementation: eager\n          # other overrides...\n\n    critic:  # if using a critic model\n      model:\n        override_config:\n          attn_implementation: eager\n          # other overrides...\n\nImportant Notes\n---------------\n\n**Backward Compatibility**: If you don't specify ``attn_implementation`` in the override config, \nVERL will continue to use ``flash_attention_2`` by default, ensuring backward compatibility with existing configurations.\n\n**Model Support**: Not all models support all attention implementations. Ensure your model is compatible \nwith the chosen attention implementation before training.\n\n**Performance Impact**: Different attention implementations have varying performance characteristics. \n``flash_attention_2`` typically offers the best performance, while ``eager`` provides better debugging capabilities.\n\n**Hardware Dependencies**: Some attention implementations (like ``flash_attention_2``) may require \nspecific hardware or CUDA versions. 
Important Notes\n---------------\n\n**Backward Compatibility**: If you don't specify ``attn_implementation`` in the override config, \nVERL will continue to use ``flash_attention_2`` by default, ensuring backward compatibility with existing configurations.\n\n**Model Support**: Not all models support all attention implementations. Ensure your model is compatible \nwith the chosen attention implementation before training.\n\n**Performance Impact**: Different attention implementations have varying performance characteristics. \n``flash_attention_2`` typically offers the best performance, while ``eager`` provides better debugging capabilities.\n\n**Hardware Dependencies**: Some attention implementations (like ``flash_attention_2``) may require \nspecific hardware or CUDA versions. If you encounter compatibility issues, try using ``eager`` or ``sdpa``.\n\nTroubleshooting\n---------------\n\nIf you encounter errors when using a specific attention implementation:\n\n1. **Check model compatibility**: Verify that your model supports the chosen attention implementation\n2. **Try eager attention**: Use ``attn_implementation=eager`` as a fallback for debugging\n3. **Check hardware requirements**: Ensure your hardware supports the attention implementation\n4. **Review error messages**: Attention implementation errors often provide clear guidance on supported options\n\nExample Error Resolution\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nIf you see an error like \"flash_attention_2 is not supported\", you can resolve it by switching to eager attention:\n\n.. code:: bash\n\n    # Instead of the default flash_attention_2\n    python3 ppo_trainer.py +actor_rollout_ref.model.override_config.attn_implementation=eager\n\nThis override ensures your training can proceed while you investigate the flash attention compatibility issue.\n"
  },
  {
    "path": "verl_distillation/docs/advance/checkpoint.rst",
    "content": ".. _checkpoint-page:\n\nUsing Checkpoints to Support Fault Tolerance Training\n=====================================================\n\nLast updated: 06/25/2025.\n\nThere could be training errors or machine failure during the whole RLHF training process, \nso it is recommended to enable checkpoints to minimize your loss.\n\nThe API Interface has already been listed in :ref:`config-explain-page`,\nand we will not repeat them. But there are still some technique details\nwe hope to clarify.\n\n.. note:: \n\n    Notice that the ``checkpoint.contents`` field has no effect to FSDP checkpoint except ``hf_model``, \n    the other 3 fields are binded together to save and load. We recommend to include ``model``, ``optimizer`` and ``extra`` all.\n\nCheckpoint Saving Directory Structure\n-------------------------------------\n\nCommonly, we use the ``default_local_dir`` declared in ``ppo_trainer.yaml`` or ``ppo_megatron_trainer.yml``\nto work as preffix when saving checkpoints, which is ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``.\n\nSo the inner checkpoint structure of **FSDP** is like:\n\n.. code::\n\n    checkpoints/${trainer.project_name}/${trainer.experiment_name}\n    ├── global_steps_${i}\n    │   ├── actor\n    │   │   ├── huggingface      # default save config and tokenizer, save huggingface model if include ``hf_model`` in checkpoint.contents\n    │   │   └── fsdp_config.json # FSDP config file, including world_size and fsdp version\n    │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   ├── critic\n    │   │   ├── huggingface\n    │   │   └── fsdp_config.json\n    │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\n    └── latest_checkpointed_iteration.txt\n\nAll model shards, optimizers and extra states are stored together, in a sharded and distributed way.\n\nWhile **Megatron** current checkpoint structure is:\n\n.. code::\n\n    checkpoints/${trainer.project_name}/${trainer.experiment_name}\n    ├── global_steps_${i}\n    │   ├── actor\n    │   │   ├── huggingface     # default save config and tokenizer, save huggingface model if include ``hf_mode`` in checkpoint.contents\n    │   │   └── dist_ckpt       # save sharded model/optimizer/rng_states, naming the same as Megatron\n    │   └── critic\n    │   │   ├── huggingface\n    │   │   └── dist_ckpt\n    └── latest_checkpointed_iteration.txt\n\nConvert FSDP and Megatron Checkpoints to HuggingFace Format Model\n-----------------------------------------------------------------\n\nWe provide a tool to convert the FSDP and Megatron checkpoints to HuggingFace format model.\nThe tool is located in ``verl/model_merger``. For older versions of verl that don't include fsdp_config.json in checkpoints, you can use the legacy model merger located at ``verl/scripts/legacy_model_merger.py``.\n\nThe script supports two main sub-commands: `merge` (to convert and save checkpoints) and `test` (to validate merged checkpoints against a reference model).\nThe arguments for the `merge` sub-command are as follows:\n\n.. 
.. code:: bash\n\n    usage: python -m verl.model_merger merge [-h] --backend {fsdp,megatron} [--local_dir LOCAL_DIR] [--tie-word-embedding] [--is-value-model] [--use_cpu_initialization] [--target_dir TARGET_DIR]\n                         [--hf_upload_path HF_UPLOAD_PATH] [--private]\n\n    options:\n    -h, --help            show this help message and exit\n    --backend {fsdp,megatron}\n                            The backend of the model\n    --local_dir LOCAL_DIR\n                            Path to the saved model checkpoints\n    --tie-word-embedding  Whether to tie word embedding weights (currently only Megatron supported)\n    --is-value-model      Whether the model is a value model (currently only Megatron supported)\n    --use_cpu_initialization\n                            Whether to use CPU initialization for the model. This is useful for large models that cannot fit into GPU memory during initialization.\n    --target_dir TARGET_DIR\n                            Directory to save the merged huggingface model\n    --hf_upload_path HF_UPLOAD_PATH\n                            Hugging Face repository ID to upload the model\n    --private             Whether to upload the model to a private Hugging Face repository\n\nExample usage for merging Megatron checkpoints:\n\n.. code:: bash\n\n    python -m verl.model_merger merge \\\n        --backend megatron \\\n        --tie-word-embedding \\\n        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\nExample usage for merging Megatron checkpoints in a distributed way:\n\n.. code:: bash\n\n    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \\\n        --backend megatron \\\n        --tie-word-embedding \\\n        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\nExample usage for merging FSDP checkpoints:\n\n.. code:: bash\n\n    python -m verl.model_merger merge \\\n        --backend fsdp \\\n        --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\n
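After merging, it is worth a quick sanity check that the result loads as a regular HuggingFace model. A minimal sketch (the path is the ``--target_dir`` from above and is illustrative); the ``test`` sub-command mentioned above offers a more thorough validation against a reference model:\n\n.. code:: python\n\n    from transformers import AutoModelForCausalLM, AutoTokenizer\n\n    merged_dir = \"/path/to/merged_hf_model\"  # path passed to --target_dir\n\n    tokenizer = AutoTokenizer.from_pretrained(merged_dir)\n    model = AutoModelForCausalLM.from_pretrained(merged_dir)\n    print(model.config.architectures)\n\n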
Megatron Merger details\n-----------------------\n\nThe current implementation of decoder layers uses ``nn.ModuleList`` to store the layers,\nso the model layers on every PP rank and VPP rank start their indices from 0.\n\nThere are 3 ways to correct this behavior:\n\n1. Modify the decoder layers' state_dict, adding an ``offset`` to each layer's index, which rewrites the ``nn.ModuleList`` implementation.\n2. Modify the layer indices when saving checkpoints and recover them when loading checkpoints.\n3. Have the checkpoint merger do this work by calculating the actual ``offset`` from the ``state_dict`` alone, which is a little complex.\n\nThe current implementation uses solution 2.\n\n\nHuggingFace to Megatron DistCheckpoint details\n----------------------------------------------\n\nIf your model is quite huge, we recommend using the Megatron dist-checkpoint to load it.\nThe Megatron dist-checkpoint supports loading under different kinds of model parallelism,\nand it is much faster than the original checkpoint loading.\n\nTo convert an original HuggingFace model to a Megatron dist-checkpoint,\nyou can use the ``scripts/converter_hf_to_mcore.py`` script. Large MoE models are temporarily supported via CPU initialization,\nwhich is a little slower, while we work on a better solution for large models.\n\nAn example command to convert a model is as follows:\n\n.. code:: bash\n\n    python scripts/converter_hf_to_mcore.py \\\n        --hf_model_path Qwen/Qwen1.5-MoE-A2.7B-Chat \\\n        --output_path /mnt/disk/Qwen/Qwen1.5-MoE-A2.7B-Chat \\\n        --use_cpu_initialization    # Only works for MoE models\n\n\nAn example command for distributed conversion of a huge model like DeepSeek-V3 671B is as follows:\n\n.. code:: bash\n\n    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \\\n        --hf_model_path deepseek-ai/DeepSeek-V3 \\\n        --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \\\n        --use_cpu_initialization    # Only works for MoE models\n\nOriginal Checkpoint Utils\n-------------------------\n\nOriginal checkpoint utils refer to the original checkpoint implementation in ``verl/models/[model]/megatron/checkpoint_utils``.\n\nWe now only need ``[model]_loader.py`` from the original checkpoint utils, since we no longer store ``hf_model`` every time (which is not recommended for large-model training; save only sharded models if you can).\n\n.. note:: \n\n    Note that ``[model]_loader`` only supports environments where the **storage cluster is reachable from every compute node**,\n    because it utilizes a **sharded load to minimize checkpoint loading overhead**:\n    every rank loads its own data from a ``state_dict`` that all of them can access.\n    There is also no need to broadcast among DP ranks, since the saved state_dict is produced only by DP rank 0.\n\n    For users who can **only place the huggingface model on one device**, we keep the original costly implementation in ``[model]_loader_deprecated``. In this implementation, rank 0 broadcasts all weights to each TP and PP rank, and then DP rank 0 broadcasts to all DP ranks. This carries a risk of OOM.\n\n    To use the deprecated loader, change the import package of ``load_state_dict_to_megatron_llama``.\n"
  },
  {
    "path": "verl_distillation/docs/advance/dpo_extension.rst",
    "content": "Extend to other RL(HF) algorithms\n=================================\n\nLast updated: 02/25/2025.\n\nWe already implemented the complete training pipeline of the PPO\nalgorithms. To extend to other algorithms, we analyze the high-level\nprinciple to use verl and provide a tutorial to implement the DPO\nalgorithm. Users can follow the similar paradigm to extend to other RL algorithms.\n\n.. note:: **Key ideas**: Single process drives multi-process computation and data communication.\n\nOverall Approach\n----------------\n\nStep 1: Consider what multi-machine multi-GPU computations are needed\nfor each model, such as ``generate_sequence`` , ``compute_log_prob`` and\n``update_policy`` in the actor_rollout model. Implement distributed\nsingle-process-multiple-data (SPMD) computation and encapsulate them\ninto APIs\n\nStep 2: Based on different distributed scenarios, including FSDP and 3D\nparallelism in Megatron-LM, implement single-process control of data\ninteraction among multi-process computations.\n\nStep 3: Utilize the encapsulated APIs to implement the control flow\n\nExample: Online DPO\n-------------------\n\nWe use verl to implement a simple online DPO algorithm. The algorithm\nflow of Online DPO is as follows:\n\n1. There is a prompt (rollout) generator which has the same weight as\n   the actor model. After a batch of prompts are fed into the generator,\n   it generates N responses for each prompt.\n2. Send all the prompts + responses to a verifier for scoring, which can\n   be reward model or a rule-based function. Then sort them in pairs to\n   form a training batch.\n3. Use this training batch to train the actor model using DPO. During\n   the process, a reference policy is needed.\n\nStep 1: What are the multi-machine multi-GPU computations\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**Sample Generator**\n\nImplementation details:\n\n.. code:: python\n\n   from verl.single_controller.base import Worker\n   from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool\n   import ray\n\n   @ray.remote\n   class SampleGenerator(Worker):\n       def __init__(self, config):\n           super().__init__()\n           self.config = config\n           \n       def generate_sequences(self, data):\n           pass\n\nHere, ``SampleGenerator`` can be viewed as a multi-process pulled up by\n``torchrun``, with each process running the same code (SPMD).\n``SampleGenerator`` needs to implement a ``generate_sequences`` API for\nthe control flow to call. The implementation details inside can use any\ninference engine including vllm, sglang and huggingface. Users can\nlargely reuse the code in\nverl/verl/workers/rollout/vllm_rollout/vllm_rollout.py and we won't\ngo into details here.\n\n**ReferencePolicy inference**\n\nAPI: compute reference log probability\n\n.. code:: python\n\n   from verl.single_controller.base import Worker\n   import ray\n\n   @ray.remote\n   class ReferencePolicy(Worker):\n       def __init__(self):\n           super().__init__()\n           self.model = Model()\n           \n       def infer(self, data):\n           return self.model(data)\n\n**Actor update**\n\nAPI: Update actor model parameters\n\n.. 
**Notes: How to distinguish between control processes and distributed computation processes**\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n- Control processes are generally functions directly decorated with\n  ``@ray.remote``\n- Computation processes are all wrapped into a ``RayWorkerGroup``.\n\nUsers can reuse most of the distributed computation logic implemented\nin the PPO algorithm, including the FSDP and Megatron-LM backends in\nverl/verl/trainer/ppo.\n\nStep 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**The core problem to solve here is how a single process sends data to\nmultiple processes, drives multi-process computation, and how the\ncontrol process obtains the results of multi-process computation.**\nFirst, we initialize the multi-process ``WorkerGroup`` in the control\nprocess.\n\n.. code:: python\n\n   @ray.remote(num_cpus=1)\n   def main_task(config):\n       # construct SampleGenerator\n       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs\n       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)\n       # put SampleGenerator onto resource pool\n       worker_group = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct reference policy\n\nAs we can see, in the control process, multiple processes are wrapped\ninto a ``RayWorkerGroup``. Inside this ``WorkerGroup``, there is a\n``self._workers`` member, where each worker is a RayActor\n(https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator.\nray_trainer.md also provides an implementation of\n``MegatronRayWorkerGroup``.\n\nAssuming the model is distributed using FSDP, and there is a batch of\ndata on the control process, then for data parallelism, the underlying\ncalling process is:\n\n.. code:: python\n\n   data = xxx\n   data_list = data.chunk(dp_size)\n\n   output = []\n   for i, d in enumerate(data_list):\n       # worker_group._workers[i] is a SampleGenerator\n       output.append(worker_group._workers[i].generate_sequences.remote(d))\n\n   output = ray.get(output)\n   output = torch.cat(output)\n\nA single process calling multiple processes involves the following 3\nsteps:\n\n1. Split the data into DP parts on the control process.\n2. Send the data to remote, call the remote computation through RPC, and\n   utilize multi-process computation.\n3. Obtain the computation results of each worker on the control process\n   and merge them.\n\nFrequently calling these 3 steps on the controller process greatly hurts\ncode readability. **In verl, we have abstracted and encapsulated these 3\nsteps, so that the worker's method + dispatch + collect can be\nregistered into the worker_group.**\n\n
.. code:: python\n\n   from verl.single_controller.base.decorator import register\n\n   def dispatch_data(worker_group, data):\n       return data.chunk(worker_group.world_size)\n       \n   def collect_data(worker_group, data):\n       return torch.cat(data)\n\n   dispatch_mode = {\n       'dispatch_fn': dispatch_data,\n       'collect_fn': collect_data\n   }\n\n   @register(dispatch_mode=dispatch_mode)\n   def generate_sequences(self, data):\n       pass\n\nIn this way, we can directly call the method inside the worker through\nthe ``worker_group`` on the control (driver) process (which is a single\nprocess):\n\n.. code:: python\n\n   output = worker_group.generate_sequences(data)\n\nThis single line includes data splitting, data distribution and\ncomputation, and data collection.\n\nFurthermore, the model parallelism size of each model is usually fixed,\nincluding dp, tp, and pp. So for these common distributed scenarios, we have\npre-implemented specific dispatch and collect methods, in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations.\n\n.. code:: python\n\n   from verl.single_controller.base.decorator import register, Dispatch\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(self, data: DataProto) -> DataProto:\n       pass\n\nHere it requires the data interface to be ``DataProto``. The definition of\n``DataProto`` is in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_.\n\nStep 3: Main training loop\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWith the above building blocks, we can implement the algorithm's control\nflow. It is recommended that ``main_task`` also be a ray remote process.\n\n.. code:: python\n\n   @ray.remote(num_cpus=1)\n   def main_task(config):\n       # construct SampleGenerator\n       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs\n       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)\n       # put SampleGenerator onto resource pool\n       sample_gen = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct reference policy\n       ray_cls = RayClassWithInitArgs(ReferencePolicy)\n       ref_policy = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct actor\n       ray_cls = RayClassWithInitArgs(DPOActor)\n       dpo_policy = RayWorkerGroup(resource_pool, ray_cls)\n       \n       dataloader = DataLoader()\n       \n       for data in dataloader:\n           # generate data\n           data = sample_gen.generate_sequences(data)\n           # generate scores for each data \n           data = generate_scores(data)\n           # generate pairwise data using scores\n           data = generate_pairwise_data(data)\n           # generate ref_log_prob\n           data.batch['ref_log_prob'] = ref_policy.infer(data)\n           # update using dpo\n           dpo_policy.update(data)\n           # logging\n\nHere, different ``WorkerGroups`` can be placed in the same resource pool or\nin different resource pools using ``create_colocated_worker_cls``,\nsimilarly to `ray_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py>`_.\n"
  },
  {
    "path": "verl_distillation/docs/advance/fsdp_extension.rst",
    "content": "\nAdd models with the FSDP backend\n==================================\n\nLast updated: 02/09/2025.\n\nModel\n--------------------------\n\nIn principle, our FSDP backend can support any HF model and we can\nsychronoize the actor model weight with vLLM using `hf_weight_loader.py` under `third_party/vllm`.\nHowever, ``hf_weight_loader`` is will gather the full state_dict of a\nmodel during synchronization, which may cause OOM. We suggest using\n``dtensor_weight_loader`` which gather the full model parameter layer by\nlayer to reduce the peak memory usage. We already support dtensor weight\nloader for the models below in `dtensor_weight_loader.py` under `third_party/vllm`:\n\n- ``GPT2LMHeadModel``\n- ``LlamaForCausalLM``\n- ``LLaMAForCausalLM``\n- ``MistralForCausalLM``\n- ``InternLMForCausalLM``\n- ``AquilaModel``\n- ``AquilaForCausalLM``\n- ``Phi3ForCausalLM``\n- ``GemmaForCausalLM``\n- ``Gemma2ForCausalLM``\n- ``GPTBigCodeForCausalLM``\n- ``Starcoder2ForCausalLM``\n- ``Qwen2ForCausalLM``\n- ``DeepseekV2ForCausalLM``\n\nTo implement ``dtensor_weight_loader`` of a model that's supported in\nvLLM, follow the guide of gemma model below:\n\n1. Copy the\n   ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` from the vllm model class\n   to ``dtensor_weight_loaders.py``\n2. Modify the arguments to\n   ``(actor_weights: Dict, vllm_model: nn.Module)``\n3. Replace the ``self`` to ``vllm_model``\n4. Add the\n   ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)``\n   before each ``param = params_dict[name]`` and modify the following\n   weight loading using ``local_loaded_weight``.\n5. Register the implemented dtensor weight loader to ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``.\n\n.. 
code-block:: diff\n\n    - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n    + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n            (\"gate_up_proj\", \"gate_proj\", 0),\n            (\"gate_up_proj\", \"up_proj\", 1),\n        ]\n    -   params_dict = dict(self.named_parameters())\n    +   params_dict = dict(vllm_model.named_parameters())\n        loaded_params = set()\n    -   for name, loaded_weight in weights:\n    +   for name, loaded_weight in actor_weights.items():\n            for (param_name, shard_name, shard_id) in stacked_params_mapping:\n                if shard_name not in name:\n                    continue\n                name = name.replace(shard_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n    +           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n    -           weight_loader(param, loaded_weight, shard_id)\n    +           weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)\n                break\n            else:\n                # lm_head is not used in vllm as it is tied with embed_token.\n                # To prevent errors, skip loading lm_head.weight.\n                if \"lm_head.weight\" in name:\n                    continue\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n    +           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\",\n                                        default_weight_loader)\n    -           weight_loader(param, loaded_weight)\n    +           weight_loader(param, local_loaded_weight.to(dtype=param.dtype))\n            loaded_params.add(name)\n        unloaded_params = params_dict.keys() - loaded_params\n        if unloaded_params:\n            raise RuntimeError(\n                \"Some weights are not initialized from checkpoints: \"\n                f\"{unloaded_params}\")"
  },
  {
    "path": "verl_distillation/docs/advance/fully_async.md",
    "content": "# Recipe: Fully Async Policy Trainer\n\n**Author:** `https://github.com/meituan-search`\n\nLast updated: 10/18/2025.\n\nThis document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,\nsupporting asynchronous sample generation and training.\nUnder this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs,\nwithout significantly affecting the results.\n\n## Introduction\n\n### Background\n\nThe separated rollout and train architecture, compared to the colocate architecture, can allocate resources more\nflexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training\nefficiency caused by long-tail problems.\nThe one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by\ndesigning a separated architecture and performing asynchronous training between rollout and train for one round.\nHowever, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot\ncompletely eliminate the impact of long-tail on training efficiency.\nIn other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have\nbeen implemented based on the separated architecture and have achieved gains.\nWe borrow from their methods and implemented them in VERL. The fully_async_policy supports asynchronous, streaming, and\npartial\nrollout training.\nBy reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy\ncan significantly improve training efficiency.\n\n> Magistral https://arxiv.org/abs/2506.10910\n>\n> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language\n> Reasoning https://arxiv.org/abs/2505.24298\n>\n> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream\n> Generation https://arxiv.org/abs/2504.15930\n>\n> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663\n>\n\n### Core Contributions\n\n* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to\n  specify the resources they occupy separately.\n* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.\n* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to\n  multiple steps, making the asynchronous solution more flexible.\n* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter\n  and Trainer.\n* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single\n  sample as the minimum transmission unit.\n* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it\n  supports training with samples generated by old parameters.\n* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter\n  synchronization, by adding `sleep() and resume()` logic, it\n  saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for\n  ongoing tasks to finish during parameter synchronization.\n\nCurrently, the supported usage mode is megatron/fsdp+vllm. 
and vLLM must use the server mode based on AgentLoop.\n\n## Design\n\nThe overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four\nparts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.\n\n![fully_async_policy_structure](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)\n\n1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the\n   production speed controlled by freshness.\n2. MessageQueue temporarily stores the samples generated by Rollouter.\n3. Trainer fetches samples from the MessageQueue one by one. After fetching `require_batches * ppo_mini_batch_size`\n   samples, it performs training. After `async_training.trigger_parameter_sync_step` rounds of training, it triggers\n   a parameter synchronization with Rollouter.\n4. ParameterSynchronizer implements synchronous parameter synchronization over NCCL.\n\nThe benefit over the baseline scheme comes from the fact that in the colocate case, devoting more resources to\nrollout cannot eliminate the idleness caused by long-tail samples.\nAfter resource isolation, rollout and train may each take longer than before (because each uses fewer resources),\nbut overlapping them reduces the end-to-end time.\n\n![fully_async_policy_revenue](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)\n\n## Usage\n\n### Parameter Description\n\n| super params                                  | implication                                                                                     |\n|-----------------------------------------------|--------------------------------------------------------------------------------------------------|\n| `trainer.nnodes`                              | Number of nodes for Trainer                                                                     |\n| `trainer.n_gpus_per_node`                     | Number of GPUs per node for Trainer                                                             |\n| `rollout.nnodes`                              | Number of nodes for Rollouter                                                                   |\n| `rollout.n_gpus_per_node`                     | Number of GPUs per node for Rollouter                                                           |\n| `data.train_batch_size`                       | Not effective in the fully async strategy (default is 0)                                        |\n| `data.gen_batch_size`                         | In the fully async strategy, uses streaming sample production logic (default is 1)              |\n| `rollout.total_rollout_steps`                 | Total number of rollout samples                                                                 |\n| `rollout.test_freq`                           | How many parameter updates the Rollouter performs before running a validation                   |\n| `actor_rollout_ref.actor.ppo_mini_batch_size` | ppo_mini_batch_size is a global count across all workers/GPUs                                   |\n| `async_training.require_batches`              | Number of ppo_mini_batch_size batches that FullyAsyncTrainer fetches at once                    |\n| `async_training.trigger_parameter_sync_step`  | How many local updates FullyAsyncTrainer performs before a parameter synchronization            |\n| `async_training.staleness_threshold`          | Freshness control                                                                               |\n| `async_training.partial_rollout`              | Whether to perform partial rollout                                                              |\n| `async_training.use_rollout_log_probs`        | Use log_probs generated by rollout                                                              |\n| `async_training.compute_prox_log_prob`        | Whether to compute log_prob using the training model's parameters during the training phase     |\n\n
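As a concrete example of how these knobs interact (the numbers are illustrative): with `ppo_mini_batch_size=32`, `require_batches=2`, and `trigger_parameter_sync_step=4`, the Trainer fetches 2 * 32 = 64 samples per local update and synchronizes parameters with the Rollouter after 4 updates, i.e. every 256 samples, which matches a colocate run with `data.train_batch_size=256`. With `staleness_threshold=0.25`, the Rollouter may generate at most (1 + 0.25) * 256 = 320 samples (minus any leftover stale samples) between two synchronizations, per the formulas below.\n\n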
**Further Explanation:**\n\n* `rollout.total_rollout_steps`\n\n  Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step:\n  `rollout.total_rollout_steps = data.train_batch_size * step`.\n\n* `async_training.trigger_parameter_sync_step`\n\n  In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches\n  `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter.\n  Between any two parameter synchronizations between Rollouter and Trainer, the Trainer will process\n  `trigger_parameter_sync_step * require_batches * ppo_mini_batch_size` samples.\n  To fairly compare speed with colocate, trigger_parameter_sync_step should be set to\n  `data.train_batch_size / (require_batches * ppo_mini_batch_size)`.\n\n* `async_training.staleness_threshold`\n\n  In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.\n\n    * staleness_threshold=0 indicates synchronous training.\n      Rollouter will generate a fixed number of samples between two parameter updates; the sample count is:\n      $$rollout\_num = trigger\_parameter\_sync\_step \times require\_batches \times ppo\_mini\_batch\_size$$\n    * staleness_threshold>0 indicates asynchronous training; it can be set to a fraction for more flexible asynchrony.\n      Rollouter will generate at most the following number of samples between two parameter updates:\n      $$rollout\_num = (1+staleness\_threshold) \times (trigger\_parameter\_sync\_step \times require\_batches \times ppo\_mini\_batch\_size) - num\_staleness\_sample$$\n\n  num_staleness_sample represents the number of stale samples generated in excess during the last rollout.\n\n  Since this is a streaming system, rollout keeps generating and the trainer keeps consuming. If the Rollouter is slower,\n  the Trainer will trigger parameter synchronization earlier, and the Rollouter will not actually produce rollout_num samples.\n  When rollout is fast enough, setting staleness_threshold to 1 is roughly equivalent to the one_step_off policy.\n  To avoid too many stale samples hurting training accuracy, it is recommended to keep this value below 1.\n\n* `async_training.partial_rollout`\n\n  partial_rollout only actually takes effect when staleness_threshold>0.\n\n* `async_training.use_rollout_log_probs`\n\n  In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to\n  the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,\n  old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm correctness. 
### Supported Modes\n\n1. on policy pipeline:\n    1. **trigger_parameter_sync_step=1, staleness_threshold=0**\n    2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for\n       training, and after training completes, Trainer and Rollouter perform a parameter synchronization;\n    3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill\n       the idle resources, causing some resource waste.\n    4. As shown in figure a;\n\n2. stream off policy pipeline:\n    1. **trigger_parameter_sync_step>1, staleness_threshold=0**\n    2. Synchronous streaming training is performed. Rollouter produces\n       `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local\n       training step every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training\n       trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;\n    3. Compared to a, since more samples are generated at once, resource idleness is lower.\n    4. Within one training step there are two periods of resource idleness: when fetching the first batch of samples,\n       train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter\n       update, rollout waits for training to complete.\n    5. As shown in figure b;\n\n3. async stream pipeline with stale samples:\n    1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**\n    2. After each parameter update, Rollouter plans to produce at most rollout_num samples (in practice, the number\n       of samples generated may be lower, depending on rollout speed).\n    3. If the rollout process is relatively fast, Rollouter generates some additional samples (num_stale_samples)\n       before parameter synchronization, for immediate use by the Trainer after synchronization.\n       When parameter synchronization is triggered, if Rollouter has ongoing tasks, it waits for them to complete\n       and does not add new tasks;\n    4. Compared to b, every step after the first no longer waits for the first batch of rollouts to finish, but it\n       does wait for the active tasks to finish.\n    5. As shown in figure c;\n\n4. async stream pipeline with partial rollout:\n    1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**\n    2. Compared to c, when parameter synchronization is triggered while Rollouter still has samples in flight, the\n       rollout process is interrupted for the synchronization, and the interrupted samples continue to be generated\n       afterwards. This reduces the time spent waiting for active tasks to finish (a toy sketch follows the figure\n       below).\n    3. As shown in figure d;\n\n![fully_async_policy_mode](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)\n\n
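The interrupt-and-resume behavior that distinguishes mode 4 can be sketched in a few lines (a toy illustration; every name here is an illustrative stand-in, not a verl API):\n\n```python\n# Toy sketch of partial rollout (mode 4, figure d); all names are stand-ins.\nclass PartialRollout:\n    def __init__(self, max_len=8):\n        self.max_len = max_len\n        self.pending = {}                      # request_id -> tokens generated so far\n\n    def generate(self, request_id, next_token, sync_requested):\n        tokens = self.pending.pop(request_id, [])   # resume an interrupted sample\n        while len(tokens) < self.max_len:\n            if sync_requested():\n                self.pending[request_id] = tokens   # park the partial sample,\n                return None                         # finish it after the weight sync\n            tokens.append(next_token(tokens))\n        return tokens\n\nrollout = PartialRollout()\n# A parameter sync arrives after three tokens, so the sample is parked mid-generation.\nsignals = iter([False, False, False, True])\nrollout.generate(\"req-0\", lambda t: len(t), lambda: next(signals))\n# After the sync, generation resumes from the saved prefix instead of restarting.\nprint(rollout.generate(\"req-0\", lambda t: len(t), lambda: False))  # [0, 1, ..., 7]\n```\n\n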
### Key Metrics\n\n| metrics | implication |\n|---|---|\n| `trainer/idle_ratio` | Trainer idle rate |\n| `rollouter/idle_ratio` | Rollouter idle rate |\n| `fully_async/count/stale_samples_processed` | Total number of stale samples used in training |\n| `fully_async/count/stale_trajectory_processed` | Total number of stale trajectories used in training (one sample produces rollout.n trajectories) |\n| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two parameter synchronizations |\n| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two parameter synchronizations |\n| `fully_async/partial/max_partial_span` | Maximum parameter-version span of partial samples processed by Trainer between two parameter synchronizations |\n\n### Parameter Tuning Recommendations\n\n* Resource Allocation and Adjustment:\n    * Reasonable resource allocation is the prerequisite for good training efficiency. Ideally, the allocation makes\n      the rollout time and train time close, minimizing pipeline bubbles, avoiding resource idleness, and keeping the\n      Trainer from consuming stale samples. In real training scenarios, the allocation can be adjusted based on the\n      observed idle time of rollout and train, available from rollouter/idle_ratio and trainer/idle_ratio. If\n      rollouter/idle_ratio is high and trainer/idle_ratio is low, increase Trainer resources and reduce Rollouter\n      resources, and vice versa. A toy version of this rule of thumb is sketched after this list.\n\n* Key Parameters:\n    * staleness_threshold: Setting it too high causes more stale samples to be used, hurting model performance. It is\n      recommended to set it below 1.\n    * require_batches: The closer it is to 1, the closer the system is to pure streaming, the smaller the training\n      bubbles, and the larger the speedup; however, it changes the order in which samples are processed;\n    * trigger_parameter_sync_step: The smaller the value, the closer the training is to on-policy, but parameter\n      synchronization becomes frequent and long-tail samples waste resources that short samples cannot backfill,\n      lowering utilization.\n      The larger the value, the higher the computational efficiency, but accuracy suffers from the training being more\n      off-policy.\n    * rollout.test_freq: Validation occupies Rollouter resources, so it is not recommended to set this too small.\n\n* Mode Selection: By adjusting these parameters, the Fully Async architecture supports optimization at different\n  levels, suiting tasks in different scenarios.\n    * For small-scale tasks that must preserve training stability and on-policy behavior, and that have low speed\n      requirements, try the on policy pipeline mode (Mode 1).\n    * For scenarios that need higher training throughput but are sensitive to staleness, try the stream off policy\n      pipeline mode: set trigger_parameter_sync_step>1 to improve training efficiency while keeping the\n      synchronization mechanism (staleness_threshold=0) (Mode 2).\n    * For large-scale tasks with high training speed requirements that can tolerate a certain degree of off-policy\n      drift and staleness, set staleness_threshold>0 and partial_rollout=True to improve training efficiency, using\n      the async stream pipeline mode (Mode 3 or 4).\n\n
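A tiny heuristic captures the rebalancing rule of thumb above (toy Python; the tolerance is illustrative, not a tuned value):\n\n```python\n# Toy heuristic mirroring the idle-ratio guidance; the tolerance is illustrative.\ndef rebalance_hint(rollouter_idle_ratio, trainer_idle_ratio, tolerance=0.1):\n    gap = rollouter_idle_ratio - trainer_idle_ratio\n    if gap > tolerance:\n        return \"shift GPUs from Rollouter to Trainer\"\n    if gap < -tolerance:\n        return \"shift GPUs from Trainer to Rollouter\"\n    return \"allocation is roughly balanced\"\n\n# Rollouter mostly idle, Trainer saturated -> Rollouter is over-provisioned.\nprint(rebalance_hint(0.35, 0.05))  # shift GPUs from Rollouter to Trainer\n```\n\n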
### Quick Start\n\n```shell\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\nuse_dynamic_bsz=True\ntotal_rollout_steps=$((512*400))\ntest_freq=10\nstaleness_threshold=0\ntrigger_parameter_sync_step=16\npartial_rollout=False\n\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\"\n```\n\n## Experiments\n\n### Asynchronous Training on 7B 
Model\n\nWe used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under long candidates and multiple resources.\nUsing the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards,\n64 cards, and 128 cards without significantly affecting experimental results.\n\n* Machine: H20\n* Model: Qwen2.5-Math-7B\n* Rollout length: max_response_length FSDP2: 28K tokens;\n* Algorithm: DAPO\n* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet\n* Engine: vllm+FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step: 400\n    * train_batch_size: 512\n\n* fully_async_policy\n    * total_rollout_steps: 512*400\n    * require_batches: 4\n    * trigger_parameter_sync_step: 4\n    * staleness_threshold: 0.5\n    * partial_rollout: True\n\n|  training mode   \t   | resource allocation \t | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t | total time<br>400 step \t |      acc/mean@1          \t      |\n|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-------------------------------:|\n| colocate sync      \t | 32                  \t | 790.10 \t | 357.41 \t | 107.71       \t | 313.81       \t | 13h 44m                \t | 1d 3h 43m              \t | 2d 9h 22m              \t | 3d 17h 5m              \t | max: 0.3313<br>last: 0.2448  \t  |\n| fully_async_policy \t | 16:16               \t |  294.77  |  21.26   | \\            \t |     269.80     |    7h 58m<br>(1.72x)     |    16h 21m<br>(1.70x)    |   1d 0h 53m<br>(2.31x)   |   1d 9h 26m<br>(2.66x)   | max: 0.3302<br>last: 0.2333   \t |\n| colocate sync      \t | 64                  \t | 365.28 \t | 150.72 \t | 70.26        \t | 133.41       \t | 10h 22m                \t | 20h 45m                \t | 1d 7h 6m               \t | 1d 17h 32m             \t | max: 0.3365<br>last:  0.2333 \t  |\n| fully_async_policy \t | 32:32               \t | 189.26 \t | 28.46  \t | \\            \t | 156.98       \t | 4h 57m<br>(2.09x)      \t | 10h 14m<br>(2.03x)     \t | 16h 58m<br>(1.83x)     \t | 21h 40m<br>(1.92x)     \t | max: 0.3677<br>last: 0.3406  \t  |\n| colocate sync      \t | 128                 \t | 356.30 \t | 177.85 \t | 53.92        \t | 113.81       \t | 8h 36m                 \t | 17h 56m                \t | 1d 5h 6m               \t | 1d 16h 48m             \t | max: 0.3573<br>last: 0.2958  \t  |\n| fully_async_policy \t | 64:64               \t | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m<br>(2.67x)      \t | 6h 46m<br>(2.65x)      \t | 10h 53m<br>(2.67x)     \t | 17h 22m<br>(2.35x)     \t | max: 0.3521<br>last: 0.3094  \t  |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg\n\n### 128-card 7B Asynchronous Mode Experiment\n\nWe used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async.\nWe can see that the benefit brought by streaming is approximately 1.6x, and after combining staleness and\npartial_rollout, the benefit reaches 2.35x.\n\n|                             mode                                         \t                              | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total 
time<br>300 step \t | total time<br>400 step \t |      acc/mean@1         \t      |\n|:-------------------------------------------------------------------------------------------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|\n|                                          colocate sync      \t                                           | 356.30 \t | 177.85 \t | 53.92        \t | 113.81       \t | 8h 36m                 \t | 17h 56m                \t | 1d 5h 6m               \t | 1d 16h 48m             \t | max: 0.3573<br>last: 0.2958  \t |\n| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) \t | 231.34 \t | 128.47 \t | \\            \t | 98.77        \t | 4h 25m                 \t | 9h 41m                 \t | 15h 2m                 \t | 1d 1h 53m              \t | max: 0.2844<br>last: 0.2604 \t  |\n|          `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5)            \t          |    \t     |    \t     |       \t        |       \t        |            \t             |            \t             |            \t             |            \t             |               \t                |\n|        `async stream pipeline with partial rollout`<br>(+partial_rollout=True)                 \t        | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m                 \t | 6h 46m                 \t | 10h 53m                \t | 17h 22m                \t | max: 0.3521<br>last: 0.3094 \t  |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg\n\n### 128-card Stale Ablation Experiment\n\nUnder the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training\nefficiency.\nWe found that the larger the staleness, the more obvious the final gains.\nWe also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps\nincrease, the response length changes significantly, causing training instability.\nFurther analysis and optimization are needed for this issue.\n\n| staleness_threshold \t | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t | total time<br>400 step \t |     acc/mean@1         \t      |\n|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|\n| 0                   \t | 231.34 \t | 128.47 \t | \\            \t | 98.77        \t | 4h 25m                 \t | 9h 41m                 \t | 15h 2m                 \t | 1d 1h 53m              \t | max: 0.2844<br>last: 0.2604 \t |\n| 0.1                 \t | 171.30 \t | 58.17  \t | \\            \t | 109.12       \t | 3h 53m                 \t | 8h 37m                 \t | 14h 25m                \t | 19h 59m                \t | max: 0.3542<br>last: 0.2979 \t |\n| 0.3                 \t | 146.11 \t | 38.88  \t | \\            \t | 103.22       \t | 3h 18m                 \t | 6h 49m                 \t | 11h 40m                \t | 17h 20m                \t | max: 0.3469<br>last: 0.2865 \t |\n| 0.5                 \t | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m                 \t | 6h 
46m                 \t | 10h 53m                \t | 17h 22m                \t | max: 0.3521<br>last: 0.3094 \t |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg\n\n### 128-card 7B require_batches Ablation Experiment\n\nIn multiple tests, we found that the number of samples issued each time in streaming affects the response length during\ntraining, which in turn affects training time. We verified the impact on results by modifying\n`async_training.require_batches`.\n\n| require_batches \t | step  \t  | gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t |     acc/mean@1         \t      |\n|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|\n| 1               \t | 203.47 \t | 30.88 \t | \\            \t | 181.08       \t | 3h 31m                 \t | 8h 29m                 \t | 17h 36m                \t | max: 0.349<br>last: 0.326   \t |\n| 2               \t | 158.72 \t | 26.32 \t | \\            \t | 128.08       \t | 3h 35m                 \t | 7h 38m                 \t | 13h 57m                \t | max: 0.351<br>last: 0.3406  \t |\n| 4               \t | 124.64 \t | 25.62 \t | \\            \t | 95.06        \t | 3h 13m                 \t | 6h 46m                 \t | 10h 53m                \t | max: 0.3521<br>last: 0.3521 \t |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg\n\n### 30B Model Mode Experiment\n\nTODO: The 30B experiment is still in progress.\n\n* Machine: H20\n* Model: Qwen2.5-32B\n* Rollout length: max_response_length FSDP2: 20K tokens;\n* Algorithm: DAPO\n* Engine: vllm+FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step:200\n    * train_batch_size: 512\n\n* fully_async_policy\n    * total_rollout_steps: 512*200\n    * trigger_parameter_sync_step: 512/32 = 16\n    * staleness_threshold: 0\n    * partial_rollout: False\n\n| training mode      | Resource allocation | mode                                       | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |\n|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|\n| colocate sync      | 128                 |                                            |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | stream off policy pipeline                 |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | async stream pipeline with stale samples   |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | async stream pipeline with partial rollout |      |                    |              |              |            |                  |\n\n## Future Plans\n\n* GRPO experiments\n* Megatron adaptation\n* SGLang integration\n* Transfer queue integration\n* Asynchronous parameter synchronization\n* AReaL asynchronous algorithm implementation\n* TPPO algorithm implementation\n* Multi-turn and Tool support"
  },
  {
    "path": "verl_distillation/docs/advance/megatron_extension.rst",
    "content": "Add models with the Megatron-LM backend\n=========================================\n\nLast updated: 04/25/2025.\n\nModel\n-----------\n\n\nIf use latest verl, we have direct support of ``GPTModel`` for Megatron backend. \nYou can use the similar way of using Megatron to pretrain custom models. \nWe list the steps here:\n\n1. Find `model_initializer.py <https://github.com/volcengine/verl/blob/main/verl/models/mcore/model_initializer.py>`_\n2. If your model is configurable by ``TransformerLayerSpec`` , you can\n   directly use ``GPTModel``. Otherwise, Please implement a new\n   ``ModelLayerSpec`` and ``ModelLayer`` here.\n3. Use the right ``LayerSpec`` , ``TransformerConfig`` and ``HuggingfaceConfig`` \n   as arguments to initialize the GPTModel.\n4. Return the model at last.\n"
  },
  {
    "path": "verl_distillation/docs/advance/one_step_off.md",
    "content": "# Recipe: One Step Off Policy Async Trainer\n\n**Author:**  `https://github.com/meituan-search`\n\nLast updated: 07/17/2025.\n\n## Introduction\n\n### Background\n\nThe current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic\nworkflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest\nmodel, and the model is updated after training completes. While this approach aligns with off-policy reinforcement\nlearning and stabilizes RL training, but it suffers from severe efficiency issues.\nModel updates must wait for the longest output in the generation phase to complete.\nDuring the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization.\nThe more severe the long-tail problem in sample generation, the lower the overall training efficiency.\nFor example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time,\nand increasing resources does not reduce the Rollout duration.\n\n![DAPO 32B Math Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png)\n> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361\n\n### Solution\n\nWe have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the\ngeneration and training processes, utilizing samples generated in the previous step for current training.\nIt also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically\nassigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time\nduring long-tail sample generation. Throughout this process, generation and training parameters maintain a one-step off\npolicy.\n\n![One Step Off Policy Diagram](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)\n> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](\n> https://arxiv.org/abs/2505.24298)\n\nOur core contributions include:\n\n1. **Parallel Generation and Training**:  \n   Samples for the next batch are asynchronously generated while the current batch is being trained.\n\n2. **Resource Isolation**:  \n   Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources\n   automatically assigned to training.\n\n3. 
**NCCL Parameter Synchronization**:  \n   Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.\n\n### Experimental Results\n\n- **Machine Configuration**: 2 nodes, 16 H20 GPUs in total\n   - Generation: 4 GPUs\n   - Training: 12 GPUs\n- **Model**: Qwen2.5-Math-7B\n- **Rollout Configuration**:\n   - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens\n- **Algorithm**: DAPO\n- **Rollout Engine**: vLLM\n\n| training mode          | engine        | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time    | acc/best@32/mean | acc/maj@32/mean |\n|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|\n| colocate sync          | VLLM+FSDP2    | 749  | 321 | -             | 247                | 88           | 286          | 19h18m        | 0.5948           | 0.417           |\n| one-step-overlap async | VLLM+FSDP2    | 520  | -   | 45            | 458                | 108          | 337          | 15h34m (+23%)  | 0.6165           | 0.494           |\n| colocate sync          | VLLM+Megatron | 699  | 207 | -             | 162                | 119          | 344          | 18h21m        | 0.605            | 0.4217          |\n| one-step-overlap async | VLLM+Megatron | 566  | -   | 59            | 501                | 120          | 347          | 13h06m (+40%) | 0.6569           | 0.4038          |\n\n* colocate sync: step ≈ gen + old_log_prob + update_actor\n* one-step-overlap async: step ≈ wait_prev_gen + old_log_prob + update_actor\n\n![One Step Off Megatron Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)\n\n> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg\n\n## Implementation\n\n### One Step Off Policy Async Pipeline\n\nOur **One Step Off Policy Async Pipeline** integrates seamlessly into the existing training logic at minimal\ncost, eliminating the need for additional sample storage management. The core mechanism uses `async_gen_next_batch`\nfor asynchronous rollout generation while maintaining continuous operation across epoch transitions\nvia `create_continuous_iterator`.\n\n```python\n# Iterator generator that simplifies one-step-off integration of the training process\ndef _create_continuous_iterator(self):\n   for epoch in range(self.config.trainer.total_epochs):\n      iterator = iter(self.train_dataloader)\n      for batch_dict in iterator:\n         yield epoch, batch_dict\n\n\n# Read the next batch of samples, sync parameters, and launch async sequence generation\ndef _async_gen_next_batch(self, continuous_iterator):\n   # read train_data\n   try:\n      epoch, batch_dict = next(continuous_iterator)\n   except StopIteration:\n      return None\n   batch = DataProto.from_single_dict(batch_dict)\n   gen_batch = batch_process(batch)  # prompt preprocessing (simplified)\n   # sync weights from actor to rollout\n   self.sync_rollout_weights()\n   # async generation\n   gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n   # wrap the in-flight generation in a future\n   return GenerationBatchFuture(epoch, batch, gen_batch_output)\n\n\ncontinuous_iterator = self._create_continuous_iterator()\n# run rollout first to achieve one-step-off\nbatch_data_future = self._async_gen_next_batch(continuous_iterator)\n\nwhile batch_data_future is not None:\n   # wait for the gen_seq result from the previous step\n   batch = batch_data_future.get()\n   # launch the next async call to generate sequences\n   batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n   # compute advantages\n   batch = critic.compute_values(batch)\n   batch = reference.compute_log_prob(batch)\n   batch = reward.compute_reward(batch)\n   batch = compute_advantages(batch)\n\n   # model update\n   critic_metrics = critic.update_critic(batch)\n   actor_metrics = actor.update_actor(batch)\n```\n\n
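The `GenerationBatchFuture` used above pairs the not-yet-generated batch with a handle to the in-flight rollout. A minimal sketch of such a wrapper (illustrative; verl's actual class and the `union` merge are assumptions here):\n\n```python\n# Minimal sketch of the generation future (illustrative, not verl's exact class).\nclass GenerationBatchFuture:\n    def __init__(self, epoch, batch, gen_batch_output):\n        self.epoch = epoch\n        self.batch = batch                        # original batch (prompts, metadata)\n        self.gen_batch_output = gen_batch_output  # handle to the in-flight rollout\n\n    def get(self):\n        # Block until the asynchronous rollout completes, then merge the\n        # generated sequences back into the original batch.\n        gen_output = self.gen_batch_output.get()\n        return self.batch.union(gen_output)\n```\n\n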
### Parameter Synchronization\n\nA highlight is that our NCCL-based weight update for the rollout model performs very well:\nmost of the time the latency is under 300ms, which is negligible for RLHF.\n\n> **sync_rollout_weights**: The time for synchronizing parameters from actor to rollout is extremely fast and can almost\n> be ignored because it is implemented with NCCL.\n\n```python\nclass ActorRolloutRefWorker:\n   # actor exposes the meta-info of model parameters for parameter sync\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def get_actor_weights_info(self):\n      params = self._get_actor_params()\n      ret = []\n      for key, tensor in params.items():\n         ret.append((key, tensor.size(), tensor.dtype))\n      self._weights_info = ret\n      return ret\n\n   # rollout sets the meta-info of model parameters for parameter sync\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def set_actor_weights_info(self, weights_info):\n      self._weights_info = weights_info\n\n\nclass AsyncRayPPOTrainer(RayPPOTrainer):\n   def init_workers(self):\n      ...\n      # rollout obtains the meta-info of model parameters from the actor for parameter sync\n      weights_info = self.actor_wg.get_actor_weights_info()[0]\n      self.rollout_wg.set_actor_weights_info(weights_info)\n      \n      # Create an actor-rollout communication group for parameter sync\n      self.create_weight_sync_group()\n```\n\n```python\n# The driver process invokes the actor and rollout respectively to create a weight synchronization group based on nccl/hccl.\ndef create_weight_sync_group(self):\n   master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote())\n   master_port = 
ray.get(self.actor_wg.workers[0]._get_free_port.remote())\n   world_size = len(self.actor_wg.workers + self.rollout_wg.workers)\n   self.actor_wg.create_weight_sync_group(\n      master_address,\n      master_port,\n      0,\n      world_size,\n   )\n   ray.get(\n      self.rollout_wg.create_weight_sync_group(\n            master_address,\n            master_port,\n            len(self.actor_wg.workers),\n            world_size,\n      )\n   )\n\n# The driver process calls the actor and rollout respectively to sync parameters via NCCL\ndef sync_rollout_weights(self):\n   self.actor_wg.sync_rollout_weights()\n   ray.get(self.rollout_wg.sync_rollout_weights())\n\n\n# FSDP model parameter sync\n@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\ndef sync_rollout_weights(self):\n   params = self._get_actor_params() if self._is_actor else None\n   if self._is_rollout:\n      inference_model = (\n         self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n      )\n      from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n      patch_vllm_moe_model_weight_loader(inference_model)\n   # Model parameters are broadcast tensor-by-tensor from actor to rollout\n   for key, shape, dtype in self._weights_info:\n      tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n      if self._is_actor:\n         assert key in params\n         origin_data = params[key]\n         if hasattr(origin_data, \"full_tensor\"):\n            origin_data = origin_data.full_tensor()\n         if torch.distributed.get_rank() == 0:\n            tensor.copy_(origin_data)\n      from ray.util.collective import collective\n\n      collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n      if self._is_rollout:\n         inference_model.load_weights([(key, tensor)])\n```\n\n## Usage\n\n### FSDP2 Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False),\n# with dedicated resources allocated to each\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Megatron Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False),\n# with dedicated resources allocated to each\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Configuration Guidelines\n\n1. **Card Number Relationships**  \n   Maintain either of these relationships for optimal batch distribution:\n   - `actor_rollout_ref.rollout.n` should be an integer divisor of:  \n     `trainer.n_gpus_per_node * trainer.nnodes`\n   - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by:  \n     `trainer.n_gpus_per_node * trainer.nnodes`\n\n   > Rationale: Ensures training samples can be evenly distributed across training GPUs when using partial resources for\n   generation. A small mechanical check of these relationships is sketched after this list.\n\n2. **Dynamic Resource Tuning**  \n   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on phase\n   durations:\n   - **Ideal state**: Rollout and training phases have comparable durations\n   - **Diagnostic metrics**:\n      - Monitor `wait_prev_gen` duration\n      - Analyze `sequence_length` distribution\n   - **Adjustment strategy**:\n      - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources\n      - High `wait_prev_gen` + long-tail sequences → Optimize stopping criteria (resource increase won't help)\n   > **wait_prev_gen**: The time consumed waiting for the previous rollout to end (the part that is not fully\n   overlapped).\n\n   **Resource Configuration Strategies:**\n   - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios,\n     keeping the number of nodes equal to allow training and rollout to share nodes;\n      - Configure `trainer.nnodes = rollout.nnodes` with\n        `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`. Control rollout resource\n        allocation by adjusting `n_gpus_per_node`.\n   - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes,\n     keeping the number of GPUs per node equal to enable independent scaling of training and rollout\n     parallelism.\n      - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control rollout resource allocation by\n        adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.\n   > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The\n   > actual calculation depends on GPU capacity:\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,\n   >   the required node count is `max(trainer.nnodes, rollout.nnodes)`\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,\n   >   the required node count is `trainer.nnodes + rollout.nnodes`\n\n
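The relationships in guideline 1 can be checked mechanically before launching a run (a small sketch in plain Python; the function name and sample values are illustrative):\n\n```python\n# Toy pre-launch check mirroring guideline 1; the names are illustrative.\ndef check_batch_distribution(rollout_n, train_batch_size, trainer_nnodes, trainer_n_gpus_per_node):\n    train_gpus = trainer_nnodes * trainer_n_gpus_per_node\n    divides_gpus = train_gpus % rollout_n == 0\n    divides_batch = (rollout_n * train_batch_size) % train_gpus == 0\n    return divides_gpus or divides_batch\n\nassert check_batch_distribution(rollout_n=2, train_batch_size=512,\n                                trainer_nnodes=1, trainer_n_gpus_per_node=6)\n```\n\n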
## Functional Support\n\n| Category           | Support Status                                                                                                    |\n|--------------------|-------------------------------------------------------------------------------------------------------------------|\n| train engine       | FSDP2 <br/> Megatron                                                                                              |\n| rollout engine     | vLLM                                                                                                              |\n| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE <br/> GPG |\n| Reward             | all                                                                                                               |\n"
  },
  {
    "path": "verl_distillation/docs/advance/placement.rst",
    "content": "Ray API Design Tutorial\n=======================================\n\nLast updated: 10/30/2024.\n\nWe provide a tutorial for our Ray API design, including:\n\n- Ray basic concepts\n- Resource Pool and RayWorkerGroup\n- Data Dispatch, Execution and Collection\n- Initialize the RayWorkerGroup and execute the distributed computation in the given Resource Pool\n\nSee details in `tutorial.ipynb <https://github.com/volcengine/verl/blob/main/examples/ray/tutorial.ipynb>`_."
  },
  {
    "path": "verl_distillation/docs/advance/ppo_lora.rst",
    "content": "RL(HF) algorithms with LoRA Support\n===========================================\n\nLast updated: 06/05/2025.\n\nWe support LoRA (Low-Rank Adaptation) for reinforcement learning algorithms such as PPO, GRPO, and others.\n\nLoRA is a parameter-efficient fine-tuning technique that injects trainable low-rank matrices into pre-trained weights (typically linear layers). This reduces memory footprint and compute cost, making it possible to fine-tune large models with limited hardware.\n\nThe benefits this brings include:\n\n- reinforcement learning with very large models (e.g. 70B+) with modest hardware (e.g. 8x80G GPUs),\n- enable larger batch sizes due to reduced memory usage,\n- simplify model transfer and deployment, as only LoRA adapters need to be saved,\n- Combine with techniques like `SLoRA <https://arxiv.org/abs/2311.03285>`_ or `CCoE <https://arxiv.org/abs/2407.11686>`_ to serve multiple LoRA adapters efficiently\n\nThis guide explains how to enable LoRA in RL training and configure related parameters.\n\nUsage Guide\n------------------------\n1. Lora is available in the `verl.trainer.ppo.ray_trainer.RayPPOTrainer`. Examples are provided via the `verl.trainer.main_ppo` entry point.\n\n2. Currently, LoRA is supported via huggingface peft, only with fsdp/fsdp2 and vllm backend (sglang support coming soon).\n\n- `strategy=fsdp` or `strategy=fsdp2`\n- `rollout.name=vllm`\n\n3. Required configurations for LoRA:\n\n- `actor_rollout_ref.model.lora_rank`: int, set to a reasonable value greater than 0 (e.g., 8, 16, 32, 64)\n- `actor_rollout_ref.model.lora_alpha`: float, the alpha term in LoRA\n- `actor_rollout_ref.rollout.load_format=\"safetensors\"`: required. This enables vLLM to load the base model.\n- `actor_rollout_ref.model.target_modules`: the target modules for LoRA. Typically set to \"all-linear\".\n\n4. Optional configurations for LoRA:\n\n- `actor_rollout_ref.model.lora_adapter_path`: string, path to a pretrained LoRA adapter directory. \n   If provided, loads existing adapter instead of creating new one. Enables multi-stage training from previously saved adapters.\n   Directory need contain `adapter_model.safetensors` and `adapter_config.json`.\n\n5. Recommend options:\n\n- `actor_rollout_ref.model.use_shm=True`: preload the model into `/dev/shm` to improve model loading speed.\n- `actor_rollout_ref.rollout.layered_summon=True`: this enables the actor-model to gather the FSDP shards per layers when synchronizing the LoRA Adapter to vLLM, thereby reducing GPU peak memory. Recommended if the model is very large (70B+) or the GPU memory is limited (< 48GB)\n\n\nBest Practices and Notes\n-------------------------\n\n1. **Learning rate**: it is recommended to increase the value of learning rate by an order of magnitude.\n\n2. **LoRA Rank**:\n\n- Too small a rank can hurt convergence.\n- LoRA rank recommendation from @thelongestusernameofall:\n\n  - A very small lora_rank can lead to slower convergence or worse training performance. It is recommended to set lora_rank to be>=32. Tests have shown that for a 0.5B model, with lora_rank=32,the training convergence speed and final performance are almost identical to non-LoRA training\n  - For a 32B model,with lora_rank=128,the training convergence speed and final performance are also almost identical to non-LoRA training.\n  - More comprehensive reference results are coming soon.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/f2b80b8b26829124dd393b7a795a0640eff11644/docs/lora.jpg?raw=true\n\n3. 
Best Practices and Notes\n-------------------------\n\n1. **Learning rate**: it is recommended to increase the learning rate by an order of magnitude.\n\n2. **LoRA Rank**:\n\n- Too small a rank can hurt convergence.\n- LoRA rank recommendation from @thelongestusernameofall:\n\n  - A very small lora_rank can lead to slower convergence or worse training performance. It is recommended to set lora_rank to >= 32. Tests have shown that for a 0.5B model, with lora_rank=32, the training convergence speed and final performance are almost identical to non-LoRA training.\n  - For a 32B model, with lora_rank=128, the training convergence speed and final performance are also almost identical to non-LoRA training.\n  - More comprehensive reference results are coming soon.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/f2b80b8b26829124dd393b7a795a0640eff11644/docs/lora.jpg?raw=true\n\n3. Reference configuration for RL training with the Qwen2.5-72B model using 8 x 80GB GPUs (increase lora_rank if needed):\n\n.. code-block::\n\n    data.train_batch_size=64 \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=8 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=64 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1\n\nExample Scripts\n-------------------\n\nFor end-to-end examples, refer to the scripts below:\n\n- LoRA training from scratch: examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh\n- LoRA training from adapter path: examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora_from_adapter.sh\n"
  },
  {
    "path": "verl_distillation/docs/advance/reward_loop.rst",
    "content": "Reward Loop\n===========\n\n.. _yyding: https://yyding1.github.io\n\nAuthor: `Yuyang Ding <https://yyding1.github.io>`_\n\nLast updated: 10/23/2025.\n\n.. warning::\n   Reward Loop is ready for use, but the API may change in future releaes.\n\nReward Loop is designed for more flexible and easy-to-use reward computation.\n\n**Design goal**:\n\n- Make reward computation more efficient\n- Support broader reward model interface (including discriminative and generative models)\n- Make user customized reward function more flexible\n\n.. image:: https://github.com/yyDing1/verl-materials/blob/main/reward_loop_overview.svg?raw=true\n\nAsync Reward Computation\n------------------------\n\nRewardLoopManager\n~~~~~~~~~~~~~~~~~\n\nThe Reward Loop refactors the design of the reward manager so that each sample is processed asynchronously in the ``run_single`` function.\nThis asynchronous design enables the Reward Loop to handle multiple reward computations concurrently, significantly improving computation efficiency.\n\n.. code:: python\n\n   class RewardLoopManagerBase(ABC):\n      async def run_single(self, data: DataProto) -> dict:\n         # ... (data preprocessing)\n         if self.is_async_reward_score:\n            result = await self.compute_score(\n                  data_source=data_source,\n                  solution_str=response_str,\n                  ground_truth=ground_truth,\n                  extra_info=extra_info,\n                  reward_router_address=self.reward_router_address,\n                  reward_model_tokenizer=self.reward_model_tokenizer,\n            )\n         else:\n            result = await self.loop.run_in_executor(\n                  None,\n                  lambda: self.compute_score(\n                     data_source=data_source,\n                     solution_str=response_str,\n                     ground_truth=ground_truth,\n                     extra_info=extra_info,\n                     reward_router_address=self.reward_router_address,\n                     reward_model_tokenizer=self.reward_model_tokenizer,\n                  ),\n            )\n         # ... (reward postprocessing)\n         return final_result\n\nUser-defined reward functions can be implemented as either synchronous or asynchronous.\n``RewardLoopManager`` automatically detects the type of the user-defined function and executes it accordingly, ensuring that the reward computation process remains non-blocking.\n\nUser-Customized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nUsers can define custom reward functions, for instance, by integrating external generative rewards or rule-based rewards to accommodate diverse scenario requirements.\n\nTo facilitate this, the Reward Loop directly exposes the reward model interface, enabling complex reward computation pipelines that involve model-based scoring.\nA user-defined reward function may look like the following:\n\n.. 
User-Customized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nUsers can define custom reward functions, for instance, by integrating external generative rewards or rule-based rewards to accommodate diverse scenario requirements.\n\nTo facilitate this, the Reward Loop directly exposes the reward model interface, enabling complex reward computation pipelines that involve model-based scoring.\nA user-defined reward function may look like the following:\n\n.. code:: python\n\n   async def compute_score_gsm8k(\n      data_source: str,\n      solution_str: str,\n      ground_truth: str,\n      extra_info: dict,\n      reward_router_address: str,\n      reward_model_tokenizer: PreTrainedTokenizer,\n   ):\n      \"\"\"Compute the reward score.\"\"\"\n\n      # Step 1: Prepare prompt and request payload\n      grm_prompt = GRM_PROMPT_TEMPLATE.format(problem=extra_info[\"question\"], solution=solution_str)\n      messages = [{\"role\": \"user\", \"content\": grm_prompt}]\n      sampling_params = {\"temperature\": 0.7, \"top_p\": 0.8, \"max_tokens\": 4096}\n      chat_complete_request = {\"messages\": messages, **sampling_params}\n\n      # Step 2: Send async request to the reward model\n      # here, chat_complete sends an async http request to the router address\n      result = await chat_complete(\n         router_address=reward_router_address,\n         chat_complete_request=chat_complete_request,\n      )\n\n      # Step 3: Parse model response and extract score\n      grm_response = result.choices[0].message.content.strip()\n      try:\n         score_str = grm_response.split(\"\\n\\n\")[-1].strip()\n         score = int(score_str)\n      except Exception:\n         score = 0\n\n      return {\"score\": score}\n\nRunnable examples are provided in the ``recipe/fapo`` directory for reference.\n\nReward Models and Router\n------------------------\n\nTo support flexible and scalable reward model computation, the Reward Loop implements a reward router that coordinates requests among multiple reward model servers.\n\nEach reward model runs as an independent server and is registered with the router.\nThe router forwards requests to the registered reward servers with load balancing and returns the results.\nThis design lets us expose a single unified router address to user-defined reward functions, enabling them to access various reward models seamlessly through the same interface.\n\nRewardModelManager\n~~~~~~~~~~~~~~~~~~\n\n.. image:: https://github.com/yyDing1/verl-materials/blob/main/reward_loop_full.svg?raw=true\n\n``RewardModelManager`` launches multiple reward servers and registers them with the reward router.\n\n.. code:: python\n\n   class RewardModelManager:\n      \"\"\"Reward model manager.\"\"\"\n\n      def __init__(self, config: RewardModelConfig, worker_group: RayWorkerGroup = None):\n         \"\"\"\n         Initialize the reward model manager.\n\n         Args:\n            config (RewardModelConfig): Reward model configuration.\n            worker_group (RayWorkerGroup, optional): Worker group. Defaults to None.\n         \"\"\"\n         self.config = config\n         self.worker_group = worker_group\n         self._initialize_llm_servers()\n         self._initialize_router()\n         if self.config.rollout.free_cache_engine:\n            self.sleep()\n\nReward Router\n~~~~~~~~~~~~~\n\nThe router forwards requests to the registered reward servers with load balancing.\n\n- For sglang reward servers, we directly use the sglang router to forward the requests.\n- For vllm reward servers, we implement a simple round-robin ``NaiveRouter`` to dispatch the requests.\n\n
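Round-robin dispatch itself is conceptually tiny; it amounts to cycling through the registered worker URLs (a toy illustration with made-up URLs, separate from the ``NaiveRouter`` shown below):\n\n.. code:: python\n\n   import itertools\n\n   # Toy round-robin selection over registered workers (illustration only).\n   workers = [\"http://10.0.0.1:8000\", \"http://10.0.0.2:8000\"]\n   pick = itertools.cycle(workers).__next__\n   print(pick(), pick(), pick())  # alternates between the two workers\n\n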
.. code:: python\n\n   class NaiveRouter:\n      def __init__(\n         self,\n         worker_urls: list[str],\n         max_connections: int = 1024,\n         timeout: int = 60,\n         max_attempts: int = 3,\n         retry_delay: float = 2.0,\n         verbose: bool = False,\n      ):\n         \"\"\"A minimal async load-balancing router.\"\"\"\n         self.verbose = verbose\n         self.worker_urls = worker_urls\n         self.request_counts = {url: 0 for url in worker_urls}\n\n         self.max_connections = max_connections\n         self.timeout = timeout\n         self.max_attempts = max_attempts\n         self.retry_delay = retry_delay\n\n         self.app = FastAPI()\n\n         # Register startup / shutdown hooks\n         self.app.on_event(\"startup\")(self._on_startup)\n         self.app.on_event(\"shutdown\")(self._on_shutdown)\n\n         # Catch-all proxy route\n         self.app.api_route(\"/{endpoint:path}\", methods=[\"GET\", \"POST\"])(self._make_async_request)\n\n         # Placeholder for aiohttp client\n         self.client = None\n\nAgent Reward Loop\n-----------------\n\nReward Loop can be integrated with AgentLoop to enable sample-wise rollout and reward computation.\n\n.. image:: https://github.com/yyDing1/verl-materials/blob/main/agent_reward_loop.svg?raw=true\n\n"
  },
  {
    "path": "verl_distillation/docs/advance/rollout_is.md",
    "content": "# Rollout Importance Sampling\n\n**Author:** [Yingru Li](https://richardli.xyz/)\n\nLast updated: 10/27/2025.\n\nThis document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl.\n\n### BibTeX Citation\n\n```bibtex\n@misc{liu-li-2025,\n  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},\n  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},\n  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},\n  year = {2025},\n  month = september,\n}\n```\n\n## Overview\n\nRollout Importance Sampling corrects for distribution mismatch between:\n- **Rollout policy**: e.g., vLLM with BFloat16\n- **Training policy**: e.g., FSDP with FP32\n\nThis mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.\n\n### Key Design Principle: Separation of IS Weights and Rejection Sampling\n\n**Important**: As of 10/27/2025, the implementation separates two mechanisms:\n\n1. **IS Weights** (`rollout_is_weights`): Ratios π_train/π_rollout with processing:\n   - **Safety-bounded** to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow:\n     * Token level: Bounds per-token ratios\n     * Sequence level: Bounds product of ratios (broadcast to all tokens in sequence)\n     * Geometric level: Bounds geometric mean of ratios (broadcast to all tokens)\n   - **Truncate mode**: Upper clamped via .clamp(max=upper_threshold)\n   - **Mask mode**: Safety-bounded ratios preserved (no threshold clamping)\n   - **All modes**: Zeroed at padding positions (response_mask == 0)\n   - Used for policy gradient calculations\n\n2. 
## Configuration\n\n```yaml\n# Rollout IS configuration (all in algorithm config)\nalgorithm:\n  # Main control: set threshold to enable (null = disabled)\n  rollout_is_threshold: 2.0\n  # Whether to apply weights to loss (default: false = metrics only)\n  rollout_is: true\n  rollout_is_threshold_lower: null  # Auto-reciprocal\n  rollout_is_level: token\n  rollout_is_mode: truncate\n  rollout_is_veto_threshold: null  # Disable veto by default\n\n# REQUIRED: Enable log prob calculation\nactor_rollout_ref:\n  rollout:\n    calculate_log_probs: true\n```\n\nKey features:\n- ✅ Three aggregation levels: token, sequence, geometric\n- ✅ Two bounding modes: truncate, mask\n- ✅ Dual threshold support (upper/lower)\n- ✅ Veto mechanism for catastrophic outliers\n- ✅ 30+ comprehensive metrics\n- ✅ Log-space computation for numerical stability\n- ✅ Memory-efficient implementation\n\n## Files\n\n### **Core Implementation**\n\n- `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` and `compute_is_metrics()`\n- `verl/trainer/ppo/core_algos.py` - Rollout IS integration with PPO\n- `verl/workers/actor/dp_actor.py` - Metrics collection and logging\n\n### **Configuration Files**\n\n- `verl/trainer/config/algorithm.py` - Rollout IS parameters in `AlgoConfig`\n- `verl/workers/config/actor.py` - Rollout IS parameters in `ActorConfig`\n- `verl/trainer/config/actor/actor.yaml` - Rollout IS configuration section\n- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with rollout IS\n\n### **Documentation**\n\n- `docs/examples/config.rst` - Configuration parameter descriptions\n\n### **Example Scripts**\n\n- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - DAPO example with rollout IS\n- `examples/rollout_importance_sampling/README.md` - Comprehensive usage guide\n- `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example\n\n### **Tests**\n\n- `tests/trainer/ppo/test_rollout_is.py` - Unit tests\n- `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests\n\n## Configuration Parameters\n\n### `algorithm.rollout_is_threshold` (float or null)\n**Main on/off switch.** Upper threshold for IS weights.\n- `null` = disabled (no computation, no metrics)\n- `float` value (e.g., 2.0) = enabled (compute weights and metrics)\n\n### `algorithm.rollout_is` (bool)\nWhether to apply IS weights to policy loss. Default: `False`\n- `true` = apply weights to loss (full IS correction)\n- `false` = metrics only mode (no weight correction, but rejection still applies)\n\n**IMPORTANT**: This flag controls IS weight application, NOT rejection sampling. See \"Operation Modes\" below.\n\n**Recommended threshold ranges:**\n- Token level: 1.5 - 5.0\n- Sequence level: 2.0 - 10.0\n- Geometric level: 1.0002 - 1.001\n\n### `algorithm.rollout_is_threshold_lower` (float or null)\nLower threshold for IS weights. 
If `null`, defaults to 1/upper (reciprocal).\n\n### `algorithm.rollout_is_level` (str)\nAggregation level for IS weights:\n- `\"token\"`: Per-token ratios ρ_t = π_train(t)/π_rollout(t)\n  - Each token has its own IS weight\n  - Safety bound: each token's ratio bounded to [exp(-20), exp(20)]\n  - Biased estimator but low variance\n- `\"sequence\"`: Product of ratios ρ_seq = ∏_t ρ_t for entire sequence\n  - All tokens in a sequence share the same IS weight (product of per-token ratios)\n  - Safety bound: product bounded to [exp(-20), exp(20)], then broadcast to all tokens\n  - Unbiased estimator but high variance\n- `\"geometric\"`: Geometric mean ρ_geo = (∏_t ρ_t)^(1/T) (experimental)\n  - All tokens in a sequence share the same IS weight (geometric mean)\n  - Safety bound: geometric mean bounded to [exp(-20), exp(20)], then broadcast to all tokens\n  - Trade-off between bias and variance\n\n### `algorithm.rollout_is_mode` (str)\nBounding mode for handling outlier IS weights:\n- `\"truncate\"`: Clamp weights at upper threshold only (TIS)\n  - No lower bound clamping or rejection for outlier ratios\n  - **IS weights modified**: Upper bound clamped via .clamp(max=upper_threshold)\n  - Lower bound remains at exp(-20) ≈ 2e-9 from safety bound\n  - **Note**: Veto-based rejection can still occur via response_mask (see `rollout_is_veto_threshold`)\n- `\"mask\"`: Rejection sampling via response_mask (MIS)\n  - Rejects tokens/sequences with IS ratios outside [lower, upper]\n  - **Important**: Rejection applied to `response_mask`, NOT by modifying IS weights\n  - **IS weights**: Safety-bounded ratios preserved (no threshold clamping, rejection via mask)\n  - **Note**: Veto-based rejection also applies via response_mask (independent mechanism)\n\n### `algorithm.rollout_is_veto_threshold` (float or None)\nPer-token veto threshold for catastrophic outliers.\n- If any token has **unclamped** ratio < this threshold, the entire sequence is rejected via `response_mask`\n- Veto checks the **true per-token ratio** π_train(t)/π_rollout(t) before any bounds are applied\n- Applied for all levels (token, sequence, geometric) - always checks individual token ratios\n- Default: `None` (veto disabled by default)\n- Recommended: `1e-4` to `1e-6` when enabled (catches extreme outliers like 10,000x off)\n- Set to `None` to disable veto mechanism\n- **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes)\n- Veto applies rejection to `response_mask`, NOT by modifying IS weights\n- **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded)\n\n### Summary: How IS Weights are Processed\n\nThe final IS weights go through multiple stages of processing:\n\n**Stage 1: Safety Bound (All Modes)**\n- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token to [2e-9, 5e8]\n- Sequence level: `exp(clamp(sum(log_ratio), -20, 20))` → bounds product to [2e-9, 5e8], broadcast to all tokens\n- Geometric level: `exp(clamp(mean(log_ratio), -20, 20))` → bounds geometric mean to [2e-9, 5e8], broadcast to all tokens\n\n**Stage 2: Threshold Processing (Mode-Dependent)**\n- Truncate mode: `.clamp(max=upper_threshold)` → upper clamps weights to threshold\n- Mask mode: No modification → weights remain as safety-bounded ratios\n\n**Stage 3: Padding (All Modes)**\n- `weights * response_mask` → zeros out padding positions\n\n**Rejection Mechanisms (Modify response_mask, NOT weights)**\n- Veto: Checks **unclamped per-token ratios** (before 
safety bound), rejects sequences via mask\n- Outlier (mask mode only): Checks safety-bounded weights against [lower, upper], rejects via mask\n\n## Operation Modes\n\nThe system has **two independent control flags** that combine to create different operation modes:\n\n1. **`rollout_is_threshold`**: Main on/off switch (None = disabled, float = enabled)\n2. **`rollout_is`**: Apply IS weights to loss (True/False)\n\n### Mode Combinations\n\n| `rollout_is_threshold` | `rollout_is` | `rollout_is_mode` | Behavior |\n|------------------------|--------------|-------------------|----------|\n| `None` | any | any | **Disabled**: No computation, no metrics, no rejection |\n| `2.0` | `False` | `truncate` | **Metrics only**: Compute weights & metrics, NO weight correction, NO rejection for outliers |\n| `2.0` | `False` | `mask` | **Rejection only**: Compute weights & metrics, NO weight correction, YES rejection sampling |\n| `2.0` | `True` | `truncate` | **Truncate mode**: Weight correction enabled, weights upper-clamped, NO rejection for outliers |\n| `2.0` | `True` | `mask` | **Mask mode (full)**: Weight correction enabled, rejection sampling enabled |\n\n### Key Insights\n\n**Rejection sampling is ALWAYS applied when:**\n- `rollout_is_threshold` is set (not None)\n- AND `rollout_is_mode = \"mask\"`\n- **Regardless of the `rollout_is` flag**\n\nThis means:\n- ✅ You can use **rejection sampling alone** without IS weight correction (`rollout_is=False, rollout_is_mode=\"mask\"`)\n- ✅ You can use **IS weights alone** without outlier rejection (`rollout_is=True, rollout_is_mode=\"truncate\"`)\n- ✅ You can use **both together** (`rollout_is=True, rollout_is_mode=\"mask\"`)\n- ✅ You can **monitor metrics only** without any correction or outlier rejection (`rollout_is=False, rollout_is_mode=\"truncate\"`)\n\n**Veto rejection** (if enabled via `rollout_is_veto_threshold`) is applied **independently** in all modes where `rollout_is_threshold` is set.\n\n### Recommended Workflow\n\n1. **Start with metrics only** to understand the mismatch:\n   ```yaml\n   rollout_is_threshold: 2.0\n   rollout_is: false\n   rollout_is_mode: truncate\n   ```\n   Monitor `mismatch/rollout_is_mean`, `mismatch/mismatch_kl` to assess distribution mismatch.\n\n2. **Enable rejection sampling** if you see high outlier fractions:\n   ```yaml\n   rollout_is_threshold: 2.0\n   rollout_is: false\n   rollout_is_mode: mask  # Rejection now applies\n   ```\n   This excludes outliers from training without modifying gradients.\n\n3. **Enable full IS correction** once comfortable with metrics:\n   ```yaml\n   rollout_is_threshold: 2.0\n   rollout_is: true\n   rollout_is_mode: mask  # Both rejection and weight correction\n   ```\n\n## Usage\n\n### Basic Setup\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 2.0  # Main control\n  rollout_is: true           # Apply to loss (default: false)\n  rollout_is_level: token\n  rollout_is_mode: truncate\n\nactor_rollout_ref:\n  rollout:\n    calculate_log_probs: true  # Required!\n```\n\n### Metrics\n\nAll metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.\n\n#### **Core IS Weight Metrics**\n\n- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens\n  - **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch)\n  - **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch\n\n- **`rollout_is_std`**: Standard deviation of IS weights\n  - **Ideal value**: < 0.5 for stable training\n  - **Warning**: > 1.0 indicates high variance, may need tighter thresholds\n\n- **`rollout_is_min`**: Minimum IS weight observed\n  - Shows the most underweighted token/sequence\n  - For sequence/geometric: computed from unclamped log-space ratios (true minimum)\n  - For token: computed from safety-bounded weights\n\n- **`rollout_is_max`**: Maximum IS weight observed\n  - Shows the most overweighted token/sequence\n  - For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound)\n  - For token: computed from safety-bounded weights (before threshold clamping)\n  - Compare with `rollout_is_threshold` to see truncation impact\n\n#### **Effective Sample Size**\n\n- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting\n  - **Formula**: `1 / mean(weights²)` where weights are normalized\n  - **Range**: 0.0 to 1.0 (as fraction of original batch)\n  - **Ideal value**: > 0.5 (retaining at least 50% effective samples)\n  - **Warning**: < 0.3 means high variance, losing too many effective samples\n\n#### **Veto Mechanism Metrics**\n\n- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism\n  - **Important**: Sequences are rejected via `response_mask=0`, NOT by modifying IS weights\n  - **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded)\n  - Veto checks **unclamped per-token ratios** π_train(t)/π_rollout(t) (true ratios before safety bound)\n  - Detects catastrophic tokens (true ratio < veto_threshold, e.g., < 1e-4)\n  - **Ideal value**: < 0.05 (less than 5% vetoed)\n  - **Warning**: > 0.1 suggests policies are too different or numerical issues\n\n- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold\n  - Identifies problematic tokens before sequence-level veto is applied\n  - Checks **unclamped per-token ratios** (true ratios, not safety-bounded)\n  - Each catastrophic token causes its entire sequence to be rejected\n  - **Warning**: > 0.01 indicates widespread distribution issues or numerical instability\n\n#### **Threshold Exceedance Metrics**\n\n- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold\n  - Shows how often truncation/masking occurs on high end\n  - For sequence/geometric: computed from unclamped log-space ratios (true exceedance)\n  - For token: computed from safety-bounded weights (before threshold clamping)\n  - **Ideal value**: < 0.1 (most weights within bounds)\n\n- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold\n  - Shows how often masking occurs on low end (mask mode only)\n  - For sequence/geometric: computed from unclamped log-space ratios (true exceedance)\n  - For token: computed from safety-bounded weights\n  - **Ideal value**: < 0.1\n\n#### **Sequence-Level Metrics** (for sequence/geometric modes)\n\n- **`rollout_is_seq_mean`**: Mean IS weight at sequence level\n  - Should match `rollout_is_mean` for sequence-level aggregation\n\n- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS 
weights\n\n- **`rollout_is_seq_min`**: Minimum sequence-level IS weight\n\n- **`rollout_is_seq_max`**: Maximum sequence-level IS weight\n\n- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level\n  - **Ideal value**: < 1.0\n  - Shows worst-case sequence mismatch\n\n- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold\n\n- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold\n\n#### **Masking Metrics** (mask mode only)\n\n- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask (mask mode only)\n  - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by modifying IS weights\n  - **IS weights in mask mode**: Safety-bounded ratios preserved (no threshold clamping)\n  - **Ideal value**: < 0.1 (less than 10% rejected)\n  - **Warning**: > 0.3 means losing too much data\n\n- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one rejected token\n  - Shows sequence-level impact of rejection sampling\n  - For token level: a sequence is counted if ANY of its tokens falls outside [lower, upper]\n  - For sequence level: all tokens share the same weight, so the entire sequence is either rejected or accepted\n\n#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)\n\n- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32)\n  - **Formula**: `exp(-mean(log_probs))`\n  - Lower is better (model is more confident)\n\n- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16)\n  - Should be close to `mismatch_training_ppl` if policies match well\n\n- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL\n  - **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))`\n  - **Ideal value**: Close to 1.0\n  - **Meaning**: > 1.0 means training is less confident than rollout\n\n- **`mismatch_training_log_ppl`**: Log perplexity of training policy\n  - Useful for identifying trends (linear scale)\n\n- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy\n\n- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities\n  - **Formula**: `mean(log_ppl_rollout - log_ppl_training)`\n  - **Ideal value**: Close to 0.0\n  - Sign indicates which policy is more confident\n\n- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference\n  - Magnitude of mismatch regardless of direction\n\n- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences\n  - Identifies worst-case sequence\n\n- **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences\n\n- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training)\n  - **Formula**: `mean(log_prob_rollout - log_prob_training)`\n  - **Ideal value**: Close to 0.0 (policies match)\n  - **Warning**: > 0.1 indicates significant mismatch\n  - **Note**: Can be negative (rollout is less confident)\n\n- **`mismatch_k3_kl`**: K3 KL estimator\n  - **Formula**: `mean(exp(log_ratio) - log_ratio - 1)`\n  - More stable for small KL values\n  - Always non-negative (see the snippet below)\n
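\nTo make the two KL estimators concrete, here is a small illustrative snippet (tensor names are assumed; the production code lives in `verl/trainer/ppo/mismatch_helper.py`):\n\n```python\nimport torch\n\ndef mismatch_kl_metrics(training_log_prob, rollout_log_prob, response_mask):\n    \"\"\"Illustrative computation of mismatch_kl (K1) and mismatch_k3_kl (K3).\"\"\"\n    mask = response_mask.bool()\n    # log_ratio = log pi_train - log pi_rollout, on valid tokens only\n    log_ratio = (training_log_prob - rollout_log_prob)[mask]\n\n    k1 = (-log_ratio).mean()  # mean(log_prob_rollout - log_prob_training); can be negative\n    k3 = (torch.exp(log_ratio) - log_ratio - 1).mean()  # always >= 0\n    return {\"mismatch/mismatch_kl\": k1.item(), \"mismatch/mismatch_k3_kl\": k3.item()}\n```\n\n#### **Example: Accessing Metrics in Code**\n\n```python\n# Metrics are returned from compute_rollout_importance_weights\nfrom verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights\n\n# NEW: Returns 3 values (weights, modified_response_mask, metrics)\nweights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(\n    old_log_prob=training_log_probs,      # from training policy\n    rollout_log_prob=rollout_log_probs,   # from 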
rollout policy\n    response_mask=response_mask,\n    rollout_is_level=\"token\",\n    rollout_is_mode=\"mask\",  # Using mask mode for rejection sampling\n    rollout_is_threshold=2.0,\n    rollout_is_threshold_lower=0.5,\n    rollout_is_veto_threshold=1e-4,  # Enable veto for catastrophic outliers\n)\n\n# Extract IS weights (processed, zeroed at padding)\nis_weights = weights_proto.batch[\"rollout_is_weights\"]\n\n# IS weights processing (mask mode with token level):\n# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token\n# 2. Mask mode: no threshold clamping (safety-bounded ratios preserved)\n# 3. Zeroed at padding positions\n\n# modified_response_mask has rejection applied:\n# 1. Outlier rejection: tokens outside [0.5, 2.0] masked to 0 (mask mode)\n# 2. Veto rejection: sequences with catastrophic tokens (ratio < 1e-4) masked to 0\n# Note: Veto checks unclamped per-token ratios, not the safety-bounded weights\n\n# All metrics have 'mismatch/' prefix\nprint(f\"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}\")\nprint(f\"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}\")\nprint(f\"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}\")\nprint(f\"Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.3f}\")\nprint(f\"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}\")\n\n# Check IS weights for valid tokens (non-padding)\nvalid_weights = is_weights[response_mask.bool()]\nprint(f\"\\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}\")\nprint(f\"✓ IS weights max (valid tokens): {valid_weights.max():.4f}\")\nprint(f\"✓ All valid IS weights > 0: {(valid_weights > 0).all()}\")\n\n# Check rejection via response_mask\nrejected_tokens = (response_mask == 1) & (modified_response_mask == 0)\nprint(f\"\\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask\")\nprint(\"✓ In mask mode: IS weights for rejected tokens are NON-ZERO (safety-bounded ratios)\")\nprint(\"✓ In truncate mode: IS weights would be upper clamped to the threshold (2.0 here)\")\nprint(\"✓ Both modes: IS weights safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]\")\n\n# Check for warning conditions\nif metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:\n    print(\"⚠️  Warning: Mean IS weight far from 1.0, significant policy mismatch detected\")\n\nif metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:\n    print(\"⚠️  Warning: Low effective sample size, high variance in IS weights\")\n\nif metrics['mismatch/rollout_is_veto_fraction'] > 0.1:\n    print(\"⚠️  Warning: High veto fraction, policies may be too different\")\n```\n\n#### **Example: Monitoring Metrics During Training**\n\n```python\n# In your training loop\nfor epoch in range(num_epochs):\n    for batch_idx, batch in enumerate(dataloader):\n        # ... 
rollout phase ...\n\n        # Compute IS weights and get metrics (NEW: 3 return values)\n        weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(\n            old_log_prob=batch.old_log_prob,\n            rollout_log_prob=batch.rollout_log_prob,\n            response_mask=batch.response_mask,\n            rollout_is_level=config.rollout_is_level,\n            rollout_is_mode=config.rollout_is_mode,\n            rollout_is_threshold=config.rollout_is_threshold,\n            rollout_is_threshold_lower=config.rollout_is_threshold_lower,\n            rollout_is_veto_threshold=config.rollout_is_veto_threshold,\n        )\n\n        # Log to tensorboard/wandb\n        for metric_name, metric_value in metrics.items():\n            logger.log_scalar(metric_name, metric_value, step=global_step)\n\n        # IMPORTANT: Update batch response_mask with rejection applied\n        batch.response_mask = modified_response_mask\n\n        # Use IS weights in training (processed based on mode)\n        # Truncate mode: upper clamped to min(weight, upper_threshold)\n        # Mask mode: safety-bounded ratios preserved (no threshold clamping)\n        # Both modes: safety bounded to [exp(-20), exp(20)], zeroed at padding\n        is_weights = weights_proto.batch[\"rollout_is_weights\"]\n        # ... apply weights to policy gradient ...\n```\n\n#### **Example: Conditional Alerting Based on Metrics**\n\n```python\ndef check_rollout_is_health(metrics, config):\n    \"\"\"Check if rollout IS metrics indicate healthy training.\"\"\"\n    warnings = []\n\n    # Check mean IS weight\n    mean_weight = metrics['mismatch/rollout_is_mean']\n    if mean_weight < 0.5 or mean_weight > 2.0:\n        warnings.append(f\"Mean IS weight {mean_weight:.3f} is far from 1.0\")\n\n    # Check effective sample size\n    ess = metrics['mismatch/rollout_is_eff_sample_size']\n    if ess < 0.3:\n        warnings.append(f\"Effective sample size {ess:.3f} is too low\")\n\n    # Check veto fraction\n    veto_frac = metrics['mismatch/rollout_is_veto_fraction']\n    if veto_frac > 0.1:\n        warnings.append(f\"Veto fraction {veto_frac:.3f} is too high\")\n\n    # Check variance\n    std = metrics['mismatch/rollout_is_std']\n    if std > 1.0:\n        warnings.append(f\"IS weight std {std:.3f} is too high\")\n\n    # Check KL divergence\n    kl = metrics['mismatch/mismatch_kl']\n    if abs(kl) > 0.1:\n        warnings.append(f\"KL divergence {kl:.3f} indicates significant mismatch\")\n\n    if warnings:\n        print(\"⚠️  Rollout IS Health Warnings:\")\n        for warning in warnings:\n            print(f\"  - {warning}\")\n        return False\n    else:\n        print(\"✅ Rollout IS metrics look healthy\")\n        return True\n\n# Use in training (NEW: 3 return values)\n_, _, metrics = compute_rollout_importance_weights(...)\nis_healthy = check_rollout_is_health(metrics, config)\n\nif not is_healthy:\n    # Consider adjusting config or investigating issues\n    print(\"Consider:\")\n    print(\"  - Tightening rollout_is_threshold\")\n    print(\"  - Switching to geometric aggregation level\")\n    print(\"  - Checking if rollout and training policies are too different\")\n```\n\n### Running Examples\n\nStart with the basic token-level truncate configuration:\n```bash\nbash examples/rollout_importance_sampling/run_with_rollout_is.sh\n```\n\nMonitor metrics for 1-2 epochs before adjusting parameters.\n\n## Configuration Examples\n\n### Example 1: Full IS Correction\n```yaml\nalgorithm:\n  
rollout_is_threshold: 2.0\n  rollout_is: true  # Apply weights to loss\n  rollout_is_level: token\n  rollout_is_mode: truncate\n```\n\n### Example 2: Metrics Only (Monitoring Mode)\n```yaml\nalgorithm:\n  rollout_is_threshold: 2.0\n  rollout_is: false  # Compute metrics, don't apply weights\n  rollout_is_level: token\n  rollout_is_mode: truncate\n```\n\n### Example 3: Geometric Mean with Mask\n```yaml\nalgorithm:\n  rollout_is_threshold: 1.0002\n  rollout_is: true\n  rollout_is_threshold_lower: 0.9998\n  rollout_is_level: geometric\n  rollout_is_mode: mask\n```\n\n### Example 4: Asymmetric Thresholds\n```yaml\nalgorithm:\n  rollout_is_threshold: 5.0\n  rollout_is: true\n  rollout_is_threshold_lower: 0.8\n  rollout_is_level: token\n  rollout_is_mode: mask\n```\n\n## Troubleshooting\n\n### Issue: High variance in IS weights\n**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3\n\n**Solutions:**\n1. Switch from `sequence` to `geometric` level\n2. Tighten thresholds\n3. Verify rollout and training aren't too different\n\n### Issue: Too many sequences vetoed\n**Symptoms:** `rollout_is_veto_fraction` > 0.1\n\n**Solutions:**\n1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`\n2. Check for numerical issues in log prob computation\n3. Verify policies aren't completely different\n\n### Issue: Mean IS weight far from 1.0\n**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0\n\n**Solutions:**\n1. Verify `calculate_log_probs=True` is set\n2. Check rollout_log_probs are correctly passed\n3. Check for systematic bias\n\n### Debugging: Visualizing Metrics\n\n**Example: Plot IS weight distribution**\n\n```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\ndef plot_is_metrics(metrics_history):\n    \"\"\"Plot rollout IS metrics over training steps.\"\"\"\n    fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n\n    # Plot 1: Mean IS weight over time\n    axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])\n    axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')\n    axes[0, 0].set_title('Mean IS Weight')\n    axes[0, 0].set_xlabel('Step')\n    axes[0, 0].legend()\n\n    # Plot 2: Effective sample size\n    axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])\n    axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')\n    axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')\n    axes[0, 1].set_title('Effective Sample Size')\n    axes[0, 1].set_xlabel('Step')\n    axes[0, 1].legend()\n\n    # Plot 3: Veto fraction\n    axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])\n    axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')\n    axes[0, 2].set_title('Veto Fraction')\n    axes[0, 2].set_xlabel('Step')\n    axes[0, 2].legend()\n\n    # Plot 4: KL divergence over time\n    axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL')\n    axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')\n    axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)\n    axes[1, 0].set_title('KL Divergence')\n    axes[1, 0].set_xlabel('Step')\n    axes[1, 0].legend()\n\n    # Plot 5: PPL ratio over time\n    axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio'])\n    axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal')\n    axes[1, 1].set_title('PPL Ratio (Training/Rollout)')\n    axes[1, 1].set_xlabel('Step')\n    axes[1, 1].legend()\n\n    # Hide unused subplot\n    axes[1, 2].axis('off')\n\n    plt.tight_layout()\n    
plt.savefig('rollout_is_metrics.png', dpi=150)\n    print(\"Saved plot to rollout_is_metrics.png\")\n```\n\n**Example: Metric collection during training**\n\n```python\n# Collect metrics over time\nmetrics_history = {\n    'mismatch/rollout_is_mean': [],\n    'mismatch/rollout_is_eff_sample_size': [],\n    'mismatch/rollout_is_veto_fraction': [],\n    'mismatch/mismatch_kl': [],\n    'mismatch/mismatch_k3_kl': [],\n    'mismatch/mismatch_ppl_ratio': [],\n}\n\n# In training loop\nfor step in range(num_steps):\n    # ... compute IS weights ... (NEW: 3 return values)\n    _, _, metrics = compute_rollout_importance_weights(...)\n\n    # Store metrics\n    for key in metrics_history.keys():\n        if key in metrics:\n            metrics_history[key].append(metrics[key])\n\n    # Plot every 100 steps\n    if step % 100 == 0:\n        plot_is_metrics(metrics_history)\n```\n\n## Performance Impact\n\n- **Memory overhead**: ~1% of model memory\n- **Computational overhead**: 1-3% depending on level\n- **Training stability**: Significantly improved when mismatch exists\n\n\n## Testing\n\nRun the test suite to verify everything works:\n\n```bash\n# Unit tests\npytest tests/trainer/ppo/test_rollout_is.py -v\n\n# Integration tests\npytest tests/trainer/ppo/test_rollout_is_integration.py -v\n```\n\nExpected output: All tests pass ✓\n\n## Additional Resources\n\n- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`\n- **Examples**: `examples/rollout_importance_sampling/`\n- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh`\n\n## Summary\n\nRollout Importance Sampling provides:\n- ✅ Robust handling of distribution mismatch\n- ✅ Numerical stability\n- ✅ Comprehensive metrics for monitoring\n- ✅ Flexibility for different scenarios\n- ✅ Memory-efficient computation\n\n## References\n\n- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda)\n- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)"
  },
  {
    "path": "verl_distillation/docs/advance/rollout_skip.rst",
    "content": "RolloutSkip Function Usage Documentation\n========================================\n\nLast updated: 08/01/2025.\n\nApplicable Scenarios\n--------------------\n\nThe RolloutSkip functionality is designed to accelerate the rollout process in reinforcement learning training by caching and reusing previously generated sequences. This feature is particularly useful when:\n\n1. You need to repeatedly run experiments with the same configuration\n\n2. You want to save time by avoiding redundant sequence generation to come close to the optimal policy\n\n\nAPI and Usage Example\n----------------------\n\n2.1 Trainer Adaptation\n~~~~~~~~~~~~~~~~~~~~~~\n\nBoth`RayDAPOTrainer()` (in `verl/recipe/dapo/dapo_ray_trainer.py`) and `RayPPOTrainer()`(in `verl/trainer/ppo/ray_trainer.py``) have already been adapted.\n\nThis is an example of how to patch rollout_skip in RayPPOTrainer.\n\n.. code-block:: python\n\n    #* Import the RolloutSkip class\n    from verl.utils.rollout_skip import RolloutSkip\n\n    ...\n    class RayPPOTrainer:\n        ...\n        def fit(self):\n            ...\n\n            #* Add code as follow:\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n            ...\n\n            for epoch in range(self.config.trainer.total_epochs):\n                for batch_dict in self.train_dataloader:\n                    ...\n\n2.2 Basic Configuration\n~~~~~~~~~~~~~~~~~~~~~~~\n\nThen, you should add the following parameters to your config to enable the RolloutSkip feature:\n\n.. code-block:: bash\n\n    actor_rollout_ref.rollout.skip_rollout=True \\\n    actor_rollout_ref.rollout.skip_dump_dir=\"/tmp/rollout_dump\" \\\n\n\nNote:\n\n1. The `skip_dump_dir` is the directory where the cached sequences will be stored. Ensure that this directory is writable and accessible by your training process. And make sure that `skip_dump_dir` is not relative path because ray will store the data in `/tmp/ray/session_<session_id>/` and the relative path will not be found in the worker.\n2. The dumped data path follows this naming pattern `{experiment_name}_{project_name}_TrainGBS{train_gbs}__InferGBS{gen_gbs}__N{n}`, once you change the `experiment_name`, `project_name`, `train_gbs`, `gen_gbs`, or `n`, the cached data will be stored in a new directory.\n"
  },
  {
    "path": "verl_distillation/docs/advance/rollout_trace.rst",
    "content": "Trace Function Usage Instructions\n========================================\n\nLast updated: 07/10/2025.\n\nApplicable Scenarios\n--------------------\n\nAgentic RL involves multiple turns of conversations, tool invocations, and user interactions during the rollout process. During the Model Training process, it is necessary to track function calls, inputs, and outputs to understand the flow path of data within the application. The Trace feature helps, in complex multi-round conversations, to view the transformation of data during each interaction and the entire process leading to the final output by recording the inputs, outputs, and corresponding timestamps of functions, which is conducive to understanding the details of how the model processes data and optimizing the training results.\n\nThe Trace feature integrates commonly used Agent trace tools, including wandb weave and mlflow, which are already supported. Users can choose the appropriate trace tool according to their own needs and preferences. Here, we introduce the usage of each tool.\n\n\nTrace Parameter Configuration\n-----------------------------\n\n- ``actor_rollout_ref.rollout.trace.backend=mlflow|weave`` # the trace backend type\n- ``actor_rollout_ref.rollout.trace.token2text=True`` # To show decoded text in trace view\n\n\nGlossary\n--------\n\n+----------------+------------------------------------------------------------------------------------------------------+\n| Object         | Explaination                                                                                         |\n+================+======================================================================================================+\n| trajectory     | A complete multi-turn conversation includes:                                                         |\n|                | 1. LLM output at least once                                                                          |\n|                | 2. Tool Call                                                                                         |\n+----------------+------------------------------------------------------------------------------------------------------+\n| step           | The training step corresponds to the global_steps variable in the trainer                            |\n+----------------+------------------------------------------------------------------------------------------------------+\n| sample_index   | The identifier of the sample, defined in the extra_info.index of the dataset. It is usually a number,|\n|                | but may also be a uuid in some cases.                                                                |\n+----------------+------------------------------------------------------------------------------------------------------+\n| rollout_n      | In the GROP algorithm, each sample is rolled out n times. rollout_n represents the serial number of  |\n|                | the rollout.                                                                                         |\n+----------------+------------------------------------------------------------------------------------------------------+\n| validate       | Whether the test dataset is used for evaluation?                                                     |\n+----------------+------------------------------------------------------------------------------------------------------+\n\nRollout trace functions\n-----------------------\n\nThere are 2 functions used for tracing:\n\n1. 
``rollout_trace_op``: A decorator used to mark the functions to trace. By default, only a few methods have it; you can add it to more functions to trace more information.\n2. ``rollout_trace_attr``: Marks the entry of a trajectory and attaches some contextual info to the trace. If you add a new type of agent, you may need to add it to enable tracing.\n\n\nUsage of wandb weave\n--------------------\n\n1.1 Basic Configuration\n~~~~~~~~~~~~~~~~~~~~~~~\n\n1. Set the ``WANDB_API_KEY`` environment variable\n2. Configuration Parameters\n\n   1. ``actor_rollout_ref.rollout.trace.backend=weave``\n   2. ``trainer.logger=['console', 'wandb']``: This item is optional. Trace and logger are independent functions. When using Weave, it is recommended to also enable the wandb logger so that both functions live in one system.\n   3. ``trainer.project_name=$project_name``\n   4. ``trainer.experiment_name=$experiment_name``\n   5. ``actor_rollout_ref.rollout.mode=async``: Since trace is mainly used for agentic RL, you need to enable the agent loop by using async mode for either vLLM or SGLang.\n\nNote:\nThe Weave Free Plan comes with a default monthly network traffic allowance of 1 GB. Training generates a substantial amount of trace data, which can reach dozens of gigabytes per day, so it is necessary to select an appropriate wandb plan.\n\n\n1.2 View Trace Logs\n~~~~~~~~~~~~~~~~~~~\n\nAfter training starts, you can see the WEAVE sidebar on the project page. Click Traces to view the logs.\n\nEach trace entry corresponds to a trajectory. You can filter and select the trajectories you need to view by step, sample_index, rollout_n, and experiment_name.\n\nAfter enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the input and output content.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_list.png?raw=true\n\n1.3 Compare Trace Logs\n~~~~~~~~~~~~~~~~~~~~~~\n\nIn Weave, you can select multiple trace items and compare the differences among them.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_compare.png?raw=true\n\nUsage of mlflow\n---------------\n\n1. Basic Configuration\n~~~~~~~~~~~~~~~~~~~~~~\n\n1. Set the ``MLFLOW_TRACKING_URI`` environment variable, which can be:\n\n   1. HTTP and HTTPS URLs of online services\n   2. A local file or directory, such as ``sqlite:////tmp/mlruns.db``, which stores the data in ``/tmp/mlruns.db``. When using local files, initialize the file first (e.g., start the UI: ``mlflow ui --backend-store-uri sqlite:////tmp/mlruns.db``) to avoid conflicts when multiple workers create files simultaneously.\n\n2. Configuration Parameters\n\n   1. ``actor_rollout_ref.rollout.trace.backend=mlflow``\n   2. ``trainer.logger=['console', 'mlflow']``. This item is optional. Trace and logger are independent functions. When using mlflow, it is recommended to also enable the mlflow logger so that both functions live in one system.\n   3. ``trainer.project_name=$project_name``\n   4. ``trainer.experiment_name=$experiment_name``\n\n\n2. View Log\n~~~~~~~~~~~\n\nSince ``trainer.project_name`` corresponds to Experiments in mlflow, in the mlflow view you need to select the corresponding project name, then click the \"Traces\" tab to view traces. 
Among them, ``trainer.experiment_name`` corresponds to the ``experiment_name`` tag; tags such as ``step``, ``sample_index``, and ``rollout_n`` are used for filtering and viewing.\n\nFor example, searching for ``\"tags.step = '1'\"`` displays all trajectories of step 1.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_list.png?raw=true\n\nOpening one of the trajectories lets you view each function call within it.\n\nAfter enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the content.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_view.png?raw=true\n\nNote:\n\n1. mlflow does not support comparing multiple traces\n2. rollout_trace cannot associate the mlflow trace with the run, so the trace content cannot be seen in the mlflow run logs.\n"
  },
  {
    "path": "verl_distillation/docs/advance/rope.rst",
    "content": "RoPE Scaling override\n=======================================\n\nLast updated: 05/14/2025.\n\nSome models such as `Qwen/Qwen2.5-7B-Instruct <https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts>`_ support RoPE Scaling but don't have it defined in their config.json file.\nFor example, this model supports this configuration:\n\n.. code:: python\n\n    {\n        ...,\n        \"rope_scaling\": {\n            \"factor\": 4.0,\n            \"original_max_position_embeddings\": 32768,\n            \"type\": \"yarn\"\n        }\n    }\n\n\n\nIn order to support a longer context for such models, you must override the model configs when starting the trainer.\n\nPPO example:\n\n.. code:: bash\n\n    +actor_rollout_ref.model.override_config.rope_scaling.type=yarn \\\n    +actor_rollout_ref.model.override_config.rope_scaling.factor=4.0 \\\n    +actor_rollout_ref.model.override_config.rope_scaling.original_max_position_embeddings=32768 \\\n\n\nAnd for the critic model\n\n.. code:: bash\n\n    +critic.model.override_config.rope_scaling.type=yarn \\\n    +critic.model.override_config.rope_scaling.factor=4.0 \\\n    +critic.model.override_config.rope_scaling.original_max_position_embeddings=32768 \\\n"
  },
  {
    "path": "verl_distillation/docs/algo/baseline.md",
    "content": "# Algorithm Baselines\n\nLast updated: 06/18/2025.\n\n## Math related datasets\n\n### GSM8k\n\nAssuming GSM8k/math dataset is preprocessed via:\n\n```bash\npython3 examples/data_preprocess/*.py\n```\n\nRefer to the table below to reproduce RL training from different pre-trained checkpoints. Below is the performance on the GSM8k dataset if not specified otherwise. More comprehensive benchmark results areavailable in the recipe folder.\n\n\n| Hardware    | Model                            | Method            | Test score   | Details |\n|-------------|----------------------------------|-------------------|--------------|---------|\n| NVIDIA GPU  | google/gemma-2-2b-it             | hf checkpoint     | 23.9         | [Huggingface](https://huggingface.co/google/gemma-2-2b-it#benchmark-results) |\n| NVIDIA GPU  | google/gemma-2-2b-it             | SFT               | 52.06        | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-sft-0.411.log) |\n| NVIDIA GPU  | google/gemma-2-2b-it             | SFT + PPO         | 64.02        | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-ppo-bsz512_4-prompt1024-resp-512-0.640.log), [wandb](https://api.wandb.ai/links/verl-team/h7ux8602) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | hf checkpoint     | 36.4         | [Qwen blog](https://qwenlm.github.io/blog/qwen2.5-llm/) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | PPO               | 56.7         | [command and log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | PRIME             | 58.7         | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen.sh), [wandb](https://api.wandb.ai/links/zefan-wang-thu-tsinghua-university/rxd1btvb) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | GRPO-LoRA         | 54.3         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.543.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-1.5B-Instruct       | GRPO-LoRA         | 77.9         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-1.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.779.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-3B-Instruct         | GRPO-LoRA         | 86.1         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-3B-bsz64_2-prompt512-resp1024-lorarank32-score0.861.log)|\n| NVIDIA GPU  | deepseek-ai/deepseek-llm-7b-chat | PPO (Megatron)    | 69.5 [1]     | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/deepseek-llm-7b-chat-megatron-bsz256_4-prompt512-resp512-0.695.log), [wandb](https://wandb.ai/verl-team/verl_megatron_gsm8k_examples/runs/10fetyr3) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO              | 89           | [script](https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO (FSDP2)      | 89.8         | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO (Megatron)   | 89.6         | 
[log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b_math_megatron.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | ReMax             | 97           | [script](https://github.com/eric-haibin-lin/verl/blob/main/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh), [wandb](https://wandb.ai/liziniu1997/verl_remax_example_gsm8k/runs/vxl10pln) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | SPPO              | 65.6 (MATH)  | [SPPO script](https://github.com/volcengine/verl/tree/main/recipe/sppo/README.md) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | GRPO-LoRA         | 93.4         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-7B-bsz64_8-prompt512-resp1024-lorarank32-score0.934.log)|\n| NVIDIA GPU  | Mixtral-8x22B-Instruct-v0.1      | Instruct model    | 83.7         | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) |\n| NVIDIA GPU  | Mixtral-8x22B-Instruct-v0.1      | RLOO (Megatron)   | 92.3         | [wandb](https://api.wandb.ai/links/ppo_dev/sbuiuf2d) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | SPIN              | 92           | [script](https://github.com/volcengine/verl/tree/main/recipe/spin/README.md) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GPG               | 88           | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/ab86c4va) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GPG (Megatron)    | 88           | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math_megatron.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/yy8bheu8) |\n| NVIDIA GPU  | Qwen/Qwen2.5-VL-7B-Instruct      | GRPO (Megatron)   | 65.4 (GEO3k) | [script](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh), [wandb](https://api.wandb.ai/links/megatron-core-moe-dev/1yngvkek) |\n| AMD MI300   | deepseek-ai/deepseek-llm-7b-chat | PPO               | 70.5 [1]     | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/ppo_run_deepseek7b_llm.log) |\n| AMD MI300   | deepseek-ai/deepseek-llm-7b-chat | GRPO              | 71.4 [1]     | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/grpo_run_deepseek7b_llm.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-14B-Instruct         | GRPO-LoRA         | 94.6         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-14B-bsz64_8-prompt512-resp1024-lorarank32-score0.946.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-32B-Instruct         | GRPO-LoRA         | 95.8         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-32B-bsz64_8-prompt512-resp1024-lorarank32-score0.958.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-72B-Instruct         | GRPO-LoRA         | 96.0         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-72B-bs64_8-prompt512-resp1024-lorarank32-score0.960.log)|\n\n### DAPO math-17k\n\n- Training: the DAPO math-17k dataset: https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k\n- Testing: AIME'24: https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024\n\nNote:\n- For Qwen/Qwen2.5-Math-7B, we directly modify the max_position_embeddings to 32768 without observing performance degradation, in order to train with longer response lengths.\n\n| Hardware    | 
Model                       | Method                  | Test score | Details |\n|-------------|-----------------------------|-------------------------|------------|---------|\n| NVIDIA GPU  | Qwen/Qwen2.5-Math-7B (32k)  | DAPO                    | 36.3       | [command](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_7b_math.sh), [logs](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361)|\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct    | DAPO + Code Interpreter | 40.0       | [command](https://github.com/volcengine/verl/blob/main/recipe/retool/run_qwen2_7b_dapo.sh)|\n\n\n\n\n## Coding related datasets\n\nUnless specified otherwise, the result below is on LeetCode.\n\n| Hardware    | Model                            | Method            | Test score   | Details |\n|-------------|----------------------------------|-------------------|--------------|---------|\n| NVIDIA GPU  | PRIME-RL/Eurus-2-7B-SFT          | PRIME             | 36.1         | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen_code.sh), [swanlab](https://swanlab.cn/@wangzefan/prime_example/runs/7f541qhspgmy8nmhdlx35/chart) |\n\n\n### Notes\n\n[1] During evaluation, we have only extracted answers following the format `\"####\"`. A more flexible answer extraction, longer response length, and better prompt engineering may lead to a higher score.\n\n[2] The default value of `actor_rollout_ref.actor.entropy_coeff` is set to `0.0` since verl 0.3.x on 2025-05-30, which is different from previous versions.\n"
  },
  {
    "path": "verl_distillation/docs/algo/collabllm.md",
    "content": "# Recipe: CollabLLM \n\nLast updated: 09/22/2025.\n\n> Open-Source Algorithm Implementation & Expriement Running: [Haiquan Chen](https://github.com/chenhaiq), [Shirley Wu](https://github.com/Wuyxin)\n\n🏠 [Homepage](https://aka.ms/CollabLLM) | 📝 [Paper](https://arxiv.org/pdf/2502.00640) | 🤗 [Datasets & Models](https://huggingface.co/collabllm) | ⭐️ [Original Implementation](https://github.com/Wuyxin/collabllm)\n\n`verl` provides a recipe for the Outstanding Paper at ICML 2025, **\"CollabLLM: From Passive Responders to Active Collaborators\"**. [CollabLLM](https://aka.ms/CollabLLM) is a unified fine-tuning framework that optimizes LLMs for effective and efficient multiturn collaboration with users.\n\n**Core Idea:** Models are rewarded based on how well their responses enable effective *future* collaboration with users.\n\nPaper Authors: [Shirley Wu](https://cs.stanford.edu/~shirwu/), [Michel Galley](https://www.microsoft.com/en-us/research/people/mgalley/), Baolin Peng, Hao Cheng, Gavin Li, Yao Dou, Weixin Cai, [James Zou](https://www.james-zou.com/), [Jure Leskovec](https://cs.stanford.edu/people/jure/), [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)\n\n\n---\n## Quick Start\n\n### 0. Environment\nMake sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).\n\n### 1. Prepare Your Dataset\n\nFirst, process your dataset using the provided script (see example commands and usage in `process_dataset.py`):\n\n```bash\npython process_dataset.py --dataset <> ... --dataset_type <sft or rl>\n```\n\n\n**Requirements:**\n- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper)\n- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)\n- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository\n\n\n### 2. Train Your Model\n\n**(Optional) For Supervised Fine-Tuning (SFT):**\n```bash\nbash train_sft_collabllm.sh\n```\n\n**For Reinforcement Learning (RL):**\n\n```bash\nbash train_rl_collabllm.sh\n```\n\nThe RL script shows an example to train CollabLLM on `math-hard-large`. \n\n- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`. \n- The Multiturn-aware Reward is aggregated from these three conversational-level rewards:\n\n    ```\n    +reward_model.reward_kwargs.metric_weights.accuracy=1 \\\n    +reward_model.reward_kwargs.metric_weights.interactivity=1 \\\n    +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \\\n    ```\n\n    You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via\n    ```\n    +reward_model.reward_kwargs.metric_weights.bleu_score=1 \n    ```\n    which will instead apply bleu score on the sampled future conversations. 
\n\n## Algorithm\n\n| Step | Name                          | Description                                                                 |\n|------|-------------------------------|-----------------------------------------------------------------------------|\n| 1    | Model response generation     | The model generates multiple responses for each prompt in a batch.          |\n| 2    | Collaborative simulation      | A user simulator (e.g., GPT or Claude) samples `num_repeat_rollouts` conversations for up to `max_user_turns` additional turns. |\n| 3    | Compute Multiturn-aware Reward | Customized conversational reward functions are applied to the sampled conversations. Rewards are aggregated, then averaged across rollouts. |\n| 4    | Update model                  | The model weights are updated using the computed multiturn-aware rewards.  |\n\n---\n\n## Configuration\n\nThe primary configuration is managed through the launch script `train_rl_collabllm.sh` and the YAML file `recipe/collabllm/config/collabllm_interaction_config.yaml`. Key configuration sections:\n\n| Section              | Key Parameters / Notes                                                                 |\n|----------------------|-----------------------------------------------------------------------------------------|\n| `data`               | Paths to training/validation files, batch sizes, sequence lengths.                      |\n| `actor_rollout_ref` (common) | Base model path (used for actor + initial reference), FSDP settings, optimization (LR, scheduler). |\n| `actor_rollout_ref` (CollabLLM-specific) | Hyperparameters under `actor_rollout_ref.rollout.multi_turn`: `max_user_turns`, `max_assistant_turns`, `num_repeat_rollouts`. |\n| `interaction`        | Defined in `collabllm_interaction_config.yaml`. Specifies user simulator and hyperparameters. Requires exported API keys. |\n| `reward_model`       | Manager set to `collabllm` by default. Modify `reward_model.reward_kwargs.metric_weights` for conversational rewards and weights. LLM Judge hyperparameters (e.g., `model`, `temperature`) go under `reward_model.reward_kwargs.llm_judge_kwargs`. |\n| `algorithm`          | GRPO-specific hyperparameters such as `actor_rollout_ref.rollout.n`.                    |\n| `trainer`            | Distributed training (nodes, GPUs per node), logging (WandB), checkpointing frequency.  |\n\n---\n\n## Key Files\n\n| File Path | Purpose |\n|-----------|---------|\n| `recipe/collabllm/collabllm_agent_loop.py` | Main logic to sample future conversations, using `CollabLLMInteraction` from `verl/interactions/collabllm_interaction.py`. |\n| `verl/workers/reward_manager/collabllm.py` | Computes rewards for future conversations, leveraging `recipe/collabllm/reward_function.py` to apply each metric. |\n\n---\n\n## Acknowledgement\n\nWe sincerely thank the `verl` community and advisors for their contributions and guidance!\n"
  },
  {
    "path": "verl_distillation/docs/algo/dapo.md",
    "content": "# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)\n\nLast updated: 06/19/2025.\n\n> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)\n\n🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)\n\n> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps.\n>\n> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)\n\n## Quickstart\n\n1. Prepare the datasets **on the Ray cluster**:\n\n```bash\nbash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default\n```\n\n2. Submit the job to the Ray cluster **from any machine**:\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"http://${RAY_IP:-localhost}:8265\" # The Ray cluster address to connect to\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml\nexport RUNTIME_ENV=\"./recipe/dapo/runtime_env.yaml\" # This sets environment variables for the Ray cluster\nbash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts\n```\n\n## Reproduction Runs\n\n| Setup                                        | AIME 2024 Acc. 
| Hardware  | Image                                                                | Commit                                                                                       | Environment Variables                                                                                                             | Training Script                                                                                                                                             | Training Record                                                                           |\n| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |\n| DAPO                                         | 52%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh)             | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Dynamic Sampling                    | 50%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Token-level Loss & Dynamic Sampling | 44%            | 16x8xH20  | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix`                    | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n\n> [!IMPORTANT]\n>\n> **📢 Call for Contribution!**\n>\n> Welcome to submit your reproduction runs and setups!\n\n## Configuration\n\n### Separated Clip Epsilons (-> Clip-Higher)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.28\n```\n\n`clip_ratio_low` and `clip_ratio_high` specify the $\\varepsilon_{\\text {low }}$ and $\\varepsilon_{\\text {high }}$ in the DAPO objective.\n\nCore relevant 
code:\n\n```python\npg_losses1 = -advantages * ratio\npg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\npg_losses = torch.maximum(pg_losses1, pg_losses2)\n```\n\n### Dynamic Sampling (with Group Filtering)\n\nAn example configuration:\n\n```yaml\ndata:\n  gen_batch_size: 1536\n  train_batch_size: 512\nalgorithm:\n  filter_groups:\n    enable: True\n    metric: acc # score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 10 # Non-positive values mean no upper limit\n```\n\nSetting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0.\n\nThe trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`.\n\nCore relevant code:\n\n```python\nprompt_bsz = self.config.data.train_batch_size\nif num_prompt_in_batch < prompt_bsz:\n    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')\n    num_gen_batches += 1\n    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')\n        continue\n    else:\n        raise ValueError(\n            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'\n        )\nelse:\n    # Align the batch\n    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n    batch = batch[:traj_bsz]\n```\n\n### Flexible Loss Aggregation Mode (-> Token-level Loss)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n```\n\nSetting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch.\n\nCore relevant code:\n\n```python\nif loss_agg_mode == \"token-mean\":\n    loss = verl_F.masked_mean(loss_mat, loss_mask)\nelif loss_agg_mode == \"seq-mean-token-sum\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n    loss = torch.mean(seq_losses)  # seq-mean\nelif loss_agg_mode == \"seq-mean-token-mean\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n    loss = torch.mean(seq_losses)  # seq-mean\nelse:\n    raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n```\n\n### Overlong Reward Shaping\n\nAn example configuration:\n\n```yaml\ndata:\n  max_response_length: 20480 # 16384 + 4096\nreward_model:\n  overlong_buffer:\n    enable: True\n    len: 4096\n    penalty_factor: 1.0\n```\n\nSetting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit.\n\nSpecifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length - overlong_buffer.len` by `0` to `overlong_buffer.len` tokens.\n\nCore relevant code:\n\n```python\nif self.overlong_buffer_cfg.enable:\n    overlong_buffer_len = self.overlong_buffer_cfg.len\n    expected_len = self.max_resp_len - overlong_buffer_len\n    exceed_len = valid_response_length - expected_len\n    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n    
\n\nCore relevant code:\n\n```python\nprompt_bsz = self.config.data.train_batch_size\nif num_prompt_in_batch < prompt_bsz:\n    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')\n    num_gen_batches += 1\n    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')\n        continue\n    else:\n        raise ValueError(\n            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'\n        )\nelse:\n    # Align the batch\n    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n    batch = batch[:traj_bsz]\n```\n\n### Flexible Loss Aggregation Mode (-> Token-level Loss)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n```\n\nSetting `loss_agg_mode` to `token-mean` will average the (policy gradient) loss across all the tokens in all the sequences in a mini-batch.\n\nCore relevant code:\n\n```python\nif loss_agg_mode == \"token-mean\":\n    loss = verl_F.masked_mean(loss_mat, loss_mask)\nelif loss_agg_mode == \"seq-mean-token-sum\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n    loss = torch.mean(seq_losses)  # seq-mean\nelif loss_agg_mode == \"seq-mean-token-mean\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n    loss = torch.mean(seq_losses)  # seq-mean\nelse:\n    raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n```\n\n### Overlong Reward Shaping\n\nAn example configuration:\n\n```yaml\ndata:\n  max_response_length: 20480 # 16384 + 4096\nreward_model:\n  overlong_buffer:\n    enable: True\n    len: 4096\n    penalty_factor: 1.0\n```\n\nSetting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit.\n\nSpecifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` as the output length exceeds `max_response_length - overlong_buffer.len` by `0` to `overlong_buffer.len` tokens. For example, with the configuration above, an output of 18432 tokens exceeds the expected length `20480 - 4096 = 16384` by 2048 tokens and receives a penalty of `-0.5`.\n\nCore relevant code:\n\n```python\nif self.overlong_buffer_cfg.enable:\n    overlong_buffer_len = self.overlong_buffer_cfg.len\n    expected_len = self.max_resp_len - overlong_buffer_len\n    exceed_len = valid_response_length - expected_len\n    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n    overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n    reward += overlong_reward\n```\n\n## FAQ\n\n### Where is the \"Overlong Filtering\" in the paper?\n\nMost experiments in the paper, including the best-performing one, are run without Overlong Filtering, because it largely overlaps with Overlong Reward Shaping in terms of learning properly from the longest outputs. So we don't implement it here.\n\n### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)?\n\n[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features.\n\n[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, and it will be maintained with new features.\n\n### Why can't I produce similar results after modifications?\n\nToday's RL infrastructure is still not fully robust, and we are working hard to improve it.\n\nWe strongly recommend modifying only one thing at a time.\n\nWe also list some known problems here:\n\n1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation.\n"
  },
  {
    "path": "verl_distillation/docs/algo/entropy.md",
    "content": "# Recipe: Entropy Mechanism\n\nLast updated: 06/27/2025.\n\n\n<div align=\"center\">\n\n  The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.\n\n[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617)  [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue\n)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)\n\n\n<div align=\"center\" style=\"font-family: Arial, sans-serif;\">\n  <p>\n    <a href=\"#🎉news\" style=\"text-decoration: none; font-weight: bold;\">🎉 News</a> •\n    <a href=\"#✨getting-started\" style=\"text-decoration: none; font-weight: bold;\">✨ Getting Started</a> •\n    <a href=\"#📖introduction\" style=\"text-decoration: none; font-weight: bold;\">📖 Introduction</a>\n  </p>\n  <p>\n    <a href=\"#🎈citation\" style=\"text-decoration: none; font-weight: bold;\">🎈 Citation</a> •\n    <a href=\"#🌻acknowledgement\" style=\"text-decoration: none; font-weight: bold;\">🌻 Acknowledgement</a> •\n    <a href=\"#📬Contact\" style=\"text-decoration: none; font-weight: bold;\">📬 Contact</a> •\n    <a href=\"#📈star-history\" style=\"text-decoration: none; font-weight: bold;\">📈 Star History</a>\n  </p>\n</div>\n\n</div>\n\n\n## 🎉News\n\n- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).\n- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. \n\n\n\n## ✨Getting started\n\nAfter preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/7b_kl_cov.sh\n```\n\nWhile for training Qwen2.5-32B on multi nodes, you can run the following commands:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/32b_kl_cov.sh\n```\n\n## 📖Introduction\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nThis paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. 
\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nTheoretically, we find entropy changes are driven by the covariance between the action probability and the logit update, which correlates with the advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy's monotonic decline. To mitigate this, we propose **Clip-Cov** and **KL-Cov**, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse and improve performance.
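\n\nAs a rough illustration of the shared idea (simplified; see the paper and the recipe code for the exact formulations), both methods first score tokens by the covariance term and then treat the highest-scoring small fraction specially:\n\n```python\nimport torch\n\n# Sketch: select the tokens with the largest covariance between (centered)\n# log-probability and (centered) advantage. KL-Cov applies an extra KL-style\n# penalty on exactly these tokens, while Clip-Cov clips their updates.\ndef high_cov_mask(log_probs: torch.Tensor, advantages: torch.Tensor, top_frac: float = 0.002) -> torch.Tensor:\n    cov = (log_probs - log_probs.mean()) * (advantages - advantages.mean())\n    k = max(1, int(top_frac * cov.numel()))\n    threshold = torch.topk(cov.flatten(), k).values.min()\n    return cov >= threshold  # True where the update is restricted\n```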
\n\n## 📃Evaluation\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\n\nOur method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.\n\n| **Method**        | **AIME24** | **AIME25** |  **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |\n| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |\n| *Qwen2.5-7B*      |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.2 |        9.6 |     58.7 |         78.8 |          27.9 |              40.7 |        36.7 |     38.6 |\n| w. Clip-higher    |       18.1 |       11.5 |     56.6 |         79.2 |          29.8 |              43.3 |        40.4 |     38.8 |\n| w. **`CLIP-Cov`** |       22.1 |   **15.8** |     58.2 |         80.4 |      **30.5** |          **44.1** |    **41.1** |     40.4 |\n| w. **`KL-Cov`**   |   **22.6** |       12.9 | **61.4** |     **80.8** |          29.1 |              42.6 |        38.2 | **40.6** |\n| *Qwen2.5-32B*     |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.8 |       16.2 |     69.7 |         84.2 |          35.2 |              43.6 |        45.5 |     45.8 |\n| w. Clip-higher    |       35.6 |       22.3 |     69.5 |         77.2 |          35.1 |              42.5 |        43.0 |     47.2 |\n| w. **`CLIP-Cov`** |       32.3 |       22.7 |     67.2 |     **87.0** |      **42.0** |          **57.2** |        46.0 |     50.3 |\n| w. **`KL-Cov`**   |   **36.8** |   **30.8** | **74.5** |         84.6 |          39.1 |              49.0 |    **46.3** | **52.2** |\n\nOur two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.\n\n\n## 🎈Citation\nIf you find this paper or repo helpful, please cite us.\n\n```bibtex\n@article{cui2025entropy,\n  title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},\n  author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},\n  journal={arXiv preprint arXiv:2505.22617},\n  year={2025}\n}\n```\n## 🌻Acknowledgement\nWe implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!\n\n## 📬 Contact\n\nFor questions, discussion, or collaboration opportunities, feel free to contact:\n- Ganqu Cui: cuiganqu@pjlab.org.cn\n- Yuchen Zhang: yuchen.zhang2003@gmail.com\n- Jiacheng Chen: jackchan9345@gmail.com\n- Ning Ding: ningding.cs@gmail.com\n\n"
  },
  {
    "path": "verl_distillation/docs/algo/gpg.md",
    "content": "# GPG: Group Policy Gradient\n\nLast updated: 07/03/2025.\n\nGroup Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning\n](https://arxiv.org/abs/2504.02546).\n\n## Key Components\n- Use a corrected advantage function to improve policy gradient accuracy and training efficiency.\n- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO)\n\n## Configuration\nTo configure GPG within the framework, use the following YAML settings.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg \nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"gpg\"\n```\n\n## Advanced Extensions\nGPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg\nactor_rollout_ref:\n  actor:\n    use_kl_loss: True # enable kl regularization\n    kl_loss_coef: 0.01\n    policy_loss:\n      loss_mode: \"gpg\"\n```"
  },
  {
    "path": "verl_distillation/docs/algo/grpo.md",
    "content": "# Group Relative Policy Optimization (GRPO)\n\nLast updated: 05/31/2025.\n\nIn reinforcement learning, classic algorithms like PPO rely on a \"critic\" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. \n\nGRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows:\n- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a \"group\" of outputs.\n- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality.\n- Baseline Calculation: The average reward of the group serves as a baseline. \n- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones.\n\nThis approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300)\n\n## Key Components\n\n- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic)\n- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group.\n- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nDespite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic).\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling.\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor\n\n- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2\n\n- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead\n\n- `actor_rollout_ref.actor.loss_agg_mode`: Default is \"token-mean\". Options include \"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. 
\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nAlthough many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic).\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `actor_rollout_ref.rollout.n`: For each prompt, sample n times. Defaults to 1. For GRPO, please set it to a value larger than 1 for group sampling.\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout_ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor\n\n- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Defaults to 0.2\n\n- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead\n\n- `actor_rollout_ref.actor.loss_agg_mode`: Default is \"token-mean\". Options include \"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. All GRPO example scripts provided in verl use the default configuration \"token-mean\" for loss aggregation instead.\n\nInstead of adding a KL penalty to the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss:\n\n- `actor_rollout_ref.actor.use_kl_loss`: Whether to use a KL loss in the actor. When used, KL is not applied in the reward function. Default is False. Please set it to True for GRPO.\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending \"+\" at the end (e.g., 'k1+' and 'k3+') applies a straight-through estimator that employs k2 for unbiased gradient estimation, regardless of the KL value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). This defines how to calculate the KL divergence between the actor and the reference policy. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\n## Advanced Extensions\n\n### DrGRPO\n\n[Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's an optimization bias in GRPO, which leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias.\n\nConfigure the following to enable DrGRPO, with all other parameters the same as GRPO's:\n\n- `actor_rollout_ref.actor.loss_agg_mode`: \"seq-mean-token-sum-norm\", which turns off seq-dim averaging\n- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO\n- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm\n\n## Reference Example\n\nQwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log)\n\n```bash\nbash examples/grpo_trainer/run_qwen3-8b.sh\n```\n\nFor more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html\n"
  },
  {
    "path": "verl_distillation/docs/algo/opo.md",
    "content": "# On-Policy RL with Optimal Reward Baseline (OPO)\n\nLast updated: 06/02/2025.\n\nLoose on-policy constraints and suboptimal baselines in reinforcement learning often lead to training instability such as large policy shifts and entropy collapse. OPO addresses these challenges by using exact on-policy training with the theretically optimal reward baseline for advantage estimation. It achieves lower policy shifts and higher output entropy, encouraging more diverse and less repetitive responses.\n\nOPO uses group sampling to generate multiple outputs for each input like GRPO. Unlike group-based algorithms which typically use the mean reward of a group as its baseline, OPO employs a theoretically optimal baseline: the length-weighted reward of the group. It also  omits the standard deviation normalization. By adopting these two key components, OPO enables the training of a single policy model with the objective of maximizing only the expected reward. For more detailes, refer to the original paper [On-Policy RL with Optimal Reward Baseline](https://arxiv.org/pdf/2505.23585).\n\n## Key Components\n\n- Exact On-Policy Training: always generates responses from the current policy, without using any pre-generated data or off-policy data.\n- Optimal Reward Baseline: uses a length-weighted reward of the group as the baseline for normalizing the rewards.\n\n## Configuration\n\nTo configure OPO within the framework, use the following YAML settings. These parameters are crucial for enabling exact on-policy training and activating the optimal reward baseline.\n\n```yaml\nalgorithm:\n  adv_estimator: opo  # Use OPO for optimal reward baseline \ndata:\n  train_batch_size: 1024\nactor_rollout_ref:\n  actor:\n    ppo_mini_batch_size: 1024 # ppo_mini_batch_size should equal to train_batch_size to enable exact on-policy training\n    entropy_coeff: 0 # disable entropy regularization\n    use_kl_loss: False # disable kl regularization\n    kl_loss_coef: 0 \n```\n\n## Advanced Extensions\n\nOPO can also be extended to other algorithms like RLOO and Reinforce++. It just needs to adjust their configurations to enable exact on-policy training and incorporate the optimal length-weighted reward baseline with minimal modifications to their advantage estimation functions.\n"
\n\n## Advanced Extensions\n\nOPO can also be extended to other algorithms like RLOO and Reinforce++: simply adjust their configurations to enable exact on-policy training and incorporate the optimal length-weighted reward baseline, with minimal modifications to their advantage estimation functions.\n"
  },
  {
    "path": "verl_distillation/docs/algo/ppo.md",
    "content": "# Proximal Policy Optimization (PPO)\n\nLast updated: 06/19/2025.\n\nProximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning.\n\nTraditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from:\n\n- High variance and sample inefficiency.\n- Instability due to large policy updates.\n\nPPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives.\n\nFor more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347).\n\n## Key Components\n\n- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model.\n\n- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias.\n\n- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nMost critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers\n\n- `critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor\n\n- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs`\n\n- `algorithm.gemma`: discount factor\n\n- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator\n\n- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo\n\n## Advanced Extensions\n\n### KL Divergence Control\n\nOptions to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. 
\n\nOptions to use KL loss for KL divergence control:\n\n- `actor_rollout_ref.actor.use_kl_loss`: Whether to use a KL loss in the actor. When used, KL is not applied in the reward function. Default is False\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending \"+\" at the end (e.g., 'k1+' and 'k3+') applies a straight-through estimator that employs k2 for unbiased gradient estimation, regardless of the KL value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). This defines how to calculate the KL divergence between the actor and the reference policy. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\nOptions to use KL penalty in the reward:\n\n- `algorithm.use_kl_in_reward`: Whether to enable the in-reward KL penalty. Default is False.\n\n- `algorithm.kl_penalty`: Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. This defines the way to calculate the KL divergence between the actor and the reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\n- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of the in-reward KL penalty. Default is 0.001.\n- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n- `algorithm.kl_ctrl.horizon`: See the source code of AdaptiveKLController for details.\n- `algorithm.kl_ctrl.target_kl`: See the source code of AdaptiveKLController for details.\n\n### Dual-clip PPO\n\nDual-Clip PPO extends clipping for the case where the advantage is negative: it applies an additional lower bound of `clip_ratio_c * advantage` to the objective, so that the loss cannot grow arbitrarily large when a negative advantage is multiplied by a large ratio.\n\n![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139)\n\n- `actor_rollout_ref.actor.clip_ratio_c`: lower bound of the value for Dual-clip PPO, defaults to 3.0\n\n## Reference Example\n\nQwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log)\n\n```bash\nbash run_gemma.sh \\\n  trainer.n_gpus_per_node=1 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  trainer.logger=console \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  data.train_batch_size=256 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=2 \\\n  critic.ppo_micro_batch_size=2\n```\n\nReference performance with verl v0.2:\n\n| Model                          | Method          | Score | Link                                                                                           |\n|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------|\n| Qwen/Qwen2.5-0.5B-Instruct     | pretrained model | 36.4  | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/)                                        |\n| Qwen/Qwen2.5-0.5B-Instruct     | PPO              | 56.7  | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n"
  },
  {
    "path": "verl_distillation/docs/algo/spin.md",
    "content": "# Recipe: Self-Play Fine-Tuning (SPIN)\n\nLast updated: 05/31/2025.\n\n`verl` provides a recipe inspired by the paper **\"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.\n\n**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:\n\n1.  **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations.\n2.  **Two-Player Game Setup:** A game involving two players acted by a single LLM.\n3.  **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.\n\nPaper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\\*, [Yihe Deng](https://github.com/uclaml/SPIN)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]\n\nverl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n---\n\n## Key Function (compute_online_dpo_loss) and Related works\nSPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). \n\nThis `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.\n\nSpecifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. 
By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets.\n\n**Reference Papers:**\n* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) \n* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) \n* [Somethings are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) \n* [Iterative preference learning from human feedback: Bridging theory and practice for rlhf under kl-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)\n* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)\n* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)\n\n\n## Our Online DPO Implementation\n\nOur `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include:\n\n* **No Critic:** Unlike PPO, we omit the value function critic.\n* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.\n* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).\n* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences.\n* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.\n\n---\n## Algorithm\n\nThis recipe implements an Online algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models.\n\n**Online Loop:** Instead of maximizing a scalar reward signal in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training:\n\n1.  **Generation:** The current model generates multiple responses for each prompt in a batch.\n2.  **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problem).\n3.  **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model.\n\n**Connection with SPIN:**\nInstead of only using a fixed target data distribution, the online generation loop in step 2 will dynamically change the target data distribution by using a certain Preference Labeling method (rule-based ranking on the math problem by selecting the better one in this recipe). 
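\n\nA minimal sketch of such a labeling step (names are illustrative, not the recipe's exact `compute_onlineDPO_pref`):\n\n```python\ndef label_preferences(responses, rewards):\n    \"\"\"Pick the highest- and lowest-reward responses of a group as (chosen, rejected).\"\"\"\n    best = max(range(len(responses)), key=lambda i: rewards[i])\n    worst = min(range(len(responses)), key=lambda i: rewards[i])\n    return responses[best], responses[worst]\n```\n\nBecause the pair is re-derived from fresh rollouts at every step, the target data distribution the model is trained toward moves with the policy instead of staying fixed.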
This explores the direction mentioned in SPIN's paper Section 7 about \"dynamically changing target data distribution\" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling.\n\n---\n\n## Reproduce the Experiment (Example Setup)\n\nThe following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct.\n\n1.  **Setup Environment (Example using Docker):**\n    ```bash\n    # Start a container with GPU access and shared memory\n    docker run -it --name spin_test --gpus all \\\n        --shm-size=32g \\\n        --ipc=host \\\n        -v /path/to/host/.cache:/root/.cache \\\n        -e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \\\n        lmsysorg/sglang:latest \\\n        /bin/bash\n\n    # Inside the container or on your host machine:\n    # Ensure /tmp is writable\n    mkdir -p /tmp\n    chmod 1777 /tmp\n\n    # Install Python 3.10 (if not present) and venv\n    sudo apt update\n    sudo apt install -y python3.10 python3.10-venv tmux\n    python3 -m ensurepip --upgrade\n\n    # Create and activate a virtual environment\n    python3 -m venv ~/.python/spin_env\n    source ~/.python/spin_env/bin/activate\n\n    # Install uv (fast package installer)\n    python3 -m pip install uv\n    ```\n\n2.  **Install verl and Dependencies:**\n    ```bash\n    # Clone the verl repository and checkout the spin branch\n    cd ~\n    git clone git@github.com:volcengine/verl.git && cd verl\n\n    # Install flash-attn (handle potential build issues)\n    python3 -m uv pip install wheel packaging\n    python3 -m uv pip install flash-attn --no-build-isolation --no-deps\n\n    # Install verl with sglang extras\n    python3 -m uv pip install -e \".[sglang]\"\n    ```\n    *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.*\n\n3.  **Login & Download Data/Model:**\n    ```bash\n    # Login to Weights & Biases (optional, for logging)\n    export WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n    # wandb login\n\n    # Download the GSM8K dataset\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k # Adjusted path\n\n    # Download the base model (Example: Qwen2.5-3B-Instruct)\n    huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct\n    ```\n\n4.  **Configure:**\n    * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node).\n    * Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`.\n\n5.  
**Run Training:**\n    ```bash\n    # Set CUDA visible devices (adjust based on your hardware and config)\n    export CUDA_VISIBLE_DEVICES=0,1,2,3\n\n    # Launch the training script (e.g., test.sh or a custom script)\n    # Ensure test.sh points to the correct config and main script\n    bash recipe/spin/run_spin.sh\n    ```\n\n---\n\n## Configuration\n\n* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).\n* Key configuration sections:\n    * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths.\n    * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler).\n    * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function.\n    * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.\n    * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).\n\n---\n\n## Key Files\n\n* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.\n* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.\n* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP.\n* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.\n* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`.\n* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.\n* `run_spin.sh` (or similar): Example bash script for launching a training run.\n* `README.md`: This file.\n\n---\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO):\n\n* [Zixiang Chen](https://sites.google.com/view/zxchen)\n* [Yuhao Yang](https://github.com/yhyang201)\n* [Yifan Zhang](https://github.com/yifanzhang-pro)\n* [Yongan Xiang](https://github.com/BearBiscuit05)\n* [Junrong Lin](https://github.com/ocss884)\n* [Yuxuan Tong](https://github.com/tongyx361)\n* [Guangming Shen](https://github.com/PeterSH6)\n* [Biao He](https://www.linkedin.com/in/biao-he/)\n* [Qingquan Song](https://qingquansong.github.io/)\n* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)\n* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_distillation/docs/algo/sppo.md",
    "content": "# Recipe: Self-Play Preference Optimization (SPPO)\n\nLast updated: 05/28/2025.\n\nverl provides a community recipe implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets.\n\nPaper Authors: [Yue Wu](https://yuewu.us/)\\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\nverl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)]\n\n## Reproduce the Experiment\n\nWe evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework.\n\n```\ngit clone git@github.com:volcengine/verl.git\ncd verl\npython3 -m uv pip install -e \".[sglang]\"\n\nexport WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n\npython3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct\n\nexport CUDA_VISIBLE_DEVICES=0,1,2,3\nbash recipe/sppo/run_qwen2.5-7b_rm.sh\n```\n\nNote that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running:\n\n```bash\npython3 -m uv pip install wheel\npython3 -m uv pip install packaging\npython3 -m uv pip install flash-attn --no-build-isolation --no-deps\n```\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from:\n\n- [Yue Wu](https://yuewu.us/)\n- [Chendong Wang](https://cdwang96.github.io/)\n- [Yifan Zhang](https://github.com/yifanzhang-pro)\n- [Yongan Xiang](https://github.com/BearBiscuit05)\n- [Junrong Lin](https://github.com/ocss884)\n- [Yuxuan Tong](https://github.com/tongyx361)\n- [Guangming Shen](https://github.com/PeterSH6)\n- [Biao He](https://www.linkedin.com/in/biao-he/)\n- [Qingquan Song](https://qingquansong.github.io/)\n- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_distillation/docs/amd_tutorial/amd_build_dockerfile_page.rst",
    "content": "Getting started with AMD (ROCM Kernel)\n=====================================================\n\nLast updated: 07/06/2025.\n\nAuthor: `Yusheng Su <https://yushengsu-thu.github.io/>`_\n\nSetup\n-----\n\nIf you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` or ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training.\n\n\ndocker/Dockerfile.rocm\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    FROM \"rlsys/rocm-6.3.4-patch:rocm6.3.4-numa-patch_ubuntu-22.04\"\n\n    SHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\n    ENV MAX_JOBS=512\n\n    ENV PATH=\"/usr/local/python3.12/bin:$PATH\"\n    RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n        ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n    ############################################\n    RUN apt-get update\n    RUN apt-get install -y pkg-config liblzma-dev\n    ############################################\n\n    ###########################################\n    ##########Install TransformerEngine########\n    ###########################################\n    WORKDIR /workspace/\n    # transformer-engine install\n    # https://github.com/ROCm/TransformerEngine\n    RUN rm -rf TransformerEngine \n    RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\n    WORKDIR /workspace/TransformerEngine\n    git checkout 236178e5\n    # git checkout bb061ade\n    # git checkout 864405c\n    ENV NVTE_FRAMEWORK=pytorch \n    ENV NVTE_ROCM_ARCH=gfx942 \n    ENV NVTE_USE_HIPBLASLT=1\n    ENV NVTE_USE_ROCM=1  \n    # export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\n    ENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n    RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n\n    ####################################################################################\n    ################Install vllm - sglang require vllm 0.6.7 dependency#################\n    ####################################################################################\n    #### Require vllm 0.6.7 - checkout 113274a0\n    WORKDIR /workspace/\n    RUN rm -rf vllm\n    RUN pip uninstall -y vllm\n    # Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\n    RUN git clone https://github.com/ROCm/vllm.git\n    # git clone https://github.com/vllm-project/vllm.git\n    WORKDIR /workspace/vllm\n    RUN git checkout 113274a0\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n    #ENV MAX_JOBS=512\n    ENV MAX_JOBS=${MAX_JOBS}\n    RUN pip install \"boto3>=1.26.0\"\n    RUN pip install setuptools_scm\n    # will add src into py. 
\n    RUN python3 setup.py install\n    WORKDIR /workspace/\n    ####################################################################################\n    ####################################################################################\n    ####################################################################################\n\n\n\n    ###########################################\n    ############For hack docker################\n    ###########################################\n    RUN pip install setuptools==75.8.0\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n    ###########################################\n    ############build sglang###################\n    ###########################################\n    # Set environment variables\n    ENV BASE_DIR=/sgl-workspace\n    ENV BUILD_TYPE=all\n    ENV SGL_REPO=https://github.com/sgl-project/sglang\n    ENV SGL_BRANCH=v0.4.6.post5\n    ENV TRITON_REPO=https://github.com/ROCm/triton.git\n    ENV TRITON_COMMIT=improve_fa_decode_3.0.0\n    ENV AITER_REPO=https://github.com/ROCm/aiter.git\n    ENV AITER_COMMIT=v0.1.2\n    # v0.1.2 version - commit id: 9d11f47\n    # ENV AITER_COMMIT=9d11f47\n    ENV HIP_FORCE_DEV_KERNARG=1\n    ENV HSA_NO_SCRATCH_RECLAIM=1\n    ENV SGLANG_SET_CPU_AFFINITY=1\n    ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\n    ENV NCCL_MIN_NCHANNELS=112\n    ENV MOE_PADDING=1\n    ENV VLLM_FP8_PADDING=1\n    ENV VLLM_FP8_ACT_PADDING=1\n    ENV VLLM_FP8_WEIGHT_PADDING=1\n    ENV VLLM_FP8_REDUCE_CONV=1\n    ENV TORCHINDUCTOR_MAX_AUTOTUNE=1\n    ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\n    ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n    ENV AMDGPU_TARGETS=gfx942\n    ENV ROCM_ARCH=gfx942\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n    # Switch to working directory\n    WORKDIR /sgl-workspace\n    # Clean and create directory\n    RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n    # Clone and build sglang\n    RUN git clone ${SGL_REPO} \\\n        && cd sglang \\\n        && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n        && cd sgl-kernel \\\n        && rm -f pyproject.toml \\\n        && mv pyproject_rocm.toml pyproject.toml \\\n        && python setup_rocm.py install \\\n        && cd .. 
\\\n        && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n            python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n        else \\\n            python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n        fi \\\n        && cd /sgl-workspace \\\n        && cp -r /sgl-workspace/sglang /sglang \\\n        && python -m pip cache purge\n\n    # Install common Python packages\n    RUN pip install IPython orjson python-multipart torchao pybind11\n    # Rebuild Triton\n    RUN pip uninstall -y triton || true \\\n        && git clone ${TRITON_REPO} \\\n        && cd triton \\\n        && git checkout ${TRITON_COMMIT} \\\n        && cd python \\\n        && python3 setup.py install \\\n        && cd /sgl-workspace\n    # ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n    # ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n    # Build aiter\n    #version: Commit 9d11f47\n        # && git checkout ${AITER_COMMIT} \\\n    RUN pip uninstall -y aiter || true\n    RUN git clone ${AITER_REPO} \\\n        && cd aiter \\\n        && git checkout ${AITER_COMMIT} \\\n        && git submodule sync \\\n        && git submodule update --init --recursive \\\n        && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n        && cd /sgl-workspace\n\n    # Copy MI300X config \n    RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n            /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n            -type f -name '*MI300X*' | \\\n            xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n    # Environment setup complete.\n    RUN echo \"Environment setup complete.\"\n\n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n\n\n    ###########################################\n    ###############vllm v0.8.5#################\n    ###########################################\n    WORKDIR /workspace/\n\n    ENV VLLM_TARGET_DEVICE=rocm \n    ENV ROCM_PATH=/opt/rocm \n    ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n    # Find the repo path in: DockerFile/Dockerfile.rocm_yang\n    # RUN git clone https://github.com/RLFoundation/vllm-patch.git\n    RUN pip uninstall -y vllm || true\n    RUN rm -rf vllm-patch\n    RUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n        && cd vllm-patch \\\n        && git checkout v0.8.5-sleep-numa \\\n        && rm -rf build/ dist/ *.egg-info \\\n        && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n        && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n        # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n    #########################################\n    #### Install megatron-core###############\n    #########################################\n    RUN pip uninstall -y megatron-core && \\\n        git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \\\n        cd Megatron-LM-amd_version && \\\n        pip install -vvv -e . 
&& \\\n        cd /workspace/\n    #########################################\n    #########################################\n    #########################################\n\n\n\n\n    #######################################\n    ################apex###################\n    #######################################\n    WORKDIR /workspace/\n    RUN pip uninstall -y apex && \\\n        git clone git@github.com:ROCm/apex.git && \\\n        cd apex && \\\n        python setup.py install && \\\n        cd /workspace/ \n    #######################################\n    #######################################\n    #######################################\n\n\n    ################################################################################\n    ###########################Add torch_memory_saver###############################\n    ################################################################################\n    # Set environment variables\n    ENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\n    ENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    ENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    RUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n    ################################################################################\n    ################################################################################\n    ################################################################################\n\n\n\n    ########################################\n    ######Install ray#######################\n    ########################################\n    # need to add this patch: https://github.com/ray-project/ray/pull/53531/files\n    RUN pip uninstall ray -y\n    RUN pip install \"ray[data,train,tune,serve]>=2.47.0\" \n    ########################################\n    ########################################\n    ########################################\n\n\n    ##########################################\n    #######Install other dependencies#########\n    ##########################################\n    RUN pip install \"tensordict==0.6.2\" --no-deps && \\\n        pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        torchdata \\\n        wandb \\\n        orjson \\\n        pybind11\n        \n    WORKDIR /workspace/\n    RUN git clone https://github.com/volcengine/verl.git && \\\n        cd verl && \\\n        pip install -e . \n    ##########################################\n    ##########################################\n    ##########################################\n\n    WORKDIR /workspace/\n    CMD [\"/usr/bin/bash\"]\n\n\nBuild the image:\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    docker docker/build -t verl-rocm .\n\nRun the container\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nNote: You can pull the docker from this DockerHub: [RLSys Foundation](https://hub.docker.com/u/yushengsuthu)\nPull the image:\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: bash\n\n    docker pull rlsys/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4\n\n    docker tag rlsys/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 verl-rocm:latest\n\nRun the container\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    docker run --rm -it \\\n      --device /dev/dri \\\n      --device /dev/kfd \\\n      -p 8265:8265 \\\n      --group-add video \\\n      --cap-add SYS_PTRACE \\\n      --security-opt seccomp=unconfined \\\n      --privileged \\\n      -v $HOME/.ssh:/root/.ssh \\\n      -v $HOME:$HOME \\\n      --shm-size 128G \\\n      -w $PWD \\\n      verl-rocm \\\n      /bin/bash\n\nOptional: if you do not want to run as root and want to run with your own user permissions, add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` to the above docker launch command.\n\nExample\n-------\n\nDue to a special setting of AMD (ROCm) torch:\n\n1. If your ``ray>=2.45.0`` (default), you need to set ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training and add this `patch <https://github.com/ray-project/ray/pull/53531/files>`_.\n2. If your ``ray<2.45.0``, you need to set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` when starting ray in verl's RLHF training.\n\nThe inference engine ``$ENGINE`` can be ``vllm`` or ``sglang``. We choose ``vllm`` as the default in the following examples.\n\n\n\nPPO\n~~~\n\n.. code-block:: bash\n\n    YOUR_PROJECT_NAME=r1-verl-ppo-upstream\n    YOUR_RUN_NAME=r1-training_ppo-upstream \n    # export HYDRA_FULL_ERROR=1\n\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n    \n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n    GPUS_PER_NODE=8\n    MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir data/gsm8k\n    python3 -c \"import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')\"\n    ENGINE=vllm #sglang\n\n    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n     data.train_files=data/gsm8k/train.parquet \\\n     data.val_files=data/gsm8k/test.parquet \\\n     data.train_batch_size=256 \\\n     data.val_batch_size=1312 \\\n     data.max_prompt_length=512 \\\n     data.max_response_length=256 \\\n     actor_rollout_ref.model.path=$MODEL_PATH \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n     actor_rollout_ref.rollout.name=$ENGINE \\\n     actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n     critic.optim.lr=1e-5 \\\n     critic.model.path=$MODEL_PATH \\\n     critic.ppo_micro_batch_size_per_gpu=4 \\\n     algorithm.kl_ctrl.kl_coef=0.001 \\\n     trainer.logger=console \\\n     trainer.project_name=$YOUR_PROJECT_NAME \\\n     trainer.experiment_name=$YOUR_RUN_NAME \\\n     trainer.val_before_train=False \\\n     trainer.n_gpus_per_node=$GPUS_PER_NODE \\\n     trainer.nnodes=1 \\\n     trainer.save_freq=10 \\\n     
trainer.test_freq=10 \\\n     trainer.total_epochs=15 #2>&1 | tee verl_demo.log\n\nGRPO\n~~~~\n\n.. code-block:: bash\n\n    YOUR_PROJECT_NAME=r1-verl-grpo-upstream\n    YOUR_RUN_NAME=r1-training_grpo-upstream\n    # export HYDRA_FULL_ERROR=1\n    # export FSDP_VERBOSE=1 \n\n    #export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n    GPUS_PER_NODE=8\n    MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n    # MODEL_PATH=Qwen/Qwen2-7B-Instruct\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir data/gsm8k\n    python3 -c \"import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')\"\n    ENGINE=vllm #sglang\n    \n    python3 -m verl.trainer.main_ppo \\\n        algorithm.adv_estimator=grpo \\\n        data.train_files=data/gsm8k/train.parquet \\\n        data.val_files=data/gsm8k/test.parquet \\\n        data.train_batch_size=1024 \\\n        data.val_batch_size=1312 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.use_dynamic_bsz=True \\\n        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n        actor_rollout_ref.actor.use_kl_loss=True \\\n        actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=$ENGINE \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.rollout.n=5 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.project_name=$YOUR_PROJECT_NAME \\\n        trainer.experiment_name=$YOUR_RUN_NAME \\\n        trainer.n_gpus_per_node=$GPUS_PER_NODE \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\n\nMulti-node training: slurm with Docker/Podman container\n---------------------------------------------------------------------------------------\n\nIf you want to run multi-node training with slurm, you can use the following script. \n\n.. note::\n    1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later.\n    2. If you want to use ``podman``, simply replace ``docker`` with ``podman`` in the following script.\n\nThe script includes the following steps:\n\n1. SLURM Configuration\n2. Environment Setup\n3. Docker/Podman Container Setup\n4. Ray Cluster Initialization\n5. Data Preprocessing\n6. Model Setup\n7. Training Launch\n\n\nslurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: bash\n\n    #!/bin/bash\n\n    #SBATCH --job-name=verl-ray-on-slurm\n    #SBATCH --nodes=2\n    #SBATCH --ntasks-per-node=2\n    #SBATCH --mem=200G\n    #SBATCH --time=30-00:00:00\n    #SBATCH --gpus-per-node=8\n    #SBATCH --cpus-per-task=28\n    #SBATCH --output=../verl_log/slurm-%j.out\n    #SBATCH --error=../verl_log/slurm-%j.err\n    #SBATCH --nodelist=gpu-[0,1]\n\n\n    # load necessary modules\n    ### Run this setup\n    # [Cluster]: Use docker\n    # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n\n    ##########################################################################\n    ### The following settings should be adjusted for your project/cluster ###\n    ##########################################################################\n\n    ### Project\n    CONTAINER_NAME=\"multinode_verl_training\"\n    IMG=\"verl.rocm\"\n    DOCKERFILE=\"docker/Dockerfile.rocm\"\n    # echo $PWD\n    verl_workdir=\"${HOME}/projects/verl_upstream\"\n    export TRANSFORMERS_CACHE=\"${HOME}/.cache/huggingface\"\n    export HF_HOME=$TRANSFORMERS_CACHE\n\n    ### Cluster Network Setting\n    export NCCL_DEBUG=TRACE\n    export GPU_MAX_HW_QUEUES=2\n    export TORCH_NCCL_HIGH_PRIORITY=1\n    export NCCL_CHECKS_DISABLE=1\n    # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 \n    export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_CROSS_NIC=0\n    export CUDA_DEVICE_MAX_CONNECTIONS=1\n    export NCCL_PROTO=Simple\n    export RCCL_MSCCL_ENABLE=0\n    export TOKENIZERS_PARALLELISM=false\n    export HSA_NO_SCRATCH_RECLAIM=1\n    ##########################################################################\n\n    ## Assign the GPUs to use\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n    ### For rocm and training script\n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n\n    # Build and launch the Docker container\n    srun bash -c \"\n        # Exit on any error\n        set -e \n\n        # Clean up dangling images (images with <none> tag)\n        docker image prune -f\n\n        # Need to pull the docker first\n        docker pull rlsys/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4\n        \n        if ! 
docker images --format \"{{.Repository}}:{{.Tag}}\" | grep -q \"${IMG}\"; then\n            echo \\\"Building ${IMG} image...\\\"\n            docker build -f \\\"${DOCKERFILE}\\\" -t \\\"${IMG}\\\" .\n        else\n            echo \\\"${IMG} image already exists, skipping build\\\"\n        fi\n\n        # Remove the old container if it exists\n        docker rm \\\"${CONTAINER_NAME}\\\" 2>/dev/null || true\n\n        # Check network devices\n        ibdev2netdev\n\n        # Launch the docker\n        docker run --rm -d \\\n        -e HYDRA_FULL_ERROR=1 \\\n        -e RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 \\\n        -e RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 \\\n        -e NCCL_DEBUG=${NCCL_DEBUG} \\\n        -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \\\n        -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \\\n        -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \\\n        -e NCCL_IB_HCA=${NCCL_IB_HCA} \\\n        -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \\\n        -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \\\n        -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \\\n        -e NCCL_PROTO=${NCCL_PROTO} \\\n        -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \\\n        -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \\\n        -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \\\n        -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \\\n        -e HF_HOME=${HF_HOME} \\\n        --network host \\\n        --device /dev/dri \\\n        --device /dev/kfd \\\n        --device /dev/infiniband \\\n        --group-add video \\\n        --cap-add SYS_PTRACE \\\n        --security-opt seccomp=unconfined \\\n        --privileged \\\n        -v \\${HOME}:\\${HOME} \\\n        -v \\${HOME}/.ssh:/root/.ssh \\\n        -w \"${verl_workdir}\" \\\n        --shm-size 128G \\\n        --name \\\"${CONTAINER_NAME}\\\" \\\n        \\\"${IMG}\\\" \\\n        tail -f /dev/null\n\n        echo \\\"Container setup completed\\\"\n    \"\n        # (Optional): If you do not want to run as root and want to assign yourself as the user,\n        # please add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` to the above docker launch script. \n\n\n\n\n\n    ### Launch Ray on the nodes before training\n\n    # Getting the node names\n    nodes_array=($(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | tr '\\n' ' '))\n\n    head_node=${nodes_array[0]}\n    head_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n    # if we detect a space character in the head node IP, we'll\n    # convert it to an ipv4 address. This step is optional.\n    if [[ \"$head_node_ip\" == *\" \"* ]]; then\n        IFS=' ' read -ra ADDR <<<\"$head_node_ip\"\n    if [[ ${#ADDR[0]} -gt 16 ]]; then\n        head_node_ip=${ADDR[1]}\n    else\n        head_node_ip=${ADDR[0]}\n    fi\n        echo \"IPV6 address detected. 
We split the IPV4 address as $head_node_ip\"\n    fi\n\n    port=6379\n    ip_head=$head_node_ip:$port\n    export ip_head\n    echo \"IP Head: $ip_head\"\n\n    # make sure we set environment variables before Ray initialization\n\n    # Print out all env variables\n    printenv\n\n    echo \"Starting HEAD at $head_node\"\n    srun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n            ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n            --dashboard-port=8266 \\\n            --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    # optional, though may be useful in certain versions of Ray < 1.0.\n    sleep 10\n\n    # number of nodes other than the head node\n    worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n    for ((i = 1; i <= worker_num; i++)); do\n        node_i=${nodes_array[$i]}\n        echo \"Debug: Starting worker on node_i = ${node_i}\"\n        if [ -z \"$node_i\" ]; then\n            echo \"Error: Empty node name for worker $i\"\n            continue\n        fi\n        echo \"Starting WORKER $i at $node_i\"\n        srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n            docker exec \"${CONTAINER_NAME}\" \\\n                ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n        sleep 5\n    done\n\n\n\n\n    # Ray initialization test (check whether any errors occurred in the above execution)\n    echo \"Testing Ray initialization in the slurm nodes...\"\n    docker exec \"${CONTAINER_NAME}\" python3 -c '\n    import ray\n    try:\n        ray.init(address=\"auto\")\n        print(\"\\n=== Ray Cluster Status ===\")\n        print(f\"Number of nodes: {len(ray.nodes())}\")\n        for node in ray.nodes():\n            print(\"Node: {}, Status: {}\".format(node[\"NodeManagerHostname\"], node[\"Alive\"]))\n            # print(f\"Node: {node}\")\n        ray.shutdown()\n        print(\"Ray initialization successful!\")\n    except Exception as e:\n        print(f\"Ray initialization failed: {str(e)}\")\n    '\n    echo \"=== Ray test completed ===\"\n    ######\n\n\n\n    # Run data preprocessing\n\n    echo \"Starting GSM8K data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/gsm8k.py\" \"--local_save_dir\" \"../data/gsm8k\"\n\n    echo \"Starting MATH data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/math_dataset.py\" \"--local_dir\" \"../data/math\"\n\n    train_files=\"../data/gsm8k/train.parquet\"\n    val_files=\"../data/gsm8k/test.parquet\"\n\n    # Download and test the model\n    echo \"Loading model...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')\"\n    MODEL_PATH=\"Qwen/Qwen2-7B-Instruct\"\n\n    echo \"== Data and model loading Done ==\"\n\n    echo \"Start to train...\"\n\n\n    PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n        python3 -m verl.trainer.main_ppo \\\n        data.train_files=$train_files \\\n  
      data.val_files=$val_files \\\n        data.train_batch_size=1024 \\\n        data.max_prompt_length=1024 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=$MODEL_PATH \\\n        critic.model.enable_gradient_checkpointing=False \\\n        critic.ppo_micro_batch_size_per_gpu=8 \\\n        critic.model.fsdp_config.param_offload=False \\\n        critic.model.fsdp_config.optimizer_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.0001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger='[\"console\",\"wandb\"]' \\\n        trainer.project_name='verl_example' \\\n        trainer.experiment_name='Qwen2-7B-Instruct_function_rm' \\\n        trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=${SLURM_NNODES} \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\nRun slurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\nJust ``sbatch`` your slurm_script.sh:\n\n.. code-block:: bash\n\n    sbatch slurm_script.sh\n\n"
  },
  {
    "path": "verl_distillation/docs/amd_tutorial/amd_vllm_page.rst",
    "content": "verl performance tuning for AMD (ROCm Kernel)\n=====================================================\n\nLast updated: 04/25/2025.\n\nAuthor: `Yang Wang <https://github.com/YangWang92/>`_\n\nPatch vLLM to Enable Sleep Mode for AMD GPUs\n--------------------------------------------------------------\n\nBy default, verl requires vLLM to enable sleep mode, which allows vLLM to offload GPU memory to CPU memory after rollout. However, this feature is still under review by the vLLM community.\n\nTo enable vLLM's sleep mode, you can first use community patched code (from `this pull request <https://github.com/vllm-project/vllm/pull/12695>`_) to build vLLM from the source code in the corresponding pull request. After the patch merged in vLLM main branch, you can directly install vLLM from the latest version.\n\n1. Clone the vLLM repository and build it with the following commands:\n\n.. code-block:: bash\n\n    git clone -b sleep_amd https://github.com/HollowMan6/vllm.git\n    cd vllm\n    sudo ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so\n    VLLM_TARGET_DEVICE=rocm ROCM_PATH=/opt/rocm/ VLLM_GPU_LANG=HIP SETUPTOOLS_SCM_PRETEND_VERSION=0.8.4.dev python3 setup.py develop\n\n2. Additionally, make sure to use the ROCm version in your Docker image lager than or equal to ROCm 6.3.4, and we recommend to use ROCm 6.4.0 for better performance (see `this comment <https://github.com/vllm-project/vllm/pull/12695#issuecomment-2637839574>`_).\n\nAfter the upgrade, you can verify whether sleep mode is enabled by running the following test code (from `this comment <https://github.com/vllm-project/vllm/pull/12695#issuecomment-2637839574>`_).\n\n.. code-block:: python\n\n\timport torch\n\tfrom vllm import LLM\n\n\tllm = LLM(model=\"meta-llama/Llama-3.1-8B-Instruct\", enable_sleep_mode=True)\n\n\tdef run_inference(prompt):\n\t\toutputs = llm.generate(prompt)\n\t\tfor output in outputs:\n\t\t\tprompt = output.prompt\n\t\t\tgenerated_text = output.outputs[0].text\n\t\t\tprint(f\"Prompt: {prompt!r}, Generated text: {generated_text!r}\")\n\n\n\tprint(\"CUDA Memory Usage (after inference):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\trun_inference(\"San Francisco is\")\n\tllm.sleep()\n\n\tprint(\"CUDA Memory Usage (after sleep):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\tllm.wake_up()\n\n\tprint(\"CUDA Memory Usage (after wakeup):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\trun_inference(\"Paris is\")\n\nIf sleep mode is enabled, you should see the memory usage reduce after sleep.\n\nAfter applying the vLLM patch and completing the installation, you can enable sleep mode in verl to reduce memory overhead. This allows verl to offload unused GPU memory during rollout, significantly lowering the memory footprint during long-context training or multi-node reinforcement learning.\n\n\nEnable CUDA Graph and Bypass ROCm-related issues\n--------------------------------------------------------------\n\nDue to potential issues with CUDA graph capture in ROCm, we’ve found that vLLM’s CUDA graph feature cannot be enabled on multiple nodes in verl on AMD platforms with vLLM V1 mode. This leads to significantly slower rollout performance.\n\nOur investigation shows that ROCm may trigger an unexpected crash when attempting to capture large batches with CUDA graph. 
One workaround is to patch the LLM configuration (from `this commit <https://github.com/volcengine/verl/blob/v0.3.0.rc0/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L100-L115>`_).\n\n.. code-block:: python\n\n    self.inference_engine = LLM(\n        model=model_path,\n        enable_sleep_mode=True,\n        tensor_parallel_size=tensor_parallel_size,\n        distributed_executor_backend=\"external_launcher\",\n        dtype=config.dtype,\n        enforce_eager=config.enforce_eager,\n        gpu_memory_utilization=config.gpu_memory_utilization,\n        disable_custom_all_reduce=True,\n        disable_mm_preprocessor_cache=True,\n        limit_mm_per_prompt=limit_mm_per_prompt,\n        skip_tokenizer_init=False,\n        max_model_len=max_model_len,\n        load_format=load_format,\n        disable_log_stats=config.disable_log_stats,\n        max_num_batched_tokens=max_num_batched_tokens,\n        enable_chunked_prefill=config.enable_chunked_prefill,\n        enable_prefix_caching=True,\n        trust_remote_code=trust_remote_code,\n        # enable the compilation config to bypass OOM on ROCm;\n        # adjust the capture sizes depending on your GPU memory size\n        compilation_config={\"cudagraph_capture_sizes\": [1, 2, 4, 8, 16, 32, 64]},\n        seed=config.get('seed', 0),\n    )\n\nThen, you can choose to enable CUDA graph by setting the following config option (see `this page <https://github.com/volcengine/verl/blob/v0.3.0.rc0/docs/README_vllm0.8.md>`_):\n\n.. code-block:: bash\n\n    actor_rollout_ref.rollout.enforce_eager=False \\\n"
  },
  {
    "path": "verl_distillation/docs/api/data.rst",
    "content": "Data interface\n=========================\n\nLast updated: 05/19/2025 (API docstrings are auto-generated).\n\nDataProto is the interface for data exchange.\n\nThe :class:`verl.DataProto` class contains two key members:\n\n- batch: a :class:`tensordict.TensorDict` object for the actual data\n- meta_info: a :class:`Dict` with additional meta information\n\nTensorDict\n~~~~~~~~~~~~\n\n:attr:`DataProto.batch` is built on top of :class:`tensordict`, a project in the PyTorch ecosystem.\nA TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size.\n\n.. code-block:: python\n\n    >>> import torch\n    >>> from tensordict import TensorDict\n    >>> tensordict = TensorDict({\"zeros\": torch.zeros(2, 3, 4), \"ones\": torch.ones(2, 3, 5)}, batch_size=[2,])\n    >>> tensordict[\"twos\"] = 2 * torch.ones(2, 5, 6)\n    >>> zeros = tensordict[\"zeros\"]\n    >>> tensordict\n    TensorDict(\n    fields={\n        ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),\n        twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),\n        zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},\n    batch_size=torch.Size([2]),\n    device=None,\n    is_shared=False)\n\nOne can also index a tensordict along its batch_size. The contents of the TensorDict can be manipulated collectively as well.\n\n.. code-block:: python\n\n    >>> tensordict[..., :1]\n    TensorDict(\n    fields={\n        ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),\n        twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),\n        zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},\n    batch_size=torch.Size([1]),\n    device=None,\n    is_shared=False)\n    >>> tensordict = tensordict.to(\"cuda:0\")\n    >>> tensordict = tensordict.reshape(6)\n\nFor more about :class:`tensordict.TensorDict` usage, see the official tensordict_ documentation.\n\n.. _tensordict: https://pytorch.org/tensordict/overview.html\n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass::  verl.DataProto\n   :members: to, select, union, make_iterator, concat\n"
  },
  {
    "path": "verl_distillation/docs/api/single_controller.rst",
    "content": "Single Controller interface\n============================\n\nLast updated: 05/27/2025 (API docstrings are auto-generated).\n\nThe Single Controller provides a unified interface for managing distributed workers\nusing Ray or other backends and executing functions across them.\nIt simplifies the process of dispatching tasks and collecting results, particularly \nwhen dealing with data parallelism or model parallelism. \n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass:: verl.single_controller.Worker\n   :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank\n\n.. autoclass:: verl.single_controller.WorkerGroup\n   :members: __init__,  world_size\n\n.. autoclass:: verl.single_controller.ClassWithInitArgs\n   :members: __init__, __call__\n\n.. autoclass:: verl.single_controller.ResourcePool\n   :members: __init__, world_size, local_world_size_list, local_rank_list\n\n.. autoclass:: verl.single_controller.ray.RayWorkerGroup\n   :members: __init__\n\n.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls"
  },
  {
    "path": "verl_distillation/docs/api/trainer.rst",
    "content": "Trainer Interface\n================================\n\nLast updated: 06/08/2025 (API docstrings are auto-generated).\n\nTrainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged.\n\n.. autosummary::\n   :nosignatures:\n\n   verl.trainer.ppo.ray_trainer.RayPPOTrainer\n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer\n   :members: __init__, init_workers, fit\n\n.. automodule:: verl.utils.tokenizer\n   :members: hf_tokenizer\n\n.. automodule:: verl.trainer.ppo.core_algos\n   :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty\n\n.. automodule:: verl.trainer.ppo.reward\n   :members: load_reward_manager, compute_reward, compute_reward_async\n\n.. autoclass:: verl.workers.reward_manager.NaiveRewardManager\n\n.. autoclass:: verl.workers.reward_manager.DAPORewardManager\n"
  },
  {
    "path": "verl_distillation/docs/api/utils.rst",
    "content": "Utilities\n============\n\nLast updated: 05/19/2025 (API docstrings are auto-generated).\n\nThis section documents the utility functions and classes in the VERL library.\n\nPython Functional Utilities\n------------------------------\n\n.. automodule:: verl.utils.py_functional\n   :members: append_to_dict\n\nFile System Utilities\n------------------------\n\n.. automodule:: verl.utils.fs\n   :members: copy_to_local\n\nTracking Utilities\n---------------------\n\n.. automodule:: verl.utils.tracking\n   :members: Tracking\n\nMetrics Utilities\n---------------------\n\n.. automodule::  verl.utils.metric\n   :members: reduce_metrics\n\nCheckpoint Management\n------------------------\n\n.. automodule:: verl.utils.checkpoint.checkpoint_manager\n   :members: find_latest_ckpt_path\n\n.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager\n   :members: FSDPCheckpointManager\n\nDataset Utilities\n---------------------\n\n.. automodule:: verl.utils.dataset.rl_dataset\n   :members: RLHFDataset, collate_fn\n\nTorch Functional Utilities\n-----------------------------\n\n.. automodule:: verl.utils.torch_functional\n   :members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits\n\nSequence Length Balancing\n----------------------------\n\n.. automodule:: verl.utils.seqlen_balancing\n   :members: get_reverse_idx, rearrange_micro_batches\n\nUlysses Utilities\n--------------------\n\n.. automodule:: verl.utils.ulysses\n   :members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\nFSDP Utilities\n------------------\n\n.. automodule:: verl.utils.fsdp_utils\n   :members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer,\n\nDebug Utilities\n-------------------\n\n.. automodule:: verl.utils.profiler\n   :members: log_gpu_memory_usage, GPUMemoryLogger\n\n"
  },
  {
    "path": "verl_distillation/docs/ascend_tutorial/ascend_profiling_en.rst",
    "content": "Data collection based on FSDP backend on Ascend devices(en)\n==========================================================================================\n\nLast updated: 08/14/2025.\n\nThis is a tutorial for data collection using the GRPO or DAPO algorithm\nbased on FSDP on Ascend devices.\n\nConfiguration\n-------------\n\nLeverage two levels of configuration to control data collection:\n\n1. **Global profiler control**: Use parameters in ``ppo_trainer.yaml`` to control the collection mode and steps.\n2. **Role profile control**: Use parameters in each role's ``profile`` field to control the collection mode for each role.\n\nGlobal collection control\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nUse parameters in ppo_trainer.yaml to control the collection mode\nand steps.\n\n-  global_profiler: Control the ranks and mode of profiling\n\n   -  tool: The profiling tool to use, options are nsys, npu, torch,\n      torch_memory.\n   -  steps: This parameter can be set as a list that has\n      collection steps, such as [2, 4], which means it will collect steps 2\n      and 4. If set to null, no collection occurs.\n   -  save_path: The path to save the collected data. Default is\n      \"outputs/profile\".\n\n\nRole collection control\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn each role's ``profiler`` field, you can control the collection mode for that role.\n\n-  enable: Whether to enable profiling for this role.\n-  all_ranks: Whether to collect data from all ranks.\n-  ranks: A list of ranks to collect data from. If empty, no data is collected.\n-  tool_config: Configuration for the profiling tool used by this role.\n\nUse parameters in each role's ``profiler.tool_config.npu`` to control npu profiler behavior:\n\n-  level: Collection level—options are level_none, level0, level1, and\n   level2\n\n   -  level_none: Disables all level-based data collection (turns off\n      profiler_level).\n   -  level0: Collect high-level application data, underlying NPU data,\n      and operator execution details on NPU.\n   -  level1: Extends level0 by adding CANN-layer AscendCL data and AI\n      Core performance metrics on NPU.\n   -  level2: Extends level1 by adding CANN-layer Runtime data and AI\n      CPU metrics.\n\n-  contents: A list of options to control the collection content, such as\n   npu, cpu, memory, shapes, module, stack.\n   \n   -  npu: Whether to collect device-side performance data.\n   -  cpu: Whether to collect host-side performance data.\n   -  memory: Whether to enable memory analysis.\n   -  shapes: Whether to record tensor shapes.\n   -  module: Whether to record framework-layer Python call stack\n      information.\n   -  stack: Whether to record operator call stack information.\n\n-  analysis: Enables automatic data parsing.\n-  discrete: Whether to enable discrete mode.\n\n\nExamples\n--------\n\nDisabling collection\n~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n      global_profiler:\n         steps: null # disable profile\n\nEnd-to-End collection\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n      global_profiler:\n         steps: [1, 2, 5]\n      actor_rollout_ref:\n         actor:\n            profiler:\n               enable: True\n               all_ranks: True\n               tool_config:\n                  npu:\n                     discrete: False\n        # rollout & ref follow actor settings\n\n\nDiscrete Mode Collection\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code:: yaml\n\n      global_profiler:\n         steps: [1, 2, 5]\n      actor_rollout_ref:\n         actor:\n            profiler:\n               enable: True\n               all_ranks: True\n               tool_config:\n                  npu:\n                     discrete: True\n        # rollout & ref follow actor settings\n\n\nVisualization\n-------------\n\nCollected data is stored in the user-defined save_path and can be\nvisualized by using the `MindStudio Insight <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html>`_ tool.\n\nIf the analysis parameter is set to False, offline parsing is required after data collection:\n\n.. code:: python\n\n    import torch_npu\n    # Set profiler_path to the parent directory of the \"localhost.localdomain_<PID>_<timestamp>_ascend_pt\" folder\n    torch_npu.profiler.profiler.analyse(profiler_path=profiler_path)"
  },
  {
    "path": "verl_distillation/docs/ascend_tutorial/ascend_profiling_zh.rst",
    "content": "Data collection based on FSDP backend on Ascend devices(zh)\n====================================\n\n在昇腾设备上基于FSDP后端进行数据采集\n\nLast updated: 08/14/2025.\n\n这是一份在昇腾设备上基于FSDP后端使用GRPO或DAPO算法进行数据采集的教程。\n\n配置\n----\n\n使用两级profile设置来控制数据采集\n\n- 全局采集控制：使用verl/trainer/config/ppo_trainer.yaml中的配置项控制采集的模式和步数，\n- 角色profile控制：通过每个角色中的配置项控制等参数。\n\n全局采集控制\n~~~~~~~~~~~~\n\n通过 ppo_trainer.yaml 中的参数控制采集步数和模式：\n\n-  global_profiler: 控制采集的rank和模式\n\n   -  tool: 使用的采集工具，选项有 nsys、npu、torch、torch_memory。\n   -  steps: 此参数可以设置为包含采集步数的列表，例如 [2, 4]，表示将采集第2步和第4步。如果设置为 null，则不进行采集。\n   -  save_path: 保存采集数据的路径。默认值为 \"outputs/profile\"。\n\n角色profiler控制\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n在每个角色的 ``profiler`` 字段中，您可以控制该角色的采集模式。\n\n-  enable: 是否为此角色启用性能分析。\n-  all_ranks: 是否从所有rank收集数据。\n-  ranks: 要收集数据的rank列表。如果为空，则不收集数据。\n-  tool_config: 此角色使用的性能分析工具的配置。\n\n通过每个角色的 ``profiler.tool_config.npu`` 中的参数控制具体采集行为：\n\n-  level: 采集级别—选项有 level_none、level0、level1 和 level2\n\n   -  level_none: 禁用所有基于级别的数据采集（关闭 profiler_level）。\n   -  level0: 采集高级应用数据、底层NPU数据和NPU上的算子执行详情。\n   -  level1: 在level0基础上增加CANN层AscendCL数据和NPU上的AI Core性能指标。\n   -  level2: 在level1基础上增加CANN层Runtime数据和AI CPU指标。\n\n-  contents: 控制采集内容的选项列表，例如\n   npu、cpu、memory、shapes、module、stack。\n   \n   -  npu: 是否采集设备端性能数据。\n   -  cpu: 是否采集主机端性能数据。\n   -  memory: 是否启用内存分析。\n   -  shapes: 是否记录张量形状。\n   -  module: 是否记录框架层Python调用栈信息。\n   -  stack: 是否记录算子调用栈信息。\n\n-  analysis: 启用自动数据解析。\n-  discrete: 使用离散模式。\n\n示例\n----\n\n禁用采集\n~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n      global_profiler:\n         steps: null # disable profile\n\n端到端采集\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n      global_profiler:\n         steps: [1, 2, 5]\n      actor_rollout_ref:\n         actor:\n            profiler:\n               enable: True\n               all_ranks: True\n               tool_config:\n                  npu:\n                     discrete: False\n        # rollout & ref follow actor settings\n\n\n离散模式采集\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n      global_profiler:\n         steps: [1, 2, 5]\n      actor_rollout_ref:\n         actor:\n            profiler:\n               enable: True\n               all_ranks: True\n               tool_config:\n                  npu:\n                     discrete: True\n        # rollout & ref follow actor settings\n\n\n可视化\n------\n\n采集后的数据存放在用户设置的save_path下，可通过 `MindStudio Insight <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html>`_ 工具进行可视化。\n\n如果analysis参数设置为False，采集之后需要进行离线解析：\n\n.. code:: python\n\n    import torch_npu\n    # profiler_path请设置为\"localhost.localdomain_<PID>_<timestamp>_ascend_pt\"目录的上一级目录\n    torch_npu.profiler.profiler.analyse(profiler_path=profiler_path)"
  },
  {
    "path": "verl_distillation/docs/ascend_tutorial/ascend_quick_start.rst",
    "content": "verl x Ascend\n===================================\n\nLast updated: 10/31/2025.\n\n我们在 verl 上增加对华为昇腾设备的支持。\n\n硬件支持\n-----------------------------------\n\nAtlas 200T A2 Box16\n\nAtlas 900 A2 PODc\n\nAtlas 800T A3\n\n\n安装\n-----------------------------------\n\n基础环境准备\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+-----------+-------------+\n| software  | version     |\n+-----------+-------------+\n| Python    | == 3.10     |\n+-----------+-------------+\n| CANN      | == 8.2.RC1  |\n+-----------+-------------+\n| torch     | == 2.5.1    |\n+-----------+-------------+\n| torch_npu | == 2.5.1    |\n+-----------+-------------+\n\n基础环境准备请参照这份 `文档 <https://gitcode.com/Ascend/pytorch>`_ 。\n\nvllm & vllm-ascend\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n为了能够在 verl 中正常使用 vllm，需使用以下命令编译安装 vllm 和 vllm-ascend。请注意根据机器类型区分安装方式。\n\n.. code-block:: bash\n    \n    # vllm\n    git clone -b v0.9.1 --depth 1 https://github.com/vllm-project/vllm.git\n    cd vllm\n    pip install -r requirements-build.txt\n\n    # for Atlas 200T A2 Box16\n    VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/\n    \n    # for Atlas 900 A2 PODc\n    VLLM_TARGET_DEVICE=empty pip install -e .\n\n.. code-block:: bash\n    \n    # vllm-ascend\n    git clone -b v0.9.1 --depth 1 https://github.com/vllm-project/vllm-ascend.git\n    cd vllm-ascend\n    export COMPILE_CUSTOM_KERNELS=1\n    python setup.py install\n\n安装verl\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: bash\n\n    git clone https://github.com/volcengine/verl.git\n    cd verl\n    pip install -r requirements-npu.txt\n    pip install -e .\n\nDockerFile镜像构建\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n如需要通过DockerFile构建镜像， 请参考 `文档 <https://github.com/volcengine/verl/tree/main/docs/ascend_tutorial/dockerfile_build_guidance.rst>`_ 。\n\n其他三方库说明\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+--------------+---------------+\n| software     | description   |\n+--------------+---------------+\n| transformers | v4.52.4       |\n+--------------+---------------+\n| flash_attn   | not supported |\n+--------------+---------------+\n| liger-kernel | not supported |\n+--------------+---------------+\n\n1. 支持通过 transformers 使能 --flash_attention_2， transformers 需等于 4.52.4版本。\n2. 不支持通过 flash_attn 使能 flash attention 加速。\n3. 不支持 liger-kernel 使能。\n4. 针对 x86 服务器，需要安装 cpu 版本的 torchvision。\n\n.. code-block:: bash\n\n    pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu\n\n\n快速开始\n-----------------------------------\n正式使用前，建议您通过对Qwen2.5-0.5B GRPO的训练尝试以检验环境准备和安装的正确性。\n\n1.下载数据集并将数据集预处理为parquet格式，以便包含计算RL奖励所需的必要字段\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\n2.执行训练\n\n.. 
code-block:: bash\n\n    set -x\n\n    export VLLM_ATTENTION_BACKEND=XFORMERS\n\n    python3 -m verl.trainer.main_ppo \\\n        algorithm.adv_estimator=grpo \\\n        data.train_files=$HOME/data/gsm8k/train.parquet \\\n        data.val_files=$HOME/data/gsm8k/test.parquet \\\n        data.train_batch_size=128 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=128 \\\n        data.filter_overlong_prompts=True \\\n        data.truncation='error' \\\n        actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=5e-7 \\\n        actor_rollout_ref.model.use_remove_padding=False \\\n        actor_rollout_ref.actor.entropy_coeff=0.001 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \\\n        actor_rollout_ref.actor.use_kl_loss=True \\\n        actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n        actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n        actor_rollout_ref.rollout.n=5 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.project_name='verl_grpo_example_gsm8k' \\\n        trainer.experiment_name='qwen2_7b_function_rm' \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=5 \\\n        trainer.total_epochs=1 \\\n        trainer.device=npu $@\n\n(Optional) MindSpeed training backend setup\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n1. Install the MindSpeed acceleration library following the `MindSpeed README <https://gitcode.com/Ascend/MindSpeed>`_.\n\n2. Set the verl worker model ``strategy`` to ``megatron``, e.g. ``actor_rollout_ref.actor.strategy=megatron``.\n\n3. Custom MindSpeed arguments can be passed in via the ``override_transformer_config`` parameter; for example, to enable the FA feature for the actor model, use ``+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True``.\n\n4. 
For more feature information, refer to the `MindSpeed+verl documentation <https://gitcode.com/Ascend/MindSpeed/blob/master/docs/user-guide/verl.md>`_.\n\nSupport status\n-----------------------------------\n\n**Table 1** RL algorithms\n\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n| algorithm |         model           |   actor.strategy  |   rollout.name    |         hardware         |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen2.5-7B-instruct     |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen2.5-32B-instruct    |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen2.5-VL-3B-instruct  |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen2.5-VL-7B-instruct  |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen2.5-VL-32B-instruct |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen3-8B                |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   GRPO    | Qwen3-32B               |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen2.5-7B-instruct     |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen2.5-32B             |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen3-8B-base           |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen3-14B-base          |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen3-30B-A3B-base      |        FSDP       |    vllm-ascend    |    Atlas 200T A2 Box16   |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   DAPO    | Qwen3-30B-A3B           |      megatron     |    vllm-ascend    |    Atlas 800T A3         |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n|   PPO     | Qwen3-8B                |        FSDP       |    vllm-ascend    |    Atlas 900 A2 PODc     |\n+-----------+-------------------------+-------------------+-------------------+--------------------------+\n\n**Table 2** 
SFT algorithms\n\n+-----------+-------------------------+-------------------+----------------------+\n| algorithm |         model           |   actor.strategy  |        hardware      |\n+-----------+-------------------------+-------------------+----------------------+\n|  SFT-PEFT | Qwen3-8B                |        FSDP       |   Atlas 900 A2 PODc  |\n+-----------+-------------------------+-------------------+----------------------+\n| ReTool-SFT| Qwen2.5-7B-instruct     |        FSDP       |   Atlas 900 A2 PODc  |\n+-----------+-------------------------+-------------------+----------------------+\n\n\n\nRoadmap\n-----------------------------------\n\nSee the `roadmap <https://github.com/volcengine/verl/discussions/2171>`_ for the support progress of more features.\n\n\n\nDisclaimer\n-----------------------------------\nThe Ascend support code, Dockerfiles, and images provided in verl are reference samples; if you plan to use them in a production environment, please communicate through official channels. Thank you.\n"
  },
  {
    "path": "verl_distillation/docs/ascend_tutorial/ascend_sglang_quick_start.rst",
    "content": "verl x Ascend\n===================================\n\nLast updated: 09/25/2025.\n\n我们在 verl 上增加对华为昇腾设备的支持。\n\n硬件支持\n-----------------------------------\n\nAtlas 200T A2 Box16\n\nAtlas 900 A2 PODc\n\nAtlas 800T A3\n\n\n安装\n-----------------------------------\n\n基础环境准备\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+-----------+-------------+\n| software  | version     |\n+-----------+-------------+\n| Python    | == 3.11     |\n+-----------+-------------+\n| CANN      | == 8.3.RC1  |\n+-----------+-------------+\n| HDK       | == 25.3.RC1 |\n+-----------+-------------+\n| torch     | == 2.6.0    |\n+-----------+-------------+\n| torch_npu | == 2.6.0    |\n+-----------+-------------+\n\n**目前verl框架中sglang npu后端仅支持上述HDK、CANN和PTA版本, 商发可用版本预计2025年10月发布**\n\n为了能够在 verl 中正常使用 sglang，需使用以下命令安装sglang、torch_memory_saver和verl。\n\nsglang\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. code-block:: bash\n    \n    # sglang\n    git clone https://github.com/sgl-project/sglang.git\n    cd sglang\n    mv python/pyproject.toml python/pyproject.toml.backup\n    mv python/pyproject_other.toml python/pyproject.toml\n    pip install -e \"python[srt_npu]\"\n\n安装torch_memory_saver\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. code-block:: bash\n    \n    # torch_memory_saver\n    git clone https://github.com/sgl-project/sgl-kernel-npu.git\n    cd sgl-kernel-npu\n    bash build.sh  -a memory-saver\n    pip install output/torch_memory_saver*.whl\n\n安装verl\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: bash\n\n    git clone https://github.com/volcengine/verl.git\n    cd verl\n    pip install --no-deps -e .\n    pip install -r requirements-npu.txt \n\n\n其他三方库说明\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+--------------+---------------+\n| software     | description   |\n+--------------+---------------+\n| transformers | v4.56.1       |\n+--------------+---------------+\n| triton_ascend| v3.2.0        |\n+--------------+---------------+\n\n1. sglang依赖 transformers v4.56.1\n2. sglang依赖triton_ascend v3.2.0\n3. 暂不支持多模态模型，卸载相关安装包torchvision、timm\n\n.. code-block:: bash\n    \n    pip uninstall torchvision\n    pip uninstall timm\n    pip uninstall triton\n    \n    pip install transformers==4.56.1\n    pip install -i https://test.pypi.org/simple/ triton-ascend==3.2.0.dev20250925\n\n\n快速开始\n-----------------------------------\n正式使用前，建议您通过对Qwen3-8B GRPO的训练尝试以检验环境准备和安装的正确性。\n\n1.下载数据集并将数据集预处理为parquet格式，以便包含计算RL奖励所需的必要字段\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\n2.执行训练\n\n.. code-block:: bash\n\n    bash verl/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_1k_npu.sh"
  },
  {
    "path": "verl_distillation/docs/ascend_tutorial/dockerfile_build_guidance.rst",
    "content": "Ascend Dockerfile Build Guidance\n===================================\n\nLast updated: 10/31/2025.\n\n我们在verl上增加对华为昇腾镜像构建的支持。\n\n\n硬件支持\n-----------------------------------\n\nAtlas 200T A2 Box16\n\nAtlas 900 A2 PODc\n\nAtlas 800T A3\n\n\n组件版本信息\n----------------\n\n=========== ============\n组件        版本\n=========== ============\n基础镜像    Ubuntu 22.04\nPython      3.11\nCANN        8.2.RC1\ntorch       2.5.1\ntorch_npu   2.5.1\nvLLM        0.9.1\nvLLM-ascend 0.9.1\nMegatron-LM v0.12.1\nMindSpeed   (f2b0977e)\n=========== ============\n\nDockerfile构建镜像脚本\n---------------------------\n\n============== ============== ==============\n设备类型         基础镜像版本     参考文件\n============== ============== ==============\nA2              8.2.RC1       `Dockerfile.ascend_8.2.rc1_a2 <https://github.com/volcengine/verl/blob/main/docker/ascend/Dockerfile.ascend_8.2.rc1_a2>`_\nA3              8.2.RC1       `Dockerfile.ascend_8.2.rc1_a3 <https://github.com/volcengine/verl/blob/main/docker/ascend/Dockerfile.ascend_8.2.rc1_a3>`_\n============== ============== ==============\n\n\n镜像构建命令示例\n--------------------\n\n.. code:: bash\n\n   # Navigate to the directory containing the Dockerfile \n   cd {verl-root-path}/docker/ascend\n   # Build the image\n   docker build -f Dockerfile.ascend_8.2.rc1_a2 -t verl-ascend:8.2.rc1-a2 .\n\n\n声明\n--------------------\nverl中提供的ascend相关Dockerfile、镜像皆为参考样例，可用于尝鲜体验，如在生产环境中使用请通过官方正式途径沟通，谢谢。"
  },
  {
    "path": "verl_distillation/docs/conf.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Configuration file for the Sphinx documentation builder.\r\n#\r\n# This file only contains a selection of the most common options. For a full\r\n# list see the documentation:\r\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\r\n\r\n# -- Path setup --------------------------------------------------------------\r\n\r\n# If extensions (or modules to document with autodoc) are in another directory,\r\n# add these directories to sys.path here. If the directory is relative to the\r\n# documentation root, use os.path.abspath to make it absolute, like shown here.\r\n#\r\n# import os\r\n# import sys\r\n# sys.path.insert(0, os.path.abspath('.'))\r\n\r\n\r\n# -- Project information -----------------------------------------------------\r\n\r\nproject = \"verl\"\r\ncopyright = \"2024 ByteDance Seed Foundation MLSys Team\"\r\nauthor = \"Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin\"\r\n\r\n\r\n# -- General configuration ---------------------------------------------------\r\n# The master toctree document.\r\nmaster_doc = \"index\"\r\n\r\n# Add any Sphinx extension module names here, as strings. They can be\r\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\r\n# ones.\r\nextensions = [\r\n    \"myst_parser\",\r\n    \"sphinx.ext.autodoc\",\r\n    \"sphinx.ext.autosummary\",\r\n    \"sphinx.ext.autosectionlabel\",\r\n    \"sphinx.ext.napoleon\",\r\n    \"sphinx.ext.viewcode\",\r\n]\r\n# Use Google style docstrings instead of NumPy docstrings.\r\nnapoleon_google_docstring = True\r\nnapoleon_numpy_docstring = False\r\n\r\n# The suffix(es) of source filenames.\r\n# You can specify multiple suffix as a list of string:\r\nsource_suffix = {\r\n    \".rst\": \"restructuredtext\",\r\n    \".md\": \"markdown\",\r\n}\r\n\r\n# Add any paths that contain templates here, relative to this directory.\r\ntemplates_path = [\"_templates\"]\r\n\r\n# The language for content autogenerated by Sphinx. Refer to documentation\r\n# for a list of supported languages.\r\n#\r\n# This is also used if you do content translation via gettext catalogs.\r\n# Usually you set \"language\" from the command line for these cases.\r\nlanguage = \"en\"\r\n\r\n# List of patterns, relative to source directory, that match files and\r\n# directories to ignore when looking for source files.\r\n# This pattern also affects html_static_path and html_extra_path.\r\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\r\n\r\n\r\n# -- Options for HTML output -------------------------------------------------\r\n\r\n# The theme to use for HTML and HTML Help pages.  See the documentation for\r\n# a list of builtin themes.\r\n#\r\nhtml_theme = \"sphinx_rtd_theme\"\r\n\r\n# Add any paths that contain custom static files (such as style sheets) here,\r\n# relative to this directory. 
They are copied after the builtin static files,\r\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\r\nhtml_static_path = [\"_static\"]\r\n\r\n# Add the JavaScript file\r\nhtml_js_files = [\r\n    \"js/runllm-widget.js\",\r\n    \"js/resizable-sidebar.js\",\r\n]\r\n\r\n# Add custom CSS file for full-width layout\r\nhtml_css_files = [\r\n    \"custom.css\",\r\n]\r\n\r\nexclude_patterns += [\"README.md\", \"README_vllm0.7.md\"]\r\n\r\nsuppress_warnings = [\"ref.duplicate\", \"ref.myst\"]\r\n"
  },
  {
    "path": "verl_distillation/docs/data/transfer_queue.md",
    "content": "# TransferQueue Data System\n\nLast updated: 09/28/2025.\n\nThis doc introduce [TransferQueue](https://github.com/TransferQueue/TransferQueue), an asynchronous streaming data management system for efficient post-training.\n\n\n<h2 id=\"overview\"> Overview</h2>\n\nTransferQueue is a high-performance data storage and transfer system with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.\n\n<p align=\"center\">\n  <img src=\"https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696193102-a5654375-65a1-4e06-9c63-142b59df90b8.png\" width=\"70%\">\n</p>\n\n\nTransferQueue offers **fine-grained, sample-level** data management capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the design of the algorithm controller.\n\n\n<p align=\"center\">\n  <img src=\"https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696791245-fa7baf96-46af-4c19-8606-28ffadc4556c.png\" width=\"70%\">\n</p>\n\n\n\n\n<h2 id=\"components\"> Components</h2>\n\n\n\n### Control Plane: Panoramic Data Management  \n\nIn the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorage`), we know that this data sample can be consumed by downstream tasks. \n\nFor consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even different computation tasks require the same data field, they can consume the data independently without interfering with each other.\n\n\n<p align=\"center\">\n  <img src=\"https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696820173-456c1784-42ba-40c8-a292-2ff1401f49c5.png\" width=\"70%\">\n</p>\n\n\n> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Besides, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller.\n\n### Data Plane: Distributed Data Storage\n\nIn the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage unit based on CPU memory, responsible for the actual storage and retrieval of data. Each storage unit can be deployed on a separate node, allowing for distributed data management.\n\n`TransferQueueStorageSimpleUnit` employs a 2D data structure as follows:\n\n- Each row corresponds to a training sample, assigned a unique index within the corresponding global batch.\n- Each column represents the input/output data fields for computational tasks.\n\nThis data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.\n\n<p align=\"center\">\n  <img src=\"https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696805154-3817011f-84e6-40d0-a80c-58b7e3e5f6a7.png\" width=\"70%\">\n</p>\n\n\n> In the future, we plan to implement a **general storage abstraction layer** to support various storage backends. 
Through this abstraction, we hope to integrate high-performance storage solutions such as [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) to support device-to-device data transfer through RDMA, further enhancing data transfer efficiency for large-scale data.\n\n\n### User Interface: Asynchronous & Synchronous Client\n\n\nThe interaction workflow of the TransferQueue system is as follows:\n\n1. A process sends a read request to the `TransferQueueController`.\n2. `TransferQueueController` scans the production and consumption metadata for each sample (row), and dynamically assembles micro-batch metadata according to the load-balancing policy. This mechanism enables sample-level data scheduling.\n3. The process retrieves the actual data from distributed storage units using the metadata provided by the controller.\n\nTo simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their frameworks.\n\n\n> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.\n\n\n<h2 id=\"show-cases\"> Show Cases</h2>\n\n### General Usage\n\nThe primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, serving as the communication interface with the TransferQueue system.\n\nCore interfaces:\n\n- `(async_)get_meta(data_fields: list[str], batch_size: int, global_step: int, get_n_samples: bool, task_name: str) -> BatchMeta`\n- `(async_)get_data(metadata: BatchMeta) -> TensorDict`\n- `(async_)put(data: TensorDict, metadata: BatchMeta, global_step)`\n- `(async_)clear(global_step: int)`\n\n\nWe will soon release a detailed tutorial and API documentation.\n\n\n### verl Example\n\n\nThe primary motivation for integrating TransferQueue into verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single-point bottleneck for the whole post-training system. \n\n![verl_dataflow_DataProto](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704289414-bcc54228-716b-4d4a-ad3b-f9ace6d10fcf.jpeg)\n\nLeveraging TransferQueue, we separate experience data transfer from metadata dispatch by:\n\n- Replacing `DataProto` with `BatchMeta` (metadata) and `TensorDict` (actual data) structures\n- Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability)\n- Accelerating data transfer by TransferQueue's distributed storage units\n\n![verl_dataflow_TransferQueue](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704301666-0807dc06-766c-4a2d-9cde-889a6bb56b34.jpeg)\n\n\nYou may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. 
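\n\nTo make this interaction pattern concrete, below is a minimal, hypothetical sketch of a computational task pulling a micro-batch through the async client and writing its outputs back. It assumes only the four core interfaces listed under General Usage; the client construction, the `async_`-prefixed method names, and the field names (`prompts`, `responses`, `old_log_probs`) are illustrative assumptions rather than the released API.\n\n```python\nfrom tensordict import TensorDict\n\n\nasync def compute_log_prob_step(client, global_step: int):\n    # 1. Ask the controller for a ready micro-batch (metadata only).\n    meta = await client.async_get_meta(\n        data_fields=[\"prompts\", \"responses\"],\n        batch_size=8,\n        global_step=global_step,\n        get_n_samples=False,\n        task_name=\"compute_log_prob\",\n    )\n\n    # 2. Fetch the actual tensors from the distributed storage units.\n    batch: TensorDict = await client.async_get_data(meta)\n\n    # 3. Run the task's computation on `batch` (placeholder here).\n    batch[\"old_log_probs\"] = batch[\"responses\"].float()\n\n    # 4. Write the new field back so the controller can mark it as\n    #    produced and downstream tasks can consume it.\n    await client.async_put(data=batch, metadata=meta, global_step=global_step)\n```\n\nBecause the controller tracks consumption per task name, another task (e.g. `generate_sequences`) could read the same `prompts` field independently without interfering with this one.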
\n\n<h2 id=\"citation\"> Citation</h2>\nPlease kindly cite our paper if you find this repo useful:\n\n```bibtex\n@article{han2025asyncflow,\n  title={AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training},\n  author={Han, Zhenyu and You, Ansheng and Wang, Haibo and Luo, Kui and Yang, Guang and Shi, Wenqi and Chen, Menglong and Zhang, Sicheng and Lan, Zeshun and Deng, Chunshi and others},\n  journal={arXiv preprint arXiv:2507.01663},\n  year={2025}\n}\n```"
  },
  {
    "path": "verl_distillation/docs/examples/config.rst",
    "content": ".. _config-explain-page:\n\nConfig Explanation\n===================\n\nLast updated: 06/18/2025.\n\nppo_trainer.yaml for RL FSDP Backend\n-------------------------------------\n\nData\n~~~~\n\n.. code:: yaml\n\n   data:\n     tokenizer: null\n     train_files: ~/data/rlhf/gsm8k/train.parquet\n     val_files: ~/data/rlhf/gsm8k/test.parquet\n     train_max_samples: -1  # set to -1 to use full dataset\n     val_max_samples: -1  # set to -1 to use full dataset\n     prompt_key: prompt\n     max_prompt_length: 512\n     max_response_length: 512\n     train_batch_size: 1024\n     return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n     return_raw_chat: False\n     return_full_prompt: False\n     shuffle: True\n     seed: 42\n     filter_overlong_prompts: False\n     filter_overlong_prompts_workers: 1\n     truncation: error\n     image_key: images\n     trust_remote_code: True\n     custom_cls:\n        path: null\n        name: null\n\n- ``data.train_files``: Training set parquet. Can be a list or a single\n  file. The program will read all files into memory, so it can't be too\n  large (< 100GB). The path can be either local path or HDFS path. For\n  HDFS path, we provide utils to download it to DRAM and convert the\n  HDFS path to local path.\n- ``data.val_files``: Validation parquet. Can be a list or a single\n  file.\n- ``data.train_max_samples``: Maximum number of samples to use from the\n  training dataset. Set to -1 to use the full dataset.\n- ``data.val_max_samples``: Maximum number of samples to use from the\n  validation dataset. Set to -1 to use the full dataset.\n- ``data.prompt_key``: The field in the dataset where the prompt is\n  located. Default is 'prompt'.\n- ``data.max_prompt_length``: Maximum prompt length. All prompts will be\n  left-padded to this length. An error will be reported if the length is\n  too long\n- ``data.max_response_length``: Maximum response length. Rollout in RL\n  algorithms (e.g. PPO) generates up to this length\n- ``data.train_batch_size``: Batch size sampled for one training\n  iteration of different RL algorithms.\n- ``data.return_raw_input_ids``: Whether to return the original\n  input_ids without adding chat template. This is mainly used to\n  accommodate situations where the reward model's chat template differs\n  from the policy. It needs to be decoded first, then apply the RM's\n  chat template. If using a model-based RM, and the policy and RM\n  chat_templates are different, this flag needs to be set\n- ``data.return_raw_chat``: Whether to return the original chat (prompt)\n  without applying chat template.\n- ``data.return_full_prompt``: Whether to return the full prompt with chat template\n- ``data.shuffle``: Whether to shuffle the data in the dataloader.\n- ``data.seed``: An integer seed to use when shuffling the data. If not set or set to\n  `null`, the data shuffling will not be seeded, resulting in a different data order on each run.\n- ``data.filter_overlong_prompts``: Default don't filter.\n- ``data.filter_overlong_prompts_workers``: For large-scale dataset, filtering\n  overlong prompts could be timeconsuming. You cat set the ``filter_overlong_prompts_workers``\n  to use multiprocessing for speed up. Default to 1.\n- ``data.truncation``: Truncate the input_ids or prompt length if they\n  exceed max_prompt_length. Default is 'error', not allow exceed the\n  max_prompt_length. The users should increase the max_prompt_length if\n  throwing the error. 
\nActor/Rollout/Reference Policy\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   actor_rollout_ref:\n    hybrid_engine: True\n    model:\n      path: ~/models/deepseek-llm-7b-chat\n      external_lib: null\n      override_config:\n        attn_implementation: flash_attention_2  # or eager, sdpa - attention implementation override\n        model_config: {}\n        moe_config:  # Megatron only, can adjust moe configuration\n          freeze_moe_router: False  # Megatron only, can freeze moe router (no grad)\n      enable_gradient_checkpointing: False\n      enable_activation_offload: False\n      trust_remote_code: False\n      use_remove_padding: False\n    actor:\n      strategy: fsdp  # This is for backward-compatibility\n      ppo_mini_batch_size: 256\n      ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n      ppo_micro_batch_size_per_gpu: 8\n      use_dynamic_bsz: False\n      ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n      grad_clip: 1.0\n      clip_ratio: 0.2\n      entropy_coeff: 0.0\n      use_kl_loss: False # True for GRPO\n      # Rollout Importance Sampling (corrects distribution mismatch between rollout and training)\n      rollout_is: False # Enable IS correction\n      rollout_is_threshold: null # Upper threshold for IS weights (null to disable)\n      rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)\n      rollout_is_level: token # Aggregation: token/sequence/geometric\n      rollout_is_mode: truncate # Bounding: truncate/mask\n      rollout_is_veto_threshold: null # Catastrophic outlier threshold (null to disable)\n      use_torch_compile: True # False to disable torch compile\n      kl_loss_coef: 0.001 # for grpo\n      kl_loss_type: low_var_kl # for grpo\n      ppo_epochs: 1\n      data_loader_seed: null\n      shuffle: False\n      ulysses_sequence_parallel_size: 1 # sp size\n      optim:\n        lr: 1e-6\n        lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n        lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n        min_lr_ratio: 0.0   # only used with cosine lr scheduler, default to 0.0\n        num_cycles: 0.5     # only used with cosine lr scheduler, default to 0.5\n        lr_scheduler_type: constant  # select from constant/cosine\n        total_training_steps: -1  # must be overridden by the program\n      fsdp_config:\n        wrap_policy:\n          # transformer_layer_cls_to_wrap: None\n          min_num_params: 0\n        param_offload: False\n        optimizer_offload: False\n        fsdp_size: -1\n      checkpoint:\n        # What to include in saved checkpoints\n        # with 'hf_model' you can save the whole model in hf format; for now we only use sharded model checkpoints to save space\n        save_contents: ['model', 'optimizer', 'extra']\n        # For more flexibility, you can specify the contents to load from the checkpoint.\n        load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n    ref:\n      fsdp_config:\n        param_offload: False\n        wrap_policy:\n          # transformer_layer_cls_to_wrap: None\n          min_num_params: 0\n      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n      log_prob_micro_batch_size_per_gpu: 16\n      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n      ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size\n    rollout:\n      name: vllm\n      temperature: 1.0\n      top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n      top_p: 1\n      prompt_length: ${data.max_prompt_length}  # not used for open-source\n      response_length: ${data.max_response_length}\n      # for vllm rollout\n      dtype: bfloat16 # should align with FSDP\n      gpu_memory_utilization: 0.5\n      ignore_eos: False\n      enforce_eager: True\n      free_cache_engine: True\n      load_format: dummy_dtensor\n      tensor_model_parallel_size: 2\n      max_num_batched_tokens: 8192\n      max_num_seqs: 1024\n      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n      log_prob_micro_batch_size_per_gpu: 16\n      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n      # for hf rollout\n      do_sample: True\n      engine_kwargs: # inference engine parameters, please refer to the vllm/sglang official docs for details\n        vllm: {}\n        sglang: {}\n\n      n: 1 # for each prompt, sample n responses (i.e. num sample times). set it to values > 1 for grpo, rloo\n      calculate_log_probs: False # set to True for computing log probs via rollouts\n      val_kwargs:\n        # sampling parameters for validation\n        top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n        top_p: 1.0\n        temperature: 0\n        n: 1\n        do_sample: False # default eager for validation\n\n      agent:\n        custom_async_server: # Use custom async server implementation for rollout\n          path: null\n          name: null\n\n**Common config for actor, rollout and reference model**\n\n- ``actor_rollout_ref.hybrid_engine``: Whether to use a hybrid engine;\n  currently only the hybrid engine is supported.\n- ``actor_rollout_ref.model.path``: Huggingface model path. This can be\n  either a local path or an HDFS path. 
For an HDFS path, we provide utils to\n  download it to DRAM and convert the HDFS path to a local path.\n- ``actor_rollout_ref.model.external_lib``: Additional Python packages\n  that need to be imported. Used to register models or tokenizers into\n  the Huggingface system.\n- ``actor_rollout_ref.model.override_config``: Used to override some of\n  the model's original configurations. Common overrides include:\n\n  - ``attn_implementation``: Override the attention implementation. Default is ``flash_attention_2``.\n    Supported values: ``flash_attention_2``, ``eager``, ``sdpa``. Use ``eager`` for debugging or\n    compatibility issues. See :ref:`attention-implementation-override` for detailed usage.\n\n- ``actor_rollout_ref.model.enable_gradient_checkpointing``: FSDP only; decides\n  whether to enable gradient checkpointing for the actor.\n  Megatron uses recompute options in ``override_transformer_config`` to set this.\n- ``actor_rollout_ref.model.enable_activation_offload``: Whether to enable\n  activation offloading for the actor.\n- ``actor_rollout_ref.model.trust_remote_code``: Whether to enable loading\n  a remote code model.\n- ``actor_rollout_ref.model.use_fused_kernels``: Whether to use fused\n  kernels in the model. If set to True, the following parameters will be\n  used.\n\n  - ``actor_rollout_ref.model.fused_kernel_options.impl_backend``: The\n    implementation backend for fused kernels. Options: \"triton\" or\n    \"torch\". Default is \"torch\".\n    In Megatron, we only support \"triton\" as the\n    implementation backend, so this option is not needed there.\n\n- ``actor_rollout_ref.model.use_remove_padding``: Whether to use remove\n  padding in the model. If set to True, the model will remove padding\n  tokens in the input_ids and response_ids. This significantly improves model running efficiency.\n\n**Actor model**\n\n- ``actor_rollout_ref.actor.strategy``: fsdp or megatron. In this\n  example, we use the fsdp backend.\n\n- ``actor_rollout_ref.actor.ppo_mini_batch_size``: The sampled data is split\n  into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO\n  updates. The ppo_mini_batch_size is a global size across all workers/GPUs.\n\n- ``actor_rollout_ref.actor.ppo_micro_batch_size``: [Will be deprecated, use ppo_micro_batch_size_per_gpu] \n  Similar to gradient accumulation, the micro_batch_size_per_gpu for one forward pass,\n  trading speed for GPU memory. The value represents the global view.\n\n- ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``: Similar to gradient\n  accumulation, the micro_batch_size_per_gpu for one forward pass, trading speed\n  for GPU memory. The value represents the local size per GPU.\n\n- ``actor_rollout_ref.actor.grad_clip``: Gradient clipping for actor\n  updates.\n- ``actor_rollout_ref.actor.use_kl_loss``: Whether to use KL loss in the actor. When used, KL is not applied in the reward function.\n\n- ``actor_rollout_ref.actor.clip_ratio``: PPO clip ratio.\n\n- ``actor_rollout_ref.actor.use_torch_compile``: Whether to use torch compile in the actor.\n\n- ``actor_rollout_ref.actor.entropy_coeff``: The weight of entropy when\n  calculating PPO loss. 
The default value was changed to 0.0 in v0.3.x.\n\n- ``actor_rollout_ref.actor.ppo_epochs``: Number of epochs for PPO\n  updates on one set of sampled data.\n\n- ``actor_rollout_ref.actor.data_loader_seed``: Since torch 2.6.0, the Megatron backend can get a wrong seed generated by PyTorch \n  across cp ranks, causing data misalignment between these ranks, so we manually set the seed to avoid hanging\n  issues. If ``actor_rollout_ref.actor.shuffle`` is not null, this must be set.\n\n- ``actor_rollout_ref.actor.shuffle``: Whether to shuffle data when\n  there are multiple epochs.\n\n- ``actor_rollout_ref.actor.optim``: Actor's optimizer parameters.\n\n- ``actor_rollout_ref.actor.fsdp_config``: FSDP config for actor\n  training.\n\n  - ``wrap_policy``: FSDP wrap policy. By default, it uses Huggingface's\n    wrap policy, i.e., wrapping by DecoderLayer.\n\n    - There is no need to set transformer_layer_cls_to_wrap, so we comment it out.\n\n  - ``*_offload``: Whether to enable parameter, gradient and optimizer\n    offload.\n\n    - Trading speed for GPU memory.\n\n- ``actor_rollout_ref.actor.use_kl_loss``: Whether to enable KL loss. Default is False.\n\n- ``actor_rollout_ref.actor.kl_loss_coef``: The coefficient of KL loss. Default is 0.001. \n\n- ``actor_rollout_ref.actor.kl_loss_type``: How to calculate the KL divergence between the actor and the reference policy. Supports ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. Appending ``+`` at the end (e.g., ``k1+`` and ``k3+``) uses straight-through to employ ``k2`` for unbiased gradient estimation, regardless of the KL value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html. A schematic sketch of the basic estimators is shown below.\n
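\nThe following is an illustrative paraphrase of the ``k1``/``k2``/``k3`` estimators discussed in that blog post, written against per-token log-probs. It is a sketch for intuition only, not verl's actual ``kl_penalty()`` implementation:\n\n.. code:: python\n\n   import torch\n\n   def kl_penalty_sketch(logprob: torch.Tensor, ref_logprob: torch.Tensor, kl_type: str) -> torch.Tensor:\n       # logr = log pi_ref(x) - log pi(x), evaluated at samples x ~ pi\n       logr = ref_logprob - logprob\n       if kl_type == \"kl\":           # k1: unbiased, but high variance\n           return -logr\n       if kl_type == \"mse\":          # k2: 0.5 * (log r)^2, low variance but biased\n           return 0.5 * logr.pow(2)\n       if kl_type == \"low_var_kl\":   # k3: exp(logr) - logr - 1, unbiased with low variance\n           return logr.exp() - logr - 1.0\n       raise NotImplementedError(kl_type)\n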
\n- ``actor_rollout_ref.actor.checkpoint``: The configurations of the checkpoint function in the actor.\n\n  - ``save_contents``: The contents to save in the checkpoint. By default, we save the model, optimizer and extra information in the checkpoint.\n    The extra information currently includes RNG states; lr_scheduler is supported for FSDP, and Megatron opt_param_scheduler support is coming soon.\n    We do not store hf_model in the checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert the checkpoint format to hf format.\n\n  - ``load_contents``: The contents to load from the checkpoint; you can specify different checkpoint loading contents. By default, it is the same as ``save_contents``.\n\n**Reference Model**\n\nThe reference model will be enabled when ``actor.use_kl_loss`` and/or ``algorithm.use_kl_in_reward`` is True.\n\n- ``actor_rollout_ref.ref``: FSDP config same as actor. **For models\n  larger than 7B, it's recommended to turn on offload for ref by\n  default**\n\n- ``actor_rollout_ref.ref.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n  The batch size for one forward pass in the computation of ``ref_log_prob``. The value represents the global size.\n\n- ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``: The batch size\n  for one forward pass in the computation of ``ref_log_prob``. The value represents the local size per GPU.\n\n**Rollout Model**\n\n- ``actor_rollout_ref.rollout.name``: hf/vllm/sglang.\n\n- Rollout (Auto-regressive) parameters. The key should be equal to the\n  property name in vLLM's ``SamplingParams``.\n\n  - ``temperature``, ``top_k``, ``top_p`` and others: Sampling\n    parameters in ``SamplingParams``.\n\n- ``actor_rollout_ref.rollout.dtype``: Rollout model parameter type. This should align with\n  the actor model parameter type in the FSDP/Megatron backend.\n\n- ``actor_rollout_ref.rollout.gpu_memory_utilization``:\n\n  - For vLLM v0.7.0 and later: The fraction of **total** GPU memory to be used for the vLLM instance.\n  - For SGLang: Corresponding to ``mem_fraction_static``, the fraction of the free GPU memory used for **static** memory like model weights and KV cache. \n\n- ``actor_rollout_ref.rollout.tensor_model_parallel_size``: TP size for rollout. Only effective\n  for vllm.\n\n- ``actor_rollout_ref.rollout.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n  The batch size for one forward pass in the computation of ``log_prob``. The value represents the global size.\n\n- ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``: Micro batch size per GPU (the batch size for\n  one forward pass) for recalculating ``log_prob``. The value represents the local size per GPU.\n\n- ``actor_rollout_ref.rollout.do_sample``: Whether to sample during training rollout. If set to False, the rollout model\n  will perform greedy sampling.\n\n- ``actor_rollout_ref.rollout.val_kwargs``: Sampling parameters used specifically during validation.\n\n  - ``top_k``: Top-k sampling parameter. Default to -1 for vLLM rollout or 0 for HF rollout.\n  - ``top_p``: Top-p sampling parameter. Default is 1.0 (disabled).\n  - ``temperature``: Sampling temperature. Default is 0 (deterministic greedy).\n  - ``n``: Number of responses to generate during validation. Default is 1.\n  - ``do_sample``: Whether to use sampling during validation. Default is False for\n    deterministic outputs. When set to True, the rollout will use the ``actor_rollout_ref.rollout.val_kwargs`` parameters\n    (top_k, top_p, temperature) to control the sampling behavior.\n\n- ``actor_rollout_ref.rollout.engine_kwargs.vllm``: Extra vLLM engine args; please refer to the vLLM official docs for details.\n\n- ``actor_rollout_ref.rollout.engine_kwargs.sglang``: Extra SGLang engine args; please refer to the SGLang official docs for details.\n\n- ``actor_rollout_ref.rollout.ignore_eos``: Whether to ignore the EOS\n  token and continue generating tokens after the EOS token is generated.\n\n- ``actor_rollout_ref.rollout.free_cache_engine``: Offload the KVCache\n  after the rollout generation stage. Default is True. When set to True,\n  for vllm v0.5.4 and v0.6.3, we need to disable the usage of CUDAGraph\n  (set ``enforce_eager`` to True).\n\n- ``actor_rollout_ref.rollout.enforce_eager``: Whether to enforce eager execution\n  in vLLM generation (i.e., disable CUDAGraph). Default is True, which disables CUDAGraph.\n\n- ``actor_rollout_ref.rollout.load_format``: Which weight loader to use\n  to load the actor model weights into the rollout model.\n\n  - ``auto``: Use Megatron weight loader.\n  - ``megatron``: Use Megatron weight loader. Deployed with Megatron\n    backend. The input model ``state_dict()`` is already partitioned\n    along the TP dimension and already gathered along the PP dimension. This\n    weight loader requires that the Rollout model and Actor model's\n    parameter shapes and names be identical.\n  - ``dtensor``: Default solution when using the Huggingface weight loader.\n    Deployed with FSDP backend and the state_dict_type is\n    ``StateDictType.SHARDED_STATE_DICT``. 
This weight loader is\n    recommended.\n  - ``hf``: Use Huggingface weight loader. Deployed with FSDP backend\n    and the state_dict_type is ``StateDictType.FULL_STATE_DICT``. This\n    solution doesn't need to rewrite the weight loader for each model\n    implemented in vLLM, but it results in larger peak memory usage.\n  - ``dummy_hf``, ``dummy_megatron``, ``dummy_dtensor``: Random\n    initialization.\n\n.. note:: In this config field, users only need to select from ``dummy_megatron``, ``dummy_dtensor``, ``dummy_hf`` for rollout initialization, and our hybrid engine will select the corresponding weight loader (i.e., ``megatron``, ``dtensor``, ``hf``) during actor/rollout weight synchronization.\n\n\nMegatron Optimizer and Optimizer Parameter Scheduler\n____________________________________________________\n\n.. code:: yaml\n\n    optim:\n      optimizer: adam\n      lr: 1e-6\n      clip_grad: 1.0\n      total_training_steps: -1  # must be overridden by the program\n      lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      lr_decay_steps: null\n      lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root\n      min_lr: 0.0 # minimum learning rate, default to 0.0\n      weight_decay: 0.01\n      weight_decay_incr_style: constant # select from constant/linear/cosine\n      lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n\n\nNotice that there are some API differences between the Megatron optimizer and the FSDP optimizer.\n\n- The Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so ``lr_decay_style`` actually means the style of lr decay after warmup.\n- The Megatron optimizer also supports scheduling the weight decay (see ``weight_decay_incr_style``).\n- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint when resuming training.\n\nFor learning rate decay, the original Megatron pretrain default for ``lr_decay_style`` is ``linear``,\nmeaning that the learning rate will be linearly decayed from the initial learning rate to ``min_lr`` within\n``lr_decay_steps``. However, in verl, to align with FSDP's default behavior, we set the default\n``lr_decay_style`` to ``constant``, meaning that the learning rate will be kept constant after the warmup stage.\n\n\nCritic Model\n~~~~~~~~~~~~\n\nMost parameters for the Critic are similar to the Actor Model.\n\nReward Model\n~~~~~~~~~~~~\n\n.. code:: yaml\n\n   reward_model:\n     enable: False\n     model:\n       input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n       path: ~/models/Anomy-RM-v0.1\n       external_lib: ${actor_rollout_ref.model.external_lib}\n       trust_remote_code: False\n       fsdp_config:\n         min_num_params: 0\n         param_offload: False\n     micro_batch_size_per_gpu: 16\n     max_length: null\n     reward_manager: naive\n\n- ``reward_model.enable``: Whether to enable the reward model. If False, we\n  compute the reward only with the user-defined reward functions. In\n  the GSM8K and MATH examples, we disable the reward model. 
For the RLHF alignment\n  example using full_hh_rlhf, we utilize a reward model to assess the\n  responses. If False, the following parameters are not effective.\n- ``reward_model.model``\n\n  - ``input_tokenizer``: Input tokenizer. If the reward model's chat\n    template is inconsistent with the policy, we need to first decode to\n    plaintext, then apply the RM's chat_template, and then score with the RM. If\n    the chat_templates are consistent, it can be set to null.\n  - ``path``: RM's HDFS path or local path. Note that RM only supports\n    AutoModelForSequenceClassification. Other model types need to define\n    their own RewardModelWorker and pass it from the code.\n  - ``trust_remote_code``: Whether to enable loading a remote code model,\n    default to False.\n- ``reward_model.reward_manager``: Reward Manager. This defines the mechanism\n  of computing rule-based reward and handling different reward sources. Default\n  is ``naive``. If all verification functions are multiprocessing-safe, the reward\n  manager can be set to ``prime`` for parallel verification.\n\nCustomized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   custom_reward_function:\n     path: null\n     name: compute_score\n\n- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.\n- ``custom_reward_function.name`` (optional): The name of the reward function within the specified file. Default is 'compute_score'.\n\nAlgorithm\n~~~~~~~~~\n\n.. code:: yaml\n\n   algorithm:\n     gamma: 1.0\n     lam: 1.0\n     adv_estimator: gae\n     use_kl_in_reward: False\n     kl_penalty: kl  # how to estimate kl divergence\n     kl_ctrl:\n       type: fixed\n       kl_coef: 0.005\n       horizon: 10000\n       target_kl: 0.1\n     # Rollout Importance Sampling\n     rollout_is: False\n     rollout_is_threshold: null\n     rollout_is_threshold_lower: null\n     rollout_is_level: token\n     rollout_is_mode: truncate\n     rollout_is_veto_threshold: null  # Disabled by default\n\n- ``gamma``: Discount factor.\n- ``lam``: Trade-off between bias and variance in the GAE estimator.\n- ``adv_estimator``: Supports ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``, ``grpo_vectorized``.\n- ``use_kl_in_reward``: Whether to enable the in-reward KL penalty. Default is False.\n- ``kl_penalty``: Supports ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to\n  calculate the KL divergence between the actor and the reference policy. For\n  specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ .\n- ``kl_ctrl``: Config for the in-reward kl_penalty controller.\n\n  - ``kl_coef``: The (initial) coefficient of the in-reward kl_penalty. Default is 0.001.\n  - ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n  - ``horizon`` and ``target_kl``: See the source code of AdaptiveKLController for details.\n\n- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False.\n- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.\n- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to the reciprocal of the upper threshold (1/upper).\n- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).\n- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).\n- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is null (disabled).\n  Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``; a minimal configuration sketch is shown below.\n
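\nAs a minimal sketch, the overrides below turn on rollout IS correction end-to-end (the threshold value is illustrative, not a recommendation):\n\n.. code:: yaml\n\n   algorithm:\n     rollout_is: True\n     rollout_is_threshold: 2.0  # illustrative upper bound; the lower bound defaults to 1/upper\n     rollout_is_level: token\n     rollout_is_mode: truncate\n   actor_rollout_ref:\n     rollout:\n       calculate_log_probs: True  # required for rollout IS\n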
\nTrainer\n~~~~~~~\n\n.. code:: yaml\n\n   trainer:\n     total_epochs: 30\n     project_name: verl_examples\n     experiment_name: gsm8k\n     logger: ['console', 'wandb']\n     log_val_generations: 0\n     nnodes: 1\n     n_gpus_per_node: 8\n     save_freq: -1\n     val_before_train: True\n     test_freq: 2\n     critic_warmup: 0\n     default_hdfs_dir: null # hdfs checkpoint path\n     default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path\n     resume_mode: auto # or disable or resume_path if resume_from_path is set\n     resume_from_path: null\n     remove_previous_ckpt_in_save: False\n     del_local_ckpt_after_load: False\n     ray_wait_register_center_timeout: 300\n\n- ``trainer.total_epochs``: Number of epochs in training.\n- ``trainer.project_name``: For wandb, swanlab, mlflow.\n- ``trainer.experiment_name``: For wandb, swanlab, mlflow.\n- ``trainer.logger``: Supports console, wandb, swanlab, mlflow, tensorboard and trackio.\n- ``trainer.log_val_generations``: The number of logged generations during validation (default ``0``).\n- ``trainer.nnodes``: Number of nodes used in the training.\n- ``trainer.n_gpus_per_node``: Number of GPUs per node.\n- ``trainer.save_freq``: The frequency (by iteration) to save checkpoints\n  of the actor and critic models.\n- ``trainer.val_before_train``: Whether to run validation before training.\n- ``trainer.test_freq``: The validation frequency (by iteration).\n- ``trainer.critic_warmup``: The number of iterations to train the critic\n  model before actual policy learning.\n- ``trainer.resume_mode``: The mode of resuming training. Supports\n  ``disable``, ``auto`` and ``resume_path``. If set to ``auto`` (the default), the\n  program will automatically resume from the latest checkpoint in\n  ``default_local_dir``. If set to ``resume_path``, the program will resume\n  from the path specified in ``resume_from_path``.\n- ``trainer.resume_from_path``: The path to resume training from. Only\n  effective when ``resume_mode`` is set to ``resume_path``.\n- ``trainer.remove_previous_ckpt_in_save``: Whether to remove previous\n  checkpoints in the save directory. Default is False.\n- ``trainer.del_local_ckpt_after_load``: Whether to delete local\n  checkpoints after loading them. Default is False.\n- ``trainer.ray_wait_register_center_timeout``: The timeout for waiting\n  for the ray register center to be ready. Default is 300 seconds.\n\n\nThe figure below illustrates how these configurations affect training.\n\nhttps://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA\n\n.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d\n\n\nevaluation.yaml\n---------------\n\nData\n~~~~\n\n
.. code:: yaml\n\n   data:\n     path: /tmp/math_Qwen2-7B-Instruct.parquet\n     prompt_key: prompt\n     response_key: responses\n     data_source_key: data_source\n     reward_model_key: reward_model\n\n- ``data.path``: Path to the dataset file (Parquet format).\n- ``data.prompt_key``: The field in the dataset where the prompt is located. Default is 'prompt'.\n- ``data.response_key``: The key that holds the generated responses. This should be a list of strings representing the responses. Default is 'responses'.\n- ``data.data_source_key``: This is used to separate metric calculations for different data sources, ensuring that metrics are calculated independently for each source.\n- ``data.reward_model_key``: The key that holds the reference answers. These reference answers typically serve as the ground truth or test cases for the task.\n\nCustomized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   custom_reward_function:\n     path: null\n     name: compute_score\n\n- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.\n- ``custom_reward_function.name`` (optional): The name of the reward function within the specified file. Default is 'compute_score'.\n\nsft_trainer.yaml for SFT FSDP Backend\n--------------------------------------\n\n\nOptim\n~~~~~~~\n\n.. code:: yaml\n\n   optim:\n     optimizer: AdamW\n     optimizer_impl: torch.optim\n     lr: 1e-5\n     weight_decay: 0.01\n     lr_warmup_steps_ratio: 0.1\n     clip_grad: 1.0\n     lr_scheduler: cosine\n     override_optimizer_config: null\n\n- ``optimizer``: Optimizer class name (e.g., ``\"AdamW\"``, ``\"AdamW8bit\"``, ``\"_AdamW\"``). The class name as it appears in the module.\n- ``optimizer_impl``: Module path to import the optimizer from (e.g., ``\"torch.optim\"``, ``\"torchao.optim\"``, ``\"bitsandbytes.optim\"``).\n- ``optim.lr``: Learning rate for the optimizer.\n- ``optim.weight_decay``: Weight decay for the optimizer.\n- ``optim.lr_warmup_steps_ratio``: Ratio of warmup steps to total training steps.\n- ``optim.clip_grad``: Gradient clipping value.\n- ``optim.lr_scheduler``: Learning rate scheduler type. Options:\n\n  - ``cosine``: Cosine learning rate scheduler with warmup (default).\n  - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases.\n\n- ``override_optimizer_config``: Dictionary of additional optimizer-specific keyword arguments. For example, to use ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding: ``{\"bf16_stochastic_round\": true}``. A fuller sketch is shown below.\n
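\nPutting these fields together, a minimal sketch of an optim block that swaps in ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding (values are illustrative):\n\n.. code:: yaml\n\n   optim:\n     optimizer: _AdamW\n     optimizer_impl: torchao.optim\n     lr: 1e-5\n     lr_scheduler: cosine\n     override_optimizer_config:\n       bf16_stochastic_round: true\n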
\nModel\n~~~~~~~~~~~~\n\nMost parameters for Model are similar to the Reward Model.\n\n.. code:: yaml\n\n   model:\n     partial_pretrain: ~/models/gemma-1.1-7b-it\n     fsdp_config:\n       model_dtype: fp32\n       wrap_policy:\n         min_num_params: 0\n       cpu_offload: False\n       offload_params: False\n     external_lib: null\n     enable_gradient_checkpointing: False\n     trust_remote_code: False\n     lora_rank: 0\n     lora_alpha: 16\n     target_modules: all-linear\n     use_liger: False\n\n- ``partial_pretrain``: HDFS path or local path for the pretrained model.\n- ``fsdp_config``\n\n  - ``model_dtype``: Model parameter type, default to ``fp32``.\n    Supported: ``bf16``, ``fp16``, ``fp32``.\n  - ``cpu_offload``: Whether to enable CPU offloading for FSDP. If True,\n    ``offload_params`` is used as an argument.\n  - ``offload_params``: Whether to offload parameters to CPU\n    when not involved in computation. If True, this offloads gradients\n    to CPU as well, meaning that the optimizer step runs on CPU.\n\n- ``lora_rank``: The rank of the LoRA model, default to 0. If ``lora_rank`` > 0,\n  we will train LoRA modules instead of tuning the full model.\n- ``lora_alpha``: The alpha parameter for LoRA scaling, default to 16.\n- ``target_modules``: The names of the modules to apply the adapter to,\n  default to ``all-linear``. See the `peft docs <https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.target_modules>`_ for details.\n\n- ``use_liger``: Whether to enable the Liger kernel, default to False. If True,\n  we apply the Liger kernel to the model (depends on `liger-kernel`).\n"
  },
  {
    "path": "verl_distillation/docs/examples/gsm8k_example.rst",
    "content": "GSM8K Example\n=============\n\nLast updated: 03/25/2025.\n\nIntroduction\n------------\n\nIn this example, we train an LLM to tackle the GSM8k task.\n\nPaper: https://arxiv.org/pdf/2110.14168\n\nDataset: https://huggingface.co/datasets/gsm8k\n\nNote that the original paper mainly focuses on training a verifier (a\nreward model) to solve math problems via Best-of-N sampling. In this\nexample, we train an RLHF agent using a rule-based reward model.\n\nDataset Introduction\n--------------------\n\nGSM8k is a math problem dataset. The prompt is an elementary school\nproblem. The LLM model is required to answer the math problem.\n\nThe training set contains 7473 samples and the test set contains 1319\nsamples.\n\n**An example**\n\nPrompt\n\n   Katy makes coffee using teaspoons of sugar and cups of water in the\n   ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups\n   of water, calculate the number of teaspoonfuls of sugar she used.\n\nSolution\n\n   The total ratio representing the ingredients she used to make the\n   coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the\n   number of teaspoons she used is 7/20, she used 7/20\\ *120 =\n   <<7/20*\\ 120=42>>42 #### 42\n\nStep 1: Prepare dataset\n-----------------------\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 gsm8k.py --local_save_dir ~/data/gsm8k\n\nStep 2: Download Model\n----------------------\n\nThere're three ways to prepare the model checkpoints for post-training:\n\n- Download the required models from huggingface or modelscope\n\n.. code:: bash\n\n   huggingface-cli download deepseek-ai/deepseek-math-7b-instruct --local-dir ~/models/deepseek-math-7b-instruct --local-dir-use-symlinks False\n   # or\n   modelscope download --model deepseek-ai/deepseek-math-7b-instruct --local_dir ~/models/deepseek-math-7b-instruct\n\n- Already store your store model in the local directory or HDFS path.\n- Also, you can directly use the model name in huggingface (e.g.,\n  deepseek-ai/deepseek-math-7b-instruct) in\n  ``actor_rollout_ref.model.path`` and ``critic.model.path`` field in\n  the run script. You can also download models from modelscope by setting environmental variable ``VERL_USE_MODELSCOPE=True``.\n  See examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh for example.\n\nNoted that users should prepare checkpoints for actor, critic and reward\nmodel.\n\n[Optional] Step 3: SFT your Model\n---------------------------------\n\nWe provide a SFT Trainer using PyTorch FSDP in\n`fsdp_sft_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/fsdp_sft_trainer.py>`_. \nUsers can customize their own SFT\nscript using our FSDP SFT Trainer.\n\nWe also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft directory <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k/>`_.\n\n.. 
.. code:: shell\n\n   set -x\n\n   torchrun -m verl.trainer.fsdp_sft_trainer \\\n       data.train_files=$HOME/data/gsm8k/train.parquet \\\n       data.val_files=$HOME/data/gsm8k/test.parquet \\\n       data.prompt_key=question \\\n       data.response_key=answer \\\n       data.micro_batch_size_per_gpu=8 \\\n       model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \\\n       trainer.project_name=gsm8k-sft \\\n       trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \\\n       trainer.total_epochs=4 \\\n       trainer.logger='[\"console\",\"wandb\"]'\n\n\nIf you use AMD GPUs (ROCm kernel), you need to add the following environment variables into the run script:\n\n    .. code-block:: bash\n\n        export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n        export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n        export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\n\nStep 4: Perform PPO training with your model on GSM8K Dataset\n-------------------------------------------------------------\n\n- Prepare your own run.sh script. Here's an example for the GSM8k dataset\n  and the deepseek-llm-7b-chat model.\n- Users could replace ``data.train_files``, ``data.val_files``,\n  ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on\n  their environment.\n- See :doc:`config` for a detailed explanation of each config field.\n\n**Reward Model/Function**\n\nWe use a rule-based reward model. We force the model to produce a final\nanswer following 4 “#” as shown in the solution. We extract the final\nanswer from both the solution and the model's output using regular\nexpression matching. We compare them and assign a reward of 1 for a correct\nanswer, 0.1 for an incorrect answer, and 0 when no answer is produced. An illustrative sketch of such a scorer is shown below.\n
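\nThe actual scoring function lives in `verl/utils/reward_score/gsm8k.py <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_; the sketch below is only an illustrative approximation of the rule described above (the regex and the function signature are simplified):\n\n.. code:: python\n\n   import re\n\n   def compute_score_sketch(solution_str: str, ground_truth: str) -> float:\n       # extract the final answer that follows \"####\", as in GSM8K solutions\n       match = re.search(r\"####\\s*(-?[0-9][0-9.,]*)\", solution_str)\n       if match is None:\n           return 0.0  # no parsable answer\n       answer = match.group(1).replace(\",\", \"\").rstrip(\".\")\n       return 1.0 if answer == ground_truth.strip() else 0.1\n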
\n**Training Script**\n\nThe training script examples for the FSDP and Megatron-LM backends are stored in the examples/ppo_trainer directory.\n\n.. code:: bash\n\n   cd ../ppo_trainer\n   bash run_deepseek7b_llm.sh\n\nThe script of run_deepseek7b_llm.sh:\n\n.. code:: bash\n\n   set -x\n\n   python3 -m verl.trainer.main_ppo \\\n      data.train_files=$HOME/data/gsm8k/train.parquet \\\n      data.val_files=$HOME/data/gsm8k/test.parquet \\\n      data.train_batch_size=1024 \\\n      data.max_prompt_length=512 \\\n      data.max_response_length=512 \\\n      actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n      actor_rollout_ref.actor.optim.lr=1e-6 \\\n      actor_rollout_ref.model.use_remove_padding=True \\\n      actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n      actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n      actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n      actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n      actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n      actor_rollout_ref.rollout.name=vllm \\\n      actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n      actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n      critic.optim.lr=1e-5 \\\n      critic.model.use_remove_padding=True \\\n      critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n      critic.model.enable_gradient_checkpointing=True \\\n      critic.ppo_micro_batch_size_per_gpu=32 \\\n      critic.model.fsdp_config.param_offload=False \\\n      critic.model.fsdp_config.optimizer_offload=False \\\n      algorithm.kl_ctrl.kl_coef=0.001 \\\n      trainer.critic_warmup=0 \\\n      trainer.logger='[\"console\",\"wandb\"]' \\\n      trainer.project_name='verl_example_gsm8k' \\\n      trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n      trainer.n_gpus_per_node=8 \\\n      trainer.nnodes=1 \\\n      trainer.save_freq=-1 \\\n      trainer.test_freq=1 \\\n      trainer.total_epochs=15 $@\n\n\nIf you use AMD GPUs (ROCm kernel), you need to add the following environment variables into the run script:\n\n    .. code-block:: bash\n\n        export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n        export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n        export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\nIf you encounter any issues using AMD GPUs to run verl, feel free to contact me - `Yusheng Su <https://yushengsu-thu.github.io/>`_."
  },
  {
    "path": "verl_distillation/docs/examples/multi_modal_example.rst",
    "content": "Multi-Modal Example Architecture\n=================================\n\nLast updated: 04/28/2025.\n\nIntroduction\n------------\n\nNow, verl has supported multi-modal training. You can use fsdp and \nvllm/sglang to start a multi-modal RL task. Megatron supports is also \non the way.\n\nFollow the steps below to quickly start a multi-modal RL task.\n\nStep 1: Prepare dataset\n-----------------------\n\n.. code:: python\n\n    # it will be saved in the $HOME/data/geo3k folder\n    python examples/data_preprocess/geo3k.py\n\nStep 2: Download Model\n----------------------\n\n.. code:: bash\n\n    # download the model from huggingface\n    python3 -c \"import transformers; transformers.pipeline(model='Qwen/Qwen2.5-VL-7B-Instruct')\"\n\nStep 3: Perform GRPO training with multi-modal model on Geo3K Dataset\n---------------------------------------------------------------------\n\n.. code:: bash\n\n    # run the task\n    bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "verl_distillation/docs/examples/ppo_code_architecture.rst",
    "content": "PPO Example Architecture\n========================\n\nLast updated: 02/17/2025.\n\nLet's start with the Proximal Policy Optimization algorithm, which is\nmost widely used algorithm in LLM post-training.\n\nThe main entry point of the PPO algorithm example is:\n`main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py>`_.\nIn this tutorial, we will go through the code architecture in `main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py>`_.\n\nDefine the data\n---------------\n\nUsers need to preprocess and store the dataset in parquet files.\nAnd we implement `RLHFDataset` to load and tokenize the parquet files.\n\nFor ``RLHFDataset`` (Default), at least 1 fields are required:\n\n- ``prompt``: Contains the string prompt\n\nWe already provide some examples of processing the datasets to parquet\nfiles in `data_preprocess directory <https://github.com/volcengine/verl/blob/main/examples/data_preprocess>`_. Currently, we support\npreprocess of GSM8k, MATH, Hellasage, Full_hh_rlhf datasets. See :doc:`../preparation/prepare_data` for\nmore information.\n\nDefine the reward functions for different datasets\n--------------------------------------------------\n\nIn this main entry point, the users only need to define their own reward\nfunction based on the datasets (or applications) utilized in PPO\ntraining.\n\nFor example, we already provide reward functions for `GSM8k <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_ \nand `MATH <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_\ndatasets in the ``_select_rm_score_fn``. In the ``RewardManager``, we\nwill compute the reward score based on the data_source to select\ncorresponding reward functions. For some RLHF datasets (e.g.,\nfull_hh_rlhf), the reward model is utilized to assess the responses\nwithout any reward functions. In this case, the ``RewardManager`` will\nreturn the ``rm_score`` computed by the reward model directly.\n\nSee `reward functions <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_ for detailed implementation.\n\nDefine worker classes\n---------------------\n\n.. 
.. code:: python\n\n   if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}: # for FSDP backend\n       assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n       from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n       from verl.single_controller.ray import RayWorkerGroup\n       ray_worker_group_cls = RayWorkerGroup\n\n   elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend\n       assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n       from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n       from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n       ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM\n\n   else:\n       raise NotImplementedError\n\n   from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n   role_worker_mapping = {\n       Role.ActorRollout: ActorRolloutRefWorker,\n       Role.Critic: CriticWorker,\n       Role.RefPolicy: ActorRolloutRefWorker\n   }\n\n   global_pool_id = 'global_pool'\n   resource_pool_spec = {\n       global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n   }\n   mapping = {\n       Role.ActorRollout: global_pool_id,\n       Role.Critic: global_pool_id,\n       Role.RefPolicy: global_pool_id,\n   }\n\nStep 1: Construct the mapping between roles and workers\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nA role represents a group of workers in the same process. We have\npre-defined several roles in `ray_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py#L38>`_.\n\n.. code:: python\n\n   class Role(Enum):\n       \"\"\"\n       To create more roles dynamically, you can subclass Role and add new members\n       \"\"\"\n       Actor = 0  # This worker only has Actor\n       Rollout = 1 # This worker only has Rollout\n       ActorRollout = 2 # This worker has both actor and rollout, it's a HybridEngine\n       Critic = 3 # This worker only has critic\n       RefPolicy = 4 # This worker only has reference policy\n       RewardModel = 5 # This worker only has reward model\n       ActorRolloutRef = 6 # This worker contains actor, rollout and reference policy simultaneously \n\nStep 2: Define the worker class corresponding to this role\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n- We have pre-implemented the ``ActorRolloutRefWorker``. Through\n  different configs, it can be a standalone actor, a standalone rollout,\n  an ActorRollout HybridEngine, or an ActorRolloutRef HybridEngine.\n- We also pre-implemented workers for ``Actor``, ``Rollout``,\n  ``Critic``, ``Reward Model`` and ``Reference model`` on two different\n  backends: PyTorch FSDP\n  and Megatron-LM.\n  See `FSDP Workers <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ \n  and `Megatron-LM Workers <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py>`_\n  for more information.\n\nStep 3: Define resource pool id and resource pool spec\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n- A resource pool is a division of the global GPU resources;\n  ``resource_pool_spec`` is a dict, mapping from id to # of GPUs.\n\n  - In the above example, we defined a global resource pool:\n    global_pool_id, and then put all roles on this one resource pool\n    with all the GPUs in this post-training task. This refers to\n    *co-located* placement where all the models share the same set of\n    GPUs.\n\n- See resource pool and placement for advanced usage. A hypothetical multi-pool sketch is shown below.\n
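\nFor instance, a hedged sketch of a disaggregated placement that gives the actor/rollout and the critic separate pools (pool names and sizes are hypothetical):\n\n.. code:: python\n\n   # Role imported from verl.trainer.ppo.ray_trainer as above.\n   # Two disjoint resource pools instead of a single global pool.\n   resource_pool_spec = {\n       \"actor_pool\": [8] * 2,   # 2 nodes x 8 GPUs for actor/rollout/ref\n       \"critic_pool\": [8] * 1,  # 1 node x 8 GPUs for the critic\n   }\n   mapping = {\n       Role.ActorRollout: \"actor_pool\",\n       Role.RefPolicy: \"actor_pool\",\n       Role.Critic: \"critic_pool\",\n   }\n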
\nDefining reward model/function\n------------------------------\n\n.. code:: python\n\n   # we should adopt a multi-source reward function here\n   # - for rule-based rm, we directly call a reward score\n   # - for model-based rm, we call a model\n   # - for code related prompt, we send to a sandbox if there are test cases\n   # - finally, we combine all the rewards together\n   # - The reward type depends on the tag of the data\n   if config.reward_model.enable:\n       from verl.workers.fsdp_workers import RewardModelWorker\n       role_worker_mapping[Role.RewardModel] = RewardModelWorker\n       mapping[Role.RewardModel] = global_pool_id\n    \n   reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)\n\n   # Note that we always use function-based RM for validation\n   val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)\n\n   resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\nSince not all tasks use a model-based RM, users need to define here\nwhether it's a model-based RM or a function-based RM.\n\n- If it's a model-based RM, directly add the ``RewardModel`` role in the\n  resource mapping and add it to the resource pool mapping.\n\n  - Note that the pre-defined ``RewardModelWorker`` only supports models\n    with the structure of huggingface\n    ``AutoModelForSequenceClassification``. If your model is not of this type, you\n    need to define your own RewardModelWorker in `FSDP Workers <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ \n    and `Megatron-LM Workers <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py>`_.\n\n- If it's a function-based RM, users are required to specify the\n  reward function for each dataset.\n\n.. code:: python\n\n   def _select_rm_score_fn(data_source):\n       if data_source == 'openai/gsm8k':\n           return gsm8k.compute_score\n       elif data_source == 'lighteval/MATH':\n           return math.compute_score\n       else:\n           raise NotImplementedError\n\nSee the reward functions implemented in this `directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/>`_ \nfor more information.\n\nDefine, init and run the PPO Trainer\n------------------------------------\n\n.. code:: python\n\n   trainer = RayPPOTrainer(config=config,\n                           tokenizer=tokenizer,\n                           role_worker_mapping=role_worker_mapping,\n                           resource_pool_manager=resource_pool_manager,\n                           ray_worker_group_cls=ray_worker_group_cls,\n                           reward_fn=reward_fn,\n                           val_reward_fn=val_reward_fn)\n   trainer.init_workers()\n   trainer.fit()\n\n- We first initialize the ``RayPPOTrainer`` with the user config, tokenizer\n  and all the above worker mappings, resource pool, worker group and\n  reward functions.\n- We then call ``trainer.init_workers()`` to initialize the models\n  on the allocated GPUs (in the resource pool).\n- The actual PPO training will be executed in ``trainer.fit()``.\n\nverl can be easily extended to other RL algorithms by reusing the Ray\nmodel workers, resource pool and reward functions. 
See :doc:`extension<../advance/dpo_extension>` for\nmore information.\n\nDetails of the ``RayPPOTrainer`` are discussed in :doc:`Ray Trainer<../workers/ray_trainer>`.\n"
  },
  {
    "path": "verl_distillation/docs/examples/sandbox_fusion_example.rst",
    "content": "Sandbox Fusion Example\n============================\n\nLast updated: 06/27/2025.\n\nIntroduction\n------------\n\nSandbox Fusion is a remote code sandbox service that provides a secure environment for running and evaluating code generated by Large Language Models (LLMs). This example demonstrates how to train an LLM and use Sandbox Fusion to verify generated code, enhancing both security and performance.\n\nBy leveraging a remote code sandbox service with greater CPU resources for concurrent code verification, you can reduce the reward stage time by 10-30%, depending on the quality of the generated code.\n\nStep 1: Prepare the Dataset\n---------------------------\n\nWe use the Eurus-2-RL-Data dataset for training. This dataset combines math and code questions, making it suitable for LLM training tasks. You can download it from HuggingFace: `Eurus-2-RL-Data Dataset <https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data>`_.\n\nStep 2: Set Up the Sandbox Fusion Service\n-----------------------------------------\n\nSandbox Fusion is a remote code sandbox service designed to securely run and evaluate LLM-generated code. To use it:\n\n1. **Access Full Documentation**: For detailed setup instructions, refer to the `Sandbox Fusion Documentation <https://bytedance.github.io/SandboxFusion/>`_.\n2. **Deploy the Service**: Choose one of the following deployment methods:\n\n   - **Local Deployment**: Follow the guide `here <https://bytedance.github.io/SandboxFusion/docs/docs/get-started#local-deployment>`_.\n   - **FaaS Instance (Volcengine)**: Create an instance using the `Volcengine Documentation <https://www.volcengine.com/docs/6662/1539235>`_.\n\nAfter deployment, you will receive an API endpoint in the format: ``https://<ip-address-or-domain-name>/run_code``.\n\nStep 3: Configure the Training Script\n-------------------------------------\n\nTo integrate Sandbox Fusion into your training script, configure the following parameters:\n\n**Key Settings for Sandbox Fusion**\n\n- ``reward_model.sandbox_fusion.url='<API-endpoint>'``: Enable Sandbox Fusion by specifying the API endpoint (must end with ``/run_code``).\n- ``reward_model.sandbox_fusion.max_concurrent=256``: Set the maximum number of concurrent API requests to the Sandbox Fusion service.\n- ``reward_model.sandbox_fusion.memory_limit_mb=1024``: Set the memory limit (in MB) for each sandbox instance. Defaults to 1024MB if not specified.\n\n**Additional Optimization**\n\nTo further reduce code verification time, enable parallel processing with:  \n\n- ``reward_model.reward_manager=prime``: The Prime reward manager verifies code across multiple subprocesses concurrently.\n\n**Example Script**\n\nFor a practical implementation, refer to the example script:  \n\n``examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh``\n\nOnce you’ve set your API endpoint in the script, you can start the training job."
  },
  {
    "path": "verl_distillation/docs/examples/skypilot_examples.rst",
    "content": "SkyPilot Examples\n=================\n\nLast updated: 09/04/2025.\n\nThis guide provides examples of running VERL reinforcement learning training on Kubernetes clusters or cloud platforms with GPU nodes using `SkyPilot <https://github.com/skypilot-org/skypilot>`_.\n\nInstallation and Configuration\n-------------------------------\n\nStep 1: Install SkyPilot\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nChoose the installation based on your target platform:\n\n.. code-block:: bash\n\n   # For Kubernetes only\n   pip install \"skypilot[kubernetes]\"\n   \n   # For AWS\n   pip install \"skypilot[aws]\"\n   \n   # For Google Cloud Platform\n   pip install \"skypilot[gcp]\"\n   \n   # For Azure\n   pip install \"skypilot[azure]\"\n   \n   # For multiple platforms\n   pip install \"skypilot[kubernetes,aws,gcp,azure]\"\n\nStep 2: Configure Your Platform\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSee https://docs.skypilot.co/en/latest/getting-started/installation.html\n\nStep 3: Set Up Environment Variables\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nExport necessary API keys for experiment tracking:\n\n.. code-block:: bash\n\n   # For Weights & Biases tracking\n   export WANDB_API_KEY=\"your-wandb-api-key\"\n   \n   # For HuggingFace gated models (if needed)\n   export HF_TOKEN=\"your-huggingface-token\"\n\nExamples\n--------\n\nAll example configurations are available in the `examples/skypilot/ <https://github.com/volcengine/verl/tree/main/examples/skypilot>`_ directory on GitHub. See the `README <https://github.com/volcengine/verl/blob/main/examples/skypilot/README.md>`_ for additional details.\n\nPPO Training\n~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky launch -c verl-ppo verl-ppo.yaml --secret WANDB_API_KEY -y\n\nRuns PPO training on GSM8K dataset using Qwen2.5-0.5B-Instruct model across 2 nodes with H100 GPUs. Based on examples in ``examples/ppo_trainer/``.\n\n`View verl-ppo.yaml on GitHub <https://github.com/volcengine/verl/blob/main/examples/skypilot/verl-ppo.yaml>`_\n\nGRPO Training\n~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky launch -c verl-grpo verl-grpo.yaml --secret WANDB_API_KEY -y\n\nRuns GRPO (Group Relative Policy Optimization) training on MATH dataset using Qwen2.5-7B-Instruct model. Memory-optimized configuration for 2 nodes. Based on examples in ``examples/grpo_trainer/``.\n\n`View verl-grpo.yaml on GitHub <https://github.com/volcengine/verl/blob/main/examples/skypilot/verl-grpo.yaml>`_\n\nMulti-turn Tool Usage Training\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky launch -c verl-multiturn verl-multiturn-tools.yaml \\\n     --secret WANDB_API_KEY --secret HF_TOKEN -y\n\nSingle-node training with 8xH100 GPUs for multi-turn tool usage with Qwen2.5-3B-Instruct. Includes tool and interaction configurations for GSM8K. 
Based on examples in ``examples/sglang_multiturn/`` but uses vLLM instead of sglang.\n\n`View verl-multiturn-tools.yaml on GitHub <https://github.com/volcengine/verl/blob/main/examples/skypilot/verl-multiturn-tools.yaml>`_\n\nConfiguration\n-------------\n\nThe example YAML files are pre-configured with:\n\n- **Infrastructure**: Kubernetes clusters (``infra: k8s``) - can be changed to ``infra: aws`` or ``infra: gcp``, etc.\n- **Docker Image**: VERL's official Docker image with CUDA 12.6 support\n- **Setup**: Automatically clones and installs VERL from source\n- **Datasets**: Downloads required datasets during setup phase\n- **Ray Cluster**: Configures distributed training across nodes\n- **Logging**: Supports Weights & Biases via ``--secret WANDB_API_KEY``\n- **Models**: Supports gated HuggingFace models via ``--secret HF_TOKEN``\n\nLaunch Command Options\n----------------------\n\n- ``-c <name>``: Cluster name for managing the job\n- ``--secret KEY``: Pass secrets for API keys (can be used multiple times)\n- ``-y``: Skip confirmation prompt\n\nMonitoring Your Jobs\n--------------------\n\nCheck Cluster Status\n~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky status\n\nView Logs\n~~~~~~~~~\n\n.. code-block:: bash\n\n   sky logs verl-ppo  # View logs for the PPO job\n\nSSH into Head Node\n~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   ssh verl-ppo\n\nAccess Ray Dashboard\n~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky status --endpoint 8265 verl-ppo  # Get dashboard URL\n\nStop a Cluster\n~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   sky down verl-ppo\n"
  },
  {
    "path": "verl_distillation/docs/faq/faq.rst",
    "content": "Frequently Asked Questions\n====================================\n\nLast updated: 09/24/2025.\n\nRay related\n------------\n\nHow to add breakpoint for debugging with distributed Ray?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nPlease checkout the official debugging guide from Ray: https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html\n\n\n\"Unable to register worker with raylet\"\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe cause of this issue is due to some system setting, e.g., SLURM added some constraints on how the CPUs are shared on a node. \nWhile `ray.init()` tries to launch as many worker processes as the number of CPU cores of the machine,\nsome constraints of SLURM restricts the `core-workers` seeing the `raylet` process, leading to the problem.\n\nTo fix this issue, you can set the config term ``ray_init.num_cpus`` to a number allowed by your system.\n\nDistributed training\n------------------------\n\nHow to run multi-node post-training with Ray?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nYou can start a ray cluster and submit a ray job, following the official guide from Ray: https://docs.ray.io/en/latest/ray-core/starting-ray.html\n\nThen in the configuration, set the ``trainer.nnode`` config to the number of machines for your job.\n\nHow to use verl on a Slurm-managed cluster?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nRay provides users with `this <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ official\ntutorial to start a Ray cluster on top of Slurm. We have verified the :doc:`GSM8K example<../examples/gsm8k_example>`\non a Slurm cluster under a multi-node setting with the following steps.\n\n1. [Optional] If your cluster support `Apptainer or Singularity <https://apptainer.org/docs/user/main/>`_ and you wish\nto use it, convert verl's Docker image to an Apptainer image. Alternatively, set up the environment with the package\nmanager available on your cluster or use other container runtimes (e.g. through `Slurm's OCI support <https://slurm.schedmd.com/containers.html>`_) available to you.\n\n.. code:: bash\n\n    apptainer pull /your/dest/dir/vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3.sif docker://verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3\n\n2. Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.\n\n3. Modify `examples/slurm/ray_on_slurm.slurm <https://github.com/volcengine/verl/blob/main/examples/slurm/ray_on_slurm.slurm>`_ with your cluster's own information.\n\n4. Submit the job script to the Slurm cluster with `sbatch`.\n\nPlease note that Slurm cluster setup may vary. If you encounter any issues, please refer to Ray's\n`Slurm user guide <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ for common caveats.\n\nIf you changed Slurm resource specifications, please make sure to update the environment variables in the job script if necessary.\n\n\nInstall related\n------------------------\n\nNotImplementedError: TensorDict does not support membership checks with the `in` keyword. 
\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nDetail error information: \n\n.. code:: bash\n\n    NotImplementedError: TensorDict does not support membership checks with the `in` keyword. If you want to check if a particular key is in your TensorDict, please use `key in tensordict.keys()` instead.\n\nCause of the problem: There is no suitable version of tensordict package for the linux-arm64 platform. The confirmation method is as follows:\n\n.. code:: bash\n\n    pip install tensordict==0.6.2\n\nOutput example:\n\n.. code:: bash\n\n    ERROR: Could not find a version that satisfies the requirement tensordict==0.6.2 (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.2b0, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.8.0, 0.8.1, 0.8.2, 0.8.3)\n    ERROR: No matching distribution found for tensordict==0.6.2\n\nSolution 1st:\n  Install tensordict from source code:\n\n.. code:: bash\n\n    pip uninstall tensordict\n    git clone https://github.com/pytorch/tensordict.git\n    cd tensordict/\n    git checkout v0.6.2\n    python setup.py develop\n    pip install -v -e .\n\nSolution 2nd:\n  Temperally modify the error takeplace codes: tensordict_var -> tensordict_var.keys()\n\n\nIllegal memory access\n---------------------------------\n\nIf you encounter the error message like ``CUDA error: an illegal memory access was encountered`` during rollout, please check the vLLM documentation for troubleshooting steps specific to your vLLM version.\n\nCheckpoints\n------------------------\n\nIf you want to convert the model checkpoint into huggingface safetensor format, please refer to ``verl/model_merger``.\n\n\nTriton ``compile_module_from_src`` error\n------------------------------------------------\n\nIf you encounter triton compilation error similar to the stacktrace below, please set the ``use_torch_compile`` flag according to\nhttps://verl.readthedocs.io/en/latest/examples/config.html to disable just-in-time compilation for fused kernels.\n\n.. 
code:: bash\n\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py\", line 345, in <lambda>\n    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/autotuner.py\", line 338, in run\n    return self.fn.run(*args, **kwargs)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py\", line 607, in run\n    device = driver.active.get_current_device()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 23, in __getattr__\n    self._initialize_obj()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 20, in _initialize_obj\n    self._obj = self._init_fn()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 9, in _create_driver\n    return actives[0]()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 371, in __init__\n    self.utils = CudaUtils()  # TODO: make static\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 80, in __init__\n    mod = compile_module_from_src(Path(os.path.join(dirname, \"driver.c\")).read_text(), \"cuda_utils\")\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 57, in compile_module_from_src\n    so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/build.py\", line 48, in _build\n    ret = subprocess.check_call(cc_cmd)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/subprocess.py\", line 369, in check_call\n    raise CalledProcessError(retcode, cmd)\n\nWhat is the meaning of train batch size, mini batch size, and micro batch size?\n------------------------------------------------------------------------------------------\n\nThis figure illustrates the relationship between different batch size configurations.\n\nhttps://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA\n\n.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d\n\nHow to generate ray timeline to analyse performance of a training job?\n------------------------------------------------------------------------------------------\n\nTo generate the ray timeline file, you can set the config term ``ray_init.timeline_file`` to a json file path.\nFor example:\n\n.. code:: bash\n\n    ray_init.timeline_file=/tmp/ray_timeline.json\n  \nThe file will be generated in the specified path at the end of a training job.\nYou can use tools like chrome://tracing or the Perfetto UI and view the ray timeline file.\n\nThis figure shows the ray timeline file generated by from a training job on 1 node with 4 GPUs\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray_timeline.png?raw=true\n\nHow to set proxy only for wandb?\n------------------------------------------------------------------------------------------\n\nIf you need a proxy to access wandb, you can add below config in your training job script.\nComparing to using global https_proxy env variable, this approach won't mess up other http requests, such as ChatCompletionScheduler.\n\n.. 
code:: bash\n\n  +trainer.wandb_proxy=http://<your proxy and port>\n\nMissmatch between inference and training sequence (high actor/grad_norm)\n------------------------------------------------------------------------------------------\n\nIf you encounter the issue of actor/grad_norm metric continuously increasing during training, it might be caused by a significant precision mismatching between the inference engine and training. You can use the following parameter to confirm this:\n\n.. code:: bash\n\n    actor_rollout_ref.rollout.calculate_log_probs=True\n\nThis parameter will add metrics like training/rollout_probs_diff_mean , which can be used to verify if there is a precision difference between inference and training.\n\nUnder normal circumstances, the value of training/rollout_probs_diff_mean should be below 0.005. If you observe this value to be higher than 0.01, it indicates a precision issue from the inference engine.\nThe precision issue is known to occur under the following conditions:\n\n1. Using non-Hopper architecture GPUs, such as A100, L20, B200, etc.\n\n2. Using vLLM `with issue 22103 <https://github.com/vllm-project/vllm/issues/22103>`_ as the inference engine.\n\n3. The input and output texts are long, for example, in multi-turn scenarios using reasioning models like Qwen3 for RL training.\n\nIf all three conditions above are met and you observe that rollout_probs_diff_mean is too high, it is recommended to add the following parameter to resolve the precision issue:\n\n.. code:: bash\n\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_cascade_attn=True\n\nThe root cause of this issue is a bug in the flash attention used by vLLM. Although it has been fixed, the fix has not yet been released in the latest version of vLLM (v0.10.2).\nFor a more detailed explanation of this issue, please refer to `Fix LSE output error in FA2 kv-split <https://github.com/vllm-project/flash-attention/pull/87>`_.\n\nUntil vLLM releases a new version with this fix, it is recommended to use the configuration above to disable cascade attention as a workaround.\n"
  },
  {
    "path": "verl_distillation/docs/hybrid_flow.rst",
    "content": "=========================================================\nHybridFlow Programming Guide\n=========================================================\n\nLast updated: 06/02/2025.\n\n.. _vermouth: https://github.com/vermouth1992\n\nAuthor: `Chi Zhang <https://github.com/vermouth1992>`_\n\nverl is an open source implementation of the paper `HybridFlow <https://arxiv.org/abs/2409.19256v2>`_ [1]_. In this section, we will introduce the basic concepts of HybridFlow, the motivation and how to program with verl APIs.\n\nMotivation and Design\n------------------------\nWe use dataflow to represent RL systems. [4]_.\n\nDataFlow\n~~~~~~~~~~~~~~~~~~~~\n\nDataflow is an abstraction of computations. Neural Network training is a typical dataflow. It can be represented by computational graph. \n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/dataflow.jpeg?raw=true\n   :alt: The dataflow graph from CS231n 2024 lecture 4\n\nThis figure [2]_ represents the computation graph of a polynomial function followed by a sigmoid function. In the data flow of neural network computation, each node represents an operator, and each edge represents the direction of forward/backward propagation. The computation graph determines the architecture of the neural network.\n\nRL as a dataflow problem\n++++++++++++++++++++++++++++++++++++++++++++++\n\nReinforcement learning (RL) training can also be represented as a dataflow. Below is the dataflow graph that represents the PPO algorithm used in RLHF [3]_:\n\n.. image:: https://picx.zhimg.com/70/v2-cb8ab5ee946a105aab6a563e92682ffa_1440w.avis?source=172ae18b&biz_tag=Post\n  :alt: PPO dataflow graph, credit to Zhihu 低级炼丹师\n\nHowever, the dataflow of RL has fundamental differences compared with dataflow of neural network training as follows:\n\n+--------------------------+--------------------------------------------------+---------------------+\n| Workload                 | Node                                             | Edge                |\n+--------------------------+--------------------------------------------------+---------------------+\n| Neural Network Training  | Operator (+/-/matmul/softmax)                    | Tensor movement     |\n+--------------------------+--------------------------------------------------+---------------------+\n| Reinforcement Learning   | High-level operators (rollout/model forward)     | Data Movement       |\n+--------------------------+--------------------------------------------------+---------------------+\n\nIn the case of tabular reinforcement learning, each operator is a simple scalar math operation (e.g., bellman update). In deep reinforcement learning(DRL), each operator is a high-level neural network computation such as model inference/update. This makes RL a two-level dataflow problem:\n\n- Control flow: defines how the high-level operators are executed (e.g., In PPO, we first perform rollout. Then, we perform advantage computation. Finally, we perform training). It expresses the **core logics of RL algorithms**.\n- Computation flow: defines the dataflow of **neural network computation** (e.g., model forward/backward/optimizer).\n\n\nDesign Choices\n~~~~~~~~~~~~~~~~~~~~\nThe model size used in DRL before the LLM era is typically small. Thus, the high-level neural network computation can be done in a single process. 
This enables embedding the computation flow inside the control flow as a single process.\n\nHowever, in the LLM era, the computation flow (e.g., training neural network) becomes a multi-process program. This naturally leads to two design choices:\n\n1. Convert the control flow into a multi-process program as well. Then colocate with computation flow (unified multi-controller)\n\n- Advantages:\n\n  - Achieves the **optimal performance** under fixed computation flow and control flow as the communication overhead in both training and data transfer is minimized.\n\n- Disadvantages:\n\n  - The computation and/or control flow is **hard to reuse** from software perspective as computation code is coupled with specific controller code. For example, the training loop of PPO is generic. Say we have an PPO training flow implemented with a specific computation flow such as FSDP. Neither the control flow or computation flow can be reused if we want to switch the computation flow from FSDP to Megatron, due to the coupling of control and computation flows.\n  - Requires more efforts from the user under flexible and dynamic control flows, due to the multi-process nature of the program.\n\n2. Separate the flows: single process for the control flow and multi-process for computation flow\n\n- Advantages:\n\n  - The computation flow defined elsewhere can be **easily reused** after the decoupling.\n  - The controller runs on a single process. Implementing a new RL algorithm with a **different control flow is simple and easy**.\n\n- Disadvantages:\n\n  - Additional **data communication overhead** each time the controller process and computatation processes interact. The data has to be sent back and forth.\n\nIn verl, the latter strategy with separate control flow and computation flow is adopted. verl is designed to decouple the control flow of RL algorithms, and the implementation of computation engines.\n\nOverall Execution Diagram\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nBelow is a simplified diagram denoting the execution of a reinforcement learning job. In the diagram, the controller runs on a single process, while the generator/actor workers, critic workers run on multiple processes, placed with specific resource groups. For rollout, the controller passes the data to the generator to perform sample generation. When the rollout is done, the data is passed back to controller for the next step of the algorithm. Similar execution is done for other workers. With the hybrid controller design, the data flow and computation is decoupled to provide both efficiency in computation and flexibility in defining algorithm training loops.\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/driver_worker.png?raw=true\n   :alt: The execution diagram\n\nCodebase walkthrough (PPO)\n------------------------------------------------\n\nEntry function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nCode: https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py\n\nIn this file, we define a remote function `main_task` that serves as the controller (driver) process as shown in the above figure. We also define a ``RewardManager``, where users can customize their reward function based on the data source in the dataset. Note that `RewardManager` should return the final token-level reward that is optimized by RL algorithms. Note that users can combine model-based rewards and rule-based rewards.\nThe ``main_task`` constructs a RayPPOTrainer instance and launch the fit. 
Note that ``main_task`` **runs as a single process**.\n\nWe highly recommend that the ``main_task`` is NOT scheduled on the head of the ray cluster because ``main_task`` will consume a lot of memory but the head usually contains very few resources.\n\nRay trainer\n~~~~~~~~~~~~~~~~~~~~\nCode: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py\n\nThe RayPPOTrainer manages \n\n- Worker and WorkerGroup construction\n- Runs the main loop of PPO algorithm\n\nNote that, the fit function of RayPPOTrainer **runs as a single process**.\n\nWorker and WorkerGroup construction\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nEach workerGroup manages a list of workers that runs remotely. Note that the worker group runs in the process of its constructor.\nEach worker inside the WorkerGroup runs on a GPU. The worker group serves as a proxy for the controller process to interact with a list of workers, in order to perform certain computations. **In order to do so, we have to bind the methods of the worker into the method of the WorkerGroup and define the data dispatch and data collection**. This is done via simple decoration that will be introduced in the Worker definition section.\n\nFor example, in PPO, we define 3 worker groups:\n\n- ActorRolloutRef: manages actor, rollout and reference policy. ActorRolloutRefWorker can be instantiated as a single actor, a single rollout, a single reference policy, a combined actor/rollout or a combined actor/rollout/ref. This design is aimed for the maximum code reuse in various scenarios. The reason for colocating actor and rollout is for fast weight transfer using nccl. The reason for coloating actor and reference is to implement an efficient lora PPO as the reference policy is simply the base model of PPO in lora. The colocation is done via ``verl.single_controller.ray.base.create_colocated_worker_cls``, where it creates a single ray remote class exposing all class methods from these roles.\n- Critic: manages the critic model\n- Reward: manages the reward model\n\nThe worker group will be constructed on the resource pool it designates. The resource pool is a set of GPUs in the ray cluster.\n\nWorker definition\n~~~~~~~~~~~~~~~~~~~~\n\n.. _ActorRolloutRefWorker: https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py\n\nWe take `ActorRolloutRefWorker <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ for an example.\nThe APIs it should expose to the controller process are:\n\n- init_model: build the underlying model\n- generate_sequences: given prompts, generate responses\n- compute_log_prob: compute the log-probability of a generated sequence using actor\n- compute_ref_log_prob: compute the log-probability of a generated sequence using reference policy\n- save_checkpoint: save the checkpoint\n\nNote that these methods are defined in the worker that can only be invoked via remote calls. For example, if the controller process wants to initialize the model, it has to call\n\n.. code-block:: python\n\n   for worker in actor_rollout_ref_wg:\n       worker.init_model.remote()\n\nIf the controller process wants to generate sequences, it has to call\n\n.. 
code-block:: python\n\n   data = xxx\n   # split the data into dp chunks\n   data_dp_lst = data.split(dp_size)\n   output_dp_lst = []\n   for i, worker in enumerate(actor_rollout_ref_wg):\n       output_future = worker.generate_sequences.remote(data_dp_lst[i])\n       output_dp_lst.append(output_future)\n   output = torch.cat(ray.get(output_dp_lst), dim=0)\n\nWe observe that controller process calling worker group methods in general can be divided into 3 parts:\n\n- Split the data into data parallel sizes\n- Dispatch the corresponding data into each worker\n- Collect and concatenate the data when the computation finishes\n\nIn verl, we design a syntax sugar to encapsulate the 3 processes into a single call from the controller process.\n\n.. code-block:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(data):\n       ...\n\n   # on the driver\n   output = actor_rollout_ref_wg.generate_sequences(data)\n\nWe decorate the method of the worker with a ``register`` that explicitly defines how the input data should be split and dispatched to each worker, and how the output data should be collected and concatenated by the controller. For example, ``Dispatch.DP_COMPUTE_PROTO`` splits the input data into dp chunks, dispatch each data to each worker, collect the output and concatenate the results. Note that this function requires the input and output to be a DataProto defined here (https://github.com/volcengine/verl/blob/main/verl/protocol.py).\n\n\nPPO main loop\n~~~~~~~~~~~~~~~~~~~~\nWith the aforementioned APIs, we can implement the main loop of PPO as if it is a single process program\n\n.. code-block:: python\n\n   for prompt in dataloader:\n       output = actor_rollout_ref_wg.generate_sequences(prompt)\n       old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)\n       ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)\n       values = critic_wg.compute_values(output)\n       rewards = reward_wg.compute_scores(output)\n       # compute_advantages is running directly on the control process\n       advantages = compute_advantages(values, rewards)\n       output = output.union(old_log_prob)\n       output = output.union(ref_log_prob)\n       output = output.union(values)\n       output = output.union(rewards)\n       output = output.union(advantages)\n       # update actor\n       actor_rollout_ref_wg.update_actor(output)\n       critic.update_critic(output)\n\nTakeaways\n~~~~~~~~~~~~~~~~~~~~\n- This programming paradigm enables users to use different computation backend without modification of the control process.\n- This programming paradigm enables flexible placement (by changing the mapping of WorkerGroup and ResourcePool) without modification of the control process.\n\nRepository organization\n------------------------------------------------\n\nImportant code files in the repository are organized as below:\n\n.. 
code-block:: bash\n\n   verl # the verl package\n     trainer\n       main_ppo.py  # the entrypoint for RL training\n       ppo\n         ray_trainer.py  # the training loop for RL algorithms such as PPO\n       fsdp_sft_trainer.py  # the SFT trainer with FSDP backend\n     config\n       generation.yaml  # configuration template for rollout\n       ppo_trainer.yaml  # configuration template for the RL trainer\n     workers\n       protocol.py  # the interface of DataProto\n       fsdp_workers.py   # the FSDP worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker\n       megatron_workers.py  # the Megatron worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker\n       actor\n         dp_actor.py  #  data parallel actor with FSDP backend\n         megatron_actor.py  # nD parallel actor with Megatron backend\n       critic\n         dp_critic.py  # data parallel critic with FSDP backend\n         megatron_critic.py  # nD parallel critic with FSDP backend\n       reward_model\n         megatron\n           reward_model.py  # reward model with Megatron backend\n       rollout\n         vllm\n           vllm_rollout.py  # rollout with vllm backend\n         hf_rollout.py  # rollout with huggingface TGI backend\n       sharding_manager\n         fsdp_ulysses.py  # data and model resharding when using FSDP + ulysses\n         fsdp_vllm.py  # data and model resharding when using FSDP + ulysses + vllm\n         megatron_vllm.py  # data and model resharding when using Megatron + vllm\n     utils\n       dataset  # datasets for SFT/RM/RL\n       reward_score  # function based reward\n         gsm8k.py  # reward function for gsm8k dataset\n         math.py  # reward function for math dataset\n       seqlen_balancing.py  # the sequence balance optimization\n     models\n       llama  # Megatron implementation for llama, deepseek, mistral, etc\n       transformers  # ulysses integration with transformer models such as llama, qwen, etc\n       weight_loader_registery.py  # registry of weight loaders for loading hf ckpt into Megatron\n     third_party\n       vllm  # adaptor for vllm's usage in RL\n         vllm_spmd  # vllm >= v0.7 adaptor\n   examples  # example scripts\n   tests  # integration and unit tests\n   .github  # the configuration of continuous integration tests\n\n\n.. [1] HybridFlow: A Flexible and Efficient RLHF Framework: https://arxiv.org/abs/2409.19256v2\n.. [2] Data flow graph credit to CS231n 2024 lecture 4: https://cs231n.stanford.edu/slides/2024/lecture_4.pdf\n.. [3] PPO dataflow graph credit to 低级炼丹师 from Zhihu​: https://zhuanlan.zhihu.com/p/635757674\n.. [4] RLFlow\n"
  },
  {
    "path": "verl_distillation/docs/index.rst",
    "content": "Welcome to verl's documentation!\n================================================\n\nverl is a flexible, efficient and production-ready RL training framework designed for large language models (LLMs) post-training. It is an open source implementation of the `HybridFlow <https://arxiv.org/pdf/2409.19256>`_ paper.\n\nverl is flexible and easy to use with:\n\n- **Easy extension of diverse RL algorithms**: The hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex Post-Training dataflows. Allowing users to build RL dataflows in a few lines of code.\n\n- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM, vLLM and SGLang. Moreover, users can easily extend to other LLM training and inference frameworks.\n\n- **Flexible device mapping and parallelism**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.\n\n- Ready integration with popular HuggingFace models\n\n\nverl is fast with:\n\n- **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, verl achieves high generation and training throughput.\n\n- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.\n\n--------------------------------------------\n\n.. _Contents:\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Quickstart\n\n   start/install\n   start/quickstart\n   start/multinode\n   start/ray_debug_tutorial\n   start/more_resources\n   start/agentic_rl\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Programming guide\n\n   hybrid_flow\n   single_controller\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Data Preparation\n\n   preparation/prepare_data\n   preparation/reward_function\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Configurations\n\n   examples/config\n\n.. toctree::\n   :maxdepth: 1\n   :caption: PPO Example\n\n   examples/ppo_code_architecture\n   examples/gsm8k_example\n   examples/multi_modal_example\n   examples/skypilot_examples\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Algorithms\n\n   algo/ppo.md\n   algo/grpo.md\n   algo/collabllm.md\n   algo/dapo.md\n   algo/spin.md\n   algo/sppo.md\n   algo/entropy.md\n   algo/opo.md\n   algo/baseline.md\n   algo/gpg.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: PPO Trainer and Workers\n\n   workers/ray_trainer\n   workers/fsdp_workers\n   workers/megatron_workers\n   workers/sglang_worker\n   workers/model_engine\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Performance Tuning Guide\n\n   perf/dpsk.md\n   perf/best_practices\n   perf/perf_tuning\n   README_vllm0.8.md\n   perf/device_tuning\n   perf/verl_profiler_system.md\n   perf/nsight_profiling.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Adding new models\n\n   advance/fsdp_extension\n   advance/megatron_extension\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Advanced Features\n\n   advance/checkpoint\n   advance/rope\n   advance/attention_implementation\n   advance/ppo_lora.rst\n   sglang_multiturn/multiturn.rst\n   sglang_multiturn/interaction_system.rst\n   advance/placement\n   advance/dpo_extension\n   examples/sandbox_fusion_example\n   advance/rollout_trace.rst\n   advance/rollout_skip.rst\n   advance/rollout_is.md\n   advance/one_step_off\n   advance/agent_loop\n   advance/reward_loop\n   advance/fully_async\n   data/transfer_queue.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Hardware Support\n\n   amd_tutorial/amd_build_dockerfile_page.rst\n   amd_tutorial/amd_vllm_page.rst\n   ascend_tutorial/ascend_quick_start.rst\n   ascend_tutorial/ascend_profiling_zh.rst\n   ascend_tutorial/ascend_profiling_en.rst\n   ascend_tutorial/dockerfile_build_guidance.rst\n   ascend_tutorial/ascend_sglang_quick_start.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API References\n\n   api/data\n   api/single_controller.rst\n   api/trainer.rst\n   api/utils.rst\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: FAQ\n\n   faq/faq\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Development Notes\n\n   sglang_multiturn/sandbox_fusion.rst\n\nContribution\n-------------\n\nverl is free software; you can redistribute it and/or modify it under the terms\nof the Apache License 2.0. We welcome contributions.\nJoin us on `GitHub <https://github.com/volcengine/verl>`_, `Slack <https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA>`_ and `Wechat <https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG>`_ for discussions.\n\nContributions from the community are welcome! Please check out our `project roadmap <https://github.com/volcengine/verl/issues/710>`_ and `good first issues <https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22>`_ to see where you can contribute.\n\nCode Linting and Formatting\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWe use pre-commit to help improve code quality. To initialize pre-commit, run:\n\n.. code-block:: bash\n\n   pip install pre-commit\n   pre-commit install\n\nTo resolve CI errors locally, you can also manually run pre-commit by:\n\n.. code-block:: bash\n\n   pre-commit run\n\nAdding CI tests\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nIf possible, please add CI test(s) for your new feature:\n\n1. Find the most relevant workflow yml file, which usually corresponds to a ``hydra`` default config (e.g. ``ppo_trainer``, ``ppo_megatron_trainer``, ``sft_trainer``, etc).\n2. Add related path patterns to the ``paths`` section if not already included.\n3. Minimize the workload of the test script(s) (see existing scripts for examples).\n\nWe are HIRING! Send us an `email <mailto:haibin.lin@bytedance.com>`_ if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.\n"
  },
  {
    "path": "verl_distillation/docs/perf/best_practices.rst",
    "content": "Verl LLM Best Practices (DAPO + Qwen3-235B)\n===========================================\n\nLast updated: 11/03/2025.\n\nPurpose\n-------\n\nThis guide uses DAPO training on Qwen3-235B as a concrete example. We unpack every parameter that appears in the optimization objective, map it to Verl configuration entries, and share field-tested recommendations so you can derive sensible settings for your own workloads.\n\n.. note::\n\n   1. The guide only covers the subset of parameters required to reproduce the DAPO experiments discussed here. For the full list, refer to the ``config`` components in the Verl source tree: https://github.com/volcengine/verl/tree/main/verl/trainer/config\n   2. PPO and GRPO introduce KL-constrained policies. We therefore include that setup in the explanations below. You can treat all configurations mentioned here as a DAPO pipeline augmented with a KL penalty.\n\nOptimization Objectives\n-----------------------\n\nDAPO objective\n~~~~~~~~~~~~~~\n\n.. math::\n\n   \\begin{aligned}\n   \\mathcal{J}_{\\mathrm{DAPO}}(\\theta)= & \\mathbb{E}_{(q, a) \\sim \\mathcal{D},\\left\\{o_i\\right\\}_{i=1}^G \\sim \\pi_{\\theta_{\\text {old }}}(\\cdot \\mid q)} \\\n    {\\left[\\frac{1}{\\sum_{i=1}^G\\left|o_i\\right|} \\sum_{i=1}^G \\sum_{t=1}^{\\left|o_i\\right|} \\min \\left(r_{i, t}(\\theta) \\hat{A}_{i, t}, \\operatorname{clip}\\left(r_{i, t}(\\theta), 1-\\varepsilon_{\\text {low }}, 1+\\varepsilon_{\\text {high }}\\right) \\hat{A}_{i, t}\\right)\\right] } \\\\\n   \\end{aligned}\n\n.. math::\n   \\text { s.t. } \\quad 0<\\mid\\left\\{o_i \\mid \\text { is_equivalent }\\left(a, o_i\\right)\\right\\} \\mid<G,\n\n.. math::\n\n   \\text {where} \\quad r_{i, t}(\\theta)=\\frac{\\pi_\\theta\\left(o_{i, t} \\mid q, o_{i,<t}\\right)}{\\pi_{\\theta_{\\text {old }}}\\left(o_{i, t} \\mid q, o_{i,<t}\\right)}, \\quad \\hat{A}_{i, t}=\\frac{R_i-\\operatorname{mean}\\left(\\left\\{R_i\\right\\}_{i=1}^G\\right)}{\\operatorname{std}\\left(\\left\\{R_i\\right\\}_{i=1}^G\\right)}\n\nGRPO objective\n~~~~~~~~~~~~~~\n\n.. math::\n\n   \\begin{aligned}\n   \\mathcal{J}_{G R P O}(\\theta) & =\\mathbb{E}_{q \\sim P(Q),\\left\\{o_i\\right\\}_{i=1}^G \\sim \\pi_{\\theta_{\\text {old }}}(O \\mid q)} \\\n   \\frac{1}{G} \\sum_{i=1}^G \\frac{1}{\\left|o_i\\right|} \\sum_{t=1}^{\\left|o_i\\right|}\\left\\{\\min \\left[\\frac{\\pi_\\theta\\left(o_{i, t} \\mid q, o_{i,<t}\\right)}{\\pi_{\\theta_{\\text {old }}}\\left(o_{i, t} \\mid q, o_{i,<t}\\right)} \\hat{A}_{i, t}, \\operatorname{clip}\\left(\\frac{\\pi_\\theta\\left(o_{i, t} \\mid q, o_{i,<t}\\right)}{\\pi_{\\theta_{\\text {old }}}\\left(o_{i, t} \\mid q, o_{i,<t}\\right)}, 1-\\varepsilon, 1+\\varepsilon\\right) \\hat{A}_{i, t}\\right]-\\beta \\mathbb{D}_{K L}\\left[\\pi_\\theta \\| \\pi_{r e f}\\right]\\right\\},\n   \\end{aligned}\n\nNotation Overview\n-----------------\n\n:math:`(q,a)\\sim D`\n  - :math:`D` denotes the training dataset. For each sample, :math:`q` is the input prompt (for math tasks, the question) and :math:`a` is the target output—typically the final answer without intermediate reasoning steps.\n\n:math:`G`\n  - Group size. For every prompt we sample :math:`G` independent responses.\n\n:math:`\\theta`\n  - Actor model parameters.\n\n:math:`\\pi`\n  - Sampling strategy that bundles the rollout backend (vLLM, sglang, etc.) and all generation hyperparameters. 
Because LLMs generate tokens autoregressively, rollout dominates runtime, so backend-specific tuning is critical.\n\n:math:`\\pi_\\theta`\n  - Actor policy obtained by instantiating :math:`\\pi` with parameters :math:`\\theta`.\n\n:math:`\\hat{A}_{i,t}`\n  - Advantage of the :math:`i`-th sample within the group at timestep :math:`t`.\n\n:math:`R_i`\n  - Reward assigned to the :math:`i`-th sample in the group.\n\n:math:`\\mathbb{D}_{KL}`\n  - KL divergence between two policies.\n\n:math:`\\beta`\n  - Coefficient that weights the KL term.\n\n:math:`\\pi_{old}`\n  - Frozen “old” policy, updated after every :math:`\\texttt{train_batch_size}` samples.\n\n:math:`\\pi_{ref}`\n  - Reference policy used to compute the KL divergence.\n\n:math:`o_i, |o_i|`\n  - :math:`o_i` is the generated output sequence for the :math:`i`-th prompt; :math:`|o_i|` is its token length.\n\n:math:`\\pi_\\theta(o_{i,t} \\mid q_i, o_{i,<t})`\n  - Probability of emitting token :math:`o_{i,t}` at timestep :math:`t` given prompt :math:`q_i` and the previously generated prefix under parameters :math:`\\theta`. In practice, the rollout engine first generates full responses, then concatenates prompts and outputs for each model; with attention masks we can compute all token probabilities in one pass.\n\n:math:`\\varepsilon_{low}` and :math:`\\varepsilon_{high}`\n  - Lower and upper clipping bounds for importance sampling. DAPO adopts a clip-higher strategy, so the upper bound is different from the lower bound to prevent overly large policy updates.\n\nParameter Reference\n-------------------\n\n:math:`(q,a)\\sim D`\n  - ``data.train_files`` / ``data.val_files``:\n    Training and validation datasets. They must be stored as ``.parquet``. Use the conversion scripts under ``examples/data_preprocess`` and make sure your ``data_source`` implements the matching reward function. You can also reuse the HuggingFace dataset ``BytedTsinghua-SIA/DAPO-Math-17k``.\n  - ``data.prompt_key``:\n    Column name for prompts. Keep the default ``prompt`` unless you have a clearer schema.\n  - ``data.max_prompt_length``:\n    Upper bound on prompt length. Set it to cover the longest prompt in the corpus; when long-tail samples make it too large, lower the value and combine with ``data.truncation``.\n  - ``data.truncation``:\n    Policy for over-length inputs (truncate-left/right or raise). ``left`` works for most runs. If training logs show large ``clip_ratio`` and poor metrics, increase ``data.max_prompt_length`` or clean the data. Set to ``error`` when strict validation is required.\n\n:math:`G`\n  - ``actor_rollout_ref.rollout.n``:\n    Number of generations per prompt. Typical values: GRPO 64, DAPO 16.\n\n:math:`\\theta`\n  - ``actor_rollout_ref.model.path``:\n    Path to the actor checkpoint in HuggingFace-compatible format.\n  - ``actor_rollout_ref.actor.megatron.use_mbridge``:\n    Enable mbridge format conversion when the model was trained with Megatron. Use the latest mbridge release: https://github.com/ISEEKYAN/mbridge.\n\n:math:`\\pi`\n  - ``actor_rollout_ref.rollout.name``:\n    Rollout backend. Verl currently supports ``vllm`` and ``sglang``—benchmark and tune according to your infrastructure.\n  - ``actor_rollout_ref.rollout.response_length`` / ``data.max_response_length``:\n    Maximum generated tokens (rollout setting takes precedence). Larger values improve quality but consume more memory and latency. 
Monitor ``clip_ratio``; values above 0.1 often mean you are truncating too much.\n  - ``actor_rollout_ref.rollout.gpu_memory_utilization``:\n    Target GPU memory usage during rollout. Push it as high as possible without triggering OOM; with parameter/gradient/optimizer offload enabled, 0.8–0.9 is common.\n  - ``actor_rollout_ref.rollout.tensor_model_parallel_size``:\n    Tensor parallel degree for the inference engine. Ensure ``(memory_per_gpu * gpu_memory_utilization * TP) > 2 * model_parameters`` (bf16/fp16). Increase TP gradually to expand KV cache capacity while watching communication cost—especially once TP > 8.\n  - ``actor_rollout_ref.rollout.temperature`` / ``top_p`` / ``top_k``:\n    Sampling knobs for rollout. Keep enough randomness; ``temperature=1.0``, ``top_p=1.0``, ``top_k=-1`` are good defaults.\n  - ``actor_rollout_ref.rollout.val_kwargs.temperature`` / ``top_p`` / ``top_k`` / ``do_sample`` / ``n``:\n    Sampling options for validation. Set ``temperature > 0`` to prevent repetitive thinking chains. For small test sets (e.g., AIME24) raise ``n`` (64 is a common choice) to reduce variance. A practical starting point is ``temperature=1.0``, ``top_p=0.7``, ``top_k=-1``, ``do_sample=True``, ``n=1`` and then increase ``n`` as needed.\n  - ``+actor_rollout_ref.rollout.engine_kwargs.vllm.*`` / ``+actor_rollout_ref.rollout.engine_kwargs.sglang.*``:\n    Extra backend options injected via the ``+`` syntax. Consult backend docs for exact semantics. Some switches (for example ``pipeline_parallel_size``) may not be supported yet; when TP=32, ``enable_expert_parallel=True`` can even slow down DeepSeek-V3 rollout, so benchmark carefully.\n\n:math:`\\pi_\\theta`\n  - ``data.train_batch_size``:\n    Total batch size per training iteration. Each rollout produces ``train_batch_size * n`` samples. Larger values reduce the number of rollouts but increase off-policy drift.\n  - ``actor_rollout_ref.actor.ppo_mini_batch_size``:\n    Mini-batch size per optimization step. Tune it the same way you would for standard deep learning workloads.\n  - ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``:\n    Samples processed per forward pass on one GPU group (a Megatron group contains TP * PP * CP GPUs). Keep it ≤ ``ppo_mini_batch_size`` and as large as memory allows.\n  - ``actor_rollout_ref.actor.use_dynamic_bsz``:\n    Enable dynamic batch sizing to adapt to sequence length and improve throughput.\n  - ``actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu``:\n    Maximum tokens per GPU when computing log probabilities under dynamic batching. Set it to at least a multiple of ``max_prompt_length + max_response_length`` to prevent truncation.\n  - Megatron parallelism parameters (``pipeline_model_parallel_size`` / ``tensor_model_parallel_size`` / ``expert_model_parallel_size`` / ``expert_tensor_parallel_size`` / ``context_parallel_size``):\n    Balance PP/TP/EP/ETP/CP to match memory and network constraints. In bf16/fp16, each parameter consumes roughly ``2 / TP`` bytes; if you keep FP32 master weights or skip optimizer offload, reserve another 4–8 bytes for Adam. Activations scale with ``micro_batch_size × sequence_length × hidden_size`` and can be mitigated with gradient checkpointing, dynamic batches, or offload. Prefer increasing TP first, add PP when necessary, extend sequence capacity with CP, align EP/ETP with TP for MoE models, and keep DP minimal on constrained clusters while combining with offload. 
Always align the setup with hardware topology and communication cost.\n  - ``actor_rollout_ref.model.use_fused_kernels``:\n    Enable Verl’s fused kernels for supported models to squeeze out additional performance.\n\n:math:`\\hat{A}_{i,t}`\n  - ``algorithm.adv_estimator``:\n    Advantage estimator. Set to ``grpo`` for DAPO/GRPO.\n\n:math:`R_i`\n  - ``reward_model.reward_manager``:\n    Reward aggregation strategy. Use ``dapo`` for DAPO and ``naive`` for GRPO.\n\n:math:`\\mathbb{D}_{KL}`\n  - ``algorithm.use_kl_in_reward``:\n    Whether to add a KL term to the reward. ``True`` for PPO, ``False`` for GRPO and DAPO.\n  - ``actor_rollout_ref.actor.use_kl_loss``:\n    Whether to include a KL loss term. ``False`` for PPO, ``True`` for GRPO, ``False`` for DAPO.\n\n:math:`\\beta`\n  - ``actor_rollout_ref.actor.kl_loss_coef``:\n    Weight of the KL loss. Start around 0.001. Larger values curb reward hacking but reduce exploration.\n  - ``algorithm.kl_ctrl.kl_coef``:\n    KL coefficient applied within the reward. Adjust to match your tolerance for divergence.\n\n:math:`\\pi_{old}`\n  - ``actor_rollout_ref.rollout.log_prob_use_dynamic_bsz``:\n    Enable dynamic batching when the old policy computes log-probabilities. Recommended.\n\n:math:`\\pi_{ref}`\n  - ``actor_rollout_ref.ref.log_prob_use_dynamic_bsz``:\n    Enable dynamic batching for the reference policy. Recommended.\n  - Reference Megatron parallelism:\n    Keep ``pipeline_model_parallel_size``, ``tensor_model_parallel_size``, ``expert_model_parallel_size``, ``expert_tensor_parallel_size``, and ``context_parallel_size`` in sync with the actor.\n  - ``actor_rollout_ref.ref.megatron.param_offload``:\n    Offload reference parameters to CPU when the actor does so. Even without gradients or optimizer states, parity helps with capacity planning.\n\n:math:`o_i` / :math:`|o_i|`\n  - ``actor_rollout_ref.actor.loss_agg_mode``:\n    Loss aggregation mode. Token-level ``token-mean`` matches the recommendations from Dr.GRPO and DAPO; use ``seq-mean-token-mean`` to reproduce the original GRPO behavior.\n\n:math:`\\pi_\\theta(o_{i,t} \\mid q_i,o_{i,<t})`\n  - ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu`` / ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``:\n    Batch size while computing token probabilities. Rollout engines generate outputs and then concatenate inputs for each model, so balance memory against throughput.\n\n:math:`\\varepsilon_{low}` / :math:`\\varepsilon_{high}`\n  - ``actor_rollout_ref.actor.clip_ratio_low`` / ``actor_rollout_ref.actor.clip_ratio_high``:\n    Importance sampling clipping bounds. For DAPO, use ``clip_ratio_low=0.2`` and ``clip_ratio_high=0.28``.\n\nvLLM inference optimizations\n  - ``actor_rollout_ref.rollout.enable_chunked_prefill``:\n    Enables chunked prefill to boost GPU utilization (vLLM only). Tune together with ``max_num_batched_tokens``.\n  - ``actor_rollout_ref.rollout.max_num_batched_tokens``:\n    Maximum tokens per batch. A practical rule of thumb is ``max(8192, max_prompt_length + max_response_length, max_model_len)``; see the vLLM docs for details.\n  - ``actor_rollout_ref.rollout.enforce_eager``:\n    Disables CUDA graphs. By default vLLM leverages CUDA graphs for speed at the cost of extra memory (not limited by ``gpu_memory_utilization``); set this to ``True`` when memory is tight.\n  - ``actor_rollout_ref.rollout.cudagraph_capture_sizes``:\n    Explicit capture batch sizes for CUDA graphs. 
Default is ``null``; on memory-constrained systems try ``[1, 2, 4, 8, 16, 32]``.\n\nOptimizer settings\n  - ``actor_rollout_ref.actor.optim.lr``:\n    Learning rate. Start around ``1e-5`` or ``1e-6``.\n  - ``actor_rollout_ref.actor.optim.lr_warmup_steps``:\n    Number of warmup steps (e.g., 10).\n  - ``actor_rollout_ref.actor.optim.weight_decay``:\n    Weight decay coefficient, typically 0.1.\n  - ``actor_rollout_ref.actor.optim.clip_grad``:\n    Gradient clipping threshold, commonly 1.0.\n  - ``+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction``:\n    Portion of optimizer updates executed on CPU. Large models such as DeepSeek benefit from enabling it with a value of 1.\n  - ``+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d`` / ``+...use_precision_aware_optimizer`` / ``+...optimizer_cpu_offload``:\n    Companion switches for hybrid optimizers. Turn them on alongside CPU offload.\n\nMegatron-related parameters\n  - ``actor_rollout_ref.actor.megatron.param_offload`` / ``optimizer_offload`` / ``grad_offload``:\n    Offload parameters, optimizer states, and gradients to CPU when GPU memory is insufficient.\n  - ``+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method`` / ``recompute_granularity`` / ``recompute_num_layers``:\n    Gradient checkpointing controls. Enable (e.g., ``uniform``, ``full``, ``1``) to trade computation for memory.\n  - ``+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype`` / ``moe_shared_expert_overlap`` / ``moe_permute_fusion`` / ``moe_enable_deepep`` / ``moe_token_dispatcher_type``:\n    Recommended MoE knobs (sample values: ``fp32``, ``False``, ``True``, ``True``, ``flex``) for stable performance.\n  - ``+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion``:\n    Enables fused gradient accumulation for additional speedup.\n  - ``+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split`` / ``account_for_loss_in_pipeline_split`` / ``num_layers_in_last_pipeline_stage``:\n    Pipeline-parallel adjustments when layer counts do not divide evenly. Treat embedding and loss as standalone stages and set ``num_layers_in_last_pipeline_stage`` (0 or ``${LAST_LAYER}``) when you need manual control.\n\nTrainer\n  - ``trainer.logger``:\n    Logging backends. Use ``['console', 'wandb']`` or, on Volcano Engine ML Platform, ``['console', 'vemlp_wandb']``.\n  - ``trainer.project_name`` / ``trainer.experiment_name``:\n    Hierarchical naming for projects and experiments so you can locate runs quickly.\n  - ``trainer.n_gpus_per_node`` / ``trainer.nnodes``:\n    Number of GPUs per node and total node count. Match your cluster allocation.\n  - ``trainer.test_freq`` / ``trainer.save_freq`` / ``trainer.total_epochs``:\n    Evaluation interval, checkpoint interval, and total epochs—configure for your SLA.\n  - ``trainer.log_val_generations``:\n    Number of validation samples stored in logs. Start with 10 and adjust as needed.\n  - ``trainer.val_before_train``:\n    Run validation before training begins when you need baseline metrics.
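\n\nPutting the main pieces together, a DAPO-style launch might wire up the options discussed above as follows. This is a partial sketch for orientation only: the dataset path, cluster sizes, and any options not shown are placeholders to be adapted to your workload, not tested values.\n\n.. code-block:: bash\n\n   # Partial sketch assembling the DAPO-related options from this guide;\n   # model/data paths and cluster sizes are placeholders.\n   python3 -m verl.trainer.main_ppo \\\n      algorithm.adv_estimator=grpo \\\n      reward_model.reward_manager=dapo \\\n      data.train_files=$HOME/data/dapo-math-17k/train.parquet \\\n      data.truncation=left \\\n      actor_rollout_ref.rollout.n=16 \\\n      actor_rollout_ref.rollout.temperature=1.0 \\\n      actor_rollout_ref.actor.clip_ratio_low=0.2 \\\n      actor_rollout_ref.actor.clip_ratio_high=0.28 \\\n      actor_rollout_ref.actor.loss_agg_mode=token-mean \\\n      trainer.n_gpus_per_node=8 trainer.nnodes=2\n"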
  },
  {
    "path": "verl_distillation/docs/perf/device_tuning.rst",
    "content": "Hardware Resource Needed for RL\n===============================\n\nLast updated: 06/25/2025.\n\nSince RL requires more resources compared to regular training, \ndetermining how much resources are needed to successfully run it before training \nis a relatively difficult task. To provide more people with reference points for \nresource selection when dealing with different models and tasks, this section is \nmainly dedicated to introducing the environmental requirements based on experiments \nwe have conducted.\n\nHowever, due to limited staff and equipment resources, we also hope for more \ncontributions from the open-source community. When submitting a PR, it is necessary \nto provide a script to be added to the example/tuning scripts.\n\nWe need two types of scripts: one is the configuration that can run with the **minimum \nresources(min)**, and the other is the configuration that runs with **recommended resources(recommended)**. For the former, \nit can be understood as a script that can run after applying all memory optimization techniques \n(e.g., offload, gradient checkpointing). For the latter, it can be understood as a script that \ncan run while avoiding operations that incur additional time overhead as much as possible (targetting best throughput).\n\nWhen defining script names, please follow this format: \n``[model]_[task]_[gpunums]_[device]_[train]_[infer].sh``. This will effectively improve \nthe script's recognizability. You can place the script under the ``examples/tuning/`` directory.\n\nIf you happen to have a configuration that has already been tested, we welcome you to submit \na PR and include a screenshot from Wandb or other verifiable evidence.\n\n----------------------------------------\n\n0.5B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-0.5B\n      - GRPO-LoRA\n      - 1*H100\n      - 116\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n1.5B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-1.5B\n      - GRPO-LoRA\n      - 1*H100\n      - 128\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n3B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-3B\n      - GRPO-LoRA\n      - 1*H100\n      - 62\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n7B\n~~~\n\n.. 
list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-7B\n      - GRPO\n      - 2*H800\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-7b_grpo_2_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-7B\n      - GRPO-LoRA\n      - 1*H100\n      - 16\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n14B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-14B\n      - GRPO\n      - 4*H800\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-14b_grpo_4_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo_4_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-14B\n      - GRPO-LoRA\n      - 2*H100\n      - 116\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n32B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-32B\n      - GRPO\n      - 8*H20\n      - \\\n      - megatron\n      - vllm0.8.2\n      - `qwen2-32b_grpo_8_h20_megatron_vllm <https://github.com/volcengine/verl/tree/main/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-32B\n      - GRPO-LoRA\n      - 4*H100\n      - 180\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n70B\n~~~\n\n.. 
list-table::\n    :widths: auto\n    :header-rows: 1\n\n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-70B\n      - GRPO\n      - 32*H20\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-70b_grpo_32_h20_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2-70B\n      - GRPO\n      - 32*H800\n      - \\\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-70b_grpo_32_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-72B\n      - GRPO-LoRA\n      - 8*H100\n      - 176\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n405B\n~~~~\n\n.. table::\n   :widths: auto\n\n   ====== ====== ====== ======== ======== ====== ====== ======\n   tag    model  task   resource MaxBatch train  infer  link\n   ====== ====== ====== ======== ======== ====== ====== ======\n   \\      \\      \\        \\        \\      \\      \\\n   ====== ====== ====== ======== ======== ====== ====== ======\n\n671B\n~~~~\n\n.. table::\n   :widths: auto\n\n   ====== ====== ====== ======== ======== ====== ====== ======\n   tag    model  task   resource MaxBatch train  infer  link\n   ====== ====== ====== ======== ======== ====== ====== ======\n   \\      \\      \\        \\        \\      \\      \\\n   ====== ====== ====== ======== ======== ====== ====== ======\n"
  },
  {
    "path": "verl_distillation/docs/perf/dpsk.md",
    "content": "# Training DeepSeek 671b\n\nLast updated: 08/20/2025.\n\nverl integrates Megatron to support large MoE models such as `Qwen3-235B-A22B` and `deepseek-ai/DeepSeek-V3`. This is an ongoing community effort.\n\nIn the journey the community added the following features and optimizations that enable verl with larger models:\n- per tensor weight resharding between rollout and training\n- context parallelism and expert parallelism enabled via megatron\n- dynamic batch size (sequence balance) for megatron\n- reduced ray-related serialization overhead\n- optimizer offloading, recomputation, and efficient kernels\n- various debugging metrics and utils\n- hybrid optimizer\n\nand the megatron backend now has a wider list of models supported:\n- DeepSeek-V3\n- Moonlight\n- Qwen3\n- Qwen2.5-VL (to be merged soon)\n- Qwen2\n- Mixtral\n\n## Getting Started\n\n### preparation\nThe recommended image with pre-built Megatron dependency is `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.13.0-preview`, which is built using the Dockerfile at [docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview](https://github.com/volcengine/verl/blob/main/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview).\n\nThe image is build in Hopper GPUs with DeepEP. It does not support None-Hopper GPUs, such as A100. You may need to reinstall DeepEP to work with A100.\n\nWith `OFFLOAD_FRACTION=1`, the system's minimum requirements are lowered. It can run on as few as 96 H20 (96GB) GPUs for DeepSeek-V3, and on as few as 32 H20 (96GB) GPUs for Qwen3-235B-A22B. However, this configuration will use 1.6TB CPU memory per node. If you run out of CPU memory or require faster training speed, you can add more nodes.\n\n### DeepSeek 671b\n\nFor DeepSeek-V3 671b, please refer to [examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh).\n\nMTP and quantilization is disabled during RL training.\n\nTo train your project, configure the following environment variables based on the number of available GPUs. These are recommended settings and can be adjusted based on your specific hardware.\n| num gpus | NNODES | TP | PP | EP | OFFLOAD_FRACTION | OFFLOAD_OPTIM | LAST_LAYER |\n| -- | -- | -- | -- | -- | -- | -- | -- |\n| 96 | 12 | 8 | 12 | 8 | 1. | False | 6 |\n| 128 | 16 | 8 | 16 | 8 | 0.5 | True | 1 |\n| 256 | 32 | 8 | 16 | 8 | 0. | True | 1 |\n| 512 | 64 | 1 | 16 | 32 | 0 | True | 1 |\n\n### Qwen3 235b\n\nFor Qwen3-235b, please refer to [examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh).\n\nTo train your project, configure the following environment variables based on the number of available GPUs. These are recommended settings and can be adjusted based on your specific hardware.\n| num gpus | NNODES | TP | PP | EP | OFFLOAD_FRACTION | OFFLOAD_OPTIM | LAST_LAYER |\n| -- | -- | -- | -- | -- | -- | -- | -- |\n| 32 | 4 | 4 | 8 | 4 | 1. | False | 6 |\n| 64 | 8 | 4 | 8 | 4 | 0.5 | True | 6 |\n| 128 | 16 | 4 | 8 | 4 | 0 | True | 6 |\n| 256 | 32 | 4 | 8 | 4 | 0 | True | 6 |\n\n### Benchmark\nHere are some benchmark results for DeepSeek / Qwen3-235B. 
### Benchmark\nHere are some benchmark results for DeepSeek / Qwen3-235B. All configurations match the recommended settings for the given number of GPUs.\n\n| model | num gpus | mean response length | rollout time(s) | GPU memory(GB) | CPU memory(GB) | MFU | step time(s) |\n| -- | -- | -- | -- | -- | -- | -- | -- |\n| DeepSeek 671b | 96 | 1960 | 1050 | 66 | 1500 | 0.19 | 1700 |\n\n### Qwen3-30B-A3B MoE\n\nFor Qwen3-30b, please refer to [examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh).\n\nTo train your project, configure the following environment variables based on the number of available GPUs. These are recommended settings and can be adjusted based on your specific hardware.\n| num gpus | NNODES | TP | PP | EP | OFFLOAD_FRACTION | OFFLOAD_OPTIM | MFU |\n| -- | -- | -- | -- | -- | -- | -- | -- |\n| 8 | 1 | 1 | 1 | 8 | 1. | True | 0.4 |\n| 16 | 2 | 1 | 1 | 8 | 1. | True | 0.37 |\n| 32 | 4 | 1 | 1 | 8 | 1. | True | 0.31 |\n\n\n## Upcoming Optimizations\n\nThe community continues to optimize large MoE models further; ongoing efforts include:\n- further optimizing memory consumption and providing recommended/tuned configurations for various machine types\n- optimizing long-context RL training performance\n- performance improvements with SGLang x Megatron\n\nWe invite the community to try and improve verl together. Get connected with us on [slack](https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA)/[wechat](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG)/[Github issues](https://github.com/volcengine/verl/issues/708)!\n\n## Acknowledgement\n@vermouth1992 @ISEEKYAN @ETOgaosion @yzlnew @ShareLer @BearBiscuit05 @ccclyu @ann-qin-lu @SwordFaith @zzong2006 @zhaochenyang20 @ocss884 @eric-haibin-lin @chenhaiq @techkang\n"
  },
  {
    "path": "verl_distillation/docs/perf/nsight_profiling.md",
    "content": "# NVIDIA Nsight Systems profiling in verl\n\nLast updated: 06/20/2025.\n\nThis guide explains how to use NVIDIA Nsight Systems for profiling verl training runs.\n\n## Configuration\n\nProfiling in verl can be configured through several parameters in the trainer configuration file (ppo_trainer.yaml or other files like dapo_trainer.yaml):\n\n### Prerequisites\n\nNsight Systems version is important, please reference `docker/Dockerfile.vllm.sglang.megatron` for the version we used.\n\n### Global profiling control\n\nverl has one single controller process and multiple worker processes. Both controller and worker processes can be profiled. Since the controller process can be executed in any nodes in the cluster, there is a message printed in the logging to indicate the controller process node hostname and process id.\n\nIn `global_profiler`, three new config entries control the profiler behaviors:\n\n* **`global_profiler.steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling.\n\n* **`global_profiler.profile_continuous_steps`**. If true, and the following `global_profiler.discrete==False`, then the continuous steps in `global_profiler.steps` will be combined into one database. For example the above step 1 and 2 are in one database, and 5 in another. If false, every step occupies at least one database. The reason for this config is to observe the program behaviors between steps.\n\nNsys options in controller nodes and worker nodes are configured in `global_profiler.global_tool_config.nsys`:\n\n* **`global_profiler.global_tool_config.nsys.controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details.\n* **`global_profiler.global_tool_config.nsys.worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: \"cudaProfilerApi\"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`.\n\n### Worker process profiling\n\nVerl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields:\n\n* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_<PID>.<RID>.nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID.\n* **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. 
### Where to find the profiling data\n\nBy default, the `*.nsys-rep` files are saved in the directory `/tmp/ray/session_latest/logs/nsight/` on each node. According to the Ray manual, this default directory is not changeable: [\"however, Ray preserves the `--output` option of the default config\"](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html).\n\nThis may be inconvenient, but it is understandable: Ray may start hundreds of processes, and saving the files in one central place would put heavy pressure on a networked file system.\n\n## Usage Example\n\nTo enable profiling for specific components and steps, modify your ppo_trainer.yaml like this:\n\n### Disable profiler\n\n```yaml\n    global_profiler:\n        steps: null # disable profiling\n```\n\n### Enable profiler and one database for one training step\n\n```yaml\n    global_profiler:\n        steps: [1, 2, 5]\n        discrete: False\n    actor_rollout_ref:\n        actor:\n            profiler:\n                enable: True\n                all_ranks: True\n        # rollout & ref follow actor settings\n    critic:\n        profiler:\n            enable: True\n            all_ranks: True\n    reward_model:\n        profiler:\n            enable: True\n            all_ranks: True\n```\n\n### Enable profiler and multiple databases for one training step\n\n```yaml\n    global_profiler:\n        steps: [1, 2, 5]\n        discrete: True\n```\n\n## Profiling Output\n\nWhen profiling is enabled, verl will generate Nsight Systems profiles for the specified components and steps. The profiles will include:\n\n- CUDA kernel execution\n- Memory operations\n- CPU-GPU synchronization\n- NVTX markers for key operations\n\nNsight Systems supports a multi-report view that opens multiple databases together. In this mode, different processes and steps can be aligned on one timeline for better analysis.\n"
  },
  {
    "path": "verl_distillation/docs/perf/perf_tuning.rst",
    "content": "Performance Tuning Guide\n==============================\n\nLast updated: 07/17/2025.\n\nAuthor: `Guangming Sheng <https://github.com/PeterSH6>`_, `Jiali Zheng <https://github.com/CurryRice233>`_\n\nIn this section, we will discuss how to tune the performance of all the stages in verl, including:\n\n1. Rollout generation throughput.\n\n2. Enable ``use_remove_padding=True`` for sequence packing (i.e., data packing and remove padding).\n\n3. Batch size tuning for forward and backward computation\n\n4. Enable ``use_dynamic_bsz=True`` for higher throughput.\n\n5. Utilize Ulysses Sequence Parallel for Long Context Training\n\n6. LigerKernel for SFT performance optimization\n\n7. Forward prefetch in FSDP training backend\n\n8. Memory optimization for entropy calculation from logits\n\nRollout Generation Tuning\n--------------------------\n\nverl currently supports two rollout backends: vLLM and TGI (with SGLang support coming soon). \n\nBelow are key factors for tuning vLLM-based rollout. Before tuning, we recommend setting ``actor_rollout_ref.rollout.disable_log_stats=False`` so that rollout statistics are logged.\n\n- Increase ``gpu_memory_utilization``.\n\n  - For vLLM v0.7.0 and later, the vLLM instance will only use gpu_memory_utilization of the **total** memory.\n  - For SGLang, it's the fraction of the free GPU memory used for **static** memory like model weights and KV cache. However, the remaining (1-gpu_memory_utilization) will also be used during inference.\n\n  However, if model parameters and optimizer states are not offloaded, using too high a fraction can lead to OOM. \n  A value between 0.5 and 0.7 often strikes a good balance between high throughput and avoiding OOM.\n\n  Note: since the definition of ``gpu_memory_utilization`` varies across inference engines, a value that works well for one engine may cause OOM for another.\n\n- Adjust ``max_num_seqs`` or ``max_num_batched_tokens``.\n  If the GPU cache utilization is relatively low in the log, increase ``max_num_seqs`` or ``max_num_batched_tokens`` \n  can enlarge the effective batch size in the decoding stage, allowing more concurrent requests per batch. \n  We recommend setting ``max_num_batched_tokens > 2048`` for higher throughput.\n\n- Use a smaller ``tensor_parallel_size``. \n  When GPU resources allow, a smaller tensor parallel size spawns more vLLM replicas. \n  Data parallelism (DP) can yield higher throughput than tensor parallelism (TP), but also increases KVCache consumption. \n  Carefully balance the trade-off between more replicas and higher memory usage.\n  Our experiment in Sec. 8.4 of `HybridFlow paper <https://arxiv.org/pdf/2409.19256v2>`_ evaluate this trade-off.\n\n- Balance performance and memory using ``cudagraph_capture_sizes``.\n  If ``cudagraph_capture_sizes`` is set, vLLM will try to capture the model execution graph for different batch sizes.\n  Since cudagraph memory can not be offloaded to cpu, The memory stay in gpu when update actor is running. \n  Using smaller batch sizes can avoid OOM but slightly reduce throughput.\n  Must to set ``enforce_eager=False`` to use ``cudagraph_capture_sizes``.\n\nMore tuning details such as dealing with Preemption and Chunked-prefill\ncan be found in `vLLM official tuning guide <https://docs.vllm.ai/en/latest/performance/optimization.html>`_ \n\nFor optimal performance, we recommend using vLLM v0.8.3 or later. 
Enable remove padding (sequence packing)\n-----------------------------------------\n\nCurrently, for llama, mistral, gemma1 and qwen based models, users can enable ``use_remove_padding=True`` to utilize the \nsequence packing implementation provided by the transformers library.\n\nFor other models, the transformers library may also support it, but we haven't tested it yet.\nUsers can add the desired model config to the `test_transformer.py <https://github.com/volcengine/verl/blob/main/tests/models/test_transformer.py#L24>`_ file\nand test its functionality by running the following command:\n\n.. code-block:: bash\n\n  pytest -s tests/models/test_transformer.py\n\nIf the test passes, you can add your desired model to the model `registry.py <https://github.com/volcengine/verl/blob/main/verl/models/registry.py#L24>`_ file.\nThen, you can enjoy the performance boost of sequence packing,\nand you are welcome to submit a PR with your tested model to verl!\n\n\nBatch Size Tuning\n-----------------\n\nTo achieve higher throughput in experience preparation (i.e., model fwd) and model update (i.e., actor/critic fwd/bwd), \nusers may need to tune ``*micro_batch_size_per_gpu`` for the different computations.\n\nIn verl, the core principle for setting batch sizes is:\n\n- **Algorithmic metrics** (train batch size, PPO mini-batch size) are *global* (from a single-controller perspective), \n  normalized in each worker. See the `normalization code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py#L120-L122>`_.\n\n- **Performance-related parameters** (micro batch size, max token length for dynamic batch size) are *local* parameters that define the per-GPU data allocation. \n  See the `normalization code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py#L127>`_.\n\n.. note:: In your training script, please use ``*micro_batch_size_per_gpu`` instead of ``*micro_batch_size``, \n  so that you don't need to consider the normalization of ``micro_batch_size``; ``micro_batch_size`` will be deprecated.\n\nBatch Size Tuning tips\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nUsers may need to tune ``*micro_batch_size_per_gpu`` to accelerate training. Here are some tips (a sketch of how these knobs relate follows the list):\n\n1. **Enable gradient checkpointing**: \n   Set ``actor_rollout_ref.model.enable_gradient_checkpointing=True`` and ``critic.model.enable_gradient_checkpointing=True``. \n   This often allows for larger micro-batch sizes and is beneficial for large mini-batch training.\n\n2. Increase ``*micro_batch_size_per_gpu`` as much as possible, until it equals the normalized ``mini_batch_size``.\n\n3. **Use larger forward-only parameters**: \n   Forward-only parameters, such as ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``, \n   ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``, and ``critic.forward_micro_batch_size_per_gpu``, can be larger (e.g., 2x) than training-related micro batch sizes,\n   such as ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`` and ``critic.ppo_micro_batch_size_per_gpu``.\n\n4. **Allow larger micro-batch sizes for Critic and Reward models**:\n   The micro batch size of the Critic and Reward models can be larger than that of the Actor model, because the Actor computes logits over the full vocabulary in its final layer while the Critic and Reward models output scalar values.\n\n5. **Enable activation offloading**:\n   Set ``actor_rollout_ref.model.enable_activation_offload=True`` and ``critic.model.enable_activation_offload=True``.\n   This often works together with gradient checkpointing to allow larger micro-batch sizes, and it's currently only available in the FSDP backend.\n
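\nA minimal sketch of how these knobs might relate in practice (the values are hypothetical; the parameter names are the ones discussed above, printed in the ``key=value`` override style used by verl's example scripts):\n\n.. code-block:: python\n\n    # Hypothetical starting point for the batch-size knobs discussed above.\n    train_micro_bsz = 4  # tune upward until it hits the normalized mini_batch_size\n\n    overrides = {\n        # tip 1: gradient checkpointing usually enables larger micro-batches\n        \"actor_rollout_ref.model.enable_gradient_checkpointing\": True,\n        # training fwd/bwd micro-batch size\n        \"actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu\": train_micro_bsz,\n        # tip 4: critic/reward can often be larger than the actor\n        \"critic.ppo_micro_batch_size_per_gpu\": 2 * train_micro_bsz,\n        # tip 3: forward-only passes can be ~2x the training micro-batch size\n        \"actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu\": 2 * train_micro_bsz,\n        \"actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu\": 2 * train_micro_bsz,\n    }\n    print(\" \".join(f\"{k}={v}\" for k, v in overrides.items()))\n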
Tuning for Dynamic Batch Size\n-----------------------------\n\nDynamic batch size is a technique that allows the model to process a similar number of tokens in each forward pass (with varying actual batch sizes).\nThis can significantly improve training efficiency and reduce memory usage.\n\nTo utilize this technique, users can set ``use_dynamic_bsz=True`` for the actor, ref, critic, and reward models.\nWith ``use_dynamic_bsz=True``, users don't need to tune ``*micro_batch_size_per_gpu``. \nInstead, users should tune the following parameters:\n\n- ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``, ``critic.ppo_max_token_len_per_gpu``: \n  The maximum number of tokens to be processed in the fwd and bwd passes of ``update_policy`` and ``update_critic``.\n\n- ``actor_rollout_ref.ref.log_prob_max_token_len_per_gpu`` and ``actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu``: \n  The maximum number of tokens to be processed in the fwd computation of ``compute_log_prob`` and ``compute_ref_log_prob``.\n\n- ``critic.forward_micro_batch_size_per_gpu``, ``reward_model.forward_micro_batch_size_per_gpu``: \n  The maximum number of tokens to be processed in the fwd computation of ``compute_values`` and ``compute_rm_score``.\n\nDynamic Batch Size Tuning tips\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nHere are some tips for tuning the above parameters:\n\n1. **Increase** ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``.  \n   Make it at least 2 x (max_prompt_length + max_response_length). We set it to 3x in `run_qwen2-7b_rm_seq_balance.sh <https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh#L25>`_.\n   Try to increase it further for higher throughput.\n\n2. **Forward-only parameters can be larger**: \n   Similar to the non-dynamic-batch scenario, forward-only token limits can exceed those used in forward/backward operations.\n \n3. **Use larger limits for Critic and Reward models**:\n   The Critic and Reward limits can be set to at least 2x the Actor's limits. For instance, we set them to 4x here:  \n   `run_qwen2-7b_rm_seq_balance.sh <https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh#L40>`_\n\nUlysses Sequence Parallel for Long Context Training\n----------------------------------------------------\n\nTo utilize this technique, users can set ``ulysses_sequence_parallel_size>1`` for the actor, ref, critic, and reward models.\n\nDifferent models may use different ``ulysses_sequence_parallel_size`` values.\n\nTo train on long sequences (>32k), users may need to decrease ``*micro_batch_size_per_gpu`` and ``*max_token_len_per_gpu`` to avoid OOM.\n\nLigerKernel for SFT\n----------------------\n\nLigerKernel is a collection of high-performance kernels for Supervised Fine-Tuning (SFT) that can improve training efficiency. To enable LigerKernel in your SFT training:\n\n1. Install liger-kernel via ``pip3 install liger-kernel``. Then, in your SFT configuration file (e.g., ``verl/trainer/config/sft_trainer.yaml``), set the ``use_liger`` parameter:\n\n   .. 
code-block:: yaml\n\n      model:\n        use_liger: True  # Enable LigerKernel for SFT\n\n2. The default value is ``False``. Enable it only when you want to use LigerKernel's optimizations.\n\n3. LigerKernel is particularly useful for improving training performance in SFT scenarios.\n\nForward prefetch in FSDP training backend\n-----------------------------------------\n\nDuring the training phase, users can enable forward prefetching in FSDP by setting ``fsdp_config.forward_prefetch=True``, for example, ``actor_rollout_ref.actor.fsdp_config.forward_prefetch=True``. This configuration prefetches the next forward-pass all-gather operation before completing the current forward computation, overlapping communication with computation and improving efficiency. For further details, refer to the `FSDP forward_prefetch <https://docs.pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp>`_ documentation.\n\n.. note::\n    Backward prefetch is unsupported because the ``BACKWARD_POST`` policy may prefetch incorrectly in nested-module cases. For details, see the `FSDP documentation <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md?plain=1#L70>`_.\n\nMigrating to FSDP2\n----------------------\n\nFSDP2 offers notable improvements over FSDP1. According to `PyTorch TorchTitan benchmarks <https://arxiv.org/abs/2410.06511v1>`_:\n\n- 7% lower GPU memory usage on average\n- 1.5% throughput improvement with BF16 training\n- Better composability with DTensor and per-parameter sharding\n\n**Enabling FSDP2 in verl:**\n\n   .. code-block:: python\n\n    # Enable FSDP2 in the actor configuration\n    actor_rollout_ref.actor.strategy=\"fsdp2\"\n\n.. note:: \n   FSDP2 requires PyTorch 2.1+ and is recommended for models with transformer architecture.\n\nMemory optimization for entropy calculation from logits\n--------------------------------------------------------\n\nThe ``logits`` tensor (typically of shape ``[bsz*seq_len, voc]``) can consume significant memory. When using ``compute_entropy_from_logits``, memory usage reaches approximately ``[bsz*seq_len, voc] × (4 bytes (float32) + 2 bytes (autocast for softmax+logsumexp) + 1 byte (softmax output))``.\n\nTo reduce this memory peak, enable chunked computation by setting\n``actor_rollout_ref.ref.entropy_from_logits_with_chunking = True``.\nThis processes the tensor in chunks of shape ``[chunk_size, voc]`` (e.g., 2048 rows at a time) rather than the full sequence length, exclusively during the model's forward pass.\n\nAdditionally, during training, standard gradient checkpointing (``enable_gradient_checkpointing=True``) does not apply to entropy calculations. To reduce memory peaks in this context, set\n``actor_rollout_ref.actor.entropy_checkpointing = True``.\nThis enables recomputation specifically for the entropy calculation, lowering memory usage during training.\n"
  },
  {
    "path": "verl_distillation/docs/perf/verl_profiler_system.md",
    "content": "# verl Profiler System\n\nLast updated: 08/18/2025.\n\n## Architecture\n\nThe architecture of verl profiler system is like below:\n\n![verl-profiler-arch](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/2bc7ed0ba2f37f21707bfac3b241eca4b86d1bc6/docs/verl_profiler_arch.png)\n\nThere is a global profiler and tool configuration to set some common config in single controller level, deciding\n\n- `tool`: which tool to use\n- `steps`: which steps to profile\n- `save_path`: results saving path\n\nWhen some tool need to profile behavior of each role, configurations in role-level is needed:\n\n- `tool`: which tool to use\n- `enable`: whether enable profiling on this role\n- rank info: `all_ranks` and `rank` to decide which rank to profile or log output\n\nFor tool config in role-level, there are some detailed behavior needed to control, like the `discrete` mode in nsys profiler.\n\nEvery role has a profiler config, and by default, rollout/ref/reward models follow the Actor's behavior.\n\n## To Add a new profiling tool\n\nNew added profiling tool shall reuse the current APIs as much as possible.\n\n1. The logic of **whether to use the tool**: `tool == [new tool]`.\n2. Add the global and local tool config to `ppo_trainer.yaml`/`ppo_megatron_trainer.yaml` and each `[role].yaml`, under `global_tool_config.[new tool]` and `tool_config.[new tool]`\n3. The tool config should be implemented in `verl/utils/profiler/config.py`, inherit the `BaseConfig` class.\n4. Implement profiling tool initialization logic using configurations in `global_profiler.global_tool_config.[new tool]` and the results saving logics (can also save in role-level profile)\n5. For role function-level profiling, please follow the nsys profiler way in `nvtx_profiler.py`, implement a profiler class inherit `DistProfiler` and import new profiler in `verl/utils/profiler/__init__.py`\n6. Add unit test and examples for others to use in convinience."
  },
  {
    "path": "verl_distillation/docs/preparation/prepare_data.rst",
    "content": "Prepare Data for Post-Training\n========================================\n\nLast updated: 02/09/2025.\n\nBefore starting the post-training job, we need to prepare the data for\nthe policy training. The data should be stored in the parquet format.\n\nWe provide several data preprocess scripts for different datasets,\nincluding GSM8K, MATH, HelloSwag, Full_hh_rlhf. To prepare other datasets, we need\nto follow the following steps: The data preprocess script can be divided\ninto two parts:\n\n1. The first part is the common part, which loads the dataset from\n   huggingface's ``datasets`` package. Then preprocess the datasets with\n   the ``make_map_fn`` and then store in the parquet format.\n\n.. code:: python\n\n   import re\n   import os\n   import datasets\n\n   from verl.utils.hdfs_io import copy, makedirs\n   import argparse\n\n   # To extract the solution for each prompts in the dataset\n   # def extract_solution(solution_str): \n   # ...\n\n\n   if __name__ == '__main__':\n       parser = argparse.ArgumentParser()\n       parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')\n       parser.add_argument('--hdfs_dir', default=None)\n\n       args = parser.parse_args()\n\n       num_few_shot = 5\n       data_source = 'openai/gsm8k'\n\n       dataset = datasets.load_dataset(data_source, 'main')\n\n       train_dataset = dataset['train']\n       test_dataset = dataset['test']\n\n           # Construct a `def make_map_fn(split)` for the corresponding datasets.\n       # ...\n           \n       train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)\n       test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)\n\n       local_dir = args.local_dir\n       hdfs_dir = args.hdfs_dir\n\n       train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))\n       test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))\n\n       makedirs(hdfs_dir)\n\n       copy(src=local_dir, dst=hdfs_dir)\n\n2. The users are required to implement the ``make_map_fn()`` function\n   (as well as the ``extract_solution``) on their own to support\n   different datasets or tasks.\n\nWe already implemented the data preprocess of GSM8k, MATH, Hellaswag and Full_hh_rlhf\ndatasets. And we take the GSM8k dataset as an example:\n\n**GSM8K**\n\nIn the ``make_map_fn``, each data field should consist of the following\n5 fields:\n\n1. ``data_source``: The name of the dataset. To index the corresponding\n   reward function in the ``RewardModel``\n2. ``prompt``: This field should be constructed in the format of\n   huggingface chat_template. The tokenizer in ``RLHFDataset`` will\n   apply chat template and tokenize the prompt.\n3. ``ability``: Define the task category.\n4. ``reward_model``: Currently, we only utilize the ``ground_truth``\n   field during evaluation. The ``ground_truth`` is computed by the\n   ``extract_solution`` function. **NOTED** that the implementation of\n   the corresponding reward function should align with this extracted\n   ``ground_truth``.\n5. ``extra_info``: Record some information of the current prompt. Not\n   use for now.\n\n.. 
.. code:: python\n\n   def extract_solution(solution_str):\n       solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)  # extract the solution after ####\n       assert solution is not None\n       final_solution = solution.group(0)\n       final_solution = final_solution.split('#### ')[1].replace(',', '')\n       return final_solution\n\n   instruction_following = \"Let's think step by step and output the final answer after \\\"####\\\".\"\n\n   # add a row to each data item that represents a unique id\n   def make_map_fn(split):\n\n       def process_fn(example, idx):\n           question = example.pop('question')\n\n           question = question + ' ' + instruction_following\n\n           answer = example.pop('answer')\n           solution = extract_solution(answer)\n           data = {\n               \"data_source\": data_source,\n               \"prompt\": [{\n                   \"role\": \"user\",\n                   \"content\": question\n               }],\n               \"ability\": \"math\",\n               \"reward_model\": {\n                   \"style\": \"rule\",\n                   \"ground_truth\": solution\n               },\n               \"extra_info\": {\n                   'split': split,\n                   'index': idx\n               }\n           }\n           return data\n\n       return process_fn\n"
  },
  {
    "path": "verl_distillation/docs/preparation/reward_function.rst",
    "content": "Implement Reward Function for Dataset\n======================================\n\nLast updated: 06/02/2025.\n\nFor each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses.\nWe already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.\nYou can also use customized reward functions.\n\nCurrently, we support reward functions for GSM8k and MATH datasets. For RLHF datasets (e.g.,\nfull_hh_rlhf) and Code Generation (e.g., APPS), we utilize reward model\nand SandBox (will opensource soon) for evaluation respectively.\n\nRewardManager\n-------------\n\nIn the entrypoint of the PPO Post-Training script `main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py#L33>`_,\nwe implement a ``RewardManager`` that utilize pre-implemented reward functions to compute the scores for each response.\n\nIn the ``RewardManager``, we implemented a ``__call__`` function to\ncompute the score for each response. \nAll the reward functions are executed by ``compute_score_fn``.\nThe input is a ``DataProto``, which includes:\n\n- ``input_ids``, ``attention_mask``: ``input_ids`` and ``attention_mask`` after applying\n  chat_template, including prompt and response\n- ``responses``: response tokens\n- ``ground_truth``: The ground truth string of the current prompt.\n  Stored in ``non_tensor_batch`` in the ``DataProto``, which should be\n  preprocessed in the parquet files.\n- ``data_source``: The dataset name of the current prompt. Stored in\n  ``non_tensor_batch`` in the ``DataProto``, which should be\n  preprocessed in the parquet files.\n\nAfter detokenize the responses, the responses string and the ground\ntruth string will be input to the ``compute_score_fn`` to compute the\nscore for each response.\n\nReward Functions\n----------------\n\nPre-implemented\n~~~~~~~~~~~~~~~\n\nWe already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.\n\n- In the `GSM8k example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_, we\n  force the response to output the final answer after four ####, then\n  use string matching to compare with the ground truth. If completely\n  correct, score 1 point; if the format is correct, score 0.1 points; if\n  the format is incorrect, score 0 points.\n- In the `MATH example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_, we follow\n  the implementation in `lm-evaluation-harness repository <https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py>`_.\n\nCustomized\n~~~~~~~~~~\n\nYou can implement customized reward functions in a separate file and specify them using ``custom_reward_function.path`` and ``custom_reward_function.name``. For the set of them, please refer to :ref:`config-explain-page`.\n\nThe parameters of your reward function should be ``data_source``, ``solution_str``, ``ground_truth``, and ``extra_info``.\nFor example:\n\n.. 
code:: python\n\n  def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None):\n    return len(solution_str)/100\n\nIf you are testing only a single customized reward function, you can simply name it ``compute_score`` and leave ``custom_reward_function.name`` unset.\n\nTo run multiple trials with different customized reward functions, you can modify both ``custom_reward_function.path`` and ``custom_reward_function.name`` for each trial. \nFor instance, you might create a single ``my_reward.py`` file and implement multiple reward functions within it. This way, for different trials, you only need to adjust ``custom_reward_function.name``, making it more convenient to conduct multiple tests within scripts.\n
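\nA small sketch of such a ``my_reward.py`` (the function bodies are illustrative; only the required signature comes from above):\n\n.. code:: python\n\n  # my_reward.py -- several custom reward functions in one file,\n  # selected per trial via custom_reward_function.name.\n\n  def compute_score(data_source, solution_str, ground_truth, extra_info=None):\n      # Used when custom_reward_function.name is unset.\n      return float(solution_str.strip() == ground_truth)\n\n  def length_penalized_score(data_source, solution_str, ground_truth, extra_info=None):\n      # Select with custom_reward_function.name=length_penalized_score.\n      exact = float(solution_str.strip() == ground_truth)\n      return exact - 0.001 * len(solution_str)\n"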
  },
  {
    "path": "verl_distillation/docs/requirements-docs.txt",
    "content": "# markdown support\r\nrecommonmark\r\nmyst_parser\r\n# markdown table support\r\nsphinx-markdown-tables\r\n\r\n# theme default rtd\r\n\r\n# crate-docs-theme\r\nsphinx-rtd-theme\r\n\r\n# pin tokenizers version to avoid env_logger version req\r\ntokenizers==0.21\r\n"
  },
  {
    "path": "verl_distillation/docs/sglang_multiturn/interaction_system.rst",
    "content": "Interaction System for Multi-turn RL Training\n=============================================\n\nLast updated: 06/25/2025.\n\nOverview\n--------\n\nThe verl interaction system enables dynamic, multi-turn conversational feedback during reinforcement learning training. This system allows models to engage in iterative problem-solving scenarios where interaction agents can provide corrective feedback, guidance, or evaluation based on the model's responses.\n\n**New in Multi-Interaction Support**: The system now supports multiple named interactions within a single training session, enabling sophisticated training scenarios where different samples can use different interaction strategies. This allows for curriculum learning, domain-specific feedback, and flexible agent switching at the sample level.\n\nKey features:\n\n- **Async-based Architecture**: Non-blocking interaction processing for distributed training\n- **Instance Management**: Stateful session handling with unique instance IDs for concurrent interactions\n- **SGLang Integration**: Seamless integration with SGLang rollout system for multi-turn conversations\n- **Configuration-driven**: Dynamic agent loading via YAML configuration files\n- **Multi-Interaction Support**: Registry system enabling multiple named interactions per rollout\n- **Sample-Level Selection**: Each sample can specify which interaction to use via configuration\n- **Reward Integration**: Turn-level scoring mechanism integrated with verl's reward system\n\nArchitecture\n------------\n\nThe interaction system follows a plugin-based architecture with clear separation of concerns:\n\n.. code-block::\n\n    Interaction Registry System\n         ↓\n    BaseInteraction (Abstract Interface)\n         ↓\n    Multiple Named Interactions (e.g., Gsm8kInteraction, CustomInteraction)\n         ↓\n    SGLang Rollout Integration (interaction_map)\n         ↓\n    Sample-Level Interaction Selection\n         ↓\n    Async Request Lifecycle Management\n\nCore Components\n~~~~~~~~~~~~~~~\n\n**Interaction Registry System**\n\nThe interaction registry system allows loading and managing multiple named interactions:\n\n.. code-block:: python\n\n    from verl.interactions.utils.interaction_registry import initialize_interactions_from_config\n    \n    # Load multiple interactions from config\n    interaction_map = initialize_interactions_from_config(\"config.yaml\")\n    \n    # Access specific interaction by name\n    gsm8k_interaction = interaction_map[\"gsm8k\"]\n    custom_interaction = interaction_map[\"custom_solver\"]\n\n**BaseInteraction Interface**\n\nAll interaction agents must implement the ``BaseInteraction`` abstract class:\n\n.. 
code-block:: python\n\n    from verl.interactions.base import BaseInteraction\n    from typing import Dict, Any, List, Tuple, Optional\n\n    class BaseInteraction:\n        def __init__(self, config: Dict[str, Any]):\n            self.config = config\n            self.name: str = config.get(\"name\", \"interaction_agent\")\n        \n        async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:\n            \"\"\"Initialize interaction session, return instance_id\"\"\"\n            \n        async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]:\n            \"\"\"Generate response, return (should_terminate, response, score, metadata)\"\"\"\n            \n        async def calculate_score(self, instance_id: str, **kwargs) -> float:\n            \"\"\"Calculate turn-level score for RL training\"\"\"\n            \n        async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n            \"\"\"Clean up resources\"\"\"\n\n**Request Lifecycle**\n\nThe interaction system integrates with SGLang's async rollout via state management:\n\n1. ``PENDING`` → Initialize interaction via ``start_interaction()``\n2. ``GENERATING`` → Model generates response\n3. ``INTERACTING`` → Process response via ``generate_response()``\n4. ``GENERATING`` → Continue if not terminated, otherwise ``COMPLETED``\n\nConfiguration\n-------------\n\n**Basic Setup**\n\nEnable interaction in your rollout configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            multi_turn:\n                enable: true\n                interaction_config_path: \"path/to/interaction_config.yaml\"\n                max_user_turns: 10\n                max_assistant_turns: 10\n\n**Interaction Configuration File**\n\nCreate an interaction configuration file (e.g., ``interaction_config.yaml``):\n\n**Single Interaction (Legacy Format)**\n\n.. code-block:: yaml\n\n    interaction:\n      - name: \"gsm8k\"\n        class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n\n**Multiple Interactions (New Format)**\n\n.. code-block:: yaml\n\n    interaction:\n      - name: \"gsm8k\"\n        class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n      - name: \"custom_solver\"\n        class_name: \"custom.interactions.CustomInteraction\"\n        config: \n          solver_type: \"advanced\"\n          timeout: 30\n      - name: \"code_verifier\"\n        class_name: \"verl.interactions.base.BaseInteraction\"\n        config: \n          verification_mode: \"strict\"\n\n**Automatic Name Generation**\n\nIf no ``name`` field is provided, the system will automatically generate one from the class name:\n\n.. code-block:: yaml\n\n    interaction:\n      - class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n        # Automatically generates name: \"gsm8k\"\n\nThe system will dynamically load all specified interaction classes and make them available by name.\n\nImplementation Example: GSM8K\n-----------------------------\n\nThe GSM8K interaction demonstrates a complete implementation for math problem-solving scenarios:\n\n.. 
code-block:: python\n\n    from verl.interactions.base import BaseInteraction\n    from verl.utils.reward_score import gsm8k\n    from uuid import uuid4\n\n    class Gsm8kInteraction(BaseInteraction):\n        def __init__(self, config: dict):\n            super().__init__(config)\n            self._instance_dict = {}\n\n        async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs):\n            if instance_id is None:\n                instance_id = str(uuid4())\n            self._instance_dict[instance_id] = {\n                \"response\": \"\",\n                \"ground_truth\": ground_truth,\n                \"reward\": 0.0,\n            }\n            return instance_id\n\n        async def generate_response(self, instance_id, messages, **kwargs):\n            # Extract last assistant message content\n            content = \"\"\n            for item in reversed(messages):\n                if item.get(\"role\") == \"assistant\":\n                    content = item.get(\"content\", \"\")\n                    break\n\n            # Ensure GSM8K format (#### prefix)\n            self._instance_dict[instance_id][\"response\"] = content\n\n            reward = await self.calculate_score(instance_id)\n            if reward == 1.0:\n                return True, \"Your response is correct!\", 1.0, {}\n            else:\n                return False, \"Your response is incorrect! You need to reflect on your answer and try again.\", 0.0, {}\n\n        async def calculate_score(self, instance_id, **kwargs):\n            return gsm8k.compute_score(\n                self._instance_dict[instance_id][\"response\"],\n                self._instance_dict[instance_id][\"ground_truth\"],\n                method=\"strict\", format_score=0.0, score=1.0,\n            )\n\n        async def finalize_interaction(self, instance_id, **kwargs):\n            del self._instance_dict[instance_id]\n\nTraining Integration\n--------------------\n\n**Training Script Configuration**\n\nInclude interaction configuration in your training command:\n\n.. code-block:: bash\n\n    python3 -m verl.trainer.main_ppo \\\\\n        --config-path=\"$CONFIG_PATH\" \\\\\n        --config-name='gsm8k_multiturn_grpo_w_interaction' \\\\\n        algorithm.adv_estimator=grpo \\\\\n        data.train_batch_size=512 \\\\\n        data.return_raw_chat=True \\\\\n        actor_rollout_ref.rollout.name=sglang \\\\\n        actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\\\n        trainer.total_epochs=15\n\n**Data Requirements**\n\nEnsure your dataset includes interaction parameters with the ``name`` field for interaction selection:\n\n.. code-block:: python\n\n    # Dataset should include interaction_kwargs in non_tensor_batch\n    interaction_kwargs = [\n        {\"name\": \"gsm8k\", \"query\": \"What is 2+2?\", \"ground_truth\": \"4\"},\n        {\"name\": \"custom_solver\", \"query\": \"Solve: x^2 + 5x + 6 = 0\", \"ground_truth\": \"x = -2, -3\"},\n        {\"name\": \"gsm8k\", \"query\": \"What is 3+3?\", \"ground_truth\": \"6\"},\n    ]\n\n**Sample-Level Interaction Selection**\n\nEach sample can specify which interaction to use via the ``name`` field. This enables flexible training scenarios where different samples use different interaction strategies:\n\n.. 
code-block:: python\n\n    # Example: Math problems use GSM8K interaction, code problems use code verifier\n    data_samples = [\n        {\n            \"prompt\": \"What is 15% of 200?\",\n            \"interaction_kwargs\": {\n                \"name\": \"gsm8k\",\n                \"query\": \"What is 15% of 200?\", \n                \"ground_truth\": \"30\"\n            }\n        },\n        {\n            \"prompt\": \"Write a function to check if a number is prime\",\n            \"interaction_kwargs\": {\n                \"name\": \"code_verifier\",\n                \"code_type\": \"python\",\n                \"expected_behavior\": \"return True for prime numbers\"\n            }\n        }\n    ]\n\n**Backward Compatibility**\n\nIf no ``name`` field is provided in ``interaction_kwargs``, the system defaults to ``\"gsm8k\"`` for backward compatibility.\n\nBest Practices\n--------------\n\n**Resource Management**\n\n- Always implement proper cleanup in ``finalize_interaction()``\n- Use unique instance IDs to avoid conflicts in concurrent training\n- Handle edge cases like empty messages or malformed content\n\n**Performance Optimization**\n\n- Keep interaction logic lightweight to avoid blocking training\n- Use async/await properly to maintain non-blocking behavior\n- Consider caching expensive computations within interaction instances\n\n**Testing**\n\nComprehensive testing is essential for interaction systems:\n\n.. code-block:: python\n\n    import pytest\n    from unittest.mock import patch\n\n    @pytest.mark.asyncio\n    async def test_interaction_workflow():\n        interaction = YourInteraction({})\n        \n        # Test complete workflow\n        instance_id = await interaction.start_interaction(ground_truth=\"expected_answer\")\n        \n\n        messages = [{\"role\": \"user\", \"content\": \"user_content\"}, {\"role\": \"assistant\", \"content\": \"assistant_content\"}]\n        should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)\n        \n        assert should_terminate in [True, False]\n        assert isinstance(reward, float)\n        \n        await interaction.finalize_interaction(instance_id)\n\nAdvanced Usage\n--------------\n\n**Multi-Interaction Training Strategies**\n\nYou can design sophisticated training scenarios using multiple interactions:\n\n.. 
code-block:: python\n\n    # Example: Progressive difficulty with different interaction agents\n    class MathTrainingPipeline:\n        def create_interaction_config(self):\n            return {\n                \"interaction\": [\n                    {\n                        \"name\": \"basic_math\",\n                        \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                        \"config\": {\"difficulty\": \"easy\"}\n                    },\n                    {\n                        \"name\": \"advanced_math\", \n                        \"class_name\": \"custom.interactions.AdvancedMathInteraction\",\n                        \"config\": {\"difficulty\": \"hard\", \"allow_hints\": True}\n                    },\n                    {\n                        \"name\": \"competition_math\",\n                        \"class_name\": \"custom.interactions.CompetitionMathInteraction\", \n                        \"config\": {\"time_limit\": 300, \"show_steps\": False}\n                    }\n                ]\n            }\n    \n        def create_curriculum_data(self, epoch):\n            if epoch < 5:\n                return [{\"name\": \"basic_math\", ...} for _ in samples]\n            elif epoch < 10:\n                return [{\"name\": \"advanced_math\", ...} for _ in samples]\n            else:\n                return [{\"name\": \"competition_math\", ...} for _ in samples]\n\n**Custom Scoring Functions**\n\nYou can integrate custom reward functions:\n\n.. code-block:: python\n\n    async def calculate_score(self, instance_id, **kwargs):\n        response = self._instance_dict[instance_id][\"response\"]\n        ground_truth = self._instance_dict[instance_id][\"ground_truth\"]\n        \n        # Custom evaluation logic\n        if custom_evaluation_function(response, ground_truth):\n            return 1.0\n        else:\n            return 0.0\n\n**Multi-step Interactions**\n\nFor complex scenarios requiring multiple feedback rounds:\n\n.. code-block:: python\n\n    async def generate_response(self, instance_id, messages, **kwargs):\n        instance = self._instance_dict[instance_id]\n        instance[\"attempts\"] += 1\n        \n        # Evaluate current response\n        reward = await self.calculate_score(instance_id)\n        \n        if reward > 0.8:\n            return True, \"Excellent work!\", reward, {}\n        elif instance[\"attempts\"] < 3:\n            return False, \"Good attempt, but try to improve...\", reward, {}\n        else:\n            return True, \"Maximum attempts reached.\", reward, {}\n\nTroubleshooting\n---------------\n\n**Common Issues**\n\n1. **Instance ID Conflicts**: Ensure unique instance IDs across concurrent sessions\n2. **Memory Leaks**: Always call ``finalize_interaction()`` to clean up resources\n3. **Blocking Operations**: Keep interaction logic async and non-blocking\n4. **Configuration Errors**: Verify interaction config path and class name are correct\n5. **Interaction Name Conflicts**: Ensure all interactions have unique names in the configuration\n6. **Missing Interaction**: Verify the ``name`` field in ``interaction_kwargs`` matches available interactions\n7. **Backward Compatibility**: When migrating from single to multi-interaction, add ``name`` fields to existing data\n\n**Debugging**\n\nEnable debug logging to trace interaction flow:\n\n.. 
code-block:: bash\n\n    export VERL_LOGGING_LEVEL=DEBUG\n\n**Performance Monitoring**\n\nMonitor interaction performance impact on training throughput and adjust accordingly.\n\nRelated Documentation\n---------------------\n\n- :doc:`multiturn`: Basic multi-turn rollout configuration\n- :doc:`sandbox_fusion`: Tool integration with SGLang\n- :doc:`search_tool_example`: Search tool implementation example"
  },
  {
    "path": "verl_distillation/docs/sglang_multiturn/multiturn.rst",
    "content": "Multi-turn Rollout Support\n==========================\n\nLast updated: 06/27/2025.\n\nBasic Configuration\n~~~~~~~~~~~~~~~~~~~\n\nTo enable multi-turn rollout, make sure to configure the following fields in your rollout configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref: \n        rollout: \n            multi_turn: True\n            name: \"sglang\"\n\nThese configuration activates the sglang engine for multi-turn interaction during rollout.\n\nCustom Tool Configuration\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor custom environment interaction tools, you can implement your own tools based on ``verl.tools.base_tool.BaseTool``. Then, specify your tool configurations in a YAML file:\n\n.. code-block:: yaml\n\n    tools:\n      - class_name: \"\"\n        config: \n            type: native\n        tool_schema:\n\nYou may refer to GSM8KTool_example_configuration_, which is one example of the tool configurations. Its implementation can be found in gsm8k_tool.py_.\n\nFinally, set the ``tools_config_file`` in your rollout config:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            tool_kwargs:\n                tools_config_file: <path_to_tool_yaml_file>\n\nThis allows integration of customized tool behaviors during actor rollout steps.\n\nIf you want rollout with simulated interaction, you can set the ``interaction_config_file`` in your rollout config:\n\n.. code-block:: yaml\n\n    interaction:\n      - class_name: \"\"\n        config: {}\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            interaction_config_file: <path_to_interaction_yaml_file>\n\nIf your tool creates multi-modal inputs, you should return a list of multi-modal inputs in your tool.execute() implementation.\n\nImage and video should be processed before returning. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations:\n\n.. code-block:: python\n\n    async def create(self, ...) -> tuple[str, ToolResponse]:\n        ...\n        from verl.utils.dataset.vision_utils import process_image, process_video\n\n        img1 = process_image(img1)\n        video1 = process_video(video1)\n\n        # due to the (image | video) key is (\"image\" | \"video\") instead of (\"images\" | \"videos\") in vllm, we need to use (\"image\" | \"video\") to specify list of images/videos\n        # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n        return instance_id, ToolResponse(image=[img1, ...], video=[video1, ...], text=\"...\")\n\n    async def execute(self, ...) 
-> Tuple[str | Dict[str, Any], float, dict]:\n        ...\n        from verl.utils.dataset.vision_utils import process_image, process_video\n\n        img1 = process_image(img1)\n        video1 = process_video(video1)\n\n        # because the key in vllm is (\"image\" | \"video\") rather than (\"images\" | \"videos\"), we use (\"image\" | \"video\") to specify the list of images/videos\n        # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n        return ToolResponse(image=[img1, ...], video=[video1, ...], text=\"...\"), 0, {}\n\nRemember to set ``return_multi_modal_inputs: False`` in your dataset config so that the multi-modal inputs are processed correctly during rollout.\nRefer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details.\n\nMCP Tool Configuration\n~~~~~~~~~~~~~~~~~~~~~~\n\nFor MCP interaction tools, you can flexibly configure them using a YAML file. The typical setup is as follows:\n\n.. code-block:: yaml\n\n    tools:\n      - class_name: \"\"\n        config:\n            type: mcp\n        mcp:\n            mcp_servers_config_path: ./mcp_server.json\n            tool_selected_list: {}\n\nThe ``tool_selected_list`` field is optional and specifies which tools to use from the servers. If you want to enable all available tools, simply omit this attribute. In addition, ``mcp_servers_config_path`` points to a JSON file containing the MCP server configurations. For example:\n\n.. code-block:: json\n\n      {\n          \"mcpServers\": {\n              \"SSE Server\": {\n                  \"url\": \"your_server_url\",\n                  \"auth_token\": \"your_server_api_token\"\n              },\n              \"STDIO Server\": {\n                  \"command\": \"npx\",\n                  \"args\": [\"-y\", \"server-mcp@0.2.1\"],\n                  \"env\": {\n                    \"SERVER_API_KEY\": \"your_server_api_token\"\n                  }\n              }\n          }\n      }\n\nSince the content formats returned by the MCP server may vary, users can inherit from ``MCPBaseTool`` and override the ``_parse_tool_result`` method to implement custom parsing logic.\n\n.. code-block:: python\n\n   class MCPYourTool(MCPBaseTool):\n       def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n           super().__init__(config, tool_schema)\n\n       def _parse_tool_result(self, content: list) -> Tuple[str, dict]:\n           ...\n\nOverall, you may refer to mcp_search_tool.py_ and mcp_tool_config.yaml_ for a custom implementation and configuration.\n\nMulti-turn Tokenization\n~~~~~~~~~~~~~~~~~~~~~~~\n\nTokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it is hard to identify which tokens belong to assistant messages. Since the token list is flat, it lacks direct alignment with the message roles.\n\nTo address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we:\n\n1. Apply the chat template to all prior messages (`messages[:i]`).\n2. Apply the chat template again including the latest message (`messages[:i+1]`).\n3. Tokenize only the *delta* between these two serialized message strings.\n\nThis ensures that only tokens generated by the assistant are included in the loss mask.\n\n.. 
.. code-block:: python\n\n   # When using a tokenizer\n   # Exclude the assistant prompt (e.g., \"<|im_start|>assistant\") from the loss by setting add_generation_prompt=True\n   prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)\n   curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)\n   delta_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)\n   token_ids += delta_ids\n   loss_mask += [1] * len(delta_ids)  # Mask only the new assistant tokens\n\n.. code-block:: python\n\n   # When using a processor\n   # Exclude the assistant prompt (e.g., \"<|im_start|>assistant\") from the loss by setting add_generation_prompt=True\n   prev = processor.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)\n   prev_ids = processor(text=prev, images=images, videos=videos, return_tensors=\"pt\")[\"input_ids\"][0].tolist()\n   curr = processor.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)\n   curr_ids = processor(text=curr, images=images, videos=videos, return_tensors=\"pt\")[\"input_ids\"][0].tolist()\n   delta_ids = curr_ids[len(prev_ids):]\n   token_ids += delta_ids\n   loss_mask += [1] * len(delta_ids)  # Mask only the new assistant tokens\n\nWhile we've validated that this produces results consistent with full message tokenization, future models' chat templates could break compatibility. To guard against silent inconsistencies, we compare the delta-based tokenization with full-tokenization results by default at the end of each rollout.\n\nIf you see the following warning, you can check the mismatched substring in the log:\n\n.. code-block::\n\n    Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.\n\nThe tokenization sanity check mode can be configured using the ``actor_rollout_ref.rollout.multi_turn.tokenization_sanity_check_mode`` parameter, which accepts the following values:\n\n- ``strict`` (default): Performs strict comparison between delta-based and full tokenization results, raising warnings for any differences.\n\n- ``ignore_strippable``: Ignores differences in whitespace characters (``\\n``, ``\\t``, ``\\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable.\n\n- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training.\n\nExample configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            multi_turn:\n                tokenization_sanity_check_mode: \"ignore_strippable\"  # Choose from: \"disable\", \"ignore_strippable\", \"strict\"\n\nHandling Multi-Modal Inputs in Datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIf your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the ``return_multi_modal_inputs`` flag in your dataset config (used by RLHFDataset).\n\n- ``return_multi_modal_inputs: True`` (default): The dataset will pre-process and include a ``multi_modal_inputs`` dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors, etc.) 
as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch.\n\n- ``return_multi_modal_inputs: False``: The dataset will not include the multi_modal_inputs field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch.\n\n\nSpecial Cases\n^^^^^^^^^^^^^\n\nSome models (e.g., Qwen/QwQ-32B and Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate.\n\nFor example, for the following conversation:\n\n.. code-block:: python\n\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n        {\"role\": \"assistant\", \"content\": \"<think>user asked about a simple math question.</think> 2 + 2 = 4.\"},\n        {\"role\": \"user\", \"content\": \"Explain why.\"},\n        {\"role\": \"assistant\", \"content\": \"<think>user wants to know the reasoning behind the answer. Search for a good explanation</think>\",\n         \"tool_calls\": [{\"id\": \"tool1\", \"type\": \"search\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}]},\n        {\"role\": \"tool\", \"content\": \"The sum of two and two is four because it is a basic arithmetic operation.\"},\n        {\"role\": \"assistant\", \"content\": \"<think>The tool provided a good explanation.</think>The sum of two and two is four because it is a basic arithmetic operation.\"}\n    ]\n\n1. Qwen/QwQ-32B will remove all reasoning content except the last assistant message after applying the chat template.\n\n.. code-block:: text\n\n    <|im_start|>system\n    You are a helpful assistant.<|im_end|>\n    <|im_start|>user\n    What is 2 + 2?<|im_end|>\n    <|im_start|>assistant\n     2 + 2 = 4.<|im_end|>\n    <|im_start|>user\n    Explain why.<|im_end|>\n    <|im_start|>assistant\n    <tool_call>\n    {\"name\": \"\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}\n    </tool_call><|im_end|>\n    <|im_start|>user\n    <tool_response>\n    The sum of two and two is four because it is a basic arithmetic operation.\n    </tool_response><|im_end|>\n    <|im_start|>assistant\n    <think>The tool provided a good explanation.</think> The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>\n\n2. Qwen3 series will remove all reasoning content before the last user message.\n\n.. code-block:: text\n\n    <|im_start|>system\n    You are a helpful assistant.<|im_end|>\n    <|im_start|>user\n    What is 2 + 2?<|im_end|>\n    <|im_start|>assistant\n     2 + 2 = 4.<|im_end|>\n    <|im_start|>user\n    Explain why.<|im_end|>\n    <|im_start|>assistant\n    <think>\n    user wants to know the reasoning behind the answer. 
Search for a good explanation\n    </think>\n\n    <tool_call>\n    {\"name\": \"\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}\n    </tool_call><|im_end|>\n    <|im_start|>user\n    <tool_response>\n    The sum of two and two is four because it is a basic arithmetic operation.\n    </tool_response><|im_end|>\n    <|im_start|>assistant\n    <think>\n    The tool provided a good explanation.\n    </think>\n\n    The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>\n\nTo handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn't include assistant messages or reasoning content, it remains consistent across turns.\n\n.. code-block:: python\n\n    BASE_CHAT_HISTORY = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"I am a user.\"}\n    ]\n    prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False)\n    curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False)\n    delta_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)\n    token_ids += delta_ids\n    loss_mask += [1] * len(delta_ids)\n\nThis method works well for the Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template. A fix_ has been proposed but not yet adopted. Until then, use the following command to download the fixed model revision:\n\n.. code-block:: bash\n\n    pip install huggingface_hub\n    huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81\n\n.. _fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81\n\nDiscrepancy Between Training and Inference Templates\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nAlthough the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not.\n\nThis mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout.\n\nHowever, this approach comes with trade-offs:\n\n1. Long reasoning content can easily exceed the model's context window, especially in multi-turn rollout.\n2. There is now a mismatch between rollout and the production environment: models will not see reasoning content from past turns if you use the default chat template in production.\n\nWe are still evaluating the impact of these issues. If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable:\n\n``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True``\n\n
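In YAML form, that is:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            multi_turn:\n                use_inference_chat_template: True\n\n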
GSM8K Multi-turn Training Performance\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSee the training performance of multi-turn rollout on the GSM8K task HERE_.\n\n.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20\n\n.. _GSM8KTool_example_configuration: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\n\n.. _gsm8k_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/gsm8k_tool.py\n\n.. _mcp_search_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/mcp_search_tool.py\n\n.. _mcp_tool_config.yaml: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml\n\nInteraction System\n~~~~~~~~~~~~~~~~~~\n\nFor dynamic conversational feedback during RL training, see:\n\n.. toctree::\n   :maxdepth: 1\n\n   interaction_system\n\nSearch Tool Integration\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. toctree::\n   :maxdepth: 1\n\n   search_tool_example\n\nCode Walkthrough\n~~~~~~~~~~~~~~~~\n\nIf you want to learn about the code execution flow in more depth, please read https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/rlhf/verl/multi-turn/code-walk-through\n"
  },
  {
    "path": "verl_distillation/docs/sglang_multiturn/sandbox_fusion.rst",
    "content": "===============================\nSandbox Fusion Tool Integration\n===============================\n\nLast updated: 06/10/2025.\n\nMotivations\n===========\n\n- As users of verl, we want to allow the model to call certain tools during Actor rollout, incorporating the results into the training process.\n- A colleague from ByteDance proposed a paper aimed at enhancing model capability through code execution tools.\n- We aim to support tool-calling capabilities of inference engines using `sandbox-fusion` as the code execution system, providing the community with a reimplementation of `retools`.\n\nReward Compute with Sandbox Fusion + FaaS Integration\n=====================================================\n\n- In current datasets and tasks, similar work already exists (e.g., Prime), which uses local processes as runners to execute model-generated code for reward computation.\n- On this basis, #1429 has advanced the design by integrating FaaS as the runner for reward computation.\n\nGoals\n=====\n\n- Adapt to the `sglang` tool-calling protocol and define tools for sandbox fusion.\n- Integrate with the `async-rollout` process, ensuring sandbox fusion tools follow asyncIO conventions.\n- Design and implement a basic rate limiter to prevent issues such as 429 errors.\n\nNon-Goals\n=========\n\n- Training effectiveness is out of scope.\n- Observability metrics are not considered.\n- Distributed failover and component fault tolerance are not addressed.\n\nDesign Details\n==============\n\nTool Schema Definition\n----------------------\n\n- Currently, only code execution is considered, requiring a `code` field in the JSON from the model.\n- Only Python code is supported for now, so no `language` parameter is defined.\n\n.. code-block:: python\n\n   OpenAIFunctionToolSchema(\n       type=\"function\",\n       function=OpenAIFunctionSchema(\n           name=\"code_interpreter\",\n           description=\"A tool for executing code.\",\n           parameters=OpenAIFunctionParametersSchema(\n               type=\"object\",\n               properties={\n                   \"code\": OpenAIFunctionPropertySchema(\n                       type=\"string\",\n                       description=\"The code to execute.\",\n                       enum=None,\n                   )\n               },\n               required=[\"code\"],\n           ),\n           strict=False,\n       )\n   )\n\nConfiguration Parameters\n--------------------------\n\n+----------------------------+--------------------------------------------------------------+\n| Parameter Name             | Description                                                  |\n+============================+==============================================================+\n| `num_workers`              | Number of worker threads/processes per DP to request runner. |\n+----------------------------+--------------------------------------------------------------+\n| `rate_limit`               | Global limit of concurrent code executions. Default: 10      |\n+----------------------------+--------------------------------------------------------------+\n| `default_timeout`          | Timeout (in seconds) for each code execution. Default: 30    |\n+----------------------------+--------------------------------------------------------------+\n| `default_language`         | Default programming language. 
Default: \"python\"              |\n+----------------------------+--------------------------------------------------------------+\n| `enable_global_rate_limit` | Whether to enable global rate limiting. Default: True        |\n+----------------------------+--------------------------------------------------------------+\n| `sandbox_fusion_url`       | URL for the veFaas sandbox execution service                 |\n+----------------------------+--------------------------------------------------------------+\n\nRate Limiting Design\n-----------------------\n\nObjective:\n\n- Limit the number of inflight requests using a token bucket model.\n\n- Ensure ordered submission to code runners to avoid starvation due to backoff.\n\nDesign Highlights:\n\n- Use Ray Global Actor as a singleton distributed counter at cluster level.\n  \n- Semaphore used for counting, with `acquire` and `release` in separate thread pools to preserve order.\n  \n- Use Ray’s cloud-pickle to serialize functions for decoupled `ExecutionWorker`.\n\n.. code-block:: python\n\n   @ray.remote(concurrency_groups={\"acquire\": 1,\"release\": 10})\n   class TokenBucketWorker:\n       def __init__(self, rate_limit: int):\n           self.rate_limit = rate_limit\n           self.current_count = 0\n           self._semaphore = threading.Semaphore(rate_limit)\n\n       @ray.method(concurrency_group=\"acquire\")\n       def acquire(self):\n           self._semaphore.acquire()\n           self.current_count += 1\n\n       @ray.method(concurrency_group=\"release\")\n       def release(self):\n           self._semaphore.release()\n           self.current_count -= 1\n\n       def get_current_count(self):\n           return self.current_count\n\n   class ExecutionWorker:\n       def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n           self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n       def _init_rate_limit(self, rate_limit):\n           return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n       def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n           with ExitStack() as stack:\n               stack.callback(self.rate_limit_worker.release.remote)\n               ray.get(self.rate_limit_worker.acquire.remote())\n               try:\n                   return fn(*fn_args, **fn_kwargs)\n               except Exception as e:\n                   logger.warning(f\"Error when executing code: {e}\")\n\n   def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode=PoolMode.ThreadMode):\n       if mode == PoolMode.ThreadMode:\n           return ray.remote(ExecutionWorker).options(max_concurrency=num_workers).remote(\n               enable_global_rate_limit=enable_global_rate_limit,\n               rate_limit=rate_limit\n           )\n       else:\n           raise NotImplementedError(\"Process mode is not implemented yet\")\n\nTool Implementation\n-------------------\n\n- Use `instance_id` to identify requests across multiple dialogue rounds.\n  \n- Use `execution_pool` to implement async invocation.\n  \n- Cleanup state after rollout completion.\n\n.. 
Tool Implementation\n-------------------\n\n- Use `instance_id` to identify requests across multiple dialogue rounds.\n\n- Use `execution_pool` to implement async invocation.\n\n- Clean up state after rollout completion.\n\n.. code-block:: python\n\n   class SandboxFusionTool(BaseTool):\n       def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n           ...\n           self.execution_pool = init_execution_pool(...)\n           ...\n\n       async def create(self, instance_id: Optional[str] = None, ...):\n           ...\n\n       async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n           code = parameters.get(\"code\", \"\")\n           timeout = parameters.get(\"timeout\", self.default_timeout)\n           language = parameters.get(\"language\", self.default_language)\n           if not isinstance(code, str):\n               code = str(code)\n\n           result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n           self._instance_dict[instance_id][\"reward\"].append(result.strip())\n\n           # step reward is 0.0 here; the final reward is computed in calc_reward\n           return result, 0.0, {}\n\n       def execute_code(self, instance_id, code, timeout=30, language=\"python\"):\n           result_status, metadata = _process_single_case(0, None, None, self.sandbox_fusion_url, code, timeout, language)\n           # we should always expect this since we don't have a correct answer\n           if metadata[\"run_status\"] == \"Finished\":\n               actual_output = metadata[\"stdout\"] if metadata[\"stdout\"] is not None else \"\"\n               return actual_output\n           else:\n               return \"no stdout here\"\n\n       async def calc_reward(self, instance_id: str, ...):\n           ...\n\n       async def release(self, instance_id: str, ...):\n           ...\n\nTest Plan\n=========\n\nUnit Tests\n----------\n\n- **test_tools_registration**: Test tool registration and initialization.\n- **test_rollout_req_creation**: Validate that `AsyncRolloutReq` is built correctly.\n- **test_over_size_case**: Ensure rollout terminates early when exceeding `max_seq_len`.\n- **test_tool_call_basic_case**: Mock `sglang` output, validate tool call and result.\n- **test_tool_call_batch_case**: Test batch processing of tool calls.\n- **test_basic_multi_process_init**: Validate Ray global actor behaves as singleton.\n- **TestSingleNodeRateLimiterCase**: Verify rate limiter works in single-node mode.\n- **test_rotten_execution**: Ensure rate limiter recovers from function errors.\n- **TestMultiNodeRateLimiterCase**: Verify behavior in multi-node environments.\n\ne2e Tests\n----------\n\nWe provide an e2e test script in the `tests/special_e2e` folder: `tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh`.\n\nBy setting `trainer.rollout_data_dir` you can dump the rollout data to local disk. Here is a sample taken from the rollout data:\n\n.. code-block:: python\n\n   {\n     \"input\": \"\n     \n     system\\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. 
Put your final answer in the format of `#### <answer>`.\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\\n{\\\"type\\\": \\\"function\\\", \\\"function\\\": {\\\"name\\\": \\\"code_interpreter\\\", \\\"description\\\": \\\"A tool for executing code.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"object\\\", \\\"properties\\\": {\\\"code\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The code to execute.\\\", \\\"enum\\\": null}}, \\\"required\\\": [\\\"code\\\"]}, \\\"strict\\\": false}}\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call>\\n\n     \n     user\\nJohn gets a bonus that's the same percentage every year.  Last year he made $100,000 and got a $10,000 bonus.  This year he makes $200,000.  How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\\n\n     \n     assistant\\n\",\n     \"output\": \"assistant\\n<think>\\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\\n\\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\\n\\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\\n\\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\\n\\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\\n\n     \n     <tool_call>\\n{\\\"name\\\": \\\"code_interpreter\\\", \\\"arguments\\\": {\\\"code\\\": \\\"total_pay_this_year = 200000\\\\nbonus_percentage = 10 / 100\\\\nbonus_this_year = total_pay_this_year * bonus_percentage\\\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\\\nprint(total_income_this_year)\\\", \\\"executes\\\": \\\"True\\\"}}\\n</tool_call>\\n\n     \n     tool\\n220000.0\\n\\n\n     \n     assistant\\n<think>\\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\\n\\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\\n</think>\\n\\n#### 220000.0\",\n     \"score\": 0,\n     \"step\": 1\n   }\n\nhere is the readable format version:\n\n.. 
code-block:: python\n\n   [system]\n   \n   You are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### <answer>`.\n   \n   # Tools\n   \n   You may call one or more functions to assist with the user query.\n   \n   You are provided with function signatures within <tools></tools> XML tags:\n   <tools>\n   {\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n   </tools>\n   \n   For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n   <tool_call>\n   {\"name\": <function-name>, \"arguments\": <args-json-object>}\n   </tool_call>\n   \n   [user]\n   \n   John gets a bonus that's the same percentage every year.  Last year he made $100,000 and got a $10,000 bonus.  This year he makes $200,000.  How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n   \n   [assistant]\n   \n   <think>\n   Okay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n   \n   First, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n   \n   Last year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n   \n   10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n   \n   This year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n   <tool_call>\n   {\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n   </tool_call>\n   \n   [tool]\n   \n   220000.0\n   \n   [assistant]\n   \n   <think>\n   So the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n   \n   I need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. 
Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n   </think>\n   \n   #### 220000.0\n\n\nYou can also use the `RolloutViewer` TUI tool to view the dumped rollout data:\n\n\n.. code-block:: bash\n\n    python scripts/rollout_viewer.py ${trainer.rollout_data_dir}\n\n\n.. image:: https://github.com/user-attachments/assets/e34e5157-2880-4a21-afb2-73885d0dfb11\n   :alt: RolloutViewer screenshot"
  },
  {
    "path": "verl_distillation/docs/sglang_multiturn/search_tool_example.rst",
    "content": "=======================\r\nSearch Tool Integration\r\n=======================\r\n\r\nLast updated: 05/30/2025.\r\n\r\nIntroduction\r\n------------\r\n- We have added a search tool calling function to Multi-Turn RL, enabling the model to initiate retrieval requests during Actor rollout and directly use retrieval results for training. **We support using a local dense retriever as the retrieval tool, as well as integrating with your own local retrieval engine.**\r\n\r\n\r\n\r\nQuick Reproduction\r\n------------------\r\n\r\nCreate a New Docker Container\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   docker run \\\r\n       -it \\\r\n       --shm-size 32g \\\r\n       --gpus all \\\r\n       -v {Huggingface-Cache-Path}:/root/.cache \\\r\n       --ipc=host \\\r\n       --network=host \\\r\n       --privileged \\\r\n       --name sglang_{your-name} \\\r\n       lmsysorg/sglang:dev \\\r\n       /bin/zsh\r\n\r\nIf you need to restart after exiting the container:\r\n\r\n.. code:: bash\r\n\r\n   docker start -i sglang_{your-name}\r\n\r\nUpdate Python and Configure the Virtual Environment using uv\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   apt update\r\n   apt install -y python3.10 python3.10-venv\r\n\r\n   # Create a virtual environment\r\n   python3 -m venv ~/.python/verl-multiturn-rollout\r\n\r\n   # Activate the virtual environment\r\n   source ~/.python/verl-multiturn-rollout/bin/activate\r\n\r\n   # Install uv\r\n   python3 -m pip install uv\r\n\r\nInstall verl Upstream\r\n~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   cd ~\r\n   git clone https://github.com/volcengine/verl.git\r\n   cd verl\r\n\r\n   # Install verl\r\n   python3 -m uv pip install .\r\n   python3 -m uv pip install -r ./requirements_sglang.txt\r\n\r\n   # Manually install flash-attn\r\n   python3 -m uv pip install wheel\r\n   python3 -m uv pip install packaging\r\n   python3 -m uv pip install flash-attn --no-build-isolation --no-deps\r\n\r\nSet Up a Local Retrieval Engine\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nIf you are using your own local retrieval service, you can skip this\r\nstep. We chose the local dense retriever provided in the search-R1\r\nexample; detailed instructions are in the `searchR1\r\ndocs <https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/refs/heads/main/docs/retriever.md>`__.\r\nIn brief:\r\n\r\n-  The GPU version offers higher accuracy and speed; each GPU uses about\r\n   5–7 GB of memory.\r\n-  The CPU version can be used for simple testing but has lower\r\n   retrieval precision, which will degrade training performance. See the\r\n   `retriever\r\n   documentation <https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md>`__\r\n   in search-R1 for details.\r\n-  Recommend using Conda to install faiss-gpu=1.8.0; venv may cause errors.\r\n\r\n**Note**: To start both the training process and the local retrieval\r\nservice, we launch two separate Python environments. The training uses\r\nuv in the verl-multiturn-rollout environment, while the retriever uses\r\nconda to install ``faiss-gpu``.\r\n\r\n.. 
code:: bash\r\n\r\n   # Download the Miniconda installer script\r\n   wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh\r\n\r\n   # Install to $HOME/miniconda3 in batch mode\r\n   bash ~/miniconda.sh -b -p $HOME/miniconda3\r\n\r\n   # Activate conda (only in the current shell)\r\n   eval \"$($HOME/miniconda3/bin/conda shell.bash hook)\"\r\n\r\n   # (Optional) Add conda to your default shell startup\r\n   conda init\r\n\r\n   # Reload shell config\r\n   source ~/.bashrc\r\n\r\n   # Create and activate the retriever environment with Python 3.10\r\n   conda create -n retriever python=3.10 -y\r\n   conda activate retriever\r\n\r\n   # Install PyTorch (with GPU support) and related libraries\r\n   conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y\r\n\r\n   # Install other Python packages\r\n   pip install transformers datasets pyserini huggingface_hub\r\n\r\n   # Install the GPU version of faiss\r\n   conda install faiss-gpu=1.8.0 -c pytorch -c nvidia -y\r\n\r\n   # Install the API service framework\r\n   pip install uvicorn fastapi\r\n\r\nDownload the Indexing and Corpus\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nThe local retrieval files are large, so prepare sufficient disk space. The download is about 60–70 GB, and the uncompressed data takes about 132 GB:\r\n\r\n.. code:: bash\r\n\r\n   conda activate retriever\r\n\r\n   save_path=/the/path/to/save\r\n   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py --save_path $save_path\r\n   cat $save_path/part_* > $save_path/e5_Flat.index\r\n   gzip -d $save_path/wiki-18.jsonl.gz\r\n\r\nStart the Local flat e5 Retrieval Server\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n1. The first startup will download models and load the index.\r\n2. Apart from the download, startup takes about 1–2 minutes.\r\n3. After startup, each GPU uses about 5–7 GB of memory, leaving the rest\r\n   for multi-turn RL training.\r\n\r\n.. code:: bash\r\n\r\n   conda activate retriever\r\n\r\n   index_file=$save_path/e5_Flat.index\r\n   corpus_file=$save_path/wiki-18.jsonl\r\n   retriever_name=e5\r\n   retriever_path=intfloat/e5-base-v2\r\n\r\n   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \\\r\n     --index_path $index_file \\\r\n     --corpus_path $corpus_file \\\r\n     --topk 3 \\\r\n     --retriever_name $retriever_name \\\r\n     --retriever_model $retriever_path \\\r\n     --faiss_gpu\r\n\r\n
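Once the server is running, you can sanity-check it with a direct request (illustrative; the request body follows the retriever input format shown later in this document):\r\n\r\n.. code:: bash\r\n\r\n   # Illustrative sanity check against the local retrieval server\r\n   curl -X POST http://127.0.0.1:8000/retrieve \\\r\n     -H \"Content-Type: application/json\" \\\r\n     -d '{\"queries\": [\"What is Python?\"], \"topk\": 3, \"return_scores\": true}'\r\n\r\n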
Set Up WANDB_API_KEY\r\n~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   export WANDB_API_KEY={YOUR_WANDB_API_KEY}\r\n\r\n   # Define a timestamp function\r\n   function now() {\r\n       date '+%Y-%m-%d-%H-%M'\r\n   }\r\n\r\nPreprocess the Dataset\r\n~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n   **Note:** The following data processing and training commands must be\r\n   run in the verl-multiturn-rollout environment.\r\n\r\n.. code:: bash\r\n\r\n   python3 examples/data_preprocess/preprocess_search_r1_dataset.py\r\n\r\nTesting on 8 x H20\r\n~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   # Ensure the now() function is defined\r\n   # Create a logs directory\r\n   mkdir -p logs\r\n\r\n   # Set GPUs and run with a suitable log path\r\n   export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\r\n\r\n   nohup bash examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh \\\r\n     trainer.experiment_name=qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn-$(now) \\\r\n     > logs/searchR1-like$(now).log 2>&1 &\r\n\r\nCustom Search Configuration\r\n---------------------------\r\n\r\nTo enable multi-turn reasoning, set the following fields in your config:\r\n\r\n.. code:: yaml\r\n\r\n   actor_rollout_ref:\r\n     rollout:\r\n       name: \"sglang\"\r\n       multi_turn:\r\n         enable: True\r\n\r\nYou must specify ``retrieval_service_url`` in ``examples/sglang_multiturn/config/tool_config/search_tool_config.yaml``, and properly configure concurrency. For more details on concurrency, refer to the Sandbox Fusion example:\r\n\r\n.. code:: yaml\r\n\r\n   tools:\r\n     - class_name: verl.tools.search_tool.SearchTool\r\n       config:\r\n         retrieval_service_url: http://127.0.0.1:8000/retrieve\r\n         num_workers: 120\r\n         rate_limit: 120\r\n         timeout: 30\r\n\r\nThe retriever input/output formats are as follows. If your service parameters match, you only need to modify ``retrieval_service_url``. You can also customize the logic in ``search_r1_like_utils.py``.\r\n\r\n.. code:: python\r\n\r\n   Input format:\r\n   {\r\n     \"queries\": [\"What is Python?\", \"Tell me about neural networks.\"],\r\n     \"topk\": 3,\r\n     \"return_scores\": true\r\n   }\r\n\r\n   Output format (when return_scores=True, similarity scores are returned):\r\n   {\r\n       \"result\": [\r\n           [   # Results for each query\r\n               {\r\n                   \"document\": doc, \"score\": score\r\n               },\r\n               # ... more documents\r\n           ],\r\n           # ... results for other queries\r\n       ]\r\n   }\r\n\r\nNotes\r\n-----\r\n\r\n1. The total training time is about 27 hours. The validation dataset is\r\n   very large (51k examples), and each validation takes about 6,000 seconds,\r\n   which is why ``val_before_train=False`` by default.\r\n"
  },
  {
    "path": "verl_distillation/docs/single_controller.rst",
    "content": "The Design of ``verl.single_controller``\n==============================================\n\nLast updated: 05/21/2025.\n\n**Author:**\\  `Wang Zhang <https://github.com/zw0610>`__\n\nPreface\n-------\n\nWe prepared this document for developers of ``verl``, particularly those\ninterested in understanding or contributing to the\n``verl.single_controller`` module. It is not intended for end users, but\nfor contributors seeking to understand the architectural rationale and\ninternal mechanics.\n\n--------------\n\nOrigin\n------\n\nThe ``single_controller`` module originated from a request I received —\nto adapt a toy single-process RLHF script into a distributed system with\nminimal changes, while maintaining ease of debugging.\n\nCommon practice — such as using PyTorch’s Distributed Data Parallel\n(DDP) — typically involves wrapping ``nn.Module`` and launching multiple\nprocesses that execute the same function under different ranks. However,\nthis approach presents two main limitations in the context of\ndistributed RLHF: - Difficulty representing multiple DAGs as required by\nPPO; - Difficulty inspecting intermediate tensors during training.\n\nTo maintain debuggability, we opted for a different approach — breaking\nthe training loop into well-defined stages like ``generate_sequences``,\n``compute_advantages``, and so on.\n\nWe selected `Ray <https://www.ray.io/>`__ as the initial backend for\n``verl`` due to its ability to expose Python class methods as RPC\nendpoints. However, Ray’s default model only supports **one method call,\none RPC**, while training LLMs typically requires coordination across\nmultiple processes.\n\nTo hide this multi-Ray actors invocation for a single method from users,\nwe introduced the following components:\n\n-  ``WorkerGroup`` – manages a group of remote workers and provides\n   a unified interface for multi-process distributed computation;\n-  ``ResourcePool`` – binds computational resources to worker\n   processes;\n-  ``ClassWithArgs`` – enables delayed remote instantiation with\n   specified initialization arguments.\n\n--------------\n\nA Running Example: ``generate_sequences``\n-----------------------------------------\n\nTo illustrate the design, we walk through how the ``generate_sequences``\nmethod in the ``ActorRolloutRefWorker`` class is registered and invoked\nacross distributed workers.\n\n--------------\n\nStep 1: Register with a Decorator\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe first step is to define the ``generate_sequences`` and decorate it\nwith ``@register`` as it will be called in driver script.\n\n**Source:**\n`fsdp_workers.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/workers/fsdp_workers.py#L528>`__\n\n.. code:: python\n\n   class ActorRolloutRefWorker(Worker):\n       ...\n       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n       def generate_sequences(self, prompts: DataProto):\n           prompts = prompts.to(torch.cuda.current_device())\n           ...\n\nThe ``@register`` decorator adds metadata to the ``generate_sequences``\nmethod. Currently, it doesn’t alter functionality, but attaches\nattributes via a magic key (``MAGIC_ATTR``):\n\n**Source:**\n`decorator.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L411>`__\n\n.. 
.. code:: python\n\n   def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):\n       ...\n       def decorator(func):\n           @wraps(func)\n           def inner(*args, **kwargs):\n               if materialize_futures:\n                   args, kwargs = _materialize_futures(*args, **kwargs)\n               return func(*args, **kwargs)\n\n           attrs = {\"dispatch_mode\": dispatch_mode, \"execute_mode\": execute_mode, \"blocking\": blocking}\n           setattr(inner, MAGIC_ATTR, attrs)\n           return inner\n\n       return decorator\n\nAs the code shows, the values of ``dispatch_mode``, ``execute_mode`` and\n``blocking`` are attached to the ``generate_sequences`` method.\n\n--------------\n\nStep 2: Binding During Initialization\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThese attached attributes are extracted and utilized when\n``ActorRolloutRefWorker``, wrapped in a ``RayClassWithArgs``, is passed\ninto a ``RayWorkerGroup``.\n\n**Source:**\n`main_generation.py <https://github.com/volcengine/verl/blob/4ae9a0fdab229f75f080e9478807783ed4c97154/verl/trainer/main_generation.py#L82>`__\n\n.. code:: python\n\n   ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role=\"rollout\")\n   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)\n   wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n\nDuring the\n`initialization <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L184>`__\nof ``RayWorkerGroup``, two key steps occur:\n\n1. Worker instances (Ray actors) are created:\n   `RayWorkerGroup._init_with_resource_pool <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L211>`__\n2. Methods decorated with ``@register`` are bound to ``RayWorkerGroup``:\n   `RayWorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L214>`__\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true\n   :alt: initialization_and_binding_of_worker_group\n\n   initialization_and_binding_of_worker_group\n\nThe binding procedure is the heart of ``verl.single_controller``.\n\n**Key function:**\n`WorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/worker_group.py#L143>`__\n\n.. code:: python\n\n   def _bind_worker_method(self, user_defined_cls, func_generator):\n       ...\n       for method_name in dir(user_defined_cls):\n           try:\n               method = getattr(user_defined_cls, method_name)\n               assert callable(method)\n           except Exception:\n               continue  # Skip properties\n           <<<to be continue 1>>>\n\nWhen a method has the ``MAGIC_ATTR``, the attributes set by\n``@register`` are extracted:\n\n
.. code:: python\n\n           <<<continue 1>>>\n           if hasattr(method, MAGIC_ATTR):\n               attribute = getattr(method, MAGIC_ATTR)\n               dispatch_mode = attribute[\"dispatch_mode\"]\n               execute_mode = attribute[\"execute_mode\"]\n               blocking = attribute[\"blocking\"]\n\n               <<<to be continue 2>>>\n\nAs shown in the flow chart above, these attributes are fed into\n``func_generator``. However, ``func_generator`` takes ``method_name``,\n``dispatch_fn``, ``collect_fn``, ``execute_fn``, and ``blocking``. We need\nto find the ``dispatch_fn`` and ``collect_fn`` associated\nwith the ``dispatch_mode`` (``DP_COMPUTE_PROTO``) from\n`DISPATCH_MODE_FN_REGISTRY <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L387>`__:\n\n.. code:: python\n\n   DISPATCH_MODE_FN_REGISTRY = {\n       Dispatch.ONE_TO_ALL: {\n           \"dispatch_fn\": dispatch_one_to_all,\n           \"collect_fn\": collect_all_to_all,\n       },\n       ...\n       Dispatch.DP_COMPUTE_PROTO: {\n           \"dispatch_fn\": dispatch_dp_compute_data_proto,\n           \"collect_fn\": collect_dp_compute_data_proto,\n       },\n       ...\n   }\n\nSimilarly, the ``execute_fn`` is selected by ``execute_mode`` and\nextracted by:\n\n.. code:: python\n\n               <<<continue 2>>>\n               # get execute_fn_name\n               execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)\n               wg_execute_fn_name = execute_mode[\"execute_fn_name\"]\n\n               # get execute_fn from string\n               try:\n                   execute_fn = getattr(self, wg_execute_fn_name)\n                   assert callable(execute_fn), \"execute_fn must be callable\"\n               except Exception:\n                   print(f\"execute_fn {wg_execute_fn_name} is invalid\")\n                   raise\n               <<<to be continue 3>>>\n\nIn this ``generate_sequences`` case:\n\n- ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO``\n- ``dispatch_fn = dispatch_dp_compute_data_proto``\n- ``collect_fn = collect_dp_compute_data_proto``\n- ``execute_fn = RayWorkerGroup.execute_all``\n\nONE_TO_ALL vs. DP_COMPUTE_PROTO\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n``dispatch_mode`` is associated with a ``dispatch_fn`` and a\n``collect_fn``. As the name implies, ``dispatch_fn`` processes the input\narguments in ``WorkerGroup`` and generates a batch (list) of input\narguments, each of which will be fed into a worker attached to the\n``WorkerGroup``.\n\n``dispatch_fn`` of ``ONE_TO_ALL`` is\n`dispatch_one_to_all <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L119>`__,\nwhich just duplicates all the input arguments into N replicas, where N\nequals the number of Workers attached to the ``worker_group``:\n\n
.. code:: python\n\n   def dispatch_one_to_all(worker_group, *args, **kwargs):\n       args = tuple([arg] * worker_group.world_size for arg in args)\n       kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}\n       return args, kwargs\n\n``dispatch_fn`` of ``DP_COMPUTE_PROTO`` is\n`dispatch_dp_compute_data_proto <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L350>`__,\nwhich uses ``DataProto.chunk`` to split a large ``DataProto`` into N\nsmaller ``DataProto``, where N equals the world_size (number of\nworkers) of the ``worker_group``:\n\n.. code:: python\n\n   def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):\n       from verl.single_controller.base.worker_group import WorkerGroup\n\n       assert isinstance(worker_group, WorkerGroup)\n       # Note: enable auto padding for dp compute DataProto\n       splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(\n           worker_group.world_size,\n           *args,\n           **kwargs,\n       )\n       return splitted_args, splitted_kwargs\n\nThe ``collect_fn`` follows the same pattern: it processes a batch (list)\nof values returned from all workers of a ``WorkerGroup`` and merges them\ninto a list, as ``collect_all_to_all`` does, or into a large ``DataProto``,\nas ``collect_dp_compute_data_proto`` does.\n\nFinally, a new method is dynamically generated using ``func_generator``\nand added to the ``WorkerGroup`` instance:\n\n.. code:: python\n\n               <<<continue 3>>>\n               # bind a new method to the RayWorkerGroup\n               func = func_generator(\n                   self,\n                   method_name,\n                   dispatch_fn=dispatch_fn,\n                   collect_fn=collect_fn,\n                   execute_fn=execute_fn,\n                   blocking=blocking,\n               )\n\n               try:\n                   setattr(self, method_name, func)\n                   method_names.append(method_name)\n               except Exception as e:\n                   raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\nThis makes the method invocable via the ``WorkerGroup`` interface.\n\n--------------\n\nStep 3: Call Chain\n~~~~~~~~~~~~~~~~~~\n\nAll the machinery above ensures that distributed calls feel identical to\nsingle-process ones. In the original single-process script, the code\nlooks like:\n\n.. code:: python\n\n   rollout = Rollout()\n   rollout.generate_sequences(batch)\n\nWith ``verl``, the multiprocess program becomes:\n\n.. code:: python\n\n   resource_pool = RayResourcePool(process_on_nodes=[4])\n   rollout = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=RayClassWithInitArgs(Rollout))\n   rollout.generate_sequences(batch)\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true\n   :alt: call_chain_of_generate_sequences\n\n   call_chain_of_generate_sequences\n\nBehind this simple call:\n\n- ``dispatch_fn`` splits input across workers\n- ``execute_fn`` performs the actual remote invocation\n- ``collect_fn`` gathers the results\n\nAll of this is abstracted away, enabling developers to write distributed\ncode with minimal changes to their existing logic.\n\n
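Conceptually, the generated method behaves like the following sketch (a simplification for illustration, not verl's actual ``func_generator``):\n\n.. code:: python\n\n   import ray\n\n   def make_method(worker_group, method_name, dispatch_fn, collect_fn, execute_fn, blocking):\n       def method(*args, **kwargs):\n           # 1. split (or replicate) the inputs, one slice per worker\n           args, kwargs = dispatch_fn(worker_group, *args, **kwargs)\n           # 2. invoke the remote method on every worker\n           output = execute_fn(method_name, *args, **kwargs)\n           if blocking:\n               output = ray.get(output)\n           # 3. merge the per-worker results into a single object\n           return collect_fn(worker_group, output)\n       return method\n\n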
--------------\n\nBeyond RL Post-Training: Generalizing ``verl.single_controller``\n----------------------------------------------------------------\n\nThe ``verl.single_controller`` module generalizes well beyond\nreinforcement learning. It provides a clean abstraction to batch-process\nremote method calls, with automatic input/output handling.\n\nBy minimizing the gap between single-process and multi-process scripts,\n``verl.single_controller`` opens the door to distributed computing in\nbroader domains — not limited to RL post-training.\n\nWe hope this design inspires more examples and extensions from the\ncommunity.\n"
  },
  {
    "path": "verl_distillation/docs/start/agentic_rl.rst",
    "content": "Agentic RL Training\n===================\n\nLast updated: 07/15/2025.\n\nOverview\n----------\nThe goal of Agentic RL is to improve the performance of backend models from reinforcement learning to the Agent. During the training process, a series of features are developed:\n\n1. Server-based asynchronous rollout\n2. Multi-turn conversations and tool calls\n3. LangGraph-based Agent\n\n\nThis document explains the system principles and usage involved to help users implement Agentic RL.\n\n\nServer-based Asynchronous Rollout\n---------------------------------\n\nSince Agents need to interact with the environment through various tool calls, in order to avoid GPU idling while waiting for tool call return results, an asyncio based co-routing mechanism is utilized to execute each rollout requests asynchronously, thereby improving training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives:\n\n1. Enabling load balancing mechanisms to balance loads across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, scheduling capabilities in stream mode (recipe\\stream_mode) are implemented as a recipe.\n2. Preventing agent specific features such as tracing from affecting the inference engine.\n\nSystem Architecture\n~~~~~~~~~~~~~~~~~~~\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true\n\nFor more detail on internal design, please refer to :doc:`Agent Loop<../advance/agent_loop>`.\n\nSystem Components\n~~~~~~~~~~~~~~~~~\n\n+--------------------------+----------------------------------------------------------------------------+\n| Component                | Role                                                                       |\n+==========================+============================================================================+\n| AgentLoop                | Client, implements Agent functions                                         |\n+--------------------------+----------------------------------------------------------------------------+\n| AsyncLLMServerManager    | Inference gateway, provides generate interface for AgentLoop               |\n+--------------------------+----------------------------------------------------------------------------+\n| AsyncServer              | Server, each instance is connected to one DP group of the inference engine |\n+--------------------------+----------------------------------------------------------------------------+\n\n**\"generate\" Interface**\n\nThe \"generate\" function based on ray actor is used between the Client and Server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible. For example, the token converted from \"<think>\" will be different from that generated by the LLM. During the training phase, it is necessary to strictly use the tokens generated by LLM inference to avoid inaccurate in computing advantage, which may affect model performance. 
**Inference Engine Adaptation**\n\nAsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM to hide the underlying differences:\n\n1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to call async_generate remotely through a Ray actor.\n2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be called directly in AsyncServer.\n\n\nUsage Example\n~~~~~~~~~~~~~\n\nFollow the :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.\n\nThere are two options required to use the agent loop:\n\n- `data.return_raw_chat=True`\n- `actor_rollout_ref.rollout.mode=async`\n\nThis example uses the sglang inference engine by default; you can also switch the rollout name to use vllm.\n\n.. code-block:: bash\n\n    bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh\n\n\nMulti-turn Conversations and Tool Calls\n---------------------------------------\n\nFollow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare tool and configuration files.\n\nThe Tool Agent Loop has an additional requirement: adding an \"agent_name\" field to the dataset (see the sketch at the end of this section). During rollout, it will choose to use tool_agent_loop or single_turn_agent (default) based on this field.\n\nUsage Example\n~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    # install mlflow to view toolcall and llm trace\n    pip install mlflow\n\n    # This will download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the \"agent_name\" field.\n    python examples/data_preprocess/gsm8k_tool_agent_loop.py\n\n    # Start training with tool calls; the mlflow-based trace helps debug the rollout details\n    bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh\n\n    # When training is done, start an mlflow server to view the trace\n    mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db\n\n    # Then open http://<your ip address>:5000 in a browser to view the trace\n\n\nNote: During training, the model may sometimes fail to generate correct toolcall tags; the resulting \"Failed to decode tool call\" error message on the console does not indicate an abnormality in training.\n\n\nFollow :doc:`Rollout trace<../advance/rollout_trace>` to learn more about the trace feature.\n\n
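For reference, a preprocessed dataset row might carry the extra field like this (a sketch; the ``tool_agent`` name is an assumption and must match the agent loop registered in your setup):\n\n.. code-block:: python\n\n    # Sketch: one preprocessed dataset row with the extra \"agent_name\" field.\n    sample = {\n        \"prompt\": [{\"role\": \"user\", \"content\": \"What is 2 + 2?\"}],\n        \"agent_name\": \"tool_agent\",  # assumed name; must match the registered agent loop\n        # ... other fields (data_source, reward_model, etc.) unchanged\n    }\n\n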
Agent Framework\n---------------\n\nSystem Architecture\n~~~~~~~~~~~~~~~~~~~\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true\n\nSystem Components\n~~~~~~~~~~~~~~~~~\n\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| Component                | Role                                                                                          |\n+==========================+===============================================================================================+\n| ChatModel                | LLM object of LangChain, used to adapt to the “generate” API provided by AsyncLLMServerManager|\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| RectAgentLoop            | Agent adaptation layer, which by default supports a naive LangGraph Agent.                    |\n|                          | New classes can be derived to support user-defined Agents, and the run function needs to be   |\n|                          | implemented to complete Agent calls.                                                          |\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| AsyncServer              | Server, each instance is connected to one DP group of the inference engine.                   |\n+--------------------------+-----------------------------------------------------------------------------------------------+\n\n\nFollow the doc \"recipe/langgraph_agent/example/README.md\" for more details."
  },
  {
    "path": "verl_distillation/docs/start/install.rst",
    "content": "Installation\n============\n\nRequirements\n------------\n\n- **Python**: Version >= 3.10\n- **CUDA**: Version >= 12.8\n\nverl supports various backends. Currently, the following configurations are available:\n\n- **FSDP** and **Megatron-LM** (optional) for training.\n- **SGLang**, **vLLM** and **TGI** for rollout generation.\n\nChoices of Backend Engines\n----------------------------\n\n1. Training:\n\nWe recommend using **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using FSDP backend can be found in :doc:`FSDP Workers<../workers/fsdp_workers>`.\n\nFor users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support `Megatron-LM v0.13.1 <https://github.com/NVIDIA/Megatron-LM/tree/core_v0.13.1>`_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`.\n\n\n2. Inference:\n\nFor inference, vllm 0.8.3 and later versions have been tested for stability. We recommend turning on env var `VLLM_USE_V1=1` for optimal performance.\n\nFor SGLang, refer to the :doc:`SGLang Backend<../workers/sglang_worker>` for detailed installation and usage instructions. SGLang rollout is under extensive development and offers many advanced features and optimizations. We encourage users to report any issues or provide feedback via the `SGLang Issue Tracker <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/106>`_.\n\nFor huggingface TGI integration, it is usually used for debugging and single GPU exploration.\n\nInstall from docker image\n-------------------------\n\nWe provide pre-built Docker images for quick setup. And from this version,\nwe utilize a new image release hierarchy for productivity and stability.\n\nThe image types are divided into three large categories:\n\n- **Base Image**: Without inference and training frameworks, only basic dependencies are installed.\n  Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA.\n- **Application Image**: Stable version with inference and training frameworks installed.\n- **Community Image**: Unstable version with the latest frameworks and features.\n\nThe first two types of images are hosted on dockerhub `verlai/verl <https://hub.docker.com/r/verlai/verl>`_ repository, while the preview images are hosted on community repository.\n\n.. note::\n\n    The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``.\n\nBase Image\n::::::::::\n\nThe stable base image is ``verlai/verl:base-verl0.6-cu128-cudnn9.8-torch2.8.0-fa2.7.4`` for vLLM and sglang. 
The installed package versions can be found from the image tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.\n\nThe base image is updated infrequently, and the application image can be built on top of it without reinstalling the base packages.\n\nApplication Image\n:::::::::::::::::\n\nFrom this version on, we provide separate images for vLLM and SGLang because of the divergence of dependent packages like PyTorch and FlashInfer.\n\nThere are two types of application images available:\n\n- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2``\n- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2``\n\nDocker images with the Megatron backend are capable of post-training large language models such as ``Qwen/Qwen3-235B-A22B`` and ``deepseek-ai/DeepSeek-V3-0324``. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details.\n\nApplication images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks.\n\nCommunity Image\n:::::::::::::::\n\nCommunity images are provided by the community, include the latest versions of vLLM and SGLang, and may include experimental features or configurations. They also cover other hardware and platforms such as AMD GPUs with ROCm, or AWS EFA and SageMaker.\n\nFor the latest vLLM with FSDP, please refer to the `hiyouga/verl <https://hub.docker.com/r/hiyouga/verl>`_ repository; the latest version is ``hiyouga/verl:ngc-th2.8.0-cu12.9-vllm0.11.0``.\n\nFor the latest SGLang with FSDP, please refer to the `hebiaobuaa/verl <https://hub.docker.com/r/hebiaobuaa/verl>`_ repository; the latest version is ``hebiaobuaa/verl:app-verl0.5-sglang0.4.9.post6-mcore0.12.2-te2.2``, which is provided by the SGLang RL Group.\n\nFor the latest vLLM with Megatron, please refer to the `iseekyan/verl <https://hub.docker.com/r/iseekyan/verl>`_ repository; the latest version is ``iseekyan/verl:megatron0.13_vllm0.11``.\n\nSee the files under ``docker/`` for NGC-based images or if you want to build your own.\n\nNote that for AWS instances with an EFA network interface (SageMaker AI Pod),\nyou need to install the EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``.\n\nInstallation from Docker\n::::::::::::::::::::::::\n\nAfter pulling the desired Docker image and installing the desired inference and training frameworks, you can run it with the following steps:\n\n1. Launch the desired Docker image and attach into it:\n\n.. code:: bash\n\n    docker create --runtime=nvidia --gpus all --net=host --shm-size=\"10g\" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl <image:tag> sleep infinity\n    docker start verl\n    docker exec -it verl bash\n\n\n2. If you use the provided images, you only need to install verl itself without its dependencies:\n\n.. code:: bash\n\n    # install the nightly version (recommended)\n    git clone https://github.com/volcengine/verl && cd verl\n    pip3 install --no-deps -e .\n\n[Optional] If you want to switch between different frameworks, you can install verl with the following command:\n\n.. 
code:: bash\n\n    # install the nightly version (recommended)\n    git clone https://github.com/volcengine/verl && cd verl\n    pip3 install -e .[vllm]\n    pip3 install -e .[sglang]\n\n\nInstall from custom environment\n---------------------------------------------\n\nWe recommend using Docker images for convenience. However, if your environment is not compatible with the Docker image, you can also install verl in a python environment.\n\n.. note::\n\n    - The Dockerfiles provide more details than these installation instructions. You can find examples in each Dockerfile, for example the `verl0.6-cu128-torch2.8.0-fa2.7.4 Dockerfile.base <https://github.com/volcengine/verl/blob/v0.6.0/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.base>`_ .\n\n\nPre-requisites\n::::::::::::::\n\nFor the training and inference engines to make the best use of the hardware, CUDA/cuDNN and other dependencies are required,\nand some of these dependencies are easily overridden when installing other packages,\nso we cover them in the :ref:`Post-installation` step.\n\n.. note::\n\n    - The installation steps below are the recommended configurations for the latest version of verl.\n\n    If you are customizing your own environment, you may ignore the strict constraints.\n\nWe need to install the following pre-requisites:\n\n- **CUDA**: Version >= 12.8\n- **cuDNN**: Version >= 9.10.0\n- **Apex**\n\nCUDA 12.8 or above is recommended, matching the Docker image;\nplease refer to `NVIDIA's official website <https://developer.nvidia.com/cuda-toolkit-archive>`_ for other versions of CUDA.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    wget https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda-repo-ubuntu2204-12-8-local_12.8.1-570.124.06-1_amd64.deb\n    dpkg -i cuda-repo-ubuntu2204-12-8-local_12.8.1-570.124.06-1_amd64.deb\n    cp /var/cuda-repo-ubuntu2204-12-8-local/cuda-*-keyring.gpg /usr/share/keyrings/\n    apt-get update\n    apt-get -y install cuda-toolkit-12-8\n    update-alternatives --set cuda /usr/local/cuda-12-8\n\n\ncuDNN can be installed via the following command;\nplease refer to `NVIDIA's official website <https://developer.nvidia.com/rdp/cudnn-archive>`_ for other versions of cuDNN.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    wget https://developer.download.nvidia.com/compute/cudnn/9.10.2/local_installers/cudnn-local-repo-ubuntu2204-9.10.2_1.0-1_amd64.deb\n    dpkg -i cudnn-local-repo-ubuntu2204-9.10.2_1.0-1_amd64.deb\n    cp /var/cudnn-local-repo-ubuntu2204-9.10.2/cudnn-*-keyring.gpg /usr/share/keyrings/\n    apt-get update\n    apt-get -y install cudnn-cuda-12\n\nInstall dependencies\n::::::::::::::::::::\n\n.. note::\n\n    We recommend using a fresh conda environment to install verl and its dependencies.\n\n    **Notice that the inference frameworks often strictly pin your pytorch version and will directly override your installed pytorch if you are not careful.**\n\n    As a countermeasure, it is recommended to install the inference frameworks first, together with the pytorch they require. For vLLM, if you want to use your existing pytorch,\n    please follow their official instructions\n    `Use an existing PyTorch installation <https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#build-wheel-from-source>`_ .\n\n\n1. First of all, to manage the environment, we recommend using conda:\n\n.. 
code:: bash\n\n   conda create -n distill python==3.12\n   conda activate distill\n\n\n2. Then, execute the install script that we provide in verl:\n\n.. code:: bash\n\n    # Make sure you have activated the distill conda env\n    # If you need to run with megatron\n    bash scripts/install_vllm_sglang_mcore.sh\n    # Or if you simply need to run with FSDP\n    USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh\n\n\nIf you encounter errors in this step, please inspect the script and manually follow its steps.\n\n[Optional] NVIDIA Apex is recommended for Megatron-LM training, but it is not needed if you only use the FSDP backend.\nYou can install it via the following command, but note that this step can take a very long time.\nIt is recommended to set the ``MAX_JOBS`` environment variable to accelerate the installation process,\nbut do not set it too large, otherwise the memory will be overloaded and your machines may hang.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    MAX_JOBS=32 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\nInstall verl\n::::::::::::\n\nTo install the latest version of verl, the best way is to clone and\ninstall it from source. Then you can modify our code to customize your\nown post-training jobs.\n\n.. code:: bash\n\n   git clone https://github.com/volcengine/verl.git\n   cd verl\n   pip install --no-deps -e .\n\n\nPost-installation\n:::::::::::::::::\n\nPlease make sure that the installed packages are not overridden during the installation of other packages.\n\nThe packages worth checking are:\n\n- **torch** and the torch family\n- **vLLM**\n- **SGLang**\n- **pyarrow**\n- **tensordict**\n- **nvidia-cudnn-cu12**: for the Megatron backend\n\nIf you encounter package version issues while running verl, please update the outdated packages.\n\n\nInstall with AMD GPUs - ROCM kernel support\n------------------------------------------------------------------\n\nWhen you run on AMD GPUs (MI300) with the ROCm platform, you cannot use the previous quickstart to run verl. Follow the steps below to build a Docker image and run it.\nIf you encounter any issues using verl on AMD GPUs, feel free to contact `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\nFind the docker for AMD ROCm: `docker/Dockerfile.rocm <https://github.com/volcengine/verl/blob/main/docker/Dockerfile.rocm>`_\n::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    #  Build the docker in the repo dir:\n    # docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .\n    # docker images # you can find your built docker\n    FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n    # Set working directory\n    # WORKDIR $PWD/app\n\n    # Set environment variables\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n    # Install vllm\n    RUN pip uninstall -y vllm && \\\n        rm -rf vllm && \\\n        git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \\\n        cd vllm && \\\n        MAX_JOBS=$(nproc) python3 setup.py install && \\\n        cd .. && \\\n        rm -rf vllm\n\n    # Copy the entire project directory\n    COPY . 
.\n\n    # Install dependencies\n    RUN pip install \"tensordict<0.6\" --no-deps && \\\n        pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        \"ray[data,train,tune,serve]\" \\\n        torchdata \\\n        transformers \\\n        wandb \\\n        orjson \\\n        pybind11 && \\\n        pip install -e . --no-deps\n\nBuild the image\n::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    docker build -t verl-rocm .\n\nLaunch the container\n::::::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    docker run --rm -it \\\n      --device /dev/dri \\\n      --device /dev/kfd \\\n      -p 8265:8265 \\\n      --group-add video \\\n      --cap-add SYS_PTRACE \\\n      --security-opt seccomp=unconfined \\\n      --privileged \\\n      -v $HOME/.ssh:/root/.ssh \\\n      -v $HOME:$HOME \\\n      --shm-size 128G \\\n      -w $PWD \\\n      verl-rocm \\\n      /bin/bash\n\nIf you do not want to run in root mode and want to run as your own user,\nplease add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` to the above docker launch command.\n\nverl on AMD GPUs currently supports FSDP as the training engine, with vLLM and SGLang as inference engines. We will support Megatron in the future.\n"
  },
  {
    "path": "verl_distillation/docs/start/more_resources.rst",
    "content": "More Resources\n==============\n\nLast updated: 06/30/2025.\n\n- Introduction to verl (`Slides <https://tongyx361.github.io/blogs/posts/verl-intro>`_)\n- verl Code Walkthrough (`Slides <https://tongyx361.github.io/blogs/posts/verl-tutorial>`_, `Talk in Chinese <https://hcqnc.xetlk.com/sl/3vACOK>`_) \n"
  },
  {
    "path": "verl_distillation/docs/start/multinode.rst",
    "content": "Multinode Training\n==================\n\nLast updated: 06/10/2025.\n\n.. _wuxibin89: https://github.com/wuxibin89\n\nAuthor: `Xibin Wu <https://github.com/wuxibin89>`_, `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\nOption 1: Launch Manually\n------------------------------\n\nSet up multinode ray cluster\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n1. Start head node with ``ray start --head --dashboard-host=0.0.0.0``, there're 2 address you should care about:\n\n- GCS address: ``ray start --address=<address>``, where worker node should connect to.\n- Dashboard address: ``<address>:8265``, where you should submit job to the cluster.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/head.png?raw=true\n\n2. Start worker node with ``ray start --address=<address>`` you get above.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/worker.png?raw=true\n\n3. Now you should see the cluster have 2 nodes with ``ray status``.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/status.png?raw=true\n\n4. Additionally, you can access dashboard in the browser with the address you get above. \n\n*Firewall rules maybe need configure to access the dashboard, if there's any trouble, please contact your network administrator.*\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/overview.png?raw=true\n\nSubmit job to ray cluster\n~~~~~~~~~~~~~~~~~~~~~~~~~\n1. Submit ray job to cluster with the dashboard address you get above.\n\n.. code-block:: bash\n\n    ray job submit --address=\"http://127.0.0.1:8265\" \\\n        --runtime-env=verl/trainer/runtime_env.yaml \\\n        --no-wait \\\n        -- \\\n        python3 -m verl.trainer.main_ppo \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        ...\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/submit.png?raw=true\n\n2. Then you can check the job status with the following commands:\n\n- ray job list: list all jobs submitted to the cluster.\n- ray job logs <Submission ID>: query the logs of the job.\n- ray job status <Submission ID>: query the status of the job.\n- ray job stop <Submission ID>: request the job to be stopped.\n- ray job list | grep submission_id | grep JobStatus | grep RUNNING | grep -oP 'raysubmit_[^'\\''\"]+' | head -n 1: get the latest job submission ID of the running job.\n- ray job logs <Submission ID> --follow: added ``--follow`` parameter to ray job logs command to enable continuous log streaming.\n\n3. You can also access driver/task/actor logs in ``/tmp/ray/session_latest/logs/``, driver log is ``job-driver-raysubmit_<Submission ID>.log``.\n\n4. We strongly recommend you to view job detail from dashboard in multinode training, because it provide more structure way to view the job information.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job.png?raw=true\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job_detail.png?raw=true\n\nOption 2: Launch via SkyPilot on Kubernetes or clouds\n------------------------------------------------------\n\n.. 
note::\n   Ready-to-use SkyPilot example configurations are available in the `examples/skypilot/ <https://github.com/volcengine/verl/tree/main/examples/skypilot>`_ directory:\n   \n   - ``verl-ppo.yaml`` - PPO training with the GSM8K dataset\n   - ``verl-grpo.yaml`` - GRPO training with the MATH dataset  \n   - ``verl-multiturn-tools.yaml`` - Multi-turn tool usage training\n   \n   See the `SkyPilot examples README <https://github.com/volcengine/verl/tree/main/examples/skypilot>`_ for detailed usage instructions.\n\nStep 1: Setup SkyPilot\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nSkyPilot supports different clouds; here we use GCP as an example. See `install skypilot <https://docs.skypilot.co/en/latest/getting-started/installation.html>`_.\n\n.. code-block:: bash\n\n    conda create -y -n sky python=3.10\n    conda activate sky\n    pip install \"skypilot[gcp]\"\n\n    conda install -c conda-forge google-cloud-sdk\n    gcloud init\n\n    # Run this if you don't have a credential file.\n    # This will generate ~/.config/gcloud/application_default_credentials.json.\n    gcloud auth application-default login\n\n    # Check if the GCP credential is correctly setup.\n    sky check gcp\n\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/setup_skypilot.png?raw=true\n\nStep 2: Prepare dataset\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   git clone https://github.com/volcengine/verl.git\n   cd verl/examples/data_preprocess\n   python3 gsm8k.py --local_save_dir ~/data/gsm8k\n\n\nStep 3: Submit a job with SkyPilot\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n1. Create a SkyPilot YAML ``verl-cluster.yml`` with the following content:\n\n.. note:: ``workdir: .`` will sync all the data in the current directory to the remote cluster.\n\n.. code-block:: yaml\n\n   resources:\n     accelerators: L4:1 # every node has 1 L4 GPU\n     image_id: docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4\n     memory: 64+        # every node has 64 GB memory\n     ports: 8265        # expose port for ray dashboard\n\n   num_nodes: 2         # cluster size\n\n   # --------------- Work Directory Synchronization (workdir) ---------------\n   # Defines the local working directory to be synchronized to the remote cluster.\n   # Here, '.' means synchronizing the directory where the sky submit command is currently run.\n   workdir: .\n\n   # --------------- (secrets) ---------------\n   secrets:\n     ## your wandb api key ##\n     WANDB_API_KEY: null\n\n   # --------------- File Mounts/Data Upload (file_mounts) ---------------\n   # If your dataset (gsm8k folder) is local, it needs to be uploaded to the remote cluster.\n   file_mounts:\n     # Remote path (relative to remote user's home directory): Local path\n     # /remote/dir1/file: /local/dir1/file\n     data/gsm8k: ~/data/gsm8k\n\n   # --------------- Environment Setup (setup) ---------------\n   # Commands run on each node of the remote cluster to set up the environment (e.g., install dependencies). 
These are run directly inside Docker.\n   setup: |\n     rm -rf verl\n     git clone https://github.com/volcengine/verl.git\n     cd verl\n     pip3 install -v -e .[vllm]\n\n   # --------------- Run Command (run) ---------------\n   # The actual task commands to be executed on the remote cluster.\n   # This script will first start the Ray cluster (different ray start commands are executed on Head and Worker nodes).\n   # Then, your training script will only be run on the Head node (SKYPILOT_NODE_RANK == 0).\n   run: |\n     # Get the Head node's IP and total number of nodes (environment variables injected by SkyPilot).\n     head_ip=`echo \"$SKYPILOT_NODE_IPS\" | head -n1`\n     num_nodes=`echo \"$SKYPILOT_NODE_IPS\" | wc -l` # Here num_nodes should be equal to 2.\n\n     # login wandb\n     python3 -c \"import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')\"\n\n     # Start Ray based on node role (Head=0, Worker>0).\n     # This logic is a standard Ray cluster startup script.\n     if [ \"$SKYPILOT_NODE_RANK\" == \"0\" ]; then\n       # Head node starts Ray Head.\n       echo \"Starting Ray head node...\"\n       # Check if a Ray Head is already running to avoid duplicate starts.\n       ps aux | grep ray | grep 6379 &> /dev/null ||  ray start --head --disable-usage-stats \\\n             --port=6379 \\\n             --dashboard-host=0.0.0.0 \\\n             --dashboard-port=8265\n\n       # Wait for all worker nodes to join the cluster.\n       while [ $(ray nodes | grep NODE_ID | wc -l) -lt $num_nodes ]; do\n         echo \"Waiting for all nodes to join... ($(ray nodes | grep NODE_ID | wc -l)/$num_nodes)\"\n         sleep 5\n       done\n\n       # Head node executes the training script.\n       echo \"Executing training script on head node...\"\n\n       python3 -m verl.trainer.main_ppo \\\n        data.train_files=data/gsm8k/train.parquet \\\n        data.val_files=data/gsm8k/test.parquet \\\n        data.train_batch_size=256 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=256 \\\n        actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        critic.optim.lr=1e-5 \\\n        critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n        critic.ppo_micro_batch_size_per_gpu=4 \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.logger=['console','wandb'] \\\n        trainer.val_before_train=False \\\n        trainer.default_hdfs_dir=null \\\n        trainer.n_gpus_per_node=1 \\\n        trainer.nnodes=2 \\\n        trainer.save_freq=20 \\\n        trainer.test_freq=20 \\\n        trainer.total_epochs=2 \\\n        trainer.project_name=verl_examples \\\n        trainer.experiment_name=experiment_name_gsm8k\n\n     else\n       # Wait for Ray Head to start.\n       sleep 10 # Increase waiting time to ensure Head finishes starting.\n       # Worker node starts Ray Worker.\n       echo \"Starting Ray worker node...\"\n\n       # Check if a Ray Worker is already running to avoid duplicate starts.\n       ps aux | grep ray | grep $head_ip:6379 &> 
/dev/null || ray start --address $head_ip:6379 --disable-usage-stats\n\n       # Sleep after `ray start` to give ray enough time to daemonize\n       sleep 5 # Ensure the Worker successfully connects to the Head.\n     fi\n\n     # No commands are added for the Worker node here; the Worker's main task is to start Ray and wait for the Head node to assign tasks.\n     echo \"Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK.\"\n\n\n.. code-block:: bash\n\n    export WANDB_API_KEY=<your-wandb-api-key>\n    sky launch -c verl --secret WANDB_API_KEY verl-cluster.yml\n\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/running_job.png?raw=true\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/running_job_1.png?raw=true\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/finished.png?raw=true\n\n**Check the cluster on GCP**\n\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/gcp_instances.png?raw=true\n\n**Check Ray Dashboard**\n\nWe can see the cluster on the Ray Dashboard via the GCP head node:\n\n.. code-block:: console\n\n    $ sky status --endpoint 8265 verl\n    1.2.3.4:8265\n\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/ray_dashboard_overview.png?raw=true\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/ray_dashboard_jobs.png?raw=true\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/ray_dashboard_cluster.png?raw=true\n\n\n**Check the model checkpoint**\n\n.. code-block:: bash\n\n    # log in to the head node\n    ssh verl\n    # The global step will vary. Find the correct path from the training logs.\n    cd ~/sky_workdir/checkpoints/verl_examples/gsm8k/\n    # Then list contents to find the checkpoint, e.g.:\n    ls -R .\n\n.. image:: https://github.com/yottalabsai/open-source/blob/main/static/verl/saved_model.png?raw=true\n\n\nOption 3: Launch via Slurm\n------------------------------\n\nRay provides users with `this <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ official\ntutorial to start a Ray cluster on top of Slurm. We have verified the :doc:`GSM8K example<../examples/gsm8k_example>`\non a Slurm cluster under a multi-node setting with the following steps.\n\n1. [Optional] If your cluster supports `Apptainer or Singularity <https://apptainer.org/docs/user/main/>`_ and you wish\nto use it, convert verl's Docker image to an Apptainer image. Alternatively, set up the environment with the package\nmanager available on your cluster or use other container runtimes (e.g. through `Slurm's OCI support <https://slurm.schedmd.com/containers.html>`_) available to you.\n\n.. code:: bash\n\n    apptainer pull /your/dest/dir/vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3.sif docker://verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3\n\n2. Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.\n\n3. Modify `examples/slurm/ray_on_slurm.slurm <https://github.com/volcengine/verl/blob/main/examples/slurm/ray_on_slurm.slurm>`_ with your cluster's own information.\n\n4. Submit the job script to the Slurm cluster with ``sbatch``, as shown below.\n\n
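A minimal submission flow looks like this (a sketch; the job script is the example shipped with verl, and ``squeue`` options may vary by site):\n\n.. code-block:: bash\n\n    sbatch examples/slurm/ray_on_slurm.slurm\n    # check that the job is queued or running\n    squeue -u $USER\n\n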
Please note that Slurm cluster setup may vary. If you encounter any issues, please refer to Ray's\n`Slurm user guide <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ for common caveats.\n\nIf you change the Slurm resource specifications, please make sure to update the environment variables in the job script as needed.\n\n\nOption 4: Launch via dstack\n------------------------------\n\n`dstackai/dstack <https://github.com/dstackai/dstack>`_ is an open-source container orchestrator that simplifies distributed training across cloud providers and on-premises environments\nwithout the need for Kubernetes or Slurm.\n\nPrerequisite\n~~~~~~~~~~~~\nOnce dstack is `installed <https://dstack.ai/docs/installation>`_, initialize the directory as a repo with ``dstack init``. \n\n.. code-block:: bash\n\n    mkdir myproject && cd myproject\n    dstack init\n\n**Create a fleet**\n\nBefore submitting distributed training jobs, create a `dstack` `fleet <https://dstack.ai/docs/concepts/fleets>`_.\n\nRun a Ray cluster task\n~~~~~~~~~~~~~~~~~~~~~~\n\nOnce the fleet is created, define a Ray cluster task, e.g. in ``ray-cluster.dstack.yml``:\n\n.. code-block:: yaml\n\n    type: task\n    name: ray-verl-cluster\n\n    nodes: 2\n\n    env:\n        - WANDB_API_KEY\n        - PYTHONUNBUFFERED=1\n        - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n    \n    image: verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2\n    commands:\n        - git clone https://github.com/volcengine/verl\n        - cd verl\n        - pip install --no-deps -e .\n        - pip install hf_transfer hf_xet\n        - |\n          if [ $DSTACK_NODE_RANK = 0 ]; then\n              python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n              python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-7B-Instruct')\"\n              ray start --head --port=6379;\n          else\n              ray start --address=$DSTACK_MASTER_NODE_IP:6379\n          fi\n\n    # Expose Ray dashboard port\n    ports:\n        - 8265\n\n    resources:\n        gpu: 80GB:8\n        shm_size: 128GB\n\n    # Save checkpoints on the instance\n    volumes:\n        - /checkpoints:/checkpoints\n\nNow, if you run this task via `dstack apply`, it will automatically forward Ray's dashboard port to `localhost:8265`.\n\n.. code-block:: bash\n\n    dstack apply -f ray-cluster.dstack.yml\n\nAs long as `dstack apply` stays attached, you can use `localhost:8265` to submit Ray jobs for execution.\n\nSubmit Ray jobs\n~~~~~~~~~~~~~~~\n\nBefore you can submit Ray jobs, make sure `ray` is installed locally:\n   \n.. code-block:: shell\n\n    pip install ray\n\nNow you can submit the training job to the Ray cluster which is available at ``localhost:8265``:\n   \n.. 
code-block:: shell\n\n    $ export RAY_ADDRESS=http://localhost:8265\n    $ ray job submit \\\n        -- python3 -m verl.trainer.main_ppo \\\n        data.train_files=/root/data/gsm8k/train.parquet \\\n        data.val_files=/root/data/gsm8k/test.parquet \\\n        data.train_batch_size=256 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=256 \\\n        actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        critic.optim.lr=1e-5 \\\n        critic.model.path=Qwen/Qwen2.5-7B-Instruct \\\n        critic.ppo_micro_batch_size_per_gpu=4 \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.project_name=ppo_training \\\n        trainer.experiment_name=qwen-2.5-7B \\\n        trainer.val_before_train=False \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        trainer.default_local_dir=/checkpoints \\\n        trainer.save_freq=10 \\\n        trainer.test_freq=10 \\\n        trainer.resume_mode=disable \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log\n\n\nFor more details on how `dstack` works, check out its `documentation <https://dstack.ai/docs>`_.\n\nHow to debug?\n---------------------\n\n\nRay Distributed Debugger VSCode Extension (Recommended)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger <https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html>`_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier.\n\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true\n      :alt: Ray Distributed Debugger VSCode extension screenshot\n\n2. Prerequisites.\n\n   Ensure the following are installed (see the extension README for more detail):\n\n   - Visual Studio Code  \n   - `ray[default]` >= 2.9.1  \n   - `debugpy` >= 1.8.0  \n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/c7098b755ff689859837773a916c857.png?raw=true\n      :alt: VSCode with Ray prerequisites\n\n3. Environment Variables.\n\n   To enable post‑mortem debugging, set:\n\n   .. code-block:: bash\n\n      export RAY_DEBUG_POST_MORTEM=1\n\n   .. admonition:: Note\n      :class: important\n\n      Be sure to remove any legacy flags before starting Ray:\n\n      - `RAY_DEBUG=legacy`  \n      - `--ray-debugger-external`\n\n4. Configuring Breakpoints.\n\n   1. Insert `breakpoint()` calls into your remote functions.  \n   2. Submit your job to the cluster.  \n\n   The extension will detect active breakpoints and display them in VSCode.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true\n      :alt: Detected breakpoint in VSCode\n\n   **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`.\n\n5. 
Launching the Debugger.\n\n   Run your job directly from the command line (do not use a `launch.json`):\n\n   .. code-block:: bash\n\n      python job.py\n\n6. Attaching to a Breakpoint.\n\n   Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true\n      :alt: Attaching VSCode debugger to Ray process\n\n7. Debugging With Multiple breakpoint() Calls.\n\n   For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/6e83c910a62c82fecb89c6619e001cd.png?raw=true\n      :alt: Disconnecting and reconnecting the debugger\n\nLegacy Ray Debugger\n~~~~~~~~~~~~~~~~~~~\n1. Ray has a builtin legacy `debugger <https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/ray-debugging.html>`_ that allows you to debug your distributed applications. To enable the debugger, start the ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``.\n\n.. code-block:: bash\n\n    # start head node\n    RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external\n    # start worker node\n    RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external\n\n2. Set a breakpoint in your code and submit the job to the cluster. Then run ``ray debug`` to wait at the breakpoint:\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true\n\n\nMulti-node training on AMD clusters\n---------------------------------------------------------------------------------------\n\nIf you want to run multi-node training with Slurm using a Docker/Podman container on an AMD cluster, you can use the following script. \n\nIf you encounter any issues using verl on AMD GPUs, please contact `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\n.. note::\n    1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later.\n    2. If you want to use ``podman``, just replace ``docker`` with ``podman`` in the script.\n\nThe script includes the following steps:\n\n1. SLURM Configuration\n2. Environment Setup\n3. Docker/Podman Container Setup\n4. Ray Cluster Initialization\n5. Data Preprocessing\n6. Model Setup\n7. Training Launch\n\n\nslurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: bash\n\n    #!/bin/bash\n\n    #SBATCH --job-name=verl-ray-on-slurm\n    #SBATCH --nodes=2\n    #SBATCH --ntasks-per-node=2\n    #SBATCH --mem=200G\n    #SBATCH --time=30-00:00:00\n    #SBATCH --gpus-per-node=8\n    #SBATCH --cpus-per-task=28\n    #SBATCH --output=../verl_log/slurm-%j.out\n    #SBATCH --error=../verl_log/slurm-%j.err\n    #SBATCH --nodelist=gpu-[0,1]\n\n\n    # load necessary modules\n    ### Run this setup\n    # [Cluster]: Use docker\n    # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n\n    ##########################################################################\n    ### The following settings should be adjusted per project and cluster ###\n    ##########################################################################\n\n    ### Project\n    CONTAINER_NAME=\"multinode_verl_training\"\n    IMG=\"verl.rocm\"\n    DOCKERFILE=\"docker/Dockerfile.rocm\"\n    # echo $PWD\n    verl_workdir=\"${HOME}/projects/verl_upstream\"\n    export TRANSFORMERS_CACHE=\"${HOME}/.cache/huggingface\"\n    export HF_HOME=$TRANSFORMERS_CACHE\n\n    ### Cluster Network Setting\n    export NCCL_DEBUG=TRACE\n    export GPU_MAX_HW_QUEUES=2\n    export TORCH_NCCL_HIGH_PRIORITY=1\n    export NCCL_CHECKS_DISABLE=1\n    # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 \n    export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_CROSS_NIC=0\n    export CUDA_DEVICE_MAX_CONNECTIONS=1\n    export NCCL_PROTO=Simple\n    export RCCL_MSCCL_ENABLE=0\n    export TOKENIZERS_PARALLELISM=false\n    export HSA_NO_SCRATCH_RECLAIM=1\n    ##########################################################################\n\n    ### For rocm and training script\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n    export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n    export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\n\n    # Build and launch the Docker container\n    srun bash -c \"\n        # Exit on any error\n        set -e \n\n        # Clean up dangling images (images with <none> tag)\n        docker image prune -f\n\n        # Need to pull the docker first\n        docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n        \n        if ! 
docker images --format \"{{.Repository}}:{{.Tag}}\" | grep -q \"${IMG}\"; then\n            echo \\\"Building ${IMG} image...\\\"\n            docker build -f \\\"${DOCKERFILE}\\\" -t \\\"${IMG}\\\" .\n        else\n            echo \\\"${IMG} image already exists, skipping build\\\"\n        fi\n\n        # Removing old container if exists\n        docker rm \\\"${CONTAINER_NAME}\\\" 2>/dev/null || true\n\n        # Checking network devices\n        ibdev2netdev\n\n        # Launch the docker\n        docker run --rm -d \\\n        -e HYDRA_FULL_ERROR=1 \\\n        -e HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES} \\\n        -e ROCR_VISIBLE_DEVICES=${ROCR_VISIBLE_DEVICES} \\\n        -e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \\\n        -e NCCL_DEBUG=${NCCL_DEBUG} \\\n        -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \\\n        -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \\\n        -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \\\n        -e NCCL_IB_HCA=${NCCL_IB_HCA} \\\n        -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \\\n        -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \\\n        -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \\\n        -e NCCL_PROTO=${NCCL_PROTO} \\\n        -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \\\n        -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \\\n        -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \\\n        -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \\\n        -e HF_HOME=${HF_HOME} \\\n        --network host \\\n        --device /dev/dri \\\n        --device /dev/kfd \\\n        --device /dev/infiniband \\\n        --group-add video \\\n        --cap-add SYS_PTRACE \\\n        --security-opt seccomp=unconfined \\\n        --privileged \\\n        -v \\${HOME}:\\${HOME} \\\n        -v \\${HOME}/.ssh:/root/.ssh \\\n        -w \"${verl_workdir}\" \\\n        --shm-size 128G \\\n        --name \\\"${CONTAINER_NAME}\\\" \\\n        \\\"${IMG}\\\" \\\n        tail -f /dev/null\n\n        echo \\\"Container setup completed\\\"\n    \"\n        # (Optional): If you do not want to run in root mode and want to assign yourself as the user,\n        # please add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` into the above docker launch script. \n\n\n\n\n\n    ### Ray launch the nodes before training\n\n    # Getting the node names\n    nodes_array=($(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | tr '\\n' ' '))\n\n    head_node=${nodes_array[0]}\n    head_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n    # if we detect a space character in the head node IP, we'll\n    # convert it to an ipv4 address. This step is optional.\n    if [[ \"$head_node_ip\" == *\" \"* ]]; then\n        IFS=' ' read -ra ADDR <<<\"$head_node_ip\"\n    if [[ ${#ADDR[0]} -gt 16 ]]; then\n        head_node_ip=${ADDR[1]}\n    else\n        head_node_ip=${ADDR[0]}\n    fi\n        echo \"IPV6 address detected. 
We split the IPV4 address as $head_node_ip\"\n    fi\n\n    port=6379\n    ip_head=$head_node_ip:$port\n    export ip_head\n    echo \"IP Head: $ip_head\"\n\n    # make sure we set environment variables before Ray initialization\n\n    # Print out all env variables\n    printenv\n\n    echo \"Starting HEAD at $head_node\"\n    srun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n            ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n            --dashboard-port=8266 \\\n            --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    # optional, though may be useful in certain versions of Ray < 1.0.\n    sleep 10\n\n    # number of nodes other than the head node\n    worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n    for ((i = 1; i <= worker_num; i++)); do\n        node_i=${nodes_array[$i]}\n        echo \"Debug: Starting worker on node_i = ${node_i}\"\n        if [ -z \"$node_i\" ]; then\n            echo \"Error: Empty node name for worker $i\"\n            continue\n        fi\n        echo \"Starting WORKER $i at $node_i\"\n        srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n            docker exec \"${CONTAINER_NAME}\" \\\n                ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n        sleep 5\n    done\n\n\n\n\n    # Ray initialization test (check whether the above execution produced any errors)\n    echo \"Testing Ray initialization in the slurm nodes...\"\n    docker exec \"${CONTAINER_NAME}\" python3 -c '\n    import ray\n    try:\n        ray.init(address=\"auto\")\n        print(\"\\n=== Ray Cluster Status ===\")\n        print(f\"Number of nodes: {len(ray.nodes())}\")\n        for node in ray.nodes():\n            print(\"Node: {}, Status: {}\".format(node[\"NodeManagerHostname\"], node[\"Alive\"]))\n            # print(f\"Node: {node}\")\n        ray.shutdown()\n        print(\"Ray initialization successful!\")\n    except Exception as e:\n        print(f\"Ray initialization failed: {str(e)}\")\n    '\n    echo \"=== Ray test completed ===\"\n    ######\n\n\n\n    # Run data preprocessing\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/gsm8k.py\" \"--local_save_dir\" \"../data/gsm8k\"\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/math_dataset.py\" \"--local_dir\" \"../data/math\"\n\n    train_files=\"../data/gsm8k/train.parquet\"\n    val_files=\"../data/gsm8k/test.parquet\"\n\n    # Download and test the model\n    echo \"Loading model...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')\"\n    MODEL_PATH=\"Qwen/Qwen2-7B-Instruct\"\n\n    echo \"== Data and model loading Done ==\"\n\n    echo \"Start to train...\"\n\n    PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n        python3 -m verl.trainer.main_ppo \\\n        data.train_files=$train_files \\\n  
      data.val_files=$val_files \\\n        data.train_batch_size=1024 \\\n        data.max_prompt_length=1024 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=$MODEL_PATH \\\n        critic.model.enable_gradient_checkpointing=False \\\n        critic.ppo_micro_batch_size_per_gpu=8 \\\n        critic.model.fsdp_config.param_offload=False \\\n        critic.model.fsdp_config.optimizer_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.0001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger='[\"console\",\"wandb\"]' \\\n        trainer.project_name='verl_example' \\\n        trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \\\n        trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=${SLURM_NNODES} \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\nRun multi-node training with the above slurm_script.sh\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSubmit your slurm_script.sh with ``sbatch``:\n\n.. code-block:: bash\n\n    sbatch slurm_script.sh\n\n"
  },
  {
    "path": "verl_distillation/docs/start/quickstart.rst",
    "content": ".. _quickstart:\n\n=========================================================\nQuickstart: PPO training on GSM8K dataset\n=========================================================\n\nPost-train a LLM using GSM8K dataset.\n\nIntroduction\n------------\n\n.. _hf_dataset_gsm8k: https://huggingface.co/datasets/gsm8k\n\nIn this example, we train an LLM to tackle the `GSM8k <hf_dataset_gsm8k>`_ task with function-based rewards. [1]_\n\nPrerequisite:\n\n- the latest version of ``verl`` and its dependencies installed following the installation guide. Using the docker image is recommended.\n\n- a GPU with at least 24 GB HBM\n\n\nDataset Introduction\n--------------------\n\nGSM8k is a math problem dataset. The prompt is an elementary school\nproblem. The LLM model is asked to solve the math problem. Below is an example:\n\nPrompt\n\n   Katy makes coffee using teaspoons of sugar and cups of water in the\n   ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups\n   of water, calculate the number of teaspoonfuls of sugar she used.\n\nSolution\n\n   The total ratio representing the ingredients she used to make the\n   coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the\n   number of teaspoons she used is 7/20, she used 7/20\\ *120 =\n   <<7/20*\\ 120=42>>42 #### 42\n\nStep 1: Prepare the dataset\n----------------------------\n\nWe preprocess the dataset in parquet format so that (1) it contains necessary fields for computing RL rewards and (2) is faster to read.\n\n.. code-block:: bash\n\n   python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\nStep 2: Download a model for post-training\n-------------------------------------------\n\nIn this example, we start with the ``Qwen2.5-0.5B-Instruct`` model.\n\nIf you want to perform SFT before RL, refer to the :doc:`Complete GSM8K Example<../examples/gsm8k_example>`, the `sft directory <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k>`_ and `SFT Trainer <https://github.com/volcengine/verl/blob/main/verl/trainer/fsdp_sft_trainer.py>`_ for further details.\n\n.. code-block:: bash\n\n   python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')\"\n\nStep 3: Perform PPO training with the instruct model\n----------------------------------------------------------------------\n\n**Reward Model/Function**\n\nWe use a pre-defined rule-based reward model. We force the model to produce a final\nanswer following 4 “#” as shown in the solution. We extract the final\nanswer from both the solution and model's output using regular\nexpression matching. We assign a reward of 1 to correct\nanswer, 0.0 to incorrect answer and 0 to no answer. \n\nFor more details, please refer to `verl/utils/reward_score/gsm8k.py <https://github.com/volcengine/verl/blob/v0.4.1/verl/utils/reward_score/gsm8k.py>`_.\n\n**Training Script**\n\nNow let's run PPO training with the dataset and model above. [2]_\n\n\nSet the ``data.train_files`` ,\\ ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths.\nYou may set ``VERL_USE_MODELSCOPE=True`` to download models from `modelscope <https://www.modelscope.cn>`_ instead of `huggingface <https://huggingface.co>`_.\n\n.. 
code-block:: bash\n\n   PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=256 \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=10 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 2>&1 | tee verl_demo.log\n\nYou are expected to see the following logs, indicating training in progress. The key metric ``val/test_score/openai/gsm8k`` is computed every ``trainer.test_freq`` steps:\n\n.. code-block:: bash\n\n    step:0 - timing/gen:21.470 - timing/ref:4.360 - timing/values:5.800 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty_coeff:0.001 - timing/adv:0.109 - timing/update_critic:15.664 - critic/vf_loss:14.947 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.056 - critic/grad_norm:1023.278 - critic/lr(1e-4):0.100 - timing/update_actor:20.314 - actor/entropy_loss:0.433 - actor/pg_loss:-0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:1.992 - actor/lr(1e-4):0.010 - critic/score/mean:0.004 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.004 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.360 - critic/advantages/min:-2.280 - critic/returns/mean:0.003 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.045 - critic/values/max:9.500 - critic/values/min:-14.000 - response_length/mean:239.133 - response_length/max:256.000 - response_length/min:77.000 - prompt_length/mean:104.883 - prompt_length/max:175.000 - prompt_length/min:68.000\n    step:1 - timing/gen:23.020 - timing/ref:4.322 - timing/values:5.953 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty:0.001 - timing/adv:0.118 - timing/update_critic:15.646 - critic/vf_loss:18.472 - critic/vf_clipfrac:0.384 - critic/vpred_mean:1.038 - critic/grad_norm:942.924 - critic/lr(1e-4):0.100 - timing/update_actor:20.526 - actor/entropy_loss:0.440 - actor/pg_loss:0.000 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - actor/grad_norm:2.060 - actor/lr(1e-4):0.010 - critic/score/mean:0.000 - critic/score/max:0.000 - critic/score/min:0.000 - critic/rewards/mean:0.000 - critic/rewards/max:0.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:2.702 - critic/advantages/min:-2.616 - critic/returns/mean:0.000 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.280 - critic/values/max:11.000 - critic/values/min:-16.000 - response_length/mean:232.242 - response_length/max:256.000 - response_length/min:91.000 - prompt_length/mean:102.398 - 
prompt_length/max:185.000 - prompt_length/min:70.000\n\nCheck out the ``Algorithm Baselines`` page for full training and validation logs for reference.\n\nThe checkpoint is saved in the following directory by default: ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``. You can merge the saved checkpoints into a huggingface model using the ``verl.model_merger`` module, for example:\n\n.. code-block:: bash\n\n    python3 -m verl.model_merger merge \\\n        --backend fsdp \\\n        --local_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor \\\n        --target_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor/huggingface\n\nFor more details about checkpointing and model merging, please refer to :ref:`checkpoint-page`.\n\nTo enable ``wandb`` for experiment tracking, set the following configs:\n\n.. code-block:: bash\n\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=$YOUR_PROJECT_NAME \\\n    trainer.experiment_name=$YOUR_RUN_NAME \\\n\nIf you encounter out-of-memory issues with less than 32GB of HBM, enabling the following configs would help:\n\n.. code-block:: bash\n\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    critic.ppo_micro_batch_size_per_gpu=1 \\\n\nFor the full set of configs, please refer to :ref:`config-explain-page` for detailed explanations and performance tuning.\n\n\n.. [1] The original paper (https://arxiv.org/pdf/2110.14168) mainly focuses on training a verifier (a reward model) to solve math problems via Best-of-N sampling. In this example, we train an RL agent using a rule-based reward model.\n.. [2] More training script examples for the FSDP and Megatron-LM backends are stored in the `examples/ppo_trainer <https://github.com/volcengine/verl/tree/main/examples/ppo_trainer>`_ directory.\n"
  },
  {
    "path": "verl_distillation/docs/start/ray_debug_tutorial.rst",
    "content": "Ray Debug Tutorial\r\n==================\r\n\r\nLast updated: 04/23/2025\r\n\r\n\r\n.. _wuxibin89: https://github.com/wuxibin89\r\n\r\nAuthor: `Ao Shen <https://aoshen524.github.io/>`_.\r\n\r\nHow to debug?\r\n---------------------\r\n\r\n\r\nRay Distributed Debugger VSCode Extension (Recommended)\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger <https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html>`_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true\r\n      :alt: Ray Distributed Debugger VSCode extension screenshot\r\n\r\n2. Prerequisites.\r\n\r\n   Ensure the following are installed (see the extension README for more detail):\r\n\r\n   - Visual Studio Code  \r\n   - `ray[default]` >= 2.9.1  \r\n   - `debugpy` >= 1.8.0  \r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/readme.png?raw=true\r\n      :alt: VSCode with Ray prerequisites\r\n\r\n3. Environment Variables.\r\n\r\n   To enable post‑mortem debugging, set:\r\n\r\n   .. code-block:: bash\r\n\r\n      export RAY_DEBUG_POST_MORTEM=1\r\n\r\n   .. admonition:: Note\r\n      :class: important\r\n\r\n      Be sure to remove any legacy flags before starting Ray:\r\n\r\n      - `RAY_DEBUG=legacy`  \r\n      - `--ray-debugger-external`\r\n\r\n4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information.\r\n\r\n\r\n   1. Insert `breakpoint()` calls into your remote functions.  \r\n   2. Submit your job to the cluster.  \r\n\r\n   The extension will detect active breakpoints and display them in VSCode.\r\n\r\n   **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`.\r\n\r\n5. Launching the Debugger.\r\n\r\n   Run your job directly from the command line (do not use a `launch.json`):\r\n\r\n   .. code-block:: bash\r\n\r\n      python job.py\r\n\r\n6. Attaching to a Breakpoint.\r\n\r\n Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/launch.png?raw=true\r\n      :alt: Attaching VSCode debugger to Ray process\r\n\r\n7. Debugging With Multiple breakpoint().\r\n\r\n   For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/disconnect.png?raw=true\r\n      :alt: Disconnecting and reconnecting the debugger\r\n\r\nLegacy Ray Debugger\r\n~~~~~~~~~~~~~~~~~~~\r\n1. Ray has a builtin legacy `debugger <https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/ray-debugging.html>`_ that allows you to debug your distributed applications. To enable debugger, start ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``.\r\n\r\n.. code-block:: bash\r\n\r\n    # start head node\r\n    RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external\r\n    # start worker node\r\n    RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external\r\n\r\n2. 
Set up a breakpoint in your code and submit the job to the cluster. Then run ``ray debug`` to wait for the breakpoint:\r\n\r\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true\r\n
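\r\nFor reference, a minimal job script with a ``breakpoint()`` inside a remote task might look like the sketch below (usable with either debugger; the function itself is illustrative):\r\n\r\n.. code-block:: python\r\n\r\n    # job.py - a minimal sketch for exercising the debugger\r\n    import ray\r\n\r\n    ray.init()\r\n\r\n    @ray.remote\r\n    def square(x):\r\n        breakpoint()  # the debugger attaches when this line is hit\r\n        return x * x\r\n\r\n    print(ray.get(square.remote(4)))\r\n"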
  },
  {
    "path": "verl_distillation/docs/workers/fsdp_workers.rst",
    "content": "PyTorch FSDP Backend\n======================\n\nLast updated: 02/12/2025.\n\nWe support PyTorch FSDP Backend by implementing various workers for\nactor, critic, reference, rollout and reward models. We also implement\nthe ``FSDPVLLMShardingManager`` that reshard weight between FSDP and\nvLLM in `fsdp_vllm.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/fsdp_vllm.py>`_.\n\n**Pros**\n\n- Readily support various models.\n\n  - Users only need to implement the corresponding\n    ``dtensor_weight_loader`` for weight synchronization between FSDP\n    and vLLM. While for ``hf_weight_loader``, users can directly apply\n    any models supported both in HF and vLLM without any code change.\n\n- Easy to organize the forward and backward computation for each model.\n\n**Cons**\n\n- Poor scalability when it comes to large-scale models (e.g. Llama 70B\n  and 405B)\n- The resharding overhead between actor and rollout could be larger than\n  Megatron-LM backend.\n\nDue to the simplicity, we recommend using FSDP backend for algorithm\nresearch and prototyping.\n\nFSDP Workers\n--------------\n\nActorRolloutRefWorker\n^^^^^^^^^^^^^^^^^^^^^\n\nActor/Rollout HybridEngine\n''''''''''''''''''''''''''\n\n1. HybridEngine, Actor and Rollout initialization API.\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def init_model(self):\n\n``ONE_TO_ALL``: when calling the ``init_model`` function from the driver\nprocess, each worker (on a GPU) will execute the following model\ninitialization process.\n\nThe initialization details of HybridEngine, Actor and Rollout are\nhighlighted below:\n\n1. ``DataParallelPPOActor`` implements the simple PPO computation logics\n   when the model is built with FSDP, including compute log prob, model\n   update.\n2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM\n   Engine and make it executed under SPMD to fit into our\n   ``WorkerGroup`` design.\n3. ``FSDPVLLMShardingManager`` a context manager to perform actual\n   resharding between actor and rollout.\n\nSee `source code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_. for more information.\n\n1. Generate sequence and recompute log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(self, prompts: DataProto):\n\n- ``Dispatch.DP_COMPUTE_PROTO``: The data will be dispatched and\n  collected along the DP dimension\n\n- In this function, the rollout model will perform auto-regressive\n  generation and the actor model will recompute the old log prob for the\n  generated response.\n\n3. Update actor model\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def update_actor(self, data: DataProto):\n\n- Update the actor model weight using PPO & entropy loss.\n\nReferenceModel\n''''''''''''''\n\n1. Reference model initialization\n\nThe reference model is initialized using the same function as the actor\nmodel without initializing the HybridEngine and Optimizer. Then the\nactor model is also wrapped by the ``DataParallelPPOActor``.\n\n2. Compute reference log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_ref_log_prob(self, data: DataProto):\n\n- In this function, the reference model will call the compute log prob\n  function in ``DataParallelPPOActor`` to compute the reference log\n  prob.\n\nCriticWorker and RewardWorker\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. 
Model initialization\n\nQuite similar to the reference model, but the CriticWorker performs\nadditional initialization for the Optimizer.\n\n2. Compute Values for CriticWorker\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_values(self, data: DataProto):\n\n3. Update Critic\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def update_critic(self, data: DataProto):\n\n4. Compute Reward\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_rm_score(self, data: DataProto):\n\n\nHybridShard\n------------\n\nWe do not yet support FSDP ``HybridShard``. To support it, we may need to\nconstruct a 2D device mesh and test the corresponding\n``dtensor_weight_loader`` and ``hf_weight_loader`` for each model.\n
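\nFor illustration, the kind of 2D device mesh this would require can be built\nwith PyTorch's device-mesh API (a sketch only; the 2x4 layout and the dimension\nnames are illustrative, assuming PyTorch >= 2.2):\n\n.. code:: python\n\n   # a sketch: 8 GPUs arranged as 2 replica groups x 4 shards, i.e.\n   # replicate along one mesh dimension and shard along the other\n   from torch.distributed.device_mesh import init_device_mesh\n\n   mesh_2d = init_device_mesh(\"cuda\", (2, 4), mesh_dim_names=(\"replicate\", \"shard\"))\n"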
  },
  {
    "path": "verl_distillation/docs/workers/megatron_workers.rst",
    "content": "Megatron-LM Backend\n===================\n\nLast updated: 06/24/2025.\n\nWe support Megatron Backend by implementing various workers for actor,\ncritic, reference, rollout and reward models. We also implement the\n``3DHybridEngine`` using Megatron-LM and vLLM/SGLang in\n`megatron_vllm.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/megatron_vllm.py>`_\nand `megatron_sglang.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/megatron_sglang.py>`_.\n\n**Pros**\n\n- Support 5D parallelism (TP, EP, CP, DP, PP) and sequence parallelism\n  for best scalablility and throughput.\n- 3D HybridEngine can significantly reduce peak memory usage and reduce\n  weight synchronize overhead between actor and rollout.\n\n**Cons**\n\n- Huggingface Models and Megatron checkpoints need tools for conversion.\n\n\nDevelopment Progress\n--------------------\n\n\nNote that [Deprecated] means that the feature is not supported in the latest\nversion of verl.\n[To-Optimize] means that the feature is implemented but not optimized yet.\n[WIP] means that the feature is working in progress.\n[In-Release] means that the feature is ready and in review process,\ncoming at any time.\n\n\n+---------------+-----------------------------------------------------------+\n| [Deprecated]  | Megatron 3D Parallelism with custom models                |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron 0.11.0 ``GPTModel`` support                      |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron GRPO support                                     |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron with vLLM 0.8.2, with per-tensor weights loading |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron with Context Parallel                            |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Qwen2MoE model support                                    |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Megatron dist Checkpoint                                  |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Huggingface and Megatron Checkpoint Converter             |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Efficient fused linear, entropy and cross entropy         |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron offload(param, grad, optimizer)                  |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron Profiler                                         |\n+---------------+-----------------------------------------------------------+\n| [In-Release]  | Megatron 0.12.0, TE 2.2 with vLLM 0.8.3 and Fused Attn    |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | Moonlight/DeepSeek-V3 model support                       |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | Expert Parallel support                                   |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | 
Megatron dynamic batch size support                       |\n+---------------+-----------------------------------------------------------+\n| [To-Do]       | Performance tuning                                        |\n+---------------+-----------------------------------------------------------+\n| [MileStone]   | Runnable with DeepSeek-V3 671B post-training              |\n+---------------+-----------------------------------------------------------+\n\n\n\nUtils of Megatron Workers\n-------------------------\n\nMegatronWorker\n^^^^^^^^^^^^^^\n\n``MegatronWorker`` is the base class of the different megatron worker\nclasses. In this class, the ``get_megatron_global_info`` and\n``get_megatron_rank_info`` functions retrieve the 3D parallel world\nsize and rank of each ``Worker`` running on a specific GPU. This information\nis used in the transfer protocol for the Megatron backend.\n\nThe following ``Worker`` classes for different models are used to\nconstruct the ``WorkerGroup``.\n\nWe implement various APIs for each ``Worker`` class, decorated with\n``@register(dispatch_mode=...)``. These APIs can be called by the ray\ndriver process, and the data is correctly collected and dispatched\naccording to the ``dispatch_mode`` of each function. The supported dispatch modes\n(i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_.\n\nActorRolloutRefWorker\n^^^^^^^^^^^^^^^^^^^^^\n\nThis class is implemented for the Actor/Rollout HybridEngine or for the\nreference model to initialize their models and perform computation.\n\nActor/Rollout HybridEngine\n''''''''''''''''''''''''''\n\n1. HybridEngine, Actor and Rollout initialization API.\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def init_model(self):\n\n``ONE_TO_ALL``: when calling the ``init_model`` function from the driver\nprocess, each worker (on a GPU) will execute the following model\ninitialization process.\n\nThe initialization details of HybridEngine, Actor and Rollout are\nhighlighted below:\n\n1. ``MegatronPPOActor`` implements the simple PPO computation logics\n   when the model is built with Megatron, including computing log prob and\n   model update.\n2. ``vLLMRollout`` supports generation with vLLM. We modify the vLLM\n   Engine so that it executes under SPMD to fit into our\n   ``WorkerGroup`` design.\n3. ``MegatronVLLMShardingManager`` is a context manager that performs the\n   actual resharding between actor and rollout.\n\nSee the `source code <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py#L63>`_ for more information.\n\n.. 
code:: python\n\n   # build actor model\n   self.actor = MegatronPPOActor(config=self.config.actor,\n                                 model_config=self.actor_model_config,\n                                 megatron_config=megatron_config,\n                                 actor_module=self.actor_module,\n                                 actor_optimizer=self.actor_optimizer,\n                                 actor_optimizer_config=self.actor_optim_config)\n\n   # build rollout\n   # rollout initialization\n   rollout = vLLMRollout(actor_module=params,\n                        config=self.config.rollout,\n                        tokenizer=self.tokenizer,\n                        model_hf_config=self.actor_model_config,\n                        train_tp=mpu.get_tensor_model_parallel_world_size())\n   # perform weight resharding between actor and rollout\n   sharding_manager = MegatronVLLMShardingManager(module=self.hybrid_engine,\n                                                  inference_engine=rollout.inference_engine,\n                                                  model_config=self.actor_model_config,\n                                                  layer_name_mapping=layer_name_mapping)\n   ...\n\n2. Generate sequences and recompute log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)\n   def generate_sequences(self, prompts: DataProto):\n\n- ``Dispatch.MEGATRON_PP_AS_DP_PROTO``: The PP dimension of the actor\n  model will be regarded as the DP dimension. Then the driver process will\n  dispatch and collect the data according to this reorganization. This\n  is because, in HybridEngine, the actor weights, which usually use\n  larger 3D parallel sizes, will be gathered along the PP dimension and\n  TP dimension. Therefore, the corresponding data should be dispatched\n  and collected through the 3D parallel group of the rollout model,\n  rather than the actor model. However, the world_size and rank\n  information can only be retrieved from ``get_megatron_global_info`` and\n  ``get_megatron_rank_info``, which record the 3D information for the\n  actor model. Moreover, the data resharding inside the TP dimension is\n  processed within the HybridEngine.\n\n- In this function, the rollout model will perform auto-regressive\n  generation and the actor model will recompute the old log prob for the\n  generated response.\n\n3. Update actor model\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def update_actor(self, data: DataProto):\n\n- ``Dispatch.MEGATRON_COMPUTE_PROTO``: The user passes data partitioned\n  by the DP dimension. The data is dispatched to all tp/pp ranks within the\n  same dp group, and ultimately only collects output data from tp=0 and\n  the last pp.\n- Update the actor model weights using PPO & entropy loss.\n\n.. note::\n\n   Currently, the training Tensor Parallel Size can be different from the\n   inference Tensor Parallel Size.\n\n\nReferenceModel\n''''''''''''''\n\n1. Reference model initialization\n\nThe reference model is initialized using the same function as the actor\nmodel, without initializing the HybridEngine and Optimizer. The\nreference model is then also wrapped by ``MegatronPPOActor``.\n\n2. Compute reference log prob\n\n.. 
code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_ref_log_prob(self, data: DataProto):\n\n- In this function, the reference model will call the compute log prob\n  function in ``MegatronPPOActor`` to compute the reference log prob.\n\nCriticWorker and RewardWorker\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. Model initialization\n\nQuite similar to the reference model, but the CriticWorker performs\nadditional initialization for the Optimizer.\n\n2. Compute Values for CriticWorker\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_values(self, data: DataProto):\n\n3. Update Critic\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def update_critic(self, data: DataProto):\n\n4. Compute Reward\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_rm_score(self, data: DataProto):\n\n\nUtils of Train Optimization\n---------------------------\n\nOffload\n^^^^^^^\n\nWhen resources are tight, offloading can lower GPU memory usage,\nhelping training and inference frameworks work well under verl.\nIt moves parameters, gradients, and optimizer states to CPU memory and only\nloads them back to the GPU when needed.\n\nTo enable offloading, add the following parameters\nfor the actor and the reference model separately.\n\n.. code:: bash\n\n   # For the actor\n   actor_rollout_ref.actor.megatron.param_offload=True \\\n   actor_rollout_ref.actor.megatron.grad_offload=True \\\n   actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n   # For the ref w/o grad and optimizer\n   actor_rollout_ref.ref.megatron.param_offload=True \\\n\n\nFor the critic, you can include these parameters.\n\n.. code:: bash\n\n   # For the critic\n   critic.megatron.param_offload=True \\\n   critic.megatron.grad_offload=True \\\n   critic.megatron.optimizer_offload=True \\\n\n\nRelated MCore Document\n----------------------\n\nThere is also a detailed document on using MCore to train different\nkinds of models; please refer to the `MCore Document <https://github.com/volcengine/verl/blob/main/verl/models/mcore/readme.md>`_.\n"
  },
  {
    "path": "verl_distillation/docs/workers/model_engine.rst",
    "content": "Model Engine\n============\n\n.. _vermouth: https://github.com/vermouth1992\n\nAuthor: `Chi Zhang <https://github.com/vermouth1992>`_\n\nLast updated: 09/25/2025.\n\nCurrent Support Matrix\n----------------------\n\n+----------+-----------+--------------+-------------+--------------------------+\n| Backends | Model     | Scalability  | Model       | Pain points              |\n|          | Supported |              | Definition  |                          |\n|          |           |              |             |                          |\n+==========+===========+==============+=============+==========================+\n| FSDP     | Day 1     | - Dense is OK| Huggingface | Monkey patch can be      |\n| +        | support   |              | + monkey    | easily impacted by       |\n| ulysses  | HF model  | - MoE is bad | patch       | transformers version     |\n+----------+-----------+--------------+-------------+--------------------------+\n| MCore    | Limited   | Best         | GPTModel    | Supporting new models is |\n|          |           |              | (One model  | difficult                |\n|          |           |              | for all)    |                          |\n+----------+-----------+--------------+-------------+--------------------------+\n\n-  We monkey patch attention function to support ulysses\n-  We monkey patch VLM models to support FSDP with mixed data with and\n   without images\n\nClass Hierarchy\n---------------\n\nNote that all the workers and trainers run in **SPMD** mode. SFT/DPO/RM\ntrainer is directly invoked by ``torchrun``. The Actor/Critic worker can\nalso be invoked by a RayWorkerGroup and provides APIs to a single\ncontroller.\n\n-  Base Engine level: implement model init, optimizer init, lr scheduler\n   init, sharding, checkpoint manager.\n-  Full Engine level: subclass base engine and implement\n   ``forward_step``.\n-  Worker/SPMD trainer level: **engine agnostic**, implement training\n   logics using abstract engine APIs\n\nRL trainer utilizes workers to construct HybridFlow program. This is out\nof the scope of model engine.\n\nExisting Model Types\n--------------------\n\n========== ====================== ======================\nModel type Language model         Value model\n========== ====================== ======================\nInput      text/image/video/audio text/image/video/audio\nOutput     logits for next token  logits as value\n========== ====================== ======================\n\nCurrently, we have two model types: language model and value model. We\nexpect to expand the category to include Qwen-Omni family (output both\ntext and audio) and VLA models.\n\nData Format\n-----------\n\nCurrently, verl adopts left-right padding data format in RL trainer.\nThis creates massive padding when the discrepancy between response\nlength is large. We will start to implement no-padding format throughout\nthe whole system.\n\n.. image:: https://github.com/vermouth1992/verl-data/blob/master/images/data_format.png?raw=true\n   :alt: Data Format\n\nHere is the migration plan:\n- Implement no-padding format in engine\n- Add a transformation layer in Actor/Critic worker.\n- Replace Actor/Critic Worker in RL trainer\n- Implement no-padding throughput system\n\nCheckpoint System\n-----------------\n\n.. 
image:: https://github.com/vermouth1992/verl-data/blob/master/images/verl-ckpt.png?raw=true\n   :alt: Model Engine Checkpoint System\n\nThe engine constructs the model using the huggingface config, then loads\nweights from the huggingface checkpoint. If the engine directly uses the\nhuggingface model definition, it can use the functions provided by\n``transformers``. Otherwise, each engine has to write its own\ncheckpoint load logic (e.g.,\n`mbridge <https://github.com/ISEEKYAN/mbridge>`__). During model\ntraining, each engine has to implement ``save_checkpoint`` and\n``load_checkpoint``, which save/load the intermediate sharded checkpoint,\nincluding model, optimizer and lr scheduler states. Each engine also has to\nimplement a checkpoint merge script that merges the intermediate sharded\ncheckpoint back to huggingface format.\n\nAPI\n---\n\nA tentative model engine API can be found at:\nhttps://github.com/volcengine/verl/blob/main/verl/workers/engine/base.py#L24\n\nExtension\n---------\n\nAdd a new backend\n~~~~~~~~~~~~~~~~~\n\n-  Start a new folder under ``verl/workers/engine``. Then, implement\n   ``transformer_impl.py``. If you want to implement a non-transformer\n   model, please contact us in advance. A schematic sketch of the\n   full-engine level appears at the end of this page.\n-  Add the engine config to the GSM8k SFT trainer script:\n   https://github.com/volcengine/verl/blob/main/tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n-  Invoke the tests with your backend:\n   https://github.com/volcengine/verl/blob/main/tests/special_e2e/sft/test_sft_engine_all.sh.\n   This test script will run various backends and various\n   configurations, and compare the loss and grad norm of the first step\n   to make sure they are close.\n\nAdd a new model type\n~~~~~~~~~~~~~~~~~~~~\n\n-  This is mainly reserved for models whose output is not just text\n   (e.g., Qwen3-Omni). Please discuss with us before you proceed.\n
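\nSchematic sketch of a new backend\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe sketch below shows the rough shape of a full-engine implementation\n(schematic only: ``BaseEngine`` here is a local stand-in stub for the abstract\nengine in ``base.py``, and every name except ``forward_step`` is illustrative\nrather than verl's actual API):\n\n.. code:: python\n\n   class BaseEngine:  # stub standing in for verl/workers/engine/base.py\n       \"\"\"Base-engine level: model/optimizer/lr-scheduler init, sharding,\n       and checkpoint management.\"\"\"\n\n   class MyBackendEngine(BaseEngine):\n       \"\"\"Full-engine level: subclasses the base engine and implements\n       forward_step.\"\"\"\n\n       def forward_step(self, micro_batch):\n           # run one micro batch through the model and return the outputs\n           # that the engine-agnostic worker level consumes\n           ...\n"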
  },
  {
    "path": "verl_distillation/docs/workers/ray_trainer.rst",
    "content": "PPO Ray Trainer\n===============\n\nLast updated: 02/12/2025.\n\nWe implement the RayPPOTrainer, which is a trainer runs on the driver\nprocess on a single CPU/GPU node (default is CPU).\n\nThe PPORayTrainer include 3 core functions for data preparation,\nWorkerGroup initialization and PPO training loop.\n\nData Preparation\n----------------\n\nThe ``PPORayTrainer``, as a single process, is responsible for loading a\ncomplete batch of samples (prompts) from the dataset and then dispatch\nto different worker_groups running on different GPUs.\n\nTo generalize the data loading, we implement the ``RLHFDataset`` class\nto load the preprocessed parquet files, apply chat templates to the\nprompts, add padding, truncate prompts that exceed max prompt length and\nthen tokenize.\n\n.. code:: python\n\n   self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,\n                                       tokenizer=self.tokenizer,\n                                       config=self.config.data)\n\nThen, the dataloader will iterate the dataset under PPO mini batch size.\n\nWorkerGroup Initialization\n--------------------------\n\nWe first introduce a basic implementation of initializing the\n``WorkerGroup`` of the actor model on a given set of GPUs.\n\n.. code:: python\n\n   # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n   # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.\n   # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models\n   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n                                   use_gpu=True,\n                                   max_colocate_count=1)\n   # define actor rollout cls to be init on remote\n   actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)\n   # define actor_rollout worker group\n   actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,\n                                                       ray_cls_with_init=actor_rollout_cls,\n                                                       default_megatron_kwargs=config.actor_rollout.megatron)\n\nDifferent WorkerGroups, like ``actor_rollout_worker_group`` ,\n``critic_worker_group`` and ``ref_worker_group`` lies on a separate\nprocess in the above implementation.\n\nThe driver process can then call the distributed compute function within\nthe ``actor_rollout_worker_group`` and other roles to construct the RL\ntraining loop.\n\nFor models colocated in the same set of GPUs, we further provide a\nfine-grain optimization, which merge the ``worker_group`` of different roles\nin the same process. This optimization can save the redundant\nCUDA/distributed context in different processes.\n\n.. code:: python\n\n   # initialize WorkerGroup\n   # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n   # you should not use `create_colocated_worker_cls`. 
Instead, directly pass different resource pools to different worker groups.\n   # See TODO(url) for more information.\n   all_wg = {}\n   for resource_pool, class_dict in self.resource_pool_to_cls.items():\n       worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n       wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)\n       spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n       all_wg.update(spawn_wg)\n\n   if self.use_critic:\n       self.critic_wg = all_wg['critic']\n       self.critic_wg.init_model()\n\n   if self.use_reference_policy:\n       self.ref_policy_wg = all_wg['ref']\n       self.ref_policy_wg.init_model()\n\n   if self.use_rm:\n       self.rm_wg = all_wg['rm']\n       self.rm_wg.init_model()\n\n   # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n   self.actor_rollout_wg = all_wg['actor_rollout']\n   self.actor_rollout_wg.init_model()\n\n.. note:: For the megatron backend, if we merge the ``worker_groups`` into the same processes, all the roles will utilize the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role in the same distributed context. If you want to use different 3D parallel sizes for different roles, please follow an architecture similar to the first code block and initialize each role's ``worker_group`` separately.\n\n\nPPO Training Loop\n-----------------\n\nWe implement the PPO training loop by calling the functions in the\nworker_group of each role. The input and output data of each function is\na ``DataProto`` object implemented in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_. In the training\nloop, the trainer will dispatch/collect the data to/from different GPUs\nfollowing the transfer protocols wrapped in the workers' functions. The\ncomputation of PPO micro batches is processed in the ``update_actor`` and\n``update_critic`` functions.\n\nTo extend to other RLHF algorithms, such as DPO and GRPO, please refer to\n:doc:`../advance/dpo_extension`.\n\n.. 
code:: python\n\n   def fit(self):\n       \"\"\"\n       The training loop of PPO.\n       The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.\n       The lightweight advantage computation is done on the driver process.\n       \"\"\"\n       from verl.utils.tracking import Tracking\n       from omegaconf import OmegaConf\n\n       logger = Tracking(project_name=self.config.trainer.project_name,\n                           experiment_name=self.config.trainer.experiment_name,\n                           default_backend=self.config.trainer.logger,\n                           config=OmegaConf.to_container(self.config, resolve=True))\n\n       global_steps = 0\n\n       # perform validation before training\n       # currently, we only support validation using the reward_function.\n       if self.val_reward_fn is not None:\n           val_metrics = self._validate()\n           pprint(f'Initial validation metrics: {val_metrics}')\n\n       for epoch in range(self.config.trainer.total_epochs):\n           for batch_dict in self.train_dataloader:\n               metrics = {}\n\n               batch: DataProto = DataProto.from_single_dict(batch_dict)\n               # batch = batch.to('cuda')\n\n               # pop those keys for generation\n               gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])\n\n               # generate a batch\n               with Timer(name='gen', logger=None) as timer:\n                   gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n               metrics['timing/gen'] = timer.last\n\n               batch = batch.union(gen_batch_output)\n\n               if self.use_reference_policy:\n                   # compute reference log_prob\n                   with Timer(name='ref', logger=None) as timer:\n                       ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                       batch = batch.union(ref_log_prob)\n                   metrics['timing/ref'] = timer.last\n\n               # compute values\n               with Timer(name='values', logger=None) as timer:\n                   values = self.critic_wg.compute_values(batch)\n                   batch = batch.union(values)\n               metrics['timing/values'] = timer.last\n\n               with Timer(name='adv', logger=None) as timer:\n                   # compute scores. Support both model and function-based.\n                   # We first compute the scores using reward model. Then, we call reward_fn to combine\n                   # the results from reward model and rule-based results.\n                   if self.use_rm:\n                       # we first compute reward model score\n                       reward_tensor = self.rm_wg.compute_rm_score(batch)\n                       batch = batch.union(reward_tensor)\n\n                   # we combine with rule-based rm\n                   reward_tensor = self.reward_fn(batch)\n                   batch.batch['token_level_scores'] = reward_tensor\n\n                   # compute rewards. 
apply_kl_penalty if available\n                   batch, kl_metrics = apply_kl_penalty(batch,\n                                                           kl_ctrl=self.kl_ctrl_in_reward,\n                                                           kl_penalty=self.config.algorithm.kl_penalty)\n                   metrics.update(kl_metrics)\n\n                   # compute advantages, executed on the driver process\n                   batch = compute_advantage(batch,\n                                               self.config.algorithm.gamma,\n                                               self.config.algorithm.lam,\n                                               adv_estimator=self.config.algorithm.adv_estimator)\n               metrics['timing/adv'] = timer.last\n\n               # update critic\n               if self.use_critic:\n                   with Timer(name='update_critic', logger=None) as timer:\n                       critic_output = self.critic_wg.update_critic(batch)\n                   metrics['timing/update_critic'] = timer.last\n                   critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])\n                   metrics.update(critic_output_metrics)\n\n               # implement critic warmup\n               if self.config.trainer.critic_warmup <= global_steps:\n                   # update actor\n                   with Timer(name='update_actor', logger=None) as timer:\n                       actor_output = self.actor_rollout_wg.update_actor(batch)\n                   metrics['timing/update_actor'] = timer.last\n                   actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])\n                   metrics.update(actor_output_metrics)\n\n               # validate\n               if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:\n                   with Timer(name='testing', logger=None) as timer:\n                       val_metrics: dict = self._validate()\n                       val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}\n                   metrics['timing/testing'] = timer.last\n                   metrics.update(val_metrics)\n\n               # collect metrics\n               data_metrics = compute_data_metrics(batch=batch)\n               metrics.update(data_metrics)\n\n               # TODO: make a canonical logger that supports various backend\n               logger.log(data=metrics, step=global_steps)\n\n               if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:\n                   actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',\n                                                   f'global_step_{global_steps}')\n                   actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')\n                   self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)\n\n                   if self.use_critic:\n                       critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',\n                                                           f'global_step_{global_steps}')\n                       critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')\n                       self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)\n\n               global_steps += 1\n\n       # perform validation after training\n       if self.val_reward_fn is not None:\n           
val_metrics = self._validate()\n           pprint(f'Final validation metrics: {val_metrics}')\n
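\nFor completeness, ``fit()`` is typically driven from a main entry point along the\nfollowing lines (a sketch; see ``verl.trainer.main_ppo`` for the real entry point,\nwith most constructor arguments elided here):\n\n.. code:: python\n\n   # sketch: construct the trainer, initialize the worker groups, then train\n   trainer = RayPPOTrainer(config=config, tokenizer=tokenizer)\n   trainer.init_workers()\n   trainer.fit()\n"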
  },
  {
    "path": "verl_distillation/docs/workers/sglang_worker.rst",
    "content": "SGLang Backend\n==============\n\nLast updated: 05/31/2025.\n\n**Authored By SGLang RL Team and listed alphabetically by last name**\n\n`Jingyi Chen <https://github.com/fzyzcjy>`_, `Yitong Guan <https://github.com/minleminzui>`_, `Zhuobin Huang <https://zobinhuang.github.io/sec_about/>`_, `Jiajun Li <https://github.com/guapisolo>`_, `Ji Li <https://github.com/GeLee-Q>`_, `Shenggui Li <https://franklee.xyz/about>`_, `Junrong Lin <https://github.com/ocss884>`_, `Xiang Long <https://github.com/SwordFaith>`_, `Rui Lu <https://scholar.google.com/citations?user=-MGuqDcAAAAJ>`_, `Jin Pan <https://jhinpan.github.io/>`_, `Shuai Shi <https://github.com/shuaills>`_, `Yushen Su <https://yushengsu-thu.github.io/>`_, `Xinyuan Tong <https://github.com/JustinTong0323>`_, `Chendong Wang <https://github.com/cedricbeta>`_, `Hanchen Zhang <https://scholar.google.com/citations?user=pGcJcagAAAAJ>`_, `Haoran Wang <https://ubecc.github.io/about/>`_, `Yongan Xiang <https://github.com/BearBiscuit05>`_, `Chengxing Xie <https://yitianlian.github.io/>`_, `Yuhao Yang <https://github.com/yhyang201>`_, `Jinwei Yao <https://kivi-yao.github.io/>`_, `Qiaolin Yu <https://github.com/Qiaolin-Yu>`_, `Yuzhen Zhou <https://github.com/zyzshishui>`_, `Chenyang Zhao <https://github.com/zhaochenyang20>`_\n\n\n\nIntroduction\n------------\n`SGLang <https://github.com/sgl-project/sglang>`_ is an open-source state-of-the-art inference service engine, fully adopted by xAI to support all inference needs of Grok during research and serving processes.\n\nCurrently, verl fully supports using SGLang as the inference engine during the rollout phase. As a rollout engine, SGLang provides the same feature coverage as vLLM., including memory saving and multi-node rollout features. After installing verl and SGLang, simply add ``actor_rollout_ref.rollout.name=sglang`` at startup script to seamlessly switch between the two inference frameworks.\n\nIn addition, the SGLang team is actively working on supporting features such as Multi-Turn Agentic RL, VLM RLHF, Server-Based RLHF, and Partial Rollout. You can track the related development progress in the `Tracking Roadmap <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/74>`_.\n\nInstallation\n------------\nPlease always follow the following command to install SGLang with verl. \n\n.. code-block:: bash\n    \n    pip install --upgrade pip\n    # Currently 0.4.8, subject to updates at any time, please refer to the latest version specified in `setup.py`\n    pip install -e \".[sglang]\"\n\nYou can check the following dependencies are in your environment:\n\n.. note::\n\n    - **PyTorch**: 2.6.0+cu124\n    - **CUDA**: 12.4\n    - **flashinfer-python**: 0.2.5+cu124torch2.6\n    - **SGLang**: 0.4.6.post5\n    - **sgl-kernel**: 0.1.4\n\nUsing SGLang as the Inference Backend for PPO Training on a Single Machine\n-------------------------------------------------------------------------\nWe use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test.\n\n1. Run the following command to prepare the gsm8k dataset:\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py\n\n2. Run the following script to conduct a PPO experiment on a single machine with 4 GPUs:\n\n.. 
code-block:: bash\n\n    export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True\n    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n        data.train_files=$HOME/data/gsm8k/train.parquet \\\n        data.val_files=$HOME/data/gsm8k/test.parquet \\\n        data.train_batch_size=4096 \\\n        data.max_prompt_length=4096 \\\n        data.max_response_length=4096 \\\n        actor_rollout_ref.rollout.name=sglang \\\n        actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        critic.optim.lr=1e-5 \\\n        critic.model.path=Qwen/Qwen2-7B-Instruct \\\n        critic.ppo_micro_batch_size_per_gpu=4 \\\n        critic.model.fsdp_config.param_offload=True \\\n        critic.model.fsdp_config.optimizer_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.logger=console \\\n        trainer.val_before_train=False \\\n        trainer.n_gpus_per_node=4 \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log\n\nWhy export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples.\n\n2. ``SGLangRollout`` will initialize ``Engine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP).\n\n3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks.\n\nWhy might there be inconsistent GPU memory?\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n**1. Ray Distributed Actor loads the model at different times**\n\n``verl`` uses Ray-based multi-process, multi-GPU concurrent training. Each ``WorkerDict`` may be called at different times:\n\n.. code-block:: python\n\n    self.rollout = SGLangRollout(...)\n\nDifferent workers initialize the model at different times → different memory usage.\n\n**2. Delayed initialization causes memory bias**\n\nSome workers start model loading/inference (e.g., ``generate_sequences()``, ``compute_log_prob()``) earlier than others.\nEarly workers already use up GPU memory → late workers still have empty memory → memory difference appears.\n\n**3. SGLang's TP init uses \"all-device broadcast\", but there's no uniform release timing**\n\nAlthough ``SGLangRollout`` may only involve a subset of GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:\n\n- Non-rollout GPUs also join the communication.\n- Later on, ``DeviceMesh`` init will fail due to \"inconsistent memory\".\n\n**4. Different FSDP/TP loading behaviors also lead to mismatch**\n\nIf using:\n\n.. 
code-block:: bash\n\n    actor.fsdp_config.param_offload=True\n    ref.fsdp_config.param_offload=True\n\nThen some workers keep params on CPU while others have already sharded them to the GPU → asymmetric memory layout.\n\nUsing SGLang as the Inference Backend for PPO Training Across Multiple Machines\n--------------------------------------------------------------------------------\nSGLang also supports running verl's Ray-based cross-machine inference in IPv4 and IPv6 scenarios. In the script below, we use TP=16 for cross-machine inference. Suppose we have two interconnected machines: node0 with IP 10.94.16.4 and node1 with IP 10.94.16.5.\n\n1. Start Ray on node0:\n\n.. code-block:: bash\n\n    ray start --head --dashboard-host=0.0.0.0\n\nYou will see the following prompt:\n\n.. code-block:: bash\n\n    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.\n\n    Local node IP: 10.94.16.4\n\n    --------------------\n    Ray runtime started.\n    --------------------\n\n    Next steps\n    To add another node to this Ray cluster, run\n        ray start --address='10.94.16.4:6379'\n\n2. Have node1 join the Ray cluster:\n\nRun the following command on node1:\n\n.. code-block:: bash\n\n    ray start --address='10.94.16.4:6379'\n\nRun the following command to confirm that the Ray cluster now has two nodes:\n\n.. code-block:: bash\n\n    ray status\n\nYou can see that the cluster has two nodes with 16 GPUs:\n\n.. code-block:: bash\n\n    ======== Autoscaler status: 2025-04-09 09:25:37.694016 ========\n    Node status\n    ---------------------------------------------------------------\n    Active:\n     1 node_ef382ffd687d8f6b060c1b68e63ada7341b936fe5b1901dd04de1027\n     1 node_1eb4d7d07e793114c23a89d1a41f1f76acf6ef5b35af844a4ee8e4ba\n    Pending:\n     (no pending nodes)\n    Recent failures:\n     (no failures)\n\n    Resources\n    ---------------------------------------------------------------\n    Usage:\n     0.0/360.0 CPU\n     0.0/16.0 GPU\n     0B/3.39TiB memory\n     0B/372.53GiB object_store_memory\n\n3. Run the following script to train meta-llama/Llama-3.1-8B-Instruct with TP=16 across 2 machines using 16 GPUs:\n\n.. 
code-block:: bash\n\n    DATA_DIR=$HOME/data/gsm8k\n\n    python3 -m verl.trainer.main_ppo \\\n        actor_rollout_ref.rollout.name=sglang \\\n        data.train_files=$DATA_DIR/train.parquet \\\n        data.val_files=$DATA_DIR/test.parquet \\\n        data.train_batch_size=4096 \\\n        data.max_prompt_length=4096 \\\n        data.max_response_length=4096 \\\n        actor_rollout_ref.model.path=meta-llama/Llama-3.1-8B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.rollout.free_cache_engine=True \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=meta-llama/Llama-3.1-8B-Instruct \\\n        critic.model.enable_gradient_checkpointing=True \\\n        critic.ppo_micro_batch_size=16 \\\n        critic.model.fsdp_config.param_offload=True \\\n        critic.model.fsdp_config.optimizer_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.val_before_train=True \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/aime2024_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the DAPO-Math-17k dataset to multiturn format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/retool_aime2024\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_path = \"BytedTsinghua-SIA/AIME-2024\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"default\")\n    else:\n        dataset = datasets.load_dataset(data_path, \"default\")\n\n    train_dataset = dataset[\"train\"]\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            orig_extra_info = example.pop(\"extra_info\")\n            extra_info = orig_extra_info.copy()\n            extra_info[\"need_tools_kwargs\"] = True\n            extra_info[\"tools_kwargs\"] = {\n                \"code_interpreter\": {\n                    \"create_kwargs\": {\n                        \"ground_truth\": example[\"reward_model\"][\"ground_truth\"],\n                    },\n                },\n            }\n            example[\"extra_info\"] = extra_info\n            return example\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/dapo_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the DAPO-Math-17k dataset to multiturn format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/retool_dapo\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_path = \"BytedTsinghua-SIA/DAPO-Math-17k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"default\")\n    else:\n        dataset = datasets.load_dataset(data_path, \"default\")\n\n    train_dataset = dataset[\"train\"]\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            orig_extra_info = example.pop(\"extra_info\")\n            extra_info = orig_extra_info.copy()\n            extra_info[\"need_tools_kwargs\"] = True\n            extra_info[\"tools_kwargs\"] = {\n                \"code_interpreter\": {\n                    \"create_kwargs\": {\n                        \"ground_truth\": example[\"reward_model\"][\"ground_truth\"],\n                    },\n                },\n            }\n            example[\"extra_info\"] = extra_info\n            return example\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/full_hh_rlhf.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n- Preprocess data and split the training set into 75% for training RM and 25% for validting RM.\n- All the training data is used to train SFT and RL.\n- Both chosen and rejected is used to train SFT\n\"\"\"\n\nimport argparse\nimport os\n\nimport pandas as pd\nfrom datasets import load_dataset\nfrom tqdm.auto import tqdm\n\nfrom verl.utils.fs import copy, makedirs\n\n\ndef generate_sft_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlh/sft\", local_dataset_path=None):\n    if local_dataset_path is not None:\n        dataset = load_dataset(local_dataset_path)\n    else:\n        dataset = load_dataset(\"Dahoas/full-hh-rlhf\")\n    output = {\"prompt\": [], \"response\": []}\n    for data in tqdm(dataset[\"train\"]):\n        # add chosen\n        output[\"prompt\"].append(data[\"prompt\"])\n        output[\"response\"].append(data[\"chosen\"])\n\n        # add rejection\n        output[\"prompt\"].append(data[\"prompt\"])\n        output[\"response\"].append(data[\"rejected\"])\n\n    df = pd.DataFrame(output)\n\n    local_dir = os.path.expanduser(local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    local_path = os.path.join(local_dir, \"train.parquet\")\n\n    df.to_parquet(path=local_path)\n\n    if target_hdfs_path_dir is not None:\n        hdfs_dir = target_hdfs_path_dir + \"/\" + \"train.parquet\"\n        makedirs(hdfs_dir)\n\n        copy(local_path, hdfs_dir)\n\n\ndef generate_rm_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlh/rm\", local_dataset_path=None):\n    if local_dataset_path is not None:\n        train_dataset = load_dataset(local_dataset_path, split=\"train[:75%]\")\n        test_dataset = load_dataset(local_dataset_path, split=\"train[-25%:]\")\n    else:\n        train_dataset = load_dataset(\"Dahoas/full-hh-rlhf\", split=\"train[:75%]\")\n        test_dataset = load_dataset(\"Dahoas/full-hh-rlhf\", split=\"train[-25%:]\")\n\n    local_dir = os.path.expanduser(local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    for dataset, name in zip([train_dataset, test_dataset], [\"train\", \"test\"], strict=True):\n        output = {\"prompt\": [], \"chosen\": [], \"rejected\": []}\n        for data in tqdm(dataset):\n            # add chosen\n            output[\"prompt\"].append(data[\"prompt\"])\n            output[\"chosen\"].append(data[\"chosen\"])\n            output[\"rejected\"].append(data[\"rejected\"])\n\n        df = pd.DataFrame(output)\n\n        local_path = os.path.join(local_dir, name + \".parquet\")\n\n        df.to_parquet(path=local_path)\n\n        if target_hdfs_path_dir is not None:\n            hdfs_dir = target_hdfs_path_dir + \"/\" + name + \".parquet\"\n            makedirs(hdfs_dir)\n\n            copy(local_path, hdfs_dir)\n\n\ndef generate_rl_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlhf/rl\", local_dataset_path=None):\n    if local_dataset_path is not None:\n        
dataset = load_dataset(local_dataset_path)\n    else:\n        dataset = load_dataset(\"Dahoas/full-hh-rlhf\")\n    train_dataset = dataset[\"train\"]\n\n    data_source = \"Dahoas/full-hh-rlhf\"\n\n    # map each data item to the RL schema and attach a unique index\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            prompt = example.pop(\"prompt\")\n            response = example.pop(\"response\")\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n                \"ability\": \"alignment\",\n                \"reward_model\": {\n                    \"style\": \"model\",\n                    \"ground_truth\": response,  # should not be used\n                },\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    local_dir = os.path.expanduser(local_dir)\n    local_path = os.path.join(local_dir, \"train.parquet\")\n    train_dataset.to_parquet(local_path)\n\n    if target_hdfs_path_dir is not None:\n        hdfs_dir = target_hdfs_path_dir + \"/\" + \"train.parquet\"\n        makedirs(hdfs_dir)\n\n        copy(local_path, hdfs_dir)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--split\", type=str, choices=[\"sft\", \"rm\", \"rl\"], required=True)\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", type=str, required=False, default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\",\n        type=str,\n        default=\"~/data/full_hh_rlhf\",\n        help=\"The save directory for the preprocessed dataset.\",\n    )\n\n    args = parser.parse_args()\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    if args.split == \"sft\":\n        generate_sft_dataset(args.hdfs_dir, os.path.join(local_save_dir, args.split), args.local_dataset_path)\n    elif args.split == \"rm\":\n        generate_rm_dataset(args.hdfs_dir, os.path.join(local_save_dir, args.split), args.local_dataset_path)\n    elif args.split == \"rl\":\n        generate_rl_dataset(args.hdfs_dir, os.path.join(local_save_dir, args.split), args.local_dataset_path)\n    else:\n        raise NotImplementedError\n
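\n# Example usage (illustrative; one invocation per split, per the --split choices above):\n#   python examples/data_preprocess/full_hh_rlhf.py --split sft\n#   python examples/data_preprocess/full_hh_rlhf.py --split rm\n#   python examples/data_preprocess/full_hh_rlhf.py --split rl\n"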
  },
  {
    "path": "verl_distillation/examples/data_preprocess/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the Geometry3k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None)\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/geo3k\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"hiyouga/geometry3k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(\n            local_dataset_path,\n        )\n    else:\n        dataset = datasets.load_dataset(\n            data_source,\n        )\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = (\n        r\"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. \"\n        r\"The reasoning process MUST BE enclosed within <think> </think> tags. \"\n        r\"The final answer MUST BE put in \\boxed{}.\"\n    )\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            problem = example.pop(\"problem\")\n            prompt = problem + \" \" + instruction_following\n            answer = example.pop(\"answer\")\n            images = example.pop(\"images\")\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    }\n                ],\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                    \"question\": problem,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=8)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=8)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. 
Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    # Expand ~ and make sure the output directory exists before writing.\n    local_save_dir = os.path.expanduser(local_save_dir)\n    os.makedirs(local_save_dir, exist_ok=True)\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/geo3k_multiturn_w_tool.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 Reallm Labs Ltd. or its affiliates\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPreprocess the Geometry3k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\",\n        default=\"~/data/geo3k_multiturn_w_tool\",\n        help=\"The save directory for the preprocessed dataset.\",\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"hiyouga/geometry3k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path)\n    else:\n        dataset = datasets.load_dataset(data_source)\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = (\n        r\"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. \"\n        r\"The reasoning process MUST BE enclosed within <think> </think> tags. \"\n        r\"The final answer MUST BE put in \\boxed{}.\"\n    )\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            problem = example.pop(\"problem\")\n            prompt = problem + \" \" + instruction_following\n            answer = example.pop(\"answer\")\n            images = example.pop(\"images\")\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_geo3k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    },\n                ],\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                    \"question\": problem,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_geo3k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": answer},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=8)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=8)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/gsm8k\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"openai/gsm8k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"main\")\n    else:\n        dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = 'Let\\'s think step by step and output the final answer after \"####\".'\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    }\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. 
Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    # Expand ~ and make sure the output directory exists before writing.\n    local_save_dir = os.path.expanduser(local_save_dir)\n    os.makedirs(local_save_dir, exist_ok=True)\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/gsm8k_multiturn_sft.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/gsm8k_sft\", help=\"The save directory for the preprocessed dataset.\"\n    )\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"openai/gsm8k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"main\")\n    else:\n        dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = 'Let\\'s think step by step and output the final answer after \"####\".'\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            data = {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": answer_raw,\n                    },\n                ],\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. 
Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    local_save_dir = os.path.expanduser(local_save_dir)\n    # Make sure the output directory exists before writing.\n    os.makedirs(local_save_dir, exist_ok=True)\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/gsm8k_multiturn_w_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/gsm8k\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"openai/gsm8k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"main\")\n    else:\n        dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"You should rethinking carefully if user point out your answer is wrong. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"interaction_kwargs\": {\n                        \"name\": \"gsm8k\",\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/gsm8k_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/gsm8k\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"openai/gsm8k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"main\")\n    else:\n        dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_gsm8k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_gsm8k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": solution},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                    \"interaction_kwargs\": {\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/gsm8k_tool_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/gsm8k\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"openai/gsm8k\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, \"main\")\n    else:\n        dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"agent_name\": \"tool_agent\",\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_gsm8k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_gsm8k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": solution},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                    \"interaction_kwargs\": {\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/hellaswag.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess Hellaswag dataset.\n\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef preprocess(text):\n    text = text.strip()\n    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.\n    text = text.replace(\" [title]\", \". \")\n    text = re.sub(\"\\\\[.*?\\\\]\", \"\", text)\n    text = text.replace(\"  \", \" \")\n    return text\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None, help=\"The save directory for the preprocessed dataset.\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/hellaswag\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"Rowan/hellaswag\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path)\n    else:\n        dataset = datasets.load_dataset(data_source, trust_remote_code=True)\n\n    train_dataset = dataset[\"train\"]\n    val_dataset = dataset[\"validation\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction = \"Please complete the following sentence.\\n\"\n\n    def make_map_fn(split):\n        def process_fn(doc, idx):\n            ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n            query = preprocess(doc[\"activity_label\"] + \": \" + ctx)\n            choices = [preprocess(ending) for ending in doc[\"endings\"]]\n            gold = int(doc[\"label\"])\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": query}],\n                \"ability\": \"nlp\",\n                \"reward_model\": {\n                    \"style\": \"model\",\n                    \"eval\": \"multiple_choice\",  # using loglikelihood\n                    \"ground_truth\": gold,\n                    \"choices\": choices,\n                },\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    # filter data that doesn't have a label\n    train_dataset = train_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n    val_dataset = val_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n    test_dataset = test_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    val_dataset = val_dataset.map(function=make_map_fn(\"validation\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), 
with_indices=True)\n\n    hdfs_dir = args.hdfs_dir\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    # Expand ~ and make sure the output directory exists before writing.\n    local_save_dir = os.path.expanduser(local_save_dir)\n    os.makedirs(local_save_dir, exist_ok=True)\n\n    train_dataset.to_parquet(os.path.join(local_save_dir, \"train.parquet\"))\n    val_dataset.to_parquet(os.path.join(local_save_dir, \"validation.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_save_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_save_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/math_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the MATH-lighteval dataset to parquet format\n\"\"\"\n\nimport argparse\nimport json\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\nfrom verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed\n\n\ndef extract_solution(solution_str):\n    return remove_boxed(last_boxed_only_string(solution_str))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=None)\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/math\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    # 'lighteval/MATH' is no longer available on huggingface.\n    # Use mirror repo: DigitalLearningGmbH/MATH-lighteval\n    data_source = \"DigitalLearningGmbH/MATH-lighteval\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(\n            local_dataset_path,\n        )\n    else:\n        dataset = datasets.load_dataset(\n            data_source,\n        )\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer within \\\\boxed{}.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question = example.pop(\"problem\")\n\n            question = question + \" \" + instruction_following\n\n            answer = example.pop(\"solution\")\n            solution = extract_solution(answer)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": question}],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_save_dir = args.local_dir\n    if local_save_dir is not None:\n        print(\"Warning: Argument 'local_dir' is deprecated. 
Please use 'local_save_dir' instead.\")\n    else:\n        local_save_dir = args.local_save_dir\n\n    local_dir = os.path.expanduser(local_save_dir)\n    # Make sure the output directory exists before writing.\n    os.makedirs(local_dir, exist_ok=True)\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n    # Save one example as JSON for reference\n    example = train_dataset[0]\n    with open(os.path.join(local_dir, \"train_example.json\"), \"w\") as f:\n        json.dump(example, f, indent=2)\n    example = test_dataset[0]\n    with open(os.path.join(local_dir, \"test_example.json\"), \"w\") as f:\n        json.dump(example, f, indent=2)\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/multiturn.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCreate a simple multi-turn dataset for testing\n\"\"\"\n\nimport argparse\nimport os\n\nimport pandas as pd\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/multiturn\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    args = parser.parse_args()\n\n    # Create example conversations\n    conversations = []\n\n    # Conversation 1\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n                {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"},\n                {\"role\": \"user\", \"content\": \"And what about Germany?\"},\n                {\"role\": \"assistant\", \"content\": \"The capital of Germany is Berlin.\"},\n            ]\n        }\n    )\n\n    # Conversation 2\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Can you explain quantum computing?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": \"Quantum computing is a type of computing that uses quantum-mechanical phenomena, \"\n                    \"such as superposition and entanglement, to perform operations on data.\",\n                },\n                {\"role\": \"user\", \"content\": \"How is it different from classical computing?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": \"Classical computing uses bits that are either 0 or 1, while quantum computing uses \"\n                    \"quantum bits or qubits that can exist in multiple states simultaneously due to superposition.\",\n                },\n            ]\n        }\n    )\n\n    # Conversation 3\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Write a simple Python function to calculate factorial.\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": (\n                        \"```python\\ndef factorial(n):\\n    if n == 0 or n == 1:\\n        return 1\\n    else:\\n        \"\n                        \"return n * factorial(n-1)\\n```\\n\\nThis is a recursive function to calculate the \"\n                        \"factorial of a number.\"\n                    ),\n                },\n                {\"role\": \"user\", \"content\": \"Can you make it iterative instead?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": (\n 
                       \"```python\\ndef factorial(n):\\n    result = 1\\n    for i in range(1, n+1):\\n        \"\n                        \"result *= i\\n    return result\\n```\\n\\nThis is an iterative version of the factorial function.\"\n                    ),\n                },\n            ]\n        }\n    )\n\n    # Create train and test datasets\n    train_data = conversations[:2]  # First 2 conversations for training\n    test_data = conversations[2:]  # Last conversation for testing\n\n    # Create output directory\n    local_dir = os.path.expanduser(args.local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    # Save to parquet files\n    train_df = pd.DataFrame(train_data)\n    test_df = pd.DataFrame(test_data)\n\n    train_df.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_df.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    # Handle HDFS if specified\n    if args.hdfs_dir is not None:\n        try:\n            from verl.utils.hdfs_io import copy, makedirs\n\n            makedirs(args.hdfs_dir)\n            copy(src=local_dir, dst=args.hdfs_dir)\n        except ImportError:\n            print(\"Warning: HDFS support not available. Skipping HDFS copy.\")\n\n    # Print statistics\n    print(f\"Train dataset size: {len(train_df)}\")\n    print(f\"Test dataset size: {len(test_df)}\")\n    print(f\"Data saved to {local_dir}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/examples/data_preprocess/preprocess_search_r1_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport argparse\r\nimport logging\r\nimport os\r\nimport tempfile\r\n\r\nimport pandas as pd\r\nfrom huggingface_hub import hf_hub_download\r\nfrom huggingface_hub.utils import EntryNotFoundError\r\n\r\nfrom verl.utils.hdfs_io import copy, makedirs\r\n\r\n# Setup logging\r\nlogging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\r\nlogger = logging.getLogger(__name__)\r\n\r\n# Configuration constants\r\nDEFAULT_SYSTEM_CONTENT = \"You are a helpful and harmless assistant.\"\r\nDEFAULT_USER_CONTENT_PREFIX = (\r\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\r\n    \"first every time you get new information. After reasoning, if you find you lack \"\r\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\r\n    \"and it will return the top searched results between <tool_response> and \"\r\n    \"</tool_response>. You can search as many times as your want. If you find no \"\r\n    \"further external knowledge needed, you can directly provide the answer inside \"\r\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\r\n    \"<answer> Beijing </answer>. 
Question: \"\r\n)\r\n\r\n\r\ndef process_single_row(row, current_split_name, row_index):\r\n    \"\"\"\r\n    Process a single row of data for SearchR1-like format.\r\n\r\n    Args:\r\n        row: DataFrame row containing the original data\r\n        current_split_name: Name of the current split (train/test)\r\n        row_index: Index of the row in the DataFrame\r\n\r\n    Returns:\r\n        pd.Series: Processed row data in the required format\r\n    \"\"\"\r\n    question = row.get(\"question\", \"\")\r\n\r\n    # Build prompt structure\r\n    user_content = user_content_prefix.rstrip(\"\\n\") + question\r\n    prompt = [{\"role\": \"system\", \"content\": system_content}, {\"role\": \"user\", \"content\": user_content}]\r\n\r\n    # Extract ground truth from reward_model or fallback to golden_answers\r\n    reward_model_data = row.get(\"reward_model\")\r\n    if isinstance(reward_model_data, dict) and \"ground_truth\" in reward_model_data:\r\n        ground_truth = reward_model_data.get(\"ground_truth\")\r\n    else:\r\n        ground_truth = row.get(\"golden_answers\", [])\r\n\r\n    # Process data source\r\n    data_source_tagged = \"searchR1_\" + str(row.get(\"data_source\", \"\"))\r\n\r\n    # Build tools kwargs structure\r\n    tools_kwargs = {\r\n        \"search\": {\r\n            \"create_kwargs\": {\"ground_truth\": ground_truth, \"question\": question, \"data_source\": data_source_tagged}\r\n        }\r\n    }\r\n\r\n    # Build complete extra_info structure\r\n    extra_info = {\r\n        \"index\": row_index,\r\n        \"need_tools_kwargs\": True,\r\n        \"question\": question,\r\n        \"split\": current_split_name,\r\n        \"tools_kwargs\": tools_kwargs,\r\n    }\r\n\r\n    return pd.Series(\r\n        {\r\n            \"data_source\": data_source_tagged,\r\n            \"prompt\": prompt,\r\n            \"ability\": row.get(\"ability\"),\r\n            \"reward_model\": reward_model_data,\r\n            \"extra_info\": extra_info,\r\n            \"metadata\": row.get(\"metadata\"),\r\n        }\r\n    )\r\n\r\n\r\ndef main():\r\n    local_save_dir = os.path.expanduser(args.local_dir)\r\n    os.makedirs(local_save_dir, exist_ok=True)\r\n\r\n    processed_files = []\r\n\r\n    # Download and process files using temporary directory\r\n    with tempfile.TemporaryDirectory() as tmp_download_dir:\r\n        for split in [\"train\", \"test\"]:\r\n            parquet_filename = f\"{split}.parquet\"\r\n            logger.info(f\"Processing {split} split...\")\r\n\r\n            try:\r\n                # Download Parquet file from HuggingFace\r\n                logger.info(f\"Downloading {parquet_filename} from {args.hf_repo_id}\")\r\n                local_parquet_filepath = hf_hub_download(\r\n                    repo_id=args.hf_repo_id,\r\n                    filename=parquet_filename,\r\n                    repo_type=\"dataset\",\r\n                    local_dir=tmp_download_dir,\r\n                    local_dir_use_symlinks=False,\r\n                )\r\n\r\n                # Load and process Parquet file\r\n                df_raw = pd.read_parquet(local_parquet_filepath)\r\n                logger.info(f\"Loaded {len(df_raw)} rows from {parquet_filename}\")\r\n\r\n                def apply_process_row(row, split_name=split):\r\n                    return process_single_row(row, current_split_name=split_name, row_index=row.name)\r\n\r\n                df_processed = df_raw.apply(apply_process_row, axis=1)\r\n\r\n                # Save processed DataFrame\r\n        
        output_file_path = os.path.join(local_save_dir, f\"{split}.parquet\")\r\n                df_processed.to_parquet(output_file_path, index=False)\r\n                logger.info(f\"Saved {len(df_processed)} processed rows to {output_file_path}\")\r\n                processed_files.append(output_file_path)\r\n\r\n            except EntryNotFoundError:\r\n                logger.warning(f\"{parquet_filename} not found in repository {args.hf_repo_id}\")\r\n            except Exception as e:\r\n                logger.error(f\"Error processing {split} split: {e}\")\r\n\r\n    if not processed_files:\r\n        logger.warning(\"No data was processed or saved\")\r\n        return\r\n\r\n    logger.info(f\"Successfully processed {len(processed_files)} files to {local_save_dir}\")\r\n\r\n    # Copy to HDFS if specified\r\n    if args.hdfs_dir:\r\n        try:\r\n            makedirs(args.hdfs_dir)\r\n            copy(src=local_save_dir, dst=args.hdfs_dir)\r\n            logger.info(f\"Successfully copied files to HDFS: {args.hdfs_dir}\")\r\n        except Exception as e:\r\n            logger.error(f\"Error copying files to HDFS: {e}\")\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    parser = argparse.ArgumentParser(description=\"Download Search-R1 from HuggingFace, process, and save to Parquet.\")\r\n    parser.add_argument(\r\n        \"--hf_repo_id\", default=\"PeterJinGo/nq_hotpotqa_train\", help=\"HuggingFace dataset repository ID.\"\r\n    )\r\n    parser.add_argument(\r\n        \"--local_dir\",\r\n        default=\"~/data/searchR1_processed_direct\",\r\n        help=\"Local directory to save the processed Parquet files.\",\r\n    )\r\n    parser.add_argument(\"--hdfs_dir\", default=None, help=\"Optional HDFS directory to copy the Parquet files to.\")\r\n\r\n    args = parser.parse_args()\r\n\r\n    # System and user content configuration\r\n    system_content = DEFAULT_SYSTEM_CONTENT\r\n    user_content_prefix = DEFAULT_USER_CONTENT_PREFIX\r\n\r\n    main()\r\n"
  },
  {
    "path": "verl_distillation/examples/generation/run_deepseek7b_mutli_node.sh",
    "content": "set -x\n\ndata_path=$HOME/data/rlhf/gsm8k/test.parquet\nsave_path=$HOME/data/rlhf/math/deepseek_v2_lite_gen_test.parquet\nmodel_path=deepseek-ai/deepseek-llm-7b-chat\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=2 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$data_path \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=$save_path \\\n    model.path=$model_path\\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=16 \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_distillation/examples/generation/run_deepseek_v2_lite_math.sh",
    "content": "set -x\n\ndata_path=$HOME/data/gsm8k/test.parquet\nsave_path=$HOME/data/gsm8k/deepseek_v2_lite_gen_test.parquet\nmodel_path=deepseek-ai/deepseek-llm-7b-chat\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$data_path \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=$save_path \\\n    model.path=$model_path \\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=2 \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_distillation/examples/gmpo_trainer/README.md",
    "content": "<div align=center>\n  \n# Geometric-Mean Policy Optimization\n</div>\n\nThis is the official implementaion of paper [***Geometric-Mean Policy Optimization***](https://arxiv.org/abs/2507.20673).\n\n<div align=center>\n<img width=\"3092\" height=\"864\" alt=\"image\" src=\"https://github.com/user-attachments/assets/20b04c4e-7ee8-4775-9af8-33c0158336e2\" />\n</div>\n\n## 1. Contents\n- Geometric-Mean Policy Optimization\n  - [1. Contents](#1-contents)\n  - [2. Introduction](#2-introduction)\n  - [3. Code Usage](#3-code-usage)\n  - [4. Contacts](#4-contacts)\n  - [5. Citation](#5-citation)\n\n## 2. Introduction\n\nGroup Relative Policy Optimization (GRPO) has significantly enhanced the reasoning capability of large language models by optimizing the arithmetic mean of token-level rewards. Unfortunately, GRPO is observed to suffer from unstable policy updates when facing tokens with outlier importance-weighted rewards, which manifest as extreme importance sampling ratios during training. In this study, we propose Geometric-Mean Policy Optimization (GMPO), with the aim to improve the stability of GRPO through suppressing token reward outliers. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. GMPO is plug-and-play—simply replacing GRPO's arithmetic mean with the geometric mean of token-level rewards, as the latter is inherently less sensitive to outliers. GMPO is theoretically plausible—analysis reveals that both GMPO and GRPO are weighted forms of the policy gradient while the former enjoys more stable weights, which consequently benefits policy optimization and performance. Experiments on multiple mathematical reasoning benchmarks show that GMPO-7B improves the average Pass@1 of GRPO by up to 4.1%, outperforming many state-of-the-art approaches.\n\n## 3. Code Usage\n\nThe key configurations are:\n```\nclip_ratio_low=0.4\nclip_ratio_high=0.4\nloss_mode=geo_mean\n```\nWe observed that using a large clip ratio during Mixture-of-Experts (MoE) model training often leads to optimization instability. When training MoE models, consider lowering the clip ratio to achieve more stable convergence.\nTo get started quickly, run:\n```\nbash examples/gmpo_trainer/run_qwen2_5-7b_math.sh\n```\n\nGMPO can be combined with other methods such as DAPO (experimental - not fully tested):\n```\nbash examples/gmpo_trainer/test_dapo_7b_math.sh \nbash examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh\n```\n\n## 4. Contacts\nIf you have any question about our work or this repository, please don't hesitate to contact us by emails or open an issue under this project.\n- [zhaoyuzhong20@mails.ucas.ac.cn](zhaoyuzhong20@mails.ucas.ac.cn)\n- [liuyue171@mails.ucas.ac.cn](liuyue171@mails.ucas.ac.cn)\n- [lecu@microsoft.com](lecu@microsoft.com)\n- [wanfang@ucas.ac.cn](wanfang@ucas.ac.cn)\n\n## 5. Citation\n```\n@article{zhao2025geometric,\n  title={Geometric-mean policy optimization},\n  author={Zhao, Yuzhong and Liu, Yue and Liu, Junpeng and Chen, Jingye and Wu, Xun and Hao, Yaru and Lv, Tengchao and Huang, Shaohan and Cui, Lei and Ye, Qixiang and others},\n  journal={arXiv preprint arXiv:2507.20673},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "verl_distillation/examples/gmpo_trainer/run_qwen2_5-7b_math.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nuse_kl_loss=False\nloss_mode=geo_mean\nclip_ratio=0.4\nsave_contents=\"['model', 'optimizer', 'extra']\"\n\nexport WANDB_MODE=offline\nsave_contents=\"['hf_model']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-Math-7B \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.checkpoint.save_contents=${save_contents} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_gmpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/gmpo_trainer/test_dapo_7b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.4\nclip_ratio_high=0.4\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\nloss_mode=geo_mean\n\n# export WANDB_MODE=offline\nsave_contents=\"['model', 'optimizer', 'extra']\"\n# save_contents=\"['hf_model']\"\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    
actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.actor.checkpoint.save_contents=\"${save_contents}\" \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
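The `overlong_buffer_cfg` flags in the script above configure DAPO's soft length penalty. A minimal sketch of the assumed form (linear ramp inside the buffer, following the DAPO paper; not verl's exact reward-manager code):

```python
# Sketch of the soft overlong penalty the overlong_buffer_cfg flags configure
# (assumed form, following the DAPO paper; not verl's exact code).
def overlong_penalty(resp_len: int, max_resp_len: int,
                     buffer_len: int, penalty_factor: float) -> float:
    expected = max_resp_len - buffer_len      # penalty-free length budget
    exceed = resp_len - expected
    if exceed <= 0:
        return 0.0                            # within budget: no penalty
    # linear ramp inside the buffer, capped at -penalty_factor beyond it
    return -min(exceed / buffer_len, 1.0) * penalty_factor

# e.g. with max_response_length=8192 and overlong_buffer_len=4096:
print(overlong_penalty(6144, 8192, 4096, 1.0))  # -0.5, halfway into the buffer
```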
  },
  {
    "path": "verl_distillation/examples/gmpo_trainer/test_dapo_qwen3_30b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.4\nclip_ratio_high=0.4\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\nloss_mode=geo_mean\n\n# export WANDB_MODE=offline\nsave_contents=\"['model', 'optimizer', 'extra']\"\n# save_contents=\"['hf_model']\"\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.actor.checkpoint.save_contents=\"${save_contents}\" \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
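Both DAPO scripts enable `use_dynamic_bsz` with a per-GPU token budget (`ppo_max_token_len_per_gpu`). A rough sketch of what that implies, using greedy packing for illustration (verl's actual batch-balancing logic differs):

```python
# Illustrative greedy packing by token budget, to show what use_dynamic_bsz=True
# with ppo_max_token_len_per_gpu implies; verl's real batch balancing differs.
def pack_by_token_budget(seq_lens: list[int], budget: int) -> list[list[int]]:
    batches, cur, cur_tokens = [], [], 0
    for n in seq_lens:
        if cur and cur_tokens + n > budget:
            batches.append(cur)               # flush the current micro-batch
            cur, cur_tokens = [], 0
        cur.append(n)
        cur_tokens += n
    if cur:
        batches.append(cur)
    return batches

# budget = (max_prompt_length + max_response_length) * 2 = 20480 in the script
print(pack_by_token_budget([9000, 10000, 4000, 8000, 7000], budget=20480))
```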
  },
  {
    "path": "verl_distillation/examples/gpg_trainer/gpg.md",
    "content": "# GPG: Group Policy Gradient\n\nGroup Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning\n](https://arxiv.org/abs/2504.02546).\n\n## Key Components\n- Use a corrected advantage function to improve policy gradient accuracy and training efficiency.\n- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO)\n\n## Configuration\nTo configure GPG within the framework, use the following YAML settings.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg \nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"gpg\"\n```\n\n## Advanced Extensions\nGPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg\nactor_rollout_ref:\n  actor:\n    use_kl_loss: True # enable kl regularization\n    kl_loss_coef: 0.01\n    policy_loss:\n      loss_mode: \"gpg\"\n```"
  },
  {
    "path": "verl_distillation/examples/gpg_trainer/run_qwen2-7b_math.sh",
    "content": "set -x\n\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gpg \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_gpg_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh",
    "content": "set -x\n\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=gpg \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_gpg_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/README.md",
    "content": "# Group Relative Policy Optimization (GRPO)\n\nIn reinforcement learning, classic algorithms like PPO rely on a \"critic\" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. \n\nGRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows:\n- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a \"group\" of outputs.\n- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality.\n- Baseline Calculation: The average reward of the group serves as a baseline. \n- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones.\n\nThis approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300)\n\n## Key Components\n\n- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic)\n- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group.\n- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nDespite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic).\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling.\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor\n\n- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2\n\n- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead\n\n- `actor_rollout_ref.actor.loss_agg_mode`: Default is \"token-mean\". Options include \"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. 
All GRPO example scripts provided in verl use the default configuration \"token-mean\" for loss aggregation instead.\n\nInstead of adding a KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss:\n\n- `actor_rollout_ref.actor.use_kl_loss`: Whether to use the KL loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO.\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl (k1), abs, mse (k2), low_var_kl (k3), and full. Appending \"+\" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that employs k2 for unbiased gradient estimation, regardless of the KL value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\n## Advanced Extensions\n\n### DrGRPO\n\nThe work [Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there is an optimization bias in GRPO that leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization, which can inadvertently favor longer, less accurate responses. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias.\n\nConfigure the following to enable DrGRPO, with all other parameters the same as GRPO's:\n\n- `actor_rollout_ref.actor.loss_agg_mode`: \"seq-mean-token-sum-norm\", which turns off seq-dim averaging\n- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO\n- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard-deviation normalization\n\n## Reference Example\n\nQwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log)\n\n```bash\nbash examples/grpo_trainer/run_qwen3-8b.sh\n```\n\nFor more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html\n"
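A minimal sketch of the two computations this README describes, group-relative advantages and the low-variance k3 KL estimator (names and shapes are illustrative assumptions, not verl's API):

```python
# Illustrative sketch of group-relative advantages and the k3 (low_var_kl)
# estimator described above; names and shapes are assumptions, not verl's API.
import torch

def grpo_advantages(rewards: torch.Tensor, norm_by_std: bool = True) -> torch.Tensor:
    """rewards: (n,) rewards of the n rollouts sampled for one prompt."""
    adv = rewards - rewards.mean()            # subtract the group baseline
    if norm_by_std:                           # DrGRPO disables this step
        adv = adv / (rewards.std() + 1e-6)
    return adv

def kl_low_var(logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
    """k3 estimator of KL(pi || pi_ref): r - 1 - log r, with r = pi_ref / pi."""
    log_ratio = ref_logp - logp
    return log_ratio.exp() - 1.0 - log_ratio

print(grpo_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0])))
```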
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh",
    "content": "set -x\n\n# # 0. download HF checkpoint\n# # remove the `quantization_config` in the `config.json`\n# # set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported\n# huggingface-cli download deepseek-ai/DeepSeek-V3-0324\n\n# no offline dist checkpoint needed, now with mbridge>=0.13.0, we can directly init model from huggingface downloaded fp8 weights\n# tested on docker://verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2\nLLM=\"<path_to_dsv3_config>\"\n\n\n# 2. run the script\ngsm8k_train_path=/root/data/gsm8k/train.parquet\ngsm8k_test_path=/root/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nALL_OFFLOAD=${ALL_OFFLOAD:-True}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n# 256 H100(80GB)\nNODES=32\nPP=16\nTP=1\nEP=16\nETP=1\nINFER_TP=32\n# consider TP/ETP, and enable recompute if short of memory\n\n# full recompute\n\nn_resp_per_prompt=4\nmax_prompt_length=2048\nmax_response_length=4096\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=True\nkl_loss_coef=0.001\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . 
--\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=grpo \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$LLM \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.rollout.top_p=1.0 \\\n    actor_rollout_ref.rollout.top_k=-1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \\\n    trainer.logger='[\"console\",\"tensorboard\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='dsv3-32nodes' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend='fused' \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=4 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=1 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    trainer.default_local_dir=$CKPT_DIR \\\n    trainer.val_before_train=False \\\n    trainer.total_epochs=100 $@\n"
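The parallel layout in the script above (NODES=32 with 8 GPUs each, PP=16, TP=1, EP=16, INFER_TP=32) determines the remaining data-parallel degree. A quick sanity check, illustrative only:

```python
# Quick sanity check for the Megatron layout above (illustrative only):
# 32 nodes x 8 GPUs = 256 training ranks; DP = world_size / (PP * TP).
def dp_degree(world_size: int, pp: int, tp: int) -> int:
    assert world_size % (pp * tp) == 0, "PP * TP must divide the world size"
    return world_size // (pp * tp)

print(dp_degree(32 * 8, pp=16, tp=1))  # -> 16 data-parallel replicas
```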
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n## !!!!!!!important!!!!!!\n# 1. set the following environment variables on all your nodes\n# env_vars:\n#   CUDA_DEVICE_MAX_CONNECTIONS: \"1\"\n#   NCCL_NVLS_ENABLE: \"0\"\n#   VLLM_USE_V1: 1\n# 2. install mbridge=0.1.13 on all your node with the following command: \n# pip3 install git+https://github.com/ISEEKYAN/mbridge\n# 3. remove the `quantization_config` in the DeepSeek-V3's `config.json` and \n# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n[ -f \"${SCRIPT_DIR}/env.sh\" ] && source \"${SCRIPT_DIR}/env.sh\"\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=True\nkl_loss_coef=0.001\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1204 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=96\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n\n# minimum nodes for DeepSeek-V3: 12 nodes\nNNODES=${NNODES:-12}\n\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n\nMODEL_PATH=$RAY_DATA_HOME/models/DeepSeek-V3-config-verl\n\nTRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet\nTEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\noffload=True\noptim_offload=${OFFLOAD_OPTIM:-True}\ngen_tp=32\ntrain_tp=${TP:-8}\ntrain_pp=${PP:-12}\n\nEP=${EP:-8}\nETP=1\nCP=1\noptimizer_offload_fraction=${OFFLOAD_FRACTION:-1.}\nLAST_LAYER=${LAST_LAYER:-6}\n\n\nproject_name='verl-deepseek-v3'\nexp_name=\"671B-${NNODES}-pp${train_pp}-tp${train_tp}-ep${EP}-actor-length${actor_ppo_max_token_len}\"\nCKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${optim_offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=${CP} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.nccl_timeout=1200 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=${CP} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_shared_expert_overlap=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=${LAST_LAYER} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=100 \\\n    trainer.total_epochs=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek7b_llm_math.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_math' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='deepseek_llm_7b_math_megatron' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_glm41v_9b.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=zai-org/GLM-4.1V-9B-Thinking \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='glm41v_9b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_gptoss_20b.sh",
    "content": "#!/bin/bash\n\ncat > get_model.py << EOF\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config\n\nmodel_id = \"openai/gpt-oss-20b\"\noutput_dir = \"$HOME/models/gpt-oss-20b-bf16\"\n\nquantization_config = Mxfp4Config(dequantize=True)\nmodel_kwargs = dict(\n    attn_implementation=\"eager\",\n    torch_dtype=torch.bfloat16,\n    quantization_config=quantization_config,\n    use_cache=False,\n    device_map=\"auto\",\n)\n\nmodel = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)\n\n# Patch config with custom attribute before saving\nmodel.config.attn_implementation = \"eager\"\n\nmodel.save_pretrained(output_dir)\ntokenizer = AutoTokenizer.from_pretrained(model_id)\ntokenizer.save_pretrained(output_dir)\nEOF\n\npython get_model.py\n# or you can use lmsys/gpt-oss-20b-bf16\n# recommend to use same value for train_batch_size and ppo_mini_batch_size\n# to avoid MOE training instability\n# use large value for max_response_length if you want to use reasoning effort high.\n\n\nmodel_dir=$HOME/models/gpt-oss-20b-bf16\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$gsm8k_train_path\" \\\n    data.val_files=\"$gsm8k_test_path\" \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=8192 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    +data.apply_chat_template_kwargs.reasoning_effort=medium \\\n    actor_rollout_ref.model.path=${model_dir} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='oai_oss_20b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=50 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_minicpmo2_6.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    data.trust_remote_code=True \\\n    data.custom_cls.path=recipe/minicpmo/rl_dataset.py \\\n    data.custom_cls.name=RLHFDataset \\\n    actor_rollout_ref.model.path=openbmb/MiniCPM-o-2_6 \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.use_orig_params=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='minicpmo2_6_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_mistral13b_skyworkrm_hhrlhf.sh",
    "content": "train_files=data/full_hh_rlhf/rl/train.parquet\ntest_files=data/full_hh_rlhf/rl/train.parquet # no use\n\nmax_prompt_length=4096\nmax_response_length=2048\n\ngen_tp=4\nn_per_prompt=5\nadv_estimator=\"grpo\"\n\nproject_name=verl_full_hh_rlhf_examples\nexp_name=\"grpo_mistral13B-skyworkLlama8b-hhrlhf\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.prompt_key=\"prompt\" \\\n    data.return_raw_chat=True \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=mistralai/Mistral-Nemo-Instruct-2407 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=$n_per_prompt \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    reward_model.enable=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.model.path=Skywork/Skywork-Reward-Llama-3.1-8B \\\n    reward_model.model.input_tokenizer=mistralai/Mistral-Nemo-Instruct-2407 \\\n    reward_model.micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.val_before_train=False \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$exp_name \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=10 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=5 $@"
  },
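  {
    "path": "verl_distillation/examples/grpo_trainer/score_with_skywork_rm.sh",
    "content": "#!/bin/bash\n# Hypothetical sketch (not part of verl): score a single conversation with the\n# Skywork reward model used in run_mistral13b_skyworkrm_hhrlhf.sh. It follows\n# the standard sequence-classification scoring pattern from the model card and\n# illustrates why reward_model.model.input_tokenizer matters there: rollouts\n# produced with the Mistral tokenizer must be re-encoded with the RM's own\n# chat template before scoring.\n\npython3 - << 'EOF'\nimport torch\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\nrm_name = \"Skywork/Skywork-Reward-Llama-3.1-8B\"\ntokenizer = AutoTokenizer.from_pretrained(rm_name)\nmodel = AutoModelForSequenceClassification.from_pretrained(\n    rm_name, torch_dtype=torch.bfloat16, device_map=\"auto\", num_labels=1\n)\n\n# The RM scores a full (prompt, response) conversation rendered with its\n# own chat template, not the actor's.\nconv = [\n    {\"role\": \"user\", \"content\": \"How do I politely decline a meeting?\"},\n    {\"role\": \"assistant\", \"content\": \"Thank them, mention a conflict, and offer an alternative time.\"},\n]\ninput_ids = tokenizer.apply_chat_template(conv, return_tensors=\"pt\").to(model.device)\n\nwith torch.no_grad():\n    # A single scalar logit is the reward score for this conversation.\n    score = model(input_ids).logits[0][0].item()\nprint(f\"reward score: {score:.4f}\")\nEOF\n"
  },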
  {
    "path": "verl_distillation/examples/grpo_trainer/run_moonlight16b_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nHF_MODEL_PATH=moonshotai/Moonlight-16B-A3B\nDIST_CKPT_PATH=${DIST_CKPT_PATH}\n\ntrain_path=$HOME/data/gsm8k/train.parquet\ntest_path=$HOME/data/gsm8k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=192 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.trust_remote_code=True \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=3 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=3 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='moonlight_megatron_ep' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=3 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b_math.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nrollout_mode=\"sync\"\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nUSE_FUSED_KERNELS=True\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh",
    "content": "set -x\n\n\n# For async rollout mode, dataset should return raw chat.\nrollout_mode=\"async\"\nrollout_name=\"sglang\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$rollout_name \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\noffload=True\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    trainer.val_before_train=False \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=16 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_3b_grpo_lora' \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n\n    # actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    # data.train_batch_size=1024 \\\n    # trainer.n_gpus_per_node=8 \\\n    # actor_rollout_ref.model.use_shm=True \\\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora_from_adapter.sh",
    "content": "set -x\n\nlora_adapter_path=${lora_adapter_path:-/path/saved/lora_adapter}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_adapter_path=${lora_adapter_path} \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_3b_grpo_lora' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
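  {
    "path": "verl_distillation/examples/grpo_trainer/merge_lora_adapter.sh",
    "content": "#!/bin/bash\n# Hypothetical sketch (not a verl utility): merge a LoRA adapter saved by the\n# GRPO LoRA runs above into the Qwen2.5-3B base model so it can be served as a\n# plain HF checkpoint. The adapter path mirrors the lora_adapter_path variable\n# in run_qwen2_5-3b_gsm8k_grpo_lora_from_adapter.sh; adjust both paths as needed.\n\nlora_adapter_path=${lora_adapter_path:-/path/saved/lora_adapter}\nmerged_dir=${merged_dir:-$HOME/models/qwen2.5-3b-grpo-merged}\n\npython3 - << EOF\nimport torch\nfrom peft import PeftModel\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nbase_id = \"Qwen/Qwen2.5-3B-Instruct\"\nbase = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)\n\n# Attach the trained adapter, then fold its weights into the base model.\nmodel = PeftModel.from_pretrained(base, \"$lora_adapter_path\")\nmerged = model.merge_and_unload()\n\nmerged.save_pretrained(\"$merged_dir\")\nAutoTokenizer.from_pretrained(base_id).save_pretrained(\"$merged_dir\")\nEOF\n"
  },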
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6\\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_32b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh",
    "content": "set -x\n\n# profiling configuration\nPROFILE_STEPS=\"[2,4]\"\nPROFILE_RANKS_ALL=False\nDISCRETE=True\nPROFILE_RANKS=\"[1,2]\"\n\n# profiling NPU options\nSAVE_PATH=\"$HOME/profile_data\"\nLEVEL=\"level1\"\nCONTENTS=['npu','cpu']\nANALYSIS=True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=32 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=2 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.profiler.enable=True \\\n    actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.discrete=$DISCRETE \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.contents=$CONTENTS \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.level=$LEVEL \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.analysis=$ANALYSIS \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.ref.profiler.enable=True \\\n    actor_rollout_ref.ref.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.discrete=$DISCRETE \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.contents=$CONTENTS \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.level=$LEVEL \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.analysis=$ANALYSIS \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.device=npu \\\n    global_profiler.tool=npu \\\n    global_profiler.steps=$PROFILE_STEPS \\\n    global_profiler.save_path=$SAVE_PATH\n    $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh",
    "content": "set -x\n\n# profiling configuration\nPROFILE_STEPS=\"[2,4]\"\nPROFILE_RANKS_ALL=True\nDISCRETE=False\n\n# profiling NPU options\nSAVE_PATH=\"$HOME/profile_data\"\nLEVEL=\"level1\"\nCONTENTS=['npu','cpu']\nANALYSIS=True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=32 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=2 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.profiler.enable=True \\\n    actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.discrete=$DISCRETE \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.contents=$CONTENTS \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.level=$LEVEL \\\n    actor_rollout_ref.actor.profiler.tool_config.npu.analysis=$ANALYSIS \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.ref.profiler.enable=True \\\n    actor_rollout_ref.ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.discrete=$DISCRETE \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.contents=$CONTENTS \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.level=$LEVEL \\\n    actor_rollout_ref.ref.profiler.tool_config.npu.analysis=$ANALYSIS \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.device=npu \\\n    global_profiler.tool=npu \\\n    global_profiler.steps=$PROFILE_STEPS \\\n    global_profiler.save_path=$SAVE_PATH\n    $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nHF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct\nDIST_CKPT_PATH=${DIST_CKPT_PATH}\n\n# convert HF model to meagatron format offlinely\n# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n\n# megatron tuning guide:\n# 1. recommend to offload all states by setting ALL_OFFLOAD=True\n# 2. enable dynamic batch size by setting actor_rollout_ref.actor.use_dynamic_bsz=True ref.log_prob_use_dynamic_bsz=True rollout.log_prob_use_dynamic_bsz=True\n# 3. set ppo_max_token_len_per_gpu and log_prob_max_token_len_per_gpu as large as possible for better MFU (limited by GPU memory). assure ppo_max_token_len_per_gpu > max_prompt_length+max_response_length, if sequence length is too long, you can increase the TP/PP size\n# 4. if memory is very limited, enable full recompute, but the mfu will be 30% lower\n#        full recompute settings:\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n\nALL_OFFLOAD=${ALL_OFFLOAD:-True}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
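  {
    "path": "verl_distillation/examples/grpo_trainer/check_token_budget.sh",
    "content": "#!/usr/bin/env bash\n# Hypothetical helper (not a verl utility): a sketch of the arithmetic behind\n# point 3 of the Megatron tuning guide in run_qwen2_5_vl-7b-megatron.sh.\n# With dynamic batch sizes, ppo_max_token_len_per_gpu must exceed\n# max_prompt_length + max_response_length, or a single full-length sequence\n# cannot fit into one micro batch. The values below mirror that script.\n\nmax_prompt_length=1024\nmax_response_length=2048\nppo_max_token_len_per_gpu=5120\n\nmin_budget=$((max_prompt_length + max_response_length))\nif [ \"$ppo_max_token_len_per_gpu\" -le \"$min_budget\" ]; then\n    echo \"ppo_max_token_len_per_gpu=$ppo_max_token_len_per_gpu <= $min_budget; raise it or increase TP/PP to shard long sequences.\" >&2\n    exit 1\nfi\necho \"ok: per-GPU budget $ppo_max_token_len_per_gpu covers a full sequence of $min_budget tokens\"\n"
  },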
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh",
    "content": "set -x\n\n# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.multi_stage_wake_up=True \\\n    global_profiler.tool=torch_memory \\\n    global_profiler.save_path=./mem_snapshots \\\n    global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries=100000 \\\n    global_profiler.global_tool_config.torch_memory.stack_depth=32 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b_freeze_vision.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.freeze_vision_tower=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.model.exclude_modules='.*visual.*' \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=6144 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=6144 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-32B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_32b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_3b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n## !!!!!!!important!!!!!!\n## set the following environment variables on all your nodes\n# env_vars:\n#   CUDA_DEVICE_MAX_CONNECTIONS: \"1\"\n#   NCCL_NVLS_ENABLE: \"0\"\n#   VLLM_USE_V1: 1\n# install mbridge=0.1.13 on all your node with the following command: \n# pip3 install git+https://github.com/ISEEKYAN/mbridge\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n[ -f \"${SCRIPT_DIR}/env.sh\" ] && source \"${SCRIPT_DIR}/env.sh\"\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=True\nkl_loss_coef=0.001\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1204 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 1))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=${TRAIN_BS:-32}\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=16\n\n# minimum nodes need for qwen3-235B-A22B\nNNODES=${NNODES:-4}\n# Paths\n\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n\nMODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B\n\nTRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet\nTEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\noffload=True\nOPTIM_OFFLOAD=${OPTIM_OFFLOAD:-True}\ngen_tp=8\ntrain_tp=${TP:-4}\ntrain_pp=${PP:-8}\n\nEP=${EP:-4}\nETP=1\nCP=1\noptimizer_offload_fraction=${OFFLOAD_FRACTION:-1.}\nlast_layer=${LAST_LAYER:-10}\n\nproject_name='verl-qwen3'\nexp_name=\"235B-${NNODES}-pp${train_pp}-tp${train_tp}-ep${EP}-actor-length${actor_ppo_max_token_len}\"\nCKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name}\n\n# TODO: support cuda graph for rollout by setting the following config\n    # actor_rollout_ref.rollout.cudagraph_capture_sizes=[1,2,4,8,16,32]\n    # actor_rollout_ref.rollout.enforce_eager=False\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${OPTIM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=${CP} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.nccl_timeout=1200 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=${CP} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \\\n    
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=\"flex\" \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=100 \\\n    trainer.total_epochs=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3-32b_npu.sh",
    "content": "set -x\n\nproject_name='GRPO-Qwen3'\nexp_name='GRPO-Qwen3-32b-npu'\ngen_tp=4\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-32B\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=${MODEL_PATH} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=4 \\\n    +actor_rollout_ref.actor.fsdp_config.mixed_precision.param_dtype=bf16 \\\n    +actor_rollout_ref.actor.fsdp_config.mixed_precision.reduce_dtype=bf16 \\\n    +actor_rollout_ref.actor.fsdp_config.mixed_precision.buffer_dtype=fp32 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=32768 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.resume_from_path=checkpoints/ \\\n    trainer.save_freq=500 \\\n    trainer.test_freq=50 \\\n    trainer.total_epochs=50 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3-8b.sh",
    "content": "# Tested successfully on the hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 image.\n# It outperforms the Qwen2 7B base model by two percentage points on the test set of GSM8K.\n\nset -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-8B \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen3_8b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3-8b_npu.sh",
    "content": "set -x\n\nproject_name='GRPO-Qwen3'\nexp_name='GRPO-Qwen3-8B-npu'\ngen_tp=2\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-8B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=${MODEL_PATH} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.default_local_dir=${CKPTS_DIR} \\\n    trainer.device=npu \\\n    trainer.resume_mode=auto \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \\\n    ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=5 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_1k_spmd_npu.sh",
    "content": "set -x\nexport HCCL_CONNECT_TIMEOUT=1500\nexport HCCL_HOST_SOCKET_PORT_RANGE=60000-60050\nexport HCCL_NPU_SOCKET_PORT_RANGE=61000-61050\n\n# WORKSPACE_HOME and DATA_HOME support custom path configuration.\nWORKSPACE_HOME=$pwd\nDATA_HOME=$pwd\n\nsp_size=4\nnum_npu=4\ntp_size=4\ntrain_prompt_bsz=16\ntrain_prompt_mini_bsz=16\n\nmax_prompt_length=512\nmax_response_length=1024\n\nCKPTS_DIR=$WORKSPACE_HOME/logs/ckpt/qwen3_8b\nmodel_path=$DATA_HOME/models/Qwen3-8B\ntrain_data=$DATA_HOME/datasets/processed_gsm8k/train.parquet\nvalid_data=$DATA_HOME/datasets/processed_gsm8k/test.parquet\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$train_data \\\n    data.val_files=$valid_data \\\n    data.train_batch_size=$train_prompt_bsz \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$train_prompt_mini_bsz \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$tp_size \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=\"ascend\" \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.nccl_timeout=1800 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.project_name='verl_grpo_example_512_1024_gsm8k' \\\n    trainer.experiment_name='qwen3_8b_function_rm' \\\n    trainer.n_gpus_per_node=$num_npu \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=1000 \\\n    trainer.test_freq=10000 \\\n    trainer.total_epochs=5 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh",
    "content": "set -x\nexport HCCL_CONNECT_TIMEOUT=1500\nexport HCCL_HOST_SOCKET_PORT_RANGE=60000-60050\nexport HCCL_NPU_SOCKET_PORT_RANGE=61000-61050\n\n# WORKSPACE_HOME and DATA_HOME support custom path configuration.\nWORKSPACE_HOME=$pwd\nDATA_HOME=$pwd\n\nsp_size=4\nnum_gpu=8\ntp_size=4\ntrain_prompt_bsz=16\ntrain_prompt_mini_bsz=16\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 32))\n\nCKPTS_DIR=$WORKSPACE_HOME/logs/ckpt/qwen3_8b\nmodel_path=$DATA_HOME/models/Qwen3-8B\ntrain_data=$DATA_HOME/datasets/dapo/dapo-math-17k.parquet\nvalid_data=$DATA_HOME/datasets/dapo/aime-2024.parquet\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$train_data \\\n    data.val_files=$valid_data \\\n    data.train_batch_size=$train_prompt_bsz \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$train_prompt_mini_bsz \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$tp_size \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=\"ascend\" \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.nccl_timeout=3600 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.project_name='verl_grpo_example_2k_32k' \\\n    trainer.experiment_name='qwen3_8b_function_rm' \\\n    trainer.n_gpus_per_node=$num_gpu \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=1000 \\\n    trainer.test_freq=10000 \\\n    trainer.total_epochs=5 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n# dependency: vllm>=0.11.0, megatron-lm>=0.13, mbridge with qwen3vl_cp branch\n# environment option1: use a stable container later than docker://verlai/verl:vllm011.dev6 \n    # and install mbridge in it by following the instruction in the container\n            # pip remove mbridge if you have installed it\n            # pip install git+https://github.com/ISEEKYAN/mbridge.git@qwen3vl_cp # for correct mbridge\n# environment option2: use container docker://verlai/verl:vllm011.dev_qwenvl_cp\n \n\nexport VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP\n\n\nHF_MODEL_PATH=${HF_MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-VL-235B-A22B-Instruct\"}\n\nGEN_TP=${GEN_TP:-16}\nCP=${CP:-2}\nTP=${TP:-4}\nPP=${PP:-8}\nEP=${EP:-8}\nETP=${ETP:-1}\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$GEN_TP \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen3_vl_235b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=8 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n\n# dependency: vllm>=0.11.0, megatron-lm>=0.13, mbridge with qwen3vl_cp branch\n# environment option1: use a stable container later than docker://verlai/verl:vllm011.dev6 \n    # and install mbridge in it by following the instruction in the container\n            # pip remove mbridge if you have installed it\n            # pip install git+https://github.com/ISEEKYAN/mbridge.git@qwen3vl_cp # for correct mbridge\n# environment option2: use container docker://verlai/verl:vllm011.dev_qwenvl_cp\n \n\nexport VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP\n\n\nHF_MODEL_PATH=${HF_MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct\"}\n\nGEN_TP=${GEN_TP:-4}\nCP=${CP:-2}\nTP=${TP:-2}\nPP=${PP:-1}\nEP=${EP:-8}\nETP=${ETP:-1}\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$GEN_TP \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen3_vl_30b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3_vl-8b-megatron.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n# dependency: vllm>=0.11.0, megatron-lm>=0.13, mbridge with qwen3vl_cp branch\n# environment option1: use a stable container later than docker://verlai/verl:vllm011.dev6 \n    # and install mbridge in it by following the instruction in the container\n            # pip remove mbridge if you have installed it\n            # pip install git+https://github.com/ISEEKYAN/mbridge.git@qwen3vl_cp # for correct mbridge\n# environment option2: use container docker://verlai/verl:vllm011.dev_qwenvl_cp\n \n\nexport VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP\n\n\nHF_MODEL_PATH=${HF_MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-VL-8B-Instruct\"}\n\nGEN_TP=${GEN_TP:-4}\nCP=${CP:-2}\nTP=${TP:-2}\nPP=${PP:-2}\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$GEN_TP \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen3_vl_8b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh",
    "content": "set -x\n\n# tested in NNODES=1~4 * 96G H20 GPU\nNNODES=${NNODES:-1}\nNGPUS_PER_NODES=${NGPUS_PER_NODES:-8}\n\nproject_name='DAPO-Qwen3-30b-MATH'\nexp_name='DAPO-Qwen3-30b-MATH-megatron'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=128\ntrain_ppo_micro_batch_size_per_gpu=2\ninfer_ppo_micro_batch_size_per_gpu=2\n# Paths\nMODEL_PATH=Qwen/Qwen3-30B-A3B\n\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nTRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet\nTEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet\nTEST_FILE=\"['$aime24_test_path']\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\noffload=True\n\noptimizer_offload_fraction=${OFFLOAD_FRACTION:-1.}\n\nCOMMON_PP=${COMMON_PP:-1}\nCOMMON_VPP=${COMMON_VPP:-null}\nCOMMON_CP=${COMMON_CP:-1}\nCOMMON_TP=${COMMON_TP:-1}\nCOMMON_EP=${COMMON_EP:-8}\nCOMMON_ETP=${COMMON_ETP:-1}\n\nTRAIN_TP=${TRAIN_TP:-$COMMON_TP}\nINFER_TP=${INFER_TP:-4}\n\nACTOR_PP=${ACTOR_PP:-$COMMON_PP}\nACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP}\nACTOR_CP=${ACTOR_CP:-$COMMON_CP}\nACTOR_TP=${ACTOR_TP:-$TRAIN_TP}\nACTOR_EP=${ACTOR_EP:-$COMMON_EP}\nACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP}\nROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP}\nREF_PP=${REF_PP:-$COMMON_PP}\nREF_VPP=${REF_VPP:-$COMMON_VPP}\nREF_CP=${REF_CP:-$COMMON_CP}\nREF_TP=${REF_TP:-$TRAIN_TP}\nREF_EP=${REF_EP:-$COMMON_EP}\nREF_ETP=${REF_ETP:-$COMMON_ETP}\nCRITIC_PP=${CRITIC_PP:-$COMMON_PP}\nCRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP}\nCRITIC_CP=${CRITIC_CP:-$COMMON_CP}\nCRITIC_TP=${CRITIC_TP:-$TRAIN_TP}\nCRITIC_EP=${CRITIC_EP:-$COMMON_EP}\nCRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP}\nRM_PP=${RM_PP:-$COMMON_PP}\nRM_VPP=${RM_VPP:-$COMMON_VPP}\nRM_CP=${RM_CP:-$COMMON_CP}\nRM_TP=${RM_TP:-$TRAIN_TP}\nRM_EP=${RM_EP:-$COMMON_EP}\nRM_ETP=${RM_ETP:-$COMMON_ETP}\n\n# install mbridge\n# pip3 install git+https://github.com/ISEEKYAN/mbridge\nUSE_MBRIDGE=True\nUSE_DIST_CKPT=False\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + 
max_response_length)) \\\n    actor_rollout_ref.model.use_fused_kernels=False \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.lr_decay_style='constant' \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=\"flex\" \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODES}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=100 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/examples/grpo_trainer/run_seed_oss_36b.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=64 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=ByteDance-Seed/Seed-OSS-36B-Base \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=8 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=2 \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.ref.strategy=fsdp2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\"]' \\\n    trainer.project_name='verl_grpo_seed_oss_36b' \\\n    trainer.experiment_name='seed_oss_36b' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/README.md",
    "content": "# Proximal Policy Optimization (PPO)\n\nProximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning.\n\nTraditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from:\n\n- High variance and sample inefficiency.\n- Instability due to large policy updates.\n\nPPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives.\n\nFor more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347).\n\n## Key Components\n\n- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model.\n\n- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias.\n\n- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nMost critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers\n\n- `critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor\n\n- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs`\n\n- `algorithm.gamma`: discount factor\n\n- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator\n\n- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo, rloo_vectorized\n\n## Advanced Extensions\n\n### KL Divergence Control\n\nOptions to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. 
## Configuration\n\nNote that all configs containing `micro_batch_size` configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs; their values should not change algorithmic/convergence behavior.\n\nMost critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout_ref.rollout.n`.\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Defaults to 0.2.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for the actor.\n\n- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for the critic. Defaults to `actor_rollout_ref.actor.ppo_epochs`.\n\n- `algorithm.gamma`: Discount factor.\n\n- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator.\n\n- `algorithm.adv_estimator`: Supported values: gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo, rloo_vectorized.\n\n## Advanced Extensions\n\n### KL Divergence Control\n\nOptions to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: a KL reward penalty and a KL loss. For more technical details, see [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155).\n\nOptions to use a KL loss for KL divergence control:\n\n- `actor_rollout_ref.actor.use_kl_loss`: Whether to use a KL loss in the actor. When enabled, KL is not applied in the reward function. Default is False.\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl (k1), abs, mse (k2), low_var_kl (k3), and full. Appending \"+\" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that employs k2 for unbiased gradient estimation, regardless of the KL value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\nOptions to use a KL penalty in the reward:\n\n- `algorithm.use_kl_in_reward`: Whether to enable the in-reward KL penalty. Default is False.\n\n- `algorithm.kl_penalty`: Supports kl (k1), abs, mse (k2), low_var_kl (k3), and full. This defines the way to calculate the KL divergence between the actor and the reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\n- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of the in-reward KL penalty. Default is 0.001.\n- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n- `algorithm.kl_ctrl.horizon`: See the source code of AdaptiveKLController for details.\n- `algorithm.kl_ctrl.target_kl`: See the source code of AdaptiveKLController for details.\n\n### Dual-clip PPO\n\nDual-clip PPO extends the clipped objective by applying a lower bound when the advantage is less than zero, so that the objective, when multiplied by a large ratio, does not drop below a specified bound.\n\n![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139)\n\n- `actor_rollout_ref.actor.clip_ratio_c`: The lower-bound constant for dual-clip PPO. Defaults to 3.0.\n\n## Reference Example\n\nQwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log)\n\n```bash\nbash run_gemma.sh \\\n  trainer.n_gpus_per_node=1 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  trainer.logger=console \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  data.train_batch_size=256 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=2 \\\n  critic.ppo_micro_batch_size=2\n```\n\nReference performance with verl v0.2:\n\n| Model                          | Method          | Score | Link                                                                                           |\n|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------|\n| Qwen/Qwen2.5-0.5B-Instruct     | pretrained model | 36.4  | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/)                                        |\n| Qwen/Qwen2.5-0.5B-Instruct     | PPO              | 56.7  | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n"
### Dual-clip PPO\n\nDual-Clip PPO adds a second clip for the case where the advantage is negative: the objective is bounded below by `clip_ratio_c * advantage`, so a very large policy ratio multiplied by a negative advantage cannot produce an arbitrarily large update.\n\n![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139)\n\n- `actor_rollout_ref.actor.clip_ratio_c`: The lower-bound constant for Dual-clip PPO. Defaults to 3.0\n\n## Reference Example\n\nQwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log)\n\n```bash\nbash run_gemma.sh \\\n  trainer.n_gpus_per_node=1 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  trainer.logger=console \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  data.train_batch_size=256 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=2 \\\n  critic.ppo_micro_batch_size=2\n```\n\nReference performance with verl v0.2:\n\n| Model                          | Method          | Score | Link                                                                                           |\n|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------|\n| Qwen/Qwen2.5-0.5B-Instruct     | pretrained model | 36.4  | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/)                                        |\n| Qwen/Qwen2.5-0.5B-Instruct     | PPO              | 56.7  | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.use_legacy_worker_impl=auto \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh",
    "content": "set -x\n\nVERL_USE_MODELSCOPE=True \\\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    algorithm.use_pf_ppo=True \\\n    algorithm.pf_ppo.reweight_method=pow \\  # [\"pow\", \"max_min\", \"max_random\"]\n    algorithm.pf_ppo.weight_pow=2.0 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    reward_model.sandbox_fusion.url='https://xxxxxxxxx.apigateway-cn-beijing.volceapi.com/run_code' \\\n    reward_model.sandbox_fusion.max_concurrent=128 \\\n    reward_model.reward_manager=prime \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/Eurus-2-RL-Data/train.parquet \\\n    data.val_files=$HOME/data/Eurus-2-RL-Data/validation.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_sandbox_fusion' \\\n    trainer.experiment_name='deepseek_llm_7b_function_sandbox_fusion' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    critic.optim.lr=1e-5 \\\n    critic.ulysses_sequence_parallel_size=2 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=64 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh",
    "content": "set -x\n\ntrain_files=$HOME/data/full_hh_rlhf/rl/train.parquet\ntest_files=$HOME/data/full_hh_rlhf/rl/train.parquet # no use\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=128 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    reward_model.enable=True \\\n    reward_model.megatron.tensor_model_parallel_size=4 \\\n    reward_model.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    reward_model.micro_batch_size_per_gpu=4 \\\n    reward_model.param_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_full_hh_rlhf_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_model_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh",
    "content": "set -x\n\n# Example runnable on H20 * 8\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh",
    "content": "set -x\n\n# Example runnable on H20 * 8\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=${train_files:-\"$gsm8k_train_path\"}\ntest_files=${test_files:-\"$gsm8k_test_path\"}\n\n# Nsight profiling configuration\nPROFILE_STEPS=\"[1]\" # or [] or null\nPROFILE_RANKS_ALL=False # or True\nPROFILE_RANKS=[0,4]\nDISCRETE=True  # or True\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.profiler.enable=True \\\n    actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.profiler.enable=True \\\n    critic.profiler.ranks=$PROFILE_RANKS \\\n    critic.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=100 \\\n    trainer.total_training_steps=1 \\\n    global_profiler.tool=nsys \\\n    global_profiler.steps=$PROFILE_STEPS \\\n    global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_gemma.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=google/gemma-2-2b-it \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=False \\\n    critic.model.path=google/gemma-2-2b-it \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='gemma2b_function_rm' \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n\n# 0. download the model\nhuggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct\n\n# 1. convert the model to mcore format\n# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path\nHF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct\nDIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct\npython scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n\n# 2. run the script\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nALL_OFFLOAD=${ALL_OFFLOAD:-False}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n\nNODES=4\nPP=2\nTP=8\nEP=8\nETP=1\nVLLM_TP=4\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . -- \npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.trust_remote_code=True \\\n    actor_rollout_ref.model.path=$LLM \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=$LLM \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    critic.model.trust_remote_code=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    critic.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    critic.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    critic.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    critic.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \\\n    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \\\n    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    critic.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    trainer.val_before_train=False \\\n    trainer.total_epochs=100 $@\n    "
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n# 0. download the model\n#huggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat\n\n# 1. convert the model to mcore format\n# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path\nHF_MODEL_PATH=/data/models/Qwen/Qwen1.5-MoE-A2.7B-Chat\nDIST_CKPT_PATH=/data/mcore_ckpt/Qwen1.5-MoE-A2.7B-Chat\npython scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n# 2. run the script\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nNODES=4\nPP=2\nTP=4\nCP=1\nVLLM_TP=4\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . -- \npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=$HF_MODEL_PATH \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.megatron.tensor_model_parallel_size=$TP \\\n    critic.megatron.pipeline_model_parallel_size=$PP \\\n    critic.megatron.context_parallel_size=$CP \\\n    critic.megatron.use_dist_checkpointing=True \\\n    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='qwen1.5_moe_nochat' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n    "
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_rm.sh",
    "content": "# Discliamer: the model used in the script is only for academic purpose.\nset -x\n\n# Data preparation scripts are available in ``examples/data_preprocess``.\n# Example usage:\n#\n#   python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\n#   python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\n\n# prepare model ckpt\nhuggingface-cli download Qwen/Qwen2-7B-Instruct --local-dir $HOME/models/Qwen2-7B-Instruct &\nhuggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 &\nwait\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"$HOME/models/Qwen2-7B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.optim.lr_warmup_steps_ratio=0.05 \\\n    critic.model.path=\"$HOME/models/Qwen2-7B-Instruct\" \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=\"$HOME/models/FsfairX-LLaMA3-RM-v0.1\" \\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.val_before_train=False \\\n    trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nFUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=${train_files:-\"$gsm8k_train_path\"}\ntest_files=${test_files:-\"$gsm8k_test_path\"}\n\nPROFILE_STEPS=\"[1,2,5]\" # or [] or null\nPROFILE_RANKS_ALL=False # or True\nPROFILE_RANKS=[0,4]\nDISCRETE=True  # or True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.profiler.enable=True \\\n    actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=2 \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    critic.profiler.enable=True \\\n    critic.profiler.ranks=$PROFILE_RANKS \\\n    critic.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    reward_model.profiler.enable=True \\\n    reward_model.profiler.ranks=$PROFILE_RANKS \\\n    reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.total_training_steps=6 \\\n    global_profiler.profile_continuous_steps=True \\\n    global_profiler.tool=nsys \\\n    global_profiler.steps=$PROFILE_STEPS \\\n  
  global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\n# For async rollout mode, dataset should return raw chat.\nrollout_mode=\"sync\"\nif [ \"$rollout_mode\" = \"async\" ]; then\n    return_raw_chat=\"True\"\nfi\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen2.5-32b.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=8 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/ppo_trainer/run_qwen3-8b_npu.sh",
    "content": "set -x\n\nexport VLLM_USE_V1=1\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/dapo-math-17k.parquet \\\n    data.val_files=$HOME/data/dapo-math-17k.parquet \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2000 \\\n    data.max_response_length=12000 \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-8B \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=14000 \\\n    actor_rollout_ref.rollout.max_num_seqs=64 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen3-8B \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=1 \\\n    critic.ulysses_sequence_parallel_size=2 \\\n    critic.model.fsdp_config.param_offload=True \\\n    critic.model.fsdp_config.optimizer_offload=True \\\n    critic.use_dynamic_bsz=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_example_dapo_math_17k' \\\n    trainer.experiment_name='qwen3_8b_fsdp' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=-1 \\\n    trainer.val_before_train=False \\\n    trainer.device=npu \\\n    trainer.max_actor_ckpt_to_keep=1 \\\n    trainer.max_critic_ckpt_to_keep=1 \\\n    trainer.total_training_steps=100 $@"
  },
  {
    "path": "verl_distillation/examples/ray/tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0ddc582b\",\n   \"metadata\": {},\n   \"source\": [\n    \"# VeRL Ray API Tutorial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"71fe3b94\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 1: Ray Basics\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 144,\n   \"id\": \"1347d381\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 145,\n   \"id\": \"e75b9d44\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import warnings\\n\",\n    \"\\n\",\n    \"import ray\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 146,\n   \"id\": \"2e90ae00\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2024-11-01 17:27:19,132\\tINFO worker.py:1752 -- Started a local Ray instance.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"9cc9d2ccbdfb48918c8fd6cd13a0807a\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/html\": [\n       \"<div class=\\\"lm-Widget p-Widget lm-Panel p-Panel jp-Cell-outputWrapper\\\">\\n\",\n       \"    <div style=\\\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\\\">\\n\",\n       \"        <div class=\\\"jp-RenderedHTMLCommon\\\" style=\\\"display: flex; flex-direction: row;\\\">\\n\",\n       \"  <svg viewBox=\\\"0 0 567 224\\\" fill=\\\"none\\\" xmlns=\\\"http://www.w3.org/2000/svg\\\" style=\\\"height: 3em;\\\">\\n\",\n       \"    <g clip-path=\\\"url(#clip0_4338_178347)\\\">\\n\",\n       \"        <path d=\\\"M341.29 165.561H355.29L330.13 129.051C345.63 123.991 354.21 112.051 354.21 94.2307C354.21 71.3707 338.72 58.1807 311.88 58.1807H271V165.561H283.27V131.661H311.8C314.25 131.661 316.71 131.501 319.01 131.351L341.25 165.561H341.29ZM283.29 119.851V70.0007H311.82C331.3 70.0007 342.34 78.2907 342.34 94.5507C342.34 111.271 331.34 119.861 311.82 119.861L283.29 119.851ZM451.4 138.411L463.4 165.561H476.74L428.74 58.1807H416L367.83 165.561H380.83L392.83 138.411H451.4ZM446.19 126.601H398L422 72.1407L446.24 126.601H446.19ZM526.11 128.741L566.91 58.1807H554.35L519.99 114.181L485.17 58.1807H472.44L514.01 129.181V165.541H526.13V128.741H526.11Z\\\" fill=\\\"var(--jp-ui-font-color0)\\\"/>\\n\",\n       \"        <path d=\\\"M82.35 104.44C84.0187 97.8827 87.8248 92.0678 93.1671 87.9146C98.5094 83.7614 105.083 81.5067 111.85 81.5067C118.617 81.5067 125.191 83.7614 130.533 87.9146C135.875 92.0678 139.681 97.8827 141.35 104.44H163.75C164.476 101.562 165.622 98.8057 167.15 96.2605L127.45 56.5605C121.071 60.3522 113.526 61.6823 106.235 60.3005C98.9443 58.9187 92.4094 54.9203 87.8602 49.0574C83.3109 43.1946 81.0609 35.8714 81.5332 28.4656C82.0056 21.0599 85.1679 14.0819 90.4252 8.8446C95.6824 3.60726 102.672 0.471508 110.08 0.0272655C117.487 -0.416977 124.802 1.86091 130.647 6.4324C136.493 11.0039 140.467 17.5539 141.821 24.8501C143.175 32.1463 141.816 39.6859 138 46.0505L177.69 85.7505C182.31 82.9877 187.58 81.4995 192.962 81.4375C198.345 81.3755 203.648 82.742 208.33 85.3976C213.012 88.0532 216.907 91.9029 219.616 96.5544C222.326 
101.206 223.753 106.492 223.753 111.875C223.753 117.258 222.326 122.545 219.616 127.197C216.907 131.848 213.012 135.698 208.33 138.353C203.648 141.009 198.345 142.375 192.962 142.313C187.58 142.251 182.31 140.763 177.69 138L138 177.7C141.808 184.071 143.155 191.614 141.79 198.91C140.424 206.205 136.44 212.75 130.585 217.313C124.731 221.875 117.412 224.141 110.004 223.683C102.596 223.226 95.6103 220.077 90.3621 214.828C85.1139 209.58 81.9647 202.595 81.5072 195.187C81.0497 187.779 83.3154 180.459 87.878 174.605C92.4405 168.751 98.9853 164.766 106.281 163.401C113.576 162.035 121.119 163.383 127.49 167.19L167.19 127.49C165.664 124.941 164.518 122.182 163.79 119.3H141.39C139.721 125.858 135.915 131.673 130.573 135.826C125.231 139.98 118.657 142.234 111.89 142.234C105.123 142.234 98.5494 139.98 93.2071 135.826C87.8648 131.673 84.0587 125.858 82.39 119.3H60C58.1878 126.495 53.8086 132.78 47.6863 136.971C41.5641 141.163 34.1211 142.972 26.7579 142.059C19.3947 141.146 12.6191 137.574 7.70605 132.014C2.79302 126.454 0.0813599 119.29 0.0813599 111.87C0.0813599 104.451 2.79302 97.2871 7.70605 91.7272C12.6191 86.1673 19.3947 82.5947 26.7579 81.6817C34.1211 80.7686 41.5641 82.5781 47.6863 86.7696C53.8086 90.9611 58.1878 97.2456 60 104.44H82.35ZM100.86 204.32C103.407 206.868 106.759 208.453 110.345 208.806C113.93 209.159 117.527 208.258 120.522 206.256C123.517 204.254 125.725 201.276 126.771 197.828C127.816 194.38 127.633 190.677 126.253 187.349C124.874 184.021 122.383 181.274 119.205 179.577C116.027 177.88 112.359 177.337 108.826 178.042C105.293 178.746 102.113 180.654 99.8291 183.44C97.5451 186.226 96.2979 189.718 96.3 193.32C96.2985 195.364 96.7006 197.388 97.4831 199.275C98.2656 201.163 99.4132 202.877 100.86 204.32ZM204.32 122.88C206.868 120.333 208.453 116.981 208.806 113.396C209.159 109.811 208.258 106.214 206.256 103.219C204.254 100.223 201.275 98.0151 197.827 96.97C194.38 95.9249 190.676 96.1077 187.348 97.4873C184.02 98.8669 181.274 101.358 179.577 104.536C177.879 107.714 177.337 111.382 178.041 114.915C178.746 118.448 180.653 121.627 183.439 123.911C186.226 126.195 189.717 127.443 193.32 127.44C195.364 127.443 197.388 127.042 199.275 126.259C201.163 125.476 202.878 124.328 204.32 122.88ZM122.88 19.4205C120.333 16.8729 116.981 15.2876 113.395 14.9347C109.81 14.5817 106.213 15.483 103.218 17.4849C100.223 19.4868 98.0146 22.4654 96.9696 25.9131C95.9245 29.3608 96.1073 33.0642 97.4869 36.3922C98.8665 39.7202 101.358 42.4668 104.535 44.1639C107.713 45.861 111.381 46.4036 114.914 45.6992C118.447 44.9949 121.627 43.0871 123.911 40.301C126.195 37.515 127.442 34.0231 127.44 30.4205C127.44 28.3772 127.038 26.3539 126.255 24.4664C125.473 22.5788 124.326 20.8642 122.88 19.4205ZM19.42 100.86C16.8725 103.408 15.2872 106.76 14.9342 110.345C14.5813 113.93 15.4826 117.527 17.4844 120.522C19.4863 123.518 22.4649 125.726 25.9127 126.771C29.3604 127.816 33.0638 127.633 36.3918 126.254C39.7198 124.874 42.4664 122.383 44.1635 119.205C45.8606 116.027 46.4032 112.359 45.6988 108.826C44.9944 105.293 43.0866 102.114 40.3006 99.8296C37.5145 97.5455 34.0227 96.2983 30.42 96.3005C26.2938 96.3018 22.337 97.9421 19.42 100.86ZM100.86 100.86C98.3125 103.408 96.7272 106.76 96.3742 110.345C96.0213 113.93 96.9226 117.527 98.9244 120.522C100.926 123.518 103.905 125.726 107.353 126.771C110.8 127.816 114.504 127.633 117.832 126.254C121.16 124.874 123.906 122.383 125.604 119.205C127.301 116.027 127.843 112.359 127.139 108.826C126.434 105.293 124.527 102.114 121.741 99.8296C118.955 97.5455 115.463 96.2983 111.86 96.3005C109.817 
96.299 107.793 96.701 105.905 97.4835C104.018 98.2661 102.303 99.4136 100.86 100.86Z\\\" fill=\\\"#00AEEF\\\"/>\\n\",\n       \"    </g>\\n\",\n       \"    <defs>\\n\",\n       \"        <clipPath id=\\\"clip0_4338_178347\\\">\\n\",\n       \"            <rect width=\\\"566.93\\\" height=\\\"223.75\\\" fill=\\\"white\\\"/>\\n\",\n       \"        </clipPath>\\n\",\n       \"    </defs>\\n\",\n       \"  </svg>\\n\",\n       \"</div>\\n\",\n       \"\\n\",\n       \"        <table class=\\\"jp-RenderedHTMLCommon\\\" style=\\\"border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);\\\">\\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>Python version:</b></td>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>3.9.2</b></td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>Ray version:</b></td>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>2.10.0</b></td>\\n\",\n       \"    </tr>\\n\",\n       \"    \\n\",\n       \"</table>\\n\",\n       \"\\n\",\n       \"    </div>\\n\",\n       \"</div>\\n\"\n      ],\n      \"text/plain\": [\n       \"RayContext(dashboard_url='', python_version='3.9.2', ray_version='2.10.0', ray_commit='09abba26b5bf2707639bb637c208d062a47b46f6')\"\n      ]\n     },\n     \"execution_count\": 146,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[36m(GPUAccumulator pid=224400)\\u001b[0m rank 0, value: tensor([1.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=225234)\\u001b[0m rank 2, value: tensor([3.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=225607)\\u001b[0m rank 0, value: tensor([2.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=226423)\\u001b[0m rank 1, value: tensor([3.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=226857)\\u001b[0m rank 3, value: tensor([6.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227475)\\u001b[0m 10\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227475)\\u001b[0m rank 0, value: tensor([10.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227655)\\u001b[0m rank 1, value: tensor([11.], device='cuda:0')\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Build a local ray cluster. The head node and worker node are on this machine\\n\",\n    \"ray.init()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a127e4e4\",\n   \"metadata\": {},\n   \"source\": [\n    \"Implement an Accumulator class.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 147,\n   \"id\": \"20e7b9a3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class Accumulator:\\n\",\n    \"    def __init__(self):\\n\",\n    \"        self.value = 0\\n\",\n    \"\\n\",\n    \"    def add(self, x):\\n\",\n    \"        self.value += x\\n\",\n    \"\\n\",\n    \"    def get_value(self):\\n\",\n    \"        return self.value\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 148,\n   \"id\": \"3b80098c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Instantiate an accumulator. 
The Accumulator can be viewed as a separate process acting as an RPC service.\\n\",\n    \"accumulator = Accumulator.remote()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 149,\n   \"id\": \"b14b1009\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"value_ref = accumulator.get_value.remote()  # Check the current value. Note that this call returns immediately and does not wait for the remote execution to complete.\\n\",\n    \"# Get the value\\n\",\n    \"value = ray.get(value_ref)\\n\",\n    \"print(value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 150,\n   \"id\": \"513a84b3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"10\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Accumulate, then check the result.\\n\",\n    \"accumulator.add.remote(10)  # Similarly, the 'add' here returns immediately.\\n\",\n    \"new_value = ray.get(accumulator.get_value.remote())\\n\",\n    \"print(new_value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3c332fe0\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 2: Resource Pool and RayWorkerGroup\\n\",\n    \"The previous example used a simple single-process worker. \\n\",\n    \"In this example, we implement a worker that occupies a GPU and assemble a RayWorkerGroup from such workers. Within this RayWorkerGroup, we implement a simple accumulator operation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 151,\n   \"id\": \"04229afb\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base import Worker\\n\",\n    \"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 152,\n   \"id\": \"0d0dbd58\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"resource_pool = RayResourcePool([4], use_gpu=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 153,\n   \"id\": \"68f6838a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class GPUAccumulator(Worker):\\n\",\n    \"    def __init__(self) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        # The initial value of each rank is the same as the rank\\n\",\n    \"        self.value = torch.zeros(size=(1,), device=\\\"cuda\\\") + self.rank\\n\",\n    \"\\n\",\n    \"    def add(self, x):\\n\",\n    \"        self.value += x\\n\",\n    \"        print(f\\\"rank {self.rank}, value: {self.value}\\\")\\n\",\n    \"        return self.value.cpu()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 154,\n   \"id\": \"23aad8fe\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([1.]), tensor([2.]), tensor([3.]), tensor([4.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Each worker's initial value is its rank; each rank then adds 1, so the values obtained across the ranks are [1, 2, 3, 4]\\n\",\n    \"class_with_args = 
RayClassWithInitArgs(cls=GPUAccumulator)\\n\",\n    \"worker_group = RayWorkerGroup(resource_pool, class_with_args)\\n\",\n    \"print(worker_group.execute_all_sync(\\\"add\\\", x=[1, 1, 1, 1]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e6705284\",\n   \"metadata\": {},\n   \"source\": [\n    \"Parameter passing follows a simple convention: the input is a list of length world_size, and the i-th element is dispatched to the i-th worker in the RayWorkerGroup. \\n\",\n    \"The return value is likewise a list, containing each worker's result in rank order.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d25c2412\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GPU Resource Sharing\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f74f6d24\",\n   \"metadata\": {},\n   \"source\": [\n    \"RayWorkerGroups mapped to the same resource pool share the same GPUs. In this example, we create three resource pools: the first occupies 4 GPUs, the second also occupies 4 GPUs, and the last occupies all 8 GPUs. The first resource pool reuses the pool created above.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 155,\n   \"id\": \"49f9c06f\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create a new resource pool and then merge the newly created resource pool with the previous one.\\n\",\n    \"resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\\\"a\\\")\\n\",\n    \"resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 156,\n   \"id\": \"05c2e305\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Establish a RayWorkerGroup on the newly created resource pool.\\n\",\n    \"worker_group_1 = RayWorkerGroup(resource_pool_1, class_with_args)\\n\",\n    \"worker_group_merge = RayWorkerGroup(resource_pool_merge, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 157,\n   \"id\": \"6b9b13f4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([2.]), tensor([3.]), tensor([4.]), tensor([5.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\\n\",\n    \"output_1 = worker_group_1.execute_all_sync(\\\"add\\\", x=[2, 2, 2, 2])\\n\",\n    \"print(output_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 158,\n   \"id\": \"d856d030\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([3.]), tensor([4.]), tensor([5.]), tensor([6.]), tensor([7.]), tensor([8.]), tensor([9.]), tensor([10.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\\n\",\n    \"output_merge = worker_group_merge.execute_all_sync(\\\"add\\\", x=[3, 3, 3, 3, 3, 3, 3, 3])\\n\",\n    \"print(output_merge)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 159,\n   \"id\": \"33a4628c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4 4 
8\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(worker_group.world_size, worker_group_1.world_size, worker_group_merge.world_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3df19d13\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 3: Data Dispatch, Execution and Collection\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"acb22d9d\",\n   \"metadata\": {},\n   \"source\": [\n    \"In the above example, we used the `execute_all_sync` function in the RayWorkerGroup to dispatch data from the driver to each worker. This is very inconvenient for coding. \\n\",\n    \"In this chapter, we use the form of function decorators to allow RayWorkerGroup to directly call functions written in the Worker, and to greatly simplify parameter passing.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 160,\n   \"id\": \"35237432\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base.decorator import Dispatch, Execute, register\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 161,\n   \"id\": \"88b8ba3b\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class GPUAccumulatorDecorator(Worker):\\n\",\n    \"    def __init__(self) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        # The initial value of each rank is the same as the rank\\n\",\n    \"        self.value = torch.zeros(size=(1,), device=\\\"cuda\\\") + self.rank\\n\",\n    \"\\n\",\n    \"    # map from a single input to all the worker\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def add(self, x):\\n\",\n    \"        print(x)\\n\",\n    \"        self.value = self.value + x\\n\",\n    \"        print(f\\\"rank {self.rank}, value: {self.value}\\\")\\n\",\n    \"        return self.value.cpu()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 162,\n   \"id\": \"eddaa043\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"class_with_args = RayClassWithInitArgs(cls=GPUAccumulatorDecorator)\\n\",\n    \"gpu_accumulator_decorator = RayWorkerGroup(resource_pool_merge, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 163,\n   \"id\": \"10087c91\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([10.]), tensor([11.]), tensor([12.]), tensor([13.]), tensor([14.]), tensor([15.]), tensor([16.]), tensor([17.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# As we can see, 10 is automatically dispatched to each Worker in this RayWorkerGroup.\\n\",\n    \"print(gpu_accumulator_decorator.add(x=10))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"540ee6ad\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Custom Dispatch, Collection\\n\",\n    \"Users can customize `dispatch` and `collection` function. You only need to write the `dispatch_fn` and `collect_fn` functions yourself. 
We also support executing RPC only on rank_zero, with specific examples provided below.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 164,\n   \"id\": \"8e041270\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 165,\n   \"id\": \"43b5be31\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def two_to_all_dispatch_fn(worker_group, *args, **kwargs):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    for arg in args:\\n\",\n    \"        assert len(arg) == 2\\n\",\n    \"        for i in range(worker_group.world_size - 2):\\n\",\n    \"            arg.append(arg[i % 2])\\n\",\n    \"    for k, v in kwargs.items():\\n\",\n    \"        assert len(v) == 2\\n\",\n    \"        for i in range(worker_group.world_size - 2):\\n\",\n    \"            v.append(v[i % 2])\\n\",\n    \"    return args, kwargs\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@ray.remote\\n\",\n    \"class TestActor(Worker):\\n\",\n    \"    # TODO: pass *args and **kwargs is bug prone and not very convincing\\n\",\n    \"    def __init__(self, x) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        self._x = x\\n\",\n    \"\\n\",\n    \"    def foo(self, y):\\n\",\n    \"        return self._x + y\\n\",\n    \"\\n\",\n    \"    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\\n\",\n    \"    def foo_rank_zero(self, x, y):\\n\",\n    \"        return self._x + y + x\\n\",\n    \"\\n\",\n    \"    @register(dispatch_mode={\\\"dispatch_fn\\\": two_to_all_dispatch_fn, \\\"collect_fn\\\": collect_all_to_all})\\n\",\n    \"    def foo_custom(self, x, y):\\n\",\n    \"        return self._x + y + x\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 166,\n   \"id\": \"83ec6609\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\\n\",\n    \"worker_group = RayWorkerGroup(resource_pool, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 167,\n   \"id\": \"62c58d8a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\\n\",\n    \"assert output_ref == [8, 10, 8, 10]\\n\",\n    \"\\n\",\n    \"output_ref = worker_group.foo_rank_zero(x=1, y=2)\\n\",\n    \"assert output_ref == 5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 168,\n   \"id\": \"14689353\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"8\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(gpu_accumulator_decorator.world_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 169,\n   \"id\": \"2c80bbf4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Shutdown ray cluster\\n\",\n    \"ray.shutdown()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a5c8151c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 4: NVMegatronRayWorkerGroup\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cd5680e9\",\n   \"metadata\": {},\n   \"source\": [\n    \"Due to the Ray issue, we can only support max_colocate_count=1 in RayResourcePool for now. \\n\",\n    \"This means that each GPU can only have one process.\\n\",\n    \"We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"92724419\",\n   \"metadata\": {},\n   \"source\": [\n    \"Therefore, we need to restart the ray and initialize a new resource_pool to demonstrate the **NVMegatronRayWorkerGroup**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9b038538\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Build a local ray cluster. The head node and worker node are on this machine\\n\",\n    \"ray.init()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ebfd8798\",\n   \"metadata\": {},\n   \"source\": [\n    \"Finally, we implement a `NVMegatronRayWorkerGroup`, within which we create a Megatron and then run a tensor parallel (tp) split Llama mlp layer. Here, we use a complex dispatch mode, `Megatron_COMPUTE`. This dispatch mode assumes that user passes the data partitioned by DP dimension. The data is dispatched to all tp/pp ranks within the same dp group, and ultimately only collects output data from tp=0 and the last pp. In this way, for users that only write code on the driver, the Megatron behind the RPC becomes transparent.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 171,\n   \"id\": \"5a032154\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/opt/tiger/Megatron-LM\\n\",\n      \"/opt/tiger/Megatron-LM/megatron/__init__.py\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import sys\\n\",\n    \"\\n\",\n    \"current_pythonpath = os.environ.get(\\\"PYTHONPATH\\\", \\\"\\\")\\n\",\n    \"\\n\",\n    \"new_path = \\\"/opt/tiger/Megatron-LM\\\"\\n\",\n    \"\\n\",\n    \"new_pythonpath = f\\\"{new_path}:{current_pythonpath}\\\" if current_pythonpath else new_path\\n\",\n    \"\\n\",\n    \"os.environ[\\\"PYTHONPATH\\\"] = new_pythonpath\\n\",\n    \"\\n\",\n    \"print(new_path)\\n\",\n    \"sys.path.append(new_path)\\n\",\n    \"\\n\",\n    \"import megatron\\n\",\n    \"\\n\",\n    \"print(megatron.__file__)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 172,\n   \"id\": \"8c84cd5a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from megatron.core import parallel_state as mpu\\n\",\n    \"from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"from verl.single_controller.base.decorator import Dispatch, Execute, register\\n\",\n    \"from verl.single_controller.base.megatron.worker import MegatronWorker\\n\",\n    \"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\\n\",\n    \"from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 173,\n   \"id\": \"1b1debcc\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"resource_pool = RayResourcePool([4], use_gpu=True, max_colocate_count=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 174,\n   \"id\": \"bccbe081\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class MLPLayerWorker(MegatronWorker):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        rank = int(os.environ[\\\"LOCAL_RANK\\\"])\\n\",\n    \"        torch.distributed.init_process_group(backend=\\\"nccl\\\")\\n\",\n    \"        torch.cuda.set_device(rank)\\n\",\n    \"\\n\",\n    \"        mpu.initialize_model_parallel(\\n\",\n    \"            tensor_model_parallel_size=4,\\n\",\n    \"            pipeline_model_parallel_size=1,\\n\",\n    \"            virtual_pipeline_model_parallel_size=None,\\n\",\n    \"            pipeline_model_parallel_split_rank=None,\\n\",\n    \"            use_sharp=False,\\n\",\n    \"            context_parallel_size=1,\\n\",\n    \"            expert_model_parallel_size=1,\\n\",\n    \"            nccl_communicator_config_path=None,\\n\",\n    \"        )\\n\",\n    \"        from megatron.core import tensor_parallel\\n\",\n    \"\\n\",\n    \"        tensor_parallel.model_parallel_cuda_manual_seed(10)\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def init_model(self, config):\\n\",\n    \"        from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"        from verl.models.llama.megatron.layers import ParallelLlamaMLP\\n\",\n    \"        from verl.utils.megatron_utils import init_model_parallel_config\\n\",\n    \"\\n\",\n    \"        megatron_config = OmegaConf.create(\\n\",\n    \"            {\\n\",\n    \"                \\\"sequence_parallel\\\": False,\\n\",\n    \"                \\\"param_dtype\\\": \\\"fp32\\\",\\n\",\n    \"                \\\"tensor_model_parallel_size\\\": mpu.get_tensor_model_parallel_world_size(),\\n\",\n    \"                \\\"pipeline_model_parallel_rank\\\": mpu.get_pipeline_model_parallel_rank(),\\n\",\n    \"                \\\"pipeline_model_parallel_size\\\": mpu.get_pipeline_model_parallel_world_size(),\\n\",\n    \"                \\\"virtual_pipeline_model_parallel_rank\\\": mpu.get_virtual_pipeline_model_parallel_rank(),\\n\",\n    \"                \\\"virtual_pipeline_model_parallel_size\\\": mpu.get_virtual_pipeline_model_parallel_world_size(),\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        megatron_config = init_model_parallel_config(megatron_config)\\n\",\n    \"        self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def get_weights(self):\\n\",\n    \"        output = {}\\n\",\n    \"        for key, val in self.parallel_layer.named_parameters():\\n\",\n    \"            output[key] = val\\n\",\n    \"        return output\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.MEGATRON_COMPUTE)\\n\",\n    \"    def run_layer(self, x):\\n\",\n    \"        x = x.to(\\\"cuda\\\")\\n\",\n    \"        y = self.parallel_layer(x)\\n\",\n    \"        return y\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 175,\n   \"id\": \"a655271d\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\\n\",\n    \"layer_worker_group = NVMegatronRayWorkerGroup(\\n\",\n    \"    resource_pool=resource_pool,\\n\",\n    \"    ray_cls_with_init=layer_cls,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 176,\n   \"id\": \"f105ebee\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4 4 1 1\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(layer_worker_group.world_size, layer_worker_group.tp_size, layer_worker_group.pp_size, layer_worker_group.dp_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 177,\n   \"id\": \"38655091\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"ffn_hidden_size = 11008\\n\",\n    \"batch_size = 16\\n\",\n    \"seq_len = 2048\\n\",\n    \"hidden_size = 4096\\n\",\n    \"\\n\",\n    \"config = OmegaConf.create(\\n\",\n    \"    {\\n\",\n    \"        \\\"hidden_size\\\": hidden_size,\\n\",\n    \"        \\\"intermediate_size\\\": ffn_hidden_size,\\n\",\n    \"        \\\"hidden_act\\\": \\\"silu\\\",\\n\",\n    \"        \\\"pretraining_tp\\\": 1,\\n\",\n    \"        \\\"tp\\\": layer_worker_group.tp_size,\\n\",\n    \"    }\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 178,\n   \"id\": \"a026efca\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"x = torch.rand(size=(seq_len, batch_size, hidden_size), dtype=torch.float32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 179,\n   \"id\": \"f5fcaf13\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[None, None, None, None]\"\n      ]\n     },\n     \"execution_count\": 179,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"layer_worker_group.init_model(config)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 180,\n   \"id\": \"3f5cc9b4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([2048, 16, 4096])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"output = layer_worker_group.run_layer(\\n\",\n    \"    [x]\\n\",\n    \")  # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\\n\",\n    \"print(output[0].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 181,\n   \"id\": \"49792210\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Shutdown ray cluster\\n\",\n    \"ray.shutdown()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
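The notebook above relies on one convention worth internalizing: dispatch takes a list of length `world_size` whose i-th element goes to the i-th worker, and collection returns the per-worker results in rank order. The same pattern can be reproduced with plain Ray actors; the sketch below uses illustrative names (`Adder`, `execute_all_sync`) that are stand-ins, not verl APIs:

```python
import ray

ray.init()


@ray.remote
class Adder:
    def __init__(self, rank):
        self.rank = rank  # plays the role of Worker.rank in the notebook

    def add(self, x):
        return self.rank + x


world_size = 4
workers = [Adder.remote(rank) for rank in range(world_size)]


def execute_all_sync(workers, method, xs):
    # Dispatch: element i of xs goes to worker i; collect results in rank order.
    assert len(xs) == len(workers)
    return ray.get([getattr(w, method).remote(x) for w, x in zip(workers, xs)])


print(execute_all_sync(workers, "add", [1, 1, 1, 1]))  # [1, 2, 3, 4]
ray.shutdown()
```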
  {
    "path": "verl_distillation/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=reinforce_plus_plus \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=1024 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=mse \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=reinforce_plus_plus_baseline \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=1024 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=mse \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
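The two scripts above differ only in `algorithm.adv_estimator` (`reinforce_plus_plus` vs. `reinforce_plus_plus_baseline`). As a rough sketch of the core idea only, not verl's exact implementation: both are critic-free REINFORCE-style estimators; the plain variant normalizes scalar rewards across the global batch, while the `_baseline` variant first subtracts a per-prompt mean over the `rollout.n=8` samples:

```python
import torch


def reinforce_pp_adv(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: (batch,) scalar sequence rewards; whiten across the global batch.
    return (rewards - rewards.mean()) / (rewards.std() + eps)


def reinforce_pp_baseline_adv(rewards: torch.Tensor, n: int = 8) -> torch.Tensor:
    # Each consecutive block of n rewards shares a prompt (rollout.n samples).
    grouped = rewards.view(-1, n)
    centered = grouped - grouped.mean(dim=1, keepdim=True)  # per-prompt baseline
    return reinforce_pp_adv(centered.view(-1))
```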
  {
    "path": "verl_distillation/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh",
    "content": "set -x\n\nexport HF_DATASETS_OFFLINE=1\nexport TRANSFORMERS_OFFLINE=1\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=remax \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_remax_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_3b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 $@\n"
  },
  {
    "path": "verl_distillation/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh",
    "content": "set -x\n\nexport HF_DATASETS_OFFLINE=1\nexport TRANSFORMERS_OFFLINE=1\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=remax \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_remax_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_7b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=10 $@\n"
  },
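For the two ReMax scripts above, the estimator's core idea (following the ReMax paper) is to use the reward of a greedy, temperature-0 rollout for the same prompt as a baseline, avoiding a critic network entirely. A minimal sketch, not verl's exact code:

```python
import torch


def remax_adv(sample_rewards: torch.Tensor,
              greedy_rewards: torch.Tensor) -> torch.Tensor:
    # sample_rewards: (batch,) rewards of sampled responses.
    # greedy_rewards: (batch,) reward of the greedy response to the same prompt,
    # used as a variance-reducing baseline.
    return sample_rewards - greedy_rewards
```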
  {
    "path": "verl_distillation/examples/rloo_trainer/run_qwen2-7b.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=rloo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_rloo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
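The RLOO estimator selected above uses a leave-one-out baseline: with `rollout.n=5` samples per prompt, each sample's baseline is the mean reward of the other four. A minimal sketch, assuming rewards arrive grouped by prompt:

```python
import torch


def rloo_adv(rewards: torch.Tensor, n: int = 5) -> torch.Tensor:
    # rewards: (batch,) with each consecutive block of n entries sharing a prompt.
    grouped = rewards.view(-1, n)
    # Leave-one-out baseline: mean reward of the other n-1 samples in the group.
    loo_baseline = (grouped.sum(dim=1, keepdim=True) - grouped) / (n - 1)
    return (grouped - loo_baseline).view(-1)
```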
  {
    "path": "verl_distillation/examples/rollout_importance_sampling/README.md",
    "content": "# Rollout Importance Sampling (IS) Examples\n\nThis directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies.\n\n**References:**\n- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n- Off-policy RL: https://fengyao.notion.site/off-policy-rl\n\n## Overview\n\nRollout Importance Sampling corrects for distribution mismatch when:\n1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16)\n2. **Training** uses another policy (e.g., FSDP with FP32)\n3. This mismatch leads to biased gradient estimates\n\n## Quick Start\n\n### Basic Configuration\n\n```yaml\nalgorithm:\n  # Main control: set threshold to enable (null = disabled)\n  rollout_is_threshold: 2.0\n  # Whether to apply weights to policy loss (true) or just compute metrics (false)\n  rollout_is: true\n  rollout_is_level: token\n  rollout_is_mode: truncate\n\n# IMPORTANT: Must enable log prob calculation\nactor_rollout_ref:\n  rollout:\n    calculate_log_probs: true\n```\n\n### Running the Example\n\n```bash\n# Basic example with token-level truncate\nbash examples/rollout_importance_sampling/run_with_rollout_is.sh\n```\n\n## Configuration Options\n\n### Aggregation Levels (`rollout_is_level`)\n\n| Level | Properties | Threshold Range |\n|-------|-----------|-----------------|\n| **token** | Per-token | 1.5 - 5.0 |\n| **sequence** | Per-sequence | 2.0 - 10.0 |\n| **geometric** | Geometric mean | 1.0002 - 1.001 |\n\n### Bounding Modes (`rollout_is_mode`)\n\n| Mode | Behavior |\n|------|----------|\n| **truncate** | Cap weights at upper threshold only |\n| **clip** | Zero out weights outside [lower, upper] |\n\n### Key Parameters\n\n- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**\n- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). 
Default: false.\n- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)\n- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: null, disabled)\n\n## Configuration Examples\n\n### Example 1: Full IS Correction (Apply Weights)\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 2.0\n  rollout_is: true  # Apply to loss\n  rollout_is_level: token\n  rollout_is_mode: truncate\n  rollout_is_veto_threshold: null  # Disabled by default\n```\n\n### Example 2: Metrics Only (No Weight Application)\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 2.0\n  rollout_is: false  # Compute metrics only, don't apply to loss\n  rollout_is_level: token\n  rollout_is_mode: truncate\n```\n\n### Example 3: Geometric Mean with Mask\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 1.0002\n  rollout_is: true\n  rollout_is_threshold_lower: 0.9998\n  rollout_is_level: geometric\n  rollout_is_mode: mask\n  rollout_is_veto_threshold: 1e-4  # Enable veto for this example\n```\n\n### Example 4: Sequence-level with Truncate\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 5.0\n  rollout_is: true\n  rollout_is_threshold_lower: null  # Auto-reciprocal: 0.2\n  rollout_is_level: sequence\n  rollout_is_mode: truncate\n  rollout_is_veto_threshold: 1e-4  # Enable veto for this example\n```\n\n### Example 5: Asymmetric Thresholds\n\n```yaml\nalgorithm:\n  rollout_is_threshold: 5.0\n  rollout_is: true\n  rollout_is_threshold_lower: 0.8\n  rollout_is_level: token\n  rollout_is_mode: mask\n```\n\n## Monitoring Metrics\n\nKey metrics to watch (all prefixed with `mismatch/` in logs):\n\n### Health Indicators\n- `rollout_is_mean`: Mean IS weight across sequences\n- `rollout_is_eff_sample_size`: Effective sample size after weighting\n- `rollout_is_veto_fraction`: Fraction of sequences vetoed\n\n### Distribution Metrics\n- `rollout_is_max`, `rollout_is_min`: Weight extremes\n- `rollout_is_std`: Standard deviation\n\n### Diagnostic Metrics\n- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold\n- `rollout_is_ratio_fraction_low`: Fraction below lower threshold\n- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected\n\n### Mismatch Metrics (Training vs Rollout Policy)\n\nThese metrics help diagnose the distribution mismatch between rollout and training policies:\n\n**Perplexity Metrics:**\n- `mismatch_training_ppl`: Perplexity of training policy\n- `mismatch_rollout_ppl`: Perplexity of rollout policy\n- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL\n- `mismatch_log_ppl_diff`: Log perplexity difference\n\n**KL Divergence Metrics:**\n- `mismatch_kl`: KL divergence KL(π_rollout || π_training)\n- `mismatch_k3_kl`: K3 KL estimator\n\n## Troubleshooting\n\n### Issue: High Variance in IS Weights\n\n**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3\n\n**Solutions**:\n1. Switch from `sequence` to `geometric` level\n2. Tighten thresholds\n3. Check if rollout and training are too different\n\n### Issue: Too Many Sequences Vetoed\n\n**Symptoms**: `rollout_is_veto_fraction` > 0.1\n\n**Solutions**:\n1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`\n2. Check for numerical issues in log prob computation\n3. Verify rollout and training policies aren't completely different\n\n### Issue: Mean IS Weight Far from 1.0\n\n**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0\n\n**Solutions**:\n1. Check that `calculate_log_probs=True` is set\n2. Verify rollout_log_probs are correctly passed\n3. 
Check for systematic bias between rollout and training\n\n### Issue: Too Much Data Discarded (Mask Mode)\n\n**Symptoms**: `rollout_is_masked_fraction` > 0.5\n\n**Solutions**:\n1. Widen thresholds\n2. Switch to `truncate` mode\n3. Use `geometric` level for better stability\n\n## Performance Considerations\n\n### Memory Usage\n- Rollout IS adds minimal memory overhead (~1% of model memory)\n- Log-space computation prevents numerical overflow\n\n### Computational Cost\n- Token-level: ~1-2% overhead\n- Sequence-level: ~2-3% overhead\n- Geometric: ~2-3% overhead\n\n## Advanced Topics\n\n### Dual Thresholds\n\nSpecify both upper and lower explicitly:\n\n```yaml\nrollout_is_threshold: 2.0        # Upper\nrollout_is_threshold_lower: 0.5  # Explicit lower (used as-is; here it happens to equal 1/upper)\n```\n\nOr use auto-reciprocal:\n\n```yaml\nrollout_is_threshold: 2.0      # Upper = 2.0, Lower = 0.5 (auto)\nrollout_is_threshold_lower: null\n```\n\n### Veto Mechanism\n\nThe veto mechanism zeros out entire sequences containing catastrophic outliers:\n\n- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected\n- This prevents extreme outliers from dominating training\n- Default: `null` (disabled)\n- Set to `1e-4` to enable (catches ratios 10,000x off)\n\n## Examples\n\nSee the script in this directory:\n- `run_with_rollout_is.sh`: Basic example with token-level truncate mode\n\n## References\n\n- Implementation: `verl/trainer/ppo/mismatch_helper.py`\n- Core algorithm: `verl/trainer/ppo/core_algos.py`\n- Paper: \"Your Efficient RL Framework Secretly Brings You Off-Policy RL Training\"\n"
  },
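To make the README's knobs concrete, here is a minimal sketch of how the IS weights could be computed from per-token log-probs. It follows the semantics described above (log-space ratios; token/sequence/geometric aggregation; truncate caps the upper tail only, mask zeros weights outside the bounds; the lower bound defaults to 1/upper), but it is not the code in `mismatch_helper.py`:

```python
import torch


def rollout_is_weights(logp_train, logp_rollout, resp_mask,
                       level="token", mode="truncate", upper=2.0, lower=None):
    # logp_*: (batch, seq) per-token log-probs; resp_mask: 1.0 on response tokens.
    lower = 1.0 / upper if lower is None else lower  # auto-reciprocal default
    log_ratio = (logp_train - logp_rollout) * resp_mask
    if level == "token":
        log_w = log_ratio
    elif level == "sequence":
        log_w = log_ratio.sum(-1, keepdim=True).expand_as(log_ratio)
    else:  # "geometric": mean per-token log-ratio = log of the geometric mean
        n_tok = resp_mask.sum(-1, keepdim=True).clamp(min=1.0)
        log_w = (log_ratio.sum(-1, keepdim=True) / n_tok).expand_as(log_ratio)
    w = log_w.exp()
    if mode == "truncate":
        w = w.clamp(max=upper)  # cap the upper tail only
    else:  # "mask": zero out weights outside [lower, upper]
        w = torch.where((w >= lower) & (w <= upper), w, torch.zeros_like(w))
    return (w * resp_mask).detach()  # weights are not differentiated through
```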
  {
    "path": "verl_distillation/examples/rollout_importance_sampling/run_with_rollout_is.sh",
    "content": "#!/usr/bin/env bash\n# Example: Basic PPO training with Rollout Importance Sampling\n# This demonstrates the standard setup for correcting distribution mismatch\n\nset -xeuo pipefail\n\n# ==============================================================================\n# Rollout Importance Sampling Configuration\n# ==============================================================================\n\n# Main control: Upper threshold for IS weights (null = disabled, float = enabled)\nrollout_is_threshold=2.0\n\n# Whether to apply IS weights to policy loss\n# true = apply weights to loss, false = compute metrics only\nrollout_is=true\n\n# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)\nrollout_is_threshold_lower=null\n\n# Aggregation level: token | sequence | geometric (experimental)\nrollout_is_level=token\n\n# Bounding mode: truncate (cap upper) | mask (zero outside bounds)\nrollout_is_mode=truncate\n\n# Catastrophic outlier veto threshold (set to null to disable, or e.g., 1e-4 to enable)\nrollout_is_veto_threshold=null\n\n# ==============================================================================\n# Model and Data Configuration\n# ==============================================================================\n\nMODEL_PATH=${MODEL_PATH:-\"Qwen/Qwen2.5-7B\"}\nTRAIN_FILE=${TRAIN_FILE:-\"data/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"data/test.parquet\"}\n\nmax_prompt_length=512\nmax_response_length=1024\n\n# ==============================================================================\n# Training Configuration\n# ==============================================================================\n\ntrain_batch_size=128\nppo_mini_batch_size=32\nppo_epochs=1\nlearning_rate=5e-7\n\n# ==============================================================================\n# Algorithm Configuration\n# ==============================================================================\n\nadv_estimator=gae\ngamma=1.0\nlam=0.95\n\n# ==============================================================================\n# Launch Training\n# ==============================================================================\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_batch_size} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.gamma=${gamma} \\\n    algorithm.lam=${lam} \\\n    algorithm.rollout_is=${rollout_is} \\\n    algorithm.rollout_is_threshold=${rollout_is_threshold} \\\n    algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \\\n    algorithm.rollout_is_level=${rollout_is_level} \\\n    algorithm.rollout_is_mode=${rollout_is_mode} \\\n    algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=${learning_rate} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    actor_rollout_ref.rollout.name=vllm \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"rollout_is_example\" \\\n    trainer.experiment_name=\"basic_token_truncate\" \\\n    trainer.total_epochs=10\n\necho \"Training completed!\"\necho \"\"\necho \"Rollout IS Configuration:\"\necho \"  - Threshold: 
${rollout_is_threshold}\"\necho \"  - Apply to loss: ${rollout_is}\"\necho \"  - Level: ${rollout_is_level}\"\necho \"  - Mode: ${rollout_is_mode}\"\necho \"\"\necho \"Monitor these key metrics in wandb:\"\necho \"  - mismatch/rollout_is_mean (should be ~1.0)\"\necho \"  - mismatch/rollout_is_eff_sample_size (should be >0.5)\"\necho \"  - mismatch/rollout_is_veto_fraction (should be <0.1)\"\n"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_deepseek_6b7.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_deepseek_6b7.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \\\n    trainer.total_epochs=4 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@"
  },
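The `data.prompt_key=extra_info` / `data.prompt_dict_keys=['question']` overrides used by the SFT scripts in this directory tell the trainer to pull the prompt and response out of a nested column. Roughly, the extraction amounts to the following nested lookup (a sketch for inspecting your parquet, not the trainer's code):

```python
import pandas as pd

df = pd.read_parquet("data/gsm8k/train.parquet")  # adjust the path as needed
row = df.iloc[0]

# prompt_key=extra_info + prompt_dict_keys=['question'] -> nested field lookup
prompt = row["extra_info"]["question"]
# +data.response_dict_keys=['answer'] -> response from the same nested column
response = row["extra_info"]["answer"]
print(repr(prompt)[:100], "->", repr(response)[:100])
```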
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_gemma_2b.sh",
    "content": "# Tested with 2 & 4 GPUs\n\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_gemma_2b.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=google/gemma-2b-it \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-gemma-2b-it \\\n    trainer.total_epochs=2 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_gemma_7b.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_gemma_7b.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=google/gemma-1.1-7b-it \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \\\n    trainer.total_epochs=4 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@\n"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen3_8b_sft_peft_sp2_npu.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=64 \\\n    model.partial_pretrain=Qwen/Qwen3-8B \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen3-8b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=2 $@ \\\n    model.lora_rank=32 \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.strategy=fsdp \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true \\\n    trainer.device=npu\n"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_qwen_05_peft.sh",
    "content": "# Tested with 2 & 4 GPUs\n\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_peft.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=1 $@ \\\n    model.lora_rank=32\\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear\n\n    # Or you can do this:\n    # model.target_modules=[q_proj,v_proj] \\\n"
  },
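The `model.lora_rank=32`, `model.lora_alpha=16`, and `model.target_modules=all-linear` overrides in the PEFT script above correspond to a standard LoRA setup. Assuming they map onto a Hugging Face `peft` configuration (a sketch, not verl's internal wiring):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# rank-32 adapters with alpha 16, attached to every linear layer
lora_cfg = LoraConfig(r=32, lora_alpha=16, target_modules="all-linear")
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the adapters are trainable
```

Restricting `target_modules` to `[q_proj,v_proj]`, as the commented-out line in the script suggests, trains far fewer parameters at some cost in quality.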
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_qwen_05_sp2.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \\\n    trainer.logger=console \\\n    trainer.total_training_steps=1 $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    model.use_liger=True \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \\\n    trainer.logger=console $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_distillation/examples/sft/gsm8k/run_seed_oss_36b_sft.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_seed_oss_36b_sft.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=ByteDance-Seed/Seed-OSS-36B-Base \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-seed-oss-36b \\\n    trainer.logger=console \\\n    trainer.total_training_steps=1 \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true $@\n"
  },
  {
    "path": "verl_distillation/examples/sft/multiturn/run_qwen_05_sp2.sh",
    "content": "#!/bin/bash\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/multiturn/train.parquet \\\n    data.val_files=$HOME/data/multiturn/test.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=multiturn-sft \\\n    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \\\n    trainer.logger=console \\\n    trainer.total_training_steps=1 $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/README.md",
    "content": "# Multi-Turn Rollout Example (GSM8K)\n\nThis example demonstrates how to perform **multi-turn rollout** using SGLang with a tool-calling capable model (e.g., Qwen2.5-3B) on the GSM8K dataset.\n\n## Usage\n\n### Step 1: Download GSM8K Dataset\n\n```bash\ncd examples/data_preprocess\npython3 gsm8k_multiturn_w_tool.py\n```\n\nThis will download and preprocess the GSM8K dataset into ~/data/gsm8k/.\n\n### Step 2: Run Multi-Turn Rollout\n\nIf you have 8 GPUs\nUse the standard 8-GPU script:\n\n```bash\ncd your_verl_root_dir\nbash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh\n```\n\nIf you have only 4 GPUs\nUse the fallback 4-GPU script:\n\n```bash\ncd your_verl_root_dir\nbash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh \n```\n\n## Notes\n\n- The rollout supports multi-turn conversations with tool-calling capabilities.\n- Current tools are used for GSM8K answer evaluation.\n- Future versions may extend to search and code interpreter tools.\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 2048\n  max_response_length: 2048\n  train_batch_size: 256\n  return_raw_chat: True\n  return_multi_modal_inputs: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    custom_chat_template: \"{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}{%- for tool in tools %}{{- \\\"\\\\n\\\" }}{{- tool | tojson }}{%- endfor %}{{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last 
or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n{% endif %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}\"\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      # tool_config_path: \"./config/tool_config/gsm8k_tool_config.yaml\"\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 2048\n  max_response_length: 2048\n  train_batch_size: 256\n  return_raw_chat: True\n  return_multi_modal_inputs: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    custom_chat_template: \"{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}{%- for tool in tools %}{{- \\\"\\\\n\\\" }}{{- tool | tojson }}{%- endfor %}{{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if 
loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n{% endif %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}\"\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      # tool_config_path: \"./config/tool_config/gsm8k_tool_config.yaml\"\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_server.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n    sglang_rollout_mode: server\n    server:\n      timeout: 60\n      max_attempts: 3\n      retry_delay: 2\n      max_connections: 1000\n      max_start_wait_time: 300.0"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_user_turns: 5\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n  \n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml",
    "content": "interaction:\n  - name: \"gsm8k\"\n    class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n    config: {}"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      tool_config_path: \"./config/tool_config/sandbox_fusion_tool_config.yaml\"\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/search_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n  shuffle: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 2\n      format: qwen\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/search_multiturn_grpo_one_step_off.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n  shuffle: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 2\n      format: qwen\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.geo3k_tool.Geo3kTool\"\n    config: \n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"calc_geo3k_reward\"\n        description: \"A tool for calculating the reward of geo3k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)\"\n        parameters:\n          type: \"object\"\n          properties:\n            answer:\n              type: \"string\"\n              description: \"The model's answer to the geo3k problem, must be a digits\"\n          required: [\"answer\"]"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.gsm8k_tool.Gsm8kTool\"\n    config: \n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"calc_gsm8k_reward\"\n        description: \"A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)\"\n        parameters:\n          type: \"object\"\n          properties:\n            answer:\n              type: \"string\"\n              description: \"The model's answer to the GSM8K math problem, must be a digits\"\n          required: [\"answer\"]\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/mcp_server.json",
    "content": "{\n    \"mcpServers\": {\n        \"Tavily Expert\": {\n            \"url\": \"your_tavily_expert_url\",\n            \"auth_token\": \"your_tavily_api_token\"\n        }\n    }\n}"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml",
    "content": "tools:\n  - class_name: verl.tools.mcp_search_tool.MCPSearchTool\n    config:\n      rate_limit: 120\n      timeout: 120\n      type: mcp\n    mcp:\n      mcp_servers_config_path: ./mcp_server.json\n      # optional\n      tool_selected_list: \n        - tavily_search_tool"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.sandbox_fusion_tools.SandboxFusionTool\"\n    config: \n      sandbox_fusion_url: \"https://xxx.apigateway-cn-beijing.volceapi.com/run_code\"\n      num_workers: 10\n      enable_global_rate_limit: true\n      rate_limit: 10\n      default_timeout: 30\n      default_language: \"python\"\n      memory_limit_mb: 1024\n      type: native\n\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml",
    "content": "tools:\n  - class_name: verl.tools.search_tool.SearchTool\n    config:\n      retrieval_service_url: http://127.0.0.1:8000/retrieve\n      num_workers: 120\n      rate_limit: 120\n      timeout: 30\n      type: native\n    tool_schema:\n      type: function\n      function:\n        name: search\n        description: Searches the web for relevant information based on the given query.\n        parameters:\n          type: object\n          properties:\n            query_list:\n              type: array\n              item:\n                type: string\n              description: A list of fully-formed semantic queries. The tool will return search results for each query.\n          required: \n            - query_list"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \\\n    data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh",
    "content": "# run on 4xH100\n# make sure your current working directory is the root of the project\n\nset -x\nexport HYDRA_FULL_ERROR=1\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-async-sgl-multi-w-tool-verify-n16-4cards' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n    critic.ppo_max_token_len_per_gpu=8192 \\\n    critic.forward_max_token_len_per_gpu=8192 \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    $@"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n# this is a verification training script, the parallel setting should be tuned to your model\n\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_megatron_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.megatron.seed=42 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \\\n    data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.sampler.class_name=\"RandomCurriculumSampler\" \\\n    data.sampler.class_path=\"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\" \\\n    data.dataloader_num_workers=0 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.train_batch_size=256 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nTRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-512}\nMICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-8}\nOFFLOAD=${OFFLOAD:-False}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo_w_interaction' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=$TRAIN_BATCH_SIZE \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=$((1024 * 3)) \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    +actor_rollout_ref.model.enable_activation_offloading=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \\\n    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=$OFFLOAD \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-0.5b_function_rm-gsm8k-sgl-multi-w-interaction-n8' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/train.parquet \\\n    data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\nfunction now() {\n    date '+%d-%H-%M'\n}\n\nEXPERIMENT_NAME=\"qwen2.5-3b_baseline_$(now)\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    global_profiler.tool=torch_memory \\\n    global_profiler.save_path=./mem_snapshots \\\n    global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries=100000 \\\n    global_profiler.global_tool_config.torch_memory.stack_depth=32 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.multi_stage_wake_up=True \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.over_sample_rate=0.1 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='multi-turn-grpo-qwen2.5-3b-sglang' \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.val_before_train=True \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh",
    "content": "# run on 4xH100\n# make sure your current working directory is the root of the project\n\nset -x\nexport HYDRA_FULL_ERROR=1\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n    critic.ppo_max_token_len_per_gpu=8192 \\\n    critic.forward_max_token_len_per_gpu=8192 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=1 \\\n    $@"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu_server.sh",
    "content": "# run on 4xH100\n# make sure your current working directory is the root of the project\n\nset -x\nexport HYDRA_FULL_ERROR=1\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo_server' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\", \"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl_server' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n    critic.ppo_max_token_len_per_gpu=8192 \\\n    critic.forward_max_token_len_per_gpu=8192 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=1 \\\n    $@"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_server.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\nfunction now() {\n    date '+%d-%H-%M'\n}\n\nEXPERIMENT_NAME=\"qwen2.5-3b_baseline_$(now)\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo_server' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.multi_stage_wake_up=True \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.over_sample_rate=0 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='multi-turn-grpo-qwen2.5-3b-sglang' \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.val_before_train=True \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh",
    "content": "# run on Ascend 910\n# make sure your current working directory is the root of the project\n\nset -x\nulimit -n 65535\n\n#set vllm v1 env\nexport VLLM_USE_V1=1\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\nTRAIN_BATCH_SIZE=32\nMICRO_BATCH_SIZE=8\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=${TRAIN_BATCH_SIZE} \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"Qwen/Qwen2.5-3B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${TRAIN_BATCH_SIZE} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9\\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.device=npu \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.logger='[\"console\"]' \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \\\n    actor_rollout_ref.rollout.trace.token2text=False \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.multi_turn.enable=true \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    actor_rollout_ref.rollout.free_cache_engine=True"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.trace.backend=mlflow \\\n    actor_rollout_ref.rollout.trace.token2text=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"mlflow\"]' \\\n    trainer.project_name='gsm8k_tool-agent' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_training_steps=2 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n# this is a verification training script, the parallel setting should be tuned to your model\n\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_megatron_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=/user/longxiang1/models/Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.megatron.seed=42 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/train.parquet \\\n    data.val_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-4B \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.rollout.over_sample_rate=0.1 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh",
    "content": "set -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npip install --upgrade \"huggingface-hub>=0.34.0\"\nhf download \\\n    BytedTsinghua-SIA/DAPO-Math-17k \\\n    --repo-type dataset \\\n    --local-dir $HOME/data/BytedTsinghua-SIA/DAPO-Math-17k\n\n\nhf download \\\n    Maxwell-Jia/AIME_2024 \\\n    --repo-type dataset \\\n    --local-dir $HOME/data/Maxwell-Jia/AIME_2024\n\n\n# Note:\n# 1. \n# a sandbox fusion server is needed to run the code interpreter tool.\n# docker run -it -p 8080:8080 volcengine/sandbox-fusion:server-20250609\n\n# 2. \n# The model located at font-info/qwen3-4b-sft-SGLang-RL (https://huggingface.co/font-info/qwen3-4b-sft-SGLang-RL)\n# is a fine-tuned version provided by the SGLang RL team. Without supervised fine-tuning (SFT)\n# on the Retool dataset, Dapo training will not converge.\n\n# If you still wish to perform SFT from scratch, follow the steps below:\n\n# Step 1: Download the SFT dataset\n#huggingface-cli download JoeYing/ReTool-SFT --repo-type dataset --local-dir ./ReTool-SFT\n\n# Step 2: Preprocess the data for SFT\n#python3 recipe/retool/retool_sft_preprocess.py\n\n# Step 3: Run SFT training\n#bash recipe/retool/run_qwen2-32b_sft.sh\n\n# having trouble setup? see https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/release_log/latest_sglang.md for more details.\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    algorithm.use_kl_in_reward=False \\\n    algorithm.kl_ctrl.kl_coef=0.0 \\\n    data.train_files=$HOME/data/BytedTsinghua-SIA/DAPO-Math-17k \\\n    data.val_files=$HOME/data/Maxwell-Jia/AIME_2024 \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=32 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=16384 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.custom_cls.path=$PROJECT_DIR/recipe/retool/retool.py \\\n    data.custom_cls.name=CustomRLHFDataset \\\n    custom_reward_function.path=$PROJECT_DIR/recipe/retool/retool.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.model.path=font-info/qwen3-4b-sft-SGLang-RL \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.clip_ratio_low=0.2 \\\n    actor_rollout_ref.actor.clip_ratio_high=0.28 \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.multi_stage_wake_up=True \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=16 \\\n    
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=16 \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=$PROJECT_DIR/recipe/retool/sandbox_fusion_tool_config.yaml \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=30 \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=sglang-dapo-multiturn \\\n    trainer.experiment_name=qwen3_4b_sft_dapo_multiturn \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.log_val_generations=20 \\\n    trainer.val_before_train=True \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    $@\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Search-R1 Contributors\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py\n\n\nimport argparse\n\nfrom huggingface_hub import hf_hub_download\n\nparser = argparse.ArgumentParser(description=\"Download files from a Hugging Face dataset repository.\")\nparser.add_argument(\"--repo_id\", type=str, default=\"PeterJinGo/wiki-18-e5-index\", help=\"Hugging Face repository ID\")\nparser.add_argument(\"--save_path\", type=str, required=True, help=\"Local directory to save files\")\n\nargs = parser.parse_args()\n\nrepo_id = \"PeterJinGo/wiki-18-e5-index\"\nfor file in [\"part_aa\", \"part_ab\"]:\n    hf_hub_download(\n        repo_id=repo_id,\n        filename=file,  # e.g., \"e5_Flat.index\"\n        repo_type=\"dataset\",\n        local_dir=args.save_path,\n    )\n\nrepo_id = \"PeterJinGo/wiki-18-corpus\"\nhf_hub_download(\n    repo_id=repo_id,\n    filename=\"wiki-18.jsonl.gz\",\n    repo_type=\"dataset\",\n    local_dir=args.save_path,\n)\n"
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Search-R1 Contributors\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/retrieval_server.py\n\nimport argparse\nimport json\nimport warnings\nfrom typing import Optional\n\nimport datasets\nimport faiss\nimport numpy as np\nimport torch\nimport uvicorn\nfrom fastapi import FastAPI\nfrom pydantic import BaseModel\nfrom tqdm import tqdm\nfrom transformers import AutoModel, AutoTokenizer\n\n\ndef load_corpus(corpus_path: str):\n    corpus = datasets.load_dataset(\"json\", data_files=corpus_path, split=\"train\", num_proc=4)\n    return corpus\n\n\ndef load_docs(corpus, doc_idxs):\n    results = [corpus[int(idx)] for idx in doc_idxs]\n    return results\n\n\ndef load_model(model_path: str, use_fp16: bool = False):\n    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)\n    model.eval()\n    model.cuda()\n    if use_fp16:\n        model = model.half()\n    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)\n    return model, tokenizer\n\n\ndef pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method=\"mean\"):\n    if pooling_method == \"mean\":\n        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)\n        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]\n    elif pooling_method == \"cls\":\n        return last_hidden_state[:, 0]\n    elif pooling_method == \"pooler\":\n        return pooler_output\n    else:\n        raise NotImplementedError(\"Pooling method not implemented!\")\n\n\nclass Encoder:\n    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):\n        self.model_name = model_name\n        self.model_path = model_path\n        self.pooling_method = pooling_method\n        self.max_length = max_length\n        self.use_fp16 = use_fp16\n\n        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)\n        self.model.eval()\n\n    @torch.no_grad()\n    def encode(self, query_list: list[str], is_query=True) -> np.ndarray:\n        # processing query for different encoders\n        if isinstance(query_list, str):\n            query_list = [query_list]\n\n        if \"e5\" in self.model_name.lower():\n            if is_query:\n                query_list = [f\"query: {query}\" for query in query_list]\n            else:\n                query_list = [f\"passage: {query}\" for query in query_list]\n\n        if \"bge\" in self.model_name.lower():\n            if is_query:\n                query_list = [\n                    f\"Represent this sentence for searching relevant passages: {query}\" for query in query_list\n                ]\n\n        inputs = self.tokenizer(\n            query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors=\"pt\"\n      
  )\n        inputs = {k: v.cuda() for k, v in inputs.items()}\n\n        if \"T5\" in type(self.model).__name__:\n            # T5-based retrieval model\n            decoder_input_ids = torch.zeros((inputs[\"input_ids\"].shape[0], 1), dtype=torch.long).to(\n                inputs[\"input_ids\"].device\n            )\n            output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True)\n            query_emb = output.last_hidden_state[:, 0, :]\n        else:\n            output = self.model(**inputs, return_dict=True)\n            query_emb = pooling(\n                output.pooler_output, output.last_hidden_state, inputs[\"attention_mask\"], self.pooling_method\n            )\n            if \"dpr\" not in self.model_name.lower():\n                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)\n\n        query_emb = query_emb.detach().cpu().numpy()\n        query_emb = query_emb.astype(np.float32, order=\"C\")\n\n        del inputs, output\n        torch.cuda.empty_cache()\n\n        return query_emb\n\n\nclass BaseRetriever:\n    def __init__(self, config):\n        self.config = config\n        self.retrieval_method = config.retrieval_method\n        self.topk = config.retrieval_topk\n\n        self.index_path = config.index_path\n        self.corpus_path = config.corpus_path\n\n    def _search(self, query: str, num: int, return_score: bool):\n        raise NotImplementedError\n\n    def _batch_search(self, query_list: list[str], num: int, return_score: bool):\n        raise NotImplementedError\n\n    def search(self, query: str, num: int = None, return_score: bool = False):\n        return self._search(query, num, return_score)\n\n    def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        return self._batch_search(query_list, num, return_score)\n\n\nclass BM25Retriever(BaseRetriever):\n    def __init__(self, config):\n        super().__init__(config)\n        from pyserini.search.lucene import LuceneSearcher\n\n        self.searcher = LuceneSearcher(self.index_path)\n        self.contain_doc = self._check_contain_doc()\n        if not self.contain_doc:\n            self.corpus = load_corpus(self.corpus_path)\n        self.max_process_num = 8\n\n    def _check_contain_doc(self):\n        return self.searcher.doc(0).raw() is not None\n\n    def _search(self, query: str, num: int = None, return_score: bool = False):\n        if num is None:\n            num = self.topk\n        hits = self.searcher.search(query, num)\n        if len(hits) < 1:\n            if return_score:\n                return [], []\n            else:\n                return []\n        scores = [hit.score for hit in hits]\n        if len(hits) < num:\n            warnings.warn(\"Not enough documents retrieved!\", stacklevel=2)\n        else:\n            hits = hits[:num]\n\n        if self.contain_doc:\n            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())[\"contents\"] for hit in hits]\n            results = [\n                {\n                    \"title\": content.split(\"\\n\")[0].strip('\"'),\n                    \"text\": \"\\n\".join(content.split(\"\\n\")[1:]),\n                    \"contents\": content,\n                }\n                for content in all_contents\n            ]\n        else:\n            results = load_docs(self.corpus, [hit.docid for hit in hits])\n\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n    def 
_batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        results = []\n        scores = []\n        for query in query_list:\n            item_result, item_score = self._search(query, num, True)\n            results.append(item_result)\n            scores.append(item_score)\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n\nclass DenseRetriever(BaseRetriever):\n    def __init__(self, config):\n        super().__init__(config)\n        self.index = faiss.read_index(self.index_path)\n        if config.faiss_gpu:\n            co = faiss.GpuMultipleClonerOptions()\n            co.useFloat16 = True\n            co.shard = True\n            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)\n\n        self.corpus = load_corpus(self.corpus_path)\n        self.encoder = Encoder(\n            model_name=self.retrieval_method,\n            model_path=config.retrieval_model_path,\n            pooling_method=config.retrieval_pooling_method,\n            max_length=config.retrieval_query_max_length,\n            use_fp16=config.retrieval_use_fp16,\n        )\n        self.topk = config.retrieval_topk\n        self.batch_size = config.retrieval_batch_size\n\n    def _search(self, query: str, num: int = None, return_score: bool = False):\n        if num is None:\n            num = self.topk\n        query_emb = self.encoder.encode(query)\n        scores, idxs = self.index.search(query_emb, k=num)\n        idxs = idxs[0]\n        scores = scores[0]\n        results = load_docs(self.corpus, idxs)\n        if return_score:\n            return results, scores.tolist()\n        else:\n            return results\n\n    def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        if isinstance(query_list, str):\n            query_list = [query_list]\n        if num is None:\n            num = self.topk\n\n        results = []\n        scores = []\n        for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc=\"Retrieval process: \"):\n            query_batch = query_list[start_idx : start_idx + self.batch_size]\n            batch_emb = self.encoder.encode(query_batch)\n            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)\n            batch_scores = batch_scores.tolist()\n            batch_idxs = batch_idxs.tolist()\n\n            # load_docs is not vectorized, but is a python list approach\n            flat_idxs = sum(batch_idxs, [])\n            batch_results = load_docs(self.corpus, flat_idxs)\n            # chunk them back\n            batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))]\n\n            results.extend(batch_results)\n            scores.extend(batch_scores)\n\n            del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results\n            torch.cuda.empty_cache()\n\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n\ndef get_retriever(config):\n    if config.retrieval_method == \"bm25\":\n        return BM25Retriever(config)\n    else:\n        return DenseRetriever(config)\n\n\n#####################################\n# FastAPI server below\n#####################################\n\n\nclass Config:\n    \"\"\"\n    Minimal config class (simulating your argparse)\n    Replace this with your real arguments or load them dynamically.\n    \"\"\"\n\n    def __init__(\n        self,\n        
retrieval_method: str = \"bm25\",\n        retrieval_topk: int = 10,\n        index_path: str = \"./index/bm25\",\n        corpus_path: str = \"./data/corpus.jsonl\",\n        dataset_path: str = \"./data\",\n        data_split: str = \"train\",\n        faiss_gpu: bool = True,\n        retrieval_model_path: str = \"./model\",\n        retrieval_pooling_method: str = \"mean\",\n        retrieval_query_max_length: int = 256,\n        retrieval_use_fp16: bool = False,\n        retrieval_batch_size: int = 128,\n    ):\n        self.retrieval_method = retrieval_method\n        self.retrieval_topk = retrieval_topk\n        self.index_path = index_path\n        self.corpus_path = corpus_path\n        self.dataset_path = dataset_path\n        self.data_split = data_split\n        self.faiss_gpu = faiss_gpu\n        self.retrieval_model_path = retrieval_model_path\n        self.retrieval_pooling_method = retrieval_pooling_method\n        self.retrieval_query_max_length = retrieval_query_max_length\n        self.retrieval_use_fp16 = retrieval_use_fp16\n        self.retrieval_batch_size = retrieval_batch_size\n\n\nclass QueryRequest(BaseModel):\n    queries: list[str]\n    topk: Optional[int] = None\n    return_scores: bool = False\n\n\napp = FastAPI()\n\n\n@app.post(\"/retrieve\")\ndef retrieve_endpoint(request: QueryRequest):\n    \"\"\"\n    Endpoint that accepts queries and performs retrieval.\n\n    Input format:\n    {\n      \"queries\": [\"What is Python?\", \"Tell me about neural networks.\"],\n      \"topk\": 3,\n      \"return_scores\": true\n    }\n\n    Output format (when return_scores=True, similarity scores are returned):\n    {\n        \"result\": [\n            [   # Results for each query\n                {\"document\": doc, \"score\": score},\n                # ... more documents\n            ],\n            # ... 
results for other queries\n        ]\n    }\n    \"\"\"\n    if not request.topk:\n        request.topk = config.retrieval_topk  # fallback to default\n\n    # Always request scores so batch_search consistently returns a (results, scores)\n    # tuple; scores are only included in the response when the client asks for them.\n    results, scores = retriever.batch_search(\n        query_list=request.queries, num=request.topk, return_score=True\n    )\n\n    # Format response\n    resp = []\n    for i, single_result in enumerate(results):\n        if request.return_scores:\n            # If scores are returned, combine them with results\n            combined = []\n            for doc, score in zip(single_result, scores[i], strict=True):\n                combined.append({\"document\": doc, \"score\": score})\n            resp.append(combined)\n        else:\n            resp.append(single_result)\n    return {\"result\": resp}\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Launch the local faiss retriever.\")\n    parser.add_argument(\n        \"--index_path\", type=str, default=\"/home/peterjin/mnt/index/wiki-18/e5_Flat.index\", help=\"Corpus indexing file.\"\n    )\n    parser.add_argument(\n        \"--corpus_path\",\n        type=str,\n        default=\"/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl\",\n        help=\"Local corpus file.\",\n    )\n    parser.add_argument(\"--topk\", type=int, default=3, help=\"Number of retrieved passages for one query.\")\n    parser.add_argument(\"--retriever_name\", type=str, default=\"e5\", help=\"Name of the retriever model.\")\n    parser.add_argument(\n        \"--retriever_model\", type=str, default=\"intfloat/e5-base-v2\", help=\"Path of the retriever model.\"\n    )\n    parser.add_argument(\"--faiss_gpu\", action=\"store_true\", help=\"Use GPU for computation\")\n\n    args = parser.parse_args()\n\n    # 1) Build a config (could also parse from arguments).\n    #    In real usage, you'd parse your CLI arguments or environment variables.\n    config = Config(\n        retrieval_method=args.retriever_name,  # or \"dense\"\n        index_path=args.index_path,\n        corpus_path=args.corpus_path,\n        retrieval_topk=args.topk,\n        faiss_gpu=args.faiss_gpu,\n        retrieval_model_path=args.retriever_model,\n        retrieval_pooling_method=\"mean\",\n        retrieval_query_max_length=256,\n        retrieval_use_fp16=True,\n        retrieval_batch_size=512,\n    )\n\n    # 2) Instantiate a global retriever so it is loaded once and reused.\n    retriever = get_retriever(config)\n\n    # 3) Launch the server. It listens on port 8000 on all interfaces (0.0.0.0).\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n
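\n# Example request once the server is running (a usage sketch; adjust host/port if you change them):\n#   curl -X POST http://127.0.0.1:8000/retrieve \\\n#        -H 'Content-Type: application/json' \\\n#        -d '{\"queries\": [\"What is Python?\"], \"topk\": 3, \"return_scores\": true}'\n"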
  },
  {
    "path": "verl_distillation/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh",
    "content": "# run on 8xH20\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\n\nTRAIN_DATA=\"$HOME/data/searchR1_processed_direct/train.parquet\"\nVAL_DATA=\"$HOME/data/searchR1_processed_direct/test.parquet\"\n\nTOOL_CONFIG=\"$CONFIG_PATH/tool_config/search_tool_config.yaml\"\n\n\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='search_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=512 \\\n    data.val_batch_size=256 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=3000 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.max_model_len=15000 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='search_r1_like_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=100 \\\n    trainer.test_freq=50 \\\n    data.train_files=\"$TRAIN_DATA\" \\\n    data.val_files=\"$VAL_DATA\"  \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$TOOL_CONFIG\" \\\n    trainer.total_epochs=1 $@\n\n"
  },
  {
    "path": "verl_distillation/examples/skypilot/README.md",
    "content": "# verl with SkyPilot\n\nRun verl reinforcement learning training jobs on Kubernetes clusters or cloud platforms with GPU nodes using [SkyPilot](https://github.com/skypilot-org/skypilot).\n\n## Installation and Configuration\n\n### Step 1: Install SkyPilot\n\nChoose the installation based on your target platform:\n\n```bash\n# For Kubernetes only\npip install \"skypilot[kubernetes]\"\n\n# For AWS\npip install \"skypilot[aws]\"\n\n# For Google Cloud Platform\npip install \"skypilot[gcp]\"\n\n# For Azure\npip install \"skypilot[azure]\"\n\n# For multiple platforms\npip install \"skypilot[kubernetes,aws,gcp,azure]\"\n```\n\n### Step 2: Configure Your Platform\n\nSee https://docs.skypilot.co/en/latest/getting-started/installation.html\n\n### Step 3: Set Up Environment Variables\n\nExport necessary API keys for experiment tracking:\n\n```bash\n# For Weights & Biases tracking\nexport WANDB_API_KEY=\"your-wandb-api-key\"\n\n# For HuggingFace gated models (if needed)\nexport HF_TOKEN=\"your-huggingface-token\"\n```\n\n## Examples\n\n### PPO Training\n```bash\nsky launch -c verl-ppo verl-ppo.yaml --secret WANDB_API_KEY -y\n```\nRuns PPO training on GSM8K dataset using Qwen2.5-0.5B-Instruct model across 2 nodes with H100 GPUs. Based on examples in [`../ppo_trainer/`](../ppo_trainer/).\n\n### GRPO Training  \n```bash\nsky launch -c verl-grpo verl-grpo.yaml --secret WANDB_API_KEY -y\n```\nRuns GRPO (Group Relative Policy Optimization) training on MATH dataset using Qwen2.5-7B-Instruct model. Memory-optimized configuration for 2 nodes. Based on examples in [`../grpo_trainer/`](../grpo_trainer/).\n\n### Multi-turn Tool Usage Training\n```bash\nsky launch -c verl-multiturn verl-multiturn-tools.yaml --secret WANDB_API_KEY --secret HF_TOKEN -y\n```\nSingle-node training with 8xH100 GPUs for multi-turn tool usage with Qwen2.5-3B-Instruct. Includes tool and interaction configurations for GSM8K. Based on examples in [`../sglang_multiturn/`](../sglang_multiturn/) but uses vLLM instead of sglang.\n\n## Configuration\n\nThe example YAML files are pre-configured with:\n\n- **Infrastructure**: Kubernetes clusters (`infra: k8s`) - can be changed to `infra: aws` or `infra: gcp`, etc.\n- **Docker Image**: verl's official Docker image with CUDA 12.6 support\n- **Setup**: Automatically clones and installs verl from source\n- **Datasets**: Downloads required datasets during setup phase\n- **Ray Cluster**: Configures distributed training across nodes\n- **Logging**: Supports Weights & Biases via `--secret WANDB_API_KEY`\n- **Models**: Supports gated HuggingFace models via `--secret HF_TOKEN`\n\n## Launch Command Options\n\n- `-c <name>`: Cluster name for managing the job\n- `--secret KEY`: Pass secrets for API keys (can be used multiple times)\n- `-y`: Skip confirmation prompt\n\n## Monitoring Your Jobs\n\n### Check cluster status\n```bash\nsky status\n```\n\n### View logs\n```bash\nsky logs verl-ppo  # View logs for the PPO job\n```\n\n### SSH into head node\n```bash\nssh verl-ppo\n```\n\n### Access Ray dashboard\n```bash\nsky status --endpoint 8265 verl-ppo  # Get dashboard URL\n```\n\n### Stop a cluster\n```bash\nsky down verl-ppo\n```\n"
  },
  {
    "path": "verl_distillation/examples/skypilot/verl-grpo.yaml",
    "content": "resources:\n  infra: k8s\n  accelerators: H100:1 \n  memory: 128+\n  image_id: docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4\n  ports: 8265\n\nnum_nodes: 2\n\nsecrets:\n  WANDB_API_KEY: \n\nsetup: |  \n  rm -rf verl\n  git clone https://github.com/volcengine/verl.git\n  cd verl\n  pip3 install -v -e .[vllm]\n  pip3 install flashinfer-python\n  echo \"Downloading Math dataset...\"\n  mkdir -p ~/data/math\n  python3 \"$(pwd)/examples/data_preprocess/math_dataset.py\" --local_dir ~/data/math\n  echo \"Math dataset download completed\"\n\nrun: |\n  HEAD_IP=$(echo \"$SKYPILOT_NODE_IPS\" | head -n1)\n  NUM_NODES=$SKYPILOT_NUM_NODES\n  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE\n  \n  if [ \"$SKYPILOT_NODE_RANK\" == \"0\" ]; then\n    echo \"Starting Ray head node...\"\n    ps aux | grep ray | grep 6379 &> /dev/null ||  ray start --head --disable-usage-stats \\\n          --port=6379 \\\n          --dashboard-host=0.0.0.0 \\\n          --dashboard-port=8265\n\n    # Wait for all worker nodes to join\n    retry_count=0\n    max_retries=30\n    while [ $retry_count -lt $max_retries ]; do\n      connected_nodes=$(ray status 2>/dev/null | grep -c \"node_\" || echo \"0\")\n      echo \"Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)\"\n      \n      if [ \"$connected_nodes\" -ge \"$NUM_NODES\" ]; then\n        echo \"All nodes connected to Ray cluster\"\n        break\n      fi\n      \n      retry_count=$((retry_count+1))\n      sleep 10\n    done\n\n    python3 -m verl.trainer.main_ppo \\\n     algorithm.adv_estimator=grpo \\\n     data.train_files=$HOME/data/math/train.parquet \\\n     data.val_files=$HOME/data/math/test.parquet \\\n     data.train_batch_size=32 \\\n     data.max_prompt_length=256 \\\n     data.max_response_length=256 \\\n     data.filter_overlong_prompts=True \\\n     data.truncation='error' \\\n     actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.model.use_remove_padding=True \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n     actor_rollout_ref.actor.ppo_epochs=1 \\\n     actor_rollout_ref.actor.use_kl_loss=False \\\n     actor_rollout_ref.actor.entropy_coeff=0 \\\n     actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n     actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n     actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n     actor_rollout_ref.rollout.name=vllm \\\n     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n     actor_rollout_ref.rollout.n=1 \\\n     actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n     actor_rollout_ref.rollout.max_num_batched_tokens=2048 \\\n     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n     actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n     algorithm.use_kl_in_reward=False \\\n     trainer.critic_warmup=0 \\\n     trainer.logger=[console,wandb] \\\n     trainer.project_name=verl_math_grpo_demo \\\n     trainer.experiment_name=qwen25_7b_grpo \\\n     trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \\\n     trainer.nnodes=$NUM_NODES \\\n     trainer.save_freq=-1 \\\n     trainer.test_freq=-1 \\\n     trainer.total_epochs=1\n\n  else\n    sleep 15\n    echo \"Starting Ray worker node...\"\n    ps aux | 
grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats\n    sleep 10\n  fi\n\n  echo \"Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK.\""
  },
  {
    "path": "verl_distillation/examples/skypilot/verl-multiturn-tools.yaml",
    "content": "resources:\n  infra: k8s\n  accelerators: H100:8\n  memory: 128+\n  image_id: docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4\n  ports: 8265\n\nnum_nodes: 1\n\nsecrets:\n  WANDB_API_KEY: \n  HF_TOKEN: # in case you're using gated models from the HF hub\n\nsetup: |\n  rm -rf verl\n  git clone https://github.com/volcengine/verl.git\n  cd verl\n  pip3 install -v -e .[vllm]\n  pip3 install flashinfer-python\n  pip install \"transformers<4.54.0\" # https://github.com/vllm-project/vllm-ascend/issues/2046\n  # Download GSM8K dataset for multiturn tool training\n  echo \"Downloading GSM8K dataset...\"\n  mkdir -p ~/data/gsm8k\n  python3 \"$(pwd)/examples/data_preprocess/gsm8k.py\" --local_dir ~/data/gsm8k\n  echo \"GSM8K dataset download completed\"\n\nrun: |\n  NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE\n  PROJECT_DIR=\"$(pwd)/verl\"\n  CONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n  \n  # Single node setup - no worker coordination needed\n  echo \"Starting Ray head node...\"\n  ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \\\n        --port=6379 \\\n        --dashboard-host=0.0.0.0 \\\n        --dashboard-port=8265\n\n  cd verl\n\n  python3 -m verl.trainer.main_ppo \\\n     --config-path=\"$CONFIG_PATH\" \\\n     --config-name='gsm8k_multiturn_grpo' \\\n     algorithm.adv_estimator=grpo \\\n     data.train_batch_size=512 \\\n     data.max_prompt_length=1024 \\\n     data.max_response_length=1024 \\\n     data.filter_overlong_prompts=True \\\n     data.truncation='error' \\\n     data.return_raw_chat=True \\\n     data.train_files=$HOME/data/gsm8k/train.parquet \\\n     data.val_files=$HOME/data/gsm8k/test.parquet \\\n     actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.model.use_remove_padding=True \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n     actor_rollout_ref.actor.use_kl_loss=True \\\n     actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n     actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n     actor_rollout_ref.actor.entropy_coeff=0 \\\n     actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n     actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n     actor_rollout_ref.rollout.name=vllm \\\n     actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n     actor_rollout_ref.rollout.n=16 \\\n     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=64 \\\n     actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n     algorithm.use_kl_in_reward=False \\\n     trainer.critic_warmup=0 \\\n     trainer.logger=[console,wandb] \\\n     trainer.project_name=verl_multiturn_tools \\\n     trainer.experiment_name=qwen25_7b_gsm8k_multiturn_tools \\\n     trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \\\n     trainer.nnodes=1 \\\n     trainer.save_freq=10 \\\n     trainer.test_freq=5 \\\n     trainer.total_epochs=10 \\\n     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n     actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n     actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n     critic.ppo_max_token_len_per_gpu=8192 \\\n     critic.forward_max_token_len_per_gpu=8192 \\\n     
actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n     actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n     actor_rollout_ref.rollout.multi_turn.max_user_turns=1\n\n  echo \"Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK.\""
  },
  {
    "path": "verl_distillation/examples/skypilot/verl-ppo.yaml",
    "content": "resources:\n  infra: k8s\n  accelerators: H100:1 \n  memory: 128+\n  image_id: docker:verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4\n  ports: 8265\n\nnum_nodes: 2\n\nsecrets:\n  WANDB_API_KEY: \n\nsetup: |  \n  rm -rf verl\n  git clone https://github.com/volcengine/verl.git\n  cd verl\n  pip3 install -v -e .[vllm]\n  pip3 install flashinfer-python\n  # Download GSM8K dataset - alternative approach\n  echo \"Downloading GSM8K dataset...\"\n  mkdir -p ~/data/gsm8k\n  # Check if the script exists and use absolute path\n  if [ -f \"$(pwd)/examples/data_preprocess/gsm8k.py\" ]; then\n    python3 \"$(pwd)/examples/data_preprocess/gsm8k.py\" --local_dir ~/data/gsm8k\n  else\n    echo \"Warning: gsm8k.py script not found, skipping dataset download\"\n    # You might want to download the dataset manually or use a different approach\n  fi\n  echo \"GSM8K dataset download completed\"\n\nrun: |\n  # Get the Head node's IP and total number of nodes\n  HEAD_IP=$(echo \"$SKYPILOT_NODE_IPS\" | head -n1)\n  NUM_NODES=$SKYPILOT_NUM_NODES\n  \n  # login wandb\n  # python3 -c \"import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')\"\n\n  if [ \"$SKYPILOT_NODE_RANK\" == \"0\" ]; then\n    # Head node starts Ray Head\n    echo \"Starting Ray head node...\"\n    ps aux | grep ray | grep 6379 &> /dev/null ||  ray start --head --disable-usage-stats \\\n          --port=6379 \\\n          --dashboard-host=0.0.0.0 \\\n          --dashboard-port=8265\n\n    # Wait for all worker nodes to join the cluster with better checking\n    echo \"Waiting for all nodes to join Ray cluster...\"\n    retry_count=0\n    max_retries=30\n    while [ $retry_count -lt $max_retries ]; do\n      connected_nodes=$(ray status 2>/dev/null | grep -c \"node_\" || echo \"0\")\n      echo \"Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)\"\n      \n      if [ \"$connected_nodes\" -ge \"$NUM_NODES\" ]; then\n        echo \"All nodes connected to Ray cluster\"\n        break\n      fi\n      \n      retry_count=$((retry_count+1))\n      sleep 10\n    done\n\n    if [ $retry_count -eq $max_retries ]; then\n      echo \"WARNING: Not all nodes connected to Ray cluster after $max_retries attempts\"\n      echo \"Current Ray status:\"\n      ray status\n    fi\n\n    python3 -m verl.trainer.main_ppo \\\n     data.train_files=$HOME/data/gsm8k/train.parquet \\\n     data.val_files=$HOME/data/gsm8k/test.parquet \\\n     data.train_batch_size=256 \\\n     data.max_prompt_length=512 \\\n     data.max_response_length=256 \\\n     actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n     actor_rollout_ref.rollout.name=vllm \\\n     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n     critic.optim.lr=1e-5 \\\n     critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n     critic.ppo_micro_batch_size_per_gpu=4 \\\n     algorithm.kl_ctrl.kl_coef=0.001 \\\n     trainer.logger=[console,wandb] \\\n     trainer.val_before_train=False \\\n     trainer.default_hdfs_dir=null \\\n     trainer.n_gpus_per_node=1 \\\n     trainer.nnodes=2 \\\n     trainer.save_freq=20 \\\n     trainer.test_freq=20 \\\n     
trainer.total_epochs=2 \\\n     trainer.project_name=verl_examples \\\n     trainer.experiment_name=experiment_name_gsm8k\n\n  else\n    # Wait for Ray Head to start\n    sleep 15\n    # Worker node starts Ray Worker\n    echo \"Starting Ray worker node...\"\n    ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats\n    sleep 10\n  fi\n\n  echo \"Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK.\""
  },
  {
    "path": "verl_distillation/examples/slurm/ray_on_slurm.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=verl-ray-on-slurm\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=1\n#SBATCH --mem=200G\n#SBATCH --partition=your-partition\n#SBATCH --time=01:00:00\n#SBATCH --account=your-account\n#SBATCH --gpus-per-node=4\n#SBATCH --cpus-per-task=64\n#SBATCH --output=slurm-%j.out\n#SBATCH --error=slurm-%j.err\n\n# load necessary modules\n\n# replace these information with your own\nverl_workdir=/path/to/verl\ntrain_files=/path/to/gsm8k/train.parquet\nval_files=/path/to/gsm8k/test.parquet\napptainer_image_path=/path/to/verl-ngc.sif\n# replace these information with your own\n\n# Getting the node names\nnodes=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\")\nnodes_array=(\"$nodes\")\n\nhead_node=${nodes_array[0]}\nhead_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n# if we detect a space character in the head node IP, we'll\n# convert it to an ipv4 address. This step is optional.\nif [[ \"$head_node_ip\" == *\" \"* ]]; then\nIFS=' ' read -ra ADDR <<<\"$head_node_ip\"\nif [[ ${#ADDR[0]} -gt 16 ]]; then\n  head_node_ip=${ADDR[1]}\nelse\n  head_node_ip=${ADDR[0]}\nfi\necho \"IPV6 address detected. We split the IPV4 address as $head_node_ip\"\nfi\n\nport=6379\nip_head=$head_node_ip:$port\nexport ip_head\necho \"IP Head: $ip_head\"\n\n# make sure we set environment variables before Ray initialization\n\nprintenv\n\necho \"Starting HEAD at $head_node\"\nsrun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n    apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n        ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n        --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n# optional, though may be useful in certain versions of Ray < 1.0.\nsleep 10\n\n# number of nodes other than the head node\nworker_num=$((SLURM_JOB_NUM_NODES - 1))\n\nfor ((i = 1; i <= worker_num; i++)); do\n    node_i=${nodes_array[$i]}\n    echo \"Starting WORKER $i at $node_i\"\n    srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n        apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n            ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    sleep 5\ndone\n\nPYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w \"$head_node\" \\\n    apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n    python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$train_files \\\n    data.val_files=$val_files \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=256 \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=\"${SLURM_GPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${SLURM_NNODES}\" \\\n    trainer.save_freq=10 \\\n    trainer.test_freq=10 
\\\n    trainer.total_epochs=15 2>&1 | tee verl_demo_slurm.log\n"
  },
  {
    "path": "verl_distillation/examples/split_placement/README.md",
    "content": "# Split Placement Example\nHere we introduce how to run the naive implementation of the split placement of PPO algorithm.\nWe will release the complete version of flexible placement in the near future.\n\n For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example.\n\n### Step 1: Placing the models to different GPUs\nSpecify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs.\n```python\nactor_rollout_ref_pool_id = 'actor_rollout_ref_pool'\ncritic_pool_id = 'critic_pool'\nif config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:\n    resource_pool_spec = {\n        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n        critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n    }\nelse:\n    resource_pool_spec = {\n        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n        critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n    }\nprint(f'resource_pool_spec: {resource_pool_spec}')\nmapping = {\n    Role.ActorRollout: actor_rollout_ref_pool_id,\n    Role.Critic: critic_pool_id,\n    Role.RefPolicy: actor_rollout_ref_pool_id,\n}\nmapping[Role.RewardModel] = critic_pool_id\n```\n\n### Step 2: Make the models executed asynchronously\nBased on the model placement, we need to make the models executed asynchronously.\n\nTo do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations.\nFor example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py`\n\n```\n@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\ndef update_actor(self, data: DataProto):\n    ...\n\n@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\ndef update_critic(self, data: DataProto):\n    ...\n```\n\nWe can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we don't do this in this example.\n\n### Step 3: Execute these operation in parallel in the single controller process\nTo implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent  `futures` on the single controller process.\n\n```python\ncritic_output = critic_output.get()\nactor_output = actor_output.get()\n```\n\n### Step 4: Run the split placement example\n\n```\nbash run_deepseek7b_llm.sh\n```\n"
  },
  {
    "path": "verl_distillation/examples/split_placement/config/ppo_trainer_split.yaml",
    "content": "# the ppo trainer split config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://../../verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  train_max_samples: -1  # set to -1 to use full dataset\n  val_max_samples: -1  # set to -1 to use full dataset\n  prompt_key: prompt\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves\n  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n  return_raw_chat: False\n  return_full_prompt: False\n  shuffle: True\n  seed: 42\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    external_lib: null\n    override_config: { }\n    enable_gradient_checkpointing: True\n    use_remove_padding: False\n  actor:\n    strategy: fsdp  # This is for backward-compatibility\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: False\n    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n    grad_clip: 1.0\n    clip_ratio: 0.2\n    entropy_coeff: 0.0\n    use_kl_loss: False # True for GRPO\n    kl_loss_coef: 0.001 # for grpo\n    kl_loss_type: low_var_kl # for grpo\n    ppo_epochs: 1\n    shuffle: False\n    ulysses_sequence_parallel_size: 1 # sp size\n    optim:\n      lr: 1e-6\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n      min_lr_ratio: null   # only useful for warmup with cosine\n      lr_scheduler_type: constant  # select from constant/cosine\n      total_training_steps: -1  # must be overridden by the program\n    fsdp_config:\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n      param_offload: False\n      optimizer_offload: False\n      fsdp_size: -1\n  ref:\n    fsdp_config:\n      param_offload: False\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size\n  rollout:\n    name: vllm\n    temperature: 1.0\n    top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n    top_p: 1\n    prompt_length: ${data.max_prompt_length}  # not used for open-source\n    response_length: ${data.max_response_length}\n    # for vllm rollout\n    dtype: bfloat16 # should align with FSDP\n    gpu_memory_utilization: 0.5\n    ignore_eos: False\n    enforce_eager: True\n    free_cache_engine: True\n    load_format: dummy_dtensor\n    tensor_model_parallel_size: 2\n    max_num_batched_tokens: 8192\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    disable_log_stats: True\n    enable_chunked_prefill: True # could get higher throughput\n    # for hf rollout\n    do_sample: True\n    # number of responses (i.e. num sample times)\n    n: 1 # > 1 for grpo\n\ncritic:\n  strategy: fsdp\n  optim:\n    lr: 1e-5\n    lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n    min_lr_ratio: null   # only useful for warmup with cosine\n    lr_scheduler_type: constant  # select from constant/cosine\n    total_training_steps: -1  # must be overridden by the program\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    override_config: { }\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    enable_gradient_checkpointing: True\n    use_remove_padding: False\n    fsdp_config:\n      param_offload: False\n      optimizer_offload: False\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n      fsdp_size: -1\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n  ppo_micro_batch_size_per_gpu: null\n  forward_micro_batch_size: ${critic.ppo_micro_batch_size}\n  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2\n  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}\n  ulysses_sequence_parallel_size: 1 # sp size\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n  grad_clip: 1.0\n  cliprange_value: 0.5\n\nreward_model:\n  enable: False\n  strategy: fsdp\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    use_remove_padding: False\n    fsdp_config:\n      min_num_params: 0\n      param_offload: False\n      fsdp_size: -1\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: null # set a number\n  max_length: null\n  ulysses_sequence_parallel_size: 1 # sp size\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n\nalgorithm:\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  use_kl_in_reward: False\n  kl_penalty: kl  # how to estimate kl divergence\n  kl_ctrl:\n    type: fixed\n    kl_coef: 0.001\n\ntrainer:\n  total_epochs: 30\n  total_training_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: [ 'console', 'wandb' ]\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  # auto: find the last ckpt to resume from; if none is found, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which may cause hangs on systems like SLURM that limit CPUs. Set this to an allowed number instead.\n"
  },
  {
    "path": "verl_distillation/examples/split_placement/main_ppo_split.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\nimport torch\nfrom omegaconf import OmegaConf\nfrom split_monkey_patch import fit\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer\nfrom verl.utils.reward_score import gsm8k, math_reward\n\n\ndef _select_rm_score_fn(data_source):\n    if data_source == \"openai/gsm8k\":\n        return gsm8k.compute_score\n    elif data_source == \"lighteval/MATH\":\n        return math_reward.compute_score\n    else:\n        raise NotImplementedError\n\n\nclass RewardManager:\n    def __init__(self, tokenizer, num_examine) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n\n    def __call__(self, data: DataProto, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            sequences = torch.cat((valid_prompt_ids, valid_response_ids))\n            sequences_str = self.tokenizer.decode(sequences)\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n\n            # select rm_score\n            data_source = data_item.non_tensor_batch[\"data_source\"]\n            compute_score_fn = _select_rm_score_fn(data_source)\n\n            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)\n            reward_tensor[i, valid_response_length - 1] = score\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(sequences_str)\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor}\n        else:\n            
@hydra.main(config_path=\"config\", config_name=\"ppo_trainer_split\", version_base=None)\ndef main(config):\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        default_runtime_env = {\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}}\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    ray.get(main_task.remote(config))\n\n\n@ray.remote\ndef main_task(config):\n    # print initial config\n    from pprint import pprint\n\n    from omegaconf import OmegaConf\n\n    from verl.utils.fs import copy_to_local\n\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    # download the checkpoint from hdfs\n    local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n    # instantiate tokenizer\n    from verl.utils import hf_tokenizer\n\n    tokenizer = hf_tokenizer(local_path)\n\n    # define worker classes\n    if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n        assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    else:\n        raise NotImplementedError\n\n    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n        Role.Critic: ray.remote(CriticWorker),\n    }\n\n    # NOTE: initialize two resource pools\n    actor_rollout_ref_pool_id = \"actor_rollout_ref_pool\"\n    critic_pool_id = \"critic_pool\"\n    # on a single node, split the node's GPUs between the two pools; with multiple nodes, split the nodes instead\n    if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:\n        resource_pool_spec = {\n            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n            critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n        }\n    else:\n        resource_pool_spec = {\n            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n            critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n        }\n    print(f\"resource_pool_spec: {resource_pool_spec}\")\n    mapping = {\n        Role.ActorRollout: actor_rollout_ref_pool_id,\n        Role.Critic: critic_pool_id,\n    }\n\n    # use reference model\n    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        mapping[Role.RefPolicy] = actor_rollout_ref_pool_id\n\n    
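# (editor sketch) for example, with trainer.nnodes=1 and trainer.n_gpus_per_node=8, the branch\n    # above yields resource_pool_spec = {\"actor_rollout_ref_pool\": [4], \"critic_pool\": [4]},\n    # i.e. actor/rollout(/ref) share four GPUs while the critic runs on the other four.\n\n    # we should adopt a multi-source reward function here\n    # - for 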
rule-based rm, we directly call a reward score\n    # - for model-based rm, we call a model\n    # - for code related prompt, we send to a sandbox if there are test cases\n    # - finally, we combine all the rewards together\n    # - The reward type depends on the tag of the data\n    if config.reward_model.enable:\n        if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n            from verl.workers.fsdp_workers import RewardModelWorker\n        elif config.reward_model.strategy == \"megatron\":\n            from verl.workers.megatron_workers import RewardModelWorker\n        else:\n            raise NotImplementedError\n        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n        mapping[Role.RewardModel] = critic_pool_id\n\n    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)\n\n    # Note that we always use function-based RM for validation\n    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)\n\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n    RayPPOTrainer.fit = fit\n    trainer = RayPPOTrainer(\n        config=config,\n        tokenizer=tokenizer,\n        role_worker_mapping=role_worker_mapping,\n        resource_pool_manager=resource_pool_manager,\n        ray_worker_group_cls=ray_worker_group_cls,\n        reward_fn=reward_fn,\n        val_reward_fn=val_reward_fn,\n    )\n    trainer.init_workers()\n    trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/examples/split_placement/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 main_ppo_split.py \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=8 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/split_placement/split_monkey_patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAn naive implementation of split placment example\n\"\"\"\n\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_data_metrics,\n    compute_timing_metrics,\n    marked_timer,\n)\nfrom verl.trainer.ppo.reward import compute_reward\nfrom verl.utils.metric import reduce_metrics\n\n\ndef fit(self):\n    \"\"\"\n    The training loop of PPO.\n    The driver process only need to call the compute functions of the worker group through RPC\n    to construct the PPO dataflow.\n    The light-weight advantage computation is done on the driver process.\n    \"\"\"\n    from omegaconf import OmegaConf\n\n    from verl.utils.tracking import Tracking\n\n    logger = Tracking(\n        project_name=self.config.trainer.project_name,\n        experiment_name=self.config.trainer.experiment_name,\n        default_backend=self.config.trainer.logger,\n        config=OmegaConf.to_container(self.config, resolve=True),\n    )\n\n    self.global_steps = 0\n\n    # load checkpoint before doing anything\n    self._load_checkpoint()\n\n    # perform validation before training\n    # currently, we only support validation using the reward_function.\n    if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n        val_metrics = self._validate()\n        pprint(f\"Initial validation metrics: {val_metrics}\")\n        logger.log(data=val_metrics, step=self.global_steps)\n        if self.config.trainer.get(\"val_only\", False):\n            return\n\n    # we start from step 1\n    self.global_steps += 1\n    last_val_metrics = None\n\n    for epoch in range(self.config.trainer.total_epochs):\n        for batch_dict in self.train_dataloader:\n            metrics = {}\n            timing_raw = {}\n\n            batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n            # pop those keys for generation\n            gen_batch = batch.pop(batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"])\n            is_last_step = self.global_steps >= self.total_training_steps\n\n            with marked_timer(\"step\", timing_raw):\n                # generate a batch\n                with marked_timer(\"gen\", timing_raw):\n                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                    timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                    gen_batch_output.meta_info.pop(\"timing\", None)\n\n                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                    with marked_timer(\"gen_max\", timing_raw):\n                        gen_baseline_batch = deepcopy(gen_batch)\n                        gen_baseline_batch.meta_info[\"do_sample\"] = 
False\n                        gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                        batch = batch.union(gen_baseline_output)\n                        # compute reward model score on batch\n                        rm_scores = None\n                        if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                            rm_scores = self.rm_wg.compute_rm_score(batch)\n                            batch = batch.union(rm_scores)\n                        reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)\n                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                        keys_to_pop = set(gen_baseline_output.batch.keys())\n                        if rm_scores is not None:\n                            keys_to_pop.update(rm_scores.batch.keys())\n                        batch.pop(batch_keys=list(keys_to_pop))\n\n                        batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                        del rm_scores, gen_baseline_batch, gen_baseline_output\n\n                batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                )\n                # repeat to align with repeated responses in rollout\n                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                batch = batch.union(gen_batch_output)\n\n                # Balance the number of valid tokens across DP ranks.\n                # NOTE: This usually changes the order of data in the `batch`,\n                # which won't affect the advantage calculation (since it's based on uid),\n                # but might affect the loss calculation (due to the change of mini-batching).\n                # TODO: Decouple the DP balancing and mini-batching.\n                self._balance_batch(batch, metrics=metrics)\n\n                # compute global_valid tokens\n                batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                # recompute old_log_probs\n                with marked_timer(\"old_log_prob\", timing_raw):\n                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                    batch = batch.union(old_log_prob)\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with marked_timer(\"ref\", timing_raw):\n                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with marked_timer(\"values\", timing_raw):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with marked_timer(\"adv\", timing_raw):\n                    # compute scores. Support both model and function-based.\n                    # We first compute the scores using reward model. 
Then, we call reward_fn to combine\n                    # the results from reward model and rule-based results.\n                    if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                        # we first compute reward model score\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    # we combine with rule-based rm\n                    reward_tensor, _ = compute_reward(batch, self.reward_fn)\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    # compute rewards. apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    # compute advantages, executed on the driver process\n                    norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                    batch = compute_advantage(\n                        batch,\n                        adv_estimator=self.config.algorithm.adv_estimator,\n                        gamma=self.config.algorithm.gamma,\n                        lam=self.config.algorithm.lam,\n                        num_repeat=self.config.actor_rollout_ref.rollout.n,\n                        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        config=self.config.algorithm,\n                    )\n\n                # critic warmup: start updating the actor only after the critic has trained for critic_warmup steps\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with marked_timer(\"update_actor_call\", timing_raw):\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                else:\n                    actor_output = None\n\n                # update critic\n                if self.use_critic:\n                    with marked_timer(\"update_critic_call\", timing_raw):\n                        critic_output = self.critic_wg.update_critic(batch)\n\n                    # NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class\n                    with marked_timer(\"update_actor_critic\", timing_raw):\n                        critic_output = critic_output.get()\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                if actor_output is not None:\n                    actor_output = actor_output.get()\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n            # validate\n            if (\n                self.val_reward_fn is not None\n                and self.config.trainer.test_freq > 0\n                and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n            ):\n                with marked_timer(\"testing\", timing_raw):\n                    val_metrics: dict = self._validate()\n                    if is_last_step:\n                        last_val_metrics = val_metrics\n    
            metrics.update(val_metrics)\n\n            if self.config.trainer.save_freq > 0 and (\n                is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n            ):\n                with marked_timer(\"save_checkpoint\", timing_raw):\n                    self._save_checkpoint()\n\n            # collect metrics\n            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n\n            # TODO: make a canonical logger that supports various backends\n            logger.log(data=metrics, step=self.global_steps)\n\n            if self.global_steps >= self.total_training_steps:\n                pprint(f\"Final validation metrics: {last_val_metrics}\")\n                return\n\n            self.global_steps += 1\n"
  },
  {
    "path": "verl_distillation/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=4\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=0.5b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n\nset -x\nnproc_per_gpu=1\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    trainer.val_before_train=False \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=1 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-1.5b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=1.5b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-1.5B-Instruct\n\nset -x\nnproc_per_gpu=128\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-14b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=14b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-14B-Instruct\n\nset -x\nnproc_per_gpu=58 # 32√ → 64× → 48√ → 56√ → 60× → 58√ → 59×\nnnodes=1\nngpu_per_node=2\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.25 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2.5-Coder-14B-Instruct\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nPYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_14b_function_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_distillation/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-32b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=32b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-32B-Instruct\n\nset -x\nnproc_per_gpu=45 # 32√ → 64× → 48× → 40√ → 44√ → 46× → 45×\nnnodes=1\nngpu_per_node=4\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh",
    "content": "set -x\n\n# we need this to avoid fragmentation of GPU memory\nexport PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nmodel_path=Qwen/Qwen2.5-32B\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=6144 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='megatron_vllm_qwen2_32b' \\\n    trainer.experiment_name='qwen2_32b_grpo_8_h20' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-3b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=3b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-3B-Instruct\n\nset -x\nnproc_per_gpu=62\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_val_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-72B-Instruct\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$data_path \\\n    data.val_files=$gsm8k_val_path \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='Qwen2_72B_Instruct' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_distillation/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh",
    "content": "set -x\n\n#### important: vllm version must be >= 0.8.3\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_val_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-72B-Instruct\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$gsm8k_train_path \\\n    data.val_files=$gsm8k_val_path \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='Qwen2_72B_Instruct' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_distillation/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-72b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=72b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-72B-Instruct\n\nset -x\nnproc_per_gpu=22 # 16√ → 32× → 24× → 20√ → 22√ → 23×\nnnodes=1\nngpu_per_node=8\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-7b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=7b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-7B-Instruct\n\nset -x\nnproc_per_gpu=16 # 64√ → 128× → 96√ → 112× → 104× → 100√ → 102× → 101×\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_distillation/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-7B-Instruct\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nPYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/examples/tutorial/agent_loop_get_started/agent_loop_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Train ReAct agent with code sandbox\\n\",\n    \"\\n\",\n    \"In this tutorial, we will demonstrate how to train a [ReAct](https://arxiv.org/abs/2210.03629) agent to solve math problem with code sandbox.\\n\",\n    \"\\n\",\n    \"The agent works as follows:\\n\",\n    \"1. Given a math problem, the agent first query LLM to generate response and tool calls, which are python code to be executed in sandbox.\\n\",\n    \"2. If there is a tool call, the agent execute the python code in code sandbox.\\n\",\n    \"3. After code execution, the agent get the result from sandbox and append to chat history.\\n\",\n    \"4. The agent query LLM again until no tool call or max context length reached.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"<figure>\\n\",\n    \"  <img src=\\\"https://langchain-ai.github.io/langgraph/agents/assets/agent.png\\\" alt=\\\"ReAct\\\" width=\\\"400\\\">\\n\",\n    \"  <figcaption style=\\\"font-style: italic; color: #666;\\\">\\n\",\n    \"    source: <a href=\\\"https://langchain-ai.github.io/langgraph/agents/overview/\\\" target=\\\"_blank\\\">LangGraph</a>\\n\",\n    \"  </figcaption>\\n\",\n    \"</figure>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 1. Prerequisite\\n\",\n    \"\\n\",\n    \"To run the examples in this notebook, you need to install the verl package first.\\n\",\n    \"```bash\\n\",\n    \"git clone https://github.com/volcengine/verl\\n\",\n    \"cd verl\\n\",\n    \"pip install -e .\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-10-16 23:20:11,956\\tINFO worker.py:2004 -- Started a local Ray instance. View the dashboard at \\u001b[1m\\u001b[32mhttp://127.0.0.1:8265 \\u001b[39m\\u001b[22m\\n\",\n      \"/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py:2052: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\\n\",\n      \"  warnings.warn(\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import asyncio\\n\",\n    \"import sys\\n\",\n    \"import tempfile\\n\",\n    \"import os\\n\",\n    \"import socket\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"import requests\\n\",\n    \"import ray\\n\",\n    \"import fastapi\\n\",\n    \"import uvicorn\\n\",\n    \"from starlette.requests import Request\\n\",\n    \"from starlette.responses import JSONResponse\\n\",\n    \"from pprint import pprint\\n\",\n    \"\\n\",\n    \"import verl\\n\",\n    \"\\n\",\n    \"ray.init()\\n\",\n    \"verl_config_dir = os.path.join(os.path.dirname(verl.__file__), \\\"trainer/config\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"For demo purpose, we will use Qwen/Qwen3-1.7B as the LLM. 
First, let's download the model and dataset used in this tutorial.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pyarrow.parquet as pq\\n\",\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"snapshot_download(\\n\",\n    \"    repo_id=\\\"verl-team/lighteval-MATH-preprocessed\\\",\\n\",\n    \"    repo_type=\\\"dataset\\\",\\n\",\n    \"    local_dir=os.path.expanduser(\\\"~/verl-team/lighteval-MATH-preprocessed\\\"),\\n\",\n    \")\\n\",\n    \"snapshot_download(\\n\",\n    \"    repo_id=\\\"Qwen/Qwen3-1.7B\\\",\\n\",\n    \"    repo_type=\\\"model\\\",\\n\",\n    \"    local_dir=os.path.expanduser(\\\"~/Qwen/Qwen3-1.7B\\\"),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model_path = os.path.expanduser(\\\"~/Qwen/Qwen3-1.7B\\\")\\n\",\n    \"train_file = os.path.expanduser(\\\"~/verl-team/lighteval-MATH-preprocessed/train.parquet\\\")\\n\",\n    \"test_file = os.path.expanduser(\\\"~/verl-team/lighteval-MATH-preprocessed/test.parquet\\\")\\n\",\n    \"\\n\",\n    \"test = pq.read_table(test_file)\\n\",\n    \"test_file = os.path.expanduser(\\\"~/verl-team/lighteval-MATH-preprocessed/test_100.parquet\\\")\\n\",\n    \"pq.write_table(test[:100], test_file)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"verl supports both the vllm and sglang rollout servers for high-performance inference. This tutorial has been tested with both, so you can choose either of them to run it.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"rollout_name = \\\"???\\\"  # vllm or sglang\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. Basic tool call\\n\",\n    \"To begin, let's see how to make a basic tool call in verl, using the example from [Transformer tool use](https://huggingface.co/docs/transformers/main/chat_extras#tool-use). To use a tool in verl, we need to define a tool class that inherits from `BaseTool` and implements the following methods:\\n\",\n    \"- `get_openai_tool_schema`: return the schema of the tool in `OpenAIFunctionToolSchema` format.\\n\",\n    \"- `execute`: execute the tool with the given parameters, and return the result in `ToolResponse` format.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"{\\n\",\n      \"  \\\"type\\\": \\\"function\\\",\\n\",\n      \"  \\\"function\\\": {\\n\",\n      \"    \\\"name\\\": \\\"get_current_temperature\\\",\\n\",\n      \"    \\\"description\\\": \\\"Get current temperature at a location.\\\",\\n\",\n      \"    \\\"parameters\\\": {\\n\",\n      \"      \\\"type\\\": \\\"object\\\",\\n\",\n      \"      \\\"properties\\\": {\\n\",\n      \"        \\\"location\\\": {\\n\",\n      \"          \\\"type\\\": \\\"string\\\",\\n\",\n      \"          \\\"description\\\": \\\"The location to get the temperature for, in the format \\\\\\\"City, State, Country\\\\\\\".\\\"\\n\",\n      \"        },\\n\",\n      \"        \\\"unit\\\": {\\n\",\n      \"          \\\"type\\\": \\\"string\\\",\\n\",\n      \"          \\\"description\\\": \\\"The unit to return the temperature in. 
Defaults to \\\\\\\"celsius\\\\\\\".\\\",\\n\",\n      \"          \\\"enum\\\": [\\n\",\n      \"            \\\"celsius\\\",\\n\",\n      \"            \\\"fahrenheit\\\"\\n\",\n      \"          ]\\n\",\n      \"        }\\n\",\n      \"      },\\n\",\n      \"      \\\"required\\\": [\\n\",\n      \"        \\\"location\\\"\\n\",\n      \"      ]\\n\",\n      \"    }\\n\",\n      \"  }\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from transformers.utils import get_json_schema\\n\",\n    \"from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class WeatherTool(BaseTool):\\n\",\n    \"    def get_current_temperature(self, location: str, unit: str = \\\"celsius\\\"):\\n\",\n    \"        \\\"\\\"\\\"Get current temperature at a location.\\n\",\n    \"\\n\",\n    \"        Args:\\n\",\n    \"            location: The location to get the temperature for, in the format \\\"City, State, Country\\\".\\n\",\n    \"            unit: The unit to return the temperature in. Defaults to \\\"celsius\\\". (choices: [\\\"celsius\\\", \\\"fahrenheit\\\"])\\n\",\n    \"\\n\",\n    \"        Returns:\\n\",\n    \"            the temperature, the location, and the unit in a dict\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        return {\\n\",\n    \"            \\\"temperature\\\": 26.1,\\n\",\n    \"            \\\"location\\\": location,\\n\",\n    \"            \\\"unit\\\": unit,\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\\n\",\n    \"        schema = get_json_schema(self.get_current_temperature)\\n\",\n    \"        return OpenAIFunctionToolSchema(**schema)\\n\",\n    \"\\n\",\n    \"    async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[ToolResponse, float, dict]:\\n\",\n    \"        try:\\n\",\n    \"            result = self.get_current_temperature(**parameters)\\n\",\n    \"            return ToolResponse(text=json.dumps(result)), 0, {}\\n\",\n    \"        except Exception as e:\\n\",\n    \"            return ToolResponse(text=str(e)), 0, {}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"weather_tool = WeatherTool(config={}, tool_schema=None)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, let's launch a standalone rollout server without hybrid engine (which is more heavy to start) to test the basic tool call.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from hydra import compose, initialize_config_dir\\n\",\n    \"from verl.workers.rollout.replica import get_rollout_replica_class\\n\",\n    \"\\n\",\n    \"with initialize_config_dir(config_dir=verl_config_dir):\\n\",\n    \"    config = compose(\\n\",\n    \"        config_name=\\\"ppo_trainer\\\",\\n\",\n    \"        overrides=[\\n\",\n    \"            \\\"actor_rollout_ref.rollout.name=\\\" + rollout_name,\\n\",\n    \"            \\\"actor_rollout_ref.rollout.mode=async\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.tensor_model_parallel_size=1\\\",\\n\",\n    \"            \\\"actor_rollout_ref.model.path=\\\" + model_path,\\n\",\n    \"            \\\"actor_rollout_ref.rollout.response_length=4096\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.skip_tokenizer_init=False\\\",\\n\",\n    \"            
\\\"+actor_rollout_ref.rollout.engine_kwargs.vllm.enable_auto_tool_choice=True\\\",\\n\",\n    \"            \\\"+actor_rollout_ref.rollout.engine_kwargs.vllm.tool_call_parser=hermes\\\",\\n\",\n    \"            \\\"+actor_rollout_ref.rollout.engine_kwargs.sglang.tool_call_parser=qwen25\\\",\\n\",\n    \"        ],\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name)\\n\",\n    \"rollout_server = rollout_server_class(\\n\",\n    \"    replica_rank=0,\\n\",\n    \"    config=config.actor_rollout_ref.rollout,\\n\",\n    \"    model_config=config.actor_rollout_ref.model,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"await rollout_server.init_standalone()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Then, we can query LLM with openai client. Note that we need to pass the tool schema to server to guide LLM generating tool calls. We can see that the LLM correctly generates a tool call to get the temperature in Paris.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[{'content': \\\"Hey, what's the temperature in Paris right now?\\\", 'role': 'user'},\\n\",\n      \" {'role': 'assistant',\\n\",\n      \"  'tool_calls': [{'function': {'arguments': '{\\\"location\\\": \\\"Paris, France\\\"}',\\n\",\n      \"                               'name': 'get_current_temperature'},\\n\",\n      \"                  'id': 'call_b10bdde504a0411690e96b55',\\n\",\n      \"                  'index': -1,\\n\",\n      \"                  'type': 'function'}]}]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from openai import AsyncOpenAI\\n\",\n    \"\\n\",\n    \"client = AsyncOpenAI(\\n\",\n    \"    api_key=\\\"dummy\\\",\\n\",\n    \"    base_url=f\\\"http://{rollout_server._server_address}/v1\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"messages = [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"Hey, what's the temperature in Paris right now?\\\"}]\\n\",\n    \"completion = await client.chat.completions.create(\\n\",\n    \"    model=config.actor_rollout_ref.model.path,\\n\",\n    \"    messages=messages,\\n\",\n    \"    tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\\n\",\n    \"    extra_body={\\n\",\n    \"        \\\"chat_template_kwargs\\\": {\\\"enable_thinking\\\": False},\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\\n\",\n    \"messages.append(message)\\n\",\n    \"pprint(messages)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can execute the tool call with arguments generated by LLM and get the temperature in Paris.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"text='{\\\"temperature\\\": 26.1, \\\"location\\\": \\\"Paris, France\\\", \\\"unit\\\": \\\"celsius\\\"}' image=None video=None\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"args = json.loads(message[\\\"tool_calls\\\"][0][\\\"function\\\"][\\\"arguments\\\"])\\n\",\n    \"tool_response, _, _ = await weather_tool.execute(\\\"\\\", args)\\n\",\n    \"print(tool_response)\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Then, we can add the tool response to the chat history and query the LLM again. With the tool response, the LLM can generate a final response to the user.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[{'content': \\\"Hey, what's the temperature in Paris right now?\\\", 'role': 'user'},\\n\",\n      \" {'role': 'assistant',\\n\",\n      \"  'tool_calls': [{'function': {'arguments': '{\\\"location\\\": \\\"Paris, France\\\"}',\\n\",\n      \"                               'name': 'get_current_temperature'},\\n\",\n      \"                  'id': 'call_b10bdde504a0411690e96b55',\\n\",\n      \"                  'index': -1,\\n\",\n      \"                  'type': 'function'}]},\\n\",\n      \" {'content': '{\\\"temperature\\\": 26.1, \\\"location\\\": \\\"Paris, France\\\", \\\"unit\\\": '\\n\",\n      \"             '\\\"celsius\\\"}',\\n\",\n      \"  'role': 'tool'},\\n\",\n      \" {'content': 'The current temperature in Paris is 26.1°C.',\\n\",\n      \"  'role': 'assistant'}]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"messages.append({\\\"role\\\": \\\"tool\\\", \\\"content\\\": tool_response.text})\\n\",\n    \"completion = await client.chat.completions.create(\\n\",\n    \"    model=config.actor_rollout_ref.model.path,\\n\",\n    \"    messages=messages,\\n\",\n    \"    tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\\n\",\n    \"    extra_body={\\n\",\n    \"        \\\"chat_template_kwargs\\\": {\\\"enable_thinking\\\": False},\\n\",\n    \"    },\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\\n\",\n    \"messages.append(message)\\n\",\n    \"pprint(messages)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 3. Advanced tool call with code sandbox\\n\",\n    \"\\n\",\n    \"Now, let's see a more realistic example of tool calling with a code sandbox, which is widely used in real-world applications.\\n\",\n    \"\\n\",\n    \"### 3.1 Implement a naive code sandbox\\n\",\n    \"\\n\",\n    \"To execute Python code snippets generated by the LLM, we need a code sandbox environment. In this tutorial, we will implement a very naive code sandbox, which is\\n\",\n    \"a FastAPI HTTP server with a `/run_code` endpoint. The server works as follows:\\n\",\n    \"1. Receive an HTTP request and write the Python code snippet to a temp file.\\n\",\n    \"2. Spawn a subprocess to execute the code, and capture the stdout and stderr of the subprocess.\\n\",\n    \"3. Return the stdout and stderr of the subprocess as the HTTP response.\\n\",\n    \"\\n\",\n    \"> 🚨 **WARNING:** This naive code sandbox is for demonstration purposes only; do not use it in production. The code runs as a plain subprocess with the notebook's own privileges and no resource limits. 
Please use Docker/Kata containers for stronger isolation and security restrictions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote(num_cpus=1)\\n\",\n    \"class Sandbox:\\n\",\n    \"    \\\"\\\"\\\"Sandbox to execute Python code.\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    def __init__(self):\\n\",\n    \"        self.address = ray._private.services.get_node_ip_address()\\n\",\n    \"        self.port = self._get_free_port()\\n\",\n    \"        asyncio.create_task(self._start_fastapi_server())\\n\",\n    \"\\n\",\n    \"    async def code_execution(self, request: Request):\\n\",\n    \"        request_json = await request.json()\\n\",\n    \"        code = request_json[\\\"code\\\"]\\n\",\n    \"        # print(f\\\"execute code:\\\\n{code}\\\")\\n\",\n    \"\\n\",\n    \"        _, temp_file = tempfile.mkstemp(suffix=\\\".py\\\", prefix=\\\"temp_code\\\", dir=None, text=True)\\n\",\n    \"        with open(temp_file, \\\"w\\\") as f:\\n\",\n    \"            f.write(code)\\n\",\n    \"\\n\",\n    \"        try:\\n\",\n    \"            process = await asyncio.create_subprocess_exec(\\n\",\n    \"                sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE\\n\",\n    \"            )\\n\",\n    \"\\n\",\n    \"            stdout, stderr = await process.communicate()\\n\",\n    \"\\n\",\n    \"            response = {\\n\",\n    \"                \\\"status\\\": \\\"Success\\\" if process.returncode == 0 else \\\"Failed\\\",\\n\",\n    \"                \\\"run_result\\\": {\\n\",\n    \"                    \\\"status\\\": \\\"Finished\\\",\\n\",\n    \"                    \\\"stdout\\\": stdout.decode(),\\n\",\n    \"                    \\\"stderr\\\": stderr.decode(),\\n\",\n    \"                    \\\"return_code\\\": process.returncode,\\n\",\n    \"                },\\n\",\n    \"            }\\n\",\n    \"            return JSONResponse(content=response)\\n\",\n    \"        finally:\\n\",\n    \"            try:\\n\",\n    \"                os.unlink(temp_file)\\n\",\n    \"            except Exception:\\n\",\n    \"                pass\\n\",\n    \"\\n\",\n    \"    def _get_free_port(self):\\n\",\n    \"        with socket.socket() as sock:\\n\",\n    \"            sock.bind((\\\"\\\", 0))\\n\",\n    \"            return sock.getsockname()[1]\\n\",\n    \"\\n\",\n    \"    async def _start_fastapi_server(self):\\n\",\n    \"        app = fastapi.FastAPI()\\n\",\n    \"        app.router.add_api_route(\\\"/run_code\\\", self.code_execution, methods=[\\\"POST\\\"])\\n\",\n    \"\\n\",\n    \"        config = uvicorn.Config(app, host=[\\\"::\\\", \\\"0.0.0.0\\\"], port=self.port, log_level=\\\"warning\\\")\\n\",\n    \"        server = uvicorn.Server(config)\\n\",\n    \"        await server.serve()\\n\",\n    \"\\n\",\n    \"    async def get_server_address(self) -> str:\\n\",\n    \"        \\\"\\\"\\\"Get FastAPI server address.\\\"\\\"\\\"\\n\",\n    \"        return f\\\"{self.address}:{self.port}\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sandbox = Sandbox.remote()\\n\",\n    \"sandbox_address = ray.get(sandbox.get_server_address.remote())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 3.2 Define the sandbox tool\\n\",\n    \"\\n\",\n    \"As in the previous 
section, we define a tool for the code sandbox. In the `execute` method, we send the code snippet to the code sandbox via an HTTP request and get the output.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"{\\n\",\n      \"  \\\"type\\\": \\\"function\\\",\\n\",\n      \"  \\\"function\\\": {\\n\",\n      \"    \\\"name\\\": \\\"code_interpreter\\\",\\n\",\n      \"    \\\"description\\\": \\\"Execute the code in the sandbox.\\\",\\n\",\n      \"    \\\"parameters\\\": {\\n\",\n      \"      \\\"type\\\": \\\"object\\\",\\n\",\n      \"      \\\"properties\\\": {\\n\",\n      \"        \\\"code\\\": {\\n\",\n      \"          \\\"type\\\": \\\"string\\\",\\n\",\n      \"          \\\"description\\\": \\\"The code to be executed.\\\"\\n\",\n      \"        }\\n\",\n      \"      },\\n\",\n      \"      \\\"required\\\": [\\n\",\n      \"        \\\"code\\\"\\n\",\n      \"      ]\\n\",\n      \"    }\\n\",\n      \"  }\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import re\\n\",\n    \"import aiohttp\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class SandboxTool(BaseTool):\\n\",\n    \"    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\\n\",\n    \"        super().__init__(config, tool_schema)\\n\",\n    \"        # Different models may use different code fence tags, e.g. python, py, etc.\\n\",\n    \"        self.code_pattern = re.compile(r\\\"```py(.*?)```\\\", re.DOTALL)\\n\",\n    \"\\n\",\n    \"    async def code_interpreter(self, code: str) -> str:\\n\",\n    \"        \\\"\\\"\\\"Execute the code in the sandbox.\\n\",\n    \"\\n\",\n    \"        Args:\\n\",\n    \"            code: The code to be executed.\\n\",\n    \"\\n\",\n    \"        Returns:\\n\",\n    \"            str: The output of the code execution.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        async with aiohttp.ClientSession() as session:\\n\",\n    \"            async with session.post(\\n\",\n    \"                self.config.get(\\\"sandbox_fusion_url\\\"),\\n\",\n    \"                json={\\\"code\\\": code},\\n\",\n    \"            ) as resp:\\n\",\n    \"                resp.raise_for_status()\\n\",\n    \"                result = await resp.json()\\n\",\n    \"                stdout, stderr = result[\\\"run_result\\\"][\\\"stdout\\\"], result[\\\"run_result\\\"][\\\"stderr\\\"]\\n\",\n    \"                return stdout + stderr\\n\",\n    \"\\n\",\n    \"    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\\n\",\n    \"        schema = get_json_schema(self.code_interpreter)\\n\",\n    \"        return OpenAIFunctionToolSchema(**schema)\\n\",\n    \"\\n\",\n    \"    async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[ToolResponse, float, dict]:\\n\",\n    \"        code = parameters[\\\"code\\\"]\\n\",\n    \"        matches = self.code_pattern.findall(code)\\n\",\n    \"        if matches:\\n\",\n    \"            code = matches[0].strip()\\n\",\n    \"\\n\",\n    \"        # NOTE: Some scripts may not explicitly print the result, so we wrap the last non-empty line in a print statement (e.g. roots -> print(roots)).\\n\",\n    \"        # A better way is to SFT the model to print the result by default; we skip the SFT stage in this tutorial.\\n\",\n    \"        lines = code.split(\\\"\\\\n\\\")\\n\",\n    \"        for i, line in reversed(list(enumerate(lines))):\\n\",\n    \"            if 
line == \\\"\\\":\\n\",\n    \"                continue\\n\",\n    \"            if not lines[i].startswith(\\\"print\\\"):\\n\",\n    \"                lines[i] = f\\\"print({line})\\\"\\n\",\n    \"            break\\n\",\n    \"        code = \\\"\\\\n\\\".join(lines)\\n\",\n    \"\\n\",\n    \"        result = await self.code_interpreter(code)\\n\",\n    \"        return ToolResponse(text=result), 0.0, {}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"sandbox_tool = SandboxTool(config={\\\"sandbox_fusion_url\\\": f\\\"http://{sandbox_address}/run_code\\\"}, tool_schema=None)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"First, let's try to execute a valid code and check the response with stdout.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"(ToolResponse(text='sqrt(3)\\\\n', image=None, video=None), 0.0, {})\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"code = \\\"\\\"\\\"```py\\n\",\n    \"import sympy\\n\",\n    \"\\n\",\n    \"print(sympy.sqrt(3))\\n\",\n    \"```\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"print(await sandbox_tool.execute(instance_id=\\\"\\\", parameters={\\\"code\\\": code}))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Then, let's try to execute an invalid code and check the response with stderr. The error message is important to inform LLM to fix code in next generation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"(ToolResponse(text='Traceback (most recent call last):\\\\n  File \\\"/tmp/temp_code3e2f638_.py\\\", line 2, in <module>\\\\n    print(sympy.sqrt(3))\\\\n          ^^^^^\\\\nNameError: name \\\\'sympy\\\\' is not defined\\\\n', image=None, video=None), 0.0, {})\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"code_invalid = \\\"\\\"\\\"\\n\",\n    \"print(sympy.sqrt(3))\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"print(await sandbox_tool.execute(instance_id=\\\"\\\", parameters={\\\"code\\\": code_invalid}))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2.3 Test sandbox tool\\n\",\n    \"\\n\",\n    \"Now, we can test sandbox tool with real math problem. In this tutorial, we will use the [DigitalLearningGmbH/MATH-lighteval](https://huggingface.co/datasets/DigitalLearningGmbH/MATH-lighteval) dataset, which consists of problems from mathematics competitions, including the AMC 10, AMC 12, AIME, and more.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"ebd09c8816b140a59a879e5a5e218950\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Generating train split: 0 examples [00:00, ? 
examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"\\n\",\n    \"dataset = load_dataset(\\\"parquet\\\", data_files=test_file)[\\\"train\\\"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"For debug purpose, we can implement ReAct agent as a simple loop. For RL training, there are more subtle issue and corner case to deal with, we provide a built-in ReAct agent loop which will be discussed in next section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"No tool calls, finish_reason: stop\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"messages = dataset[\\\"prompt\\\"][0]\\n\",\n    \"\\n\",\n    \"while True:\\n\",\n    \"    # 1. Chat with the model\\n\",\n    \"    completion = await client.chat.completions.create(\\n\",\n    \"        model=config.actor_rollout_ref.model.path,\\n\",\n    \"        messages=messages,\\n\",\n    \"        tools=[sandbox_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\\n\",\n    \"        extra_body={\\n\",\n    \"            \\\"chat_template_kwargs\\\": {\\\"enable_thinking\\\": False},\\n\",\n    \"        },\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\\n\",\n    \"    messages.append(message)\\n\",\n    \"\\n\",\n    \"    # 2. Call tools\\n\",\n    \"    finish_reason = completion.choices[0].finish_reason\\n\",\n    \"    if finish_reason != \\\"tool_calls\\\":\\n\",\n    \"        print(f\\\"No tool calls, finish_reason: {finish_reason}\\\")\\n\",\n    \"        break\\n\",\n    \"\\n\",\n    \"    try:\\n\",\n    \"        tool_calls = completion.choices[0].message.tool_calls[0]\\n\",\n    \"        args = json.loads(tool_calls.function.arguments)\\n\",\n    \"        result, _, _ = await sandbox_tool.execute(\\\"\\\", args)\\n\",\n    \"    except Exception as e:\\n\",\n    \"        print(f\\\"Error: {e}\\\")\\n\",\n    \"\\n\",\n    \"    # 3. Add tool response to messages\\n\",\n    \"    messages.append(\\n\",\n    \"        {\\n\",\n    \"            \\\"role\\\": \\\"tool\\\",\\n\",\n    \"            \\\"content\\\": result.text,\\n\",\n    \"        }\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[{'content': \\\"How many vertical asymptotes does the graph of $y=\\\\\\\\frac{2}{x^2+x-6}$ have? Let's think step by step and output the final answer within \\\\\\\\boxed{}.\\\",\\n\",\n       \"  'role': 'user'},\\n\",\n       \" {'content': \\\"To determine the number of vertical asymptotes for the function $ y = \\\\\\\\frac{2}{x^2 + x - 6} $, we need to find the values of $ x $ where the denominator equals zero, as these points are where the function is undefined and potentially where it has vertical asymptotes.\\\\n\\\\nThe denominator is $ x^2 + x - 6 $. 
To find the vertical asymptotes, we need to solve the equation:\\\\n\\\\n$$ x^2 + x - 6 = 0 $$\\\\n\\\\nThis is a quadratic equation, and we can solve it using the quadratic formula:\\\\n\\\\n$$ x = \\\\\\\\frac{-b \\\\\\\\pm \\\\\\\\sqrt{b^2 - 4ac}}{2a} $$\\\\n\\\\nwhere $ a = 1 $, $ b = 1 $, and $ c = -6 $. Let's solve this equation to find the values of $ x $ where the denominator is zero, which will give us the vertical asymptotes.\\\",\\n\",\n       \"  'role': 'assistant',\\n\",\n       \"  'tool_calls': [{'id': 'call_4d873672ff8445159e4e5e45',\\n\",\n       \"    'function': {'arguments': '{\\\"code\\\": \\\"from sympy import symbols, solve\\\\\\\\nx = symbols(\\\\'x\\\\')\\\\\\\\nroots = solve(x**2 + x - 6, x)\\\\\\\\nroots\\\"}',\\n\",\n       \"     'name': 'code_interpreter'},\\n\",\n       \"    'type': 'function',\\n\",\n       \"    'index': -1}]},\\n\",\n       \" {'role': 'tool', 'content': '[-3, 2]\\\\n'},\\n\",\n       \" {'content': 'The roots of the equation $ x^2 + x - 6 = 0 $ are $ x = -3 $ and $ x = 2 $. These are the values of $ x $ where the denominator is zero, which means the function $ y = \\\\\\\\frac{2}{x^2 + x - 6} $ is undefined at these points. \\\\n\\\\nSince the denominator is zero at these values, the function has vertical asymptotes at $ x = -3 $ and $ x = 2 $. Therefore, the graph of the function has two vertical asymptotes.\\\\n\\\\nThe final answer is $\\\\\\\\boxed{2}$.',\\n\",\n       \"  'role': 'assistant'}]\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"messages\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can see that the ReAct agent properly queries the LLM, executes the sandbox tool call, and finally generates the answer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 4. End-to-end training with tool agent loop\\n\",\n    \"\\n\",\n    \"After the tool has been implemented and tested, we can run end-to-end RL training to tune the model to properly use the tool. To simplify agentic RL training, verl provides the [Agent Loop](https://verl.readthedocs.io/en/latest/advance/agent_loop.html) abstraction, which allows users to define custom agent loops:\\n\",\n    \"- Search agent\\n\",\n    \"- Math agent\\n\",\n    \"- SWE agent\\n\",\n    \"- GUI agent\\n\",\n    \"- ...\\n\",\n    \"\\n\",\n    \"For ease of use, verl provides two pre-defined agent loops:\\n\",\n    \"- SingleTurnAgentLoop: single-turn conversation without tool calling\\n\",\n    \"- ToolAgentLoop: multi-turn conversation with tool calling and interaction\\n\",\n    \"\\n\",\n    \"To use ToolAgentLoop, users only need to provide a tool configuration in a json/yaml file. In the configuration file, specify the following fields for each tool:\\n\",\n    \"- class_name: the fully qualified class name of the tool, used to dynamically load the custom tool class\\n\",\n    \"- config: keyword arguments used to initialize the tool instance\\n\",\n    \"\\n\",\n    \"Let's dump our sandbox tool configuration to a json file:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-10-16 23:07:16,868\\tINFO worker.py:2004 -- Started a local Ray instance. 
View the dashboard at \\u001b[1m\\u001b[32mhttp://127.0.0.1:8265 \\u001b[39m\\u001b[22m\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"ray.shutdown()\\n\",\n    \"\\n\",\n    \"sandbox = Sandbox.remote()\\n\",\n    \"sandbox_address = ray.get(sandbox.get_server_address.remote())\\n\",\n    \"\\n\",\n    \"tool_config = {\\n\",\n    \"    \\\"tools\\\": [\\n\",\n    \"        {\\n\",\n    \"            \\\"class_name\\\": \\\"sandbox.SandboxTool\\\",\\n\",\n    \"            \\\"config\\\": {\\n\",\n    \"                \\\"type\\\": \\\"native\\\",\\n\",\n    \"                \\\"sandbox_fusion_url\\\": f\\\"http://{sandbox_address}/run_code\\\",\\n\",\n    \"            },\\n\",\n    \"        },\\n\",\n    \"    ],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"tool_config_path = \\\"tool_config.json\\\"\\n\",\n    \"with open(tool_config_path, \\\"w\\\") as f:\\n\",\n    \"    json.dump(tool_config, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/tmp/ipykernel_174199/3963810189.py:3: UserWarning: \\n\",\n      \"The version_base parameter is not specified.\\n\",\n      \"Please specify a compatability version level, or None.\\n\",\n      \"Will assume defaults for version 1.1\\n\",\n      \"  with initialize_config_dir(config_dir=verl_config_dir):\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from hydra import compose, initialize_config_dir\\n\",\n    \"\\n\",\n    \"with initialize_config_dir(config_dir=verl_config_dir):\\n\",\n    \"    config = compose(\\n\",\n    \"        config_name=\\\"ppo_trainer\\\",\\n\",\n    \"        overrides=[\\n\",\n    \"            \\\"algorithm.adv_estimator=grpo\\\",\\n\",\n    \"            \\\"data.train_files=\\\" + train_file,\\n\",\n    \"            \\\"data.val_files=\\\" + test_file,\\n\",\n    \"            \\\"data.return_raw_chat=True\\\",\\n\",\n    \"            \\\"data.train_batch_size=32\\\",\\n\",\n    \"            \\\"data.max_prompt_length=1024\\\",\\n\",\n    \"            \\\"data.max_response_length=1024\\\",\\n\",\n    \"            \\\"+data.apply_chat_template_kwargs.enable_thinking=False\\\",\\n\",\n    \"            # actor related\\n\",\n    \"            \\\"actor_rollout_ref.model.path=\\\" + model_path,\\n\",\n    \"            \\\"actor_rollout_ref.actor.ppo_mini_batch_size=8\\\",\\n\",\n    \"            \\\"actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8\\\",\\n\",\n    \"            \\\"actor_rollout_ref.actor.fsdp_config.param_offload=True\\\",\\n\",\n    \"            \\\"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\\\",\\n\",\n    \"            # rollout related\\n\",\n    \"            \\\"actor_rollout_ref.rollout.name=\\\" + rollout_name,\\n\",\n    \"            \\\"actor_rollout_ref.rollout.mode=async\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.tensor_model_parallel_size=1\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.n=8\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.multi_turn.tool_config_path=\\\" + tool_config_path,\\n\",\n    \"            \\\"actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent\\\",\\n\",\n    \"            \\\"actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8\\\",\\n\",\n    \"            # trainer related\\n\",\n    \"            \\\"trainer.val_before_train=True\\\",\\n\",\n    \"            
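# log_val_generations dumps sample validation rollouts so tool-call behavior is visible in the logger\\n\",\n    \"            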
\\\"trainer.log_val_generations=10\\\",\\n\",\n    \"            \\\"trainer.n_gpus_per_node=8\\\",\\n\",\n    \"            \\\"trainer.test_freq=-1\\\",\\n\",\n    \"            \\\"trainer.total_training_steps=5\\\",\\n\",\n    \"            \\\"trainer.logger=['console','tensorboard', 'wandb']\\\",\\n\",\n    \"            \\\"trainer.project_name=verl\\\",\\n\",\n    \"            \\\"trainer.experiment_name=\\\" + os.path.basename(model_path),\\n\",\n    \"        ],\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.trainer.main_ppo import main\\n\",\n    \"\\n\",\n    \"main(config)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"For demo purpose, we only train 5 steps, you can verify the training process by checking wandb metrics:\\n\",\n    \"- num_turns: min/max/mean chat conversation turns in each step.\\n\",\n    \"- critic rewards: min/max/mean critic rewards in each step.\\n\",\n    \"\\n\",\n    \"For more realistic agentic RL training, please refer to our recipe:\\n\",\n    \"- [retool](https://github.com/volcengine/verl/tree/main/recipe/retool): implementation of paper [ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536)\\n\",\n    \"- [collabllm](https://github.com/volcengine/verl/tree/main/recipe/collabllm): implementation of paper [CollabLLM: From Passive Responders to Active Collaborators](https://arxiv.org/pdf/2502.00640)\\n\",\n    \"- [deepeyes](https://github.com/volcengine/verl/tree/main/recipe/deepeyes): implementation of paper [DeepEyes: Incentivizing \\\"Thinking with Images\\\" via Reinforcement Learning](https://arxiv.org/abs/2505.14362)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"fileId\": \"398ea641-8a51-4a0b-b64e-6b7cd6b72164\",\n  \"filePath\": \"/opt/tiger/open_verl/examples/agent_loop_tutorial.ipynb\",\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "verl_distillation/examples/tutorial/agent_loop_get_started/sandbox.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport re\n\nimport aiohttp\nfrom transformers.utils import get_json_schema\n\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse\n\n\nclass SandboxTool(BaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        # Different model may use different code pattern, e.g. python, py, etc.\n        self.code_pattern = re.compile(r\"```py(.*?)```\", re.DOTALL)\n\n    async def code_interpreter(self, code: str) -> str:\n        \"\"\"Execute the code in the sandbox.\n\n        Args:\n            code: The code to be executed.\n\n        Returns:\n            str: The output of the code execution.\n        \"\"\"\n        async with aiohttp.ClientSession() as session:\n            async with session.post(\n                self.config.get(\"sandbox_fusion_url\"),\n                json={\"code\": code},\n            ) as resp:\n                resp.raise_for_status()\n                result = await resp.json()\n                stdout, stderr = result[\"run_result\"][\"stdout\"], result[\"run_result\"][\"stderr\"]\n                return stdout + stderr\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.code_interpreter)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[str, float, dict]:\n        code = parameters[\"code\"]\n        matches = self.code_pattern.findall(code)\n        if matches:\n            code = matches[0].strip()\n\n        # NOTE: Some script may not explicitly print result, we need to add a print statement to the end of the script.\n        # More better way is to SFT the model to make it print result by default, we skip SFT stage in this tutorial.\n        lines = code.split(\"\\n\")\n        for i, line in reversed(list(enumerate(lines))):\n            if line == \"\":\n                continue\n            if not lines[i].startswith(\"print\"):\n                lines[i] = f\"print({line})\"\n            break\n        code = \"\\n\".join(lines)\n\n        result = await self.code_interpreter(code)\n        return ToolResponse(text=result), 0.0, {}\n"
  },
  {
    "path": "verl_distillation/init_ray.sh",
    "content": "#!/bin/bash\n# Single Node Ray Initialization Script\n# Usage: bash init_ray.sh <HEAD_NODE_IP> <PORT> <RANK>\n#   HEAD_NODE_IP: IP address of the head node\n#   PORT: Ray port (default: 6379)\n#   RANK: Node rank (0 for head, >0 for workers)\n\nset -e\n\n# Parse arguments\nHEAD_NODE_IP=${1:-\"127.0.0.1\"}\nPORT=${2:-6379}\nRANK=${3:-0}\n\n# Configuration\nNUM_CPUS=${NUM_CPUS:-\"\"}\nNUM_GPUS=${NUM_GPUS:-\"\"}\nOBJECT_STORE_MEMORY=${OBJECT_STORE_MEMORY:-\"\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"distill\"}\n\n# Colors\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $(hostname): $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $(hostname): $1\"\n}\n\n# Activate conda environment\nif [ -f \"/root/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"/root/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/miniconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/miniconda3/etc/profile.d/conda.sh\"\nfi\n\nif command -v conda &> /dev/null; then\n    conda activate ${CONDA_ENV_NAME} 2>/dev/null || log_warn \"Could not activate conda env: ${CONDA_ENV_NAME}\"\nfi\n\n# Build ray start command options\nRAY_OPTS=\"\"\nif [ -n \"${NUM_CPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-cpus=${NUM_CPUS}\"\nfi\nif [ -n \"${NUM_GPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-gpus=${NUM_GPUS}\"\nfi\nif [ -n \"${OBJECT_STORE_MEMORY}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --object-store-memory=${OBJECT_STORE_MEMORY}\"\nfi\n\n# Stop existing Ray instance\nray stop --force 2>/dev/null || true\nsleep 2\n\n# Start Ray\nif [ \"${RANK}\" -eq 0 ]; then\n    log_info \"Starting Ray HEAD node on port ${PORT}...\"\n    ray start --head --port=${PORT} ${RAY_OPTS}\nelse\n    log_info \"Starting Ray WORKER node, connecting to ${HEAD_NODE_IP}:${PORT}...\"\n    ray start --address=${HEAD_NODE_IP}:${PORT} ${RAY_OPTS}\nfi\n\nsleep 3\n\n# Check status\nlog_info \"Ray node started. Checking status...\"\nray status\n"
  },
  {
    "path": "verl_distillation/init_ray_cluster.sh",
    "content": "#!/bin/bash\n# Multi-node Ray Cluster Initialization Script\n# Usage: bash init_ray_cluster.sh [--stop]\n#   --stop: Stop Ray on all nodes instead of starting\n\nset -e\n\nSCRIPT_DIR=$(cd $(dirname $0); pwd)\nPROJECT_DIR=${SCRIPT_DIR}\n\n# Configuration\nPORT=${RAY_PORT:-6379}\nHOSTFILE=${HOSTFILE:-\"/etc/mpi/hostfile\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"distill\"}\nLOG_DIR=\"${PROJECT_DIR}/logs/ray\"\n\n# Colors\nRED='\\033[0;31m'\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $1\"\n}\n\nlog_error() {\n    echo -e \"${RED}[ERROR]${NC} $1\"\n}\n\n# Function to stop Ray on all nodes\nstop_cluster() {\n    log_info \"Stopping Ray on all nodes...\"\n\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_warn \"Hostfile not found, stopping local Ray only\"\n        ray stop --force 2>/dev/null || true\n        return\n    fi\n\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    for node in ${ALL_NODES}; do\n        log_info \"Stopping Ray on ${node}...\"\n        ssh -n ${node} \"source /root/anaconda3/etc/profile.d/conda.sh && conda activate ${CONDA_ENV_NAME} && ray stop --force\" 2>/dev/null &\n    done\n\n    wait\n    log_info \"Ray stopped on all nodes\"\n}\n\n# Function to start Ray cluster\nstart_cluster() {\n    # Check hostfile\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_error \"Hostfile not found: ${HOSTFILE}\"\n        log_info \"Please create a hostfile with one IP per line\"\n        log_info \"Example:\"\n        echo \"  192.168.1.100\"\n        echo \"  192.168.1.101\"\n        echo \"  192.168.1.102\"\n        exit 1\n    fi\n\n    # Get head node (first line)\n    HEAD_NODE=$(awk 'NR==1 {print $1}' ${HOSTFILE})\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    log_info \"Head node: ${HEAD_NODE}\"\n    log_info \"Ray port: ${PORT}\"\n    log_info \"Conda env: ${CONDA_ENV_NAME}\"\n    echo \"\"\n    log_info \"Nodes in cluster:\"\n    echo \"${ALL_NODES}\"\n    echo \"\"\n\n    # Create log directory\n    mkdir -p \"${LOG_DIR}\"\n\n    # Stop existing Ray instances first\n    log_info \"Stopping any existing Ray instances...\"\n    stop_cluster\n    sleep 3\n\n    # Start head node first (synchronously)\n    log_info \"Starting Ray HEAD on ${HEAD_NODE}...\"\n    ssh -n ${HEAD_NODE} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} 0\" \\\n        > \"${LOG_DIR}/ray_${HEAD_NODE}.log\" 2>&1\n\n    if [ $? -ne 0 ]; then\n        log_error \"Failed to start Ray HEAD. 
Check ${LOG_DIR}/ray_${HEAD_NODE}.log\"\n        exit 1\n    fi\n    log_info \"Ray HEAD started successfully\"\n\n    # Wait for head to be ready\n    sleep 5\n\n    # Start worker nodes (asynchronously)\n    rank=1\n    for node in ${ALL_NODES}; do\n        if [ \"${node}\" == \"${HEAD_NODE}\" ]; then\n            continue\n        fi\n\n        log_info \"Starting Ray WORKER on ${node} (rank ${rank})...\"\n        ssh -n ${node} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} ${rank}\" \\\n            > \"${LOG_DIR}/ray_${node}.log\" 2>&1 &\n        rank=$((rank + 1))\n    done\n\n    # Wait for all workers\n    log_info \"Waiting for all workers to join...\"\n    wait\n    sleep 3\n\n    # Check cluster status\n    echo \"\"\n    log_info \"Ray cluster initialization complete!\"\n    log_info \"Logs saved to: ${LOG_DIR}/\"\n    echo \"\"\n    log_info \"Cluster status:\"\n    ssh -n ${HEAD_NODE} \"source /root/anaconda3/etc/profile.d/conda.sh && conda activate ${CONDA_ENV_NAME} && ray status\"\n}\n\n# Main\ncase \"${1}\" in\n    --stop)\n        stop_cluster\n        ;;\n    *)\n        start_cluster\n        ;;\nesac\n"
  },
  {
    "path": "verl_distillation/pyproject.toml",
    "content": "# -------------------------------\n# build-system\n# -------------------------------\n[build-system]\nrequires = [\n    \"setuptools>=61.0\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n\n# -------------------------------\n# project (PEP 621 metadata)\n# -------------------------------\n[project]\nname = \"verl\"\n# We'll mark the version as \"dynamic\" because it's read from the file \"verl/version/version\" \n# (PEP 621 calls this \"dynamic version\"). \n# The actual version is specified in the [tool.setuptools.dynamic] section below.\ndynamic = [\"version\", \"dependencies\", \"optional-dependencies\", \"authors\", \"urls\"]\n\ndescription = \"verl: Volcano Engine Reinforcement Learning for LLM\"\nlicense = {text = \"Apache-2.0\"}  # Changed from file to text format\nreadme = {file = \"README.md\", content-type = \"text/markdown\"}\nrequires-python = \">=3.10\"\n\n# -------------------------------\n# tool.ruff - Linting configuration\n# -------------------------------\n[tool.ruff]\n# Note: While the formatter will attempt to format lines such that they remain within the line-length,\n# it isn't a hard upper bound, and formatted lines may exceed the line-length.\nline-length = 120\nexclude = [\"tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\", \"scripts/legacy_model_merger.py\"]\n\n[tool.ruff.lint]\nisort = {known-first-party = [\"verl\"]}\n# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml\nselect = [\n    # pycodestyle\n    \"E\",\n    # Pyflakes\n    \"F\",\n    # pyupgrade\n    \"UP\",\n    # flake8-bugbear\n    \"B\",\n    # isort\n    \"I\",\n    \"G\",\n]\nignore = [\n    # star imports\n    \"F405\", \"F403\",\n    # lambda expression assignment\n    \"E731\",\n    # Loop control variable not used within loop body\n    \"B007\",\n    # f-string format\n    \"UP032\",\n    # `.log()` statement uses f-string\n    \"G004\",\n    # X | None for type annotations\n    \"UP045\",\n    # deprecated import\n    \"UP035\",\n]\n\n# -------------------------------\n# tool.mypy - typechecking config\n# -------------------------------\n[tool.mypy]\npretty            = true\nignore_missing_imports = true\nexplicit_package_bases = true\nfollow_imports = \"skip\"\n\n# Blanket silence\nignore_errors = true\n\n[[tool.mypy.overrides]]\nmodule = [\n\"verl.trainer.config.algorithm\",\n\"verl.trainer.ppo.core_algos\",\n\"verl.trainer.ppo.reward\",\n\"verl.workers.reward_manager\",\n\"verl.workers.reward_manager.*\",\n]\nignore_errors = false\n\n# -------------------------------\n# tool.setuptools - Additional config\n# -------------------------------\n[tool.setuptools]\n# True means `setuptools` will attempt to include all relevant files in package_data automatically.\n# This corresponds to `include_package_data=True` in setup.py.\ninclude-package-data = true\n\n# We read the version from a file in 'verl/version/version'\n[tool.setuptools.dynamic]\nversion = {file = \"verl/version/version\"}\n\n# If you need to mimic `package_dir={'': '.'}`:\n[tool.setuptools.package-dir]\n\"\" = \".\"\n\n# If you need to include specific non-Python data (like YAML files or version file):\n# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']}\n[tool.setuptools.package-data]\nverl = [\n  \"version/*\",\n  \"trainer/config/*.yaml\",\n  \"trainer/config/*/*.yaml\",\n]\n\"recipe.onpolicy_distill\" = [\n  \"config/*.yaml\",\n]\n"
  },
  {
    "path": "verl_distillation/recipe/README.md",
    "content": "# Recipe\nThe examples under `recipes/` are representative extensions to verl for specific end-to-end RL training recipes.\nThe help the community reproduce experiments, verl team provides a snapshot of the codebase when each recipe is initially PR'ed to verl main. You can find them via [github branches](https://github.com/volcengine/verl/branches/all?query=recipe)\n\n# Awesome work using verl\n\n- [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset. ![GitHub Repo stars](https://img.shields.io/github/stars/Unakar/Logic-RL)\n- [Seed-Coder](https://github.com/ByteDance-Seed/Seed-Coder): RL training of Seed-Coder boosts performance on competitive programming ![GitHub Repo stars](https://img.shields.io/github/stars/ByteDance-Seed/Seed-Coder)\n- [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195)\n- [s3](https://github.com/pat-jj/s3) **Efficient Yet Effective** Search Agent Training via RL ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/s3)\n- [Rec-R1](https://arxiv.org/pdf/2503.24289): Bridging Generative Large Language Models and Recommendation Systems via Reinforcement Learning\n- [Explore RL Data Scaling](https://arxiv.org/abs/2503.22230): Exploring Data Scaling Trends and Effects in Reinforcement Learning from Human Feedback\n- [FIRE](https://arxiv.org/abs/2410.21236): Flaming-hot initiation with regular execution sampling for large language models\n- [DQO](https://arxiv.org/abs/2410.09302): Enhancing multi-Step reasoning abilities of language models through direct Q-function optimization\n- [ProRL](https://arxiv.org/abs/2505.24864): Prolonged Reinforcement Learning Expands Reasoning Boundaries in Large Language Models\n- [cognition-engineering](https://github.com/gair-nlp/cognition-engineering): Test time scaling drives cognition engineering. ![GitHub Repo stars](https://img.shields.io/github/stars/gair-nlp/cognition-engineering)\n- [Trust Region Preference Approximation](https://github.com/XueruiSu/Trust-Region-Preference-Approximation): A simple and stable **reinforcement learning algorithm** for LLM reasoning. 
![GitHub Repo stars](https://img.shields.io/github/stars/XueruiSu/Trust-Region-Preference-Approximation)\n- [AdaRFT](https://github.com/uscnlp-lime/verl): Efficient Reinforcement Finetuning via **Adaptive Curriculum Learning** ![GitHub Repo stars](https://img.shields.io/github/stars/uscnlp-lime/verl)\n- [critic-rl](https://github.com/HKUNLP/critic-rl): LLM critics for code generation ![GitHub Repo stars](https://img.shields.io/github/stars/HKUNLP/critic-rl)\n- [self-rewarding-reasoning-LLM](https://arxiv.org/pdf/2502.19613): self-rewarding and correction with **generative reward models** ![GitHub Repo stars](https://img.shields.io/github/stars/RLHFlow/Self-rewarding-reasoning-LLM)\n- [DeepEnlighten](https://github.com/DolbyUUU/DeepEnlighten): Reproduce R1 with **social reasoning** tasks and analyze key findings ![GitHub Repo stars](https://img.shields.io/github/stars/DolbyUUU/DeepEnlighten)\n- [MetaSpatial](https://github.com/PzySeere/MetaSpatial): Reinforcing **3D Spatial Reasoning** in **VLMs** for the **Metaverse** ![GitHub Repo stars](https://img.shields.io/github/stars/PzySeere/MetaSpatial)\n- [PURE](https://github.com/CJReinforce/PURE): **Credit assignment** is the key to successful reinforcement fine-tuning using **process reward model** ![GitHub Repo stars](https://img.shields.io/github/stars/CJReinforce/PURE)\n- [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs ![GitHub Repo stars](https://img.shields.io/github/stars/kanishkg/cognitive-behaviors)\n- [deepscaler](https://github.com/agentica-project/rllm/tree/deepscaler): iterative context scaling with GRPO ![GitHub Repo stars](https://img.shields.io/github/stars/agentica-project/deepscaler)\n- [DAPO](https://dapo-sia.github.io/): the fully open source SOTA RL algorithm that beats DeepSeek-R1-zero-32B ![GitHub Repo stars](https://img.shields.io/github/stars/volcengine/verl)\n- [NoisyRollout](https://github.com/NUS-TRAIL/NoisyRollout): Reinforcing Visual Reasoning with Data Augmentation ![GitHub Repo stars](https://img.shields.io/github/stars/NUS-TRAIL/NoisyRollout)\n"
  },
  {
    "path": "verl_distillation/recipe/__init__.py",
    "content": "# This file makes `recipe` a regular Python package so that entrypoints like\n# `python -m recipe.onpolicy_distill.main_onpolicy_distill` work reliably after installation.\n\n\n"
  },
  {
    "path": "verl_distillation/recipe/char_count/README.md",
    "content": "# Char Count\n## Introduction\nChar count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB.\n\n## Problem formulation\nThe prompt is: \"How many {char} are there in {word}?\". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example,\n\n```text\nQuestion: How many n are there in n-i-n-e?\nAnswer:\nn = n\ni != n\nn = n\ne != n\n\\boxed{2}\n```\n\nNote that\n- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer.\n- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box.\n- The task can be verified.\n- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range.\n\n## Scripts\nTo create the dataset, run\n```bash\npython3 create_dataset.py\n```\nWe create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path.\n\nTo run the SFT\n```bash\nbash train_sft.sh\n```\nWe train SFT for 3 epochs. After 3 epochs, the validation score is around 0.12.\n\nTo run GRPO\n```bash\nbash train_grpo.sh\n```\nWe train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.36.\n"
  },
  {
    "path": "verl_distillation/recipe/char_count/create_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nTask description:\nGiven a random word and a random char, count the number of occurrence of char in the word.\n\nCreate CoT dataset that split the word into separate char. Then list the char and count the occurrence.\n\nThe word set comes from shakespeare\n\"\"\"\n\nimport os.path\nimport random\n\nprompt_template = \"How many {} are there in word {}?\"\n\n\ndef generate_random_char():\n    return chr(97 + random.randint(0, 25))\n\n\ndef create_prompt_response(min_length=3, max_length=5):\n    # randomly generate a length\n    word_length = random.randint(min_length, max_length)\n    # randomly generate a target count number. This makes the target number\n    target_count_number = random.randint(1, word_length)\n\n    char_lst = []\n    # generate the word\n    # step 1: generate the target word\n    target_char = generate_random_char()\n\n    for _ in range(target_count_number):\n        char_lst.append(target_char)\n\n    # step 2: generate other words\n    for _ in range(word_length - target_count_number):\n        while True:\n            char = generate_random_char()\n            if char != target_char:\n                char_lst.append(char)\n                break\n\n    # step 3: random permute char_lst\n    random.shuffle(char_lst)\n\n    word = \"-\".join(char_lst)\n\n    prompt = prompt_template.format(target_char, word)\n    final_answer = []\n\n    # cot\n    number = 0\n    for i, char in enumerate(char_lst):\n        cot = f\"{char}\"\n        if char != target_char:\n            cot += \" != \"\n        else:\n            cot += \" = \"\n            number += 1\n        cot += f\"{target_char}.\"\n\n        final_answer.append(cot)\n\n    conclusion = f\"\\\\boxed{{{number}}} {target_char} in {word}.\"\n\n    final_answer.append(conclusion)\n\n    final_answer = \"\\n\".join(final_answer)\n\n    return prompt, final_answer\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--total_number\", type=int, default=10000)\n    parser.add_argument(\"--min_length\", type=int, default=5)\n    parser.add_argument(\"--max_length\", type=int, default=20)\n    parser.add_argument(\"--data_path\", type=str, default=\"~/data/char_count\")\n\n    args = vars(parser.parse_args())\n\n    total_number = args[\"total_number\"]\n    min_length = args[\"min_length\"]\n    max_length = args[\"max_length\"]\n    data_path = args[\"data_path\"]\n    data_path = os.path.expanduser(data_path)\n\n    full_output = []\n    for _ in range(total_number):\n        output = create_prompt_response(min_length=min_length, max_length=max_length)\n        full_output.append(output)\n\n    # random reorder\n    random.shuffle(full_output)\n\n    # split for train and test\n    train_split_len = int(0.9 * len(full_output))\n    train_outputs = full_output[:train_split_len]\n    test_output = 
full_output[train_split_len:]\n\n    sft_train_dataset = {\"prompt\": [], \"response\": []}\n\n    for o in train_outputs:\n        sft_train_dataset[\"prompt\"].append(o[0])\n        sft_train_dataset[\"response\"].append(o[1])\n\n    sft_test_dataset = {\"prompt\": [], \"response\": []}\n\n    for o in test_output:\n        sft_test_dataset[\"prompt\"].append(o[0])\n        sft_test_dataset[\"response\"].append(o[1])\n\n    import pandas as pd\n\n    sft_train_dataset = pd.DataFrame(data=sft_train_dataset)\n    sft_test_dataset = pd.DataFrame(data=sft_test_dataset)\n\n    folder = os.path.join(data_path, \"sft\")\n\n    os.makedirs(folder, exist_ok=True)\n\n    sft_train_dataset.to_parquet(os.path.join(folder, \"train.parquet\"))\n    sft_test_dataset.to_parquet(os.path.join(folder, \"test.parquet\"))\n\n    # build RL dataset\n    rl_train_dataset = {\"prompt\": [], \"data_source\": [], \"ability\": [], \"reward_model\": [], \"extra_info\": []}\n\n    rl_test_dataset = {\"prompt\": [], \"data_source\": [], \"ability\": [], \"reward_model\": [], \"extra_info\": []}\n\n    from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed\n\n    for o in train_outputs:\n        prompt = o[0]\n        response = o[1]\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_train_dataset[\"prompt\"].append(prompt_with_template)\n        rl_train_dataset[\"data_source\"].append(\"char_count\")\n        rl_train_dataset[\"ability\"].append(\"other\")\n        rl_train_dataset[\"reward_model\"].append(\n            {\"style\": \"rule\", \"ground_truth\": remove_boxed(last_boxed_only_string(response))}\n        )\n        rl_train_dataset[\"extra_info\"].append({\"response\": response})\n\n    for o in test_output:\n        prompt = o[0]\n        response = o[1]\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_test_dataset[\"prompt\"].append(prompt_with_template)\n        rl_test_dataset[\"data_source\"].append(\"char_count\")\n        rl_test_dataset[\"ability\"].append(\"other\")\n        rl_test_dataset[\"reward_model\"].append(\n            {\"style\": \"rule\", \"ground_truth\": remove_boxed(last_boxed_only_string(response))}\n        )\n        rl_test_dataset[\"extra_info\"].append({\"response\": response})\n\n    rl_train_dataset = pd.DataFrame(data=rl_train_dataset)\n    rl_test_dataset = pd.DataFrame(data=rl_test_dataset)\n\n    folder = os.path.join(data_path, \"rl\")\n\n    os.makedirs(folder, exist_ok=True)\n\n    rl_train_dataset.to_parquet(os.path.join(folder, \"train.parquet\"))\n    rl_test_dataset.to_parquet(os.path.join(folder, \"test.parquet\"))\n"
  },
  {
    "path": "verl_distillation/recipe/char_count/reward_function.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nReward function\n\"\"\"\n\nfrom verl.utils.reward_score import math_reward\n\n\ndef char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None):\n    try:\n        last_boxed_string = math_reward.last_boxed_only_string(solution_str)\n        if last_boxed_string is None:\n            return 0\n        solution = math_reward.remove_boxed(last_boxed_string)\n        if solution == ground_truth:\n            return 1\n        else:\n            return 0\n    except Exception:\n        print(ground_truth, solution_str)\n        return 0\n"
  },
  {
    "path": "verl_distillation/recipe/char_count/train_grpo.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/char_count/rl/train.parquet \\\n    data.val_files=$HOME/data/char_count/rl/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=128 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=./models/sft/global_step_105 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"tensorboard\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='smol135m_grpo' \\\n    trainer.val_before_train=True \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    custom_reward_function.path=recipe/char_count/reward_function.py \\\n    custom_reward_function.name=char_count_reward_function\n"
  },
  {
    "path": "verl_distillation/recipe/char_count/train_sft.sh",
    "content": "set -x\n\nnproc_per_node=1\nsave_path=./models/sft\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/char_count/sft/train.parquet \\\n    data.val_files=$HOME/data/char_count/sft/test.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=response \\\n    data.micro_batch_size_per_gpu=8 \\\n    data.max_length=256 \\\n    data.train_batch_size=256 \\\n    use_remove_padding=True \\\n    model.partial_pretrain=HuggingFaceTB/SmolLM2-135M-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=char_count-sft \\\n    trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \\\n    trainer.total_epochs=3 \\\n    trainer.logger=console"
  },
  {
    "path": "verl_distillation/recipe/collabllm/README.md",
    "content": "# CollabLLM\n\nThis repository implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm).\n\n\nCollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This implementation adapts the original imlpementation to work with the Verl training framework.\n\n## Quick start\n\n### 0. Environment\nMake sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).\n\n### 1. Prepare Your Dataset\n\nFirst, process your dataset using the provided script:\n\n```bash\npython process_dataset.py --dataset <> ... --dataset_type <sft or rl>\n```\n\n\n**Requirements:**\n- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper)\n- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)\n- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository\n\n*Note: Check `process_dataset.py` for example commands and usage.*\n\n### 2. Train Your Model\n\n**(Optional) For Supervised Fine-Tuning (SFT):**\n```bash\nbash train_sft_collabllm.sh\n```\n\n**For Reinforcement Learning (RL):**\n\n```bash\nbash train_rl_collabllm.sh\n```\n\nThe RL script shows an example to train CollabLLM on `math-hard-large`. \n\n- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`. \n- The Multiturn-aware Reward is aggregated from these three conversational-level rewards:\n\n    ```\n    +reward_model.reward_kwargs.metric_weights.accuracy=1 \\\n    +reward_model.reward_kwargs.metric_weights.interactivity=1 \\\n    +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \\\n    ```\n\n    You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via\n    ```\n    +reward_model.reward_kwargs.metric_weights.bleu_score=1 \n    ```\n    which will instead apply bleu score on the sampled future conversations. \n\n## Configuration \nRead [doc](https://verl.readthedocs.io/en/latest/) for detailed configurations.\n\n## Citation\nIf you find CollabLLM useful in your research, please cite the following:\n\n```bibtex\n@inproceedings{collabllm2025,\n    title={CollabLLM: From Passive Responders to Active Collaborators},\n    author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and \n            Gavin Li and Yao Dou and Weixin Cai and James Zou and \n            Jure Leskovec and Jianfeng Gao},\n    booktitle={International Conference on Machine Learning (ICML)},\n    year={2025}\n}\n```\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/collabllm_agent_loop.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom copy import deepcopy\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom recipe.collabllm.utils import is_valid_messages\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopOutput\nfrom verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop\nfrom verl.utils.rollout_trace import rollout_trace_op\nfrom verl.workers.rollout.schemas import Message\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass CollabLLMAgentLoop(ToolAgentLoop):\n    @rollout_trace_op\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        messages = list(kwargs[\"raw_prompt\"])\n        image_data = deepcopy(kwargs.get(\"multi_modal_data\", {}).get(\"image\", None))\n        metrics = {}\n        request_id = uuid4().hex\n        tools_kwargs = kwargs.get(\"tools_kwargs\", {})\n\n        # Initialize interaction if needed\n        interaction = None\n        interaction_kwargs = {}\n        if self.interaction_config_file:\n            interaction_kwargs = kwargs[\"extra_info\"][\"interaction_kwargs\"]\n            if \"name\" not in interaction_kwargs:\n                raise ValueError(\"'name' key is required in interaction_kwargs\")\n            interaction_name = interaction_kwargs[\"name\"]\n            if interaction_name not in self.interaction_map:\n                raise ValueError(\n                    f\"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: \"\n                    f\"{list(self.interaction_map.keys())}\"\n                )\n            interaction = self.interaction_map[interaction_name]\n            await interaction.start_interaction(request_id, **interaction_kwargs)\n        # Create AgentData instance to encapsulate all state\n        agent_data = AgentData(\n            messages=messages,\n            image_data=image_data,\n            metrics=metrics,\n            request_id=request_id,\n            tools_kwargs=tools_kwargs,\n            interaction=interaction,\n            interaction_kwargs=interaction_kwargs,\n        )\n        # for collabllm, first generate the model response\n        await self._handle_pending_state(agent_data, sampling_params)\n\n        status = await self._handle_generating_state(agent_data, sampling_params)\n\n        if status == AgentState.TERMINATED:\n            # tell reward manager to score -1 and skip future interaction\n            # to avoid reward hacking with incomplete messages\n            num_repeats = 0\n        else:\n            # then, collect interaction rollouts\n            num_repeats = self.config.actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts\n\n        interaction_requests = [deepcopy(agent_data) for _ in range(num_repeats)]\n\n        # messages are only used in the collabllm reward manager\n        messages_lst = []\n        for _agent_data in interaction_requests:\n            if not is_valid_messages(_agent_data.messages[-1]):\n                break\n\n            prev_msg_len = len(_agent_data.messages)\n            await self.run_agent_data_loop(_agent_data, sampling_params, AgentState.INTERACTING)\n            messages_lst.append([Message(**msg) for msg in _agent_data.messages])\n\n            if interaction.config.get(\"enable_log\"):\n                print(f\"Assistant: ...{messages_lst[-1][prev_msg_len - 1].content[-100:]}\")\n                print(f\"User:      {messages_lst[-1][prev_msg_len].content[:100]}...\")\n\n        # Finalize output\n        response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]\n        prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]\n        multi_modal_data = {\"image\": agent_data.image_data} if agent_data.image_data is not None else {}\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=agent_data.response_mask[: self.response_length],\n            multi_modal_data=multi_modal_data,\n            response_logprobs=agent_data.response_logprobs[: self.response_length]\n            if agent_data.response_logprobs\n            else None,\n            num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,\n            metrics=agent_data.metrics,\n            extra_fields={\n                \"turn_scores\": agent_data.turn_scores,\n                \"messages\": {\"messages\": messages_lst},  # compatible with sglang interaction\n            },\n        )\n        return output\n\n    async def run_agent_data_loop(self, agent_data: AgentData, sampling_params: dict[str, Any], state: AgentState):\n        \"\"\"\n        Run the agent data loop to process the agent data.\n\n        Args:\n            agent_data (AgentData): The agent data to process.\n            sampling_params (dict[str, Any]): The sampling parameters.\n            state (AgentState): The initial state of the agent.
\n        \"\"\"\n\n        while state != AgentState.TERMINATED:\n            if state == AgentState.PENDING:\n                state = await self._handle_pending_state(agent_data, sampling_params)\n            elif state == AgentState.GENERATING:\n                state = await self._handle_generating_state(agent_data, sampling_params)\n            elif state == AgentState.PROCESSING_TOOLS:\n                state = await self._handle_processing_tools_state(agent_data)\n            elif state == AgentState.INTERACTING:\n                state = await self._handle_interacting_state(agent_data)\n            else:\n                logger.error(f\"Invalid state: {state}\")\n                state = AgentState.TERMINATED\n
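\n\n# Flow note (descriptive of the loop above): states advance PENDING -> GENERATING and\n# then branch to PROCESSING_TOOLS, INTERACTING, or TERMINATED. CollabLLM's run() first\n# generates the assistant turn, then enters this loop at INTERACTING up to\n# num_repeat_rollouts times to sample future user turns; the resulting conversations\n# are scored by the collabllm reward manager.\n"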
  },
  {
    "path": "verl_distillation/recipe/collabllm/collabllm_interation.py",
    "content": "# Copyright 2024 CollabLLM Ltd. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport copy\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom recipe.collabllm.utils import remove_think_block\nfrom verl.interactions.base import BaseInteraction\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nTERMINATION_SIGNAL = \"[[TERMINATE CHAT]]\"\nUSER_PROMPT_TEMPLATE = \"\"\"You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario.\n\n## Input Information:\nYou will be provided with:\n- Task Description: The type of task you are trying to accomplish.\n- Complete Prompt or Reference Goal: This field may include the complete user request/query or a reference answer to user's request. Use this field to understand the user's intent, requirements, or what would count as a satisfactory outcome.\n- Chat History: The ongoing conversation between you (as the user) and the AI\n\nInputs:\n<|The Start of Task Description (Not visible to the AI)|>\n{task_desc}\n<|The End of Task Description|>\n\n<|The Start of Complete Prompt or Reference Goal (Not visible to the AI)|>\n{single_turn_prompt}\n<|The End of Complete Prompt or Reference Goal|>\n\n<|The Start of Chat History|>\n{chat_history}\n<|The End of Chat History|>\n\n\n## Guidelines:\n- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat.\n- Minimize Effort: IMPORTANT! As a user, avoid being too detailed in your responses. Provide vague or incomplete demands in the early stages of the conversation to minimize your effort. Let the AI ask for clarification rather than providing everything upfront.\n- Knowledge Background: Reflect the user's knowledge level in the role-playing. If the user is less knowledgeable about a task, they might not notice incorrect statements. Ask questions that demonstrate your current understanding and areas of confusion.\n- Occasionally Make Mistakes: Real-world users might misspell words, provide incorrect dates, give wrong information, or ask unclear questions. Simulate this behavior to reflect natural interactions.\n- Mention Personal Preferences: Include preferences or constraints that might influence your requests or responses. For example, \"I prefer short answers,\" \"I need this done quickly,\" or \"I like detailed comments in code.\"\n- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. 
Redirect the chat back to the main objective if it starts to stray.\n\n## Output Format:\nYou should output a JSON object with three entries:\n- \"current_answer\" (str): Briefly summarize the AI's current solution to the task.\n- \"thought\" (str): Output your thought process as a user deciding what to say next. Consider:\n1. Have you obtained a satisfactory solution from the AI? If yes, you can terminate this chat.\n2. If not, what specific part of the problem or solution are you struggling with?\n3. Has the AI asked you to perform a task or answer a question? If so, how should you approach it?\n4. Are you noticing any patterns or potential misunderstandings that need clarification?\n5. If you're stuck, how can you phrase your question to get the most helpful response while demonstrating your current understanding?\n- \"response\" (str): Based on your thought process, respond to the AI as the user you are role-playing. Stop immediately when the user's response is completed.\n\n## Important Notes:\n- Respond Based on Previous Messages: Your responses should be based on the context of the current chat history. Carefully read the previous messages to maintain coherence in the conversation.\n- Conversation Flow: If \"Current Chat History\" is empty, start the conversation from scratch with an initial request. Otherwise, continue based on the existing conversation.\n- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses.\n- Completion Signal: Use \"{termination_signal}\" as your response when you believe your goal has been solved or if you determine the AI cannot help further.\n- Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured.
\n\nRemember to stay in character as a user throughout your response, and follow the instructions and guidelines carefully.\"\"\"  # noqa: E501\n\n\nclass CollabLLMInteraction(BaseInteraction):\n    \"\"\"A demo interaction for calculating the reward of CollabLLM.\n\n    - `start_interaction`: start an interaction instance for a trajectory.\n    - `generate_response`: generate the response of the assistant.\n    - `calculate_score`: calculate the score of the interaction.\n    - `finalize_interaction`: finalize the interaction instance.\n    \"\"\"\n\n    def __init__(self, config: dict):\n        super().__init__(config)\n        _config = copy.deepcopy(config)\n\n        _config.pop(\"enable_log\", None)\n\n        self.name = _config.pop(\"name\")\n        self.user_model = _config.pop(\"user_model\")\n\n        self.termination_signal = _config.pop(\"termination_signal\", TERMINATION_SIGNAL)\n        self.num_retries = _config.pop(\"num_retries\", 3)\n\n        self.user_model_kwargs = _config\n\n        self._instance_dict = {}\n\n    async def start_interaction(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        self.interaction_kwargs = kwargs\n        assert \"single_turn_prompt\" in kwargs, \"single_turn_prompt is required in interaction_kwargs\"\n        return instance_id\n\n    @rollout_trace_op\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict]:\n        assert messages[-1][\"role\"] in [\"system\", \"assistant\"], (\n            \"Last message input to the user model must be from system or assistant role\"\n        )\n\n        import litellm\n\n        chat_history = self._parse_messages(messages, strip_sys_prompt=True)\n        prompt = USER_PROMPT_TEMPLATE.format(\n            task_desc=self.interaction_kwargs.get(\"task_desc\", \"general assistance task\"),\n            single_turn_prompt=self.interaction_kwargs[\"single_turn_prompt\"],\n            chat_history=chat_history,\n            termination_signal=self.termination_signal,\n        )\n        response = \"\"\n        for i in range(self.num_retries):\n            try:\n                full_response = (\n                    (\n                        await litellm.acompletion(\n                            model=self.user_model,\n                            messages=[{\"role\": \"user\", \"content\": prompt}],\n                            **self.user_model_kwargs,\n                        )\n                    )\n                    .choices[0]\n                    .message.content\n                )\n            except litellm.RateLimitError as e:\n                logger.warning(f\"[CollabLLMInteraction] hit RateLimitError: {e}. Retrying...\")\n
Retrying...\")\n                await asyncio.sleep(max(2**i, 60))\n                continue\n            except Exception as e:\n                logger.exception(f\"An unexpected error occurred in CollabLLMAgentLoop: {e}\")\n                continue\n\n            try:\n                if isinstance(full_response, str):\n                    full_response = extract_json(full_response)\n            except Exception as e:\n                logger.warning(f\"[CollabLLMInteraction] Error extracting JSON: {e}. Retrying...\")\n                continue\n\n            if isinstance(full_response, dict):\n                keys = full_response.keys()\n                if {\"current_answer\", \"thought\", \"response\"}.issubset(keys):\n                    response = full_response.pop(\"response\")\n                    if isinstance(response, str):\n                        break\n                    else:\n                        logger.warning(\n                            f\"[CollabLLMInteraction] got an invaild response {response} full_response {full_response}. \\\n                                Retrying...\"\n                        )\n                        continue\n                else:\n                    logger.warning(f\"[CollabLLMInteraction] Keys {keys} do not match expected keys. Retrying...\")\n                    continue\n\n        self._instance_dict[instance_id][\"response\"] = response\n        logger.debug(f\"[CollabLLMInteraction] User: {response}\")\n        should_terminate_sequence = self.termination_signal in response\n        reward = 0.0\n\n        return should_terminate_sequence, response, reward, {}\n\n    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n\n    def _parse_messages(self, messages, strip_sys_prompt=True):\n        if messages is None:\n            return \"\"\n\n        if strip_sys_prompt:\n            messages = [msg for msg in messages if msg[\"role\"] != \"system\"]\n\n        messages = [remove_think_block(msg) for msg in messages]\n\n        chat = \"\\n\".join(f\"**{m['role'].capitalize()}**: {m['content']}\" for m in messages)\n\n        return chat\n\n\ndef extract_json(s):\n    def convert_value(value):\n        true_values = {\"true\": True, \"false\": False, \"null\": None}\n        value_lower = value.lower()\n        if value_lower in true_values:\n            return true_values[value_lower]\n        try:\n            if \".\" in value or \"e\" in value.lower():\n                return float(value)\n            else:\n                return int(value)\n        except ValueError:\n            return value  # Return as string if not a number\n\n    def parse_number(s, pos):\n        start = pos\n        while pos < len(s) and s[pos] in \"-+0123456789.eE\":\n            pos += 1\n        num_str = s[start:pos]\n        try:\n            if \".\" in num_str or \"e\" in num_str.lower():\n                return float(num_str), pos\n            else:\n                return int(num_str), pos\n        except ValueError:\n            logger.error(f\"Invalid number at position {start}: {num_str}\")\n            raise\n\n    def skip_whitespace(s, pos):\n        while pos < len(s) and s[pos] in \" \\t\\n\\r\":\n            pos += 1\n        return pos\n\n    def parse_string(s, pos):\n        quote_char = s[pos]\n        assert quote_char in ('\"', \"'\")\n        pos += 1\n        result = \"\"\n        while pos < len(s):\n            c = s[pos]\n            if c == \"\\\\\":\n      
          pos += 1\n                if pos >= len(s):\n                    raise ValueError(\"Invalid escape sequence\")\n                c = s[pos]\n                escape_sequences = {\"n\": \"\\n\", \"t\": \"\\t\", \"r\": \"\\r\", \"\\\\\": \"\\\\\", quote_char: quote_char}\n                result += escape_sequences.get(c, c)\n            elif c == quote_char:\n                pos += 1\n                # Attempt to convert to a number if possible\n                converted_value = convert_value(result)\n                return converted_value, pos\n            else:\n                result += c\n            pos += 1\n        raise ValueError(\"Unterminated string\")\n\n    def parse_key(s, pos):\n        pos = skip_whitespace(s, pos)\n        if s[pos] in ('\"', \"'\"):\n            key, pos = parse_string(s, pos)\n            return key, pos\n        else:\n            raise ValueError(f\"Expected string for key at position {pos}\")\n\n    def parse_object(s, pos):\n        obj = {}\n        assert s[pos] == \"{\"\n        pos += 1\n        pos = skip_whitespace(s, pos)\n        while pos < len(s) and s[pos] != \"}\":\n            pos = skip_whitespace(s, pos)\n            key, pos = parse_key(s, pos)\n            pos = skip_whitespace(s, pos)\n            if pos >= len(s) or s[pos] != \":\":\n                raise ValueError(f'Expected \":\" at position {pos}')\n            pos += 1\n            pos = skip_whitespace(s, pos)\n            value, pos = parse_value(s, pos)\n            obj[key] = value\n            pos = skip_whitespace(s, pos)\n            if pos < len(s) and s[pos] == \",\":\n                pos += 1\n                pos = skip_whitespace(s, pos)\n            elif pos < len(s) and s[pos] == \"}\":\n                break\n            elif pos < len(s) and s[pos] != \"}\":\n                raise ValueError(f'Expected \",\" or \"}}\" at position {pos}')\n        if pos >= len(s) or s[pos] != \"}\":\n            raise ValueError(f'Expected \"}}\" at position {pos}')\n        pos += 1\n        return obj, pos\n\n    def parse_array(s, pos):\n        lst = []\n        assert s[pos] == \"[\"\n        pos += 1\n        pos = skip_whitespace(s, pos)\n        while pos < len(s) and s[pos] != \"]\":\n            value, pos = parse_value(s, pos)\n            lst.append(value)\n            pos = skip_whitespace(s, pos)\n            if pos < len(s) and s[pos] == \",\":\n                pos += 1\n                pos = skip_whitespace(s, pos)\n            elif pos < len(s) and s[pos] == \"]\":\n                break\n            elif pos < len(s) and s[pos] != \"]\":\n                raise ValueError(f'Expected \",\" or \"]\" at position {pos}')\n        if pos >= len(s) or s[pos] != \"]\":\n            raise ValueError(f'Expected \"]\" at position {pos}')\n        pos += 1\n        return lst, pos\n\n    def parse_triple_quoted_string(s, pos):\n        if s[pos : pos + 3] == \"'''\":\n            quote_str = \"'''\"\n        elif s[pos : pos + 3] == '\"\"\"':\n            quote_str = '\"\"\"'\n        else:\n            raise ValueError(f\"Expected triple quotes at position {pos}\")\n        pos += 3\n        result = \"\"\n        while pos < len(s):\n            if s[pos : pos + 3] == quote_str:\n                pos += 3\n                # Attempt to convert to a number if possible\n                converted_value = convert_value(result)\n                return converted_value, pos\n            else:\n                result += s[pos]\n                pos += 1\n        raise 
ValueError(\"Unterminated triple-quoted string\")\n\n    def parse_value(s, pos):\n        pos = skip_whitespace(s, pos)\n        if pos >= len(s):\n            raise ValueError(\"Unexpected end of input\")\n        if s[pos] == \"{\":\n            return parse_object(s, pos)\n        elif s[pos] == \"[\":\n            return parse_array(s, pos)\n        elif s[pos : pos + 3] in (\"'''\", '\"\"\"'):\n            return parse_triple_quoted_string(s, pos)\n        elif s[pos] in ('\"', \"'\"):\n            return parse_string(s, pos)\n        elif s[pos : pos + 4].lower() == \"true\":\n            return True, pos + 4\n        elif s[pos : pos + 5].lower() == \"false\":\n            return False, pos + 5\n        elif s[pos : pos + 4].lower() == \"null\":\n            return None, pos + 4\n        elif s[pos] in \"-+0123456789.\":\n            return parse_number(s, pos)\n        else:\n            raise ValueError(f\"Unexpected character at position {pos}: {s[pos]}\")\n\n    json_start = s.index(\"{\")\n    json_end = s.rfind(\"}\")\n    s = s[json_start : json_end + 1]\n\n    s = s.strip()\n    result, pos = parse_value(s, 0)\n    pos = skip_whitespace(s, pos)\n    if pos != len(s):\n        raise ValueError(f\"Unexpected content at position {pos}\")\n    return result\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/config/agent.yaml",
    "content": "- name: collabllm_agent\n  _target_: recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/config/collabllm_interaction_config.yaml",
    "content": "interaction:\n  - name: \"collabllm\"\n    class_name: \"recipe.collabllm.collabllm_interation.CollabLLMInteraction\"\n    config: {\n      \"user_model\": \"gpt-4o-mini\",\n      \"num_retries\": 3,\n      \"max_tokens\": 512,\n      \"temperature\": 1.0,\n      \"enable_log\": True\n    }"
  },
  {
    "path": "verl_distillation/recipe/collabllm/metrics/accuracy.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom recipe.collabllm.utils import extract_json, parse_messages\n\nACCURACY_PROMPT = '''You are a helpful and meticulous evaluator. Your task is to \\\nevaluate the *accuracy* of an AI model's answer to a target question. \\\nYou will be given the target question, the ground truth answer, and the conversation between the AI and the user.\n\nProvided Information:\n\n<|The Start of Target Question and Ground Truth Answer|>\nTarget Question: {single_turn_prompt}\nGround Truth Answer: {ground_truth}\n<|The End of Target Question and Ground Truth Answer|>\n\n<|The Start of The Conversation|>\n{chat_history}\n<|The End of The Conversation|>\n\nYou should determine whether the model's final response to the target question is \\\nfactually correct and consistent with the provided ground truth.\n\nRating criteria (binary):\n  • 1 = Correct   — the response matches the ground truth.\n  • 0 = Incorrect — the response contradicts or misses the ground truth.\n\nOutput format (JSON):\n{{\n    \"thought\": \"<your reasoning here>\",\n    \"accuracy\": <0 or 1>\n}}\n\nDouble check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. 
\\\nUse \" or \"\"\" to wrap up the thought and use single quotes inside the \"thought\" field to avoid JSON escape issues.\n\nYour evaluation:\n'''\n\n\nasync def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):\n    # Check if litellm is available, fallback to openai if not\n    try:\n        import litellm\n\n        use_litellm = True\n    except ImportError:\n        # litellm not found, falling back to openai\n        import openai\n\n        use_litellm = False\n\n    chat_history = parse_messages(messages, strip_sys_prompt=True)\n    prompt = ACCURACY_PROMPT.format(\n        single_turn_prompt=extra_info[\"interaction_kwargs\"][\"single_turn_prompt\"],\n        ground_truth=ground_truth,\n        chat_history=chat_history,\n    )\n\n    if use_litellm:\n        full_response = (\n            (\n                await litellm.acompletion(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n    else:\n        client = openai.AsyncOpenAI()  # Assumes API key is set in environment\n        full_response = (\n            (\n                await client.chat.completions.create(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n\n    full_response = extract_json(full_response)\n\n    assert isinstance(full_response, dict), f\"Expected a dict, got {type(full_response)}\"\n    assert {\"accuracy\", \"thought\"}.issubset(full_response.keys()), (\n        f\"Expected keys not found from {full_response.keys()}\"\n    )\n\n    accuracy = full_response.pop(\"accuracy\")\n    return float(accuracy)\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/metrics/bleu_score.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom nltk.translate.bleu_score import sentence_bleu\n\nfrom recipe.collabllm.utils import extract_json, parse_messages\n\nEXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \\\nYour task is to extract the final and complete version of a document that was generated during \\\na multiturn conversation between a user and a chat assistant. \\\nThe extracted content should reflect the final and comprehensive response provided by the assistant \\\nbased on the user’s request.\n\nYou will be provided with the conversation:\n\n<|The Start of The Conversation|>\n{chat_history}\n<|The End of The Conversation|>\n\nInstructions for Extraction:\n\n1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts \\\nof the content provided by the assistant. This may include:\n   - Different sections of text (e.g., an essay, report, or article).\n\n2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \\\nensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \\\noutput that incorporates all modifications and additions made during the conversation. For example, if the assistant \\\nwrites an introducation at the beginning and move on to the conclusion, the final output should include both the \\\nintroduction and the conclusion.\n\n3. Focus on Completeness:\n   - For text-based documents: Ensure that the extracted content is comprehensive and represents the full document \\\n     or section as discussed in the conversation.\n\nYou should output a JSON object with two entries:\n- \"thought\" (str): Output your thought process when extracting the final content. \n   1. How do different parts of the conversation contribute to the final output?\n   2. How do you make sure you included the most updated and complete information?\n   3. How do you make sure you did not include any information that is not necessary?\n- \"final_completion\" (str): The final and complete version of the document extracted from the conversation.\n\nNote: \n1. If there are multiple lines, you should use triple quotes (\"\"\") to wrap the content. For example, \\\n   \"final_completion\": \"\"\"first line. \n   second line.\"\"\" or \"thought\": \"\"\"first line;\n   second line.\"\"\".\n2. In the \"final_completion\" entry, replace all double quotes (\") with single quotes (') to prevent JSON formatting \\\nissues. For example, you can output \"final_completion\": \"'Hello World' is a common phrase.\" \n\nTake a deep breath and carefully follow the instructions and guidelines provided. 
\n\n\nasync def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):\n    # Check if litellm is available, fall back to openai if not\n    try:\n        import litellm\n\n        use_litellm = True\n    except ImportError:\n        # litellm not found, falling back to openai\n        import openai\n\n        use_litellm = False\n\n    chat_history = parse_messages(messages, strip_sys_prompt=True)\n    prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(chat_history=chat_history)\n\n    if use_litellm:\n        full_response = (\n            (\n                await litellm.acompletion(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n    else:\n        client = openai.AsyncOpenAI()  # Assumes API key is set in environment\n        full_response = (\n            (\n                await client.chat.completions.create(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n\n    full_response = extract_json(full_response)\n\n    assert isinstance(full_response, dict), f\"Expected a dict, got {type(full_response)}\"\n    assert {\"final_completion\", \"thought\"}.issubset(full_response.keys()), (\n        f\"Expected keys not found from {full_response.keys()}\"\n    )\n\n    final_completion = full_response.pop(\"final_completion\")\n\n    # sentence_bleu expects tokenized input (a list of reference token lists and a\n    # hypothesis token list), so split both strings on whitespace before scoring\n    bleu = sentence_bleu([ground_truth.split()], final_completion.split())\n    return float(bleu)\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/metrics/interactivity.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom recipe.collabllm.utils import extract_json, parse_messages\n\nINTERACTIVITY_PROMPT = '''You are a helpful and meticulous conversation evaluator. \\\nYour task is to evaluate the interactivity of the responses provided by an AI assistant \\\nto user questions in a given conversation:\n\n<|The Start of the Conversation to be Evaluated|>\n{chat_history}\n<|The End of the Conversation to be Evaluated|>\n\nYou should assess the assistant's engagement, clarity, and ability to understand the user's needs. \\\nGive a float number between 0 and 1. \n\nScoring Criteria:\n- Let U = user understanding & response clarity ∈ [0,1]  \n  - 1.0 = Fully understands the user's intent and gives a clear answer.  \n  - 0.7 = Mostly understands and the answer is generally clear.  \n  - 0.3 = Partially misunderstands or the answer is hard to follow.  \n  - 0.0 = Misunderstands the intent and gives an unclear or irrelevant answer.\n- Let Q = clarification in [0,1]\n  - 1.0 = Asks precise, necessary clarifying questions when needed.\n  - 0.7 = Asks somewhat helpful but incomplete clarifications.\n  - 0.3 = Only asks generic questions (e.g., “Does that help?”).\n  - 0.0 = Asks no clarifying questions when needed.\n- Let S = suggestion helpfulness in [0,1]\n  - 1.0 = Provides useful, actionable suggestions.\n  - 0.7 = Suggestions are somewhat helpful but limited.\n  - 0.3 = Suggestions are vague or generic.\n  - 0.0 = No suggestions when they would clearly help.\nscore = average([U, Q, S])\n\nOutput format (JSON):\n{{\n    \"thought\": \"<How interactive is the assistant?>\",\n    \"interactivity\": <score>\n}}\n\nDouble check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \\\nUse \" or \"\"\" to wrap up the thought. You should not use other triple quotes inside the \"thought\" field. 
\\\nInstead you should use single quotes to avoid JSON escape issues.\n\nYour evaluation:\n'''\n\n\nasync def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):\n    # Check if litellm is available, fallback to openai if not\n    try:\n        import litellm\n\n        use_litellm = True\n    except ImportError:\n        # litellm not found, falling back to openai\n        import openai\n\n        use_litellm = False\n\n    chat_history = parse_messages(messages, strip_sys_prompt=True)\n    prompt = INTERACTIVITY_PROMPT.format(chat_history=chat_history)\n\n    if use_litellm:\n        full_response = (\n            (\n                await litellm.acompletion(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n    else:\n        client = openai.AsyncOpenAI()  # Assumes API key is set in environment\n        full_response = (\n            (\n                await client.chat.completions.create(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n\n    full_response = extract_json(full_response)\n\n    assert isinstance(full_response, dict), f\"Expected a dict, got {type(full_response)}\"\n    assert {\"interactivity\", \"thought\"}.issubset(full_response.keys()), (\n        f\"Expected keys not found from {full_response.keys()}\"\n    )\n\n    interactivity = full_response.pop(\"interactivity\")\n    return float(interactivity)\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/metrics/pass_rate.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom bigcodebench.eval import untrusted_check\n\nfrom recipe.collabllm.utils import extract_json, parse_messages\n\nEXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \\\nYour task is to extract the final and complete version of a code function {entry_point} that was generated \\\nduring a multiturn conversation between a user and a chat assistant. \\\nThe extracted content should reflect the final and comprehensive response provided by the \\\nassistant based on the user’s request.\n\nYou will be provided with the task and the conversation:\n\n<|The Start of The Task|>\n{single_turn_prompt}\n<|The End of The Task|>\n\n<|The Start of The Conversation|>\n{chat_history}\n<|The End of The Conversation|>\n\nInstructions for Extraction:\n\n1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts of \\\nthe content provided by the assistant. This may include:\n   - Different parts of the code snippet, function, class, or script.\n\n2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \\\nensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \\\noutput that incorporates all modifications and additions made during the conversation. For example, if the assistant \\\nwrites a function at the beginning and changes a part, the final output should take the modification into account.\n\n3. Focus on Completeness:\n   - For code: Extract a complete and functional code snippet, including all necessary components such as imports, \\\n     functions, classes, and any other essential elements. The code should be runnable, but you do not need to \\\n     include any testing examples including the contents after `if __name__ == \"__main__\":`. Only the function code \\\n     is required. \n\nYou should output a JSON object with two entries:\n- \"thought\" (str): Output your thought process when extracting the final content. \n   1. How do different parts of the conversation contribute to the final output?\n   2. How do you make sure you included the most updated and complete information?\n   3. How do you make sure you did not include any information that is not necessary?\n- \"final_completion\" (str): The final and complete version of the code extracted from the conversation. \\\nRename main function name for the task to {entry_point} if needed. Remove any comments wrapped by \"\"\".\n\nNote: \n1. If there are multiple lines, you should use triple quotes (\"\"\") to wrap the content. For example, \\\n   \"final_completion\": \"\"\"first line. \n   second line.\"\"\" or \"thought\": \"\"\"first line;\n   second line.\"\"\". You should not use other triple quotes inside. \n2. 
In the \"final_completion\" entry, replace all double quotes (\") with single quotes (') to prevent JSON formatting \\\n   issues. For example, you can output \"final_completion\": \"'Hello World' is a common phrase.\" \n\nTake a deep breath and carefully follow the instructions and guidelines provided. \n'''\n\n\nasync def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):\n    # Check if litellm is available, fallback to openai if not\n    try:\n        import litellm\n\n        use_litellm = True\n    except ImportError:\n        # litellm not found, falling back to openai\n        import openai\n\n        use_litellm = False\n\n    chat_history = parse_messages(messages, strip_sys_prompt=True)\n\n    prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(\n        chat_history=chat_history,\n        single_turn_prompt=extra_info[\"interaction_kwargs\"][\"single_turn_prompt\"],\n        entry_point=extra_info[\"single_turn_metadata\"][\"entry_point\"],\n    )\n\n    if use_litellm:\n        full_response = (\n            (\n                await litellm.acompletion(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n    else:\n        client = openai.AsyncOpenAI()  # Assumes API key is set in environment\n        full_response = (\n            (\n                await client.chat.completions.create(\n                    messages=[{\"role\": \"user\", \"content\": prompt}],\n                    **kwargs,\n                )\n            )\n            .choices[0]\n            .message.content\n        )\n\n    full_response = extract_json(full_response)\n\n    assert isinstance(full_response, dict), f\"Expected a dict, got {type(full_response)}\"\n    assert {\"final_completion\", \"thought\"}.issubset(full_response.keys()), (\n        f\"Expected keys not found from {full_response.keys()}\"\n    )\n\n    final_completion = full_response.pop(\"final_completion\")\n    metadata = extra_info[\"single_turn_metadata\"]\n    res = untrusted_check(\n        final_completion,\n        metadata[\"test\"],\n        metadata[\"entry_point\"],\n        max_as_limit=300 * 1024,\n        max_data_limit=300 * 1024,\n        max_stack_limit=300 * 1024,\n        min_time_limit=60,\n        gt_time_limit=60,\n    )\n    passed = res[0] == \"pass\"\n\n    # info = res[1] # for printing extra info\n    return float(passed)\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/metrics/token_amount.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef compute_score(data_source, messages, ground_truth, extra_info, **kwargs):\n    prompt = extra_info[\"prompt\"]\n\n    # Calculate the token penalty based on the length of the prompt\n    future_conv = messages[len(prompt) :]\n\n    # simple length estimation\n    total_tokens = sum(len(m.content.split()) for m in future_conv)\n\n    return total_tokens\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/process_dataset.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env python3\n\"\"\"\n# available datasets: \n# math-hard(-large), medium(-large), bigcodebench(-large)\n# to create your own dataset, refer to https://github.com/Wuyxin/collabllm\n\nDATASET=math-hard-large\n\npython recipe/collabllm/process_dataset.py \\\n  --dataset collabllm/collabllm-multiturn-$DATASET  \\\n  --local_dir $HOME/data/collabllm-$DATASET \\\n  --dataset_type sft\n\npython recipe/collabllm/process_dataset.py \\\n  --dataset collabllm/collabllm-multiturn-$DATASET  \\\n  --local_dir $HOME/data/collabllm-$DATASET \\\n  --dataset_type rl\n  \n\nPreprocess collabllm/collabllm-multiturn-math-hard into (ground_truth, extra_info).\n\n- ground_truth: picked from --prefer_field (default: single_turn_completion),\n                falling back to --fallback_field (default: completion)\n- extra_info:   a shallow copy of the original example plus bookkeeping fields\n- reward_model: {\"style\": \"rule\", \"ground_truth\": ground_truth}\n\nSaves one parquet per split into --local_dir and a small JSON preview.\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport uuid\nfrom typing import Any, Optional\n\nfrom datasets import Dataset, concatenate_datasets, load_dataset\n\nSYSTEM_PROMPT = \"\"\"The assistant is designed to be helpful, proactive, and highly interactive.\n\nThe assistant strives to accurately interpret the user's intent throughout the conversation, acknowledging previous\ninteractions to maintain context and continuity. If the user's message is unclear or lacks necessary details, the\nassistant always asks for clarification rather than making assumptions. For example, if the user's request is\nincomplete, the assistant responds with: \"Could you provide more details so I can assist you better?\"\n\nThe assistant asks specific follow-up questions and offers suggestions based on the user's needs, avoiding vague or\ngeneric prompts. It proactively provides guidance and potential next steps, especially in complex tasks such as\nwriting, analysis, coding, and question answering.\n\nThe assistant is mindful of how much content the user needs to read or type, keeping interactions concise and\nefficient. It reduces unnecessary repetition and ensures responses are relevant, well-structured, and free from\nerrors. When presenting options or asking for feedback, the assistant simplifies interactions by offering\nmultiple-choice answers or specific suggestions to make it easier for the user to respond quickly.\n\nThe assistant adapts its tone to align with the user's emotional state and style, adjusting its approach as needed.\nIf uncertain about something, the assistant honestly says, \"I don't know,\" and suggests ways for the user to find\nthe information.\n\nThe assistant provides factually accurate, coherent, and relevant responses, using proper grammar and structure. 
\n\n# Required fields: \"prompt\", \"ground_truth\", \"extra_info\"\n# In \"extra_info\" dict:\n# (1) Required: \"single_turn_prompt\", which is the specific problem used to inform the user simulator,\n# (2) Optional: \"task_desc\" (a short task description),\n# (3) Optional: other fields for customized reward computation\ndef collapse_example(example: dict[str, Any]) -> dict[str, Any]:\n    if \"prompt\" not in example:\n        raise ValueError(\"Missing required 'prompt' field.\")\n\n    ground_truth = (\n        example.get(\"ground_truth\") or example.get(\"single_turn_completion\") or example.get(\"completion\") or \"\"\n    )\n\n    extra_info = {}\n    for k, v in example.items():\n        if k in (\"prompt\", \"ground_truth\", \"extra_info\"):\n            continue\n        extra_info.setdefault(k, v)  # keep extra_info values if keys overlap\n\n    # make sure extra_info has the required fields\n    assert \"single_turn_prompt\" in extra_info, \"Missing 'single_turn_prompt' in extra_info.\"\n\n    # add the system prompt at the beginning of the list\n    example[\"prompt\"] = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + example[\"prompt\"]\n\n    extra_info.setdefault(\"prompt\", example[\"prompt\"])  # save the original prompt\n    extra_info.setdefault(\n        \"interaction_kwargs\",\n        {\n            \"name\": \"collabllm\",\n            \"single_turn_prompt\": extra_info.pop(\"single_turn_prompt\"),\n            \"task_desc\": extra_info.pop(\"task_desc\", \"general ask-for-assistance task\"),\n        },\n    )\n    return {\n        \"prompt\": example[\"prompt\"],\n        \"ground_truth\": ground_truth,\n        \"raw_prompt\": example[\"prompt\"],  # save the original prompt\n        \"extra_info\": extra_info,\n        \"reward_model\": {\"style\": \"rule\", \"ground_truth\": ground_truth},\n        \"data_source\": \"collabllm\",\n        \"agent_name\": \"collabllm_agent\",\n        \"index\": str(uuid.uuid4()),\n    }\n\n\n# ---------- IO helpers ----------\ndef save_parquet(ds_split: Dataset, filename: str, out_dir: str) -> None:\n    os.makedirs(out_dir, exist_ok=True)\n    path = os.path.join(out_dir, f\"{filename}.parquet\")\n    ds_split.to_parquet(path)\n    print(f\"[OK] Wrote {filename}.parquet → {path} ({len(ds_split)} rows)\")\n\n\ndef maybe_copy_to_hdfs(local_dir: str, hdfs_dir: Optional[str]) -> None:\n    if not hdfs_dir:\n        return\n    try:\n        from verl.utils.hdfs_io import copy, makedirs  # type: ignore\n    except Exception as e:\n        print(f\"[WARN] Skipping HDFS copy (verl not available): {e}\")\n        return\n    makedirs(hdfs_dir)\n    copy(src=local_dir, dst=hdfs_dir)\n    print(f\"[OK] Copied {local_dir} → {hdfs_dir}\")\n\n\n# ---------- Main ----------\ndef main():\n    ap = argparse.ArgumentParser()\n    ap.add_argument(\n        \"--dataset\", default=\"collabllm/collabllm-multiturn-math-hard\", help=\"HF dataset path or local dir/file.\"\n    )\n    ap.add_argument(\"--task_desc\", default=\"solving math problems\", help=\"Task description for the dataset.\")\n    ap.add_argument(\"--local_dir\", default=\"~/data/collabllm-math-hard\", help=\"Output directory.\")\n    ap.add_argument(\"--hdfs_dir\", default=None, help=\"Optional HDFS destination (requires verl).\")\n    ap.add_argument(\n
absolute int).\"\n    )\n    ap.add_argument(\"--seed\", type=int, default=42, help=\"Random seed for splitting.\")\n    ap.add_argument(\"--num_proc\", type=int, default=1, help=\"Parallel workers for map().\")\n    ap.add_argument(\"--dataset_type\", default=\"rl\", choices=[\"rl\", \"sft\"], help=\"Type of dataset (e.g., 'rl', 'sft').\")\n    args = ap.parse_args()\n\n    out_dir = os.path.expanduser(args.local_dir)\n    os.makedirs(out_dir, exist_ok=True)\n\n    print(f\"[INFO] Loading dataset: {args.dataset}\")\n    ds_dict = load_dataset(args.dataset)\n    parts = list(ds_dict.values())\n    ds_all: Dataset = parts[0] if len(parts) == 1 else concatenate_datasets(parts)\n    # Dataset({\n    #     features: ['prompt', 'completion', 'conv_id', 'score', 'single_turn_prompt',\n    #       'single_turn_completion', 'single_turn_metadata', 'turn_id', 'sessions', 'rewards'],\n    #     num_rows: xxx\n    # })\n\n    if args.dataset_type == \"rl\":\n        # If multiple splits exist, merge them before collapsing/splitting.\n        ds_all = ds_all.map(lambda x: {\"task_desc\": args.task_desc}, num_proc=args.num_proc)\n\n        print(f\"[INFO] Collapsing to formatted fields on {len(ds_all)} rows…\")\n        ds_all = ds_all.map(\n            function=collapse_example,\n            remove_columns=ds_all.column_names,\n            num_proc=args.num_proc,\n        )\n\n        def dedup_by_prompt(dataset):\n            seen = set()\n            unique_rows = []\n            for ex in dataset:\n                prompt_key = json.dumps(ex[\"prompt\"], sort_keys=True, ensure_ascii=False)\n                if prompt_key not in seen:\n                    seen.add(prompt_key)\n                    unique_rows.append(ex)\n            return Dataset.from_list(unique_rows)\n\n        ds_all = dedup_by_prompt(ds_all)\n\n    elif args.dataset_type == \"sft\":\n        df = ds_all.to_pandas()\n\n        # Sort so that within each conv_id the highest turn_id is first,\n        # and if multiple rows share the same turn_id, the highest score comes first\n        df = df.sort_values([\"conv_id\", \"turn_id\", \"score\"], ascending=[True, False, False])\n\n        # Keep only the top row per conv_id\n        df = df.drop_duplicates(subset=\"conv_id\", keep=\"first\")\n\n        # Back to HF Dataset\n        ds_all = Dataset.from_pandas(df, preserve_index=False)\n\n        # Append assistant response into prompt list\n        def append_completion(example):\n            example[\"prompt\"] = (\n                [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n                + example[\"prompt\"]\n                + [{\"role\": \"assistant\", \"content\": example[\"completion\"]}]\n            )\n            return example\n\n        ds_all = ds_all.map(append_completion)\n\n        # Keep only prompt column\n        cols_to_remove = [col for col in ds_all.column_names if col != \"prompt\"]\n        ds_all = ds_all.remove_columns(cols_to_remove)\n\n    print(f\"[INFO] Splitting with validation_size={args.validation_size}, seed={args.seed}\")\n    split = ds_all.train_test_split(test_size=args.validation_size, seed=args.seed, shuffle=True)\n    train_ds, val_ds = split[\"train\"], split[\"test\"]\n    print(train_ds, val_ds)\n\n    save_parquet(train_ds, f\"{args.dataset_type}_train\", out_dir)\n    save_parquet(val_ds, f\"{args.dataset_type}_validation\", out_dir)\n\n    maybe_copy_to_hdfs(local_dir=out_dir, hdfs_dir=args.hdfs_dir)\n    print(f\"[DONE] {args.dataset_type}_train.parquet and 
{args.dataset_type}_validation.parquet written.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/reward_function.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport importlib.util\nimport os\nimport sys\nfrom typing import Any, Callable, Optional\n\nimport litellm\nimport torch\nfrom transformers import PreTrainedTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager\n\nTERMINATION_SIGNAL = \"[[TERMINATE CHAT]]\"\n\n\nasync def conversation_level_reward_func(\n    data_source, messages, ground_truth, extra_info, metrics, **kwargs\n) -> torch.Tensor:\n    \"\"\"\n    Async version of conversation-level reward function.\n\n    Apply conversation-level reward function to the future interactions between the user simulator\n    and policy model, which are generated from `verl/interactions/collabllm_interation.py`\n    \"\"\"\n    num_retries = kwargs.get(\"num_retries\", 6)\n\n    rewards = {}\n    for metric in metrics:\n        current_dir = os.path.dirname(os.path.abspath(__file__))\n        metric_file_path = os.path.join(current_dir, f\"metrics/{metric}.py\")\n\n        if not os.path.exists(metric_file_path):\n            print(f\"Error: Metric file '{metric_file_path}' not found. Assigning 0 to metric '{metric}'.\")\n            rewards[metric] = 0.0\n            continue\n\n        spec = importlib.util.spec_from_file_location(f\"metric_{metric}\", metric_file_path)\n        if spec is None:\n            print(f\"Error: Could not create spec for metric '{metric}'. Assigning 0 to metric '{metric}'.\")\n            rewards[metric] = 0.0\n            continue\n\n        module = importlib.util.module_from_spec(spec)\n\n        try:\n            sys.modules[f\"metric_{metric}\"] = module\n            assert spec.loader is not None\n            spec.loader.exec_module(module)\n        except Exception as e:\n            print(f\"Error loading metric module from '{metric_file_path}': {e}. Assigning 0 to metric '{metric}'.\")\n            rewards[metric] = 0.0\n            continue\n\n        # Assume each metric file has a compute_score function\n        if not hasattr(module, \"compute_score\"):\n            print(\n                f\"Error: Function 'compute_score' not found in '{metric_file_path}'. 
Assigning 0 to metric '{metric}'.\"\n            )\n            rewards[metric] = 0.0\n            continue\n\n        compute_score_fn = module.compute_score\n\n        # Retry mechanism for calling the metric function\n        for attempt in range(num_retries):\n            try:\n                # Call the metric function (await if it's async)\n                if asyncio.iscoroutinefunction(compute_score_fn):\n                    rewards[metric] = await compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)\n                else:\n                    rewards[metric] = compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs)\n                break  # Success, exit retry loop\n            except Exception as e:\n                if attempt == num_retries - 1:  # Last attempt\n                    print(\n                        f\"Error: Failed to compute metric '{metric}' after {num_retries} attempts. \"\n                        f\"Last error: {e}. Assigning 0 to metric '{metric}'.\"\n                    )\n                    rewards[metric] = 0.0\n                else:\n                    print(f\"Attempt {attempt + 1} failed for metric '{metric}': {e}. Retrying...\")\n                    if isinstance(e, litellm.RateLimitError):\n                        await asyncio.sleep(max(2**attempt, 60))  # Exponential backoff\n\n    # Return dict with metric names as keys\n    return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()}\n\n\n@register(\"collabllm\")\nclass CollabLLMRewardManager(AbstractRewardManager):\n    \"\"\"\n    The Reward Manager used in https://github.com/Wuyxin/collabllm/\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        num_examine: int,\n        metric_weights: dict,\n        llm_judge_kwargs: dict,\n        reward_fn_key: str = \"data_source\",\n        compute_score: Optional[Callable] = None,\n        normalize_by_data_source=False,\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n\n        self.metric_weights = metric_weights\n        self.llm_judge_kwargs = llm_judge_kwargs\n        self.normalize_by_data_source = normalize_by_data_source\n\n        self.metrics = list(self.metric_weights.keys())\n\n    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                return {\"reward_tensor\": data.batch[\"rm_scores\"]}\n            else:\n                return data.batch[\"rm_scores\"]\n        # Use thread-compatible async loop management instead of asyncio.run()\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        try:\n            return loop.run_until_complete(self._compute_rewards_async(data, return_dict))\n        finally:\n            loop.close()\n\n    async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:\n        # batched scoring\n        prompt_ids = data.batch[\"prompts\"]\n        prompt_length = prompt_ids.shape[-1]\n        valid_response_length = data.batch[\"attention_mask\"][:, prompt_length:].sum(dim=-1)\n\n        data_source = data.non_tensor_batch[\"data_source\"]\n        ground_truth = data.non_tensor_batch[\"ground_truth\"]\n        extra_info = data.non_tensor_batch[\"extra_info\"]\n        message_lst = data.non_tensor_batch[\"messages\"]\n\n        # batch the messages into multiple\n        num_repeat_rollouts = len(message_lst[0][\"messages\"])\n        batch_size = len(data_source)\n\n        grouped_messages = [\n            [message_lst[i][\"messages\"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts)\n        ]\n\n        # Flatten lists for all batch items across all rollouts\n        flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]\n        flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]\n        flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)]\n        flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)]\n\n        if num_repeat_rollouts > 0:\n            tasks = [\n                self.compute_score(\n                    flattened_data_sources[i],\n                    flattened_messages[i],\n                    flattened_ground_truths[i],\n                    flattened_extra_infos[i],\n                    self.metrics,\n                    **self.llm_judge_kwargs,\n                )\n                for i in range(len(flattened_data_sources))\n            ]\n            score_dicts = await asyncio.gather(*tasks)\n\n            # Aggregate scores for each metric across repeated rollouts\n            scores_by_metrics = {\n                metric: torch.stack([score_dict[metric] for score_dict in score_dicts])\n                .view(num_repeat_rollouts, -1)\n                .sum(dim=0)\n                for metric in self.metrics\n            }\n\n            # Apply metric-specific weights\n            weighted_scores_by_metrics = {\n                metric: torch.clamp(\n                    scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts,\n                    min=-1.0,\n                    max=1.0,\n                )\n                for metric in self.metrics\n            }\n            # Compute mean of weighted scores for each metric\n            mean_weighted_scores_by_metrics = {\n                metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics\n            }\n\n            # Combine weighted scores from all metrics into a single tensor\n            scores = 
torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0)\n        else:\n            score_dicts = []\n            scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device)\n            mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics}\n\n        print(\"Scores:\", scores, mean_weighted_scores_by_metrics)\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n\n        for i in range(len(data)):\n            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor}\n        else:\n            return reward_tensor\n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/train_rl_collabllm.sh",
    "content": "# Usage: sh recipe/collabllm/train_rl_collabllm.sh <optional resume path>\n\nset -x\n\nPROJECT_DIR=\"$(pwd)\"\nexport VLLM_USE_V1=1\n\nRESUME_PATH=\"${1:-}\"\n\nif [ -z \"$RESUME_PATH\" ]; then\n    RESUME_PATH=null\nfi\n\nDATASET=math-hard-large\nPROJECT_DIR=\"$(pwd)\"\nAGENTLOOP_CONFIG_PATH=\"$PROJECT_DIR/recipe/collabllm/config/agent.yaml\"\n\n\npython3 -m verl.trainer.main_ppo \\\n    trainer.val_before_train=False \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/collabllm-$DATASET/rl_train.parquet \\\n    data.val_files=$HOME/data/collabllm-$DATASET/rl_validation.parquet \\\n    reward_model.reward_manager=collabllm \\\n    +reward_model.reward_kwargs.metric_weights.accuracy=1 \\\n    +reward_model.reward_kwargs.metric_weights.interactivity=1 \\\n    +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \\\n    +reward_model.reward_kwargs.llm_judge_kwargs.model=gpt-4o-mini \\\n    +reward_model.reward_kwargs.llm_judge_kwargs.max_tokens=2048 \\\n    +reward_model.reward_kwargs.llm_judge_kwargs.temperature=0 \\\n    data.train_batch_size=16 \\\n    data.max_prompt_length=8196 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=\"Qwen/Qwen2.5-7B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=8 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.multi_turn.enable=true \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \\\n    actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \\\n    actor_rollout_ref.rollout.agent.agent_loop_config_path=$AGENTLOOP_CONFIG_PATH \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\", \"wandb\"]' \\\n    trainer.project_name=verlxcollabllm \\\n    trainer.experiment_name=collabllm-qwen2.5-7B-$DATASET \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    trainer.save_freq=100 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=20 \\\n    custom_reward_function.path=recipe/collabllm/reward_function.py \\\n    custom_reward_function.name=conversation_level_reward_func \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/recipe/collabllm/config/collabllm_interaction_config.yaml\" \\\n    trainer.resume_from_path=$RESUME_PATH \n"
  },
  {
    "path": "verl_distillation/recipe/collabllm/train_sft_collabllm.sh",
    "content": "#!/bin/bash\nset -x\n\nif [ \"$#\" -lt 1 ]; then\n    echo \"Usage: sft_train_collabllm.sh [<nproc_per_node> other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\n\n# Shift the arguments so $@ refers to the rest\nshift 1\n\nDATASET=math-hard-large\n\ntorchrun --nnodes=1 --nproc_per_node=$nproc_per_node \\\n    -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/collabllm-$DATASET/sft_train.parquet \\\n    data.val_files=$HOME/data/collabllm-$DATASET/sft_validation.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=prompt \\\n    optim.lr=1e-6 \\\n    data.train_batch_size=64 \\\n    data.micro_batch_size_per_gpu=2 \\\n    data.max_length=8196 \\\n    model.partial_pretrain=Qwen/Qwen2.5-7B-Instruct \\\n    trainer.project_name=collabllm-sft-$DATASET \\\n    trainer.experiment_name=collabllm-sft-qwen2.5-7B-$DATASET \\\n    trainer.logger=console \\\n    trainer.total_epochs=3 $@ \\\n    ulysses_sequence_parallel_size=1 \\\n    use_remove_padding=true $@"
  },
  {
    "path": "verl_distillation/recipe/collabllm/utils.py",
    "content": "# Copyright 2025 CollabLLM team and/or its affiliates\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nimport re\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef parse_messages(messages, strip_sys_prompt=True):\n    \"\"\"\n    Args:\n        messages: List[dict]\n            List of dictionaries with keys 'role' and 'content'\n            Example: messages = [{'role': 'user', 'content': 'Hello!'},\n                                 {'role': 'assistant', 'content': 'Hi!'}, ...]\n    \"\"\"\n    if messages is None:\n        return \"\"\n\n    if strip_sys_prompt:\n        messages = strip_system_prompt(messages)\n\n    chat = \"\\n\".join(f\"**{m.role.capitalize()}**: {m.content}\" for m in messages)\n\n    return chat\n\n\ndef strip_system_prompt(messages):\n    \"\"\"\n    Args:\n        messages: List[dict]\n            List of dictionaries with keys 'role' and 'content'\n            Example: messages = [{'role': 'user', 'content': 'Hello!'},\n                                 {'role': 'assistant', 'content': 'Hi!'}, ...]\n    \"\"\"\n    return [msg for msg in messages if msg.role != \"system\"]\n\n\ndef extract_json(s):\n    def convert_value(value):\n        true_values = {\"true\": True, \"false\": False, \"null\": None}\n        value_lower = value.lower()\n        if value_lower in true_values:\n            return true_values[value_lower]\n        try:\n            if \".\" in value or \"e\" in value.lower():\n                return float(value)\n            else:\n                return int(value)\n        except ValueError:\n            return value  # Return as string if not a number\n\n    def parse_number(s, pos):\n        start = pos\n        while pos < len(s) and s[pos] in \"-+0123456789.eE\":\n            pos += 1\n        num_str = s[start:pos]\n        try:\n            if \".\" in num_str or \"e\" in num_str.lower():\n                return float(num_str), pos\n            else:\n                return int(num_str), pos\n        except ValueError:\n            logger.error(f\"Invalid number at position {start}: {num_str}\")\n            raise\n\n    def skip_whitespace(s, pos):\n        while pos < len(s) and s[pos] in \" \\t\\n\\r\":\n            pos += 1\n        return pos\n\n    def parse_string(s, pos):\n        quote_char = s[pos]\n        assert quote_char in ('\"', \"'\")\n        pos += 1\n        result = \"\"\n        while pos < len(s):\n            c = s[pos]\n            if c == \"\\\\\":\n                pos += 1\n                if pos >= len(s):\n                    raise ValueError(\"Invalid escape sequence\")\n                c = s[pos]\n                escape_sequences = {\"n\": \"\\n\", \"t\": \"\\t\", \"r\": \"\\r\", \"\\\\\": \"\\\\\", quote_char: quote_char}\n                result += escape_sequences.get(c, c)\n            elif c == quote_char:\n                pos += 1\n                
# Attempt to convert to a number if possible\n                converted_value = convert_value(result)\n                return converted_value, pos\n            else:\n                result += c\n            pos += 1\n        raise ValueError(\"Unterminated string\")\n\n    def parse_key(s, pos):\n        pos = skip_whitespace(s, pos)\n        if s[pos] in ('\"', \"'\"):\n            key, pos = parse_string(s, pos)\n            return key, pos\n        else:\n            raise ValueError(f\"Expected string for key at position {pos}\")\n\n    def parse_object(s, pos):\n        obj = {}\n        assert s[pos] == \"{\"\n        pos += 1\n        pos = skip_whitespace(s, pos)\n        while pos < len(s) and s[pos] != \"}\":\n            pos = skip_whitespace(s, pos)\n            key, pos = parse_key(s, pos)\n            pos = skip_whitespace(s, pos)\n            if pos >= len(s) or s[pos] != \":\":\n                raise ValueError(f'Expected \":\" at position {pos}')\n            pos += 1\n            pos = skip_whitespace(s, pos)\n            value, pos = parse_value(s, pos)\n            obj[key] = value\n            pos = skip_whitespace(s, pos)\n            if pos < len(s) and s[pos] == \",\":\n                pos += 1\n                pos = skip_whitespace(s, pos)\n            elif pos < len(s) and s[pos] == \"}\":\n                break\n            elif pos < len(s) and s[pos] != \"}\":\n                raise ValueError(f'Expected \",\" or \"}}\" at position {pos}')\n        if pos >= len(s) or s[pos] != \"}\":\n            raise ValueError(f'Expected \"}}\" at position {pos}')\n        pos += 1\n        return obj, pos\n\n    def parse_array(s, pos):\n        lst = []\n        assert s[pos] == \"[\"\n        pos += 1\n        pos = skip_whitespace(s, pos)\n        while pos < len(s) and s[pos] != \"]\":\n            value, pos = parse_value(s, pos)\n            lst.append(value)\n            pos = skip_whitespace(s, pos)\n            if pos < len(s) and s[pos] == \",\":\n                pos += 1\n                pos = skip_whitespace(s, pos)\n            elif pos < len(s) and s[pos] == \"]\":\n                break\n            elif pos < len(s) and s[pos] != \"]\":\n                raise ValueError(f'Expected \",\" or \"]\" at position {pos}')\n        if pos >= len(s) or s[pos] != \"]\":\n            raise ValueError(f'Expected \"]\" at position {pos}')\n        pos += 1\n        return lst, pos\n\n    def parse_triple_quoted_string(s, pos):\n        if s[pos : pos + 3] == \"'''\":\n            quote_str = \"'''\"\n        elif s[pos : pos + 3] == '\"\"\"':\n            quote_str = '\"\"\"'\n        else:\n            raise ValueError(f\"Expected triple quotes at position {pos}\")\n        pos += 3\n        result = \"\"\n        while pos < len(s):\n            if s[pos : pos + 3] == quote_str:\n                pos += 3\n                # Attempt to convert to a number if possible\n                converted_value = convert_value(result)\n                return converted_value, pos\n            else:\n                result += s[pos]\n                pos += 1\n        raise ValueError(\"Unterminated triple-quoted string\")\n\n    def parse_value(s, pos):\n        pos = skip_whitespace(s, pos)\n        if pos >= len(s):\n            raise ValueError(\"Unexpected end of input\")\n        if s[pos] == \"{\":\n            return parse_object(s, pos)\n        elif s[pos] == \"[\":\n            return parse_array(s, pos)\n        elif s[pos : pos + 3] in (\"'''\", '\"\"\"'):\n            
return parse_triple_quoted_string(s, pos)\n        elif s[pos] in ('\"', \"'\"):\n            return parse_string(s, pos)\n        elif s[pos : pos + 4].lower() == \"true\":\n            return True, pos + 4\n        elif s[pos : pos + 5].lower() == \"false\":\n            return False, pos + 5\n        elif s[pos : pos + 4].lower() == \"null\":\n            return None, pos + 4\n        elif s[pos] in \"-+0123456789.\":\n            return parse_number(s, pos)\n        else:\n            raise ValueError(f\"Unexpected character at position {pos}: {s[pos]}\")\n\n    json_start = s.index(\"{\")\n    json_end = s.rfind(\"}\")\n    s = s[json_start : json_end + 1]\n\n    s = s.strip()\n    result, pos = parse_value(s, 0)\n    pos = skip_whitespace(s, pos)\n    if pos != len(s):\n        raise ValueError(f\"Unexpected content at position {pos}\")\n    return result\n\n\ndef remove_think_block(msg: dict):\n    \"\"\"\n    remove <think>.*?</think> from content\n    \"\"\"\n    if \"content\" in msg and isinstance(msg[\"content\"], str):\n        msg[\"content\"] = re.sub(r\"<think>.*?</think>\", \"\", msg[\"content\"], flags=re.DOTALL).strip()\n    return msg\n\n\ndef is_valid_messages(msg: dict) -> bool:\n    \"\"\"\n    check if is valid messages, including:\n    1. <think> is paried with </think>\n    2. is not empty inside and outside <think>\n    3. is not nested, and at most one <think> block is allowed.\n    4. can not be empty if remove ending \"<|im_end|>\"\n    \"\"\"\n    content = msg.get(\"content\")\n    if not isinstance(content, str):\n        return True\n\n    # Base case: empty or whitespace-only content is invalid.\n    if not content.strip():\n        return False\n\n    num_think_open = content.count(\"<think>\")\n    num_think_close = content.count(\"</think>\")\n\n    # Rule 1: Check for paired tags.\n    if num_think_open != num_think_close:\n        return False\n\n    # Rule 3: Allow at most one think block.\n    if num_think_open > 1:\n        return False\n\n    # Case 1: No <think> blocks.\n    if num_think_open == 0:\n        visible_content = content\n    # Case 2: Exactly one <think> block.\n    else:\n        # Rule 2: Check for empty content inside the think block.\n        match = re.search(r\"<think>(.*?)</think>\", content, re.DOTALL)\n        if not match or not match.group(1).strip():\n            return False\n\n        # The \"visible\" content is what's outside the think block.\n        visible_content = re.sub(r\"<think>.*?</think>\", \"\", content, flags=re.DOTALL)\n\n    visible_content = visible_content.strip()\n\n    # Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>.\n    if visible_content.endswith(\"<|im_end|>\"):\n        visible_content = visible_content[: -len(\"<|im_end|>\")]\n\n    if not visible_content.strip():\n        return False\n\n    return True\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/README.md",
    "content": "# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)\n\n> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)\n\n> [!IMPORTANT]\n>\n> **🔥 News!!!**\n>\n> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `recipe/dapo`](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n).\n> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n).\n\n🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)\n\n> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps.\n>\n> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)\n\n## Quickstart\n\n1. Prepare the datasets **on the Ray cluster**:\n\n```bash\nbash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default\n```\n\n2. Submit the job to the Ray cluster **from any machine**:\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"http://${RAY_IP:-localhost}:8265\" # The Ray cluster address to connect to\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml\nexport RUNTIME_ENV=\"./recipe/dapo/runtime_env.yaml\" # This sets environment variables for the Ray cluster\nbash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts\n```\n\n## Reproduction Runs\n\n| Setup                                        | AIME 2024 Acc. 
| Hardware  | Image                                                                | Commit                                                                                       | Environment Variables                                                                                                             | Training Script                                                                                                                                             | Training Record                                                                           |\n| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |\n| DAPO                                         | 52%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh)             | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Dynamic Sampling                    | 50%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Token-level Loss & Dynamic Sampling | 44%            | 16x8xH20  | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix`                    | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n\n> [!IMPORTANT]\n>\n> **📢 Call for Contribution!**\n>\n> Welcome to submit your reproduction runs and setups!\n\n## Configuration\n\n### Separated Clip Epsilons (-> Clip-Higher)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.28\n```\n\n`clip_ratio_low` and `clip_ratio_high` specify the $\\varepsilon_{\\text {low }}$ and $\\varepsilon_{\\text {high }}$ in the DAPO objective.\n\nCore relevant 
code:\n\n```python\npg_losses1 = -advantages * ratio\npg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\npg_losses = torch.maximum(pg_losses1, pg_losses2)\n```\n\n### Dynamic Sampling (with Group Filtering)\n\nAn example configuration:\n\n```yaml\ndata:\n  gen_batch_size: 1536\n  train_batch_size: 512\nalgorithm:\n  filter_groups:\n    enable: True\n    metric: acc # score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 10 # Non-positive values mean no upper limit\n```\n\nSetting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0.\n\nThe trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`.\n\nCore relevant code:\n\n```python\nprompt_bsz = self.config.data.train_batch_size\nif num_prompt_in_batch < prompt_bsz:\n    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')\n    num_gen_batches += 1\n    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')\n        continue\n    else:\n        raise ValueError(\n            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'\n        )\nelse:\n    # Align the batch\n    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n    batch = batch[:traj_bsz]\n```\n\n### Flexible Loss Aggregation Mode (-> Token-level Loss)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n```\n\nSetting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch.\n\nCore relevant code:\n\n```python\nif loss_agg_mode == \"token-mean\":\n    loss = verl_F.masked_mean(loss_mat, loss_mask)\nelif loss_agg_mode == \"seq-mean-token-sum\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n    loss = torch.mean(seq_losses)  # seq-mean\nelif loss_agg_mode == \"seq-mean-token-mean\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n    loss = torch.mean(seq_losses)  # seq-mean\nelse:\n    raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n```\n\n### Overlong Reward Shaping\n\nAn example configuration:\n\n```yaml\ndata:\n  max_response_length: 20480 # 16384 + 4096\nreward_model:\n  overlong_buffer:\n    enable: True\n    len: 4096\n    penalty_factor: 1.0\n```\n\nSetting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit.\n\nSpecifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length - overlong_buffer.len` by `0` to `overlong_buffer.len` tokens.\n\nCore relevant code:\n\n```python\nif self.overlong_buffer_cfg.enable:\n    overlong_buffer_len = self.overlong_buffer_cfg.len\n    expected_len = self.max_resp_len - overlong_buffer_len\n    exceed_len = valid_response_length - expected_len\n    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n    
overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n    reward += overlong_reward\n```\n\n## FAQ\n\n### Where is the \"Overlong Filtering\" in the paper?\n\nMost experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here.\n\n### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)?\n\n[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features.\n\n[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features.\n\n### Why can't I produce similar results after modifications?\n\nRL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve.\n\nWe strongly recommend to only modify one thing at a time.\n\nWe also list some known problems here:\n\n1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation.\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/config/dapo_megatron_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\nreward_model:\n  reward_manager: dapo\n  overlong_buffer: \n    enable: False # We try to avoid forgetting to set enable\n    len: 0\n    penalty_factor: 0.0\n    log: False\n\nalgorithm:\n  filter_groups:\n    _target_: verl.trainer.config.FilterGroupsConfig\n    enable: False # We try to avoid forgetting to set enable\n    metric: null # acc / score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 0 # Non-positive values mean no upper limit\n\ntrainer:\n  project_name: verl-dapo\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/config/dapo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\nreward_model:\n  reward_manager: dapo\n  overlong_buffer: \n    enable: False # We try to avoid forgetting to set enable\n    len: 0\n    penalty_factor: 0.0\n    log: False\n\nalgorithm:\n  filter_groups:\n    _target_: verl.trainer.config.FilterGroupsConfig\n    enable: False # We try to avoid forgetting to set enable\n    metric: null # acc / score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 0 # Non-positive values mean no upper limit\n\ntrainer:\n  project_name: verl-dapo\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/dapo_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler import marked_timer\nfrom verl.utils.rollout_skip import RolloutSkip\n\n\nclass RayDAPOTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict):\n        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n        # recompute old_log_probs\n        with marked_timer(\"old_log_prob\", timing_raw, \"blue\"):\n            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n            entropys = old_log_prob.batch[\"entropys\"]\n            response_masks = batch.batch[\"response_mask\"]\n            loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n            entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n            old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n            metrics.update(old_log_prob_metrics)\n            old_log_prob.batch.pop(\"entropys\")\n            batch = batch.union(old_log_prob)\n\n        if self.use_reference_policy:\n            # compute reference log_prob\n            with marked_timer(\"ref\", timing_raw, \"olive\"):\n                if not self.ref_in_actor:\n                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                else:\n                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                batch = batch.union(ref_log_prob)\n\n        return batch\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            
experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n        self.gen_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        if self.config.actor_rollout_ref.rollout.get(\"skip_rollout\", False):\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        self.gen_steps += 1\n        last_val_metrics = None\n\n        prev_step_profile = False\n        curr_step_profile = (\n            self.global_steps in self.config.global_profiler.steps\n            if self.config.global_profiler.steps is not None\n            else False\n        )\n        next_step_profile = False\n\n        timing_raw = defaultdict(float)\n        batch = None\n        num_prompt_in_batch = 0\n        num_gen_batches = 0\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(\n                        not prev_step_profile and curr_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n\n                new_batch: DataProto = DataProto.from_single_dict(batch_dict)\n                num_gen_batches += 1\n                # pop those keys for generation\n                if \"multi_modal_data\" in new_batch.non_tensor_batch.keys():\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\", \"multi_modal_data\"],\n                    )\n                else:\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\"],\n                    )\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, \"red\"):\n                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        
timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with marked_timer(\"gen_max\", timing_raw, \"red\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            new_batch = new_batch.union(gen_baseline_output)\n                            # compute reward model score on new_batch\n                            rm_scores = None\n                            if self.use_rm and \"rm_scores\" not in new_batch.batch.keys():\n                                rm_scores = self.rm_wg.compute_rm_score(new_batch)\n                                new_batch = new_batch.union(rm_scores)\n                            reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            keys_to_pop = set(gen_baseline_output.batch.keys())\n                            if rm_scores is not None:\n                                keys_to_pop.update(rm_scores.batch.keys())\n                            new_batch.pop(batch_keys=list(keys_to_pop))\n\n                            new_batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del rm_scores, gen_baseline_batch, gen_baseline_output\n\n                    new_batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    new_batch = new_batch.union(gen_batch_output)\n\n                    if self.config.algorithm.use_kl_in_reward:\n                        # We need these metrics for apply_kl_penalty if using kl in reward\n                        new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw)\n                        # otherwise, we will compute those after dynamic sampling\n\n                    with marked_timer(\"reward\", timing_raw, \"yellow\"):\n                        # compute scores. Support both model and function-based.\n                        # We first compute the scores using reward model. 
Then, we call reward_fn to combine\n                        # the results from reward model and rule-based results.\n                        if self.use_rm and \"rm_scores\" not in new_batch.batch.keys():\n                            # we first compute reward model score\n                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)\n                            new_batch = new_batch.union(reward_tensor)\n\n                        # we combine with rule-based rm\n                        reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)\n\n                        new_batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        if reward_extra_infos_dict:\n                            new_batch.non_tensor_batch.update(\n                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                            )\n\n                        # compute rewards. apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            new_batch, kl_metrics = apply_kl_penalty(\n                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(\n                                kl_metrics\n                            )  # TODO: This will be cleared if we use multiple genenration batches\n                        else:\n                            new_batch.batch[\"token_level_rewards\"] = new_batch.batch[\"token_level_scores\"]\n\n                    if not self.config.algorithm.filter_groups.enable:\n                        batch = new_batch\n                    else:  # NOTE: When prompts after filtering is less than train batch size,\n                        # we skip to the next generation batch\n                        metric_name = self.config.algorithm.filter_groups.metric\n                        if metric_name == \"seq_final_reward\":\n                            # Turn to numpy for easier filtering\n                            new_batch.non_tensor_batch[\"seq_final_reward\"] = (\n                                new_batch.batch[\"token_level_rewards\"].sum(dim=-1).numpy()\n                            )\n                        elif metric_name == \"seq_reward\":\n                            new_batch.non_tensor_batch[\"seq_reward\"] = (\n                                new_batch.batch[\"token_level_scores\"].sum(dim=-1).numpy()\n                            )\n\n                        # Collect the sequence reward for each trajectory\n                        prompt_uid2metric_vals = defaultdict(list)\n                        for uid, metric_val in zip(\n                            new_batch.non_tensor_batch[\"uid\"], new_batch.non_tensor_batch[metric_name], strict=True\n                        ):\n                            prompt_uid2metric_vals[uid].append(metric_val)\n\n                        prompt_uid2metric_std = {}\n                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():\n                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)\n\n                        kept_prompt_uids = [\n                            uid\n                            for uid, std in prompt_uid2metric_std.items()\n                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1\n                        ]\n                        num_prompt_in_batch += 
len(kept_prompt_uids)\n\n                        kept_traj_idxs = []\n                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch[\"uid\"]):\n                            if traj_from_prompt_uid in kept_prompt_uids:\n                                kept_traj_idxs.append(idx)\n\n                        new_batch = new_batch[kept_traj_idxs]\n                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])\n\n                        prompt_bsz = self.config.data.train_batch_size\n                        if num_prompt_in_batch < prompt_bsz:\n                            print(f\"{num_prompt_in_batch=} < {prompt_bsz=}\")\n                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n                                print(f\"{num_gen_batches=}. Keep generating...\")\n                                self.gen_steps += 1\n                                is_last_step = self.global_steps >= self.total_training_steps\n                                continue\n                            else:\n                                raise ValueError(\n                                    f\"{num_gen_batches=} >= {max_num_gen_batches=}.\"\n                                    + \" Generated too many. Please check if your data are too difficult.\"\n                                    + \" You could also try set max_num_gen_batches=0 to enable endless trials.\"\n                                )\n                        else:\n                            # Align the batch\n                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n                            batch = batch[:traj_bsz]\n\n                    # === Updating ===\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    if not self.config.algorithm.use_kl_in_reward:\n                        batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, \"cyan\"):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)\n                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)\n                    # IS and mismatch metrics already have mismatch/ prefix\n                    metrics.update(is_metrics)\n\n                    with marked_timer(\"adv\", timing_raw, \"brown\"):\n                        # compute advantages, executed on the 
driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        )\n\n                    # update critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, \"pink\"):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, \"red\"):\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw, \"green\"):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with marked_timer(\"save_checkpoint\", timing_raw, \"green\"):\n                        self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    next_step_profile = (\n                        self.global_steps + 1 in self.config.global_profiler.steps\n                        if self.config.global_profiler.steps is not None\n                        else False\n                    )\n                    self._stop_profiling(\n                        curr_step_profile and not next_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                    prev_step_profile = curr_step_profile\n                    curr_step_profile = next_step_profile\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n               
 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflops and theoretical tflops\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                timing_raw = defaultdict(float)  # clear timing\n\n                metrics[\"train/num_gen_batches\"] = num_gen_batches\n                batch = None\n                num_prompt_in_batch = 0\n                num_gen_batches = 0\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n                self.gen_steps += 1\n        # check if the last step checkpoint exists\n        checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\")\n        if not os.path.exists(checkpoint_dir):\n            # save last step checkpoint\n            timing_raw = defaultdict(float)\n            with marked_timer(\"save_checkpoint\", timing_raw, \"green\"):\n                self._save_checkpoint()\n            metrics = {f\"timing/{k}\": v for k, v in timing_raw.items()}\n            logger.log(data=metrics, step=self.global_steps)\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/main_dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.utils.device import is_cuda_available\n\nfrom .dapo_ray_trainer import RayDAPOTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"dapo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        default_runtime_env = {\n            \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n        }\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    try:\n        if (\n            is_cuda_available\n            and config.global_profiler.tool == \"nsys\"\n            and OmegaConf.select(config.global_profiler, \"steps\") is not None\n            and len(OmegaConf.select(config.global_profiler, \"steps\")) > 0\n        ):\n            nsight_options = OmegaConf.to_container(\n                config.global_profiler.global_tool_config.nsys.controller_nsight_options\n            )\n            runner = TaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n        else:\n            runner = TaskRunner.remote()\n        ray.get(runner.run.remote(config))\n    finally:\n        if ray.is_initialized():\n            ray.shutdown()\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # used for multimodal LLM, could be none\n        processor = hf_processor(local_path, 
trust_remote_code=trust_remote_code, use_fast=True)\n\n        from verl.single_controller.ray import RayWorkerGroup\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            Role.Critic: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        reward_fn = load_reward_manager(\n            config,\n            tokenizer,\n            0,\n            max_resp_len=config.data.max_response_length,\n            overlong_buffer_cfg=config.reward_model.overlong_buffer,\n        )\n\n        # Note that we always use function-based RM for validation\n        val_reward_fn = load_reward_manager(\n            config,\n            tokenizer,\n            1,\n            max_resp_len=config.data.max_response_length,\n            overlong_buffer_cfg=config.reward_model.overlong_buffer,\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RayDAPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n  
      )\n        trainer.init_workers()\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/prepare_dapo_data.sh",
    "content": "#!/usr/bin/env bash\nset -uxo pipefail\n\nexport VERL_HOME=${VERL_HOME:-\"${HOME}/verl\"}\nexport TRAIN_FILE=${TRAIN_FILE:-\"${VERL_HOME}/data/dapo-math-17k.parquet\"}\nexport TEST_FILE=${TEST_FILE:-\"${VERL_HOME}/data/aime-2024.parquet\"}\nexport OVERWRITE=${OVERWRITE:-0}\n\nmkdir -p \"${VERL_HOME}/data\"\n\nif [ ! -f \"${TRAIN_FILE}\" ] || [ \"${OVERWRITE}\" -eq 1 ]; then\n  wget -O \"${TRAIN_FILE}\" \"https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true\"\nfi\n\nif [ ! -f \"${TEST_FILE}\" ] || [ \"${OVERWRITE}\" -eq 1 ]; then\n  wget -O \"${TEST_FILE}\" \"https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true\"\nfi\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_early_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Early-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# An early version for DAPO\nloss_agg_mode=\"seq-mean-token-mean\"\n\nenable_filter_groups=False\ngen_prompt_bsz=512 # NOTE: no filtering here\ntrain_prompt_bsz=512\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=16\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 
\\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen2.5_32b_npu.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO-Qwen2.5-32B'\nexp_name='Qwen2.5-32B-npu-32rank-gbs128'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\nclip_ratio_low=0.2\nclip_ratio_high=0.28\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\nloss_agg_mode=\"token-mean\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\n\nNNODES=2\n\ntrain_prompt_bsz=128\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nPWD=./\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\noffload=True\ngen_tp=4\nenable_chunked_prefill=True\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=\"['console','wandb']\" \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=20 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.device=npu \\\n    trainer.resume_mode=auto \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# Rollout Importance Sampling Example\n# References:\n#   - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n#   - Off-policy RL: https://fengyao.notion.site/off-policy-rl\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-32B-RolloutIS'  # Rollout Importance Sampling\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\n# Rollout Importance Sampling parameters\nrollout_is=True\nrollout_is_threshold=2.0\nrollout_is_threshold_lower=null  # No lower bound\nrollout_is_level=token  # token-level\nrollout_is_mode=truncate  # truncate mode\nrollout_is_veto_threshold=null  # No veto\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\n\n# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)\n#\n# Please note that server mode (agent loop) hasn't returned rollout_log_probs for now,\n# so currently server mode is not supported for Rollout IS.\n#\n# Rollout IS parameters (configured at top of script):\n#   algorithm.rollout_is=True\n#   algorithm.rollout_is_threshold=2.0  # Upper threshold (can be tuned)\n#   algorithm.rollout_is_level=token  # Aggregation level\n#   algorithm.rollout_is_mode=truncate  # Bounding mode\n#   actor_rollout_ref.rollout.calculate_log_probs=True  # Required!\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    algorithm.rollout_is=${rollout_is} \\\n    algorithm.rollout_is_threshold=${rollout_is_threshold} \\\n    algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \\\n    algorithm.rollout_is_level=${rollout_is_level} \\\n    algorithm.rollout_is_mode=${rollout_is_mode} \\\n    algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    
trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen2.5_7b_npu.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO-Qwen2.5-7B-Instruct'\nexp_name='DAPO-Qwen2.5-7B-Instruct'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\nclip_ratio_low=0.2\nclip_ratio_high=0.28\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\nloss_agg_mode=\"token-mean\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\n\nNNODES=1\n\ntrain_prompt_bsz=16\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=1\n\n# Ray\nPWD=./\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-7B-Instruct\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\noffload=True\ngen_tp=1\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.50 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=\"['console']\" \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=20 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.device=npu \\\n    trainer.resume_mode=auto \\\n    actor_rollout_ref.actor.entropy_checkpointing=True \\\n    actor_rollout_ref.ref.entropy_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \\\n    actor_rollout_ref.ref.entropy_from_logits_with_chunking=True"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen3_14b_base_npu.sh",
    "content": "#!/bin/bash\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-14B-Base'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=False\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=16\ngen_prompt_bsz=$((train_prompt_bsz * 2))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=1\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-14B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Performance Related Parameter\nsp_size=2\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\noffload=True\ngen_tp=2\n\nray job submit --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=8 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=['console'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=20 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    data.shuffle=False \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.actor.entropy_checkpointing=True \\\n    actor_rollout_ref.ref.entropy_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    trainer.device=npu\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen3_8b_base_npu.sh",
    "content": "#!/bin/bash\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-8B-Base'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=False\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=16\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=1\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-1}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-8B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Performance Related Parameter\nsp_size=2\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))\noffload=True\ngen_tp=2\n\nray job submit --runtime-env=\"${RUNTIME_ENV}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=['console'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=20 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    data.shuffle=False \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.actor.entropy_checkpointing=True \\\n    actor_rollout_ref.ref.entropy_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    trainer.device=npu\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen3_moe_30b_base_fsdp_npu.sh",
    "content": "#!/usr/bin/env bash\nset -euxo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-MOE-30B-FSDP-128rank-gbs512'\n\nNNODES=8\nNPUS_PER_NODE=16\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\nloss_agg_mode=\"token-mean\"\nppo_mini_batch_size=32\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\n\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=16 # For load-balance. For smaller cluster this can be set to as less as 2.\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / 2))\noffload=True\nrecompute=True\nmax_num_seqs=128\ngen_tp=2\n\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length))  \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=${recompute} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=False \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.device=\"npu\" \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \n   \n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_qwen3_moe_30b_megatron_npu.sh",
    "content": "#!/bin/bash\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-megatron'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=16\ngen_prompt_bsz=$((train_prompt_bsz * 2))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=2\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-1}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B\"}\n# MCORE_MODEL_PATH points to the converted checkpoint.\n# To avoid loading these weights, set actor_rollout_ref.actor.megatron.use_dist_checkpointing=False.\nMCORE_MODEL_PATH=${MCORE_MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-dist_ckpt\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\noffload=True\n\nmax_num_batched_tokens=$((max_prompt_length + max_response_length))\n\n# Megatron backen\ntrain_tp=4\ntrain_ep=2\ntrain_pp=2\ntrain_cp=1\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    --config-name=\"dapo_megatron_trainer\" \\\n    data.filter_overlong_prompts=False \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.shuffle=False \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_epochs=1 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    +actor_rollout_ref.model.override_config.attention_dropout=0. \\\n    +actor_rollout_ref.model.override_config.embd_pdrop=0. \\\n    +actor_rollout_ref.model.override_config.resid_pdrop=0. \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.enable_prefix_caching=False \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \\\n    actor_rollout_ref.rollout.max_model_len=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger=['console'] \\\n    
trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=-1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.device=\"npu\" \\\n    actor_rollout_ref.nccl_timeout=14400 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -euxo pipefail\n# DAPO (w/o Dynamic Sampling)\n\nproject_name='DAPO-verl'\nexp_name='DAPO-wo-DS-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=False\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/runtime_env.yaml",
    "content": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  VLLM_USE_V1: \"1\"\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_7b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7B-Math-Test'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 2))\nenable_overlong_buffer=True\noverlong_buffer_len=512\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=16\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=2 \\\n    trainer.save_freq=2 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_7b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_7b_math_lora.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\n# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=8 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_7b_math_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-megatron-0519a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\ntrain_tp=4\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_dspk_671b_megatron_96gb.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# 0. download the config\n# only need to download the configuration_deepseek.py and config.json\n# remove the `quantization_config` in the `config.json`\n# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported\nhuggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json\n\nproject_name='DAPO'\nexp_name='DAPO-DeepSeek-671b-megatron'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=0.1\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=256 # must be > n_gpus. need to fix\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32  # mini_bsz * n >= micro_bsz * pp * dp\n\nNNODES=${NNODES:-64}\n\n# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main\n# change the MODEL_PATH and MCORE_MODEL_PATH to your own path\n# Paths\nMODEL_PATH=\"<path_to_dsv3_config>\"\nMCORE_MODEL_PATH=\"<path_to_dpsk-v3-671B-BF16-dist_ckpt>\"\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\naime24_test_path=${RAY_DATA_HOME}/data/aime-2024.parquet\n# TEST_FILE=\"['$math500_test_path', '$aime24_test_path']\"\n\nTEST_FILE=\"['$aime24_test_path']\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=32\ntrain_tp=1\ntrain_ep=32\ntrain_pp=16\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_glm_air_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNNODES=${NNODES:-8}\nNGPUS_PER_NODES=${NGPUS_PER_NODES:-8}\n\nproject_name='DAPO'\nexp_name='DAPO-GLM-AIR-MATH-megatron'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=128\ntrain_ppo_micro_batch_size_per_gpu=2\ninfer_ppo_micro_batch_size_per_gpu=2\n# Paths\nMODEL_PATH=/models/zai-org/GLM-4.5-Air-Base\n# GLM Base model can use chat_template.jinja from instruct models\ncp /models/zai-org/GLM-4.5-Air/chat_template.jinja ${MODEL_PATH}/chat_template.jinja\n\nTRAIN_FILE=/data/dapo/dapo-math-17k.parquet\naime24_test_path=/data/dapo/aime-2024.parquet\n# math500_test_path=/data/rlhf/math500/test.parquet\n\n# TEST_FILE=\"['$math500_test_path', '$aime24_test_path']\"\n\nTEST_FILE=\"['$aime24_test_path']\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length)))\noffload=True\n\nCOMMON_PP=${COMMON_PP:-2}\nCOMMON_VPP=${COMMON_VPP:-null}\nCOMMON_CP=${COMMON_CP:-4}\nCOMMON_TP=${COMMON_TP:-2}\nCOMMON_EP=${COMMON_EP:-8}\nCOMMON_ETP=${COMMON_ETP:-1}\n\nTRAIN_TP=${TRAIN_TP:-$COMMON_TP}\nINFER_TP=${INFER_TP:-8}\n\nACTOR_PP=${ACTOR_PP:-$COMMON_PP}\nACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP}\nACTOR_CP=${ACTOR_CP:-$COMMON_CP}\nACTOR_TP=${ACTOR_TP:-$TRAIN_TP}\nACTOR_EP=${ACTOR_EP:-$COMMON_EP}\nACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP}\nROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP}\nREF_PP=${REF_PP:-$COMMON_PP}\nREF_VPP=${REF_VPP:-$COMMON_VPP}\nREF_CP=${REF_CP:-$COMMON_CP}\nREF_TP=${REF_TP:-$TRAIN_TP}\nREF_EP=${REF_EP:-$COMMON_EP}\nREF_ETP=${REF_ETP:-$COMMON_ETP}\nCRITIC_PP=${CRITIC_PP:-$COMMON_PP}\nCRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP}\nCRITIC_CP=${CRITIC_CP:-$COMMON_CP}\nCRITIC_TP=${CRITIC_TP:-$TRAIN_TP}\nCRITIC_EP=${CRITIC_EP:-$COMMON_EP}\nCRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP}\nRM_PP=${RM_PP:-$COMMON_PP}\nRM_VPP=${RM_VPP:-$COMMON_VPP}\nRM_CP=${RM_CP:-$COMMON_CP}\nRM_TP=${RM_TP:-$TRAIN_TP}\nRM_EP=${RM_EP:-$COMMON_EP}\nRM_ETP=${RM_ETP:-$COMMON_ETP}\n\nUSE_MBRIDGE=True\nUSE_DIST_CKPT=False\n\n# Install the latest mbridge\n# pip install --no-cache-dir git+https://github.com/ISEEKYAN/mbridge.git\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.lr_decay_style='constant' \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \\\n    actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=\"selective\" \\\n    actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules=[\"core_attn\",\"moe_act\",\"layernorm\",\"mlp\",\"moe\"] \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_shared_expert_overlap=False \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=\"flex\" \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=False \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.name='vllm' \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n  
  actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODES}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=100 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10"
  },
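The Megatron recipe above pins the model context window to the rollout budget: `max_position_embeddings` is overridden to `max_prompt_length + max_response_length`, and vLLM's `max_num_batched_tokens` is set to the same sum. A minimal sketch of that arithmetic, using an assumed 2k-prompt / 8k-response split (illustrative only; the script's actual values are defined earlier in the file):

```python
# Sketch of the length budget wired into the Megatron recipe above.
# The 2k/8k split is an assumed example, not read from this script.
max_prompt_length = 1024 * 2
max_response_length = 1024 * 8

# The model override caps positions at exactly one full rollout:
max_position_embeddings = max_prompt_length + max_response_length  # 10240

# vLLM's chunked-prefill budget gets the same cap, so a single batch can
# always hold one maximum-length prompt+response sequence.
max_num_batched_tokens = max_prompt_length + max_response_length   # 10240
assert max_position_embeddings == max_num_batched_tokens == 10240
```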
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_qwen3_30b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n   
 actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
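The `use_dynamic_bsz` settings in `test_dapo_qwen3_30b_math.sh` replace fixed micro-batches with per-GPU token budgets derived from the sequence length. A short worked example of the numbers this script produces:

```python
# Worked example of the dynamic-batch token budgets in test_dapo_qwen3_30b_math.sh.
max_prompt_length = 1024 * 2    # 2048
max_response_length = 1024 * 8  # 8192
seq_len = max_prompt_length + max_response_length  # 10240

actor_ppo_max_token_len = seq_len * 2  # 20480 tokens packed per GPU for training
infer_ppo_max_token_len = seq_len * 3  # 30720 tokens per GPU for log-prob passes

# Log-prob passes keep no gradients or optimizer state, which is why the
# script affords them a 3x budget versus 2x for the training forward/backward.
print(actor_ppo_max_token_len, infer_ppo_max_token_len)
```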
  {
    "path": "verl_distillation/recipe/dapo/test_dapo_qwen3_30b_math_single_node.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0719a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 4))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=0.1\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=64\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=16\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=8\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
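Both DAPO scripts expose the soft overlong punishment via `enable_overlong_buffer`, `overlong_buffer_len`, and `overlong_penalty_factor` (enabled in the multi-node script, disabled in the single-node one). Below is a sketch of the scheme as described in the DAPO paper; verl's `dapo` reward manager may differ in detail:

```python
# Hedged sketch of DAPO's soft overlong penalty: responses longer than
# (max_resp_len - buffer_len) are penalized linearly across the buffer.
def overlong_penalty(resp_len: int, max_resp_len: int,
                     buffer_len: int, penalty_factor: float) -> float:
    exceed = resp_len - (max_resp_len - buffer_len)
    if exceed <= 0:
        return 0.0  # within the expected length: no penalty
    return -min(exceed / buffer_len, 1.0) * penalty_factor

# Multi-node settings: max_response_length=8192, buffer=4096, factor=1.0.
print(overlong_penalty(4096, 8192, 4096, 1.0))  #  0.0 (at the boundary)
print(overlong_penalty(6144, 8192, 4096, 1.0))  # -0.5 (halfway into the buffer)
print(overlong_penalty(8192, 8192, 4096, 1.0))  # -1.0 (fully overlong)
```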
  {
    "path": "verl_distillation/recipe/deepeyes/README.md",
    "content": "# DeepEyes: Incentivizing \"Thinking with Images\" via Reinforcement Learning\n\nThis directory contains the implementation for reproducing the DeepEyes paper within the verl framework, supporting multi-turn visual tool calls. This implementation is based on the original [DeepEyes paper](https://arxiv.org/abs/2505.14362) and its [official implementation](https://github.com/Visual-Agent/DeepEyes), integrated with the multi-modal and multi-turn capabilities of the verl framework.\n\n## Reproducing the Experiment\n\n> **Note on the 'Chart' Dataset:**\n> \n> The provided preprocessing script intentionally excludes `data_v0.8_visual_toolbox_v2.parquet`, which contains the 'Chart' data. This subset consists of very high-resolution images, often resembling large figures composed of multiple sub-plots, much like those found in academic papers.\n>\n> Consequently, even after using the zoom-in tool, the resulting cropped images remain large. This poses a significant risk of causing Out-of-Memory (OOM) errors, which can abruptly terminate the training process. \n> \n> **We strongly recommend against training on the 'Chart' dataset on a single node.**\n\n> **Note on the 'thinklite' Dataset:**\n> Many images in the `thinklite` dataset have a very low resolution, with either a height or width below 28 pixels. This fails to meet the minimum input size required by the Qwen-2.5VL image processor and would cause errors during data loading.\n>\n> To mitigate this, we upscale these low-resolution images to satisfy the processor's requirements. However, please be aware that because the original resolution is low, subsequent `crop` operations by the zoom-in tool might frequently trigger exceptions, which could in turn affect the model's tool-use performance.\n\nFirst, launch an inference service to act as a judge for reward calculation. 
You can use the following script as a reference:\n\n```bash\npython -m sglang.launch_server --model-path /path/to/Qwen2.5-72B-Instruct \\\n    --port 18901 \\\n    --tp-size 8 \\\n    --context-length 32768 \\\n    --trust-remote-code \\\n    --log-requests false\n```\n\nNext, you can start the training:\n\n```bash\nbash recipe/deepeyes/run_deepeyes_grpo.sh\n```\n\n## Performance\n\n![score](https://private-user-images.githubusercontent.com/82520804/474784419-b13f4f72-bb3a-4281-a43b-1f34a9037c0c.png)\n\n![entropy](https://private-user-images.githubusercontent.com/82520804/474785253-752106a9-e25d-4b44-aef9-1ac98015d05c.png)\n\n![num_turns](https://private-user-images.githubusercontent.com/82520804/474785462-c99c7952-14db-485a-acd2-14e5956ecc34.png)\n\nSee [Comment](https://github.com/volcengine/verl/pull/2398#issuecomment-3157142856) for more details.\n\nNote: AgentLoop does not record num_tool_calls directly; it records num_turns. In this setup, the number of tool calls can be computed as num_tool_calls = num_turns / 2 - 1.\n\n## References and Acknowledgements\n\n- [DeepEyes Paper](https://arxiv.org/abs/2505.14362)\n- [DeepEyes Official Implementation](https://github.com/Visual-Agent/DeepEyes)\n\n---\nIf you need further details for reproduction or encounter any issues, feel free to open an issue or contact the maintainers."
  },
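The README's closing note gives num_tool_calls = num_turns / 2 - 1. A small worked example under one plausible reading: each tool call contributes an assistant turn plus a tool-response turn, on top of the initial user turn and the final answer turn:

```python
# Worked example of the turn-count formula from the DeepEyes README above.
def num_tool_calls(num_turns: int) -> int:
    return num_turns // 2 - 1

# user -> assistant(tool_call) -> tool_response -> assistant(answer): 4 turns
print(num_tool_calls(4))  # 1 tool call
# Each additional tool_call/tool_response pair adds 2 turns:
print(num_tool_calls(8))  # 3 tool calls
```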
  {
    "path": "verl_distillation/recipe/deepeyes/configs/deepeyes_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 2048\n  max_response_length: 2048\n  train_batch_size: 256\n  return_raw_chat: True\n  return_multi_modal_inputs: False\n  custom_cls:\n    path: \"recipe/deepeyes/deepeyes.py\"\n    name: CustomRLHFDataset\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    custom_chat_template: \"{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\\\n' }}{%- if messages[0]['role'] == 'system' %}{%- if messages[0]['content'] is string %}{{- messages[0]['content'] }}{%- else %}{{- messages[0]['content'][0]['text'] }}{%- endif %}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}{%- for tool in tools %}{{- \\\"\\\\n\\\" }}{{- tool | tojson }}{%- endfor %}{{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif 
%}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n{% endif %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}\"\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      tool_config_path: \"recipe/deepeyes/configs/image_zoom_in_tool_config.yaml\"\n\ncustom_reward_function:\n  path: \"recipe/deepeyes/deepeyes.py\"\n  name: compute_score"
  },
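The config above wires in custom behavior through `{path, name}` pairs (`data.custom_cls` and `custom_reward_function`, both pointing at `recipe/deepeyes/deepeyes.py`). A minimal sketch of how such a pair can be resolved with `importlib`; verl's actual loader may differ:

```python
# Minimal sketch: resolve a {path, name} pair like custom_cls above by loading
# the module from its file path and fetching the named attribute.
import importlib.util

def load_from_file(path: str, name: str):
    spec = importlib.util.spec_from_file_location("custom_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, name)

# dataset_cls = load_from_file("recipe/deepeyes/deepeyes.py", "CustomRLHFDataset")
# reward_fn   = load_from_file("recipe/deepeyes/deepeyes.py", "compute_score")
```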
  {
    "path": "verl_distillation/recipe/deepeyes/configs/image_zoom_in_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.image_zoom_in_tool.ImageZoomInTool\"\n    config:\n      num_workers: 256\n      rate_limit: 256\n      timeout: 60\n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"image_zoom_in_tool\"\n        description: \"Zoom in on a specific region of an image by cropping it based on a bounding box (bbox) and an optional object label.\"\n        parameters:\n          type: \"object\"  \n          properties:\n            bbox_2d:\n              type: \"array\"\n              items:\n                type: \"number\"\n              minItems: 4\n              maxItems: 4\n              description: \"The bounding box of the region to zoom in, as [x1, y1, x2, y2], where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.\"\n            label:\n              type: \"string\"\n              description: \"The name or label of the object in the specified bounding box (optional).\"\n          required: [\"bbox_2d\"]"
  },
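The schema above hands the model a `bbox_2d` of `[x1, y1, x2, y2]` pixel coordinates. As an illustration of the crop such a zoom-in tool performs (a sketch, not verl's `ImageZoomInTool` implementation, which lives in `verl.tools.image_zoom_in_tool`):

```python
# Illustrative crop for the bbox_2d schema above; coordinates are clamped so a
# malformed bounding box from the model cannot crash the tool call.
from PIL import Image

def zoom_in(image: Image.Image, bbox_2d: list[float]) -> Image.Image:
    x1, y1, x2, y2 = bbox_2d
    x1, x2 = max(0, x1), min(image.width, x2)
    y1, y2 = max(0, y1), min(image.height, y2)
    return image.crop((int(x1), int(y1), int(x2), int(y2)))

# zoomed = zoom_in(Image.open("scene.jpg"), [226, 399, 265, 464])
```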
  {
    "path": "verl_distillation/recipe/deepeyes/deepeyes.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport io\nimport logging\nimport os\nimport random\nimport re\n\nimport requests\nfrom openai import OpenAI\nfrom PIL import Image\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.dataset.rl_dataset import RLHFDataset\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\nopenai_api_key = \"EMPTY\"\nopenai_api_base = os.environ.get(\"LLM_AS_A_JUDGE_BASE\", \"http://10.1.100.71:18901/v1\")\n\nclient = OpenAI(\n    api_key=openai_api_key,\n    base_url=openai_api_base,\n)\n\nmodel_name = \"\"\nif openai_api_base:\n    try:\n        response = requests.get(f\"{openai_api_base}/models\")\n        response.raise_for_status()\n        models = response.json()\n        if models.get(\"data\"):\n            model_name = models[\"data\"][0][\"id\"]\n        else:\n            logger.warning(\"No models found at the specified API base for reward scoring.\")\n    except (requests.exceptions.RequestException, KeyError, IndexError) as e:\n        logger.warning(f\"Failed to get model from {openai_api_base}: {e}. Reward scoring will be disabled.\")\n\n\nclass CustomRLHFDataset(RLHFDataset):\n    def __getitem__(self, item):\n        \"\"\"\n        Note that we also return the raw_input_ids so that it can be combined with other chat template\n        \"\"\"\n        row_dict: dict = self.dataframe[item]\n        row_dict[self.prompt_key] = [\n            {\n                \"role\": \"system\",\n                # We don't need tool description, because custom_chat_template will add it.\n                \"content\": (\n                    \"You are a helpful assistant. You can call functions to assist with the user query. \"\n                    \"Important: You must call only one function at a time. 
After each function call, \"\n                    \"wait for the execution result before making the next function call if needed.\"\n                ),\n            },\n            {\n                \"role\": \"user\",\n                \"content\": row_dict[self.prompt_key][1][\"content\"],\n            },\n        ]\n        messages = self._build_messages(row_dict)\n        model_inputs = {}\n\n        if self.processor is not None:\n            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            multi_modal_data = {}\n\n            images = None\n            row_dict_images = row_dict.pop(self.image_key, None)\n            if row_dict_images:\n                images = [Image.open(io.BytesIO(image[\"bytes\"])) for image in row_dict_images]\n\n                # due to the image key is \"image\" instead of \"images\" in vllm, we need to use \"image\" here\n                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205  # noqa: E501\n                multi_modal_data[\"image\"] = images\n\n            model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors=\"pt\")\n\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            if \"second_per_grid_ts\" in model_inputs:\n                model_inputs.pop(\"second_per_grid_ts\")\n\n            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature\n            row_dict[\"multi_modal_data\"] = multi_modal_data\n\n            # We will do batch.union() in the trainer,\n            # so we cannot have \"multi_modal_inputs\" in row_dict if rollout generates new multi_modal_inputs\n            if self.return_multi_modal_inputs:\n                row_dict[\"multi_modal_inputs\"] = dict(model_inputs)\n\n                # second_per_grid_ts isn't used for training, just for mrope\n                row_dict[\"multi_modal_inputs\"].pop(\"second_per_grid_ts\", None)\n\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = verl_F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if self.processor is not None and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__:\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            position_ids = [\n                get_rope_index(\n                    self.processor,\n                    input_ids=input_ids[0],\n                    image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                    video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                    second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                    attention_mask=attention_mask[0],\n                )\n            ]  # (1, 3, seq_len)\n\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n  
      row_dict[\"input_ids\"] = input_ids[0]\n        row_dict[\"attention_mask\"] = attention_mask[0]\n        row_dict[\"position_ids\"] = position_ids[0]\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        row_dict[\"raw_prompt_ids\"] = raw_prompt_ids\n        # encode prompts without chat template\n        if self.return_raw_chat:\n            row_dict[\"raw_prompt\"] = messages\n\n        # get prompts with chat template\n        if self.return_full_prompt:\n            row_dict[\"full_prompts\"] = raw_prompt  # array of strings\n\n        # add index for each prompt\n        index = row_dict.get(\"extra_info\", {}).get(\"index\", 0)\n        tools_kwargs = {\n            \"image_zoom_in_tool\": {\n                \"create_kwargs\": {\"image\": images[0]},\n                # \"execute_kwargs\": {},\n                # \"calc_reward_kwargs\": {},\n                # \"release_kwargs\": {},\n            }\n        }\n        row_dict[\"index\"] = index\n        row_dict[\"tools_kwargs\"] = tools_kwargs\n        row_dict[\"agent_name\"] = \"tool_agent\"\n        return row_dict\n\n\ndef compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:\n    \"\"\"\n    Compute reward score for model solutions with robust handling of various formats.\n\n    Returns a weighted combination of:\n    - Accuracy reward (0.8 weight): Whether the answer is semantically correct\n    - Format reward (0.2 weight): Whether the output follows expected format\n    - Tool reward (1.2 weight): Whether tools were used when answer is correct\n    \"\"\"\n\n    # Initialize tracking variables\n    is_format_error = False\n\n    # 1. Check <think> tag format\n    count_think_1 = solution_str.count(\"<think>\")\n    count_think_2 = solution_str.count(\"</think>\")\n    if count_think_1 != count_think_2:\n        is_format_error = True\n\n    # 2. Check vision tokens (skip this since tokenizer removes special tokens)\n    # We'll use <tool_call> and <tool_response> instead to detect tool usage\n\n    # 3. 
Extract answer text with multiple fallback strategies\n    answer_text = \"\"\n\n    # Strategy 1: Try to extract from <answer> tags first\n    predict_no_think = (\n        solution_str.split(\"</think>\")[-1].strip() if \"</think>\" in solution_str else solution_str.strip()\n    )\n\n    # Check <answer> tag format\n    count_answer_1 = predict_no_think.count(\"<answer>\")\n    count_answer_2 = predict_no_think.count(\"</answer>\")\n    if count_answer_1 != count_answer_2:\n        is_format_error = True\n\n    # Try to extract from <answer> tags\n    answer_match = re.search(r\"<answer>(.*?)</answer>\", predict_no_think, re.DOTALL)\n    if answer_match:\n        answer_text = answer_match.group(1).strip()\n    else:\n        # No proper <answer> tags found - this is a format error\n        is_format_error = True\n\n        # Strategy 2: If no <answer> tags, extract content after tool responses\n        # Look for pattern: <tool_response>...</tool_response>assistant\\n[actual_answer]\n        tool_response_match = re.search(\n            r\"</tool_response>\\s*assistant\\s*\\n(.*?)$\", predict_no_think, re.DOTALL | re.MULTILINE\n        )\n        if tool_response_match:\n            answer_text = tool_response_match.group(1).strip()\n        else:\n            # Strategy 3: If no tool responses, look for content after </think>\n            if \"</think>\" in solution_str:\n                # Remove any remaining tool-related tags and extract meaningful content\n                remaining_content = predict_no_think\n                # Remove tool calls and responses\n                remaining_content = re.sub(r\"<tool_call>.*?</tool_call>\", \"\", remaining_content, flags=re.DOTALL)\n                remaining_content = re.sub(\n                    r\"<tool_response>.*?</tool_response>\", \"\", remaining_content, flags=re.DOTALL\n                )\n                # Remove user/assistant markers\n                remaining_content = re.sub(r\"\\b(user|assistant)\\b\", \"\", remaining_content)\n                answer_text = remaining_content.strip()\n            else:\n                # Strategy 4: Use the entire solution_str as fallback\n                answer_text = solution_str.strip()\n\n    # Clean up answer text\n    answer_text = answer_text.strip()\n\n    # If answer is still empty after all strategies, mark as format error\n    if not answer_text:\n        is_format_error = True\n        answer_text = solution_str.strip()  # Use full text as last resort\n\n    # 4. Evaluate correctness using LLM judge\n    question_text = extra_info.get(\"question\", \"\") if extra_info else \"\"\n\n    if not client or not model_name:\n        logger.warning(\"Reward function client not initialized or model name not found.\")\n        return 0.0\n\n    system_prompt = (\n        \"You are an expert evaluator. Your task is to determine if a model's answer is semantically equivalent to a \"\n        \"provided standard answer, given a specific question.\\n\"\n        \"Your evaluation must be strict. The model's answer is only correct if it fully matches the meaning of the \"\n        \"standard answer.\\n\"\n        'You must provide your final judgement as a single word: either \"CORRECT\" or \"INCORRECT\". Do not provide '\n        \"any explanation or other text.\"\n    )\n\n    user_prompt = (\n        f\"I will provide a question, a standard answer, and a model's answer. 
You must evaluate if the model's \"\n        f\"answer is correct.\\n\\n\"\n        f\"---\\n\"\n        f\"**Example 1:**\\n\"\n        f\"[Question]: Is the countertop tan or blue?\\n\"\n        f\"[Standard Answer]: The countertop is tan.\\n\"\n        f\"[Model's Answer]: tan\\n\"\n        f\"[Your Judgement]: CORRECT\\n\"\n        f\"---\\n\"\n        f\"**Example 2:**\\n\"\n        f\"[Question]: Is the man phone both blue and closed?\\n\"\n        f\"[Standard Answer]: Yes, the man phone is both blue and closed.\\n\"\n        f\"[Model's Answer]: No.\\n\"\n        f\"[Your Judgement]: INCORRECT\\n\"\n        f\"---\\n\"\n        f\"**Task:**\\n\"\n        f\"[Question]: {question_text}\\n\"\n        f\"[Standard Answer]: {ground_truth}\\n\"\n        f\"[Model's Answer]: {answer_text}\\n\"\n        f\"[Your Judgement]:\"\n    )\n\n    try:\n        chat_response = client.chat.completions.create(\n            model=model_name,\n            messages=[\n                {\"role\": \"system\", \"content\": system_prompt},\n                {\"role\": \"user\", \"content\": user_prompt},\n            ],\n            seed=random.randint(0, 1000000),\n            temperature=0.1,  # Lower temperature for more deterministic judgement\n            extra_body={\n                \"chat_template_kwargs\": {\"enable_thinking\": False},\n            },\n        )\n        response = chat_response.choices[0].message.content.strip()\n    except Exception as e:\n        logger.warning(f\" [WARNING] Chat completion request failed: {e}\")\n        return 0.0\n\n    # Parse LLM judge response\n    if re.search(r\"\\bCORRECT\\b\", response, re.IGNORECASE):\n        acc_reward = 1.0\n    elif re.search(r\"\\bINCORRECT\\b\", response, re.IGNORECASE):\n        acc_reward = 0.0\n    else:\n        logger.warning(\n            f\" [WARNING] Judgement format error. Expected 'CORRECT' or 'INCORRECT'.\\n\"\n            f\"Response: '{response}'\\n\"\n            f\"Model Answer: '{answer_text}'\\n\"\n            f\"Ground Truth: '{ground_truth}'\"\n        )\n        acc_reward = 0.0\n\n    # Penalize excessively long answers (potential judge hacking)\n    if len(answer_text) >= 1000:\n        acc_reward = 0.0\n        is_format_error = True\n\n    # 5. 
Check tool usage - look for tool_call/tool_response patterns instead of vision tokens\n    has_tool_usage = bool(\n        re.search(r\"<tool_call>.*?</tool_call>\", solution_str, re.DOTALL)\n        or re.search(r\"<tool_response>.*?</tool_response>\", solution_str, re.DOTALL)\n    )\n\n    # Tool reward: only give if tools were used AND answer is correct\n    tool_reward = 1.0 if has_tool_usage and acc_reward > 0.5 else 0.0\n\n    # Format reward: penalty for format errors\n    format_reward = -1.0 if is_format_error else 0.0\n\n    # Log debug information for problematic cases\n    if is_format_error or not answer_text:\n        logger.debug(\n            f\"Format issue detected:\\n\"\n            f\"Solution: {solution_str[:200]}...\\n\"\n            f\"Extracted answer: '{answer_text}'\\n\"\n            f\"Format error: {is_format_error}\\n\"\n            f\"Tool usage: {has_tool_usage}\"\n        )\n\n    # Final weighted score\n    final_score = 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward\n\n    return final_score\n\n\nif __name__ == \"__main__\":\n    # Test case 1: Original test case\n    predict_str = \"The answer is 2 + 2 = 4 </think> <answer> right </answer> <answer> left </answer>\"\n    ground_truth = \"left\"\n    extra_info = {\n        \"answer\": \"The woman is to the left of the man who is holding the camera.\",\n        \"id\": 0,\n        \"image\": \"/cpfs/user/honglingyi/DATA/LLM/Vstar/gqa/images/713270.jpg\",\n        \"pred_ans\": \"The woman is to the right of the man who is holding the camera.\",\n        \"question\": \"Is the woman to the left or to the right of the man who is holding the camera?\",\n    }\n    print(\"=== Test Case 1: Original test ===\")\n    import time\n\n    time_start = time.time()\n    score = compute_score(\"common_reasoning\", predict_str, ground_truth, extra_info)\n    print(f\"Score: {score}\")\n    time_end = time.time()\n    print(f\"Time: {time_end - time_start}\")\n\n    # Test case 2: Problematic case mentioned by user\n    problematic_solution = \"\"\"<tool_call>\n{\"name\": \"image_zoom_in_tool\", \"arguments\": {\"bbox_2d\": [226, 399, 265, 464], \"label\": \"white van\"}}\n</tool_call>user\n<tool_response>\nZoomed in on the image to the region [226, 399, 265, 464] with label white van.\n</tool_response>\nassistant\nThe white van is visible in the lower section of the image, near the diagonal road.\"\"\"\n\n    problematic_ground_truth = \"Yes, the white van is indeed situated in the bottom part of the picture.\"\n    problematic_extra_info = {\n        \"question\": \"Is the white van in the bottom part of the picture?\",\n    }\n\n    print(\"\\n=== Test Case 2: Problematic case (no answer tags) ===\")\n    print(f\"Solution: {problematic_solution}\")\n    print(f\"Ground truth: {problematic_ground_truth}\")\n\n    time_start = time.time()\n    score2 = compute_score(\"common_reasoning\", problematic_solution, problematic_ground_truth, problematic_extra_info)\n    print(f\"Score: {score2}\")\n    time_end = time.time()\n    print(f\"Time: {time_end - time_start}\")\n\n    # Test case 3: Well-formatted case with tools\n    well_formatted_solution = \"\"\"<think>\nI need to use the image zoom tool to get a better look at the specific area.\n</think>\n<tool_call>\n{\"name\": \"image_zoom_in_tool\", \"arguments\": {\"bbox_2d\": [226, 399, 265, 464], \"label\": \"white van\"}}\n</tool_call>\n<tool_response>\nZoomed in on the image to the region [226, 399, 265, 464] with label white 
van.\n</tool_response>\n<answer>Yes, the white van is indeed situated in the bottom part of the picture.</answer>\"\"\"\n\n    print(\"\\n=== Test Case 3: Well-formatted case ===\")\n    time_start = time.time()\n    score3 = compute_score(\n        \"common_reasoning\", well_formatted_solution, problematic_ground_truth, problematic_extra_info\n    )\n    print(f\"Score: {score3}\")\n    time_end = time.time()\n    print(f\"Time: {time_end - time_start}\")\n"
  },
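`compute_score` in `deepeyes.py` combines three signals with fixed weights (0.8 accuracy, 0.2 format, 1.2 tool use). A worked example of the resulting score range:

```python
# Worked example of the weighting in deepeyes.py's compute_score.
def final_score(acc_reward: float, is_format_error: bool, has_tool_usage: bool) -> float:
    format_reward = -1.0 if is_format_error else 0.0
    tool_reward = 1.0 if has_tool_usage and acc_reward > 0.5 else 0.0
    return 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward

print(final_score(1.0, False, True))   #  2.0: correct, well-formatted, used the tool
print(final_score(1.0, True,  True))   #  1.8: correct but with a format error
print(final_score(1.0, False, False))  #  0.8: correct without any tool call
print(final_score(0.0, True,  False))  # -0.2: wrong answer and broken format
```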
  {
    "path": "verl_distillation/recipe/deepeyes/run_deepeyes_grpo.sh",
    "content": "#!/bin/bash\n\nset -x\n\nexport LLM_AS_A_JUDGE_BASE=\"your llm-as-a-judge server/v1\"\nexport WANDB_API_KEY=\"your wandb key\"\n\nPROJECT_NAME=\"your_project_name\"\nEXPERIMENT_NAME=\"your_experiment_name\"\n\nBASEDIR=base_dir\nSAVE_CHECKPOINT_DIR=${BASEDIR}/verl_checkpoints\nDATASET_TRAIN=${BASEDIR}/dataset/train.parquet\nDATASET_VAL=${BASEDIR}/dataset/val.parquet\n\nREF_MODEL_PATH=ref_model_path\n\nPYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n    --config-path=${BASEDIR}/recipe/deepeyes/configs \\\n    --config-name='deepeyes_multiturn_grpo' \\\n    data.train_files=${DATASET_TRAIN} \\\n    data.val_files=[${DATASET_VAL}] \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=8192 \\\n    data.max_response_length=16384 \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    algorithm.adv_estimator=grpo \\\n    algorithm.kl_ctrl.kl_coef=0.0 \\\n    actor_rollout_ref.model.path=${REF_MODEL_PATH} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0.0 \\\n    actor_rollout_ref.actor.checkpoint.save_contents=['model','hf_model','optimizer','extra'] \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=32768 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=5 \\\n    actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=recipe/deepeyes/configs/image_zoom_in_tool_config.yaml \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=['console','wandb','tensorboard'] \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=8 \\\n    trainer.test_freq=80 \\\n    trainer.project_name=${PROJECT_NAME} \\\n    trainer.experiment_name=${EXPERIMENT_NAME} \\\n    trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \\\n    +trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \\\n    +trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \\\n    trainer.total_epochs=1 2>&1 | tee ./logs/${EXPERIMENT_NAME}.log\n"
  },
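`run_deepeyes_grpo.sh` assumes the judge endpoint exported as `LLM_AS_A_JUDGE_BASE` is already serving. A quick pre-flight check mirroring the model discovery `deepeyes.py` performs at import time:

```python
# Pre-flight check for the LLM-as-a-judge endpoint used by run_deepeyes_grpo.sh;
# mirrors deepeyes.py, which queries {base}/models on the OpenAI-compatible server.
import os
import requests

base = os.environ.get("LLM_AS_A_JUDGE_BASE", "http://localhost:18901/v1")
resp = requests.get(f"{base}/models", timeout=10)
resp.raise_for_status()
models = resp.json().get("data", [])
assert models, f"No models served at {base}; reward scoring would return 0.0."
print("Judge model:", models[0]["id"])
```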
  {
    "path": "verl_distillation/recipe/entropy/32b_clip_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='clipcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=1\nclip_ratio_high=1\nclip_cov_ratio=0.0002\nclip_cov_lb=1.0\nclip_cov_ub=5.0\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"clip_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.02\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
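Both entropy recipes enable group filtering (`algorithm.filter_groups.enable=True` with `metric=acc`) and over-generate (`gen_prompt_bsz` is 3x `train_prompt_bsz`). The idea, sketched below under the DAPO dynamic-sampling reading: a prompt whose rollouts all succeed or all fail yields zero GRPO advantage, so it is dropped and generation continues (up to `max_num_gen_batches`) until a full batch of informative groups remains. Illustrative only; verl's trainer logic is more involved:

```python
# Hedged sketch of acc-based group filtering in the entropy recipes above.
def keep_group(accs: list[float]) -> bool:
    # A group is informative only if its rollouts disagree on accuracy;
    # uniform groups have zero within-group advantage under GRPO.
    return len(set(accs)) > 1

print(keep_group([1.0] * 8))             # False: solved by all 8 rollouts
print(keep_group([0.0] * 8))             # False: failed by all 8 rollouts
print(keep_group([1.0, 0.0, 1.0, 0.0]))  # True: mixed outcomes carry signal
```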
  {
    "path": "verl_distillation/recipe/entropy/32b_kl_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.0002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/32b_kl_cov_mininbsz.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=16\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.0002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/7b_clip_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-7B'\nexp_name='clipcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=1\nclip_ratio_high=1\nclip_cov_ratio=0.0002\nclip_cov_lb=1.0\nclip_cov_ub=5.0\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"clip_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=30720\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.2\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/7b_kl_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-7B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=30720\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    actor_rollout_ref.rollout.name=vllm \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.lr_scheduler_type=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/README.md",
    "content": "<div align=\"center\">\n\n# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.\n\n[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617)  [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue\n)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)\n\n\n<div align=\"center\" style=\"font-family: Arial, sans-serif;\">\n  <p>\n    <a href=\"#🎉news\" style=\"text-decoration: none; font-weight: bold;\">🎉 News</a> •\n    <a href=\"#✨getting-started\" style=\"text-decoration: none; font-weight: bold;\">✨ Getting Started</a> •\n    <a href=\"#📖introduction\" style=\"text-decoration: none; font-weight: bold;\">📖 Introduction</a>\n  </p>\n  <p>\n    <a href=\"#🎈citation\" style=\"text-decoration: none; font-weight: bold;\">🎈 Citation</a> •\n    <a href=\"#🌻acknowledgement\" style=\"text-decoration: none; font-weight: bold;\">🌻 Acknowledgement</a> •\n    <a href=\"#📬Contact\" style=\"text-decoration: none; font-weight: bold;\">📬 Contact</a> •\n    <a href=\"#📈star-history\" style=\"text-decoration: none; font-weight: bold;\">📈 Star History</a>\n  </p>\n</div>\n\n</div>\n\n\n# 🎉News\n\n- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).\n- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. \n\n\n\n# ✨Getting started\n\nAfter preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/7b_kl_cov.sh\n```\n\nWhile for training Qwen2.5-32B on multi nodes, you can run the following commands:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/32b_kl_cov.sh\n```\n\n# 📖Introduction\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nThis paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. 
\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nTheoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. \n\n# 📃Evaluation\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\n\nOur method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. \n| **Method**        | **AIME24** | **AIME25** |  **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |\n| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |\n| *Qwen2.5-7B*      |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.2 |        9.6 |     58.7 |         78.8 |          27.9 |              40.7 |        36.7 |     38.6 |\n| w. Clip-higher    |       18.1 |       11.5 |     56.6 |         79.2 |          29.8 |              43.3 |        40.4 |     38.8 |\n| w. **`CLIP-Cov`** |       22.1 |   **15.8** |     58.2 |         80.4 |      **30.5** |          **44.1** |    **41.1** |     40.4 |\n| w. **`KL-Cov`**   |   **22.6** |       12.9 | **61.4** |     **80.8** |          29.1 |              42.6 |        38.2 | **40.6** |\n| *Qwen2.5-32B*     |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.8 |       16.2 |     69.7 |         84.2 |          35.2 |              43.6 |        45.5 |     45.8 |\n| w. Clip-higher    |       35.6 |       22.3 |     69.5 |         77.2 |          35.1 |              42.5 |        43.0 |     47.2 |\n| w. **`CLIP-Cov`** |       32.3 |       22.7 |     67.2 |     **87.0** |      **42.0** |          **57.2** |        46.0 |     50.3 |\n| w. **`KL-Cov`**   |   **36.8** |   **30.8** | **74.5** |         84.6 |          39.1 |              49.0 |    **46.3** | **52.2** |\n\nOur two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. 
\n\n# 📃Evaluation\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\n\nOur method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.\n\n| **Method**        | **AIME24** | **AIME25** |  **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |\n| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |\n| *Qwen2.5-7B*      |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.2 |        9.6 |     58.7 |         78.8 |          27.9 |              40.7 |        36.7 |     38.6 |\n| w. Clip-higher    |       18.1 |       11.5 |     56.6 |         79.2 |          29.8 |              43.3 |        40.4 |     38.8 |\n| w. **`CLIP-Cov`** |       22.1 |   **15.8** |     58.2 |         80.4 |      **30.5** |          **44.1** |    **41.1** |     40.4 |\n| w. **`KL-Cov`**   |   **22.6** |       12.9 | **61.4** |     **80.8** |          29.1 |              42.6 |        38.2 | **40.6** |\n| *Qwen2.5-32B*     |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.8 |       16.2 |     69.7 |         84.2 |          35.2 |              43.6 |        45.5 |     45.8 |\n| w. Clip-higher    |       35.6 |       22.3 |     69.5 |         77.2 |          35.1 |              42.5 |        43.0 |     47.2 |\n| w. **`CLIP-Cov`** |       32.3 |       22.7 |     67.2 |     **87.0** |      **42.0** |          **57.2** |        46.0 |     50.3 |\n| w. **`KL-Cov`**   |   **36.8** |   **30.8** | **74.5** |         84.6 |          39.1 |              49.0 |    **46.3** | **52.2** |\n\nOur two approaches both achieve non-trivial improvements across all benchmarks, outperforming GRPO by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.\n\n\n# 🎈Citation\nIf you find this paper or repo helpful, please cite us.\n\n```bibtex\n@article{cui2025entropy,\n  title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},\n  author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},\n  journal={arXiv preprint arXiv:2505.22617},\n  year={2025}\n}\n```\n\n# 🌻Acknowledgement\nWe implement our reinforcement learning algorithm by extending [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!\n\n# 📬 Contact\n\nFor questions, discussion, or collaboration opportunities, feel free to contact:\n- Ganqu Cui: cuiganqu@pjlab.org.cn\n- Yuchen Zhang: yuchen.zhang2003@gmail.com\n- Jiacheng Chen: jackchan9345@gmail.com\n- Ning Ding: ningding.cs@gmail.com\n\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/config/entropy_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\nreward_model:\n  reward_kwargs:\n        overlong_buffer_cfg: ${reward_model.overlong_buffer}\n  reward_manager: dapo\n  overlong_buffer: \n    enable: False \n    len: 0\n    penalty_factor: 0.0\n    log: False\n\nalgorithm:\n  filter_groups:\n    enable: False # We try to avoid forgetting to set enable\n    metric: null # acc / score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 0 # Non-positive values mean no upper limit\n\ntrainer:\n  project_name: verl-entropy\n\nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"vanilla\" # /clip-cov / kl-cov from https://arxiv.org/abs/2505.\n      clip_cov_ratio: 0.0002 # for clip-cov loss\n      clip_cov_lb: 1.0 # for clip-cov loss\n      clip_cov_ub: 5.0 # for clip-cov loss\n      kl_cov_ratio: 0.0002 # for kl-cov loss\n      ppo_kl_coef: 0.1 # for kl-cov loss"
  },
  {
    "path": "verl_distillation/recipe/entropy/entropy_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler import simple_timer\n\n\nclass RayEntropyTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    def compute_kl_related_metrics(self, batch: DataProto, timing_raw: dict):\n        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n        # recompute old_log_probs\n        with simple_timer(\"old_log_prob\", timing_raw):\n            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n            batch = batch.union(old_log_prob)\n\n        if self.use_reference_policy:\n            # compute reference log_prob\n            with simple_timer(\"ref\", timing_raw):\n                if not self.ref_in_actor:\n                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                else:\n                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                batch = batch.union(ref_log_prob)\n\n        return batch\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: 
{val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        timing_raw = defaultdict(float)\n        batch = None\n        num_prompt_in_batch = 0\n        num_gen_batches = 0\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n\n                new_batch: DataProto = DataProto.from_single_dict(batch_dict)\n                num_gen_batches += 1\n                # pop those keys for generation\n                if \"multi_modal_inputs\" in new_batch.non_tensor_batch.keys():\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\", \"multi_modal_data\", \"multi_modal_inputs\"],\n                    )\n                else:\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\"],\n                    )\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with simple_timer(\"gen\", timing_raw):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            new_batch = new_batch.union(gen_baseline_output)\n                            # compute reward model score on new_batch\n                            rm_scores = None\n                            if self.use_rm and \"rm_scores\" not in new_batch.batch.keys():\n                                rm_scores = self.rm_wg.compute_rm_score(new_batch)\n                                new_batch = new_batch.union(rm_scores)\n                            reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            keys_to_pop = set(gen_baseline_output.batch.keys())\n                            if rm_scores is not None:\n                                keys_to_pop.update(rm_scores.batch.keys())\n                            new_batch.pop(batch_keys=list(keys_to_pop))\n\n                            
new_batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del rm_scores, gen_baseline_batch, gen_baseline_output\n\n                    new_batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    new_batch = new_batch.union(gen_batch_output)\n\n                    if self.config.algorithm.use_kl_in_reward:\n                        # We need these metrics for apply_kl_penalty if using kl in reward\n                        new_batch = self.compute_kl_related_metrics(new_batch, timing_raw)\n                        # otherwise, we will compute those after dynamic sampling\n\n                    with simple_timer(\"reward\", timing_raw):\n                        # compute scores. Support both model and function-based.\n                        # We first compute the scores using reward model. Then, we call reward_fn to combine\n                        # the results from reward model and rule-based results.\n                        if self.use_rm and \"rm_scores\" not in new_batch.batch.keys():\n                            # we first compute reward model score\n                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)\n                            new_batch = new_batch.union(reward_tensor)\n\n                        # we combine with rule-based rm\n                        reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)\n\n                        new_batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        print(f\"{list(reward_extra_infos_dict.keys())=}\")\n                        if reward_extra_infos_dict:\n                            new_batch.non_tensor_batch.update(\n                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                            )\n\n                        # compute rewards. 
apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            new_batch, kl_metrics = apply_kl_penalty(\n                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(\n                                kl_metrics\n                            )  # TODO: This will be cleared if we use multiple generation batches\n                        else:\n                            new_batch.batch[\"token_level_rewards\"] = new_batch.batch[\"token_level_scores\"]\n\n                    if not self.config.algorithm.filter_groups.enable:\n                        batch = new_batch\n                    else:  # NOTE: When the number of prompts left after filtering is less than the train batch size,\n                        # we skip to the next generation batch\n                        metric_name = self.config.algorithm.filter_groups.metric\n                        if metric_name == \"seq_final_reward\":\n                            # Turn to numpy for easier filtering\n                            new_batch.non_tensor_batch[\"seq_final_reward\"] = (\n                                new_batch.batch[\"token_level_rewards\"].sum(dim=-1).numpy()\n                            )\n                        elif metric_name == \"seq_reward\":\n                            new_batch.non_tensor_batch[\"seq_reward\"] = (\n                                new_batch.batch[\"token_level_scores\"].sum(dim=-1).numpy()\n                            )\n\n                        # Collect the sequence reward for each trajectory\n                        prompt_uid2metric_vals = defaultdict(list)\n                        for uid, metric_val in zip(\n                            new_batch.non_tensor_batch[\"uid\"], new_batch.non_tensor_batch[metric_name], strict=True\n                        ):\n                            prompt_uid2metric_vals[uid].append(metric_val)\n\n                        prompt_uid2metric_std = {}\n                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():\n                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)\n\n                        # keep prompts whose sampled responses disagree on the metric (std > 0)\n                        kept_prompt_uids = [\n                            uid\n                            for uid, std in prompt_uid2metric_std.items()\n                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1\n                        ]\n                        num_prompt_in_batch += len(kept_prompt_uids)\n\n                        kept_traj_idxs = []\n                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch[\"uid\"]):\n                            if traj_from_prompt_uid in kept_prompt_uids:\n                                kept_traj_idxs.append(idx)\n\n                        new_batch = new_batch[kept_traj_idxs]\n                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])\n\n                        prompt_bsz = self.config.data.train_batch_size\n                        if num_prompt_in_batch < prompt_bsz:\n                            print(f\"{num_prompt_in_batch=} < {prompt_bsz=}\")\n                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
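\n                                # dynamic sampling: fetch another generation batch and keep filtering\n                                # until enough informative prompts have been collected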
Keep generating...\")\n                                continue\n                            else:\n                                raise ValueError(\n                                    f\"{num_gen_batches=} >= {max_num_gen_batches=}.\"\n                                    + \" Generated too many. Please check if your data are too difficult.\"\n                                    + \" You could also try set max_num_gen_batches=0 to enable endless trials.\"\n                                )\n                        else:\n                            # Align the batch\n                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n                            print(\n                                f\"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. \"\n                                f\"Collecting finished.\"\n                            )\n                            batch = batch[:traj_bsz]\n\n                    # === Updating ===\n                    # balance the number of valid tokens on each dp rank.\n                    # Note that this breaks the order of data inside the batch.\n                    # Please take care when you implement group based adv computation such as GRPO and rloo\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    if not self.config.algorithm.use_kl_in_reward:\n                        batch = self.compute_kl_related_metrics(batch, timing_raw)\n\n                    # compute values\n                    if self.use_critic:\n                        with simple_timer(\"values\", timing_raw):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with simple_timer(\"adv\", timing_raw):\n                        # compute advantages, executed on the driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        )\n\n                    # update critic\n                    if self.use_critic:\n                        with simple_timer(\"update_critic\", timing_raw):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with simple_timer(\"update_actor\", timing_raw):\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = 
reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                    ):\n                        with simple_timer(\"testing\", timing_raw):\n                            val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                        metrics.update(val_metrics)\n\n                    if self.config.trainer.save_freq > 0 and (\n                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                    ):\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflops and theoretical tflops\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                timing_raw = defaultdict(float)  # clear timing\n\n                metrics[\"train/num_gen_batches\"] = num_gen_batches\n                batch = None\n                num_prompt_in_batch = 0\n                num_gen_batches = 0\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/main_entropy.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom .entropy_ray_trainer import RayEntropyTrainer\nfrom .reward import load_reward_manager\n\n\n@hydra.main(config_path=\"config\", config_name=\"entropy_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        default_runtime_env = {\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"WARN\",\n            }\n        }\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\ndef merge_dict(a: dict, b: dict) -> dict:\n    \"\"\"Return a new dict that has `a` updated with `b` (b wins on conflicts).\n\n    Example::\n\n      >>> d1 = {\"x\": 1, \"y\": 2}\n      >>> d2 = {\"y\": 20, \"z\": 3}\n      >>> new_dict = merge_dict(d1, d2)\n      >>> print(new_dict)   # {'x': 1, 'y': 20, 'z': 3}\n      >>> print(d1)         # {\"x\": 1, \"y\": 2} (unchanged)\n      >>> print(d2)         # {\"y\": 20, \"z\": 3} (unchanged)\n    \"\"\"\n    return a | b\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n        print(f\"{config.actor_rollout_ref.model.path}\")\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray 
import RayWorkerGroup\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = ActorRolloutRefWorker\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            Role.Critic: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        reward_kwargs = {\n            \"max_resp_len\": config.data.max_response_length,\n            \"overlong_buffer_cfg\": config.reward_model.overlong_buffer,\n        }\n        cfg_reward_kwargs = config.reward_model.get(\"reward_kwargs\", {})\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs)\n        )\n        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        train_dataset = create_rl_dataset(\n            config.data.train_files,\n            config.data,\n            tokenizer,\n            processor,\n            max_samples=config.data.get(\"train_max_samples\", -1),\n        )\n        val_dataset = 
create_rl_dataset(\n            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get(\"val_max_samples\", -1)\n        )\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n        trainer = RayEntropyTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        trainer.init_workers()\n        trainer.fit()\n\n\ndef create_rl_dataset(data_paths, data_config, tokenizer, processor, max_samples: int = -1):\n    \"\"\"Create a dataset.\n\n    Arguments:\n        data_paths: Path(s) to the dataset file(s).\n        data_config: The data config.\n        tokenizer (Tokenizer): The tokenizer.\n        processor (Processor): The processor.\n        max_samples (int): Maximum number of samples to load; -1 means no limit.\n\n    Returns:\n        dataset (Dataset): The dataset.\n    \"\"\"\n    from torch.utils.data import Dataset\n\n    from verl.utils.dataset.rl_dataset import RLHFDataset\n\n    if \"custom_cls\" in data_config and data_config.custom_cls.get(\"path\", None) is not None:\n        from verl.utils.import_utils import load_extern_type\n\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n        if not issubclass(dataset_cls, Dataset):\n            raise TypeError(\n                f\"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' \"\n                f\"must inherit from torch.utils.data.Dataset\"\n            )\n    else:\n        dataset_cls = RLHFDataset\n    print(f\"Using dataset class: {dataset_cls.__name__}\")\n\n    dataset = dataset_cls(\n        data_files=data_paths,\n        tokenizer=tokenizer,\n        processor=processor,\n        config=data_config,\n        max_samples=max_samples,\n    )\n\n    return dataset\n\n\ndef create_rl_sampler(data_config, dataset):\n    \"\"\"Create a sampler for the dataset.\n\n    Arguments:\n        data_config: The data config.\n        dataset (Dataset): The dataset.\n\n    Returns:\n        sampler (Sampler): The sampler.\n    \"\"\"\n    import torch\n    from torch.utils.data import RandomSampler, SequentialSampler\n\n    # use sampler for better ckpt resume\n    if data_config.shuffle:\n        train_dataloader_generator = torch.Generator()\n        seed = data_config.get(\"seed\")\n        if seed is not None:\n            train_dataloader_generator.manual_seed(seed)\n        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)\n    else:\n        sampler = SequentialSampler(data_source=dataset)\n\n    return sampler\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/reward.py",
    "content": "# Copyright 2025 Individual Contributor: Thibaut Barroyer\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport multiprocessing\nfrom functools import partial\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn\n\nfrom .reward_score import _default_compute_score\n\n\ndef load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):\n    \"\"\"\n    Load and initialize a reward manager based on the configuration.\n\n    Args:\n        config: PPO trainer configuration object containing reward_model fields.\n        tokenizer: Tokenizer object used for processing text.\n        num_examine: Number of samples to examine.\n        **reward_kwargs: Additional keyword arguments for the reward manager.\n\n    Returns:\n        An instance of the specified reward manager class.\n    \"\"\"\n    from verl.workers.reward_manager import get_reward_manager_cls\n\n    # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:\n    # naive: NaiveRewardManager\n    # prime: PrimeRewardManager\n    # batch: BatchRewardManager\n    # dapo: DAPORewardManager\n    # Note(haibin.lin): For custom reward managers, please make sure they are imported and\n    # registered via `verl.workers.reward_manager.register`\n    # By default reward_manager is set to naive (NaiveRewardManager)\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n    # Try to get a custom reward function based on the configuration\n    compute_score = get_custom_reward_fn(config)\n    final_compute_score = compute_score\n\n    if compute_score is None:\n        sandbox_config = config.reward_model.get(\"sandbox_fusion\")\n        sandbox_url = sandbox_config.get(\"url\") if sandbox_config else None\n        if sandbox_url:\n            sandbox_manager = multiprocessing.Manager()\n            # Create a semaphore to control concurrent access to the sandbox\n            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get(\"max_concurrent\", 64))\n            final_compute_score = partial(\n                _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore\n            )\n        else:\n            final_compute_score = _default_compute_score\n\n    # Instantiate and return the reward manager with the specified parameters\n    return reward_manager_cls(\n        tokenizer=tokenizer,\n        num_examine=num_examine,\n        compute_score=final_compute_score,\n        reward_fn_key=config.data.reward_fn_key,\n        **reward_kwargs,\n    )\n\n\n@ray.remote(num_cpus=1)\ndef compute_reward_async(data: DataProto, config, tokenizer):\n    \"\"\"\n    Load the reward manager and compute the reward for a batch of data.\n    This is meant to be run in a separate Ray worker.\n    \"\"\"\n    reward_fn = load_reward_manager(config, tokenizer, num_examine=0, 
**config.reward_model.get(\"reward_kwargs\", {}))\n    return compute_reward(data, reward_fn)\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/reward_score/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# from . import gsm8k, math, prime_math, prime_code\n\nimport traceback\n\nfrom . import entropy_math\n\n\ndef _default_compute_score(\n    data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None\n):\n    try:\n        res = entropy_math.compute_score(solution_str, str(ground_truth))\n        # print(f\"data_source: {data_source}\")\n        # raise NotImplementedError(f\"Reward function is not implemented for {data_source=}\")\n\n        if isinstance(res, dict):\n            return res\n        elif isinstance(res, int | float | bool):\n            return float(res)\n        else:\n            return float(res[0])\n    except Exception as e:\n        print(f\"[ERROR] Error in process_completion for task : {str(e)}\")\n        traceback.print_exc()  # 打印完整堆栈\n        raise  # 重新抛出异常以便上层捕获\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/reward_score/entropy_math/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except Exception in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provides a math answer grading function with high recall.\nBased on HF math_verify, verl, open reasoner zero, etc.\n\"\"\"\n\nimport os\nimport re\nimport signal\nfrom itertools import islice, zip_longest\nfrom math import isclose\nfrom typing import Optional\n\nimport sympy\nfrom latex2sympy2_extended import latex2sympy\nfrom math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify\nfrom pylatexenc import latex2text\nfrom sympy import N, simplify\nfrom sympy.parsing import sympy_parser\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n\"\"\"\nThis code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py).\n\"\"\"\n\n\ndef timeout_ours(timeout_seconds: int = 8):\n    if os.name == \"posix\":\n        import signal\n\n        def decorator(func):\n            def handler(signum, frame):\n                raise TimeoutError(\"Operation timed out!\")\n\n            def wrapper(*args, **kwargs):\n                old_handler = signal.getsignal(signal.SIGALRM)\n                signal.signal(signal.SIGALRM, handler)\n                signal.alarm(timeout_seconds)\n\n                try:\n                    return func(*args, **kwargs)\n                finally:\n                    signal.alarm(0)\n                    signal.signal(signal.SIGALRM, old_handler)\n\n            return wrapper\n\n        return decorator\n    else:\n        raise NotImplementedError(f\"Unsupported OS: {os.name}\")\n\n\n# Dan Hendrycks' code\ndef mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(r\"^\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:\n        return answer\n\n\n# units mainly from MathQA\nunit_texts = [\n    \"east\",\n    \"degree\",\n    \"mph\",\n    \"kmph\",\n    \"ft\",\n    \"m sqaure\",\n    \" m east\",\n    \"sq m\",\n    \"deg\",\n    \"mile\",\n    \"q .\",\n    \"monkey\",\n    \"prime\",\n    \"ratio\",\n    \"profit of rs\",\n    \"rd\",\n    \"o\",\n    \"gm\",\n    \"p . m\",\n    \"lb\",\n    \"tile\",\n    \"per\",\n    \"dm\",\n    \"lt\",\n    \"gain\",\n    \"ab\",\n    \"way\",\n    \"west\",\n    \"a .\",\n    \"b .\",\n    \"c .\",\n    \"d .\",\n    \"e .\",\n    \"f .\",\n    \"g .\",\n    \"h .\",\n    \"t\",\n    \"a\",\n    \"h\",\n    \"no change\",\n    \"men\",\n    \"soldier\",\n    \"pie\",\n    \"bc\",\n    \"excess\",\n    \"st\",\n    \"inches\",\n    \"noon\",\n    \"percent\",\n    \"by\",\n    \"gal\",\n    \"kmh\",\n    \"c\",\n    \"acre\",\n    \"rise\",\n    \"a . 
m\",\n    \"th\",\n    \"π r 2\",\n    \"sq\",\n    \"mark\",\n    \"l\",\n    \"toy\",\n    \"coin\",\n    \"sq . m\",\n    \"gallon\",\n    \"° f\",\n    \"profit\",\n    \"minw\",\n    \"yr\",\n    \"women\",\n    \"feet\",\n    \"am\",\n    \"pm\",\n    \"hr\",\n    \"cu cm\",\n    \"square\",\n    \"v â € ™\",\n    \"are\",\n    \"rupee\",\n    \"rounds\",\n    \"cubic\",\n    \"cc\",\n    \"mtr\",\n    \"s\",\n    \"ohm\",\n    \"number\",\n    \"kmph\",\n    \"day\",\n    \"hour\",\n    \"minute\",\n    \"min\",\n    \"second\",\n    \"man\",\n    \"woman\",\n    \"sec\",\n    \"cube\",\n    \"mt\",\n    \"sq inch\",\n    \"mp\",\n    \"∏ cm ³\",\n    \"hectare\",\n    \"more\",\n    \"sec\",\n    \"unit\",\n    \"cu . m\",\n    \"cm 2\",\n    \"rs .\",\n    \"rs\",\n    \"kg\",\n    \"g\",\n    \"month\",\n    \"km\",\n    \"m\",\n    \"cm\",\n    \"mm\",\n    \"apple\",\n    \"liter\",\n    \"loss\",\n    \"yard\",\n    \"pure\",\n    \"year\",\n    \"increase\",\n    \"decrease\",\n    \"d\",\n    \"less\",\n    \"Surface\",\n    \"litre\",\n    \"pi sq m\",\n    \"s .\",\n    \"metre\",\n    \"meter\",\n    \"inch\",\n]\n\nunit_texts.extend([t + \"s\" for t in unit_texts])\n\n\ndef _strip_string(string):\n    def _fix_fracs(string):\n        substrs = string.split(\"\\\\frac\")\n        new_str = substrs[0]\n        if len(substrs) > 1:\n            substrs = substrs[1:]\n            for substr in substrs:\n                new_str += \"\\\\frac\"\n                if substr[0] == \"{\":\n                    new_str += substr\n                else:\n                    try:\n                        assert len(substr) >= 2\n                    except Exception:\n                        return string\n                    a = substr[0]\n                    b = substr[1]\n                    if b != \"{\":\n                        if len(substr) > 2:\n                            post_substr = substr[2:]\n                            new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                        else:\n                            new_str += \"{\" + a + \"}{\" + b + \"}\"\n                    else:\n                        if len(substr) > 2:\n                            post_substr = substr[2:]\n                            new_str += \"{\" + a + \"}\" + b + post_substr\n                        else:\n                            new_str += \"{\" + a + \"}\" + b\n        string = new_str\n        return string\n\n    def _fix_a_slash_b(string):\n        if len(string.split(\"/\")) != 2:\n            return string\n        a = string.split(\"/\")[0]\n        b = string.split(\"/\")[1]\n        try:\n            a = int(a)\n            b = int(b)\n            assert string == \"{}/{}\".format(a, b)\n            new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n            return new_string\n        except Exception:\n            return string\n\n    def _remove_right_units(string):\n        # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n        if \"\\\\text{ \" in string:\n            splits = string.split(\"\\\\text{ \")\n            assert len(splits) == 2\n            return splits[0]\n        else:\n            return string\n\n    def _fix_sqrt(string):\n        if \"\\\\sqrt\" not in string:\n            return string\n        splits = string.split(\"\\\\sqrt\")\n        new_string = splits[0]\n        for split in splits[1:]:\n            if split[0] != \"{\":\n                a = split[0]\n                new_substr 
= \"\\\\sqrt{\" + a + \"}\" + split[1:]\n            else:\n                new_substr = \"\\\\sqrt\" + split\n            new_string += new_substr\n        return new_string\n\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n    # print(string)\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n    # print(string)\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n    # print(string)\n\n    # matrix\n    string = re.sub(r\"\\\\begin\\{array\\}\\{.*?\\}\", r\"\\\\begin{pmatrix}\", string)\n    string = re.sub(r\"\\\\end\\{array\\}\", r\"\\\\end{pmatrix}\", string)\n    string = string.replace(\"bmatrix\", \"pmatrix\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n    string = string.replace(\"\\\\neq\", \"\\\\ne\").replace(\"\\\\leq\", \"\\\\le\").replace(\"\\\\geq\", \"\\\\ge\")\n    # print(string)\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n    # print(string)\n\n    # Remove unit: miles, dollars if after is not none\n    _string = re.sub(r\"\\\\text{.*?}$\", \"\", string).strip()\n    if _string != \"\" and _string != string:\n        # print(\"Warning: unit not removed: '{}' -> '{}'\".format(string, _string))\n        string = _string\n\n    # Remove unit: texts\n    for _ in range(2):\n        for unit_text in unit_texts:\n            # use regex, the prefix should be either the start of the string or a non-alphanumeric character\n            # the suffix should be either the end of the string or a non-alphanumeric character\n            _string = re.sub(r\"(^|\\W)\" + unit_text + r\"($|\\W)\", r\"\\1\\2\", string)\n            if _string != \"\":\n                string = _string\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\\\\\%\", \"\")\n    string = string.replace(\"\\\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2:\n        if len(string.split(\"=\")[0]) <= 2:\n            string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. 
Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n\n\nSUBSTITUTIONS = [\n    (\"an \", \"\"),\n    (\"a \", \"\"),\n    (\".$\", \"$\"),\n    (\"\\\\$\", \"\"),\n    (r\"\\ \", \"\"),\n    (\" \", \"\"),\n    (\"mbox\", \"text\"),\n    (\",\\\\text{and}\", \",\"),\n    (\"\\\\text{and}\", \",\"),\n    (\"\\\\text{m}\", \"\\\\text{}\"),\n]\n\n\nREMOVED_EXPRESSIONS = [\n    \"square\",\n    \"ways\",\n    \"integers\",\n    \"dollars\",\n    \"mph\",\n    \"inches\",\n    \"ft\",\n    \"hours\",\n    \"km\",\n    \"units\",\n    \"\\\\ldots\",\n    \"sue\",\n    \"points\",\n    \"feet\",\n    \"minutes\",\n    \"digits\",\n    \"cents\",\n    \"degrees\",\n    \"cm\",\n    \"gm\",\n    \"pounds\",\n    \"meters\",\n    \"meals\",\n    \"edges\",\n    \"students\",\n    \"childrentickets\",\n    \"multiples\",\n    \"\\\\text{s}\",\n    \"\\\\text{.}\",\n    \"\\\\text{\\ns}\",\n    \"\\\\text{}^2\",\n    \"\\\\text{}^3\",\n    \"\\\\text{\\n}\",\n    \"\\\\text{}\",\n    r\"\\mathrm{th}\",\n    r\"^\\circ\",\n    r\"^{\\circ}\",\n    r\"\\;\",\n    r\",\\!\",\n    \"{,}\",\n    '\"',\n    \"\\\\dots\",\n]\n\n\ndef normalize_final_answer(final_answer: str) -> str:\n    \"\"\"\n    Normalize a final answer to a quantitative reasoning question.\n    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.\n    \"\"\"\n    # final_answer = final_answer.split(\"=\")[-1]\n\n    for before, after in SUBSTITUTIONS:\n        final_answer = final_answer.replace(before, after)\n    for expr in REMOVED_EXPRESSIONS:\n        final_answer = final_answer.replace(expr, \"\")\n\n    # Extract answer that is in LaTeX math, is bold,\n    # is surrounded by a box, etc.\n    final_answer = re.sub(r\"(.*?)(\\$)(.*?)(\\$)(.*)\", \"$\\\\3$\", final_answer)\n    final_answer = re.sub(r\"(\\\\text\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\textbf\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\overline\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\boxed\\{)(.*)(\\})\", \"\\\\2\", final_answer)\n\n    # Normalize shorthand TeX:\n    # \\fracab -> \\frac{a}{b}\n    # \\frac{abc}{bef} -> \\frac{abc}{bef}\n    # \\fracabc -> \\frac{a}{b}c\n    # \\sqrta -> \\sqrt{a}\n    # \\sqrtab -> sqrt{a}b\n    final_answer = re.sub(r\"(frac)([^{])(.)\", \"frac{\\\\2}{\\\\3}\", final_answer)\n    final_answer = re.sub(r\"(sqrt)([^{])\", \"sqrt{\\\\2}\", final_answer)\n    final_answer = final_answer.replace(\"$\", \"\")\n\n    # Normalize 100,000 -> 100000\n    if final_answer.replace(\",\", \"\").isdigit():\n        final_answer = final_answer.replace(\",\", \"\")\n\n    return final_answer\n\n\ndef repeatness(s: str):\n    def ranks(seq):\n        index = {v: i for i, v in enumerate(sorted(set(seq)))}\n        return [index[v] for v in seq]\n\n    def suffixArray(s):\n        line = ranks(s)\n        n, k, ans, sa = len(s), 1, line, [0] * len(s)\n        while k < n - 1:\n            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))\n            ans, k = line, k << 1\n        for i, k in enumerate(ans):\n            sa[k] = i\n        return ans, sa\n\n    
def lcp(arr, suffixArr, inv_suff):\n        n, ans, k = len(arr), [0] * len(arr), 0\n\n        for i in range(n):\n            if inv_suff[i] == n - 1:\n                k = 0\n                continue\n\n            j = suffixArr[inv_suff[i] + 1]\n            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:\n                k += 1\n\n            ans[inv_suff[i]] = k\n            if k > 0:\n                k -= 1\n\n        return ans\n\n    arr = [ord(i) for i in s]\n    n = len(arr)\n    if n <= 1:\n        return 0\n    c, sa = suffixArray(arr)\n    cnt = sum(lcp(arr, sa, c))\n\n    return (cnt * 2 / (n * (n + 1))) > 0.2\n\n\nclass timeout:\n    def __init__(self, seconds=1, error_message=\"Timeout\"):\n        self.seconds = seconds\n        self.error_message = error_message\n\n    def handle_timeout(self, signum, frame):\n        raise TimeoutError(self.error_message)\n\n    def __enter__(self):\n        signal.signal(signal.SIGALRM, self.handle_timeout)\n        signal.alarm(self.seconds)\n\n    def __exit__(self, type, value, traceback):\n        signal.alarm(0)\n\n\ndef latex_eval(latex):\n    sym = parse_latex(latex)\n    val = sym.evalf()\n    return sym, val\n\n\ndef numeric_equal(prediction: float, reference: float):\n    # Note that relative tolerance has significant impact\n    # on the result of the synthesized GSM-Hard dataset\n    # if reference.is_integer():\n    #     return isclose(reference, round(prediction), abs_tol=1e-4)\n    # else:\n    # prediction = round(prediction, len(str(reference).split(\".\")[-1]))\n    return isclose(reference, prediction, rel_tol=1e-4)\n\n\n@timeout_ours(timeout_seconds=5)\ndef symbolic_equal(a, b):\n    def _parse(s):\n        for f in [parse_latex, parse_expr, latex2sympy]:\n            try:\n                return f(s.replace(\"\\\\\\\\\", \"\\\\\"))\n            except Exception:\n                try:\n                    return f(s)\n                except Exception:\n                    pass\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    # direct equal\n    try:\n        if str(a) == str(b) or a == b:\n            return True\n    except Exception:\n        pass\n\n    # simplify equal\n    try:\n        if a.equals(b) or simplify(a - b) == 0:\n            return True\n    except Exception:\n        pass\n\n    # equation equal\n    try:\n        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):\n            return True\n    except Exception:\n        pass\n\n    try:\n        if numeric_equal(float(N(a)), float(N(b))):\n            return True\n    except Exception:\n        pass\n\n    # matrix\n    try:\n        # if a and b are matrix\n        if a.shape == b.shape:\n            _a = a.applyfunc(lambda x: round(x, 3))\n            _b = b.applyfunc(lambda x: round(x, 3))\n            if _a.equals(_b):\n                return True\n    except Exception:\n        pass\n\n    return False\n\n\ndef _is_latex_equal(str1, str2):\n    try:\n        sym1, val1 = latex_eval(str1)\n        sym2, val2 = latex_eval(str2)\n        if sym1 == sym2 or val1 == val2:\n            return True\n        else:\n            raise ValueError\n    except Exception:\n        try:\n            norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)\n            sym1, val1 = latex_eval(norm1)\n            sym2, val2 = latex_eval(norm2)\n            if sym1 == sym2 or val1 == val2:\n                return True\n        except Exception:\n            return norm1 == norm2\n    return False\n\n\ndef 
is_latex_equal(given_answer: str, ground_truth: str) -> bool:\n    try:\n        with timeout(1):\n            try:\n                if (len(given_answer) > 128 and repeatness(given_answer)) or (\n                    len(ground_truth) > 128 and repeatness(ground_truth)\n                ):\n                    return False\n                # First conduct normalized string matching.\n                ground_truth_normalized = _normalize(ground_truth)\n                given_normalized = _normalize(given_answer)\n                if ground_truth_normalized is None:\n                    return False\n                if ground_truth_normalized == given_normalized:\n                    return True\n\n                # Next call math verify.\n                given_answer = given_answer.replace(\"\\n\", \"\")\n                ground_truth = ground_truth.replace(\"\\n\", \"\")\n                if \"$\" not in given_answer:\n                    given_answer = f\"${given_answer}$\"\n                if \"$\" not in ground_truth:\n                    ground_truth = f\"${ground_truth}$\"\n                return verify(\n                    parse(\n                        ground_truth,\n                        extraction_config=(\n                            LatexExtractionConfig(boxed_match_priority=0),\n                            ExprExtractionConfig(),\n                        ),\n                        fallback_mode=\"no_fallback\",\n                        extraction_mode=[\"first_match\"],\n                        parsing_timeout=1,\n                    ),\n                    parse(\n                        given_answer,\n                        extraction_config=(\n                            LatexExtractionConfig(boxed_match_priority=0),\n                            ExprExtractionConfig(),\n                        ),\n                        fallback_mode=\"no_fallback\",\n                        extraction_mode=[\"first_match\"],\n                        parsing_timeout=1,\n                    ),\n                    timeout_seconds=1,\n                )\n                # or symbolic_equal(ground_truth, given_answer)\n            except Exception:\n                return False\n    except TimeoutError:\n        return False\n\n\ndef is_value_equal(given_answer: str, ground_truth: str) -> bool:\n    assert ground_truth is not None\n    ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)\n    given_answer_normalized_mathd = mathd_normalize_answer(given_answer)\n\n    str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd\n    try:\n        number_equal = float(ground_truth_normalized_mathd) == float(given_answer_normalized_mathd)\n        return str_equal or number_equal\n    except Exception:\n        return str_equal\n\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [r\"\\^[0-9]+\\^\", r\"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" \\\\frac\")  
# Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x: str) -> bool:\n    try:\n        x = _strip_properly_formatted_commas(x)\n        x = float(x)\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _str_to_int(x: str) -> bool:\n    x = x.replace(\",\", \"\")\n    x = float(x)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evalable\n    e.g. 7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(\"([0-9]) +([0-9])\")\n    step = p1.sub(\"\\\\1+\\\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(r\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(\"\\\\1\\\\3\\\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(r\"^\\\\text\\{(?P<text>.+?)\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n    ]:\n        expr = re.sub(f\"{unit}(es)?(s)? *(\\\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(r\"\\^ *\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        try:\n            expr = _parse_latex(expr)\n        except Exception:\n            pass\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n    expr = expr.replace(\" \", \"\")\n\n    # if we somehow still have latex braces here, just drop them\n    expr = expr.replace(\"{\", \"\")\n    expr = expr.replace(\"}\", \"\")\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = set([x for x in expr if x.isalpha()])\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    for bad_regex in BAD_REGEXES:\n        if re.search(bad_regex, expr) is not None:\n            return False\n\n    return True\n\n\n@timeout_ours(timeout_seconds=5)\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if (\n        len(expr) > 2\n        and expr[0] in TUPLE_CHARS\n        and expr[-1] in TUPLE_CHARS\n        and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])\n    ):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n    if right_brace_idx is None:\n        retval = None\n    else:\n        retval = string[idx : right_brace_idx + 1]\n\n    return retval\n\n\ndef remove_boxed(s):\n    left = \"\\\\boxed{\"\n    try:\n        assert s[: len(left)] == left\n        assert s[-1] == \"}\"\n        return s[len(left) : -1]\n    except Exception:\n        return None\n\n\ndef extract_boxed_answer(solution: str) -> str:\n    \"\"\"Extract the answer from inside a LaTeX \\\\boxed{} command\"\"\"\n    solution = last_boxed_only_string(solution)\n    solution = remove_boxed(solution)\n    return 
solution\n\n\ndef grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    if len(ground_truth_elems) > 1 and (\n        ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]\n    ):\n        is_correct = False\n    elif len(ground_truth_elems) != len(given_elems):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):\n            if _is_frac(ground_truth_elem) and _is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match\n                # (no sympy.simplify)\n                is_correct = False\n            else:\n                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:\n    ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)\n    given_answer_normalized_mathd = mathd_normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n    return False\n\n\ndef extract_answer(passage: str) -> str:\n    if \"\\\\boxed\" in passage:\n        return extract_boxed_answer(passage)\n    return None\n\n\ndef grade(model_answer: str, gt_answer: str, fast: bool = True):\n    if \"\\\\boxed\" in gt_answer:\n        gt_answer = extract_answer(gt_answer)\n    correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer)\n    if not fast:\n        # This mode further uses math_verify to recall originally false positives.\n        # Will be a bit slower, and sensitive to bad inputs.\n        correct = correct or is_latex_equal(\n            model_answer,\n            gt_answer,\n        )\n    return correct\n\n\ndef compute_score(model_response, gt_answer, fast=False):\n    model_answer = extract_answer(model_response)\n    if model_answer is None:\n        return {\n            \"score\": 0.0,\n            \"format_score\": 0.0,\n            \"acc\": False,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n        # return 0.0, 0.0  # Cannot even parse anything.\n    is_correct = False\n    if isinstance(gt_answer, float) or isinstance(gt_answer, int):\n        gt_answer = str(gt_answer)\n    if isinstance(gt_answer, str):\n        is_correct = grade(model_answer, gt_answer, fast)\n    elif isinstance(gt_answer, list):\n        is_correct = False\n        for gt in gt_answer:\n            is_correct |= grade(model_answer, gt, fast)\n    if is_correct:\n        return {\n            \"score\": 1.0,\n          
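  # format_score=1.0 records that a \\boxed{} answer was successfully parsed\n          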
  \"format_score\": 1.0,\n            \"acc\": True,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n    else:\n        return {\n            \"score\": 0.0,\n            \"format_score\": 1.0,\n            \"acc\": False,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/reward_score/entropy_math/grader.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) Microsoft Corporation.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE\n\n# Copyright (c) 2023 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:\n- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py\n- https://github.com/microsoft/ProphetNet/tree/master/CRITIC\n- https://github.com/openai/prm800k\n\"\"\"\n\nimport contextlib\nimport math\nimport re\nfrom math import isclose\n\n# sympy related\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n# verl related\nfrom verl.utils.py_functional import timeout_limit\n\n\ndef is_digit(s):\n    try:\n        if \"{,}\" in str(s):\n            num = float(str(s).replace(\"{,}\", \"\"))\n            return True, num\n\n        num = float(str(s).replace(\",\", \"\"))\n        return True, num\n    except ValueError:\n        return False, None\n\n\ndef normalize(answer, pi) -> str:\n    # checking if answer is $<number> and removing $ in that case to compare\n    if isinstance(answer, str) and bool(re.match(r\"\\$\\d+(\\.\\d+)?\", answer)):\n        return answer[1:]\n\n    # checking if answer is <number>% or <number>\\\\% and removing %\n    if isinstance(answer, str) and (\n        bool(re.match(r\"^\\d+(\\.\\d+)?%$\", answer)) or bool(re.match(r\"^\\d+(\\.\\d+)?\\\\%$\", answer))\n    ):\n        return answer.replace(\"\\\\%\", \"\").replace(\"%\", \"\")\n\n    # handle base\n    answer = handle_base(answer)\n\n    # handle pi\n    answer = handle_pi(answer, pi)\n\n    return answer\n\n\ndef handle_base(x) -> str:\n    if 
isinstance(x, str) and \"_\" in x:\n        # Due to base\n        x = x.split(\"_\")[0]\n        x = float(x)\n        return int(x)\n    return x\n\n\ndef handle_pi(string, pi):\n    if isinstance(string, str) and \"\\\\pi\" in string:\n        # Find the first occurrence of \"\\pi\"\n        idx = string.find(\"\\\\pi\")\n\n        # Iterate over the string and find all occurrences of \"\\pi\" with a valid previous character\n        while idx != -1:\n            if idx > 0 and string[idx - 1].isdigit():\n                # Replace \"\\pi\" with \"*math.pi\" if the previous character is a digit\n                string = string[:idx] + f\"*{pi}\" + string[idx + 3 :]\n            else:\n                # Replace \"\\pi\" with \"1*math.pi\" if the previous character is not a digit\n                string = string[:idx] + f\"1*{pi}\" + string[idx + 3 :]\n\n            # Find the next occurrence of \"\\pi\"\n            idx = string.find(\"\\\\pi\", idx + 1)\n\n        # Evaluate the expression using eval() function\n        with contextlib.suppress(Exception):\n            string = eval(string)\n\n    return string\n\n\ndef math_equal(\n    prediction: bool | float | str,\n    reference: float | str,\n    include_percentage: bool = True,\n    tolerance: float = 1e-4,\n    timeout: float = 10.0,\n    pi: float = math.pi,\n) -> bool:\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n\n    prediction = normalize(prediction, pi)\n    reference = normalize(reference, pi)\n\n    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases\n        prediction = prediction[:1000]\n\n    # 0. string comparison\n    if isinstance(prediction, str) and isinstance(reference, str):\n        if prediction.strip().lower() == reference.strip().lower():\n            return True\n        if prediction.replace(\" \", \"\") == reference.replace(\" \", \"\"):\n            return True\n\n    try:  # 1. numerical equal\n        if is_digit(prediction)[0] and is_digit(reference)[0]:\n            prediction = is_digit(prediction)[1]\n            reference = is_digit(reference)[1]\n            # number questions\n            gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]\n            for item in gt_result:\n                try:\n                    if isclose(item, prediction, rel_tol=tolerance):\n                        return True\n                except Exception:\n                    continue\n            return False\n    except Exception:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    ## deal with [], (), {}\n    prediction = format_intervals(prediction)\n\n    pred_str, ref_str = prediction, reference\n    if (prediction.startswith(\"[\") and prediction.endswith(\"]\") and not reference.startswith(\"(\")) or (\n        prediction.startswith(\"(\") and prediction.endswith(\")\") and not reference.startswith(\"[\")\n    ):\n        pred_str = pred_str.strip(\"[]()\")\n        ref_str = ref_str.strip(\"[]()\")\n    for s in [\"{\", \"}\", \"(\", \")\"]:\n        ref_str = ref_str.replace(s, \"\")\n        pred_str = pred_str.replace(s, \"\")\n    if pred_str == ref_str:\n        return True\n\n    ## [a, b] vs. 
[c, d], return a==c and b==d\n    if (\n        prediction\n        and reference\n        and prediction[0] in \"([\"\n        and prediction[-1] in \")]\"\n        and prediction[0] == reference[0]\n        and prediction[-1] == reference[-1]\n    ):\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    if \",\" in prediction and \",\" in reference:\n        pred_parts = [item.strip() for item in prediction.split(\",\")]\n        ref_parts = [item.strip() for item in reference.split(\",\")]\n\n        if len(pred_parts) == len(ref_parts):\n            return bool(\n                all(\n                    [\n                        math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)\n                        for i in range(len(pred_parts))\n                    ]\n                )\n            )\n\n    # if we have point == tuple of values\n    if prediction.startswith(\"Point\") and reference[0] == \"(\" and reference[-1] == \")\":\n        pred_parts = prediction[prediction.find(\"(\") + 1 : -1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    # if reference is a matrix\n    if r\"\\begin{pmatrix}\" in reference and prediction.startswith(\"Matrix\"):\n        try:\n            pred_matrix = parse_expr(prediction)\n            ref_matrix_items = reference.split()[1:-1:2]\n            if len(pred_matrix) == len(ref_matrix_items) and all(\n                [\n                    math_equal(pred, ref, include_percentage, tolerance)\n                    for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)\n                ]\n            ):\n                return True\n        except Exception:\n            pass\n    elif r\"\\begin{pmatrix}\" in reference and prediction.startswith(\"[\") and prediction.endswith(\"]\"):\n        if isinstance(eval(prediction), list):\n            try:\n                pred_matrix = eval(prediction)\n                # ref_matrix_items = reference.split()[1:-1:2]\n                ref_matrix_items = (\n                    reference.removeprefix(r\"\\\\begin{pmatrix}\")\n                    .removeprefix(r\"\\begin{pmatrix}\")\n                    .removesuffix(r\"\\\\end{pmatrix}\")\n                    .removesuffix(r\"\\end{pmatrix}\")\n                )\n                ref_matrix_items = ref_matrix_items.split(\"\\\\\")\n                ref_matrix_items = [row.split(\"&\") if \"&\" in row else row for row in ref_matrix_items]\n                if len(pred_matrix) == len(ref_matrix_items) and all(\n                    [\n                        math_equal(pred, ref, include_percentage, tolerance)\n                        for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)\n                    ]\n                ):\n                    return True\n            except Exception:\n                pass\n\n    return symbolic_equal(prediction, reference, tolerance, timeout)\n\n\ndef symbolic_equal(a, 
b, tolerance, timeout=10.0):\n    def _parse(s):\n        for f in [parse_expr, parse_latex]:\n            try:\n                with timeout_limit(seconds=timeout):\n                    return f(s)\n            except TimeoutError:\n                print(f\"Parsing timed out for {s}\")\n                continue\n            except Exception:\n                continue\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if simplify(a - b) == 0:\n                return True\n    except TimeoutError:\n        print(f\"Simplification timed out for {a} - {b}\")\n        pass\n    except Exception:\n        pass\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if isclose(N(a), N(b), rel_tol=tolerance):\n                return True\n    except TimeoutError:\n        print(f\"Numerical evaluation timed out for {a}, {b}\")\n        pass\n    except Exception:\n        pass\n    return False\n\n\ndef format_intervals(prediction):\n    patterns = {\n        \"Interval(\": r\"^Interval\\((.*)\\)$\",\n        \"Interval.Ropen(\": r\"^Interval\\.Ropen\\((.*)\\)$\",\n        \"Interval.Lopen(\": r\"^Interval\\.Lopen\\((.*)\\)$\",\n        \"Interval.open(\": r\"^Interval\\.open\\((.*)\\)$\",\n    }\n\n    for key, pattern in patterns.items():\n        match = re.match(pattern, prediction)\n        if match:\n            inner_content = match.group(1)\n\n            if key == \"Interval(\":  # Interval(a, b) == [a, b]\n                return f\"[{inner_content}]\"\n            elif key == \"Interval.Ropen(\":  # Interval.Ropen(a, b) == [a, b)\n                return f\"[{inner_content})\"\n            elif key == \"Interval.Lopen(\":  # Interval.Lopen(a, b) == (a, b]\n                return f\"({inner_content}]\"\n            elif key == \"Interval.open(\":  # Interval.open(a, b) == (a, b)\n                return f\"({inner_content})\"\n\n    return prediction\n"
  },
  {
    "path": "verl_distillation/recipe/entropy/reward_score/entropy_math/math_normalize.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence).\n\nFrom: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py\n\"\"\"\n\nimport re\nfrom typing import Optional\n\n\ndef normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(r\"^\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:\n        return answer\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef _fix_a_slash_b(string):\n    if 
len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:\n        return string\n\n\ndef _remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\\\\\%\", \"\")\n    string = string.replace(\"\\\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/README.md",
    "content": "<p align=\"center\">\n<h1 align=\"center\">FAPO: Flawed-Aware Policy Optimization for Efficient and Reliable Reasoning</h1>\n\n<p align=\"center\">\n    <a href=\"https://fapo-rl.github.io/\"><img alt=\"Project Page\" src=\"https://img.shields.io/badge/📒-Project Page-blue\"></a>\n    <a href=\"https://verl.readthedocs.io/en/latest/advance/reward_loop.html\"><img alt=\"Infra Design\" src=\"https://img.shields.io/badge/🏗️-Infra Design-teal\">\n    <a href=\"https://huggingface.co/collections/dyyyyyyyy/fapo\"><img alt=\"Resources\" src=\"https://img.shields.io/badge/🤗 HuggingFace-Data & Models-green\"></a>\n    <a href=\"\"><img alt=\"Paper\" src=\"https://img.shields.io/badge/📄-Arxiv Paper-orange\"></a>\n    <a href=\"https://github.com/yyDing1/FAPO\"><img alt=\"Code\" src=\"https://img.shields.io/badge/💻-Code-blueviolet\"></a>\n</p>\n\n- **Algorithm Insights:** Visit our [Project Page](https://fapo-rl.github.io/) for an overview; comprehensive details are available in the [Paper]().\n- **Infrastructure Design:** Refer to the [Reward Loop](https://verl.readthedocs.io/en/latest/advance/reward_loop.html) document for architectural insights.\n- **Open-Source Software:** Explore the [Huggingface Collections](https://huggingface.co/collections/dyyyyyyyy/fapo) for datasets and models.\n\n\n![fapo-result](https://fapo-rl.github.io/_astro/intro_main.DKe72RHX_1Us2HB.webp)\n\n## Step 1: Train FAPO-GenRM-4B (Generative Reward Model)\n\nWe provide our training and evaluation datasets [here](https://huggingface.co/datasets/dyyyyyyyy/FAPO-Critic).\nDirectly download them to `${RAY_DATA_HOME}/data/`.\n\nThen, submit the training job to the ray cluster:\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"...\" # The Ray cluster address to connect to\nexport RAY_DATA_HOME=\"...\" # The directory to store the data\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml\nexport RUNTIME_ENV=\"./recipe/fapo/runtime_env.yaml\" # This sets environment variables for the Ray cluster\nbash recipe/fapo/run_fapo_genrm_train.sh\n```\n\nYou can skip this step if you want to use the pre-trained FAPO-GenRM-4B model available [here](https://huggingface.co/dyyyyyyyy/FAPO-GenRM-4B).\n\n## Step 2: Integrate the GRM into the Final Training\n\nOur training data is identical to that of DAPO-Math-17K, except that we replace the instruction with \"Put the final answer in \\boxed{}\", which is a common practice for current instruct models.\n\nYou can construct the training and evaluation datasets by:\n```bash\npython recipe/fapo/prepare_fapo_data.py --local_dir ${RAY_DATA_HOME}/data/\n```\n\nOr you can directly use the data available [here](https://huggingface.co/datasets/dyyyyyyyy/FAPO-Reasoning-Dataset).\n\nTo integrate the GRM into the final training, we provide two options:\n\n1. **Launch GRM as an external service:** Launch multiple model servers and a router in advance to handle and dispatch incoming requests. Refer to `verl/recipe/genrm_remote` for more details. The scripts is `verl/recipe/fapo/run_fapo_{7b/32b}_remote.sh`.\n2. **Launch GRM in verl single controller:** Start the GRM model directly inside the verl single controller with an integrated router. 
\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"...\" # The Ray cluster address to connect to\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment (env vars, pip packages, etc.) for the Ray cluster via yaml\nexport RUNTIME_ENV=\"./recipe/fapo/runtime_env.yaml\"\n\n# run Baseline Models\nbash recipe/fapo/run_baseline_7b.sh  # 7b baseline model\nbash recipe/fapo/run_baseline_32b.sh  # 32b baseline model\n\n# run FAPO Models (with external GRM service)\n# Note that you should launch the external GRM service first,\n# and specify the router address in the compute_score function\nbash recipe/fapo/run_fapo_7b_remote.sh  # 7b fapo model\nbash recipe/fapo/run_fapo_32b_remote.sh  # 32b fapo model\n\n# run FAPO Models (single controller mode)\nbash recipe/fapo/run_fapo_7b.sh  # 7b fapo model\nbash recipe/fapo/run_fapo_32b.sh  # 32b fapo model\n```\n\n## Infrastructure Design\n\nWe implement RewardLoop to enable efficient and flexible reward computation.\nThe core implementation can be found in `verl/experimental/reward/`.\nRefer to [this official document](https://verl.readthedocs.io/en/latest/advance/reward_loop.html) for more implementation details.
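\n\nFor reference, the run scripts plug a reward function into this loop through `custom_reward_function.path` and `custom_reward_function.name`. The following sketch mirrors the interface of `compute_score_baseline` in `recipe/fapo/reward_fn_reasoning.py`; the correctness check is a toy stand-in for illustration, not the shipped verifier:\n\n```python\ndef compute_score_sketch(solution_str: str, ground_truth: str, **kwargs) -> dict:\n    # Toy check for illustration only; the real recipe extracts the last boxed\n    # answer from the tail of the response and normalizes it before comparing.\n    correct = ground_truth.strip() in solution_str[-300:]\n    return {\"score\": 1.0 if correct else -1.0, \"acc\": correct, \"pred\": ground_truth if correct else \"[INVALID]\"}\n```\n"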
  },
  {
    "path": "verl_distillation/recipe/fapo/config/rm_config.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nreward_model:\n  _target_: verl.workers.config.RewardModelConfig\n\n  reward_manager: dapo\n  enable: False\n\n  # Whether to deploy the model to a separate resource pool.\n  enable_resource_pool: False\n  n_gpus_per_node: 0\n  nnodes: 0\n\n  model:\n    type: discriminative\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: False\n\n  rollout:\n    _target_: verl.workers.config.RolloutConfig\n    name: ???\n    dtype: bfloat16\n    gpu_memory_utilization: 0.5\n    enforce_eager: true\n    cudagraph_capture_sizes: null\n    free_cache_engine: true\n    data_parallel_size: 1\n    expert_parallel_size: 1\n    tensor_model_parallel_size: 2\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    load_format: auto\n    engine_kwargs: {}\n    limit_images: null\n    enable_chunked_prefill: true\n    enable_prefix_caching: true\n    disable_log_stats: true\n    skip_tokenizer_init: true\n\n    prompt_length: 512\n    response_length: 512"
  },
  {
    "path": "verl_distillation/recipe/fapo/prepare_fapo_data.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\n\nfrom datasets import concatenate_datasets, load_dataset\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef example_map_fn(example, idx, process_fn, data_source, ability, split):\n    question, prompt, ground_truth = process_fn(example)\n    data = {\n        \"data_source\": data_source,\n        \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n        \"ability\": ability,\n        \"reward_model\": {\"style\": \"rule\", \"ground_truth\": ground_truth},\n        \"extra_info\": {\"split\": split, \"index\": idx, \"question\": question},\n    }\n    return data\n\n\ndef build_aime2024_dataset():\n    def process_aime2024(example):\n        question, ground_truth = example[\"Problem\"], str(example[\"Answer\"])\n        prompt = question.strip() + \"\\n\\n\" + \"Please reason step by step, and put your final answer within \\\\boxed{}.\"\n        return question, prompt, ground_truth\n\n    data_source = \"Maxwell-Jia/AIME_2024\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"train\")\n    map_fn = partial(example_map_fn, process_fn=process_aime2024, data_source=\"aime24\", ability=\"Math\", split=\"test\")\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_aime2025_dataset():\n    def process_aime2025(example):\n        question, ground_truth = example[\"problem\"], str(example[\"solution\"])\n        prompt = question.strip() + \"\\n\\n\" + \"Please reason step by step, and put your final answer within \\\\boxed{}.\"\n        return question, prompt, ground_truth\n\n    data_source = \"yentinglin/aime_2025\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"train\")\n    map_fn = partial(example_map_fn, process_fn=process_aime2025, data_source=\"aime25\", ability=\"Math\", split=\"test\")\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_gpqa_diamond_dataset():\n    import random\n\n    GPQA_QUERY_TEMPLATE = (\n        \"{Question}\\n\"\n        \"A. {A}\\nB. {B}\\nC. {C}\\nD. 
{D}\\n\\n\"\n        \"Please reason step by step, and put your final answer (only the choice letter) within \\\\boxed{{}}.\"\n    )\n\n    def process_gpqa_diamond(example):\n        choices = [\n            example[\"Incorrect Answer 1\"].strip(),\n            example[\"Incorrect Answer 2\"].strip(),\n            example[\"Incorrect Answer 3\"].strip(),\n        ]\n        random.shuffle(choices)\n        gold_index = random.randint(0, 3)\n        choices.insert(gold_index, example[\"Correct Answer\"].strip())\n        question = example[\"Question\"]\n        query_prompt = GPQA_QUERY_TEMPLATE.format(\n            A=choices[0],\n            B=choices[1],\n            C=choices[2],\n            D=choices[3],\n            Question=question,\n        )\n        gold_choice = \"ABCD\"[gold_index]\n        return question, query_prompt, gold_choice\n\n    data_source = \"Idavidrein/gpqa\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n\n    dataset = load_dataset(data_source, \"gpqa_diamond\", split=\"train\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_gpqa_diamond, data_source=\"gpqa-diamond\", ability=\"General\", split=\"test\"\n    )\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_dapo_train_dataset():\n    def process_dapo(example):\n        question, ground_truth = example[\"prompt\"], example[\"solution\"]\n        prompt = question.strip() + \"\\n\\n\" + \"Please reason step by step, and put your final answer within \\\\boxed{}.\"\n        return question, prompt, ground_truth\n\n    data_source = \"open-r1/DAPO-Math-17k-Processed\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, \"all\", split=\"train\")\n    map_fn = partial(example_map_fn, process_fn=process_dapo, data_source=\"math-dapo\", ability=\"Math\", split=\"train\")\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/genrm\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--tasks\", default=\"all\")\n\n    args = parser.parse_args()\n\n    train_dataset = build_dapo_train_dataset()\n    train_dataset = concatenate_datasets([train_dataset for _ in range(20)])\n\n    test_datasets = []\n    # AIME 2024\n    aime24_dataset = build_aime2024_dataset()\n    test_datasets.extend([aime24_dataset for _ in range(32)])\n    # AIME 2025\n    aime25_dataset = build_aime2025_dataset()\n    test_datasets.extend([aime25_dataset for _ in range(32)])\n    # GPQA Diamond\n    gpqa_dataset = build_gpqa_diamond_dataset()\n    test_datasets.extend([gpqa_dataset for _ in range(4)])\n    test_dataset = concatenate_datasets(test_datasets)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"fapo-train-boxed.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"fapo-test-full-boxed.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/reward_fn_genrm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\n\nfrom verl.utils.reward_score.math_dapo import last_boxed_only_string, remove_boxed\n\n\ndef parse_ans(\n    solution_str: str,\n    total_steps: int,\n) -> tuple[bool, str]:\n    try:\n        boxed_answer = last_boxed_only_string(solution_str[-300:])\n        extracted_answer = int(remove_boxed(boxed_answer))\n        if extracted_answer == -1 or 0 <= extracted_answer < total_steps:\n            return extracted_answer\n        else:\n            return None\n    except Exception:\n        return None\n\n\ndef compute_score_fapo_genrm(\n    solution_str: str,\n    ground_truth: int,\n    extra_info: dict,\n    **kwargs,\n) -> float:\n    # Verify the solution\n    total_steps = extra_info[\"total_steps\"]\n    extracted_answer = parse_ans(solution_str, total_steps)\n    gt = \"correct\" if ground_truth == -1 else \"incorrect\"\n    pred = \"correct\" if extracted_answer == -1 else \"incorrect\"\n    if extracted_answer is None:\n        pred = \"[INVALID]\"\n    acc = gt == pred\n    # reward = 1.0 if acc else -1.0\n    if extracted_answer is None:\n        reward = -1.0\n    elif ground_truth == -1:\n        reward = 1.0 if extracted_answer == -1 else -1.0\n    else:\n        # ground truth != -1\n        if extracted_answer == -1:\n            reward = -1.0\n        else:\n            # gt != -1, pred != -1\n            reward = 1.0\n            reward -= abs(extracted_answer - ground_truth) / total_steps\n\n    return {\n        \"score\": reward,\n        \"acc\": acc,\n        \"pred\": extracted_answer,\n        \"gt\": ground_truth,\n    }\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/reward_fn_reasoning.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport json\nimport logging\nimport os\n\nimport aiohttp\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils.reward_score.math_dapo import last_boxed_only_string, normalize_final_answer, remove_boxed\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef verify(\n    solution_str: str,\n    gt: str,\n) -> tuple[bool, str]:\n    solution_str = solution_str[-300:]\n    boxed_answer = last_boxed_only_string(solution_str)\n    if boxed_answer is not None:\n        extracted_answer = remove_boxed(boxed_answer)\n    else:\n        extracted_answer = \"[INVALID]\"\n\n    pred = normalize_final_answer(extracted_answer)\n    gt = normalize_final_answer(gt)\n    return (pred == gt), pred\n\n\nasync def compute_score_baseline(\n    solution_str: str,\n    ground_truth: str,\n    **kwargs,\n):\n    loop = asyncio.get_running_loop()\n    \"\"\"Compute the reward score for Baseline.\"\"\"\n    correct, pred = await loop.run_in_executor(None, lambda: verify(solution_str, ground_truth))\n    reward_score = 1.0 if correct else -1.0\n    return {\"score\": reward_score, \"acc\": correct, \"pred\": pred}\n\n\n# FAPO Hyper-parameters\nFAPO_GENRM_TEMPLATE = (\n    \"The following is a math problem with its ground truth answer, along with an AI solution (split into steps):\\n\\n\"\n    \"[Math Problem]\\n\\n\"\n    \"{problem}\\n\\n\"\n    \"[Ground Truth]\\n\\n\"\n    \"{ground_truth}\\n\\n\"\n    \"[AI Solution]\\n\\n\"\n    \"{solution}\\n\\n\"\n    \"Your task is to review and critique the solution step by step. \"\n    \"Once you identify an error in a step, return the index of the step where the earliest error occurs. 
\"\n    \"Otherwise, return the index of -1 (which typically denotes 'not found').\\n\\n\"\n    \"Please reason step by step, put your final answer (i.e., the index) in \\\\boxed{{}}.\"\n)\nGRM_SAMPLING_PARAMS = {\n    \"max_new_tokens\": 16384,\n}\nFLAWED_REWARD_PENALTY = 1.0\n\n\nasync def generate_aiohttp(router_address: str, prompt_ids: list[int], sampling_params: dict):\n    payload = {\n        \"input_ids\": prompt_ids,\n        \"sampling_params\": sampling_params,\n    }\n    url = f\"http://{router_address}/generate\"\n    try:\n        session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None))\n        async with session.post(url, json=payload) as resp:\n            output = await resp.text()\n            try:\n                output = json.loads(output)\n                return output\n            except Exception:\n                logger.error(f\"Failed to parse JSON response: {output}\")\n                return {}\n    finally:\n        await session.close()\n\n\nasync def compute_score_fapo(\n    data_source: str,\n    solution_str: str,\n    ground_truth: str,\n    extra_info: dict,\n    reward_router_address: str,\n    reward_model_tokenizer: PreTrainedTokenizer,\n):\n    \"\"\"Compute the reward score for FAPO.\"\"\"\n    loop = asyncio.get_running_loop()\n\n    question, split = extra_info[\"question\"], extra_info[\"split\"]\n    correct, pred = await loop.run_in_executor(None, lambda: verify(solution_str, ground_truth))\n    reward_score = 1.0 if correct else -1.0\n    is_flawed_positive = False\n\n    # for test set or incorrect solution, directly return the reward score\n    if split == \"test\" or not correct:\n        return {\"score\": reward_score, \"acc\": correct, \"pred\": pred, \"is_flawed_positive\": is_flawed_positive}\n\n    grm_prompt = FAPO_GENRM_TEMPLATE.format(\n        problem=question,\n        ground_truth=ground_truth,\n        solution=solution_str,\n    )\n    grm_prompt_ids = await loop.run_in_executor(\n        None,\n        lambda: reward_model_tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": grm_prompt}],\n            tokenize=True,\n            add_generation_prompt=True,\n        ),\n    )\n    grm_outputs = await generate_aiohttp(\n        router_address=reward_router_address,\n        prompt_ids=grm_prompt_ids,\n        sampling_params=GRM_SAMPLING_PARAMS,\n    )\n    grm_response_ids = grm_outputs.get(\"output_ids\", None)\n    if grm_response_ids is not None:\n        grm_response = await loop.run_in_executor(\n            None, lambda: reward_model_tokenizer.decode(grm_response_ids, skip_special_tokens=True)\n        )\n        try:\n            err_location = remove_boxed(last_boxed_only_string(grm_response))\n            is_flawed_positive = int(err_location) != -1\n        except Exception:\n            is_flawed_positive = False\n\n        if is_flawed_positive:\n            reward_score -= FLAWED_REWARD_PENALTY\n\n    return {\"score\": reward_score, \"acc\": correct, \"pred\": pred, \"is_flawed_positive\": is_flawed_positive}\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/reward_fn_reasoning_remote.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\n\nimport aiohttp\n\nfrom verl.utils.reward_score.math_dapo import last_boxed_only_string, normalize_final_answer, remove_boxed\n\n\ndef verify(\n    solution_str: str,\n    gt: str,\n) -> tuple[bool, str]:\n    boxed_answer = last_boxed_only_string(solution_str)\n    if boxed_answer is not None:\n        extracted_answer = remove_boxed(boxed_answer)\n    else:\n        extracted_answer = \"[INVALID]\"\n\n    pred = normalize_final_answer(extracted_answer)\n    gt = normalize_final_answer(gt)\n    return (pred == gt), pred\n\n\ndef compute_score_baseline(\n    solution_str: str,\n    ground_truth: str,\n    **kwargs,\n) -> float:\n    # Limit solution length for efficiency\n    solution_str = solution_str[-300:]  # The longest answer in MATH-500 has 159 characters\n\n    # Verify the solution\n    correct, pred = verify(solution_str, ground_truth)\n\n    reward = 1.0 if correct else -1.0\n    acc = correct\n\n    return {\n        \"score\": reward,\n        \"acc\": acc,\n        \"pred\": pred,\n    }\n\n\nADDRESS = \"xx.xx.xx.xx:xxxx\"\nMODEL_NAME = \"FAPO-4B-GenRM\"\nFAPO_GENRM_TEMPLATE = (\n    \"The following is a math problem with its ground truth answer, along with an AI solution (split into steps):\\n\\n\"\n    \"[Math Problem]\\n\\n\"\n    \"{problem}\\n\\n\"\n    \"[Ground Truth]\\n\\n\"\n    \"{ground_truth}\\n\\n\"\n    \"[AI Solution]\\n\\n\"\n    \"{solution}\\n\\n\"\n    \"Your task is to review and critique the solution step by step. \"\n    \"Once you identify an error in a step, return the index of the step where the earliest error occurs. \"\n    \"Otherwise, return the index of -1 (which typically denotes 'not found').\\n\\n\"\n    \"Please reason step by step, put your final answer (i.e., the index) in \\\\boxed{{}}.\"\n)\n\n\nasync def chat_completions_aiohttp(address, **chat_complete_request):\n    try:\n        request_url = f\"http://{address}/v1/chat/completions\"\n        timeout = aiohttp.ClientTimeout(total=None)\n        session = aiohttp.ClientSession(timeout=timeout)\n        async with session.post(\n            url=request_url,\n            json=chat_complete_request,\n        ) as resp:\n            output = await resp.text()\n            try:\n                output = json.loads(output)\n                return output[\"choices\"][0][\"message\"][\"content\"]\n            except Exception as e:\n                print(f\"Error: {e}. 
Output: {output}\")\n                return \"\"\n    finally:\n        await session.close()\n\n\ndef judge_fp_process(response, return_err_step=False):\n    try:\n        boxed_result = last_boxed_only_string(response)\n        result = remove_boxed(boxed_result)\n        reward_score = int(eval(result)) != -1\n        if return_err_step:\n            return reward_score, int(result)\n        return reward_score\n    except Exception:\n        if return_err_step:\n            return None, None\n        return None\n\n\nasync def compute_score_fapo(data_source, solution_str, ground_truth, extra_info, keep_genrm_critics=False, **kwargs):\n    question, split = extra_info[\"question\"], extra_info[\"split\"]\n    result = compute_score_baseline(solution_str, ground_truth)\n    result[\"flawed_positive\"] = False\n\n    if split == \"test\" or result[\"acc\"] == 0:\n        if keep_genrm_critics:\n            result[\"genrm_critics\"] = \"\"\n        return result\n    else:\n        prompt = FAPO_GENRM_TEMPLATE.format(problem=question, ground_truth=ground_truth, solution=solution_str)\n        messages = [{\"role\": \"user\", \"content\": prompt}]\n        response = await chat_completions_aiohttp(\n            ADDRESS,\n            messages=messages,\n            model=MODEL_NAME,\n            max_tokens=16384,\n        )\n        if response is not None and judge_fp_process(response):  # flawed positive\n            result[\"score\"] = 0.0\n            result[\"flawed_positive\"] = True\n\n        if keep_genrm_critics and response is not None:\n            result[\"genrm_critics\"] = response\n\n    return result\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_baseline_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='Baseline-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\nfsdp_size=32\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \\\n    custom_reward_function.name=compute_score_baseline \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=600 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_baseline_7b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='Baseline-7B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=1\nfsdp_size=8\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \\\n    custom_reward_function.name=compute_score_baseline \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_fapo_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='FAPO-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nRM_NODES=${RM_NODES:-2}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nGRM_PATH=${GRM_PATH:-\"${RAY_DATA_HOME}/models/FAPO-GenRM-4B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\nfsdp_size=32\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    
actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.enable=True \\\n    reward_model.enable_resource_pool=True \\\n    reward_model.n_gpus_per_node=8 \\\n    reward_model.nnodes=\"${RM_NODES}\" \\\n    reward_model.model.path=${GRM_PATH} \\\n    reward_model.rollout.name=sglang \\\n    reward_model.rollout.gpu_memory_utilization=0.95 \\\n    reward_model.rollout.tensor_model_parallel_size=1 \\\n    reward_model.rollout.free_cache_engine=False \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \\\n    custom_reward_function.name=compute_score_fapo \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=600 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_fapo_32b_remote.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='FAPO-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\nfsdp_size=32\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning_remote.py \\\n    custom_reward_function.name=compute_score_fapo \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=600 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_fapo_7b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='FAPO-7B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nRM_NODES=${RM_NODES:-2}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nGRM_PATH=${GRM_PATH:-\"${RAY_DATA_HOME}/models/FAPO-GenRM-4B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=1\nfsdp_size=8\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    
actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.enable=True \\\n    reward_model.enable_resource_pool=True \\\n    reward_model.n_gpus_per_node=8 \\\n    reward_model.nnodes=\"${RM_NODES}\" \\\n    reward_model.model.path=${GRM_PATH} \\\n    reward_model.rollout.name=sglang \\\n    reward_model.rollout.gpu_memory_utilization=0.95 \\\n    reward_model.rollout.tensor_model_parallel_size=1 \\\n    reward_model.rollout.free_cache_engine=False \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \\\n    custom_reward_function.name=compute_score_fapo \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_fapo_7b_remote.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='FAPO-7B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=8\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=1\nfsdp_size=8\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/recipe/fapo/config\"\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    --config-path $CONFIG_PATH \\\n    --config-name rm_config.yaml \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_reasoning_remote.py \\\n    custom_reward_function.name=compute_score_fapo \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/run_fapo_genrm_train.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='FAPO-Reproduce'\nexp_name='FAPO-GenRM-4B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 5))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-4B-Instruct-2507\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/test.parquet\"}\n\n# Algorithm\ntemperature=1.2\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_temperature=0.6\nval_top_p=0.95\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=1\nfsdp_size=8\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --address \"${RAY_ADDRESS}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    custom_reward_function.path=recipe/fapo/reward_fn_genrm.py \\\n    custom_reward_function.name=compute_score_fapo_genrm \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=500 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/fapo/runtime_env.yaml",
    "content": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  VLLM_USE_V1: \"1\"\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/README.md",
    "content": "# Recipe: Fully Async Policy Trainer\n\n**Author:** `https://github.com/meituan-search`\n\nLast updated: 10/18/2025.\n\nThis document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,\nsupporting asynchronous sample generation and training.\nUnder this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs,\nwithout significantly affecting the results.\n\n## Introduction\n\n### Background\n\nThe separated rollout and train architecture, compared to the colocate architecture, can allocate resources more\nflexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training\nefficiency caused by long-tail problems.\nThe one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by\ndesigning a separated architecture and performing asynchronous training between rollout and train for one round.\nHowever, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot\ncompletely eliminate the impact of long-tail on training efficiency.\nIn other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have\nbeen implemented based on the separated architecture and have achieved gains.\nWe borrow from their methods and implemented them in VERL. The fully_async_policy supports asynchronous, streaming, and\npartial\nrollout training.\nBy reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy\ncan significantly improve training efficiency.\n\n> Magistral https://arxiv.org/abs/2506.10910\n>\n> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language\n> Reasoning https://arxiv.org/abs/2505.24298\n>\n> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream\n> Generation https://arxiv.org/abs/2504.15930\n>\n> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663\n>\n\n### Core Contributions\n\n* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to\n  specify the resources they occupy separately.\n* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.\n* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to\n  multiple steps, making the asynchronous solution more flexible.\n* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter\n  and Trainer.\n* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single\n  sample as the minimum transmission unit.\n* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it\n  supports training with samples generated by old parameters.\n* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter\n  synchronization, by adding `sleep() and resume()` logic, it\n  saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for\n  ongoing tasks to finish during parameter synchronization.\n\nCurrently, the supported usage mode is megatron/fsdp+vllm. 
vllm must use the server mode based on AgentLoop.\n\n## Design\n\nThe overall architecture of fully_async_policy is shown in the figure below. It consists of four parts: the Rollouter,\nthe MessageQueue, the Trainer, and the ParameterSynchronizer.\n\n![fully_async_policy_structure](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)\n\n1. The Rollouter generates sequences sample by sample and puts them into the MessageQueue; its production speed is\n   governed by the freshness control.\n2. The MessageQueue temporarily stores the samples generated by the Rollouter.\n3. The Trainer fetches samples from the MessageQueue one by one and trains once it has fetched\n   `require_batches*ppo_mini_batch_size` samples; after async_training.trigger_parameter_sync_step rounds of training\n   it triggers a parameter synchronization with the Rollouter (this loop is sketched below).\n4. The ParameterSynchronizer implements synchronous parameter synchronization over NCCL.\n\n
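To make this producer/consumer flow concrete, here is a minimal, hedged sketch of the loop described above. It is an illustration only: the thread layout and the names `rollouter`, `trainer`, and `message_queue` are invented for this sketch and are not the recipe's actual API, and the NCCL synchronization is reduced to a comment.\n\n```python\nimport queue\nimport threading\n\n# Knob values mirror the parameter table below (hypothetical defaults).\nPPO_MINI_BATCH_SIZE = 32\nREQUIRE_BATCHES = 4\nTRIGGER_PARAMETER_SYNC_STEP = 4\n\n# Bounded queue standing in for the MessageQueue; the bound crudely models\n# freshness control by blocking the producer when the Trainer falls behind.\nmessage_queue: \"queue.Queue[int]\" = queue.Queue(maxsize=768)\nstop = threading.Event()\n\ndef rollouter() -> None:\n    \"\"\"Produce samples one by one; a single sample is the transmission unit.\"\"\"\n    sample_id = 0\n    while not stop.is_set():\n        message_queue.put(sample_id)\n        sample_id += 1\n\ndef trainer(total_sync_rounds: int) -> None:\n    \"\"\"Train every require_batches * ppo_mini_batch_size samples, then sync.\"\"\"\n    fetch_size = REQUIRE_BATCHES * PPO_MINI_BATCH_SIZE\n    for _ in range(total_sync_rounds):\n        for _local_update in range(TRIGGER_PARAMETER_SYNC_STEP):\n            batch = [message_queue.get() for _ in range(fetch_size)]\n            # ... one local PPO update on `batch` would happen here ...\n        # After trigger_parameter_sync_step local updates, synchronize\n        # parameters with the Rollouter (NCCL via the ParameterSynchronizer).\n    stop.set()\n\nthreading.Thread(target=rollouter, daemon=True).start()\ntrainer(total_sync_rounds=2)\n```\n\n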
The source of the gains over the baseline is that, in the colocate case, giving rollout more resources cannot remove\nthe idle time caused by long-tail samples.\nAfter resource isolation, rollout and training may each take longer than before (because each uses fewer resources),\nbut their execution overlaps, so the end-to-end time is reduced.\n\n![fully_async_policy_revenue](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)\n\n## Usage\n\n### Parameter Description\n\n| parameter                                     | meaning                                                                                        |\n|-----------------------------------------------|------------------------------------------------------------------------------------------------|\n| `trainer.nnodes`                              | Number of nodes for the Trainer                                                                |\n| `trainer.n_gpus_per_node`                     | Number of GPUs per node for the Trainer                                                        |\n| `rollout.nnodes`                              | Number of nodes for the Rollouter                                                              |\n| `rollout.n_gpus_per_node`                     | Number of GPUs per node for the Rollouter                                                      |\n| `data.train_batch_size`                       | Not effective in the fully async strategy (default is 0)                                       |\n| `data.gen_batch_size`                         | In the fully async strategy, samples are produced in a streaming fashion (default is 1)        |\n| `rollout.total_rollout_steps`                 | Total number of rollout samples                                                                |\n| `rollout.test_freq`                           | Number of Rollouter parameter updates between validations                                      |\n| `actor_rollout_ref.actor.ppo_mini_batch_size` | Global mini-batch size, aggregated across all workers/GPUs                                     |\n| `async_training.require_batches`              | Number of ppo_mini_batch_size batches the FullyAsyncTrainer fetches at once                    |\n| `async_training.trigger_parameter_sync_step`  | Number of local updates the FullyAsyncTrainer performs before each parameter synchronization   |\n| `async_training.staleness_threshold`          | Freshness control                                                                              |\n| `async_training.partial_rollout`              | Whether to perform partial rollout                                                             |\n| `async_training.use_rollout_log_probs`        | Use the log_probs generated by rollout                                                         |\n| `async_training.compute_prox_log_prob`        | Whether to compute log_prob with the training model's parameters during the training phase     |\n\n**Further Explanation:**\n\n* `rollout.total_rollout_steps`\n\n  To match a colocate run, multiply train_batch_size by the number of steps:\n  `rollout.total_rollout_steps = data.train_batch_size * step`.\n\n* `async_training.trigger_parameter_sync_step`\n\n  In the fully async strategy, this is the number of local updates the Trainer performs (i.e., the number of times it\n  fetches `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with the Rollouter.\n  Between two consecutive parameter synchronizations, the Trainer processes\n  `trigger_parameter_sync_step * require_batches * ppo_mini_batch_size` samples.\n  For a fair speed comparison with colocate, trigger_parameter_sync_step should be set to\n  `data.train_batch_size / (require_batches * ppo_mini_batch_size)`.\n\n* `async_training.staleness_threshold`\n\n  In the fully async strategy, the maximum allowed proportion of stale samples (see the sketch after this list).\n\n    * staleness_threshold=0 means synchronous training.\n      The Rollouter generates a fixed number of samples between two parameter updates:\n      $$rollout\_num = trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size$$\n    * staleness_threshold>0 means asynchronous training; fractional values allow more flexible asynchrony.\n      The Rollouter generates at most the following number of samples between two parameter updates:\n      $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample$$\n\n  num_staleness_sample is the number of stale samples generated in excess during the previous rollout.\n\n  Because this is a streaming system, rollout keeps producing while the trainer keeps consuming. If the Rollouter is\n  slower, the Trainer triggers parameter synchronization earlier and the Rollouter never actually produces rollout_num\n  samples.\n  When rollout is fast enough, setting staleness_threshold to 1 is essentially equivalent to one_step_off_policy.\n  To prevent too many stale samples from hurting training accuracy, we recommend setting this value below 1.\n\n* `async_training.partial_rollout`\n\n  partial_rollout only takes effect when staleness_threshold>0.\n\n* `async_training.use_rollout_log_probs`\n\n  In reinforcement learning algorithms, log_probs are implicitly tied to both the parameter version and the sampled\n  tokens. Under PPO/GRPO/DAPO, importance sampling requires old_log_prob to be the log_probs of the rollout parameters\n  on the rollout tokens to keep the algorithm correct. 
In the fully async strategy, old_log_prob is therefore\n  computed by the rollout by default, rather than by the trainer.\n\n* `async_training.require_batches`\n\n  In streaming training, require_batches should be 1, meaning training starts as soon as ppo_mini_batch_size samples\n  have been produced.\n  In practice, however, we found that issuing too few samples at a time destabilizes training and lengthens responses,\n  because of the order in which data is distributed.\n  require_batches therefore controls how many mini-batches are distributed, and hence how many samples participate in\n  training, at once.\n\n* `async_training.compute_prox_log_prob` (experimental)\n\n  During training we observed that metrics and response lengths can become unstable in the later stages. To mitigate\n  this, the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) technique can\n  be applied; it requires computing log_prob with the training engine, which is what this switch enables.\n  Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d\n  (async stream pipeline with partial rollout), our implementation approximates AReaL's `Decoupled PPO`.\n\n
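As a concrete illustration of the sample-budget arithmetic in the `staleness_threshold` item above, here is a small, hedged Python sketch; the function `max_rollout_num` is invented for this document and is not part of the recipe:\n\n```python\ndef max_rollout_num(trigger_parameter_sync_step: int,\n                    require_batches: int,\n                    ppo_mini_batch_size: int,\n                    staleness_threshold: float,\n                    num_staleness_sample: int = 0) -> int:\n    \"\"\"Upper bound on samples the Rollouter produces between two parameter syncs.\"\"\"\n    base = trigger_parameter_sync_step * require_batches * ppo_mini_batch_size\n    if staleness_threshold == 0:\n        return base  # synchronous training: a fixed sample count\n    # Asynchronous training: up to (1 + staleness) * base, minus the stale\n    # samples already produced in excess during the previous rollout.\n    return int((1 + staleness_threshold) * base) - num_staleness_sample\n\n# The 128-GPU 7B experiment below: trigger=4, require=4, mini_bsz=32, staleness=0.5\nprint(max_rollout_num(4, 4, 32, 0.5))  # 768 at most, vs. 512 when fully synchronous\n# Fair-speed comparison with colocate (train_batch_size=512):\nprint(512 // (4 * 32))  # trigger_parameter_sync_step = 4\n```\n\n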
### Supported Modes\n\n1. on policy pipeline:\n    1. **trigger_parameter_sync_step=1, staleness_threshold=0**\n    2. The Rollouter produces `require_batches*ppo_mini_batch_size` samples at a time; the Trainer fetches them and\n       trains, and after training completes, the Trainer and the Rollouter perform a parameter synchronization.\n    3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill\n       the idle resources, so some resources are wasted.\n    4. Shown in figure a.\n\n2. stream off policy pipeline:\n    1. **trigger_parameter_sync_step>1, staleness_threshold=0**\n    2. Training is streamed synchronously. The Rollouter produces\n       `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at a time; the Trainer performs a\n       local update every time it fetches `require_batches*ppo_mini_batch_size` samples, and after\n       trigger_parameter_sync_step local updates, the Trainer and the Rollouter perform a parameter synchronization.\n    3. Compared to a, more samples are generated at a time, so resources sit idle less.\n    4. Each step still contains two idle periods: while fetching the first batch, the trainer waits for\n       `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter update, rollout\n       waits for training to complete.\n    5. Shown in figure b.\n\n3. async stream pipeline with stale samples:\n    1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**\n    2. After each parameter update, the Rollouter plans to produce at most rollout_num samples (in practice it may\n       produce fewer, depending on rollout speed).\n    3. If rollout is relatively fast, the Rollouter generates some extra samples (num_stale_samples) before parameter\n       synchronization, for the Trainer to use immediately after synchronization.\n       When parameter synchronization is triggered, the Rollouter waits for any in-flight tasks to complete and does\n       not add new ones.\n    4. Compared to b, every step after the first no longer waits for the first batch of rollouts to finish, but does\n       wait for active tasks to finish.\n    5. Shown in figure c.\n\n4. async stream pipeline with partial rollout:\n    1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**\n    2. Compared to c, when parameter synchronization is triggered while samples are still being produced, the Rollouter\n       interrupts the rollout, synchronizes parameters, and resumes generating the interrupted samples afterwards,\n       reducing the time spent waiting for active tasks to finish (sketched after the figure below).\n    3. Shown in figure d.\n\n![fully_async_policy_mode](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)\n\n
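The following minimal sketch illustrates the interrupt/resume idea behind partial rollout (mode 4). It is hedged: `sleep()`/`resume()` echo the hooks named in the Core Contributions, but `InFlightSample`, `PartialRollouter`, and the surrounding logic are invented for illustration and are not the recipe's actual API:\n\n```python\nfrom dataclasses import dataclass, field\n\n@dataclass\nclass InFlightSample:\n    prompt: str\n    tokens: list = field(default_factory=list)  # partially generated response\n    param_version: int = 0  # version that produced the tokens so far\n\nclass PartialRollouter:\n    def __init__(self) -> None:\n        self.paused = []\n        self.param_version = 0\n\n    def sleep(self, active_samples: list) -> None:\n        # On a parameter-sync request, interrupt generation and save the\n        # partially generated samples instead of waiting for them to finish.\n        self.paused = active_samples\n\n    def resume(self) -> list:\n        # After the sync, continue the saved samples under the new parameters.\n        # Tokens generated earlier keep their old version; the gap between\n        # versions is the \"parameter span\" tracked by max_partial_span below.\n        todo, self.paused = self.paused, []\n        return todo\n\nr = PartialRollouter()\nr.sleep([InFlightSample(\"prompt\", tokens=[101, 102])])\nr.param_version += 1  # the ParameterSynchronizer updated the weights\nfor s in r.resume():\n    print(f\"resuming a sample first generated under version {s.param_version}\")\n```\n\n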
### Key Metrics\n\n| metric                                         | meaning                                                                                                  |\n|------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n| `trainer/idle_ratio`                           | Trainer idle ratio                                                                                       |\n| `rollouter/idle_ratio`                         | Rollouter idle ratio                                                                                     |\n| `fully_async/count/stale_samples_processed`    | Total number of stale samples used in training                                                           |\n| `fully_async/count/stale_trajectory_processed` | Total number of stale trajectories used in training (one sample produces rollout.n trajectories)         |\n| `fully_async/partial/total_partial_num`        | Number of partial samples processed by the Trainer between two parameter synchronizations                |\n| `fully_async/partial/partial_ratio`            | Ratio of partial samples processed by the Trainer between two parameter synchronizations                 |\n| `fully_async/partial/max_partial_span`         | Maximum parameter span of the partial samples processed by the Trainer between two parameter synchronizations |\n\n### Parameter Tuning Recommendations\n\n* Resource allocation and adjustment:\n    * Reasonable resource allocation is the prerequisite for good training efficiency. Ideally, rollout time and train\n      time should be close, minimizing pipeline bubbles, avoiding idle resources, and keeping the Trainer from\n      consuming stale samples. In real training, allocation can be adjusted based on the observed idle time of rollout\n      and train, available as rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and\n      trainer/idle_ratio is low, move resources from the Rollouter to the Trainer, and vice versa.\n\n* Key parameters:\n    * staleness_threshold: setting it too high causes more stale samples to be used, hurting model performance; we\n      recommend values below 1.\n    * require_batches: the closer to 1, the closer to a purely streaming process, the smaller the training bubbles, and\n      the larger the speedup, but it affects the order in which samples are processed.\n    * trigger_parameter_sync_step: smaller values are closer to on-policy training but cause frequent parameter\n      synchronization, and the resources wasted on long-tail samples cannot be filled by short samples, lowering\n      utilization. Larger values give higher computational efficiency, but accuracy suffers from the off-policy gap.\n    * rollout.test_freq: validation occupies Rollouter resources, so avoid setting it too small.\n\n* Mode selection: by adjusting these parameters, the fully async architecture supports different levels of\n  acceleration for different scenarios.\n    * For small-scale tasks that must preserve training stability and on-policy behavior and have modest speed\n      requirements, try the on policy pipeline (mode 1).\n    * For scenarios that need higher training throughput but are sensitive to staleness, try the stream off policy\n      pipeline: set trigger_parameter_sync_step>1 to improve training efficiency while keeping the synchronization\n      mechanism (staleness_threshold=0) (mode 2).\n    * For large-scale tasks with high speed requirements that can tolerate some off-policyness and staleness, set\n      staleness_threshold>0 and partial_rollout=True and use the async stream pipeline (mode 3 or 4).\n\n### Quick Start\n\n```shell\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\nuse_dynamic_bsz=True # referenced below, so it must be defined\ntotal_rollout_steps=$((512 * 400))\ntest_freq=10\nstaleness_threshold=0\ntrigger_parameter_sync_step=16\npartial_rollout=False\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\"\n```\n\n
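For orientation, the quick-start values above (trigger_parameter_sync_step=16, staleness_threshold=0, partial_rollout=False) select the stream off policy pipeline (mode 2). The small helper below (invented for this document, not part of the recipe) restates the mode-selection rules from \"Supported Modes\":\n\n```python\ndef select_mode(trigger_parameter_sync_step: int,\n                staleness_threshold: float,\n                partial_rollout: bool) -> str:\n    \"\"\"Map the three key knobs to the four supported modes.\"\"\"\n    if staleness_threshold == 0:\n        if trigger_parameter_sync_step == 1:\n            return \"on policy pipeline (mode 1)\"\n        return \"stream off policy pipeline (mode 2)\"\n    if partial_rollout:\n        return \"async stream pipeline with partial rollout (mode 4)\"\n    return \"async stream pipeline with stale samples (mode 3)\"\n\nprint(select_mode(16, 0, False))  # quick-start defaults -> mode 2\nprint(select_mode(4, 0.5, True))  # the 7B experiment below -> mode 4\n```\n\n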
## Experiments\n\n### Asynchronous Training on 7B Model\n\nWe used Qwen2.5-Math-7B to verify the gains of the fully async strategy with long generations and at several resource\nscales.\nUsing the `async stream pipeline with partial rollout` strategy (staleness_threshold=0.5, partial_rollout=True), we\nachieved roughly a 2x speedup on 32, 64, and 128 GPUs without significantly affecting experimental results.\n\n* Machine: H20\n* Model: Qwen2.5-Math-7B\n* Rollout length: max_response_length = 28K tokens\n* Algorithm: DAPO\n* Dataset: TRAIN_FILE: dapo-math-17k.parquet; TEST_FILE: aime-2024.parquet\n* Engine: vLLM + FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step: 400\n    * train_batch_size: 512\n\n* fully_async_policy:\n    * total_rollout_steps: 512*400\n    * require_batches: 4\n    * trigger_parameter_sync_step: 4\n    * staleness_threshold: 0.5\n    * partial_rollout: True\n\n| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |\n|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |\n| fully_async_policy | 16:16 | 294.77 | 21.26 | \\ | 269.80 | 7h 58m<br>(1.72x) | 16h 21m<br>(1.70x) | 1d 0h 53m<br>(2.31x) | 1d 9h 26m<br>(2.66x) | max: 0.3302<br>last: 0.2333 |\n| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |\n| fully_async_policy | 32:32 | 189.26 | 28.46 | \\ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |\n| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |\n| fully_async_policy | 64:64 | 150.63 | 33.14 | \\ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg\n\n### 128-card 7B Asynchronous Mode Experiment\n\nWe used Qwen2.5-Math-7B to verify the effect of each mode supported by fully async.\nStreaming alone brings roughly a 1.6x speedup; combining it with staleness and partial_rollout raises the speedup to\n2.35x.\n\n|                             mode                                         \t                              | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total 
time<br>300 step \t | total time<br>400 step \t |      acc/mean@1         \t      |\n|:-------------------------------------------------------------------------------------------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|\n|                                          colocate sync      \t                                           | 356.30 \t | 177.85 \t | 53.92        \t | 113.81       \t | 8h 36m                 \t | 17h 56m                \t | 1d 5h 6m               \t | 1d 16h 48m             \t | max: 0.3573<br>last: 0.2958  \t |\n| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) \t | 231.34 \t | 128.47 \t | \\            \t | 98.77        \t | 4h 25m                 \t | 9h 41m                 \t | 15h 2m                 \t | 1d 1h 53m              \t | max: 0.2844<br>last: 0.2604 \t  |\n|          `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5)            \t          |    \t     |    \t     |       \t        |       \t        |            \t             |            \t             |            \t             |            \t             |               \t                |\n|        `async stream pipeline with partial rollout`<br>(+partial_rollout=True)                 \t        | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m                 \t | 6h 46m                 \t | 10h 53m                \t | 17h 22m                \t | max: 0.3521<br>last: 0.3094 \t  |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg\n\n### 128-card Stale Ablation Experiment\n\nUnder the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training\nefficiency.\nWe found that the larger the staleness, the more obvious the final gains.\nWe also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps\nincrease, the response length changes significantly, causing training instability.\nFurther analysis and optimization are needed for this issue.\n\n| staleness_threshold \t | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t | total time<br>400 step \t |     acc/mean@1         \t      |\n|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|\n| 0                   \t | 231.34 \t | 128.47 \t | \\            \t | 98.77        \t | 4h 25m                 \t | 9h 41m                 \t | 15h 2m                 \t | 1d 1h 53m              \t | max: 0.2844<br>last: 0.2604 \t |\n| 0.1                 \t | 171.30 \t | 58.17  \t | \\            \t | 109.12       \t | 3h 53m                 \t | 8h 37m                 \t | 14h 25m                \t | 19h 59m                \t | max: 0.3542<br>last: 0.2979 \t |\n| 0.3                 \t | 146.11 \t | 38.88  \t | \\            \t | 103.22       \t | 3h 18m                 \t | 6h 49m                 \t | 11h 40m                \t | 17h 20m                \t | max: 0.3469<br>last: 0.2865 \t |\n| 0.5                 \t | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m                 \t | 6h 
46m                 \t | 10h 53m                \t | 17h 22m                \t | max: 0.3521<br>last: 0.3094 \t |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg\n\n### 128-card 7B require_batches Ablation Experiment\n\nIn multiple tests, we found that the number of samples issued each time in streaming affects the response length during\ntraining, which in turn affects training time. We verified the impact on results by modifying\n`async_training.require_batches`.\n\n| require_batches \t | step  \t  | gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t |     acc/mean@1         \t      |\n|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|\n| 1               \t | 203.47 \t | 30.88 \t | \\            \t | 181.08       \t | 3h 31m                 \t | 8h 29m                 \t | 17h 36m                \t | max: 0.349<br>last: 0.326   \t |\n| 2               \t | 158.72 \t | 26.32 \t | \\            \t | 128.08       \t | 3h 35m                 \t | 7h 38m                 \t | 13h 57m                \t | max: 0.351<br>last: 0.3406  \t |\n| 4               \t | 124.64 \t | 25.62 \t | \\            \t | 95.06        \t | 3h 13m                 \t | 6h 46m                 \t | 10h 53m                \t | max: 0.3521<br>last: 0.3521 \t |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg\n\n### 30B Model Mode Experiment\n\nTODO: The 30B experiment is still in progress.\n\n* Machine: H20\n* Model: Qwen2.5-32B\n* Rollout length: max_response_length FSDP2: 20K tokens;\n* Algorithm: DAPO\n* Engine: vllm+FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step:200\n    * train_batch_size: 512\n\n* fully_async_policy\n    * total_rollout_steps: 512*200\n    * trigger_parameter_sync_step: 512/32 = 16\n    * staleness_threshold: 0\n    * partial_rollout: False\n\n| training mode      | Resource allocation | mode                                       | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |\n|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|\n| colocate sync      | 128                 |                                            |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | stream off policy pipeline                 |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | async stream pipeline with stale samples   |      |                    |              |              |            |                  |\n| fully_async_policy | 64:64               | async stream pipeline with partial rollout |      |                    |              |              |            |                  |\n\n## Future Plans\n\n* GRPO experiments\n* Megatron adaptation\n* SGLang integration\n* Transfer queue integration\n* Asynchronous parameter synchronization\n* AReaL asynchronous algorithm implementation\n* TPPO algorithm implementation\n* Multi-turn and Tool support"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/README_zh.md",
    "content": "# Recipe: Fully Async Policy Trainer\n\n**Author:**  `https://github.com/meituan-search`\n\nLast updated: 10/17/2025.\n\n本文档介绍了完全异步PPO训练系统，该系统实现了 Trainer 和 Rollouter 的完全解耦，支持异步样本生成和训练。\n在该系统下，我们使用128卡训练qwen2.5-7B模型取得了2.35x-2.67x的性能提升,同时效果没有显著受到影响。\n\n## Introduction\n\n### Background\n\nrollout和train分离架构相较于colocate的架构能够更加灵活地分配资源，设计更加灵活的训练逻辑，从而处理长尾等问题带来的GPU利用率低，训练效率低的问题。\none_step_off_policy通过分离架构的设计并进行rollout和train一轮异步的训练方法，缓解了rollout时间过长的问题，并在训练效率上取得了一些收益，\n但其强制使用一轮异步的数据，存在不够灵活等问题，而且并不能完全去除长尾对训练效率带来的的影响；在其他框架如areal、Magistral、streamrl、asyncflow上，\n已经基于分离架构实现了异步训练、流式训练，并取得了收益；我们借鉴其方法，在verl上进行了实现。fully_async_policy支持异步、流式、partial\nrollout的训练， 通过合理设置资源分配情况、参数同步频率等参数，fully_async_policy能够显著提高训练效率。\n\n> Magistral https://arxiv.org/abs/2506.10910\n>\n> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language\n> Reasoning https://arxiv.org/abs/2505.24298\n>\n> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream\n> Generation https://arxiv.org/abs/2504.15930\n>\n> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663\n>\n\n### 核心贡献\n\n* **资源隔离**：与使用hybrid_engine不同，Rollouter和Trainer使用分离的计算资源，需要分别指定所占用的资源。\n* **生成与训练并行**：Trainer在训练的同时，Rollouter在生成新的样本。\n* **多步异步**: 相比 one step off policy 支持0.x步到多步的异步设定，异步方案更加灵活。\n* **nccl参数同步**：使用nccl通信原语进行Rollouter与Trainer参数的通信。\n* **Stream推理与训练**：Rollouter逐样本生成数据，同时数据传输以单个sample为最小传输单位。\n* **异步训练与新鲜度控制**：通过设置参数async_training.staleness_threshold，支持使用旧参数生成的样本进行训练。\n* **PartialRollout**: Rollouter推理过程支持partial rollout逻辑，通过参数同步时，添加`sleep()`和`resume()`\n  逻辑，保存进行中的rollout的样本，并在下一次rollout中继续使用，减少参数同步等待进行中的任务结束时间。\n\n目前支持使用模式为 megatron/fsdp+vllm。vllm必须使用基于AgentLoop的server模式。\n\n## 设计\n\nfully_async_policy的整体架构如下图所示，fully_async_policy主要由Rollouter、MessageQueue、Trainer、ParameterSynchronizer四部分组成。\n\n![fully_async_policy_structure](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)\n\n1. Rollouter逐样本生成序列，并将生成的sample放入MessageQueue中，生产的速度受新鲜度控制。\n2. MessageQueue用于暂存Rollouter生成的sample。\n3. Trainer逐样本从MessageQueue中获取，获取到`require_batches*ppo_mini_batch_size`\n   数量的样本后，就会进行训练，训练async_training.trigger_parameter_sync_step轮后，触发与Rollouter的一次参数同步。\n4. 
ParameterSynchronizer 实现了NCCL的同步参数同步能力。\n\n当前方案对比base的收益来源，在于colocate情况下，rollout使用更多的资源无法解决长尾样本带来的空闲，\n当我们进行资源隔离后，rollout的时间和train的时间都可能相较于之前更长（因为使用的资源变少了），\n但是相互之间的耗时overlap，端到端的耗时反而有所缩减。\n\n![fully_async_policy_revenue](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)\n\n## 使用方式\n\n### 参数说明\n\n| super params | implication |\n|---|---|\n| `trainer.nnodes` | Trainer的node数量 |\n| `trainer.n_gpus_per_node` | Trainer每个node上gpu的数量 |\n| `rollout.nnodes` | Rollouter的node数量 |\n| `rollout.n_gpus_per_node` | Rollouter每个node上gpu的数量 |\n| `data.train_batch_size` | 在fully async策略中，该值不生效（默认设置为0） |\n| `data.gen_batch_size` | 在fully async策略中，使用流式的样本生产逻辑（默认设置为1） |\n| `rollout.total_rollout_steps` | 总的rollout的sample数量 |\n| `rollout.test_freq` | Rollouter每更新多少次参数，进行一次validation |\n| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |\n| `async_training.require_batches` | FullyAsyncTrainer一次性获取的ppo_mini_batch_size的数量 |\n| `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后，进行一次参数同步 |\n| `async_training.staleness_threshold` | 新鲜度控制 |\n| `async_training.partial_rollout` | 是否进行partial_rollout |\n| `async_training.use_rollout_log_probs` | 使用rollout产生的log_probs |\n| `async_training.compute_prox_log_prob`（experimental） | 是否在train阶段，使用train模型的参数计算token的log_prob |\n\n**进一步的解释：**\n\n* `rollout.total_rollout_steps`\n\n  与 colocate 相比，数量可以通过 train_batch_size 与 step 相乘对齐:\n  `rollout.total_rollout_steps = data.train_batch_size * step`。\n\n* `async_training.trigger_parameter_sync_step`\n\n  在fully async策略中，表示Trainer进行多少次本地更新后（也就是获取多少次`require_batches * ppo_mini_batch_size`数量样本），\n  与Rollouter之间进行一次参数同步。\n  每两次Rollouter和Trainer参数同步之间，Trainer将会处理`trigger_parameter_sync_step * require_batches * ppo_mini_batch_size`份sample。\n  如果为了与colocate在公平的情况下对比速度，trigger_parameter_sync_step应该设置为 `data.train_batch_size / (\n  require_batches * ppo_mini_batch_size)`。\n\n* `async_training.staleness_threshold`\n\n  在fully async策略中，表示最大允许使用的staleness样本的比例。\n\n    * staleness_threshold=0，表示同步训练。\n      Rollouter两次参数更新之间将会生成固定数量的样本，样本数为：\n      $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$\n    * staleness_threshold>0，表示异步训练， 可以设置为小数，支持更灵活的异步调用。\n      Rollouter两次参数更新之间将会最多生成的样本数为：\n      $$rollout\_num = 
(1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$\n\n  num_staleness_sample 表示上一次rollout多生成的陈旧样本数。\n\n  由于是流式系统，rollout持续生成，trainer持续消费。如果rollouter较慢，trainer会更早触发参数同步，rollouter并不会实际生产rollout_num个样本。\n  当rollout 足够快时，staleness_threshold设置为1，基本上等价于one_step_off policy。\n  为了避免过期样本太多影响训练精度，建议该值设置小于1。\n\n* `async_training.partial_rollout`\n\n  partial_rollout只会在staleness_threshold>0时才实际上起作用。\n\n* `async_training.use_rollout_log_probs`\n\n  在强化学习算法中，log_probs与参数版本，token都存在隐性的相关性。由于PPO/GRPO/DAPO等算法的设定，我们在计算重要性采样时，\n  即 old_log_prob必须使用rollout参数及token所对应log_probs，才能保证算法的正确性。在fully\n  async策略中，我们默认old_log_prob是由rollout所计算的，而不是由trainer所计算。\n\n* `async_training.require_batches`\n\n  在流式训练中，require_batches 应该设置为1，表示生产够ppo_mini_batch_size样本后，就进行训练。\n  在实际测试中，我们发现，如果单次下发的样本较少，由于数据分发的顺序，会导致训练不稳定，response 长度变长。\n  在这里，我们额外提供 require_batches 进行流式分发，控制单次参与训练的样本数量。\n\n* `async_training.compute_prox_log_prob` （experimental）\n\n  我们在训练过程中，观测到随着训练的进行，训练后期指标和response长度可能会出现不稳定的情况，\n  这里我们可以使用 [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) 的技术进行\n  重要性采样，缓解这一问题。为了使用 `Rollout Importance Sampling` 我们需要使用训练引擎使用当前的参数版本计算old_log_prob，此开关需要打开。\n  此外，在 mode d (async stream pipeline with partial rollout) 的情况下开启 `compute_prox_log_prob` 以及\n  `Rollout Importance Sampling` 后，我们的实现已近似AReaL的 `Decoupled PPO`。\n\n### 模式支持\n\n1. on policy pipeline:\n    1. **trigger_parameter_sync_step=1，staleness_threshold=0**\n    2. Rollouter一次生产`require_batches*ppo_mini_batch_size`\n       的samples，Trainer获取这些samples后进行训练，训练完后Trainer和Rollouter之间进行一次参数同步;\n    3. 在rollout阶段，如果存在长尾的样本，但是rollout样本数较少时，较短的样本无法填充到空闲的资源中，会造成一定的资源浪费。\n    4. 如图a所示；\n\n2. stream off policy pipeline:\n    1. **trigger_parameter_sync_step>1，staleness_threshold=0**\n    2. 将会进行同步的流式训练，Rollouter一次生产`require_batches*ppo_mini_batch_size*trigger_parameter_sync_step`\n       的samples，Trainer每获取`require_batches*ppo_mini_batch_size`\n       就进行一次本地训练，训练trigger_parameter_sync_step次后，Trainer和Rollouter之间进行一次参数同步;\n    3. 相较于a，由于一次生成的样本更多，资源的空闲会更低。\n    4. 在一次step训练中，会存在两次资源闲置的时间，分别是在第一次获取样本时，train等待`require_batches*ppo_mini_batch_size`\n       个样本生产，以及最后一次参数更新时，rollout等待训练完成。\n    5. 如图b所示；\n\n3. async stream pipeline with staleness samples:\n    1. **trigger_parameter_sync_step>=1，staleness_threshold>0，partial_rollout=False**\n    2. Rollouter在每次参数更新后将计划最多生产rollout_num个样本（实际根据rollout速度，生成的样本可能会少于这个值）。\n    3. 如果rollout过程比较快，Rollouter将会在参数同步前额外生成一部分样本num_stale_samples，用于参数同步后立即给Trainer使用。\n       触发参数同步时，如果Rollouter有正在生产的任务，将会等待任务完成，同时不会添加新的任务；\n    4. 相较于b，除第一次step训练外，后续的训练都不会有wait first batch rollout finish的时间，但是会有wait active task\n       finish的时间。\n    5. 如图c所示；\n\n4. async stream pipeline with partial rollout:\n    1. **trigger_parameter_sync_step>=1，staleness_threshold>0，partial_rollout=True**\n    2. 相较于c，触发参数同步时，Rollouter如果有正在生产的sample，会打断rollout过程并进行参数同步，被中断的sample会在参数同步后继续生成。减少了wait\n       active task finish的时间。\n    3. 
如图d所示；\n\n![fully_async_policy_mode](\nhttps://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)\n\n### 关键指标\n\n| metrics                                        | implication                                               |\n|------------------------------------------------|-----------------------------------------------------------|\n| `trainer/idle_ratio`                           | Trainer闲置率                                                |\n| `rollouter/idle_ratio`                         | Rollouter闲置率                                              |\n| `fully_async/count/stale_samples_processed`    | 训练使用的旧sample总数                                            |\n| `fully_async/count/stale_trajectory_processed` | 训练使用的旧trajectory总数(一个sample会生产rollout.n条trajectory)       |\n| `fully_async/partial/total_partial_num`        | 两次trigger_parameter_sync_step之间Trainer处理的partial样本数       |\n| `fully_async/partial/partial_ratio`            | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的比例     |\n| `fully_async/partial/max_partial_span`         | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的最大参数跨度 |\n\n### 调参建议\n\n* 资源分配与调整:\n    * 合理的资源分配是获得好的训练效率的前提。理想的资源分配情况应该是使得Rollout的时间和Train的时间接近，从而使得整个训练过程流水气泡最小，\n      避免资源闲置，同时Trainer不会使用旧样本。在真实训练场景下，可以根据实际训练过程中rollout和train的空闲时间调整资源分配，\n      可从rollouter/idle_ratio和trainer/idle_ratio获得，如果rollouter/idle_ratio较高trainer/idle_ratio较低，\n      应该增多Trainer的资源减少Rollouter的资源，反之亦然。\n\n* 关键参数：\n    * staleness_threshold: 设置太大会导致较多的旧样本使用，影响模型效果，建议设置小于1。\n    * require_batches：越接近1，越接近纯流式过程，训练过程中bubble越小，能够在速度上获得更快的加速效果，但会对样本的处理顺序产生影响；\n    * trigger_parameter_sync_step: 设置的越小越接近on policy，但会导致频繁的参数同步，长尾样本浪费的资源无法被短样本填充，资源利用率低。\n      设置的越大有更高的计算效率，但是精度上会受到off policy的影响。\n    * rollout.test_freq: 会占用Rollouter资源，不建议设置太小。\n\n* 模式选择：通过调整不同的参数，Fully Async架构支持不同程度上的优化加速，适用于不同场景的任务。\n    * 对于小规模任务，需要保证训练的稳定性和 on-policy 性，对速度要求不高的场景，可以尝试使用on policy pipeline的模式（模式1）。\n    * 对于需要提高训练吞吐量，但对 staleness 敏感的场景，可以尝试使用 stream off policy pipeline 的模式。即通过\n      设置trigger_parameter_sync_step>1 ，提高 训练效率，但仍保持同步机制 (staleness_threshold=0 )（模式2）。\n    * 对于大规模任务，对训练速度有较高要求，且可以容忍一定 off-policy 程度、staleness的场景，可以设置staleness_threshold>\n      0、partial_rollout=True提高训练效率，使用 async stream pipeline 模式（模式 3 或 4）。\n\n### 快速开始\n\n```shell\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*400)))\ntest_freq=10\nstaleness_threshold=0\ntrigger_parameter_sync_step=16\npartial_rollout=False\n\n\npython -m recipe.fully_async_policy.fully_async_main \\\n\ttrain_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    
trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\"\n```\n\n## 实验\n\n### 在7B模型上进行异步训练\n\n我们使用 Qwen2.5-Math-7B 验证 fully async 策略在长候选下，多种资源下的收益情况。\n使用`async stream pipeline with staleness samples` 策略，我们在32卡，64卡，128卡都取得2x左右的性能提升，同时没有显著影响实验效果。\n\n* 机器：H20\n* 模型：Qwen2.5-Math-7B\n* rollout长度：max_response_length FSDP2: 28K tokens;\n* 算法：DAPO\n* 数据集： TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet\n* engine: vllm+FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step: 400\n    * train_batch_size: 512\n\n* fully_async_policy\n    * total_rollout_steps: 512*400\n    * require_batches: 4\n    * trigger_parameter_sync_step: 4\n    * staleness_threshold: 0.5\n    * partial_rollout: True\n\n|  training mode   \t   | resource allocation \t | step  \t  |  gen  \t  | old_log_prob \t | update_actor \t | total time<br>100 step \t | total time<br>200 step \t | total time<br>300 step \t | total time<br>400 step \t |      acc/mean@1          \t      |\n|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-------------------------------:|\n| colocate sync      \t | 32                  \t | 790.10 \t | 357.41 \t | 107.71       \t | 313.81       \t | 13h 44m                \t | 1d 3h 43m              \t | 2d 9h 22m              \t | 3d 17h 5m              \t | max: 0.3313<br>last: 0.2448  \t  |\n| fully_async_policy \t | 16:16               \t |  294.77  |  21.26   | \\            \t |     269.80     |    7h 58m<br>(1.72x)     |    16h 21m<br>(1.70x)    |   1d 0h 53m<br>(2.31x)   |   1d 9h 26m<br>(2.66x)   | max: 0.3302<br>last: 0.2333   \t |\n| colocate sync      \t | 64                  \t | 365.28 \t | 150.72 \t | 70.26        \t | 133.41       \t | 10h 22m                \t | 20h 45m                \t | 1d 7h 6m               \t | 1d 17h 32m             \t | max: 0.3365<br>last:  0.2333 \t  |\n| fully_async_policy \t | 32:32               \t | 189.26 \t | 28.46  \t | \\            \t | 156.98       \t | 4h 57m<br>(2.09x)      \t | 10h 14m<br>(2.03x)     \t | 16h 58m<br>(1.83x)     \t | 21h 40m<br>(1.92x)     \t | max: 0.3677<br>last: 0.3406  \t  |\n| colocate sync      \t | 128                 \t | 356.30 \t | 177.85 \t | 53.92        \t | 113.81       \t | 8h 36m                 \t | 17h 56m                \t | 1d 5h 6m               \t | 1d 16h 48m             \t | max: 0.3573<br>last: 0.2958  \t  |\n| fully_async_policy \t | 64:64               \t | 150.63 \t | 33.14  \t | \\            \t | 113.16       \t | 3h 13m<br>(2.67x)      \t | 6h 46m<br>(2.65x)      \t | 10h 53m<br>(2.67x)     \t | 17h 22m<br>(2.35x)     \t | max: 0.3521<br>last: 0.3094  \t  |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg\n\n### 128卡 7B 异步模式实验\n\n我们使用 Qwen2.5-Math-7B 验证 fully async 所支持的各个模式的效果。\n我们可以看到 stream 带来的收益大约1.6x，叠加 staleness 和 partial_rollout 后，收益为2.35x。\n\n|                             mode                                         \t              
\n(For fully_async_policy, a resource allocation of `m:n` means the GPUs are split between the trainer and the rollouter, m + n in total; the timing columns are average seconds per step.)\n\n| training mode | resource allocation | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |\n|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |\n| fully_async_policy | 16:16 | 294.77 | 21.26 | \\ | 269.80 | 7h 58m<br>(1.72x) | 16h 21m<br>(1.70x) | 1d 0h 53m<br>(2.31x) | 1d 9h 26m<br>(2.66x) | max: 0.3302<br>last: 0.2333 |\n| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |\n| fully_async_policy | 32:32 | 189.26 | 28.46 | \\ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |\n| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |\n| fully_async_policy | 64:64 | 150.63 | 33.14 | \\ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg\n\n### 7B Async-Mode Experiments on 128 GPUs\n\nWe use Qwen2.5-Math-7B to evaluate each mode supported by fully async.\nStreaming alone yields roughly a 1.6x speedup; stacking staleness and partial rollout on top raises this to 2.35x.\n\n| mode | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |\n|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |\n| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step=4,<br>require_batches=4) | 231.34 | 128.47 | \\ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |\n| `async stream pipeline with staleness samples`<br>(+staleness_threshold=0.5) |  |  |  |  |  |  |  |  |  |\n| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \\ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg\n
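\nThe three async modes differ only in configuration overrides. A minimal sketch of the deltas, reusing the flag names from the launch command above (values as in this experiment):\n\n```bash\n# stream off policy pipeline (fully async baseline)\nasync_training.trigger_parameter_sync_step=4\nasync_training.require_batches=4\nasync_training.staleness_threshold=0\nasync_training.partial_rollout=False\n\n# async stream pipeline with staleness samples: additionally tolerate stale samples\nasync_training.staleness_threshold=0.5\n\n# async stream pipeline with partial rollout: additionally interrupt and later resume in-flight rollouts\nasync_training.partial_rollout=True\n```\n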
\n### Staleness Ablation on 128 GPUs\n\nUnder the `async stream pipeline with partial rollout` mode, we study how the staleness setting affects training efficiency.\nWe find that the larger the staleness, the larger the end-to-end speedup.\nWe also note that staleness 0.3 and 0.5 finish in roughly the same time: as the number of training steps grows, the response length varies considerably and training becomes unstable.\nFurther analysis and optimization of this issue is left for future work.\n\n| staleness_threshold | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |\n|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n| 0 | 231.34 | 128.47 | \\ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |\n| 0.1 | 171.30 | 58.17 | \\ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |\n| 0.3 | 146.11 | 38.88 | \\ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |\n| 0.5 | 150.63 | 33.14 | \\ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_stale?nw=nwuserhouzg\n\n### 7B require_batches Ablation on 128 GPUs\n\nAcross repeated tests, we found that the number of samples dispatched per streaming batch affects the response length seen in training, and therefore the training time. We vary `async_training.require_batches` to measure its effect on the results.\n\n| require_batches | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |\n|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n| 1 | 203.47 | 30.88 | \\ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |\n| 2 | 158.72 | 26.32 | \\ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |\n| 4 | 124.64 | 25.62 | \\ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |\n\n> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg\n\n### 30B Model Mode Experiments\n\nTODO: the 30B experiments are still in progress.\n\n* Machine: H20\n* Model: Qwen2.5-32B\n* Rollout length: max_response_length (FSDP2): 20K tokens\n* Algorithm: DAPO\n* Engine: vLLM + FSDP2\n* rollout.n: 16\n* ppo_mini_batch_size: 32\n* test_freq: 20\n\n* colocate sync:\n    * step: 200\n    * train_batch_size: 512\n\n* fully_async_policy:\n    * total_rollout_steps: 512*200\n    * trigger_parameter_sync_step: 512/32 = 16\n    * staleness_threshold: 0\n    * partial_rollout: False\n\n| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |\n|---|---|---|---|---|---|---|---|---|\n| colocate sync | 128 |  |  |  |  |  |  |  |\n| fully_async_policy | 64:64 | stream off policy pipeline |  |  |  |  |  |  |\n| fully_async_policy | 64:64 | async stream pipeline with staleness samples |  |  |  |  |  |  |\n| fully_async_policy | 64:64 | async stream pipeline with partial rollout |  |  |  |  |  |  |\n\n## Roadmap\n\n* GRPO experiments\n* Megatron support\n* SGLang integration\n* Transfer queue integration\n* Asynchronous parameter synchronization\n* AReaL-style asynchronous algorithm implementation\n* TPPO algorithm implementation\n* Multi-turn and tool support"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/agent_loop/__init__.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .agent_loop import FullyAsyncAgentLoopManager\nfrom .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop\n\n_ = [PartialSingleTurnAgentLoop]\n__all__ = [FullyAsyncAgentLoopManager]\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/agent_loop/agent_loop.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nimport os\nfrom typing import Any, Optional\n\nimport hydra\nimport numpy as np\nimport ray\nfrom omegaconf import DictConfig\n\nfrom recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica\nfrom verl.experimental.agent_loop.agent_loop import (\n    AgentLoopManager,\n    AgentLoopOutput,\n    AgentLoopWorkerBase,\n    AsyncLLMServerManager,\n    _agent_loop_registry,\n    _DummyConfig,\n    get_trajectory_info,\n)\nfrom verl.protocol import DataProto\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.utils.rollout_trace import rollout_trace_attr\nfrom verl.workers.rollout.replica import TokenOutput\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass FullyAsyncLLMServerManager(AsyncLLMServerManager):\n    async def generate_for_partial(self, request_id, prompt_ids, sampling_params, **kwargs_extra) -> TokenOutput:\n        \"\"\"Generate tokens from prompt ids. with partial rollout function\"\"\"\n        server = self._choose_server(request_id)\n        output = await server.generate_for_partial.remote(\n            request_id=request_id,\n            prompt_ids=prompt_ids,\n            sampling_params=sampling_params,\n            **kwargs_extra,\n        )\n        return output\n\n\nclass FullyAsyncAgentLoopOutput(AgentLoopOutput):\n    \"\"\"Agent loop output.\"\"\"\n\n    is_cancel: bool = False\n    \"\"\"Indicates whether the request was interrupted\"\"\"\n    log_probs: list[float] = None\n    \"\"\"Response token log probs including LLM generated token, tool response token.\"\"\"\n    param_version_start: int = 0\n    \"\"\"Indicate start parameter version when this response is generated\"\"\"\n    param_version_end: int = 0\n    \"\"\"Indicate end parameter version when this response is generated, used for partial rollout\"\"\"\n\n\n@ray.remote\nclass FullyAsyncAgentLoopWorker(AgentLoopWorkerBase):\n    def __init__(\n        self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], reward_router_address: str = None\n    ):\n        self.server_manager = FullyAsyncLLMServerManager(config, server_handles)\n        super().__init__(config, server_handles, reward_router_address)\n\n    async def generate_sequences_no_post(\n        self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]]\n    ) -> list[AgentLoopOutput]:\n        \"\"\"Generate sequences from agent loop.\n\n        Args:\n            batch (DataProto): Input batch.\n            partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result.\n\n        Returns:\n            list[FullyAsyncAgentLoopOutput]: List of agent loop outputs, one per sample in the batch.\n        \"\"\"\n        config = self.config.actor_rollout_ref.rollout\n        sampling_params = dict(\n            temperature=config.temperature,\n        
    top_p=config.top_p,\n            repetition_penalty=1.0,\n            logprobs=config.calculate_log_probs,\n        )\n\n        # override sampling params for validation\n        if batch.meta_info.get(\"validate\", False):\n            sampling_params[\"top_p\"] = config.val_kwargs.top_p\n            sampling_params[\"temperature\"] = config.val_kwargs.temperature\n\n        # by default, we assume it's a single turn agent\n        if \"agent_name\" not in batch.non_tensor_batch:\n            batch.non_tensor_batch[\"agent_name\"] = np.array([\"single_turn_agent\"] * len(batch), dtype=object)\n\n        if \"index\" in batch.non_tensor_batch:\n            index = batch.non_tensor_batch[\"index\"]\n        else:\n            index = np.arange(len(batch))\n\n        trajectory_info = await get_trajectory_info(\n            batch.meta_info.get(\"global_steps\", -1), index, batch.meta_info.get(\"validate\", False)\n        )\n\n        if not partial_output_list:\n            partial_output_list = [None] * len(batch)\n\n        tasks = []\n        for i in range(len(batch)):\n            kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}\n            kwargs[\"output\"] = partial_output_list[i]\n            tasks.append(\n                asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs))\n            )\n        return await asyncio.gather(*tasks)\n\n    async def _partial_run_agent_loop(\n        self,\n        sampling_params: dict[str, Any],\n        trajectory: dict[str, Any],\n        *,\n        agent_name: str,\n        **kwargs,\n    ) -> AgentLoopOutput:\n        with rollout_trace_attr(\n            step=trajectory[\"step\"],\n            sample_index=trajectory[\"sample_index\"],\n            rollout_n=trajectory[\"rollout_n\"],\n            validate=trajectory[\"validate\"],\n            name=\"agent_loop\",\n        ):\n            assert agent_name in _agent_loop_registry, (\n                f\"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}\"\n            )\n\n            agent_loop_config = _agent_loop_registry[agent_name]\n            agent_loop = hydra.utils.instantiate(\n                config=agent_loop_config,\n                trainer_config=_DummyConfig(config=self.config),\n                server_manager=self.server_manager,\n                tokenizer=self.tokenizer,\n                processor=self.processor,\n            )\n            return await agent_loop.run(sampling_params, **kwargs)\n\n\nclass FullyAsyncAgentLoopManager(AgentLoopManager):\n    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):\n        self.config = config\n        self.worker_group = worker_group\n        self.reward_model_manager = None\n        self.reward_router_address = None\n        self.agent_loop_workers_class = FullyAsyncAgentLoopWorker\n        self.rollout_replica_class = FullyAsyncvLLMReplica\n\n        self.rm_wg = rm_wg\n        self.rollout_replicas = None\n        self.server_handles = None\n        self.server_addresses = None\n        self.agent_loop_workers = None\n\n    @classmethod\n    async def create(cls, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):\n        instance = cls(config, worker_group, rm_wg)\n        await instance._async_init()\n        return instance\n\n    async def _async_init(self):\n        if self.config.reward_model.enable and 
self.config.reward_model.enable_resource_pool:\n            from verl.experimental.reward import RewardModelManager\n\n            self.reward_model_manager = RewardModelManager(self.config.reward_model, self.rm_wg)\n            self.reward_router_address = self.reward_model_manager.get_router_address()\n\n        await self._initialize_llm_servers_async()\n        self._init_agent_loop_workers()\n\n    async def _initialize_llm_servers_async(self):\n        rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size\n        world_size = (\n            self.worker_group.world_size\n            if self.worker_group\n            else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes\n        )\n        num_replicas = world_size // rollout_world_size\n\n        rollout_config = self.config.actor_rollout_ref.rollout\n        model_config = self.config.actor_rollout_ref.model\n        self.rollout_replicas = [\n            self.rollout_replica_class(\n                replica_rank=replica_rank,\n                config=rollout_config,\n                model_config=model_config,\n                gpus_per_node=self.config.trainer.n_gpus_per_node,\n            )\n            for replica_rank in range(num_replicas)\n        ]\n\n        if self.worker_group:\n            await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])\n        else:\n            await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])\n\n        self.server_handles = [server._server_handle for server in self.rollout_replicas]\n        self.server_addresses = [server._server_address for server in self.rollout_replicas]\n\n    async def generate_single_sample_async(\n        self,\n        sample: DataProto,\n        partial_output_list: Optional[list[AgentLoopOutput]],\n    ) -> list[AgentLoopOutput]:\n        \"\"\"\n        Asynchronously process a single sample\n\n        Args:\n            sample: Single sample data\n            partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result.\n\n        Returns:\n            list[AgentLoopOutput]: Processing results\n        \"\"\"\n        worker = self._select_best_worker()\n        output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list)\n        return await asyncio.wrap_future(output_future.future())\n\n    def _select_best_worker(self):\n        \"\"\"Select the best worker, simple round-robin load balancing\"\"\"\n        if not hasattr(self, \"_worker_index\"):\n            self._worker_index = 0\n\n        worker = self.agent_loop_workers[self._worker_index]\n        self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers)\n        return worker\n\n    async def cancel(self):\n        await asyncio.gather(*[replica.cancel() for replica in self.rollout_replicas])\n\n    async def resume(self):\n        await asyncio.gather(*[replica.resume() for replica in self.rollout_replicas])\n\n    async def wake_up(self):\n        await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas])\n\n    async def sleep(self):\n        await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])\n\n    async def reset_prefix_cache(self):\n        await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas])\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom recipe.fully_async_policy.agent_loop.agent_loop import AgentLoopOutput, FullyAsyncAgentLoopOutput\nfrom verl.experimental.agent_loop import AgentLoopBase\nfrom verl.experimental.agent_loop.agent_loop import register\nfrom verl.utils.profiler import simple_timer\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@register(\"partial_single_turn_agent\")\nclass PartialSingleTurnAgentLoop(AgentLoopBase):\n    \"\"\"Naive agent loop that only do single turn chat completion.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length\n        self.response_length = self.config.actor_rollout_ref.rollout.response_length\n        self.apply_chat_template_kwargs = self.config.data.get(\"apply_chat_template_kwargs\", {})\n\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        output: Optional[FullyAsyncAgentLoopOutput] = kwargs.get(\"output\", None)\n        messages = list(kwargs[\"raw_prompt\"])\n        param_version = kwargs.get(\"param_version\", 0)\n\n        metrics = {}\n        request_id = uuid4().hex\n        image_data = (kwargs.get(\"multi_modal_data\") or {}).get(\"image\", None)\n\n        param_version_start = param_version\n        param_version_end = param_version\n\n        if not output:\n            # TODO(baiyan): it is supposed to use the correct processor,\n            #    but I found the async training would hang if use_correct_processor=True.\n            #    so we use the tokenizer to tokenize the prompt for now.\n            use_correct_processor = False\n            if self.processor is not None and use_correct_processor:\n\n                def get_prompt_ids():\n                    raw_prompt = self.processor.apply_chat_template(\n                        messages,\n                        add_generation_prompt=True,\n                        tokenize=False,\n                        **self.apply_chat_template_kwargs,\n                    )\n                    model_inputs = self.processor(text=[raw_prompt], images=image_data, return_tensors=\"pt\")\n                    return model_inputs.pop(\"input_ids\").squeeze(0).tolist()\n\n                prompt_ids = await self.loop.run_in_executor(None, get_prompt_ids)\n            else:\n                prompt_ids = await self.loop.run_in_executor(\n                    None,\n                    lambda: self.tokenizer.apply_chat_template(\n                        messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs\n                    ),\n                )\n        else:\n            if output.is_cancel:\n                # Resume the paused sample,\n                # add 
        with simple_timer(\"generate_sequences\", metrics):\n            response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial(\n                request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data\n            )\n        if not output:\n            response_mask = [1] * len(response_ids)\n        else:\n            # For a resumed sample: restore the original prompt_ids, prepend the earlier\n            # partial response to response_ids, and rebuild response_mask\n            prompt_ids = output.prompt_ids\n            log_probs = output.log_probs + log_probs\n            response_ids = output.response_ids + response_ids\n            response_mask = [1] * len(response_ids)\n\n        return FullyAsyncAgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            num_turns=2,\n            metrics=metrics,\n            is_cancel=is_cancel,\n            log_probs=log_probs,\n            param_version_start=param_version_start,\n            param_version_end=param_version_end,\n            # multi_modal_data={\"image\": image_data} if image_data is not None else {},\n        )\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\nasync_training:\n\n  # Maximum samples staleness threshold\n  staleness_threshold: 0.1\n\n  # Frequency of parameter synchronization between rollouter and trainer, \n  # One step means trainer obtains a batch of required samples\n  trigger_parameter_sync_step: 4\n  \n  # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once\n  require_batches: 1\n\n  # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout\n  partial_rollout: True\n\n  # Whether to use rollout log probs for training\n  use_rollout_log_probs: True\n\n  # compute_prox_log_prob\n  compute_prox_log_prob: False\n\n# Rollout config\nrollout:\n\n  # Number of nodes used in the rollout\n  nnodes: 1\n\n  # Number of GPUs per node                     \n  n_gpus_per_node: 8\n\n  # number of responses (i.e. num sample times). > 1 for grpo\n  n: 4\n\n  # total rollout samples # TODO rename to total_rollout_samples\n  total_rollout_steps: 100\n\n  # Number of epochs in training \n  total_epochs: 10\n\n  # Test frequency, how many times a parameter update triggers a validation\n  test_freq: 1\n\ndata:\n  # Number of samples generated, currently only support 1\n  gen_batch_size: 1\n\nactor_rollout_ref:\n  actor:\n    # Whether to use rollout log probs for training\n    use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nasync_training:\n\n  # Maximum samples staleness threshold\n  staleness_threshold: 0.1\n\n  # Frequency of parameter synchronization between rollouter and trainer, \n  # One step means trainer obtains a batch of required samples\n  trigger_parameter_sync_step: 4\n  \n  # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once\n  require_batches: 1\n\n  # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout\n  partial_rollout: True\n\n  # Whether to use rollout log probs for training\n  use_rollout_log_probs: True\n\n  # compute_prox_log_prob\n  compute_prox_log_prob: False\n\n# Rollout config\nrollout:\n\n  # Number of nodes used in the rollout\n  nnodes: 1\n\n  # Number of GPUs per node                     \n  n_gpus_per_node: 8\n\n  # number of responses (i.e. num sample times). > 1 for grpo\n  n: 4\n\n  # total rollout samples # TODO rename to total_rollout_samples\n  total_rollout_steps: 100\n\n  # Number of epochs in training \n  total_epochs: 10\n\n  # Test frequency, how many times a parameter update triggers a validation\n  test_freq: 1\n\ndata:\n  # Number of samples generated, currently only support 1\n  gen_batch_size: 1\n\nactor_rollout_ref:\n  actor:\n    # Whether to use rollout log probs for training\n    use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/detach_utils.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport time\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Any, Optional\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopOutput\nfrom verl.trainer.ppo.ray_trainer import compute_response_mask\nfrom verl.utils.model import compute_position_id_with_mask\n\n\ndef postprocess_agent_loop_outputs(rs: \"RolloutSample\", tokenizer, config, processor) -> DataProto:\n    \"\"\"Static method to postprocess a list of AgentLoopOutput into DataProto\n\n    Args:\n        rs: RolloutSample\n        tokenizer: Tokenizer instance\n        config: Configuration object\n\n    Returns:\n        DataProto: Processed batch data\n    \"\"\"\n    inputs: list[AgentLoopOutput] = rs.agent_loop_output_list\n    full_batch = rs.full_batch\n    # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n    # prompts: left pad\n    # responses: right pad\n    # input_ids: prompt + response\n    # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n    # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n\n    # prompts\n    tokenizer.padding_side = \"left\"\n    outputs = tokenizer.pad(\n        [{\"input_ids\": input.prompt_ids} for input in inputs],\n        padding=\"max_length\",\n        max_length=config.actor_rollout_ref.rollout.prompt_length,\n        return_tensors=\"pt\",\n        return_attention_mask=True,\n    )\n    prompt_ids, prompt_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n    # responses\n    tokenizer.padding_side = \"right\"\n    outputs = tokenizer.pad(\n        [{\"input_ids\": input.response_ids} for input in inputs],\n        padding=\"max_length\",\n        max_length=config.actor_rollout_ref.rollout.response_length,\n        return_tensors=\"pt\",\n        return_attention_mask=True,\n    )\n    response_ids, response_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n    # response_mask\n    outputs = tokenizer.pad(\n        [{\"input_ids\": input.response_mask} for input in inputs],\n        padding=\"max_length\",\n        max_length=config.actor_rollout_ref.rollout.response_length,\n        return_tensors=\"pt\",\n        return_attention_mask=False,\n    )\n    response_mask = outputs[\"input_ids\"]\n    assert response_ids.shape == response_mask.shape, (\n        f\"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}\"\n    )\n    response_mask = response_mask * response_attention_mask\n\n    # Handle multi-modal inputs and position_ids calculation\n    # Only support Qwen2VLImageProcessor for multi-modal processing currently\n    # TODO: support other multi-modal inputs\n    multi_modal_inputs = None\n    if processor is not None and \"Qwen2VLImageProcessor\" in 
processor.image_processor.__class__.__name__:\n        # qwen-vl mrope\n        # (placeholder: no Qwen3VLProcessor-specific handling is implemented here yet)\n\n        images = [one.get(\"image\", None) for one in full_batch.non_tensor_batch.get(\"multi_modal_data\")]\n        current_text = [tokenizer.decode(input.prompt_ids, skip_special_tokens=False) for input in inputs]\n        multi_modal_inputs = processor(\n            text=current_text,\n            images=images,\n            return_tensors=\"pt\",\n            max_length=config.actor_rollout_ref.rollout.prompt_length,\n            padding=\"max_length\",\n            padding_side=\"left\",\n        )\n\n        prompt_ids = multi_modal_inputs.pop(\"input_ids\")\n        prompt_attention_mask = multi_modal_inputs.pop(\"attention_mask\")\n\n        # TODO: megatron will calculate rope position_ids in the forward pass, so we don't need to calculate them here,\n        #       but for FSDP support we do need to calculate them here\n\n        # # We must use dict(multi_modal_inputs) to convert BatchFeature values to a new dict\n        # # because np.array() only keeps the keys for BatchFeature.\n        # multi_modal_inputs = dict(multi_modal_inputs)\n\n        # image_grid_thw = multi_modal_inputs.get(\"image_grid_thw\")\n        # video_grid_thw = multi_modal_inputs.get(\"video_grid_thw\")\n        # second_per_grid_ts = multi_modal_inputs.get(\"second_per_grid_ts\")\n\n        # vision_position_ids = get_rope_index(\n        #     processor,\n        #     input_ids=input_ids.squeeze(0),\n        #     image_grid_thw=image_grid_thw,\n        #     video_grid_thw=video_grid_thw,\n        #     second_per_grid_ts=second_per_grid_ts,\n        #     attention_mask=attention_mask.squeeze(0),\n        # ).unsqueeze(0)  # (1, 3, seq_len)\n\n        # valid_mask = attention_mask[0].bool()\n        # text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)\n        # text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())\n        # text_position_ids = text_position_ids.unsqueeze(0)\n        # position_ids = torch.cat((text_position_ids, vision_position_ids), dim=1)  # (1, 4, seq_length)\n\n    input_ids = torch.cat([prompt_ids, response_ids], dim=1)\n    attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)\n    position_ids = compute_position_id_with_mask(attention_mask)  # (1, seq_len)\n\n    batch = TensorDict(\n        {\n            \"prompts\": prompt_ids,  # [bsz, prompt_length]\n            \"responses\": response_ids,  # [bsz, response_length]\n            \"response_mask\": response_mask,  # [bsz, response_length]\n            \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n            \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n            \"position_ids\": position_ids,  # [bsz, prompt_length + response_length]\n        },\n        batch_size=len(input_ids),\n    )\n\n    num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)\n    metrics = [input.metrics.model_dump() for input in inputs]\n    return DataProto(batch=batch, non_tensor_batch={\"__num_turns__\": num_turns}, meta_info={\"metrics\": metrics})\n\n\n@dataclass\nclass RolloutSample:\n    \"\"\"Enhanced rollout sample containing both original batch info and AgentLoopOutput\"\"\"\n\n    # Original batch information\n    full_batch: Any\n\n    # AgentLoopOutput from generation\n    agent_loop_output_list: list[Any]  # AgentLoopOutput\n\n    # Metadata\n    sample_id: str\n    epoch: int\n\n    # Processing metadata\n    processing_times: list[float]\n    param_version: int\n    param_version_start: list[int]\n    param_version_end: list[int]\n    rollout_status: dict[str, Any]\n\n\n@dataclass\nclass ValidateMetrics:\n    \"\"\"Metrics for validation\"\"\"\n\n    timing_raw: dict[str, Any]\n    metrics: Optional[dict[str, Any]] = None\n    global_steps: Optional[int] = None\n    param_version: Optional[int] = None\n\n\ndef prepare_single_generation_data(batch_dict, global_steps, rollout_n) -> DataProto:\n    \"\"\"\n    Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample.\n    Separates the data used for generation from the original data.\n\n    Returns:\n        DataProto: the generation-ready batch, repeated rollout_n times\n    \"\"\"\n\n    full_batch = DataProto.from_single_dict(batch_dict)\n\n    batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n    non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n\n    full_batch.pop(\n        batch_keys=batch_keys_to_pop,\n        non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n    )\n\n    # Use the partial_single_turn_agent, which supports partial rollout\n    full_batch.non_tensor_batch[\"agent_name\"] = np.array([\"partial_single_turn_agent\"] * len(full_batch), dtype=object)\n\n    # Repeat the batch so that each prompt yields rollout_n responses\n    full_batch = full_batch.repeat(repeat_times=rollout_n, interleave=True)\n    return full_batch\n\n\ndef process_rollout_log_probs(data_proto: DataProto, rollout_log_probs: list[list[float]]) -> torch.Tensor:\n    \"\"\"\n    Process rollout_log_probs according to the mask in DataProto\n    mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n\n    Args:\n        data_proto: A DataProto object containing batch information\n        rollout_log_probs: A two-dimensional list, each sublist containing the log_probs of a sample\n\n    Returns:\n        torch.Tensor: The processed log_probs tensor, with shape: [bsz, response_length]\n    \"\"\"\n\n    batch = data_proto.batch\n    response_mask = batch[\"response_mask\"]\n    rollout_log_probs_tensor = torch.zeros(response_mask.shape, dtype=torch.float32) - 1\n\n    for i, log_probs_seq in enumerate(rollout_log_probs):\n        # Get the effective length of the current sample (the number of positions with 1 in the mask)\n        valid_length = response_mask[i].sum().item()\n\n        # Ensure that the length of log_probs_seq does not exceed the valid length\n        actual_length = min(len(log_probs_seq), valid_length)\n\n        # Fill log_probs into the corresponding positions\n        if actual_length > 0:\n            rollout_log_probs_tensor[i, :actual_length] = torch.tensor(log_probs_seq[:actual_length])\n\n    rollout_log_probs_tensor = rollout_log_probs_tensor.to(torch.float32)\n    return rollout_log_probs_tensor\n\n\ndef merge_rollout_sample(config, tokenizer, rs: RolloutSample, processor):\n    \"\"\"\n    Supplement and refine the RolloutSample object in place\n    \"\"\"\n    # Step 1: Create a DataProto from the AgentLoopOutput generation results\n    gen_batch_output = postprocess_agent_loop_outputs(rs, tokenizer, config, processor)\n    rollout_log_probs = [x.log_probs for x in rs.agent_loop_output_list]\n    rollout_log_probs = process_rollout_log_probs(gen_batch_output, rollout_log_probs)\n    gen_batch_output.batch[\"rollout_log_probs\"] = rollout_log_probs.to(torch.float32)\n\n    # Step 2: Add uid\n    rs.full_batch.non_tensor_batch[\"uid\"] = np.array([f\"uid_{rs.sample_id}\"] * len(rs.full_batch), dtype=object)\n\n    # Step 3: Merge batches\n    # Merge the non_tensor_batch and meta_info of the original batch into the final batch\n    for key, value in rs.full_batch.non_tensor_batch.items():\n        gen_batch_output.non_tensor_batch[key] = value\n    gen_batch_output.meta_info.update(rs.full_batch.meta_info)\n\n    # Step 4: Set full_batch\n    rs.full_batch = gen_batch_output\n    rs.processing_times = []\n    for agent_loop in rs.agent_loop_output_list:\n        rs.processing_times.append(agent_loop.metrics.generate_sequences)\n    rs.param_version_start = [agent_loop.param_version_start for agent_loop in rs.agent_loop_output_list]\n    rs.param_version_end = [agent_loop.param_version_end for agent_loop in rs.agent_loop_output_list]\n    # Step 5: Clear agent_loop_output_list\n    rs.agent_loop_output_list = []\n    return rs\n\n\ndef assemble_batch_from_rollout_samples(\n    rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None\n) -> DataProto:\n    \"\"\"\n    Assemble gen_batch_output from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer.\n\n    Args:\n        rollout_samples: List of RolloutSample objects\n        tokenizer: Tokenizer instance\n        config: Configuration object containing trainer settings\n        balance_batch: Optional callable used to balance the batch\n\n    Returns:\n        DataProto: Assembled gen_batch_output\n\n    Raises:\n        ValueError: If rollout_samples is empty\n    \"\"\"\n    start_time = time.time()\n\n    if not rollout_samples:\n        raise ValueError(\"Empty rollout_samples provided for batch assembly\")\n\n    print(f\"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects\")\n\n    rollout_samples_batch = []\n    processing_times = []\n    rollout_status = rollout_samples[0].rollout_status\n    # Add a prefix to all rollout_status keys\n    rollout_status = {f\"fully_async/{key}\": value for key, value in rollout_status.items()}\n\n    for rs in rollout_samples:\n        rollout_samples_batch.append(rs.full_batch)\n        processing_times.extend(rs.processing_times)\n    final_batch = DataProto.concat(rollout_samples_batch)\n\n    # Calculate response_mask (if not present)\n    if \"response_mask\" not in final_batch.batch.keys():\n        final_batch.batch[\"response_mask\"] = compute_response_mask(final_batch)\n\n    if balance_batch:\n        balance_batch(final_batch, metrics={})\n\n    # Calculate the global valid token number\n    if \"attention_mask\" in final_batch.batch:\n        final_batch.meta_info[\"global_token_num\"] = torch.sum(final_batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n    # Collect statistics\n    param_versions = [rs.param_version for rs in rollout_samples]\n    trajectory_param_versions = [version for rs in rollout_samples for version in rs.param_version_end]\n\n    processing_time_stats = {\n        \"processing_time/avg\": np.mean(processing_times),\n        \"processing_time/max\": np.max(processing_times),\n        \"processing_time/min\": np.min(processing_times),\n        \"processing_time/tp50\": np.percentile(processing_times, 50),\n        \"processing_time/tp99\": np.percentile(processing_times, 99),\n        \"processing_time/tp95\": np.percentile(processing_times, 95),\n    }\n    processing_time_stats = {f\"fully_async/{key}\": value for key, value in processing_time_stats.items()}\n
\n    param_version_diff = [\n        abs(a - b)\n        for rs in rollout_samples\n        for a, b in zip(rs.param_version_end, rs.param_version_start, strict=False)\n    ]\n    num_diff0 = param_version_diff.count(0)\n    partial_stats = {\n        \"fully_async/partial/total_partial_num\": len(param_version_diff) - num_diff0,\n        \"fully_async/partial/partial_ratio\": (len(param_version_diff) - num_diff0) / len(param_version_diff),\n        \"fully_async/partial/max_partial_span\": max(param_version_diff),\n    }\n    # add meta_info\n    final_batch.meta_info.update(\n        {\n            \"rollout_param_versions\": param_versions,\n            \"param_version_diversity\": len(set(param_versions)) if param_versions else 0,\n            \"trajectory_param_versions\": trajectory_param_versions,\n            **processing_time_stats,\n            **rollout_status,\n            **partial_stats,\n        }\n    )\n\n    print(f\"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s\")\n\n    return final_batch\n\n\nclass MetricsAggregator:\n    \"\"\"Metrics aggregator, used to combine metrics from multiple training steps\"\"\"\n
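\n    # Usage sketch:\n    #   agg = MetricsAggregator(total_gpus=128)\n    #   agg.add_step_metrics(step_metrics, sample_count=512)\n    #   merged = agg.get_aggregated_metrics()\n    #   agg.reset()\n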
\n    def __init__(self, total_gpus: int):\n        # Store all values for each metric\n        self.metric_values: dict[str, list[float]] = defaultdict(list)\n        # Store the number of samples at each step for weighted averaging\n        self.sample_counts: list[int] = []\n        # Store the timestamp of each step for time-related calculations\n        self.timestamps: list[float] = []\n        # Step count\n        self.step_count = 0\n        # Total number of GPUs used\n        self.total_gpus = total_gpus\n\n        # Metric aggregation rule configuration\n        self.aggregation_rules = self._init_aggregation_rules()\n\n    def _init_aggregation_rules(self) -> dict[str, list[str]]:\n        \"\"\"Initialize metrics aggregation rules\"\"\"\n        return {\n            # Time-based metrics; more can be added here\n            \"time_sum\": [\"perf/time_per_step\"],\n            \"last\": [\n                \"fully_async/count/total_generated_samples\",\n                \"fully_async/count/stale_samples_processed\",\n                \"fully_async/count/stale_trajectory_processed\",\n                \"fully_async/count/current_param_version\",\n                \"fully_async/count/dropped_stale_samples\",\n                \"training/global_step\",  # TODO change name to: total_step\n            ],\n        }\n\n    def add_step_metrics(self, metrics: dict[str, Any], sample_count: int, timestamp: float = None):\n        \"\"\"Add metrics for a single step\"\"\"\n        if timestamp is None:\n            timestamp = time.time()\n\n        self.sample_counts.append(sample_count)\n        self.timestamps.append(timestamp)\n        self.step_count += 1\n\n        # Store all metric values\n        for key, value in metrics.items():\n            if isinstance(value, int | float | np.number):\n                self.metric_values[key].append(float(value))\n            elif isinstance(value, torch.Tensor):\n                self.metric_values[key].append(float(value.item()))\n\n    def _get_aggregation_type(self, metric_name: str) -> str:\n        \"\"\"Determine the aggregation type based on the metric name\"\"\"\n        for agg_type, metric_list in self.aggregation_rules.items():\n            if metric_name in metric_list:\n                return agg_type\n\n        metric_lower = metric_name.lower()\n        if any(keyword in metric_lower for keyword in [\"timing_s/\"]):\n            return \"time_sum\"\n        if any(keyword in metric_lower for keyword in [\"mean\", \"avg\", \"average\"]):\n            return \"avg\"\n        if any(keyword in metric_lower for keyword in [\"max\", \"maximum\"]):\n            return \"max\"\n        if any(keyword in metric_lower for keyword in [\"min\", \"minimum\"]):\n            return \"min\"\n        if any(keyword in metric_lower for keyword in [\"sum\", \"total\"]):\n            return \"sum\"\n        if any(keyword in metric_lower for keyword in [\"weighted_avg\"]):\n            return \"weighted_avg\"\n\n        return \"avg\"\n\n    def _aggregate_single_metric(self, metric_name: str, values: list[float]) -> float:\n        \"\"\"Aggregate a single metric\"\"\"\n        if not values:\n            return 0.0\n\n        agg_type = self._get_aggregation_type(metric_name)\n\n        if agg_type == \"last\":\n            return values[-1]\n\n        elif agg_type == \"weighted_avg\":\n            # Weighted average\n            if len(values) != len(self.sample_counts):\n                # If the lengths do not match, use a simple average\n                return sum(values) / len(values)\n\n            total_samples = sum(self.sample_counts)\n            if total_samples == 0:\n                return sum(values) / len(values)\n\n            weighted_sum = sum(v * c for v, c in zip(values, self.sample_counts, strict=False))\n            return weighted_sum / total_samples\n\n        elif agg_type == \"sum\" or agg_type == \"time_sum\":\n            return sum(values)\n\n        elif agg_type == \"avg\":\n            return sum(values) / len(values)\n\n        elif agg_type == \"max\":\n            return max(values)\n\n        elif agg_type == \"min\":\n            return min(values)\n\n        else:\n            # Default: average\n            return sum(values) / len(values)\n\n    def get_aggregated_metrics(self) -> dict[str, Any]:\n        \"\"\"Aggregate all recorded metrics\"\"\"\n        t = time.time()\n        if self.step_count == 0:\n            return {}\n\n        aggregated = {}\n\n        # Aggregate all metrics\n        for metric_name, values in self.metric_values.items():\n            aggregated[metric_name] = self._aggregate_single_metric(metric_name, values)\n\n        # Aggregate special metrics\n        aggregated = self._special_metrics_aggregate(aggregated)\n\n        print(f\"aggregated metrics done. cost {time.time() - t}\")\n\n        return aggregated\n\n    def _special_metrics_aggregate(self, aggregated: dict[str, Any]) -> dict[str, Any]:\n        \"\"\"Calculate special derived metrics\"\"\"\n\n        # global_seqlen/minmax_diff\n        if \"global_seqlen/minmax_diff\" in aggregated.keys():\n            aggregated[\"global_seqlen/minmax_diff\"] = aggregated[\"global_seqlen/max\"] - aggregated[\"global_seqlen/min\"]\n\n        # perf/throughput\n        REQUIRED_PERF_KEYS = {\"perf/throughput\", \"perf/total_num_tokens\", \"perf/time_per_step\"}\n        if REQUIRED_PERF_KEYS.issubset(aggregated):\n            aggregated[\"perf/throughput\"] = aggregated[\"perf/total_num_tokens\"] / (\n                aggregated[\"perf/time_per_step\"] * self.total_gpus\n            )\n\n        # trainer/idle_ratio\n        if \"timing_s/gen\" in aggregated.keys() and \"timing_s/step\" in aggregated.keys():\n            aggregated[\"trainer/idle_ratio\"] = aggregated[\"timing_s/gen\"] / aggregated[\"timing_s/step\"]\n\n        return aggregated\n\n    def reset(self):\n        \"\"\"Reset the aggregator\"\"\"\n        self.metric_values.clear()\n        self.sample_counts.clear()\n        self.timestamps.clear()\n        self.step_count = 0\n\n    def get_current_stats(self) -> dict[str, Any]:\n        \"\"\"Get statistics about the current aggregation state (for debugging)\"\"\"\n        return {\n            \"step_count\": self.step_count,\n            \"metric_count\": len(self.metric_values),\n            \"total_samples\": sum(self.sample_counts),\n            \"metric_names\": list(self.metric_values.keys()),\n        }\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/fsdp2_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\nfrom packaging import version\nfrom torch.distributed.tensor import DTensor\nfrom torch.distributed.tensor._dtensor_spec import DTensorSpec\n\nif version.parse(torch.__version__) < version.parse(\"2.6\"):\n    raise RuntimeError(\"PyTorch 2.6 or higher is required to use fstp_utils.\")\n\n\ndef fsdp2_sharded_save_to_cpu(\n    model: torch.nn.Module,\n) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]:\n    \"\"\"\n    Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory.\n\n    Args:\n        model: FSDP2-wrapped model whose parameters are of DTensor type.\n\n    Returns:\n        cpu_sharded_state: Dictionary of CPU shards for the current process.\n                          Key = parameter name, Value = (CPU shard tensor, original DTensorSpec)\n        global_spec: DTensorSpec of the first parameter (used to verify global rules during loading)\n    \"\"\"\n    cpu_sharded_state = {}\n    global_spec = None  # Record global sharding rules (all parameters follow the same spec)\n\n    for param_name, param in model.named_parameters():\n        # Only process sharded parameters of DTensor type (core parameters of FSDP2)\n        if not isinstance(param, DTensor):\n            # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data\n            cpu_tensor = param.detach().cpu()\n            cpu_sharded_state[param_name] = (cpu_tensor, None)\n            continue\n\n        # Record global sharding rules (take spec of the first DTensor to ensure consistency)\n        if global_spec is None:\n            global_spec = param._spec\n            assert hasattr(global_spec, \"device_mesh\"), \"DTensorSpec must contain 'device_mesh' attribute\"\n            assert hasattr(global_spec, \"placements\"), \"DTensorSpec must contain 'placements' attribute\"\n\n        # 1. Extract local shard data from the current GPU (_local_tensor)\n        local_gpu_tensor = param._local_tensor  # Local shard attribute defined in your DTensor class\n        # 2. Move to CPU memory and detach from computation graph\n        local_cpu_tensor = local_gpu_tensor.detach().cpu()\n        # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged)\n        cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec)\n\n    assert global_spec is not None, \"No DTensor-type parameters found in the model. 
FSDP2 sharding may not be enabled.\"\n    return cpu_sharded_state, global_spec\n\n\ndef fsdp2_sharded_load_from_cpu(\n    model: torch.nn.Module,\n    cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]],\n    target_spec: DTensorSpec,\n) -> None:\n    \"\"\"\n    Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU,\n                  keeping sharding rules unchanged.\n\n    Args:\n        model: FSDP2 model to be restored (must have the same structure as when saved)\n        cpu_sharded_state: Shard data read from CPU memory by the current process\n                          (from fsdp2_sharded_save_to_cpu)\n        target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency)\n    \"\"\"\n    # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs)\n    current_device_mesh = None\n    for param in model.parameters():\n        if isinstance(param, DTensor):\n            current_device_mesh = param._spec.device_mesh\n            break\n    assert current_device_mesh is not None, \"DTensor parameters not initialized in the model to be loaded\"\n    assert current_device_mesh == target_spec.device_mesh, (\n        f\"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}\"\n    )\n\n    for param_name, param in model.named_parameters():\n        # Skip parameters not in the saved state (e.g., newly added parameters)\n        if param_name not in cpu_sharded_state:\n            continue\n\n        # Extract CPU shard data and original Spec\n        local_cpu_tensor, saved_spec = cpu_sharded_state[param_name]\n\n        # Handle different parameter types: DTensor sharded parameters vs. regular parameters\n        if isinstance(param, DTensor):\n            # 1. Verify sharding rule consistency (placements must match original Spec)\n            assert saved_spec is not None, f\"DTensorSpec missing in saved state for parameter {param_name}\"\n            assert saved_spec.placements == target_spec.placements, (\n                f\"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!\"\n            )\n\n            # 2. Move CPU shard data to the current GPU (device of param._local_tensor)\n            target_device = param._local_tensor.device\n            local_gpu_tensor = local_cpu_tensor.to(target_device)\n\n            # 3. Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged)\n            param._local_tensor.copy_(local_gpu_tensor)\n\n        else:\n            # Regular parameters: load directly to original device\n            target_device = param.device\n            param.data.copy_(local_cpu_tensor.to(target_device))\n\n    # Process synchronization: ensure all processes complete loading before proceeding\n    dist.barrier()\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/fsdp_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom recipe.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils.device import (\n    get_device_name,\n    get_torch_device,\n)\nfrom verl.utils.fsdp_utils import (\n    fsdp_version,\n)\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n__all__ = [\"DetachActorWorker\", \"DetachAsyncRolloutWorker\", \"CriticWorker\"]\n\n\ndef get_inference_model(rollout):\n    \"\"\"\n    get models according to different types of inference_engine\n    Args:\n        rollout: rollout object\n    Returns:\n        model: model object\n    \"\"\"\n    inference_engine = rollout.inference_engine\n    if hasattr(inference_engine, \"llm_engine\"):\n        inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n    elif hasattr(inference_engine, \"worker\"):\n        inference_model = inference_engine.worker.model_runner.model\n    else:\n        raise AttributeError(\n            f\"Unsupported inference_engine type: {type(inference_engine)}. 
\"\n            f\"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute).\"\n        )\n    return inference_model\n\n\nclass DetachNcclSync(AsyncActorRolloutRefWorker):\n    def _get_actor_params(self):\n        pass\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params = self._get_actor_params() if self._is_actor else None\n        if self._is_rollout:\n            inference_model = get_inference_model(self.rollout)\n\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            patch_vllm_moe_model_weight_loader(inference_model)\n        for key, shape, dtype in self._weights_info:\n            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n            if self._is_actor:\n                assert key in params\n                origin_data = params[key]\n                if hasattr(origin_data, \"full_tensor\"):\n                    origin_data = origin_data.full_tensor()\n                if torch.distributed.get_rank() == 0:\n                    tensor.copy_(origin_data)\n            from ray.util.collective import collective\n\n            collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n            if self._is_rollout:\n                inference_model.load_weights([(key, tensor)])\n        get_torch_device().empty_cache()\n\n\nclass DetachActorWorker(DetachNcclSync):\n    def _get_actor_params(self):\n        assert self._is_actor\n        params = self.actor_module_fsdp.state_dict()\n        from verl.utils.model import convert_weight_keys\n\n        params = convert_weight_keys(\n            params, getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n        )\n        return params\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n        if fsdp_version(self.actor_module_fsdp) == 1:\n            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType\n\n            FSDP.set_state_dict_type(\n                self.actor_module_fsdp,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n        params = self._get_actor_params()\n        ret = []\n        for key, tensor in params.items():\n            ret.append((key, tensor.size(), tensor.dtype))\n        self._weights_info = ret\n        return ret\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_model_to_cpu(self, n):\n        if not hasattr(self, \"cpu_saved_models\"):\n            self.cpu_saved_models = {}\n        self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def restore_model_from_cpu(self, n):\n        if n in self.cpu_saved_models:\n            cpu_sharded_state, global_spec = self.cpu_saved_models[n]\n            fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def clear_cpu_model(self, n):\n        if n in self.cpu_saved_models:\n            del 
self.cpu_saved_models[n]\n\n\nclass DetachAsyncRolloutWorker(DetachNcclSync):\n    def __init__(self, config: DictConfig, role: str):\n        print(f\"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}\")\n        ActorRolloutRefWorker.__init__(self, config, role)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/fully_async_main.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport socket\nimport threading\nfrom pprint import pprint\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom recipe.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter\nfrom recipe.fully_async_policy.fully_async_trainer import FullyAsyncTrainer\nfrom recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.trainer.ppo.utils import Role\nfrom verl.utils.fs import copy_to_local\n\n\ndef create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:\n    \"\"\"\n    Create resource pool manager\n\n    Args:\n        config: Configuration object\n        roles: List of roles that need to create resource pools\n\n    Returns:\n        ResourcePoolManager: Resource pool manager\n    \"\"\"\n    resource_pool_spec = {}\n    mapping = {}\n\n    # Actor/Critic resource pool\n    if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]):\n        assert config.trainer.n_gpus_per_node > 0, \"config.trainer.n_gpus_per_node must be greater than 0\"\n        assert config.trainer.nnodes > 0, \"config.trainer.nnodes must be greater than 0\"\n\n        trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes\n        resource_pool_spec[\"trainer_pool\"] = trainer_pool\n\n        # Map training-related roles to the same resource pool\n        for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]:\n            if role in roles:\n                mapping[role] = \"trainer_pool\"\n\n    # Rollout resource pool\n    if Role.Rollout in roles:\n        assert config.rollout.n_gpus_per_node > 0, \"config.rollout.n_gpus_per_node must be greater than 0\"\n        assert config.rollout.nnodes > 0, \"config.rollout.nnodes must be greater than 0\"\n\n        rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes\n        resource_pool_spec[\"rollout_pool\"] = rollout_pool\n        mapping[Role.Rollout] = \"rollout_pool\"\n\n    return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n\ndef create_role_worker_mapping(config):\n    \"\"\"\n    Create mapping from roles to worker classes\n\n    Args:\n        config: Configuration object\n\n    Returns:\n        dict: Mapping from roles to worker classes\n    \"\"\"\n    # Select worker class based on strategy\n    if config.actor_rollout_ref.actor.strategy in [\"fsdp\", \"fsdp2\"]:\n        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n        from recipe.fully_async_policy.fsdp_workers import (\n            CriticWorker,\n            DetachActorWorker,\n            DetachAsyncRolloutWorker,\n        )\n        from verl.single_controller.ray import RayWorkerGroup\n\n        ray_worker_group_cls = 
RayWorkerGroup\n\n    elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n        assert config.critic.strategy == \"megatron\"\n        from recipe.fully_async_policy.megatron_worker import CriticWorker, DetachActorWorker, DetachAsyncRolloutWorker\n        from verl.single_controller.ray import RayWorkerGroup\n\n        ray_worker_group_cls = RayWorkerGroup\n    else:\n        raise NotImplementedError(f\"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}\")\n\n    role_worker_mapping = {\n        Role.Actor: ray.remote(DetachActorWorker),\n        Role.Rollout: ray.remote(DetachAsyncRolloutWorker),\n        Role.Critic: ray.remote(CriticWorker),\n    }\n\n    if config.reward_model.enable:\n        if config.reward_model.strategy in [\"fsdp\", \"fsdp2\"]:\n            from verl.workers.fsdp_workers import RewardModelWorker\n        # TODO megatron support\n        else:\n            raise NotImplementedError(f\"Unsupported reward model strategy: {config.reward_model.strategy}\")\n\n        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n\n    # Add reference policy (if KL loss or reward is required)\n    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker)\n\n    return role_worker_mapping, ray_worker_group_cls\n\n\n@ray.remote(num_cpus=1)\nclass FullyAsyncTaskRunner:\n    \"\"\"\n    Ray remote class for executing distributed PPO training tasks.\n    \"\"\"\n\n    def __init__(self):\n        self.running = False\n        self.components = {}\n        self.shutdown_event = threading.Event()\n\n    def run(self, config):\n        print(\"[ASYNC MAIN] Starting fully async PPO training...\")\n        self._initialize_components(config)\n        self._run_training_loop()\n\n    def _initialize_components(self, config) -> None:\n        print(f\"[ASYNC MAIN] TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        print(\"[ASYNC MAIN] Initializing model and tokenizer...\")\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        self.components[\"tokenizer\"] = tokenizer\n        self.components[\"processor\"] = processor\n        self.components[\"config\"] = config\n\n        print(\"[ASYNC MAIN] Creating worker mapping and resource pools...\")\n        role_worker_mapping, ray_worker_group_cls = create_role_worker_mapping(config)\n        self.components[\"role_worker_mapping\"] = role_worker_mapping\n        self.components[\"ray_worker_group_cls\"] = ray_worker_group_cls\n\n        print(\"[ASYNC MAIN] Loading reward functions...\")\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        
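# num_examine controls how many batches of decoded samples the reward manager prints to the\n        # console for inspection: 0 keeps training quiet, 1 prints one example batch at validation.\n        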
self.components[\"reward_fn\"] = reward_fn\n        self.components[\"val_reward_fn\"] = val_reward_fn\n\n        print(\"[ASYNC MAIN] Creating FullyAsyncRollouter...\")\n        self._create_rollouter(config)\n\n        print(\"[ASYNC MAIN] Creating FullyAsyncTrainer...\")\n        self._create_trainer(config)\n\n        # sync total_train_steps between rollouter and trainer\n        total_train_steps = ray.get(self.components[\"rollouter\"].get_total_train_steps.remote())\n        print(f\"total_train_steps {total_train_steps}\")\n        ray.get(self.components[\"trainer\"].set_total_train_steps.remote(total_train_steps))\n\n        # max_queue_size\n        max_queue_size = ray.get(self.components[\"rollouter\"].get_max_queue_size.remote())\n        print(f\"[ASYNC MAIN] Creating MessageQueue... max_queue_size {max_queue_size}\")\n        message_queue = MessageQueue.remote(config, max_queue_size)\n        message_queue_client = MessageQueueClient(message_queue)\n        self.components[\"message_queue\"] = message_queue\n        self.components[\"message_queue_client\"] = message_queue_client\n\n        ray.get(self.components[\"rollouter\"].set_message_queue_client.remote(self.components[\"message_queue_client\"]))\n        ray.get(self.components[\"trainer\"].set_message_queue_client.remote(self.components[\"message_queue_client\"]))\n\n        print(\"[ASYNC MAIN] Setting up parameter synchronization...\")\n        from recipe.fully_async_policy.param_sync import ParameterSynchronizer\n\n        param_synchronizer = ParameterSynchronizer.remote(\n            config=config,\n            trainer=self.components[\"trainer\"],\n            rollouter=self.components[\"rollouter\"],\n            mq=self.components[\"message_queue_client\"],\n        )\n        ray.get(self.components[\"trainer\"].set_parameter_synchronizer.remote(param_synchronizer))\n\n        # load checkpoint and sync parameter before doing anything\n        val_before_train = val_reward_fn is not None and config.trainer.get(\"val_before_train\", True)\n        ray.get(self.components[\"trainer\"].load_checkpoint.remote())\n        ray.get(param_synchronizer.sync_weights.remote(version=0, validate=val_before_train))\n        ray.get(param_synchronizer.wait_last_valid.remote())\n\n        self.components[\"param_synchronizer\"] = param_synchronizer\n        print(\"[ASYNC MAIN] All components initialized successfully\")\n\n    def _create_rollouter(self, config) -> None:\n        rollouter = FullyAsyncRollouter.remote(\n            config=config,\n            tokenizer=self.components[\"tokenizer\"],\n            role_worker_mapping={Role.Rollout: self.components[\"role_worker_mapping\"][Role.Rollout]},\n            resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]),\n            ray_worker_group_cls=self.components[\"ray_worker_group_cls\"],\n            processor=self.components[\"processor\"],\n            reward_fn=self.components[\"reward_fn\"],\n            val_reward_fn=self.components[\"val_reward_fn\"],\n            device_name=config.trainer.device,\n        )\n\n        ray.get(rollouter.init_workers.remote())\n        ray.get(rollouter.set_max_required_samples.remote())\n\n        self.components[\"rollouter\"] = rollouter\n        print(\"[ASYNC MAIN] Rollouter created and initialized successfully\")\n\n    def _create_trainer(self, config) -> None:\n        trainer_role_mapping = {\n            role: worker_cls\n            for role, worker_cls in 
self.components[\"role_worker_mapping\"].items()\n            if role != Role.Rollout\n        }\n\n        trainer = FullyAsyncTrainer.remote(\n            config=config,\n            tokenizer=self.components[\"tokenizer\"],\n            role_worker_mapping=trainer_role_mapping,\n            resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())),\n            ray_worker_group_cls=self.components[\"ray_worker_group_cls\"],\n            processor=self.components[\"processor\"],\n            reward_fn=self.components[\"reward_fn\"],\n            val_reward_fn=self.components[\"val_reward_fn\"],\n            device_name=config.trainer.device,\n        )\n\n        ray.get(trainer.init_workers.remote())\n        self.components[\"trainer\"] = trainer\n        print(\"[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully\")\n\n    def _run_training_loop(self):\n        self.running = True\n\n        print(\"[ASYNC MAIN] Starting Rollouter and Trainer...\")\n        rollouter_future = self.components[\"rollouter\"].fit.remote()\n        trainer_future = self.components[\"trainer\"].fit.remote()\n\n        futures = [rollouter_future, trainer_future]\n\n        try:\n            while futures:\n                # Use ray.wait to monitor all futures and return when any one is completed.\n                done_futures, remaining_futures = ray.wait(futures, num_returns=1, timeout=None)\n\n                for future in done_futures:\n                    try:\n                        ray.get(future)\n                        print(\"[ASYNC MAIN] One component completed successfully\")\n                    except Exception as e:\n                        print(f\"[ASYNC MAIN] Component failed with error: {e}\")\n                        for remaining_future in remaining_futures:\n                            ray.cancel(remaining_future)\n                        raise e\n\n                futures = remaining_futures\n\n        except Exception as e:\n            print(f\"[ASYNC MAIN] Training failed: {e}\")\n            for future in futures:\n                ray.cancel(future)\n            raise\n        finally:\n            self.components[\"message_queue_client\"].clear_queue()\n            print(\"[ASYNC MAIN] Training completed or interrupted\")\n\n\n@hydra.main(config_path=\"config\", config_name=\"fully_async_ppo_trainer\", version_base=None)\ndef main(config):\n    from verl.trainer.main_ppo import run_ppo\n\n    # Ensure async training config exists\n    if not hasattr(config, \"async_training\"):\n        raise RuntimeError(\"must set async_training config\")\n    from time import time\n\n    start_time = time()\n    run_ppo(config, task_runner_class=FullyAsyncTaskRunner)\n    print(f\"total time: {time() - start_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/fully_async_rollouter.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport time\nfrom pprint import pformat\n\nimport ray\nfrom ray import ObjectRef\n\nfrom recipe.fully_async_policy.detach_utils import (\n    RolloutSample,\n    ValidateMetrics,\n    merge_rollout_sample,\n    prepare_single_generation_data,\n)\nfrom recipe.fully_async_policy.message_queue import MessageQueueClient\nfrom recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager\nfrom verl.trainer.ppo.utils import Role, WorkerType\nfrom verl.utils.profiler import marked_timer\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\n@ray.remote(num_cpus=10, max_concurrency=100)\nclass FullyAsyncRollouter(FullyAsyncRayPPOTrainer):\n    \"\"\"\n    Asynchronous sample generator, responsible for continuously generating training samples\n    and putting them into the MessageQueue.\n    Based on, and improved from, the OneStepOffRayTrainer implementation.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        device_name=None,\n    ):\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n\n        assert not self.hybrid_engine\n        assert self.config.data.train_batch_size == 0, \"train_batch_size must be zero\"\n        assert self.config.data.gen_batch_size == 1, \"gen_batch_size must be one\"\n        assert self.config.async_training.staleness_threshold >= 0, \"staleness_threshold must be >= 0\"\n        assert self.config.async_training.trigger_parameter_sync_step >= 1, (\n            \"trigger_parameter_sync_step must be >= 1\"\n        )\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n        self.validation_generations_logger = ValidationGenerationsLogger(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n        )\n\n        self.ref_in_actor = False\n        self.kl_ctrl_in_reward = False\n        self.use_critic = False\n        self.use_reference_policy = False\n        self.use_rm = False\n\n        print(\"[FullyAsyncRollouter] Creating datasets...\")\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        self._validate_config()\n        print(f\"[FullyAsyncRollouter] Rollouter _create_dataloader...\\n{train_dataset}\\n{val_dataset}\")\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n        # ==================== fully async config ====================\n\n        self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n        if self.config.rollout.total_rollout_steps is not None:\n            self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)\n        print(f\"[FullyAsyncRollouter] Total rollout steps: {self.total_rollout_steps}\")\n        self.total_train_steps = None\n\n        # Rollouter parameter configuration\n        self.message_queue_client = None\n\n        # Worker groups: rollout_wg is the same as actor_rollout_wg\n        self.rollout_wg = None\n        self.actor_rollout_wg = None\n        self.async_rollout_manager = None\n\n        # Config\n        self.staleness_threshold: float = config.async_training.get(\"staleness_threshold\", 1)\n        # required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples.\n        self.require_batches = config.async_training.require_batches\n        self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches\n        self.max_required_samples = None\n        self.max_concurrent_samples = None\n        # queue size\n        self.max_queue_size = None\n\n        # Statistics\n        self.current_param_version = 0\n        self.total_generated_samples = 0\n        self.staleness_samples = 0\n        self.dropped_stale_samples = 0\n        self.processed_sample_count = 0\n        self.global_steps = 0\n        self.idle_start_time = None\n        self.version_start_time = None\n\n        # Concurrency control\n        # Modified by self.pause() or self._should_pause_generation()\n        self.paused = False\n        self.running = True\n        self.monitor_loop_trigger = True\n\n        # Initialize async locks directly\n        self.lock = asyncio.Lock()\n        self.condition = asyncio.Condition(self.lock)\n\n        # Initialize async queues\n        self.pending_queue = asyncio.Queue(maxsize=128)\n        self.active_tasks = set()\n        self.result_queue = asyncio.Queue()\n        self.cancel_queue = asyncio.Queue()\n\n    async def set_message_queue_client(self, message_queue_client: MessageQueueClient):\n        \"\"\"Set message queue client\"\"\"\n        async with self.lock:\n            self.message_queue_client = message_queue_client\n\n    async def set_max_required_samples(self):\n        async with self.lock:\n            self.max_required_samples = int(\n                self.required_samples\n                * (self.staleness_threshold + 1)\n                * self.config.async_training.trigger_parameter_sync_step\n            )\n            self.total_train_steps = int(\n                self.total_rollout_steps\n                / (self.required_samples * self.config.async_training.trigger_parameter_sync_step)\n            
)\n\n            self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16\n            self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples)\n            self.max_queue_size = self.max_required_samples\n\n            print(\n                f\"[FullyAsyncRollouter] required_samples: {self.required_samples} \"\n                f\"max_required_samples: {self.max_required_samples} \"\n                f\"max_queue_size: {self.max_queue_size} \"\n                f\"total_train_steps: {self.total_train_steps} \"\n                f\"total_rollout_steps: {self.total_rollout_steps} \"\n                f\"max_concurrent_samples: {self.max_concurrent_samples} \"\n            )\n\n    def get_rollout_wg(self):\n        \"\"\"Get rollout worker group\"\"\"\n        return self.rollout_wg\n\n    def get_max_queue_size(self):\n        return self.max_queue_size\n\n    def get_total_train_steps(self):\n        return self.total_train_steps\n\n    async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0):\n        \"\"\"Update current parameter version\"\"\"\n        async with self.lock:\n            old_version = self.current_param_version\n            self.current_param_version = version\n            # every time the params change, reset staleness_samples\n            self.staleness_samples = (\n                len(self.active_tasks)\n                + self.result_queue.qsize()\n                + self.cancel_queue.qsize()\n                + await self.message_queue_client.get_queue_size()\n            )\n            timing_raw = {}\n            idle_ratio = None\n            if self.idle_start_time is not None and self.version_start_time is not None:\n                rollout_active_time = self.idle_start_time - self.version_start_time\n                rollout_version_time = time.time() - self.version_start_time\n                idle_ratio = 1 - rollout_active_time / rollout_version_time\n                timing_raw[\"rollouter/active_time\"] = rollout_active_time\n                timing_raw[\"rollouter/version_time\"] = rollout_version_time\n                timing_raw[\"rollouter/idle_ratio\"] = idle_ratio\n                self.idle_start_time = None\n            print(\n                f\"[FullyAsyncRollouter][Public][update_param_version] \"\n                f\"Parameter version updated from {old_version} to {version}, \"\n                f\"reset staleness_samples to: {self.staleness_samples}, \"\n                f\"idle_ratio: {idle_ratio}\"\n            )\n            val_metrics = None\n            if (\n                self.val_reward_fn is not None\n                and self.config.rollout.test_freq > 0\n                and self.current_param_version % self.config.rollout.test_freq == 0\n                and self.current_param_version > 0  # don't test here in the initial parameter sync\n            ) or (validate and self.val_reward_fn is not None):\n                with marked_timer(\"rollouter/validate_time\", timing_raw, color=\"green\"):\n                    val_metrics: dict = self._validate()\n            data = ValidateMetrics(\n                timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version\n            )\n            await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))\n\n            self.version_start_time = time.time()\n\n    def _validate_config(self):\n        # Validate asynchronous training configuration\n        if not hasattr(self.config, \"async_training\"):\n            raise ValueError(\"[FullyAsyncRollouter] Missing async_training configuration\")\n        assert self.config.actor_rollout_ref.rollout.calculate_log_probs, \"rollout must calculate log_probs\"\n\n    async def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self._init_resource_pools()\n        self._create_worker_classes()\n        self._init_worker_groups()\n        self._init_models()\n        await self._init_async_rollout_manager()\n\n    def _create_actor_rollout_classes(self):\n        # only create rollout\n        for role in [Role.Rollout]:\n            resource_pool = self.resource_pool_manager.get_resource_pool(role)\n            role_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[role],\n                config=self.config.actor_rollout_ref,\n                role=str(role),\n            )\n            self.resource_pool_to_cls[resource_pool][str(role)] = role_cls\n\n    def _init_models(self):\n        self.rollout_wg = self.all_wg[str(Role.Rollout)]\n        self.rollout_wg.init_model()\n        self.actor_rollout_wg = self.rollout_wg\n\n    def _create_continuous_iterator(self):\n        \"\"\"\n        Create a continuous data iterator across epochs\n        \"\"\"\n        for epoch in range(self.config.rollout.total_epochs):\n            iterator = iter(self.train_dataloader)\n            for batch_dict in iterator:\n                yield epoch, batch_dict\n\n    async def _init_async_rollout_manager(self):\n        # create async rollout manager and request scheduler\n        assert self.config.actor_rollout_ref.rollout.mode == \"async\"\n        from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager\n\n        self.async_rollout_mode = True\n        self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(\n            config=self.config,\n            worker_group=self.rollout_wg,\n        )\n\n    # Add samples to the pending_queue\n    async def _feed_samples(self):\n        continuous_iterator = self._create_continuous_iterator()\n\n        for epoch, batch_dict in continuous_iterator:\n            # Similar to _prepare_generate_batch: Separate data\n            full_batch = prepare_single_generation_data(\n                batch_dict, self.global_steps, self.config.actor_rollout_ref.rollout.n\n            )\n\n            sample_id = f\"sample_{epoch}_{self.global_steps}\"\n\n            rollout_sample = RolloutSample(\n                full_batch=full_batch,\n                agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n,\n                sample_id=sample_id,\n                epoch=epoch,\n                param_version=0,\n                param_version_start=[],\n                param_version_end=[],\n                processing_times=[],\n                rollout_status={},\n            )\n\n            await self.pending_queue.put(rollout_sample)\n\n            # Check if we have reached the last step\n            if self.global_steps >= self.total_rollout_steps:\n                print(\n                    f\"[FullyAsyncRollouter][Feed] \"\n                    f\"Maximum count has been reached, stop adding new samples: \"\n                    f\"{self.global_steps} >= {self.total_rollout_steps}\"\n                )\n                
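# Stop feeding new samples; the \"DONE\" sentinel enqueued after this loop tells the\n                # processor to drain its in-flight tasks and exit.\n                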
break\n\n            self.global_steps += 1\n\n        # End signal\n        await self.pending_queue.put(\"DONE\")\n        print(f\"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added\")\n\n    async def _processor_worker(self):\n        \"\"\"\n        Streaming worker coroutine: each sample is submitted for processing without waiting for a batch\n        \"\"\"\n        while True:\n            if self.paused or await self._should_pause_generation():\n                print(\n                    \"[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return...\"\n                )\n                async with self.lock:\n                    self.paused = True\n                while self.active_tasks:\n                    async with self.lock:\n                        # After acquiring the lock, the number of active_tasks may have changed, so verify again\n                        if self.active_tasks:\n                            done_tasks, self.active_tasks = await asyncio.wait(\n                                self.active_tasks, return_when=asyncio.FIRST_COMPLETED\n                            )\n                            for task in done_tasks:\n                                await task\n\n                async with self.lock:\n                    while self.paused:\n                        self.idle_start_time = time.time()\n                        await self.condition.wait()\n                continue\n\n            sample_from_cancel_queue = False\n            if not self.cancel_queue.empty():\n                rollout_sample = await self.cancel_queue.get()\n                sample_from_cancel_queue = True\n            else:\n                rollout_sample = await self.pending_queue.get()\n                self.staleness_samples += 1\n\n            if rollout_sample == \"DONE\":\n                print(\n                    \"[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete...\"\n                )\n                while self.active_tasks:\n                    async with self.lock:\n                        if self.active_tasks:\n                            done_tasks, self.active_tasks = await asyncio.wait(\n                                self.active_tasks, return_when=asyncio.FIRST_COMPLETED\n                            )\n                            for task in done_tasks:\n                                await task\n                break\n\n            # Check whether the number of concurrent tasks exceeds the limit\n            while len(self.active_tasks) >= self.max_concurrent_samples:\n                async with self.lock:\n                    if self.active_tasks:\n                        done_tasks, self.active_tasks = await asyncio.wait(\n                            self.active_tasks, return_when=asyncio.FIRST_COMPLETED\n                        )\n                        for task in done_tasks:\n                            await task\n\n            # Submit single sample processing\n            async with self.lock:\n                # After the pause ends and the lock is reacquired, verify that we are\n                # no longer in the pause phase; otherwise keep waiting\n                while self.paused:\n                    await self.condition.wait()\n                task = asyncio.create_task(\n                    self._process_single_sample_streaming(rollout_sample),\n                    name=rollout_sample.sample_id,\n                )\n            
    self.active_tasks.add(task)\n\n            if sample_from_cancel_queue:\n                self.cancel_queue.task_done()\n            else:\n                self.pending_queue.task_done()\n\n    async def _process_single_sample_streaming(self, rollout_sample: RolloutSample):\n        \"\"\"Process a single sample in streaming fashion\"\"\"\n        # Call the asynchronous generation method\n        rollout_sample.full_batch.non_tensor_batch[\"param_version\"] = [self.current_param_version] * len(\n            rollout_sample.full_batch\n        )\n        agent_loop_output_list = await self.async_rollout_manager.generate_single_sample_async(\n            rollout_sample.full_batch, rollout_sample.agent_loop_output_list\n        )\n        rollout_sample.agent_loop_output_list = agent_loop_output_list\n\n        is_cancel = any(agent_loop.is_cancel for agent_loop in agent_loop_output_list)\n\n        if is_cancel:\n            # Put in the cancel queue and wait for generation to resume\n            await self.cancel_queue.put(rollout_sample)\n        else:\n            # put into the result_queue\n            rollout_sample.param_version = self.current_param_version\n            rollout_sample.rollout_status = await self.get_statistics()\n            await self.result_queue.put(rollout_sample)\n\n        self.processed_sample_count += 1\n\n    async def _consumer_worker(self):\n        \"\"\"\n        The consumer coroutine is responsible for obtaining the processing results\n        from the result queue and putting them into the message queue\n        \"\"\"\n        while True:\n            rollout_sample = await self.result_queue.get()\n            rollout_sample = merge_rollout_sample(self.config, self.tokenizer, rollout_sample, self.processor)\n\n            # Put RolloutSample into the message queue\n            success = await self.message_queue_client.put_sample(\n                sample=ray.cloudpickle.dumps(rollout_sample),\n                param_version=rollout_sample.param_version,\n            )\n            if success:\n                self.total_generated_samples += 1\n            else:\n                self.dropped_stale_samples += 1\n\n            self.result_queue.task_done()\n\n    async def _streaming_generation_main(self):\n        \"\"\"The main entry method for stream processing\"\"\"\n\n        # we start from step 1\n        self.global_steps += 1\n\n        if self.async_rollout_manager is None:\n            await self._init_async_rollout_manager()\n\n        # Start the streaming loop\n        print(f\"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}\")\n\n        # Start the sample feed, streaming processor and consumer coroutines\n        self.feed_task = asyncio.create_task(self._feed_samples())\n        self.processor_task = asyncio.create_task(self._processor_worker())\n        self.consumer_task = asyncio.create_task(self._consumer_worker())\n\n        try:\n            # Wait for sample feed to complete\n            await self.feed_task\n            print(\"[FullyAsyncRollouter] Sample feed completed\")\n\n            # Wait for streaming to complete\n            await self.processor_task\n            print(\"[FullyAsyncRollouter] Streaming process completed\")\n\n            # Wait for the result queue to clear\n            await self.result_queue.join()\n            print(\"[FullyAsyncRollouter] Result queue cleared\")\n\n        except Exception as e:\n            print(f\"[FullyAsyncRollouter] Streaming process exception: {e}\")\n\n        finally:\n            if self.processor_task:\n                self.processor_task.cancel()\n            if self.consumer_task:\n                self.consumer_task.cancel()\n\n            await asyncio.gather(self.processor_task, self.consumer_task, return_exceptions=True)\n\n        # Send a finish signal\n        await self.message_queue_client.put_sample(\n            sample=None,\n            param_version=self.current_param_version,\n        )\n\n        async with self.lock:\n            self.running = False\n\n    async def fit(self):\n        \"\"\"\n        Start the async rollouter - entry point that sets up and runs async tasks\n        Main async fit method that coordinates all coroutines\n        \"\"\"\n\n        print(\"[FullyAsyncRollouter] Starting FullyAsyncRollouter...\")\n\n        if self.message_queue_client is None:\n            raise ValueError(\"MessageQueue client not set. Call set_message_queue_client() first.\")\n\n        # Set the running status flag\n        async with self.lock:\n            self.paused = False\n            self.running = True\n\n        # Create the main asynchronous tasks\n        generation_task = asyncio.create_task(self._streaming_generation_main())\n        monitor_task = asyncio.create_task(self._async_monitor_loop())\n\n        try:\n            # Run generation and monitoring tasks concurrently\n            await asyncio.gather(generation_task, monitor_task, return_exceptions=True)\n        except Exception as e:\n            print(f\"[FullyAsyncRollouter] Asynchronous task execution error: {e}\")\n        finally:\n            if not generation_task.done():\n                generation_task.cancel()\n            if not monitor_task.done():\n                monitor_task.cancel()\n\n            # Wait for the tasks to complete\n            await asyncio.gather(generation_task, monitor_task, return_exceptions=True)\n\n        print(\"[FullyAsyncRollouter] Rollouter fit completed\")\n\n    async def _async_monitor_loop(self):\n        \"\"\"\n        Async coroutine for monitoring:\n        Function 1: Log information output\n        Function 2: Trigger rollout recovery\n        \"\"\"\n        last_stats_time = time.time()\n        stats_interval = 60.0\n        check_interval = 10.0\n\n        while True:\n            async with self.lock:\n                if not self.running:\n                    break\n            await asyncio.sleep(check_interval)\n            # Print statistics periodically\n            current_time = time.time()\n            if current_time - last_stats_time >= stats_interval:\n                stats = await self.get_statistics()\n                print(f\"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}\")\n                last_stats_time = current_time\n\n            # Trigger rollout recovery\n            if self.monitor_loop_trigger:\n                if not await self._should_pause_generation():\n                    async with self.lock:\n                        self.paused = False\n                        self.condition.notify_all()\n\n    async def _should_pause_generation(self) -> bool:\n        \"\"\"Determine whether generation should be paused\"\"\"\n        queue_stats = self.message_queue_client.get_statistics_sync()\n        queue_size = queue_stats[\"queue_size\"]\n\n        if queue_size >= self.max_queue_size:\n            if not self.paused:\n                
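# Back-pressure: the MessageQueue is full, so generation pauses; log only on the\n                # transition into the paused state.\n                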
print(\n                    f\"[FullyAsyncRollouter][ShouldPause]  \"\n                    f\"due to full queue: size={queue_size}, max={self.max_queue_size}\"\n                )\n            return True\n\n        if self.staleness_samples >= self.max_required_samples:\n            if not self.paused:\n                print(\n                    \"[FullyAsyncRollouter][ShouldPause] \"\n                    f\"due to \"\n                    f\"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples} \"\n                )\n            return True\n\n        return False\n\n    async def pause(self):\n        \"\"\"pause rollout\"\"\"\n        print(\"[FullyAsyncRollouter][Public][Pause]\")\n        async with self.lock:\n            self.paused = True\n            # Cancel all rollout tasks\n            if self.config.async_training.partial_rollout:\n                await self.async_rollout_manager.cancel()\n            if self.active_tasks:\n                await asyncio.gather(*self.active_tasks, return_exceptions=True)\n                self.active_tasks.clear()\n                print(\"[FullyAsyncRollouter][Public][Pause] All active tasks completed\")\n            await self.async_rollout_manager.reset_prefix_cache()\n            self.monitor_loop_trigger = False\n\n    async def resume(self, dependency_ref: ObjectRef = None):\n        if dependency_ref is not None:\n            ray.get(dependency_ref)\n        print(\"[FullyAsyncRollouter][Public][Resume]\")\n        async with self.lock:\n            self.paused = False\n            self.monitor_loop_trigger = True\n            self.condition.notify_all()\n\n            if self.config.async_training.partial_rollout:\n                await self.async_rollout_manager.resume()\n\n    async def get_statistics(self) -> dict:\n        queue_stats = self.message_queue_client.get_statistics_sync()\n\n        stats = {\n            # monitor stats\n            \"monitor/active_tasks_size\": len(self.active_tasks),\n            \"monitor/queue/pending_queue_size\": self.pending_queue.qsize(),\n            \"monitor/queue/cancel_queue_size\": self.cancel_queue.qsize(),\n            \"monitor/queue/result_queue_size\": self.result_queue.qsize(),\n            \"monitor/queue/mq_queue_size\": queue_stats[\"queue_size\"],\n            # counting stats\n            \"count/current_param_version\": self.current_param_version,\n            \"count/total_generated_samples\": self.total_generated_samples,\n            \"count/staleness_samples\": self.staleness_samples,\n            \"count/dropped_stale_samples\": self.dropped_stale_samples,\n            # static stats\n            \"static/max_required_samples\": self.max_required_samples,\n            \"static/required_samples\": self.required_samples,\n            \"static/staleness_threshold\": self.staleness_threshold,\n            \"static/max_queue_size\": self.max_queue_size,\n            \"static/max_concurrent_samples\": self.max_concurrent_samples,\n        }\n\n        return stats\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/fully_async_trainer.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\nfrom datetime import datetime\nfrom pprint import pprint\nfrom typing import Any\n\nimport ray\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom recipe.fully_async_policy.detach_utils import (\n    MetricsAggregator,\n    ValidateMetrics,\n    assemble_batch_from_rollout_samples,\n)\nfrom recipe.fully_async_policy.message_queue import MessageQueueClient\nfrom recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager\nfrom verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model\nfrom verl.utils.debug import marked_timer\n\n\n@ray.remote(num_cpus=10)\nclass FullyAsyncTrainer(FullyAsyncRayPPOTrainer):\n    \"\"\"\n    A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training.\n    Based on an improved implementation of OneStepOffRayTrainer\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        device_name=None,\n    ):\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert not self.hybrid_engine\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)\n        self.use_rm = need_reward_model(self.role_worker_mapping)\n        self.use_critic = need_critic(self.config)\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n\n        # if ref_in_actor is True, the reference policy will be the actor without LoRA applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # kl loss control currently not supported\n        if self.config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)\n\n        # ==================== fully async config ====================\n\n        self.message_queue_client = None\n        self.param_synchronizer = None\n\n        # Statistics\n        # we start from step 1\n        self.global_steps = 1\n        self.local_trigger_step 
= 1\n        self.processed_samples = 0\n        self.stale_samples_processed = 0\n        self.stale_trajectory_processed = 0\n        self.current_param_version = 0\n        self.total_train_steps = None\n        self.progress_bar = None\n        self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step\n\n        # required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples.\n        self.require_batches = config.async_training.require_batches\n        self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches\n        self.compute_prox_log_prob = self.config.async_training.compute_prox_log_prob\n        total_gpus = (\n            config.trainer.nnodes * config.trainer.n_gpus_per_node\n            + config.rollout.nnodes * config.rollout.n_gpus_per_node\n        )\n        self.metrics_aggregator = MetricsAggregator(total_gpus=total_gpus)\n\n    def set_message_queue_client(self, message_queue_client: MessageQueueClient):\n        \"\"\"Set message queue client\"\"\"\n        self.message_queue_client = message_queue_client\n\n    def set_parameter_synchronizer(self, param_synchronizer):\n        \"\"\"Set parameter synchronizer\"\"\"\n        self.param_synchronizer = param_synchronizer\n\n    def set_total_train_steps(self, total_train_steps):\n        self.total_train_steps = total_train_steps\n        self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc=\"Training Progress\")\n\n    def get_actor_wg(self):\n        \"\"\"Get actor worker group\"\"\"\n        return self.actor_wg\n\n    def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]:\n        \"\"\"\n        Get samples from the message queue and assemble a training batch.\n        Uses a loop to continuously collect samples until enough are gathered.\n\n        Returns:\n            tuple: (epoch, batch), or (None, None) if not enough samples could be collected\n        \"\"\"\n        print(\n            f\"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue\",\n            flush=True,\n        )\n\n        # Collect samples using a simple loop calling get_sample\n        consumer_start = time.time()\n        queue_samples = []\n        queue_len = 0\n        while len(queue_samples) < self.required_samples:\n            # Get a single sample and wait until there is a sample or None is received\n            sample, queue_len = self.message_queue_client.get_sample_sync()\n\n            if sample is None:\n                print(\n                    f\"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. \"\n                    f\"Collected {len(queue_samples)}/{self.required_samples} samples\"\n                )\n                break\n\n            queue_samples.append(sample)\n\n            if len(queue_samples) % 64 == 0:\n                print(\n                    f\"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. \"\n                    f\"mq_len: {queue_len}\"\n                )\n\n        consumer_end = time.time()\n\n        if not queue_samples or len(queue_samples) < self.required_samples:\n            print(\"[FullyAsyncTrainer] not enough samples collected after loop\")\n            return None, None\n        total_wait_time = consumer_end - consumer_start\n\n        print(\n            f\"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, \"\n            f\"total wait time: {total_wait_time:.2f} seconds. \"\n            f\"mq_len: {queue_len}\"\n        )\n\n        queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples]\n        # Assemble batch - now working directly with RolloutSample objects\n        if self.config.trainer.balance_batch:\n            batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch)\n        else:\n            batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None)\n\n        batch.meta_info[\"fully_async/total_wait_time\"] = total_wait_time\n        return 0, batch\n\n    def _create_actor_rollout_classes(self):\n        # create actor\n        for role in [Role.Actor]:\n            resource_pool = self.resource_pool_manager.get_resource_pool(role)\n            role_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[role],\n                config=self.config.actor_rollout_ref,\n                role=str(role),\n            )\n            self.resource_pool_to_cls[resource_pool][str(role)] = role_cls\n\n    def _init_models(self):\n        if self.use_critic:\n            self.critic_wg = self.all_wg[str(Role.Critic)]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = self.all_wg[str(Role.RewardModel)]\n            self.rm_wg.init_model()\n\n        self.actor_wg = self.all_wg[str(Role.Actor)]\n        self.actor_wg.init_model()\n        self.actor_rollout_wg = self.actor_wg  # for compatibility with functions that are not modified\n\n    def _init_async_rollout_manager(self):\n        pass\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        print(\"[FullyAsyncTrainer] Starting FullyAsyncTrainer...\")\n        if self.message_queue_client is None:\n            raise ValueError(\"MessageQueue client not set. Call set_message_queue_client() first.\")\n        if self.param_synchronizer is None:\n            raise ValueError(\"param_synchronizer client not set. Call set_parameter_synchronizer() first.\")\n\n        from verl.utils.tracking import Tracking\n\n        self.logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.max_steps_duration = 0\n\n        # get validation data before training\n        val_data = self.message_queue_client.get_validate_sync()\n        if val_data:\n            val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)\n            if val_data.metrics:\n                self.logger.log(data=val_data.metrics, step=val_data.param_version)\n                pprint(f\"[FullyAsyncTrainer] Initial validation metrics: {val_data.metrics}\")\n            self.logger.log(data=val_data.timing_raw, step=val_data.param_version)\n\n        # Queue mode: pull batches from the MessageQueue instead of a traditional dataloader iterator\n        while True:\n            metrics = {}\n            timing_raw = {}\n\n            with marked_timer(\"step\", timing_raw):\n                with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                    epoch, batch = self._get_samples_from_queue()\n                    if batch is None:\n                        break\n                    self._collect_metrics_from_samples(batch, metrics)\n                batch, reward_extra_infos_dict = self._process_batch_common(\n                    batch, metrics, timing_raw, self.local_trigger_step if self.compute_prox_log_prob else None\n                )\n                self._log_rollout(batch, reward_extra_infos_dict, timing_raw)\n                self._check_save_checkpoint(False, timing_raw)\n\n            self._collect_metrics(batch, 0, metrics, timing_raw)\n            self.metrics_aggregator.add_step_metrics(\n                metrics=metrics, sample_count=self.required_samples, timestamp=time.time()\n            )\n            # Trigger parameter synchronization after training step\n            time_str = datetime.now().strftime(\"%H:%M:%S.%f\")[:-3]\n            print(\n                f\"[FullyAsyncTrainer] global_steps: {self.global_steps} \"\n                f\"local_trigger_step: {self.local_trigger_step} \"\n                f\"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} \"\n                f\"{time_str}\"\n            )\n            self._trigger_parameter_sync_after_step(global_steps=self.global_steps)\n            val_data = self.message_queue_client.get_validate_sync()\n            if val_data:\n                val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)\n                if val_data.metrics:\n                    self.logger.log(data=val_data.metrics, step=val_data.param_version)\n                    pprint(\n                        f\"[FullyAsyncTrainer] parameter version: {val_data.param_version} \"\n                        f\"Validation metrics: {val_data.metrics}\"\n                    )\n                self.logger.log(data=val_data.timing_raw, step=val_data.param_version)\n            self.global_steps += 1\n\n        # final parameter sync and validation\n        if val_data is None or val_data.metrics is None:\n            self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps - 1)\n            ray.get(self.param_synchronizer.wait_last_valid.remote())\n            val_data = 
self.message_queue_client.get_validate_sync()\n            if val_data:\n                val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)\n                if val_data.metrics:\n                    self.logger.log(data=val_data.metrics, step=val_data.param_version)\n                    pprint(f\"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}\")\n                self.logger.log(data=val_data.timing_raw, step=val_data.param_version)\n        else:\n            pprint(f\"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}\")\n        self.progress_bar.close()\n\n        self._check_save_checkpoint(True, timing_raw)  # TODO: check checkpoint\n\n    def load_checkpoint(self):\n        return self._load_checkpoint()\n\n    def _collect_metrics_from_samples(self, batch, metrics):\n        \"\"\"\n        Collect metrics from samples\n        \"\"\"\n        if hasattr(batch, \"meta_info\") and batch.meta_info:\n            samples_param_versions = batch.meta_info[\"rollout_param_versions\"]\n            stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1)\n            self.stale_samples_processed += stale_count\n            trajectory_param_versions = batch.meta_info[\"trajectory_param_versions\"]\n            stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1)\n            self.stale_trajectory_processed += stale_traj_count\n            metrics.update(\n                {\n                    \"fully_async/count/stale_samples_processed\": self.stale_samples_processed,\n                    \"fully_async/count/stale_trajectory_processed\": self.stale_trajectory_processed,\n                    \"fully_async/count/current_param_version\": self.current_param_version,\n                }\n            )\n            for key, value in batch.meta_info.items():\n                if key.startswith(\"fully_async\"):\n                    metrics[key] = value\n\n    def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None):\n        \"\"\"\n        Trigger parameter synchronization after training step\n        This ensures rollouter always uses the latest trained parameters\n        \"\"\"\n        if self.local_trigger_step < self.trigger_parameter_sync_step and not validate:\n            self.local_trigger_step += 1\n            return\n\n        self.current_param_version += 1\n        self.local_trigger_step = 1\n        self.logger.log(\n            data=self.metrics_aggregator.get_aggregated_metrics(),\n            step=self.current_param_version,\n        )\n        self.progress_bar.update(1)\n        self.metrics_aggregator.reset()\n        timing_param_sync = {}\n        with marked_timer(\"timing_s/wait_last_valid\", timing_param_sync):\n            ray.get(self.param_synchronizer.wait_last_valid.remote())\n        with marked_timer(\"timing_s/param_sync\", timing_param_sync):\n            ray.get(\n                self.param_synchronizer.sync_weights.remote(\n                    self.current_param_version, validate=validate, global_steps=global_steps\n                )\n            )\n        self.logger.log(data=timing_param_sync, step=self.current_param_version)\n"
  },
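  {
    "path": "verl_distillation/recipe/fully_async_policy/examples/sync_cadence_sketch.py",
    "content": "\"\"\"Standalone illustrative sketch (hypothetical example, not imported by the recipe):\nmodels the cadence of FullyAsyncTrainer._trigger_parameter_sync_after_step in plain Python,\nwithout Ray. A weight sync fires only every trigger_parameter_sync_step training steps\n(or when validation forces one); otherwise only the local counter advances. The counter is\nassumed to start at 1, matching the reset value used by the trainer.\n\"\"\"\n\n\nclass SyncCadence:\n    def __init__(self, trigger_parameter_sync_step: int):\n        self.trigger_parameter_sync_step = trigger_parameter_sync_step\n        self.local_trigger_step = 1\n        self.current_param_version = 0\n\n    def after_step(self, validate: bool = False) -> bool:\n        \"\"\"Return True when this step would trigger a parameter sync.\"\"\"\n        if self.local_trigger_step < self.trigger_parameter_sync_step and not validate:\n            self.local_trigger_step += 1\n            return False\n        # Same bookkeeping as the trainer: bump the version and reset the counter.\n        self.current_param_version += 1\n        self.local_trigger_step = 1\n        return True\n\n\nif __name__ == \"__main__\":\n    cadence = SyncCadence(trigger_parameter_sync_step=4)\n    fired = [step for step in range(1, 13) if cadence.after_step()]\n    # With trigger_parameter_sync_step=4, syncs fire on steps 4, 8 and 12.\n    print(f\"syncs fired at steps {fired}; param version is {cadence.current_param_version}\")\n"
  },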
  {
    "path": "verl_distillation/recipe/fully_async_policy/megatron_worker.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n# Copyright 2025 NVIDIA Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils.device import (\n    get_device_name,\n    get_torch_device,\n)\nfrom verl.utils.megatron_utils import per_tensor_generator\nfrom verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n__all__ = [\"DetachActorWorker\", \"DetachAsyncRolloutWorker\", \"CriticWorker\"]\n\n\ndef get_inference_model(rollout):\n    \"\"\"\n    get models according to different types of inference_engine\n    Args:\n        rollout: rollout object\n    Returns:\n        model: model object\n    \"\"\"\n    inference_engine = rollout.inference_engine\n    if hasattr(inference_engine, \"llm_engine\"):\n        inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n    elif hasattr(inference_engine, \"worker\"):\n        inference_model = inference_engine.worker.model_runner.model\n    else:\n        raise AttributeError(\n            f\"Unsupported inference_engine type: {type(inference_engine)}. 
\"\n            f\"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute).\"\n        )\n    return inference_model\n\n\nclass DetachNcclSync(AsyncActorRolloutRefWorker):\n    def _get_actor_params(self):\n        pass\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params_generator = self._get_actor_params_generator() if self._is_actor else None\n        if self._is_rollout:\n            inference_model = get_inference_model(self.rollout)\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            patch_vllm_moe_model_weight_loader(inference_model)\n        for key, shape, dtype in self._weights_info:\n            if self._is_actor:\n                weight_key, weight = next(params_generator)\n                assert key == weight_key\n                assert shape == weight.size()\n                assert dtype == weight.dtype\n\n            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n            if self._is_actor and torch.distributed.get_rank() == 0:\n                tensor.copy_(weight)\n            from ray.util.collective import collective\n\n            collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n            if self._is_rollout:\n                inference_model.load_weights([(key, tensor)])\n\n\nclass DetachActorWorker(DetachNcclSync):\n    def _get_actor_params_generator(self):\n        assert self._is_actor\n        if self.bridge is not None:\n            generator = self.bridge.export_weights(self.actor.actor_module)\n        else:\n            generator = per_tensor_generator(\n                self.actor.actor_module,\n                self.actor_model_config,\n                self.weight_converter,\n                self.tf_config,\n                self.layer_name_mapping,\n            )\n\n        return generator\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n\n        params_generator = self._get_actor_params_generator()\n        ret = []\n        for key, tensor in params_generator:\n            ret.append((key, tensor.size(), tensor.dtype))\n\n        self._weights_info = ret\n        return ret\n\n\nclass DetachAsyncRolloutWorker(DetachNcclSync):\n    def __init__(self, config: DictConfig, role: str):\n        print(f\"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}\")\n        ActorRolloutRefWorker.__init__(self, config, role)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/message_queue.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport logging\nfrom collections import deque\nfrom typing import Any\n\nimport ray\nfrom omegaconf import DictConfig\n\nlogger = logging.getLogger(__name__)\n\n\n@ray.remote(num_cpus=2, max_concurrency=20)\nclass MessageQueue:\n    \"\"\"\n    Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer\n    \"\"\"\n\n    def __init__(self, config: DictConfig, max_queue_size: int = 1000):\n        self.config = config\n        if max_queue_size is None:\n            raise ValueError(f\"max_queue_size cannot be None, got: {max_queue_size}\")\n        self.max_queue_size = int(max_queue_size)\n        self.queue = deque(maxlen=self.max_queue_size)\n        self.current_param_version = 0\n\n        self.val_queue = deque()\n\n        try:\n            if hasattr(config, \"async_training\") and config.async_training is not None:\n                self.staleness_threshold = getattr(config.async_training, \"staleness_threshold\", 3)\n            else:\n                self.staleness_threshold = 3\n        except (AttributeError, RecursionError):\n            self.staleness_threshold = 3\n\n        # Asyncio for message handling\n        self.running = True\n\n        # async safe\n        self._lock = asyncio.Lock()\n        self._consumer_condition = asyncio.Condition(self._lock)\n\n        # statistic message\n        self.total_produced = 0\n        self.total_consumed = 0\n        self.dropped_samples = 0\n\n        print(\n            f\"[MessageQueue] initialized with max_queue_size={max_queue_size},\"\n            f\"staleness_threshold={self.staleness_threshold}\"\n        )\n\n    async def put_sample(self, sample: Any, param_version: int) -> bool:\n        \"\"\"\n        Put a batch sample into the queue\n\n        Args:\n            sample: Sample data\n            param_version: Parameter version number\n\n        Returns:\n            bool: Whether the sample was successfully put into the queue\n        \"\"\"\n        async with self._lock:\n            # If queue is full, remove the oldest sample (rarely happens)\n            is_drop = False\n            if len(self.queue) >= self.max_queue_size:\n                self.queue.popleft()\n                self.dropped_samples += 1\n                is_drop = True\n                logger.warning(\"Queue full, dropped sample\")\n            self.queue.append(sample)\n            self.total_produced += 1\n\n            # Notify waiting consumers\n            self._consumer_condition.notify_all()\n\n            if self.total_produced % 100 == 0:\n                print(f\"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}\")\n            if is_drop:\n                return False\n            return True\n\n    async def get_sample(self) -> Any | None:\n        \"\"\"\n        Get a single sample from the queue, wait until one is 
available\n\n        Returns:\n            tuple[Any, int]: (sample, remaining queue size), or None if the queue is closed\n        \"\"\"\n        async with self._lock:\n            while len(self.queue) == 0 and self.running:\n                await self._consumer_condition.wait()\n\n            # If queue is closed and empty, return None\n            if not self.running and len(self.queue) == 0:\n                return None\n\n            # Get one sample\n            data = self.queue.popleft()\n            self.total_consumed += 1\n            return data, len(self.queue)\n\n    async def update_param_version(self, version: int):\n        \"\"\"Update current parameter version\"\"\"\n        async with self._lock:\n            old_version = self.current_param_version\n            self.current_param_version = version\n            print(f\"Parameter version updated from {old_version} to {version}\")\n\n    async def get_queue_size(self) -> int:\n        \"\"\"Get current queue length\"\"\"\n        async with self._lock:\n            return len(self.queue)\n\n    async def get_statistics(self) -> dict[str, Any]:\n        \"\"\"Get queue statistics\"\"\"\n        async with self._lock:\n            return {\n                \"queue_size\": len(self.queue),\n                \"total_produced\": self.total_produced,\n                \"total_consumed\": self.total_consumed,\n                \"dropped_samples\": self.dropped_samples,\n                \"current_param_version\": self.current_param_version,\n                \"staleness_threshold\": self.staleness_threshold,\n                \"max_queue_size\": self.max_queue_size,\n            }\n\n    async def clear_queue(self):\n        \"\"\"Clear the queue\"\"\"\n        async with self._lock:\n            cleared_count = len(self.queue)\n            self.queue.clear()\n            logger.info(f\"Cleared {cleared_count} samples from queue\")\n\n    async def shutdown(self):\n        \"\"\"Shutdown the message queue\"\"\"\n        async with self._lock:\n            self.running = False\n            # Notify all waiting coroutines so they can exit\n            self._consumer_condition.notify_all()\n        logger.info(\"MessageQueue shutdown\")\n\n    async def get_memory_usage(self) -> dict:\n        \"\"\"Get memory usage statistics\"\"\"\n        async with self._lock:\n            # Estimate memory usage of samples in queue\n            import sys\n\n            total_size = 0\n            sample_count = len(self.queue)\n\n            if sample_count > 0:\n                # Estimate size of a single sample (simplified estimation)\n                sample = list(self.queue)[0]\n                try:\n                    sample_size = sys.getsizeof(sample)\n                    # Since we now store RolloutSample directly, estimate based on its components\n                    if hasattr(sample, \"original_batch_dict\") and sample.original_batch_dict:\n                        # Estimate batch data size\n                        batch_data = sample.original_batch_dict.get(\"batch\", {})\n                        sample_size += len(batch_data) * 1000  # Roughly estimate 1KB per batch entry\n                    if hasattr(sample, \"agent_loop_output\"):\n                        # Estimate AgentLoopOutput size\n                        sample_size += 5000  # Roughly estimate 5KB for AgentLoopOutput\n                    total_size = sample_size * sample_count\n                except Exception:\n                    total_size = sample_count * 15000  # Roughly estimate 15KB per 
RolloutSample\n\n            return {\n                \"queue_samples\": sample_count,\n                \"estimated_memory_bytes\": total_size,\n                \"estimated_memory_mb\": total_size / (1024 * 1024),\n            }\n\n    async def put_validate(self, data):\n        async with self._lock:\n            self.val_queue.append(data)\n\n    async def get_validate(self):\n        async with self._lock:\n            if self.val_queue:\n                return self.val_queue.popleft()\n            else:\n                return None\n\n\nclass MessageQueueClient:\n    \"\"\"Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor\"\"\"\n\n    def __init__(self, queue_actor: Any):\n        self.queue_actor = queue_actor\n\n    async def put_sample(self, sample: Any, param_version: int) -> bool:\n        \"\"\"Put batch into queue (async)\"\"\"\n        future = self.queue_actor.put_sample.remote(sample, param_version)\n        return await asyncio.wrap_future(future.future())\n\n    async def put_validate(self, data: Any) -> bool:\n        future = self.queue_actor.put_validate.remote(data)\n        return await asyncio.wrap_future(future.future())\n\n    def get_validate_sync(self) -> Any | None:\n        return ray.get(self.queue_actor.get_validate.remote())\n\n    async def get_sample(self) -> Any | None:\n        \"\"\"Get single sample from queue, wait until one is available (async)\"\"\"\n        future = self.queue_actor.get_sample.remote()\n        return await asyncio.wrap_future(future.future())\n\n    async def get_queue_size(self) -> int:\n        \"\"\"Get queue size (async)\"\"\"\n        future = self.queue_actor.get_queue_size.remote()\n        return await asyncio.wrap_future(future.future())\n\n    async def get_statistics(self) -> dict[str, Any]:\n        \"\"\"Get statistics (async)\"\"\"\n        future = self.queue_actor.get_statistics.remote()\n        return await asyncio.wrap_future(future.future())\n\n    async def clear_queue(self):\n        \"\"\"Clear queue (async)\"\"\"\n        future = self.queue_actor.clear_queue.remote()\n        await asyncio.wrap_future(future.future())\n\n    async def shutdown(self):\n        \"\"\"Shutdown queue (async)\"\"\"\n        future = self.queue_actor.shutdown.remote()\n        await asyncio.wrap_future(future.future())\n\n    async def get_memory_usage(self) -> dict:\n        \"\"\"Get memory usage statistics (async)\"\"\"\n        future = self.queue_actor.get_memory_usage.remote()\n        return await asyncio.wrap_future(future.future())\n\n    # Synchronous versions of the methods (deprecated)\n    def put_sample_sync(self, sample: Any, param_version: int) -> bool:\n        \"\"\"Put batch into queue (sync - deprecated, use put_sample instead)\"\"\"\n        return ray.get(self.queue_actor.put_sample.remote(sample, param_version))\n\n    def get_sample_sync(self) -> Any | None:\n        \"\"\"Get single sample from queue (sync - deprecated, use get_sample instead)\"\"\"\n        return ray.get(self.queue_actor.get_sample.remote())\n\n    def get_statistics_sync(self) -> dict[str, Any]:\n        \"\"\"Get statistics (sync - deprecated, use get_statistics instead)\"\"\"\n        return ray.get(self.queue_actor.get_statistics.remote())\n\n    def update_param_version_sync(self, version: int):\n        \"\"\"Update parameter version (sync)\"\"\"\n        return ray.get(self.queue_actor.update_param_version.remote(version))\n"
  },
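  {
    "path": "verl_distillation/recipe/fully_async_policy/examples/message_queue_usage_sketch.py",
    "content": "\"\"\"Standalone illustrative sketch (hypothetical example, not imported by the recipe):\nexercises the MessageQueue actor and MessageQueueClient from message_queue.py on a local\nRay instance, using the synchronous client methods for brevity. Assumes ray and omegaconf\nare installed and that the verl_distillation root is on PYTHONPATH.\n\"\"\"\n\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient\n\nif __name__ == \"__main__\":\n    ray.init()\n    config = OmegaConf.create({\"async_training\": {\"staleness_threshold\": 3}})\n    client = MessageQueueClient(MessageQueue.remote(config, max_queue_size=8))\n\n    # Producer side (the Rollouter): push samples tagged with a parameter version.\n    client.put_sample_sync({\"prompt_id\": 0}, param_version=1)\n    client.put_sample_sync({\"prompt_id\": 1}, param_version=1)\n\n    # Consumer side (the Trainer): get_sample returns (sample, remaining queue size).\n    sample, remaining = client.get_sample_sync()\n    print(f\"got {sample}; {remaining} sample(s) left\")\n    print(client.get_statistics_sync())\n    ray.shutdown()\n"
  },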
  {
    "path": "verl_distillation/recipe/fully_async_policy/param_sync.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport time\n\nimport ray\nfrom ray.util.collective import collective\n\nlogger = logging.getLogger(__name__)\n\n\n@ray.remote\nclass ParameterSynchronizer:\n    \"\"\"\n    Unified parameter synchronizer, responsible for synchronizing model parameters between actor and rollout\n    Based on the mature synchronization mode implementation of one_step_off_policy\n    Merges the functions of the original multiple synchronizer classes\n    \"\"\"\n\n    def __init__(self, config, trainer, rollouter, mq):\n        self.config = config\n        self.trainer = trainer\n        self.rollouter = rollouter\n        self.mq_client = mq\n        self.actor_wg = ray.get(trainer.get_actor_wg.remote())\n        self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote())\n\n        # Basic attributes\n        self.weights_info = None\n        self.sync_group_initialized = False\n        self.sync_group_name = \"actor_rollout\"\n        self.wait_last_update = None\n        self.wait_last_resume = None\n\n        # Statistics\n        self.current_version = 0\n\n        self._init_weights_info()\n        self._init_sync_group()\n\n    def get_current_param_version(self) -> int:\n        \"\"\"Get current parameter version number\"\"\"\n        return self.current_version\n\n    def get_weights_info(self):\n        \"\"\"Get weights info\"\"\"\n        return self.weights_info\n\n    def _init_weights_info(self):\n        self.weights_info = self.actor_wg.get_actor_weights_info()[0]\n        self.rollout_wg.set_actor_weights_info(self.weights_info)\n\n    def _init_sync_group(self):\n        print(\"[ParameterSynchronizer] Initializing parameter synchronization group...\")\n        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers\n        collective.create_collective_group(\n            actor_rollout_workers,\n            len(actor_rollout_workers),\n            list(range(0, len(actor_rollout_workers))),\n            backend=\"nccl\",\n            group_name=self.sync_group_name,\n        )\n\n    def sync_weights(self, version, validate=False, global_steps=0):\n        \"\"\"Sync weights between trainer and rollouter, and update parameter version\"\"\"\n        start_time = time.time()\n\n        self.current_version = version\n        print(f\"[ParameterSynchronizer] Starting weight synchronization (version {self.current_version})...\")\n\n        ray.get(self.rollouter.pause.remote())\n\n        # Update MQ version\n        self.mq_client.update_param_version_sync(version)\n\n        # sync weights\n        self.actor_wg.sync_rollout_weights()\n        ray.get(self.rollout_wg.sync_rollout_weights())\n        end_time = time.time()\n        print(f\"[ParameterSynchronizer] sync_weights success. 
cost {end_time - start_time:.2f} seconds\")\n\n        # Async Update rollout version & validation\n        self.wait_last_update = self.rollouter.update_param_version.remote(version, validate, global_steps)\n        self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update)\n\n    def wait_last_valid(self):\n        print(\"[ParameterSynchronizer] Waiting last sync and validate...\")\n        start_time = time.time()\n        if self.wait_last_update:\n            ray.get(self.wait_last_update)\n        if self.wait_last_resume:\n            ray.get(self.wait_last_resume)\n        print(f\"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds\")\n"
  },
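  {
    "path": "verl_distillation/recipe/fully_async_policy/examples/weight_broadcast_sketch.py",
    "content": "\"\"\"Standalone illustrative sketch (hypothetical example, not imported by the recipe):\nthe per-tensor broadcast pattern behind DetachNcclSync.sync_rollout_weights, reduced to a\nsingle-process torch.distributed group on the gloo backend so it runs on CPU. The recipe\nitself uses ray.util.collective with NCCL across all actor and rollout workers: weights_info\n(name, shape, dtype) is exchanged first so every rank can pre-allocate buffers, then each\ntensor is broadcast from the actor's rank 0 and loaded on the rollout side.\n\"\"\"\n\nimport os\n\nimport torch\nimport torch.distributed as dist\n\nif __name__ == \"__main__\":\n    os.environ.setdefault(\"MASTER_ADDR\", \"127.0.0.1\")\n    os.environ.setdefault(\"MASTER_PORT\", \"29500\")\n    dist.init_process_group(\"gloo\", rank=0, world_size=1)\n\n    source = {\"layer.weight\": torch.randn(4, 4), \"layer.bias\": torch.randn(4)}\n    # Metadata exchange, mirroring get_actor_weights_info().\n    weights_info = [(name, t.size(), t.dtype) for name, t in source.items()]\n\n    received = {}\n    for key, shape, dtype in weights_info:\n        # Every rank allocates an empty buffer of the advertised spec.\n        tensor = torch.empty(shape, dtype=dtype)\n        if dist.get_rank() == 0:\n            tensor.copy_(source[key])  # the actor's rank 0 fills the buffer\n        # Broadcast from rank 0; a rollout rank would then load_weights the tensor.\n        dist.broadcast(tensor, src=0)\n        received[key] = tensor\n\n    assert all(torch.equal(received[k], source[k]) for k in source)\n    print(f\"synchronized {len(received)} tensors\")\n    dist.destroy_process_group()\n"
  },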
  {
    "path": "verl_distillation/recipe/fully_async_policy/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport ray\nimport torch\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.single_controller.ray import RayClassWithInitArgs\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.trainer.ppo.utils import Role\nfrom verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\nfrom verl.utils.rollout_skip import RolloutSkip\n\n\nclass FullyAsyncRayPPOTrainer(RayPPOTrainer):\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. 
Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self._init_resource_pools()\n        self._create_worker_classes()\n        self._init_worker_groups()\n        self._init_models()\n        self._init_async_rollout_manager()\n\n    def _init_resource_pools(self):\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n    def _create_worker_classes(self):\n        self._create_actor_rollout_classes()\n        self._create_critic_class()\n        self._create_reference_policy_class()\n        self._create_reward_model_class()\n\n    def _create_actor_rollout_classes(self):\n        raise NotImplementedError\n\n    def _create_critic_class(self):\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cfg = omega_conf_to_dataclass(self.config.critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)\n            self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls\n\n    def _create_reference_policy_class(self):\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=str(Role.RefPolicy),\n                # profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls\n\n    def _create_reward_model_class(self):\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls\n\n    def _init_worker_groups(self):\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.global_profiler, \"steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.global_profiler, \"steps\")\n            # Only require nsight worker options when tool is nsys\n            if OmegaConf.select(self.config.global_profiler, \"tool\") == \"nsys\":\n                assert (\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                    is not None\n          
      ), \"worker_nsight_options must be set when using nsys with profile_steps\"\n                wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                )\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n        self.all_wg = all_wg\n\n    def _init_models(self):\n        if self.use_critic:\n            self.critic_wg = self.all_wg[str(Role.Critic)]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = self.all_wg[str(Role.RewardModel)]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]\n        self.actor_rollout_wg.init_model()\n\n    def _init_async_rollout_manager(self):\n        pass\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        if self.config.actor_rollout_ref.rollout.get(\"skip_rollout\", False):\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n        self.max_steps_duration = 0\n\n        prev_step_profile = False\n        curr_step_profile = (\n            self.global_steps in 
self.config.global_profiler.steps\n            if self.config.global_profiler.steps is not None\n            else False\n        )\n        next_step_profile = False\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(\n                        not prev_step_profile and curr_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n\n                batch, gen_batch = self._prepare_generate_batch(batch_dict)\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        if self.reward_fn is None:\n                            raise ValueError(\"A reward_fn is required for REMAX advantage estimation.\")\n\n                        with marked_timer(\"gen_max\", timing_raw, color=\"purple\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            if not self.async_rollout_mode:\n                                gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n                            else:\n                                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)\n                            batch = batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    batch = self._post_generate_batch(batch, gen_batch_output, metrics)\n                    batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)\n                    self._log_rollout(batch, reward_extra_infos_dict, timing_raw)\n\n                last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw)\n                self._check_save_checkpoint(is_last_step, timing_raw)\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    next_step_profile = (\n                        self.global_steps + 1 in self.config.global_profiler.steps\n                        if self.config.global_profiler.steps is not None\n                        else 
False\n                    )\n                    self._stop_profiling(\n                        curr_step_profile and not next_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                    prev_step_profile = curr_step_profile\n                    curr_step_profile = next_step_profile\n\n                self._collect_metrics(batch, epoch, metrics, timing_raw)\n                self._post_batch_processing(batch)\n\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n\n                progress_bar.update(1)\n                self.global_steps += 1\n\n                if (\n                    hasattr(self.config.actor_rollout_ref.actor, \"profiler\")\n                    and self.config.actor_rollout_ref.actor.profiler.tool == \"torch_memory\"\n                ):\n                    self.actor_rollout_wg.dump_memory_snapshot(\n                        tag=f\"post_update_step{self.global_steps}\", sub_dir=f\"step{self.global_steps}\"\n                    )\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n    def _prepare_generate_batch(self, batch_dict):\n        batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n        # add uid to batch\n        batch.non_tensor_batch[\"uid\"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)\n\n        gen_batch = self._get_gen_batch(batch)\n\n        # pass global_steps to trace\n        gen_batch.meta_info[\"global_steps\"] = self.global_steps\n        gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n        return batch, gen_batch\n\n    def _post_generate_batch(self, batch, gen_batch_output, metrics):\n        # repeat to align with repeated responses in rollout\n        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n        batch = batch.union(gen_batch_output)\n\n        if \"response_mask\" not in batch.batch.keys():\n            batch.batch[\"response_mask\"] = compute_response_mask(batch)\n        # Balance the number of valid tokens across DP ranks.\n        # NOTE: This usually changes the order of data in the `batch`,\n        # which won't affect the advantage calculation (since it's based on uid),\n        # but might affect the loss calculation (due to the change of mini-batching).\n        # TODO: Decouple the DP balancing and mini-batching.\n        if self.config.trainer.balance_batch:\n            self._balance_batch(batch, metrics=metrics)\n\n        # compute global_valid tokens\n        batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n        return batch\n\n    def _process_batch_common(self, batch, metrics, timing_raw, local_trigger_step=None):\n        with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n            # compute reward model score\n            if self.use_rm:\n                reward_tensor = self.rm_wg.compute_rm_score(batch)\n                batch = batch.union(reward_tensor)\n\n            if self.config.reward_model.launch_reward_fn_async:\n                future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)\n            else:\n               
 reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n        with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n\n            def compute_old_log_prob(batch):\n                old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                entropys = old_log_prob.batch[\"entropys\"]\n                response_masks = batch.batch[\"response_mask\"]\n                loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                metrics.update(old_log_prob_metrics)\n                old_log_prob.batch.pop(\"entropys\")\n                batch = batch.union(old_log_prob)\n                if \"rollout_log_probs\" in batch.batch.keys():\n                    # TODO: we may want to add diff of probs too.\n                    from verl.utils.debug.metrics import calculate_debug_metrics\n\n                    metrics.update(calculate_debug_metrics(batch))\n                return batch\n\n            async_training = self.config.get(\"async_training\", None)\n            if async_training and async_training.use_rollout_log_probs:\n                # If local_trigger_step == 1, load the training engine's parameters to the CPU\n                # and save a copy for subsequent MIS use.\n                # If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob,\n                # then restore the parameters of the current version.\n                if local_trigger_step == 1:\n                    self.actor_rollout_wg.save_model_to_cpu(1)\n                    batch = compute_old_log_prob(batch)\n                elif local_trigger_step is not None:\n                    self.actor_rollout_wg.save_model_to_cpu(local_trigger_step)\n                    self.actor_rollout_wg.restore_model_from_cpu(1)\n                    batch = compute_old_log_prob(batch)\n                    self.actor_rollout_wg.restore_model_from_cpu(local_trigger_step)\n                    self.actor_rollout_wg.clear_cpu_model(local_trigger_step)\n                else:\n                    batch.batch[\"old_log_probs\"] = batch.batch[\"rollout_log_probs\"]\n                    batch.meta_info[\"temperature\"] = self.config.actor_rollout_ref.rollout.temperature\n\n            else:\n                batch = compute_old_log_prob(batch)\n\n        if self.use_reference_policy:\n            # compute reference log_prob\n            with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                if not self.ref_in_actor:\n                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                else:\n                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                batch = batch.union(ref_log_prob)\n\n        # compute values\n        if self.use_critic:\n            with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                values = self.critic_wg.compute_values(batch)\n                batch = batch.union(values)\n\n        with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n            # we combine with rule-based rm\n            reward_extra_infos_dict: dict[str, list]\n            if self.config.reward_model.launch_reward_fn_async:\n                reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n            batch.batch[\"token_level_scores\"] = reward_tensor\n\n            if reward_extra_infos_dict:\n                batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n            # compute rewards. apply_kl_penalty if available\n            if self.config.algorithm.use_kl_in_reward:\n                batch, kl_metrics = apply_kl_penalty(\n                    batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                )\n                metrics.update(kl_metrics)\n            else:\n                batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n            # Compute rollout importance sampling weights centrally (once per batch)\n            # This corrects for mismatch between rollout policy and training policy\n            # Also computes mismatch metrics (KL, PPL, etc.)\n            batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)\n            # IS and mismatch metrics already have mismatch/ prefix\n            metrics.update(is_metrics)\n\n            # compute advantages, executed on the driver process\n            norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                \"norm_adv_by_std_in_grpo\", True\n            )  # GRPO adv normalization factor\n\n            batch = compute_advantage(\n                batch,\n                adv_estimator=self.config.algorithm.adv_estimator,\n                gamma=self.config.algorithm.gamma,\n                lam=self.config.algorithm.lam,\n                num_repeat=self.config.actor_rollout_ref.rollout.n,\n                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                config=self.config.algorithm,\n            )\n\n        # update critic\n        if self.use_critic:\n            with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                critic_output = self.critic_wg.update_critic(batch)\n            critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n            metrics.update(critic_output_metrics)\n\n        # implement critic warmup\n        if self.config.trainer.critic_warmup <= self.global_steps:\n            # update actor\n            with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                actor_output = self.actor_rollout_wg.update_actor(batch)\n            actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n            metrics.update(actor_output_metrics)\n        return batch, reward_extra_infos_dict\n\n    def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw):\n        # Log rollout generations if enabled\n        rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n        if rollout_data_dir:\n            with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                sample_gts = [item.non_tensor_batch.get(\"reward_model\", {}).get(\"ground_truth\", None) for item in batch]\n\n                if \"request_id\" in batch.non_tensor_batch:\n                    reward_extra_infos_dict.setdefault(\n                        \"request_id\",\n                        batch.non_tensor_batch[\"request_id\"].tolist(),\n                    )\n\n                self._dump_generations(\n                    inputs=inputs,\n                    outputs=outputs,\n                    gts=sample_gts,\n                    scores=scores,\n                    reward_extra_infos_dict=reward_extra_infos_dict,\n                    dump_path=rollout_data_dir,\n                )\n\n    def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw):\n        if (\n            self.val_reward_fn is not None\n            and self.config.trainer.test_freq > 0\n            and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n        ):\n            with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                val_metrics: dict = self._validate()\n                if is_last_step:\n                    last_val_metrics = val_metrics\n            metrics.update(val_metrics)\n        return last_val_metrics\n\n    def _check_save_checkpoint(self, is_last_step, timing_raw):\n        # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.\n        esi_close_to_expiration = should_save_ckpt_esi(\n            max_steps_duration=self.max_steps_duration,\n            redundant_time=self.config.trainer.esi_redundant_time,\n        )\n        # Check if the conditions for saving a checkpoint are met.\n        # The conditions include a mandatory condition (1) and\n        # one of the following optional conditions (2/3/4):\n        # 1. The save frequency is set to a positive value.\n        # 2. It's the last training step.\n        # 3. The current step number is a multiple of the save frequency.\n        # 4. 
The ESI(Elastic Server Instance)/training plan is close to expiration.\n        if self.config.trainer.save_freq > 0 and (\n            is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration\n        ):\n            if esi_close_to_expiration:\n                print(\"Force saving checkpoint: ESI instance expiration approaching.\")\n            with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                self._save_checkpoint()\n\n    def _collect_metrics(self, batch, epoch, metrics, timing_raw):\n        steps_duration = timing_raw[\"step\"]\n        self.max_steps_duration = max(self.max_steps_duration, steps_duration)\n\n        # training metrics\n        metrics.update(\n            {\n                \"training/global_step\": self.global_steps,\n                \"training/epoch\": epoch,\n            }\n        )\n        # collect metrics\n        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n        # TODO: implement actual tflpo and theoretical tflpo\n        n_gpus = self.resource_pool_manager.get_n_gpus()\n        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n\n    def _post_batch_processing(self, batch: DataProto):\n        # this is experimental and may be changed/removed in the future in favor of a general-purpose one\n        if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):\n            self.train_dataloader.sampler.update(batch=batch)\n\n        # this is experimental and may be changed/removed in the future\n        # in favor of a general-purpose data buffer pool\n        if hasattr(self.train_dataset, \"on_batch_end\"):\n            # The dataset may be changed after each training batch\n            self.train_dataset.on_batch_end(batch=batch)\n"
  },
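  {
    "path": "verl_distillation/recipe/fully_async_policy/examples/old_log_prob_versioning_sketch.py",
    "content": "\"\"\"Standalone illustrative sketch (hypothetical example, not imported by the recipe):\nthe parameter-version bookkeeping performed in _process_batch_common of ray_trainer.py when\nasync_training.use_rollout_log_probs is enabled. Plain floats stand in for the model weights\nand for the save_model_to_cpu/restore_model_from_cpu worker calls, so the restore order is\neasy to follow.\n\"\"\"\n\n\nclass VersionedActor:\n    def __init__(self):\n        self.params = 0.0  # stands in for the live training weights\n        self.cpu_copies: dict[int, float] = {}\n\n    def save_model_to_cpu(self, tag: int):\n        self.cpu_copies[tag] = self.params\n\n    def restore_model_from_cpu(self, tag: int):\n        self.params = self.cpu_copies[tag]\n\n    def clear_cpu_model(self, tag: int):\n        self.cpu_copies.pop(tag, None)\n\n\nif __name__ == \"__main__\":\n    actor = VersionedActor()\n    for local_trigger_step in range(1, 4):\n        actor.params += 1.0  # one optimizer update per local step\n        if local_trigger_step == 1:\n            # First local step: keep a copy of version 1 for later reuse.\n            actor.save_model_to_cpu(1)\n            old_log_prob_params = actor.params\n        else:\n            # Later steps: park the current weights, read version 1 to compute\n            # the old_log_prob, then restore the current weights and clean up.\n            actor.save_model_to_cpu(local_trigger_step)\n            actor.restore_model_from_cpu(1)\n            old_log_prob_params = actor.params\n            actor.restore_model_from_cpu(local_trigger_step)\n            actor.clear_cpu_model(local_trigger_step)\n        print(f\"step {local_trigger_step}: old_log_prob computed under params={old_log_prob_params}\")\n"
  },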
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_16-16'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 28))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=4\nsp_size=4\nfsdp_size=8\n\n# Fully async specific parameters\nNNODES_ROLLOUT=${NNODES_ROLLOUT:-2}\nNNODES_TRAIN=${NNODES_TRAIN:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*400)))\ntest_freq=20\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    
async_training.use_rollout_log_probs=True"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_32-32'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 28))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=4\nsp_size=4\nfsdp_size=8\n\n# Fully async specific parameters\nNNODES_ROLLOUT=${NNODES_ROLLOUT:-4}\nNNODES_TRAIN=${NNODES_TRAIN:-4}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*400)))\ntest_freq=20\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    
async_training.use_rollout_log_probs=True\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-12'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=1\nsp_size=1\nfsdp_size=2\n\n# Fully async specific parameters\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*100)))\ntest_freq=10\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=\"${test_freq}\" \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    async_training.use_rollout_log_probs=True"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-4'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=1\nsp_size=1\nfsdp_size=2\n\n# Fully async specific parameters\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=4\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*100)))\ntest_freq=10\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=False \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    async_training.use_rollout_log_probs=True"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 28))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=4\nsp_size=4\nfsdp_size=8\n\n# Fully async specific parameters\nNNODES_ROLLOUT=${NNODES_ROLLOUT:-8}\nNNODES_TRAIN=${NNODES_TRAIN:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*400)))\ntest_freq=20\nstaleness_threshold=0.5\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    
async_training.use_rollout_log_probs=True\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 28))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=4\nsp_size=4\nfsdp_size=8\n\n# Fully async specific parameters\nNNODES_ROLLOUT=${NNODES_ROLLOUT:-8}\nNNODES_TRAIN=${NNODES_TRAIN:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*400)))\ntest_freq=20\nstaleness_threshold=0.5\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\n# Rollout Importance Sampling\nrollout_is_threshold=1.001\nrollout_is=True\nrollout_is_threshold_lower=0.99\nrollout_is_level=geometric\nrollout_is_mode=mask\nrollout_is_veto_threshold=1e-4\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    
async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    async_training.use_rollout_log_probs=True \\\n    async_training.compute_prox_log_prob=True \\\n    algorithm.rollout_is=${rollout_is} \\\n    algorithm.rollout_is_threshold=${rollout_is_threshold} \\\n    algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \\\n    algorithm.rollout_is_level=${rollout_is_level} \\\n    algorithm.rollout_is_mode=${rollout_is_mode} \\\n    algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold}\n\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-8-8'\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=1\nsp_size=1\nfsdp_size=2\n\n# Fully async specific parameters\nNNODES_ROLLOUT=${NNODES_ROLLOUT:-1}\nNNODES_TRAIN=${NNODES_TRAIN:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\ntotal_rollout_steps=$(((512*100)))\ntest_freq=10\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=4\npartial_rollout=True\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES_TRAIN}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.nnodes=\"${NNODES_ROLLOUT}\" \\\n    rollout.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=10 \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    
async_training.use_rollout_log_probs=True"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n\nHF_MODEL_PATH=${HF_MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-VL-7B-Instruct\"}\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Fully async specific parameters\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=4\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=4\ntrain_prompt_mini_bsz=128\ntotal_rollout_steps=$(((512*100)))\ntest_freq=5\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\nrequire_batches=2\npartial_rollout=True\ntotal_epochs=200\n\npython -m recipe.fully_async_policy.fully_async_main \\\n    --config-path=config \\\n    --config-name='fully_async_ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    actor_rollout_ref.rollout.max_model_len=32768 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=32768 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.return_raw_chat=${return_raw_chat} \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_decay_steps=51200 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=5120 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=5120 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \\\n    
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_megatron_async' \\\n    trainer.test_freq=\"${test_freq}\" \\\n    trainer.total_epochs=\"${total_epochs}\" \\\n    trainer.val_before_train=False \\\n    trainer.save_freq=-1 \\\n    trainer.resume_mode=auto \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" \\\n    rollout.total_rollout_steps=\"${total_rollout_steps}\" \\\n    rollout.total_epochs=\"${total_epochs}\" \\\n    rollout.test_freq=\"${test_freq}\" \\\n    async_training.staleness_threshold=\"${staleness_threshold}\" \\\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\" \\\n    async_training.require_batches=\"${require_batches}\" \\\n    async_training.partial_rollout=\"${partial_rollout}\" \\\n    async_training.use_rollout_log_probs=True"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/shell/runtime_env.yaml",
    "content": "env_vars:\n  VLLM_USE_V1: \"1\"\n  NCCL_DEBUG: \"INFO\"\n  HYDRA_FULL_ERROR: \"1\""
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/unittest/simple_streaming_demo.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport random\nimport time\n\n\nclass SimpleStreamingSystem:\n    \"\"\"Simplified streaming system demonstration\"\"\"\n\n    def __init__(self, max_concurrent_tasks: int = 4):\n        self.max_concurrent_tasks = max_concurrent_tasks\n        self.data_queue = asyncio.Queue()\n        self.result_queue = asyncio.Queue()\n        self.consumer_count = 0\n\n    # Data stream coroutine\n    async def data_stream(self):\n        # Add initial data\n        # Prepare test data\n        test_data = [{\"id\": f\"task_{i}\", \"content\": f\"data_{i}\"} for i in range(8)]\n        await self.add_data_stream(test_data)\n\n        # Simulate subsequent data stream\n        await asyncio.sleep(3)\n        print(\"\\nAdding second batch of data...\")\n        extra_data = [{\"id\": f\"extra_{i}\", \"content\": f\"extra_data_{i}\"} for i in range(5)]\n        await self.add_data_stream(extra_data)\n\n        # Send termination signal\n        await asyncio.sleep(1)\n        await self.data_queue.put(\"DONE\")\n        print(\"Sending termination signal\")\n\n    async def add_data_stream(self, data_list: list[dict]):\n        \"\"\"Simulate data stream\"\"\"\n        print(\"Starting to add data stream...\")\n\n        for i, data_item in enumerate(data_list):\n            await self.data_queue.put(data_item)\n            print(f\"Data {data_item['id']} added to pending queue\")\n\n            # Simulate interval between data streams\n            if i < len(data_list) - 1:  # Don't wait after the last item\n                await asyncio.sleep(0.8)\n\n        print(\"Initial data stream added successfully\")\n\n    async def _process_data_async(self, data_item: dict):\n        \"\"\"Asynchronously process a single data item\"\"\"\n        data_id = data_item[\"id\"]\n        content = data_item[\"content\"]\n\n        # Simulate different processing times (1-3 seconds)\n        processing_time = random.uniform(1, 3)\n\n        print(f\"    Starting to process {data_id}, estimated time {processing_time:.1f}s\")\n\n        # Asynchronously wait for processing completion\n        await asyncio.sleep(processing_time)\n\n        result = {\n            \"id\": data_id,\n            \"processed_content\": f\"Processed {content}\",\n            \"processing_time\": round(processing_time, 2),\n            \"completed_at\": time.time(),\n        }\n\n        # Immediately put into result queue\n        await self.result_queue.put(result)\n        print(f\"    {data_id} processing completed! 
(took {processing_time:.1f}s) -> Added to result queue\")\n\n    async def _submit_worker(self):\n        \"\"\"Stream submission worker coroutine\"\"\"\n        active_tasks = set()\n\n        print(\"Stream submitter started...\")\n\n        while True:\n            # Get data to process\n            data_item = await self.data_queue.get()\n\n            if data_item == \"DONE\":\n                print(\"Received termination signal, waiting for remaining tasks to complete...\")\n                if active_tasks:\n                    await asyncio.gather(*active_tasks, return_exceptions=True)\n                break\n\n            # Check concurrent limit\n            while len(active_tasks) >= self.max_concurrent_tasks:\n                print(f\"Reached maximum concurrency {self.max_concurrent_tasks}, waiting for tasks to complete...\")\n                done_tasks, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)\n\n                # Clean up completed tasks\n                for task in done_tasks:\n                    try:\n                        await task\n                        print(f\"Task completed {task}\")\n                    except Exception as e:\n                        print(f\"Task execution failed: {e}\")\n\n            # Immediately submit new task\n            task = asyncio.create_task(self._process_data_async(data_item), name=f\"active {data_item}\")\n            active_tasks.add(task)\n\n            print(f\"Submitted task {data_item['id']}, current concurrency: {len(active_tasks)}\")\n\n    async def _consumer_worker(self):\n        \"\"\"Result consumer coroutine\"\"\"\n        print(\"Consumer started...\")\n\n        while True:\n            try:\n                # Get processing result from result queue\n                result = await asyncio.wait_for(self.result_queue.get(), timeout=2.0)\n\n                self.consumer_count += 1\n\n                print(\n                    f\"Consumed #{self.consumer_count}: {result['id']} \"\n                    f\"(processing time {result['processing_time']}s) - {result['processed_content']}\"\n                )\n\n            except asyncio.TimeoutError:\n                print(\"    Consumer waiting...\")\n                await asyncio.sleep(0.5)\n\n    async def run_demo(self):\n        \"\"\"Run demonstration\"\"\"\n        print(\"=\" * 60)\n        print(f\"Maximum concurrency: {self.max_concurrent_tasks}\")\n        print(\"=\" * 60)\n\n        # Start core coroutines\n        stream_task = asyncio.create_task(self.data_stream())\n        submit_task = asyncio.create_task(self._submit_worker())\n        consumer_task = asyncio.create_task(self._consumer_worker())\n\n        try:\n            # Wait for data stream to complete\n            await stream_task\n            print(\"Data stream completed\")\n\n            # Wait for processing to complete\n            await submit_task\n            print(\"All tasks processed\")\n\n        finally:\n            # Cleanup\n            submit_task.cancel()\n            consumer_task.cancel()\n            await asyncio.gather(submit_task, consumer_task, return_exceptions=True)\n\n        print(f\"\\nFinal statistics: Consumed {self.consumer_count} results\")\n\n\nasync def main():\n    \"\"\"Main function\"\"\"\n    system = SimpleStreamingSystem(max_concurrent_tasks=3)\n    await system.run_demo()\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/vllm_rollout/__init__.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py",
    "content": "# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nfrom typing import Any, Optional, Sequence\n\nimport ray\nfrom ray.actor import ActorHandle\nfrom vllm import SamplingParams\nfrom vllm.inputs import TokensPrompt\nfrom vllm.outputs import RequestOutput\n\nfrom verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig\nfrom verl.workers.rollout.replica import RolloutMode\nfrom verl.workers.rollout.vllm_rollout.vllm_async_server import (\n    _qwen2_5_vl_dedup_image_tokens,\n    vLLMHttpServerBase,\n    vLLMReplica,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(logging.INFO)\n\n\n@ray.remote(num_cpus=1)\nclass vLLMHttpServerForPartial(vLLMHttpServerBase):\n    def __init__(\n        self,\n        config: RolloutConfig | RewardModelConfig,\n        model_config: HFModelConfig,\n        rollout_mode: RolloutMode,\n        workers: list[ActorHandle],\n        replica_rank: int,\n        node_rank: int,\n        gpus_per_node: int,\n        nnodes: int,\n    ):\n        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)\n\n        # for cancel LLMServer\n        self.paused = False\n        self.lock = asyncio.Lock()\n        self.cancel_event: dict[str, asyncio.Event] = {}\n        self.req_output: dict[str, Optional[RequestOutput]] = {}\n\n    async def _generate_step(\n        self,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ):\n        max_tokens = self.config.max_model_len - len(prompt_ids)\n        sampling_params[\"logprobs\"] = 1\n        sampling_params.setdefault(\"repetition_penalty\", self.config.get(\"repetition_penalty\", 1.0))\n        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)\n        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)\n        prompt = TokensPrompt(\n            prompt_token_ids=prompt_ids, multi_modal_data={\"image\": image_data} if image_data else None\n        )\n        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)\n\n        # Get final response\n        async for output in generator:\n            self.req_output[request_id] = output\n        assert self.req_output[request_id] is not None\n\n    async def generate_for_partial(\n        self,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:\n        async with self.lock:\n            if self.paused:\n                # After cancel, all tasks will return directly and wait for the next submission\n                return [], [], True\n            self.req_output[request_id]: Optional[RequestOutput] = None\n  
          self.cancel_event[request_id] = asyncio.Event()\n            cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())\n            generation_handle = asyncio.create_task(\n                self._generate_step(prompt_ids, sampling_params, request_id, image_data)\n            )\n\n        done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED)\n\n        for task in done:\n            await task\n\n        for task in pend:\n            task.cancel()\n\n        async with self.lock:\n            if self.req_output[request_id] is None:\n                return [], [], True\n            token_ids = self.req_output[request_id].outputs[0].token_ids\n            log_probs: list[float] = []\n            for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs):\n                # In sampling_params, logprobs is set to 1, which should return 1,\n                # but in practice there are multiple. Take the log_prob corresponding to token_id\n                token_id = self.req_output[request_id].outputs[0].token_ids[i]\n                log_probs.append(x[token_id].logprob)\n            is_cancel = generation_handle not in done\n            self.cancel_event.pop(request_id, None)\n            self.req_output.pop(request_id, None)\n        return token_ids, log_probs, is_cancel\n\n    async def cancel(self):\n        async with self.lock:\n            self.paused = True\n            for request_id in self.cancel_event:\n                self.cancel_event[request_id].set()\n\n    async def resume(self):\n        async with self.lock:\n            self.paused = False\n\n    async def reset_prefix_cache(self):\n        async with self.lock:\n            await self.engine.reset_prefix_cache()\n\n\nclass FullyAsyncvLLMReplica(vLLMReplica):\n    def __init__(\n        self,\n        replica_rank: int,\n        config: RolloutConfig | RewardModelConfig,\n        model_config: HFModelConfig,\n        gpus_per_node: int = 8,\n        is_reward_model: bool = False,\n    ):\n        super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)\n        self.server_class = vLLMHttpServerForPartial\n\n    async def cancel(self):\n        \"\"\"Cancel each rollout server.\"\"\"\n        await asyncio.gather(*[server.cancel.remote() for server in self.servers])\n\n    async def resume(self):\n        \"\"\"Resume each rollout server.\"\"\"\n        await asyncio.gather(*[server.resume.remote() for server in self.servers])\n\n    async def reset_prefix_cache(self):\n        \"\"\"reset kv cache in each rollout server.\"\"\"\n        await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers])\n"
  },
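  {
    "path": "verl_distillation/recipe/fully_async_policy/cancel_pattern_example.py",
    "content": "# Illustrative sketch, not part of the original recipe (this file path is\n# hypothetical): a dependency-free asyncio toy of the cancellation pattern used\n# by generate_for_partial above -- race the generation task against a cancel\n# event with FIRST_COMPLETED, keep whatever partial output exists, and cancel\n# the task that is still pending.\nimport asyncio\n\n\nasync def fake_generate(results: dict, key: str):\n    for i in range(10):\n        await asyncio.sleep(0.1)  # stand-in for one decode step\n        results[key] = i  # partial output survives cancellation\n\n\nasync def main():\n    results: dict[str, int] = {}\n    cancel_event = asyncio.Event()\n    generation_handle = asyncio.create_task(fake_generate(results, \"req-0\"))\n    cancel_handle = asyncio.create_task(cancel_event.wait())\n\n    # fire the cancel mid-generation, like vLLMHttpServerForPartial.cancel()\n    asyncio.get_running_loop().call_later(0.35, cancel_event.set)\n\n    done, pending = await asyncio.wait(\n        [generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED\n    )\n    for task in pending:\n        task.cancel()\n\n    is_cancel = generation_handle not in done\n    print(f\"partial output: {results['req-0']}, cancelled: {is_cancel}\")\n\n\nasyncio.run(main())\n"
  },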
  {
    "path": "verl_distillation/recipe/genrm_remote/README.md",
    "content": "# Generative Reward Model\n\n## Scripts\n\n### Step 1: Launch a vLLM Server (Optional)\n\nDeploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service.\n\n```bash \nvllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo\n```\n\n### Step 2: Perform RL using GenRM\n\n```bash\nbash recipe/api-genrm/run_genrm_remote.sh\n```\n\nThe implementation works by passing a customized reward function (see `reward_function.py`)\n\nFor convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first.\n\n## Advanced: Customizing Your GenRM\n\nYou can use sglang server with data parallel for faster inference:\n\n```bash\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4\n```\n\nNote that you should modify the `BASE_URL` in `reward_function.py` to match your SGLang Server address.\n\nYou can also create your own customized GenRM by implementing a custom reward function. Here are some tips for customizing your own GenRM based on `reward_function.py`:\n\n- Design appropriate prompts for your GenRM\n- Convert GenRM responses into RL rewards\n- ...\n\nSince these aspects are highly flexible, we only provide a demo implementation. The actual design and implementation of GenRM is left to the user's discretion.\n"
  },
  {
    "path": "verl_distillation/recipe/genrm_remote/reward_function.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom time import sleep\n\nimport requests\n\nfrom verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed\n\nBASE_URL = \"http://localhost:30000\"\nAPI_KEY = \"EMPTY\"\nMAX_RETRIES = 3\nBASE_DELAY = 2\nMAX_WORKERS = 32\nMODEL_NAME = \"genrm-demo\"\nGENRM_PROMPT_TEMPLATE = \"\"\"\nThe following is a math problem and an AI solution:\n\n[Math Problem]\n\n{problem}\n\n[AI Solution]\n\n{solution}\n\nYour task is to review and critique the solution step by step, and output whether the AI solution is correct.\n\nPlease put your final answer (i.e., 'True' or 'False') in \\\\boxed{{}}.\n\"\"\".strip()\n\n\ndef get_response(problem, solution_str, ground_truth):\n    prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str)\n    messages = [{\"role\": \"user\", \"content\": prompt}]\n    for attempt in range(MAX_RETRIES):\n        try:\n            headers = {\"Content-Type\": \"application/json\"}\n            chat_url = f\"{BASE_URL}/v1/chat/completions\"\n            data = {\"model\": MODEL_NAME, \"messages\": messages}\n            output = requests.post(chat_url, headers=headers, json=data, timeout=30)\n            response = output.json()[\"choices\"][0][\"message\"][\"content\"]\n            return response\n        except Exception as e:\n            if attempt < MAX_RETRIES - 1:\n                print(\"Exception: \", repr(e))\n                delay = BASE_DELAY * (2**attempt)\n                print(f\"Retrying in {delay} seconds...\")\n                sleep(delay)\n            else:\n                print(f\"Failed after {MAX_RETRIES} attempts. 
Error: {e}\")\n\n    raise ConnectionRefusedError(f\"Failed to run the model for {prompt}!\")\n\n\ndef compute_reward(response):\n    reward_score = 0.0\n    try:\n        boxed_result = last_boxed_only_string(response)\n        if boxed_result is not None:\n            result = remove_boxed(boxed_result)\n            reward_score = float(result == \"True\")\n    except Exception as e:\n        print(e)\n    return reward_score\n\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info):\n    split = extra_info[\"split\"]\n    from verl.utils.reward_score import default_compute_score\n\n    func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info)\n\n    if split == \"test\":\n        return func_rm_score\n    else:\n        problem = extra_info[\"question\"]\n        response = get_response(problem, solution_str, ground_truth)\n        if response is not None:\n            reward_score = compute_reward(response)\n        else:\n            reward_score = 0.0\n\n        return reward_score\n\n\ndef compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos):\n    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:\n        futures = []\n        for data_source, solution_str, ground_truth, extra_info in zip(\n            data_sources, solution_strs, ground_truths, extra_infos, strict=True\n        ):\n            future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info)\n            futures.append(future)\n\n        results = [future.result() for future in futures]\n\n    return results\n"
  },
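  {
    "path": "verl_distillation/recipe/genrm_remote/example_reward_check.py",
    "content": "# Illustrative sketch, not part of the original recipe (this file path is\n# hypothetical): a quick offline check of how compute_reward in\n# reward_function.py turns a GenRM critique into a scalar reward. Assumes this\n# file sits next to reward_function.py and that verl is installed, since\n# reward_function.py imports verl's math reward utilities.\nfrom reward_function import compute_reward\n\n# The GenRM is prompted to end its critique with \\boxed{True} or \\boxed{False};\n# compute_reward maps \\boxed{True} to 1.0 and anything else to 0.0.\nassert compute_reward(\"All steps check out. \\\\boxed{True}\") == 1.0\nassert compute_reward(\"Step 3 is wrong. \\\\boxed{False}\") == 0.0\nassert compute_reward(\"no boxed verdict at all\") == 0.0\nprint(\"compute_reward sanity checks passed\")\n"
  },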
  {
    "path": "verl_distillation/recipe/genrm_remote/run_genrm_remote.sh",
    "content": "# vllm server\n# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo\n\n# sglang server\n# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4\n\nset -x\n\nCUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=${HOME}/data/gsm8k/train.parquet \\\n    data.val_files=${HOME}/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=8 \\\n    algorithm.use_kl_in_reward=False \\\n    reward_model.reward_manager=batch \\\n    custom_reward_function.path=recipe/genrm_remote/reward_function.py \\\n    custom_reward_function.name=compute_score_batch \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_func_rm_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_3b_gen_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.val_before_train=True \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode='disable'\n"
  },
  {
    "path": "verl_distillation/recipe/gspo/test_gspo_3b_math.sh",
    "content": "#!/usr/bin/env bash\n#SBATCH --job-name=rl-gspo-3B\n#SBATCH --partition=main\n#SBATCH --nodes=1                # Number of nodes\n#SBATCH --ntasks-per-node=1      # One task per node\n#SBATCH --cpus-per-task=128      # cpu-cores per task\n#SBATCH --gres=gpu:8\n#SBATCH --mem=0\n#SBATCH --exclusive\n#SBATCH --time=500:00:00\n#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out\n#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err\n\nset -xeuo pipefail\n\n# activate the venv\necho \"Activating distill environment...\"\neval \"$(conda shell.bash hook)\"\nconda deactivate\nconda activate distill\n\n# can make training faster, depends on your infrastructure\nexport NCCL_IBEXT_DISABLE=1\nexport NCCL_NVLS_ENABLE=1\nexport NCCL_IB_HCA=mlx5\nexport UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1\n\n# Set how many GPUs we actually have on this node.\nexport GPUS_PER_NODE=8\n\nNNODES=${SLURM_JOB_NUM_NODES}\nexport NNODES\n\nexport VLLM_ATTENTION_BACKEND=FLASH_ATTN\nexport RAY_LOGGING_LEVEL=DEBUG\nexport HYDRA_FULL_ERROR=1\nexport WANDB_API_KEY=... # your wandb API key\n\necho \"Using $NNODES nodes for training...\"\n\n# ------------------------------------- Setup xp params ---------------------------------------\nproject_name='RL-GSPO'\n\nadv_estimator=grpo\nloss_mode=gspo\nloss_agg_mode=\"seq-mean-token-mean\"\nMODEL_PATH=Qwen/Qwen2.5-3B-Instruct\noffload=false # it's a small model, offloading will just slow-down training\nrollout_engine=vllm\nrollout_mode=sync # can be async to speedup large scale xps\ngpu_memory_utilization=0.8\nreward_manager=dapo\nadv_estimator=grpo\nshuffle_dataset=true\nfirst_time_dataset_prep=true # prepare dataset\n\ntest_freq=10\nsave_freq=10\ntotal_epochs=10\ntotal_training_steps=500\nval_before_train=false\n\nuse_kl_in_reward=false\nkl_coef=0.0\nuse_kl_loss=false\nkl_loss_coef=0.0\n\nclip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1\nclip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1\ntrain_batch_size=512\nppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 
5.1\nppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory\nn_resp_per_prompt=16\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\n# dapo reward manager params\nenable_overlong_buffer=false # true\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Paths and namings\nSFT_MODEL=$(basename $MODEL_PATH)\nexp_name=\"${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL\"\nCKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name}\n\n# Sampling params at rollouts\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=true\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=true\ngen_tp=1\nentropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training.\n\n# ------------------------------------- train/val data preparation ---------------------------------------\nif [ \"$first_time_dataset_prep\" = true ]; then\n    echo \"Preprocessing GSM8K dataset...\"\n    python examples/data_preprocess/gsm8k.py --local_save_dir /data/gsm8k/\nfi\n\ngsm8k_train_path=/data/gsm8k/train.parquet\ngsm8k_test_path=/data/gsm8k/test.parquet\n\n# set the paths\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    data.train_files=\"${train_files}\" \\\n    data.val_files=\"${test_files}\" \\\n    data.shuffle=$shuffle_dataset \\\n    data.prompt_key=prompt \\\n    data.truncation='error' \\\n    data.filter_overlong_prompts=true \\\n    data.train_batch_size=${train_batch_size} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.model.use_remove_padding=true \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.name=${rollout_engine} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=true \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \\\n    
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=true \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=true \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \\\n    reward_model.reward_manager=${reward_manager} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${GPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=${val_before_train} \\\n    trainer.test_freq=${test_freq} \\\n    trainer.save_freq=${save_freq} \\\n    trainer.total_epochs=${total_epochs} \\\n    trainer.total_training_steps=${total_training_steps} \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=2 \\\n    $@\n"
  },
  {
    "path": "verl_distillation/recipe/gspo/test_gspo_3b_math_slurm.sh",
    "content": "#!/usr/bin/env bash\n#SBATCH --job-name=rl-gspo-3B\n#SBATCH --partition=main\n#SBATCH --nodes=1                # Number of nodes\n#SBATCH --ntasks-per-node=1      # One task per node\n#SBATCH --cpus-per-task=128      # cpu-cores per task\n#SBATCH --gres=gpu:8\n#SBATCH --mem=0\n#SBATCH --exclusive\n#SBATCH --time=500:00:00\n#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out\n#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err\n\nset -xeuo pipefail\n\n# activate the venv\necho \"Activating distill environment...\"\neval \"$(conda shell.bash hook)\"\nconda deactivate\nconda activate distill\n\n# can make training faster, depends on your infrastructure\nexport NCCL_IBEXT_DISABLE=1\nexport NCCL_NVLS_ENABLE=1\nexport NCCL_IB_HCA=mlx5\nexport UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1\n\n# Set how many GPUs we actually have on this node.\nexport GPUS_PER_NODE=8\n\nNNODES=${SLURM_JOB_NUM_NODES}\nexport NNODES\n\nexport VLLM_ATTENTION_BACKEND=FLASH_ATTN\nexport RAY_memory_monitor_refresh_ms=0\nexport RAY_LOGGING_LEVEL=DEBUG\nexport HYDRA_FULL_ERROR=1\nexport WANDB_API_KEY=... # your wandb API key\n\n# Let Ray know how many nodes to expect\nexport RAY_NUM_NODES=$NNODES\n\necho \"Using $NNODES nodes for training...\"\n\n# ------------------------------------- Setup xp params ---------------------------------------\nproject_name='RL-GSPO'\n\nadv_estimator=grpo\nloss_mode=gspo\nloss_agg_mode=\"seq-mean-token-mean\"\nMODEL_PATH=Qwen/Qwen2.5-3B-Instruct\noffload=false # it's a small model, offloading will just slow-down training\nrollout_engine=vllm\nrollout_mode=sync # can be async to speedup large scale xps\ngpu_memory_utilization=0.8\nreward_manager=dapo\nadv_estimator=grpo\nshuffle_dataset=true\nfirst_time_dataset_prep=true # prepare dataset\n\ntest_freq=10\nsave_freq=10\ntotal_epochs=10\ntotal_training_steps=500\nval_before_train=false\n\nuse_kl_in_reward=false\nkl_coef=0.0\nuse_kl_loss=false\nkl_loss_coef=0.0\n\nclip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1\nclip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1\ntrain_batch_size=512\nppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 
5.1\nppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory\nn_resp_per_prompt=16\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\n# dapo reward manager params\nenable_overlong_buffer=false # true\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# Paths and namings\nSFT_MODEL=$(basename $MODEL_PATH)\nexp_name=\"${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL\"\nCKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name}\n\n# Sampling params at rollouts\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=1\nuse_dynamic_bsz=true\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=true\ngen_tp=1\nentropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training.\n\n# ------------------------------------- train/val data preparation ---------------------------------------\nif [ \"$first_time_dataset_prep\" = true ]; then\n    echo \"Preprocessing GSM8K dataset...\"\n    python examples/data_preprocess/gsm8k.py --local_save_dir /data/gsm8k/\nfi\n\ngsm8k_train_path=/data/gsm8k/train.parquet\ngsm8k_test_path=/data/gsm8k/test.parquet\n\n# set the paths\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    data.train_files=\"${train_files}\" \\\n    data.val_files=\"${test_files}\" \\\n    data.shuffle=$shuffle_dataset \\\n    data.prompt_key=prompt \\\n    data.truncation='error' \\\n    data.filter_overlong_prompts=true \\\n    data.train_batch_size=${train_batch_size} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.model.use_remove_padding=true \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.name=${rollout_engine} \\\n    actor_rollout_ref.rollout.mode=${rollout_mode} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=true \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \\\n    
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=true \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=true \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \\\n    reward_model.reward_manager=${reward_manager} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${GPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=${val_before_train} \\\n    trainer.test_freq=${test_freq} \\\n    trainer.save_freq=${save_freq} \\\n    trainer.total_epochs=${total_epochs} \\\n    trainer.total_training_steps=${total_training_steps} \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=2 \\\n    $@\n"
  },
  {
    "path": "verl_distillation/recipe/gspo/test_gspo_qwen30b_a3b_ep.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport NCCL_DEBUG=WARN\n# export VERL_LOGGING_LEVEL=DEBUG\n\nproject_name='DAPO'\nexp_name='GSPO-Qwen3-30B-A3B-Base-MATH'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=3e-4\nclip_ratio_high=4e-4\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=gspo\n\ntrain_prompt_bsz=256\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\n# RAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# MODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\n# CKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\n# TRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\n# TEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\nMODEL_PATH=$HDFS_ROOT/model/Qwen3-30B-A3B-Base\nCKPTS_DIR=$DATA_ROOT/checkpoint/${project_name}/${exp_name}\nTRAIN_FILE=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k/data/dapo-math-17k.parquet\naime24_test_path=$DATA_ROOT/dataset/aime-2024.parquet\n\nTEST_FILE=\"['$aime24_test_path']\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\n\n# gen\nrollout_name=vllm # vllm or sglang\ngen_tp=1\ngen_dp=4\ngen_ep=4\n\n# train\ntrain_tp=4\ntrain_pp=1\nEP=4\nETP=1\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.return_raw_chat=True \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    
actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=${rollout_name} \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.calculate_log_probs=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \\\n    actor_rollout_ref.rollout.expert_parallel_size=${gen_ep} \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.use_mbridge=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    
trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}-tp${gen_tp}-ep${gen_ep}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=30 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/infigui-g1/README.md",
    "content": "# Recipe for InfiGUI-G1\n\nThis directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731).\n\nThis work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks.\n\nWe provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default.\n\n## Environment Setup\n\nPlease follow the main environment setup guide for `verl`.\n\nThe provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`\n\n## Data Preparation\n\nBefore starting the training, you need to download the example dataset. This dataset is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact), containing only grounding tasks and excluding easy samples.\n\nThe data is hosted on the Hugging Face. You can download it using the `huggingface-cli`:\n\n```bash\nhuggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered\n```\n\nThis command will download the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts.\n\n## Training\n\nWe provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`.\n\n-   **Train the 3B model:**\n\n    ```bash\n    bash recipe/infigui-g1/run_3b.sh\n    ```\n\n-   **Train the 7B model:**\n\n    ```bash\n    bash recipe/infigui-g1/run_7b.sh\n    ```\n\n## Using Custom Data\n\nIf you wish to train on your own dataset, please format your data to match the structure of the example files located in `data/omniact_grounding_filtered`.\n\nOnce your data is ready, you need to update the data path arguments in the training script.\n\nIn `run_3b.sh` or `run_7b.sh`, modify the following lines:\n\n```bash\n    data.train_files=./path/to/your/train_data.parquet \\\n    data.val_files=./path/to/your/val_data.parquet \\\n```\n\nReplace the paths with the location of your custom data files.\n"
  },
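  {
    "path": "verl_distillation/recipe/infigui-g1/inspect_data_example.py",
    "content": "# Illustrative sketch, not part of the original recipe (this file path is\n# hypothetical): print the structure of the downloaded parquet files so custom\n# datasets can be formatted to match, as the README suggests. The path assumes\n# the default download location and that pandas with parquet support is\n# available.\nimport pandas as pd\n\ndf = pd.read_parquet(\"data/omniact_grounding_filtered/omniact_filtered_train.parquet\")\nprint(df.columns.tolist())  # the column layout custom data should reproduce\nprint(df.head(2))  # a couple of complete rows for reference\n"
  },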
  {
    "path": "verl_distillation/recipe/infigui-g1/reward_fn.py",
    "content": "# Copyright 2025 Individual Contributor: InfiX.ai\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport math\nimport re\nfrom itertools import combinations\n\nFMT_RATIO = 1.0\nACC_RATIO = 1.0\n\n\n# ============================================================================\n# Utility Functions\n# ============================================================================\n\n\ndef extract_think_format(predict_str: str) -> None | dict[str, str]:\n    \"\"\"\n    Check if the predicted string meets format requirements and extract thinking and answer parts.\n\n    Args:\n        predict_str: The predicted string\n\n    Returns:\n        If format requirements are met, returns a dictionary containing thinking and answer parts;\n        otherwise returns None\n    \"\"\"\n    if not predict_str or not isinstance(predict_str, str):\n        return None\n\n    # Check if <think> is at the beginning\n    if not predict_str.startswith(\"<think>\"):\n        return None\n\n    # Check if there is <think>...</think> format\n    pattern = r\"<think>(.*?)</think>\"\n    think_match = re.search(pattern, predict_str, re.DOTALL)\n    if not think_match:\n        return None\n\n    if predict_str.count(\"<think>\") != 1 or predict_str.count(\"</think>\") != 1:\n        return None\n\n    # Extract thinking content\n    think_content = think_match.group(1).strip()\n    if not think_content:\n        return None\n\n    # Get content after </think>\n    think_end_pos = predict_str.find(\"</think>\") + len(\"</think>\")\n    post_think_content = predict_str[think_end_pos:].strip()\n\n    # Check if there is non-empty content after </think>\n    if not post_think_content:\n        return None\n\n    return {\"think\": think_content, \"answer\": post_think_content}\n\n\ndef extract_and_parse_json(input_string, wrapper):\n    \"\"\"\n    Try to extract and parse JSON from a string.\n\n    Args:\n        input_string: The input string\n        wrapper: JSON wrapper symbols, can be '{}' or '[]'\n\n    Returns:\n        Parsed JSON object, returns None if parsing fails\n    \"\"\"\n    if len(wrapper) != 2:\n        raise ValueError(\"Wrapper must be exactly two characters long\")\n\n    start_char, end_char = wrapper\n    start_index = input_string.find(start_char)\n\n    if start_index == -1:\n        return None\n\n    # Find the matching end character by balancing brackets/braces\n    balance = 1\n    end_index = -1\n    for i in range(start_index + 1, len(input_string)):\n        if input_string[i] == start_char:\n            balance += 1\n        elif input_string[i] == end_char:\n            balance -= 1\n\n        if balance == 0:\n            end_index = i\n            break\n\n    if end_index == -1:\n        return None\n\n    json_string = input_string[start_index : end_index + 1]\n\n    try:\n        return json.loads(json_string)\n    except json.JSONDecodeError:\n        return None\n\n\n# 
============================================================================\n# AER Reward Functions\n# ============================================================================\n\n\ndef _extract_verifiable_answer(answer):\n    \"\"\"\n    Extract and verify the format of the point list from the answer string.\n\n    A valid format is a JSON list of dictionaries, where each dictionary\n    has a \"point_2d\" key with a list of two numbers as the value.\n\n    Args:\n        answer: The answer string to extract points from\n\n    Returns:\n        List of valid points or None if format is invalid\n    \"\"\"\n    points = extract_and_parse_json(answer, \"[]\")\n    if points is None or not isinstance(points, list):\n        return None\n\n    # Verify each point in the list\n    for point in points:\n        if isinstance(point, dict) and \"point_2d\" in point:\n            point_2d = point[\"point_2d\"]\n            if isinstance(point_2d, list) and len(point_2d) == 2:\n                continue\n\n        # If any point is malformed, the whole answer is invalid\n        return None\n\n    return points\n\n\ndef _format_reward(answer):\n    \"\"\"\n    Calculate the format reward for 'point' type data.\n\n    This function is now primarily used as a check to see if the format is valid.\n\n    Args:\n        answer: The answer string to validate\n\n    Returns:\n        Tuple of (reward, is_collinear) where reward is 1.0 for valid format, 0.0 otherwise\n    \"\"\"\n    points = _extract_verifiable_answer(answer)\n    if points is None:\n        return 0.0, 0\n\n    points_2d = [item[\"point_2d\"] for item in points]\n    if _check_collinear(points_2d):\n        return 0.0, 1\n\n    return 1.0, 0\n\n\ndef _check_collinear(points_2d):\n    \"\"\"\n    Check if 3 or more points in the list are collinear on any straight line.\n\n    This uses the cross-product method to avoid division and handle all line types.\n\n    Args:\n        points_2d: A list of [x, y] coordinates\n\n    Returns:\n        True if 3 or more points are collinear, False otherwise\n    \"\"\"\n    if len(points_2d) < 3:\n        return False\n\n    # Iterate through all unique combinations of 3 points\n    for p1, p2, p3 in combinations(points_2d, 3):\n        x1, y1 = p1\n        x2, y2 = p2\n        x3, y3 = p3\n\n        # Check for collinearity using the cross-product method.\n        # If (y2 - y1) * (x3 - x1) == (y3 - y1) * (x2 - x1), the points are collinear.\n        # This is equivalent to checking if the area of the triangle formed by the points is 0.\n        if math.isclose((y2 - y1) * (x3 - x1), (y3 - y1) * (x2 - x1)):\n            return True\n\n    return False\n\n\ndef _accuracy_reward(answer, ground_truth):\n    \"\"\"\n    Calculate the accuracy reward based on the symmetric zero-centered formula.\n\n    The reward is in the range [-1, 1].\n\n    Args:\n        answer: The answer string containing predicted points\n        ground_truth: Ground truth bounding box dictionary\n\n    Returns:\n        Tuple containing:\n        - accuracy (float): The calculated reward\n        - extracted_answer (str): The JSON string of the predicted points\n        - num_pred (int): The number of predicted points\n        - first_correct (int): 1 if the first predicted point is correct, 0 otherwise\n    \"\"\"\n    pred_points = _extract_verifiable_answer(answer)\n\n    # If no valid points are extracted, this is considered a format error, return -1 reward\n    if pred_points is None:\n        return -1.0, \"\", 0, 0\n\n  
  num_pred = len(pred_points)\n    extracted_answer = json.dumps(pred_points)\n\n    if num_pred == 0:\n        return -1.0, extracted_answer, 0, 0\n\n    # Find the rank 'k' of the first correct point\n    first_correct_rank = -1\n    for i, item in enumerate(pred_points):\n        point_2d = item[\"point_2d\"]\n        if (\n            ground_truth[\"x1\"] <= point_2d[0] <= ground_truth[\"x2\"]\n            and ground_truth[\"y1\"] <= point_2d[1] <= ground_truth[\"y2\"]\n        ):\n            first_correct_rank = i + 1  # 1-based index\n            break\n\n    # Calculate reward based on the zero-centered symmetric formula\n    accuracy = 0.0\n    if first_correct_rank != -1:\n        # Case a: Correct point found (Positive reward space)\n        k = first_correct_rank\n        accuracy = 1.0 / math.sqrt(num_pred * k)\n    else:\n        # Case b: No correct point found (Negative reward space)\n        accuracy = -1.0 / num_pred\n\n    first_correct = 1 if first_correct_rank == 1 else 0\n\n    return accuracy, extracted_answer, num_pred, first_correct\n\n\ndef calculate_point_reward(solution_str, ground_truth, extra_info=None, fmt_ratio=1.0, acc_ratio=1.0, **kwargs):\n    \"\"\"\n    Calculate the final reward for 'point' type data.\n\n    Implements the full logic including format checks, collinearity checks,\n    and the zero-centered symmetric reward calculation.\n\n    Args:\n        solution_str: The solution string from the model\n        ground_truth: Ground truth data\n        extra_info: Extra information dictionary (may be None)\n        fmt_ratio: Format reward ratio\n        acc_ratio: Accuracy reward ratio\n        **kwargs: Additional keyword arguments\n\n    Returns:\n        Dictionary containing detailed reward information\n    \"\"\"\n    # extra_info defaults to None, so guard the lookup of the 'no_think' flag\n    if extra_info is not None and extra_info.get(\"no_think\", False):\n        answer = solution_str\n    else:\n        solution_dict = extract_think_format(solution_str)\n        # If the overall 'think'/'answer' format is wrong, return score of -1\n        if solution_dict is None:\n            return {\n                \"score\": -1.0,\n                \"format\": 0.0,\n                \"accuracy\": -1.0,\n                \"pred\": \"\",\n                \"num_pred\": 0,\n                \"has_correct\": 0,\n                \"first_correct\": 0,\n                \"only_correct\": 0,\n                \"is_collinear\": 0,\n            }\n\n        answer = solution_dict[\"answer\"]\n\n    # Reuse _format_reward to check the format of the 'answer' part\n    # If it's invalid, return score of -1\n    format_reward, is_collinear = _format_reward(answer)\n    if format_reward == 0.0:\n        return {\n            \"score\": -1.0,\n            \"format\": 0.0,\n            \"accuracy\": -1.0,\n            \"pred\": \"\",\n            \"num_pred\": 0,\n            \"has_correct\": 0,\n            \"first_correct\": 0,\n            \"only_correct\": 0,\n            \"is_collinear\": is_collinear,\n        }\n\n    # If format is OK, calculate the accuracy reward\n    accuracy_reward, extracted_answer, num_pred, first_correct = _accuracy_reward(answer, ground_truth)\n\n    return {\n        \"score\": fmt_ratio * format_reward + acc_ratio * accuracy_reward,\n        \"format\": format_reward,\n        \"accuracy\": accuracy_reward,\n        \"pred\": extracted_answer,\n        \"num_pred\": num_pred,\n        \"has_correct\": 1 if accuracy_reward > 0 else 0,\n        \"first_correct\": first_correct,\n        \"only_correct\": 1 if num_pred == 1 and accuracy_reward > 0 else 0,\n        \"is_collinear\": 0,\n    }\n\n\n# ============================================================================\n# AER Reward Handler Registry\n# ============================================================================\n\n# Dictionary to map data_source to the respective reward calculation function\nAER_REWARD_HANDLERS = {\n    \"point\": calculate_point_reward,\n}\n\n\ndef aer_gui_reward_function(data_source, solution_str, ground_truth, extra_info=None, **kwargs):\n    \"\"\"\n    Main reward function dispatcher for the Adaptive Exploration Reward (AER) system.\n\n    Delegates reward calculation to specific functions based on the data_source using a dictionary lookup.\n\n    Args:\n        data_source: The source or type of the data (e.g., \"point\", \"bbox\")\n        solution_str: The solution string generated by the model\n        ground_truth: The ground truth data\n        extra_info: Any extra information passed along (optional)\n        **kwargs: Additional keyword arguments that might be passed from the PPO trainer config\n\n    Returns:\n        Dictionary containing detailed reward information with keys:\n        - score: The final calculated reward score\n        - format: Format validation score\n        - accuracy: Accuracy score\n        - pred: Extracted prediction string\n        - num_pred: Number of predictions\n        - has_correct: Whether any correct prediction exists\n        - first_correct: Whether the first prediction is correct\n        - only_correct: Whether exactly one prediction was made and it is correct\n        - is_collinear: Whether points are collinear (for point type)\n    \"\"\"\n    handler = AER_REWARD_HANDLERS.get(data_source, None)\n\n    if handler:\n        try:\n            return handler(\n                solution_str, ground_truth, extra_info=extra_info, fmt_ratio=FMT_RATIO, acc_ratio=ACC_RATIO, **kwargs\n            )\n        except Exception as e:\n            logging.exception(\n                f\"Error executing reward handler for data_source '{data_source}': {e}\",\n            )\n            return {\n                \"score\": -1.0,\n                \"format\": 0.0,\n                \"accuracy\": -1.0,\n                \"pred\": \"\",\n                \"num_pred\": 0,\n                \"has_correct\": 0,\n                \"first_correct\": 0,\n                \"only_correct\": 0,\n                \"is_collinear\": 0,\n            }  # Return a default penalty score on error\n    else:\n        raise ValueError(f\"Unknown data_source: '{data_source}'. No specific reward handler defined.\")\n"
  },
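  {
    "path": "verl_distillation/recipe/infigui-g1/aer_reward_example.py",
    "content": "# Illustrative sketch, not part of the original recipe (this file path is\n# hypothetical): exercises the AER reward from reward_fn.py on toy inputs,\n# assuming this file sits next to reward_fn.py. Ground truth uses the\n# x1/y1/x2/y2 box keys the handler reads, answers use the point_2d JSON list\n# format it parses, and extra_info={} keeps the <think> format check enabled.\nfrom reward_fn import aer_gui_reward_function\n\ngt = {\"x1\": 100, \"y1\": 100, \"x2\": 200, \"y2\": 200}\n\n# One prediction, inside the box: accuracy = 1/sqrt(1*1) = 1.0, score = 1 + 1\nhit = '<think>locate the button</think>[{\"point_2d\": [150, 150]}]'\nprint(aer_gui_reward_function(\"point\", hit, gt, extra_info={})[\"score\"])  # 2.0\n\n# Two predictions, the first misses: accuracy = 1/sqrt(2*2) = 0.5, score = 1.5\npartial = '<think>two guesses</think>[{\"point_2d\": [10, 10]}, {\"point_2d\": [150, 150]}]'\nprint(aer_gui_reward_function(\"point\", partial, gt, extra_info={})[\"score\"])  # 1.5\n\n# Every prediction misses: accuracy = -1/2 = -0.5, score = 1 - 0.5\nmiss = '<think>unsure</think>[{\"point_2d\": [10, 10]}, {\"point_2d\": [20, 30]}]'\nprint(aer_gui_reward_function(\"point\", miss, gt, extra_info={})[\"score\"])  # 0.5\n"
  },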
  {
    "path": "verl_distillation/recipe/infigui-g1/run_3b.sh",
    "content": "#!/bin/bash\nset -x\nulimit -n 65535\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=rloo \\\n    data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \\\n    data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=7168 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \\\n    custom_reward_function.name=aer_gui_reward_function \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.model.enable_activation_offload=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=False \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.clip_ratio_high=0.4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name='infigui-g1' \\\n    trainer.experiment_name='3b' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=16 \\\n    trainer.test_freq=16 \\\n    trainer.total_epochs=6\n"
  },
  {
    "path": "verl_distillation/recipe/infigui-g1/run_7b.sh",
    "content": "#!/bin/bash\nset -x\nulimit -n 65535\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=rloo \\\n    data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \\\n    data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=7168 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \\\n    custom_reward_function.name=aer_gui_reward_function \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.model.enable_activation_offload=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=False \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.clip_ratio_high=0.4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name='infigui-g1' \\\n    trainer.experiment_name='7b' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=16 \\\n    trainer.test_freq=16 \\\n    trainer.total_epochs=6\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/chat_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRef: https://python.langchain.com/docs/how_to/custom_chat_model/\n\"\"\"\n\nimport asyncio\nimport json\nimport logging\nimport os\nimport uuid\nfrom typing import Any, Optional\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.language_models.base import LanguageModelInput\nfrom langchain_core.messages import (\n    AIMessage,\n    BaseMessage,\n    convert_to_openai_messages,\n)\nfrom langchain_core.messages.tool import InvalidToolCall, ToolCall\nfrom langchain_core.outputs import ChatGeneration, ChatResult\nfrom langchain_core.runnables import Runnable, RunnableConfig\nfrom langchain_core.tools import StructuredTool\nfrom langchain_core.utils.function_calling import convert_to_openai_tool\nfrom pydantic import Field\n\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager\nfrom verl.experimental.agent_loop.tool_parser import ToolParser\nfrom verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MaxTokenExceededError(Exception):\n    \"\"\"Indicate that history chat messages + tool message exceeds LLM max_tokens.\"\"\"\n\n    pass\n\n\nclass ChatModel(BaseChatModel):\n    model_name: str = Field(alias=\"model\")\n    \"\"\"The name of the model\"\"\"\n\n    client: AsyncLLMServerManager\n    \"\"\"AsyncLLM server manager\"\"\"\n\n    tokenizer: Any\n    \"\"\"Tokenizer for the model\"\"\"\n\n    max_tokens: int\n    \"\"\"Max tokens to generate\"\"\"\n\n    tool_parser: str = \"hermes\"\n    \"\"\"Tool parser for the model\"\"\"\n\n    max_parallel_calls: int = 1\n    \"\"\"Max parallel tool calls\"\"\"\n\n    temperature: float = 1.0\n    \"\"\"Temperature for sampling\"\"\"\n\n    top_p: float = 1.0\n    \"\"\"Top p for sampling\"\"\"\n\n    repetition_penalty: float = 1.0\n    \"\"\"Repetition penalty for sampling\"\"\"\n\n    def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:\n        \"\"\"Bind tools to the model.\n\n        Args:\n            tools: Sequence of tools to bind to the model.\n\n        Returns:\n            A Runnable that returns a message.\n        \"\"\"\n        formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]\n\n        # used to remove system prompt prefix when encoding tool response\n        system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)\n        kwargs[\"system_prompt\"] = system_prompt\n\n        return self.bind(tools=formatted_tools, **kwargs)\n\n    def with_structured_output(\n        self,\n        schema: dict | type,\n        *,\n        include_raw: bool = False,\n        **kwargs: Any,\n    ) -> Runnable[LanguageModelInput, dict | BaseChatModel]:\n        \"\"\"Ref: 
https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/\"\"\"\n        raise NotImplementedError\n\n    def _generate(\n        self,\n        messages: list[BaseMessage],\n        stop: Optional[list[str]] = None,\n        **kwargs: Any,\n    ) -> ChatResult:\n        raise NotImplementedError\n\n    async def _agenerate(\n        self,\n        messages: list[BaseMessage],\n        stop: Optional[list[str]] = None,\n        **kwargs: Any,\n    ) -> ChatResult:\n        \"\"\"Asynchronously generate chat completion message.\n\n        Args:\n            messages (list[BaseMessage]): List of messages.\n            stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the\n                first occurrence of any of these substrings. Defaults to None.\n\n        Returns:\n            ChatResult: Chat result.\n        \"\"\"\n        request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)\n\n        sampling_params = {\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"repetition_penalty\": self.repetition_penalty,\n        }\n        if \"sampling_params\" in kwargs:\n            sampling_params.update(kwargs[\"sampling_params\"])\n\n        output = await self.client.generate(\n            request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n        )\n\n        message = await self._postprocess(request_id, prompt_ids, response_mask, output.token_ids, **kwargs)\n        generation = ChatGeneration(message=message)\n        return ChatResult(generations=[generation])\n\n    @property\n    def _llm_type(self) -> str:\n        \"\"\"Get the type of language model used by this chat model.\"\"\"\n        return self.model_name\n\n    async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:\n        \"\"\"Preprocess messages for chat completion.\n\n        To ensure strong consistency with the policy model, the AsyncLLM server generates responses token-in\n        token-out instead of from a messages list.\n\n        But agent frameworks use a messages list to represent chat history. To bridge the gap, we store the\n        trajectory (prompt_ids, response_mask) in the latest AIMessage.response_metadata.\n\n        1. Encode ToolMessage to token ids.\n        2. Retrieve trajectory (prompt_ids, response_mask) from the latest AIMessage.response_metadata.\n        3. 
Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.\n\n        Ref: https://python.langchain.com/docs/concepts/chat_history/\n\n        Args:\n            messages (list[BaseMessage]): List of messages.\n\n        Returns:\n            tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.\n        \"\"\"\n        # messages: [system], human, ai, human|tool, ai, human|tool, ...\n        assert messages[-1].type in [\"human\", \"tool\"], (\n            f\"Last message must be human or tool, but got {messages[-1].type}\"\n        )\n        loop = asyncio.get_running_loop()\n\n        # Case 1: initial chat completion: [system], human\n        if messages[-1].type == \"human\" and (len(messages) == 1 or messages[-2].type != \"ai\"):\n            prompt_ids = await loop.run_in_executor(\n                None,\n                lambda: self.tokenizer.apply_chat_template(\n                    convert_to_openai_messages(messages),\n                    tools=kwargs.get(\"tools\"),\n                    add_generation_prompt=True,\n                    tokenize=True,\n                ),\n            )\n            return str(uuid.uuid4()), prompt_ids, []\n\n        # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...\n        for i in range(len(messages) - 1, -1, -1):\n            if messages[i].type == \"ai\":\n                break\n        assert \"prompt_ids\" in messages[i].response_metadata, \"Last message must have prompt_ids in response_metadata\"\n        assert \"response_mask\" in messages[i].response_metadata, (\n            \"Last message must have response_mask in response_metadata\"\n        )\n\n        # encode tool response\n        tool_responses = convert_to_openai_messages(messages[i + 1 :])\n        if self.tool_parser == \"hermes\":\n            tool_response_ids = await loop.run_in_executor(\n                None,\n                lambda messages=tool_responses: self.tokenizer.apply_chat_template(\n                    messages, add_generation_prompt=True, tokenize=True\n                ),\n            )\n            tool_response_ids = tool_response_ids[len(kwargs[\"system_prompt\"]) :]\n        elif self.tool_parser == \"gpt-oss\":\n            # Format tool responses manually\n            # since gpt-oss chat template requires tool call messages to parse tool response messages\n            # we need to format the tool response messages manually\n            tool_response_texts = []\n            for tool_msg in tool_responses:\n                if tool_msg[\"role\"] == \"tool\":\n                    # Use tool message's name if available (for multiple tool calls)\n                    actual_tool_name = tool_msg.get(\"name\", \"unknown\")\n                    if actual_tool_name == \"unknown\":\n                        logger.error(f\"actual_tool_name: {actual_tool_name}\")\n                    formatted = format_gpt_oss_tool_response_manually(tool_msg[\"content\"], actual_tool_name)\n                    tool_response_texts.append(formatted)\n\n            # Tokenize the manually formatted tool responses\n            tool_response_text = \"\".join(tool_response_texts)\n            # need to add generation tokens for gpt-oss manually since add_generation_prompt is True\n            tool_response_text = add_generation_prompt_for_gpt_oss(tool_response_text)\n            logger.debug(f\"tool_response_text: {tool_response_text}\")\n\n            tool_response_ids = await loop.run_in_executor(\n    
            None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)\n            )\n        else:\n            raise ValueError(f\"Unsupported tool parser: {self.tool_parser}\")\n\n        # stop generation if response length exceeds max response length\n        if len(messages[i].response_metadata[\"response_mask\"]) + len(tool_response_ids) >= self.max_tokens:\n            raise MaxTokenExceededError(f\"Max response length {self.max_tokens} exceeded\")\n\n        # append tool response to prompt\n        request_id = messages[i].response_metadata.pop(\"request_id\")\n        prompt_ids = messages[i].response_metadata.pop(\"prompt_ids\")\n        response_mask = messages[i].response_metadata.pop(\"response_mask\")\n        prompt_ids += tool_response_ids\n        response_mask += [0] * len(tool_response_ids)\n\n        return request_id, prompt_ids, response_mask\n\n    async def _postprocess(\n        self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any\n    ) -> AIMessage:\n        \"\"\"Postprocess response_ids when chat completion is done.\n\n        1. Decode response_ids, parse tool calls to AIMessage.\n        2. Append response_ids to prompt_ids, and append 1 to response_mask.\n        3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.\n\n        Args:\n            request_id (str): Unique request id.\n            prompt_ids (list[int]): Input prompt token ids in this chat completion.\n            response_mask (list[int]): Response mask before this chat completion.\n            response_ids (list[int]): LLM generated token ids in this chat completion.\n\n        Returns:\n            AIMessage: Postprocessed message.\n        \"\"\"\n        prompt_ids += response_ids\n        response_mask += [1] * len(response_ids)\n\n        tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)\n        content, function_calls = await tool_parser.extract_tool_calls(response_ids)\n\n        tool_calls, invalid_tool_calls = [], []\n\n        for function_call in function_calls:\n            error = None\n            try:\n                args = json.loads(function_call.arguments)\n                if not isinstance(args, dict):\n                    error = f\"Tool arguments must be a JSON object, got {type(args).__name__}\"\n            except json.JSONDecodeError as e:\n                error = f\"Invalid JSON tool arguments: {e}\"\n\n            if error:\n                logger.warning(error)\n                invalid_tool_calls.append(\n                    InvalidToolCall(\n                        name=function_call.name,\n                        args=function_call.arguments,\n                        id=str(uuid.uuid4()),\n                        error=error,\n                    )\n                )\n            else:\n                tool_calls.append(\n                    ToolCall(\n                        name=function_call.name,\n                        args=args,\n                        id=str(uuid.uuid4()),\n                    )\n                )\n\n        message = AIMessage(\n            content=content,\n            tool_calls=tool_calls[: self.max_parallel_calls],\n            invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],\n            response_metadata={\n                \"request_id\": request_id,\n                \"prompt_ids\": prompt_ids,\n                \"response_mask\": response_mask,\n            },\n        )\n        
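# The trajectory (prompt_ids, response_mask) travels with the AIMessage so that the next\n        # _preprocess call can resume this request without re-encoding earlier turns.\n        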
return message\n\n\nclass TruncateStructuredTool(StructuredTool):\n    \"\"\"Structured tool with response truncation.\"\"\"\n\n    tool_response_truncate_side: str\n    \"\"\"truncate side of tool response: left, middle, right\"\"\"\n\n    max_tool_response_length: int\n    \"\"\"max length of tool response\"\"\"\n\n    async def _arun(\n        self,\n        *args: Any,\n        config: RunnableConfig,\n        **kwargs: Any,\n    ) -> Any:\n        tool_response = await super()._arun(*args, config=config, **kwargs)\n        tool_response = str(tool_response)\n\n        if len(tool_response) > self.max_tool_response_length:\n            if self.tool_response_truncate_side == \"left\":\n                tool_response = tool_response[: self.max_tool_response_length] + \"...(truncated)\"\n            elif self.tool_response_truncate_side == \"right\":\n                tool_response = \"(truncated)...\" + tool_response[-self.max_tool_response_length :]\n            else:\n                length = self.max_tool_response_length // 2\n                tool_response = tool_response[:length] + \"...(truncated)...\" + tool_response[-length:]\n\n        return tool_response\n\n\ndef convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:\n    \"\"\"Convert messages to AgentLoopOutput.\n\n    Args:\n        messages (List[BaseMessage]): List of messages, last message must be assistant\n            with response_metadata containing `prompt_ids` and `response_mask`.\n        response_length (int): Max length of response.\n\n    Returns:\n        AgentLoopOutput: agent loop output trajectory used for training.\n    \"\"\"\n    # skip last tool calls\n    for i in range(len(messages) - 1, -1, -1):\n        if messages[i].type != \"tool\":\n            break\n    last_message = messages[i]\n    assert last_message.type == \"ai\", f\"Last message must be assistant, but got {last_message.type}\"\n    assert \"prompt_ids\" in last_message.response_metadata, \"Last message must have prompt_ids in response_metadata\"\n    assert \"response_mask\" in last_message.response_metadata, (\n        \"Last message must have response_mask in response_metadata\"\n    )\n\n    num_turns = 0\n    for i in range(len(messages)):\n        if messages[i].type == \"system\":\n            continue\n        # parallel tool calls are in single turn\n        if i == 0 or messages[i].type != messages[i - 1].type:\n            num_turns += 1\n\n    prompt_ids = last_message.response_metadata[\"prompt_ids\"]\n    response_mask = last_message.response_metadata[\"response_mask\"]\n\n    response_ids = prompt_ids[-len(response_mask) :]\n    prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]\n\n    output = AgentLoopOutput(\n        prompt_ids=prompt_ids,\n        response_ids=response_ids[:response_length],\n        response_mask=response_mask[:response_length],\n        num_turns=num_turns,\n        metrics={},\n    )\n    return output\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/README.md",
    "content": "# MathExpression: LangGraph Agent Example\n\nMathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).\n\n### Define react agent with tool\nFirstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`:\n```python\n@tool(parse_docstring=True)\ndef calculate(a: int, b: int, operand: str) -> int:\n    \"\"\"\n    Compute the results using operand with two integers\n\n    Args:\n        a: the first operand\n        b: the second operand\n        operand: '+' or '-' or '*' or '@'\n    \"\"\"\n    assert operand in [\"+\", \"-\", \"*\", \"@\"], f\"unknown operand {operand}\"\n    if operand == \"@\":\n        return 3 * a - 2 * b\n    return eval(f\"{a} {operand} {b}\")\n```\n\nWithout calling `calculate`, ReactAgent is impossible to evaluate math expression correctly.\n\nThen, we can equip ReactAgent with `calculate` tool:\n```python\nclass MathExpressionReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer):\n        cls.tools = [calculate]\n        super().init_class(config, tokenizer)\n```\n\nWe can define agent loop config in yaml file, which will be used by AgentLoopWorker to dynamic load custom AgentLoop class.\n```yaml\n- name: math_expression\n  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n```\n\n### Prepare dataset\nNow, let's prepare two small datasets for training and evaluation:\n```bash\npython recipe/langgraph_agent/example/create_dataset.py\n```\n\n- Parameters: `--train_size` (default: 5000), `--test_size` (default: 500), `--output_dir` (default: `data/math_expression_tool`).\n- Example with custom sizes/output:\n```bash\npython recipe/langgraph_agent/example/create_dataset.py \\\n  --train_size 10000 \\\n  --test_size 1000 \\\n  --output_dir data/math_expression_tool\n```\n\nNote that dataset should contain a column `agent_name` with `math_expression`, which is used by `AgentLoopWorker` to select the\nagent loop class.\n| prompt | reward_model | agent_name |\n|--------------------------------------|------------------------------|-----------------|\n| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |\n| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |\n\nGenerated math expressions are like below, requiring model to call `calculate` multiple times to solve sub expressions.\n```\n(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6\n(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6\n9 @ 4\n((2 @ 2) @ (3 @ 3)) @ 4\n```\n\n### Training\nHook all these up and start training:\n```bash\nbash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log\n```\n\nTo submit on a SLURM cluster (the script contains SBATCH headers):\n```bash\nsbatch recipe/langgraph_agent/example/run_qwen2.5_3b.sh\n```\n\n**Note on `GPUS_PER_NODE` and `NNODES`:**\n\n- `GPUS_PER_NODE`: GPUs per node.  \n  Detection order: `SLURM_GPUS_ON_NODE` (if set) → `GPUS_PER_NODE` → `2`.\n- `NNODES`: number of nodes.  
\n### Training\nHook all these up and start training:\n```bash\nbash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log\n```\n\nTo submit on a SLURM cluster (the script contains SBATCH headers):\n```bash\nsbatch recipe/langgraph_agent/example/run_qwen2.5_3b.sh\n```\n\n**Note on `GPUS_PER_NODE` and `NNODES`:**\n\n- `GPUS_PER_NODE`: GPUs per node.  \n  Detection order: `SLURM_GPUS_ON_NODE` (if set) → `GPUS_PER_NODE` → `2`.\n- `NNODES`: number of nodes.  \n  Detection order: `SLURM_JOB_NUM_NODES` (if set) → `NNODES` → `1`.\n- Total GPUs = `GPUS_PER_NODE × NNODES` (must be ≥ 2).\n\nLocal override (no `SLURM_*` set):\n```bash\nGPUS_PER_NODE=4 NNODES=2 bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh\n```\n\nAfter a total of 39 steps, the model should achieve 100% accuracy on the test dataset:\n- val-aux/lighteval/MATH/reward: 1.0\n- val-aux/num_turns/mean: 9.0, the average number of messages, including assistant and tool turns.\n\nWe can see that the model solves the math expression by calling the `calculate` tool:\n```\nsystem\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.\nuser\nWe define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \\boxed{} only in the last turn. Now answer the following questions:\nCompute 5 @ (3 @ 6 @ 8)\nassistant\nSure, let's break this down step by step. We need to compute the expression \\(5 @ (3 @ 6 @ 8)\\). First, we need to compute the inner expression \\(3 @ 6 @ 8\\).\n\nLet's compute \\(3 @ 6\\) first using the external tool.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": 3, \"b\": 6, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n-3\n</tool_response>\nassistant\nWe have computed that \\(3 @ 6 = -3\\). Now we need to compute \\(-3 @ 8\\).\n\nLet's use the external tool to compute this.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": -3, \"b\": 8, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n-25\n</tool_response>\nassistant\nWe have computed that \\(-3 @ 8 = -25\\). Now we can put this result into the original expression to get the final result of \\(5 @ (3 @ 6 @ 8)\\), which is \\(5 @ (-25)\\).\n\nLet's compute \\(5 @ (-25)\\) using the external tool.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": 5, \"b\": -25, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n65\n</tool_response>\nassistant\nThe final result of \\(5 @ (3 @ 6 @ 8)\\) is \\(\\boxed{65}\\).\n```\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/agent.yaml",
    "content": "- name: math_expression\n  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/create_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCreate dataset for calculator\n\"\"\"\n\nimport argparse\nimport os\nimport random\n\nimport pandas as pd\n\n\ndef generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):\n    \"\"\"\n    Generate a random mathematical expression with operators +, -, *, /, and parentheses.\n\n    Args:\n        min_terms (int): Minimum number of terms in the expression.\n        max_terms (int): Maximum number of terms in the expression.\n        max_number (int): Maximum value for numbers in the expression.\n        allow_decimals (bool): Whether to allow decimal numbers.\n        max_depth (int): Maximum nesting depth for parentheses.\n\n    Returns:\n        str: A valid mathematical expression as a string.\n    \"\"\"\n\n    def generate_number():\n        \"\"\"Generate a random number (integer or float).\"\"\"\n        assert min_number < max_number\n        num = random.uniform(min_number, max_number)\n        if not allow_decimals:\n            num = int(num)\n        else:\n            num = round(num, random.randint(0, 2))  # Round to 0-2 decimal places\n        return str(num)\n\n    def generate_term(depth=0):\n        \"\"\"Generate a term (number or parenthesized expression).\"\"\"\n        if depth < max_depth and random.random() < 0.5:  # 50% chance to add parentheses\n            expr = generate_expression(depth + 1)\n            return f\"({expr})\"\n        else:\n            return generate_number()\n\n    def generate_expression(depth=0):\n        \"\"\"Generate a full expression with multiple terms and operators.\"\"\"\n        num_terms = random.randint(min_terms, max_terms)\n        terms = [generate_term(depth) for _ in range(num_terms)]\n\n        # Randomly select operators\n        operators = [\"+\", \"-\", \"*\", \"/\", \"@\"]\n        expr = terms[0]\n\n        for i in range(1, num_terms):\n            # Bias towards + and - for readability\n            op = random.choices(\n                operators,\n                weights=[0, 0, 0, 0, 1],  # + and - are 1.5x more likely than * and /\n            )[0]\n            expr += f\" {op} \" + terms[i]\n\n        return expr\n\n    return generate_expression()\n\n\ndef test():\n    # Example 1: Basic integer expression\n    print(generate_math_expression())\n    # Output: (3 + 7) * 2 - 5\n\n    # Example 2: Expression with decimals\n    print(generate_math_expression(allow_decimals=True))\n    # Output: 4.5 / (2.1 + 3.7) - 1.2\n\n    # Example 3: More complex expression with higher depth\n    print(generate_math_expression(max_terms=6, max_depth=3))\n    # Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4\n\n    # Example 4: Simplified expression\n    print(generate_math_expression(min_terms=2, max_terms=3, max_number=5))\n    # Output: 4 - 2 * 3\n\n\ndef calculate(expression: str) -> float:\n    \"\"\"\n    Evaluate a mathematical expression with +, 
 -, *, /, @, and parentheses.\n    The @ operator is defined as: a @ b = 3a - 2b.\n\n    Args:\n        expression (str): Input mathematical expression (e.g., \"3@2+4\").\n\n    Returns:\n        float: Result of the evaluated expression.\n\n    Raises:\n        ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero).\n    \"\"\"\n\n    def tokenize(s: str) -> list:\n        \"\"\"Convert the input string into tokens (numbers, operators, parentheses).\"\"\"\n        tokens = []\n        i = 0\n        while i < len(s):\n            if s[i].isdigit() or s[i] == \".\":\n                # Parse number (integer or float)\n                j = i\n                while j < len(s) and (s[j].isdigit() or s[j] == \".\"):\n                    j += 1\n                tokens.append(s[i:j])\n                i = j\n            elif s[i] in \"+-*/@()\":\n                # Operator or parenthesis\n                tokens.append(s[i])\n                i += 1\n            elif s[i].isspace():\n                # Skip whitespace\n                i += 1\n            else:\n                raise ValueError(f\"Invalid character: {s[i]}\")\n        return tokens\n\n    def infix_to_postfix(tokens: list) -> list:\n        \"\"\"Convert infix notation to postfix notation (Reverse Polish Notation).\"\"\"\n        output = []\n        stack = []\n        # \"@\" has the highest precedence, binding tighter than \"*\" and \"/\" (e.g., 2 + 3 @ 4 = 2 + (3*3 - 2*4) = 3)\n        precedence = {\"@\": 3, \"*\": 2, \"/\": 2, \"+\": 1, \"-\": 1}\n\n        for token in tokens:\n            if token.isdigit() or \".\" in token:\n                output.append(token)\n            elif token == \"(\":\n                stack.append(token)\n            elif token == \")\":\n                while stack and stack[-1] != \"(\":\n                    output.append(stack.pop())\n                if not stack or stack[-1] != \"(\":\n                    raise ValueError(\"Mismatched parentheses\")\n                stack.pop()  # Discard '('\n            else:  # Operator\n                while stack and stack[-1] != \"(\" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):\n                    output.append(stack.pop())\n                stack.append(token)\n\n        # Pop remaining operators\n        while stack:\n            if stack[-1] in \"()\":\n                raise ValueError(\"Mismatched parentheses\")\n            output.append(stack.pop())\n\n        return output\n\n    def evaluate_postfix(postfix: list) -> float:\n        \"\"\"Evaluate postfix expression using a stack.\"\"\"\n        stack = []\n        for token in postfix:\n            if token.isdigit() or \".\" in token:\n                stack.append(float(token))\n            else:\n                if len(stack) < 2:\n                    raise ValueError(\"Invalid expression\")\n                b = stack.pop()\n                a = stack.pop()\n                if token == \"+\":\n                    res = a + b\n                elif token == \"-\":\n                    res = a - b\n                elif token == \"*\":\n                    res = a * b\n                elif token == \"/\":\n                    if b == 0:\n                        raise ValueError(\"Division by zero\")\n                    res = a / b\n                elif token == \"@\":\n                    res = 3 * a - 2 * b  # Custom @ operator implementation\n                else:\n                    raise ValueError(f\"Invalid operator: {token}\")\n                stack.append(res)\n\n        if len(stack) != 1:\n            raise 
ValueError(\"Invalid expression\")\n        return stack[0]\n\n    # Remove spaces and validate parentheses\n    expression = expression.replace(\" \", \"\")\n    if expression.count(\"(\") != expression.count(\")\"):\n        raise ValueError(\"Mismatched parentheses\")\n\n    tokens = tokenize(expression)\n    postfix = infix_to_postfix(tokens)\n    result = evaluate_postfix(postfix)\n\n    # Convert integers to integer representation\n    if result.is_integer():\n        return int(result)\n    return result\n\n\ndef generate_data(total_num_dataset, split):\n    rl_dataset = {\n        \"prompt\": [],\n        \"data_source\": [],\n        \"ability\": [],\n        \"reward_model\": [],\n        \"extra_info\": [],\n        \"agent_name\": [],\n    }\n\n    for idx in range(total_num_dataset):\n        while True:\n            try:\n                expression: str = generate_math_expression(\n                    min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1\n                )\n\n                num_plus = expression.count(\"+\")\n                num_minus = expression.count(\"-\")\n                num_mul = expression.count(\"*\")\n                num_star = expression.count(\"@\")\n\n                answer = str(calculate(expression))\n                # answer = str(eval(expression))\n                break\n            except Exception as e:\n                print(e)\n                continue\n\n        num_tool_calls = num_plus + num_minus + num_mul + num_star\n\n        prompt = (\n            f\"We define a new math operator @, where you can only call an external tool to compute. \"\n            f\"Please put your final answer inside \\\\boxed{{}} only in the last turn. Now answer the \"\n            f\"following questions:\\nCompute {expression}\"\n        )\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_dataset[\"prompt\"].append(prompt_with_template)\n        rl_dataset[\"data_source\"].append(\"lighteval/MATH\")\n        rl_dataset[\"ability\"].append(\"math\")\n        rl_dataset[\"reward_model\"].append({\"style\": \"lighteval/MATH\", \"ground_truth\": answer})\n        rl_dataset[\"extra_info\"].append(\n            {\"index\": idx, \"expression\": expression, \"split\": split, \"expected_tool_calls\": num_tool_calls}\n        )\n        rl_dataset[\"agent_name\"].append(\"math_expression\")\n\n    rl_dataset = pd.DataFrame(data=rl_dataset)\n    return rl_dataset\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Math Expression Dataset Generator\")\n    parser.add_argument(\"--train_size\", type=int, default=5000, help=\"Number of training samples\")\n    parser.add_argument(\"--test_size\", type=int, default=500, help=\"Number of testing samples\")\n    parser.add_argument(\"--output_dir\", default=\"data/math_expression_tool\", help=\"Directory to save the dataset\")\n    args = parser.parse_args()\n\n    # print(calculate(\"3@2\"))          # Output: 5 (3*3 - 2*2)\n    # print(calculate(\"3@2+4\"))        # Output: 9 (5 + 4)\n    # print(calculate(\"3*(4@2)\"))      # Output: 24 (3 * 8)\n    # print(calculate(\"(5@3)*2\"))      # Output: 18 (9 * 2)\n\n    train_dataset = generate_data(total_num_dataset=args.train_size, split=\"train\")\n    test_dataset = generate_data(total_num_dataset=args.test_size, split=\"test\")\n\n    # Make sure the dataset directory exists\n    
os.makedirs(args.output_dir, exist_ok=True)\n\n    # Save the datasets to parquet files\n    train_dataset.to_parquet(os.path.join(args.output_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(args.output_dir, \"test.parquet\"))\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/math_expression.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom langchain_core.tools import tool\n\nfrom recipe.langgraph_agent.react_agent_loop import ReactAgentLoop\n\n\n@tool(parse_docstring=True)\ndef calculate(a: int, b: int, operand: str) -> int:\n    \"\"\"\n    Compute the results using operand with two integers\n\n    Args:\n        a: the first operand\n        b: the second operand\n        operand: '+' or '-' or '*' or '@'\n    \"\"\"\n    assert operand in [\"+\", \"-\", \"*\", \"@\"], f\"unknown operand {operand}\"\n    if operand == \"@\":\n        return 3 * a - 2 * b\n    return eval(f\"{a} {operand} {b}\")\n\n\nclass MathExpressionReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        cls.tools = [calculate]\n        super().init_class(config, tokenizer)\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh",
    "content": "#!/usr/bin/env bash\n#SBATCH --job-name=rl-langgraph-3B\n#SBATCH --partition=main\n#SBATCH --nodes=1\n#SBATCH --ntasks-per-node=1\n#SBATCH --cpus-per-task=64\n#SBATCH --gres=gpu:4\n#SBATCH --mem=0\n#SBATCH --time=10:00:00\n#SBATCH --output=%x_%j.out\n#SBATCH --error=%x_%j.err\n\nset -xeuo pipefail\n\n# ================= cluster topology =================\nexport GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-2}}  # GPUs on this node\nNNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}\nexport NNODES\nexport RAY_NUM_NODES=$NNODES\n\n# Require at least 2 GPUs\nTOTAL_GPUS=$((GPUS_PER_NODE * NNODES))\nif [ \"$TOTAL_GPUS\" -lt 2 ]; then\n  echo \"Error: at least 2 GPUs are required, detected $TOTAL_GPUS.\" >&2\n  exit 1\nfi\n\necho \"Using $NNODES nodes and $GPUS_PER_NODE GPUs per node...\"\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\n# Prefer local model if present, otherwise fall back to HF hub path\nmodel_path=\"lmsys/gpt-oss-20b-bf16\"\n\n# Use the default output directory produced by create_dataset.py\ntrain_files=$DATA_ROOT/data/math_expression_tool/train.parquet\ntest_files=$DATA_ROOT/data/math_expression_tool/test.parquet\n\n# Agent config\nagent_loop_config_path=recipe/langgraph_agent/example/agent.yaml\n\n# =================== wandb ===================\nproject_name=math_expression_tool\nexperiment_name=gpt-oss-20b-bf16\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=grpo\n\nuse_kl_in_reward=false\nkl_coef=0.0\nuse_kl_loss=false\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=1024\nmax_response_length=8192\nactor_lr=1e-6\n\ntrain_batch_size=128\nppo_mini_batch_size=16\nn_resp_per_prompt=8\nn_resp_per_prompt_val=1\n\n# =================== logging ===================\nexport RAY_LOGGING_LEVEL=DEBUG\nexport HYDRA_FULL_ERROR=1\n\n# ================= performance =================\nexport NCCL_IBEXT_DISABLE=1\nexport NCCL_NVLS_ENABLE=1\nexport NCCL_IB_HCA=mlx5\nexport UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1\nexport VLLM_USE_V1=1\nexport VLLM_ATTENTION_BACKEND=FLASH_ATTN\n\ninfer_tp=2  # vLLM tensor parallel size\ntrain_sp=4  # Ulysses sequence parallel size for actor\noffload=true\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))\nlog_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))\n\ntrain_files=\"['$train_files']\"\ntest_files=\"['$test_files']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=true \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=true \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=\"$model_path\" \\\n    actor_rollout_ref.model.use_remove_padding=true \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=true \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    
actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=true \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.format=gpt-oss \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \\\n    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=$n_resp_per_prompt \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=1.0\\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=\"$GPUS_PER_NODE\" \\\n    trainer.val_before_train=true \\\n    trainer.log_val_generations=50 \\\n    trainer.nnodes=\"$NNODES\" \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"$default_local_dir\" \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 \"$@\""
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/example/run_qwen2.5_3b.sh",
    "content": "#!/usr/bin/env bash\n#SBATCH --job-name=rl-langgraph-3B\n#SBATCH --partition=main\n#SBATCH --nodes=1\n#SBATCH --ntasks-per-node=1\n#SBATCH --cpus-per-task=64\n#SBATCH --gres=gpu:4\n#SBATCH --mem=0\n#SBATCH --time=10:00:00\n#SBATCH --output=%x_%j.out\n#SBATCH --error=%x_%j.err\n\nset -xeuo pipefail\n\n# ================= cluster topology =================\nexport GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-2}}  # GPUs on this node\nNNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}\nexport NNODES\nexport RAY_NUM_NODES=$NNODES\n\n# Require at least 2 GPUs\nTOTAL_GPUS=$((GPUS_PER_NODE * NNODES))\nif [ \"$TOTAL_GPUS\" -lt 2 ]; then\n  echo \"Error: at least 2 GPUs are required, detected $TOTAL_GPUS.\" >&2\n  exit 1\nfi\n\necho \"Using $NNODES nodes and $GPUS_PER_NODE GPUs per node...\"\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\n# Prefer local model if present, otherwise fall back to HF hub path\nmodel_path=${model_path:-$DATA_ROOT/model/Qwen2.5-3B-Instruct}\nif [ ! -d \"$model_path\" ]; then\n  model_path=Qwen/Qwen2.5-3B-Instruct\nfi\n\n# Use the default output directory produced by create_dataset.py\ntrain_files=$DATA_ROOT/data/math_expression_tool/train.parquet\ntest_files=$DATA_ROOT/data/math_expression_tool/test.parquet\n\n# Agent config\nagent_loop_config_path=recipe/langgraph_agent/example/agent.yaml\n\n# =================== wandb ===================\nproject_name=math_expression_tool\nexperiment_name=qwen2.5-3b\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=grpo\n\nuse_kl_in_reward=false\nkl_coef=0.0\nuse_kl_loss=false\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=1024\nmax_response_length=2048\nactor_lr=1e-6\n\ntrain_batch_size=128\nppo_mini_batch_size=16\nn_resp_per_prompt=8\nn_resp_per_prompt_val=1\n\n# =================== logging ===================\nexport RAY_LOGGING_LEVEL=DEBUG\nexport HYDRA_FULL_ERROR=1\n\n# ================= performance =================\nexport NCCL_IBEXT_DISABLE=1\nexport NCCL_NVLS_ENABLE=1\nexport NCCL_IB_HCA=mlx5\nexport UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1\nexport VLLM_USE_V1=1\nexport VLLM_ATTENTION_BACKEND=FLASH_ATTN\n\ninfer_tp=2  # vLLM tensor parallel size\ntrain_sp=4  # Ulysses sequence parallel size for actor\noffload=true\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))\nlog_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))\n\ntrain_files=\"['$train_files']\"\ntest_files=\"['$test_files']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=true \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=true \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=\"$model_path\" \\\n    actor_rollout_ref.model.use_remove_padding=true \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=true \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=true \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.n=$n_resp_per_prompt \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=\"$GPUS_PER_NODE\" \\\n    trainer.val_before_train=true \\\n    trainer.log_val_generations=50 \\\n    trainer.nnodes=\"$NNODES\" \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=\"$default_local_dir\" \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 \"$@\""
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/react_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nLangGraph React Agent Loop.\n\nThis implementation is exact same as `ToolAgentLoop`.\n\nRef: https://langchain-ai.github.io/langgraph/tutorials/workflows/\n\"\"\"\n\nfrom typing import Any, Literal\n\nfrom langchain_core.runnables import RunnableConfig\nfrom langgraph.graph import END, MessagesState, StateGraph\nfrom langgraph.prebuilt import ToolNode\n\nfrom recipe.langgraph_agent.chat_model import (\n    ChatModel,\n    MaxTokenExceededError,\n    convert_to_agent_output,\n)\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput\n\n\nasync def call_model(state: MessagesState, config: RunnableConfig):\n    model = config[\"configurable\"][\"model\"]\n    sampling_params = config[\"configurable\"][\"sampling_params\"]\n    try:\n        message = await model.ainvoke(state[\"messages\"], sampling_params=sampling_params)\n        return {\"messages\": [message]}\n    except MaxTokenExceededError:\n        # last message is ToolMessage\n        return {\"messages\": []}\n\n\ndef should_continue(state: MessagesState, config: RunnableConfig) -> Literal[\"tools\", END]:\n    max_assistant_turns = config[\"configurable\"][\"max_assistant_turns\"]\n    num_assistant_turns = 0\n    for message in state[\"messages\"]:\n        if message.type == \"ai\":\n            num_assistant_turns += 1\n\n    last_message = state[\"messages\"][-1]\n\n    # LLM call failed, e.g: max response length exceeded\n    if last_message.type == \"tool\":\n        return END\n\n    # max assistant turns exceeded\n    if max_assistant_turns and num_assistant_turns >= max_assistant_turns:\n        return END\n\n    # no tool calls\n    if not last_message.tool_calls:\n        return END\n\n    return \"tools\"\n\n\nclass ReactAgentLoop(AgentLoopBase):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n        print(\"Performing class-level ReactAgentLoop initialization\")\n\n        # build graph\n        cls.graph = cls.build_graph()\n\n    @classmethod\n    def build_graph(cls) -> StateGraph:\n        workflow = StateGraph(MessagesState)\n\n        workflow.add_node(\"agent\", call_model)\n        workflow.add_node(\"tools\", ToolNode(cls.tools))\n        workflow.set_entry_point(\"agent\")\n        workflow.add_conditional_edges(\n            \"agent\",\n            should_continue,\n            {\n                \"tools\": \"tools\",\n                END: END,\n            },\n        )\n\n        workflow.add_edge(\"tools\", \"agent\")\n        graph = workflow.compile()\n        return graph\n\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        messages = list(kwargs[\"raw_prompt\"])\n\n        model_path = self.config.actor_rollout_ref.model.path\n        model_name = 
\"/\".join(model_path.split(\"/\")[-2:])\n\n        rollout = self.config.actor_rollout_ref.rollout\n        model = ChatModel(\n            model=model_name,\n            client=self.server_manager,\n            tokenizer=self.tokenizer,\n            max_tokens=rollout.response_length,\n            max_parallel_calls=rollout.multi_turn.max_parallel_calls,\n            tool_parser=rollout.multi_turn.format,\n        )\n\n        model = model.bind_tools(self.tools, tool_choice=\"any\")\n\n        config = {\n            \"configurable\": {\n                \"model\": model,\n                \"sampling_params\": sampling_params,\n                \"max_user_turns\": rollout.multi_turn.max_user_turns,\n                \"max_assistant_turns\": rollout.multi_turn.max_assistant_turns,\n            }\n        }\n\n        # TODO: how to handle multiple trajectories in an graph invocation?\n        # Each graph node may has its own LLM calls and state, e.g:\n        # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart\n        state = await self.graph.ainvoke(input={\"messages\": messages}, config=config)\n\n        output = convert_to_agent_output(state[\"messages\"], rollout.response_length)\n        return output\n"
  },
  {
    "path": "verl_distillation/recipe/langgraph_agent/test_react_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\n\nimport numpy as np\nimport pytest\nimport ray\nfrom langchain_core.tools import tool\nfrom omegaconf import DictConfig\n\nfrom recipe.langgraph_agent.react_agent_loop import ReactAgentLoop\nfrom tests.experimental.agent_loop.agent_utils import init_agent_loop_manager\nfrom verl.protocol import DataProto\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    model_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n    config.actor_rollout_ref.rollout.agent.num_workers = 2\n\n    config.actor_rollout_ref.actor.use_dynamic_bsz = True\n    # test sleep/wake_up with fsdp offload\n    config.actor_rollout_ref.actor.fsdp_config.param_offload = True\n    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True\n\n    return config\n\n\n@tool(parse_docstring=True)\ndef get_current_temperature(location: str, unit: str = \"celsius\"):\n    \"\"\"Get current temperature at a location.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, and the unit in a dict\n    \"\"\"\n    print(f\"[DEBUG] get_current_temperature: {location}, {unit}\")\n    return {\n        \"temperature\": 26.1,\n        \"location\": location,\n        \"unit\": unit,\n    }\n\n\n@tool(parse_docstring=True)\ndef get_temperature_date(location: str, date: str, unit: str = \"celsius\"):\n    \"\"\"Get temperature at a location and date.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        date: The date to get the temperature for, in the format \"Year-Month-Day\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, the date and the unit in a dict\n    \"\"\"\n    print(f\"[DEBUG] get_temperature_date: {location}, {date}, {unit}\")\n    return {\n        \"temperature\": 25.9,\n        \"location\": location,\n        \"date\": date,\n        \"unit\": unit,\n    }\n\n\nclass TestReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        # TODO: find better way to configure tools\n        cls.tools = [get_current_temperature, get_temperature_date]\n        super().init_class(config, tokenizer, **kwargs)\n\n\ndef test_react_agent(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    agent_loop_config = [\n        {\n            \"_target_\": \"recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop\",\n            \"name\": \"react_agent\",\n        },\n    ]\n    agent_loop_config_path = \"/tmp/agent_loop_config.json\"\n    with open(agent_loop_config_path, \"w\") as f:\n        json.dump(agent_loop_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2\n    init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in New York now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? 
How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"react_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # [user, assistant]\n            assert num_turns[i] == 2\n        else:\n            # [user, assistant, tool, assistant]\n            assert num_turns[i] == 4\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    response_length = response_mask.size(1)\n\n    for i in range(len(responses)):\n        # response with tool response\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        assert \"<tool_response>\" not in response_without_obs, (\n            f\"found <tool_response> in response: {response_without_obs}\"\n        )\n        assert \"</tool_response>\" not in response_without_obs, (\n            f\"found </tool_response> in response: {response_without_obs}\"\n        )\n        print(\"=========================\")\n        print(response_with_obs)\n        print(\"---\")\n        print(response_without_obs)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/recipe/minicpmo/rl_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport logging\nimport math\nimport os\nimport re\nfrom typing import Optional\n\nimport datasets\nimport torch\nfrom omegaconf import DictConfig, ListConfig\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.dataset.vision_utils import process_image\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\n\ndef build_transform():\n    IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_MEAN\n    IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_STD\n    return transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n        ]\n    )\n\n\ndef build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):\n    if new_schema:\n        start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)\n        end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)\n    else:\n        start_cond = input_ids == tokenizer.im_start_id\n        end_cond = input_ids == tokenizer.im_end_id\n    image_start_tokens = torch.where(start_cond)[0]\n    image_start_tokens += 1\n    image_end_tokens = torch.where(end_cond)[0]\n    if len(image_start_tokens) != len(image_end_tokens):\n        logger.error(\"image start token != image end tokens\")\n        raise Exception(\"image start token != image end tokens\")\n    if len(image_start_tokens) > 0:\n        image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])\n    else:\n        image_bound = []\n    return image_bound\n\n\ndef preprocess(\n    images_dict,\n    conversations,\n    tokenizer,\n    transform,\n    query_nums=64,\n    slice_config=None,\n    llm_type=None,\n    patch_size=14,\n    batch_vision=False,\n    max_length=2048,\n    truncation=\"error\",\n    apply_chat_template_kwargs=None,\n    logger=None,\n):\n    \"\"\"\n    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation\n    \"\"\"\n    conversations = copy.deepcopy(conversations)\n    assert conversations[0][\"role\"] == \"user\", \"the first role must be user\"\n\n    if slice_config is not None:\n        assert isinstance(slice_config, dict)\n        assert \"patch_size\" in slice_config\n        assert \"max_slice_nums\" in slice_config\n        assert \"scale_resolution\" in slice_config\n    default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end\n    new_schema = False\n    use_image_id = False\n    if 
llm_type == \"qwen\":\n        new_schema = True\n        use_image_id = True\n    image_placeholder_dict = {}\n    images = []\n    image_id_cnt = 0\n    for img_name, image in images_dict.items():\n        if slice_config:\n            source_image, patches, best_grid = slice_image(\n                image,\n                slice_config[\"max_slice_nums\"],\n                slice_config[\"scale_resolution\"],\n                slice_config[\"patch_size\"],\n            )\n            images.append(source_image)\n            image_placeholder = default_image_placeholder\n            if len(patches) > 0:\n                for i in range(len(patches)):\n                    for j in range(len(patches[0])):\n                        images.append(patches[i][j])\n                if use_image_id:\n                    image_placeholder = (\n                        f\"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}\" + image_placeholder\n                    )\n                    image_id_cnt += 1\n                image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)\n            image_placeholder_dict[img_name] = image_placeholder\n        else:\n            images.append(image)\n            if use_image_id:\n                image_placeholder = f\"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}\" + image_placeholder\n                image_id_cnt += 1\n            else:\n                image_placeholder = default_image_placeholder\n            image_placeholder_dict[img_name] = image_placeholder\n\n    images = [transform(i) for i in images]\n\n    if len(images_dict) == 1 and \"<image>\" in images_dict:\n        if \"<image>\" in conversations[0][\"content\"]:\n            conversations[0][\"content\"] = conversations[0][\"content\"].replace(\"<image>\", image_placeholder)\n        else:\n            conversations[0][\"content\"] = image_placeholder + \"\\n\" + conversations[0][\"content\"]\n    else:\n        pattern = r\"<image_\\d+>\"\n        new_conversations = []\n        for conversation in conversations:\n            content = conversation[\"content\"]\n            parts = re.split(f\"({pattern})\", content)\n            for i, part in enumerate(parts):\n                if not part.strip():\n                    continue\n                if re.match(pattern, part):\n                    if part in image_placeholder_dict:\n                        parts[i] = image_placeholder_dict[part]\n                    else:\n                        raise Exception(f\"not found {part} in image dict\")\n            conversation[\"content\"] = \"\\n\".join(parts)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # TODO change role in conversation for different llm\n    prompt_with_chat_template = tokenizer.apply_chat_template(\n        conversations, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})\n    )\n\n    input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(\n        prompt=prompt_with_chat_template,\n        tokenizer=tokenizer,\n        max_length=max_length,\n        pad_token_id=tokenizer.pad_token_id,\n        left_pad=True,\n        truncation=truncation,\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n    image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger)\n\n    input_dict = {\n        \"input_ids\": input_ids[0],\n        \"attention_mask\": attention_mask[0],\n        
\"position_ids\": position_ids[0],\n        \"image_bound\": image_bound,\n    }\n\n    if batch_vision:\n        tgt_sizes = []\n        reshape_images = []\n        for image in images:\n            H, W = image.shape[1:]\n            reshape_image = reshape_by_patch(image, patch_size)\n            reshape_images.append(reshape_image)\n            tgt_sizes.append([H // patch_size, W // patch_size])\n        if tgt_sizes:\n            tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)\n\n        input_dict[\"pixel_values\"] = reshape_images\n        input_dict[\"tgt_sizes\"] = tgt_sizes\n\n    else:\n        input_dict[\"pixel_values\"] = images\n        input_dict[\"tgt_sizes\"] = []\n\n    return input_dict\n\n\ndef slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):\n    original_size = image.size\n    original_width, original_height = original_size\n    log_ratio = math.log(original_width / original_height)\n    ratio = original_width * original_height / (scale_resolution * scale_resolution)\n    multiple = min(math.ceil(ratio), max_slice_nums)\n\n    source_image = None\n    best_grid = None\n    patches = []\n\n    if multiple <= 1 or never_split:\n        # dont need to slice, upsample\n        best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)\n        source_image = image.resize(best_size, Image.Resampling.BICUBIC)\n    else:\n        candidate_split_grids_nums = []\n        for i in [multiple - 1, multiple, multiple + 1]:\n            if i == 1 or i > max_slice_nums:\n                continue\n            candidate_split_grids_nums.append(i)\n\n        # source image, down-sampling and ensure divided by patch_size\n        best_resize = find_best_resize(original_size, scale_resolution, patch_size)\n        source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)\n        candidate_grids = []\n\n        # find best grid\n        for split_grids_nums in candidate_split_grids_nums:\n            m = 1\n            while m <= split_grids_nums:\n                if split_grids_nums % m == 0:\n                    candidate_grids.append([m, split_grids_nums // m])\n                m += 1\n\n        best_grid = [1, 1]\n        min_error = float(\"inf\")\n        for grid in candidate_grids:\n            error = abs(log_ratio - math.log(grid[0] / grid[1]))\n            if error < min_error:\n                best_grid = grid\n                min_error = error\n\n        refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)\n\n        refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)\n        patches = split_to_patches(refine_image, best_grid)\n\n    return source_image, patches, best_grid\n\n\ndef ensure_divide(length, patch_size):\n    return max(round(length / patch_size) * patch_size, patch_size)\n\n\ndef find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):\n    width, height = original_size\n    if (width * height > scale_resolution * scale_resolution) or allow_upscale:\n        r = width / height\n        height = int(scale_resolution / math.sqrt(r))\n        width = int(height * r)\n    best_width = ensure_divide(width, patch_size)\n    best_height = ensure_divide(height, patch_size)\n    return (best_width, best_height)\n\n\ndef get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):\n    width, height = original_size\n    grid_x, grid_y = grid\n\n  
  refine_width = ensure_divide(width, grid_x)\n    refine_height = ensure_divide(height, grid_y)\n\n    grid_width = refine_width / grid_x\n    grid_height = refine_height / grid_y\n\n    best_grid_size = find_best_resize(\n        (grid_width, grid_height),\n        scale_resolution,\n        patch_size,\n        allow_upscale=allow_upscale,\n    )\n\n    refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)\n\n    return refine_size\n\n\ndef split_to_patches(image, grid):\n    patches = []\n    width, height = image.size\n    grid_x = int(width / grid[0])\n    grid_y = int(height / grid[1])\n\n    for i in range(0, height, grid_y):\n        images = []\n        for j in range(0, width, grid_x):\n            box = (j, i, j + grid_x, i + grid_y)\n            patch = image.crop(box)\n            images.append(patch)\n        patches.append(images)\n\n    return patches\n\n\ndef get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):\n    if new_schema:\n        image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end\n    else:\n        image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end\n\n    cols = grid[0]\n    rows = grid[1]\n    slices = []\n    for i in range(rows):\n        lines = []\n        for j in range(cols):\n            lines.append(image_placeholder)\n        slices.append(\"\".join(lines))\n    if new_schema:\n        slice_placeholder = \"\\n\".join(slices)\n    else:\n        slice_placeholder = tokenizer.slice_start + \"\\n\".join(slices) + tokenizer.slice_end\n    return slice_placeholder\n\n\ndef reshape_by_patch(image_tensor, patch_size):\n    \"\"\"\n    :param image_tensor: shape [3, H, W]\n    :param patch_size:\n    :return: [3, patch_size, HW/patch_size]\n    \"\"\"\n    patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size))\n\n    patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)\n    patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)\n    return patches\n\n\ndef init_minicpmo_config(processor, config):\n    \"\"\"Initialize MiniCPM-o specific configuration\"\"\"\n    minicpmo_config = {\n        \"transform\": build_transform(),\n        \"patch_size\": config.get(\"patch_size\", 14),\n        \"query_nums\": config.get(\"query_nums\", 64),\n        \"slice_config\": config.get(\n            \"slice_config\", {\"max_slice_nums\": 9, \"patch_size\": config.get(\"patch_size\", 14), \"scale_resolution\": 448}\n        ),\n        \"llm_type\": config.get(\"llm_type\", \"qwen\"),\n        \"batch_vision\": config.get(\"batch_vision\", True),\n    }\n    return minicpmo_config\n\n\ndef process_minicpmo_data(\n    row_dict,\n    messages,\n    tokenizer,\n    minicpmo_config,\n    image_key,\n    max_prompt_length,\n    truncation,\n    apply_chat_template_kwargs,\n    logger,\n):\n    \"\"\"Process data for MiniCPM-o model\"\"\"\n    if len(row_dict[image_key]) == 1:\n        multi_modal_data = {}\n        image = process_image(row_dict.pop(image_key)[0])\n        multi_modal_data[\"image\"] = [image]\n        images_dict = {\"<image>\": image}\n    else:\n        raise NotImplementedError\n\n    model_inputs = preprocess(\n        images_dict,\n        messages,\n        tokenizer,\n        minicpmo_config[\"transform\"],\n        query_nums=minicpmo_config[\"query_nums\"],\n        slice_config=minicpmo_config[\"slice_config\"],\n        
llm_type=minicpmo_config[\"llm_type\"],\n        patch_size=minicpmo_config[\"patch_size\"],\n        batch_vision=minicpmo_config[\"batch_vision\"],\n        max_length=max_prompt_length,\n        truncation=truncation,\n        apply_chat_template_kwargs=apply_chat_template_kwargs,\n        logger=logger,\n    )\n\n    raw_prompt = tokenizer.apply_chat_template(\n        messages, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})\n    )\n    raw_prompt = raw_prompt.replace(\"<image>\", \"(<image>./</image>)\")\n\n    return model_inputs, multi_modal_data, raw_prompt\n\n\nclass RLHFDataset(Dataset):\n    \"\"\"\n    Load and preprocess RLHF data from Parquet files.\n\n    - Caches files locally.\n    - Reads into a HuggingFace Dataset and tokenizes prompts.\n    - Optionally handles images/videos via a ProcessorMixin.\n    - Filters prompts over a max length.\n    - Supports resuming from checkpoints.\n\n    Args:\n        data_files (str or list): Path(s) to Parquet file(s).\n        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.\n        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.\n        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n    ):\n        if not isinstance(data_files, list | ListConfig):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(data_files)\n        self.original_data_files = copy.deepcopy(data_files)  # use for resume\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n\n        self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n        self.apply_chat_template_kwargs = config.get(\"apply_chat_template_kwargs\", {})\n\n        self.num_workers = config.get(\"filter_overlong_prompts_workers\", max(1, os.cpu_count() // 4))\n        self.num_workers = min(self.num_workers, os.cpu_count())\n        self.use_shm = config.get(\"use_shm\", False)\n        self.chat_template_func = config.get(\"chat_template_func\", None)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.serialize_dataset = False\n        self.minicpmo_config = init_minicpmo_config(self.processor, config)\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self, use_origin_parquet=False):\n        from verl.utils.fs import copy_to_local\n\n        data_files = self.data_files if not use_origin_parquet else self.original_data_files\n        for i, parquet_file in enumerate(data_files):\n            self.data_files[i] = 
copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n            dataframes.append(dataframe)\n        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n    def resume_dataset_state(self):\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        # resume dataframe if it's not serialized in data.pt\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)  # download and resume from original parquet files\n            self._read_files_and_tokenize()\n        else:\n            print(\"an old dataloader ckpt file is being used; please train from scratch for better ckpt performance\")\n\n    def __len__(self):\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict):\n        return example.pop(self.prompt_key)\n\n    def __getitem__(self, item):\n        \"\"\"\n        Note that we also return the raw_input_ids so that they can be combined with other chat templates\n        \"\"\"\n        row_dict: dict = self.dataframe[item]\n        messages = self._build_messages(row_dict)\n        model_inputs = {}\n\n        if self.processor is not None:\n            model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data(\n                row_dict,\n                messages,\n                self.tokenizer,\n                self.minicpmo_config,\n                self.image_key,\n                self.max_prompt_length,\n                self.truncation,\n                self.apply_chat_template_kwargs,\n                logger,\n            )\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n            position_ids = model_inputs.pop(\"position_ids\")\n\n            # There's a trap here: multi_modal_inputs has to be a dict, not a BatchFeature\n            row_dict[\"multi_modal_data\"] = multi_modal_data\n            row_dict[\"multi_modal_inputs\"] = dict(model_inputs)\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(\n                messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs\n            )\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row_dict[\"input_ids\"] = input_ids\n        row_dict[\"attention_mask\"] = attention_mask\n        row_dict[\"position_ids\"] = position_ids\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        row_dict[\"raw_prompt_ids\"] = raw_prompt_ids\n        # encode prompts without chat template\n        if self.return_raw_chat:\n            row_dict[\"raw_prompt\"] = messages\n\n        # get prompts with chat template\n        if self.return_full_prompt:\n            row_dict[\"full_prompts\"] = raw_prompt  # array of strings\n\n        # add index for each prompt\n        index = row_dict.get(\"extra_info\", {}).get(\"index\", 0)\n        tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"tools_kwargs\", {})\n        interaction_kwargs = row_dict.get(\"extra_info\", {}).get(\"interaction_kwargs\", {})\n        need_tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"need_tools_kwargs\", self.need_tools_kwargs)\n        if need_tools_kwargs and not tools_kwargs:\n            logger.warning(\"tools_kwargs is empty for index %s, data source: %s\", index, row_dict[\"data_source\"])\n        row_dict[\"index\"] = index\n        row_dict[\"tools_kwargs\"] = tools_kwargs\n        row_dict[\"interaction_kwargs\"] = interaction_kwargs\n        return row_dict\n\n    def __getstate__(self):\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n\n            if \"dataframe\" in state:\n                del state[\"dataframe\"]\n            return state\n\n        return self.__dict__.copy()\n
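\n\nif __name__ == \"__main__\":\n    # Minimal smoke test (illustrative addition, not part of the original recipe):\n    # verifies the shape contract documented in reshape_by_patch's docstring,\n    # [3, H, W] -> [3, patch_size, H*W/patch_size], on a toy tensor.\n    _img = torch.randn(3, 28, 42)  # H and W divisible by patch_size=14\n    _out = reshape_by_patch(_img, patch_size=14)\n    assert _out.shape == (3, 14, 28 * 42 // 14)\n    print(\"reshape_by_patch output shape:\", tuple(_out.shape))\n"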
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/README.md",
    "content": "# Recipe: One Step Off Policy Async Trainer\n\n**Author:**  `https://github.com/meituan-search`\n\nLast updated: 07/17/2025.\n\n## Introduction\n\n### Background\n\nThe current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic\nworkflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest\nmodel, and the model is updated after training completes. While this approach aligns with off-policy reinforcement\nlearning and stabilizes RL training, but it suffers from severe efficiency issues.\nModel updates must wait for the longest output in the generation phase to complete.\nDuring the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization.\nThe more severe the long-tail problem in sample generation, the lower the overall training efficiency.\nFor example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time,\nand increasing resources does not reduce the Rollout duration.\n\n![DAPO 32B Math Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png)\n> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361\n\n### Solution\n\nWe have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the\ngeneration and training processes, utilizing samples generated in the previous step for current training.\nIt also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically\nassigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time\nduring long-tail sample generation. Throughout this process, generation and training parameters maintain a one-step off\npolicy.\n\n![One Step Off Policy Diagram](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)\n> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](\n> https://arxiv.org/abs/2505.24298)\n> original work: [Asynchronous RLHF: Faster and More Efficient Off-Policy RL for Language Models](https://arxiv.org/abs/2410.18252)\n\nOur core contributions include:\n\n1. **Parallel Generation and Training**:  \n   Samples for the next batch are asynchronously generated while the current batch is being trained.\n\n2. **Resource Isolation**:  \n   Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources\n   automatically assigned to training.\n\n3. 
3. **NCCL Parameter Synchronization**:  \n   Employs NCCL communication primitives for seamless parameter transfer between the generation and training modules.\n\n### Experimental Results\n\n- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)\n    - Generation: 4 GPUs\n    - Training: 12 GPUs\n- **Model**: Qwen2.5-Math-7B\n- **Rollout Configuration**:\n    - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens\n- **Algorithm**: DAPO\n- **Rollout Engine**: vLLM\n\n| training mode          | engine        | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time    | acc/best@32/mean | acc/maj@32/mean |\n|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|\n| colocate sync          | VLLM+FSDP2    | 749  | 321 | -             | 247                | 88           | 286          | 19h18m        | 0.5948           | 0.417           |\n| one-step-overlap async | VLLM+FSDP2    | 520  | -   | 45            | 458                | 108          | 337          | 15h34m (+23%) | 0.6165           | 0.494           |\n| colocate sync          | VLLM+Megatron | 699  | 207 | -             | 162                | 119          | 344          | 18h21m        | 0.605            | 0.4217          |\n| one-step-overlap async | VLLM+Megatron | 566  | -   | 59            | 501                | 120          | 347          | 13h06m (+40%) | 0.6569           | 0.4038          |\n\n* colocate sync: step ≈ gen + old_log_prob + update_actor\n* one-step-overlap async: step ≈ wait_prev_gen + old_log_prob + update_actor\n\n![One Step Off Megatron Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)\n\n> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg\n\n## Implementation\n\n### One Step Off Policy Async Pipeline\n\nOur **One Step Off Policy Async Pipeline** integrates seamlessly into the existing training logic at minimal cost,\neliminating the need for additional sample storage management.
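\n\nEach in-flight rollout is wrapped in a `GenerationBatchFuture` (used in the pipeline sketch below). A minimal sketch of such a wrapper, assuming the rollout worker group hands back a `.get()`-style future and that `DataProto.union` merges the generation output into the batch (the recipe's actual class may differ):\n\n```python\nclass GenerationBatchFuture:\n    \"\"\"Sketch only: ties a training batch to its in-flight rollout result.\"\"\"\n\n    def __init__(self, epoch, batch, gen_batch_output):\n        self.epoch = epoch  # epoch the batch was drawn from\n        self.batch = batch  # DataProto holding the prompts\n        self.gen_batch_output = gen_batch_output  # future from async_generate_sequences\n\n    def get(self):\n        # Block until the previous step's generation finishes, then merge\n        # the generated sequences back into the training batch.\n        gen_batch_result = self.gen_batch_output.get()\n        return self.batch.union(gen_batch_result)\n```\n\n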
The core mechanism uses `async_gen_next_batch`\nfor asynchronous rollout generation while maintaining continuous operation during epoch transitions\nvia `create_continuous_iterator`.\n\n```python\n# iterator generator that simplifies one-step-off integration into the training loop\ndef _create_continuous_iterator(self):\n    for epoch in range(self.config.trainer.total_epochs):\n        iterator = iter(self.train_dataloader)\n        for batch_dict in iterator:\n            yield epoch, batch_dict\n\n\n# read the next batch of samples, sync parameters, and launch async sequence generation\ndef _async_gen_next_batch(self, continuous_iterator):\n    # read train_data\n    try:\n        epoch, batch_dict = next(continuous_iterator)\n    except StopIteration:\n        return None\n    batch = DataProto.from_single_dict(batch_dict)\n    gen_batch = batch_process(batch)\n    # sync weights from actor to rollout\n    self.sync_rollout_weights()\n    # async generation\n    gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n    # wrap the pending result in a future\n    return GenerationBatchFuture(epoch, batch, gen_batch_output)\n\n\ncontinuous_iterator = self._create_continuous_iterator()\n# run rollout first to achieve one-step-off\nbatch_data_future = self._async_gen_next_batch(continuous_iterator)\n\nwhile batch_data_future is not None:\n    # wait for the gen_seq result from the previous step\n    batch = batch_data_future.get()\n    # launch the next async call to generate sequences\n    batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n    # compute advantages\n    batch = critic.compute_values(batch)\n    batch = reference.compute_log_prob(batch)\n    batch = reward.compute_reward(batch)\n    batch = compute_advantages(batch)\n\n    # model update\n    critic_metrics = critic.update_critic(batch)\n    actor_metrics = actor.update_actor(batch)\n```\n\n### Parameter Synchronization\n\nA key strength of this recipe is that the NCCL-based weight update for the rollout model is very fast:\nmost of the time the latency is under 300 ms, which is negligible for RLHF.\n\n> **sync_rollout_weights**: The time for synchronizing parameters from actor to rollout is extremely short and can almost\n> be ignored because it is implemented with NCCL.\n\n```python\nclass ActorRolloutRefWorker:\n    # actor acquires the meta-info of model parameters for parameter sync\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        params = self._get_actor_params()\n        ret = []\n        for key, tensor in params.items():\n            ret.append((key, tensor.size(), tensor.dtype))\n        self._weights_info = ret\n        return ret\n\n    # rollout sets the meta-info of model parameters for parameter sync\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        self._weights_info = weights_info\n\n\nclass AsyncRayPPOTrainer(RayPPOTrainer):\n    def init_workers(self):\n\n\n...\n# rollout obtains the meta-info of model parameters from the actor for parameter sync\nweights_info = self.actor_wg.get_actor_weights_info()[0]\nself.rollout_wg.set_actor_weights_info(weights_info)\n\n# Create an actor-rollout communication group for parameter sync\nactor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers\ncollective.create_collective_group(\n    actor_rollout_workers,\n    len(actor_rollout_workers),\n    list(range(0, len(actor_rollout_workers))),\n    backend=\"nccl\",\n
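    # all actor and rollout workers join a single \"actor_rollout\" NCCL group;\n    # global rank 0 (the first actor worker) serves as the broadcast source below\n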
    group_name=\"actor_rollout\"\n)\n```\n\n```python\n# the driver process calls the actor and rollout groups respectively to sync parameters via NCCL\ndef sync_rollout_weights(self):\n    self.actor_wg.sync_rollout_weights()\n    ray.get(self.rollout_wg.sync_rollout_weights())\n\n\n# fsdp model parameter sync\n@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\ndef sync_rollout_weights(self):\n    params = self._get_actor_params() if self._is_actor else None\n    if self._is_rollout:\n        inference_model = (\n            self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n        )\n        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n        patch_vllm_moe_model_weight_loader(inference_model)\n    # Model parameters are broadcast tensor-by-tensor from actor to rollout\n    for key, shape, dtype in self._weights_info:\n        tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n        if self._is_actor:\n            assert key in params\n            origin_data = params[key]\n            if hasattr(origin_data, \"full_tensor\"):\n                origin_data = origin_data.full_tensor()\n            if torch.distributed.get_rank() == 0:\n                tensor.copy_(origin_data)\n        from ray.util.collective import collective\n\n        collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n        if self._is_rollout:\n            inference_model.load_weights([(key, tensor)])\n```\n\n## Usage\n\n### FSDP2 Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False);\n# trainer.* holds the training resources, rollout.* the generation resources\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Megatron Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False);\n# trainer.* holds the training resources, rollout.* the generation resources\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Configuration Guidelines\n\n1. **GPU Count Relationships**  \n   Maintain either of these relationships for optimal batch distribution:\n    - `actor_rollout_ref.rollout.n` should be an integer divisor of:  \n      `trainer.n_gpus_per_node * trainer.nnodes`\n    - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by:  \n      `trainer.n_gpus_per_node * trainer.nnodes`\n\n   > Rationale: Ensures training samples can be evenly distributed across training GPUs when using partial resources for\n   > generation. For example, the 4/12 GPU split above uses `rollout.n=12` with 12 training GPUs (2 nodes × 6), so both\n   > conditions hold.\n\n
2. **Dynamic Resource Tuning**  \n   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on phase\n   durations:\n    - **Ideal state**: Rollout and training phases have comparable durations\n    - **Diagnostic metrics**:\n        - Monitor `wait_prev_gen` duration\n        - Analyze `sequence_length` distribution\n    - **Adjustment strategy**:\n        - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources\n        - High `wait_prev_gen` + long-tail sequences → Optimize stopping criteria (resource increase won't help)\n   > **wait_prev_gen**: The time consumed waiting for the previous rollout to finish (the part that is not fully\n   > overlapped).\n\n   **Resource Configuration Strategies:**\n    - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios,\n      keeping the number of nodes equal to allow training and rollout to share nodes:\n        - Configure `trainer.nnodes = rollout.nnodes` with\n          `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`. Control rollout resource\n          allocation by adjusting `n_gpus_per_node`.\n    - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes,\n      keeping the number of GPUs per node equal to enable independent scaling of training and rollout\n      parallelism:\n        - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control rollout resource allocation by\n          adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.\n\n   > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The\n   > actual calculation depends on GPU capacity (see the sketch at the end of this document):\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,\n   >   the required node count is `max(trainer.nnodes, rollout.nnodes)`\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,\n   >   the required node count is `trainer.nnodes + rollout.nnodes`\n\n## Functional Support\n\n| Category           | Supported Options                                                                                               |\n|--------------------|-----------------------------------------------------------------------------------------------------------------|\n| train engine       | FSDP2  <br/> Megatron                                                                                           |\n| rollout engine     | vLLM <br/> SGLang                                                                                               |\n| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |\n| Reward             | all                                                                                                             |\n
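\n## Appendix: Node-Count Sketch\n\nThe note above on the required node count is simple arithmetic. The helper below is a minimal sketch (the function `required_nodes` and its homogeneous-node assumption are illustrative, not part of the recipe code):\n\n```python\ndef required_nodes(trainer_nnodes: int, trainer_n_gpus_per_node: int,\n                   rollout_nnodes: int, rollout_n_gpus_per_node: int,\n                   physical_gpus_per_node: int = 8) -> int:\n    \"\"\"Compute the physical nodes needed, per the note in the guidelines.\"\"\"\n    if trainer_n_gpus_per_node + rollout_n_gpus_per_node <= physical_gpus_per_node:\n        # training and rollout can share each node\n        return max(trainer_nnodes, rollout_nnodes)\n    # otherwise training and rollout occupy disjoint nodes\n    return trainer_nnodes + rollout_nnodes\n\n\n# the 4/12 experiment above: 2 nodes x (6 training + 2 rollout) GPUs on 8-GPU nodes\nassert required_nodes(2, 6, 2, 2) == 2\n```\n"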
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\n# config for the rollout (only for resource isolation)\nrollout:\n  # Number of nodes used in the rollout\n  nnodes: 1\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\n# config for the rollout (only for resource isolation)\nrollout:\n  # Number of nodes used in the rollout\n  nnodes: 1\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-one-step-off-4-12'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\"\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-colocate'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-one-step-off-4-12'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\"\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_fsdp2_sglang_colocate.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-colocate'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-megatron-one-step-off-4-12'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=2\ntrain_tp=2\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    critic.strategy=megatron \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\"\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0519a1-megatron-colocate'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=2\ntrain_tp=2\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    critic.strategy=megatron \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/distributed_util.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom verl.utils.device import is_npu_available\n\n\ndef stateless_init_process_group(master_address, master_port, rank, world_size, device):\n    \"\"\"\n    vLLM provides `StatelessProcessGroup` to create a process group\n    without considering the global process group in torch.distributed.\n    It is recommended to create `StatelessProcessGroup`, and then initialize\n    the data-plane communication (NCCL) between external (train processes)\n    and vLLM workers.\n    \"\"\"\n    # NOTE: If it is necessary to support weight synchronization with the sglang backend in the future,\n    # the following can be used:\n    # from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator\n    # from sglang.srt.distributed.utils import statelessprocessgroup\n    if is_npu_available:\n        from vllm_ascend.distributed.device_communicators.pyhccl import (\n            PyHcclCommunicator as PyNcclCommunicator,\n        )\n    else:\n        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator\n    from vllm.distributed.utils import StatelessProcessGroup\n\n    pg = StatelessProcessGroup.create(host=master_address, port=master_port, rank=rank, world_size=world_size)\n    pynccl = PyNcclCommunicator(pg, device=device)\n    return pynccl\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/fsdp_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import AutoConfig\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_nccl_backend,\n    get_torch_device,\n)\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    fsdp_version,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import get_generation_config, update_model_config\nfrom verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer\nfrom verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max\nfrom verl.utils.ray_utils import get_event_loop\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker\nfrom verl.workers.fsdp_workers import CriticWorker\nfrom verl.workers.rollout import get_rollout_class\n\nfrom .distributed_util import stateless_init_process_group\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n__all__ = [\"ActorRolloutRefWorker\", \"AsyncActorRolloutRefWorker\", \"CriticWorker\", \"RolloutWorker\"]\n\n\nclass ActorRolloutRefWorker(ARRWorker):\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):\n        rank = torch.distributed.get_rank() + rank_offset\n        self._weight_sync_group = stateless_init_process_group(\n            master_address,\n            master_port,\n            rank,\n            world_size,\n            get_torch_device().current_device(),\n        )\n\n    def _get_actor_params(self):\n        assert self._is_actor\n        params = self.actor_module_fsdp.state_dict()\n        from verl.utils.model import convert_weight_keys\n\n        params = convert_weight_keys(\n            params, getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n        )\n        return params\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params = self._get_actor_params() if self._is_actor else 
None\n        rollout_name = self.config.rollout.name\n        if self._is_rollout:\n            if rollout_name == \"vllm\":\n                inference_model = (\n                    self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n                )\n                from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n                patch_vllm_moe_model_weight_loader(inference_model)\n            elif rollout_name == \"sglang\":\n                inference_model = self.rollout._engine\n            else:\n                raise NotImplementedError(f\"Unknown rollout name: {rollout_name}\")\n        loop = get_event_loop()\n        for key, shape, dtype in self._weights_info:\n            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n            if self._is_actor:\n                assert key in params\n                origin_data = params[key]\n                if hasattr(origin_data, \"full_tensor\"):\n                    origin_data = origin_data.full_tensor()\n                if torch.distributed.get_rank() == 0:\n                    tensor.copy_(origin_data)\n\n            self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())\n            if self._is_rollout:\n                if rollout_name == \"vllm\":\n                    inference_model.load_weights([(key, tensor)])\n                elif rollout_name == \"sglang\":\n                    loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))\n\n    async def update_weights(self, inference_engine, params):\n        from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights\n\n        await sgl_update_weights(\n            engine=inference_engine,\n            params_batch=params,\n            device_mesh_key=\"infer_tp\",\n            device_mesh=self.rollout_device_mesh,\n        )\n\n        if self.rollout_device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await inference_engine.flush_cache()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n        if fsdp_version(self.actor_module_fsdp) == 1:\n            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType\n\n            FSDP.set_state_dict_type(\n                self.actor_module_fsdp,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n        params = self._get_actor_params()\n        ret = []\n        for key, tensor in params.items():\n            ret.append((key, tensor.size(), tensor.dtype))\n        self._weights_info = ret\n        return ret\n\n\nclass RolloutWorker(ActorRolloutRefWorker):\n    def __init__(self, config: DictConfig, role: str):\n        Worker.__init__(self)\n        assert role == \"rollout\"\n        self.config = config\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ.get(\"RANK\", 0))\n            world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n            torch.distributed.init_process_group(\n                backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\",\n                rank=rank,\n                world_size=world_size,\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", 
None),\n            )\n        # TODO(haibin.lin):\n        # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig,\n        # it will actually convert the ProfilerConfig dataclass back to a DictConfig.\n        # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py)\n        # as they provides DictConfig-like interface\n        # The benefit of creating the dataclass config is to perform validation during __post_init__\n        omega_profiler_config = config.get(\"profiler\", {})\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)\n        )\n        self._is_rollout = True\n        self._is_actor = False\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n\n        use_shm = self.config.model.get(\"use_shm\", False)\n        local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n        trust_remote_code = self.config.model.get(\"trust_remote_code\", False)\n\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        # override model kwargs\n        actor_model_config = AutoConfig.from_pretrained(\n            local_path, trust_remote_code=trust_remote_code, attn_implementation=\"flash_attention_2\"\n        )\n\n        # patch for kimi-vl\n        if getattr(actor_model_config, \"model_type\", None) == \"kimi_vl\":\n            actor_model_config.text_config.topk_method = \"greedy\"\n\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config)\n        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)\n        if self.rank == 0:\n            print(f\"Model config after override: {actor_model_config}\")\n\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n        )\n        
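# Build a 2-D (dp, infer_tp) device mesh: ranks within one \"infer_tp\" group\n        # serve a single tensor-parallel engine replica, and only the group's local\n        # rank 0 collects generated sequences.\n        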
rollout_device_mesh = init_device_mesh(\n            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n        )\n        self.rollout_device_mesh = rollout_device_mesh\n\n        is_collect = rollout_device_mesh[\"infer_tp\"].get_local_rank() == 0\n        self._register_dispatch_collect_info(\n            \"rollout\", dp_rank=rollout_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n        )\n\n        rollout_name = self.config.rollout.name\n        if rollout_name not in (\"vllm\", \"sglang\"):\n            raise NotImplementedError(f\"rollout_name: {rollout_name} is not supported\")\n\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)\n        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)\n        self.model_config = model_config\n\n        log_gpu_memory_usage(f\"Before building {rollout_name} rollout\", logger=logger)\n        rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(\n            config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh\n        )\n        log_gpu_memory_usage(f\"After building {rollout_name} rollout\", logger=logger)\n\n        if rollout_name == \"vllm\":\n            from .vllm_sharding_manager import VLLMShardingManager\n\n            rollout_sharding_manager = VLLMShardingManager(\n                inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh\n            )\n\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n        elif rollout_name == \"sglang\":\n            from .sglang_sharding_manager import SGLangShardingManager\n\n            rollout_sharding_manager = SGLangShardingManager(device_mesh=rollout_device_mesh)\n\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        self.model_config = model_config\n        self.rollout = rollout\n        self.rollout_sharding_manager = rollout_sharding_manager\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"rollout\"), blocking=False)\n    def async_generate_sequences(self, prompts):\n        # Support all hardwares\n        prompts = prompts.to(get_device_id())\n\n        assert self._is_rollout\n\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id\n            if self.generation_config is not None\n            else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id\n            if self.generation_config is not None\n            else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n        timing_generate = {}\n        with self.rollout_sharding_manager:\n            log_gpu_memory_usage(\"After entering rollout sharding manager\", logger=logger)\n\n            with simple_timer(\"generate_sequences\", timing_generate):\n                output = self.rollout.generate_sequences(prompts=prompts)\n\n            log_gpu_memory_usage(\"After rollout generation\", logger=logger)\n\n        timing_generate.update(self.rollout_sharding_manager.timing)\n        # We calculate the average timing across all ranks\n        # to make sure meta_info[\"timing\"] is the same\n        timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max(\n            timing_generate[\"generate_sequences\"]\n        )\n        timing_generate = 
reduce_timing(timing_generate)\n        timing_generate.update(\n            {\n                \"generation_timing/max\": timing_generate_max,\n                \"generation_timing/min\": timing_generate_min,\n                \"generation_timing/topk_ratio\": timing_generate_topk_ratio,\n            }\n        )\n        output.meta_info[\"timing\"] = timing_generate\n        output = output.to(\"cpu\")\n\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh",
    "content": "set -x\n\nproject_name='GRPO'\nexp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-0.6B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1152 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=192 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=True \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" $@"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh",
    "content": "set -x\n\nproject_name='GRPO'\nexp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-sglang-one-step-off-2-6'\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-0.6B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1152 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=192 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=True \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" $@"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh",
    "content": "set -x\n\nproject_name='GRPO'\nexp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen/Qwen2.5-3B-Instruct\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1152 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=192 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=True \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" $@"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/main_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom recipe.one_step_off_policy.utils import need_critic\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.trainer.ppo.utils import need_reference_policy\nfrom verl.utils.config import validate_config\n\nfrom .ray_trainer import OneStepOffRayTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"one_step_off_ppo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\n# Define a function to run the PPO-like training process\ndef run_ppo(config) -> None:\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        # Set environment variables in the runtime environment to control tokenizer parallelism,\n        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating\n        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration\n        default_runtime_env = get_ppo_ray_runtime_env()\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    # Create a remote instance of the TaskRunner class, and\n    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete\n    if (\n        config.global_profiler.tool == \"nsys\"\n        and OmegaConf.select(config.global_profiler, \"steps\") is not None\n        and len(OmegaConf.select(config.global_profiler, \"steps\")) > 0\n    ):\n        nsight_options = OmegaConf.to_container(config.global_profiler.tool_config.nsys.controller_nsight_options)\n        runner = TaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n    # [Optional] get the path of the timeline trace file from the configuration, default to None\n    # This file is used for performance analysis\n    timeline_json_file = config.ray_kwargs.get(\"timeline_json_file\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # Print the 
initial configuration. `resolve=True` will evaluate symbolic values.\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n\n        pprint(OmegaConf.to_container(config, resolve=True))\n\n        OmegaConf.resolve(config)\n\n        # Define worker classes based on the actor strategy.\n        if config.actor_rollout_ref.actor.strategy == \"fsdp2\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n\n            from .fsdp_workers import (\n                ActorRolloutRefWorker,\n                AsyncActorRolloutRefWorker,\n                CriticWorker,\n                RolloutWorker,\n            )\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n\n            from .megatron_workers import (\n                ActorRolloutRefWorker,\n                AsyncActorRolloutRefWorker,\n                CriticWorker,\n                RolloutWorker,\n            )\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from .ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.Actor: ray.remote(actor_rollout_cls),\n            Role.Rollout: ray.remote(RolloutWorker),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"actor_pool\"\n\n        assert config.trainer.n_gpus_per_node > 0, \"config.trainer.n_gpus_per_node must be greater than 0\"\n        assert config.trainer.nnodes > 0, \"config.trainer.nnodes must be greater than 0\"\n        assert config.rollout.n_gpus_per_node > 0, \"config.rollout.n_gpus_per_node must be greater than 0\"\n        assert config.rollout.nnodes > 0, \"config.rollout.nnodes must be greater than 0\"\n\n        actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes\n        rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes\n\n        resource_pool_spec = {\n            \"actor_pool\": actor_pool,\n            \"rollout_pool\": rollout_pool,\n        }\n        mapping = {\n            Role.Actor: \"actor_pool\",\n            Role.Rollout: \"rollout_pool\",\n            Role.Critic: \"actor_pool\",\n        }\n        print(f\"resource_pool_spec: {resource_pool_spec}\")\n        # We should adopt a multi-source reward function here:\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # finally, we combine all the rewards together\n        # The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if 
config.reward_model.strategy in [\"fsdp2\"]:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # Add a reference policy worker if KL loss or KL reward is used.\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(role_worker_mapping),\n            use_critic=need_critic(config),\n        )\n\n        # Download the checkpoint from HDFS to the local machine.\n        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor.\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Load the reward manager for training and validation.\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets.\n        train_dataset = create_rl_dataset(\n            config.data.train_files,\n            config.data,\n            tokenizer,\n            processor,\n            max_samples=config.data.get(\"train_max_samples\", -1),\n        )\n        val_dataset = create_rl_dataset(\n            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get(\"val_max_samples\", -1)\n        )\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the PPO trainer.\n        trainer = OneStepOffRayTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n            device_name=config.trainer.device,\n        )\n        # Initialize the workers of the trainer.\n        trainer.init_workers()\n  
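      # One-step off-policy: rollout workers generate the next batch while the\n        # actor trains on the current one; sync_rollout_weights keeps the two\n        # disjoint resource pools consistent.\n  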
      # Start the training process.\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/megatron_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig, OmegaConf\n\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.debug import (\n    log_gpu_memory_usage,\n)\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker\nfrom verl.workers.megatron_workers import CriticWorker, RewardModelWorker\nfrom verl.workers.rollout import get_rollout_class\n\nfrom .distributed_util import stateless_init_process_group\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n__all__ = [\"ActorRolloutRefWorker\", \"AsyncActorRolloutRefWorker\", \"CriticWorker\", \"RewardModelWorker\", \"RolloutWorker\"]\n\n\nclass ActorRolloutRefWorker(ARRWorker):\n    def __init__(self, config: DictConfig, role: str):\n        assert role in [\"actor\", \"ref\"]\n        tmp_role = \"ref\" if role == \"ref\" else \"actor_rollout\"\n        super().__init__(config, tmp_role)\n        if role == \"actor\":\n            self._is_rollout = False\n        self.role = role\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):\n        rank = torch.distributed.get_rank() + rank_offset\n        self._weight_sync_group = stateless_init_process_group(\n            master_address,\n            master_port,\n            rank,\n            world_size,\n            get_torch_device().current_device(),\n        )\n\n    def _get_actor_params_generator(self):\n        assert self._is_actor\n        from verl.models.mcore import get_mcore_weight_converter\n        from verl.utils.megatron_utils import per_tensor_generator\n\n        layer_name_mapping = {\n            \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n            \"gate_proj_layer_name\": \"linear_fc1.\",\n        }\n        weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)\n        generator = per_tensor_generator(\n            self.actor.actor_module,\n            self.actor_model_config,\n            weight_converter,\n            self.tf_config,\n            layer_name_mapping,\n        )\n        return generator\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params_generator = self._get_actor_params_generator() if self._is_actor else None\n       
 if self._is_rollout:\n            inference_model = (\n                self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n            )\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            patch_vllm_moe_model_weight_loader(inference_model)\n        for key, shape, dtype in self._weights_info:\n            if self._is_actor:\n                weight_key, weight = next(params_generator)\n                assert key == weight_key\n                assert shape == weight.size()\n                assert dtype == weight.dtype\n\n            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n            if self._is_actor and torch.distributed.get_rank() == 0:\n                tensor.copy_(weight)\n\n            self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())\n            if self._is_rollout:\n                inference_model.load_weights([(key, tensor)])\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n\n        params_generator = self._get_actor_params_generator()\n        ret = []\n        for key, tensor in params_generator:\n            ret.append((key, tensor.size(), tensor.dtype))\n\n        self._weights_info = ret\n        return ret\n\n\nclass RolloutWorker(ActorRolloutRefWorker):\n    def __init__(self, config: DictConfig, role: str):\n        assert role == \"rollout\"\n        ARRWorker.__init__(self, config, role)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        override_transformer_config = {}\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        trust_remote_code = self.config.model.get(\"trust_remote_code\", False)\n\n        from verl.utils.model import get_generation_config\n\n        self._init_hf_config_and_tf_config(\n            self.config.model.path,\n            self.config.model.path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            trust_remote_code,\n        )\n        self.generation_config = get_generation_config(self.local_path)\n\n        from torch.distributed.device_mesh import init_device_mesh\n\n        assert self.config.rollout.name == \"vllm\"\n        assert self.config.rollout.mode == \"sync\"\n\n        from .vllm_sharding_manager import VLLMShardingManager\n\n        # NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in the actor,\n        # we will reorganize their weight format when resharding from actor to rollout.\n\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n        )\n        rollout_device_mesh = 
init_device_mesh(\n            get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n        )\n        is_collect = rollout_device_mesh[\"infer_tp\"].get_local_rank() == 0\n        self._register_dispatch_collect_info(\n            \"rollout\", dp_rank=rollout_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n        )\n        log_gpu_memory_usage(\"Before building vllm rollout\", logger=None)\n\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)\n        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)\n        rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(\n            config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh\n        )\n        log_gpu_memory_usage(\"After building vllm rollout\", logger=logger)\n\n        sharding_manager = VLLMShardingManager(\n            inference_engine=rollout.inference_engine,\n            device_mesh=rollout_device_mesh,\n        )\n        log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        self.rollout, self.sharding_manager = rollout, sharding_manager\n        self.rollout.sharding_manager = sharding_manager\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"rollout\"), blocking=False)\n    def async_generate_sequences(self, *args, **kwargs):\n        return super().generate_sequences(*args, **kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom pprint import pprint\n\nimport numpy as np\nimport ray\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import Dataset, Sampler\nfrom tqdm import tqdm\n\nfrom recipe.one_step_off_policy.utils import need_critic\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import (\n    RayPPOTrainer,\n    ResourcePoolManager,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\nclass GenerationBatchFuture:\n    \"\"\"\n    Wrapper class for encapsulating batch generation results\n    \"\"\"\n\n    def __init__(self, epoch, batch, gen_batch_output, future_reward=None):\n        \"\"\"\n        :param epoch: current epoch\n        :param batch: Input batch data\n        :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)\n        :param future_reward: Future for reward computation (optional)\n        \"\"\"\n        self.epoch = epoch\n        self.batch = batch\n        self.gen_batch_output = gen_batch_output\n        self.future_reward = future_reward\n\n    def get(self):\n        \"\"\"\n        Get the actual results by calling get() method on gen_batch_output\n\n        Returns:\n            tuple: (epoch, batch, gen_batch_result, future_reward)\n                - epoch: Current epoch\n                - batch: Original input batch data\n                - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself\n                - future_reward: Future for reward computation if available, else None\n        \"\"\"\n        # Call get() method on gen_batch_output if available\n        if hasattr(self.gen_batch_output, \"get\"):\n            gen_batch_result = self.gen_batch_output.get()\n        else:\n            gen_batch_result = self.gen_batch_output\n\n        return self.epoch, self.batch, gen_batch_result, self.future_reward\n\n\nclass OneStepOffRayTrainer(RayPPOTrainer):\n    # TODO: support each role have 
individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Dataset | None = None,\n        val_dataset: Dataset | None = None,\n        collate_fn=None,\n        train_sampler: Sampler | None = None,\n        device_name=\"cuda\",\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). 
Defaults to \"cuda\".\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n\n        assert not self.hybrid_engine\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)\n        self.use_rm = need_reward_model(self.role_worker_mapping)\n        self.use_critic = need_critic(config)\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name\n        self.validation_generations_logger = ValidationGenerationsLogger()\n\n        # if ref_in_actor is True, the reference policy will be actor without lora applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # kl loss control currently not suppoorted\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _validate(self):\n        self.actor_rollout_wg = self.rollout_wg\n        ret = super()._validate()\n        self.actor_rollout_wg = self.actor_wg\n        return ret\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        for role, role_name in [(Role.Actor, \"actor\"), (Role.Rollout, \"rollout\")]:\n            resource_pool = self.resource_pool_manager.get_resource_pool(role)\n            role_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[role],\n                config=self.config.actor_rollout_ref,\n                role=role_name,\n            )\n            self.resource_pool_to_cls[resource_pool][role_name] = role_cls\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=\"ref\",\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = 
RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.global_profiler, \"steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.global_profiler, \"steps\")\n            assert (\n                OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                is not None\n            ), \"worker_nsight_options must be set when profile_steps is set\"\n            wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n            )\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                device_name=self.device_name,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        self.actor_wg = all_wg[\"actor\"]\n        self.rollout_wg = all_wg[\"rollout\"]\n        self.actor_wg.init_model()\n        self.rollout_wg.init_model()\n        self.actor_rollout_wg = self.actor_wg  # for compatibility with functions that were not modified\n        weights_info = self.actor_wg.get_actor_weights_info()[0]\n        self.rollout_wg.set_actor_weights_info(weights_info)\n\n        self.create_weight_sync_group()\n        self.sync_rollout_weights()\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\":\n            from verl.workers.rollout.async_server import AsyncLLMServerManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AsyncLLMServerManager(\n                config=self.config,\n                worker_group=self.rollout_wg,\n            )\n\n    def create_weight_sync_group(self):\n        master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote())\n        master_port = 
ray.get(self.actor_wg.workers[0]._get_free_port.remote())\n        world_size = len(self.actor_wg.workers + self.rollout_wg.workers)\n        self.actor_wg.create_weight_sync_group(\n            master_address,\n            master_port,\n            0,\n            world_size,\n        )\n        ray.get(\n            self.rollout_wg.create_weight_sync_group(\n                master_address,\n                master_port,\n                len(self.actor_wg.workers),\n                world_size,\n            )\n        )\n\n    def sync_rollout_weights(self):\n        if not self.hybrid_engine:\n            self.actor_wg.sync_rollout_weights()\n            ray.get(self.rollout_wg.sync_rollout_weights())\n\n    def _create_continuous_iterator(self):\n        \"\"\"\n        Create a continuous data iterator across epochs\n        \"\"\"\n        for epoch in range(self.config.trainer.total_epochs):\n            iterator = iter(self.train_dataloader)\n            for batch_dict in iterator:\n                yield epoch, batch_dict\n\n    def _async_gen_next_batch(self, continuous_iterator):\n        \"\"\"\n        Call parameter synchronization and asynchronous sequence generation.\n        \"\"\"\n        try:\n            epoch, batch_dict = next(continuous_iterator)\n        except StopIteration:\n            return None\n        except Exception as e:\n            print(f\"Error in async_gen_next_batch: {e}\")\n            return None\n\n        # Create the initial batch from the data loader\n        batch = DataProto.from_single_dict(batch_dict)\n\n        # pop those keys for generation\n        batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n        non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n        if \"multi_modal_data\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n        if \"raw_prompt\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n        if \"tools_kwargs\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n        if \"interaction_kwargs\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n\n        gen_batch = batch.pop(\n            batch_keys=batch_keys_to_pop,\n            non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n        )\n        gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n        # sync weights from actor to rollout\n        self.sync_rollout_weights()\n\n        # async generation\n        gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n\n        # Launch individual reward computations as each generation completes\n        future_reward = None\n        if self.config.reward_model.launch_reward_fn_async:\n            # Store the object reference and set up callback\n            future_reward = self._launch_individual_rewards.remote(\n                gen_batch_output, self.config, self.tokenizer, batch.non_tensor_batch\n            )\n\n        # Return the original, now-modified `batch` and the `future_reward`\n        return GenerationBatchFuture(epoch, batch, gen_batch_output, future_reward)\n\n    @staticmethod\n    @ray.remote\n    def _launch_individual_rewards(gen_batch_output, config, tokenizer, original_non_tensor_batch):\n        # Get generation results\n        gen_batch_result = gen_batch_output.get()\n\n        # Repeat non_tensor_batch 
to match the number of responses\n        n = config.actor_rollout_ref.rollout.n\n        repeated_non_tensor_batch = {}\n        for key, value in original_non_tensor_batch.items():\n            repeated_non_tensor_batch[key] = np.repeat(value, n, axis=0)\n\n        # Split into individual responses with preserved non_tensor_batch\n        responses_split = []\n        for i in range(len(gen_batch_result)):\n            response_data = gen_batch_result[i : i + 1]  # Get single response\n            # Add repeated non_tensor_batch values\n            for key in repeated_non_tensor_batch:\n                response_data.non_tensor_batch[key] = repeated_non_tensor_batch[key][i : i + 1]\n            responses_split.append(response_data)\n\n        # Launch async reward computation\n        reward_futures = [\n            compute_reward_async.remote(response_data, config, tokenizer) for response_data in responses_split\n        ]\n\n        # Wait for results and combine\n        results = ray.get(reward_futures)\n        rewards_list = [r[0] for r in results]\n        extras_list = [r[1] for r in results]\n\n        combined_reward_tensor = torch.cat(rewards_list, dim=0)\n        combined_extras_dict = {}\n        if extras_list and extras_list[0]:\n            for key in extras_list[0].keys():\n                combined_extras_dict[key] = [d[key] for d in extras_list if key in d]\n\n        return combined_reward_tensor, combined_extras_dict\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The lightweight advantage computation is done on the driver process.\n        \"\"\"\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        # iterator across epochs\n        continuous_iterator = self._create_continuous_iterator()\n\n        # Start the first asynchronous generation task.\n        batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n        while batch_data_future is not None:\n            do_profile = (\n                self.global_steps in self.config.global_profiler.steps\n                if self.config.global_profiler.steps is not None\n                else False\n            )\n    
        if do_profile:\n                self.actor_wg.start_profile()\n                if not self.hybrid_engine:\n                    self.rollout_wg.start_profile()\n                if self.use_reference_policy:\n                    self.ref_policy_wg.start_profile()\n                if self.use_critic:\n                    self.critic_wg.start_profile()\n                if self.use_rm:\n                    self.rm_wg.start_profile()\n\n            metrics = {}\n            timing_raw = {}\n            is_last_step = self.global_steps >= self.total_training_steps\n\n            with marked_timer(\"step\", timing_raw):\n                # wait for the previous batch\n                with marked_timer(\"wait_prev_gen\", timing_raw, color=\"red\"):\n                    epoch, batch, gen_batch_output, future_reward = batch_data_future.get()\n                    timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                    gen_batch_output.meta_info.pop(\"timing\", None)\n\n                # asynchronously generate the next batch (syncs weights from actor to rollout)\n                with marked_timer(\"sync_rollout_weights\", timing_raw, color=\"purple\"):\n                    if not is_last_step:\n                        batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n                batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                )\n                # repeat to align with repeated responses in rollout\n                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                batch = batch.union(gen_batch_output)\n\n                batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                # Balance the number of valid tokens across DP ranks.\n                # NOTE: This usually changes the order of data in the `batch`,\n                # which won't affect the advantage calculation (since it's based on uid),\n                # but might affect the loss calculation (due to the change of mini-batching).\n                # TODO: Decouple the DP balancing and mini-batching.\n                if self.config.trainer.balance_batch:\n                    self._balance_batch(batch, metrics=metrics)\n\n                # compute global_valid tokens\n                batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                    # compute reward model score\n                    if self.use_rm:\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    # Use the pre-launched future reward if available\n                    if self.config.reward_model.launch_reward_fn_async:\n                        # future_reward was already started in _async_gen_next_batch\n                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                    else:\n                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                # recompute old_log_probs\n                with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                    old_log_prob = self.actor_wg.compute_log_prob(batch)\n                    entropys = old_log_prob.batch[\"entropys\"]\n                    response_masks = 
batch.batch[\"response_mask\"]\n                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                    old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                    metrics.update(old_log_prob_metrics)\n                    old_log_prob.batch.pop(\"entropys\")\n                    batch = batch.union(old_log_prob)\n\n                    if \"rollout_log_probs\" in batch.batch.keys():\n                        # TODO: we may want to add diff of probs too.\n                        rollout_old_log_probs = batch.batch[\"rollout_log_probs\"]\n                        actor_old_log_probs = batch.batch[\"old_log_probs\"]\n                        attention_mask = batch.batch[\"attention_mask\"]\n                        responses = batch.batch[\"responses\"]\n                        response_length = responses.size(1)\n                        response_mask = attention_mask[:, -response_length:]\n\n                        rollout_probs = torch.exp(rollout_old_log_probs)\n                        actor_probs = torch.exp(actor_old_log_probs)\n                        rollout_probs_diff = torch.abs(rollout_probs - actor_probs)\n                        rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())\n                        rollout_probs_diff_max = torch.max(rollout_probs_diff)\n                        rollout_probs_diff_mean = torch.mean(rollout_probs_diff)\n                        rollout_probs_diff_std = torch.std(rollout_probs_diff)\n                        metrics.update(\n                            {\n                                \"training/rollout_probs_diff_max\": rollout_probs_diff_max.detach().item(),\n                                \"training/rollout_probs_diff_mean\": rollout_probs_diff_mean.detach().item(),\n                                \"training/rollout_probs_diff_std\": rollout_probs_diff_std.detach().item(),\n                            }\n                        )\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                        if not self.ref_in_actor:\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        else:\n                            ref_log_prob = self.actor_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                    # we combine with rule-based rm\n                    reward_extra_infos_dict: dict[str, list]\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    if reward_extra_infos_dict:\n                        batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                    # compute rewards. 
apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)\n                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)\n                    # IS and mismatch metrics already have mismatch/ prefix\n                    metrics.update(is_metrics)\n\n                    # compute advantages, executed on the driver process\n\n                    norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                        \"norm_adv_by_std_in_grpo\", True\n                    )  # GRPO adv normalization factor\n\n                    batch = compute_advantage(\n                        batch,\n                        adv_estimator=self.config.algorithm.adv_estimator,\n                        gamma=self.config.algorithm.gamma,\n                        lam=self.config.algorithm.lam,\n                        num_repeat=self.config.actor_rollout_ref.rollout.n,\n                        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        config=self.config.algorithm,\n                    )\n\n                # update critic\n                if self.use_critic:\n                    with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                        critic_output = self.critic_wg.update_critic(batch)\n                    critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                    metrics.update(critic_output_metrics)\n\n                # implement critic warmup\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                        batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                        actor_output = self.actor_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                # Log rollout generations if enabled\n                rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                if rollout_data_dir:\n                    with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                        inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                        outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                        scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                        self._dump_generations(\n                            inputs=inputs,\n                            outputs=outputs,\n                            scores=scores,\n                            reward_extra_infos_dict=reward_extra_infos_dict,\n                            dump_path=rollout_data_dir,\n                        )\n\n            # validate\n            if (\n   
             self.val_reward_fn is not None\n                and self.config.trainer.test_freq > 0\n                and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n            ):\n                with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                    val_metrics: dict = self._validate()\n                    if is_last_step:\n                        last_val_metrics = val_metrics\n                metrics.update(val_metrics)\n\n            if self.config.trainer.save_freq > 0 and (\n                is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n            ):\n                with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                    self._save_checkpoint()\n\n            # training metrics\n            metrics.update(\n                {\n                    \"training/global_step\": self.global_steps,\n                    \"training/epoch\": epoch,\n                }\n            )\n            # collect metrics\n            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n            # TODO: implement actual tflops and theoretical tflops\n            n_gpus = self.resource_pool_manager.get_n_gpus()\n            metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n\n            # TODO: make a canonical logger that supports various backends\n            logger.log(data=metrics, step=self.global_steps)\n\n            progress_bar.update(1)\n            self.global_steps += 1\n\n            if do_profile:\n                self.actor_wg.stop_profile()\n                if not self.hybrid_engine:\n                    self.rollout_wg.stop_profile()\n                if self.use_reference_policy:\n                    self.ref_policy_wg.stop_profile()\n                if self.use_critic:\n                    self.critic_wg.stop_profile()\n                if self.use_rm:\n                    self.rm_wg.stop_profile()\n\n            if is_last_step:\n                pprint(f\"Final validation metrics: {last_val_metrics}\")\n                progress_bar.close()\n                return\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/sglang_sharding_manager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.utils.debug import GPUMemoryLogger\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.workers.sharding_manager.base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass SGLangShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(self, device_mesh: DeviceMesh):\n        self.device_mesh = device_mesh\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n        self.timing = {}\n        gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n        get_torch_device().manual_seed(gen_dp_rank + 1000)\n        self.gen_random_states = get_torch_device().get_rng_state()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().empty_cache()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = self.device_mesh[\"infer_tp\"].get_group()\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom omegaconf import DictConfig\n\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator\n\n\ndef need_critic(config: DictConfig) -> bool:\n    \"\"\"Given a config, do we need critic\"\"\"\n    if config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n        return True\n    elif config.algorithm.adv_estimator in [\n        AdvantageEstimator.GRPO,\n        AdvantageEstimator.GRPO_PASSK,\n        AdvantageEstimator.REINFORCE_PLUS_PLUS,\n        # AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy\n        AdvantageEstimator.RLOO,\n        AdvantageEstimator.OPO,\n        AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,\n        AdvantageEstimator.GPG,\n    ]:\n        return False\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_distillation/recipe/one_step_off_policy/vllm_sharding_manager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.debug import GPUMemoryLogger\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.workers.sharding_manager.base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass VLLMShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(self, inference_engine, device_mesh: DeviceMesh):\n        self.device_mesh = device_mesh\n        self.inference_engine = inference_engine\n        inference_engine.wake_up()\n        assert device_mesh is not None\n        assert inference_engine is not None\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n        self.timing = {}\n        gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n        get_torch_device().manual_seed(gen_dp_rank + 1000)\n        self.gen_random_states = get_torch_device().get_rng_state()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.gen_random_states = get_torch_device().get_rng_state()\n        self.inference_engine.reset_prefix_cache()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
  {
    "path": "verl_distillation/recipe/onpolicy_distill/__init__.py",
    "content": "# On-policy distillation recipe package.\n\n\n"
  },
  {
    "path": "verl_distillation/recipe/onpolicy_distill/config/onpolicy_distill_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\n# This recipe reuses `ppo_trainer` config and only provides a new training entry.\n# You are expected to override most fields (models, data, rollout, etc.) via CLI.\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\ntrainer:\n  project_name: verl-on-policy-distill\n\n\n"
  },
  {
    "path": "verl_distillation/recipe/onpolicy_distill/main_onpolicy_distill.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other mpain.\n\"\"\"\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.main_ppo import TaskRunner, run_ppo\nfrom verl.utils.import_utils import load_extern_type\n\nfrom .onpolicy_distill_trainer import RayOnPolicyDistillTrainer\n\n\ndef create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):\n    \"\"\"Create a dataset.\n\n    Arguments:\n        data_paths: List of paths to data files.\n        data_config: The data config.\n        tokenizer (Tokenizer): The tokenizer.\n        processor (Processor): The processor.\n\n    Returns:\n        dataset (Dataset): The dataset.\n    \"\"\"\n    from torch.utils.data import Dataset\n\n    from verl.utils.dataset.onerec_dataset import OneRecDataset\n\n    # Check if a custom dataset class is specified in the data configuration\n    # and if the path to the custom class is provided\n    if \"custom_cls\" in data_config and data_config.custom_cls.get(\"path\", None) is not None:\n        # Dynamically load the custom dataset class\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n        # Verify that the custom dataset class inherits from torch.utils.data.Dataset\n        if not issubclass(dataset_cls, Dataset):\n            raise TypeError(\n                f\"The custom dataset class '{data_config.custom_cls.name}' from \"\n                f\"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset\"\n            )\n    elif \"datagen\" in data_config and data_config.datagen.get(\"path\", None) is not None and is_train:\n        # If a data generation strategy is specified, use the DynamicGenDataset class\n        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset\n\n        dataset_cls = DynamicGenDataset\n        print(\"Using DynamicGenDataset for data generation.\")\n    else:\n        # Use the default RLHFDataset class if no custom class is specified\n        dataset_cls = OneRecDataset\n    print(f\"Using dataset class: {dataset_cls.__name__}\")\n\n    # Instantiate the dataset using the determined dataset class\n    dataset = dataset_cls(\n        data_files=data_paths,\n        tokenizer=tokenizer,\n        processor=processor,\n        config=data_config,\n        max_samples=max_samples,\n    )\n\n    return dataset\n\n@ray.remote(num_cpus=1)\nclass OnPolicyDistillTaskRunner(TaskRunner):\n\n    def run(self, config):\n        import os\n        import socket\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.trainer.ppo.reward import load_reward_manager\n        from verl.trainer.ppo.utils import need_critic, need_reference_policy\n        from verl.utils import hf_processor, hf_tokenizer\n        from verl.utils.config 
import validate_config\n        from verl.utils.dataset.rl_dataset import collate_fn\n        from verl.utils.fs import copy_to_local\n        from verl.utils.import_utils import load_extern_type\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        # Initialize role worker mapping\n        self.role_worker_mapping = {}\n        self.mapping = {}\n\n        # Add actor rollout worker based on the actor strategy\n        actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)\n\n        # Add critic worker to role mapping\n        self.add_critic_worker(config)\n\n        # Add reward model worker if enabled\n        self.add_reward_model_worker(config)\n\n        # Add a reference policy worker if KL loss or KL reward is used\n        self.add_ref_policy_worker(config, actor_rollout_cls)\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(self.role_worker_mapping),\n            use_critic=need_critic(config),\n        )\n\n        # Download the checkpoint from HDFS to the local machine\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n        if config.actor_rollout_ref.model.get(\"custom_chat_template\", None) is not None:\n            print(f\"{config.actor_rollout_ref.model.custom_chat_template=}\")\n            if processor is not None:\n                processor.chat_template = config.actor_rollout_ref.model.custom_chat_template\n            if tokenizer is not None:\n                tokenizer.chat_template = config.actor_rollout_ref.model.custom_chat_template\n\n        # Load the reward manager for training and validation\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n\n        # Initialize resource pool manager\n        resource_pool_manager = self.init_resource_pool_mgr(config)\n\n        # Create training and validation datasets\n        from verl.trainer.main_ppo import create_rl_sampler\n\n        train_dataset = create_rl_dataset(\n            config.data.train_files,\n            config.data,\n            tokenizer,\n            processor,\n            is_train=True,\n            max_samples=config.data.get(\"train_max_samples\", -1),\n        )\n        val_dataset = create_rl_dataset(\n            config.data.val_files,\n            config.data,\n            tokenizer,\n            processor,\n            is_train=False,\n            max_samples=config.data.get(\"val_max_samples\", -1),\n        )\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the on-policy distillation trainer with RayOnPolicyDistillTrainer instead of RayPPOTrainer\n        trainer = RayOnPolicyDistillTrainer(\n            
config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=self.role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        # Initialize the workers of the trainer\n        trainer.init_workers()\n\n        # Start the training process\n        trainer.fit()\n\n\n@hydra.main(config_path=\"config\", config_name=\"onpolicy_distill_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for on-policy distillation training with Hydra configuration management.\n\n    Args:\n        config: Hydra configuration containing training parameters.\n    \"\"\"\n    run_ppo(config, task_runner_class=OnPolicyDistillTaskRunner)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/onpolicy_distill/onpolicy_distill_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport gc\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_on_policy_distill_data_metrics, compute_throughout_metrics,\n    compute_timing_metrics)\nfrom verl.trainer.ppo.ray_trainer import (AdvantageEstimator, RayPPOTrainer,\n                                          apply_kl_penalty, compute_advantage,\n                                          compute_response_mask)\nfrom verl.trainer.ppo.reward import compute_reward\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler import marked_timer\nfrom verl.utils.rollout_skip import RolloutSkip\n\n\nclass RayOnPolicyDistillTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict):\n        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n        # recompute old_log_probs\n        with marked_timer(\"old_log_prob\", timing_raw, \"blue\"):\n            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n            entropys = old_log_prob.batch[\"entropys\"]\n            response_masks = batch.batch[\"response_mask\"]\n            loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n            entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n            old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n            metrics.update(old_log_prob_metrics)\n            old_log_prob.batch.pop(\"entropys\")\n            batch = batch.union(old_log_prob)\n\n        if self.use_reference_policy:\n            # compute reference log_prob\n            with marked_timer(\"ref\", timing_raw, \"olive\"):\n                if not self.ref_in_actor:\n                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                else:\n                    ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                batch = batch.union(ref_log_prob)\n\n        return batch\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            
project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n        self.gen_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        if self.config.actor_rollout_ref.rollout.get(\"skip_rollout\", False):\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        self.gen_steps += 1\n        last_val_metrics = None\n\n        prev_step_profile = False\n        curr_step_profile = (\n            self.global_steps in self.config.global_profiler.steps\n            if self.config.global_profiler.steps is not None\n            else False\n        )\n        next_step_profile = False\n\n        timing_raw = defaultdict(float)\n        batch = None\n        num_prompt_in_batch = 0\n        num_gen_batches = 0\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(\n                        not prev_step_profile and curr_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n\n                new_batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                num_gen_batches += 1\n                gen_batch = self._get_gen_batch(new_batch)\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, \"red\"):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)\n\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    new_batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object\n                    )\n           
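         # uid ties together the n responses sampled from the same prompt; the advantage calculation below groups samples by uid.\n          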
          # repeat to align with repeated responses in rollout\n                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    new_batch = new_batch.union(gen_batch_output)\n\n                    batch = new_batch\n                    # === Updating ===\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)\n\n                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)\n                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)\n                    # IS and mismatch metrics already have mismatch/ prefix\n                    metrics.update(is_metrics)\n\n                    with marked_timer(\"adv\", timing_raw, \"brown\"):\n                        # compute advantages, executed on the driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            distill_adv_max_clip=self.config.algorithm.distill_adv_max_clip,\n                            distill_adv_min_clip=self.config.algorithm.distill_adv_min_clip,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        )\n\n                    # update actor\n                    with marked_timer(\"update_actor\", timing_raw, \"red\"):\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                    # pop multi-modal fields (large non-tensor objects) before saving the model\n                    non_tensor_batch_keys_to_pop = []\n                    if \"multi_modal_data\" in batch.non_tensor_batch:\n                        non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n                    if \"multi_modal_inputs\" in batch.non_tensor_batch:\n                        non_tensor_batch_keys_to_pop.append(\"multi_modal_inputs\")\n                    if \"processor_kwargs\" in batch.non_tensor_batch:\n                        non_tensor_batch_keys_to_pop.append(\"processor_kwargs\")\n                    batch.pop(non_tensor_batch_keys=non_tensor_batch_keys_to_pop)\n                    gc.collect()\n\n                    
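# decoded rollouts are dumped below only when trainer.rollout_data_dir is set in the launch config\n                    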
# Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                            inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                            outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                            INVALID_FIELDS = [\"score\", \"index\", \"uid\", \"__num_turns__\", \"multi_modal_inputs\",\n                                              \"sample_reward\", \"raw_prompt\"]\n                            extra_infos = {}\n                            for key in batch.non_tensor_batch.keys():\n                                if key not in INVALID_FIELDS:\n                                    extra_infos[key] = batch.non_tensor_batch[key].tolist()\n                            self._dump_generations(\n                                inputs=inputs,\n                                outputs=outputs,\n                                scores=[0 for _ in range(len(outputs))],\n                                reward_extra_infos_dict=extra_infos,\n                                dump_path=rollout_data_dir,\n                                logger=logger,\n                            )\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw, \"green\"):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with marked_timer(\"save_checkpoint\", timing_raw, \"green\"):\n                        self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    next_step_profile = (\n                        self.global_steps + 1 in self.config.global_profiler.steps\n                        if self.config.global_profiler.steps is not None\n                        else False\n                    )\n                    self._stop_profiling(\n                        curr_step_profile and not next_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                    prev_step_profile = curr_step_profile\n                    curr_step_profile = next_step_profile\n\n                # collect metrics\n                metrics.update(compute_on_policy_distill_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflops and theoretical tflops\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                timing_raw = defaultdict(float)  # clear timing\n\n                
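# num_gen_batches is reset every step; without dynamic-sampling style filtering it is normally 1\n                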
metrics[\"train/num_gen_batches\"] = num_gen_batches\n                batch = None\n                num_prompt_in_batch = 0\n                num_gen_batches = 0\n\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n                self.gen_steps += 1\n        # check if last step checkpint exists\n        checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\")\n        if not os.path.exists(checkpoint_dir):\n            # save last step checkpoint\n            timing_raw = defaultdict(float)\n            with marked_timer(\"save_checkpoint\", timing_raw, \"green\"):\n                self._save_checkpoint()\n            metrics = {f\"timing/{k}\": v for k, v in timing_raw.items()}\n            logger.log(data=metrics, step=self.global_steps)\n"
  },
  {
    "path": "verl_distillation/recipe/onpolicy_distill/run_qwen3_distill.sh",
    "content": "#!/bin/bash\n# On-policy Distillation: distill from a teacher model (e.g., Qwen3-1.7B) to a student model\n# with extended vocabulary (e.g., recommendation pretrained model with item tokens).\n#\n# Usage:\n#   export BASE_MODEL=/path/to/student_model\n#   export TEACHER_MODEL=/path/to/teacher_model\n#   export DATASET_PARQUET=/path/to/train.parquet\n#   bash run_qwen3_1.7b_distill.sh [hostfile]\n\nset -x\nHOME=$(pwd)\ntimestamp=$(date +\"%Y-%m-%d-%H:%M:%S\")\n\n# tmp_hostfile_dir 需要保留,框架需要\nHOSTFILE=\"${1:-/etc/mpi/hostfile}\"\nNODES=$(wc -l < $HOSTFILE)\n\nif [ ! -d \"$HOME/tmp_hostfile_dir\" ]; then\n    mkdir -p \"$HOME/tmp_hostfile_dir\"\nfi\nif [ ! -d \"$HOME/timeline_dir\" ]; then\n    mkdir -p \"$HOME/timeline_dir\"\nfi\ncat $HOSTFILE > \"$HOME/tmp_hostfile_dir/hostfile_$timestamp\"\n\nN_GPUS_PER_NODE=2\n\nproject_name=\"verl_on_policy_distill\"\n\nexperiment_name=\"verl_1.7b_distill_${NODES}_${timestamp}\"\nCKPT_HOME=${CKPT_HOME:-\"$HOME/outputs\"}\nCKPT_DIR=${CKPT_DIR:-\"${CKPT_HOME}/ckpts/${project_name}/${experiment_name}/\"}\n\nrollout_mode=\"async\"\nrollout_name=\"sglang\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\nexport HYDRA_FULL_ERROR=1\n\n# rollout buffer setting\nNUM_WORKER=16\nCUDA_GRAPH_MAX_BS=64\n\n# ===== Open-source friendly defaults =====\n# You MUST set these paths for your own environment.\nexport BASE_MODEL=${BASE_MODEL:-\"\"}\nexport TEACHER_MODEL=${TEACHER_MODEL:-\"\"}\nexport DATASET_PARQUET=${DATASET_PARQUET:-\"$(realpath ../output/onpolicy_distillation.parquet)\"}\n\n# Logging: default is console only.\n# To enable W&B, export WANDB_API_KEY and override trainer.logger:\n#   export WANDB_API_KEY=\"your-key\"\n#   ... trainer.logger='[console,wandb]' ...\nexport WANDB_API_KEY=${WANDB_API_KEY:-\"\"}\n\nif [ -z \"$BASE_MODEL\" ] || [ -z \"$TEACHER_MODEL\" ] || [ -z \"$DATASET_PARQUET\" ]; then\n  echo \"[ERROR] Please set BASE_MODEL / TEACHER_MODEL / DATASET_PARQUET before running.\"\n  echo \"  BASE_MODEL=$BASE_MODEL\"\n  echo \"  TEACHER_MODEL=$TEACHER_MODEL\"\n  echo \"  DATASET_PARQUET=$DATASET_PARQUET\"\n  exit 1\nfi\n\nexport USE_DYNAMIC_BSZ=True # 是否开启动态batch size, 则无视上述batch_size设置，按token数来分配显卡，避免某张显卡处理的token数过多导致OOM显存溢出\nexport MAX_TOKENS_PER_GPU=24000  # n*(prompt_len+response_len)\n\nexport TRAIN_BATCH_SIZE=32\nexport LEARNING_RATE=5e-6\n\n\nexport ROLLOUT_N=1  # 每个prompt的CoT采样数量\nexport BEAM_SIZE_PER_ROLLOUT=1  # 每个CoT的beam search数量\nexport TEMPERATURE=1.1\nexport ENABLE_THINK=True  # 是否在user prompt末尾添加/think\nexport THINK_MODE=\"auto\"\nexport MAX_RESPONSE_LEN=2048\n\nexport DISTILL_ADV_MAX=5.0\nexport DISTILL_ADV_MIN=-30.0\n\n# ===== Extended vocabulary distillation settings =====\n# Token ID threshold: tokens with id >= this value are considered \"extended vocab tokens\"\n# For Qwen3 with OneRec item tokens, 151669 is the start of extended vocabulary.\n# Set to empty string or \"null\" to disable extended vocab handling.\nexport EXTEND_VOCAB_START_TOKEN=151669\n# Whether to mask the entire response if it contains any extended token\nexport MASK_RESPONSE_IF_HAVE_EXTEND_TOKEN=False\n\nexport TRAIN_FILES=$DATASET_PARQUET\nexport VAL_FILES=$DATASET_PARQUET\n\necho \"Training files: $TRAIN_FILES\"\necho \"Validation files: $VAL_FILES\"\n\n\nPYTHONUNBUFFERED=1 python3 -m recipe.onpolicy_distill.main_onpolicy_distill --config-name='onpolicy_distill_trainer'\\\n    +ray_kwargs.ray_init.runtime_env.env_vars.TRACE_GPU_MEM=False \\\n    
+ray_kwargs.ray_init.runtime_env.env_vars.WORK_DIR=$HOME \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.WANDB_API_KEY=\"$WANDB_API_KEY\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.nosp=\"1\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_IB_ECE_ENABLE=\"0\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.CUDA_DEVICE_MAX_CONNECTIONS=\"32\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NVTE_ALLOW_NONDETERMINISTIC_ALGO=\"1\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_NVLS_ENABLE=\"0\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.PYTHONWARNINGS=\"ignore\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_DEBUG=\"VERSION\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_IB_DISABLE=\"0\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_IB_GID_INDEX=\"3\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_ASYNC_ERROR_HANDLING=\"1\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_SOCKET_IFNAME=\"bond0\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_IB_HCA=\"mlx5\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_PXN_DISABLE=\"0\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.NCCL_IB_QPS_PER_CONNECTION=\"4\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.SGLANG_VLM_CACHE_SIZE_MB=\"512\" \\\n    +ray_kwargs.ray_init.runtime_env.env_vars.TIMESTAMP=$timestamp \\\n    algorithm.adv_estimator=on_policy_distill \\\n    data.train_files=$TRAIN_FILES \\\n    data.val_files=$VAL_FILES \\\n    data.max_prompt_length=10240 \\\n    ++data.enable_think=$ENABLE_THINK \\\n    ++data.think_mode=$THINK_MODE \\\n    data.prompt_key=prompt \\\n    data.image_key=dummy \\\n    data.video_key=dummy \\\n    ++data.data_source_key='source' \\\n    data.reward_fn_key='source' \\\n    data.max_response_length=$MAX_RESPONSE_LEN \\\n    data.train_batch_size=$TRAIN_BATCH_SIZE \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=$return_raw_chat \\\n    actor_rollout_ref.actor.use_dynamic_bsz=$USE_DYNAMIC_BSZ \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.rollout.calculate_log_probs=False \\\n    actor_rollout_ref.actor.optim.lr=${LEARNING_RATE} \\\n    actor_rollout_ref.actor.clip_ratio_high=0.28 \\\n    actor_rollout_ref.model.enable_activation_offload=True \\\n    actor_rollout_ref.model.path=$BASE_MODEL \\\n    +actor_rollout_ref.ref.model.path=$TEACHER_MODEL \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.ref_log_prob_replace_val=-100 \\\n    actor_rollout_ref.ref.ref_log_prob_replace_val=-100 \\\n    actor_rollout_ref.rollout.name=$rollout_name \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.extend_vocab_start_token=$EXTEND_VOCAB_START_TOKEN \\\n    actor_rollout_ref.rollout.mask_response_if_have_extend_token=$MASK_RESPONSE_IF_HAVE_EXTEND_TOKEN \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.chunked_prefill_size=16384 \\\n    
+actor_rollout_ref.rollout.engine_kwargs.sglang.cuda_graph_max_bs=$CUDA_GRAPH_MAX_BS \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.max_running_requests=$CUDA_GRAPH_MAX_BS \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.disable_radix_cache=False \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.log_level=info \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.log_requests=False \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.log_requests_level=2 \\\n    actor_rollout_ref.rollout.n=$ROLLOUT_N \\\n    actor_rollout_ref.rollout.temperature=${TEMPERATURE} \\\n    actor_rollout_ref.rollout.top_p=0.95 \\\n    actor_rollout_ref.rollout.top_k=200 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.agent.num_workers=$NUM_WORKER \\\n    actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent \\\n    algorithm.use_kl_in_reward=False \\\n    ++algorithm.distill_adv_max_clip=$DISTILL_ADV_MAX \\\n    ++algorithm.distill_adv_min_clip=$DISTILL_ADV_MIN \\\n    actor_rollout_ref.actor.loss_agg_mode=\"token-mean\" \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    trainer.logger='[console]' \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=$N_GPUS_PER_NODE \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=5 \\\n    trainer.max_actor_ckpt_to_keep=100 \\\n    trainer.test_freq=-1 \\\n    trainer.default_hdfs_dir=null \\\n    trainer.default_local_dir=$CKPT_DIR \\\n    trainer.val_before_train=False \\\n    trainer.val_only=False \\\n    trainer.rollout_data_dir=$HOME \\\n    +trainer.validation_data_dir=$HOME \\\n    +trainer.ray_timeline_dir=$HOME/tmp_hostfile_dir \\\n    trainer.total_epochs=1 2>&1 | tee $project_name-$experiment_name-$timestamp.log\n\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/README.md",
    "content": "# Open math reasoning\n## Introduction\nIn this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with backend agostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate a SFT demo from end to end.\n\nNote that you may need to modify the path as needed in the following scripts.\n## Dataset Preprocessing\n### Download Dataset\n```bash\nhf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning\nhf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24\nhf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25\n```\n\n### Preprocess the dataset\n```bash\npython3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning\n```\n\n### Prepare the eval dataset\n```bash\npython3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset\n```\n\n## Train the model using SFT\n```bash\nexport CKPT_HOME=/path/to/ckpt\nexport MODEL_ID=Qwen/Qwen3-8B-Base\nexport TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet\n```\n\n### FSDP backend\n```bash\nexport BACKEND=fsdp2\nbash recipe/open_math_reasoning/run_sft_qwen3_8b.sh\n```\n\n### Megatron backend\n```bash\nexport BACKEND=megatron\nbash recipe/open_math_reasoning/run_sft_qwen3_8b.sh\n```\n\n## Eval the model\n### Merge checkpoint into huggingface format\nFSDP backend\n```bash\npython -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface\n```\nMegatron backend\n```bash\npython -m verl.model_merger merge --backend megatron --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface --use_cpu_initialization\n```\n\n### Generate the responses\n```bash\nexport MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface\nbash recipe/open_math_reasoning/run_generation.sh\n```\n\n### Evaluate the responses\n```bash\nbash recipe/open_math_reasoning/run_eval.sh\n```\n\nYou should see the results like:\n```python\n{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335}\n```\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/compute_score.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef compute_score_data_source(data_source, response, ground_truth):\n    from verl.utils.reward_score.math_reward import compute_score\n\n    if data_source in [\"aime24\", \"aime25\"]:\n        return compute_score(response, ground_truth)\n    else:\n        raise ValueError(f\"Unknown data source: {data_source}\")\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/prepare_eval_dataset.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# prepare eval dataset including AIME'24, AIME'25\n\n# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24\n# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25\n\nimport os\n\nimport datasets\n\nfrom verl.utils.reward_score.math_reward import remove_boxed\n\ninstruction_following = \"Please reason step by step, and put your final answer within \\\\boxed{}.\"\n\n\ndef make_map_fn(data_source):\n    def process_fn(example, idx):\n        question_raw = example.pop(\"problem\")\n\n        question = question_raw + \" \" + instruction_following\n\n        if \"solution\" not in example:\n            example[\"solution\"] = example[\"answer\"]\n\n        answer_raw = example.pop(\"solution\")\n\n        example.clear()\n\n        try:\n            solution = remove_boxed(answer_raw)\n        except Exception:\n            solution = answer_raw\n\n        data = {\n            \"data_source\": data_source,\n            \"prompt\": [\n                {\n                    \"role\": \"user\",\n                    \"content\": question,\n                }\n            ],\n            \"ability\": \"math\",\n            \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n            \"extra_info\": {\n                \"index\": idx,\n                \"answer\": answer_raw,\n                \"question\": question_raw,\n            },\n        }\n        return data\n\n    return process_fn\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\", default=\"~/data/math-ai\", help=\"The save directory for the preprocessed dataset.\"\n    )\n\n    args = parser.parse_args()\n\n    if args.local_dataset_path is not None:\n        aime24_dataset_path = os.path.join(args.local_dataset_path, \"math-ai/aime24\")\n        aime25_dataset_path = os.path.join(args.local_dataset_path, \"math-ai/aime25\")\n    else:\n        aime24_dataset_path = \"math-ai/aime24\"\n        aime25_dataset_path = \"math-ai/aime25\"\n\n    aime24_dataset = datasets.load_dataset(aime24_dataset_path, split=\"test\")\n    aime25_dataset = datasets.load_dataset(aime25_dataset_path, split=\"test\")\n\n    aime24_dataset = aime24_dataset.map(function=make_map_fn(\"aime24\"), with_indices=True)\n    aime25_dataset = aime25_dataset.map(function=make_map_fn(\"aime25\"), with_indices=True)\n\n    local_save_dir = os.path.expanduser(args.local_save_dir)\n    os.makedirs(local_save_dir, exist_ok=True)\n\n    aime24_dataset.to_parquet(os.path.join(local_save_dir, \"aime24_test.parquet\"))\n    aime25_dataset.to_parquet(os.path.join(local_save_dir, \"aime25_test.parquet\"))\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nhuggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \\\n    --local-dir /path/to/nvidia/OpenMathReasoning\nhuggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \\\n    --local-dir /opt/tiger/nvidia/OpenMathReasoning\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dataset_path\", default=None, help=\"The local path to the raw dataset, if it exists.\")\n    parser.add_argument(\n        \"--local_save_dir\",\n        default=\"~/data/open_math_reasoning\",\n        help=\"The save directory for the preprocessed dataset.\",\n    )\n\n    args = parser.parse_args()\n    local_dataset_path = args.local_dataset_path\n\n    data_source = \"nvidia/OpenMathReasoning\"\n\n    if local_dataset_path is not None:\n        dataset = datasets.load_dataset(local_dataset_path, split=\"cot\")\n    else:\n        dataset = datasets.load_dataset(data_source, split=\"cot\")\n\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question = example.pop(\"problem\")\n            solution = example.pop(\"generated_solution\")\n\n            extra_info = {}\n            for key, value in example.items():\n                extra_info[key] = value\n            example.clear()\n\n            data = {\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": question, \"loss_mask\": 0},\n                    {\"role\": \"assistant\", \"content\": solution, \"loss_mask\": 1},\n                ],\n                \"extra_info\": extra_info,\n            }\n            return data\n\n        return process_fn\n\n    # filter out data where the problem_type is not has_answer_extracted\n    dataset = dataset.filter(lambda example: example[\"problem_type\"] == \"has_answer_extracted\")\n    dataset = dataset.map(function=make_map_fn(\"cot\"), with_indices=True)\n    local_save_dir = os.path.expanduser(args.local_save_dir)\n    os.makedirs(local_save_dir, exist_ok=True)\n    dataset.to_parquet(os.path.join(local_save_dir, \"cot_dataset.parquet\"))\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/run_eval.sh",
    "content": "#!/usr/bin/env bash\n\n# Evaluation\npython3 -m verl.trainer.main_eval \\\n    data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \\\n    custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \\\n    custom_reward_function.name=compute_score_data_source\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/run_generation.sh",
    "content": "#!/usr/bin/env bash\n\nMODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface}\n\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\nNNODES=${NNODES:-1}\nOUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet}\nGEN_TP=${GEN_TP:-1}  # Default tensor parallel size to 2\n\naime24_test_path=${HOME}/data/math-ai/aime24_test.parquet\naime25_test_path=${HOME}/data/math-ai/aime25_test.parquet\ntrain_files=\"['$aime24_test_path', '$aime25_test_path']\"\n\npython3 -m verl.trainer.main_generation_server \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.rollout.top_p=0.7 \\\n    actor_rollout_ref.rollout.prompt_length=2048 \\\n    actor_rollout_ref.rollout.response_length=20480 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\"${GEN_TP}\" \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=32 \\\n    data.train_files=\"$train_files\" \\\n    data.prompt_key=prompt \\\n    +data.output_path=\"${OUTPUT_PATH}\" \\\n\n\n\n"
  },
  {
    "path": "verl_distillation/recipe/open_math_reasoning/run_sft_qwen3_8b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nENTRYPOINT=${ENTRYPOINT:-\"-m verl.trainer.sft_trainer\"}\n\nTRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet}\n\nbackend=${BACKEND:-fsdp}\n\nproject_name=verl_sft_test\n\nRESUME_MODE=auto\nMODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base}\n\nSP_SIZE=${SP_SIZE:-8}\nFSDP_SIZE=${FSDP_SIZE:-16}\nFSDP_STRATEGY=${FSDP_STRATEGY:-\"fsdp2\"}\n\nTP_SIZE=${TP_SIZE:-8}\nPP_SIZE=${PP_SIZE:-2}\nVPP_SIZE=${VPP_SIZE:-null}\nCP_SIZE=${CP_SIZE:-1}\n\nPAD_MODE=${PAD_MODE:-no_padding}\n\nUSE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}\n\nFSDP_ENGINE_CONFIG=\"\\\n    engine=${backend} \\\n    optim=${backend} \\\n    optim.lr=2e-5 \\\n    optim.lr_warmup_steps_ratio=0.01 \\\n    optim.weight_decay=0.1 \\\n    optim.betas=\"[0.9,0.95]\" \\\n    optim.clip_grad=1.0 \\\n    optim.min_lr_ratio=0.1 \\\n    optim.warmup_style=cosine \\\n    engine.ulysses_sequence_parallel_size=${SP_SIZE} \\\n    engine.strategy=${FSDP_STRATEGY} \\\n    engine.fsdp_size=${FSDP_SIZE}\"\n\n\nMEGATRON_ENGINE_CONFIG=\"\\\n    engine=${backend} \\\n    optim=${backend} \\\n    optim.lr=2e-5 \\\n    optim.lr_warmup_steps_ratio=0.01 \\\n    optim.weight_decay=0.1 \\\n    optim.betas=\"[0.9,0.95]\" \\\n    optim.clip_grad=1.0 \\\n    optim.lr_warmup_init=0 \\\n    optim.lr_decay_style=cosine \\\n    optim.min_lr=2e-6 \\\n    engine.tensor_model_parallel_size=${TP_SIZE} \\\n    engine.pipeline_model_parallel_size=${PP_SIZE} \\\n    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \\\n    engine.context_parallel_size=${CP_SIZE} \\\n    engine.use_mbridge=False\"\n\nif [ \"$backend\" = \"fsdp\" ]; then\n    ENGINE_CONFIG=\"$FSDP_ENGINE_CONFIG\"\n    echo \"Using fsdp engine\"\n    exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1\nelse\n    ENGINE_CONFIG=\"$MEGATRON_ENGINE_CONFIG\"\n    echo \"Using megatron engine\"\n    exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-megatron-1018a1\nfi\n\nCKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}}\nmkdir -p \"${CKPT_HOME}\"\n\ntorchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \\\n    ${ENTRYPOINT} \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.train_batch_size=96 \\\n    data.max_length=32768 \\\n    data.pad_mode=${PAD_MODE} \\\n    data.truncation=error \\\n    data.use_dynamic_bsz=True \\\n    data.max_token_len_per_gpu=65536 \\\n    data.messages_key=messages \\\n    model.path=$MODEL_ID \\\n    model.use_remove_padding=${USE_REMOVE_PADDING} \\\n    ${ENGINE_CONFIG} \\\n    trainer.test_freq=-1 \\\n    trainer.save_freq=4000 \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPT_HOME}\" \\\n    trainer.resume_mode=${RESUME_MODE} \\\n    trainer.max_ckpt_to_keep=5 \\\n    checkpoint.save_contents=[model,optimizer,extra]"
  },
  {
    "path": "verl_distillation/recipe/prime/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/prime/config/prime_trainer.yaml",
    "content": "# the prime config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  filter_accuracy: True\n  accuracy_lower_bound: 0.2\n  accuracy_upper_bound: 0.8\n  oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized.\n  filter_truncate: True\n  truncation: right\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    use_remove_padding: True\n  rollout:\n    # number of responses (i.e. num sample times)\n    n: 4\n  actor:\n    entropy_coeff: 0.001\n\nreward_model:\n  enable: True\n  strategy: fsdp\n  model:\n    ref_path: ${reward_model.model.path}\n    use_remove_padding:  True\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n    fused_kernel_options:\n      impl_backend: torch # triton, torch\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}\n    ref_type: freeze\n    fsdp_config:\n      min_num_params: 0\n      param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}\n      optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}\n    update: before # ``before`` for double-forward, ``after`` for single-forward\n    optim:\n      lr: 1e-6\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      min_lr_ratio: null\n      warmup_style: null # deprecated\n      lr_scheduler_type: constant\n      total_training_steps: -1  # must be overridden by program\n      weight_decay: 0.\n      grad_clip: 10.0\n    beta_train: 0.05\n    loss_type: ce # currently only supports ce loss\n  prime_granularity: token\n  prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train\n  mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  reward_manager: prime\n\nalgorithm:\n  adv_estimator: rloo\n  # now supports rloo. it treats different source of reward separately.\n  kl_ctrl:\n    type: fixed\n    kl_coef: 0.000\n  reward_gt_coef: 5\n  reward_dpo_coef: 5\n\ntrainer:\n  project_name: prime\n  experiment_name: examples\n  val_before_train: False\n  balance_batch: False\n"
  },
  {
    "path": "verl_distillation/recipe/prime/main_prime.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.ppo.utils import need_reference_policy\nfrom verl.utils.config import validate_config\n\nfrom .prime_ray_trainer import RayPRIMETrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"prime_trainer\", version_base=None)\ndef main(config):\n    run_prime(config)\n\n\ndef run_prime(config, compute_score=None):\n    if not ray.is_initialized():\n        default_runtime_env = {\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}}\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        # this is for local ray cluster\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    ray.get(main_task.remote(config, compute_score))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\ndef main_task(config, compute_score=None):\n    # print initial config\n    from pprint import pprint\n\n    from omegaconf import OmegaConf\n\n    from verl.utils.fs import copy_local_path_from_hdfs\n\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    # define worker classes\n    if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n        assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.fsdp_workers import ActorRolloutRefWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.megatron_workers import ActorRolloutRefWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    else:\n        
raise NotImplementedError\n\n    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n    }\n\n    global_pool_id = \"global_pool\"\n    resource_pool_spec = {\n        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n    }\n    mapping = {\n        Role.ActorRollout: global_pool_id,\n    }\n\n    # use reference model\n    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        mapping[Role.RefPolicy] = global_pool_id\n\n    if config.reward_model.enable:\n        from .prime_fsdp_workers import PRIMERewardModelWorker\n\n        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)\n        mapping[Role.RewardModel] = global_pool_id\n\n    # validate config\n    # TODO: Additional config checks can be added with proper function under prime recipe\n    validate_config(\n        config=config,\n        use_reference_policy=need_reference_policy(role_worker_mapping),\n        use_critic=False,\n    )\n\n    # download the checkpoint from hdfs\n    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)\n\n    # instantiate tokenizer\n    from verl.utils import hf_tokenizer\n\n    tokenizer = hf_tokenizer(local_path)\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    if reward_manager_name == \"naive\":\n        from verl.workers.reward_manager import NaiveRewardManager\n\n        reward_manager_cls = NaiveRewardManager\n    elif reward_manager_name == \"prime\":\n        from verl.workers.reward_manager import PrimeRewardManager\n\n        reward_manager_cls = PrimeRewardManager\n    else:\n        raise NotImplementedError\n    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)\n\n    # Note that we always use function-based RM for validation\n    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)\n\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n    trainer = RayPRIMETrainer(\n        config=config,\n        tokenizer=tokenizer,\n        role_worker_mapping=role_worker_mapping,\n        resource_pool_manager=resource_pool_manager,\n        ray_worker_group_cls=ray_worker_group_cls,\n        reward_fn=reward_fn,\n        val_reward_fn=val_reward_fn,\n    )\n    trainer.init_workers()\n    trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/prime/prime_core_algos.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nimport verl\nimport verl.utils.torch_functional as verl_F\n\n\ndef compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):\n    # calculate rloo reward on different reward sources, and sum again\n    def masked_rloo(reward_tensor_original, mask_tensor):\n        reward_tensor = reward_tensor_original.clone()\n        reward_tensor[~mask_tensor] = 0\n        for start_pos in range(0, reward_tensor.shape[0], n_samples):\n            cur_rewards_mean = torch.cat(\n                [\n                    reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)\n                    for pos in range(start_pos, start_pos + n_samples)\n                ],\n                dim=0,\n            )\n            cur_rewards_sum = cur_rewards_mean.sum()\n            cur_reward_baseline = cur_rewards_sum / (n_samples - 1)\n            reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (\n                reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]\n                * (n_samples / (n_samples - 1))\n                - cur_reward_baseline\n            )\n\n        return reward_tensor\n\n    reward_tensors = []\n\n    with torch.no_grad():\n        if \"rm_scores\" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:\n            reward_tensor = data.batch[\"rm_scores\"]\n            reward_mask = response_mask.bool()\n\n            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)\n\n        if \"acc\" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:\n            reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)\n            reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)\n\n            prompt_ids = data.batch[\"prompts\"]\n            prompt_length = prompt_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][:, prompt_length:].sum(-1)\n\n            reward_mask[\n                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),\n                valid_response_length - 1,\n            ] = True\n            reward_tensor[\n                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),\n                valid_response_length - 1,\n            ] = data.batch[\"acc\"]\n\n            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)\n\n        final_reward_tensor = sum(reward_tensors)\n\n        returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])\n\n        advantages = returns.clone()\n        advantages = verl_F.masked_whiten(advantages, response_mask)\n\n        return 
advantages, returns\n\n\ndef compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):\n    cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()\n    cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)\n    return cur_dpo_loss\n\n\ndef compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode=\"none\"):\n    # we always assume that the BoN size equals n_samples\n    # mode1: use acc as rm\n    # mode2: use Q as rm\n    cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta\n    other_Q = torch.zeros_like(cur_Q)\n    for i in range(token_level_scores.shape[0]):\n        Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]]\n        if len(Q_chosen) > 0:\n            other_Q[i] = Q_chosen.mean() * beta\n        else:\n            other_Q[i] = 0\n    dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))\n    if bon_mode == \"none\":\n        dpo_loss = dpo_loss.mean()\n    else:\n        weight = torch.zeros_like(dpo_loss)\n        n_samples = acc_bc.shape[1]\n        if bon_mode == \"bon_rm\":\n            for i in range(token_level_scores.shape[0]):\n                weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)\n        elif bon_mode == \"bon_acc\":\n            for i in range(token_level_scores.shape[0]):\n                weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)\n        else:\n            raise NotImplementedError\n        dpo_loss = (dpo_loss * weight).sum()\n\n    return dpo_loss\n\n\ndef compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):\n    dpo_acc = []\n    for start_id in range(0, token_level_scores.shape[0], n_samples):\n        cur_scores = (\n            token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]\n        ).sum(dim=1)\n\n        def get_upper_triangle(tensor_x):\n            diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)\n            upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)\n            return diff_matrix[upper_tri_indices]\n\n        cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples])  # in range [-1,1]\n        cur_score_diff = get_upper_triangle(cur_scores)  # in R\n        cur_score_prediction = (cur_score_diff > 0).float()  # in [0,1]\n        if cur_acc_diff.abs().sum() == 0:\n            cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5\n        else:\n            cur_acc = (\n                ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()\n            ).sum() / cur_acc_diff.abs().sum()\n\n        dpo_acc.append(cur_acc.unsqueeze(0))\n\n    return torch.cat(dpo_acc, dim=0).mean()\n\n\ndef compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):\n    return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()\n"
  },
  {
    "path": "verl_distillation/recipe/prime/prime_dp_rm.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nfrom torch import nn, optim\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\nfrom .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm\n\n__all__ = [\"DataParallelPRIMERewardModel\"]\n\n\nclass DataParallelPRIMERewardModel:\n    def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):\n        self.config = config\n        self.reward_module = reward_module\n        self.ref_module = ref_module\n        self.reward_optimizer = reward_optimizer\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        print(f\"Reward model use_remove_padding={self.use_remove_padding}\")\n        self.use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n        print(f\"Reward model use_fused_kernels={self.use_fused_kernels}\")\n\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n\n    def _forward_micro_batch(self, micro_batch, prompt_length):\n        input_ids = micro_batch[\"input_ids\"]\n        batch_size, seqlen = input_ids.shape\n        attention_mask = micro_batch[\"attention_mask\"]\n        position_ids = micro_batch[\"position_ids\"]\n\n        num_actions = micro_batch[\"input_ids\"].shape[-1] - prompt_length\n        max_positions = micro_batch[\"attention_mask\"][:, prompt_length:].sum(-1)\n\n        if self.use_remove_padding:\n            input_ids_rmpad, indices, *_ = unpad_input(\n                input_ids.unsqueeze(-1), attention_mask\n            )  # input_ids_rmpad (total_nnz, ...)\n            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n            # unpad the position_ids to align the rotary\n            position_ids_rmpad = index_first_axis(\n                rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n            ).transpose(0, 1)\n\n            # for compute the log_prob\n            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n            # pad and slice the inputs if sp > 1\n            if self.ulysses_sequence_parallel_size > 1:\n                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                )\n                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size\n                )\n\n            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)\n            output = self.reward_module(\n                input_ids=input_ids_rmpad,\n                attention_mask=None,\n                position_ids=position_ids_rmpad,\n                use_cache=False,\n                return_dict=self.use_fused_kernels,\n            )\n\n            if self.use_fused_kernels:\n                rm_log_labels = output.log_probs.squeeze(0)  # (total_nnz,)\n                rm_log_labels = rm_log_labels.to(torch.float32)\n\n            else:\n                rm_output_logits = output.logits.squeeze(0)\n                rm_log_labels = verl_F.logprobs_from_logits(\n                    logits=rm_output_logits,\n                    labels=input_ids_rmpad_rolled,\n                )\n\n            if self.ulysses_sequence_parallel_size > 1:\n                rm_log_labels = gather_outputs_and_unpad(\n                    rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                )\n            rm_log_labels = pad_input(\n                hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n            ).squeeze(-1)[:, -num_actions - 1 : -1]\n\n        else:\n            output = self.reward_module(\n                input_ids=micro_batch[\"input_ids\"],\n                attention_mask=micro_batch[\"attention_mask\"],\n                position_ids=micro_batch[\"position_ids\"],\n                use_cache=False,\n                return_dict=self.use_fused_kernels,\n            )\n\n            if self.use_fused_kernels:\n                rm_log_labels = output.log_probs[:, :-1]  # (bsz, seq_length)\n                rm_log_labels = rm_log_labels.to(torch.float32)\n\n            else:\n                rm_output_logits = output.logits\n                rm_log_prob = torch.nn.functional.log_softmax(\n                    rm_output_logits[:, :-1, :], dim=-1\n                )  # (batch_size, seq_length, vocab_size)\n                rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch[\"input_ids\"][:, 1:].unsqueeze(-1)).squeeze(\n                    -1\n                )  # (batch, seq_length)\n\n        if self.ref_module is not None:\n            # do not have to pad again\n            with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n                if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:\n                    ref_output = self.ref_module(\n                        input_ids=input_ids_rmpad,\n                        attention_mask=None,\n                        position_ids=position_ids_rmpad,\n                        use_cache=False,\n                    )\n\n                    if self.use_fused_kernels:\n                        ref_log_labels = 
ref_output.log_probs.squeeze(0)  # (total_nnz,)\n                        ref_log_labels = ref_log_labels.to(torch.float32)\n\n                    else:\n                        ref_output_logits = ref_output.logits.squeeze(0)\n                        ref_log_labels = verl_F.logprobs_from_logits(\n                            logits=ref_output_logits, labels=input_ids_rmpad_rolled\n                        )\n\n                    ref_log_labels = gather_outputs_and_unpad(\n                        ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n                    ref_log_labels = pad_input(\n                        hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n                    ).squeeze(-1)[:, -num_actions - 1 : -1]\n                else:\n                    ref_output = self.ref_module(\n                        input_ids=micro_batch[\"input_ids\"],\n                        attention_mask=micro_batch[\"attention_mask\"],\n                        position_ids=micro_batch[\"position_ids\"],\n                        use_cache=False,\n                    )\n\n                    if self.use_fused_kernels:\n                        ref_log_labels = ref_output.log_probs[:, :-1]  # (batch_size, seq_length)\n                        ref_log_labels = ref_log_labels.to(torch.float32)\n\n                    else:\n                        ref_output_logits = ref_output.logits\n                        ref_log_prob = torch.nn.functional.log_softmax(\n                            ref_output_logits[:, :-1, :], dim=-1\n                        )  # (batch_size, seq_length, vocab_size)\n                        ref_log_labels = ref_log_prob.gather(\n                            dim=-1, index=micro_batch[\"input_ids\"][:, 1:].unsqueeze(-1)\n                        ).squeeze(-1)  # (batch, seq_length)\n\n        else:\n            ref_log_labels = micro_batch[\"old_log_probs\"]\n\n        # Tensor.to() is not in-place, so the cast must be assigned back\n        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)\n        q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:]  # per-token log-prob difference between RM and reference (strictly, the increments of q)\n\n        # trim unnecessary logprobs here\n        for i in range(micro_batch[\"input_ids\"].shape[0]):\n            q[i, max_positions[i] :] = 0\n\n        # reward computation does not need gradients; only q does\n        with torch.no_grad():\n            # generalized estimation of r should go before the reward filling; r is the process reward for the\n            # policy model, i.e., the advantage of the reward model.\n            
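# A hedged illustration with assumed numbers (beta_train=0.05): when lambda == 0, the per-token\n            # reward below is simply r_t = 0.05 * q_t; when lambda > 0, the reversed loop accumulates\n            # lastgaelam = q_t * beta + lambda * lastgaelam, a discount-free lambda-return, and with\n            # prime_use_gt the final valid token is re-anchored to the outcome accuracy first.\n            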
lam = self.config.get(\"lambda\", 0.0)\n            beta = self.config.model.get(\"beta_train\", 0.05)\n            if lam == 0.0:\n                r = q * beta\n            else:\n                # the reward coefficient has no effect here\n                acc = micro_batch[\"acc\"]\n                q_ = q * beta\n                r = torch.zeros_like(q)\n                lastgaelam = 0\n                # overwrite the last token and mask out all padding to simplify the recursion when we rely on\n                # the outcome reward to estimate V\n                for i in range(q.shape[0]):\n                    if self.config.prime_use_gt:\n                        q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum()\n                    q_[i, max_positions[i] :] = 0\n\n                for t in reversed(range(num_actions)):\n                    delta = q_[:, t]\n                    lastgaelam = delta + lam * lastgaelam\n                    r[:, t] = lastgaelam\n\n            token_level_score = torch.zeros_like(q)\n\n            if self.config.prime_granularity == \"token\":\n                for i in range(micro_batch[\"input_ids\"].shape[0]):\n                    token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1]\n            elif self.config.prime_granularity == \"whole\":\n                for i in range(micro_batch[\"input_ids\"].shape[0]):\n                    # collapse the sequence reward onto the final valid token; .sum() is required because\n                    # the left-hand side is a single scalar position\n                    token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]].sum()\n            else:\n                raise NotImplementedError\n\n        return token_level_score, q\n\n    def _optimizer_step(self):\n        assert self.config.model.optim.grad_clip is not None\n\n        if isinstance(self.reward_module, FSDP):\n            grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(\n                self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip\n            )\n        self.reward_optimizer.step()\n        return grad_norm\n\n    def prime_norm(self, token_level_scores):\n        if self.config.prime_norm == \"batch_norm\":\n            reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])\n            token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)\n        return token_level_scores\n\n    def compute_rm_score(self, data: DataProto):\n        self.reward_module.eval()\n        self.ref_module.eval()\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\", \"acc\"]\n        batch = data.select(batch_keys=select_keys).batch\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        prompt_length = data.batch[\"input_ids\"].shape[-1] - data.batch[\"responses\"].shape[-1]\n\n        if use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        rm_scores_lst = []\n        q_lst = []\n        for micro_batch in micro_batches:\n            with torch.no_grad():\n                rm_score, q = 
self._forward_micro_batch(micro_batch, prompt_length)\n            rm_scores_lst.append(rm_score)\n            q_lst.append(q)\n        rm_scores = torch.concat(rm_scores_lst, dim=0)\n        q = torch.concat(q_lst, dim=0)\n\n        rm_scores = self.prime_norm(rm_scores)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == rm_scores.size(0), f\"{len(indices)} vs. {rm_scores.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            rm_scores = rm_scores[revert_indices]\n\n        return (\n            rm_scores,\n            q.detach(),\n            {\n                \"reward_model/reward\": rm_scores.sum(dim=-1).mean().item(),\n                \"reward_model/raw_reward\": q.sum(dim=-1).mean().item(),\n            },\n        )\n\n    def update_rm(self, data: DataProto):\n        # make sure we are in training mode\n        self.reward_module.train()\n        metrics = {}\n\n        beta = self.config.model.get(\"beta_train\", 0.05)\n\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"acc\", \"prompts\"]\n\n        for key in [\"Q_bc\", \"acc_bc\"]:\n            if key in data.batch.keys():\n                select_keys.append(key)\n\n        batch = data.select(batch_keys=select_keys).batch\n        # Split into mini-batches for updating the reward model.\n        # See the PPO paper for details: https://arxiv.org/abs/1707.06347\n        dataloader = batch.split(self.config.mini_batch_size)\n\n        rm_scores_lst = []\n        q_lst = []\n\n        for batch_idx, data in enumerate(dataloader):\n            # split batch into micro_batches\n            mini_batch = data\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n            else:\n                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)\n                self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu\n\n            self.reward_optimizer.zero_grad()\n\n            for data in micro_batches:\n                data = data.to(get_device_name())\n                attention_mask = data[\"attention_mask\"]\n                acc = data[\"acc\"]\n\n                prompt_ids = data[\"prompts\"]\n                prompt_length = prompt_ids.shape[-1]\n\n                response_mask = attention_mask[:, prompt_length:]\n\n                rm_score, q = self._forward_micro_batch(data, prompt_length)\n\n                rm_scores_lst.append(rm_score)\n                q_lst.append(q.detach())\n\n                if self.config.model.loss_type == \"ce\":\n                    dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)\n                elif self.config.model.loss_type == \"dpo\":\n                    # this dpo implementation is detached, so the average win/lose reward must be known\n                    # before the update\n                    dpo_loss = compute_detach_dpo_loss_rm(\n                        q, acc, Q_bc=data[\"Q_bc\"], acc_bc=data[\"acc_bc\"], response_mask=response_mask, beta=beta\n                    )\n                elif self.config.model.loss_type == \"bon_acc\":\n                    # change the original distribution of each sample to the BoN distribution, then update the reward model\n                    
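# (assumed reading, for orientation: the detached DPO update is reweighted as if responses were\n                    # drawn from a Best-of-N sampling distribution; \"bon_acc\" ranks candidates by outcome\n                    # accuracy, \"bon_rm\" by reward-model score)\n                    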
dpo_loss = compute_detach_dpo_loss_rm(\n                        q,\n                        acc,\n                        Q_bc=data[\"Q_bc\"],\n                        acc_bc=data[\"acc_bc\"],\n                        response_mask=response_mask,\n                        beta=beta,\n                        bon_mode=\"bon_acc\",\n                    )\n                elif self.config.model.loss_type == \"bon_rm\":\n                    dpo_loss = compute_detach_dpo_loss_rm(\n                        q,\n                        acc,\n                        Q_bc=data[\"Q_bc\"],\n                        acc_bc=data[\"acc_bc\"],\n                        response_mask=response_mask,\n                        beta=beta,\n                        bon_mode=\"bon_rm\",\n                    )\n                else:\n                    raise NotImplementedError\n\n                # keep the metrics under their own name so they do not shadow the micro-batch `data`,\n                # which is still needed for the loss scaling below\n                micro_batch_metrics = {\"reward_model/dpo_loss\": dpo_loss.detach().item()}\n\n                if self.config.use_dynamic_bsz:\n                    # scale by this micro-batch's share of the mini-batch\n                    loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)\n                else:\n                    loss = dpo_loss / self.gradient_accumulation\n\n                loss.backward()\n\n                append_to_dict(metrics, micro_batch_metrics)\n\n            grad_norm = self._optimizer_step()\n            append_to_dict(metrics, {\"reward_model/grad_norm\": grad_norm.detach().item()})\n        self.reward_optimizer.zero_grad()\n\n        rm_scores = torch.cat(rm_scores_lst, dim=0)\n        q = torch.concat(q_lst, dim=0)\n\n        rm_scores = self.prime_norm(rm_scores)\n\n        metrics.update(\n            {\n                \"reward_model/reward\": rm_scores.sum(dim=-1).mean().item(),\n                \"reward_model/raw_reward\": q.sum(dim=-1).mean().item(),\n            }\n        )\n\n        return rm_scores, metrics\n"
  },
  {
    "path": "verl_distillation/recipe/prime/prime_fsdp_workers.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nimport warnings\n\nimport torch\nimport torch.distributed\nfrom omegaconf import OmegaConf\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl import DataProto\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.device import get_device_id, get_device_name, get_nccl_backend\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_local_path_from_hdfs\nfrom verl.utils.fsdp_utils import (\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.config.optimizer import build_optimizer\nfrom verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nfrom .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass PRIMERewardModelWorker(Worker):\n    def __init__(self, config):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n    
    if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n            assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0\n\n    def _build_reward_ref_model_optimizer(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.fsdp import MixedPrecision\n\n        from verl.utils.model import print_model_size\n        from verl.utils.torch_dtypes import PrecisionType\n\n        local_path = copy_local_path_from_hdfs(config.model.path)\n\n        tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)\n        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        if self.rank == 0:\n            print(f\"Reward model overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        trust_remote_code = False\n        reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        reward_model_config.num_labels = 1\n\n        init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings)\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            reward_model_config.classifier_dropout = 0.0\n            reward_model_config.hidden_dropout = \"0\"\n            reward_module = AutoModelForCausalLM.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                torch_dtype=torch_dtype,\n                config=reward_model_config,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            fused_kernel_options = config.model.get(\"fused_kernel_options\", None)\n            fused_kernels_backend = (\n                fused_kernel_options.get(\"impl_backend\", None) if fused_kernel_options is not None else None\n            )\n\n            apply_monkey_patch(\n                model=reward_module,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_remove_padding=config.model.get(\"use_remove_padding\", False),\n                use_fused_kernels=config.model.get(\"use_fused_kernels\", False),\n                fused_kernels_backend=fused_kernels_backend,\n            )\n\n            # some parameters may not be in torch_dtype\n            reward_module.to(torch_dtype)\n\n            if config.model.get(\"enable_gradient_checkpointing\", False):\n                reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n        if self.rank == 0:\n            
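# log the model size once on rank 0\n            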
print_model_size(reward_module)\n\n        self.reward_model_config = reward_model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)\n\n        log_gpu_memory_usage(\"Before reward model FSDP\", logger=None)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            reward_model_config.classifier_dropout = 0.0\n            reward_model_config.hidden_dropout = \"0\"\n            ref_module = AutoModelForCausalLM.from_pretrained(\n                pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path),\n                torch_dtype=torch_dtype,\n                config=reward_model_config,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            # some parameters may not be in torch_dtype\n            ref_module.to(torch_dtype)\n\n        reward_module = FSDP(\n            reward_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,\n            mixed_precision=mixed_precision,\n            sync_module_states=True,\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n            cpu_offload=None,\n        )\n\n        log_gpu_memory_usage(\"After reward FSDP\", logger=None)\n\n        ref_module = FSDP(\n            ref_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,\n            mixed_precision=mixed_precision,\n            sync_module_states=True,\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n            cpu_offload=None,\n        )\n\n        reward_optimizer = build_optimizer(reward_module.parameters(), config.model.optim)\n\n        total_steps = config.model.optim.get(\"total_training_steps\", 0)\n        num_warmup_steps = int(config.model.optim.get(\"lr_warmup_steps\", -1))\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.model.optim.get(\"lr_warmup_steps_ratio\", 0.0)\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        from verl.utils.torch_functional import get_constant_schedule_with_warmup\n\n        reward_lr_scheduler = 
get_constant_schedule_with_warmup(\n            optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps\n        )\n\n        return reward_module, ref_module, reward_optimizer, reward_lr_scheduler\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from .prime_dp_rm import DataParallelPRIMERewardModel\n\n        self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = (\n            self._build_reward_ref_model_optimizer(config=self.config)\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.reward_optimizer)\n\n        self.rm = DataParallelPRIMERewardModel(\n            config=self.config,\n            reward_module=self.reward_module,\n            ref_module=self.ref_module,\n            reward_optimizer=self.reward_optimizer,\n        )\n\n        self.flops_counter = FlopsCounter(self.reward_model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.reward_module,\n            optimizer=self.reward_optimizer,\n            lr_scheduler=self.reward_lr_scheduler,\n            tokenizer=self.tokenizer,\n        )\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_rm_score(self, data: DataProto):\n        data = data.to(get_device_name())\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n            load_fsdp_model_to_gpu(self.ref_module)\n        micro_batch_size = self.config.micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            rm_scores, q, metrics = self.rm.compute_rm_score(data=data)\n\n            prompt_length = data.batch[\"prompts\"].shape[-1]\n            response_mask = data.batch[\"attention_mask\"][:, prompt_length:]\n            acc = data.batch[\"acc\"]\n\n            dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info[\"n\"])\n            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info[\"n\"])\n\n            metrics[\"reward_model/dpo_acc\"] = dpo_acc.detach().item()\n            metrics[\"reward_model/dpo_acc_abs\"] = dpo_acc_abs.detach().item()\n\n            output = DataProto.from_dict(tensors={\"rm_scores\": rm_scores, \"q\": q}, meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def update_rm(self, data: DataProto):\n        data = data.to(get_device_name())\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.ref_module)\n 
           load_fsdp_model_to_gpu(self.reward_module)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            rm_scores, metrics = self.rm.update_rm(data=data)\n\n            self.reward_lr_scheduler.step()\n            lr = self.reward_lr_scheduler.get_last_lr()[0]\n            metrics[\"rm/lr\"] = lr\n\n            prompt_length = data.batch[\"prompts\"].shape[-1]\n            response_mask = data.batch[\"attention_mask\"][:, prompt_length:]\n            acc = data.batch[\"acc\"]\n\n            dpo_acc_before = compute_dpo_accuracy(\n                rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info[\"n\"]\n            )\n            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info[\"n\"])\n\n            metrics[\"reward_model/dpo_acc_before\"] = dpo_acc_before.detach().item()\n            metrics[\"reward_model/dpo_acc_abs_before\"] = dpo_acc_abs.detach().item()\n\n            output = DataProto.from_dict(tensors={\"rm_scores\": rm_scores}, meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.reward_optimizer)\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, del_local_after_load=True):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n\n        self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load)\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n"
  },
  {
    "path": "verl_distillation/recipe/prime/prime_ray_trainer.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport os\nimport statistics\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom omegaconf import OmegaConf, open_dict\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import _compute_response_info\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager\nfrom verl.trainer.ppo.utils import Role, WorkerType\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path\nfrom verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler.performance import simple_timer\n\nfrom . import prime_core_algos\n\n\ndef compute_advantage(data: DataProto, adv_estimator, config):\n    if adv_estimator == \"rloo\":\n        responses = data.batch[\"responses\"]\n        response_length = responses.size(-1)\n        attention_mask = data.batch[\"attention_mask\"]\n        response_mask = attention_mask[:, -response_length:]\n        advantages, returns = prime_core_algos.compute_rloo_advantage_return(\n            data, response_mask, config.actor_rollout_ref.rollout.n, config\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    else:\n        raise NotImplementedError\n    return data\n\n\ndef compute_data_metrics(batch, use_critic=True):\n    advantages = batch.batch[\"advantages\"]\n    returns = batch.batch[\"returns\"]\n\n    max_response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = batch.batch[\"attention_mask\"][:, -max_response_length:].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(batch)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n\n    if use_critic:\n        values = batch.batch[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n\n    metrics = {\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        
\"critic/returns/max\": torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n                \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": torch.min(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if use_critic\n            else {}\n        ),\n        # response length\n        \"response_length/mean\": torch.mean(response_length).detach().item(),\n        \"response_length/max\": torch.max(response_length).detach().item(),\n        \"response_length/min\": torch.min(response_length).detach().item(),\n        \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float())\n        .detach()\n        .item(),\n        # prompt length\n        \"prompt_length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt_length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt_length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n    }\n    return metrics\n\n\ndef compute_response_mask(data: DataProto):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_timing_metrics(batch, timing_raw):\n    response_info = _compute_response_info(batch)\n    num_prompt_tokens = torch.sum(response_info[\"prompt_length\"]).item()\n    num_response_tokens = torch.sum(response_info[\"response_length\"]).item()\n    num_overall_tokens = num_prompt_tokens + num_response_tokens\n\n    num_tokens_of_section = {\n        \"gen\": num_response_tokens,\n        **{name: num_overall_tokens for name in [\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\"]},\n    }\n\n    return {\n        **{f\"timing_s/{name}\": value for name, value in timing_raw.items()},\n        **{\n            f\"timing_per_token_ms/{name}\": timing_raw[name] * 1000 / num_tokens_of_section[name]\n            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())\n        },\n    }\n\n\nclass RayPRIMETrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        reward_fn=None,\n        val_reward_fn=None,\n        device_name=\"cuda\",\n    ):\n        # assert get_torch_device().is_available(), 'cuda must be available on driver'\n\n        super().__init__(\n            config,\n            tokenizer,\n            role_worker_mapping,\n            resource_pool_manager,\n            ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            device_name=device_name,\n        )\n\n        
self.use_critic = False\n\n    def _create_dataloader(self, *args, **kwargs):\n        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        self.train_dataset = RLHFDataset(\n            data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data\n        )\n        # use sampler for better ckpt resume\n        if self.config.data.shuffle:\n            train_dataloader_generator = torch.Generator()\n            seed = self.config.data.get(\"seed\")\n            if seed is not None:\n                train_dataloader_generator.manual_seed(seed)\n            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)\n        else:\n            sampler = SequentialSampler(data_source=self.train_dataset)\n\n        self.train_dataloader = DataLoader(\n            dataset=self.train_dataset,\n            batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor),\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=sampler,\n        )\n\n        self.val_dataset = RLHFDataset(\n            data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data\n        )\n        self.val_dataloader = DataLoader(\n            dataset=self.val_dataset,\n            batch_size=len(self.val_dataset),\n            shuffle=True,\n            drop_last=True,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1\n        assert len(self.val_dataloader) >= 1\n\n        print(f\"Size of train dataloader: {len(self.train_dataloader)}\")\n        print(f\"Size of val dataloader: {len(self.val_dataloader)}\")\n\n        # inject total_training_steps to actor/critic optim_config. 
This is hacky.\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        OmegaConf.set_struct(self.config, True)\n        with open_dict(self.config):\n            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n            self.config.critic.optim.total_training_steps = total_training_steps\n\n    def _save_checkpoint(self):\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path,\n            actor_remote_path,\n            self.global_steps,\n        )\n\n        if self.use_rm:\n            reward_local_path = os.path.join(local_global_step_folder, \"reward\")\n            reward_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"reward\")\n            )\n            self.rm_wg.save_checkpoint(\n                reward_local_path,\n                reward_remote_path,\n                self.global_steps,\n            )\n\n        # save dataloader\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        import dill\n\n        torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert 
isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        reward_path = os.path.join(global_step_folder, \"reward\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load rm\n        if self.use_rm:\n            self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)\n\n        # load dataloader,\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        self.train_dataloader = torch.load(dataloader_local_path)\n        if isinstance(self.train_dataloader.dataset, RLHFDataset):\n            self.train_dataloader.dataset.resume_dataset_state()\n\n    def compute_reward(self, batch: DataProto, n_samples: int):\n        update_style = self.config.reward_model.model.get(\"update\", \"none\")\n        reward_output_metrics = {}\n        if update_style == \"none\":  # only run forward\n            reward_output = self.rm_wg.compute_rm_score(batch)\n        elif update_style == \"after\":  # update and directly return the reward\n            reward_output = self.rm_wg.update_rm(batch)\n        elif update_style == \"before\":  # update reward model, and then run forward\n            reward_output = self.rm_wg.update_rm(batch)\n            if \"metrics\" in reward_output.meta_info.keys():\n                reward_output_metrics = reduce_metrics(reward_output.meta_info[\"metrics\"])\n\n            reward_output = self.rm_wg.compute_rm_score(batch)\n        elif update_style == \"reverse\":  # run forward to calculate statistics, then update reward model\n            reward_output = self.rm_wg.compute_rm_score(batch)\n\n            # broadcast q and acc tensor to each result\n            bc_td = DataProto.from_dict(\n                tensors={\n                    \"Q_bc\": reward_output.batch[\"q\"]\n                    .sum(dim=-1)\n                    .view(-1, n_samples)\n                    .unsqueeze(1)\n                    .expand(-1, n_samples, -1)\n                    .reshape(-1, n_samples),\n                    \"acc_bc\": batch.batch[\"acc\"]\n                    .view(-1, n_samples)\n                    .unsqueeze(1)\n                    .expand(-1, n_samples, -1)\n                    .reshape(-1, n_samples),\n                }\n            )\n            batch = batch.union(bc_td)\n            reward_output = self.rm_wg.update_rm(batch)\n        else:\n            raise NotImplementedError\n\n        return reward_output, reward_output_metrics\n\n    def fit(self):\n        \"\"\"\n        The 
training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC to\n        construct the PPO dataflow. The lightweight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # we start from step 1\n        self.global_steps += 1\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                gen_batch = batch.pop(batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"])\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with simple_timer(\"gen\", timing_raw):\n                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == \"remax\":\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            batch = batch.union(gen_baseline_output)\n                            rm_scores, _ = self.compute_reward(batch, 1)\n                            reward_baseline_tensor = rm_scores.batch.get(\n                                \"rm_scores\", rm_scores.batch.get(\"acc_bc\", None)\n                            )\n                            if reward_baseline_tensor is None:\n                                raise ValueError(\n                                    \"Neither 'rm_scores' nor 'acc_bc' found in reward model output for baseline.\"\n                                )\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            keys_to_pop = set(gen_baseline_output.batch.keys())\n         
                   keys_to_pop.update(rm_scores.batch.keys())\n                            batch.pop(batch_keys=list(keys_to_pop))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    # verify\n                    with simple_timer(\"verify\", timing_raw):\n                        scores = self.reward_fn.verify(batch)\n                        metrics[\"acc\"] = statistics.mean(scores)\n\n                    # filter the batch. 1/oversample_factor samples will be kept.\n                    # If there is a filter, prompts passing it will be prioritized.\n\n                    batch = self.filter_and_downsample(scores, batch)\n                    batch.meta_info[\"n\"] = self.config.actor_rollout_ref.rollout.n\n                    n_samples = self.config.actor_rollout_ref.rollout.n\n\n                    # recompute old_log_probs\n                    with simple_timer(\"old_log_prob\", timing_raw):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = compute_response_mask(batch)\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with simple_timer(\"ref\", timing_raw):\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    with simple_timer(\"adv\", timing_raw):\n                        if self.use_rm:\n                            reward_output, reward_output_metrics = self.compute_reward(batch, n_samples)\n                            batch = batch.union(reward_output)\n                            if 
\"metrics\" in reward_output.meta_info.keys():\n                                reward_output_metrics.update(reduce_metrics(reward_output.meta_info[\"metrics\"]))\n                            metrics.update(reward_output_metrics)\n\n                        # compute advantages, executed on the driver process\n                        batch = compute_advantage(\n                            batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config\n                        )\n\n                    # update actor\n                    with simple_timer(\"update_actor\", timing_raw):\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and self.global_steps % self.config.trainer.test_freq == 0\n                    ):\n                        with simple_timer(\"testing\", timing_raw):\n                            val_metrics: dict = self._validate()\n                        metrics.update(val_metrics)\n\n                    if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n\n                self.global_steps += 1\n\n                if self.global_steps >= self.total_training_steps:\n                    # perform validation after training\n                    if self.val_reward_fn is not None:\n                        val_metrics = self._validate()\n                        pprint(f\"Final validation metrics: {val_metrics}\")\n                        logger.log(data=val_metrics, step=self.global_steps)\n                    if (\n                        self.config.trainer.save_freq > 0\n                        and (self.global_steps - 1) % self.config.trainer.save_freq != 0\n                    ):\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n                    return\n\n    def filter_and_downsample(self, scores, batch: DataProto):\n        \"\"\"\n        downsample the batch according to oversample_factor\n        samples passing the filters will be prioritized\n        \"\"\"\n        n_samples = int(self.config.actor_rollout_ref.rollout.n)\n        reward_matrix = torch.tensor(scores).reshape(-1, n_samples)\n\n        filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool)\n\n        if self.config.data.filter_accuracy:\n            acc_tensor = torch.mean(reward_matrix, dim=-1)\n            filter_mask[\n                (acc_tensor > self.config.data.accuracy_upper_bound)\n                | (acc_tensor < self.config.data.accuracy_lower_bound)\n            ] = False\n\n        if self.config.data.filter_truncate:\n            length_matrix = (\n                
batch.batch[\"attention_mask\"][:, -batch.batch[\"responses\"].shape[-1] :]\n                .sum(dim=-1)\n                .reshape(-1, n_samples)\n            )\n            length_tensor = torch.max(length_matrix, dim=-1)[0]\n            filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False\n\n        reorder_index = torch.argsort(filter_mask, descending=True)\n        reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)\n        batch.reorder(\n            reorder_index[: int(len(batch) // self.config.data.oversample_factor)]\n        )  # this operation is inplace\n\n        return batch\n"
  },
  {
    "path": "verl_distillation/recipe/prime/run_prime_qwen.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\n\n# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nmodel_path=PRIME-RL/Eurus-2-7B-SFT\n# model_path=Qwen/Qwen2.5-0.5B-Instruct\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=64 \\\n    data.val_batch_size=6312 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=3072 \\\n    data.filter_overlong_prompts=True \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=$model_path \\\n    reward_model.micro_batch_size_per_gpu=1 \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=64 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='prime_example' \\\n    trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=64 \\\n    trainer.test_freq=64 \\\n    trainer.total_epochs=15 $@\n"
  },
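`run_prime_qwen.sh` selects `algorithm.adv_estimator=rloo`. For reference, the standard RLOO (leave-one-out) advantage baselines each rollout against the mean reward of its `n - 1` siblings from the same prompt; a minimal sketch of that formula, not verl's actual implementation:

```python
import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """rewards: (num_prompts, n), the n rollout rewards per prompt.

    A_i = r_i - mean_{j != i}(r_j)
    """
    n = rewards.shape[-1]
    loo_mean = (rewards.sum(dim=-1, keepdim=True) - rewards) / (n - 1)
    return rewards - loo_mean

# With rollout.n=4 as in the script above:
print(rloo_advantages(torch.tensor([[1.0, 0.0, 0.0, 1.0]])))
# tensor([[ 0.6667, -0.6667, -0.6667,  0.6667]])
```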
  {
    "path": "verl_distillation/recipe/prime/run_prime_qwen_code.sh",
    "content": "set -x\n\n\n# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data\ncode_train_path=$HOME/data/code/train.parquet\ncode_test_path=$HOME/data/code/test.parquet\n\ntrain_files=\"['$code_train_path']\"\ntest_files=\"['$code_test_path']\"\n\nmodel_path=PRIME-RL/Eurus-2-7B-SFT\n# model_path=Qwen/Qwen2.5-0.5B-Instruct\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=64 \\\n    data.val_batch_size=6312 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=3072 \\\n    data.filter_overlong_prompts=True \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=$model_path \\\n    reward_model.micro_batch_size_per_gpu=1 \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=64 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='prime_example' \\\n    trainer.experiment_name='Eurus-2-7B-SFT-code' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=64 \\\n    trainer.test_freq=64 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_distillation/recipe/r1/README.md",
    "content": "# DeepSeek R1 Reproduction\n\nThis recipe is under development, if you are interested, checkout the TODO list and join this project! https://github.com/volcengine/verl/issues/708 \n\n## Reproducing Evaluation\n\nEval Results of DS-R1-Distill-Qwen2.5-1.5B (k=8)\n\nDataset | Test Results | Reported\n-- | -- | --\nGPQA Diamond | 35.3 | 33.8\nLiveCodeBench | 16.9 | 16.9\nAIME 2024 | 30.4 | 28.9\nCNMO 2024 (en) | 45.1 | -\nCNMO 2024 (zh) | 41.0 | -\n\n---\n\nEval Results (DS-R1)\n\nDataset | Test Results (k=1) | Test Results (k=4) | Reported\n-- | -- | -- | --\nGPQA Diamond | 67.7 | 69.6 | 71.5\nLiveCodeBench | 64.7 | 63.1 | 65.9\nAIME 2024 | 86.7 | 79.2 | 79.8\nCNMO 2024 | 75.0 | 78.5 | 78.8\n"
  },
  {
    "path": "verl_distillation/recipe/r1/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/r1/config/evaluation.yaml",
    "content": "data:\n  path: /tmp/math_Qwen2-7B-Instruct.parquet\n  prompt_key: prompt\n  response_key: responses\n  data_source_key: data_source\n  reward_model_key: reward_model\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then."
  },
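`custom_reward_function.path`/`name` point at a plain Python function; `recipe/r1/reward_score.py` below shows the real dispatcher. A minimal stand-in illustrating the expected signature (the file name and exact-match rule here are purely illustrative):

```python
# my_reward.py (hypothetical), referenced via
# custom_reward_function.path=my_reward.py and custom_reward_function.name=compute_score
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # Toy rule-based scorer: 1.0 on exact match, 0.0 otherwise.
    return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
```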
  {
    "path": "verl_distillation/recipe/r1/data_process.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\n\nfrom datasets import concatenate_datasets, load_dataset\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef example_map_fn(example, idx, process_fn, data_source, ability, split):\n    question, solution = process_fn(example)\n    data = {\n        \"data_source\": data_source,\n        \"prompt\": [{\"role\": \"user\", \"content\": question}],\n        \"ability\": ability,\n        \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n        \"extra_info\": {\"split\": split, \"index\": idx},\n    }\n    return data\n\n\ndef build_aime2024_dataset():\n    def process_aime2024(example):\n        return example[\"Problem\"], str(example[\"Answer\"])\n\n    data_source = \"Maxwell-Jia/AIME_2024\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"train\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_aime2024, data_source=data_source, ability=\"English\", split=\"test\"\n    )\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_gpqa_dimond_dataset():\n    import random\n\n    GPQA_QUERY_TEMPLATE = (\n        \"Answer the following multiple choice question. The last line of your response should be of the following \"\n        \"format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. 
Think step by step before \"\n        \"answering.\\n\\n{Question}\\n\\nA) {A}\\nB) {B}\\nC) {C}\\nD) {D}\"\n    )\n\n    def process_gpqa_diamond(example):\n        choices = [example[\"Incorrect Answer 1\"], example[\"Incorrect Answer 2\"], example[\"Incorrect Answer 3\"]]\n        random.shuffle(choices)\n        gold_index = random.randint(0, 3)\n        choices.insert(gold_index, example[\"Correct Answer\"])\n        query_prompt = GPQA_QUERY_TEMPLATE.format(\n            A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example[\"Question\"]\n        )\n        gold_choice = \"ABCD\"[gold_index]\n        return query_prompt, gold_choice\n\n    data_source = \"Idavidrein/gpqa\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n\n    dataset = load_dataset(data_source, \"gpqa_diamond\", split=\"train\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability=\"Math\", split=\"test\"\n    )\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_cnmo2024_dataset():\n    def process_cnmo2024(example):\n        return example[\"question\"], example[\"answer\"]\n\n    data_source = \"opencompass/LiveMathBench\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n\n    dataset_en = load_dataset(data_source, \"v202412_CNMO_en\", split=\"test\")\n    map_fn_en = partial(\n        example_map_fn, process_fn=process_cnmo2024, data_source=\"opencompass/cnmo2024_en\", ability=\"Math\", split=\"test\"\n    )\n    dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names)\n\n    dataset_zh = load_dataset(data_source, \"v202412_CNMO_cn\", split=\"test\")\n    map_fn_zh = partial(\n        example_map_fn, process_fn=process_cnmo2024, data_source=\"opencompass/cnmo2024_zh\", ability=\"Math\", split=\"test\"\n    )\n    dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names)\n\n    dataset = concatenate_datasets([dataset_en, dataset_zh])\n    return dataset\n\n\ndef build_livecodebench_dataset():\n    import base64\n    import json\n    import pickle\n    import zlib\n\n    def process_livecodebench(example):\n        # Construct Query Prompt\n        # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140\n        query_prompt = (\n            f\"You will be given a question (problem specification) and will generate a correct Python program \"\n            f\"that matches the specification and passes all tests.\\n\\nQuestion: {example['question_content']}\\n\\n\"\n        )\n        if example[\"starter_code\"]:\n            query_prompt += (\n                f\"You will use the following starter code to write the solution to the problem and enclose your \"\n                f\"code within delimiters.\\n```python\\n{example['starter_code']}\\n```\"\n            )\n        else:\n            query_prompt += (\n                \"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test \"\n                \"on the sample inputs). Enclose your code within delimiters as follows. 
Ensure that when the python \"\n                \"program runs, it reads the inputs, runs the algorithm and writes output to STDOUT.\"\n                \"```python\\n# YOUR CODE HERE\\n```\"\n            )\n\n        # Construct test cases\n        public_test_cases = json.loads(example[\"public_test_cases\"])\n        try:\n            private_test_cases = json.loads(example[\"private_test_cases\"])\n        except Exception as e:\n            print(f\"Error loading private test cases: {e}\")\n            private_test_cases = json.loads(\n                pickle.loads(zlib.decompress(base64.b64decode(example[\"private_test_cases\"].encode(\"utf-8\"))))\n            )\n        full_test_cases = public_test_cases + private_test_cases\n\n        metadata = json.loads(example[\"metadata\"])\n        test_cases = {\n            \"inputs\": [t[\"input\"] for t in full_test_cases],\n            \"outputs\": [t[\"output\"] for t in full_test_cases],\n            \"fn_name\": metadata.get(\"func_name\", None),\n        }\n        text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode(\"utf-8\")\n        return query_prompt, text_cases_compressed\n\n    data_source = \"livecodebench/code_generation_lite\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"test\")\n    # R1 Evaluation use LiveCodeBench 24.08-25.01\n    dataset = dataset.filter(lambda line: \"2024-08-00T00:00:00\" <= line[\"contest_date\"] < \"2025-01-00T00:00:00\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability=\"Code\", split=\"test\"\n    )\n\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8)\n    return dataset\n\n\nTASK2DATA = {\n    \"aime2024\": build_aime2024_dataset,\n    \"gpqa_diamond\": build_gpqa_dimond_dataset,\n    \"cnmo2024\": build_cnmo2024_dataset,\n    \"livecodebench\": build_livecodebench_dataset,\n}\nSUPPORTED_TASKS = TASK2DATA.keys()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/r1\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--tasks\", default=\"all\")\n\n    args = parser.parse_args()\n\n    if args.tasks.lower() == \"all\":\n        args.tasks = SUPPORTED_TASKS\n    else:\n        args.tasks = [task.strip() for task in args.tasks.split(\",\") if task.strip()]\n        for task in args.tasks:\n            if task not in SUPPORTED_TASKS:\n                raise NotImplementedError(f\"{task} has not been supported.\")\n\n    datasets = []\n    for task in args.tasks:\n        datasets.append(TASK2DATA[task]())\n    test_dataset = concatenate_datasets(datasets)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
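A quick sanity check on the emitted parquet (assuming the default `--local_dir`); every row follows the schema built by `example_map_fn` above:

```python
import os

import pandas as pd

df = pd.read_parquet(os.path.expanduser("~/data/r1/test.parquet"))
print(df.columns.tolist())
# ['data_source', 'prompt', 'ability', 'reward_model', 'extra_info']
print(df.iloc[0]["prompt"])        # [{'role': 'user', 'content': '...'}]
print(df.iloc[0]["reward_model"])  # {'style': 'rule', 'ground_truth': '...'}
```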
  {
    "path": "verl_distillation/recipe/r1/main_eval.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nOffline evaluate the performance of a generated file using reward model and ground truth verifier.\nThe input is a parquet file that contains N generated sequences and (optional) the ground truth.\n\n\"\"\"\n\nfrom collections import defaultdict\n\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport ray\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.utils.fs import copy_to_local\n\n\n@ray.remote\ndef process_item(config, data_source, response_lst, reward_data):\n    reward_fn = get_custom_reward_fn(config)\n    ground_truth = reward_data[\"ground_truth\"]\n    score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]\n    return data_source, np.mean(score_lst)\n\n\n@hydra.main(config_path=\"config\", config_name=\"evaluation\", version_base=None)\ndef main(config):\n    local_path = copy_to_local(config.data.path)\n    dataset = pd.read_parquet(local_path)\n    responses = dataset[config.data.response_key]\n    data_sources = dataset[config.data.data_source_key]\n    reward_model_data = dataset[config.data.reward_model_key]\n\n    total = len(dataset)\n\n    # Initialize Ray\n    if not ray.is_initialized():\n        ray.init(**OmegaConf.to_container(config.ray_kwargs.get(\"ray_init\", {})))\n\n    # evaluate test_score based on data source\n    data_source_reward = defaultdict(list)\n\n    # Create remote tasks\n    remote_tasks = [\n        process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)\n    ]\n\n    # Process results as they come in\n    with tqdm(total=total) as pbar:\n        while len(remote_tasks) > 0:\n            # Use ray.wait to get completed tasks\n            done_ids, remote_tasks = ray.wait(remote_tasks)\n            for result_id in done_ids:\n                data_source, score = ray.get(result_id)\n                data_source_reward[data_source].append(score)\n                pbar.update(1)\n\n    metric_dict = {}\n    for data_source, rewards in data_source_reward.items():\n        metric_dict[f\"test_score/{data_source}\"] = np.mean(rewards)\n\n    print(metric_dict)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/r1/reward_score.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef reward_func(data_source, solution_str, ground_truth, extra_info=None):\n    if data_source in [\"Maxwell-Jia/AIME_2024\", \"opencompass/cnmo2024_en\", \"opencompass/cnmo2024_zh\"]:\n        from recipe.r1.tasks import math_reward\n\n        return math_reward.compute_score(solution_str, ground_truth)\n    elif data_source == \"Idavidrein/gpqa\":\n        from recipe.r1.tasks import gpqa\n\n        return gpqa.compute_score(solution_str, ground_truth)\n    elif data_source in [\"livecodebench/code_generation_lite\", \"livecodebench/code_generation\"]:\n        from recipe.r1.tasks import livecodebench\n\n        return livecodebench.compute_score(solution_str, ground_truth)\n    else:\n        raise NotImplementedError\n"
  },
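Usage sketch of the dispatcher above (the GPQA-style response string is made up):

```python
from recipe.r1.reward_score import reward_func

# GPQA responses are scored by extracting the trailing "Answer: X" line.
score = reward_func(
    data_source="Idavidrein/gpqa",
    solution_str="The spectra rule out A and B.\n\nAnswer: D",
    ground_truth="D",
)
print(score)  # 1.0
```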
  {
    "path": "verl_distillation/recipe/r1/run_r1_distill_qwen.sh",
    "content": "MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B\nDATA_PATH=/workspace/datasets/r1_bench\n\n# Eval Data Process\npython3 -m recipe.r1.data_process \\\n    --local_dir $DATA_PATH \\\n    --tasks all\n\n# Generation\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$DATA_PATH/test.parquet \\\n    data.prompt_key=prompt \\\n    data.batch_size=1024 \\\n    data.n_samples=8 \\\n    data.output_path=$DATA_PATH/test-output-8.parquet \\\n    model.path=$MODEL_PATH \\\n    rollout.temperature=0.6 \\\n    rollout.top_p=0.95 \\\n    rollout.prompt_length=1024 \\\n    rollout.response_length=32768 \\\n    rollout.tensor_model_parallel_size=1 \\\n    rollout.gpu_memory_utilization=0.9 \\\n    rollout.max_num_batched_tokens=65536\n\n# Evaluation\npython3 -m recipe.r1.main_eval \\\n    data.path=$DATA_PATH/test-output-8.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=responses \\\n    custom_reward_function.path=recipe/r1/reward_score.py \\\n    custom_reward_function.name=reward_func\n"
  },
  {
    "path": "verl_distillation/recipe/r1/tasks/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/r1/tasks/gpqa.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25\nANSWER_PATTERN_MULTICHOICE = r\"(?i)Answer[ \\t]*:[ \\t]*\\$?([A-D])\\$?\"\n\n\ndef compute_score(solution_str, ground_truth) -> float:\n    match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str)\n    extracted_answer = match.group(1) if match else None\n    score = 1.0 if extracted_answer == ground_truth else 0.0\n    return score\n"
  },
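The extraction pattern is case-insensitive and tolerates tabs and optional `$` wrappers around the letter. Note, though, that the matched letter is returned verbatim, so a lowercase match will not equal an uppercase ground truth:

```python
import re

ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"

for text in ("Answer: C", "answer:\t$B$", "Final step.\nANSWER : a"):
    match = re.search(ANSWER_PATTERN_MULTICHOICE, text)
    print(match.group(1) if match else None)
# C
# B
# a   <- returned as-is, so it would score 0.0 against ground truth "A"
```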
  {
    "path": "verl_distillation/recipe/r1/tasks/livecodebench.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport base64\nimport json\nimport multiprocessing\nimport pickle\nimport zlib\n\n# Reuse `run_test` for convenience\nfrom verl.utils.reward_score.prime_code.testing_util import run_test\n\n\ndef _temp_run(in_outs, generation, debug, result, metadata_list, timeout):\n    res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)\n    result.append(res)\n    metadata_list.append(metadata)\n\n\ndef check_correctness(in_outs, generation, timeout, debug=True):\n    \"\"\"Check correctness of code generation with a global timeout.\n    The global timeout is to catch some extreme/rare cases not handled by the timeouts\n    inside `run_test`\"\"\"\n\n    manager = multiprocessing.Manager()\n    result = manager.list()\n    metadata_list = manager.list()\n    p = multiprocessing.Process(\n        target=_temp_run,\n        args=(in_outs, generation, debug, result, metadata_list, timeout),\n    )\n    p.start()\n    p.join(timeout=(timeout + 1) * len(in_outs[\"inputs\"]) + 5)\n    if p.is_alive():\n        p.kill()\n    if not result:\n        # consider that all tests failed\n        result = [[-1 for i in range(len(in_outs[\"inputs\"]))]]\n        if debug:\n            print(\"global timeout\")\n    return result[0], metadata_list[0]\n\n\ndef compute_score(completion, test_cases):\n    solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n\n    # extract test cases\n    try:\n        in_outs = json.loads(test_cases)\n    except Exception as e:\n        print(f\"Error loading test cases: {e}\")\n        in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode(\"utf-8\")))))\n\n    success = False\n    try:\n        res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False)\n        success = all(map(lambda x: x is True, res))\n    except Exception:\n        pass\n\n    return success\n"
  },
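The base64/zlib/pickle fallback mirrors how `recipe/r1/data_process.py` packs oversized LiveCodeBench test cases; a self-contained round trip:

```python
import base64
import json
import pickle
import zlib

test_cases = {"inputs": ["1 2\n"], "outputs": ["3\n"], "fn_name": None}

# Pack as in data_process.py: json -> pickle -> zlib -> base64 string.
packed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8")

# Unpack as in compute_score's fallback path above.
unpacked = json.loads(pickle.loads(zlib.decompress(base64.b64decode(packed.encode("utf-8")))))
assert unpacked == test_cases
```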
  {
    "path": "verl_distillation/recipe/r1/tasks/math_reward.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport contextlib\n\ntry:\n    from math_verify.metric import math_metric\n    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig\nexcept ImportError:\n    print(\"To use Math-Verify, please install it first by running `pip install math-verify`.\")\n\n\ndef compute_score(model_output: str, ground_truth: str) -> bool:\n    verify_func = math_metric(\n        gold_extraction_target=(LatexExtractionConfig(),),\n        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),\n    )\n    ret_score = 0.0\n\n    # Wrap the ground truth in \\boxed{} format for verification\n    ground_truth_boxed = \"\\\\boxed{\" + ground_truth + \"}\"\n    with contextlib.suppress(Exception):\n        ret_score, _ = verify_func([ground_truth_boxed], [model_output])\n\n    return ret_score\n"
  },
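Usage sketch (requires `math-verify`; the solution string is made up). Because the ground truth is wrapped in `\boxed{}` before verification, any matching `\boxed{...}` in the model output should verify:

```python
from recipe.r1.tasks.math_reward import compute_score

print(compute_score("... so the final answer is \\boxed{204}.", "204"))
# expected: 1.0
```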
  {
    "path": "verl_distillation/recipe/retool/README.md",
    "content": "# Retool\n[ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536)\n\n## Overview\n- Base model: [Qwen/Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct)\n- SFT dataset: [JoeYing/ReTool-SFT](https://huggingface.co/datasets/JoeYing/ReTool-SFT)\n- RL dataset: [BytedTsinghua-SIA/DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k)\n- Val dataset: [yentinglin/aime_2025](https://huggingface.co/datasets/yentinglin/aime_2025)\n\n## How it works\n\n Retool's workflow is divided into two key phases:\n\n1.  Cold Start and Supervised Fine Tuning (SFT)\n    \n     The data generation pipeline builds a high-quality dataset containing code-enhanced inference trajectories, and supervised fine-tuning enables the model to master basic Tool call (e.g., code execution) and analysis of the execution results.\n    \n2.  Dynamic Interaction and Policy Optimization (RL).\n    \n     With the verl Reinforcement Learning framework, the model dynamically inserts code blocks during inference and interacts with the sandbox environment in real-time, generating a hybrid trajectory of natural language thinking and code snippets, sending the code to the sandbox for asynchronous execution when code termination markers are detected, and the execution results (success outputs/errors) are fed back to the model for guiding the subsequent inference. This \"think-execute-feedback\" cycle, together with the design of rewards based on the accuracy of the final answer, enables the model to independently optimize the Tool call strategy, and improves the reasoning efficiency and computational accuracy.\n\n## SFT\n1. Data preparation\n```bash\npython3 recipe/retool/retool_sft_preprocess.py\n```\n\n2. Training\n```bash\nbash recipe/retool/run_qwen2-32b_sft.sh\n```\n\nAfter 6 epoches, validation metrics:\n- val-core/aime_2025/acc/mean@30: 0.24\n- val-aux/num_turns/mean: 7.2\n\n## RL\n\n### GRPO\n```bash\nbash recipe/retool/run_qwen2-32b_dapo.sh\n```\n\nAfter 150 steps, validation metrics:\n- val-core/aime_2025/acc/mean@30: 0.6\n- val-aux/num_turns/mean: 10\n\n### PPO\n\n```bash\nbash recipe/retool/run_qwen2-32b_ppo.sh\n```\n\nAfter 250 steps, validation metrics:\n- val-core/aime_2025/acc/mean@30: 0.55\n- val-aux/num_turns/mean: 8.3\n"
  },
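The final-answer reward mentioned above is shaped to encourage tool use: `compute_score` in `recipe/retool/retool.py` (below) adds a small per-turn bonus to wrong answers, capped so a wrong answer never scores above -0.6. A worked example, assuming `math_dapo` returns -1.0 for a wrong answer:

```python
# Wrong answer after num_turns = 6 (i.e., two tool-call rounds):
score = -1.0
tool_call_reward = (6 - 2) / 2 * 0.1           # 0.2
score = min(-0.6, score + tool_call_reward)    # -0.8
print(score)
```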
  {
    "path": "verl_distillation/recipe/retool/retool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport re\nfrom typing import Any\n\nimport datasets\n\nfrom verl.tools.base_tool import OpenAIFunctionToolSchema\nfrom verl.tools.sandbox_fusion_tools import SandboxFusionTool\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.reward_score import math_dapo\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__name__)\n\n\nclass CustomSandboxFusionTool(SandboxFusionTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        self.code_pattern = re.compile(r\"```python(.*?)```\", re.DOTALL)\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        code = parameters[\"code\"]\n        matches = self.code_pattern.findall(code)\n        if matches:\n            code = matches[0].strip()\n\n        # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script\n        lines = code.split(\"\\n\")\n        for i, line in reversed(list(enumerate(lines))):\n            if line == \"\":\n                continue\n            if not lines[i].startswith(\"print\"):\n                lines[i] = f\"print({line})\"\n            break\n        code = \"\\n\".join(lines)\n\n        timeout = parameters.get(\"timeout\", self.default_timeout)\n        language = parameters.get(\"language\", self.default_language)\n        if not isinstance(code, str):\n            code = str(code)\n\n        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n        # sandbox has no score or metrics, use Nones\n        return result, None, None\n\n\nanswer_format = \"\"\"\\nThe answer format must be: \\\\boxed{'The final answer goes here.'}\"\"\"\n\n\nclass CustomRLHFDataset(RLHFDataset):\n    \"\"\"Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.\"\"\"\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(parquet_file)[\"train\"]\n            data_source = \"/\".join(parquet_file.split(\"/\")[-2:])\n            if data_source in [\"Maxwell-Jia/AIME_2024\", \"yentinglin/aime_2025\"]:\n                dataframe = dataframe.map(\n                    self.map_fn, fn_kwargs={\"data_source\": data_source}, remove_columns=dataframe.column_names\n                )\n            else:\n                dataframe = dataframe.map(self.map_fn2, num_proc=16)\n            dataframes.append(dataframe)\n        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n    def map_fn(self, row: dict, *, data_source: str = None):\n   
     if data_source == \"Maxwell-Jia/AIME_2024\":\n            problem, answer = row[\"Problem\"], row[\"Answer\"]\n        elif data_source == \"yentinglin/aime_2025\":\n            problem, answer = row[\"problem\"], row[\"answer\"]\n\n        prompt = problem + answer_format\n        data = {\n            \"data_source\": data_source.split(\"/\")[1].lower(),  # aime_2024, aime_2025\n            \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n            \"ability\": \"MATH\",\n            \"reward_model\": {\"ground_truth\": str(answer)},\n            \"agent_name\": \"tool_agent\",\n        }\n        return data\n\n    def map_fn2(self, row: dict):\n        content = row[\"prompt\"][0][\"content\"]\n        row[\"prompt\"][0][\"content\"] = content + answer_format\n        row[\"agent_name\"] = \"tool_agent\"\n        return row\n\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info):\n    # use \\\\boxed{...} answer\n    result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True)\n\n    # encourage model to call tools\n    num_turns = extra_info[\"num_turns\"]\n    if result[\"score\"] < 0:\n        tool_call_reward = (num_turns - 2) / 2 * 0.1\n        result[\"score\"] = min(-0.6, result[\"score\"] + tool_call_reward)\n\n    if result[\"pred\"] is None:\n        result[\"pred\"] = \"\"\n\n    return result\n"
  },
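A standalone illustration of the two rewrites `CustomSandboxFusionTool.execute` applies before submitting code to the sandbox (the sandbox call itself is omitted): markdown fences are stripped, and the last non-empty line is wrapped in `print(...)` so the sandbox always emits a result.

```python
import re

# Build the fence string programmatically so this example block itself
# stays well-formed; the tool uses the equivalent literal pattern.
TICKS = "`" * 3
code_pattern = re.compile(TICKS + r"python(.*?)" + TICKS, re.DOTALL)

raw = f"{TICKS}python\na = 20\nb = 4\na * b + 4\n{TICKS}"

# 1) Strip the markdown fence if present.
matches = code_pattern.findall(raw)
code = matches[0].strip() if matches else raw

# 2) Wrap the last non-empty line in print(...) unless it already prints.
lines = code.split("\n")
for i, line in reversed(list(enumerate(lines))):
    if line == "":
        continue
    if not lines[i].startswith("print"):
        lines[i] = f"print({line})"
    break
print("\n".join(lines))
# a = 20
# b = 4
# print(a * b + 4)
```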
  {
    "path": "verl_distillation/recipe/retool/retool_sft_preprocess.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nConvert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.\n\"\"\"\n\nimport json\nimport os\nimport re\nfrom typing import Any\n\nimport datasets\nfrom omegaconf import OmegaConf\n\ncode_pattern = re.compile(r\"```python(.*?)```\", re.DOTALL)\n\n\ndef extract_code_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<code>\", \"</code>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    code = content[i + len(start) : j]\n    matches = code_pattern.findall(code)\n    if matches:\n        code = matches[0].strip()\n\n    message = {\n        \"role\": \"assistant\",\n        \"content\": content[:i].strip(),\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\"code\": code},\n                },\n            },\n        ],\n    }\n    return message, content[j + len(stop) :]\n\n\ndef extract_answer_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<answer>\", \"</answer>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    answer = content[:i] + content[i + len(start) : j]\n    message = {\n        \"role\": \"assistant\",\n        \"content\": answer.strip(),\n    }\n    return message, content[j + len(stop) :]\n\n\ndef extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<interpreter>\", \"</interpreter>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    interpreter = content[i + len(start) : j]\n    message = {\n        \"role\": \"tool\",\n        \"content\": interpreter.strip(),\n    }\n    return message, content[j + len(stop) :]\n\n\ndef process(row: dict, *, tools: str):\n    messages = []\n\n    # extract problem\n    content = row[\"messages\"][0][\"content\"]\n    start = \"*user question:*\"\n    i = content.find(start)\n    assert i != -1\n    prompt = content[i + len(start) :].replace(\"<answer>\", \"\").replace(\"</answer>\", \"\").strip()\n    messages.append(\n        {\n            \"role\": \"user\",\n            \"content\": prompt,\n        }\n    )\n\n    # extract multi turns\n    content = row[\"messages\"][1][\"content\"]\n    role = \"assistant\"\n    while len(content) > 0:\n        if role == \"assistant\":\n            message, content = extract_code_message(content)\n            if message is None:\n                message, content = extract_answer_message(content)\n            assert message is not None\n            messages.append(message)\n            role = \"tool\"\n        else:\n            message, content = 
extract_interpreter_message(content)\n            assert message is not None\n            messages.append(message)\n            role = \"assistant\"\n\n    tools = json.loads(tools)\n    return {\"messages\": messages, \"tools\": tools}\n\n\nif __name__ == \"__main__\":\n    tools_config_file = \"recipe/retool/sandbox_fusion_tool_config.yaml\"\n    tools_config = OmegaConf.load(tools_config_file)\n    tool_schema = OmegaConf.to_container(tools_config[\"tools\"][0][\"tool_schema\"])\n    tools = json.dumps([tool_schema])\n\n    data = datasets.load_dataset(\"JoeYing/ReTool-SFT\")[\"train\"]\n    data = data.map(process, fn_kwargs={\"tools\": tools})\n    save_path = os.path.expanduser(\"~/ReTool-SFT/data/train-00000-of-00001.parquet\")\n    data.to_parquet(save_path)\n"
  },
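The preprocessing above flattens each ReTool trace into standard multi-turn tool-calling messages. A made-up example of one resulting row (the `tools` entry is the `code_interpreter` schema from the YAML config below):

```python
example_row = {
    "messages": [
        {"role": "user", "content": "What is 17 * 24?"},
        {
            "role": "assistant",
            "content": "I'll compute this with the interpreter.",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "code_interpreter",
                        "arguments": {"code": "print(17 * 24)"},
                    },
                }
            ],
        },
        {"role": "tool", "content": "408"},
        {"role": "assistant", "content": "The answer is \\boxed{408}."},
    ],
    "tools": [],  # list holding the code_interpreter tool schema (omitted here)
}
```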
  {
    "path": "verl_distillation/recipe/retool/run_gpt_oss_ppo.sh",
    "content": "set -x\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\ndapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k\naime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024\naime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025\nactor_model_path=lmsys/gpt-oss-20b-bf16\ncritic_model_path=$actor_model_path\n\ntrain_files=\"['$dapo_math_17k']\"\ntest_files=\"['$aime_2025']\"\n\n# tool\ntool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml\n\n# wandb\nproject_name=wuxibin_retool\nexperiment_name=gpt-oss-20b-bf16_ppo\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=gae\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=2048\nmax_response_length=16384\nactor_lr=1e-6\ncritic_lr=2e-6\ngae_gamma=1.0\ngae_lam=1.0\n\ncritic_warmup=20\n\ntrain_batch_size=512\nppo_mini_batch_size=512\nn_resp_per_prompt_val=30\n\n# ================= perfomance =================\ninfer_tp=4 # vllm\ntrain_sp=4 # train\n\noffload=True\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 ))\ncritic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    algorithm.gamma=$gae_gamma \\\n    algorithm.lam=$gae_lam \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    +data.apply_chat_template_kwargs.reasoning_effort=medium \\\n    data.truncation='error' \\\n    data.custom_cls.path=recipe/retool/retool.py \\\n    data.custom_cls.name=CustomRLHFDataset \\\n    custom_reward_function.path=recipe/retool/retool.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.model.path=$actor_model_path \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    
actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \\\n    actor_rollout_ref.rollout.multi_turn.format=gpt-oss \\\n    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    critic.optim.lr=$critic_lr \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=$critic_model_path \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \\\n    critic.ulysses_sequence_parallel_size=$train_sp \\\n    critic.model.fsdp_config.param_offload=$offload \\\n    critic.model.fsdp_config.optimizer_offload=$offload \\\n    trainer.critic_warmup=$critic_warmup \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=True \\\n    trainer.log_val_generations=100 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=30 \\\n    trainer.default_local_dir=$default_local_dir \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2-32b_dapo.sh",
    "content": "set -x\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\ndapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k\naime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024\naime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025\nmodel_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372\n\ntrain_files=\"['$dapo_math_17k']\"\ntest_files=\"['$aime_2025']\"\n\n# tool\ntool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml\n\n# wandb\nproject_name=wuxibin_retool\nexperiment_name=qwen2.5-32b_dapo\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=2048\nmax_response_length=16384\nactor_lr=1e-6\n\ntrain_batch_size=512\nppo_mini_batch_size=64\nn_resp_per_prompt=16\nn_resp_per_prompt_val=30\n\n# ================= perfomance =================\ninfer_tp=4 # vllm\ntrain_sp=8 # train\noffload=True\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 ))\nlog_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.custom_cls.path=recipe/retool/retool.py \\\n    data.custom_cls.name=CustomRLHFDataset \\\n    custom_reward_function.path=recipe/retool/retool.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \\\n    
actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.n=$n_resp_per_prompt \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=True \\\n    trainer.log_val_generations=100 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=30 \\\n    trainer.default_local_dir=$default_local_dir \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2-32b_ppo.sh",
    "content": "set -x\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\ndapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k\naime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024\naime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025\nactor_model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372\ncritic_model_path=$actor_model_path\n\ntrain_files=\"['$dapo_math_17k']\"\ntest_files=\"['$aime_2025']\"\n\n# tool\ntool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml\n\n# wandb\nproject_name=wuxibin_retool\nexperiment_name=qwen2.5-32b_ppo\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=gae\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=2048\nmax_response_length=16384\nactor_lr=1e-6\ncritic_lr=2e-6\ngae_gamma=1.0\ngae_lam=1.0\n\ncritic_warmup=20\n\ntrain_batch_size=1024\nppo_mini_batch_size=256\nn_resp_per_prompt_val=30\n\n# ================= perfomance =================\ninfer_tp=4 # vllm\ntrain_sp=4 # train\n\noffload=True\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 ))\ncritic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    algorithm.gamma=$gae_gamma \\\n    algorithm.lam=$gae_lam \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.custom_cls.path=recipe/retool/retool.py \\\n    data.custom_cls.name=CustomRLHFDataset \\\n    custom_reward_function.path=recipe/retool/retool.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.model.path=$actor_model_path \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    
actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    critic.optim.lr=$critic_lr \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=$critic_model_path \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \\\n    critic.ulysses_sequence_parallel_size=$train_sp \\\n    critic.model.fsdp_config.param_offload=$offload \\\n    critic.model.fsdp_config.optimizer_offload=$offload \\\n    trainer.critic_warmup=$critic_warmup \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=True \\\n    trainer.log_val_generations=100 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=30 \\\n    trainer.default_local_dir=$default_local_dir \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2-32b_sft.sh",
    "content": "#!/bin/bash\nset -x\n\nnnodes=2\nnproc_per_node=8\nmaster_addr=\nmaster_port=\n\nexperiment_name=multiturn-sft-qwen-2.5-32b-instruct\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\nTRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nEVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nMODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct\nSAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name\n\ntorchrun --nnodes=$nnodes \\\n     --nproc_per_node=$nproc_per_node \\\n     --master-addr=$master_addr \\\n     --master-port=$master_port \\\n     --node-rank=$node_rank \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$TRAIN_DATA \\\n    data.val_files=$EVAL_DATA \\\n    data.max_length=16384 \\\n    data.train_batch_size=32 \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=$MODEL_PATH \\\n    model.strategy=fsdp \\\n    trainer.default_local_dir=$SAVE_PATH \\\n    trainer.project_name=wuxibin-multiturn-sft \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=6 \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2_7b_dapo.sh",
    "content": "set -x\n\nexport VLLM_USE_V1=1\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\ndapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k\naime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024\naime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025\nmodel_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-7b-instruct/global_step_372\n\ntrain_files=\"['$dapo_math_17k']\"\ntest_files=\"['$aime_2025', '$aime_2024']\"\n\n# tool\ntool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml\n\n# wandb\nproject_name=retool\nexperiment_name=qwen2.5-7b_dapo\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=16\nmax_prompt_length=2048\nmax_response_length=16384\nactor_lr=1e-6\n\ntrain_batch_size=64\nppo_mini_batch_size=16\nn_resp_per_prompt=16\nn_resp_per_prompt_val=30\n\n# ================= perfomance =================\ninfer_tp=4 # vllm\ntrain_sp=4 # train\noffload=True\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 ))\nlog_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.custom_cls.path=recipe/retool/retool.py \\\n    data.custom_cls.name=CustomRLHFDataset \\\n    custom_reward_function.path=recipe/retool/retool.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.enable=True \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \\\n    
actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.n=$n_resp_per_prompt \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=True \\\n    trainer.log_val_generations=20 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.default_local_dir=$default_local_dir \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2_7b_sft.sh",
    "content": "#!/bin/bash\nset -x\n\nnnodes=1\nnproc_per_node=8\nmaster_addr=\nmaster_port=\nnode_rank=${ARNOLD_ID:-0}\n\nproject_name=retool\nexperiment_name=multiturn-sft-qwen-2.5-7b-instruct\n\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\nTRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nEVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nMODEL_PATH=$HDFS_ROOT/model/Qwen2.5-7B-Instruct\nSAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name\n\ntorchrun --nnodes=$nnodes \\\n     --nproc_per_node=$nproc_per_node \\\n     --master-addr=$master_addr \\\n     --master-port=$master_port \\\n     --node-rank=$node_rank \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$TRAIN_DATA \\\n    data.val_files=$EVAL_DATA \\\n    data.max_length=16384 \\\n    data.train_batch_size=32 \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=$MODEL_PATH \\\n    model.strategy=fsdp \\\n    trainer.default_local_dir=$SAVE_PATH \\\n    trainer.project_name=wuxibin-multiturn-sft \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=6 \\\n    trainer.save_freq=62 \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_distillation/recipe/retool/run_qwen2_7b_sft_npu.sh",
    "content": "#!/bin/bash\nset -x\n\nnnodes=1\nnproc_per_node=8\n\nproject_name=retool_sft\nexperiment_name=multiturn-sft-qwen-2.5-7b-instruct\n\nTRAIN_DATA=PATH/TO/ReTool-SFT/data/train-00000-of-00001.parquet\nEVAL_DATA=PATH/TO/ReTool-SFT/data/train-00000-of-00001.parquet\nMODEL_PATH=PATH/TO/Qwen2.5-7B-Instruct\nSAVE_PATH=PATH/TO/checkpoint/$experiment_name\n\ntorchrun --nnodes=$nnodes \\\n     --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$TRAIN_DATA \\\n    data.val_files=$EVAL_DATA \\\n    data.max_length=16384 \\\n    data.train_batch_size=64 \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    data.micro_batch_size_per_gpu=8 \\\n    model.partial_pretrain=$MODEL_PATH \\\n    model.strategy=fsdp \\\n    trainer.default_local_dir=$SAVE_PATH \\\n    trainer.project_name=$project_name \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.logger='[\"console\"]' \\\n    trainer.total_epochs=6 \\\n    trainer.save_freq=10 \\\n    trainer.device=npu \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_distillation/recipe/retool/sandbox_fusion_tool_config.yaml",
    "content": "tools:\n  - class_name: \"recipe.retool.retool.CustomSandboxFusionTool\"\n    config:\n      sandbox_fusion_url: \"http://localhost:8080/run_code\"\n      num_workers: 128\n      enable_global_rate_limit: true\n      rate_limit: 128\n      default_timeout: 30\n      default_language: \"python\"\n      memory_limit_mb: 1024\n      type: native\n\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]\n"
  },
  {
    "path": "verl_distillation/recipe/spin/README.md",
    "content": "# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\n\nThis repository hosts a `verl` recipe inspired by the paper **\"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.\n\n**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:\n\n1.  **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations.\n2.  **Two-Player Game Setup:** A game involving two players acted by a single LLM.\n3.  **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.\n\nPaper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\\*, [Yihe Deng](https://github.com/uclaml/SPIN)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]\n\nverl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n---\n\n## Key Function (compute_online_dpo_loss) and Related works\nSPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). \n\nThis `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.\n\nSpecifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. 
\n**Reference Papers:**\n* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) \n* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) \n* [Some things are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) \n* [Iterative preference learning from human feedback: Bridging theory and practice for RLHF under KL-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)\n* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)\n* [Direct language model alignment from online AI feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)\n\n## Our Online DPO Implementation\n\nOur `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include:\n\n* **No Critic:** Unlike PPO, we omit the value function critic.\n* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for the DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.\n* **Online Preference Generation:** The `compute_onlinedpo_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).\n* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences.\n* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.\n\n---\n## Algorithm\n\nThis recipe implements an online DPO algorithm adapted to the `verl` reinforcement learning framework, providing an alternative to PPO for fine-tuning language models.\n\n**Online Loop:** Instead of maximizing a scalar reward signal as in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training:\n\n1.  **Generation:** The current model generates multiple responses for each prompt in a batch.\n2.  **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on math problems.)\n3.  **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model.\n\n**Connection with SPIN:**\nInstead of using only a fixed target data distribution, the preference labeling in step 2 dynamically reshapes the target data distribution (here, rule-based ranking that selects the better response on math problems). This explores the direction mentioned in Section 7 of the SPIN paper, \"dynamically changing target data distribution\", to potentially elevate LLM performance beyond the ceiling of fixed human-annotated data.\n
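\nThe snippet below is a minimal sketch of step 2 (preference labeling) and the loss inside step 3, using the helpers from `core_algos.py`; the toy reward and log-prob tensors stand in for real rollout and model outputs, and it assumes the repository root is on `PYTHONPATH`:\n\n```python\nimport torch\n\nfrom recipe.spin.core_algos import compute_online_dpo_loss, compute_onlinedpo_pref\n\n# Two prompts, two interleaved responses each: [r1_p0, r2_p0, r1_p1, r2_p1].\ntoken_level_rewards = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]])\nresponse_mask = torch.ones_like(token_level_rewards)\n\n# Boolean mask marking the chosen (higher summed reward) response of each pair.\nchosen = compute_onlinedpo_pref(token_level_rewards, response_mask)\n# chosen -> tensor([True, False, False, True])\n\n# Sequence-level log probs would come from the actor and the reference model;\n# the placeholders here only exercise the loss.\npolicy_logps = torch.tensor([-4.0, -5.0, -6.0, -3.0])\nref_logps = torch.tensor([-4.5, -4.5, -5.5, -4.0])\n\nloss = compute_online_dpo_loss(\n    policy_chosen_logps=policy_logps[chosen],\n    policy_rejected_logps=policy_logps[~chosen],\n    reference_chosen_logps=ref_logps[chosen],\n    reference_rejected_logps=ref_logps[~chosen],\n    beta=0.1,\n)\n```\n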
\n---\n\n## Reproduce the Experiment (Example Setup)\n\nThe following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct.\n\n1.  **Setup Environment (Example using Docker):**\n    ```bash\n    # Start a container with GPU access and shared memory\n    docker run -it --name spin_test --gpus all \\\n        --shm-size=32g \\\n        --ipc=host \\\n        -v /path/to/host/.cache:/root/.cache \\\n        -e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \\\n        lmsysorg/sglang:latest \\\n        /bin/bash\n\n    # Inside the container or on your host machine:\n    # Ensure /tmp is writable\n    mkdir -p /tmp\n    chmod 1777 /tmp\n\n    # Install Python 3.10 (if not present) and venv\n    sudo apt update\n    sudo apt install -y python3.10 python3.10-venv tmux\n    python3 -m ensurepip --upgrade\n\n    # Create and activate a virtual environment\n    python3 -m venv ~/.python/spin_env\n    source ~/.python/spin_env/bin/activate\n\n    # Install uv (fast package installer)\n    python3 -m pip install uv\n    ```\n\n2.  **Install verl and Dependencies:**\n    ```bash\n    # Clone the verl repository\n    cd ~\n    git clone git@github.com:volcengine/verl.git && cd verl\n\n    # Install flash-attn (handle potential build issues)\n    python3 -m uv pip install wheel packaging\n    python3 -m uv pip install flash-attn --no-build-isolation --no-deps\n\n    # Install verl with sglang extras\n    python3 -m uv pip install -e \".[sglang]\"\n    ```\n    *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.*\n\n3.  **Login & Download Data/Model:**\n    ```bash\n    # Login to Weights & Biases (optional, for logging)\n    export WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n    # wandb login\n\n    # Download the GSM8K dataset\n    python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\n    # Download the base model (Example: Qwen2.5-3B-Instruct)\n    huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct\n    ```\n\n4.  **Configure:**\n    * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node).\n    * Pay attention to `actor_rollout_ref.model.path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`.\n\n
5.  **Run Training:**\n    ```bash\n    # Set CUDA visible devices (adjust based on your hardware and config)\n    export CUDA_VISIBLE_DEVICES=0,1,2,3\n\n    # Launch the training script\n    # Ensure run_spin.sh points to the correct config and main script\n    bash recipe/spin/run_spin.sh\n    ```\n\n---\n\n## Configuration\n\n* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).\n* Key configuration sections:\n    * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths.\n    * `actor_rollout_ref`: Paths to the base model (used for the actor and the initial reference), FSDP settings, optimization parameters (learning rate, scheduler).\n    * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function.\n    * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.\n    * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).\n\n---\n\n## Key Files\n\n* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.\n* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.\n* `fsdp_workers.py`: Implements the Ray workers (Actor, Reference) using FSDP.\n* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.\n* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlinedpo_pref`.\n* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.\n* `run_spin.sh` (or similar): Example bash script for launching a training run.\n* `README.md`: This file.\n\n---\n\n## Acknowledgement\n\nWe sincerely thank the `verl` community and our advisors for their contributions and guidance, including (adapted from SPPO):\n\n* [Zixiang Chen](https://sites.google.com/view/zxchen)\n* [Yuhao Yang](https://github.com/yhyang201)\n* [Yifan Zhang](https://github.com/yifanzhang-pro)\n* [Yongan Xiang](https://github.com/BearBiscuit05)\n* [Junrong Lin](https://github.com/ocss884)\n* [Yuxuan Tong](https://github.com/tongyx361)\n* [Guangming Shen](https://github.com/PeterSH6)\n* [Biao He](https://www.linkedin.com/in/biao-he/)\n* [Qingquan Song](https://qingquansong.github.io/)\n* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)\n* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n---\n"
  },
  {
    "path": "verl_distillation/recipe/spin/config/spin_trainer.yaml",
    "content": "# the sppo config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nactor_rollout_ref:\n  actor:\n    dpo_beta: 0.1\n    optim:\n      lr_warmup_steps: 15\n  rollout:\n    name: sglang\n    tensor_model_parallel_size: 2\n    gpu_memory_utilization: 0.5\n    val_kwargs:\n      n: 2  # 2 will trigger validation, 1 will bypass\n\nalgorithm:\n  adv_estimator: null\n\ntrainer:\n  log_val_generations: 0\n  ref_update_freq: 1"
  },
  {
    "path": "verl_distillation/recipe/spin/core_algos.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport torch\n\n\nclass AdaptiveKLController:\n    \"\"\"\n    Adaptive KL controller described in the paper:\n    https://arxiv.org/pdf/1909.08593.pdf\n    \"\"\"\n\n    def __init__(self, init_kl_coef, target_kl, horizon):\n        self.value = init_kl_coef\n        self.target = target_kl\n        self.horizon = horizon\n\n    def update(self, current_kl, n_steps):\n        target = self.target\n        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current_kl, n_steps):\n        pass\n\n\ndef get_kl_controller(kl_ctrl):\n    if kl_ctrl.type == \"fixed\":\n        return FixedKLController(kl_coef=kl_ctrl.kl_coef)\n    elif kl_ctrl.type == \"adaptive\":\n        assert kl_ctrl.horizon > 0, f\"horizon must be larger than 0. Got {kl_ctrl.horizon}\"\n        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)\n    else:\n        raise NotImplementedError\n\n\ndef compute_onlinedpo_pref(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Computes preferences between pairs of sequences based on summed rewards\n    and returns a mask aligned with the interleaved batch.\n\n    Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]\n\n    Args:\n        token_level_rewards: Tensor of shape [batch_size * 2, seq_len]\n        response_mask: Tensor of shape [batch_size * 2, seq_len]\n\n    Returns:\n        torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates\n                      the corresponding entry is the chosen response for its pair.\n                      Example: [True, False, False, True, ...] means for prompt 0,\n                               response 1 was chosen; for prompt 1, response 2 was chosen.\n    \"\"\"\n    # print(f\"---- [DEBUG] Inside compute_onlinedpo_pref ----\")\n    if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:\n        raise ValueError(\n            f\"Input tensor batch dimension must be even for pair comparison, got shapes: \"\n            f\"{token_level_rewards.shape}, {response_mask.shape}\"\n        )\n    if token_level_rewards.shape != response_mask.shape:\n        raise ValueError(f\"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}\")\n\n    # 1. Calculate Sequence Scores\n    scores = (token_level_rewards * response_mask).sum(dim=-1)\n    # print(f\"  Calculated sequence scores shape: {scores.shape}\") # [batch_size * 2]\n\n    # 2. 
Reshape scores to group pairs: [batch_size, 2]\n    try:\n        score_pairs = scores.view(-1, 2)\n    except RuntimeError as e:\n        print(f\"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}\")\n        raise\n    # print(f\"  Reshaped score pairs shape: {score_pairs.shape}\")  # [batch_size, 2]\n\n    # 3. Compare scores to find which index (0 or 1) is the winner within each pair\n    #    winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1\n    winner_indices = torch.argmax(score_pairs, dim=1)  # 0 if first is max, 1 if second is max\n    # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max)\n    # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]\n    # print(f\"  Winner indices shape: {winner_indices.shape}\") # [batch_size]\n    # print(f\"  Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}\") # Counts number of 1s\n\n    # 4. Create the final [batch_size * 2] mask\n    num_pairs = score_pairs.shape[0]\n    full_batch_size = num_pairs * 2\n    # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]\n    # full_indices = torch.arange(full_batch_size, device=scores.device)\n    # Create indices corresponding to the winner within each pair's original index\n    # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]\n    # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4]\n    pair_indices = torch.arange(num_pairs, device=scores.device)\n    winner_global_indices = (pair_indices * 2) + winner_indices\n\n    # Create boolean mask - True at the winner's position\n    output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)\n    output_preference_mask[winner_global_indices] = True\n\n    # print(f\"  Output preference mask shape: {output_preference_mask.shape}\") # Should be [batch_size * 2]\n    # print(f\"  Output mask True count (Chosen): {output_preference_mask.sum().item()}\") # Should be batch_size\n    # print(f\"  Output mask False count (Rejected): {(~output_preference_mask).sum().item()}\") # Should be batch_size\n    # print(f\"---- [DEBUG] Exiting compute_onlinedpo_pref ----\")\n\n    return output_preference_mask\n\n\ndef compute_online_dpo_loss(\n    policy_chosen_logps: torch.Tensor,\n    policy_rejected_logps: torch.Tensor,\n    reference_chosen_logps: torch.Tensor,\n    reference_rejected_logps: torch.Tensor,\n    beta: float,\n    label_smoothing: float = 0.0,\n    loss_type: str = \"sigmoid\",\n    reference_free: bool = False,\n) -> torch.Tensor:\n    import torch.nn.functional as F\n\n    pi_logratios = policy_chosen_logps - policy_rejected_logps\n    ref_logratios = reference_chosen_logps - reference_rejected_logps\n\n    if reference_free:\n        ref_logratios = torch.zeros_like(pi_logratios)\n\n    logits = pi_logratios - ref_logratios\n\n    if loss_type == \"sigmoid\":\n        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing\n    elif loss_type == \"ipo\":\n        losses = (logits - 1 / (2 * beta)) ** 2\n    else:\n        raise ValueError(f\"Unsupported loss_type: {loss_type}. Choose 'sigmoid' or 'ipo'.\")\n\n    return losses.mean()\n\n\ndef get_batch_logps(\n    logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False\n) -> torch.FloatTensor:\n    \"\"\"\n    Compute the log probabilities of the given labels under the given logits.\n\n    Args:\n        logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).\n                Shape: (batch_size, sequence_length, vocab_size)\n        labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length)\n        average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum.\n\n    Returns:\n        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.\n    \"\"\"\n    if logits.shape[:-1] != labels.shape:\n        raise ValueError(\"Logits and labels must have the same shape[:-1]\")\n\n    # Ensure labels are contiguous and on the same device as logits\n    labels = labels.contiguous().to(logits.device)\n    # Shift so that tokens < n predict n\n    shift_logits = logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n\n    # Calculate per token log probability\n    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction=\"none\")\n    per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n    per_token_logps = per_token_logps.view(\n        shift_logits.size(0), shift_logits.size(1)\n    )  # Reshape back to (batch_size, seq_len-1)\n\n    # Create a mask for the labels that are not -100\n    loss_mask = shift_labels != -100\n\n    # Apply the mask to the per token log probabilities\n    masked_logps = per_token_logps * loss_mask\n\n    # Calculate the sum or average log probability per sequence\n    sequence_logps = masked_logps.sum(dim=-1)\n\n    if average_log_prob:\n        # Avoid division by zero for sequences with no valid tokens\n        num_valid_tokens = loss_mask.sum(dim=-1)\n        return sequence_logps / torch.clamp(num_valid_tokens, min=1)\n    else:\n        return sequence_logps\n"
  },
  {
    "path": "verl_distillation/recipe/spin/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport itertools\nimport math\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\n\nfrom recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.workers.actor import DataParallelPPOActor\n\n__all__ = [\"DataParallelPPOActor\"]\n\n\nclass SPINDataParallelPPOActor(DataParallelPPOActor):\n    def compute_log_prob(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. 
torch.int64.\n\n        Returns:\n            torch.Tensor: the log_prob tensor\n        \"\"\"\n        # set to eval\n        self.actor_module.eval()\n\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        batch = data.select(batch_keys=select_keys).batch\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n        if has_multi_modal_inputs:\n            num_micro_batches = data.batch.batch_size[0] // micro_batch_size\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n        elif use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        log_probs_lst = []\n        for micro_batch in micro_batches:\n            if isinstance(micro_batch, DataProto):\n                micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n\n            with torch.no_grad():\n                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)\n            log_probs_lst.append(log_probs)\n        log_probs = torch.concat(log_probs_lst, dim=0)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            log_probs = log_probs[revert_indices]\n\n        return log_probs\n\n    def update_policy_dpo_with_ref(self, data: DataProto):\n        \"\"\"\n        Performs the DPO update step using pre-calculated reference log probs\n        from an external, periodically updated reference model.\n        \"\"\"\n        self.actor_module.train()  # Ensure training mode\n\n        # --- Retrieve necessary data ---\n        try:\n            # Expects batch prepared by fit_dpo loop, including reference log probs\n            batch_td = data.batch\n            chosen_labels = batch_td[\"chosen_labels\"]\n            rejected_labels = batch_td[\"rejected_labels\"]\n            # ... 
other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ...\n\n            # === Get PRE-CALCULATED reference log probs from input data ===\n            reference_chosen_logps = batch_td[\"reference_chosen_logps\"]  # Should be sequence-level logps\n            reference_rejected_logps = batch_td[\"reference_rejected_logps\"]  # Should be sequence-level logps\n            # ============================================================\n\n            # Get DPO params from meta_info\n            # beta = data.meta_info.get('dpo_beta', 0.1) # Default beta\n            beta = self.config.get(\"dpo_beta\", 0.1)  # Default beta\n            loss_type = data.meta_info.get(\"dpo_loss_type\", \"sigmoid\")\n            label_smoothing = data.meta_info.get(\"dpo_label_smoothing\", 0.0)\n            # reference_free should now be False as we provide ref logps\n            reference_free = data.meta_info.get(\"reference_free\", False)  # Default False\n\n        except KeyError as e:\n            print(f\"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}\")\n            print(f\"Available keys in data.batch: {list(batch_td.keys())}\")  # Debug print\n            return {}  # Return empty metrics on error\n        except Exception as e_data:\n            print(f\"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}\")\n            return {}\n\n        # --- Micro-batching Setup ---\n        micro_batch_size = self.config.get(\"ppo_micro_batch_size_per_gpu\")\n        if micro_batch_size is None:\n            # Fallback or default if not set, or raise error\n            micro_batch_size = 1  # Example fallback, adjust as needed\n            print(f\"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}\")\n            # raise ValueError(\"Config 'ppo_micro_batch_size_per_gpu' must be set.\")\n\n        # Ensure chosen_input_ids exists before getting shape\n        if \"chosen_input_ids\" not in batch_td:\n            print(\"ERROR: 'chosen_input_ids' not found in batch_td for DPO update.\")\n            return {}\n        bsz = batch_td[\"chosen_input_ids\"].shape[0]\n\n        if bsz == 0:\n            print(\"Warning: DPO batch size is 0 in update_policy_dpo. 
Skipping update.\")\n            return {\"actor/dpo_loss\": 0.0, \"actor/grad_norm\": 0.0}  # Return zero metrics if batch is empty\n\n        num_micro_batches = math.ceil(bsz / micro_batch_size)\n        gradient_accumulation_steps = num_micro_batches\n\n        # --- Metrics Accumulation ---\n        total_loss = 0.0\n        accumulated_metrics = defaultdict(list)\n        metrics = {}  # Final metrics dict\n\n        # --- Zero Gradients ---\n        self.actor_optimizer.zero_grad(set_to_none=True)\n\n        # --- Micro-batch Loop ---\n        for i in range(num_micro_batches):\n            start_idx = i * micro_batch_size\n            end_idx = min(start_idx + micro_batch_size, bsz)\n            if start_idx >= end_idx:\n                continue\n\n            # Slice the full DPO batch into micro-batches\n            # Important: Slice ALL required tensors, including labels and inputs\n            micro_batch_chosen_labels = chosen_labels[start_idx:end_idx]\n            micro_batch_rejected_labels = rejected_labels[start_idx:end_idx]\n            micro_batch_chosen_inputs = {\n                \"input_ids\": batch_td[\"chosen_input_ids\"][start_idx:end_idx],\n                \"attention_mask\": batch_td[\"chosen_attention_mask\"][start_idx:end_idx],\n            }\n            if \"chosen_position_ids\" in batch_td:\n                micro_batch_chosen_inputs[\"position_ids\"] = batch_td[\"chosen_position_ids\"][start_idx:end_idx]\n\n            micro_batch_rejected_inputs = {\n                \"input_ids\": batch_td[\"rejected_input_ids\"][start_idx:end_idx],\n                \"attention_mask\": batch_td[\"rejected_attention_mask\"][start_idx:end_idx],\n            }\n            if \"rejected_position_ids\" in batch_td:\n                micro_batch_rejected_inputs[\"position_ids\"] = batch_td[\"rejected_position_ids\"][start_idx:end_idx]\n\n            # Determine autocast dtype\n            autocast_dtype = torch.bfloat16  # Or get dynamically from config/FSDP settings\n            # --- Autocast Forward Pass ---\n            with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype):\n                # --- Step 1: Forward pass for CURRENT policy log probs (with grad) ---\n                policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False)\n                policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False)\n\n                # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps ---\n                policy_chosen_logps = get_batch_logps(\n                    policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False\n                )\n                policy_rejected_logps = get_batch_logps(\n                    policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False\n                )\n\n                # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) ---\n                # Slice the full batch reference logps for the current micro-batch\n                micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx]\n                micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx]\n                # --- The ActorAsRef calculation block is REMOVED ---\n\n                # --- Step 4: Calculate DPO Logits and Loss ---\n                pi_logratios = policy_chosen_logps - policy_rejected_logps\n                ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps  # 
Uses pre-calculated values\n                logits = pi_logratios - ref_logratios  # DPO logits\n\n                loss = compute_online_dpo_loss(\n                    policy_chosen_logps=policy_chosen_logps,  # Has grad\n                    policy_rejected_logps=policy_rejected_logps,  # Has grad\n                    reference_chosen_logps=micro_ref_chosen_logps,  # No grad (from input)\n                    reference_rejected_logps=micro_ref_rejected_logps,  # No grad (from input)\n                    beta=beta,\n                    label_smoothing=label_smoothing,\n                    loss_type=loss_type,\n                    reference_free=reference_free,  # Should be False now\n                )\n\n                # --- Scale loss for gradient accumulation ---\n                scaled_loss = loss / gradient_accumulation_steps\n\n                # --- Accumulate Metrics ---\n                total_loss += loss.item()  # Unscaled loss\n                accumulated_metrics[\"actor/dpo_loss_batch\"].append(loss.item())\n                accumulated_metrics[\"actor/dpo_logits_batch\"].append(logits.mean().item())\n                # Accumulate policy and reference log probs/ratios if needed for debugging\n                accumulated_metrics[\"actor/policy_chosen_logps_batch\"].append(policy_chosen_logps.mean().item())\n                accumulated_metrics[\"actor/policy_rejected_logps_batch\"].append(policy_rejected_logps.mean().item())\n                accumulated_metrics[\"actor/reference_chosen_logps_batch\"].append(micro_ref_chosen_logps.mean().item())\n                accumulated_metrics[\"actor/reference_rejected_logps_batch\"].append(\n                    micro_ref_rejected_logps.mean().item()\n                )\n\n            # --- Backward Pass (outside autocast) ---\n            # Check if loss requires grad before backward\n            if scaled_loss.requires_grad:\n                scaled_loss.backward()\n            else:\n                print(f\"Warning: Scaled loss at micro-batch {i} does not require grad. 
Skipping backward.\")\n\n        # --- End Micro-batch Loop ---\n\n        # --- Optimizer Step (after accumulating gradients for all micro-batches) ---\n        grad_norm = self._optimizer_step()\n\n        # --- Populate Final Metrics ---\n        if num_micro_batches > 0 and bsz > 0:  # Check if any processing happened\n            metrics[\"actor/dpo_loss\"] = total_loss / num_micro_batches\n            metrics[\"actor/grad_norm\"] = (\n                grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float(\"inf\")\n            )\n            # Average other accumulated metrics\n            for key, val_list in accumulated_metrics.items():\n                if val_list:\n                    metrics[key.replace(\"_batch\", \"\")] = np.mean(val_list)\n\n            # Calculate accuracy / rewards / margins based on averaged logprobs if desired\n            if (\n                \"actor/policy_chosen_logps\" in metrics\n                and \"actor/policy_rejected_logps\" in metrics\n                and \"actor/reference_chosen_logps\" in metrics\n                and \"actor/reference_rejected_logps\" in metrics\n            ):\n                policy_ratio_mean = metrics[\"actor/policy_chosen_logps\"] - metrics[\"actor/policy_rejected_logps\"]\n                ref_ratio_mean = metrics[\"actor/reference_chosen_logps\"] - metrics[\"actor/reference_rejected_logps\"]\n                logits_mean = policy_ratio_mean - ref_ratio_mean\n                metrics[\"actor/rewards_chosen\"] = beta * (\n                    metrics[\"actor/policy_chosen_logps\"] - metrics[\"actor/reference_chosen_logps\"]\n                )\n                metrics[\"actor/rewards_rejected\"] = beta * (\n                    metrics[\"actor/policy_rejected_logps\"] - metrics[\"actor/reference_rejected_logps\"]\n                )\n                metrics[\"actor/rewards_accuracies\"] = float(logits_mean > 0)  # Mean accuracy proxy\n                metrics[\"actor/rewards_margins\"] = metrics[\"actor/rewards_chosen\"] - metrics[\"actor/rewards_rejected\"]\n\n        else:  # Handle case where no micro-batches were run (e.g., bsz=0)\n            metrics[\"actor/dpo_loss\"] = 0.0\n            metrics[\"actor/grad_norm\"] = 0.0\n            # Initialize other metrics to 0 or NaN as appropriate\n            for key in accumulated_metrics.keys():\n                metrics[key.replace(\"_batch\", \"\")] = 0.0\n            metrics[\"actor/rewards_chosen\"] = 0.0\n            metrics[\"actor/rewards_rejected\"] = 0.0\n            metrics[\"actor/rewards_accuracies\"] = 0.0\n            metrics[\"actor/rewards_margins\"] = 0.0\n\n        return metrics  # Return aggregated metrics\n"
  },
  {
    "path": "verl_distillation/recipe/spin/fsdp_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\nimport os\nimport warnings\n\nimport numpy as np\nimport psutil\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom omegaconf import OmegaConf, open_dict\nfrom torch.distributed.device_mesh import init_device_mesh\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2\")\n    return sharding_strategy\n\n\nclass SPINRolloutRefWorker(ActorRolloutRefWorker):\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", 
False)\n        use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout or self._is_ref:\n            # we need the model for actor and rollout\n            if self._is_actor or self._is_ref:\n                optim_config = self.config.actor.optim\n                fsdp_config = self.config.actor.fsdp_config\n            else:\n                optim_config = None\n                fsdp_config = OmegaConf.create()\n            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (\n                self._build_model_optimizer(\n                    model_path=self.config.model.path,\n                    fsdp_config=fsdp_config,\n                    optim_config=optim_config,\n                    override_model_config=override_model_config,\n                    use_remove_padding=use_remove_padding,\n                    use_fused_kernels=use_fused_kernels,\n                    enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                    trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                    use_liger=self.config.model.get(\"use_liger\", False),\n                    role=\"actor\",\n                )\n            )\n\n            # get the original unwrapped module\n            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n        # load from checkpoint\n        if self._is_actor or self._is_ref:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                self.config.actor.use_remove_padding = use_remove_padding\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = DataParallelPPOActor(\n                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self._build_rollout(trust_remote_code=self.config.model.get(\"trust_remote_code\", False))\n\n        if self._is_ref:\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                
processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    def compute_ref_log_prob(self, data: DataProto):\n        assert self._is_ref\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        with self.ulysses_sharding_manager:\n            output = self.ref_policy.compute_log_prob(data=data)\n            output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1:\n            self.ref_policy.actor_module._handle.reshard(True)\n\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    def compute_log_prob(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        # perform recompute log_prob\n        with self.ulysses_sharding_manager:\n            output = self.actor.compute_log_prob(data=data)\n            output = DataProto.from_dict(\n                tensors={\"old_log_probs\": output}, meta_info={\"temperature\": self.config.rollout.temperature}\n            )\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1:\n            self.actor.actor_module._handle.reshard(True)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        log_gpu_memory_usage(\"After compute_log_prob\", logger=logger)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    def update_actor_dpo(self, data: DataProto):\n        \"\"\"\n        Wrapper for actor update step. 
Handles FSDP state management.\n        Calls self.actor.update_policy which now contains DPO logic based\n        on pre-calculated log probabilities.\n        \"\"\"\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        assert self._is_actor  # Make sure this worker has the actor role\n        if self.actor is None:\n            raise RuntimeError(\"Actor instance (self.actor) not initialized in worker.\")\n\n        # --- FSDP State Management ---\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())\n\n        log_gpu_memory_usage(\"Before update policy (DPO via PPO path)\", logger=logger)\n\n        # --- Ulysses Sharding (if used) ---\n        with self.ulysses_sharding_manager:\n            # --- Call the core update method (now containing DPO logic) ---\n            with Timer(name=\"update_policy_dpo_via_ppo\", logger=None) as timer:  # Use a distinct timer name\n                # Calls the modified update_policy method\n                metrics = self.actor.update_policy_dpo_with_ref(data=data)  # <-- THIS CALLS THE MODIFIED FUNCTION\n            delta_time = timer.last\n\n            # --- Add Performance Metrics ---\n            # MFU calculation might be less accurate/meaningful here for DPO\n            metrics[\"perf/approx_tokens_processed\"] = torch.sum(\n                data.batch.get(\"attention_mask\", torch.tensor(0))\n            ).item()  # Approx tokens\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n\n            # --- LR Scheduler Step ---\n            lr = self.actor_lr_scheduler.get_last_lr()[0]\n            metrics[\"actor/lr\"] = lr\n            self.actor_lr_scheduler.step()\n\n            log_gpu_memory_usage(\"After update policy (DPO via PPO path)\", logger=logger)\n\n            # --- Prepare Output ---\n            output = DataProto(meta_info={\"metrics\": metrics})\n            output = output.to(\"cpu\")\n\n        # --- FSDP State Management (Offload) ---\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n\n        return output\n\n\n# TODO(sgm): we may need to extract it to dp_reward_model.py\nclass RewardModelWorker(Worker):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from 
torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        if self.ulysses_device_mesh is not None:\n            is_collect = self.ulysses_device_mesh[\"sp\"].get_local_rank() == 0\n            self._register_dispatch_collect_info(\n                \"reward\", dp_rank=self.ulysses_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n            )\n        else:\n            self._register_dispatch_collect_info(\"reward\", dp_rank=self.rank, is_collect=True)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= torch.distributed.get_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_model(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import CPUOffload\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.model.path)\n\n        if self.config.model.input_tokenizer is None:\n            self._do_switch_chat_template = False\n        else:\n            self._do_switch_chat_template = True\n            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)\n            self.input_tokenizer = hf_tokenizer(\n                input_tokenizer_local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False)\n            )\n            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        trust_remote_code = config.model.get(\"trust_remote_code\", False)\n        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        model_config.num_labels = 1\n\n        # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            reward_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                config=model_config,\n                torch_dtype=torch.bfloat16,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            if config.model.get(\"use_remove_padding\", False) or self.ulysses_sequence_parallel_size > 1:\n                from verl.models.transformers.monkey_patch import apply_monkey_patch\n\n                apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)\n\n            reward_module.to(torch.bfloat16)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        reward_module = FSDP(\n            reward_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,  # zero3\n            sync_module_states=True,\n            cpu_offload=CPUOffload(offload_params=True),\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n        )\n\n        return reward_module\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        self.reward_module = self._build_model(config=self.config)\n\n    def _forward_micro_batch(self, micro_batch):\n        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n\n        from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\n        with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                position_ids_rmpad = index_first_axis(\n                    rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n                ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.reward_module(\n                    input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False\n                )  # prevent the model from thinking we are generating\n                reward_rmpad = output.logits\n                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    reward_rmpad = gather_outputs_and_unpad(\n                        reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)\n            else:\n                output = self.reward_module(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                rm_score = output.logits  # (batch_size, seq_len, 1)\n                rm_score = rm_score.squeeze(-1)\n\n            # extract the score at the last valid token: position_ids increase along valid tokens,\n            # so the argmax of (position_ids * attention_mask) is the index of the last non-pad token\n            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]\n            return rm_score\n\n    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):\n        batch_size = data.batch.batch_size[0]\n        # expand as token_level_reward\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        response_length = data.batch[\"responses\"].shape[-1]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)\n        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores\n\n        # select the response part\n        token_level_scores = token_level_scores[:, -response_length:]\n\n        return token_level_scores\n\n    def _switch_chat_template(self, data: DataProto):\n        src_max_length = data.batch[\"attention_mask\"].shape[-1]\n\n        src_tokenizer = self.input_tokenizer\n        target_tokenizer = self.tokenizer\n\n        rm_input_ids = []\n        rm_attention_mask = []\n\n        for i in range(data.batch.batch_size[0]):\n            if not isinstance(data.non_tensor_batch[\"raw_prompt\"][i], list | np.ndarray):\n                raise TypeError(\n                    f\"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}\"\n                )\n\n            # extract raw prompt\n            chat: list = list(data.non_tensor_batch[\"raw_prompt\"][i])\n\n            # extract response\n            response_ids = data.batch[\"responses\"][i]\n            response_length = response_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][i][-response_length:].sum()\n            valid_response_ids 
= response_ids[:valid_response_length]\n\n            # decode\n            response = src_tokenizer.decode(valid_response_ids)\n            # remove eos\n            response = response.replace(src_tokenizer.eos_token, \"\")\n\n            chat.append({\"role\": \"assistant\", \"content\": response})\n\n            prompt_with_chat_template = target_tokenizer.apply_chat_template(\n                chat, add_generation_prompt=False, tokenize=False\n            )\n            if self.rank == 0 and i == 0:\n                # for debugging purpose\n                print(f\"Switch template. chat: {prompt_with_chat_template}\")\n\n            # the maximum length is actually determined by the reward model itself\n            max_length = self.config.get(\"max_length\", src_max_length)\n            if max_length is None:\n                max_length = src_max_length\n\n            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids, attention_mask = verl_F.postprocess_data(\n                input_ids=model_inputs[\"input_ids\"],\n                attention_mask=model_inputs[\"attention_mask\"],\n                max_length=max_length,\n                pad_token_id=target_tokenizer.pad_token_id,\n                left_pad=False,  # right padding\n                truncation=self.config.get(\"truncation\", \"right\"),\n            )  # truncate from the right\n\n            rm_input_ids.append(input_ids)\n            rm_attention_mask.append(attention_mask)\n\n        rm_input_ids = torch.cat(rm_input_ids, dim=0)\n        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)\n\n        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)\n\n        rm_inputs = {\"input_ids\": rm_input_ids, \"attention_mask\": rm_attention_mask, \"position_ids\": rm_position_ids}\n\n        return DataProto.from_dict(rm_inputs)\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"reward\"))\n    def compute_rm_score(self, data: DataProto):\n        import itertools\n\n        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\n\n        # Support all hardware\n        data = data.to(get_device_id())\n        if self._do_switch_chat_template:\n            rm_data = self._switch_chat_template(data)\n        else:\n            rm_input_ids = data.batch[\"input_ids\"]\n            rm_attention_mask = data.batch[\"attention_mask\"]\n            rm_position_ids = data.batch[\"position_ids\"]\n            rm_inputs = {\n                \"input_ids\": rm_input_ids,\n                \"attention_mask\": rm_attention_mask,\n                \"position_ids\": rm_position_ids,\n            }\n            rm_data = DataProto.from_dict(rm_inputs)\n\n        # Support all hardware\n        rm_data.batch = rm_data.batch.to(get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            use_dynamic_bsz = self.config.use_dynamic_bsz\n            if use_dynamic_bsz:\n                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)\n            else:\n                micro_batches = 
rm_data.batch.split(self.config.micro_batch_size_per_gpu)\n            output = []\n            for micro_batch in micro_batches:\n                rm_score = self._forward_micro_batch(micro_batch)\n                output.append(rm_score)\n            scores = torch.cat(output, dim=0)  # (batch_size)\n\n            if use_dynamic_bsz:\n                indices = list(itertools.chain.from_iterable(indices))\n                assert len(indices) == scores.size(0), f\"{len(indices)} vs. {scores.size()}\"\n                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                scores = scores[revert_indices]\n\n            token_level_scores = self._expand_to_token_level(data, scores)\n            # Note that these are only the scores; they may not be the final rewards used to train RL\n            output = DataProto.from_dict(tensors={\"rm_scores\": token_level_scores})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # reshard the root FSDP module to free its unsharded parameters\n        self.reward_module._handle.reshard(True)\n\n        output = output.to(\"cpu\")\n        return output\n"
  },
  {
    "path": "verl_distillation/recipe/spin/main_spin.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport hydra\nimport ray\n\nfrom recipe.spin.spin_trainer import RaySPINTrainer\nfrom recipe.spin.utils import validate_config\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.trainer.ppo.utils import need_reference_policy\n\n\n@hydra.main(config_path=\"config\", config_name=\"spin_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices\n    # isolation, will solve in the future\n    os.environ[\"ENSURE_CUDA_VISIBLE_DEVICES\"] = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\n                \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n            }\n        )\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            # from recipe.spin.fsdp_workers import ActorRolloutRefWorker\n            from recipe.spin.fsdp_workers import SPINRolloutRefWorker\n            from verl.single_controller.ray import RayWorkerGroup\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from recipe.spin.spin_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            # Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n            Role.ActorRollout: ray.remote(SPINRolloutRefWorker),\n            # Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            # Role.Critic: global_pool_id,\n        }\n\n        if config.reward_model.enable:\n            
if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from recipe.spin.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        # if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        # role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)\n        mapping[Role.RefPolicy] = global_pool_id\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(role_worker_mapping),\n            use_critic=False,\n        )\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        from verl.workers.reward_manager import get_reward_manager_cls\n\n        # Note(haibin.lin): please make sure custom reward managers are imported and\n        # registered via `verl.workers.reward_manager.register`\n        reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n        reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n        compute_score = get_custom_reward_fn(config)\n        reward_kwargs = dict(config.reward_model.get(\"reward_kwargs\", {}))\n        reward_fn = reward_manager_cls(\n            tokenizer=tokenizer,\n            num_examine=0,\n            compute_score=compute_score,\n            reward_fn_key=config.data.reward_fn_key,\n            **reward_kwargs,\n        )\n\n        # Note that we always use function-based RM for validation\n        val_reward_fn = reward_manager_cls(\n            tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RaySPINTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n        )\n        trainer.init_workers()\n        trainer.fit_dpo()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/recipe/spin/run_spin.sh",
    "content": "set -e\nset -x\nVISIBLE_DEVICES=\"4,5,6,7\"\nexport HYDRA_FULL_ERROR=1\n\nCUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=1024 \\\n  data.max_prompt_length=1024 \\\n  data.max_response_length=1024 \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=8 \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size=64 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=True \\\n  trainer.n_gpus_per_node=4 \\\n  trainer.nnodes=1 \\\n  trainer.save_freq=-1 \\\n  trainer.test_freq=1 \\\n  +trainer.log_freq=1 \\\n  trainer.ref_update_freq=1 \\\n  trainer.total_epochs=1000 2>&1 | tee verl_demo.log"
  },
  {
    "path": "verl_distillation/recipe/spin/spin_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport traceback\nimport uuid\nfrom collections import defaultdict\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom pprint import pprint\nfrom typing import Any, Optional\n\nimport numpy as np\nimport ray\nimport torch\nfrom codetiming import Timer\nfrom omegaconf import OmegaConf, open_dict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom recipe.spin import core_algos\nfrom verl import DataProto\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.metric_utils import compute_throughout_metrics, compute_timing_metrics, process_validation_metrics\nfrom verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. Resource pool will be initialized first.\n    Mapping\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. 
processes) in each RayResourcePool\n            # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.\n            # For the Megatron backend, we recommend max_colocate_count>1, which can use a different\n            # WorkerGroup for each model\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray._private.state.available_resources_per_node()\n        node_available_gpus = {node: node_info.get(\"GPU\", 0) for node, node_info in node_available_resources.items()}\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n        # check each resource pool can be satisfied, O(#resource_pools * #nodes)\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)\n            for node, available_gpus in node_available_gpus.items():\n                if available_gpus >= num_gpus:\n                    node_available_gpus[node] -= num_gpus\n                    num_nodes -= 1\n                    if num_nodes == 0:\n                        break\n            if num_nodes > 0:\n                raise ValueError(\n                    f\"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this \"\n                    f\"ray cluster\"\n                )\n\n\ndef _compute_response_info(batch: DataProto) -> dict[str, Any]:\n    \"\"\"Placeholder: Computes prompt and response lengths.\"\"\"\n    try:\n        # Assuming 'prompts' and 'responses' keys exist after generation/union\n        prompt_len = batch.batch[\"prompts\"].shape[1]\n        resp_len = batch.batch[\"responses\"].shape[1]\n        # This is simplified; a real implementation might use attention masks\n        # to get actual lengths per sample.\n        batch_size = batch.batch.batch_size[0]\n        prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)\n        response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)\n\n        # Try getting actual lengths from attention mask if possible (more accurate)\n        if \"response_mask\" in batch.batch:\n            response_lengths_tensor = 
batch.batch[\"response_mask\"].sum(dim=1).float()\n            # if \"attention_mask\" in batch.batch and \"response_mask\" in batch.batch:\n            # full_mask = batch.batch[\"attention_mask\"]\n            # resp_mask = batch.batch[\"response_mask\"]\n            # Infer prompt mask length based on where response mask starts or total length\n            # This logic depends heavily on how your masks are constructed.\n            # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor\n            # Fallback to using prompt shape if mask logic is complex:\n            prompt_lengths_tensor = torch.tensor(\n                [batch.batch[\"prompts\"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device\n            )\n\n        return {\n            \"prompt_length\": prompt_lengths_tensor,\n            \"response_length\": response_lengths_tensor,\n            \"max_response_length\": resp_len,\n            \"max_prompt_length\": prompt_len,  # Or from config if fixed padding\n        }\n    except KeyError as e:\n        print(f\"Warning: Missing key in _compute_response_info: {e}. Returning defaults.\")\n        # Return default/dummy values if keys are missing\n        b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1\n        max_resp = batch.batch.get(\"responses\").shape[1] if batch.batch.get(\"responses\") is not None else 0\n        max_prompt = batch.batch.get(\"prompts\").shape[1] if batch.batch.get(\"prompts\") is not None else 0\n        return {\n            \"prompt_length\": torch.zeros(b_size),\n            \"response_length\": torch.zeros(b_size),\n            \"max_response_length\": max_resp,\n            \"max_prompt_length\": max_prompt,\n        }\n\n\n# --- Modified Metric Function ---\ndef compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:\n    \"\"\"\n    Computes and returns metrics relevant for the DPO-like process.\n    Assumes 'batch' contains results after generation and preference marking,\n    potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.\n    Removes PPO-specific advantage/return/critic metrics.\n    \"\"\"\n    print(\"---- [DEBUG] Computing DPO Data Metrics ----\")\n    metrics = {}\n    try:\n        # --- Scores and Rewards (from reward_fn) ---\n        if \"token_level_scores\" in batch.batch and batch.batch[\"token_level_scores\"] is not None:\n            sequence_score = batch.batch[\"token_level_scores\"].sum(-1)\n            metrics.update(\n                {\n                    \"reward/score/mean\": torch.mean(sequence_score).item(),\n                    \"reward/score/max\": torch.max(sequence_score).item(),\n                    \"reward/score/min\": torch.min(sequence_score).item(),\n                }\n            )\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.\")\n\n        if \"token_level_rewards\" in batch.batch and batch.batch[\"token_level_rewards\"] is not None:\n            sequence_reward = batch.batch[\"token_level_rewards\"].sum(-1)\n            metrics.update(\n                {\n                    \"reward/rewards/mean\": torch.mean(sequence_reward).item(),\n                    \"reward/rewards/max\": torch.max(sequence_reward).item(),\n                    \"reward/rewards/min\": torch.min(sequence_reward).item(),\n                }\n            )\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.\")\n\n        # --- 
DPO Specific Metrics (if stored previously) ---\n        if \"dpo_logits\" in batch.batch and batch.batch[\"dpo_logits\"] is not None:\n            metrics[\"actor/dpo_logits\"] = batch.batch[\"dpo_logits\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.\")\n\n        if \"chosen_logps\" in batch.batch and batch.batch[\"chosen_logps\"] is not None:\n            metrics[\"actor/chosen_logps\"] = batch.batch[\"chosen_logps\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.\")\n\n        if \"rejected_logps\" in batch.batch and batch.batch[\"rejected_logps\"] is not None:\n            metrics[\"actor/rejected_logps\"] = batch.batch[\"rejected_logps\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.\")\n\n        # Add metrics based on the 'preferences' mask if available\n        # if \"preferences\" in batch.batch and batch.batch[\"preferences\"] is not None:\n        # prefs_mask = batch.batch[\"preferences\"]  # Shape [batch_size * n]\n        # Calculate accuracy based on RM scores (assuming higher score -> True in mask)\n        # Requires chosen/rejected scores to be available or recalculated\n        # This is complex here, better calculated in the main loop or update function\n\n        # --- Length Metrics ---\n        response_info = _compute_response_info(batch)\n        prompt_length = response_info[\"prompt_length\"]\n        response_length = response_info[\"response_length\"]\n        max_response_length = response_info[\"max_response_length\"]\n        max_prompt_length = response_info[\"max_prompt_length\"]  # Use calculated or from config\n\n        metrics.update(\n            {\n                \"response_length/mean\": torch.mean(response_length).item(),\n                \"response_length/max\": torch.max(response_length).item(),\n                \"response_length/min\": torch.min(response_length).item(),\n                \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float()).item(),\n                \"prompt_length/mean\": torch.mean(prompt_length).item(),\n                \"prompt_length/max\": torch.max(prompt_length).item(),\n                \"prompt_length/min\": torch.min(prompt_length).item(),\n                # Prompt clip ratio might need adjustment based on how max_prompt_length is defined\n                \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),\n            }\n        )\n\n    except KeyError as e:\n        print(f\"ERROR in compute_dpo_data_metrics: Missing key {e}\")\n    except Exception as e:\n        print(f\"ERROR in compute_dpo_data_metrics: {e}\")\n        traceback.print_exc()\n\n    print(f\"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----\")\n    return metrics\n\n\ndef apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n    attention_mask = data.batch[\"attention_mask\"]\n    response_mask = attention_mask[:, -response_length:]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = 
core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_response_mask(data: DataProto):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_onlineDPO_pref(data: DataProto):\n    \"\"\"\n    Wrapper to compute DPO preference and add it to the DataProto batch.\n    Includes debugging prints.\n    \"\"\"\n    # print(f\"\\n---- [DEBUG] Entering compute_onlineDPO_pref ----\")\n    # print(f\"  Input batch keys: {list(data.batch.keys())}\")\n\n    # Check inputs\n    rewards_tensor = data.batch.get(\"token_level_rewards\")\n    mask_tensor = data.batch.get(\"response_mask\")\n\n    if rewards_tensor is None or mask_tensor is None:\n        print(\"  ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!\")\n        # Handle error case - maybe return original data or raise?\n        # Returning original data for now to potentially allow skipping\n        return data\n\n    try:\n        preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)\n        # Store the result\n        data.batch[\"preferences\"] = preferences\n\n    except AttributeError:\n        print(\"ERROR: Function 'compute_onlinedpo_pref' not found in core_algos.py!\")\n        # Assign dummy value or raise error\n        data.batch[\"preferences\"] = None  # Indicate failure\n    except Exception as e_pref:\n        print(f\"ERROR during core_algos.compute_onlinedpo_pref: {e_pref}\")\n        traceback.print_exc()\n        data.batch[\"preferences\"] = None  # Indicate failure\n\n    # print(f\"---- [DEBUG] Exiting compute_onlineDPO_pref ----\")\n    return data\n\n\n@contextmanager\ndef _timer(name: str, timing_raw: dict[str, float]):\n    with Timer(name=name, logger=None) as timer:\n        yield\n    timing_raw[name] = timer.last\n\n\nclass RaySPINTrainer:\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support an individual ray_worker_group_cls for each role,\n    # i.e., support a different backend for each role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n      
  # assert get_torch_device().is_available(), 'cuda must be available on driver'\n\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only the hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(role_worker_mapping)\n        self.use_rm = need_reward_model(role_worker_mapping)\n        self.use_critic = False\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.validation_generations_logger = ValidationGenerationsLogger()\n        self.async_rollout_mode = False\n        self.device_name = device_name if device_name else self.config.trainer.device\n\n        # define in-reward KL control\n        # kl loss control is currently not supported\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files,\n                self.config.data,\n                self.tokenizer,\n                self.processor,\n                max_samples=self.config.data.get(\"train_max_samples\", -1),\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files,\n                self.config.data,\n                self.tokenizer,\n                self.processor,\n                max_samples=self.config.data.get(\"val_max_samples\", -1),\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=self.config.data.get(\"dataloader_num_workers\", 8),\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=self.config.data.get(\"dataloader_num_workers\", 8),\n          
  shuffle=False,\n            drop_last=False,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, \"\n            f\"Size of val dataloader: {len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}\")\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort by input text\n\n        # Use fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_scores = []\n\n        for test_data in self.val_dataloader:\n            test_batch = DataProto.from_single_dict(test_data)\n\n            # repeat test batch\n            test_batch = test_batch.repeat(\n                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n            )\n\n            # we only do validation on rule-based rm\n            if self.config.reward_model.enable and test_batch[0].non_tensor_batch[\"reward_model\"][\"style\"] == \"model\":\n                return {}\n\n            # Store original inputs\n            input_ids = test_batch.batch[\"input_ids\"]\n            # TODO: Can we keep special tokens except for padding tokens?\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            sample_inputs.extend(input_texts)\n\n            batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n            
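# Pop the generation inputs into their own DataProto so that the remaining\n            # test_batch can later be union()-ed with the generation output without\n            # key collisions.\n            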
non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n            if \"multi_modal_inputs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.extend([\"multi_modal_data\", \"multi_modal_inputs\"])\n            if \"raw_prompt\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n            if \"tools_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n            test_gen_batch = test_batch.pop(\n                batch_keys=batch_keys_to_pop,\n                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n            )\n\n            test_gen_batch.meta_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,\n                \"validate\": True,\n            }\n            print(f\"test_gen_batch meta info: {test_gen_batch.meta_info}\")\n\n            # pad to be divisible by dp_size\n            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)\n            if not self.async_rollout_mode:\n                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)\n            else:\n                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)\n\n            # unpad\n            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)\n            print(\"validation generation end\")\n\n            # Store generated outputs\n            output_ids = test_output_gen_batch.batch[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            test_batch = test_batch.union(test_output_gen_batch)\n\n            # evaluate using reward_function\n            result = self.val_reward_fn(test_batch, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            reward_extra_infos_dict[\"reward\"].extend(scores)\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n\n            data_source_lst.append(test_batch.non_tensor_batch.get(\"data_source\", [\"unknown\"] * reward_tensor.shape[0]))\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            sample_gts = [\n                item.non_tensor_batch.get(\"reward_model\", {}).get(\"ground_truth\", None) for item in test_batch\n            ]\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                gts=sample_gts,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or 
len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n        print(f\"DEBUG: Data sources shape: {data_sources.shape}\")  # Added Print\n        print(f\"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}\")  # Added Print\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)\n        print(\n            f\"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}\"\n        )  # Added Print\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Init resource pool and worker group\"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=\"actor_rollout\",\n            )\n            self.resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role=\"ref\"\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            
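# NOTE: the RM worker is registered into the same colocated resource pool as the\n            # actor/rollout workers (Role.RewardModel is mapped to global_pool_id in main_spin.py).\n            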
self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different\n        # parallel size,\n        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to\n        # different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        self.wg_dicts = []\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n            # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699\n            self.wg_dicts.append(wg_dict)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = all_wg[\"actor_rollout\"]\n        self.actor_rollout_wg.init_model()\n\n    def _save_checkpoint(self):\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and \"\n                \"max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n     
   )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, \"critic\")\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"critic\")\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = os.path.join(global_step_folder, \"critic\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader,\n        # TODO: from remote not 
implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n        # return the resumed step so that fit_dpo continues from it instead of resetting to 1\n        return self.global_steps\n\n    def _balance_batch(self, batch: DataProto, metrics, logging_prefix=\"global_seqlen\"):\n        \"\"\"Reorder the data on single controller such that each dp rank gets similar total tokens\"\"\"\n        attention_mask = batch.batch[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = batch.batch[\"attention_mask\"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)\n        world_size = self.actor_rollout_wg.world_size\n        global_partition_lst = get_seqlen_balanced_partitions(\n            global_seqlen_lst, k_partitions=world_size, equal_size=True\n        )\n        # reorder based on index. The data will be automatically equally partitioned by dispatch function\n        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])\n        batch.reorder(global_idx)\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n\n    def fit_dpo(self):  # Renamed from the standard PPO fit loop for clarity\n        \"\"\"\n        The training loop of Online DPO using a periodically updated reference model.\n        The driver process calls worker groups for computation.\n        Advantage computation is replaced by DPO logic.\n        \"\"\"\n        import traceback  # Ensure traceback is imported\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        # Initialize logger\n        logger = None\n        try:\n            logger = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n                config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),\n            )\n        except Exception as e:\n            print(f\"Warning: Failed to initialize logger: {e}\")\n\n        self.global_steps = 0\n        # Load checkpoint before doing anything\n        loaded_step = self._load_checkpoint()\n        self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1\n        print(\n            f\"Starting Online DPO training from global step {self.global_steps}. \"\n            f\"Total steps: {self.total_training_steps}\"\n        )\n        print(f\"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}\")\n\n        # Check if reference policy is configured correctly for this mode\n        if not self.use_reference_policy:\n            print(\n                \"WARNING: 'use_reference_policy' is False. Periodic reference model update requires a \"\n                \"reference policy worker. 
DPO updates might fail or use incorrect logic.\"\n            )\n            # Consider raising an error if strict adherence is required:\n            # raise ValueError(\"Periodic reference model update requires 'use_reference_policy' to be True \"\n            #                  \"and a configured reference worker.\")\n\n        # Perform validation before training\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            print(\"Running validation before Online DPO training...\")\n            val_metrics = self._validate()\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            if logger and val_metrics:\n                logger.log(data=val_metrics, step=max(0, self.global_steps - 1))\n            if self.config.trainer.get(\"val_only\", False):\n                print(\"Validation only mode enabled. Exiting training.\")\n                if logger and hasattr(logger, \"finish\"):\n                    logger.finish()\n                return\n\n        # Add tqdm progress bar\n        progress_bar = tqdm(\n            total=self.total_training_steps,\n            initial=self.global_steps,\n            desc=\"Online DPO Training Progress\",\n            position=0,\n            leave=True,\n        )\n\n        last_val_metrics = None\n        should_stop = False\n\n        for epoch in range(self.config.trainer.total_epochs):\n            if should_stop:\n                break\n            print(f\"--- Starting Online DPO Epoch {epoch} ---\")\n            try:\n                train_iterator = iter(self.train_dataloader)\n            except TypeError:\n                print(\"Warning: Dataloader is not iterable.\")\n                train_iterator = self.train_dataloader  # Fallback attempt\n\n            for batch_idx, batch_dict in enumerate(train_iterator):\n                if self.global_steps > self.total_training_steps:\n                    should_stop = True\n                    break\n\n                metrics = {}\n                timing_raw = {}\n                step_timer = Timer(logger=None)\n                ref_log_prob_computed = False  # Flag to track if ref log probs were computed\n\n                try:  # Outer try-except for the whole step\n                    step_timer.start()\n                    with _timer(\"step\", timing_raw):\n                        batch: DataProto = DataProto.from_single_dict(batch_dict)\n                        current_batch_size = batch.batch.batch_size[0]\n                        print(\n                            f\"\\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: \"\n                            f\"{current_batch_size}\"\n                        )\n\n                        # --- Reference Model Update ---\n                        ref_update_freq = self.config.trainer.get(\"ref_update_freq\", -1)\n                        if (\n                            self.use_reference_policy\n                            and ref_update_freq > 0\n                            and self.global_steps % ref_update_freq == 0\n                        ):\n                            print(f\"\\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...\")\n                            try:\n                                # --- This requires careful implementation with FSDP ---\n                                # 1. 
Save actor state dict (potentially to CPU memory or disk)\n                                #    This needs to be done collectively across actor worker ranks.\n                                #    The checkpoint_manager might be adaptable, or use FSDP APIs directly.\n                                #    Example placeholder using a conceptual save/load mechanism:\n                                actor_state_path = \"/tmp/actor_state_mid\"  # Temporary path\n                                self.actor_rollout_wg.save_checkpoint(actor_state_path)  # Adapt save logic\n\n                                # 2. Load the state dict onto the reference model worker group\n                                #    This also needs collective loading on the ref worker ranks.\n                                self.ref_policy_wg.load_checkpoint(actor_state_path, None, True)  # Adapt load logic\n\n                                print(f\"[Step {self.global_steps}] Reference Model Weights Updated.\")\n                                # Optionally remove the temporary state file\n                                # os.remove(actor_state_path) # Needs rank-aware removal or shared storage\n\n                            except Exception as sync_e:\n                                print(f\"ERROR during reference model sync at step {self.global_steps}: {sync_e}\")\n                                traceback.print_exc()\n\n                        # Pop keys for generation\n                        pop_batch_keys = [\"input_ids\", \"attention_mask\"]\n                        if \"position_ids\" in batch.batch:\n                            pop_batch_keys.append(\"position_ids\")\n                        pop_non_tensor_keys = [\"raw_prompt_ids\"] if \"raw_prompt_ids\" in batch.non_tensor_batch else []\n                        if \"multi_modal_inputs\" in batch.non_tensor_batch.keys():\n                            pop_non_tensor_keys.extend([\"multi_modal_data\", \"multi_modal_inputs\"])\n                        original_non_tensor_data = batch.non_tensor_batch\n                        gen_batch = batch.pop(\n                            batch_keys=pop_batch_keys,\n                            non_tensor_batch_keys=pop_non_tensor_keys,\n                        )\n                        gen_batch = gen_batch.repeat(\n                            repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                        )\n                        # (Add Debug prints for gen_batch if needed)\n\n                        # Generate sequences (chosen/rejected pairs)\n                        with _timer(\"gen\", timing_raw):\n                            try:\n                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                                # (Add Debug prints for gen_batch_output if needed)\n                            except Exception as gen_e:\n                                print(f\"\\n!!!!!!!! 
ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!\")\n                                print(gen_e)\n                                traceback.print_exc()\n                                print(\"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\")\n                                step_timer.stop()\n                                continue\n\n                        # Combine original prompts with generated sequences\n                        batch.non_tensor_batch = original_non_tensor_data  # Restore non-tensor data\n                        batch.non_tensor_batch[\"uid\"] = np.array(\n                            [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object\n                        )\n                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                        batch = batch.union(gen_batch_output)\n                        # (Add Debug prints after union if needed)\n\n                        # Compute response mask (needed for ref logprob calc and DPO prep)\n                        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n                        if self.config.trainer.balance_batch:\n                            self._balance_batch(batch, metrics=metrics)\n\n                        batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                        # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef\n                        # fallback) ---\n                        # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed\n                        #       unless used for other metrics or a fallback. 
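Keep it for now.\n                        # A hedged aside (illustration, not verl API): DPO-style comparisons use\n                        # sequence-level log-probs, obtained from token-level ones exactly as done\n                        # for 'ref_log_prob' in the preparation step further below:\n                        #   seq_logp = (token_logp * response_mask).sum(-1)  # [bsz * n]\n                        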
with _timer(\"policy_log_prob\", timing_raw):\n                            policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)\n                            batch = batch.union(policy_log_prob_output)  # Adds 'old_log_probs'\n                            # (Debug prints for old_log_probs)\n\n                        # --- Compute Log Probs using the EXTERNAL Reference Model ---\n                        if self.use_reference_policy:\n                            with _timer(\"ref_log_prob_dpo\", timing_raw):\n                                # print(f\"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----\")\n                                try:\n                                    # 'batch' contains interleaved chosen/rejected sequences\n                                    ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(\n                                        batch\n                                    )  # Returns DataProto with 'ref_log_prob'\n                                    batch = batch.union(\n                                        ref_log_prob_output\n                                    )  # Adds 'ref_log_prob' key [batch_size * n, seq_len]\n                                    ref_log_prob_computed = True  # Mark success\n                                    # print(f\"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: \"\n                                    #       f\"{batch.batch['ref_log_prob'].shape} ----\")\n                                except Exception as ref_e:\n                                    print(f\"ERROR computing reference log probs at step {self.global_steps}: {ref_e}\")\n                                    traceback.print_exc()\n                                    # Do not assign None into the TensorDict (it only holds tensors);\n                                    # the flag below is what the DPO prep step checks on failure.\n                                    ref_log_prob_computed = False\n                        else:\n                            print(\n                                \"Warning: Skipping external reference log prob calculation as use_reference_policy \"\n                                \"is False.\"\n                            )\n                            # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor\n\n                        # --- Compute Rewards/Scores (used to determine preference) ---\n                        with _timer(\"reward_calc\", timing_raw):\n                            # (Reward calculation logic using RM or reward_fn as before)\n                            # ... Ensure this calculates 'token_level_rewards' or similar ...\n                            if self.use_rm:\n                                reward_tensor_rm = self.rm_wg.compute_rm_score(batch)\n                                batch = batch.union(reward_tensor_rm)  # Adds 'rm_scores'\n\n                            reward_extra_infos_dict = {}\n                            try:\n                                if self.reward_fn is None:\n                                    #  print(f\"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! \"\n                                    #        f\"Using dummy rewards. 
----\")\n                                    # Use rm_scores if available, otherwise zeros\n                                    reward_tensor = batch.batch.get(\n                                        \"rm_scores\", torch.zeros_like(batch.batch[\"response_mask\"], dtype=torch.float32)\n                                    )\n                                else:\n                                    reward_result = self.reward_fn(batch, return_dict=True)\n                                    reward_tensor = reward_result[\"reward_tensor\"]  # Final combined reward\n                                    reward_extra_infos_dict = reward_result.get(\"reward_extra_info\", {})\n\n                            except Exception:\n                                # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '\n                                #       f'Using dummy rewards. ----')\n                                traceback.print_exc()\n                                reward_tensor = torch.zeros_like(batch.batch[\"response_mask\"], dtype=torch.float32)\n                                reward_extra_infos_dict = {}\n\n                            # Use 'token_level_rewards' as the key for preference calculation\n                            batch.batch[\"token_level_rewards\"] = reward_tensor\n                            if reward_extra_infos_dict:\n                                batch.non_tensor_batch.update(\n                                    {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                                )\n\n                        # --- Determine Preferences ---\n                        # Uses 'token_level_rewards' to determine chosen/rejected based on score\n                        batch = compute_onlineDPO_pref(batch)  # Adds 'preferences' key\n\n                        # --- Prepare DPO Batch ---\n                        dpo_update_batch_proto = None  # Initialize\n                        with _timer(\"prepare_dpo_batch\", timing_raw):\n                            try:\n                                if \"preferences\" not in batch.batch or batch.batch[\"preferences\"] is None:\n                                    raise ValueError(\"'preferences' key missing or None after compute_onlineDPO_pref.\")\n\n                                # Check if reference log probs were computed successfully (if needed)\n                                if self.use_reference_policy and not ref_log_prob_computed:\n                                    raise ValueError(\"Reference log probs required but failed to compute.\")\n\n                                # Check required base keys\n                                required_keys = [\"input_ids\", \"attention_mask\", \"response_mask\"]\n                                for rk in required_keys:\n                                    if rk not in batch.batch or batch.batch[rk] is None:\n                                        raise KeyError(f\"Required key '{rk}' missing from batch for DPO prep.\")\n\n                                preferences_mask = batch.batch[\"preferences\"]  # Shape [batch_size * n]\n                                not_preferences_mask = ~preferences_mask\n\n                                # Gather Chosen/Rejected Base Tensors\n                                chosen_input_ids = batch.batch[\"input_ids\"][preferences_mask]\n                                chosen_attention_mask = batch.batch[\"attention_mask\"][preferences_mask]\n                                rejected_input_ids = 
batch.batch[\"input_ids\"][not_preferences_mask]\n                                rejected_attention_mask = batch.batch[\"attention_mask\"][not_preferences_mask]\n                                chosen_position_ids = (\n                                    batch.batch.get(\"position_ids\")[preferences_mask]\n                                    if \"position_ids\" in batch.batch\n                                    else None\n                                )\n                                rejected_position_ids = (\n                                    batch.batch.get(\"position_ids\")[not_preferences_mask]\n                                    if \"position_ids\" in batch.batch\n                                    else None\n                                )\n\n                                # Create Labels\n                                print(\"WARNING: Creating DPO labels using configured max_prompt_length...\")\n                                prompt_len = self.config.data.max_prompt_length\n                                chosen_labels = chosen_input_ids.clone()\n                                chosen_labels[:, :prompt_len] = -100\n                                rejected_labels = rejected_input_ids.clone()\n                                rejected_labels[:, :prompt_len] = -100\n\n                                # Calculate and Gather Reference Log Probs (Sequence Level)\n                                if self.use_reference_policy:\n                                    ref_log_prob_tensor = batch.batch[\"ref_log_prob\"]  # Token level [bsz * n, seq_len]\n                                    response_mask_full = batch.batch[\n                                        \"response_mask\"\n                                    ]  # Response mask [bsz * n, seq_len]\n                                    ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(\n                                        dim=-1\n                                    )  # Sequence level [bsz * n]\n                                    reference_chosen_logps = ref_sequence_logps[preferences_mask]\n                                    reference_rejected_logps = ref_sequence_logps[not_preferences_mask]\n                                else:\n                                    # If not using external ref, DPO needs ActorAsRef logic in dp_actor\n                                    # We won't add the keys here, dp_actor will handle it (or fail if not modified)\n                                    print(\n                                        \"Info: Not adding explicit reference logps to DPO batch \"\n                                        \"(use_reference_policy=False).\"\n                                    )\n                                    reference_chosen_logps = None  # Explicitly None\n                                    reference_rejected_logps = None\n\n                                # Package Tensors\n                                dpo_tensors = {\n                                    \"chosen_input_ids\": chosen_input_ids,\n                                    \"chosen_attention_mask\": chosen_attention_mask,\n                                    \"chosen_labels\": chosen_labels,\n                                    \"rejected_input_ids\": rejected_input_ids,\n                                    \"rejected_attention_mask\": rejected_attention_mask,\n                                    \"rejected_labels\": rejected_labels,\n                                }\n                                # 
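For orientation (hedged; standard sigmoid-DPO, not necessarily this repo's exact kernel):\n                                #   loss = -logsigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))\n                                # where pi_*/ref_* are sequence-level policy/reference log-probs of the\n                                # chosen (c) and rejected (r) responses packaged here.\n                                # 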
Conditionally add reference logps if computed\n                                if reference_chosen_logps is not None:\n                                    dpo_tensors[\"reference_chosen_logps\"] = reference_chosen_logps\n                                if reference_rejected_logps is not None:\n                                    dpo_tensors[\"reference_rejected_logps\"] = reference_rejected_logps\n                                # Add position ids if they exist\n                                if chosen_position_ids is not None:\n                                    dpo_tensors[\"chosen_position_ids\"] = chosen_position_ids\n                                if rejected_position_ids is not None:\n                                    dpo_tensors[\"rejected_position_ids\"] = rejected_position_ids\n\n                                # Prepare Meta Info\n                                dpo_meta = {\n                                    \"dpo_beta\": OmegaConf.select(self.config.algorithm, \"dpo_beta\", default=0.1),\n                                    \"dpo_loss_type\": OmegaConf.select(\n                                        self.config.algorithm, \"dpo_loss_type\", default=\"sigmoid\"\n                                    ),\n                                    \"dpo_label_smoothing\": OmegaConf.select(\n                                        self.config.algorithm, \"dpo_label_smoothing\", default=0.0\n                                    ),\n                                    \"use_reference_policy\": self.use_reference_policy,\n                                    \"reference_free\": not self.use_reference_policy,  # False if using external ref\n                                    \"global_step\": self.global_steps,\n                                }\n\n                                dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)\n                                # print(f\"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----\")\n                                # print(f\"  Keys: {list(dpo_update_batch_proto.batch.keys())}\")\n                                # print(f\"  Meta Info: {dpo_meta}\")\n\n                            except Exception as e_prep:\n                                print(f\"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}\")\n                                traceback.print_exc()\n                                dpo_update_batch_proto = None  # Skip update on error\n\n                        # --- Actor Update Step ---\n                        actor_output = None\n                        if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto:\n                            with _timer(\"update_actor\", timing_raw):\n                                # Pass the batch containing reference log probs (if computed)\n                                # The modified update_actor_dpo expects them if reference_free=False\n                                actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)\n                            if actor_output and \"metrics\" in actor_output.meta_info:\n                                metrics.update(reduce_metrics(actor_output.meta_info[\"metrics\"]))\n                        elif dpo_update_batch_proto is None:\n                            print(\n                                f\"Skipping actor update at step {self.global_steps} due to DPO batch preparation error.\"\n                            )\n\n                    
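    # Hedged note: OmegaConf.select(cfg, \"a.b\", default=x) returns the value at\n                        # the dotted path, or the default when the key is missing entirely.\n                    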
    # --- Validation and Saving ---\n                        test_freq = OmegaConf.select(self.config.trainer, \"test_freq\", default=-1)\n                        is_last_step = self.global_steps >= self.total_training_steps\n                        if (\n                            self.val_reward_fn is not None\n                            and test_freq > 0\n                            and (is_last_step or self.global_steps % test_freq == 0)\n                        ):\n                            print(f\"\\nRunning DPO validation at step {self.global_steps}...\")\n                            val_timing_raw = {}\n                            with _timer(\"testing\", val_timing_raw):\n                                val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                            if val_metrics:\n                                metrics[\"time/validation_run\"] = val_timing_raw.get(\"testing\", 0)\n                                metrics.update(val_metrics)\n                            else:\n                                print(\"Validation skipped or returned no metrics.\")\n\n                        save_freq = OmegaConf.select(self.config.trainer, \"save_freq\", default=-1)\n                        if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):\n                            print(f\"\\nSaving DPO checkpoint at step {self.global_steps}...\")\n                            with _timer(\"save_checkpoint\", timing_raw):\n                                self._save_checkpoint()  # Saves actor (and potentially critic if used elsewhere)\n                            metrics[\"time/save_checkpoint\"] = timing_raw.get(\"save_checkpoint\", 0)\n\n                    # --- End main step timer context ---\n\n                    # --- Metrics calculation AFTER the 'step' timer block ---\n                    metrics.update(compute_dpo_data_metrics(batch=batch))  # Use DPO-specific metrics\n                    metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                    n_gpus = self.resource_pool_manager.get_n_gpus()\n                    if \"step\" in timing_raw:\n                        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                    else:\n                        print(\n                            f\"Warning: 'step' key missing from timing_raw at step {self.global_steps}. 
\"\n                            f\"Skipping throughput.\"\n                        )\n\n                    step_timer.stop()\n                    metrics[\"time/step\"] = step_timer.last\n\n                    # Log metrics\n                    log_freq = OmegaConf.select(self.config.trainer, \"log_freq\", default=1)\n                    if logger and self.global_steps % log_freq == 0:\n                        log_payload = metrics.copy()\n                        # Add learning rate to log payload\n                        if actor_output and \"actor/lr\" in metrics:\n                            log_payload[\"actor/lr\"] = metrics[\"actor/lr\"]\n\n                        print(f\"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}\")\n                        try:\n                            logger.log(data=log_payload, step=self.global_steps)\n                        except Exception as e:\n                            print(f\"Logging failed at step {self.global_steps}: {e}\")\n\n                    # Update progress bar\n                    postfix_metrics = {\n                        k: f\"{v:.3f}\" if isinstance(v, float) else v\n                        for k, v in metrics.items()\n                        if isinstance(v, int | float)\n                    }\n                    progress_bar.set_postfix(postfix_metrics)\n\n                except Exception as step_e:\n                    print(f\"\\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!\")\n                    print(f\"Caught Exception: {step_e}\")\n                    traceback.print_exc()\n                    print(\"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\")\n                    step_timer.stop()\n                    should_stop = True\n                    break\n\n                if is_last_step or should_stop:\n                    print(f\"Stopping DPO training at step {self.global_steps}.\")\n                    break\n\n                self.global_steps += 1\n                progress_bar.update(1)\n\n            # End of epoch handling\n            if hasattr(self.train_dataloader, \"reset\"):\n                try:\n                    self.train_dataloader.reset()\n                except Exception as e:\n                    print(f\"Warning: Failed to reset train dataloader state: {e}\")\n            if should_stop:\n                break\n\n        # --- Final cleanup and logging ---\n        progress_bar.close()\n        final_step = max(0, self.global_steps - 1)\n        print(f\"Online DPO Training finished at step {final_step}.\")\n        # Save final checkpoint\n        save_freq = OmegaConf.select(self.config.trainer, \"save_freq\", default=-1)\n        if not self.config.trainer.get(\"val_only\", False) and (save_freq <= 0 or final_step % save_freq != 0):\n            print(f\"Saving final DPO checkpoint at step {final_step}...\")\n            self._save_checkpoint()\n\n        # Final validation run\n        if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get(\"val_only\", False):\n            print(\"Running final validation...\")\n            last_val_metrics = self._validate()\n            if last_val_metrics and logger:\n                last_val_metrics[\"final_validation\"] = True\n                try:\n                    logger.log(data=last_val_metrics, step=final_step)\n                except Exception as e:\n                    print(f\"[Final Val Metrics Log Error]: {e}\")\n\n        
pprint(f\"Final validation metrics: {last_val_metrics}\")\n        if logger and hasattr(logger, \"finish\"):\n            logger.finish()\n        print(\"Online DPO Training Run Complete.\")\n"
  },
  {
    "path": "verl_distillation/recipe/spin/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom omegaconf import DictConfig\n\n\ndef validate_config(\n    config: DictConfig,\n    use_reference_policy: bool,\n    use_critic: bool,\n) -> None:\n    \"\"\"\n    Validate an OmegaConf DictConfig\n\n    Args:\n        config: The OmegaConf DictConfig to validate.\n        use_reference_policy (bool): is ref policy needed\n        use_critic (bool): is critic needed\n    \"\"\"\n    # number of GPUs total\n    n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes\n\n    # 1. Check total batch size for data correctness\n    real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n\n    assert real_train_batch_size % n_gpus == 0, (\n        f\"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus}).\"\n    )\n\n    # A helper function to check \"micro_batch_size\" vs \"micro_batch_size_per_gpu\"\n    # We throw an error if the user sets both. The new convention is \"..._micro_batch_size_per_gpu\".\n    def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n        settings = {\n            \"actor_rollout_ref.actor\": \"micro_batch_size\",\n            \"critic\": \"micro_batch_size\",\n            \"reward_model\": \"micro_batch_size\",\n            \"actor_rollout_ref.ref\": \"log_prob_micro_batch_size\",\n            \"actor_rollout_ref.rollout\": \"log_prob_micro_batch_size\",\n        }\n\n        if name in settings:\n            param = settings[name]\n            param_per_gpu = f\"{param}_per_gpu\"\n\n            if mbs is None and mbs_per_gpu is None:\n                raise ValueError(f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\")\n\n            if mbs is not None and mbs_per_gpu is not None:\n                raise ValueError(\n                    f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. \"\n                    f\"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported \"\n                    f\"(the former is deprecated).\"\n                )\n\n    if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n        # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu\n        check_mutually_exclusive(\n            config.actor_rollout_ref.actor.ppo_micro_batch_size,\n            config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,\n            \"actor_rollout_ref.actor\",\n        )\n\n        if use_reference_policy:\n            # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.ref.log_prob_micro_batch_size,\n                config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.ref\",\n            )\n\n        #  The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu\n        check_mutually_exclusive(\n            config.actor_rollout_ref.rollout.log_prob_micro_batch_size,\n            config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,\n            \"actor_rollout_ref.rollout\",\n        )\n\n    if use_critic and not config.critic.use_dynamic_bsz:\n        # Check for critic micro-batch size conflicts\n        check_mutually_exclusive(\n            config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, \"critic\"\n        )\n\n    # Check for reward model micro-batch size conflicts\n    if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:\n        check_mutually_exclusive(\n            config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, \"reward_model\"\n        )\n\n    # Actor\n    # check if train_batch_size is larger than ppo_mini_batch_size\n    # if NOT dynamic_bsz, we must ensure:\n    #    ppo_mini_batch_size is divisible by ppo_micro_batch_size\n    #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus\n    if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n        assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size\n        sp_size = config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1)\n        if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:\n            assert (\n                config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size\n                == 0\n            )\n            assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus\n\n    assert config.actor_rollout_ref.actor.loss_agg_mode in [\n        \"token-mean\",\n        \"seq-mean-token-sum\",\n        \"seq-mean-token-mean\",\n    ], f\"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}\"\n\n    if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:\n        print(\"NOTICE: You have both enabled in-reward kl and kl loss.\")\n\n    # critic\n    if use_critic and not config.critic.use_dynamic_bsz:\n        assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size\n        sp_size = config.critic.get(\"ulysses_sequence_parallel_size\", 1)\n        if config.critic.ppo_micro_batch_size is not None:\n            assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0\n            assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus\n\n    # Check if use_remove_padding is enabled when using sequence parallelism for fsdp\n    if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n        if (\n            config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1) > 1\n            or config.actor_rollout_ref.ref.get(\"ulysses_sequence_parallel_size\", 1) > 1\n        ):\n            assert config.actor_rollout_ref.model.use_remove_padding, (\n                \"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`.\"\n            )\n\n    if use_critic and config.critic.strategy in {\"fsdp\", \"fsdp2\"}:\n        if config.critic.get(\"ulysses_sequence_parallel_size\", 1) > 1:\n            assert config.critic.model.use_remove_padding, (\n                \"When using sequence parallelism for critic, you must enable `use_remove_padding`.\"\n            )\n\n    if config.data.get(\"val_batch_size\", None) is not None:\n        
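# (hedged) val_batch_size no longer takes effect; validation goes out as one batch.\n        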
print(\n            \"WARNING: val_batch_size is deprecated. Validation datasets are sent to the inference engines \"\n            \"as a whole batch, and the engines schedule memory themselves.\"\n        )\n\n    # check eval config\n    if config.actor_rollout_ref.rollout.val_kwargs.do_sample:\n        assert config.actor_rollout_ref.rollout.temperature > 0, (\n            \"validation gen temperature should be greater than 0 when enabling do_sample\"\n        )\n\n    print(\"[validate_config] All configuration checks passed successfully!\")\n"
  },
  {
    "path": "verl_distillation/recipe/sppo/README.md",
    "content": "# SPPO: Self-Play Preference Optimization for Language Model Alignment\n\nThis repository hosts the community implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preferences, and empirically validated through extensive evaluations on multiple datasets.\n\nPaper Authors: [Yue Wu](https://yuewu.us/)\\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\nverl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)] [[Original Implementation](https://github.com/uclaml/SPPO)]\n\n## Reproduce the Experiment\n\nWe evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework.\n\n```bash\ngit clone git@github.com:volcengine/verl.git\ncd verl\npython3 -m uv pip install -e \".[sglang]\"\n\nexport WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n\npython3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct\n\nexport CUDA_VISIBLE_DEVICES=0,1,2,3\nbash recipe/sppo/run_qwen2.5-7b_rm.sh\n```\n\nNote that the installation may occasionally fail to install flash-attn. If this happens, you can install it manually by running:\n\n```bash\npython3 -m uv pip install wheel\npython3 -m uv pip install packaging\npython3 -m uv pip install flash-attn --no-build-isolation --no-deps\n```\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from:\n\n- [Yue Wu](https://yuewu.us/)\n- [Chendong Wang](https://cdwang96.github.io/)\n- [Yifan Zhang](https://github.com/yifanzhang-pro)\n- [Yongan Xiang](https://github.com/BearBiscuit05)\n- [Junrong Lin](https://github.com/ocss884)\n- [Yuxuan Tong](https://github.com/tongyx361)\n- [Guangming Shen](https://github.com/PeterSH6)\n- [Biao He](https://www.linkedin.com/in/biao-he/)\n- [Qingquan Song](https://qingquansong.github.io/)\n- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_distillation/recipe/sppo/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/recipe/sppo/config/sppo_trainer.yaml",
    "content": "# the sppo config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nactor_rollout_ref:\n  actor:\n    _target_: recipe.sppo.config.SPPOActorConfig\n\n    # sppo_eta is an additional hyperparameter for SPPO, not available in\n    # verl core. Specifying _target_ with SPPOActorConfig is needed to\n    # extend verl ActorConfig with custom fields.\n    # Additionally, it is possible to use the `extra` field natively supported\n    # by all verl core dataclasses, without having to define SPPOActorConfig:\n    # extra:\n    #   sppo_eta: 1.0\n    sppo_eta: 1.0\n\n    optim:\n      lr_warmup_steps: 15\n  rollout:\n    name: sglang\n    tensor_model_parallel_size: 2\n    gpu_memory_utilization: 0.5\n    val_kwargs:\n      n: 2  # 2 will trigger validation, 1 will bypass\n\nalgorithm:\n  adv_estimator: null\n  sppo_eta: 1.0\n\ntrainer:\n  log_val_generations: 0"
  },
  },
  {
    "path": "verl_distillation/recipe/sppo/config.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\n\nfrom verl.workers.config import FSDPActorConfig\n\n\n@dataclass\nclass SPPOActorConfig(FSDPActorConfig):\n    sppo_eta: float = 1.0\n"
  },
  {
    "path": "verl_distillation/recipe/sppo/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, kl_penalty\nfrom verl.utils.device import get_device_id\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import rearrange_micro_batches\nfrom verl.workers.actor.dp_actor import DataParallelPPOActor\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef compute_sppo_loss(\n    old_log_prob: torch.Tensor,  # (bs, seq_len)\n    log_prob: torch.Tensor,  # (bs, seq_len)\n    rewards: torch.Tensor,  # (bs,)\n    response_mask: torch.Tensor,  # (bs, seq_len)\n    eta: float = 1.0,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    SPPO loss: a squared regression of the sequence-level log-ratio onto the\n    eta-scaled preference reward, i.e. loss_i = (log(pi/pi_old)_i - eta * r_i)^2.\n    \"\"\"\n    # Compute log-ratios over masked tokens\n    log_prob_sum = (log_prob * response_mask).sum(dim=1)  # (bs,)\n    old_log_prob_sum = (old_log_prob * response_mask).sum(dim=1)  # (bs,)\n    log_ratios = log_prob_sum - old_log_prob_sum  # (bs,)\n\n    scaled_rewards = eta * rewards\n    loss_vec = (log_ratios - scaled_rewards) ** 2  # (bs,)\n\n    if loss_agg_mode == \"token-mean\":\n        sample_mask = response_mask.any(dim=1).float()  # (bs,)\n        loss = verl_F.masked_mean(loss_vec, sample_mask)\n    else:\n        # avoid returning an unbound `loss` for unsupported aggregation modes\n        raise NotImplementedError(f\"loss_agg_mode '{loss_agg_mode}' is not supported by SPPO\")\n\n    return loss, log_ratios, scaled_rewards\n\n\nclass DataParallelSPPOActor(DataParallelPPOActor):\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def update_policy(self, data: DataProto):\n        # make sure we are in training mode\n        self.actor_module.train()\n\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        multi_turn = data.meta_info.get(\"multi_turn\", False)\n\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\", \"old_log_probs\", \"seq_level_rewards\"]\n        if multi_turn:\n            select_keys.append(\"loss_mask\")\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        batch = data.select(batch_keys=select_keys).batch\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # (a mini-batch is one optimizer step; the micro-batches below only accumulate\n        # gradients: gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu)\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)\n        else:\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n        metrics = {}\n        for epoch in range(self.config.ppo_epochs):\n            for batch_idx, data in enumerate(dataloader):\n                # split batch into micro_batches\n                mini_batch = data\n                if has_multi_modal_inputs:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu\n                    micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n                elif self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    # split batch into micro_batches\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.actor_optimizer.zero_grad()\n\n                for data in micro_batches:\n                    # Support all hardwares\n                    if isinstance(data, DataProto):\n                        data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}\n                    else:\n                        data = data.to(get_device_id())  # actor device is cpu when using offload\n                    responses = data[\"responses\"]\n                    response_length = responses.size(1)\n                    attention_mask = data[\"attention_mask\"]\n                    if multi_turn:\n                        response_mask = data[\"loss_mask\"][:, -response_length:]\n                    else:\n                        response_mask = attention_mask[:, -response_length:]\n\n                    old_log_prob = data[\"old_log_probs\"]\n                    rewards = data[\"seq_level_rewards\"]\n\n                    entropy_coeff = self.config.entropy_coeff\n                    loss_agg_mode = self.config.loss_agg_mode\n                    eta = self.config.get(\"sppo_eta\", 1.0)\n\n                    # all return: (bsz, response_length)\n                    calculate_entropy = False\n                    if entropy_coeff != 0:\n                        calculate_entropy = True\n                    entropy, log_prob = self._forward_micro_batch(\n                        micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy\n                    )\n\n                    pg_loss, log_ratios, preference = compute_sppo_loss(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        rewards=rewards,\n                        response_mask=response_mask,\n                        eta=eta,\n                        
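# eta comes from SPPOActorConfig.sppo_eta (see config.py / sppo_trainer.yaml)\n                        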
loss_agg_mode=loss_agg_mode,\n                    )\n\n                    if entropy_coeff != 0:\n                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        # compute policy loss\n                        policy_loss = pg_loss - entropy_loss * entropy_coeff\n                    else:\n                        policy_loss = pg_loss\n\n                    if self.config.use_kl_loss:\n                        ref_log_prob = data[\"ref_log_prob\"]\n                        # compute kl loss\n                        kld = kl_penalty(\n                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type\n                        )\n                        kl_loss = agg_loss(\n                            loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode\n                        )\n\n                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                        metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n                        metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)\n                    else:\n                        loss = policy_loss / self.gradient_accumulation\n                    loss.backward()\n\n                    data = {\n                        \"actor/loss\": loss.detach().item(),\n                        \"actor/log_ratio_mean\": log_ratios.mean().detach().item(),\n                        \"actor/preference_mean\": preference.mean().detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n\n                grad_norm = self._optimizer_step()\n                data = {\"actor/grad_norm\": grad_norm.detach().item()}\n                # record the grad norm inside the mini-batch loop (once per optimizer step);\n                # previously it was appended after the loop and logged only once per epoch\n                append_to_dict(metrics, data)\n        self.actor_optimizer.zero_grad()\n        return metrics\n"
  },
  {
    "path": "verl_distillation/recipe/sppo/main_sppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.trainer.ppo.utils import need_reference_policy\nfrom verl.utils.config import validate_config\n\nfrom .sppo_ray_trainer import RaySPPOTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"sppo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices\n    # isolation, will solve in the future\n    os.environ[\"ENSURE_CUDA_VISIBLE_DEVICES\"] = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        default_runtime_env = {\n            \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n        }\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n\n            from .sppo_worker import SPPOActorRolloutRefWorker  # , CriticWorker\n\n            actor_rollout_cls = SPPOActorRolloutRefWorker\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker\n\n            actor_rollout_cls = ActorRolloutRefWorker\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise 
NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        # sppo does not use critic\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(role_worker_mapping),\n            use_critic=False,\n        )\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RaySPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n        )\n        trainer.init_workers()\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
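The `run_ppo` entry point above layers user-supplied `ray_kwargs.ray_init.runtime_env` settings over a default runtime environment with `OmegaConf.merge`, so user overrides win while unspecified defaults are retained. A minimal sketch of that merge semantics (the user override below is illustrative, not from the recipe):

```python
# Minimal sketch of the runtime-env merge in run_ppo above.
# The user override below is illustrative, not part of the recipe.
from omegaconf import OmegaConf

default_runtime_env = {
    "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
}
# As if the CLI passed ray_kwargs.ray_init.runtime_env.env_vars.NCCL_DEBUG=INFO
user_runtime_env = {"env_vars": {"NCCL_DEBUG": "INFO"}}

merged = OmegaConf.merge(default_runtime_env, user_runtime_env)
assert merged.env_vars.NCCL_DEBUG == "INFO"  # the later config wins
assert merged.env_vars.TOKENIZERS_PARALLELISM == "true"  # unspecified defaults retained
print(OmegaConf.to_container(merged))
```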
  {
    "path": "verl_distillation/recipe/sppo/run_qwen2.5-7b_rm.sh",
    "content": "# Discliamer: the model used in the script is only for academic purpose.\nset -x\n\n# Data preparation scripts are available in ``examples/data_preprocess``.\n# Example usage:\n#\n#   python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\n#   python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k\n\ngsm8k_train_path=$HOME/data/math/train.parquet\ngsm8k_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\n# prepare model ckpt\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct &\n# huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 &\nwait\n\npython3 -m recipe.sppo.main_sppo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"$HOME/models/Qwen2.5-7B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang  \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='sppo-sglang' \\\n    trainer.val_before_train=True \\\n    trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=1000 $@\n    # Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml\n    # The experiment will converge to 0.656 on MATH dataset after 20 epochs"
  },
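Before launching the script above, it can save a wasted run to sanity-check the prepared parquet files. A small pre-flight sketch (pandas is an assumed dependency here; any parquet reader works, and the paths mirror the script's defaults):

```python
# Pre-flight check sketch for the launch script above: verify the prepared
# parquet files exist and load before paying the cost of Ray startup.
from pathlib import Path

import pandas as pd

for name in ("train.parquet", "test.parquet"):
    path = Path.home() / "data" / "math" / name
    assert path.exists(), f"missing {path}; run examples/data_preprocess/math_dataset.py first"
    df = pd.read_parquet(path)
    print(f"{path}: {len(df)} rows, columns={list(df.columns)}")
```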
  {
    "path": "verl_distillation/recipe/sppo/sppo_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\nfrom typing import Optional\n\nimport numpy as np\nimport ray\nimport torch\nfrom torch.utils.data import Dataset, Sampler\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    ResourcePoolManager,\n    apply_kl_penalty,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler.performance import simple_timer\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\ndef softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor:\n    \"\"\"\n    Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) )\n    Falls back to arithmetic mean when β=0.\n    \"\"\"\n    if beta == 0.0:\n        return x.mean(dim=dim, keepdim=keepdim)\n\n    # cast beta to tensor on same device/dtype\n    beta_t = x.new_tensor(beta)\n    # numerically-stable logsumexp(β x)\n    lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim)\n    n = x.size(dim)\n    log_n = x.new_tensor(n).log()\n\n    return (lse - log_n) / beta_t\n\n\ndef compute_advantage(data: DataProto, beta=1.0):\n    rewards = data.batch[\"token_level_rewards\"].sum(axis=-1)  # (bs, )\n    s_mean = softmean(rewards, beta, keepdim=True)  # (bs, )\n    rewards = rewards - s_mean  # (bs, )\n    data.batch[\"seq_level_rewards\"] = rewards  # (bs, )\n    return data\n\n\nclass RaySPPOTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n  
      self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only support hybrid engine\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(role_worker_mapping)\n        self.use_rm = need_reward_model(role_worker_mapping)\n        self.use_critic = False\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.validation_generations_logger = ValidationGenerationsLogger()\n        self.device_name = device_name if device_name else self.config.trainer.device\n\n        # define in-reward KL control\n        # kl loss control currently not supported\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the\n        worker group through RPC to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n                non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n                if \"multi_modal_data\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n                if \"raw_prompt\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n                if \"tools_kwargs\" in batch.non_tensor_batch:\n                    
non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n                gen_batch = batch.pop(\n                    batch_keys=batch_keys_to_pop,\n                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n                )\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with simple_timer(\"gen\", timing_raw):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            batch = batch.union(gen_baseline_output)\n                            # compute reward model score on batch\n                            rm_scores = None\n                            if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                                rm_scores = self.rm_wg.compute_rm_score(batch)\n                                batch = batch.union(rm_scores)\n                            reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            keys_to_pop = set(gen_baseline_output.batch.keys())\n                            if rm_scores is not None:\n                                keys_to_pop.update(rm_scores.batch.keys())\n                            batch.pop(batch_keys=list(keys_to_pop))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del rm_scores, gen_baseline_batch, gen_baseline_output\n\n                    batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if 
self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                with simple_timer(\"reward\", timing_raw):\n                    # compute reward model score\n                    if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    if self.config.reward_model.launch_reward_fn_async:\n                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)\n                    else:\n                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                # recompute old_log_probs\n                with simple_timer(\"old_log_prob\", timing_raw):\n                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                    entropys = old_log_prob.batch[\"entropys\"]\n                    response_masks = batch.batch[\"response_mask\"]\n                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                    old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                    metrics.update(old_log_prob_metrics)\n                    old_log_prob.batch.pop(\"entropys\")\n                    batch = batch.union(old_log_prob)\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with simple_timer(\"ref\", timing_raw):\n                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with simple_timer(\"values\", timing_raw):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with simple_timer(\"adv\", timing_raw):\n                    # we combine with rule-based rm\n                    reward_extra_infos_dict: dict[str, list]\n                    if self.config.reward_model.launch_reward_fn_async:\n                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    if reward_extra_infos_dict:\n                        batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                    # compute rewards. 
apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n                        batch.batch[\"seq_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    beta = self.config.algorithm.sppo_eta\n                    batch = compute_advantage(batch, beta=beta)\n\n                # update critic\n                if self.use_critic:\n                    with simple_timer(\"update_critic\", timing_raw):\n                        critic_output = self.critic_wg.update_critic(batch)\n                    critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                    metrics.update(critic_output_metrics)\n\n                # implement critic warmup\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with simple_timer(\"update_actor\", timing_raw):\n                        batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                # Log rollout generations if enabled\n                rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                if rollout_data_dir:\n                    self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with simple_timer(\"testing\", timing_raw):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with simple_timer(\"save_checkpoint\", timing_raw):\n                        self._save_checkpoint()\n\n            # training metrics\n            metrics.update(\n                {\n                    \"training/global_step\": self.global_steps,\n                    \"training/epoch\": epoch,\n                }\n            )\n\n            # TODO: make a canonical logger that supports various backend\n            logger.log(data=metrics, step=self.global_steps)\n\n            if is_last_step:\n                pprint(f\"Final validation metrics: {last_val_metrics}\")\n                progress_bar.close()\n                return\n\n            progress_bar.update(1)\n            self.global_steps += 1\n"
  },
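The `softmean` baseline above interpolates between the arithmetic mean (β→0) and the maximum (β→∞); SPPO's sequence-level "advantage" is simply the reward minus this soft baseline. A quick numeric sanity check (the function body mirrors the recipe; the reward values are made up):

```python
# Numeric sanity check for softmean() in sppo_ray_trainer.py above.
# beta=0 recovers the arithmetic mean; a large beta approaches the max.
import torch


def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    if beta == 0.0:
        return x.mean(dim=dim, keepdim=keepdim)
    beta_t = x.new_tensor(beta)
    lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim)
    return (lse - x.new_tensor(x.size(dim)).log()) / beta_t


rewards = torch.tensor([0.0, 0.5, 1.0])  # made-up sequence-level rewards
assert torch.isclose(softmean(rewards, 0.0), rewards.mean())
assert torch.isclose(softmean(rewards, 1000.0), rewards.max(), atol=1e-2)

# SPPO's sequence-level advantage: reward minus the soft baseline,
# as in compute_advantage() above (beta comes from algorithm.sppo_eta).
advantage = rewards - softmean(rewards, beta=1.0)
print(advantage)
```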
  {
    "path": "verl_distillation/recipe/sppo/sppo_worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom omegaconf import OmegaConf, open_dict\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass SPPOActorRolloutRefWorker(ActorRolloutRefWorker):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from .dp_actor import DataParallelSPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            if self._is_actor:\n                optim_config = self.config.actor.optim\n                fsdp_config = self.config.actor.fsdp_config\n            else:\n                optim_config = None\n                fsdp_config = OmegaConf.create()\n            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (\n                self._build_model_optimizer(\n                    model_path=self.config.model.path,\n                    fsdp_config=fsdp_config,\n                    optim_config=optim_config,\n                    override_model_config=override_model_config,\n                    use_remove_padding=use_remove_padding,\n                    use_fused_kernels=use_fused_kernels,\n                    enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                    trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                    use_liger=self.config.model.get(\"use_liger\", False),\n                    role=\"actor\",\n                )\n            )\n\n            # get the original unwrapped module\n            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_param:\n                
offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n                log_gpu_memory_usage(\"After offload actor model during init\", logger=logger)\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n        # load from checkpoint\n        if self._is_actor:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                self.config.actor.use_remove_padding = use_remove_padding\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = DataParallelSPPOActor(\n                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self._build_rollout(trust_remote_code=self.config.model.get(\"trust_remote_code\", False))\n\n        if self._is_ref:\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n"
  },
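`SPPOActorRolloutRefWorker.init_model` above offloads FSDP parameters and optimizer state to CPU right after initialization so GPU memory stays free for rollout. verl's `offload_fsdp_*` helpers handle sharded state; the generic torch sketch below only illustrates the underlying idea of parking optimizer state on the CPU between updates:

```python
# Conceptual sketch of the optimizer-offload pattern used in init_model above.
# verl's offload_fsdp_optimizer handles FSDP-sharded state; this generic torch
# version just moves optimizer state tensors to CPU to free accelerator memory.
import torch


def offload_optimizer_to_cpu(optimizer: torch.optim.Optimizer) -> None:
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to("cpu", non_blocking=True)


model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())
model(torch.randn(4, 8)).sum().backward()
optimizer.step()  # populates the Adam moment buffers
offload_optimizer_to_cpu(optimizer)  # park them on CPU until the next update
```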
  {
    "path": "verl_distillation/recipe/transfer_queue/agent_loop.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport numpy as np\nimport ray\nfrom transfer_queue import BatchMeta\n\nimport verl.experimental.agent_loop.agent_loop as agent_loop\nfrom verl import DataProto\n\n\nclass AgentLoopManager(agent_loop.AgentLoopManager):\n    def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:\n        \"\"\"Split input batch and dispatch to agent loop workers.\n\n        Args:\n            prompts (BatchMeta): Input batch.\n\n        Returns:\n            BatchMeta: Output batch metadata.\n        \"\"\"\n\n        if self.rm_micro_batch_size and len(prompts) % self.rm_micro_batch_size != 0:\n            raise ValueError(\n                f\"The length of prompts {len(prompts)} cannot divide the world size of rm_wg {self.rm_micro_batch_size}\"\n            )\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.wake_up()\n        chunkes = prompts.chunk(len(self.agent_loop_workers))\n        outputs = ray.get(\n            [\n                worker.generate_sequences.remote(chunk)\n                for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)\n            ]\n        )\n        output = BatchMeta.concat(outputs)\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.sleep()\n\n        # calculate performance metrics\n        metrics = [output.extra_info.pop(\"metrics\") for output in outputs]  # List[List[Dict[str, str]]]\n        timing = self._performance_metrics(metrics, output)\n\n        output.set_extra_info(\"timing\", timing)\n        return output\n\n    def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:\n        timing = {}\n        t_generate_sequences = np.array([metric[\"generate_sequences\"] for chunk in metrics for metric in chunk])\n        t_tool_calls = np.array([metric[\"tool_calls\"] for chunk in metrics for metric in chunk])\n        timing[\"agent_loop/generate_sequences/min\"] = t_generate_sequences.min()\n        timing[\"agent_loop/generate_sequences/max\"] = t_generate_sequences.max()\n        timing[\"agent_loop/generate_sequences/mean\"] = t_generate_sequences.mean()\n        timing[\"agent_loop/tool_calls/min\"] = t_tool_calls.min()\n        timing[\"agent_loop/tool_calls/max\"] = t_tool_calls.max()\n        timing[\"agent_loop/tool_calls/mean\"] = t_tool_calls.mean()\n\n        return timing\n\n    def create_transferqueue_client(self, controller_infos, storage_infos, role):\n        ray.get(\n            [\n                worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)\n                for worker in self.agent_loop_workers\n            ]\n        )\n"
  },
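`AgentLoopManager.generate_sequences` above fans the prompt batch out across the agent loop workers and then flattens each worker's per-sample timing dicts before reducing them to min/max/mean. A stand-alone sketch of that aggregation (the timings are invented):

```python
# Sketch of the timing aggregation in _performance_metrics above: one inner
# list of per-sample metric dicts per agent loop worker, flattened then reduced.
import numpy as np

metrics = [  # invented timings; one inner list per worker
    [{"generate_sequences": 1.2, "tool_calls": 0.1}],
    [{"generate_sequences": 0.9, "tool_calls": 0.3}, {"generate_sequences": 1.1, "tool_calls": 0.2}],
]
t_gen = np.array([m["generate_sequences"] for chunk in metrics for m in chunk])
timing = {
    "agent_loop/generate_sequences/min": t_gen.min(),
    "agent_loop/generate_sequences/max": t_gen.max(),
    "agent_loop/generate_sequences/mean": t_gen.mean(),
}
print(timing)  # min=0.9, max=1.2, mean≈1.07
```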
  {
    "path": "verl_distillation/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\n# config for TransferQueue\ntransfer_queue:\n  enable: True\n"
  },
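The recipe config above composes the stock `ppo_trainer` config (made discoverable via the `hydra.searchpath` entry) and layers a single `transfer_queue.enable` flag on top. Composing the real config requires the verl tree, so the sketch below merges stand-in configs with OmegaConf just to show the layering:

```python
# Rough illustration of the Hydra defaults list above: the base ppo_trainer
# config plus a transfer_queue overlay. The base values here are stand-ins;
# real composition happens through Hydra with the searchpath shown above.
from omegaconf import OmegaConf

base_ppo_trainer = OmegaConf.create({"trainer": {"nnodes": 1}})  # stand-in for verl's ppo_trainer
recipe_overlay = OmegaConf.create({"transfer_queue": {"enable": True}})

config = OmegaConf.merge(base_ppo_trainer, recipe_overlay)
assert config.transfer_queue.enable  # the flag run_ppo checks before exporting TRANSFER_QUEUE_ENABLE=1
print(OmegaConf.to_yaml(config))
```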
  {
    "path": "verl_distillation/recipe/transfer_queue/main_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.main_ppo import (\n    TaskRunner as MainTaskRunner,\n)\nfrom verl.trainer.main_ppo import (\n    create_rl_dataset,\n    create_rl_sampler,\n)\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.trainer.ppo.utils import need_critic, need_reference_policy\nfrom verl.utils.config import validate_config\nfrom verl.utils.device import is_cuda_available\n\nfrom .ray_trainer import RayPPOTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for PPO training with Hydra configuration management.\n\n    Args:\n        config_dict: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    run_ppo(config)\n\n\n# Define a function to run the PPO-like training process\ndef run_ppo(config, task_runner_class=None) -> None:\n    \"\"\"Initialize Ray cluster and run distributed PPO training process.\n\n    Args:\n        config: Training configuration object containing all necessary parameters\n                for distributed PPO training including Ray initialization settings,\n                model paths, and training hyperparameters.\n        task_runner_class: For recipe to change TaskRunner.\n    \"\"\"\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        # Set environment variables in the runtime environment to control tokenizer parallelism,\n        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating\n        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration\n        default_runtime_env = get_ppo_ray_runtime_env()\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n\n        if config.transfer_queue.enable:\n            # Add runtime environment variables for transfer queue\n            runtime_env_vars = runtime_env_kwargs.get(\"env_vars\", {})\n            runtime_env_vars[\"TRANSFER_QUEUE_ENABLE\"] = \"1\"\n            runtime_env_kwargs[\"env_vars\"] = runtime_env_vars\n\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    if task_runner_class is None:\n        task_runner_class = ray.remote(num_cpus=1)(TaskRunner)  # please make sure main_task is not scheduled on head\n\n    # Create a remote 
instance of the TaskRunner class, and\n    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete\n    if (\n        is_cuda_available\n        and config.global_profiler.tool == \"nsys\"\n        and config.global_profiler.get(\"steps\") is not None\n        and len(config.global_profiler.get(\"steps\", [])) > 0\n    ):\n        from verl.utils.import_utils import is_nvtx_available\n\n        assert is_nvtx_available(), \"nvtx is not available in CUDA platform. Please 'pip3 install nvtx'\"\n        nsight_options = OmegaConf.to_container(\n            config.global_profiler.global_tool_config.nsys.controller_nsight_options\n        )\n        runner = task_runner_class.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = task_runner_class.remote()\n    ray.get(runner.run.remote(config))\n\n    # [Optional] get the path of the timeline trace file from the configuration, default to None\n    # This file is used for performance analysis\n    timeline_json_file = config.ray_kwargs.get(\"timeline_json_file\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\nclass TaskRunner(MainTaskRunner):\n    def run(self, config):\n        \"\"\"Execute the main PPO training workflow.\n\n        This method sets up the distributed training environment, initializes\n        workers, datasets, and reward functions, then starts the training process.\n\n        Args:\n            config: Training configuration object containing all parameters needed\n                   for setting up and running the PPO training process.\n        \"\"\"\n        # Print the initial configuration. `resolve=True` will evaluate symbolic values.\n        from pprint import pprint\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)\n        self.add_critic_worker(config)\n\n        # We should adopt a multi-source reward function here:\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # finally, we combine all the rewards together\n        # The reward type depends on the tag of the data\n        self.add_reward_model_worker(config)\n\n        # Add a reference policy worker if KL loss or KL reward is used.\n        self.add_ref_policy_worker(config, actor_rollout_cls)\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(self.role_worker_mapping),\n            use_critic=need_critic(config),\n        )\n\n        # Download the checkpoint from HDFS to the local machine.\n        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor.\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n       
 # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Load the reward manager for training and validation.\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n\n        resource_pool_manager = self.init_resource_pool_mgr(config)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets.\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the PPO trainer.\n        trainer = RayPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=self.role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        # Initialize the workers of the trainer.\n        trainer.init_workers()\n        # Start the training process.\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
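When `transfer_queue.enable` is set, `run_ppo` above injects `TRANSFER_QUEUE_ENABLE=1` into the Ray runtime environment, which Ray propagates to every remote task and actor. A minimal demonstration of that propagation mechanism (requires a local `ray` install; the flag name matches the recipe):

```python
# Minimal sketch of the env-var propagation used in run_ppo above: variables
# placed in the Ray runtime_env become visible inside every remote worker.
import os

import ray

ray.init(runtime_env={"env_vars": {"TRANSFER_QUEUE_ENABLE": "1"}})


@ray.remote
def read_flag() -> str:
    return os.environ.get("TRANSFER_QUEUE_ENABLE", "unset")


assert ray.get(read_flag.remote()) == "1"
ray.shutdown()
```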
  {
    "path": "verl_distillation/recipe/transfer_queue/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport asyncio\nimport json\nimport logging\nimport math\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom dataclasses import dataclass, field\nfrom pprint import pprint\nfrom typing import Any, Optional\n\nimport numpy as np\nimport ray\nimport tensordict\nimport torch\nfrom omegaconf import OmegaConf, open_dict\nfrom packaging.version import parse as parse_version\nfrom tensordict import TensorDict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\nfrom transfer_queue import (\n    BatchMeta,\n    TransferQueueController,\n    TransferQueueStorageSimpleUnit,\n    get_placement_group,\n    process_zmq_server_info,\n)\n\nfrom verl import DataProto\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.single_controller.ray import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n)\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.config import AlgoConfig\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.trainer.ppo.utils import (\n    Role,\n    WorkerType,\n    need_critic,\n    need_reference_policy,\n    need_reward_model,\n)\nfrom verl.utils.checkpoint.checkpoint_manager import (\n    find_latest_ckpt_path,\n    should_save_ckpt_esi,\n)\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.rollout_skip import RolloutSkip\nfrom verl.utils.seqlen_balancing import (\n    get_seqlen_balanced_partitions,\n    log_seqlen_unbalance,\n)\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.tracking import ValidationGenerationsLogger\nfrom verl.utils.transferqueue_utils import (\n    create_transferqueue_client,\n    get_transferqueue_client,\n    get_val_transferqueue_client,\n    tqbridge,\n)\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. 
Resource pool will be initialized first.\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        \"\"\"Create Ray resource pools for distributed training.\n\n        Initializes resource pools based on the resource pool specification,\n        with each pool managing GPU resources across multiple nodes.\n        For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups.\n        For Megatron backend, uses max_colocate_count>1 for different models.\n        \"\"\"\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1\n            # that can utilize different WorkerGroup for differnt models\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray._private.state.available_resources_per_node()\n        node_available_gpus = {\n            node: node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0)\n            for node, node_info in node_available_resources.items()\n        }\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n\n@tqbridge(put_data=False)\ndef compute_reward_decorated(data, reward_fn):\n    return compute_reward(data, reward_fn)\n\n\n@tqbridge(put_data=False)\ndef compute_reward_async_decorated(data, reward_fn):\n    return compute_reward_async.remote(data, reward_fn)\n\n\n@tqbridge(put_data=False)\ndef apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    \"\"\"Apply KL penalty to the token-level rewards.\n\n    This function computes the KL divergence between the reference policy and current policy,\n    then applies a penalty to the token-level rewards based on this divergence.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.\n        
kl_penalty (str, optional): Type of KL penalty to apply. Defaults to \"kl\".\n\n    Returns:\n        tuple: A tuple containing:\n            - The updated data with token-level rewards adjusted by KL penalty\n            - A dictionary of metrics related to the KL penalty\n    \"\"\"\n    response_mask = data.batch[\"response_mask\"]\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return token_level_rewards, metrics\n\n\ndef compute_response_mask(batch_meta: BatchMeta, data_system_client):\n    \"\"\"Compute the attention mask for the response part of the sequence.\n\n    This function extracts the portion of the attention mask that corresponds to the model's response,\n    which is used for masking computations that should only apply to response tokens.\n\n    Args:\n        batch_meta (BatchMeta): The data containing batched model outputs and inputs.\n\n    Returns:\n        BatchMeta: The BatchMeta of attention mask for the response tokens.\n    \"\"\"\n    data = asyncio.run(data_system_client.async_get_data(batch_meta))\n\n    responses = data[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data[\"attention_mask\"]\n    response_mask = attention_mask[:, -response_length:]\n    output = TensorDict({\"response_mask\": response_mask}, batch_size=response_mask.size(0))\n\n    asyncio.run(data_system_client.async_put(data=output, metadata=batch_meta))\n    batch_meta.add_fields(output)\n\n    return batch_meta\n\n\n@tqbridge(put_data=False)\ndef compute_advantage(\n    data: DataProto,\n    adv_estimator: AdvantageEstimator,\n    gamma: float = 1.0,\n    lam: float = 1.0,\n    num_repeat: int = 1,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> tuple[Any, Any]:\n    \"\"\"Compute advantage estimates for policy optimization.\n\n    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.\n    The advantage estimates are used to guide policy optimization in RL algorithms.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).\n        gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.\n        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.\n        num_repeat (int, optional): Number of times to repeat the computation. 
Defaults to 1.\n        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in\n            GRPO. Defaults to True.\n        config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None.\n\n    Returns:\n        tuple: A tuple containing:\n            - advantages: The computed advantage estimates.\n            - returns: The computed returns.\n    \"\"\"\n    # prepare response group\n    if adv_estimator == AdvantageEstimator.GAE:\n        # Compute advantages and returns using Generalized Advantage Estimation (GAE)\n        advantages, returns = core_algos.compute_gae_advantage_return(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            values=data.batch[\"values\"],\n            response_mask=data.batch[\"response_mask\"],\n            gamma=gamma,\n            lam=lam,\n        )\n        # TODO: (TQ) adapt core_algos.compute_pf_ppo_reweight_data function to support transfer queue\n        if config.get(\"use_pf_ppo\", False):\n            data = core_algos.compute_pf_ppo_reweight_data(\n                data,\n                config.pf_ppo.get(\"reweight_method\"),\n                config.pf_ppo.get(\"weight_pow\"),\n            )\n    elif adv_estimator == AdvantageEstimator.GRPO:\n        # Initialize the mask for GRPO calculation\n        grpo_calculation_mask = data.batch[\"response_mask\"]\n        # Call compute_grpo_outcome_advantage with parameters matching its definition\n        advantages, returns = core_algos.compute_grpo_outcome_advantage(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            response_mask=grpo_calculation_mask,\n            index=data.non_tensor_batch[\"uid\"],\n            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n        )\n    else:\n        # handle all other adv estimator type other than GAE and GRPO\n        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)\n        adv_kwargs = {\n            \"token_level_rewards\": data.batch[\"token_level_rewards\"],\n            \"response_mask\": data.batch[\"response_mask\"],\n            \"config\": config,\n        }\n        if \"uid\" in data.non_tensor_batch:  # optional\n            adv_kwargs[\"index\"] = data.non_tensor_batch[\"uid\"]\n        if \"reward_baselines\" in data.batch:  # optional\n            adv_kwargs[\"reward_baselines\"] = data.batch[\"reward_baselines\"]\n\n        # calculate advantage estimator\n        advantages, returns = adv_estimator_fn(**adv_kwargs)\n    return advantages, returns\n\n\n@tqbridge(put_data=False)\ndef compute_data_metrics_decorated(batch, use_critic: bool = True):\n    return compute_data_metrics(batch, use_critic)\n\n\n@tqbridge(put_data=False)\ndef compute_timing_metrics_decorated(batch, timing_raw: dict[str, float]) -> dict[str, Any]:\n    return compute_timing_metrics(batch, timing_raw)\n\n\n@tqbridge(put_data=False)\ndef compute_throughout_metrics_decorated(batch, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:\n    return compute_throughout_metrics(batch, timing_raw, n_gpus)\n\n\n@tqbridge(put_data=False)\ndef calculate_debug_metrics_decorated(data):\n    from verl.utils.debug.metrics import calculate_debug_metrics\n\n    return calculate_debug_metrics(data)\n\n\n@tqbridge(put_data=False)\ndef compute_val_reward_decorated(reward_fn, data, return_dict):\n    return reward_fn(data, return_dict)\n\n\nclass RayPPOTrainer:\n    \"\"\"Distributed PPO trainer using Ray for scalable 
reinforcement learning.\n\n    This trainer orchestrates distributed PPO training across multiple nodes and GPUs,\n    managing actor rollouts, critic training, and reward computation with Ray backend.\n    Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). 
Defaults to None.\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only support hybrid engine\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)\n        self.use_rm = need_reward_model(self.role_worker_mapping)\n        self.use_critic = need_critic(self.config)\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n        self.validation_generations_logger = ValidationGenerationsLogger(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n        )\n\n        # if ref_in_actor is True, the reference policy will be actor without lora applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # kl loss control currently not suppoorted\n        if self.config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n        self.data_system_client = self._initialize_train_data_system(\n            self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n\n        )\n        self.val_data_system_client = self._initialize_val_data_system(\n            self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n\n        )\n\n    def _initialize_train_data_system(self, global_batch_size, num_n_samples, role=\"train\"):\n        # 1. initialize TransferQueueStorage\n        total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples\n        self.data_system_storage_units = {}\n        storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)\n        for storage_unit_rank in range(self.config.trainer.num_data_storage_units):\n            storage_node = TransferQueueStorageSimpleUnit.options(\n                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank\n            ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))\n            self.data_system_storage_units[storage_unit_rank] = storage_node\n            logging.info(f\"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.\")\n\n        # 2. initialize TransferQueueController\n        # we support inilialize multiple controller instances for large-scale scenario. 
# 2. initialize TransferQueueController\n        # we support initializing multiple controller instances for large-scale scenarios. Please allocate exactly\n        # one controller for a single WorkerGroup.\n        self.data_system_controllers = {}\n        controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)\n        for controller_rank in range(self.config.trainer.num_data_controllers):\n            self.data_system_controllers[controller_rank] = TransferQueueController.options(\n                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank\n            ).remote(\n                num_storage_units=self.config.trainer.num_data_storage_units,\n                global_batch_size=global_batch_size,\n                num_global_batch=self.config.trainer.num_global_batch,\n                num_n_samples=num_n_samples,\n            )\n            logging.info(f\"TransferQueueController #{controller_rank} has been created.\")\n\n        # 3. register controller & storage\n        self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)\n        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)\n\n        ray.get(\n            [\n                storage_unit.register_controller_info.remote(self.data_system_controller_infos)\n                for storage_unit in self.data_system_storage_units.values()\n            ]\n        )\n\n        # 4. create client\n        # each client should be allocated to exactly one controller\n        create_transferqueue_client(\n            client_id=\"Trainer-\" + role,\n            controller_infos=self.data_system_controller_infos,\n            storage_infos=self.data_system_storage_unit_infos,\n        )\n        data_system_client = get_transferqueue_client()\n        return data_system_client\n\n    def _initialize_val_data_system(self, global_batch_size, num_n_samples, role=\"val\"):\n        # 1. initialize TransferQueueStorage\n        total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples\n        self.val_data_system_storage_units = {}\n        storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)\n        for storage_unit_rank in range(self.config.trainer.num_data_storage_units):\n            storage_node = TransferQueueStorageSimpleUnit.options(\n                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank\n            ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))\n            self.val_data_system_storage_units[storage_unit_rank] = storage_node\n            logging.info(f\"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.\")\n\n        # 2. initialize TransferQueueController\n        # we support initializing multiple controller instances for large-scale scenarios. 
Please allocate exactly\n        # one controller for a single WorkerGroup.\n        self.val_data_system_controllers = {}\n        controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)\n        for controller_rank in range(self.config.trainer.num_data_controllers):\n            self.val_data_system_controllers[controller_rank] = TransferQueueController.options(\n                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank\n            ).remote(\n                num_storage_units=self.config.trainer.num_data_storage_units,\n                global_batch_size=global_batch_size,\n                num_global_batch=self.config.trainer.num_global_batch,\n                num_n_samples=num_n_samples,\n            )\n            logging.info(f\"TransferQueueController #{controller_rank} has been created.\")\n\n        # 3. register controller & storage\n        self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)\n        self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)\n\n        ray.get(\n            [\n                storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)\n                for storage_unit in self.val_data_system_storage_units.values()\n            ]\n        )\n\n        # 4. create client\n        # each client should be allocated to exactly one controller\n        create_transferqueue_client(\n            client_id=\"Trainer-\" + role,\n            controller_infos=self.val_data_system_controller_infos,\n            storage_infos=self.val_data_system_storage_unit_infos,\n        )\n        data_system_client = get_val_transferqueue_client()\n        return data_system_client\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files, self.config.data, self.tokenizer, self.processor\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files, self.config.data, self.tokenizer, self.processor\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        num_workers = self.config.data[\"dataloader_num_workers\"]\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=num_workers,\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n        
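# Illustrative note: when config.data.val_batch_size is unset, the entire\n        # validation set becomes a single batch; e.g., with 2,000 validation prompts\n        # and val_kwargs.n=4, one validation pass puts 8,000 samples into the data\n        # system (numbers are hypothetical).\n        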
self.val_batch_size = val_batch_size\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=num_workers,\n            shuffle=self.config.data.get(\"validation_shuffle\", True),\n            drop_last=False,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: \"\n            f\"{len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}\")\n\n    def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):\n        \"\"\"Dump rollout/validation samples as JSONL.\"\"\"\n        os.makedirs(dump_path, exist_ok=True)\n        filename = os.path.join(dump_path, f\"{self.global_steps}.jsonl\")\n\n        n = len(inputs)\n        base_data = {\n            \"input\": inputs,\n            \"output\": outputs,\n            \"gts\": gts,\n            \"score\": scores,\n            \"step\": [self.global_steps] * n,\n        }\n\n        for k, v in reward_extra_infos_dict.items():\n            if len(v) == n:\n                base_data[k] = v\n\n        lines = []\n        for i in range(n):\n            entry = {k: v[i] for k, v in base_data.items()}\n            lines.append(json.dumps(entry, ensure_ascii=False))\n\n        with open(filename, \"w\") as f:\n            f.write(\"\\n\".join(lines) + \"\\n\")\n\n        print(f\"Dumped generations to {filename}\")\n\n    def _log_rollout_data(\n        self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str\n    ):\n        \"\"\"\n        Log rollout data to disk.\n\n        Args:\n            log_rollout_meta (BatchMeta): The batch_meta of rollout data\n            reward_extra_infos_dict (dict): Additional reward information to log\n            timing_raw (dict): Timing information for profiling\n            rollout_data_dir (str): Directory path to save the rollout data\n        \"\"\"\n        with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n            data = asyncio.run(self.data_system_client.async_get_data(log_rollout_meta))\n\n            inputs = self.tokenizer.batch_decode(data[\"prompts\"], skip_special_tokens=True)\n            outputs = self.tokenizer.batch_decode(data[\"responses\"], skip_special_tokens=True)\n      
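      # Illustrative record in the JSONL dump produced below (values are\n            # hypothetical):\n            #   {\"input\": \"user prompt ...\", \"output\": \"model response ...\",\n            #    \"gts\": \"ground truth\", \"score\": 1.0, \"step\": 42}\n      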
      scores = data[\"token_level_scores\"].sum(-1).cpu().tolist()\n            sample_gts = [item.get(\"ground_truth\", None) for item in data.get(\"reward_model\", {})]\n\n            reward_extra_infos_to_dump = reward_extra_infos_dict.copy()\n            if \"request_id\" in log_rollout_meta.field_names:\n                # add request_id to the dumped copy (not to the caller's dict)\n                reward_extra_infos_to_dump.setdefault(\n                    \"request_id\",\n                    data[\"request_id\"].tolist(),\n                )\n\n            self._dump_generations(\n                inputs=inputs,\n                outputs=outputs,\n                gts=sample_gts,\n                scores=scores,\n                reward_extra_infos_dict=reward_extra_infos_to_dump,\n                dump_path=rollout_data_dir,\n            )\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort by input text\n\n        # Use a fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take the first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _get_gen_batch(self, batch: DataProto) -> DataProto:\n        reward_model_keys = {\"data_source\", \"reward_model\", \"extra_info\", \"uid\"} & batch.non_tensor_batch.keys()\n\n        # pop those keys for generation\n        batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n        non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys\n        gen_batch = batch.pop(\n            batch_keys=batch_keys_to_pop,\n            non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),\n        )\n\n        # For the agent loop, we need the reward model keys to compute the score.\n        if self.async_rollout_mode:\n            gen_batch.non_tensor_batch.update(batch.non_tensor_batch)\n\n        return gen_batch\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_gts = []\n        sample_scores = []\n        sample_turns = []\n        sample_uids = []\n\n        for test_data in self.val_dataloader:\n            if \"uid\" not in test_data.keys():\n                test_data[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(test_data[\"input_ids\"]))], dtype=object\n                )\n\n            # repeat test data\n            repeated_test_data = self.repeat_dict(\n                test_data, repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n            )\n\n            test_batch: TensorDict = self.dict_to_tensordict(repeated_test_data)\n\n            # we only do validation with rule-based reward models\n            if self.config.reward_model.enable and test_batch[0][\"reward_model\"][\"style\"] == \"model\":\n                
return {}\n\n            asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))\n\n            # Store original inputs\n            batch_meta = asyncio.run(\n                self.val_data_system_client.async_get_meta(\n                    data_fields=[\"input_ids\", \"uid\", \"reward_model\"],\n                    batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                    global_step=self.global_steps - 1,\n                    get_n_samples=False,\n                    task_name=\"get_data\",\n                )\n            )\n            data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))\n            input_ids = data[\"input_ids\"]\n            # TODO: Can we keep special tokens except for padding tokens?\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            sample_inputs.extend(input_texts)\n            sample_uids.extend(data[\"uid\"])\n\n            ground_truths = [item.get(\"ground_truth\", None) for item in data.get(\"reward_model\", {})]\n            sample_gts.extend(ground_truths)\n\n            test_gen_meta = asyncio.run(\n                self.val_data_system_client.async_get_meta(\n                    data_fields=list(test_batch.keys()),  # TODO: (TQ) Get metadata by specified fields\n                    batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                    global_step=self.global_steps - 1,  # self.global_steps starts from 1\n                    get_n_samples=False,\n                    task_name=\"generate_sequences\",\n                )\n            )\n            test_gen_meta.extra_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,\n                \"validate\": True,\n                \"global_steps\": self.global_steps,\n            }\n            print(f\"test_gen_meta info: {test_gen_meta.extra_info}\")\n\n            # TODO: (TQ) Support padding and unpadding to make DataProto divisible by dp_size with TransferQueue\n            if not self.async_rollout_mode:\n                test_output_gen_meta = self.actor_rollout_wg.generate_sequences(test_gen_meta)\n            else:\n                test_output_gen_meta = self.async_rollout_manager.generate_sequences(test_gen_meta)\n\n            test_batch_meta = test_gen_meta.union(test_output_gen_meta)\n\n            print(\"validation generation end\")\n\n            # Store generated outputs\n            test_response_meta = asyncio.run(\n                self.val_data_system_client.async_get_meta(\n                    data_fields=[\"responses\"],\n                    batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                    global_step=self.global_steps - 1,  # self.global_steps starts from 1\n                    get_n_samples=False,\n                    task_name=\"get_response\",\n                )\n            )\n            data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))\n            output_ids = data[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            
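# TransferQueue access pattern used throughout validation (sketch; the\n            # field list is illustrative): first fetch metadata for the fields a\n            # task needs, then materialize the payload, e.g.\n            #   meta = asyncio.run(client.async_get_meta(data_fields=[\"responses\"], ...))\n            #   data = asyncio.run(client.async_get_data(meta))\n\n            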
test_batch_meta.set_extra_info(\"validate\", True)\n\n            # evaluate using reward_function\n            if self.val_reward_fn is None:\n                raise ValueError(\"val_reward_fn must be provided for validation.\")\n\n            compute_reward_fields = [\n                \"responses\",\n                \"prompts\",\n                \"attention_mask\",\n                \"reward_model\",\n                \"data_source\",\n            ]\n            if \"rm_scores\" in batch_meta.field_names:\n                compute_reward_fields = [\"rm_scores\"]\n            val_reward_meta = asyncio.run(\n                self.val_data_system_client.async_get_meta(\n                    data_fields=compute_reward_fields,\n                    batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                    global_step=self.global_steps - 1,\n                    get_n_samples=False,\n                    task_name=\"compute_reward\",\n                )\n            )\n            val_reward_meta.update_extra_info(test_batch_meta.extra_info)\n            result = compute_val_reward_decorated(self.val_reward_fn, val_reward_meta, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            reward_extra_infos_dict[\"reward\"].extend(scores)\n            print(f\"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}\")\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n                    print(f\"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}\")\n\n            # collect num_turns of each prompt\n            if \"__num_turns__\" in test_batch_meta.field_names:\n                num_turns_meta = asyncio.run(\n                    self.val_data_system_client.async_get_meta(\n                        data_fields=[\"__num_turns__\"],\n                        batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                        global_step=self.global_steps - 1,  # self.global_steps starts from 1\n                        get_n_samples=False,\n                        task_name=\"get_num_turns\",\n                    )\n                )\n                data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))\n                sample_turns.append(data[\"__num_turns__\"])\n\n            data_source = [\"unknown\"] * reward_tensor.shape[0]\n            if \"data_source\" in test_batch_meta.field_names:\n                data_source_meta = asyncio.run(\n                    self.val_data_system_client.async_get_meta(\n                        data_fields=[\"data_source\"],\n                        batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,\n                        global_step=self.global_steps - 1,  # self.global_steps starts from 1\n                        get_n_samples=False,\n                        task_name=\"get_data_source\",\n                    )\n                )\n                data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))\n                data_source = data[\"data_source\"]\n\n            data_source_lst.append(data_source)\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, 
scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                gts=sample_gts,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        if len(sample_turns) > 0:\n            sample_turns = np.concatenate(sample_turns)\n            metric_dict[\"val-aux/num_turns/min\"] = sample_turns.min()\n            metric_dict[\"val-aux/num_turns/max\"] = sample_turns.max()\n            metric_dict[\"val-aux/num_turns/mean\"] = sample_turns.mean()\n\n        asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. 
Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=\"actor_rollout\",\n            )\n            self.resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cfg = omega_conf_to_dataclass(self.config.critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=\"ref\",\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model worker if needed (used when no reward_fn is provided)\n        if self.use_rm:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support a different\n        # parallel size per role, you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass a different resource pool to each worker group.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.global_profiler, \"steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.global_profiler, \"steps\")\n            # Only require nsight worker options when the tool is nsys\n            if OmegaConf.select(self.config.global_profiler, \"tool\") == \"nsys\":\n                assert (\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                    is not None\n                ), \"worker_nsight_options must be set when using nsys with profile_steps\"\n                wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, 
\"worker_nsight_options\")\n                )\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        self.rm_wg = None\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = all_wg[\"actor_rollout\"]\n        self.actor_rollout_wg.init_model()\n\n        # set transferqueue server info for each worker\n        for _, wg in all_wg.items():\n            wg.create_transferqueue_client(\n                self.data_system_controller_infos, self.data_system_storage_unit_infos, role=\"train\"\n            )\n            wg.create_transferqueue_client(\n                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role=\"val\"\n            )\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\":\n            from .agent_loop import AgentLoopManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AgentLoopManager(\n                config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg\n            )\n\n            self.async_rollout_manager.create_transferqueue_client(\n                self.data_system_controller_infos, self.data_system_storage_unit_infos, role=\"train\"\n            )\n            self.async_rollout_manager.create_transferqueue_client(\n                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role=\"val\"\n            )\n\n    def _save_checkpoint(self):\n        from verl.utils.fs import local_mkdir_safe\n\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated,\"\n                + \" set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            
self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n        )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, \"critic\")\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"critic\")\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n        local_mkdir_safe(local_global_step_folder)\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = 
os.path.join(global_step_folder, \"critic\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader state\n        # TODO: loading from remote is not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n    def _start_profiling(self, do_profile: bool) -> None:\n        \"\"\"Start profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.start_profile(role=\"e2e\", profile_step=self.global_steps)\n            # no standalone ref worker group exists when ref_in_actor is True\n            if self.use_reference_policy and not self.ref_in_actor:\n                self.ref_policy_wg.start_profile(profile_step=self.global_steps)\n            if self.use_critic:\n                self.critic_wg.start_profile(profile_step=self.global_steps)\n            if self.use_rm:\n                self.rm_wg.start_profile(profile_step=self.global_steps)\n\n    def _stop_profiling(self, do_profile: bool) -> None:\n        \"\"\"Stop profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.stop_profile()\n            # no standalone ref worker group exists when ref_in_actor is True\n            if self.use_reference_policy and not self.ref_in_actor:\n                self.ref_policy_wg.stop_profile()\n            if self.use_critic:\n                self.critic_wg.stop_profile()\n            if self.use_rm:\n                self.rm_wg.stop_profile()\n\n    def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix=\"global_seqlen\"):\n        \"\"\"Reorder the BatchMeta on the single controller so that each DP rank gets a similar number of tokens\"\"\"\n        data = asyncio.run(data_system_client.async_get_data(batch))\n\n        attention_mask = data[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = attention_mask.view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)\n        world_size = self.actor_rollout_wg.world_size\n        global_partition_lst = get_seqlen_balanced_partitions(\n            global_seqlen_lst, k_partitions=world_size, equal_size=True\n        )\n        # reorder based on index; the data will be automatically and equally partitioned by the dispatch function\n        
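# Illustrative example (hypothetical values): with global_seqlen_lst = [10, 2, 8, 4]\n        # and world_size = 2, a balanced equal-size partition could be\n        # global_partition_lst = [[0, 3], [1, 2]] (token sums 14 and 10), which\n        # flattens to global_idx = [0, 3, 1, 2] below.\n        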
global_idx = [j for partition in global_partition_lst for j in partition]\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n        return global_idx\n\n    @classmethod\n    def repeat_dict(\n        cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True\n    ) -> dict[str, torch.Tensor | np.ndarray]:\n        \"\"\"\n        Repeat the batch dict a specified number of times.\n\n        Args:\n            batch_dict (dict): Mapping from field names to torch.Tensor or np.ndarray batches.\n            repeat_times (int): Number of times to repeat the data.\n            interleave (bool): Whether to interleave the repeated data.\n\n        Returns:\n            dict: A new dict with repeated data.\n        \"\"\"\n        if repeat_times == 1:\n            return batch_dict\n\n        repeated_batch_dict = {}\n        if batch_dict:\n            if interleave:\n                # Interleave the data\n                for key, val in batch_dict.items():\n                    if isinstance(val, torch.Tensor):\n                        repeated_batch_dict[key] = val.repeat_interleave(repeat_times, dim=0)\n                    elif isinstance(val, np.ndarray):\n                        repeated_batch_dict[key] = np.repeat(val, repeat_times, axis=0)\n                    else:\n                        raise ValueError(f\"Unsupported type in data {type(val)}\")\n            else:\n                # Stack the data\n                for key, val in batch_dict.items():\n                    if isinstance(val, torch.Tensor):\n                        repeated_batch_dict[key] = (\n                            val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, *val.shape[1:])\n                        )\n                    elif isinstance(val, np.ndarray):\n                        repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))\n                    else:\n                        raise ValueError(f\"Unsupported type in data {type(val)}\")\n        return repeated_batch_dict\n\n    @classmethod\n    def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict:\n        \"\"\"\n        Create a TensorDict from a dict of tensors and non_tensors.\n        Note that this requires tensordict version 0.10 or later.\n        \"\"\"\n        assert parse_version(tensordict.__version__) >= parse_version(\"0.10\"), (\n            \"Storing non-tensor data in TensorDict requires tensordict version 0.10 or later\"\n        )\n        tensors_batch = {}\n        batch_size = None\n\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor | np.ndarray):\n                tensors_batch[key] = val\n            else:\n                raise ValueError(f\"Unsupported type in data {type(val)}\")\n\n            if batch_size is None:\n                batch_size = len(val)\n            else:\n                assert len(val) == batch_size\n\n        if batch_size is None:\n            batch_size = []\n        else:\n            batch_size = [batch_size]\n\n        return TensorDict(tensors_batch, batch_size=batch_size)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage 
computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        if self.config.actor_rollout_ref.rollout.get(\"skip_rollout\", False):\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n        self.max_steps_duration = 0\n\n        prev_step_profile = False\n        curr_step_profile = (\n            self.global_steps in self.config.global_profiler.steps\n            if self.config.global_profiler.steps is not None\n            else False\n        )\n        next_step_profile = False\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n                base_get_meta_kwargs = dict(\n                    batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,\n                    global_step=self.global_steps - 1,  # self.global_steps starts from 1\n                    get_n_samples=False,\n                )\n\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(\n                        not prev_step_profile and curr_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n\n                # add uid to batch\n                batch_dict[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch_dict[\"input_ids\"]))], dtype=object\n                )\n                # When n > 1, repeat input data before putting to data system, simulating DataProto repeat.\n                repeated_batch_dict = self.repeat_dict(\n                    batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n                batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)\n                asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))\n\n                gen_meta = asyncio.run(\n                    self.data_system_client.async_get_meta(\n                        data_fields=list(batch.keys()),  # TODO: (TQ) Get 
metadata by specified fields\n                        task_name=\"generate_sequences\",\n                        **base_get_meta_kwargs,\n                    )\n                )\n                # pass global_steps to trace\n                gen_meta.set_extra_info(\"global_steps\", self.global_steps)\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                        if not self.async_rollout_mode:\n                            gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta)\n                        else:\n                            gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta)\n                        timing_raw.update(gen_output_meta.extra_info[\"timing\"])\n                        gen_output_meta.extra_info.pop(\"timing\", None)\n\n                    # TODO: (TQ) support transfer queue\n                    # if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                    #     if self.reward_fn is None:\n                    #         raise ValueError(\"A reward_fn is required for REMAX advantage estimation.\")\n                    #\n                    #     with marked_timer(\"gen_max\", timing_raw, color=\"purple\"):\n                    #         gen_baseline_meta = deepcopy(gen_meta)\n                    #         gen_baseline_meta.extra_info[\"do_sample\"] = False\n                    #         if not self.async_rollout_mode:\n                    #             gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta)\n                    #         else:\n                    #             gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta)\n                    #         batch = batch.union(gen_baseline_output)\n                    #         reward_baseline_tensor = self.reward_fn(batch)\n                    #         reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n                    #\n                    #         batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n                    #\n                    #         batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n                    #\n                    #         del gen_baseline_batch, gen_baseline_output\n\n                    batch_meta: BatchMeta = gen_meta.union(gen_output_meta)\n\n                    if \"response_mask\" not in batch_meta.field_names:\n                        response_mask_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=[\"responses\", \"attention_mask\"],\n                                task_name=\"compute_response_mask\",\n                                **base_get_meta_kwargs,\n                            )\n                        )\n                        response_mask_output_meta = compute_response_mask(response_mask_meta, self.data_system_client)\n                        batch_meta = batch_meta.union(response_mask_output_meta)\n\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss 
calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    # NOTE: the attention_mask metadata is fetched unconditionally because it is\n                    # also needed below for the global token count, even when balance_batch is off.\n                    attention_mask_meta = asyncio.run(\n                        self.data_system_client.async_get_meta(\n                            data_fields=[\"attention_mask\"],\n                            task_name=\"balance_batch\",\n                            **base_get_meta_kwargs,\n                        )\n                    )\n                    balanced_idx = None\n                    if self.config.trainer.balance_batch:\n                        balanced_idx = self._balance_batch(\n                            attention_mask_meta, self.data_system_client, metrics=metrics\n                        )\n                        batch_meta.reorder(balanced_idx)\n\n                    # compute global valid tokens\n                    data = asyncio.run(self.data_system_client.async_get_data(attention_mask_meta))\n                    batch_meta.extra_info[\"global_token_num\"] = torch.sum(data[\"attention_mask\"], dim=-1).tolist()\n\n                    with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                        # compute reward model score\n                        if self.use_rm and \"rm_scores\" not in batch_meta.field_names:\n                            reward_meta = self.rm_wg.compute_rm_score(batch_meta)\n                            batch_meta = batch_meta.union(reward_meta)\n\n                        compute_reward_fields = [\n                            \"responses\",\n                            \"prompts\",\n                            \"attention_mask\",\n                            \"reward_model\",\n                            \"data_source\",\n                        ]\n                        if \"rm_scores\" in batch_meta.field_names:\n                            compute_reward_fields.append(\"rm_scores\")\n                        compute_reward_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=compute_reward_fields,\n                                task_name=\"compute_reward\",\n                                **base_get_meta_kwargs,\n                            )\n                        )\n                        compute_reward_meta.reorder(balanced_idx)\n                        if self.config.reward_model.launch_reward_fn_async:\n                            future_reward = compute_reward_async_decorated(\n                                data=compute_reward_meta,\n                                reward_fn=self.reward_fn,\n                            )\n                        else:\n                            reward_tensor, reward_extra_infos_dict = compute_reward_decorated(\n                                compute_reward_meta, self.reward_fn\n                            )\n                        batch_meta = batch_meta.union(compute_reward_meta)\n\n                    # recompute old_log_probs\n                    with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                        old_log_prob_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=[\n                                    \"input_ids\",\n                                    \"attention_mask\",\n                                    \"position_ids\",\n                                    \"prompts\",\n                                    
\"responses\",\n                                    \"response_mask\",\n                                    \"data_source\",\n                                    \"reward_model\",\n                                    \"extra_info\",\n                                    \"uid\",\n                                    \"index\",\n                                    \"tools_kwargs\",\n                                    \"interaction_kwargs\",\n                                    \"ability\",\n                                ],\n                                task_name=\"compute_log_prob\",\n                                **base_get_meta_kwargs,\n                            )\n                        )\n                        old_log_prob_meta.reorder(balanced_idx)\n\n                        old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta)\n                        data = asyncio.run(self.data_system_client.async_get_data(old_log_prob_output_meta))\n                        entropys = data[\"entropys\"]\n                        response_masks = data[\"response_mask\"]\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                        metrics.update(old_log_prob_metrics)\n\n                        batch_meta = batch_meta.union(old_log_prob_output_meta)\n\n                        if \"rollout_log_probs\" in batch_meta.field_names:\n                            # TODO: we may want to add diff of probs too.\n                            data_fields = [\"rollout_log_probs\", \"old_log_probs\", \"responses\"]\n                            if \"response_mask\" in batch_meta.field_names:\n                                data_fields.append(\"response_mask\")\n                            if \"attention_mask\" in batch_meta.field_names:\n                                data_fields.append(\"attention_mask\")\n                            calculate_debug_metrics_meta = asyncio.run(\n                                self.data_system_client.async_get_meta(\n                                    data_fields=data_fields,\n                                    task_name=\"calculate_debug_metrics\",\n                                    **base_get_meta_kwargs,\n                                )\n                            )\n                            calculate_debug_metrics_meta.reorder(balanced_idx)\n\n                            metrics.update(calculate_debug_metrics_decorated(calculate_debug_metrics_meta))\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        ref_log_prob_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=[\n                                    \"input_ids\",\n                                    \"attention_mask\",\n                                    \"position_ids\",\n                                    \"prompts\",\n                                    \"responses\",\n                                    \"response_mask\",\n                                    \"old_log_probs\",\n                                    \"data_source\",\n                                    \"reward_model\",\n                                    \"extra_info\",\n      
                              \"uid\",\n                                    \"index\",\n                                    \"tools_kwargs\",\n                                    \"interaction_kwargs\",\n                                    \"ability\",\n                                ],\n                                task_name=\"compute_ref_log_prob\",\n                                **base_get_meta_kwargs,\n                            )\n                        )\n                        ref_log_prob_meta.reorder(balanced_idx)\n                        with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                            if not self.ref_in_actor:\n                                ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta)\n                            else:\n                                ref_log_prob_output_meta = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_meta)\n                            batch_meta = batch_meta.union(ref_log_prob_output_meta)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                            values_meta = self.critic_wg.compute_values(batch_meta)\n                            batch_meta = batch_meta.union(values_meta)\n\n                    with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        if self.config.reward_model.launch_reward_fn_async:\n                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                        reward_td = TensorDict({\"token_level_scores\": reward_tensor}, batch_size=reward_tensor.size(0))\n                        asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta))\n                        batch_meta.add_fields(reward_td)\n\n                        if reward_extra_infos_dict:\n                            reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                            reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new)\n                            asyncio.run(\n                                self.data_system_client.async_put(data=reward_extra_infos_td, metadata=batch_meta)\n                            )\n                            batch_meta.add_fields(reward_extra_infos_td)\n\n                        # compute rewards. 
apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            apply_kl_penalty_fields = [\n                                \"response_mask\",\n                                \"token_level_scores\",\n                                \"old_log_probs\",\n                                \"ref_log_prob\",\n                            ]\n                            apply_kl_penalty_meta = asyncio.run(\n                                self.data_system_client.async_get_meta(\n                                    data_fields=apply_kl_penalty_fields,\n                                    task_name=\"apply_kl_penalty\",\n                                    **base_get_meta_kwargs,\n                                )\n                            )\n                            apply_kl_penalty_meta.reorder(balanced_idx)\n                            token_level_rewards, kl_metrics = apply_kl_penalty(\n                                apply_kl_penalty_meta,\n                                kl_ctrl=self.kl_ctrl_in_reward,\n                                kl_penalty=self.config.algorithm.kl_penalty,\n                            )\n                            token_level_rewards_td = TensorDict(\n                                {\"token_level_rewards\": token_level_rewards}, batch_size=token_level_rewards.size(0)\n                            )\n                            asyncio.run(\n                                self.data_system_client.async_put(\n                                    data=token_level_rewards_td, metadata=apply_kl_penalty_meta\n                                )\n                            )\n                            apply_kl_penalty_meta.add_fields(token_level_rewards_td)\n\n                            metrics.update(kl_metrics)\n                            batch_meta = batch_meta.union(apply_kl_penalty_meta)\n                        else:\n                            token_level_scores_meta = asyncio.run(\n                                self.data_system_client.async_get_meta(\n                                    data_fields=[\"token_level_scores\"],\n                                    task_name=\"token_level_scores\",\n                                    **base_get_meta_kwargs,\n                                )\n                            )\n                            token_level_scores_meta.reorder(balanced_idx)\n                            data = asyncio.run(self.data_system_client.async_get_data(token_level_scores_meta))\n                            token_level_rewards_td = TensorDict(\n                                {\"token_level_rewards\": data[\"token_level_scores\"]},\n                                batch_size=data[\"token_level_scores\"].size(0),\n                            )\n                            asyncio.run(\n                                self.data_system_client.async_put(\n                                    data=token_level_rewards_td, metadata=token_level_scores_meta\n                                )\n                            )\n                            batch_meta.add_fields(token_level_rewards_td)\n\n                        # compute advantages, executed on the driver process\n\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                            \"norm_adv_by_std_in_grpo\", True\n                        )  # GRPO adv normalization factor\n\n                        assert \"response_mask\" in batch_meta.field_names, (\n                            
f\"`response_mask` must be in batch_meta {batch_meta.field_names} for advantage computation\"\n                        )\n                        compute_advantage_fields = [\n                            \"response_mask\",\n                            \"token_level_rewards\",\n                        ]\n                        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n                            compute_advantage_fields.append(\"values\")\n                        elif self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:\n                            compute_advantage_fields.append(\"uid\")\n                        else:\n                            if \"uid\" in batch_meta.field_names:\n                                compute_advantage_fields.append(\"uid\")\n                            if \"reward_baselines\" in batch_meta.field_names:\n                                compute_advantage_fields.append(\"reward_baselines\")\n\n                        compute_advantage_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=compute_advantage_fields,\n                                task_name=\"compute_advantage\",\n                                **base_get_meta_kwargs,\n                            )\n                        )\n                        compute_advantage_meta.reorder(balanced_idx)\n\n                        advantages, returns = compute_advantage(\n                            compute_advantage_meta,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                            config=self.config.algorithm,\n                        )\n\n                        advantages_td = TensorDict(\n                            {\"advantages\": advantages, \"returns\": returns}, batch_size=advantages.size(0)\n                        )\n                        asyncio.run(\n                            self.data_system_client.async_put(data=advantages_td, metadata=compute_advantage_meta)\n                        )\n                        compute_advantage_meta.add_fields(advantages_td)\n\n                        batch_meta = batch_meta.union(compute_advantage_meta)\n\n                    # update critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                            critic_output_meta = self.critic_wg.update_critic(batch_meta)\n                            batch_meta = batch_meta.union(critic_output_meta)\n                        critic_output_metrics = reduce_metrics(critic_output_meta.extra_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                            batch_meta.extra_info[\"multi_turn\"] = (\n                                self.config.actor_rollout_ref.rollout.multi_turn.enable\n                            )\n\n                            update_actor_meta = asyncio.run(\n   
                             self.data_system_client.async_get_meta(\n                                    data_fields=[\n                                        \"input_ids\",\n                                        \"attention_mask\",\n                                        \"position_ids\",\n                                        \"prompts\",\n                                        \"responses\",\n                                        \"response_mask\",\n                                        \"old_log_probs\",\n                                        \"ref_log_prob\",\n                                        \"advantages\",\n                                        \"returns\",\n                                        \"token_level_rewards\",\n                                        \"token_level_scores\",\n                                        \"data_source\",\n                                        \"reward_model\",\n                                        \"extra_info\",\n                                        \"uid\",\n                                        \"index\",\n                                        \"tools_kwargs\",\n                                        \"interaction_kwargs\",\n                                        \"ability\",\n                                    ],\n                                    batch_size=self.config.data.train_batch_size\n                                    * self.config.actor_rollout_ref.rollout.n,\n                                    global_step=self.global_steps - 1,\n                                    get_n_samples=False,\n                                    task_name=\"update_actor\",\n                                )\n                            )\n                            update_actor_meta.reorder(balanced_idx)\n                            update_actor_meta.set_extra_info(\n                                \"global_token_num\", batch_meta.get_extra_info(\"global_token_num\")\n                            )\n                            update_actor_meta.set_extra_info(\"temperature\", batch_meta.get_extra_info(\"temperature\"))\n\n                            actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta)\n                            batch_meta = batch_meta.union(actor_output_meta)\n                        actor_output_metrics = reduce_metrics(actor_output_meta.extra_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        data_fields = [\"prompts\", \"responses\", \"token_level_scores\", \"reward_model\"]\n                        if \"request_id\" in batch_meta.field_names:\n                            data_fields.append(\"request_id\")\n                        log_rollout_meta = asyncio.run(\n                            self.data_system_client.async_get_meta(\n                                data_fields=data_fields,\n                                batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,\n                                global_step=self.global_steps - 1,\n                                get_n_samples=False,\n                                task_name=\"log_rollout\",\n                            )\n                        )\n                        log_rollout_meta.reorder(balanced_idx)\n                       
 self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir)\n\n                # TODO: clear meta after iteration\n\n                # TODO: validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.\n                esi_close_to_expiration = should_save_ckpt_esi(\n                    max_steps_duration=self.max_steps_duration,\n                    redundant_time=self.config.trainer.esi_redundant_time,\n                )\n                # Check if the conditions for saving a checkpoint are met.\n                # The conditions include a mandatory condition (1) and\n                # one of the following optional conditions (2/3/4):\n                # 1. The save frequency is set to a positive value.\n                # 2. It's the last training step.\n                # 3. The current step number is a multiple of the save frequency.\n                # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration\n                ):\n                    if esi_close_to_expiration:\n                        print(\"Force saving checkpoint: ESI instance expiration approaching.\")\n                    with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                        self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    next_step_profile = (\n                        self.global_steps + 1 in self.config.global_profiler.steps\n                        if self.config.global_profiler.steps is not None\n                        else False\n                    )\n                    self._stop_profiling(\n                        curr_step_profile and not next_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                    prev_step_profile = curr_step_profile\n                    curr_step_profile = next_step_profile\n\n                steps_duration = timing_raw[\"step\"]\n                self.max_steps_duration = max(self.max_steps_duration, steps_duration)\n\n                # training metrics\n                metrics.update(\n                    {\n                        \"training/global_step\": self.global_steps,\n                        \"training/epoch\": epoch,\n                    }\n                )\n                # collect metrics\n                compute_data_metrics_fields = [\n                    \"token_level_rewards\",\n                    \"token_level_scores\",\n                    \"advantages\",\n                    \"returns\",\n                    \"responses\",\n                    \"attention_mask\",\n                    \"response_mask\",\n                ]\n 
               if \"__num_turns__\" in batch_meta.field_names:\n                    compute_data_metrics_fields.append(\"__num_turns__\")\n                if \"tool_call_counts\" in batch_meta.field_names:\n                    compute_data_metrics_fields.append(\"tool_call_counts\")\n                compute_data_metrics_meta = asyncio.run(\n                    self.data_system_client.async_get_meta(\n                        data_fields=compute_data_metrics_fields,\n                        task_name=\"compute_data_metrics\",\n                        **base_get_meta_kwargs,\n                    )\n                )\n                compute_data_metrics_meta.reorder(balanced_idx)\n                metrics.update(\n                    compute_data_metrics_decorated(batch=compute_data_metrics_meta, use_critic=self.use_critic)\n                )\n\n                compute_timing_metrics_fields = [\"responses\", \"attention_mask\"]\n                compute_timing_metrics_meta = asyncio.run(\n                    self.data_system_client.async_get_meta(\n                        data_fields=compute_timing_metrics_fields,\n                        task_name=\"compute_timing_metrics\",\n                        **base_get_meta_kwargs,\n                    )\n                )\n                compute_timing_metrics_meta.reorder(balanced_idx)\n                metrics.update(\n                    compute_timing_metrics_decorated(batch=compute_timing_metrics_meta, timing_raw=timing_raw)\n                )\n\n                compute_throughout_metrics_meta = BatchMeta(\n                    samples=[],\n                    extra_info={\"global_token_num\": batch_meta.get_extra_info(\"global_token_num\")},\n                )\n                # TODO: implement actual tflpo and theoretical tflpo\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(\n                    compute_throughout_metrics_decorated(\n                        batch=compute_throughout_metrics_meta, timing_raw=timing_raw, n_gpus=n_gpus\n                    )\n                )\n\n                # this is experimental and may be changed/removed in the future in favor of a general-purpose one\n                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):\n                    # TODO: (TQ) support transfer queue\n                    self.train_dataloader.sampler.update(batch=batch)\n\n                asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n\n                progress_bar.update(1)\n                self.global_steps += 1\n\n                if (\n                    hasattr(self.config.actor_rollout_ref.actor, \"profiler\")\n                    and self.config.actor_rollout_ref.actor.profiler.tool == \"torch_memory\"\n                ):\n                    self.actor_rollout_wg.dump_memory_snapshot(\n                        tag=f\"post_update_step{self.global_steps}\", sub_dir=f\"step{self.global_steps}\"\n                    )\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                # this is experimental and may be changed/removed in the future\n                # in favor of a general-purpose data buffer pool\n                if hasattr(self.train_dataset, 
\"on_batch_end\"):\n                    # The dataset may be changed after each training batch\n                    # TODO: (TQ) support transfer queue\n                    self.train_dataset.on_batch_end(batch=batch)\n"
  },
  {
    "path": "verl_distillation/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh",
    "content": "set -x\n\nproject_name='GRPO-Qwen3'\nexp_name='GRPO-Qwen3-8B-npu'\ngen_tp=2\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-8B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\npython3 -m recipe.transfer_queue.main_ppo \\\n    --config-name='transfer_queue_ppo_trainer' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=${MODEL_PATH} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.default_local_dir=${CKPTS_DIR} \\\n    trainer.device=npu \\\n    trainer.resume_mode=auto \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \\\n    ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \\\n    trainer.val_before_train=False \\\n    trainer.save_freq=5 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 \\\n    +trainer.num_global_batch=1 \\\n    +trainer.num_data_storage_units=2 \\\n    +trainer.num_data_controllers=1"
  },
  {
    "path": "verl_distillation/requirements-cuda.txt",
    "content": "flash-attn"
  },
  {
    "path": "verl_distillation/requirements-npu.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\ndatasets\ndill\nhydra-core\nnumpy<2.0.0\npandas\npeft>=0.15.2\npyarrow>=15.0.0\npybind11\npylatexenc\ntensordict>=0.8.0,<=0.10.0,!=0.9.0\ntransformers==4.52.4\nray==2.46.0\nwandb\nmathruler\ntorchdata\neinops\nqwen_vl_utils\ntorchvision==0.20.1\n"
  },
  {
    "path": "verl_distillation/requirements.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\nclick==8.0.4\ndatasets\ndill\n# flash-attn\nhydra-core\nliger-kernel\nnumpy<2.0.0\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\npre-commit\nray==2.49.0\ntensordict>=0.8.0,<=0.9.1,!=0.9.0\ntorchdata\ntransformers\n# vllm==0.8.4\nopentelemetry-api>=1.26.0,<1.27.0\nopentelemetry-sdk>=1.26.0,<1.27.0\nopentelemetry-exporter-otlp-proto-grpc>=1.26.0,<1.27.0\nopentelemetry-exporter-otlp-proto-http>=1.26.0,<1.27.0\nwandb\npackaging>=20.0\nuvicorn\nfastapi\nlatex2sympy2_extended\nmath_verify\nsglang==0.5.2\n"
  },
  {
    "path": "verl_distillation/requirements_sglang.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\ndatasets\ndill\nflash-attn\nhydra-core\nnumpy<2.0.0\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\nray[default]>=2.10\ntensordict>=0.8.0,<=0.10.0,!=0.9.0\ntorchdata\ntorchvision\ntransformers\nwandb\nsglang[all]==0.5.2\nhuggingface_hub\n"
  },
  {
    "path": "verl_distillation/requirements_transferqueue.txt",
    "content": "# requirements.txt records the full set of dependencies for development\ngit+https://github.com/TransferQueue/TransferQueue.git@68c04e7\n"
  },
  {
    "path": "verl_distillation/scripts/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/scripts/converter_hf_to_mcore.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport warnings\nfrom contextlib import contextmanager\nfrom importlib.metadata import version\nfrom typing import Any, Callable, ContextManager, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\ntry:\n    # NPU patch\n    import mindspeed.megatron_adaptor  # noqa: F401\nexcept ImportError:\n    pass\n\nfrom accelerate import init_empty_weights\nfrom megatron.core import dist_checkpointing\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.dist_checkpointing.mapping import ShardedTensor\nfrom megatron.core.dist_checkpointing.serialization import StrictHandling\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed\nfrom packaging.version import Version\nfrom transformers import AutoConfig\n\nfrom verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards\nfrom verl.models.mcore import hf_to_mcore_config\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.megatron_utils import get_model\n\n\ndef _init_args():\n    \"\"\"\n    Examples:\n\n    1. single rank conversion for any model:\n        > python converter_hf_to_mcore.py --hf_model_path %{hf_model} --output_path ${output_path}\n    2. 
distributed conversion for DeepseekV3 671B:\n        > torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \\\n          --hf_model_path %{hf_model} --output_path ${output_path}\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hf_model_path\", type=str, required=True, help=\"The path for the huggingface model\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"The path for the output mcore model\")\n    parser.add_argument(\"--use_cpu_initialization\", action=\"store_true\", help=\"Whether to use cpu initialization\")\n    parser.add_argument(\"--test\", action=\"store_true\", help=\"Whether to test the conversion\")\n    parser.add_argument(\"--trust_remote_code\", action=\"store_true\", help=\"Whether to trust remote code\")\n    args = parser.parse_args()\n    return args\n\n\ndef test_conversion(megatron_model_provider, tfconfig, output_path, model):\n    ########### test ###########\n    # load model\n    model_test = get_model(\n        model_provider_func=megatron_model_provider,\n        model_type=ModelType.encoder_or_decoder,\n        wrap_with_ddp=True,\n        transformer_config=tfconfig,\n    )\n    ref_state_dict = model_test[0].module.sharded_state_dict()\n    dist_checkpointing.load(ref_state_dict, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)\n\n    dut_state_dict = model[0].module.state_dict()\n    for name in dut_state_dict.keys():\n        if dut_state_dict[name] is None:\n            print(f\"[Warning] {name} is none in dut_state_dict\")\n            continue\n        dut_data = dut_state_dict[name].data\n        if name in ref_state_dict:\n            ref_data = ref_state_dict[name]\n            if isinstance(ref_data, ShardedTensor):\n                ref_data = ref_data.data.view(ref_data.local_shape)\n            else:\n                ref_data = ref_data.data\n            assert dut_data.shape == ref_data.shape, f\"{name=} {dut_data.shape=} {ref_data.shape=}\"\n            assert (dut_data == ref_data).all(), f\"{name} is not equal\"\n            print(f\"{name} is equal\")\n        else:\n            print(f\"[Warning] {name} is not in ref_state_dict\")\n    for name in ref_state_dict.keys():\n        if ref_state_dict[name] is None:\n            print(f\"[Warning] {name} is none in ref_state_dict\")\n            continue\n        ref_data = ref_state_dict[name]\n        if isinstance(ref_data, ShardedTensor):\n            ref_data = ref_data.data.view(ref_data.local_shape)\n        else:\n            ref_data = ref_data.data\n        if name in dut_state_dict:\n            dut_data = dut_state_dict[name].data\n            assert dut_data.shape == ref_data.shape, f\"{name=} {dut_data.shape=} {ref_data.shape=}\"\n            assert (dut_data == ref_data).all(), f\"{name} is not equal\"\n            print(f\"{name} is equal\")\n        else:\n            print(f\"[Warning] {name} is not in dut_state_dict\")\n    print(\"Conversion test passed!\")\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron(\n    hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None\n):\n    if layer_start_end is None:\n        layer_start_end = (0, len(model.decoder.layers))\n    layer_start, layer_end = layer_start_end\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    numel = 0\n\n    num_attention_heads = hf_config.num_attention_heads\n    num_key_value_heads = 
hf_config.num_key_value_heads\n    hidden_dim = hf_config.hidden_size\n    head_dim = getattr(hf_config, \"head_dim\", hidden_dim // num_attention_heads)\n    if num_attention_heads != num_key_value_heads:\n        print(\"[WARNING] Converting GQA model\")\n    has_qkv_bias = getattr(hf_config, \"qkv_bias\", False) or getattr(hf_config, \"attention_bias\", False)\n    has_share_expert = getattr(hf_config, \"shared_expert_intermediate_size\", None)\n    if pp_rank == 0:\n        numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)\n\n    assert len(model.decoder.layers) == (layer_end - layer_start), (\n        f\"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}\"\n    )\n    for layer_idx, (layer, hf_layer) in enumerate(\n        zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)\n    ):\n        global_layer_idx = layer_idx + layer_start\n        numel_cur = numel\n        numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight)\n\n        q = hf_layer.self_attn.q_proj.weight.view(\n            [num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]\n        )\n        k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])\n        v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])\n        qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()\n        numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight)\n\n        if has_qkv_bias:\n            q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])\n            k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])\n            v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])\n            qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()\n            numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias)\n\n        if hasattr(hf_layer.self_attn, \"q_norm\"):\n            numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight)\n            numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight)\n\n        numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)\n        numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)\n\n        numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)\n\n        for idx, hf_expert in enumerate(hf_layer.mlp.experts):\n            fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n            numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f\"weight{idx}\"])\n            numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f\"weight{idx}\"])\n\n        if has_share_expert:\n            numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight)\n            shared_fc1_weight = torch.cat(\n                [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]\n            )\n            numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)\n            numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)\n        print(f\"{pp_rank=} {global_layer_idx=} 
{layer_idx=} {numel=} numel this layer={numel - numel_cur}\")\n\n    if pp_rank == pp_size - 1:\n        numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)\n        numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)\n    return numel\n\n\ndef safe_copy(\n    src_tensor: torch.Tensor,\n    dst_tensor: torch.Tensor,\n    skip_dtype_assert: bool = False,\n):\n    if not skip_dtype_assert:\n        if src_tensor.dtype != dst_tensor.dtype:\n            raise ValueError(f\"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}\")\n    assert src_tensor.shape == dst_tensor.shape\n    dst_tensor.data.copy_(src_tensor.data)\n    return src_tensor.numel()\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config):\n    mgmodel = mgmodel.bfloat16()\n    hfmodel = hfmodel.bfloat16()\n    num_attention_heads = hf_config.num_attention_heads\n    num_query_groups = hf_config.num_key_value_heads\n    hidden_size = hf_config.hidden_size\n    head_dim = hidden_size // num_attention_heads\n\n    # 1. vision model\n    if Version(version(\"transformers\")) < Version(\"4.52.0\"):\n        print(\"Using transformers < 4.52 API to load vision model\")\n        hfvision = hfmodel.visual\n    else:\n        hfvision = hfmodel.model.visual\n    mgvision = mgmodel.vision_model\n    vision_hidden_size = mgvision.config.hidden_size\n    vision_num_query_groups = mgvision.config.num_query_groups\n    vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads\n    copied_numel = 0\n    safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq)\n    copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight)\n    for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True):\n        # norm1 --> linear_qkv.norm\n        copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight)\n        # norm2 --> mlp.linear_fc1.norm\n        copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight)\n        # qkv --> self_attention.linear_qkv\n        converted_weight = (\n            hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size)\n            .transpose(0, 1)\n            .flatten(1, 2)\n            .reshape(-1, vision_hidden_size)\n            .contiguous()\n        )\n        copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight)\n        converted_bias = (\n            hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1)\n            .transpose(0, 1)\n            .flatten(1, 2)\n            .view(-1)\n            .contiguous()\n        )\n        copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias)\n        # proj --> self_attention.linear_proj\n        copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight)\n        copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias)\n        # mlp --> mlp: gate\n        fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight])\n        fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias])\n        copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight)\n        copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias)\n        copied_numel += 
safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight)\n        copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias)\n\n    # 2. vision projector\n    hfprojector = hfvision.merger\n    mgprojector = mgvision.projection\n    copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight)\n\n    copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight)\n    copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias)\n    copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight)\n    copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias)\n    n_params = sum([t.numel() for t in hfvision.state_dict().values()])\n    assert n_params == copied_numel, f\"n_params={n_params} != copied_numel={copied_numel}\"\n    # 3. llm [just Qwen2]\n    if Version(version(\"transformers\")) < Version(\"4.52.0\"):\n        print(\"Using transformers < 4.52 API to load llm\")\n        hfllm = hfmodel.model\n    else:\n        hfllm = hfmodel.model.language_model\n    mgllm = mgmodel.language_model\n    copied_numel = 0\n    copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight)\n    layermaps = zip(mgllm.decoder.layers, hfllm.layers, strict=True)\n    for mglayer, hflayer in layermaps:\n        copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight)\n\n        q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous()\n        copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight)\n\n        q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1)\n        k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1)\n        v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1)\n        qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous()\n        copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias)\n        copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight)\n\n        fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight])\n        copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight)\n\n        copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight)\n        copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight)\n\n    copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight)\n    if not hf_config.tie_word_embeddings:\n        safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight)\n\n    n_params = sum([t.numel() for t in hfllm.state_dict().values()])\n\n    assert n_params == copied_numel, f\"n_params={n_params} != copied_numel={copied_numel}\"\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron_dpskv3(\n    hf_model,\n    model,\n    hf_config,\n    tfconfig,\n    layer_start_end: 
Optional[tuple[int, int]] = None,\n):\n    warnings.warn(\"MTP model is not supported yet\", stacklevel=2)\n    if layer_start_end is None:\n        layer_start_end = (0, len(model.decoder.layers))\n    layer_start, layer_end = layer_start_end\n    numel: int = 0\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    if pp_rank == 0:\n        numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)\n\n    assert len(model.decoder.layers) == (layer_end - layer_start), (\n        f\"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}\"\n    )\n    for layer_idx, (layer, hf_layer) in enumerate(\n        zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)\n    ):\n        global_layer_idx = layer_idx + layer_start\n        numel_cur: int = numel\n        numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)\n\n        if hf_config.q_lora_rank is None:\n            numel += safe_copy(hf_layer.self_attn.q_proj.weight, layer.self_attention.linear_q_proj.weight)\n        else:\n            numel += safe_copy(hf_layer.self_attn.q_a_proj.weight, layer.self_attention.linear_q_down_proj.weight)\n            numel += safe_copy(hf_layer.self_attn.q_b_proj.weight, layer.self_attention.linear_q_up_proj.weight)\n            numel += safe_copy(\n                hf_layer.self_attn.q_a_layernorm.weight, layer.self_attention.linear_q_up_proj.layer_norm_weight\n            )\n\n        numel += safe_copy(\n            hf_layer.self_attn.kv_a_proj_with_mqa.weight, layer.self_attention.linear_kv_down_proj.weight\n        )\n        numel += safe_copy(hf_layer.self_attn.kv_b_proj.weight, layer.self_attention.linear_kv_up_proj.weight)\n        numel += safe_copy(\n            hf_layer.self_attn.kv_a_layernorm.weight, layer.self_attention.linear_kv_up_proj.layer_norm_weight\n        )\n        numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)\n\n        if not hasattr(layer.mlp, \"router\"):\n            numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight)\n            numel += safe_copy(\n                torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), layer.mlp.linear_fc1.weight\n            )\n            numel += safe_copy(hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight)\n        else:\n            numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)\n            # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \\\n            # recover to fp32 in the first forward. 
There is always a diff in the bias between two models (~0.3%)\n            numel += safe_copy(\n                hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True\n            )\n            if tfconfig.moe_grouped_gemm:\n                for i, hf_expert in enumerate(hf_layer.mlp.experts):\n                    fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n                    linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, \"weight\" + str(i))\n                    numel += safe_copy(fc1_weight, linear_fc1_weighti)\n                    linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, \"weight\" + str(i))\n                    numel += safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti)\n            else:\n                for i, hf_expert in enumerate(hf_layer.mlp.experts):\n                    expert = layer.mlp.experts.local_experts[i]\n                    fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n                    numel += safe_copy(fc1_weight, expert.linear_fc1.weight)\n                    numel += safe_copy(hf_expert.down_proj.weight, expert.linear_fc2.weight)\n            numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)\n            shared_fc1_weight = torch.cat(\n                [hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight]\n            )\n            numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)\n            numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)\n        print(f\"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}\")\n        assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), \"numel mismatch\"\n\n    if pp_rank == pp_size - 1:\n        numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)\n        if not hf_config.tie_word_embeddings:\n            numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)\n    print(f\"{pp_rank=} {numel=}\")\n    return numel\n\n\n@contextmanager\ndef noop_context() -> Any:\n    yield\n\n\ndef support_distributed_convert(hf_config: AutoConfig) -> bool:\n    for arch in [\"DeepseekV3ForCausalLM\", \"Qwen3MoeForCausalLM\", \"Qwen2MoeForCausalLM\"]:\n        if arch in hf_config.architectures:\n            return True\n    return False\n\n\ndef convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):\n    os.makedirs(output_path, exist_ok=True)\n    if len(os.listdir(output_path)) > 0 and not test:\n        print(f\"Output path {output_path} is not empty, skipping conversion\")\n        return\n\n    # init torch distributed and mpu\n    if \"WORLD_SIZE\" not in os.environ:\n        os.environ[\"RANK\"] = \"0\"\n        os.environ[\"WORLD_SIZE\"] = \"1\"\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        os.environ[\"MASTER_PORT\"] = \"12355\"\n\n    torch.distributed.init_process_group(\"nccl\")\n\n    rank = dist.get_rank()\n    local_rank = os.getenv(\"LOCAL_RANK\", 0)\n    world_size = dist.get_world_size()\n    get_torch_device().set_device(f\"{get_device_name()}:{local_rank}\")\n\n    mpu.initialize_model_parallel(\n        tensor_model_parallel_size=1,\n        pipeline_model_parallel_size=world_size,\n        
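# Conversion shards the model along the pipeline dimension only; tensor, context, and expert parallel sizes stay at 1.\n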
        virtual_pipeline_model_parallel_size=None,\n        context_parallel_size=1,\n        expert_model_parallel_size=1,\n    )\n    model_parallel_cuda_manual_seed(0)\n\n    # init hf config\n    hf_config = AutoConfig.from_pretrained(hf_model_path)\n    print(hf_config, flush=True)\n\n    if world_size > 1 and not support_distributed_convert(hf_config):\n        raise NotImplementedError(f\"Distributed conversion is not supported for {hf_config.architectures} yet.\")\n\n    pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size)\n    print(f\"Pipeline shards: {pipeline_shards}\", flush=True)\n\n    tfconfig = hf_to_mcore_config(\n        hf_config,\n        torch.bfloat16,\n        num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None,\n        num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None,\n    )\n    tfconfig.use_cpu_initialization = use_cpu_initialization\n    tie_word_embeddings = getattr(hf_config, \"tie_word_embeddings\", False)\n\n    # init megatron model\n    def megatron_model_provider(pre_process, post_process):\n        from verl.models.mcore import init_mcore_model\n\n        parallel_model = init_mcore_model(\n            tfconfig,\n            hf_config,\n            pre_process,\n            post_process,\n            share_embeddings_and_output_weights=tie_word_embeddings,\n            value=False,\n        )\n        return parallel_model\n\n    context: Callable[..., ContextManager] = init_empty_weights if use_cpu_initialization else noop_context\n    with context():\n        model = get_model(\n            model_provider_func=megatron_model_provider,\n            model_type=ModelType.encoder_or_decoder,\n            wrap_with_ddp=False,\n            transformer_config=tfconfig,\n        )\n\n    if use_cpu_initialization:\n        # convert meta device to empty tensor so it can use `copy_` function\n        model[0].module = model[0].module.to_empty(device=\"cpu\")\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        from transformers import AutoModelForCausalLM, AutoModelForImageTextToText\n\n    # init hf model\n    if \"Qwen2_5_VLForConditionalGeneration\" in hf_config.architectures:\n        hf_model = AutoModelForImageTextToText.from_pretrained(\n            hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code\n        )\n    else:\n        hf_model = AutoModelForCausalLM.from_pretrained(\n            hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code\n        )\n    hf_state_dict = hf_model.state_dict()\n\n    # distributed convert\n    if world_size > 1 and support_distributed_convert(hf_config):\n        pipeline_cumsum = np.cumsum(pipeline_shards)\n        layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1]\n        layer_end = pipeline_cumsum[rank]\n        if \"DeepseekV3ForCausalLM\" in hf_config.architectures:\n            numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(\n                hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end)\n            )\n        elif \"Qwen3MoeForCausalLM\" in hf_config.architectures or \"Qwen2MoeForCausalLM\" in hf_config.architectures:\n            numel_partial: int = convert_checkpoint_from_transformers_to_megatron(\n                hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end)\n            )\n        else:\n  
          raise NotImplementedError(f\"Distributed conversion is not supported for {hf_config.architectures} yet.\")\n\n        numel_tensor = torch.tensor([numel_partial]).to(get_device_name())\n        dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM)\n        numel = int(numel_tensor.cpu().item())\n        print(f\"total numel={numel} vs {hf_model.num_parameters()=}\")\n        if numel != hf_model.num_parameters():\n            warnings.warn(f\"numel mismatch: {numel=} != {hf_model.num_parameters()=}\", stacklevel=1)\n\n    # load hf state dict to megatron model\n    elif \"Qwen2MoeForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)\n    elif \"Qwen2_5_VLForConditionalGeneration\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config)\n    elif \"DeepseekV3ForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)\n    elif \"Qwen3MoeForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)\n    else:\n        assert not use_cpu_initialization, \"use_cpu_initialization is only supported for MoE model\"\n        from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n        load_state_dict_to_megatron_gptmodel(\n            state_dict=hf_state_dict,\n            wrapped_models=model,\n            config=hf_config,\n            params_dtype=torch.bfloat16,\n            is_value_model=False,\n        )\n\n    megatron_state_dict = model[0].module.sharded_state_dict()\n    del hf_state_dict, hf_model\n\n    # save megatron model\n    if len(os.listdir(output_path)) == 0:\n        dist_checkpointing.save(megatron_state_dict, output_path, sharded_strategy=None, async_sharded_save=False)\n    if test:\n        test_conversion(megatron_model_provider, tfconfig, output_path, model)\n\n\nif __name__ == \"__main__\":\n    args = _init_args()\n    convert_hf_to_mcore(\n        args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code\n    )\n"
  },
  {
    "path": "verl_distillation/scripts/diagnose.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Diagnose script for checking OS/hardware/python/pip/verl/network.\nThe output of this script can be a very good hint to issue/problem.\n\"\"\"\n\nimport os\nimport platform\nimport socket\nimport subprocess\nimport sys\nimport time\n\nimport psutil\n\ntry:\n    from urllib.parse import urlparse\n    from urllib.request import urlopen\nexcept ImportError:\n    from urllib2 import urlopen\n    from urlparse import urlparse\nimport argparse\nimport importlib.metadata\n\nimport torch\n\nURLS = {\n    \"PYPI\": \"https://pypi.python.org/pypi/pip\",\n}\n\nREGIONAL_URLS = {\n    \"cn\": {\n        \"PYPI(douban)\": \"https://pypi.douban.com/\",\n        \"Conda(tsinghua)\": \"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\",\n    }\n}\n\n\ndef test_connection(name, url, timeout=10):\n    \"\"\"Simple connection test\"\"\"\n    urlinfo = urlparse(url)\n    start = time.time()\n    try:\n        socket.gethostbyname(urlinfo.netloc)\n    except Exception as e:\n        print(\"Error resolving DNS for {}: {}, {}\".format(name, url, e))\n        return\n    dns_elapsed = time.time() - start\n    start = time.time()\n    try:\n        _ = urlopen(url, timeout=timeout)\n    except Exception as e:\n        print(\"Error open {}: {}, {}, DNS finished in {} sec.\".format(name, url, e, dns_elapsed))\n        return\n    load_elapsed = time.time() - start\n    print(\"Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.\".format(name, url, dns_elapsed, load_elapsed))\n\n\ndef check_python():\n    print(\"----------Python Info----------\")\n    print(\"Version      :\", platform.python_version())\n    print(\"Compiler     :\", platform.python_compiler())\n    print(\"Build        :\", platform.python_build())\n    print(\"Arch         :\", platform.architecture())\n\n\ndef check_pip():\n    print(\"------------Pip Info-----------\")\n    try:\n        import pip\n\n        print(\"Version      :\", pip.__version__)\n        print(\"Directory    :\", os.path.dirname(pip.__file__))\n    except ImportError:\n        print(\"No corresponding pip install for current python.\")\n\n\ndef _get_current_git_commit():\n    try:\n        result = subprocess.run([\"git\", \"rev-parse\", \"HEAD\"], capture_output=True, text=True, check=True)\n        return result.stdout.strip()\n    except subprocess.CalledProcessError as e:\n        print(f\"Error running git command: {e.stderr.strip()}\")\n        return None\n    except FileNotFoundError:\n        print(\"Did not find command: git\")\n        return None\n\n\ndef check_verl():\n    print(\"----------verl Info-----------\")\n    try:\n        sys.path.insert(0, os.getcwd())\n        import verl\n\n        print(\"Version      :\", verl.__version__)\n        verl_dir = os.path.dirname(verl.__file__)\n        print(\"Directory    :\", verl_dir)\n        try:\n            commit_hash = _get_current_git_commit()\n            
print(\"Commit Hash  :\", commit_hash)\n        except AttributeError:\n            print(\"Commit hash not found. \")\n    except ImportError as e:\n        print(f\"No verl installed: {e}\")\n    except Exception as e:\n        import traceback\n\n        if not isinstance(e, IOError):\n            print(\"An error occurred trying to import verl.\")\n            print(\"This is very likely due to missing or incompatible library files.\")\n        print(traceback.format_exc())\n\n\ndef check_os():\n    print(\"----------Platform Info----------\")\n    print(\"Platform     :\", platform.platform())\n    print(\"system       :\", platform.system())\n    print(\"node         :\", platform.node())\n    print(\"release      :\", platform.release())\n    print(\"version      :\", platform.version())\n\n\ndef check_hardware():\n    print(\"----------Hardware Info----------\")\n    print(\"machine      :\", platform.machine())\n    print(\"processor    :\", platform.processor())\n    if sys.platform.startswith(\"darwin\"):\n        pipe = subprocess.Popen((\"sysctl\", \"-a\"), stdout=subprocess.PIPE)\n        output = pipe.communicate()[0]\n        for line in output.split(b\"\\n\"):\n            if b\"brand_string\" in line or b\"features\" in line:\n                print(line.strip())\n    elif sys.platform.startswith(\"linux\"):\n        subprocess.call([\"lscpu\"])\n    elif sys.platform.startswith(\"win32\"):\n        subprocess.call([\"wmic\", \"cpu\", \"get\", \"name\"])\n\n\ndef check_network(args):\n    print(\"----------Network Test----------\")\n    if args.timeout > 0:\n        print(\"Setting timeout: {}\".format(args.timeout))\n        socket.setdefaulttimeout(10)\n    for region in args.region.strip().split(\",\"):\n        r = region.strip().lower()\n        if not r:\n            continue\n        if r in REGIONAL_URLS:\n            URLS.update(REGIONAL_URLS[r])\n        else:\n            import warnings\n\n            warnings.warn(\"Region {} do not need specific test, please refer to global sites.\".format(r), stacklevel=2)\n    for name, url in URLS.items():\n        test_connection(name, url, args.timeout)\n\n\ndef check_environment():\n    print(\"----------Environment----------\")\n    for k, v in os.environ.items():\n        if k.startswith(\"VERL_\") or k.startswith(\"OMP_\") or k.startswith(\"KMP_\") or k == \"CC\" or k == \"CXX\":\n            print('{}=\"{}\"'.format(k, v))\n\n\ndef check_pip_package_versions():\n    packages = [\"vllm\", \"sglang\", \"ray\", \"torch\"]\n    for package in packages:\n        try:\n            version = importlib.metadata.version(package)\n            print(f\"{package}\\t     : {version}\")\n        except importlib.metadata.PackageNotFoundError:\n            print(f\"{package}\\t     : not found.\")\n\n\ndef check_cuda_versions():\n    if torch.cuda.is_available():\n        try:\n            cuda_runtime_version = torch.version.cuda\n            print(f\"CUDA Runtime : {cuda_runtime_version}\")\n            import subprocess\n\n            nvcc_output = subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"utf-8\")\n            cuda_compiler_version = next((line for line in nvcc_output.splitlines() if \"release\" in line), None)\n            if cuda_compiler_version:\n                print(f\"CUDA Compiler : {cuda_compiler_version.strip()}\")\n            else:\n                print(\"Could not determine CUDA compiler version.\")\n        except FileNotFoundError as e:\n            print(f\"CUDA compiler : Not found: {e}\")\n 
       except Exception as e:\n            print(f\"An error occurred while checking CUDA versions: {e}\")\n    else:\n        print(\"CUDA is not available.\")\n\n\ndef _get_cpu_memory():\n    \"\"\"\n    Get the total CPU memory capacity in GB.\n    \"\"\"\n    memory = psutil.virtual_memory()\n    return memory.total / (1024**3)\n\n\ndef _get_gpu_info():\n    \"\"\"\n    Get GPU type, GPU memory, and GPU count using nvidia-smi command.\n    \"\"\"\n    try:\n        result = subprocess.run(\n            [\"nvidia-smi\", \"--query-gpu=gpu_name,memory.total\", \"--format=csv,noheader,nounits\"],\n            capture_output=True,\n            text=True,\n            check=True,\n        )\n        gpu_lines = result.stdout.strip().split(\"\\n\")\n        gpu_count = len(gpu_lines)\n        gpu_info = []\n        for line in gpu_lines:\n            gpu_name, gpu_memory = line.split(\", \")\n            gpu_info.append(\n                {\n                    \"type\": gpu_name,\n                    \"memory\": float(gpu_memory) / 1024,  # Convert to GB\n                }\n            )\n        return gpu_count, gpu_info\n    except (subprocess.CalledProcessError, FileNotFoundError):\n        print(\"Failed to execute nvidia-smi command.\")\n        return 0, []\n\n\ndef _get_system_info():\n    \"\"\"\n    Get CPU memory capacity, GPU type, GPU memory, and GPU count.\n    \"\"\"\n    cpu_memory = _get_cpu_memory()\n    gpu_count, gpu_info = _get_gpu_info()\n    return {\"cpu_memory\": cpu_memory, \"gpu_count\": gpu_count, \"gpu_info\": gpu_info}\n\n\ndef check_system_info():\n    print(\"----------System Info----------\")\n    system_info = _get_system_info()\n    print(f\"CPU Memory\\t: {system_info['cpu_memory']:.2f} GB\")\n    print(f\"GPU Count\\t: {system_info['gpu_count']}\")\n    for i, gpu in enumerate(system_info[\"gpu_info\"]):\n        print(f\"GPU {i + 1}\\tType    : {gpu['type']}\")\n        print(f\"GPU {i + 1}\\tMemory  : {gpu['memory']:.2f} GB\")\n\n\ndef parse_args():\n    \"\"\"Parse arguments.\"\"\"\n    parser = argparse.ArgumentParser(\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n        description=\"Diagnose script for checking the current system.\",\n    )\n    choices = [\"python\", \"pip\", \"verl\", \"system\", \"os\", \"environment\"]\n    for choice in choices:\n        parser.add_argument(\"--\" + choice, default=1, type=int, help=\"Diagnose {}.\".format(choice))\n    parser.add_argument(\"--network\", default=0, type=int, help=\"Diagnose network.\")\n    parser.add_argument(\"--hardware\", default=0, type=int, help=\"Diagnose hardware.\")\n    parser.add_argument(\n        \"--region\",\n        default=\"\",\n        type=str,\n        help=\"Additional sites in which region(s) to test. 
\\\n                        Specify 'cn', for example, to test mirror sites in China.\",\n    )\n    parser.add_argument(\"--timeout\", default=10, type=int, help=\"Connection test timeout threshold, 0 to disable.\")\n    args = parser.parse_args()\n    return args\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if args.python:\n        check_python()\n\n    if args.pip:\n        check_pip()\n        check_pip_package_versions()\n\n    if args.verl:\n        check_verl()\n\n    if args.os:\n        check_os()\n\n    if args.hardware:\n        check_hardware()\n\n    if args.network:\n        check_network(args)\n\n    if args.environment:\n        check_environment()\n        check_cuda_versions()\n\n    if args.system:\n        check_system_info()\n
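\n# Example invocation (a sketch; the script path is an assumption, the flags come from parse_args above):\n#   python scripts/diagnose.py --network 1 --region cn --timeout 5\n"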
  },
  {
    "path": "verl_distillation/scripts/generate_trainer_config.sh",
    "content": "#!/usr/bin/env bash\nset -euox pipefail\n\n\n# Define config specifications: \"config_name:output_file:config_arg\"\nCONFIG_SPECS=(\n    \"ppo_trainer:_generated_ppo_trainer.yaml:\"\n    \"ppo_megatron_trainer:_generated_ppo_megatron_trainer.yaml:--config-name=ppo_megatron_trainer.yaml\"\n)\n\ngenerate_config() {\n    local config_name=\"$1\"\n    local output_file=\"$2\"\n    local config_arg=\"$3\"\n    \n    local target_cfg=\"verl/trainer/config/${output_file}\"\n    local tmp_header=$(mktemp)\n    local tmp_cfg=$(mktemp)\n    \n    echo \"# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\" > \"$tmp_header\"\n    echo \"# in which it invokes 'python3 scripts/print_cfg.py --cfg job ${config_arg}' to flatten the 'verl/trainer/config/${config_name}.yaml' config fields into a single file.\" >> \"$tmp_header\"\n    echo \"# Do not modify this file directly.\" >> \"$tmp_header\"\n    echo \"# The file is usually only for reference and never used.\" >> \"$tmp_header\"\n    echo \"\" >> \"$tmp_header\"\n    \n    python3 scripts/print_cfg.py --cfg job ${config_arg} > \"$tmp_cfg\"\n    \n    cat \"$tmp_header\" > \"$target_cfg\"\n    sed -n '/^actor_rollout_ref/,$p' \"$tmp_cfg\" >> \"$target_cfg\"\n    \n    rm \"$tmp_cfg\" \"$tmp_header\"\n    \n    echo \"Generated: $target_cfg\"\n}\n\nfor spec in \"${CONFIG_SPECS[@]}\"; do\n    IFS=':' read -r config_name output_file config_arg <<< \"$spec\"\n    generate_config \"$config_name\" \"$output_file\" \"$config_arg\"\ndone\n\nfor spec in \"${CONFIG_SPECS[@]}\"; do\n    IFS=':' read -r config_name output_file config_arg <<< \"$spec\"\n    target_cfg=\"verl/trainer/config/${output_file}\"\n    if ! git diff --exit-code -- \"$target_cfg\" >/dev/null; then\n        echo \"✖ $target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes.\"\n        exit 1\n    fi\ndone\n\necho \"All good\"\nexit 0\n"
  },
  {
    "path": "verl_distillation/scripts/init_random_model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script override a model with custom config and random weights, mainly for create small models for \ndebugging purposes.\n\nUsage:\n    python scripts/init_random_model.py \\\n        --hf_model_path <path_to_hf_model> \\\n        --new_config_path <path_to_new_config.json> \\\n        --output_path <path_to_output_model>\n\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport warnings\nfrom typing import Any\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig\n\n\ndef _init_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hf_model_path\", type=str, required=True, help=\"The path for the huggingface model\")\n    parser.add_argument(\"--new_config_path\", type=str, required=True, help=\"The path for the new config file\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"The path for the output random model\")\n    args = parser.parse_args()\n    return args\n\n\ndef check_output_path(output_path: str):\n    if os.path.exists(output_path):\n        warnings.warn(f\"Output path '{output_path}' already exists. 
Will do nothing.\", stacklevel=2)\n        exit()\n    else:\n        os.makedirs(output_path, exist_ok=True)\n        print(f\"Output path '{output_path}' created.\")\n\n\ndef check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) -> bool:\n    \"\"\"\n    Check if the original config and new config are compatible.\n    This is a placeholder function; actual implementation may vary based on requirements.\n    \"\"\"\n    # Example check: ensure 'model_type' is the same\n    if new_config.get(\"model_type\", None) is not None and original_config.get(\"model_type\") != new_config.get(\n        \"model_type\"\n    ):\n        raise RuntimeError(\"Model types do not match.\")\n    for key in new_config:\n        if key not in original_config:\n            warnings.warn(\n                f\"Key '{key}' in new config does not exist in original config, may not take effect.\", stacklevel=2\n            )\n\n\ndef init_random_model(hf_model_path, new_config_path, output_path):\n    config = AutoConfig.from_pretrained(hf_model_path)\n    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)\n    config_dict = PretrainedConfig.get_config_dict(hf_model_path)[0]\n    print(config_dict)\n    with open(new_config_path) as f:\n        new_config_dict = json.load(f)\n    check_configs(config_dict, new_config_dict)\n    config_dict.update(new_config_dict)\n    new_confg = config.from_dict(config_dict)\n    print(f\"new_config: {new_confg}\")\n    model = AutoModelForCausalLM.from_config(new_confg)\n    model.save_pretrained(output_path)\n    tokenizer.save_pretrained(output_path)\n    new_confg.save_pretrained(output_path)\n    print(f\"Random model initialized and saved to {output_path}\")\n\n\nif __name__ == \"__main__\":\n    args = _init_args()\n    check_output_path(args.output_path)\n    init_random_model(\n        hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path\n    )\n"
  },
  {
    "path": "verl_distillation/scripts/install_vllm_sglang_mcore.sh",
    "content": "#!/bin/bash\n\nUSE_MEGATRON=${USE_MEGATRON:-1}\nUSE_SGLANG=${USE_SGLANG:-1}\n\nexport MAX_JOBS=32\n\necho \"1. install inference frameworks and pytorch they need\"\nif [ $USE_SGLANG -eq 1 ]; then\n    pip install \"sglang[all]==0.5.2\" --no-cache-dir && pip install torch-memory-saver --no-cache-dir\nfi\npip install --no-cache-dir \"vllm==0.11.0\"\n\necho \"2. install basic packages\"\npip install \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \"tensordict>=0.8.0,<=0.10.0,!=0.9.0\" torchdata \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest py-spy pre-commit ruff tensorboard \n\necho \"pyext is lack of maintainace and cannot work with python 3.12.\"\necho \"if you need it for prime code rewarding, please install using patched fork:\"\necho \"pip install git+https://github.com/ShaohonChen/PyExt.git@py311support\"\n\npip install \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n\necho \"3. install FlashAttention and FlashInfer\"\n# Install flash-attn-2.8.1 (cxx11abi=False)\nwget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.8.1+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl\n\npip install --no-cache-dir flashinfer-python==0.3.1\n\n\nif [ $USE_MEGATRON -eq 1 ]; then\n    echo \"4. install TransformerEngine and Megatron\"\n    echo \"Notice that TransformerEngine installation can take very long time, please be patient\"\n    pip install \"onnxscript==0.3.1\"\n    NVTE_FRAMEWORK=pytorch pip3 install --no-deps git+https://github.com/NVIDIA/TransformerEngine.git@v2.6\n    pip3 install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1\nfi\n\n\necho \"5. May need to fix opencv\"\npip install opencv-python\npip install opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\n\nif [ $USE_MEGATRON -eq 1 ]; then\n    echo \"6. Install cudnn python package (avoid being overridden)\"\n    pip install nvidia-cudnn-cu12==9.10.2.21\nfi\n\necho \"Successfully installed all packages\"\n"
  },
  {
    "path": "verl_distillation/scripts/legacy_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends.\n\nTo merge FSDP checkpoints:\n```sh\npython scripts/legacy_model_merger.py merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nTo merge Megatron checkpoints:\n```sh\npython scripts/legacy_model_merger.py merge \\\n    --backend megatron \\\n    --tie-word-embedding \\\n    --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nFor more details, please refer to documentation:\nhttps://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model\n\"\"\"\n\nimport argparse\nimport os\nimport re\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom concurrent.futures import ThreadPoolExecutor\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\nfrom accelerate import init_empty_weights\nfrom safetensors.torch import load_file\nfrom torch.distributed._tensor import Placement, Shard\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    AutoModelForVision2Seq,\n    GenerationConfig,\n    PretrainedConfig,\n)\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom tqdm import tqdm\n\nfrom verl.utils import hf_processor, hf_tokenizer\n\n\n@dataclass\nclass ModelMergerConfig:\n    operation: str  # 'merge' or 'test'\n    backend: str\n    local_dir: str\n    hf_model_config_path: str\n    target_dir: Optional[str] = \"tmp\"\n    hf_upload_path: Optional[str] = None\n    private: bool = False\n    test_hf_dir: Optional[str] = None\n    tie_word_embedding: bool = False\n    is_value_model: bool = False\n    hf_model_path: Optional[str] = None\n    hf_upload: bool = field(init=False)\n\n    def __post_init__(self):\n        self.hf_upload = self.operation == \"merge\" and bool(self.hf_upload_path)\n        if self.operation == \"test\":\n            self.target_dir = None\n            self.hf_upload_path = None\n            self.private = False\n\n\nclass BaseModelMerger(ABC):\n    def __init__(self, config: ModelMergerConfig):\n        self.config = config\n        self.hf_model_config_path = config.hf_model_config_path\n\n        if config.hf_model_path:\n            print(\n                \"Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. 
\"\n            )\n            self.hf_model_config_path = config.hf_model_path\n\n        # Auto-detect huggingface subdirectory if it exists\n        huggingface_subdir = os.path.join(self.hf_model_config_path, \"huggingface\")\n        if os.path.isdir(huggingface_subdir):\n            self.hf_model_config_path = huggingface_subdir\n\n        self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path)\n\n    def get_transformers_auto_model_class(self):\n        # Handle case where architectures might be None or empty\n        if self.model_config.architectures is None or len(self.model_config.architectures) == 0:\n            # Try to infer from model_type if architectures is missing\n            model_type = getattr(self.model_config, 'model_type', '').lower()\n            if 'vision' in model_type or 'vl' in model_type:\n                return AutoModelForVision2Seq\n            elif 'causal' in model_type or 'gpt' in model_type or 'llama' in model_type or 'qwen' in model_type:\n                return AutoModelForCausalLM\n            else:\n                raise NotImplementedError(\n                    f\"Cannot determine model class: architectures is None and model_type '{model_type}' is not recognized\"\n                )\n        \n        architecture = self.model_config.architectures[0]\n        if \"ForTokenClassification\" in architecture:\n            return AutoModelForTokenClassification\n        elif \"ForCausalLM\" in architecture:\n            return AutoModelForCausalLM\n        elif \"ForConditionalGeneration\" in architecture:\n            return AutoModelForVision2Seq\n\n        raise NotImplementedError(f\"Unknown architecture {self.model_config.architectures}\")\n\n    def patch_model_generation_config(self, model):\n        \"\"\"\n        The generation_config created from model config may be different to the pretrained model,\n        this may lead to error when generating: https://github.com/volcengine/verl/issues/1246\n\n        This function patch the generation_config created from model config to the pretrained model.\n        \"\"\"\n        if model.can_generate():\n            try:\n                model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path)\n            except OSError:\n                print(\n                    f\"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.\"\n                )\n        return model\n\n    def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Save lora adapter to safetensors.\n\n        Returns:\n            lora_path: str, the path to the lora adapter. 
None if no lora adapter found.\n\n        Note:\n            This function changes the 'state_dict' in place.\n        \"\"\"\n        lora_params_names = [name for name in state_dict.keys() if \"lora_\" in name]\n\n        if len(lora_params_names) == 0:\n            return None\n\n        import json\n        from collections import OrderedDict\n\n        import peft\n        from safetensors.torch import save_file\n\n        lora_params = OrderedDict()\n        target_modules = set()\n        lora_key = None\n\n        for name in lora_params_names:\n            lora_key = name.replace(\".default.weight\", \".weight\")\n            target_modules.add(lora_key.split(\".\")[-3])\n            lora_params[lora_key] = state_dict.pop(name)\n\n        lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])\n        peft_dict = {\n            \"r\": lora_rank,\n            \"lora_alpha\": 0,  # lora_alpha is not set. An error should be raised to inform the user to set it manually.\n            \"target_modules\": list(target_modules),\n        }\n        peft_config = peft.LoraConfig(**peft_dict).to_dict()\n        peft_config[\"task_type\"] = peft_config[\"task_type\"].value if peft_config[\"task_type\"] else None\n        peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value if peft_config[\"peft_type\"] else None\n        peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n\n        lora_path = os.path.join(self.config.target_dir, \"lora_adapter\")\n        os.makedirs(lora_path, exist_ok=True)\n        with open(os.path.join(lora_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n        save_file(lora_params, os.path.join(lora_path, \"adapter_model.safetensors\"))\n\n        for name in list(state_dict.keys()):\n            key = (\n                name.replace(\"base_model.model.\", \"\")\n                .replace(\".base_layer.weight\", \".weight\")\n                .replace(\".base_layer.bias\", \".bias\")\n            )\n            state_dict[key] = state_dict.pop(name)\n\n        return lora_path\n\n    def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n        with init_empty_weights():\n            model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16)\n        model.to_empty(device=\"cpu\")\n        model = self.patch_model_generation_config(model)\n\n        lora_path = self.save_lora_adapter(state_dict)\n        if lora_path:\n            print(f\"Saving lora adapter to {lora_path}\")\n\n        print(f\"Saving model to {self.config.target_dir}\")\n        model.save_pretrained(self.config.target_dir, state_dict=state_dict)\n        del state_dict\n        del model\n\n        processor = hf_processor(self.hf_model_config_path)\n        try:\n            tokenizer = hf_tokenizer(self.hf_model_config_path)\n        except Exception as e:\n            warnings.warn(f\"Failed to create tokenizer: {e}. 
This may affect tokenizer saving\", stacklevel=1)\n            tokenizer = None\n        if processor is not None:\n            print(f\"Saving processor to {self.config.target_dir}\")\n            processor.save_pretrained(self.config.target_dir)\n        if tokenizer is not None:\n            print(f\"Saving tokenizer to {self.config.target_dir}\")\n            tokenizer.save_pretrained(self.config.target_dir)\n\n    def upload_to_huggingface(self):\n        from huggingface_hub import HfApi\n\n        api = HfApi()\n        api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True)\n        api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type=\"model\")\n\n    @abstractmethod\n    def merge_and_save(self):\n        raise NotImplementedError(\"Subclasses should implement this method\")\n\n\nclass FSDPModelMerger(BaseModelMerger):\n    def _get_world_size(self) -> int:\n        \"\"\"Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').\"\"\"\n        for filename in os.listdir(self.config.local_dir):\n            match = re.match(r\"model_world_size_(\\d+)_rank_0\\.pt\", filename)\n            if match:\n                return int(match.group(1))\n        raise FileNotFoundError(\n            f\"Could not determine world size. No file matching 'model_world_size_(\\\\d+)_rank_0.pt' found in {self.config.local_dir}\"\n        )\n\n    def _load_rank_zero_state_dict(self, world_size: int) -> dict:\n        return torch.load(\n            Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_0.pt\",\n            map_location=\"cpu\",\n            weights_only=False,\n        )\n\n    def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]:\n        \"\"\"\n        Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict.\n        If no DTensor is found, infers a simple FSDP mesh based on world_size.\n        \"\"\"\n        pivot_key = sorted(list(state_dict.keys()))[0]\n        weight = state_dict[pivot_key]\n\n        if isinstance(weight, DTensor):\n            # get sharding info\n            device_mesh = weight.device_mesh\n            mesh = device_mesh.mesh\n            mesh_dim_names = device_mesh.mesh_dim_names\n        else:\n            # for non-DTensor\n            mesh = np.array([world_size], dtype=np.int64)\n            mesh_dim_names = (\"fsdp\",)\n\n        return mesh, mesh_dim_names\n\n    def _calculate_shard_configuration(\n        self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]\n    ) -> tuple[int, tuple[int, ...]]:\n        \"\"\"Calculates the total number of shards and the shape of the device mesh.\"\"\"\n        assert mesh_dim_names in ((\"fsdp\",), (\"ddp\", \"fsdp\")), f\"Unsupported mesh_dim_names {mesh_dim_names}\"\n\n        if \"tp\" in mesh_dim_names:\n            # TODO: \"tp\" is not supported yet due to the above assert\n            total_shards = mesh.shape[-1] * mesh.shape[-2]\n            mesh_shape = (mesh.shape[-2], mesh.shape[-1])\n        else:\n            total_shards = mesh.shape[-1]\n            mesh_shape = (mesh.shape[-1],)\n\n        return total_shards, mesh_shape\n\n    def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor:\n        \"\"\"Merges a list of tensors based on their DTensor placement\"\"\"\n        if placement.is_replicate():\n            return 
tensors[0]\n        elif placement.is_partial():\n            raise NotImplementedError(\"Partial placement is not supported yet\")\n        elif placement.is_shard():\n            return torch.cat(tensors, dim=placement.dim).contiguous()\n\n        raise NotImplementedError(f\"Unsupported placement: {placement}\")\n\n    def _load_and_merge_state_dicts(\n        self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]\n    ) -> dict[str, torch.Tensor]:\n        model_state_dict_lst = [None] * total_shards\n\n        def process_one_shard(rank: int, model_state_dict_lst: list):\n            model_path = Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_{rank}.pt\"\n            state_dict = torch.load(model_path, map_location=\"cpu\", weights_only=False)\n            model_state_dict_lst[rank] = state_dict\n            return state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)]\n            for future in tqdm(futures, desc=f\"Loading {total_shards} FSDP shards\", total=total_shards):\n                future.result()\n\n        # Merge state dicts from all shards\n        state_dict = {}\n        param_placements: dict[str, list] = {}\n\n        for key in set(model_state_dict_lst[0].keys()):\n            state_dict[key] = []\n            for model_state_shard in model_state_dict_lst:\n                # add tensor shard in order of rank to state_dict[key]\n                tensor = model_state_shard.pop(key)\n                if isinstance(tensor, DTensor):\n                    state_dict[key].append(tensor._local_tensor.bfloat16())\n\n                    placements = tuple(tensor.placements)\n                    # replicated placement at dp dimension can be discarded\n                    if mesh_dim_names[0] in (\"dp\", \"ddp\"):\n                        placements = placements[1:]\n\n                    if key not in param_placements:\n                        param_placements[key] = placements\n                    else:\n                        assert param_placements[key] == placements\n                else:\n                    state_dict[key].append(tensor.bfloat16())\n\n        del model_state_dict_lst\n\n        # Merge tensors\n        for key in sorted(state_dict):\n            if not isinstance(state_dict[key], list):\n                print(f\"No need to merge key {key}\")\n                continue\n            if key in param_placements:\n                # merge shards\n                placements: tuple[Shard] = param_placements[key]\n                if len(mesh_shape) == 1:\n                    # 1-D list, FSDP without TP\n                    assert len(placements) == 1\n                    shards = state_dict[key]\n                    state_dict[key] = self._merge_by_placement(shards, placements[0])\n                else:\n                    # 2-D list, FSDP + TP\n                    raise NotImplementedError(\"FSDP + TP is not supported yet\")\n            else:\n                state_dict[key] = torch.cat(state_dict[key], dim=0)\n\n        return state_dict\n\n    def merge_and_save(self):\n        world_size = self._get_world_size()\n        rank_zero_state_dict = self._load_rank_zero_state_dict(world_size)\n\n        mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size)\n        print(f\"Got device mesh {mesh}, 
mesh_dim_names {mesh_dim_names}\")\n\n        total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names)\n        print(f\"Processing model shards with {total_shards} {mesh_shape} in total\")\n\n        merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names)\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._test_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n\n        hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16)\n        hf_state_dict = hf_model.state_dict()\n        del hf_model\n\n        hf_model_keys = set(hf_state_dict.keys())\n        collected_keys = set(state_dict.keys())\n\n        missing_keys = hf_model_keys - collected_keys\n        assert len(missing_keys) == 0, f\"Missing keys in collected state dict: {list(sorted(missing_keys))}\"\n\n        extra_keys = collected_keys - hf_model_keys\n        assert len(extra_keys) == 0, f\"Extra keys in collected state dict: {list(sorted(extra_keys))}\"\n\n        for key in hf_model_keys:\n            hf_shape = hf_state_dict[key].shape\n            collected_shape = state_dict[key].shape\n            assert hf_shape == collected_shape, (\n                f\"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}\"\n            )\n\n            hf_dtype = hf_state_dict[key].dtype\n            collected_dtype = state_dict[key].dtype\n            assert hf_dtype == collected_dtype, (\n                f\"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}\"\n            )\n\n            torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6)\n\n        print(\"FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.\")\n\n\nclass MegatronModelMerger(BaseModelMerger):\n    def __init__(self, config: ModelMergerConfig):\n        from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path\n\n        config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir)\n        super().__init__(config)\n\n        self.params_mapping = {\n            # megatron core gpt model name, huggingface model name\n            # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the longer key within the containing relationship is processed first.\n            \"embedding.word_embeddings\": \"model.embed_tokens\",\n            # attn\n            \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_qkv.layer_norm_bias\": \"input_layernorm.bias\",\n            \"self_attention.linear_qkv\": \"self_attn.qkv_proj\",\n            \"self_attention.q_layernorm\": \"self_attn.q_norm\",\n            \"self_attention.k_layernorm\": \"self_attn.k_norm\",\n            \"self_attention.linear_proj\": 
\"self_attn.o_proj\",\n            # mla\n            \"self_attention.linear_q_proj\": \"self_attn.q_proj\",\n            \"self_attention.linear_q_down_proj\": \"self_attn.q_a_proj\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n            \"self_attention.linear_q_up_proj\": \"self_attn.q_b_proj\",\n            \"self_attention.linear_kv_down_proj\": \"self_attn.kv_a_proj_with_mqa\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj\": \"self_attn.kv_b_proj\",\n            # mlp\n            \"pre_mlp_layernorm\": \"post_attention_layernorm\",\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc1.layer_norm_bias\": \"post_attention_layernorm.bias\",\n            \"mlp.linear_fc1\": \"mlp.gate_up_proj\",\n            \"mlp.linear_fc2\": \"mlp.down_proj\",\n            # moe\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n            \"mlp.router\": \"mlp.gate\",\n            \"mlp.shared_experts.linear_fc1\": \"mlp.shared_experts.gate_up_proj\",\n            \"mlp.shared_experts.linear_fc2\": \"mlp.shared_experts.down_proj\",\n            \"linear_fc1\": \"gate_up_proj\",\n            \"linear_fc2\": \"down_proj\",\n            # output\n            \"final_layernorm\": \"norm\",\n            \"output_layer\": \"lm_head\",\n        }\n\n    def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]:\n        tp_rank = pp_rank = None\n        rank_list = sharded_dir.split(\"_\")[2:]\n        if re.match(r\"mp_rank_(\\d\\d)_(\\d\\d\\d)\", sharded_dir):\n            tp_rank = int(rank_list[0])\n            pp_rank = int(rank_list[1])\n        elif re.match(r\"mp_rank_(\\d\\d)\", sharded_dir):\n            tp_rank = int(rank_list[0])\n            pp_rank = 0\n\n        assert tp_rank is not None and pp_rank is not None, f\"Invalid sharded dir {sharded_dir}\"\n\n        return tp_rank, pp_rank\n\n    def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]:\n        \"\"\"\n        Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories).\n        Determines TP and PP sizes from directory names.\n        \"\"\"\n        tp_size = 0\n        pp_size = 0\n        sharded_dirs = sorted(os.listdir(model_path))\n        for sharded_dir in sharded_dirs:\n            assert \"model.pt\" in os.listdir(Path(model_path) / sharded_dir), f\"model.pt not found in {sharded_dir}\"\n            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)\n            tp_size = max(tp_size, tp_rank + 1)\n            pp_size = max(pp_size, pp_rank + 1)\n        return sharded_dirs, tp_size, pp_size\n\n    def _merge_across_tp(\n        self,\n        key: str,\n        tp_data: list[torch.Tensor],\n        config: PretrainedConfig,\n        tp_size: int,\n        is_value_model: bool = False,\n    ) -> Union[torch.Tensor, list[torch.Tensor]]:\n        if \"linear_fc1.weight\" in key:\n            # if the tensor is gate and proj\n            gate_lst = []\n            up_lst = []\n            for infer_param in tp_data:\n                gate, up = infer_param.chunk(2)\n                gate_lst.append(gate)\n                up_lst.append(up)\n            gate = torch.cat(gate_lst, dim=0)\n            up = torch.cat(up_lst, dim=0)\n            return [gate, 
up]\n        elif \"self_attention.linear_qkv.\" in key and \"layer_norm\" not in key:\n            # if the tensor is qkv, for each param on tp, split into q, k, v\n            # concat q, k, v separately.\n            q_lst = []\n            k_lst = []\n            v_lst = []\n            assert config.num_attention_heads % config.num_key_value_heads == 0\n            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads\n            assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0\n            kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)\n            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]\n\n            for infer_param in tp_data:\n                num_query_groups_per_partition = config.num_key_value_heads // tp_size\n                for chunk in infer_param.chunk(num_query_groups_per_partition):\n                    split_size = [\n                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,\n                        kv_size_per_tp // num_query_groups_per_partition,\n                        kv_size_per_tp // num_query_groups_per_partition,\n                    ]\n                    q, k, v = chunk.split(split_size)\n                    q_lst.append(q)\n                    k_lst.append(k)\n                    v_lst.append(v)\n\n            q = torch.cat(q_lst, dim=0)\n            k = torch.cat(k_lst, dim=0)\n            v = torch.cat(v_lst, dim=0)\n            return [q, k, v]\n        elif \"layer_norm\" in key or \"layernorm\" in key or \"router\" in key or (\"output_layer\" in key and is_value_model):\n            return tp_data[0]\n        else:\n            dim = 0\n            if \"linear_fc2.weight\" in key or \"self_attention.linear_proj\" in key:\n                dim = 1\n            return torch.cat(tp_data, dim=dim)\n\n    def _load_state_dicts(\n        self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int\n    ) -> list[list[dict]]:\n        model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)]\n\n        def _process_one_megatron_shard(sharded_dir: str):\n            model_file_path = Path(model_ckpt_path) / sharded_dir / \"model.pt\"\n            state_dict = torch.load(model_file_path, map_location=\"cpu\", weights_only=False)\n            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)\n            model_state_dict_lst[pp_rank][tp_rank] = state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs]\n            for future in tqdm(futures, desc=f\"Loading {len(sharded_dirs)} Megatron shards\", total=len(sharded_dirs)):\n                future.result()\n\n        return model_state_dict_lst\n\n    def _check_megatron_state_key(self, key: str) -> None:\n        \"\"\"\n        Checks if the key is a valid Megatron state key.\n\n        Now the model merger only supports keys that start with \"decoder/embedding/output_layer\" in TransformerLayer.\n        Keys must not start with \"model.\"\n        \"\"\"\n        if key.startswith(\"model.\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. 
Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer.\"\n            )\n\n        skip_checking_keys = [\"embedding.word_embeddings\", \"output_layer\"]\n        for skip_key in skip_checking_keys:\n            if skip_key in key:\n                print(f\"skip checking key {key}\")\n                return\n\n        # Exclude extra state keys\n        if not key.startswith(\"decoder\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer.\"\n            )\n\n    def _merge_state_dicts(\n        self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int\n    ) -> dict[str, torch.Tensor]:\n        state_dict = {}\n        vpp_size = len(model_state_dict_lst[0][0])\n        layers_cum = 0\n\n        for vpp_rank in range(vpp_size):\n            for pp_rank in range(pp_size):\n                layers_handled = 0\n                keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()\n                for key in keys:\n                    if \"extra_state\" in key:\n                        continue\n                    if self.config.tie_word_embedding and (\"output_layer\" in key):\n                        print(\"skip lm_head and reward_head loading because of tie_word_embeddings\")\n                        continue\n\n                    self._check_megatron_state_key(key)\n                    hf_name = self._replace_name(key, self.params_mapping)\n                    assert hf_name is not None, f\"Failed to convert layer name [{key}] from megatron to huggingface.\"\n                    if \"model.layers.\" in hf_name:\n                        local_layer_no = int(hf_name.split(\".\")[2])\n                        layers_handled = max(local_layer_no, layers_handled)\n                        global_layer_no = local_layer_no + layers_cum\n                        new_key_list = hf_name.split(\".\")\n                        new_key_list[2] = str(global_layer_no)\n                        hf_name = \".\".join(new_key_list)\n                    else:\n                        warnings.warn(f\"hf_name {hf_name} will not be fixed with layer number\", stacklevel=2)\n\n                    tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]\n                    merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model)\n\n                    if not isinstance(merged, list):\n                        state_dict[hf_name] = merged\n                    elif len(merged) == 3:\n                        # split qkv\n                        for n, d in zip([\"q\", \"k\", \"v\"], merged):\n                            state_dict[hf_name.replace(\"qkv\", n)] = d\n                    elif len(merged) == 2:\n                        # split gate up\n                        state_dict[hf_name.replace(\"gate_up\", \"gate\")] = merged[0]\n                        state_dict[hf_name.replace(\"gate_up\", \"up\")] = merged[1]\n                    print(\n                        f\"converted {key} to {hf_name} with shape {merged.shape if isinstance(merged, torch.Tensor) else [t.shape for t in merged]}\"\n                    )\n\n                layers_cum += layers_handled + 1  # zero based\n\n        return state_dict\n\n    def merge_and_save(self):\n        from verl.utils.megatron_utils import get_model_checkpoint_path\n\n        model_ckpt_path = get_model_checkpoint_path(self.config.local_dir)\n     
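        # A Megatron checkpoint dir is expected to hold sharded subdirs named like mp_rank_00/model.pt\n        # (TP rank only) or mp_rank_00_000/model.pt (TP rank, then PP rank); see _get_tp_pp_rank_from_sharded_dir.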
   sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path)\n        print(f\"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}\")\n\n        model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size)\n        merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size)\n        del model_state_dict_lst\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._test_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Compares the merged Megatron state_dict against a reference safetensors model.\n        Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name.\n        \"\"\"\n        ref_state_dict = load_file(Path(self.config.test_hf_dir) / \"model.safetensors\")\n\n        for name, loaded_weight in state_dict.items():\n            # name = self._replace_name(original_name, self.params_mapping)\n            if not name or name.endswith(\".bias\") and name not in ref_state_dict:\n                continue\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if self.config.tie_word_embedding and \"lm_head.weight\" in name:\n                continue\n            if name not in ref_state_dict:\n                raise RuntimeError(f\"key: {name} not exist in state_dict\")\n            param = ref_state_dict[name]\n            assert loaded_weight.dtype == param.dtype\n            torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2)\n\n    def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str:\n        for m_name, v_name in name_mapping.items():\n            if m_name not in megatron_name:\n                continue\n\n            megatron_name = megatron_name.replace(\"decoder\", \"model\")\n            param_name = megatron_name.replace(m_name, v_name)\n            return param_name\n\n        return None  # Return None if no mapping found\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"verl model merger\")\n    subparsers = parser.add_subparsers(dest=\"operation\", required=True, help=\"Specify 'merge' or 'test' operation.\")\n\n    base_op_parser = argparse.ArgumentParser(add_help=False)\n    base_op_parser.add_argument(\n        \"--backend\", type=str, required=True, choices=[\"fsdp\", \"megatron\"], help=\"The backend of the model\"\n    )\n    base_op_parser.add_argument(\"--local_dir\", type=str, required=True, help=\"Path to the saved model checkpoints\")\n    base_op_parser.add_argument(\n        \"--hf_model_path\",\n        type=str,\n        default=None,\n        help=\"(Deprecated) Path to the original Hugging Face model for config.\",\n    )\n    base_op_parser.add_argument(\n        \"--tie-word-embedding\",\n        action=\"store_true\",\n        help=\"Whether to tie word embedding weights (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\n        \"--is-value-model\",\n     
   action=\"store_true\",\n        help=\"Whether the model is a value model (currently only Megatron supported)\",\n    )\n\n    merge_parser = subparsers.add_parser(\"merge\", parents=[base_op_parser], help=\"Merge model checkpoints and save.\")\n    merge_parser.add_argument(\n        \"--target_dir\", default=\"tmp\", type=str, help=\"Directory to save the merged huggingface model\"\n    )\n    merge_parser.add_argument(\n        \"--hf_upload_path\", default=None, type=str, help=\"Hugging Face repository ID to upload the model\"\n    )\n    merge_parser.add_argument(\n        \"--private\", action=\"store_true\", help=\"Whether to upload the model to a private Hugging Face repository\"\n    )\n\n    test_parser = subparsers.add_parser(\n        \"test\", parents=[base_op_parser], help=\"Test merged model against a reference Hugging Face model\"\n    )\n    test_parser.add_argument(\n        \"--test_hf_dir\", type=str, required=True, help=\"Path to the reference Hugging Face model directory for testing\"\n    )\n\n    args = parser.parse_args()\n\n    common_config_args = {\n        \"operation\": args.operation,\n        \"backend\": args.backend,\n        \"tie_word_embedding\": args.tie_word_embedding,\n        \"is_value_model\": args.is_value_model,\n        \"local_dir\": args.local_dir,\n        \"hf_model_path\": args.hf_model_path,\n        \"hf_model_config_path\": args.local_dir,\n    }\n\n    if args.operation == \"merge\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            target_dir=args.target_dir,\n            hf_upload_path=args.hf_upload_path,\n            private=args.private,\n            test_hf_dir=None,\n        )\n        os.makedirs(config.target_dir, exist_ok=True)\n    elif args.operation == \"test\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            test_hf_dir=args.test_hf_dir,\n            # the following args are not used by test operation\n            target_dir=None,\n            hf_upload_path=None,\n            private=False,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown operation: {args.operation}\")\n\n    if config.backend == \"fsdp\":\n        merger = FSDPModelMerger(config)\n    elif config.backend == \"megatron\":\n        merger = MegatronModelMerger(config)\n    else:\n        raise NotImplementedError(f\"Unknown backend: {config.backend}\")\n\n    merger.merge_and_save()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/scripts/print_cfg.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\ntry:\n    import hydra\nexcept ImportError as e:\n    raise ImportError(\"Please install hydra-core via 'pip install hydra-core' and retry.\") from e\n\n\n@hydra.main(config_path=\"../verl/trainer/config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for PPO training with Hydra configuration management.\n\n    Args:\n        config_dict: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    print(config)\n    from verl.utils.config import omega_conf_to_dataclass\n\n    profiler_config = omega_conf_to_dataclass(config.critic.profiler)\n    print(profiler_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/scripts/rollout_viewer.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport re\nimport traceback\nfrom pathlib import Path\nfrom typing import Annotated, Optional\n\nimport aiofiles\n\ntry:\n    import ujson as json\nexcept ImportError:\n    import json\nimport typer\nfrom rich.highlighter import ReprHighlighter\nfrom rich.markdown import Markdown\nfrom rich.table import Table\nfrom rich.text import Text\nfrom textual import on\nfrom textual.app import App, ComposeResult\nfrom textual.containers import Horizontal, Vertical, VerticalScroll\nfrom textual.widgets import Input, ProgressBar, Select, SelectionList, Static\n\nINDEX_KEY = \"__IDX\"\nFILE_SUFFIX = \".jsonl\"\n\n\ndef check_textual_version():\n    # check if textual version is equal to 0.52.1\n    import textual\n    from packaging.version import Version\n\n    if Version(textual.__version__) != Version(\"0.52.1\"):\n        raise ImportError(f\"Textual version {textual.__version__} is not supported, please pip install textual==0.52.1\")\n\n\ncheck_textual_version()\n\n\nasync def load_path(p: Path, data: dict, mask_strs: str, idx: int, pbar):\n    samples = []\n    async with aiofiles.open(p, encoding=\"utf-8\") as f:\n        async for line in f:\n            d = json.loads(line)\n            for k in d:\n                if isinstance(d[k], str):\n                    if mask_strs:\n                        d[k] = re.sub(rf\"{mask_strs}\", \"*\", d[k])\n                else:\n                    d[k] = json.dumps(d[k], ensure_ascii=False, indent=4)\n\n            d[INDEX_KEY] = len(samples)\n            samples.append(d)\n        data[idx] = {\"samples\": samples}\n\n    print(f\"path {p} loaded\")\n    pbar.advance(1)\n\n\nasync def load_dir(path: Path, data: dict[int, dict], pbar, mask_strs: str = \"\"):\n    paths = list(path.glob(f\"*{FILE_SUFFIX}\"))\n    paths = sorted(paths, key=lambda x: int(x.stem))\n\n    tasks = [load_path(p, data, mask_strs, i, pbar) for i, p in enumerate(paths)]\n\n    await asyncio.gather(*tasks)\n\n\nclass Highlighter(ReprHighlighter):\n    highlights = ReprHighlighter.highlights + [\n        r\"(?P<tag_name>[][\\<\\>{}()\\|（）【】\\[\\]=`])\",\n        r\"\\<\\|(?P<tag_name>[\\w\\W]*?)\\|\\>\",\n    ]\n\n\ndef center_word_with_equals_exactly(word: str, total_length: int, char: str = \"=\") -> str:\n    if len(word) > total_length:\n        return word\n\n    padding = total_length - len(word)\n    left_pad = (padding) // 2\n    right_pad = (padding + 1) // 2\n    return char * left_pad + \" \" + word + \" \" + char * right_pad\n\n\ndef highlight_keyword(content: str, keyword: Optional[str]):\n    if not keyword:\n        return Text(content)\n    text = Text()\n    parts = content.split(keyword)\n    for i, part in enumerate(parts):\n        text.append(part, style=None)\n        if i < len(parts) - 1:\n            # text.append(keyword, style=Style(color=\"#d154d1\", 
bgcolor=\"yellow\", bold=True))\n            text.append(keyword, style=\"on #8f51b5\")\n    return text\n\n\nhelp_doc = \"\"\"\n⌨️   keybinds：\n\n- `f/esc`: find/cancel\n- `tab/←/→`: change focus\n- `j/k`: page down/up\n- `g/G`: scroll home/end\n- `n/N`: next sample/step\n- `p/P`: previous sample/step\n- `s`: switch display mode\n  - plain text\n  - rich table\n\n\"\"\"\n\n\nclass JsonLineViewer(App):\n    BINDINGS = [\n        (\"left\", \"focus_previous\", \"Focus Previous\"),\n        (\"right\", \"focus_next\", \"Focus Next\"),\n        (\"s\", \"swith_render\", \"switch render\"),\n        # control\n        (\"n\", \"next_sample\", \"Next Sample\"),\n        (\"N\", \"next_step\", \"Next Step\"),\n        (\"p\", \"previous_sample\", \"Previous Sample\"),\n        (\"P\", \"previous_step\", \"Previous Step\"),\n        # search\n        (\"f\", \"toggle_search\", \"find\"),\n        (\"enter\", \"next_search\", \"find next\"),\n        (\"escape\", \"cancel_search\", \"cancel find\"),\n        # scroll\n        (\"j\", \"page_down\", \"page down\"),\n        (\"k\", \"page_up\", \"page up\"),\n        (\"g\", \"page_home\", \"page home\"),\n        (\"G\", \"page_end\", \"page end\"),\n    ]\n\n    CSS = \"\"\"\n\n    Select:focus > SelectCurrent {\n        border: tall #8f51b5;\n    }\n    Select.-expanded > SelectCurrent {\n        border: tall #8f51b5;\n    }\n    #select-container {\n        width: 15%;\n        height: 100%;\n        align: center top;\n    }\n    #search-container {\n        height: 10%;\n        align: center top;\n    }\n    #search-box {\n        width: 50%;\n    }\n    #reqid-box {\n        width: 50%;\n    }\n    \"\"\"\n\n    def __init__(self, step_num: int, data: dict[int, dict], pbar):\n        super().__init__()\n        self.step_num = step_num\n\n        self.data = data\n        self.render_table = False\n        self.selected_step_index = 0\n        self.selected_sample_index = 0\n        self.pbar = pbar\n\n        self.matches = []\n        self.current_match_index = 0\n\n        self.highlighter = Highlighter()\n\n        first_samples = data[list(data.keys())[0]][\"samples\"]\n        # Prepare the initial field filter list (all keys from the first sample)\n        self.filter_fields = [(f, f, True) for f in first_samples[0].keys()]\n\n        # Internal set used for fast membership checks when we add new fields on the fly.\n        # We keep it here so that when new columns appear in later steps (e.g. 
`request_id`),\n        # they can be added to the UI automatically without restarting the viewer.\n        self._field_set: set[str] = set(first_samples[0].keys())\n        self.sample_num = len(first_samples)\n\n    def compose(self) -> ComposeResult:\n        with Horizontal(id=\"search-container\"):\n            yield Input(placeholder=\"find something...\", id=\"search-box\")\n            yield Input(placeholder=\"request id...\", id=\"reqid-box\")\n            with Vertical(id=\"search-container2\"):\n                yield self.pbar\n                yield Static(\"\", id=\"search-status\")\n\n        with Horizontal():\n            with Vertical(id=\"select-container\"):\n                yield Static(\"\\n\")\n                yield Static(\n                    renderable=Markdown(\n                        help_doc,\n                    ),\n                    markup=False,\n                )\n                yield Static(\"\\n\")\n                yield Select(\n                    id=\"step-select\",\n                    value=0,\n                    prompt=\"select step\",\n                    options=[(\"step: 1\", 0)],\n                    allow_blank=False,\n                )\n                yield Select(\n                    id=\"sample-select\",\n                    value=0,\n                    prompt=\"select sample\",\n                    options=[(\"sample: 1\", 0)],\n                    allow_blank=False,\n                )\n                yield Select(\n                    id=\"sample-sort\",\n                    value=0,\n                    prompt=\"排序\",\n                    options=[\n                        (\"sort\", 0),\n                        (\"score asc\", 1),\n                        (\"score desc\", 2),\n                    ],\n                    allow_blank=False,\n                )\n\n                yield SelectionList[int]((\"Select ALL\", 1, True), id=\"fields-select-all\")\n                with VerticalScroll(id=\"scroll-view2\"):\n                    yield SelectionList[str](*self.filter_fields, id=\"fields-select\")\n            with VerticalScroll(id=\"scroll-view\"):\n                yield Static(id=\"content\", markup=False)\n\n    async def on_mount(self) -> None:\n        self.step_select = self.query_one(\"#step-select\", Select)\n        self.sample_select = self.query_one(\"#sample-select\", Select)\n        self.sample_sort = self.query_one(\"#sample-sort\", Select)\n        self.content_display = self.query_one(\"#content\", Static)\n        self.search_box = self.query_one(\"#search-box\", Input)\n        self.reqid_box = self.query_one(\"#reqid-box\", Input)\n        self.scroll_view = self.query_one(\"#scroll-view\", VerticalScroll)\n        self.search_status = self.query_one(\"#search-status\", Static)\n        self.fields_select = self.query_one(\"#fields-select\", SelectionList)\n        self.fields_select.border_title = \"field filter\"\n\n        if self.data:\n            self.step_select.set_options([(f\"step: {i + 1}\", i) for i in range(self.step_num)])\n            self.sample_select.set_options([(f\"sample: {i + 1}\", i) for i in range(self.sample_num)])\n            self.step_select.focus()\n            await self.update_content()\n\n    def update_result_options(self, offset: int = 0, sort_desc: Optional[bool] = None):\n        options = []\n        if isinstance(self.selected_step_index, int) and self.selected_step_index < len(self.data):\n            if self.sample_num is None or sort_desc is not None:\n                
samples = self.data[self.selected_step_index].get(\"samples\", [])\n                if not samples:\n                    self.selected_sample_index = offset\n                    return\n                if sort_desc is not None:\n                    samples = sorted(\n                        samples,\n                        key=lambda x: x.get(\"score\", x.get(\"score_1\", 0)),\n                        reverse=sort_desc,\n                    )\n\n                options = [(f\"sample: {r[INDEX_KEY] + 1}\", r[INDEX_KEY]) for r in samples]\n                self.sample_select.set_options(options)\n                self.sample_num = len(samples)\n\n            if sort_desc is not None and options:\n                self.selected_sample_index = options[0][1]\n            else:\n                self.selected_sample_index = offset\n\n    async def update_content(self, search_keyword: Optional[str] = None):\n        content = \"\"\n        try:\n            samples = self.data[self.selected_step_index].get(\"samples\", [])\n            content_dict_full = samples[self.selected_sample_index]\n\n            # Dynamically track any NEW keys that appear and add them to the field filter.\n            self._update_fields_select(content_dict_full.keys())\n\n            # Apply field selection filter (only show selected fields)\n            content_dict = {k: v for k, v in content_dict_full.items() if k in self.fields_select.selected}\n            if self.render_table:\n                content = Table(\"key\", \"value\", show_lines=True)\n                for k in content_dict:\n                    v = content_dict[k]\n                    v = f\"{v}\"\n                    content.add_row(\n                        k,\n                        self.highlighter(highlight_keyword(v, search_keyword)),\n                    )\n            else:\n                text = Text()\n                for k in content_dict:\n                    v = content_dict[k]\n                    s = center_word_with_equals_exactly(k, 64) + f\"\\n{v}\\n\"\n                    text.append(highlight_keyword(s, search_keyword))\n                content = self.highlighter(text)\n        except KeyError:\n            content = f\"Loading data asynchronously, progress: {len(self.data)}/{self.step_num} step\"\n\n        except Exception:\n            content = self.highlighter(traceback.format_exc())\n\n        self.content_display.update(content)\n\n    # ---------------------------------------------------------------------\n    # Request-ID jump logic\n    # ---------------------------------------------------------------------\n\n    @on(Input.Submitted, \"#reqid-box\")\n    async def on_reqid_submitted(self, event: Input.Submitted) -> None:\n        \"\"\"Jump to the sample that has a matching `request_id`.\"\"\"\n\n        req_id_raw = event.value.strip()\n        # Remove hyphens so search is tolerant to different id formats\n        req_id = req_id_raw.replace(\"-\", \"\")\n        if not req_id:\n            return\n\n        found = False\n        for step_idx, step_data in self.data.items():\n            for sample in step_data.get(\"samples\", []):\n                sample_id = str(sample.get(\"request_id\", \"\"))\n                if sample_id.replace(\"-\", \"\") == req_id:\n                    # Update selected indices\n                    self.selected_step_index = step_idx\n                    self.step_select.value = step_idx\n\n                    # Ensure sample list is updated and select sample\n                    
self.update_result_options(offset=sample[INDEX_KEY])\n                    self.selected_sample_index = sample[INDEX_KEY]\n                    self.sample_select.value = sample[INDEX_KEY]\n\n                    await self._clear_search()\n                    await self.update_content()\n\n                    found = True\n                    break\n            if found:\n                break\n\n        if not found:\n            self.search_status.update(Text(f\"request_id '{req_id_raw}' not found\", style=\"bold red\"))\n        else:\n            # Keep the typed id in the input box so users see what was searched.\n            pass\n\n    # ---------------------------------------------------------------------\n    # Helper: add new fields to SelectionList on-the-fly\n    # ---------------------------------------------------------------------\n\n    def _update_fields_select(self, keys):\n        \"\"\"Add any unseen *keys* to the field-selection widget so they can be toggled.\n\n        The viewer is often launched with only the first step loaded. Later steps may\n        introduce new columns (e.g. `request_id`). This helper ensures those fields\n        become visible without requiring a restart.\n        \"\"\"\n        # Ensure we have the widget (only after on_mount)\n        if not hasattr(self, \"fields_select\"):\n            return\n\n        for k in keys:\n            if k not in self._field_set:\n                self._field_set.add(k)\n                try:\n                    # By default, new fields are selected so they appear immediately.\n                    self.fields_select.add_option(k, k, selected=True)\n                except Exception:\n                    # Fallback for older textual versions where signature is different.\n                    self.fields_select.add_option((k, k, True))\n\n    @on(Select.Changed, \"#step-select\")\n    async def step_changed(self, event):\n        self.selected_step_index = event.value\n        self.update_result_options()\n        await self.update_content()\n\n    @on(Select.Changed, \"#sample-select\")\n    async def sample_changed(self, event):\n        self.selected_sample_index = event.value\n        await self._clear_search()\n        await self.update_content()\n\n    @on(Select.Changed, \"#sample-sort\")\n    async def sort_changed(self, event):\n        v = event.value\n        self.update_result_options(sort_desc=None if v == 0 else False if v == 1 else True)\n        await self.update_content()\n\n    @on(SelectionList.SelectedChanged, \"#fields-select\")\n    async def fields_changed(self, event):\n        await self.update_content()\n\n    @on(SelectionList.SelectedChanged, \"#fields-select-all\")\n    async def fields_all_changed(self, event):\n        s = self.query_one(\"#fields-select-all\", SelectionList)\n        if s.selected:\n            self.fields_select.select_all()\n        else:\n            self.fields_select.deselect_all()\n\n    def action_focus_previous(self):\n        self.screen.focus_previous()\n\n    def action_focus_next(self):\n        self.screen.focus_next()\n\n    async def action_next_step(self) -> None:\n        self.selected_step_index += 1\n        if self.selected_step_index >= self.step_num:\n            self.selected_step_index = 0\n        self.step_select.value = self.selected_step_index\n        self.update_result_options()\n        await self.update_content()\n\n    async def action_next_sample(self) -> None:\n        self.selected_sample_index += 1\n        if not self.sample_num 
or self.selected_sample_index >= self.sample_num:\n            self.selected_sample_index = 0\n        self.sample_select.value = self.selected_sample_index\n        await self._clear_search()\n        await self.update_content()\n\n    async def action_previous_step(self) -> None:\n        self.selected_step_index -= 1\n        if self.selected_step_index < 0:\n            self.selected_step_index = self.step_num - 1\n        self.step_select.value = self.selected_step_index\n        self.update_result_options()\n        await self.update_content()\n\n    async def action_previous_sample(self) -> None:\n        self.selected_sample_index -= 1\n        if self.selected_sample_index < 0:\n            self.selected_sample_index = self.sample_num - 1\n        self.sample_select.value = self.selected_sample_index\n        await self._clear_search()\n        await self.update_content()\n\n    async def action_swith_render(self):\n        self.render_table = not self.render_table\n        await self.update_content()\n\n    def action_toggle_search(self) -> None:\n        self.search_box.focus()\n\n    async def action_cancel_search(self) -> None:\n        self.search_box.value = \"\"\n        await self._clear_search()\n        await self.update_content()\n\n    async def _clear_search(self):\n        self.matches = []\n        self.search_status.update(\"\")\n        self.current_match_index = 0\n\n    @on(Input.Submitted, \"#search-box\")\n    async def on_search_submitted(self, event: Input.Submitted) -> None:\n        self.matches = []\n        self.current_match_index = 0\n        if event.value:\n            await self.update_content(event.value)\n            renderable = self.content_display.render()\n            if isinstance(renderable, Table):\n                return\n\n            assert isinstance(renderable, Text)\n            console = self.content_display._console\n            lines = renderable.wrap(console, self.scroll_view.container_size.width)\n            line_idx_recorded = set()\n            for line_idx, line in enumerate(lines):\n                if line_idx in line_idx_recorded:\n                    continue\n                if event.value in line:\n                    self.matches.append(\n                        {\n                            \"line\": line_idx,\n                            \"word\": event.value,\n                        }\n                    )\n                    line_idx_recorded.add(line_idx)\n            self.scroll_view.focus()\n            await self.action_next_search()\n\n    async def action_next_search(self) -> None:\n        if not self.matches or self.current_match_index >= len(self.matches):\n            return\n\n        idx = self.current_match_index\n        target_line = self.matches[idx][\"line\"]\n        self.scroll_view.scroll_to(x=0, y=target_line, animate=False)\n        self.current_match_index = (idx + 1) % len(self.matches)\n        self.search_status.update(\n            Text(\n                f\"Find: {idx + 1}/{len(self.matches)}\",\n                style=\"bold on #8f51b5\",\n            )\n        )\n\n    def action_page_up(self):\n        self.scroll_view.scroll_page_up(animate=False)\n\n    def action_page_down(self):\n        self.scroll_view.scroll_page_down(animate=False)\n\n    def action_page_home(self):\n        self.scroll_view.scroll_home(animate=False)\n\n    def action_page_end(self):\n        self.scroll_view.scroll_end(animate=False)\n\n\nasync def _run(path: Path, mask_str: 
str):\n    assert path.exists(), f\"{path} does not exist\"\n\n    paths = list(path.glob(f\"*{FILE_SUFFIX}\"))\n    paths = sorted(paths, key=lambda x: int(x.stem))\n\n    if not paths:\n        raise ValueError(f\"no available reward dump files under {path}\")\n\n    print(f\"found {len(paths)} jsonl files\")\n\n    pbar = ProgressBar(total=len(paths), name=\"data load progress\")\n    data = {}\n    await load_path(paths[0], data, mask_str, 0, pbar)\n    app = JsonLineViewer(step_num=len(paths), data=data, pbar=pbar)\n    await asyncio.gather(load_dir(path, data, pbar, mask_str), app.run_async())\n\n\napp = typer.Typer()\n\n\n@app.command(help=\"launch the TUI app\")\ndef run(\n    rollout_data_dir: Path,\n    mask_str: Annotated[str, typer.Option(help=\"regex pattern whose matches will be masked to *\")] = r\"<\\|image_pad\\|>|<\\|imgpad\\|>\",\n):\n    asyncio.run(_run(rollout_data_dir, mask_str))\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "verl_distillation/setup.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# setup.py is the fallback installation script when pyproject.toml does not work\nimport os\nfrom pathlib import Path\n\nfrom setuptools import find_packages, setup\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\nwith open(os.path.join(version_folder, \"verl/version/version\")) as f:\n    __version__ = f.read().strip()\n\ninstall_requires = [\n    \"accelerate\",\n    \"codetiming\",\n    \"datasets\",\n    \"dill\",\n    \"hydra-core\",\n    \"numpy<2.0.0\",\n    \"pandas\",\n    \"peft\",\n    \"pyarrow>=19.0.0\",\n    \"pybind11\",\n    \"pylatexenc\",\n    \"ray[default]>=2.41.0\",\n    \"torchdata\",\n    \"tensordict>=0.8.0,<=0.10.0,!=0.9.0\",\n    \"transformers\",\n    \"wandb\",\n    \"packaging>=20.0\",\n    \"tensorboard\",\n]\n\nTEST_REQUIRES = [\"pytest\", \"pre-commit\", \"py-spy\", \"pytest-asyncio\"]\nPRIME_REQUIRES = [\"pyext\"]\nGEO_REQUIRES = [\"mathruler\", \"torchvision\", \"qwen_vl_utils\"]\nGPU_REQUIRES = [\"liger-kernel\", \"flash-attn\"]\nMATH_REQUIRES = [\"math-verify\"]  # Add math-verify as an optional dependency\nVLLM_REQUIRES = [\"tensordict>=0.8.0,<=0.10.0,!=0.9.0\", \"vllm>=0.8.5,<=0.11.0\"]\nSGLANG_REQUIRES = [\n    \"tensordict>=0.8.0,<=0.10.0,!=0.9.0\",\n    \"sglang[srt,openai]==0.5.2\",\n    \"torch==2.8.0\",\n]\nTRL_REQUIRES = [\"trl<=0.9.6\"]\nMCORE_REQUIRES = [\"mbridge\"]\nTRANSFERQUEUE_REQUIRES = [\"TransferQueue @ git+https://github.com/TransferQueue/TransferQueue.git@68c04e7\"]\n\nextras_require = {\n    \"test\": TEST_REQUIRES,\n    \"prime\": PRIME_REQUIRES,\n    \"geo\": GEO_REQUIRES,\n    \"gpu\": GPU_REQUIRES,\n    \"math\": MATH_REQUIRES,\n    \"vllm\": VLLM_REQUIRES,\n    \"sglang\": SGLANG_REQUIRES,\n    \"trl\": TRL_REQUIRES,\n    \"mcore\": MCORE_REQUIRES,\n    \"transferqueue\": TRANSFERQUEUE_REQUIRES,\n}\n\n\nthis_directory = Path(__file__).parent\nlong_description = (this_directory / \"README.md\").read_text()\n\nsetup(\n    name=\"verl\",\n    version=__version__,\n    package_dir={\"\": \".\"},\n    packages=find_packages(where=\".\"),\n    url=\"https://github.com/volcengine/verl\",\n    license=\"Apache 2.0\",\n    author=\"Bytedance - Seed - MLSys\",\n    author_email=\"zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk\",\n    description=\"verl: Volcano Engine Reinforcement Learning for LLM\",\n    install_requires=install_requires,\n    extras_require=extras_require,\n    package_data={\n        \"\": [\"version/*\"],\n        \"verl\": [\"trainer/config/*.yaml\"],\n        \"recipe.onpolicy_distill\": [\"config/*.yaml\"],\n    },\n    include_package_data=True,\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n)\n"
  },
  {
    "path": "verl_distillation/tests/README.md",
    "content": "# Tests layout\n\nEach folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:\n- `tests/trainer` for testing functionality related to `verl/trainer`\n- `tests/models` for testing functionality related to `verl/models`\n- ...\n\nThere are a few folders with `special_` prefix, created for special purposes:\n- `special_distributed`: unit tests that must run with multiple GPUs\n- `special_e2e`: end-to-end tests with training/generation scripts\n- `special_npu`: tests for NPUs\n- `special_sanity`: a suite of quick sanity tests\n- `special_standalone`: a set of test that are designed to run in dedicated environments\n\nAccelerators for tests \n- By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.\n- For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment.\n\n# Workflow layout\n\nAll CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:\n1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml`\n2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`\n3. End-to-end tests: `e2e_*.yml`\n4. Unit tests\n  - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`\n  - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix.\n  - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when\n    - new workflow yaml is added to `.github/workflows`\n    - new tests are added to workflow mentioned in 2."
  },
  {
    "path": "verl_distillation/tests/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/agent_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\nfrom omegaconf import DictConfig\n\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker\n\n\ndef init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:\n    # =========================== 1. Create hybrid ActorRollout workers ===========================\n    actor_rollout_cls = (\n        AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == \"async\" else ActorRolloutRefWorker\n    )\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(actor_rollout_cls),\n    }\n    if config.reward_model.enable:\n        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n\n    global_pool_id = \"global_pool\"\n    resource_pool_spec = {\n        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n    }\n    mapping = {\n        Role.ActorRollout: global_pool_id,\n    }\n    if config.reward_model.enable_resource_pool:\n        mapping[Role.RewardModel] = \"reward_pool\"\n        if config.reward_model.n_gpus_per_node <= 0:\n            raise ValueError(\"config.reward_model.n_gpus_per_node must be greater than 0\")\n        if config.reward_model.nnodes <= 0:\n            raise ValueError(\"config.reward_model.nnodes must be greater than 0\")\n\n        reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes\n        resource_pool_spec[\"reward_pool\"] = reward_pool\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n    resource_pool_manager.create_resource_pool()\n    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}\n\n    # create actor and rollout\n    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)\n    actor_rollout_cls = RayClassWithInitArgs(\n        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role=\"actor_rollout\"\n    )\n    resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n\n    if config.reward_model.enable:\n        # we create a RM here\n        resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)\n        rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model)\n        resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n    all_wg = {}\n    for resource_pool, class_dict in resource_pool_to_cls.items():\n        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n        wg_dict = RayWorkerGroup(resource_pool=resource_pool, 
ray_cls_with_init=worker_dict_cls)\n        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n        all_wg.update(spawn_wg)\n    actor_rollout_wg = all_wg[\"actor_rollout\"]\n    actor_rollout_wg.init_model()\n\n    if config.actor_rollout_ref.rollout.mode == \"sync\":\n        return actor_rollout_wg\n\n    if config.reward_model.enable_resource_pool and config.reward_model.enable:\n        rm_wg = all_wg[\"rm\"]\n        rm_wg.init_model()\n    else:\n        rm_wg = None\n    # =========================== 2. Create AgentLoopManager ===========================\n    agent_loop_manager = AgentLoopManager(\n        config=config,\n        worker_group=actor_rollout_wg,\n        rm_wg=rm_wg,\n    )\n\n    return agent_loop_manager\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2",
    "content": "{% set image_count = namespace(value=0) %}\n{% set video_count = namespace(value=0) %}\n{%- if tools %}\n{{- '<|im_start|>system\\n' }}\n{%- if messages[0]['role'] == 'system' %}\n{%- if messages[0]['content'] is string %}\n{{- messages[0]['content'] }}\n{%- else %}\n{{- messages[0]['content'][0]['text'] }}\n{%- endif %}\n{%- else %}\n{{- 'You are a helpful assistant.' }}\n{%- endif %}\n{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n{%- for tool in tools %}\n{{- \"\\n\" }}\n{{- tool | tojson }}\n{%- endfor %}\n{{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{% for message in messages %}\n{% if message['role'] != 'system' or loop.first == false %}\n{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}\n{{ message['content'] }}<|im_end|>\n{% else %}\n{% for content in message['content'] %}\n{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n{% set image_count.value = image_count.value + 1 %}\n{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>\n{% elif content['type'] == 'video' or 'video' in content %}\n{% set video_count.value = video_count.value + 1 %}\n{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>\n{% elif 'text' in content %}\n{{ content['text'] }}\n{% endif %}\n{% endfor %}<|im_end|>\n{% endif %}\n{%- elif message.role == \"assistant\" %}\n{{- '<|im_start|>' + message.role }}\n{%- if message.content %}\n{{- '\\n' + message.content }}\n{%- endif %}\n{%- for tool_call in message.tool_calls %}\n{%- if tool_call.function is defined %}\n{%- set tool_call = tool_call.function %}\n{%- endif %}\n{{- '\\n<tool_call>\\n{\"name\": \"' }}\n{{- tool_call.name }}\n{{- '\", \"arguments\": ' }}\n{{- tool_call.arguments | tojson }}\n{{- '}\\n</tool_call>' }}\n{%- endfor %}\n{{- '<|im_end|>\\n' }}\n{%- elif message.role == \"tool\" %}\n{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n{{- '<|im_start|>user' }}\n{%- endif %}\n{{- '\\n<tool_response>\\n' }}\n{% if message['content'] is string %}\n{{ message.content }}\n{% else %}\n{% for content in message['content'] %}\n{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n{% set image_count.value = image_count.value + 1 %}\n{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>\n{% elif content['type'] == 'video' or 'video' in content %}\n{% set video_count.value = video_count.value + 1 %}\n{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>\n{% elif content['type'] == 'text' or 'text' in content %}\n{{ content['text'] }}\n{% endif %}\n{% endfor %}\n{% endif %}\n{{- '\\n</tool_response>' }}\n{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n{{- '<|im_end|>\\n' }}\n{%- endif %}\n{%- endif %}\n{% endif %}\n{% endfor %}\n{%- else %}\n{% for message in messages %}\n{% if loop.first and 
message['role'] != 'system' %}\n<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}\n{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}\n{{ message['content'] }}<|im_end|>\n{% else %}\n{% for content in message['content'] %}\n{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n{% set image_count.value = image_count.value + 1 %}\n{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>\n{% elif content['type'] == 'video' or 'video' in content %}\n{% set video_count.value = video_count.value + 1 %}\n{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>\n{% elif 'text' in content %}\n{{ content['text'] }}\n{% endif %}\n{% endfor %}<|im_end|>\n{% endif %}\n{%- elif message.role == \"assistant\" %}\n{{- '<|im_start|>' + message.role }}\n{%- if message.content %}\n{{- '\\n' + message.content }}\n{%- endif %}\n{%- for tool_call in message.tool_calls %}\n{%- if tool_call.function is defined %}\n{%- set tool_call = tool_call.function %}\n{%- endif %}\n{{- '\\n<tool_call>\\n{\"name\": \"' }}\n{{- tool_call.name }}\n{{- '\", \"arguments\": ' }}\n{{- tool_call.arguments | tojson }}\n{{- '}\\n</tool_call>' }}\n{%- endfor %}\n{{- '<|im_end|>\\n' }}\n{%- elif message.role == \"tool\" %}\n{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n{{- '<|im_start|>user' }}\n{%- endif %}\n{{- '\\n<tool_response>\\n' }}\n{% if message['content'] is string %}\n{{ message.content }}\n{% else %}\n{% for content in message['content'] %}\n{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n{% set image_count.value = image_count.value + 1 %}\n{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>\n{% elif content['type'] == 'video' or 'video' in content %}\n{% set video_count.value = video_count.value + 1 %}\n{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>\n{% elif content['type'] == 'text' or 'text' in content %}\n{{ content['text'] }}\n{% endif %}\n{% endfor %}\n{% endif %}\n{{- '\\n</tool_response>' }}\n{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n{{- '<|im_end|>\\n' }}\n{%- endif %}\n{%- endif %}\n{% endfor %}\n{%- endif %}\n{% if add_generation_prompt %}\n<|im_start|>assistant\n{% endif %}"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_agent_loop_reward.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport pytest\nimport ray\nfrom hydra import compose, initialize_config_dir\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom transformers import AutoTokenizer\n\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.protocol import DataProto\nfrom verl.trainer.main_ppo import create_rl_sampler\nfrom verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n\n@pytest.mark.skip(reason=\"compute score is depreated and replaced by reward manager worker\")\ndef test_agent_loop_compute_score():\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(\"ppo_trainer\")\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n    config.data.return_raw_chat = True\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.actor.use_dynamic_bsz = True\n    config.actor_rollout_ref.rollout.name = os.environ[\"ROLLOUT_NAME\"]\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.enforce_eager = True\n    config.actor_rollout_ref.rollout.prompt_length = 1024\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = True\n\n    # 1. init agent loop manager\n    agent_loop_manager = AgentLoopManager(config)\n\n    # 2. init dataset and dataloader\n    local_folder = os.path.expanduser(\"~/data/gsm8k/\")\n    data_files = [os.path.join(local_folder, \"train.parquet\")]\n    tokenizer = AutoTokenizer.from_pretrained(model_path)\n\n    dataset = RLHFDataset(\n        data_files=data_files,\n        tokenizer=tokenizer,\n        config=config.data,\n        processor=None,\n    )\n\n    batch_size = 128\n    sampler = create_rl_sampler(config.data, dataset)\n    dataloader = StatefulDataLoader(\n        dataset=dataset,\n        batch_size=batch_size,\n        num_workers=config.data.dataloader_num_workers,\n        drop_last=True,\n        collate_fn=collate_fn,\n        sampler=sampler,\n    )\n\n    # 3. generate_sequences with agent loop\n    batch_dict = next(iter(dataloader))\n    batch = DataProto.from_single_dict(batch_dict)\n    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)\n\n    rm_scores = gen_batch.batch[\"rm_scores\"]\n    sample_scores = rm_scores.sum(dim=1)\n    assert sample_scores.min() == 0.0, f\"min score: {sample_scores.min()}\"\n    assert sample_scores.max() == 1.0, f\"max score: {sample_scores.max()}\"\n    print(f\"gsm8k acc: {sample_scores.mean()}\")\n\n    print(\"Test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_agent_loop_reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport pytest\nimport ray\nfrom hydra import compose, initialize_config_dir\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom transformers import AutoTokenizer\n\nfrom tests.experimental.agent_loop.agent_utils import AgentLoopManager\nfrom verl.protocol import DataProto\nfrom verl.trainer.main_ppo import create_rl_sampler\nfrom verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n\n@pytest.mark.skip(reason=\"reward model is depreated and replaced by GRM\")\ndef test_agent_loop_compute_score_with_model():\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(\"ppo_trainer\")\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n    config.data.return_raw_chat = True\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.actor.use_dynamic_bsz = True\n    config.actor_rollout_ref.rollout.name = os.environ[\"ROLLOUT_NAME\"]\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.enforce_eager = True\n    config.actor_rollout_ref.rollout.prompt_length = 1024\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = True\n    config.reward_model.enable = True\n    config.reward_model.model.path = model_path\n    config.reward_model.use_dynamic_bsz = True\n    config.reward_model.forward_max_token_len_per_gpu = 6000\n    config.reward_model.micro_batch_size_per_gpu = 40\n    config.reward_model.enable_resource_pool = True\n    config.reward_model.n_gpus_per_node = 1\n    config.reward_model.nnodes = 1\n    config.reward_model.model.trust_remote_code = True\n    config.reward_model.model.input_tokenizer = None\n    config.trainer.n_gpus_per_node = 4\n    config.trainer.nnodes = 1\n    # 1. init agent loop manager\n    agent_loop_manager = AgentLoopManager(config)\n\n    # 2. 
init dataset and dataloader\n    local_folder = os.path.expanduser(\"~/data/gsm8k/\")\n    data_files = [os.path.join(local_folder, \"train.parquet\")]\n    tokenizer = AutoTokenizer.from_pretrained(model_path)\n\n    dataset = RLHFDataset(\n        data_files=data_files,\n        tokenizer=tokenizer,\n        config=config.data,\n        processor=None,\n    )\n\n    batch_size = 128\n    sampler = create_rl_sampler(config.data, dataset)\n    dataloader = StatefulDataLoader(\n        dataset=dataset,\n        batch_size=batch_size,\n        num_workers=config.data.dataloader_num_workers,\n        drop_last=True,\n        collate_fn=collate_fn,\n        sampler=sampler,\n    )\n\n    # 3. generate_sequences with agent loop\n    batch_dict = next(iter(dataloader))\n    batch = DataProto.from_single_dict(batch_dict)\n    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)\n\n    rm_scores = gen_batch.batch[\"rm_scores\"]\n    sample_scores = rm_scores.sum(dim=1)\n    print(sample_scores)\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_basic_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\nfrom typing import Any\n\nimport numpy as np\nimport pytest\nimport ray\nfrom omegaconf import DictConfig\nfrom transformers.utils import get_json_schema\n\nfrom tests.experimental.agent_loop.agent_utils import init_agent_loop_manager\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.experimental.agent_loop.agent_loop import get_trajectory_info\nfrom verl.protocol import DataProto\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema\nfrom verl.tools.schemas import ToolResponse\nfrom verl.trainer.ppo.reward import compute_reward, load_reward_manager\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(\n            config_name=\"ppo_trainer\",\n            overrides=[\n                \"actor_rollout_ref.actor.use_dynamic_bsz=true\",\n                # test sleep/wake_up with fsdp offload\n                \"actor_rollout_ref.actor.fsdp_config.param_offload=True\",\n                \"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\",\n                \"reward_model.reward_manager=dapo\",\n                \"+reward_model.reward_kwargs.overlong_buffer_cfg.enable=False\",\n                \"+reward_model.reward_kwargs.overlong_buffer_cfg.len=3072\",\n                \"+reward_model.reward_kwargs.max_resp_len=4096\",\n            ],\n        )\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.name = os.environ[\"ROLLOUT_NAME\"]\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.enforce_eager = True\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n    config.actor_rollout_ref.rollout.agent.num_workers = 2\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = True\n\n    return config\n\n\ndef test_single_turn(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    agent_loop_manager = AgentLoopManager(init_config)\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    reward_fn = load_reward_manager(\n        init_config, tokenizer, num_examine=0, **init_config.reward_model.get(\"reward_kwargs\", {})\n    )\n\n    raw_prompts = [\n        [\n            {\n                \"role\": \"user\",\n                \"content\": \"Let's play a role playing 
game. Your name is Alice, your favorite color is blue.\",\n            }\n        ],\n        [{\"role\": \"user\", \"content\": \"Let's play a role playing game. Your name is Bob, your favorite color is red.\"}],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array(raw_prompts),\n            \"agent_name\": np.array([\"single_turn_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n    )\n    n = init_config.actor_rollout_ref.rollout.n\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # check result\n    seq_len = result.batch[\"prompts\"].size(1) + result.batch[\"responses\"].size(1)\n    assert result.batch[\"input_ids\"].size(1) == seq_len\n    assert result.batch[\"attention_mask\"].size(1) == seq_len\n    assert result.batch[\"position_ids\"].size(1) == seq_len\n\n    if init_config.actor_rollout_ref.rollout.calculate_log_probs:\n        assert result.batch[\"rollout_log_probs\"].size(1) == result.batch[\"responses\"].size(1)\n\n    # check compute score\n    assert result.batch[\"rm_scores\"].shape == result.batch[\"responses\"].shape\n    reward_tensor, reward_extra_info = compute_reward(result, reward_fn)\n    assert reward_tensor.shape == result.batch[\"responses\"].shape\n    assert \"acc\" in reward_extra_info, f\"reward_extra_info {reward_extra_info} should contain 'acc'\"\n    assert reward_extra_info[\"acc\"].shape == (len(result),), f\"invalid acc: {reward_extra_info['acc']}\"\n\n    # check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    assert np.all(num_turns == 2)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\nclass WeatherTool(BaseTool):\n    def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n        \"\"\"Get current temperature at a location.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, and the unit in a dict\n        \"\"\"\n        print(f\"[DEBUG] get_current_temperature: {location}, {unit}\")\n        return {\n            \"temperature\": 26.1,\n            \"location\": location,\n            \"unit\": unit,\n        }\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_current_temperature)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        try:\n            result = self.get_current_temperature(**parameters)\n            return ToolResponse(text=json.dumps(result)), 0, {}\n        except Exception as e:\n            return ToolResponse(text=str(e)), 0, {}\n\n\nclass WeatherToolWithData(BaseTool):\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_temperature_date)\n        return OpenAIFunctionToolSchema(**schema)\n\n    def get_temperature_date(self, location: str, date: str, unit: str = \"celsius\"):\n        \"\"\"Get temperature at a location and date.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            date: The date to get the temperature for, in the format \"Year-Month-Day\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, the date and the unit in a dict\n        \"\"\"\n        print(f\"[DEBUG] get_temperature_date: {location}, {date}, {unit}\")\n        return {\n            \"temperature\": 25.9,\n            \"location\": location,\n            \"date\": date,\n            \"unit\": unit,\n        }\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        try:\n            result = self.get_temperature_date(**parameters)\n            return ToolResponse(text=json.dumps(result)), 0, {}\n        except Exception as e:\n            return ToolResponse(text=str(e)), 0, {}\n\n\ndef test_tool_agent(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        },\n        ignore_reinit_error=True,\n    )\n\n    # =========================== 1. 
Init rollout manager ===========================\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2\n    init_config.actor_rollout_ref.rollout.calculate_log_probs = True\n    agent_loop_manager = AgentLoopManager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in New York now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"tool_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # [user, assistant]\n            assert num_turns[i] == 2\n        else:\n            # [user, assistant, tool, assistant]\n            assert num_turns[i] == 4\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert result.batch[\"rm_scores\"].size(1) == responses.size(1)\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    assert result.batch[\"rollout_log_probs\"].size(1) == result.batch[\"responses\"].size(1)\n\n    response_length = response_mask.size(1)\n    for i in range(len(responses)):\n        # response with tool response\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n  
      valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        assert \"<tool_response>\" not in response_without_obs, (\n            f\"found <tool_response> in response: {response_without_obs}\"\n        )\n        assert \"</tool_response>\" not in response_without_obs, (\n            f\"found </tool_response> in response: {response_without_obs}\"\n        )\n        print(\"=========================\")\n        print(response_with_obs)\n        print(\"---\")\n        print(response_without_obs)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\ndef test_tool_agent_with_interaction(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n\n    interaction_config = {\n        \"interaction\": [\n            {\"name\": \"weather\", \"class_name\": \"verl.interactions.weather_interaction.WeatherInteraction\", \"config\": {}}\n        ]\n    }\n    interaction_config_path = \"/tmp/interaction_config.json\"\n    with open(interaction_config_path, \"w\") as f:\n        json.dump(interaction_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in New York now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? 
How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"tool_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n            \"extra_info\": np.array(\n                [\n                    {\"interaction_kwargs\": {\"name\": \"weather\"}},\n                    {\"interaction_kwargs\": {\"name\": \"weather\"}},\n                    {\"interaction_kwargs\": {\"name\": \"weather\"}},\n                    {\"interaction_kwargs\": {\"name\": \"weather\"}},\n                ]\n            ),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # [user, assistant, user]\n            assert num_turns[i] == 3\n        else:\n            # [user, assistant, tool, assistant, user]\n            assert num_turns[i] == 5\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    response_length = response_mask.size(1)\n\n    for i in range(len(responses)):\n        # response with tool response\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        assert \"\\udb82\\udc89\" not in response_without_obs, f\"found \\udb82\\udc89 in response: {response_without_obs}\"\n        assert \"\\udb82\\udc8a\" not in response_without_obs, f\"found \\udb82\\udc8a in response: {response_without_obs}\"\n        print(\"=========================\")\n        print(response_with_obs)\n        print(\"---\")\n        print(response_without_obs)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\n@pytest.mark.asyncio\nasync def test_get_trajectory_info():\n    \"\"\"Tests the get_trajectory_info function.\"\"\"\n    # Two samples (indices 1 and 3), each with two rollouts\n    step = 10\n    index = [1, 1, 3, 3]\n    expected_info = [\n        {\"step\": step, \"sample_index\": 1, \"rollout_n\": 0, \"validate\": False},\n        {\"step\": step, \"sample_index\": 1, \"rollout_n\": 1, \"validate\": False},\n        {\"step\": step, \"sample_index\": 3, \"rollout_n\": 0, \"validate\": False},\n        {\"step\": step, \"sample_index\": 3, \"rollout_n\": 1, \"validate\": False},\n    ]\n\n    trajectory_info = await get_trajectory_info(step, index, validate=False)\n\n    assert trajectory_info == expected_info\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport pytest\nfrom transformers import AutoTokenizer\n\nfrom verl.experimental.agent_loop.tool_parser import GptOssToolParser\n\n\n@pytest.mark.asyncio\n@pytest.mark.skip(reason=\"local test only\")\nasync def test_gpt_oss_tool_parser():\n    example_text = \"\"\"\n<|start|>assistant<|channel|>commentary to=functions.get_current_weather \\\n<|constrain|>json<|message|>{\"location\": \"Tokyo\"}<|call|>\n<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\\\n{ \"temperature\": 20, \"sunny\": true }<|end|>\"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\"openai/gpt-oss-20b\")\n    response_ids = tokenizer.encode(example_text)\n    tool_parser = GptOssToolParser(tokenizer)\n    _, function_calls = await tool_parser.extract_tool_calls(response_ids)\n    assert len(function_calls) == 1\n    assert function_calls[0].name == \"get_current_weather\"\n    assert function_calls[0].arguments == '{\"location\": \"Tokyo\"}'\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_multi_modal.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\nfrom typing import Any\n\nimport numpy as np\nimport pytest\nimport ray\nfrom omegaconf import DictConfig\nfrom PIL import Image\nfrom transformers.utils import get_json_schema\n\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.protocol import DataProto\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema\nfrom verl.tools.schemas import ToolResponse\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(\n            config_name=\"ppo_trainer\",\n            overrides=[\n                \"actor_rollout_ref.actor.use_dynamic_bsz=true\",\n                # test sleep/wake_up with fsdp offload\n                \"actor_rollout_ref.actor.fsdp_config.param_offload=True\",\n                \"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\",\n            ],\n        )\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-VL-3B-Instruct\")\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.name = os.environ[\"ROLLOUT_NAME\"]\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.enforce_eager = True\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n    config.actor_rollout_ref.rollout.agent.num_workers = 2\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = True\n\n    return config\n\n\nclass ImageGeneratorTool(BaseTool):\n    def generate_image(self, description: str, size: str = \"256x256\"):\n        \"\"\"Generate a simple image based on description.\n\n        Args:\n            description: The description of the image to generate.\n            size: The size of the image. Defaults to \"256x256\". 
(choices: [\"256x256\", \"512x512\"])\n\n        Returns:\n            A generated image\n        \"\"\"\n        print(f\"[DEBUG] generate_image: {description}, {size}\")\n        # Create a simple colored image for testing\n        width, height = map(int, size.split(\"x\"))\n\n        # Create different colors based on description\n        if \"red\" in description.lower():\n            color = (255, 0, 0)\n        elif \"blue\" in description.lower():\n            color = (0, 0, 255)\n        elif \"green\" in description.lower():\n            color = (0, 255, 0)\n        else:\n            color = (128, 128, 128)  # gray\n\n        # Create image\n        image = Image.new(\"RGB\", (width, height), color)\n\n        # Add some pattern to make it more interesting\n        for i in range(0, width, 50):\n            for j in range(0, height, 50):\n                # Add white squares in a grid pattern\n                for x in range(i, min(i + 20, width)):\n                    for y in range(j, min(j + 20, height)):\n                        image.putpixel((x, y), (255, 255, 255))\n\n        return image\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.generate_image)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        try:\n            image = self.generate_image(**parameters)\n            # Return the PIL Image directly - the framework should handle the conversion\n            return ToolResponse(image=[image]), 0, {}\n        except Exception as e:\n            return ToolResponse(text=str(e)), 0, {}\n\n\ndef test_multimodal_tool_agent(init_config):\n    \"\"\"Test agent loop with multimodal tool that returns images using Qwen VL model.\"\"\"\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        },\n        ignore_reinit_error=True,\n    )\n\n    # Add custom chat template to enable tool calling support (same as recipe/deepeyes)\n    template_path = os.path.join(os.path.dirname(__file__), \"qwen_vl_tool_chat_template.jinja2\")\n    with open(template_path, encoding=\"utf-8\") as f:\n        custom_chat_template = f.read()\n\n    init_config.actor_rollout_ref.model.custom_chat_template = custom_chat_template\n\n    # =========================== 1. Init rollout manager with image tool ===========================\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_multi_modal.ImageGeneratorTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/multimodal_tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1\n    init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1\n    agent_loop_manager = AgentLoopManager(init_config)\n\n    # =========================== 2. 
Generate sequences with multimodal prompts ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"Please generate a red image for me.\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"Can you create a blue picture with size 512x512?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": (\n                    \"You are Qwen VL, created by Alibaba Cloud. You are a helpful \"\n                    \"assistant that can generate and analyze images.\"\n                ),\n            },\n            {\"role\": \"user\", \"content\": \"Generate a green landscape image and describe what you see in it.\"},\n        ],\n    ]\n\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"tool_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # First prompt: \"How are you?\" - should have 2 turns [user, assistant]\n            assert num_turns[i] == 2, f\"Expected 2 turns but got {num_turns[i]} for sample {i}\"\n        else:\n            # Tool-calling prompts should have 4 turns [user, assistant, tool, assistant]\n            assert num_turns[i] == 4, f\"Expected 4 turns but got {num_turns[i]} for sample {i}\"\n\n    # Check that images were properly returned in the tool responses\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    response_length = response_mask.size(1)\n\n    image_found_count = 0\n    for i in range(len(responses)):\n        # response with tool response (including images)\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        # Check that tool responses were properly masked out from training\n        assert \"<tool_response>\" not in response_without_obs, (\n            f\"found <tool_response> in response: {response_without_obs}\"\n        )\n        assert \"</tool_response>\" not in response_without_obs, (\n            f\"found </tool_response> in response: {response_without_obs}\"\n        )\n\n        # Check that images were included in the full response\n        if \"<image>\" in response_with_obs or \"image\" in response_with_obs.lower():\n            image_found_count += 1\n\n        print(\"=========================\")\n        print(\"Response with tool 
observations:\")\n        print(response_with_obs)\n        print(\"---\")\n        print(\"Response without tool observations:\")\n        print(response_without_obs)\n\n    # Verify that tool-calling responses contained image-related content\n    print(f\"Found {image_found_count} responses with image content out of {len(responses)}\")\n    # We should have at least some image content from the tool-calling prompts\n    # Note: First prompt might not use tools, so we don't expect 100% image content\n    expected_tool_calls = sum(1 for i in range(len(num_turns)) if num_turns[i] == 4)\n    assert image_found_count >= 0, (\n        f\"No image-related content found, but expected at least some from {expected_tool_calls} tool calls\"\n    )\n\n    print(\"Multimodal tool test passed!\")\n    ray.shutdown()\n\n\ndef test_multimodal_single_turn_agent(init_config):\n    \"\"\"Test single turn agent loop with multimodal inputs using Qwen VL model.\"\"\"\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        },\n        ignore_reinit_error=True,\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1\n    init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1\n    agent_loop_manager = AgentLoopManager(init_config)\n\n    # =========================== 2. Generate sequences with multimodal prompts ===========================\n    # Create a simple test image\n    test_image = Image.new(\"RGB\", (256, 256), (100, 150, 200))\n    test_image2 = Image.new(\"RGB\", (512, 512), (100, 150, 200))\n\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        ],\n        [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\"},\n                    {\"type\": \"text\", \"text\": \"What color is this image?\"},\n                ],\n            },\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen VL, created by Alibaba Cloud. 
You are a helpful assistant.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\"},\n                    {\"type\": \"text\", \"text\": \"Describe this image in detail.\"},\n                ],\n            },\n        ],\n    ]\n\n    # Prepare multi_modal_data for each prompt\n    multi_modal_data_list = [\n        None,  # First prompt: text only\n        {\"image\": test_image},  # Second prompt: with image\n        {\"image\": test_image2},  # Third prompt: with image\n    ]\n\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"single_turn_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n    )\n\n    # Add multi_modal_data to batch\n    multi_modal_data_array = np.array([data if data else {} for data in multi_modal_data_list], dtype=object)\n    batch.non_tensor_batch[\"multi_modal_data\"] = multi_modal_data_array\n\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns - all should be single turn (2: user + assistant)\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        assert num_turns[i] == 2, f\"Expected 2 turns but got {num_turns[i]} for sample {i}\"\n\n    # Verify responses\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    prompts = result.batch[\"prompts\"]\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n\n    # Check for image pads in prompts\n    image_pad_count = 0\n    for i in range(len(prompts)):\n        prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()\n        prompt_text = tokenizer.decode(prompt_ids)\n\n        # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)\n        sample_idx = i // n\n        has_image_pad = \"<|image_pad|>\" in prompt_text or \"<|vision_start|>\" in prompt_text\n\n        print(\"=========================\")\n        print(f\"Sample {i} (original prompt index: {sample_idx}):\")\n        print(f\"Prompt length: {len(prompt_ids)} tokens\")\n        print(f\"Has image_pad: {has_image_pad}\")\n\n        if sample_idx != 0:  # Samples 1 and 2 should have images\n            if has_image_pad:\n                image_pad_count += 1\n                # Count the number of image_pad tokens\n                num_image_pads = prompt_text.count(\"<|image_pad|>\")\n                print(f\"Number of <|image_pad|> tokens: {num_image_pads}\")\n            else:\n                print(\"WARNING: Expected image_pad but not found!\")\n\n        # Show first 200 chars of prompt\n        print(f\"Prompt text (first 200 chars): {prompt_text[:200]}...\")\n\n    for i in range(len(responses)):\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_text = tokenizer.decode(valid_tokens)\n        print(f\"Sample {i} response: {response_text[:100]}...\")\n\n    # Verify that we found image 
pads in multimodal samples\n    expected_multimodal_samples = 2 * n  # 2 prompts with images, repeated n times\n    print(f\"\\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected\")\n    assert image_pad_count > 0, \"No image_pad tokens found in multimodal samples!\"\n\n    print(\"Single turn multimodal test passed!\")\n    ray.shutdown()\n\n\ndef test_multimodal_partial_single_turn_agent(init_config):\n    \"\"\"Test partial single turn agent loop with multimodal inputs using Qwen VL model.\"\"\"\n\n    # TODO(baiyan):\n    #    see verl/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py for more details.\n    #    if use_correct_processor=True, the test will pass but the async training will hang, so I disable this test\n    #    for now\n\n    return\n\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        },\n        ignore_reinit_error=True,\n    )\n    from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager\n\n    # =========================== 1. Init rollout manager ===========================\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1\n    init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1\n    import asyncio\n\n    loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(loop)\n    agent_loop_manager = loop.run_until_complete(FullyAsyncAgentLoopManager.create(init_config))\n\n    # =========================== 2. Generate sequences with multimodal prompts ===========================\n    # Create a simple test image\n    test_image = Image.new(\"RGB\", (256, 256), (200, 100, 50))\n    test_image2 = Image.new(\"RGB\", (512, 512), (100, 150, 200))\n\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n        ],\n        [\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\"},\n                    {\"type\": \"text\", \"text\": \"What do you see in this image?\"},\n                ],\n            },\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen VL, a helpful multimodal assistant.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image\"},\n                    {\"type\": \"text\", \"text\": \"Analyze the colors in this image.\"},\n                ],\n            },\n        ],\n    ]\n\n    # Prepare multi_modal_data for each prompt\n    multi_modal_data_list = [\n        None,  # First prompt: text only\n        {\"image\": test_image},  # Second prompt: with image\n        {\"image\": test_image2},  # Third prompt: with image\n    ]\n\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"partial_single_turn_agent\"] * len(raw_prompts)),\n            \"data_source\": np.array([\"openai/gsm8k\"] * len(raw_prompts)),\n            \"reward_model\": np.array([{\"style\": \"rule\", \"ground_truth\": \"1.0\"}] * len(raw_prompts)),\n        },\n   
 )\n\n    # Add multi_modal_data to batch\n    multi_modal_data_array = np.array([data if data else {} for data in multi_modal_data_list], dtype=object)\n    batch.non_tensor_batch[\"multi_modal_data\"] = multi_modal_data_array\n\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns - all should be single turn (2: user + assistant)\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        assert num_turns[i] == 2, f\"Expected 2 turns but got {num_turns[i]} for sample {i}\"\n\n    # Verify responses\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    prompts = result.batch[\"prompts\"]\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n\n    # Check for image pads in prompts\n    image_pad_count = 0\n    for i in range(len(prompts)):\n        prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()\n        prompt_text = tokenizer.decode(prompt_ids)\n\n        # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)\n        sample_idx = i // n\n        has_image_pad = \"<|image_pad|>\" in prompt_text or \"<|vision_start|>\" in prompt_text\n\n        print(\"=========================\")\n        print(f\"Sample {i} (original prompt index: {sample_idx}):\")\n        print(f\"Prompt length: {len(prompt_ids)} tokens\")\n        print(f\"Has image_pad: {has_image_pad}\")\n\n        if sample_idx != 0:  # Samples 1 and 2 should have images\n            if has_image_pad:\n                image_pad_count += 1\n                # Count the number of image_pad tokens\n                num_image_pads = prompt_text.count(\"<|image_pad|>\")\n                print(f\"Number of <|image_pad|> tokens: {num_image_pads}\")\n            else:\n                print(\"WARNING: Expected image_pad but not found!\")\n\n        # Show first 200 chars of prompt\n        print(f\"Prompt text (first 200 chars): {prompt_text[:200]}...\")\n\n    for i in range(len(responses)):\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_text = tokenizer.decode(valid_tokens)\n        print(f\"Sample {i} response: {response_text[:100]}...\")\n\n    # Verify that we found image pads in multimodal samples\n    expected_multimodal_samples = 2 * n  # 2 prompts with images, repeated n times\n    print(f\"\\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected\")\n    assert image_pad_count > 0, \"No image_pad tokens found in multimodal samples!\"\n\n    print(\"Partial single turn multimodal test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/experimental/agent_loop/test_standalone_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport os\n\nimport pytest\nimport ray\nfrom omegaconf import DictConfig\nfrom openai import AsyncOpenAI, OpenAI\n\nfrom tests.experimental.agent_loop.agent_utils import init_agent_loop_manager\nfrom verl.workers.rollout.replica import get_rollout_replica_class\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n\n    config.trainer.n_gpus_per_node = 4\n    config.trainer.nnodes = 2\n    config.actor_rollout_ref.actor.use_dynamic_bsz = True\n    config.actor_rollout_ref.model.path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n    config.actor_rollout_ref.rollout.name = os.environ[\"ROLLOUT_NAME\"]\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = False\n\n    return config\n\n\n@pytest.mark.asyncio\n@pytest.mark.parametrize(\"tp_size\", [2, 4])\nasync def test_standalone_rollout(init_config, tp_size):\n    \"\"\"Test standalone rollout single node and multi nodes.\"\"\"\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = tp_size\n    num_replicas = (init_config.trainer.n_gpus_per_node * init_config.trainer.nnodes) // tp_size\n    rollout_config = init_config.actor_rollout_ref.rollout\n    model_config = init_config.actor_rollout_ref.model\n\n    # create standalone rollout server\n    rollout_server_class = get_rollout_replica_class(init_config.actor_rollout_ref.rollout.name)\n    rollout_servers = [\n        rollout_server_class(\n            replica_rank=replica_rank, config=rollout_config, model_config=model_config, gpus_per_node=2\n        )\n        for replica_rank in range(num_replicas)\n    ]\n    await asyncio.gather(*[server.init_standalone() for server in rollout_servers])\n\n    server_handles = [server._server_handle for server in rollout_servers]\n    server_addresses = [server._server_address for server in rollout_servers]\n    assert len(server_handles) == num_replicas\n    assert len(server_addresses) == num_replicas\n\n    os.environ.pop(\"HTTPS_PROXY\", None)\n    os.environ.pop(\"HTTP_PROXY\", None)\n    os.environ.pop(\"NO_PROXY\", None)\n\n    client = AsyncOpenAI(\n        api_key=\"123-abc\",\n        base_url=f\"http://{server_addresses[0]}/v1\",\n    )\n\n    completion = await client.chat.completions.create(\n        model=init_config.actor_rollout_ref.model.path,\n        messages=[{\"role\": \"user\", \"content\": \"What can you do?\"}],\n   
 )\n    print(completion.choices[0].message.content)\n\n    ray.shutdown()\n\n\n@pytest.mark.skip(reason=\"local test only\")\ndef test_hybrid_rollout_with_ep(init_config):\n    \"\"\"Test hybrid rollout with expert parallelism, DP=2, TP=4, EP=8.\"\"\"\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen3-30B-A3B-Instruct-2507\")\n    init_config.actor_rollout_ref.model.path = model_path\n\n    # parallelism config\n    init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2\n    init_config.actor_rollout_ref.rollout.data_parallel_size = 4\n    init_config.actor_rollout_ref.rollout.expert_parallel_size = 8\n\n    # 1. init hybrid worker: FSDP+rollout\n    # - build FSDP model and optimizer\n    # - offload FSDP model and optimizer, build rollout\n    # - sleep rollout and load FSDP model and optimizer\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    # 2. wake up rollout\n    # - wake_up weights\n    # - load_weights from FSDP\n    # - wake_up kv_cache\n    agent_loop_manager.wake_up()\n\n    # 3. test async openai call\n    server_address = agent_loop_manager.server_addresses[0]\n    client = OpenAI(\n        api_key=\"123-abc\",\n        base_url=f\"http://{server_address}/v1\",\n    )\n\n    smapling_params = {\n        \"temperature\": 1.0,\n        \"top_p\": 1.0,\n        \"max_tokens\": 512,\n    }\n\n    response = client.chat.completions.create(\n        model=model_path,\n        messages=[{\"role\": \"user\", \"content\": \"What can you do?\"}],\n        **smapling_params,\n    )\n\n    completion = response.choices[0].message.content\n    print(f\"response: {completion}\")\n\n    print(\"Test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/experimental/reward/reward_fn.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\n\nimport aiohttp\nfrom openai.types.chat import ChatCompletion\nfrom transformers import PreTrainedTokenizer\n\nGRM_PROMPT_TEMPLATE = \"\"\"\nYou are given a problem and a proposed solution.\n\nProblem:\n{problem}\n\nSolution:\n{solution}\n\nPlease evaluate how well the solution addresses the problem. \nGive a score from 1 to 10, where:\n- 1 means the solution is completely irrelevant or incorrect.\n- 5 means the solution is partially correct but incomplete or not well reasoned.\n- 10 means the solution is fully correct, well-reasoned, and directly solves the problem.\n\nOnly output the score as a single number (integer).\n\"\"\".strip()\n\n\nasync def chat_complete(router_address: str, chat_complete_request: dict):\n    url = f\"http://{router_address}/v1/chat/completions\"\n    try:\n        timeout = aiohttp.ClientTimeout(total=None)\n        session = aiohttp.ClientSession(timeout=timeout)\n        async with session.post(url, json=chat_complete_request) as resp:\n            output = await resp.text()\n            output = json.loads(output)\n            return ChatCompletion(**output)\n    except Exception as e:\n        raise e\n    finally:\n        await session.close()\n\n\nasync def compute_score_gsm8k(\n    data_source: str,\n    solution_str: str,\n    ground_truth: str,\n    extra_info: dict,\n    reward_router_address: str,\n    reward_model_tokenizer: PreTrainedTokenizer,\n):\n    \"\"\"Compute the reward score.\"\"\"\n\n    grm_prompt = GRM_PROMPT_TEMPLATE.format(problem=extra_info[\"question\"], solution=solution_str)\n    messages = [{\"role\": \"user\", \"content\": grm_prompt}]\n    sampling_params = {\"temperature\": 0.7, \"top_p\": 0.8, \"max_tokens\": 4096}\n    model_name = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n    chat_complete_request = {\n        \"messages\": messages,\n        \"model\": model_name,\n        **sampling_params,\n    }\n    result = await chat_complete(\n        router_address=reward_router_address,\n        chat_complete_request=chat_complete_request,\n    )\n    grm_response = result.choices[0].message.content\n    try:\n        score = int(grm_response.split(\"\\n\\n\")[-1].strip())\n    except Exception:\n        score = 0\n    return {\"score\": score, \"acc\": score == 10}\n"
  },
  {
    "path": "verl_distillation/tests/experimental/reward/test_agent_loop_reward_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport ray\nfrom hydra import compose, initialize_config_dir\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom transformers import AutoTokenizer\n\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.protocol import DataProto\nfrom verl.trainer.main_ppo import create_rl_sampler\nfrom verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n\ndef test_agent_loop_reward_manager():\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n    with initialize_config_dir(config_dir=os.path.abspath(\"recipe/fapo/config\")):\n        config = compose(\"rm_config\")\n\n    rollout_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B-Instruct\")\n    reward_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-1.5B-Instruct\")\n\n    # actor_rollout_ref config\n    config.data.return_raw_chat = True\n    config.data.max_prompt_length = 1024\n    config.data.max_response_length = 4096\n    config.actor_rollout_ref.model.path = rollout_model_path\n    config.actor_rollout_ref.actor.use_dynamic_bsz = True\n    config.actor_rollout_ref.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2\n    config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9\n    config.actor_rollout_ref.rollout.enforce_eager = True\n    config.actor_rollout_ref.rollout.prompt_length = 1024\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.skip_tokenizer_init = True\n    config.trainer.n_gpus_per_node = 4\n    config.trainer.nnodes = 1\n\n    config.reward_model.reward_manager = \"dapo\"\n    config.reward_model.enable = True\n    config.reward_model.enable_resource_pool = True\n    config.reward_model.n_gpus_per_node = 4\n    config.reward_model.nnodes = 1\n    config.reward_model.model.path = reward_model_path\n    config.reward_model.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.reward_model.rollout.gpu_memory_utilization = 0.9\n    config.reward_model.rollout.tensor_model_parallel_size = 2\n    config.reward_model.rollout.skip_tokenizer_init = False\n    config.reward_model.rollout.prompt_length = 5120\n    config.reward_model.rollout.response_length = 4096\n    config.custom_reward_function.path = \"tests/experimental/reward/reward_fn.py\"\n    config.custom_reward_function.name = \"compute_score_gsm8k\"\n\n    # 1. init reward model manager\n    agent_loop_manager = AgentLoopManager(config)\n\n    # 2. 
init test data\n    local_folder = os.path.expanduser(\"~/data/gsm8k/\")\n    data_files = [os.path.join(local_folder, \"train.parquet\")]\n    tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)\n\n    dataset = RLHFDataset(\n        data_files=data_files,\n        tokenizer=tokenizer,\n        config=config.data,\n        processor=None,\n    )\n\n    batch_size = 64\n    sampler = create_rl_sampler(config.data, dataset)\n    dataloader = StatefulDataLoader(\n        dataset=dataset,\n        batch_size=batch_size,\n        num_workers=config.data.dataloader_num_workers,\n        drop_last=True,\n        collate_fn=collate_fn,\n        sampler=sampler,\n    )\n\n    # 3. generate responses\n    batch_dict = next(iter(dataloader))\n    batch = DataProto.from_single_dict(batch_dict)\n    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)\n\n    rm_scores = gen_batch.batch[\"rm_scores\"]\n    sample_scores = rm_scores.sum(dim=1)\n    print(sample_scores)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/experimental/reward/test_reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport ray\nfrom hydra import compose, initialize_config_dir\n\nfrom verl.experimental.reward import RewardModelManager\nfrom verl.protocol import DataProto\n\nGRM_PROMPT_TEMPLATE = \"\"\"\nYou are given a problem and a proposed solution.\n\nProblem:\n{problem}\n\nSolution:\n{solution}\n\nPlease evaluate how well the solution addresses the problem. \nGive a score from 1 to 10, where:\n- 1 means the solution is completely irrelevant or incorrect.\n- 5 means the solution is partially correct but incomplete or not well reasoned.\n- 10 means the solution is fully correct, well-reasoned, and directly solves the problem.\n\nOnly output the score as a single number (integer).\n\"\"\".strip()\n\n\ndef create_data_samples() -> DataProto:\n    convs = [\n        {\n            \"problem\": \"What is the range of the numeric output of a sigmoid node in a neural network?\",\n            \"solution\": \"Between -1 and 1.\",\n        },\n        {\n            \"problem\": \"What is the range of the numeric output of a sigmoid node in a neural network?\",\n            \"solution\": \"Between 0 and 1.\",\n        },\n        {\n            \"problem\": \"What is the capital of Australia?\",\n            \"solution\": \"Canberra is the capital city of Australia.\",\n        },\n        {\n            \"problem\": \"What is the capital of Australia?\",\n            \"solution\": \"Sydney is the capital city of Australia.\",\n        },\n    ]\n\n    messages = [[{\"role\": \"user\", \"content\": GRM_PROMPT_TEMPLATE.format(**conv)}] for conv in convs]\n    prompts = DataProto.from_dict(\n        non_tensors={\n            \"raw_prompt\": messages,\n        }\n    )\n    return convs, prompts\n\n\ndef test_reward_model_manager():\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n    with initialize_config_dir(config_dir=os.path.abspath(\"recipe/fapo/config\")):\n        config = compose(\"rm_config\")\n\n    model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B-Instruct\")\n\n    config.reward_model.reward_manager = \"dapo\"\n    config.reward_model.enable = True\n    config.reward_model.enable_resource_pool = True\n    config.reward_model.n_gpus_per_node = 8\n    config.reward_model.nnodes = 1\n    config.reward_model.model.path = model_path\n    config.reward_model.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.reward_model.rollout.gpu_memory_utilization = 0.9\n    config.reward_model.rollout.tensor_model_parallel_size = 2\n    config.reward_model.rollout.skip_tokenizer_init = False\n    config.reward_model.rollout.prompt_length = 2048\n    config.reward_model.rollout.response_length = 4096\n\n    # 1. 
init reward model manager\n    reward_model_manager = RewardModelManager(config.reward_model)\n\n    # 2. init test data\n    convs, prompts = create_data_samples()\n\n    # 3. generate responses\n    sampling_params = {\n        \"max_tokens\": 4096,\n        \"temperature\": 0.7,\n        \"top_p\": 0.8,\n        \"top_k\": 20,\n    }\n    results = reward_model_manager.generate_sequences(prompts, sampling_params)\n    responses = [result.choices[0].message.content for result in results]\n\n    for idx, (conv, response) in enumerate(zip(convs, responses, strict=False)):\n        print(f\"Problem {idx}:\\n{conv['problem']}\\n\")\n        print(f\"AI Solution {idx}:\\n{conv['solution']}\\n\")\n        print(f\"GRM Response {idx}:\\n{response}\\n\")\n        print(\"=\" * 50 + \"\\n\")\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/interactions/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/tests/interactions/test_gsm8k_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom verl.interactions.gsm8k_interaction import Gsm8kInteraction\n\n\nclass TestGsm8kInteraction:\n    \"\"\"Test cases for Gsm8kInteraction class.\"\"\"\n\n    def setup_method(self):\n        \"\"\"Set up test environment before each test method.\"\"\"\n        self.config = {\"name\": \"gsm8k\"}\n        self.interaction = Gsm8kInteraction(self.config)\n\n    def test_init(self):\n        \"\"\"Test Gsm8kInteraction initialization.\"\"\"\n        assert self.interaction._instance_dict == {}\n        assert self.interaction.config == self.config\n        assert self.interaction.name == \"gsm8k\"\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_with_instance_id(self):\n        \"\"\"Test start_interaction with provided instance_id.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert result_id == instance_id\n        assert instance_id in self.interaction._instance_dict\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"\"\n        assert self.interaction._instance_dict[instance_id][\"ground_truth\"] == ground_truth\n        assert self.interaction._instance_dict[instance_id][\"reward\"] == 0.0\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_without_instance_id(self):\n        \"\"\"Test start_interaction without provided instance_id (auto-generated).\"\"\"\n        ground_truth = \"42\"\n\n        result_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        assert result_id is not None\n        assert len(result_id) == 36  # UUID4 length\n        assert result_id in self.interaction._instance_dict\n        assert self.interaction._instance_dict[result_id][\"ground_truth\"] == ground_truth\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_without_ground_truth(self):\n        \"\"\"Test start_interaction without ground_truth parameter.\"\"\"\n        instance_id = \"test_instance\"\n\n        result_id = await self.interaction.start_interaction(instance_id=instance_id)\n\n        assert result_id == instance_id\n        assert self.interaction._instance_dict[instance_id][\"ground_truth\"] is None\n\n    @pytest.mark.asyncio\n    async def test_generate_response_correct_answer_with_prefix(self):\n        \"\"\"Test generate_response with correct answer already having #### prefix.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"assistant\", \"content\": \"#### 42\"}]\n\n    
    with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert reward == 1.0\n        assert metadata == {}\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_correct_answer_without_prefix(self):\n        \"\"\"Test generate_response with correct answer missing #### prefix.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"assistant\", \"content\": \"42\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert reward == 1.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_incorrect_answer(self):\n        \"\"\"Test generate_response with incorrect answer.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"assistant\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert response == \"Your response is incorrect! 
You need to reflect on your answer and try again.\"\n        assert reward == 0.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"24\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_multiple_messages(self):\n        \"\"\"Test generate_response with multiple messages (should use last assistant message).\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [\n            {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n            {\"role\": \"assistant\", \"content\": \"### 4\"},\n            {\"role\": \"user\", \"content\": \"What is 40+2?\"},\n            {\"role\": \"assistant\", \"content\": \"#### 42\"},\n        ]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_no_assistant_message(self):\n        \"\"\"Test generate_response with no assistant messages.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"user\", \"content\": \"Hello!\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"\"\n\n    @pytest.mark.asyncio\n    async def test_calculate_score_direct_call(self):\n        \"\"\"Test calculate_score method directly.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        # Set a response\n        self.interaction._instance_dict[instance_id][\"response\"] = \"#### 42\"\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0) as mock_compute:\n            score = await self.interaction.calculate_score(instance_id)\n\n            assert score == 1.0\n            mock_compute.assert_called_once_with(\"#### 42\", \"42\", method=\"strict\", format_score=0.0, score=1.0)\n\n    @pytest.mark.asyncio\n    async def test_calculate_score_with_kwargs(self):\n        \"\"\"Test calculate_score method with additional kwargs.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        # Set a response\n        self.interaction._instance_dict[instance_id][\"response\"] = \"#### 24\"\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0) as mock_compute:\n            score = await self.interaction.calculate_score(instance_id, 
extra_param=\"test\")\n\n            assert score == 0.0\n            mock_compute.assert_called_once_with(\"#### 24\", \"42\", method=\"strict\", format_score=0.0, score=1.0)\n\n    @pytest.mark.asyncio\n    async def test_finalize_interaction(self):\n        \"\"\"Test finalize_interaction method.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert instance_id in self.interaction._instance_dict\n\n        await self.interaction.finalize_interaction(instance_id)\n\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_finalize_interaction_with_kwargs(self):\n        \"\"\"Test finalize_interaction method with additional kwargs.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert instance_id in self.interaction._instance_dict\n\n        await self.interaction.finalize_interaction(instance_id, extra_param=\"test\")\n\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_finalize_nonexistent_interaction(self):\n        \"\"\"Test finalize_interaction with non-existent instance_id.\"\"\"\n        instance_id = \"nonexistent_instance\"\n\n        # This should raise KeyError\n        with pytest.raises(KeyError):\n            await self.interaction.finalize_interaction(instance_id)\n\n    @pytest.mark.asyncio\n    async def test_full_interaction_workflow_correct(self):\n        \"\"\"Test complete interaction workflow with correct answer.\"\"\"\n        ground_truth = \"42\"\n\n        # Start interaction\n        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        # Generate response with correct answer\n        messages = [{\"role\": \"assistant\", \"content\": \"42\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert reward == 1.0\n\n        # Finalize interaction\n        await self.interaction.finalize_interaction(instance_id)\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_full_interaction_workflow_incorrect(self):\n        \"\"\"Test complete interaction workflow with incorrect answer.\"\"\"\n        ground_truth = \"42\"\n\n        # Start interaction\n        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        # Generate response with incorrect answer\n        messages = [{\"role\": \"assistant\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert reward == 0.0\n\n        # Continue with another attempt\n        messages.append({\"role\": \"user\", \"content\": response})\n        messages.append({\"role\": \"assistant\", \"content\": \"42\"})\n\n        
with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert reward == 1.0\n\n        # Finalize interaction\n        await self.interaction.finalize_interaction(instance_id)\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_multiple_concurrent_interactions(self):\n        \"\"\"Test multiple concurrent interaction instances.\"\"\"\n        ground_truth_1 = \"42\"\n        ground_truth_2 = \"24\"\n\n        # Start multiple interactions\n        instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1)\n        instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2)\n\n        assert len(self.interaction._instance_dict) == 2\n        assert instance_id_1 in self.interaction._instance_dict\n        assert instance_id_2 in self.interaction._instance_dict\n\n        # Test responses for both instances\n        messages_1 = [{\"role\": \"assistant\", \"content\": \"42\"}]\n        messages_2 = [{\"role\": \"assistant\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", side_effect=[1.0, 1.0]):\n            should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)\n            should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2)\n\n        assert should_terminate_1 is True\n        assert should_terminate_2 is True\n        assert reward_1 == 1.0\n        assert reward_2 == 1.0\n\n        # Finalize both interactions\n        await self.interaction.finalize_interaction(instance_id_1)\n        await self.interaction.finalize_interaction(instance_id_2)\n\n        assert len(self.interaction._instance_dict) == 0\n\n    @pytest.mark.asyncio\n    async def test_edge_case_empty_messages(self):\n        \"\"\"Test edge case with empty messages list.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = []\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert reward == 0.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"\"\n\n    @pytest.mark.asyncio\n    async def test_edge_case_message_without_content(self):\n        \"\"\"Test edge case with message without content field.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [\n            {\"role\": \"assistant\"}  # Missing content field\n        ]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n  
      assert reward == 0.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] is None\n\n    def test_inheritance_from_base_interaction(self):\n        \"\"\"Test that Gsm8kInteraction properly inherits from BaseInteraction.\"\"\"\n        from verl.interactions.base import BaseInteraction\n\n        assert isinstance(self.interaction, BaseInteraction)\n\n        # Test that all required methods are implemented\n        assert hasattr(self.interaction, \"start_interaction\")\n        assert hasattr(self.interaction, \"generate_response\")\n        assert hasattr(self.interaction, \"calculate_score\")\n        assert hasattr(self.interaction, \"finalize_interaction\")\n\n        # Test that methods are callable\n        assert callable(self.interaction.start_interaction)\n        assert callable(self.interaction.generate_response)\n        assert callable(self.interaction.calculate_score)\n        assert callable(self.interaction.finalize_interaction)\n\n    def test_name_attribute_initialization(self):\n        \"\"\"Test name attribute initialization with different configs.\"\"\"\n        # Test with explicit name in config\n        config_with_name = {\"name\": \"custom_gsm8k\"}\n        interaction_with_name = Gsm8kInteraction(config_with_name)\n        assert interaction_with_name.name == \"custom_gsm8k\"\n\n        # Test with default name when not provided in config\n        config_without_name = {}\n        interaction_without_name = Gsm8kInteraction(config_without_name)\n        assert interaction_without_name.name == \"interaction_agent\"  # Default from BaseInteraction\n\n        # Test that name is accessible as attribute\n        assert hasattr(self.interaction, \"name\")\n        assert self.interaction.name == \"gsm8k\"\n"
  },
  {
    "path": "verl_distillation/tests/interactions/test_interaction_registry.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport tempfile\n\nimport pytest\nfrom omegaconf import OmegaConf\n\nfrom verl.interactions.base import BaseInteraction\nfrom verl.interactions.gsm8k_interaction import Gsm8kInteraction\nfrom verl.interactions.utils.interaction_registry import (\n    get_interaction_class,\n    initialize_interactions_from_config,\n)\n\n\nclass TestInteractionRegistry:\n    def test_get_interaction_class(self):\n        \"\"\"Test getting interaction class by name.\"\"\"\n        # Test getting base interaction class\n        base_cls = get_interaction_class(\"verl.interactions.base.BaseInteraction\")\n        assert base_cls == BaseInteraction\n\n        # Test getting gsm8k interaction class\n        gsm8k_cls = get_interaction_class(\"verl.interactions.gsm8k_interaction.Gsm8kInteraction\")\n        assert gsm8k_cls == Gsm8kInteraction\n\n    def test_initialize_single_interaction_from_config(self):\n        \"\"\"Test initializing single interaction from config.\"\"\"\n        # Create temporary config file\n        config_content = {\n            \"interaction\": [\n                {\n                    \"name\": \"test_gsm8k\",\n                    \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                }\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that interaction was created\n            assert len(interaction_map) == 1\n            assert \"test_gsm8k\" in interaction_map\n            assert isinstance(interaction_map[\"test_gsm8k\"], Gsm8kInteraction)\n            assert interaction_map[\"test_gsm8k\"].name == \"test_gsm8k\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_multiple_interactions_from_config(self):\n        \"\"\"Test initializing multiple interactions from config.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\n                    \"name\": \"gsm8k_solver\",\n                    \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                },\n                {\n                    \"name\": \"base_agent\",\n                    \"class_name\": \"verl.interactions.base.BaseInteraction\",\n                    \"config\": {\"custom_param\": \"test_value\"},\n                },\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            
interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that both interactions were created\n            assert len(interaction_map) == 2\n            assert \"gsm8k_solver\" in interaction_map\n            assert \"base_agent\" in interaction_map\n\n            # Check types\n            assert isinstance(interaction_map[\"gsm8k_solver\"], Gsm8kInteraction)\n            assert isinstance(interaction_map[\"base_agent\"], BaseInteraction)\n\n            # Check names were injected\n            assert interaction_map[\"gsm8k_solver\"].name == \"gsm8k_solver\"\n            assert interaction_map[\"base_agent\"].name == \"base_agent\"\n\n            # Check custom config was passed\n            assert interaction_map[\"base_agent\"].config.get(\"custom_param\") == \"test_value\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_interaction_without_explicit_name(self):\n        \"\"\"Test that interaction name is derived from class name when not specified.\"\"\"\n        config_content = {\n            \"interaction\": [{\"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}}]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that interaction name was derived from class name\n            assert len(interaction_map) == 1\n            assert \"gsm8k\" in interaction_map  # Should be \"gsm8k\" after removing \"interaction\" suffix\n            assert isinstance(interaction_map[\"gsm8k\"], Gsm8kInteraction)\n            assert interaction_map[\"gsm8k\"].name == \"gsm8k\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_empty_config(self):\n        \"\"\"Test initializing from empty config.\"\"\"\n        config_content = {\"interaction\": []}\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n            assert len(interaction_map) == 0\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_invalid_class_name(self):\n        \"\"\"Test handling of invalid class name.\"\"\"\n        config_content = {\n            \"interaction\": [{\"name\": \"invalid\", \"class_name\": \"invalid.module.InvalidClass\", \"config\": {}}]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            with pytest.raises(ModuleNotFoundError):\n                initialize_interactions_from_config(temp_config_path)\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_duplicate_interaction_names(self):\n        \"\"\"Test handling of duplicate interaction names.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\"name\": \"duplicate\", \"class_name\": \"verl.interactions.base.BaseInteraction\", \"config\": {}},\n                {\n                    \"name\": \"duplicate\",\n                    \"class_name\": 
\"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                },\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            with pytest.raises(ValueError, match=\"Duplicate interaction name 'duplicate' found\"):\n                initialize_interactions_from_config(temp_config_path)\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_auto_name_generation_edge_cases(self):\n        \"\"\"Test automatic name generation for various class name patterns.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\"class_name\": \"verl.interactions.base.BaseInteraction\", \"config\": {}},\n                {\"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}},\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that names were generated correctly\n            assert len(interaction_map) == 2\n            assert \"base\" in interaction_map  # BaseInteraction -> base\n            assert \"gsm8k\" in interaction_map  # Gsm8kInteraction -> gsm8k\n        finally:\n            os.unlink(temp_config_path)\n"
  },
  {
    "path": "verl_distillation/tests/kill_github_tests.sh",
    "content": "#!/bin/bash\n\nif [ \"$#\" -ne 1 ]; then\n    echo \"Usage: $0 YOUR_GITHUB_TOKEN\"\n    echo \"Please provide exactly one input argument for your github token.\"\n    exit 1\nfi\n\n# Set your GitHub repository details\nOWNER=\"volcengine\"\nREPO=\"verl\"\nTOKEN=$1\n\n# API URL for workflow runs\nAPI_URL=\"https://api.github.com/repos/$OWNER/$REPO/actions/runs?status=queued\"\n\n# Check required commands\ncommand -v jq >/dev/null 2>&1 || { echo \"jq is required but not installed. Aborting.\"; exit 1; }\n\n# Get queued workflow runs\nresponse=$(curl -s -H \"Authorization: token $TOKEN\" -H \"Accept: application/vnd.github.v3+json\" \"$API_URL\")\n\n# Run this for debugging\n# echo $response\n\n# Extract run IDs\nqueued_run_ids=$(echo \"$response\" | jq -r '.workflow_runs[] | .id')\n\nif [ -z \"$queued_run_ids\" ]; then\n    echo \"No queued workflow runs found.\"\n    exit 0\nfi\n\n# Cancel each queued run\nfor run_id in $queued_run_ids; do\n    echo \"Cancelling run $run_id\"\n    cancel_url=\"https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/cancel\"\n    curl -s -X POST -H \"Authorization: token $TOKEN\" -H \"Accept: application/vnd.github.v3+json\" \"$cancel_url\"\ndone\n\necho \"Cancelled all queued workflow runs.\"\n"
  },
  {
    "path": "verl_distillation/tests/models/test_engine.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nfrom functools import partial\n\nimport numpy as np\nimport pytest\nimport ray\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.utils.model import compute_position_id_with_mask, create_random_mask\nfrom verl.utils.torch_functional import logprobs_from_logits_naive\nfrom verl.workers.config import (\n    ActorConfig,\n    CriticConfig,\n    FSDPEngineConfig,\n    FSDPOptimizerConfig,\n    HFModelConfig,\n    McoreEngineConfig,\n    McoreOptimizerConfig,\n)\nfrom verl.workers.roles import ActorWorker, CriticWorker\nfrom verl.workers.roles.utils.losses import ppo_loss, sft_loss\n\n\n@pytest.mark.parametrize(\"strategy\", [\"megatron\", \"fsdp\", \"fsdp2\"])\ndef test_actor_engine(strategy):\n    ray.init()\n\n    path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B-Instruct\")\n    model_config = HFModelConfig(path=path)\n\n    if strategy == \"megatron\":\n        engine_config = McoreEngineConfig(\n            forward_only=False,\n            use_mbridge=False,\n            tensor_model_parallel_size=2,\n            pipeline_model_parallel_size=2,\n            context_parallel_size=2,\n        )\n        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)\n    elif strategy in [\"fsdp\", \"fsdp2\"]:\n        engine_config = FSDPEngineConfig(\n            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2\n        )\n        optimizer_config = FSDPOptimizerConfig()\n    else:\n        raise NotImplementedError(f\"strategy {strategy} is not supported\")\n\n    config = ActorConfig(\n        model_config=model_config,\n        engine=engine_config,\n        strategy=strategy,\n        ppo_micro_batch_size_per_gpu=256,\n        ppo_mini_batch_size=4,\n        optim=optimizer_config,\n        use_dynamic_bsz=True,\n        rollout_n=1,\n    )\n    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)\n    resource_pool = RayResourcePool(process_on_nodes=[8])\n    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    # init model\n    wg.init_model()\n\n    batch_size = 8\n    seqlen = 32\n\n    response_length = seqlen // 2\n\n    torch.manual_seed(1)\n    np.random.seed(1)\n\n    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6\n    )\n    position_ids = 
compute_position_id_with_mask(attention_mask)\n\n    global_token_num = torch.sum(attention_mask, dim=-1).tolist()\n\n    print(input_ids.float().mean(), attention_mask.float().mean())\n\n    responses = input_ids[:, response_length:]\n    response_mask = attention_mask[:, response_length:]\n\n    assert torch.all(response_mask[:, 0] == 1)\n\n    data = DataProto.from_single_dict(\n        {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n            \"responses\": responses,\n            \"response_mask\": response_mask,\n        },\n        meta_info={\"temperature\": 1.0, \"global_token_num\": global_token_num},\n    )\n\n    sft_loss_ = partial(sft_loss, config=config)\n\n    # eval\n    output = wg.compute_log_prob(data)\n\n    # load hf model and compare results with hf model\n    hf_model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)\n    hf_output = hf_model(input_ids, attention_mask=attention_mask)\n    hf_logprobs = logprobs_from_logits_naive(\n        hf_output.logits[:, -response_length - 1 : -1, :].float(), input_ids[:, -response_length:]\n    )\n    hf_logprobs_mean = torch.mean(hf_logprobs * response_mask)\n    mcore_logprobs_mean = torch.mean(output.batch[\"old_log_probs\"] * response_mask)\n\n    torch.testing.assert_close(hf_logprobs_mean, mcore_logprobs_mean, atol=1e-3, rtol=1e-2)\n\n    data = data.union(output)\n\n    wg.set_loss_fn(sft_loss_)\n\n    # train for one step\n    metrics = wg.update_actor(data)\n    print(metrics)\n\n    # add ppo data\n    data.batch[\"advantages\"] = torch.rand_like(responses, dtype=torch.float32)\n    data.batch[\"ref_log_prob\"] = torch.rand_like(responses, dtype=torch.float32)\n\n    # set ppo loss\n    ppo_loss_ = partial(ppo_loss, config=config)\n    wg.set_loss_fn(ppo_loss_)\n\n    # update again\n    ppo_metrics = wg.update_actor(data)\n    print(ppo_metrics)\n\n    ray.shutdown()\n\n\ndef create_model():\n    from transformers import Qwen3Config\n\n    config = Qwen3Config(num_hidden_layers=2, num_labels=1)\n    model = AutoModelForTokenClassification.from_config(config)\n    assert model.config.num_labels == 1\n    path = os.path.expanduser(\"~/models/test_model\")\n    model.save_pretrained(path)\n    config.save_pretrained(path)\n    return path\n\n\n@pytest.mark.parametrize(\"strategy\", [\"megatron\", \"fsdp\", \"fsdp2\"])\ndef test_critic_engine(strategy):\n    ray.init()\n\n    path = create_model()\n    model_config = HFModelConfig(path=path, load_tokenizer=False)\n\n    if strategy == \"megatron\":\n        engine_config = McoreEngineConfig(\n            forward_only=False,\n            use_mbridge=False,\n            tensor_model_parallel_size=2,\n            pipeline_model_parallel_size=2,\n            context_parallel_size=2,\n        )\n        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)\n    elif strategy in [\"fsdp\", \"fsdp2\"]:\n        engine_config = FSDPEngineConfig(\n            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2\n        )\n        optimizer_config = FSDPOptimizerConfig()\n    else:\n        raise NotImplementedError(f\"strategy {strategy} is not supported\")\n\n    config = CriticConfig(\n        model_config=model_config,\n        engine=engine_config,\n        strategy=strategy,\n        ppo_micro_batch_size_per_gpu=256,\n        ppo_mini_batch_size=4,\n        optim=optimizer_config,\n        use_dynamic_bsz=True,\n        
rollout_n=1,\n    )\n    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(CriticWorker), config=config)\n    resource_pool = RayResourcePool(process_on_nodes=[8])\n    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    # init model\n    wg.init_model()\n\n    batch_size = 8\n    seqlen = 32\n\n    response_length = seqlen // 2\n\n    torch.manual_seed(1)\n    np.random.seed(1)\n\n    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n\n    global_token_num = torch.sum(attention_mask, dim=-1).tolist()\n\n    print(input_ids.float().mean(), attention_mask.float().mean())\n\n    responses = input_ids[:, response_length:]\n    response_mask = attention_mask[:, response_length:]\n\n    assert torch.all(response_mask[:, 0] == 1)\n\n    data = DataProto.from_single_dict(\n        {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n            \"responses\": responses,\n            \"response_mask\": response_mask,\n        },\n        meta_info={\"temperature\": 1.0, \"global_token_num\": global_token_num},\n    )\n\n    # eval\n    output = wg.compute_values(data)\n\n    # load hf model and compare results with hf model\n    with torch.device(\"cuda\"):\n        hf_model = AutoModelForTokenClassification.from_pretrained(\n            path, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        hf_output = hf_model(input_ids.cuda(), attention_mask=attention_mask.cuda())\n        hf_values = hf_output.logits[:, -response_length - 1 : -1, :].float().squeeze(-1).cpu()\n    hf_values_mean = torch.mean(hf_values * response_mask)\n\n    engine_values = torch.mean(output.batch[\"values\"] * response_mask)\n\n    torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)\n\n    data = data.union(output)\n\n    # add ppo data\n    data.batch[\"values\"] = torch.rand_like(responses, dtype=torch.float32)\n    data.batch[\"returns\"] = torch.rand_like(responses, dtype=torch.float32)\n\n    # update again\n    ppo_metrics = wg.update_critic(data)\n    print(ppo_metrics)\n\n    ray.shutdown()\n\n\ndef create_actor_model(tmp_path, config):\n    model = AutoModelForCausalLM.from_config(config)\n    path = os.path.join(tmp_path, \"test_model\")\n    model.save_pretrained(path)\n    config.save_pretrained(path)\n    return path\n\n\ndef _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, model_path: str):\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    ref_model_config = AutoConfig.from_pretrained(model_path)\n    with torch.device(\"meta\"):\n        ref_model = AutoModelForCausalLM.from_config(ref_model_config)\n\n    from verl.workers.engine import BaseEngine, EngineRegistry\n\n    # construct configs\n    model_config = HFModelConfig(path=model_path, load_tokenizer=False)\n\n    if strategy == \"megatron\":\n        engine_config = McoreEngineConfig(\n            forward_only=False,\n            use_mbridge=True,\n            tensor_model_parallel_size=2,\n            
pipeline_model_parallel_size=2,\n            context_parallel_size=1,\n        )\n        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)\n    elif strategy in [\"fsdp\", \"fsdp2\"]:\n        engine_config = FSDPEngineConfig(\n            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2\n        )\n        optimizer_config = FSDPOptimizerConfig()\n    else:\n        raise NotImplementedError(f\"strategy {strategy} is not supported\")\n\n    checkpoint_config = CheckpointConfig()\n\n    # build model engine\n    engine: BaseEngine = EngineRegistry.new(\n        model_type=\"language_model\",\n        backend=engine_config.strategy,\n        model_config=model_config,\n        engine_config=engine_config,\n        optimizer_config=optimizer_config,\n        checkpoint_config=checkpoint_config,\n    )\n\n    engine.initialize()\n\n    # get per tensor parameter\n    per_tensor_params = engine.get_per_tensor_param()\n\n    ref_state_dict = ref_model.state_dict()\n\n    # load ground truth and compare\n    for key, value in per_tensor_params:\n        assert key in ref_state_dict, f\"{key} not in ref_state_dict\"\n        assert value.shape == ref_state_dict[key].shape, (\n            f\"{key} shape not equal, {value.shape} != {ref_state_dict[key].shape}\"\n        )\n        if rank == 0:\n            print(key, value.shape)\n\n    dist.barrier()\n    dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\"world_size\", [8])\n@pytest.mark.parametrize(\"config\", [Qwen3Config(num_hidden_layers=2), Qwen3MoeConfig(num_hidden_layers=2)])\n@pytest.mark.parametrize(\"strategy\", [\"megatron\", \"fsdp\", \"fsdp2\"])\ndef test_per_tensor_generator(world_size, tmp_path, config, strategy):\n    rendezvous_file = str(tmp_path / \"rdzv_mask\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n    # create a model\n    model_path = create_actor_model(tmp_path, config)\n    # spawn workers\n    mp.spawn(\n        fn=_worker,\n        args=(world_size, rendezvous_file, strategy, model_path),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "verl_distillation/tests/models/test_transformer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nfrom transformers import (\n    ApertusConfig,\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    GemmaConfig,\n    LlamaConfig,\n    MistralConfig,\n    Qwen2Config,\n)\n\nfrom verl.utils.model import compute_position_id_with_mask, create_random_mask\nfrom verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean\n\n# TODO(sgm): add more models for test\n# we only need one scale for each model\ntest_configs = [\n    LlamaConfig(num_hidden_layers=1),\n    MistralConfig(num_hidden_layers=1),\n    GemmaConfig(num_hidden_layers=1),\n    Qwen2Config(num_hidden_layers=1),\n    ApertusConfig(num_hidden_layers=1),\n]\n\n\ndef test_hf_casual_models():\n    batch_size = 4\n    seqlen = 128\n    response_length = 127\n\n    for config in test_configs:\n        # config = AutoConfig.from_pretrained(test_case)\n        with torch.device(\"cuda\"):\n            model = AutoModelForCausalLM.from_config(\n                config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n            )\n            model = model.to(device=\"cuda\")\n        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n        attention_mask = create_random_mask(\n            input_ids=input_ids,\n            max_ratio_of_left_padding=0.1,\n            max_ratio_of_valid_token=0.8,\n            min_ratio_of_valid_token=0.5,\n        )\n        position_ids = compute_position_id_with_mask(\n            attention_mask\n        )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_rmpad = model(\n            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n        ).logits  # (1, total_nnz, vocab_size)\n\n        origin_logits = model(\n            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n        ).logits\n        origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)\n\n        logits_rmpad = logits_rmpad.squeeze(0)\n        log_probs = log_probs_from_logits_all_rmpad(\n            input_ids_rmpad=input_ids_rmpad,\n            logits_rmpad=logits_rmpad,\n            indices=indices,\n            batch_size=batch_size,\n            seqlen=seqlen,\n            response_length=response_length,\n        )  # (batch, seqlen)\n        origin_log_probs = log_probs_from_logits_all_rmpad(\n            input_ids_rmpad=input_ids_rmpad,\n            logits_rmpad=origin_logits_rmpad,\n            indices=origin_logits_indices,\n            batch_size=batch_size,\n            seqlen=seqlen,\n            response_length=response_length,\n        )  # (batch, seqlen)\n\n        torch.testing.assert_close(\n            masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),\n            masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),\n            atol=1e-2,\n            rtol=1e-5,\n        )\n    print(\"Check pass\")\n\n\ndef test_hf_value_models():\n    batch_size = 4\n    seqlen = 128\n\n    for config in test_configs:\n        # config = AutoConfig.from_pretrained(test_case)\n        config.num_labels = 1\n        config.classifier_dropout = 0\n        config.hidden_dropout = 0\n        with torch.device(\"cuda\"):\n            model = AutoModelForTokenClassification.from_config(\n                config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n            )\n            model = model.to(device=\"cuda\")\n        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n        attention_mask = create_random_mask(\n            input_ids=input_ids,\n            max_ratio_of_left_padding=0.1,\n            max_ratio_of_valid_token=0.8,\n            min_ratio_of_valid_token=0.5,\n        )\n        position_ids = compute_position_id_with_mask(\n            attention_mask\n        )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        origin_logits = model(\n            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n        ).logits\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        rmpad_logits = model(\n            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n        ).logits  # (1, total_nnz, 1)\n        rmpad_logits = rmpad_logits.squeeze(0)\n        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)\n\n        torch.testing.assert_close(\n            masked_mean(pad_logits, attention_mask[:, :, None]),\n            masked_mean(origin_logits, attention_mask[:, :, None]),\n            atol=1e-2,\n            rtol=1e-5,\n        )\n    print(\"Value model check pass\")\n\n\ndef test_attn_implementation_override():\n    \"\"\"Test that attn_implementation override config is properly respected.\"\"\"\n    # Test case 1: Test the actual extraction logic (no network required)\n    test_cases = [\n        ({}, \"flash_attention_2\"),  # Default case\n        ({\"attn_implementation\": \"eager\"}, \"eager\"),  # Override case\n        ({\"attn_implementation\": \"sdpa\"}, \"sdpa\"),  # Another override\n        ({\"other_config\": \"value\"}, \"flash_attention_2\"),  # No attn_implementation key\n    ]\n\n    for override_config, expected in test_cases:\n        actual = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert actual == expected, f\"Expected {expected}, got {actual} for config {override_config}\"\n\n    # Test case 2: Test with local config creation (simulate FSDP worker behavior)\n    # Test default behavior\n    override_config_default = {}\n    attn_implementation_default = override_config_default.get(\"attn_implementation\", \"flash_attention_2\")\n    assert attn_implementation_default == \"flash_attention_2\"\n\n    # Test override behavior\n    override_config_eager = {\"attn_implementation\": \"eager\"}\n    attn_implementation_eager = override_config_eager.get(\"attn_implementation\", \"flash_attention_2\")\n    assert attn_implementation_eager == \"eager\"\n\n    # Test that we can create a config with specific attn_implementation\n    config_with_eager = LlamaConfig(num_hidden_layers=1, _attn_implementation=\"eager\")\n    assert config_with_eager._attn_implementation == \"eager\"\n\n    config_with_flash = LlamaConfig(num_hidden_layers=1, _attn_implementation=\"flash_attention_2\")\n    assert config_with_flash._attn_implementation == \"flash_attention_2\"\n\n    print(\"✓ All attn_implementation override config tests passed\")\n\n\ndef test_fsdp_worker_attn_implementation_integration():\n    \"\"\"Test integration of attn_implementation with FSDP worker logic.\"\"\"\n\n    # Mock the FSDP worker configuration scenario\n    mock_override_config = {\"attn_implementation\": \"eager\"}\n\n    # Test the exact logic used in FSDP workers\n    attn_implementation = mock_override_config.get(\"attn_implementation\", \"flash_attention_2\")\n    assert attn_implementation == \"eager\"\n\n    # Test with empty config (should default)\n    mock_override_config_empty = {}\n    attn_implementation_default = mock_override_config_empty.get(\"attn_implementation\", \"flash_attention_2\")\n    assert attn_implementation_default == \"flash_attention_2\"\n\n    # Test that the parameter would be passed correctly to both AutoConfig and Model\n    expected_calls = [\n        
(\"AutoConfig.from_pretrained\", {\"attn_implementation\": attn_implementation}),\n        (\"AutoModel.from_pretrained\", {\"attn_implementation\": attn_implementation}),\n    ]\n\n    # Verify the parameter extraction works as expected\n    for call_name, expected_params in expected_calls:\n        assert expected_params[\"attn_implementation\"] == \"eager\"\n\n    print(\"✓ FSDP worker integration test passed\")\n\n\nif __name__ == \"__main__\":\n    test_hf_casual_models()\n    test_hf_value_models()\n    test_attn_implementation_override()\n    test_fsdp_worker_attn_implementation_integration()\n"
  },
  {
    "path": "verl_distillation/tests/models/test_transformers_ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport contextlib\nimport copy\nfrom dataclasses import dataclass\n\nimport pytest\nimport torch\nimport torch.distributed\nimport transformers\nfrom flash_attn.bert_padding import index_first_axis, rearrange, unpad_input\nfrom packaging import version\nfrom torch.distributed import init_device_mesh\nfrom transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config\n\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.protocol import DataProto\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.model import compute_position_id_with_mask, create_random_mask\nfrom verl.utils.ulysses import (\n    gather_outputs_and_unpad,\n    get_ulysses_sequence_parallel_world_size,\n    set_ulysses_sequence_parallel_group,\n    ulysses_pad_and_slice_inputs,\n)\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\n# TODO(sgm): add more models for test\n# we only need one scale for each model\n\n\n@dataclass\nclass SequenceParallelConfig:\n    config: PretrainedConfig\n    sp_size: int\n    is_valid: bool\n\n\ndef test_configs():\n    configs = [\n        SequenceParallelConfig(\n            LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),\n            sp_size=4,\n            is_valid=True,\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),\n            sp_size=8,\n            is_valid=False,\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True\n        ),\n    ]\n\n    if version.parse(transformers.__version__) >= version.parse(\"4.56.0\"):\n        from transformers import ApertusConfig\n\n        configs.append(\n            SequenceParallelConfig(\n                ApertusConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32, hidden_size=4096),\n                sp_size=8,\n                is_valid=True,\n            )\n        )\n\n    return configs\n\n\ndef sync_model_parameters_global(layer):\n    # synchronize weights\n    for p in layer.parameters():\n        torch.distributed.broadcast(tensor=p.data, src=0)\n\n\n@pytest.mark.parametrize(\"test_config\", test_configs())\ndef test_hf_casual_fwd_bwd(test_config):\n    if not torch.distributed.is_initialized():\n        initialize_global_process_group()\n\n    context = contextlib.nullcontext() if 
test_config.is_valid else pytest.raises(AssertionError)\n    with context:\n        world_size = torch.distributed.get_world_size()\n        _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size)\n\n    # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`\n    # torch.distributed.destroy_process_group()\n\n\ndef _hf_casual_fwd(config, sp_size, dp_size):\n    assert torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, sp_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)\n\n    batch_size = 1\n    seqlen = 128\n    # response_length = 127\n\n    # patch before load\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        apply_monkey_patch(model, sp_size)\n        model = model.to(device=\"cuda\")\n        sync_model_parameters_global(model)\n\n    # different rank will generate different input_ids following fsdp\n    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8\n    )\n    position_ids = compute_position_id_with_mask(\n        attention_mask\n    )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n    model_inputs = {\n        \"input_ids\": input_ids.cuda(),\n        \"attention_mask\": attention_mask.cuda(),\n        \"position_ids\": position_ids.int().cuda(),\n    }\n\n    model_inputs = DataProto.from_dict(model_inputs)\n\n    # 1. perform ulysses forward\n    with sharding_manager:\n        model_inputs = sharding_manager.preprocess_data(model_inputs)\n        input_ids = model_inputs.batch[\"input_ids\"]\n        attention_mask = model_inputs.batch[\"attention_mask\"]\n        position_ids = model_inputs.batch[\"position_ids\"]\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # slice input tensor for ulysses\n        # input_ids are padded and sliced\n        # postition_ids are only padded but not sliced\n        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n        )\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_split_in_seq = model(\n            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False\n        ).logits  # (1, total_nnz/n, vocab_size)\n\n        # all_gather output\n        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)\n\n    # 2. 
perform normal forward\n    set_ulysses_sequence_parallel_group(None)\n    logits_rmpad_local = model(\n        input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n    ).logits  # (1, total_nnz, vocab_size)\n\n    mean_local = logits_rmpad_local.mean()\n    mean_full = logits_full.mean()\n    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)\n\n\ndef _hf_casual_fwd_bwd(config, sp_size, dp_size):\n    assert torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, sp_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)\n\n    batch_size = 1\n    seqlen = 128\n    # response_length = 127\n\n    # patch before load\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        apply_monkey_patch(model, sp_size)\n        model = model.to(device=\"cuda\")\n        sync_model_parameters_global(model)\n\n    # different rank will generate different input_ids following fsdp\n    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8\n    )\n    position_ids = compute_position_id_with_mask(\n        attention_mask\n    )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n    model_inputs = {\n        \"input_ids\": input_ids.cuda(),\n        \"attention_mask\": attention_mask.cuda(),\n        \"position_ids\": position_ids.int().cuda(),\n    }\n\n    model_inputs = DataProto.from_dict(model_inputs)\n\n    # 1. perform ulysses forward\n    with sharding_manager:\n        model_inputs = sharding_manager.preprocess_data(model_inputs)\n        input_ids = model_inputs.batch[\"input_ids\"]\n        attention_mask = model_inputs.batch[\"attention_mask\"]\n        position_ids = model_inputs.batch[\"position_ids\"]\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # slice input tensor for ulysses\n        # input_ids are padded and sliced\n        # postition_ids are only padded but not sliced\n        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n        )\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_split_in_seq = model(\n            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False\n        ).logits  # (1, total_nnz/n, vocab_size)\n\n        # all_gather output\n        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)\n\n    # 2. 
perform normal forward\n    set_ulysses_sequence_parallel_group(None)\n    input_ids_full = copy.deepcopy(input_ids_rmpad)\n    position_ids_full = copy.deepcopy(position_ids_rmpad)\n    model_no_sp = copy.deepcopy(model)\n    logits_rmpad_local = model_no_sp(\n        input_ids_full, position_ids=position_ids_full, use_cache=False\n    ).logits  # (1, total_nnz, vocab_size)\n\n    mean_local = logits_rmpad_local.mean()\n    mean_full = logits_full.mean()\n\n    mean_full.backward()\n    mean_local.backward()\n\n    # 3. check the gradients\n    grad = model.model.layers[0].self_attn.q_proj.weight.grad\n    grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad\n    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=3e-5)\n    # The check should be less strict because the gradient is not an averaged value.\n    torch.testing.assert_close(grad, grad_full, rtol=1e-2, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-svv\"])\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/base/test_decorator.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nimport verl.single_controller.base.decorator as decorator_module\nfrom verl.single_controller.base.decorator import (\n    DISPATCH_MODE_FN_REGISTRY,\n    Dispatch,\n    _check_dispatch_mode,\n    get_predefined_dispatch_fn,\n    register_dispatch_mode,\n    update_dispatch_mode,\n)\n\n\n@pytest.fixture\ndef reset_dispatch_registry():\n    # Store original state\n    original_registry = DISPATCH_MODE_FN_REGISTRY.copy()\n    yield\n    # Reset registry after test\n    decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()\n    decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry)\n\n\ndef test_register_new_dispatch_mode(reset_dispatch_registry):\n    # Test registration\n    def dummy_dispatch(worker_group, *args, **kwargs):\n        return args, kwargs\n\n    def dummy_collect(worker_group, output):\n        return output\n\n    register_dispatch_mode(\"TEST_MODE\", dummy_dispatch, dummy_collect)\n\n    # Verify enum extension\n    _check_dispatch_mode(Dispatch.TEST_MODE)\n\n    # Verify registry update\n    assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {\n        \"dispatch_fn\": dummy_dispatch,\n        \"collect_fn\": dummy_collect,\n    }\n    # Clean up\n    Dispatch.remove(\"TEST_MODE\")\n\n\ndef test_update_existing_dispatch_mode(reset_dispatch_registry):\n    # Store original implementation\n    original_mode = Dispatch.ONE_TO_ALL\n\n    # New implementations\n    def new_dispatch(worker_group, *args, **kwargs):\n        return args, kwargs\n\n    def new_collect(worker_group, output):\n        return output\n\n    # Test update=\n    update_dispatch_mode(original_mode, new_dispatch, new_collect)\n\n    # Verify update\n    assert get_predefined_dispatch_fn(original_mode)[\"dispatch_fn\"] == new_dispatch\n    assert get_predefined_dispatch_fn(original_mode)[\"collect_fn\"] == new_collect\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/check_worker_alive/main.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nimport time\n\nimport ray\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestActor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def foo(self, wait_time):\n        time.sleep(wait_time)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    wait_time = int(os.getenv(\"WAIT_TIME\", \"10\"))\n\n    ray.init()\n\n    # test single-node-no-partition\n    print(\"test single-node-no-partition\")\n    resource_pool = RayResourcePool([2], use_gpu=False)\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    print(\"create worker group\")\n    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"test\")\n\n    wg.start_worker_aliveness_check(1)\n    time.sleep(1)\n\n    print(time.time(), \"start foo\")\n\n    _ = wg.foo(wait_time)\n    print(\"foo started\")\n\n    print(\n        time.time(),\n        f\"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time\",\n    )\n    time.sleep(wait_time * 6)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/detached_worker/README.md",
    "content": "# Detached Worker\n## How to run (Only on a single node)\n- Start a local ray cluster: \n```bash\nray start --head --port=6379\n```\n- Run the server\n```bash\npython3 server.py\n```\n- On another terminal, Run the client\n```bash\npython3 client.py\n```\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/detached_worker/client.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn client, we can get the server handler and send RPC request\n\"\"\"\n\nimport ray\nimport torch\nfrom server import Trainer\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\n\n\ndef compute_position_id_with_mask(mask):\n    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\", namespace=\"verl\")\n    # get the worker group using names\n    worker_names = [\"trainerTrainer_0:0\", \"trainerTrainer_0:1\"]\n    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)\n    worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)\n\n    batch_size = 16\n    sequence_length = 1024\n\n    # give Trainer some data to train\n    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device=\"cuda\")\n    attention_mask = torch.ones_like(input_ids)\n    position_ids = compute_position_id_with_mask(attention_mask)\n\n    data = DataProto(\n        batch=TensorDict(\n            {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids},\n            batch_size=batch_size,\n        ),\n        meta_info={},\n    )\n\n    output = worker_group.train_model(data)\n\n    print(output)\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/detached_worker/run.sh",
    "content": "#!/bin/bash\nray start --head --port=6379\npython3 server.py\npython3 client.py\nray stop --force"
  },
  {
    "path": "verl_distillation/tests/single_controller/detached_worker/server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nServer starts a Trainer. Client sends data to the server to train.\n\"\"\"\n\nimport os\n\nos.environ[\"MEGATRON_USE_CUDA_TIMER\"] = \"0\"\nos.environ[\"MEGATRON_START_PROCESS_TIMER\"] = \"False\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport ray\nimport torch\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core import tensor_parallel\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom omegaconf import OmegaConf\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl import DataProto\nfrom verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config\nfrom verl.utils.megatron_utils import get_model, mcore_model_parallel_config\n\n\n@ray.remote\nclass Trainer(Worker):\n    def __init__(self):\n        super().__init__()\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(backend=\"nccl\")\n            torch.cuda.set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=2,\n                pipeline_model_parallel_size=1,\n                virtual_pipeline_model_parallel_size=None,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=1,\n                expert_model_parallel_size=1,\n                nccl_communicator_config_path=None,\n            )\n            tensor_parallel.model_parallel_cuda_manual_seed(10)\n\n            is_collect = (\n                mpu.get_tensor_model_parallel_rank() == 0\n                and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1\n                and mpu.get_context_parallel_rank() == 0\n            )\n            self._register_dispatch_collect_info(\n                mesh_name=\"train\", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect\n            )\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        actor_model_config = LlamaConfig(\n            vocab_size=256,\n            hidden_size=2048,\n            intermediate_size=5504,\n            num_hidden_layers=24,\n            num_attention_heads=16,\n            num_key_value_heads=16,\n        )\n\n        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)\n        self.megatron_config = megatron_config\n\n        def megatron_actor_model_provider(pre_process, post_process):\n            
# vpp is not supported yet because it will hang for some reason. Need debugging\n            # this_megatron_config = copy.deepcopy(megatron_config)\n            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank\n            parallel_model = ParallelLlamaForCausalLMRmPadPP(\n                config=actor_model_config,\n                megatron_config=megatron_config,\n                pre_process=pre_process,\n                post_process=post_process,\n            )\n            parallel_model.cuda()\n            return parallel_model\n\n        actor_module = get_model(\n            model_provider_func=megatron_actor_model_provider,\n            model_type=ModelType.encoder_or_decoder,\n            wrap_with_ddp=True,\n        )\n        actor_module = nn.ModuleList(actor_module)\n\n        optim_config = OmegaConf.create({\"lr\": 1e-6, \"clip_grad\": 1.0})\n\n        optim_config = init_megatron_optim_config(optim_config)\n        self.optimizer_config = optim_config\n        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)\n\n        self.model = actor_module[0]\n        self.optimizer = actor_optimizer\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"train\"))\n    def train_model(self, data: DataProto) -> DataProto:\n        input_ids = data.batch[\"input_ids\"]\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n\n        self.optimizer.zero_grad()\n        self.model.zero_grad_buffer(\n            zero_buffer=(not self.optimizer_config.use_distributed_optimizer)\n        )  # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n        # update for 1 iteration\n        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits\n        output.mean().backward()\n\n        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(\n            self.megatron_config, self.megatron_config.timers\n        )\n\n        return DataProto(batch=TensorDict({\"loss\": output.detach()}, batch_size=output.shape[0]))\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\", namespace=\"verl\")\n\n    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)\n    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool,\n        ray_cls_with_init=cls_with_init_args,\n        name_prefix=\"trainer\",\n        detached=True,\n    )\n\n    worker_group.init_model()\n\n    worker_names = worker_group.worker_names\n    print(worker_names)\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_auto_padding_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport ray\nimport torch\n\nfrom verl import DataProto\nfrom verl.protocol import DataProtoConfig\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n# or set env var VERL_AUTO_PADDING = \"1\" / \"true\"\nDataProtoConfig.auto_padding = True\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\ndef test_auto_padding():\n    ray.init(num_cpus=100)\n\n    chunk_size = 4\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n\n    # test locally first\n    for test_size in range(4, 20):\n        local_data = DataProto.from_dict({\"a\": torch.zeros(test_size)}, {\"na\": np.zeros(test_size, dtype=object)})\n        # print(f\"before padding, local_data = {local_data}\")\n        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0\n        local_data.padding(padding_size)\n        # print(f\"after padding, local_data = {local_data}\")\n        assert len(local_data) == len(local_data) + len(local_data) % chunk_size, (\n            f\"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}\"\n        )\n        chunked = local_data.chunk(chunk_size)\n        assert len(chunked) == chunk_size, f\"during test_size = {test_size}, expecting {chunk_size}, got {chunked}\"\n        for dp in chunked:\n            assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), (\n                f\"test size = {test_size}, expecting dp to be length of \"\n                f\"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}\"\n            )\n\n    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO\n    data = DataProto.from_dict({\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(1)}, {\"na\": np.array([str(i) for i in range(1)], dtype=object)})\n    output = 
actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(1)}, {\"na\": np.array([str(i) for i in range(1)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(8)}, {\"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(8)}, {\"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in kwargs split and padding.\"\n\n    # test data proto specific config\n    DataProtoConfig.auto_padding = False\n\n    data = DataProto.from_dict(\n        {\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data)\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict(\n        {\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data=data)\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_single_dict(\n        {\"a\": torch.zeros(1), \"na\": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in args split and padding.\"\n\n    data = DataProto.from_single_dict(\n        {\"a\": torch.zeros(1), \"na\": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_single_dict({\"a\": torch.zeros(8), \"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in args split and padding.\"\n\n    data = DataProto.from_single_dict({\"a\": torch.zeros(8), \"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in kwargs split and padding.\"\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_auto_padding()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_colocated_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.config = config\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    async def sub(self, data: DataProto):\n        data.batch[\"a\"] -= self.config[\"b\"]\n        return data\n\n\ndef test_colocated_workers():\n    ray.init()\n\n    import torch\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)})\n    # create separate workers on the same resource pool\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    critic_cls = RayClassWithInitArgs(cls=Critic, config={\"b\": 10})\n    resource_pool = RayResourcePool(process_on_nodes=[2])\n\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)\n\n    expected_actor_output = actor_wg.add(data)\n    expected_critic_output = critic_wg.sub(data)\n\n    # create colocated workers\n    cls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\n    ray_cls_with_init = create_colocated_worker_cls(cls_dict)\n    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())\n\n    colocated_actor_wg = spawn_wg[\"actor\"]\n    colocated_critic_wg = spawn_wg[\"critic\"]\n\n    actor_output = colocated_actor_wg.add(data)\n    critic_output = colocated_critic_wg.sub(data)\n\n    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)\n    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_colocated_workers_fused.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls_fused,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.config = config\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def sub(self, data: DataProto):\n        data.batch[\"a\"] -= self.config[\"b\"]\n        return data\n\n\ndef test_colocated_workers_fused():\n    ray.init()\n\n    import torch\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)})\n    # create separate workers on the same resource pool\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    critic_cls = RayClassWithInitArgs(cls=Critic, config={\"b\": 10})\n    resource_pool = RayResourcePool(process_on_nodes=[2])\n\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)\n\n    expected_actor_output = actor_wg.add(data)\n    expected_critic_output = critic_wg.sub(data)\n\n    # create colocated workers\n    cls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\n    ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict)\n    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())\n\n    colocated_actor_wg = spawn_wg[\"actor\"]\n    colocated_critic_wg = spawn_wg[\"critic\"]\n\n    actor_output = colocated_actor_wg.add(data)\n    critic_output = colocated_critic_wg.sub(data)\n\n    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)\n    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_data_transfer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn this test, we instantiate a data parallel worker with 8 GPUs\n\"\"\"\n\nimport ray\nimport tensordict\nimport torch\nfrom codetiming import Timer\nfrom torch import distributed as dist\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.utils.ray_utils import parallel_put\n\n\n@ray.remote\nclass DummyWorker(Worker):\n    def __init__(self):\n        super().__init__()\n        dist.init_process_group()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)\n    def do_nothing(self, data):\n        for key in data.batch.keys():\n            data.batch[key] += 1\n        if tensordict.__version__ >= \"0.5.0\":\n            data.batch = data.batch.consolidate()\n        return data\n\n\ndef test_data_transfer():\n    ray.init()\n    # construct resource pool\n    resource_pool = RayResourcePool([8])\n    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)\n    # construct worker group\n    wg = RayWorkerGroup(resource_pool, cls_with_init)\n\n    # this is real dataset size\n    batch_size = 4096\n    seqlen = 32768\n\n    data_dict = {}\n\n    for i in range(2):\n        data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))\n\n    data = DataProto.from_dict(tensors=data_dict)\n\n    print(data)\n\n    # we manually split data here and send to each worker\n    data_list = data.chunk(wg.world_size)\n\n    for i in range(wg.world_size):\n        # consolidate is necessary\n        if tensordict.__version__ >= \"0.5.0\":\n            data_list[i].batch = data_list[i].batch.consolidate()\n\n    with Timer(name=\"ray.pickle\", initial_text=True):\n        for i in range(wg.world_size):\n            ray.cloudpickle.pickle.dumps(data_list[i])\n\n    with Timer(name=\"raw.pickle\", initial_text=True):\n        import pickle\n\n        for i in range(wg.world_size):\n            pickle.dumps(data_list[i])\n\n    # we put in advance\n    with Timer(name=\"put\", initial_text=True):\n        # takes around 40 seconds\n        data_list_ref = parallel_put(data_list)\n        # for i in range(wg.world_size):\n        #     data_list[i] = ray.put(data_list[i])\n\n    with Timer(name=\"launch\", initial_text=True):\n        output_ref = wg.do_nothing(data_list_ref)\n\n    with Timer(name=\"get\", initial_text=True):\n        # takes around 40 seconds\n        output_lst = ray.get(output_ref)\n\n    for input_data, output_data in zip(data_list, output_lst, strict=True):\n        for key in input_data.batch.keys():\n            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (\n                input_data.batch[key],\n                output_data.batch[key],\n                key,\n            )\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_decorator_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport time\n\nimport pytest\nimport ray\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.protocol import DataProto, DataProtoFuture\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n# Pytest fixture for Ray setup/teardown\n@pytest.fixture\ndef ray_init_shutdown():\n    ray.init(num_cpus=100)\n    yield\n    ray.shutdown()\n\n\n# Define a simple worker for testing\n@ray.remote\nclass DecoratorTestWorker(Worker):\n    def __init__(self, initial_value=0):\n        super().__init__()\n        self.value = initial_value\n        # Simulate some setup if needed\n        time.sleep(0.1)  # Ensure worker init completes\n\n    # Test method for synchronous DP compute (default behavior)\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def dp_compute(self, data: DataProto) -> DataProto:\n        time.sleep(0.1)  # Simulate work\n        rank_value = torch.tensor(self.rank, device=data.batch[\"input\"].device, dtype=data.batch[\"input\"].dtype)\n        data.batch[\"output\"] = data.batch[\"input\"] + self.value + rank_value\n        return data\n\n    # Test async def method with DP compute (default behavior)\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\n    async def async_dp_compute(self, data: DataProto) -> DataProto:\n        # Simulate async work\n        await asyncio.sleep(0.1)  # Simulate async work\n        rank_value = torch.tensor(self.rank, device=data.batch[\"input\"].device, dtype=data.batch[\"input\"].dtype)\n        data.batch[\"output_async\"] = data.batch[\"input\"] * 2 + self.value + rank_value\n        return data\n\n\n# Test function for synchronous DP compute\ndef test_decorator_dp_compute(ray_init_shutdown):\n    \"\"\"\n    Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO.\n    Verifies the result correctness.\n    \"\"\"\n    num_workers = 2\n    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)  # Use CPU for simplicity\n    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)\n    worker_group = RayWorkerGroup(\n        resource_pool, cls_with_args, name_prefix=f\"decorator_test_sync_dp_{int(time.time())}\"\n    )\n\n    # Prepare input data (size 4, for 2 workers)\n    input_tensor = torch.arange(4, dtype=torch.float32)\n    data = DataProto(batch=TensorDict({\"input\": input_tensor}, batch_size=[4]))\n\n    # Call the decorated method\n    output = worker_group.dp_compute(data)\n\n    # Assert the result correctness\n    assert isinstance(output, DataProto), \"Expected DataProto result\"\n    assert \"output\" in output.batch.keys()\n    assert len(output) == len(data), \"Output length should match input length\"\n\n  
  # Expected output calculation for DP_COMPUTE_PROTO with 2 workers\n    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]\n    # Worker 0 adds initial_value(10) + rank(0) = 10\n    # Worker 1 adds initial_value(10) + rank(1) = 11\n    expected_output_part1 = torch.tensor([0, 1], dtype=torch.float32) + 10 + 0\n    expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1\n    expected_output = torch.cat([expected_output_part1, expected_output_part2])\n\n    torch.testing.assert_close(output.batch[\"output\"], expected_output, msg=\"Sync DP compute output data mismatch\")\n\n\n# Test function for async def method with DP compute\ndef test_decorator_async_function(ray_init_shutdown):\n    \"\"\"\n    Tests the decorator with an `async def` method using DP_COMPUTE_PROTO.\n    Verifies that the call returns a future and the result is correct after .get().\n    \"\"\"\n    num_workers = 2\n    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)\n    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)\n    worker_group = RayWorkerGroup(\n        resource_pool, cls_with_args, name_prefix=f\"decorator_test_async_dp_{int(time.time())}\"\n    )\n\n    # Prepare input data (size 4, for 2 workers)\n    input_tensor = torch.arange(4, dtype=torch.float32)\n    data = DataProto(batch=TensorDict({\"input\": input_tensor}, batch_size=[4]))\n\n    # Call the async decorated method - this should return a future\n    future_output: DataProtoFuture = worker_group.async_dp_compute(data)\n\n    # Assert that the call returned a future\n    assert isinstance(future_output, DataProtoFuture), \"Expected DataProtoFuture for async def call\"\n\n    # Get the result (this should block)\n    result_data = future_output.get()\n\n    # Assert the result correctness\n    assert isinstance(result_data, DataProto)\n    assert \"output_async\" in result_data.batch.keys()\n    assert len(result_data) == len(data), \"Output length should match input length\"\n\n    # Expected output calculation for DP_COMPUTE_PROTO with 2 workers\n    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]\n    # Worker 0 calculates: input * 2 + initial_value(5) + rank(0)\n    # Worker 1 calculates: input * 2 + initial_value(5) + rank(1)\n    expected_output_part1 = (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0\n    expected_output_part2 = (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1\n    expected_output = torch.cat([expected_output_part1, expected_output_part2])\n\n    torch.testing.assert_close(\n        result_data.batch[\"output_async\"], expected_output, msg=\"Async DP compute output data mismatch\"\n    )\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_device_mesh_register.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ray\nimport torch\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register\n\n\n@ray.remote\nclass TestActor(Worker):\n    def __init__(self):\n        super().__init__()\n\n        import torch.distributed\n\n        torch.distributed.init_process_group(backend=\"nccl\")\n        self.infer_device_mesh = torch.distributed.device_mesh.init_device_mesh(\n            device_type=\"cuda\", mesh_shape=[2, 4], mesh_dim_names=[\"dp\", \"tp\"]\n        )\n        self.train_device_mesh = torch.distributed.device_mesh.init_device_mesh(\n            device_type=\"cuda\", mesh_shape=[2, 2, 2], mesh_dim_names=[\"pp\", \"dp\", \"tp\"]\n        )\n\n        self._register_dispatch_collect_info(\n            \"infer\",\n            dp_rank=self.infer_device_mesh[\"dp\"].get_local_rank(),\n            is_collect=self.infer_device_mesh[\"tp\"].get_local_rank() == 0,\n        )\n        self._register_dispatch_collect_info(\n            \"train\",\n            dp_rank=self.train_device_mesh[\"dp\"].get_local_rank(),\n            is_collect=self.train_device_mesh[\"tp\"].get_local_rank() == 0\n            and self.train_device_mesh[\"pp\"].get_local_rank() == 1,\n        )\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"infer\"))\n    def generate_data_proto(self, data: DataProto):\n        tp_rank = self.infer_device_mesh[\"tp\"].get_local_rank()\n        dp_rank = self.infer_device_mesh[\"dp\"].get_local_rank()\n        data.batch[\"a\"] += (tp_rank + 1) * dp_rank\n        return data\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"train\"))\n    def train_data_proto(self, data: DataProto):\n        tp_rank = self.train_device_mesh[\"tp\"].get_local_rank()\n        dp_rank = self.train_device_mesh[\"dp\"].get_local_rank()\n        pp_rank = self.train_device_mesh[\"pp\"].get_local_rank()\n        data.batch[\"a\"] += (tp_rank + 1) * (dp_rank + 2) * (pp_rank + 3)\n        # tp rank 0, pp rank 1, dp rank 0, output data added: 8 + 3 = 11\n        # tp rank 0, pp rank 1, dp rank 1, output data added: 12 + 4 = 16\n        return data\n\n\ndef test_dist_global_info_wg():\n    # create a worker group with size 8\n    # register a infer dist info with tp=4, dp=2\n    # register a train dist info with tp=2, dp=2, pp=2\n    # test the correctness of data dispatch and computation\n    from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n    ray.init()\n\n    ray_cls = RayClassWithInitArgs(TestActor)\n    resource_pool = RayResourcePool(process_on_nodes=[8])\n    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls)\n\n    infer_input_data_proto = DataProto.from_single_dict(data={\"a\": torch.tensor([1, 2])})\n    infer_output_data_proto = 
wg.generate_data_proto(infer_input_data_proto)\n\n    assert wg._dispatch_info[\"infer\"] == [0, 0, 0, 0, 1, 1, 1, 1]\n\n    assert torch.all(torch.eq(infer_output_data_proto.batch[\"a\"], torch.tensor([1, 3])))\n\n    train_input_data_proto = DataProto.from_single_dict(data={\"a\": torch.tensor([3, 4])})\n    train_output_data_proto = wg.train_data_proto(train_input_data_proto)\n\n    assert wg._dispatch_info[\"train\"] == [0, 0, 1, 1, 0, 0, 1, 1]\n\n    assert torch.all(torch.eq(train_output_data_proto.batch[\"a\"], torch.tensor([11, 16])))\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_dist_global_info_wg()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_driverfunc_to_worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport ray\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool\n\nos.environ[\"RAY_DEDUP_LOGS\"] = \"0\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\n\n@ray.remote\nclass ModelActor(Worker):\n    def __init__(self):\n        pass\n\n\nclass HackSelf:\n    def __init__(self):\n        pass\n\n\ndef get_aux_metrics(self, test_proto):\n    sequence_ids = test_proto.batch[\"sequence_ids\"]\n    decode_count = []\n    for i in range(sequence_ids.size(0)):\n        decode_count.append(len(sequence_ids[i].tolist()))\n    ret_proto = DataProto(\n        batch=TensorDict(\n            {\"sequence_ids\": sequence_ids, \"decode_count\": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0)\n        )\n    )\n    return ret_proto\n\n\ndef test():\n    # construct model\n    ray.init()\n\n    # create 2 workers, each hold a GPU\n    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix=\"a\")\n\n    class_with_args = RayClassWithInitArgs(cls=ModelActor)\n    shard_wg = RayWorkerGroup(resource_pool, class_with_args)\n\n    test_bs = 8\n    test_proto = DataProto(\n        TensorDict(\n            {\n                \"sequence_ids\": torch.ones([test_bs, 2048], dtype=torch.int64),\n            },\n            batch_size=test_bs,\n        ),\n        meta_info={\"query_length\": 1536},\n    )\n\n    # Sharding among different ranks\n    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)\n\n    # compare execute on driver\n    hs = HackSelf()\n    ret_proto2 = get_aux_metrics(hs, test_proto)\n\n    torch.testing.assert_close(ret_proto1.batch[\"decode_count\"], ret_proto2.batch[\"decode_count\"])\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_fused_workers_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_raw_cls,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def add(self, x):\n        x += self.rank\n        return x\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, val) -> None:\n        super().__init__()\n        self.val = val\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL)\n    def sub(self, x):\n        x -= self.val\n        return x\n\n\nactor_cls = RayClassWithInitArgs(cls=Actor)\ncritic_cls = RayClassWithInitArgs(cls=Critic, val=10)\ncls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\nFusedBaseClass = create_colocated_worker_raw_cls(cls_dict)\n\n\n@ray.remote\nclass HybridWorker(FusedBaseClass):\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def foo(self, x):\n        return self.critic.sub(self.actor.add(x))\n\n\ndef test_fused_workers():\n    ray.init(num_cpus=100)\n\n    # create separate workers on the same resource pool\n    process_on_nodes = [2]\n    resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=False)\n\n    # create colocated workers\n    hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)\n    hybrid_cls_with_init.fused_worker_used = True\n\n    fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)\n    fused_wg.fuse(cls_dict.keys())\n\n    x = fused_wg.actor.add(0.1)\n    print(x)\n    y = fused_wg.critic.sub(x)\n    print(y)\n    z = fused_wg.foo(0.1)\n    print(z)\n    for i, j in zip(y, z, strict=True):\n        assert i == j\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_fused_workers()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_high_level_scheduling_api.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport gc\nimport time\n\nimport ray\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool\n\n\n@ray.remote\nclass TestActor(Worker):\n    # TODO: pass *args and **kwargs is bug prone and not very convincing\n    def __init__(self, cuda_visible_devices=None) -> None:\n        super().__init__(cuda_visible_devices)\n\n    def get_node_id(self):\n        return ray.get_runtime_context().get_node_id()\n\n\ndef test():\n    ray.init()\n\n    # test single-node-no-partition\n    print(\"test single-node-no-partition\")\n    resource_pool = RayResourcePool([8], use_gpu=True)\n\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    print(\"create actor worker group\")\n    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_actor\")\n    print(\"create critic worker group\")\n    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"hight_level_api_critic\")\n    print(\"create rm worker group\")\n    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_rm\")\n    print(\"create ref worker group\")\n    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_ref\")\n\n    assert actor_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert critic_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert rm_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert ref_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n\n    del actor_wg\n    del critic_wg\n    del rm_wg\n    del ref_wg\n    gc.collect()  # make sure ray actors are deleted\n\n    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]\n    print(\"wait 5s to remove placemeng_group\")\n    time.sleep(5)\n    # test single-node-multi-partition\n\n    print(\"test single-node-multi-partition\")\n    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix=\"rm\")\n    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix=\"ref\")\n    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)\n\n    assert rm_resource_pool.world_size == 4\n    assert ref_resource_pool.world_size == 4\n    assert total_resource_pool.world_size == 8\n\n    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix=\"high_level_api_actor\")\n    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix=\"high_level_api_critic\")\n    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix=\"high_level_api_rm\")\n    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix=\"high_level_api_ref\")\n\n    assert 
actor_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert critic_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert rm_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(4)]\n    assert ref_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(4, 8)]\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_nested_worker.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ray\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\nclass TestActor(Worker):\n    # TODO: pass *args and **kwargs is bug prone and not very convincing\n    def __init__(self, x) -> None:\n        super().__init__()\n        self.a = x\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get(self):\n        return self.a + self.rank\n\n\nclass TestHighLevelActor(Worker):\n    def __init__(self, x=None) -> None:\n        super().__init__()\n        self.test_actor = TestActor(x=x)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get(self):\n        return self.test_actor.get()\n\n\ndef test_nested_worker():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestActor), x=2)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic\"\n    )\n\n    output = worker_group.get()\n\n    assert output == [2, 3, 4, 5]\n\n    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestHighLevelActor), x=2)\n    high_level_worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic_2\"\n    )\n\n    output_1 = high_level_worker_group.get()\n\n    assert output_1 == [2, 3, 4, 5]\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_ray_collectives.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest for using ray collective group.\nSuppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to\nRollout relationship by using collective groups\nActor: rank 0, 1 - Rollout rank 0\nRollout rank 2, 3 - Rollout rank 1\nThen, we initiate 4 p2p comms from actor to rollout\n\"\"\"\n\nimport ray\nimport ray.util.collective as collective\nimport torch\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass Actor(Worker):\n    @register(Dispatch.ONE_TO_ALL)\n    def init(self):\n        remote_rank = self.rank // 2\n        self.group_name = f\"A{self.rank}_R{remote_rank}\"\n        collective.init_collective_group(world_size=2, rank=0, backend=\"nccl\", group_name=self.group_name)\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def send_tensors(self):\n        tensor = torch.ones(size=(4,), dtype=torch.float32, device=\"cuda\") * self.rank\n        collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name)\n\n\n@ray.remote\nclass Rollout(Worker):\n    @register(Dispatch.ONE_TO_ALL)\n    def init(self):\n        self.remote_first_rank = self.rank * 2\n        self.remote_second_rank = self.remote_first_rank + 1\n        self.first_group_name = f\"A{self.remote_first_rank}_R{self.rank}\"\n        self.second_group_name = f\"A{self.remote_second_rank}_R{self.rank}\"\n\n        collective.init_collective_group(world_size=2, rank=1, backend=\"nccl\", group_name=self.first_group_name)\n        collective.init_collective_group(world_size=2, rank=1, backend=\"nccl\", group_name=self.second_group_name)\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def receive_tensors(self):\n        self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device=\"cuda\")\n        self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device=\"cuda\")\n\n        collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name)\n        collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name)\n\n    @register(Dispatch.ONE_TO_ALL)\n    def get_tensors(self):\n        return {f\"src_{self.remote_first_rank}\": self.tensor1, f\"src_{self.remote_second_rank}\": self.tensor2}\n\n\ndef test_ray_collective_group():\n    ray.init()\n\n    actor_resource_pool = RayResourcePool([4])\n    rollout_resource_pool = RayResourcePool([2])\n\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    rollout_cls = RayClassWithInitArgs(cls=Rollout)\n\n    actor_wg = RayWorkerGroup(\n        resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix=\"collective_group_actor\"\n    )\n    rollout_wg = RayWorkerGroup(\n        resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, 
name_prefix=\"collective_group_rollout\"\n    )\n\n    actor_wg.init()\n    rollout_wg.init()\n\n    out1 = actor_wg.send_tensors()\n    out2 = rollout_wg.receive_tensors()\n\n    # block to wait\n    ray.get(out1)\n    ray.get(out2)\n\n    output = rollout_wg.get_tensors()\n\n    rollout_0_output = output[0]\n    rollout_1_output = output[1]\n\n    output = rollout_0_output | rollout_1_output\n\n    print(output)\n\n    for i in range(4):\n        assert torch.sum(output[f\"src_{i}\"]).item() == 4 * i\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_ray_collective_group()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_ray_local_envs_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ne2e test verl.single_controller.ray\n\"\"\"\n\nimport os\n\nimport ray\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestActor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    def getenv(self, key):\n        val = os.getenv(key, f\"{key} not set\")\n        return val\n\n\ndef test_basics():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=False)\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic\"\n    )\n\n    output = worker_group.execute_all_sync(\"getenv\", key=\"RAY_LOCAL_WORLD_SIZE\")\n    assert output == [\"4\", \"4\", \"4\", \"4\"]\n\n    ray.shutdown()\n\n\ndef test_customized_worker_env():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=False)\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool,\n        ray_cls_with_init=class_with_args,\n        name_prefix=\"worker_group_customized\",\n        worker_env={\n            \"test_key\": \"test_value\",  # new key will be appended\n        },\n    )\n\n    output = worker_group.execute_all_sync(\"getenv\", key=\"test_key\")\n    assert output == [\"test_value\", \"test_value\", \"test_value\", \"test_value\"]\n\n    try:\n        worker_group = RayWorkerGroup(\n            resource_pool=resource_pool,\n            ray_cls_with_init=class_with_args,\n            name_prefix=\"worker_group_error\",\n            worker_env={\n                \"WORLD_SIZE\": \"100\",  # override system env will result in error\n            },\n        )\n    except ValueError as e:\n        assert \"WORLD_SIZE\" in str(e)\n    else:\n        raise ValueError(\"test failed\")\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_basics()\n    test_customized_worker_env()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_ray_utils_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport ray\n\nfrom verl.utils.ray_utils import parallel_put\n\n\n# Initialize Ray for testing if not already done globally\n@pytest.fixture()\ndef init_ray():\n    ray.init(num_cpus=4)\n    yield\n    ray.shutdown()\n\n\ndef test_parallel_put_basic(init_ray):\n    data = [1, \"hello\", {\"a\": 2}, [3, 4]]\n    refs = parallel_put(data)\n    assert len(refs) == len(data)\n    retrieved_data = [ray.get(ref) for ref in refs]\n    assert retrieved_data == data\n\n\ndef test_parallel_put_empty(init_ray):\n    data = []\n    with pytest.raises(AssertionError):\n        _ = parallel_put(data)\n\n\ndef test_parallel_put_workers(init_ray):\n    data = list(range(20))\n    # Test with specific number of workers\n    refs = parallel_put(data, max_workers=4)\n    assert len(refs) == len(data)\n    retrieved_data = [ray.get(ref) for ref in refs]\n    assert retrieved_data == data\n    # Test with default workers (should cap)\n    refs_default = parallel_put(data)\n    assert len(refs_default) == len(data)\n    retrieved_data_default = [ray.get(ref) for ref in refs_default]\n    assert retrieved_data_default == data\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_rvdz.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\n\n@ray.remote\nclass TestWorker:\n    def __init__(self, rank, world_size, group_name):\n        self.rank = rank\n        self.world_size = world_size\n        self.group_name = group_name\n        self.communicator = None\n\n    def init(self):\n        from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray\n\n        self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)\n\n    def test(self):\n        if self.communicator is None:\n            return None\n        return self.communicator.rank_id()\n\n\ndef test_rvdz():\n    ray.init()\n\n    group_name = \"test_group\"\n    world_size = 2\n\n    workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]\n\n    ray.get([worker.init.remote() for worker in workers])\n\n    ranks = ray.get([worker.test.remote() for worker in workers])\n\n    assert ranks == [0, 1], f\"expecting [0, 1], got {ranks}\"\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_worker_group_basics.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ne2e test verl.single_controller.ray\n\"\"\"\n\nimport ray\nimport torch\n\nfrom verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\ndef two_to_all_dispatch_fn(worker_group, *args, **kwargs):\n    \"\"\"\n    Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\n    \"\"\"\n    for arg in args:\n        assert len(arg) == 2\n        for i in range(worker_group.world_size - 2):\n            arg.append(arg[i % 2])\n    for k, v in kwargs.items():\n        assert len(v) == 2\n        for i in range(worker_group.world_size - 2):\n            v.append(v[i % 2])\n    return args, kwargs\n\n\n@ray.remote\nclass TestActor(Worker):\n    # TODO: pass *args and **kwargs is bug prone and not very convincing\n    def __init__(self, x) -> None:\n        super().__init__()\n        self._x = x\n\n    def foo(self, y):\n        return self._x + y\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n    def foo_rank_zero(self, x, y):\n        return self._x + y + x\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def foo_one_to_all(self, x, y):\n        return self._x + y + x\n\n    @register(Dispatch.ALL_TO_ALL, blocking=False)\n    def foo_all_to_all(self, x, y):\n        return self._x + y + x\n\n    @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n    def foo_custom(self, x, y):\n        return self._x + y + x\n\n\n@ray.remote(num_gpus=0.1)\ndef remote_call_wg(worker_names):\n    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n    worker_group = RayWorkerGroup.from_detached(\n        worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None\n    )\n    print(worker_group.worker_names)\n\n    output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\n    assert output_ref == [8, 10, 8, 10]\n\n    output_ref = worker_group.foo_rank_zero(x=1, y=2)\n    assert output_ref == 5\n\n    return worker_group.worker_names\n\n\ndef add_one(data):\n    data = data.to(\"cuda\")\n    data += 1\n    data = data.to(\"cpu\")\n    return data\n\n\ndef test_basics():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic\"\n    )\n\n    print(worker_group.worker_names)\n\n    # this will wait for all the results\n    output = worker_group.execute_all_sync(\"foo\", y=3)\n    assert output == [5, 5, 5, 5]\n\n    # this is a list of object reference. 
It won't block.\n    output_ref = worker_group.execute_all_async(\"foo\", y=4)\n    print(output_ref)\n\n    assert ray.get(output_ref) == [6, 6, 6, 6]\n\n    output_ref = worker_group.foo_one_to_all(x=1, y=2)\n    assert ray.get(output_ref) == [5, 5, 5, 5]\n\n    output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8])\n    assert ray.get(output_ref) == [8, 10, 12, 14]\n\n    print(ray.get(remote_call_wg.remote(worker_group.worker_names)))\n\n    output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2))\n    torch.testing.assert_close(output, torch.ones(2, 2) + 1)\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_basics()\n"
  },
  {
    "path": "verl_distillation/tests/single_controller/test_worker_group_torch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nos.environ[\"RAY_DEDUP_LOGS\"] = \"0\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport ray\nimport torch\nimport torch.distributed\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestAllGatherActor(Worker):\n    def __init__(self, size) -> None:\n        super().__init__()\n        self.size = size\n\n    def init(self):\n        torch.distributed.init_process_group()\n        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=\"cuda\")\n        self.tensor += self.rank\n\n    def all_gather(self):\n        world_size = self._world_size\n        output = torch.zeros(\n            size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device\n        )\n        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)\n        return output\n\n\n@ray.remote\nclass TestAllGatherActorV2(Worker):\n    def __init__(self, size) -> None:\n        super().__init__()\n        self.size = size\n\n        torch.distributed.init_process_group()\n        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=\"cuda\")\n        self.tensor += self.rank\n\n    def all_gather(self):\n        world_size = self._world_size\n        output = torch.zeros(\n            size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device\n        )\n        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)\n        return output\n\n\ndef test_all_gather_torch():\n    \"\"\"\n    In this test, we instantiate 4 GPUs in a group and test the all_gather\n    \"\"\"\n    ray.init()\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)\n\n    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"worker_group_torch\")\n\n    worker_group.execute_all_sync(\"init\")\n    output = worker_group.execute_all_sync(\"all_gather\")\n    for i in range(1, len(output)):\n        assert torch.all(output[i] == output[0])\n\n    output = output[0].cpu()\n    print(output)\n    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))\n\n    ray.shutdown()\n\n\ndef test_all_gather_torch_v2():\n    \"\"\"\n    In this test, we instantiate 4 GPUs in a group and test the all_gather\n    \"\"\"\n    ray.init()\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)\n\n    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"worker_group_torch\")\n\n    output = worker_group.execute_all_sync(\"all_gather\")\n    for i in 
range(1, len(output)):\n        assert torch.all(output[i] == output[0])\n\n    output = output[0].cpu()\n    print(output)\n    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/special_distributed/README.md",
    "content": "This folder is reserved for unit tests (instead of end-to-end tests) that require multiple GPUs.\n"
  },
  {
    "path": "verl_distillation/tests/special_distributed/run_all.sh",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env bash\n\nset -e -x\ntorchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py"
  },
  {
    "path": "verl_distillation/tests/special_distributed/test_fsdp_ckpt.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport shutil\nimport tempfile\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config\n\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2\n\n\ndef create_random_input_ids(batch_size, seq_len, vocab_size):\n    from flash_attn.bert_padding import unpad_input\n\n    from verl.utils.model import compute_position_id_with_mask, create_random_mask\n\n    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n\n    attention_mask = create_random_mask(\n        input_ids, max_ratio_of_left_padding=0.1, min_ratio_of_valid_token=0.5, max_ratio_of_valid_token=0.7\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n\n    input_ids = unpad_input(input_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)\n    position_ids = unpad_input(position_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)\n    return input_ids, position_ids\n\n\ndef test_fsdp_ckpt(strategy=\"fsdp\"):\n    assert torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n    local_rank, rank, world_size = initialize_global_process_group()\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"dp\",))\n\n    model_name = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B-Instruct\")\n    config = Qwen2Config(num_hidden_layers=1)\n\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        model = model.to(device=\"cuda\")\n\n    # Wrap model with FSDP\n    if strategy == \"fsdp\":\n        mixed_precision = MixedPrecision(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32\n        )\n\n        model = FSDP(\n            model,\n            use_orig_params=False,\n            device_id=torch.cuda.current_device(),\n            sharding_strategy=ShardingStrategy.FULL_SHARD,\n            mixed_precision=mixed_precision,\n            device_mesh=device_mesh,\n        )\n    else:\n        mp_policy = MixedPrecisionPolicy(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n        )\n        fsdp_kwargs = {\n            \"mesh\": device_mesh,\n            \"mp_policy\": mp_policy,\n        }\n        apply_fsdp2(model, fsdp_kwargs, {})\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, 
gamma=0.9)\n\n    # Create checkpoint manager\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    checkpoint_manager = FSDPCheckpointManager(\n        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer\n    )\n\n    # Generate sample input\n    batch_size = 10\n    seq_len = 1024\n    vocab_size = config.vocab_size\n    # First input for initial update\n    input_ids1, position_ids1 = create_random_input_ids(batch_size, seq_len, vocab_size)\n\n    # Second input for verification\n    input_ids2, position_ids2 = create_random_input_ids(batch_size, seq_len, vocab_size)\n\n    # Step 1: Initial update and save checkpoint\n    outputs1 = model(input_ids=input_ids1, position_ids=position_ids1)\n    loss1 = outputs1.logits.mean()\n    loss1.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Save checkpoint after first update\n    temp_dir = tempfile.mkdtemp()\n    checkpoint_path = os.path.join(temp_dir, \"checkpoint\")\n    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)\n    saved_state_dict = model.state_dict()\n\n    # Step 2: Second update and forward pass\n    outputs2 = model(input_ids=input_ids2, position_ids=position_ids2)\n    loss2 = outputs2.logits.mean()\n    loss2.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after second update\n    with torch.no_grad():\n        logits_before_load = model(input_ids=input_ids2, position_ids=position_ids2).logits\n\n    # Step 3: Load checkpoint and repeat second update\n    checkpoint_manager.load_checkpoint(checkpoint_path)\n    loaded_state_dict = model.state_dict()\n    for key in loaded_state_dict:\n        assert key in saved_state_dict, f\"Key {key} not found in saved state dict\"\n        torch.testing.assert_close(loaded_state_dict[key], saved_state_dict[key], atol=0.0, rtol=0.0)\n\n    # Repeat the second update with same input\n    outputs3 = model(input_ids=input_ids2, position_ids=position_ids2)\n    loss3 = outputs3.logits.mean()\n    loss3.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after loaded checkpoint and update\n    with torch.no_grad():\n        logits_after_load = model(input_ids=input_ids2, position_ids=position_ids2).logits\n\n    # Step 4: Verify outputs match\n    torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)\n    print(\"Checkpoint save/load test passed!\")\n\n    # Cleanup\n    shutil.rmtree(temp_dir)\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    strategy = os.environ.get(\"STRATEGY\", \"fsdp\")\n    os.environ[\"FLASH_ATTENTION_DETERMINISTIC\"] = \"1\"\n    test_fsdp_ckpt(strategy=strategy)\n"
  },
  {
    "path": "verl_distillation/tests/special_distributed/test_mcore_config_converter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport megatron.core.parallel_state as mpu\nimport torch\nfrom megatron.core.transformer import MLATransformerConfig, TransformerConfig\nfrom transformers import AutoConfig, PretrainedConfig\n\nfrom verl.models.mcore import hf_to_mcore_config\nfrom verl.utils.distributed import destroy_global_process_group, initialize_global_process_group\n\nTEST_MODELS = [\n    \"Qwen/Qwen2.5-7B\",  # Qwen2 dense\n    \"Qwen/Qwen3-8B\",  # Qwen3 dense\n    \"deepseek-ai/deepseek-coder-1.3b-instruct\",  # deepseek dense\n    \"Qwen/Qwen2-57B-A14B\",  # Qwen2 moe\n    \"Qwen/Qwen3-30B-A3B\",  # Qwen3 moe\n    # \"mistralai/Mixtral-8x7B-v0.1\",  # Mixtral # require authentication\n    \"deepseek-ai/DeepSeek-V3-Base\",  # Deepseek V3\n]\n\n\ndef check_config_converter_results(tf_config: TransformerConfig | MLATransformerConfig, hf_config: PretrainedConfig):\n    assert tf_config.num_layers == hf_config.num_hidden_layers, (\n        f\"Number of layers mismatch: {tf_config.num_layers} != {hf_config.num_hidden_layers}\"\n    )\n    assert tf_config.hidden_size == hf_config.hidden_size, (\n        f\"Hidden size mismatch: {tf_config.hidden_size} != {hf_config.hidden_size}\"\n    )\n    assert tf_config.num_attention_heads == hf_config.num_attention_heads, (\n        f\"Number of attention heads mismatch: {tf_config.num_attention_heads} != {hf_config.num_attention_heads}\"\n    )\n    assert tf_config.num_query_groups == hf_config.num_key_value_heads, (\n        f\"Number of query groups mismatch: {tf_config.num_query_groups} != {hf_config.num_key_value_heads}\"\n    )\n    assert tf_config.ffn_hidden_size == hf_config.intermediate_size, (\n        f\"FFN hidden size mismatch: {tf_config.ffn_hidden_size} != {hf_config.intermediate_size}\"\n    )\n    assert tf_config.attention_dropout == hf_config.attention_dropout, (\n        f\"Attention dropout mismatch: {tf_config.attention_dropout} != {hf_config.attention_dropout}\"\n    )\n    assert tf_config.hidden_dropout == getattr(hf_config, \"hidden_dropout\", 0.0), (\n        f\"Hidden dropout mismatch: {tf_config.hidden_dropout} != {getattr(hf_config, 'hidden_dropout', 0.0)}\"\n    )\n    if getattr(hf_config, \"head_dim\", None) is not None:\n        assert tf_config.kv_channels == getattr(hf_config, \"head_dim\", None), (\n            f\"Head dim mismatch: {tf_config.kv_channels} != {getattr(hf_config, 'head_dim', None)}\"\n        )\n    assert tf_config.layernorm_epsilon == hf_config.rms_norm_eps, (\n        f\"Layernorm epsilon mismatch: {tf_config.layernorm_epsilon} != {hf_config.rms_norm_eps}\"\n    )\n\n\ndef modify_hf_config(name: str, hf_config: PretrainedConfig):\n    if name == \"deepseek-ai/DeepSeek-V3-Base\":\n        hf_config.num_nextn_predict_layers = 0\n        hf_config.quantization_config = None\n    return hf_config\n\n\ndef test_mcore_config_converter():\n    \"\"\"\n    Test the conversion of Hugging Face 
model configurations to MCore configurations.\n    \"\"\"\n    local_rank, rank, world_size = initialize_global_process_group()\n    mpu.initialize_model_parallel(\n        tensor_model_parallel_size=2,\n        pipeline_model_parallel_size=2,\n        virtual_pipeline_model_parallel_size=None,\n        pipeline_model_parallel_split_rank=None,\n        use_sharp=False,\n        context_parallel_size=2,\n        expert_model_parallel_size=1,\n        expert_tensor_parallel_size=None,\n        nccl_communicator_config_path=None,\n    )\n    for model_name in TEST_MODELS:\n        print(f\"testing {model_name}\")\n        hf_config = AutoConfig.from_pretrained(os.path.expanduser(f\"~/models/configs/{model_name}/config.json\"))\n        hf_config = modify_hf_config(model_name, hf_config)\n        tf_config = hf_to_mcore_config(hf_config, torch.bfloat16)\n        check_config_converter_results(tf_config, hf_config)\n\n    destroy_global_process_group()\n\n\nif __name__ == \"__main__\":\n    test_mcore_config_converter()\n"
  },
  {
    "path": "verl_distillation/tests/special_distributed/test_tensor_dict.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport numpy as np\nimport torch\nimport torch.distributed\n\nfrom verl.protocol import DataProto, all_gather_data_proto\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef test_all_gather_data_proto():\n    device_mesh = torch.distributed.device_mesh.init_device_mesh(\"cuda\", mesh_shape=[2, 2], mesh_dim_names=[\"dp\", \"tp\"])\n\n    global_rank = torch.distributed.get_rank()\n\n    obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])\n\n    labels = [\"a\", \"b\"] if global_rank % 2 == 0 else [\"b\", \"a\"]\n    labels = np.array(labels, dtype=object)\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    all_gather_data_proto(data=data, process_group=device_mesh.get_group(\"dp\"))\n\n    if global_rank == 0:\n        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device=\"cuda\")\n        expected_labels = [\"a\", \"b\", \"a\", \"b\"]\n    elif global_rank == 1:\n        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device=\"cuda\")\n        expected_labels = [\"b\", \"a\", \"b\", \"a\"]\n    elif global_rank == 2:\n        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device=\"cuda\")\n        expected_labels = [\"a\", \"b\", \"a\", \"b\"]\n    elif global_rank == 3:\n        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device=\"cuda\")\n        expected_labels = [\"b\", \"a\", \"b\", \"a\"]\n\n    torch.testing.assert_close(data.batch[\"obs\"], expected_obs, atol=0, rtol=0)\n    assert (data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert data.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_vocab_parallel_entropy():\n    from megatron.core import parallel_state as mpu\n\n    from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy\n    from verl.utils.profiler import log_gpu_memory_usage\n    from verl.utils.torch_functional import entropy_from_logits\n\n    mpu.initialize_model_parallel(\n        tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None\n    )\n\n    batch_size = 2\n    seqlen = 128\n    vocab_size = 155136\n\n    logits = torch.randn(batch_size * seqlen, vocab_size, device=\"cuda\", requires_grad=True)\n    target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device=\"cuda\", dtype=torch.int64)\n\n    # broadcast across tp\n    torch.distributed.broadcast(\n        logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()\n    )\n    torch.distributed.broadcast(\n        target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()\n    )\n\n    tp_rank = mpu.get_tensor_model_parallel_rank()\n    
vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()\n\n    # get the local logits of each tp\n    vocab_parallel_logits = (\n        logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_()\n    )\n    logits.grad = None\n    vocab_parallel_logits.grad = None\n\n    log_gpu_memory_usage(\"begin\")\n    output_entropy = vocab_parallel_entropy(vocab_parallel_logits)\n    log_gpu_memory_usage(\"after forward\")\n    grad_output = torch.randn_like(output_entropy)\n    output_entropy.backward(grad_output)\n    log_gpu_memory_usage(\"after backward\")\n\n    target_entropy = entropy_from_logits(logits)\n    torch.testing.assert_close(output_entropy, target_entropy)\n    target_entropy.backward(grad_output)\n    torch.testing.assert_close(\n        logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad\n    )\n    # make sure logits is not altered\n    torch.testing.assert_close(\n        logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits\n    )\n\n    if mpu.get_tensor_model_parallel_rank() == 0:\n        print(\"test_vocab_parallel_entropy passes\")\n\n    mpu.destroy_model_parallel()\n\n\nif __name__ == \"__main__\":\n    local_rank, rank, world_size = initialize_global_process_group()\n    test_all_gather_data_proto()\n    test_vocab_parallel_entropy()\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/README.md",
    "content": "This folder is reserved for end-to-end tests that typically require multiple GPUs.\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/check_custom_rwd_fn.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\n\ndef check_congratulations_in_file(output_file):\n    with open(output_file) as f:\n        output = f.read()\n\n    success_message = \"Congratulations!!! You have called my_reward_function successfully!!!\"\n    assert success_message in output, f\"Success message of my_reward_function not found in {output_file}\"\n    print(\"Check passes\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output_file\", required=True, type=str)\n\n    args = parser.parse_args()\n\n    check_congratulations_in_file(args.output_file)\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/check_results.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nimport numpy as np\n\n\ndef extract_reward_from_line(line):\n    # TODO: this function needs error handling\n    try:\n        key_vals = line.split(\" - \")\n        for key_val in key_vals:\n            key, val = key_val.split(\":\")\n            if key == \"critic/rewards/mean\":\n                reward = float(val)\n                return reward\n        return -np.inf\n    except Exception:\n        return -np.inf\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output_file\", required=True, type=str)\n    parser.add_argument(\"--target\", type=float, default=0.2, help=\"target reward score\")\n\n    args = parser.parse_args()\n\n    with open(args.output_file) as f:\n        output = f.read().split(\"\\n\")\n\n    best_reward = -np.inf\n    for line in output:\n        if line.startswith(\"step\"):\n            reward = extract_reward_from_line(line)\n            if reward > best_reward:\n                best_reward = reward\n\n    print(f\"Best reward is {best_reward}\")\n    assert best_reward > args.target, f\"Best reward must be greater than {args.target}. best_reward: {best_reward}\"\n    print(\"Check passes\")\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/envs/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .digit_completion import DigitCompletion\n\n__all__ = [\"DigitCompletion\"]\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/envs/digit_completion/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import AutoTokenizer, LlamaConfig\n\nfrom .task import DigitCompletion, generate_ground_truth_response\nfrom .tokenizer import CharTokenizer\n\nAutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)\n\n__all__ = [\"DigitCompletion\", \"generate_ground_truth_response\", \"CharTokenizer\"]\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/envs/digit_completion/task.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Task and environment definition for digit completion.\"\"\"\n\nimport numpy as np\n\n\nclass DigitCompletion:\n    \"\"\"\n    The implementation of a simple digit completion task.\n    The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers.\n    If the max number is reached, the next number should be modulo with max number.\n\n    For example,\n    - prompt = [1, 2, 3]\n    - N = 5\n    - max_number = 6\n\n    the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]\n\n    Note that the tokenizer is char-level to increase the difficulty.\n    \"\"\"\n\n    def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):\n        \"\"\"\n\n        Args:\n            max_number: the maximum number allowed in the arithmetic sequence\n            max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff]\n            max_num_in_response: the maximum number in the response\n        \"\"\"\n        super().__init__()\n        self.max_number = max_number\n        self.max_diff = max_diff\n        self.max_num_in_response = max_num_in_response\n        assert self.max_num_in_response < 10\n        assert self.max_number > 0\n        assert self.max_diff > 0\n        self.max_number_length = len(str(max_number))\n        # {num1},{num2}:{max_num_in_response},{max_number}\n        self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length  # no negative is allowed\n\n        self.np_rng = np.random.default_rng(seed=seed)\n\n    def __str__(self):\n        return (\n            f\"Prompt length: {self.prompt_length}. Response length: {self.response_length}, \"\n            f\"Max number: {self.max_number}. 
Max diff: {self.max_diff}, \"\n            f\"Max number in response: {self.max_num_in_response}\"\n        )\n\n    def get_state(self):\n        return {\"rng\": self.np_rng}\n\n    def set_state(self, state):\n        assert \"rng\" in state, \"rng must be inside state\"\n        self.np_rng = state[\"rng\"]\n\n    @property\n    def prompt_length(self):\n        return self._prompt_length\n\n    @property\n    def response_length(self):\n        # number length + comma length + [EOS]\n        # The actual number times 1.5 to allow 'U'\n        return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2\n\n    def add(self, a, b):\n        return (a + b) % self.max_number\n\n    def get_all_prompts(self):\n        all_prompts = []\n        for first_num in range(self.max_number + 1):\n            for diff in range(0, self.max_diff + 1):\n                second_num = self.add(first_num, diff)\n                for num_to_complete in range(self.max_num_in_response + 1):\n                    prompt = str(first_num) + \",\" + str(second_num) + f\":{self.max_number},{num_to_complete}\"\n                    all_prompts.append(prompt)\n        return all_prompts\n\n    def sample_str_prompts(self):\n        # step 1: sample initial numbers\n        first_num = self.np_rng.integers(self.max_number + 1)\n        diff = self.np_rng.integers(self.max_diff + 1)\n        second_num = self.add(first_num, diff)\n        num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)\n        prompt = str(first_num) + \",\" + str(second_num) + f\":{self.max_number},{num_to_complete}\"\n        return prompt\n\n    def sample_batch_str_prompts(self, batch_size):\n        str_prompts = []\n        for _ in range(batch_size):\n            str_prompts.append(self.sample_str_prompts())\n        return str_prompts\n\n\ndef compute_attention_mask(prompts, pad_token_id):\n    mask = np.ones_like(prompts)\n    mask[prompts == pad_token_id] = 0\n    return mask\n\n\ndef compute_position_id_with_mask(mask):\n    return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None)\n\n\ndef generate_ground_truth_response(prompt: str):\n    \"\"\"Generate ground truth response given a prompt.\"\"\"\n    num, info = prompt.split(\":\")\n    num1, num2 = num.split(\",\")\n    max_number, num_to_gen = info.split(\",\")\n    num1 = int(num1)\n    num2 = int(num2)\n    max_number = int(max_number)\n    num_to_gen = int(num_to_gen)\n    diff = (num2 - num1) % max_number\n    results = []\n    last_num = num2\n    for _ in range(num_to_gen):\n        curr = (last_num + diff) % max_number\n        results.append(str(curr))\n        last_num = curr\n    response = \",\".join(results)\n    return response\n\n\ndef compute_reward(prompt: str, response: str, sequence_reward=1.0):\n    \"\"\"We compute dense reward here so that we can directly train RL without SFT\"\"\"\n    response_length = len(response)\n    ground_truth_response = generate_ground_truth_response(prompt)\n    per_token_reward = sequence_reward / (len(ground_truth_response) + 1)  # including [EOS]\n\n    # pad\n    reward = np.zeros(response_length, dtype=np.float32)  # this assumes that each char is a token\n    # assign reward until mismatches\n    ground_truth_idx = 0\n    for i in range(response_length):\n        if ground_truth_idx == len(ground_truth_response):\n            break\n\n        ground_truth_response_token = ground_truth_response[ground_truth_idx]\n        response_token = response[i]\n        if 
ground_truth_response_token == response_token:\n            reward[i] = per_token_reward\n            ground_truth_idx += 1\n        else:\n            # no matches\n            break\n\n    return reward, {\"ground_truth_response\": ground_truth_response}\n\n\nif __name__ == \"__main__\":\n    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)\n    print(task.sample_str_prompts())\n\n    prompt = \"7,8:20,0\"\n    response = \"\"\n    print(compute_reward(prompt, response))\n\n    prompt = \"7,8:20,0\"\n    response = \"E000\"\n    print(compute_reward(prompt, response))\n\n    prompt = \"9,10:20,2\"\n    response = \"11,12,13\"\n    print(compute_reward(prompt, response))\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/envs/digit_completion/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py\n\nCharacterTokenzier for Hugging Face Transformers.\n\nThis is heavily inspired from CanineTokenizer in transformers package.\n\"\"\"\n\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Sequence\n\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\n\n\nclass CharTokenizer(PreTrainedTokenizer):\n    def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):\n        \"\"\"Character tokenizer for Hugging Face transformers.\n\n        Args:\n            characters (Sequence[str]): List of desired characters. Any character which\n                is not included in this list will be replaced by a special token called\n                [UNK] with id=6. Following are list of all of the special tokens with\n                their corresponding ids:\n                    \"[CLS]\": 0\n                    \"[SEP]\": 1\n                    \"[BOS]\": 2\n                    \"[MASK]\": 3\n                    \"[PAD]\": 4\n                    \"[RESERVED]\": 5\n                    \"[UNK]\": 6\n                an id (starting at 7) will be assigned to each character.\n\n            model_max_length (int): Model maximum sequence length.\n        \"\"\"\n        eos_token_str = \"E\"\n        sep_token_str = \"S\"\n        pad_token_str = \"P\"\n        unk_token_str = \"U\"\n\n        self.characters = characters\n        self.model_max_length = model_max_length\n        eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False)\n        sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False)\n        pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False)\n        unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False)\n\n        self._vocab_str_to_int = {\n            sep_token_str: 0,\n            eos_token_str: 1,\n            pad_token_str: 2,\n            unk_token_str: 3,\n            **{ch: i + 4 for i, ch in enumerate(characters)},\n        }\n        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}\n\n        super().__init__(\n            eos_token=eos_token,\n            sep_token=sep_token,\n            pad_token=pad_token,\n            unk_token=unk_token,\n            add_prefix_space=False,\n            model_max_length=model_max_length,\n            **kwargs,\n        )\n\n        self.chat_template = chat_template\n\n    @property\n    def vocab_size(self) -> int:\n        return len(self._vocab_str_to_int)\n\n    def get_vocab(self):\n        return self._vocab_str_to_int\n\n    def _tokenize(self, text: str) -> list[str]:\n        return list(text)\n\n    def _convert_token_to_id(self, token: str) -> int:\n        return self._vocab_str_to_int.get(token, self._vocab_str_to_int[\"U\"])\n\n    def 
_convert_id_to_token(self, index: int) -> str:\n        return self._vocab_int_to_str[index]\n\n    def convert_tokens_to_string(self, tokens):\n        return \"\".join(tokens)\n\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None\n    ) -> list[int]:\n        sep = [self.sep_token_id]\n        cls = [self.cls_token_id]\n        result = cls + token_ids_0 + sep\n        if token_ids_1 is not None:\n            result += token_ids_1 + sep\n        return result\n\n    def get_special_tokens_mask(\n        self,\n        token_ids_0: list[int],\n        token_ids_1: Optional[list[int]] = None,\n        already_has_special_tokens: bool = False,\n    ) -> list[int]:\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0,\n                token_ids_1=token_ids_1,\n                already_has_special_tokens=True,\n            )\n\n        result = [1] + ([0] * len(token_ids_0)) + [1]\n        if token_ids_1 is not None:\n            result += ([0] * len(token_ids_1)) + [1]\n        return result\n\n    def get_config(self) -> dict:\n        return {\n            \"char_ords\": [ord(ch) for ch in self.characters],\n            \"model_max_length\": self.model_max_length,\n            \"chat_template\": self.chat_template,\n        }\n\n    @classmethod\n    def from_config(cls, config: dict):\n        cfg = {}\n        cfg[\"characters\"] = [chr(i) for i in config[\"char_ords\"]]\n        cfg[\"model_max_length\"] = config[\"model_max_length\"]\n        cfg[\"chat_template\"] = config[\"chat_template\"]\n        return cls(**cfg)\n\n    def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        cfg = self.get_config()\n        with open(cfg_file, \"w\") as f:\n            json.dump(cfg, f, indent=4)\n\n    @classmethod\n    def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        with open(cfg_file) as f:\n            cfg = json.load(f)\n        return cls.from_config(cfg)\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/generation/run_gen_qwen05.sh",
    "content": "#!/usr/bin/env bash\n# Tested with 1 & 4 GPUs\nset -xeuo pipefail\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\n\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-4}\nOUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}\nGEN_TP=${GEN_TP:-2}  # Default tensor parallel size to 2\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    data.path=\"${HOME}/data/gsm8k/test.parquet\" \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=\"${OUTPUT_PATH}\" \\\n    model.path=\"${MODEL_ID}\" \\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=\"${GEN_TP}\" \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/generation/run_gen_qwen05_server.sh",
    "content": "#!/usr/bin/env bash\n# Tested with 1 & 4 GPUs\nset -xeuo pipefail\n\nMODEL_ID=${MODEL_ID:-$HOME/models/Qwen/Qwen2.5-0.5B-Instruct}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\nOUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}\nGEN_TP=${GEN_TP:-2}  # Default tensor parallel size to 2\n\npython3 -m verl.trainer.main_generation_server \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    actor_rollout_ref.model.path=\"${MODEL_ID}\" \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.rollout.top_k=50 \\\n    actor_rollout_ref.rollout.top_p=0.7 \\\n    actor_rollout_ref.rollout.prompt_length=2048 \\\n    actor_rollout_ref.rollout.response_length=1024 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\"${GEN_TP}\" \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=4 \\\n    data.train_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n    data.prompt_key=prompt \\\n    +data.output_path=\"${OUTPUT_PATH}\" \\\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json",
    "content": "{\n    \"num_hidden_layers\": 2,\n    \"max_window_layers\": 2\n}"
  },
  {
    "path": "verl_distillation/tests/special_e2e/ppo_trainer/run_function_reward.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\nMAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512}\nMAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}\n\nENGINE=${ENGINE:-vllm}\nROLLOUT_MODE=${ROLLOUT_MODE:-sync}\n\nRETURN_RAW_CHAT=\"False\"\nSKIP_TOKENIZER_INIT=${SKIP_TOKENIZER_INIT:-False}\nif [ \"$ROLLOUT_MODE\" = \"async\" ]; then\n    RETURN_RAW_CHAT=\"True\"\n    SKIP_TOKENIZER_INIT=\"True\"\nfi\n\nGPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8}\nACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}\nACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}\nREF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}\nRM_PAD=${RM_PAD:-True}\nFUSED_KERNELS=${FUSED_KERNELS:-False}\nFUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend\nADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}\nLOSS_MODE=${LOSS_MODE:-vanilla}\nUSE_KL=${USE_KL:-False}\nCUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}\nENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185\nSTRATEGY=${STRATEGY:-fsdp}\n# LoRA config\nLORA_RANK=${LORA_RANK:-0}\nLORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}}\nLORA_TARGET=${LORA_TARGET:-\"all-linear\"}\nLORA_EXCLUDE=${LORA_EXCLUDE:-\"DONT_EXCLUDE\"}\nUSE_SHM=${USE_SHM:-False}\nLOAD_FORMAT=${LOAD_FORMAT:-dummy}\nLAYERED_SUMMON=${LAYERED_SUMMON:-False}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\n# whether to save hf_model\nSAVE_HF_MODEL=${SAVE_HF_MODEL:-False}\nFSDP_SIZE=${FSDP_SIZE:--1}\nSP_SIZE=${SP_SIZE:-1}\n\nif [ \"${SAVE_HF_MODEL}\" = \"True\" ]; then\n    CHECKPOINT_CONTENTS=\"['model','hf_model','optimizer','extra']\"\nelse\n    CHECKPOINT_CONTENTS=\"['model','optimizer','extra']\"\nfi\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\nreward_fn_name=null\nreward_fn_file_path=null\noutput_file=\"$(pwd)/output.txt\"\nif [ \"${CUSTOM_REWARD_FN}\" = \"True\" ]; then\n    reward_fn_name=\"my_reward_function\"\n    reward_fn_file_path=\"$(pwd)/my_reward_function.py\"\n    rm -rf \"${reward_fn_file_path}\"\n    cat <<EOF > \"$reward_fn_file_path\"\ndef ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None):\n    print(f\"Congratulations!!! 
You have called ${reward_fn_name} successfully!!!\")\n    return 0.1\nEOF\n\n    rm -rf \"${output_file}\"\nfi\n\nexp_name=\"${VERL_EXP_NAME:-$(basename \"${MODEL_ID,,}\")-function-reward-minimal}\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=\"${ADV_ESTIMATOR}\" \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=\"${train_prompt_bsz}\" \\\n    data.max_prompt_length=\"${MAX_PROMPT_LEN}\" \\\n    data.max_response_length=\"${MAX_RESPONSE_LEN}\" \\\n    data.return_raw_chat=${RETURN_RAW_CHAT} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_shm=${USE_SHM} \\\n    actor_rollout_ref.model.lora_rank=${LORA_RANK} \\\n    actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \\\n    actor_rollout_ref.model.target_modules=${LORA_TARGET} \\\n    actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=\"${RM_PAD}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.strategy=${STRATEGY} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \\\n    actor_rollout_ref.actor.use_kl_loss=\"${USE_KL}\" \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=\"${LOSS_MODE}\" \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=\"${ENGINE}\" \\\n    actor_rollout_ref.rollout.mode=\"${ROLLOUT_MODE}\" \\\n    actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \\\n    actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \\\n    actor_rollout_ref.rollout.skip_tokenizer_init=\"${SKIP_TOKENIZER_INIT}\" \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=\"${GPU_MEMORY_UTILIZATION}\" \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=\"${ENABLE_CHUNKED_PREFILL}\" \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=\"${REF_FSDP_PARAM_OFFLOAD}\" \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=\"${RM_PAD}\" \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    custom_reward_function.path=\"${reward_fn_file_path}\"\\\n    custom_reward_function.name=\"${reward_fn_name}\"\\\n    algorithm.use_kl_in_reward=\"${USE_KL}\" \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" 
\\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NUM_GPUS}\" \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.test_freq=\"${TEST_FREQ}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.device=cuda \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" $@ \\\n    | tee \"${output_file}\"\n\nif [ \"${CUSTOM_REWARD_FN}\" = \"True\" ]; then\n    python3 tests/special_e2e/check_custom_rwd_fn.py --output_file=\"${output_file}\"\n    check_exit_code=$?\n    rm -rf \"${reward_fn_file_path}\"\n    rm -rf \"${output_file}\"\n    # Return the exit code of check_custom_rwd_fn.py if it fails\n    if [ $check_exit_code -ne 0 ]; then\n        exit $check_exit_code\n    fi\nfi\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/ppo_trainer/run_model_reward.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\n\nRM_PAD=${RM_PAD:-True}\nFUSED_KERNELS=${FUSED_KERNELS:-False}\nFUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend\nSP_SIZE=${SP_SIZE:-1}\nSEQ_BALANCE=${SEQ_BALANCE:-False}\nLIGER=${LIGER:-False}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ntrain_max_token_num_per_gpu=32768\ninfer_max_token_num_per_gpu=32768\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-model-reward-minimal\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_liger=\"${LIGER}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=\"${RM_PAD}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.optim.lr=1e-5 \\\n    critic.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    critic.model.use_remove_padding=\"${RM_PAD}\" \\\n    critic.optim.lr_warmup_steps_ratio=0.05 \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    
critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.model.use_remove_padding=\"${RM_PAD}\" \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    reward_model.forward_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NUM_GPUS}\" \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.test_freq=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" $@\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/ppo_trainer/run_single_gpu.sh",
    "content": "PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=256  \\\n  data.max_prompt_length=512 \\\n  data.max_response_length=256  \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4  \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n  critic.optim.lr=1e-5 \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  critic.ppo_micro_batch_size_per_gpu=4 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=1 \\\n  trainer.nnodes=1 \\\n  actor_rollout_ref.rollout.name=hf \\\n  trainer.total_training_steps=2"
  },
  {
    "path": "verl_distillation/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh",
    "content": "PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=256  \\\n  data.max_prompt_length=512 \\\n  data.max_response_length=256  \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4  \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n  critic.optim.lr=1e-5 \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  critic.ppo_micro_batch_size_per_gpu=4 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=['console'] \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=1 \\\n  trainer.nnodes=1 \\\n  actor_rollout_ref.rollout.name=hf \\\n  trainer.use_legacy_worker_impl=disable \\\n  trainer.total_training_steps=2"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_dapo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nadv_estimator=grpo\n\nkl_coef=0.0\nuse_kl_in_reward=False\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=seq_reward\nmax_num_gen_batches=10\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ngen_prompt_bsz=$((train_prompt_bsz * 4))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-dapo-minimal\"\n\npython3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\" \\\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n    reward_model.reward_manager=dapo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} 
\\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=2 \\\n    trainer.resume_mode=disable \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_fully_async_policy.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# Test script for fully_async_policy E2E regression testing\n# This script runs fully async PPO training with both FSDP2 and Megatron backends\n# to ensure the asynchronous training mechanism works correctly\n\nNUM_GPUS=${NUM_GPUS:-8}\nACTOR_STRATEGY=${ACTOR_STRATEGY:-\"fsdp2\"}  # fsdp2 or megatron\n\n# Download model if not exists\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\n\nrollout_mode=\"async\"\nrollout_name=\"vllm\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\n\n# Temperature parameters\ntemperature=1.0\ntop_p=1.0\ntop_k=-1\nval_top_p=0.7\n\n# Fully async specific parameters\nn_gpus_rollout=4\nn_gpus_training=4\n\ntrain_prompt_bsz=0\ngen_prompt_bsz=1\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=16\ntotal_rollout_steps=$(((128)))\ntest_freq=-1\nstaleness_threshold=0.1\ntrigger_parameter_sync_step=4\npartial_rollout=True\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-fully-async-policy-${ACTOR_STRATEGY}-minimal\"\n\necho \"Running fully_async_policy with ${ACTOR_STRATEGY} strategy\"\necho \"Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}\"\n\n# Common parameters for both FSDP2 and Megatron\ncommon_params=(\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\"\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\"\n    data.prompt_key=prompt\n    data.truncation='left'\n    data.max_prompt_length=${max_prompt_length}\n    data.max_response_length=${max_response_length}\n    data.train_batch_size=${train_prompt_bsz}\n    data.gen_batch_size=${gen_prompt_bsz}\n    data.return_raw_chat=${return_raw_chat}\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt}\n    actor_rollout_ref.rollout.calculate_log_probs=True\n    algorithm.adv_estimator=${adv_estimator}\n    algorithm.use_kl_in_reward=${use_kl_in_reward}\n    algorithm.kl_ctrl.kl_coef=${kl_coef}\n    actor_rollout_ref.hybrid_engine=False\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}\n    actor_rollout_ref.actor.clip_ratio_c=10.0\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\"\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.optim.lr_warmup_steps=-1\n    actor_rollout_ref.actor.optim.weight_decay=0.1\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80\n    actor_rollout_ref.rollout.temperature=${temperature}\n    actor_rollout_ref.rollout.top_p=${top_p}\n    actor_rollout_ref.rollout.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}\n    
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.name=${rollout_name}\n    actor_rollout_ref.rollout.mode=${rollout_mode}\n    reward_model.reward_manager=dapo\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length}\n    trainer.logger=['console']\n    trainer.project_name='verl-test-fully-async'\n    trainer.experiment_name=\"${exp_name}\"\n    trainer.val_before_train=True\n    trainer.save_freq=-1\n    trainer.resume_mode=disable\n    trainer.nnodes=1\n    trainer.n_gpus_per_node=${n_gpus_training}\n    rollout.nnodes=1\n    rollout.n_gpus_per_node=${n_gpus_rollout}\n    rollout.total_rollout_steps=${total_rollout_steps}\n    rollout.total_epochs=2\n    rollout.test_freq=${test_freq}\n    # Fully async specific configurations\n    async_training.staleness_threshold=${staleness_threshold}\n    async_training.partial_rollout=\"${partial_rollout}\"\n    async_training.trigger_parameter_sync_step=\"${trigger_parameter_sync_step}\"\n)\n\nif [ \"${ACTOR_STRATEGY}\" == \"fsdp2\" ]; then\n    echo \"Running fully async training with FSDP2 strategy...\"\n    # FSDP2 specific parameters\n    gen_tp=1\n    sp_size=1\n    fsdp_size=1\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.fully_async_policy.fully_async_main \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.strategy=fsdp2 \\\n        critic.strategy=fsdp2 \\\n        actor_rollout_ref.actor.grad_clip=1.0 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.use_dynamic_bsz=True \\\n        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n        actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@\n\nelif [ \"${ACTOR_STRATEGY}\" == \"megatron\" ]; then\n    echo \"Running fully async training with Megatron strategy...\"\n    # Megatron specific parameters\n    gen_tp=2\n    train_tp=1\n    train_pp=2\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.fully_async_policy.fully_async_main \\\n        --config-path=config \\\n        --config-name='fully_async_ppo_megatron_trainer.yaml' \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.actor.strategy=megatron \\\n        critic.strategy=megatron \\\n        actor_rollout_ref.actor.optim.lr_decay_steps=10000000 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n        
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n        actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@\nelse\n    echo \"Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'\"\n    exit 1\nfi\n\necho \"Fully async policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy\"\n\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_genrm_remote.sh",
    "content": "#!/usr/bin/env bash\n\nexport no_proxy=\"localhost,127.0.0.1\"\n\nset -x\n\n# Launch a vllm server\nCUDA_VISIBLE_DEVICES=0 vllm serve $HOME/models/verl-team/GenRM-CI-Test-1.5B \\\n    --served_model_name genrm-demo --host localhost --port 30000 > /dev/null &\nSERVER_PID=$!\n\n# kill server when script exits\ncleanup() {\n    echo \"Cleaning up...\"\n    kill $SERVER_PID 2>/dev/null || true\n    wait $SERVER_PID 2>/dev/null || true\n    echo \"Cleanup done\"\n}\ntrap cleanup EXIT\n\n# wait for server to start\nwait_for_server() {\n    local max_attempts=60\n    local attempt=0\n    local sleep_time=10\n\n    while [ $attempt -lt $max_attempts ]; do\n        if curl -s \"http://localhost:30000/health\" >/dev/null; then\n            echo \"Server is up and running!\"\n            return 0\n        fi\n        echo \"Waiting for server to start... (attempt $((attempt+1))/$max_attempts)\"\n        sleep $sleep_time\n        ((attempt++))\n    done\n    \n    echo \"Error: Failed to start server after $max_attempts attempts\" >&2\n    return 1\n}\n\nif ! wait_for_server; then\n    exit 1\nfi\n\nCUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=${HOME}/data/gsm8k/train.parquet \\\n    data.val_files=${HOME}/data/gsm8k/test.parquet \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    algorithm.use_kl_in_reward=False \\\n    reward_model.reward_manager=batch \\\n    custom_reward_function.path=recipe/genrm_remote/reward_function.py \\\n    custom_reward_function.name=compute_score_batch \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name='qwen2.5-0.5b-gen-rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode='disable' \\\n    trainer.total_training_steps=1\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\n#huggingface-cli download Qwen/Qwen2.5-VL-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-VL-3B-Instruct\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nFSDP_STRATEGY=${FSDP_STRATEGY:-fsdp}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=64 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name=qwen2.5-vl-3b_function_rm-geo3k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0619-verify-n8 \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    data.train_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@"
  },
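  {
    "path": "verl_distillation/tests/special_e2e/examples/geo3k_multiturn_usage.sh",
    "content": "#!/usr/bin/env bash\n# Illustrative usage sketch (hypothetical file, not a CI entry point):\n# run_geo3k_fsdp_sgl_multiturn_w_tool.sh keys its actor/ref strategy off the\n# FSDP_STRATEGY env var (default: fsdp), so the same script covers both FSDP\n# variants; run it from the repo root as the script's header asks.\nFSDP_STRATEGY=fsdp bash tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh\nFSDP_STRATEGY=fsdp2 bash tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh\n"
  },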
  {
    "path": "verl_distillation/tests/special_e2e/run_grpo_lora_with_merge.sh",
    "content": "#!/usr/bin/env bash\n#\n#  An e2e test script for testing the GRPO LoRA training process \n#  and processing the generated checkpoint using the merge_model.py script.  \n\nset -xeuo pipefail\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nif [ ! -d \"$MODEL_PATH\" ]; then\n    echo \"Downloading model to ${MODEL_PATH}...\"\n#    huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\nelse\n    echo \"Model directory ${MODEL_PATH} already exists, skip downloading.\"\nfi\n\n\nBATCH_SIZE=16\nEXP_NAME=\"qwen2.5_0.5b_grpo_lora\"\n# step 1. train model with grpo-lora for 1 step\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=${BATCH_SIZE} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=${MODEL_PATH} \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${BATCH_SIZE} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name=${EXP_NAME} \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.total_training_steps=1 \\\n    trainer.save_freq=1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n\n# step 2. merge model\npython3 -m verl.model_merger merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/ \\\n    --target_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf\n\n# step 3. assert\n# make sure adapter_model.safetensors exists and its size is larger than 1MB\nfile_path=\"checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf/lora_adapter/adapter_model.safetensors\"\n\nif [ ! 
-f \"$file_path\" ]; then\n    echo \"Error: File $file_path does not exist!\"\n    exit 1\nfi\n\nfile_size=$(stat -c %s \"$file_path\")\n\nmin_size_mb=1\nmin_size=$((min_size_mb * 1024 * 1024))  # 1MB = 1048576 bytes\n\nif [ \"$file_size\" -lt \"$min_size\" ]; then\n    echo \"Error: File $file_path is too small! Current size: $((file_size/1024))KB, Required: ${min_size_mb}MB\"\n    exit 1\nfi\n\necho \"Check passed: File exists and size is $(($file_size/1024/1024))MB\"\nexit 0\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh",
    "content": "# run on 8xH20\n# make sure your current working directory is the root of the project\n\nset -x\n\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_sf_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=16384 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    data.train_files=$HOME/data/retool_dapo/train.parquet \\\n    data.val_files=$HOME/data/retool_aime2024/train.parquet \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-4B \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_liger=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    +actor_rollout_ref.model.enable_activation_offloading=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml\" \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='retool_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-retool-async-sgl-no-sft-n8-v2505271300' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=100 \\\n    trainer.test_freq=20 \\\n    trainer.total_training_steps=1000 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\n#huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-3B-Instruct\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nFSDP_STRATEGY=${FSDP_STRATEGY:-fsdp}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_one_step_off_policy.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# Test script for one_step_off_policy E2E regression testing\n# This script runs one_step_off_policy with both FSDP2 and Megatron backends\n# to ensure the asynchronous training mechanism works correctly\n\nNUM_GPUS=${NUM_GPUS:-8}\nACTOR_STRATEGY=${ACTOR_STRATEGY:-\"fsdp2\"}  # fsdp2 or megatron\n\n# Download model if not exists\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\ntrain_prompt_bsz=8\nn_resp_per_prompt=3\ntrain_prompt_mini_bsz=4\n\n# Temperature parameters\ntemperature=1.0\ntop_p=1.0\ntop_k=-1\nval_top_p=0.7\n\n# One-step-off-policy specific parameters\n# Allocate 2 GPUs for rollout, remaining for training\nn_gpus_rollout=2\nn_gpus_training=$((NUM_GPUS - n_gpus_rollout))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-one-step-off-policy-${ACTOR_STRATEGY}-minimal\"\n\necho \"Running one_step_off_policy with ${ACTOR_STRATEGY} strategy\"\necho \"Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}\"\n\n# Common parameters for both FSDP2 and Megatron\ncommon_params=(\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\"\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\"\n    data.prompt_key=prompt\n    data.truncation='left'\n    data.max_prompt_length=${max_prompt_length}\n    data.max_response_length=${max_response_length}\n    data.train_batch_size=${train_prompt_bsz}\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt}\n    algorithm.adv_estimator=${adv_estimator}\n    algorithm.use_kl_in_reward=${use_kl_in_reward}\n    algorithm.kl_ctrl.kl_coef=${kl_coef}\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}\n    actor_rollout_ref.actor.clip_ratio_c=10.0\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\"\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.optim.lr_warmup_steps=-1\n    actor_rollout_ref.actor.optim.weight_decay=0.1\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80\n    actor_rollout_ref.rollout.temperature=${temperature}\n    actor_rollout_ref.rollout.top_p=${top_p}\n    actor_rollout_ref.rollout.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.name=vllm \\\n    reward_model.reward_manager=dapo\n    
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length}\n    trainer.logger=['console']\n    trainer.project_name='verl-test'\n    trainer.experiment_name=\"${exp_name}\"\n    trainer.val_before_train=False\n    trainer.test_freq=-1\n    trainer.save_freq=-1\n    trainer.total_epochs=2\n    trainer.total_training_steps=2\n    trainer.resume_mode=disable\n    trainer.nnodes=1\n    trainer.n_gpus_per_node=${n_gpus_training}\n    rollout.nnodes=1\n    rollout.n_gpus_per_node=${n_gpus_rollout}\n\n)\n\nif [ \"${ACTOR_STRATEGY}\" == \"fsdp2\" ]; then\n    echo \"Running with FSDP2 strategy...\"\n    # FSDP2 specific parameters\n    gen_tp=2\n    sp_size=2\n    fsdp_size=2\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.one_step_off_policy.main_ppo \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.actor.strategy=fsdp2 \\\n        critic.strategy=fsdp2 \\\n        actor_rollout_ref.actor.grad_clip=1.0 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.use_dynamic_bsz=True \\\n        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n        actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@\n\nelif [ \"${ACTOR_STRATEGY}\" == \"megatron\" ]; then\n    echo \"Running with Megatron strategy...\"\n    # Megatron specific parameters\n    gen_tp=2\n    train_tp=1\n    train_pp=2\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.one_step_off_policy.main_ppo \\\n        --config-path=config \\\n        --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.actor.strategy=megatron \\\n        critic.strategy=megatron \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n        
actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@\nelse\n    echo \"Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'\"\n    exit 1\nfi\n\necho \"One-step-off-policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy\""
  },
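  {
    "path": "verl_distillation/tests/special_e2e/examples/one_step_off_policy_usage.sh",
    "content": "#!/usr/bin/env bash\n# Illustrative usage sketch for run_one_step_off_policy.sh; a hypothetical\n# helper that is not wired into CI. It only exercises the env-var interface the\n# test script already exposes (NUM_GPUS, ACTOR_STRATEGY, MODEL_ID/MODEL_PATH)\n# plus trailing Hydra overrides, which the script forwards via $@.\nset -euo pipefail\n\n# FSDP2 backend on one 8-GPU node: 2 GPUs serve rollout, the other 6 train.\nNUM_GPUS=8 ACTOR_STRATEGY=fsdp2 \\\n    bash tests/special_e2e/run_one_step_off_policy.sh\n\n# Megatron backend, with one extra override appended after the script name.\nNUM_GPUS=8 ACTOR_STRATEGY=megatron \\\n    bash tests/special_e2e/run_one_step_off_policy.sh \\\n    trainer.total_training_steps=1\n"
  },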
  {
    "path": "verl_distillation/tests/special_e2e/run_ppo_trainer_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\nexport VERL_LOGGING_LEVEL=INFO\nexport VERL_PPO_LOGGING_LEVEL=INFO\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nUSE_DUMMY_MODEL=${USE_DUMMY_MODEL:-False}\nDUMMY_MODEL_PATH=${DUMMY_MODEL_PATH:-${HOME}/dummy_models/${MODEL_ID}}\nif [ \"$USE_DUMMY_MODEL\" = \"True\" ]; then\n    if [ -z \"${DUMMY_MODEL_CONFIG_PATH}\"  ]; then\n        echo \"[ERROR] DUMMY_MODEL_CONFIG_PATH not set\"\n        exit 1\n    fi\n    \n    python scripts/init_random_model.py \\\n        --hf_model_path \"${MODEL_PATH}\" \\\n        --new_config_path \"${DUMMY_MODEL_CONFIG_PATH}\" \\\n        --output_path \"${DUMMY_MODEL_PATH}\"\n\n    MODEL_PATH=\"${DUMMY_MODEL_PATH}\"\nfi\n\nTRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet}\n\nADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\nUSE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True}\nppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN:-2400}\nforward_max_token_len_per_gpu=${FWD_MAX_TOKEN_LEN:-4800}\ntrain_traj_micro_bsz_per_gpu=${MICRO_BSZ:-2} # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / 
g\n\nMAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512}\nMAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512}\n\nCOMMON_PP=${COMMON_PP:-2}\nCOMMON_VPP=${COMMON_VPP:-2}\nCOMMON_CP=${COMMON_CP:-2}\nCOMMON_TP=${COMMON_TP:-2}\nCOMMON_EP=${COMMON_EP:-1}\nCOMMON_ETP=${COMMON_ETP:-1}\n\nTRAIN_TP=${TRAIN_TP:-$COMMON_TP}\nINFER_TP=${INFER_TP:-$COMMON_TP}\n\nACTOR_PP=${ACTOR_PP:-$COMMON_PP}\nACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP}\nACTOR_CP=${ACTOR_CP:-$COMMON_CP}\nACTOR_TP=${ACTOR_TP:-$TRAIN_TP}\nACTOR_EP=${ACTOR_EP:-$COMMON_EP}\nACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP}\nROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP}\nREF_PP=${REF_PP:-$COMMON_PP}\nREF_VPP=${REF_VPP:-$COMMON_VPP}\nREF_CP=${REF_CP:-$COMMON_CP}\nREF_TP=${REF_TP:-$TRAIN_TP}\nREF_EP=${REF_EP:-$COMMON_EP}\nREF_ETP=${REF_ETP:-$COMMON_ETP}\nCRITIC_PP=${CRITIC_PP:-$COMMON_PP}\nCRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP}\nCRITIC_CP=${CRITIC_CP:-$COMMON_CP}\nCRITIC_TP=${CRITIC_TP:-$TRAIN_TP}\nCRITIC_EP=${CRITIC_EP:-$COMMON_EP}\nCRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP}\nRM_PP=${RM_PP:-$COMMON_PP}\nRM_VPP=${RM_VPP:-$COMMON_VPP}\nRM_CP=${RM_CP:-$COMMON_CP}\nRM_TP=${RM_TP:-$TRAIN_TP}\nRM_EP=${RM_EP:-$COMMON_EP}\nRM_ETP=${RM_ETP:-$COMMON_ETP}\n\nALL_OFFLOAD=${ALL_OFFLOAD:-False}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nUSE_MBRIDGE=${USE_MBRIDGE:-False}\nUSE_FUSED_KERNELS=${USE_FUSED_KERNELS:-False}\n\nLR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null}\n\nCHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']\nSKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}\nif [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then\n    CHECKPOINT_CONTENTS=['model','optimizer','extra']\nfi\n\nUSE_DIST_CKPT=${USE_DIST_CKPT:-False}\nDIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/${MODEL_ID}}\nif [ \"$USE_DIST_CKPT\" = \"True\" ]; then\n    if [ \"$USE_DUMMY_MODEL\" = \"True\" ]; then\n        DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID}\n    fi\n    python scripts/converter_hf_to_mcore.py \\\n        --hf_model_path \"${MODEL_PATH}\" \\\n        --output_path \"${DIST_CKPT_PATH}\"\nfi\n\nENGINE=${ENGINE:-\"vllm\"}\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-megatron-gsm8k-minimal\"\n\nif [ \"$ENGINE\" = \"vllm\" ]; then\n    MODE=${MODE:-\"sync\"}\n    ROLLOUT_MODE_ARG=\"actor_rollout_ref.rollout.mode=${MODE}\"\n    if [ \"$MODE\" = \"async\" ]; then\n        ROLLOUT_MODE_ARG=\"${ROLLOUT_MODE_ARG} data.return_raw_chat=True\"\n    fi\nelse\n    ROLLOUT_MODE_ARG=\"\"\nfi\n\nOPTIM_MEMORY_EFFICIENT=${OPTIM_MEMORY_EFFICIENT:-False}\n\nPROFILE_ENABLE=${PROFILE_ENABLE:-False}\nPROFILE_STEPS=${PROFILE_STEPS:-[1]}\nPROFILE_RANKS_ALL=${PROFILE_RANKS_ALL:-True}\nPROFILE_RANKS=${PROFILE_RANKS:-[0,1,2,3]}\nDISCRETE=${DISCRETE:-True}  # or True\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=\"${ADV_ESTIMATOR}\" \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    
data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=${MAX_PROMPT_LENGTH} \\\n    data.max_response_length=${MAX_RESPONSE_LENGTH} \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${USE_FUSED_KERNELS} \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=$OPTIM_MEMORY_EFFICIENT \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=$OPTIM_MEMORY_EFFICIENT \\\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=$OPTIM_MEMORY_EFFICIENT \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \\\n    actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$ACTOR_EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ACTOR_ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \\\n    actor_rollout_ref.actor.profiler.enable=$PROFILE_ENABLE \\\n    actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.rollout.name=\"${ENGINE}\" ${ROLLOUT_MODE_ARG}\\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$REF_EP \\\n    
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$REF_ETP \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    critic.optim.lr=2e-5 \\\n    critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \\\n    +critic.optim.override_optimizer_config.optimizer_cpu_offload=$OPTIM_MEMORY_EFFICIENT \\\n    +critic.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=$OPTIM_MEMORY_EFFICIENT \\\n    +critic.optim.override_optimizer_config.use_precision_aware_optimizer=$OPTIM_MEMORY_EFFICIENT \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \\\n    critic.megatron.use_mbridge=${USE_MBRIDGE} \\\n    critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \\\n    critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \\\n    critic.megatron.context_parallel_size=$CRITIC_CP \\\n    critic.megatron.tensor_model_parallel_size=$CRITIC_TP \\\n    critic.megatron.expert_model_parallel_size=$CRITIC_EP \\\n    critic.megatron.expert_tensor_parallel_size=$CRITIC_ETP \\\n    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \\\n    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \\\n    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \\\n    critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \\\n    critic.profiler.enable=$PROFILE_ENABLE \\\n    critic.profiler.ranks=$PROFILE_RANKS \\\n    critic.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    reward_model.enable=True \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    reward_model.megatron.use_mbridge=${USE_MBRIDGE} \\\n    reward_model.megatron.pipeline_model_parallel_size=$RM_PP \\\n    reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \\\n    reward_model.megatron.context_parallel_size=$RM_CP \\\n    reward_model.megatron.tensor_model_parallel_size=$RM_TP \\\n    reward_model.megatron.expert_model_parallel_size=$RM_EP \\\n    reward_model.megatron.expert_tensor_parallel_size=$RM_ETP \\\n    reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \\\n    reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    reward_model.profiler.enable=$PROFILE_ENABLE \\\n    reward_model.profiler.ranks=$PROFILE_RANKS \\\n    reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    algorithm.use_kl_in_reward=False \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.test_freq=\"${TEST_FREQ}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" \\\n    global_profiler.profile_continuous_steps=True \\\n    global_profiler.tool=nsys \\\n    
global_profiler.steps=$PROFILE_STEPS \\\n    global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@\n"
  },
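  {
    "path": "verl_distillation/tests/special_e2e/examples/ppo_trainer_megatron_usage.sh",
    "content": "#!/usr/bin/env bash\n# Hypothetical usage sketch for run_ppo_trainer_megatron.sh, not part of the CI\n# matrix. The script fans the COMMON_* defaults out to per-role parallelism\n# (ACTOR/REF/CRITIC/RM), so a sweep usually only overrides the shared knobs;\n# every variable below is one the script itself defines.\nset -euo pipefail\n\n# Quick smoke run: collapse the shared parallelism knobs (VPP left null with PP=1).\nNUM_GPUS=8 COMMON_TP=1 COMMON_PP=1 COMMON_VPP=null COMMON_CP=1 \\\n    bash tests/special_e2e/run_ppo_trainer_megatron.sh\n\n# Train with TP=2 but serve rollouts with TP=1, offloading params/grads/optimizer.\nTRAIN_TP=2 INFER_TP=1 ALL_OFFLOAD=True \\\n    bash tests/special_e2e/run_ppo_trainer_megatron.sh\n"
  },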
  {
    "path": "verl_distillation/tests/special_e2e/run_prime.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet}\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-prime-minimal\"\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=${train_prompt_bsz} \\\n    reward_model.reward_manager=prime \\\n    trainer.val_before_train=False \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_training_steps=1 $@\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n#huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \\\n#    --local-dir $HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$HOME/data/r1/test.parquet \\\n    data.prompt_key=prompt \\\n    data.batch_size=1024 \\\n    data.n_samples=1 \\\n    data.output_path=$HOME/data/r1/test-output-k1.parquet \\\n    model.path=$HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \\\n    rollout.temperature=0.6 \\\n    rollout.top_p=0.95 \\\n    rollout.prompt_length=1024 \\\n    rollout.response_length=32768 \\\n    rollout.tensor_model_parallel_size=1 \\\n    rollout.gpu_memory_utilization=0.95 \\\n    rollout.max_num_batched_tokens=65536 \\\n    rollout.enforce_eager=False \\\n    rollout.free_cache_engine=True\n\npython3 -m recipe.r1.main_eval \\\n    data.path=$HOME/data/r1/test-output-k1.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=responses \\\n    custom_reward_function.path=recipe/r1/reward_score.py \\\n    custom_reward_function.name=reward_func"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_spin.sh",
    "content": "set -e\nset -x\nNUM_GPUS=${NUM_GPUS:-8}\n\nexp_name=\"Qwen2.5-0.5B-Instruct-spin-minimal\"\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nCUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \\\n  data.train_files=\"${HOME}/data/gsm8k/train.parquet\" \\\n  data.val_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n  data.train_batch_size=1024 \\\n  data.max_prompt_length=1024 \\\n  data.max_response_length=1024 \\\n  actor_rollout_ref.model.path=$MODEL_PATH \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=8 \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size=64 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=4 \\\n  trainer.nnodes=1 \\\n  trainer.save_freq=-1 \\\n  trainer.test_freq=1 \\\n  +trainer.log_freq=1 \\\n  trainer.ref_update_freq=1 \\\n  trainer.total_training_steps=1 \\\n  trainer.total_epochs=1000 2>&1 | tee verl_demo.log"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_sppo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# in e2e_sppo.yml, we set NUM_GPUS=8 L20\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nexp_name=\"Qwen2.5-0.5B-Instruct-sppo-minimal\"\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\npython3 -m recipe.sppo.main_sppo \\\n    data.train_files=\"${HOME}/data/math/train.parquet\" \\\n    data.val_files=\"${HOME}/data/math/test.parquet\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"$MODEL_PATH\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang  \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=$NUM_GPUS \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_training_steps=1 \\\n    trainer.total_epochs=2 $@\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/run_test.sh",
    "content": "#!/bin/bash\nset -xeuo pipefail\n\n# Get the configuration name and engine name from arguments\nCONFIG_NAME=\"$1\"\nENGINE=\"${2:-vllm}\"\n\n# Download model if needed\n#huggingface-cli download Qwen/Qwen2.5-0.5B --local-dir \"$HOME/models/Qwen/Qwen2.5-0.5B\"\n\n# Run the training with the specified configuration\npython3 -m verl.trainer.main_ppo \\\n    --config-name \"$CONFIG_NAME\" \"$@\" "
  },
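  {
    "path": "verl_distillation/tests/special_e2e/examples/run_test_usage.sh",
    "content": "#!/usr/bin/env bash\n# Minimal usage sketch for run_test.sh (illustrative file, not a CI entry).\n# Arg 1 is the Hydra config name, arg 2 the rollout engine (defaults to vllm),\n# and anything after them is forwarded as extra overrides; the config name and\n# override below are example values.\nbash tests/special_e2e/run_test.sh ppo_trainer vllm \\\n    trainer.total_training_steps=1\n"
  },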
  {
    "path": "verl_distillation/tests/special_e2e/sft/compare_sft_engine_results.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\n\nimport torch\n\n\ndef get_result(file):\n    file = os.path.expanduser(file)\n    result = []\n    with open(file) as f:\n        lines = f.readlines()\n        for line in lines:\n            result.append(json.loads(line))\n    return result\n\n\ndef compare_results(golden_results, other_result):\n    golden_loss = golden_results[0][\"data\"][\"train/loss\"]\n    golden_grad_norm = golden_results[0][\"data\"][\"train/grad_norm\"]\n\n    loss = other_result[0][\"data\"][\"train/loss\"]\n    grad_norm = other_result[0][\"data\"][\"train/grad_norm\"]\n\n    torch.testing.assert_close(golden_loss, loss, atol=1e-2, rtol=1e-2)\n    torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=1e-2)\n\n\nif __name__ == \"__main__\":\n    golden_results = get_result(\"~/verl/test/log/golden.jsonl\")\n\n    # get all other results\n    other_results = {}\n    # walk through all files in ~/verl/test/log\n    for file in os.listdir(os.path.expanduser(\"~/verl/test/log/verl_sft_test\")):\n        if file.endswith(\".jsonl\"):\n            other_results[file] = get_result(os.path.join(os.path.expanduser(\"~/verl/test/log/verl_sft_test\"), file))\n\n    # # compare results\n    for file, other_result in other_results.items():\n        print(f\"compare results {file}\")\n        compare_results(golden_results, other_result)\n\n    print(\"All results are close to golden results\")\n"
  },
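  {
    "path": "verl_distillation/tests/special_e2e/sft/example_compare_inputs.sh",
    "content": "#!/usr/bin/env bash\n# Hypothetical sketch, not part of the suite: fabricate the file-logger layout\n# that compare_sft_engine_results.py expects, then run the comparator over it.\n# Each .jsonl line must carry {\"data\": {\"train/loss\": ..., \"train/grad_norm\": ...}};\n# the golden record lives at ~/verl/test/log/golden.jsonl and every candidate\n# run under ~/verl/test/log/verl_sft_test/. The numbers here are made up and\n# sit within the comparator's tolerances (atol 1e-2 on loss, 1e-4 on grad norm).\nset -euo pipefail\n\nmkdir -p ~/verl/test/log/verl_sft_test\necho '{\"data\": {\"train/loss\": 1.2345, \"train/grad_norm\": 0.6789}}' > ~/verl/test/log/golden.jsonl\necho '{\"data\": {\"train/loss\": 1.2349, \"train/grad_norm\": 0.6789}}' > ~/verl/test/log/verl_sft_test/demo.jsonl\n\n# should print \"All results are close to golden results\"\npython3 tests/special_e2e/sft/compare_sft_engine_results.py\n"
  },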
  {
    "path": "verl_distillation/tests/special_e2e/sft/run_sft.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nENTRYPOINT=${ENTRYPOINT:-\"-m verl.trainer.fsdp_sft_trainer\"}\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\n\nSP_SIZE=${SP_SIZE:-1}\nLIGER=${LIGER:-False}\nMULTITURN=${MULTITURN:-False}\nLORA_RANK=${LORA_RANK:-0}\nRM_PAD=${RM_PAD:-True}\n\nTOTAL_TRAIN_STEP=${TOTAL_TRAIN_STEP:-1}\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:-1}\n\nmicro_bsz=2\nNUM_GPUS=8\n\nproject_name=\"verl-test\"\nexp_name=\"$(basename \"${MODEL_ID,,}\")-sft-minimal\"\nckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}}\n\nmkdir -p \"${ckpts_home}\"\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    data.response_dict_keys=['answer'] \\\n    data.multiturn.enable=\"${MULTITURN}\" \\\n    data.multiturn.messages_key=messages \\\n    optim.lr=1e-4 \\\n    data.micro_batch_size_per_gpu=${micro_bsz} \\\n    model.strategy=fsdp \\\n    model.partial_pretrain=\"${MODEL_PATH}\" \\\n    model.lora_rank=\"${LORA_RANK}\" \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.use_liger=\"${LIGER}\" \\\n    ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    use_remove_padding=\"${RM_PAD}\" \\\n    trainer.default_local_dir=\"${ckpts_home}\" \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.total_training_steps=${TOTAL_TRAIN_STEP} \\\n    trainer.save_freq=${SAVE_FREQ} \\\n    trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \\\n    trainer.max_ckpt_to_keep=1 \\\n    trainer.resume_mode=${RESUME_MODE} \\\n    trainer.logger=['console'] $@\n\nrm -rf \"${ckpts_home:?}/*\""
  },
  {
    "path": "verl_distillation/tests/special_e2e/sft/run_sft_engine_gsm8k.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nENTRYPOINT=${ENTRYPOINT:-\"-m verl.trainer.sft_trainer\"}\n\nNUM_GPUS=${NUM_GPUS:-1}\n\nTRAIN_FILES=~/data/gsm8k_sft/train.parquet\nVAL_FILES=~/data/gsm8k_sft/test.parquet\n\nbackend=${BACKEND:-fsdp}\n\nproject_name=verl_sft_test\n\nRESUME_MODE=disable\n\nckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen3-0.6B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n#huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nSP_SIZE=${SP_SIZE:-1}\nFSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}}\nFSDP_STRATEGY=${FSDP_STRATEGY:-\"fsdp\"}\n\nTP_SIZE=${TP_SIZE:-1}\nPP_SIZE=${PP_SIZE:-1}\nVPP_SIZE=${VPP_SIZE:-null}\nCP_SIZE=${CP_SIZE:-1}\n\nPAD_MODE=${PAD_MODE:-no_padding}\n\nUSE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}\n\nFSDP_ENGINE_CONFIG=\"\\\n    engine=${backend} \\\n    optim=${backend} \\\n    optim.lr=1e-5 \\\n    optim.lr_warmup_steps_ratio=0.2 \\\n    optim.weight_decay=0.1 \\\n    optim.betas=\"[0.9,0.95]\" \\\n    optim.clip_grad=1.0 \\\n    optim.min_lr_ratio=0.1 \\\n    optim.lr_scheduler_type=cosine \\\n    engine.ulysses_sequence_parallel_size=${SP_SIZE} \\\n    engine.strategy=${FSDP_STRATEGY} \\\n    engine.fsdp_size=${FSDP_SIZE}\"\n\n\nMEGATRON_ENGINE_CONFIG=\"\\\n    engine=${backend} \\\n    optim=${backend} \\\n    optim.lr=1e-5 \\\n    optim.lr_warmup_steps_ratio=0.2 \\\n    optim.weight_decay=0.1 \\\n    optim.betas=\"[0.9,0.95]\" \\\n    optim.clip_grad=1.0 \\\n    optim.lr_warmup_init=0 \\\n    optim.lr_decay_style=cosine \\\n    optim.min_lr=1e-6 \\\n    engine.tensor_model_parallel_size=${TP_SIZE} \\\n    engine.pipeline_model_parallel_size=${PP_SIZE} \\\n    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \\\n    engine.context_parallel_size=${CP_SIZE}\"\n\nif [ \"$backend\" = \"fsdp\" ]; then\n    ENGINE_CONFIG=\"$FSDP_ENGINE_CONFIG\"\n    echo \"Using fsdp engine\"\n    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}\nelse\n    ENGINE_CONFIG=\"$MEGATRON_ENGINE_CONFIG\"\n    echo \"Using megatron engine\"\n    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}\nfi\n\nmkdir -p \"${ckpts_home}\"\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=256 \\\n    data.pad_mode=${PAD_MODE} \\\n    data.truncation=error \\\n    data.use_dynamic_bsz=True \\\n    data.max_token_len_per_gpu=8192 \\\n    data.messages_key=messages \\\n    model.path=$MODEL_PATH \\\n    model.use_remove_padding=${USE_REMOVE_PADDING} \\\n    ${ENGINE_CONFIG} \\\n    trainer.test_freq=after_each_epoch \\\n    trainer.save_freq=-1 \\\n    trainer.logger=['console','file'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.total_epochs=2 \\\n    trainer.total_training_steps=2 \\\n    trainer.default_local_dir=\"${ckpts_home}\" \\\n    trainer.resume_mode=${RESUME_MODE} \\\n\n    # trainer.total_training_steps=${TOTAL_TRAIN_STEP} \\\n    # trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \\\n    # trainer.max_ckpt_to_keep=1 \\\n    \nrm -rf \"${ckpts_home:?}/*\""
  },
  {
    "path": "verl_distillation/tests/special_e2e/sft/test_sft_engine_all.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nrm -rf ~/verl/test/log\nmkdir -p ~/verl/test/log\n\nexport VERL_FILE_LOGGER_ROOT=~/verl/test/log\n\n# test with single gpu as golden\necho \"run with single gpu as golden\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\n# test with fsdp 1\necho \"run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\necho \"run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\necho \"run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding\"\nBACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\necho \"run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding\"\nBACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\n# test use_remove_padding and pad_mode no_padding\necho \"run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\n\n# test with fsdp 2\necho \"run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\necho \"run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2\"\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\necho \"run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2\"\nBACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\nBACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\nBACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\n# test with megatron\necho \"run with tp1 pp1 cp1 num_gpus1\"\nBACKEND=megatron TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 NUM_GPUS=1 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\necho \"run with tp2 pp2 vpp2 cp1 num_gpus8\"\nBACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh\n\n# TODO: toggle with following test when cp is fixed\n# BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh >& ~/verl/test/log/gsm8k-tp2_pp2_vpp2_cp1_num_gpus8.log\n\npython3 tests/special_e2e/sft/compare_sft_engine_results.py\n\nrm -rf ~/verl/test/log\n"
  },
  {
    "path": "verl_distillation/tests/special_e2e/sft/test_sp_loss_match.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):\n    \"\"\"Test consistency between original forward pass and SP+rmpad forward passes.\n\n    Args:\n        trainer: The FSDPSFTTrainer instance to test\n        total_steps: Number of steps to test (default: 4)\n    \"\"\"\n    if trainer.device_mesh.get_rank() == 0:\n        print(\"\\nStarting debug comparison between original and SP+rmpad forward passes...\")\n        print(f\"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}\")\n        print(f\"Remove padding: {trainer.use_remove_padding}\\n\")\n\n    steps_remaining = total_steps\n\n    for epoch in range(1):  # Just one epoch for testing\n        trainer.train_sampler.set_epoch(epoch=epoch)\n        for data in trainer.train_dataloader:\n            data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()\n            trainer.fsdp_model.train()\n            micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu)\n\n            for idx, micro_batch in enumerate(micro_batches):\n                if trainer.device_mesh.get_rank() == 0:\n                    print(f\"\\nProcessing micro batch {idx + 1}/{len(micro_batches)}\")\n\n                # Compute losses using both methods\n                # Disable SP and rmpad\n                trainer.use_remove_padding = False\n                old_sp = trainer.config.ulysses_sequence_parallel_size\n                trainer.config.ulysses_sequence_parallel_size = 1\n                loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)\n\n                # Do SP and rmpad\n                trainer.config.ulysses_sequence_parallel_size = old_sp\n                trainer.use_remove_padding = True\n                loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)\n\n                # Collect losses across all ranks\n                loss_ref_all = loss_ref.clone()\n                loss_sp_all = loss_sp.clone()\n                torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)\n                torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)\n\n                # Calculate relative difference of averaged losses\n                rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)\n\n                if trainer.device_mesh.get_rank() == 0:\n                    print(\"\\nComparison Results (Averaged across ranks):\")\n                    print(f\"Reference Loss: {loss_ref_all.item():.6f}\")\n                    print(f\"SP+rmpad 
Loss: {loss_sp_all.item():.6f}\")\n                    print(f\"Relative Difference: {rel_diff.item():.6f}\")\n\n                    assert rel_diff.item() < 1e-2, \"Significant difference detected between averaged losses!\"\n                    print(\"Loss difference is within the acceptable range.\")\n\n                steps_remaining -= 1\n                if steps_remaining == 0:\n                    break\n            if steps_remaining == 0:\n                break\n        break\n\n    if trainer.device_mesh.get_rank() == 0:\n        print(\"\\nDebug comparison completed successfully.\")\n\n\ndef create_trainer(config):\n    \"\"\"Create and initialize a trainer instance with the given config.\n\n    Args:\n        config: Configuration object with training parameters\n\n    Returns:\n        FSDPSFTTrainer: Initialized trainer instance\n    \"\"\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    device_mesh = init_device_mesh(device_type=\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"fsdp\",))\n\n    dp_size = world_size // config.ulysses_sequence_parallel_size\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n\n    # build tokenizer and datasets first\n    from verl.trainer.fsdp_sft_trainer import create_sft_dataset\n    from verl.utils import hf_tokenizer\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)\n    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)\n    train_dataset = create_sft_dataset(\n        config.data.train_files, config.data, tokenizer, max_samples=config.data.get(\"train_max_samples\", -1)\n    )\n    val_dataset = create_sft_dataset(\n        config.data.val_files, config.data, tokenizer, max_samples=config.data.get(\"val_max_samples\", -1)\n    )\n\n    return FSDPSFTTrainer(\n        config=config,\n        device_mesh=device_mesh,\n        ulysses_device_mesh=ulysses_device_mesh,\n        tokenizer=tokenizer,\n        train_dataset=train_dataset,\n        val_dataset=val_dataset,\n    )\n\n\ndef main(config):\n    \"\"\"Main function to run trainer tests.\n\n    Args:\n        config: Configuration object with training parameters\n    \"\"\"\n    trainer = create_trainer(config)\n    test_trainer_forward_consistency(trainer)\n\n\nif __name__ == \"__main__\":\n    import hydra\n    from omegaconf import DictConfig\n\n    @hydra.main(config_path=\"../../../verl/trainer/config\", config_name=\"sft_trainer\")\n    def hydra_entry(cfg: DictConfig) -> None:\n        main(cfg)\n\n    hydra_entry()\n"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen2_5_05b_dapo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-16}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\nadv_estimator=grpo\n\nkl_coef=0.0\nuse_kl_in_reward=False\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=seq_reward\nmax_num_gen_batches=10\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ngen_prompt_bsz=$((train_prompt_bsz * 4))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-dapo-minimal\"\n\npython3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\" \\\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n    reward_model.reward_manager=dapo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.actor.entropy_checkpointing=True \\\n    
actor_rollout_ref.ref.entropy_checkpointing=True \\\n    actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \\\n    actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.resume_mode=disable \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu $@\n"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen2_5_05b_grpo.sh",
    "content": "set -x\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu $@\n"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen2_5_05b_grpo_mindspeed.sh",
    "content": "set -x\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\nUSE_DIST_CKPT=${USE_DIST_CKPT:-False}\nDIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/qwen2_5_05b_grpo_mindspeed}\nif [ \"$USE_DIST_CKPT\" = \"True\" ]; then\n    if [ \"$USE_DUMMY_MODEL\" = \"True\" ]; then\n        DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID}\n    fi\n    python scripts/converter_hf_to_mcore.py \\\n        --hf_model_path \"${MODEL_PATH}\" \\\n        --output_path \"${DIST_CKPT_PATH}\"\nfi\n\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=${MODEL_ID} \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.strategy=megatron \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True $@\n"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh",
    "content": "set -x\n\nmkdir -p ./save_ckpts\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=8 \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=32 \\\n    model.partial_pretrain=\"${MODEL_PATH}\" \\\n    trainer.default_local_dir=./save_ckpts \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    model.lora_rank=32 \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.strategy=fsdp \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true \\\n    trainer.device=npu\n\nrm -rf ./outputs ./save_ckpts\n"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen2_5_vl_3b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-VL-3B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.ref.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_3b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_distillation/tests/special_npu/run_qwen3_06b_ppo.sh",
    "content": "set -x\n\n# TODO (FightingZhen) Env VLLM_USE_V1=1 is not supported in vllm==0.7.3\n# export VLLM_USE_V1=1\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}  # TODO: change to Qwen3-0.6B when CI server is ready\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=128 \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=8 \\\n    critic.ulysses_sequence_parallel_size=2 \\\n    critic.model.fsdp_config.param_offload=True \\\n    critic.model.fsdp_config.optimizer_offload=True \\\n    critic.use_dynamic_bsz=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\"]' \\\n    trainer.project_name='verl_ppo_example_gsm8k_qwen3' \\\n    trainer.experiment_name='qwen3_06b_fsdp' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu $@\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_api_docs.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nFail CI if any function or class that is publicly exported via\n``__all__`` lacks a docstring.\n\nUsage\n-----\n  # Check specific modules or packages\n  python check_docstrings.py mypkg.core mypkg.utils\n\n  # Check an entire source tree (all top-level packages under cwd)\n  python check_docstrings.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport importlib\nimport inspect\nimport pkgutil\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType\nfrom typing import Iterable\n\n_ALLOW_LIST = [\n    \"verl.third_party.vllm.LLM\",\n    \"verl.third_party.vllm.parallel_state\",\n    \"verl.utils.profiler.WorkerProfiler\",\n    \"verl.utils.profiler.WorkerProfilerExtension\",\n    \"verl.utils.profiler.log_gpu_memory_usage\",\n    \"verl.utils.profiler.log_print\",\n    \"verl.utils.profiler.mark_annotate\",\n    \"verl.utils.profiler.mark_end_range\",\n    \"verl.utils.profiler.mark_start_range\",\n    \"verl.models.mcore.qwen2_5_vl.get_vision_model_config\",\n    \"verl.models.mcore.qwen2_5_vl.get_vision_projection_config\",\n    \"verl.models.mcore.mbridge.freeze_moe_router\",\n    \"verl.models.mcore.mbridge.make_value_model\",\n    \"verl.utils.transformers_compat.flash_attn_supports_top_left_mask\",\n]\n\n\ndef iter_submodules(root: ModuleType) -> Iterable[ModuleType]:\n    \"\"\"Yield *root* and every sub-module inside it.\"\"\"\n    yield root\n\n    def print_pkg_error(pkg_name):\n        print(f\"[warn] Skipping {pkg_name!r}\", file=sys.stderr)\n\n    if getattr(root, \"__path__\", None):  # only packages have __path__\n        for mod_info in pkgutil.walk_packages(root.__path__, prefix=f\"{root.__name__}.\", onerror=print_pkg_error):\n            try:\n                yield importlib.import_module(mod_info.name)\n            except Exception as exc:\n                print(f\"[warn] Skipping {mod_info.name!r}: {exc}\", file=sys.stderr)\n\n\ndef names_missing_doc(mod: ModuleType) -> list[str]:\n    \"\"\"Return fully-qualified names that need docstrings.\"\"\"\n    missing: list[str] = []\n    public = getattr(mod, \"__all__\", [])\n    for name in public:\n        obj = getattr(mod, name, None)\n        if f\"{mod.__name__}.{name}\" in _ALLOW_LIST:\n            continue\n        if obj is None:\n            # Exported but not found in the module: flag it anyway.\n            missing.append(f\"{mod.__name__}.{name}  (not found)\")\n            continue\n\n        if inspect.isfunction(obj) or inspect.isclass(obj):\n            doc = inspect.getdoc(obj)\n            if not doc or not doc.strip():\n                missing.append(f\"{mod.__name__}.{name}\")\n    return missing\n\n\ndef check_module(qualname: str) -> list[str]:\n    \"\"\"Import *qualname* and check it (and sub-modules).\"\"\"\n    try:\n        module = importlib.import_module(qualname)\n    except ModuleNotFoundError as exc:\n        print(f\"[error] 
Cannot import '{qualname}': {exc}\", file=sys.stderr)\n        return [qualname]\n\n    missing: list[str] = []\n    for submod in iter_submodules(module):\n        missing.extend(names_missing_doc(submod))\n    return missing\n\n\ndef autodiscover_packages() -> list[str]:\n    \"\"\"Detect top-level packages under CWD when no argument is given.\"\"\"\n    pkgs: list[str] = []\n    for p in Path.cwd().iterdir():\n        if p.is_dir() and (p / \"__init__.py\").exists():\n            pkgs.append(p.name)\n    return pkgs\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(description=__doc__)\n    parser.add_argument(\n        \"modules\",\n        nargs=\"*\",\n        help=\"Fully-qualified module or package names (defaults to every top-level package found in CWD).\",\n    )\n    args = parser.parse_args()\n\n    targets = args.modules or autodiscover_packages()\n    if not targets:\n        raise ValueError(\"[error] No modules specified and none detected automatically.\")\n\n    all_missing: list[str] = []\n    for modname in targets:\n        all_missing.extend(check_module(modname))\n\n    if all_missing:\n        print(\"\\nMissing docstrings:\")\n        for name in sorted(all_missing):\n            print(f\"  - {name}\")\n        raise ValueError(\"Missing docstrings detected. Please enhance them with docs accordingly.\")\n\n    print(\"✅ All exported functions/classes have docstrings.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_dataproto_usage.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis CI test is used for checking whether DataProto is used in the code of some directory\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nSEARCH_WHITELIST = []\n\nSEARCH_KEYWORDS = [\"DataProto\"]\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--directory\", \"-d\", required=True, type=str)\n    args = parser.parse_args()\n    directory_in_str = args.directory\n\n    pathlist = Path(directory_in_str).glob(\"**/*.py\")\n    for path in pathlist:\n        path_in_str = str(path.absolute())\n\n        # judge whether current path is in pre-defined search whitelist or not.\n        path_in_whitelist = False\n\n        for sw in SEARCH_WHITELIST:\n            # for easy debugging in non-linux system\n            sw = sw.replace(\"/\", os.sep)\n            if sw in path_in_str:\n                print(f\"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.\")\n                path_in_whitelist = True\n                break\n\n        if path_in_whitelist:\n            continue\n\n        with open(path_in_str, encoding=\"utf-8\") as f:\n            file_content = f.read()\n\n            find_invalid_device_management = False\n\n            for sk in SEARCH_KEYWORDS:\n                if sk in file_content:\n                    find_invalid_device_management = True\n                    break\n\n            print(\n                f\"[CHECK] File {path_in_str} is detected for DataProto usage check, check result: \"\n                f\"{'success' if not find_invalid_device_management else f'failed, because detect {sk}'}.\"\n            )\n\n            assert not find_invalid_device_management, (\n                f\"file {path_in_str} contains DataProto usage, please use TensorDict directly!\"\n            )\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_device_api_usage.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`.\nSearch targets include .py files in verl/recipe and verl/verl.\nSome files that must contain \".cuda\", \"cuda\" or \"nccl\" keyword is pre-defined in whitelist below.\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\n# directory or file path must contain keyword \".cuda\" or \"cuda\"\nCUDA_KEYWORD_CHECK_WHITELIST = [\n    \"verl/utils/device.py\",\n    \"recipe/prime/prime_ray_trainer.py\",  # appear in default device_name\n    \"recipe/spin/spin_trainer.py\",  # appear in default device_name\n    \"recipe/sppo/sppo_ray_trainer.py\",  # appear in default device_name\n    \"recipe/one_step_off_policy/ray_trainer.py\",  # appear in default device_name\n    \"recipe/transfer_queue/ray_trainer.py\",  # appear in default device_name\n    \"verl/utils/profiler/nvtx_profile.py\",  # appear in NsightSystemsProfiler\n    \"verl/utils/kernel/linear_cross_entropy.py\",  # appear in nvidia nvtx\n    \"verl/utils/rendezvous/ray_backend.py\",  # appear in cupy importance\n    \"verl/single_controller/ray/base.py\",  # appear in default device_name\n    \"verl/trainer/ppo/ray_trainer.py\",  # appear in default device_name\n    \"verl/utils/reward_score/sandbox_fusion/utils.py\",  # appear in sandbox language type\n    \"verl/workers/reward_model/megatron/reward_model.py\",  # appear in default device_name\n    \"verl/third_party/torch/distributed/_state_dict_utils.py\",  # torch monkey patch fixes\n    \"verl/third_party/torch/distributed/checkpoint/state_dict.py\",  # torch monkey patch fixes\n    \"verl/workers/engine/base.py\",  # appear in default device_name\n    \"verl/workers/engine/fsdp/transformer_impl.py\",  # appear in default device_name\n    \"verl/workers/rollout/vllm_rollout/vllm_async_server.py\",  # appear in config.cudagraph_capture_sizes\n    \"verl/workers/rollout/sglang_rollout/async_sglang_server.py\",  # manually set CUDA_VISIBLE_DEVICES\n]\n\n# directory or file path must contain keyword \"nccl\"\nNCCL_KEYWORD_CHECK_WHITELIST = [\n    \"verl/utils/device.py\",\n    \"verl/third_party/sglang/parallel_state.py\",  # appear in default backend\n    \"verl/recipe/fully_async_policy/param_sync.py\",  # fully_async_policy in default backend\n]\n\nSEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST\n\nSEARCH_KEYWORDS = [\".cuda\", '\"cuda\"', '\"nccl\"']\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--directory\", \"-d\", required=True, type=str)\n    args = parser.parse_args()\n    directory_in_str = args.directory\n\n    pathlist = Path(directory_in_str).glob(\"**/*.py\")\n    for path in pathlist:\n        path_in_str = str(path.absolute())\n\n        # judge whether current path is in pre-defined search whitelist or not.\n       
 path_in_whitelist = False\n\n        for sw in SEARCH_WHITELIST:\n            # for easy debugging in non-linux system\n            sw = sw.replace(\"/\", os.sep)\n            if sw in path_in_str:\n                print(f\"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.\")\n                path_in_whitelist = True\n                break\n\n        if path_in_whitelist:\n            continue\n\n        with open(path_in_str, encoding=\"utf-8\") as f:\n            file_content = f.read()\n\n            find_invalid_device_management = False\n\n            for sk in SEARCH_KEYWORDS:\n                if sk in file_content:\n                    find_invalid_device_management = True\n                    break\n\n            print(\n                f\"[CHECK] File {path_in_str} is detected for device api usage check, check result: \"\n                f\"{'success' if not find_invalid_device_management else f'failed, because detect {sk}'}.\"\n            )\n\n            assert not find_invalid_device_management, (\n                f'file {path_in_str} contains .cuda/\"cuda\"/\"nccl\" usage, please use api in '\n                f\"verl/utils/device.py directly.\"\n            )\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_docs_time_info.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCheck that every .md and .rst file under docs/ contains the substring \"Last updated\",\nwith an allow-list for exceptions.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\n\n# === CONFIGURATION ===\n\n# Relative paths (to docs/) or glob patterns to skip checking\nALLOW_LIST = {\n    \"docs/README.md\",  # you can list individual files\n    \"docs/legacy/*.rst\",  # or glob patterns\n    \"docs/index.rst\",\n    \"docs/start/install.rst\",\n    \"docs/start/quickstart.rst\",\n    \"docs/README_vllm0.7.md\",\n}\n\n# The folder to scan\nDOCS_DIR = Path(\"docs\")\n\n# === SCRIPT ===\n\n\ndef is_allowed(path: Path) -> bool:\n    \"\"\"\n    Return True if `path` matches any entry in ALLOW_LIST.\n    \"\"\"\n    rel = str(path)\n    for pattern in ALLOW_LIST:\n        if Path(rel).match(pattern):\n            return True\n    return False\n\n\ndef main():\n    if not DOCS_DIR.exists():\n        print(f\"Error: Documentation directory '{DOCS_DIR}' does not exist.\", file=sys.stderr)\n        sys.exit(1)\n\n    missing = []\n\n    # Gather all .md and .rst files under docs/\n    for ext in (\"*.md\", \"*.rst\"):\n        for path in DOCS_DIR.rglob(ext):\n            if is_allowed(path):\n                continue\n\n            text = path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n            if \"Last updated\" not in text:\n                missing.append(path)\n\n    # Report\n    if missing:\n        print(\"\\nThe following files are missing the 'Last updated' string:\\n\")\n        for p in missing:\n            print(f\"  - {p}\")\n        print(f\"\\nTotal missing: {len(missing)}\\n\", file=sys.stderr)\n        raise AssertionError(\n            \"Some documentation files lack a 'Last updated' line. Please include info such as \"\n            \"'Last updated: mm/dd/yyyy' to indicate the last update time of the document.\"\n        )\n    else:\n        print(\"✅ All checked files contain 'Last updated'.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_docstrings.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPython script to check docstrings for functions and classes in specified files.\nChecks that every public function and class has proper docstring documentation.\n\"\"\"\n\nimport ast\nimport os\nimport sys\n\n\nclass DocstringChecker(ast.NodeVisitor):\n    \"\"\"AST visitor to check for missing docstrings in functions and classes.\"\"\"\n\n    def __init__(self, filename: str):\n        self.filename = filename\n        self.missing_docstrings: list[tuple[str, str, int]] = []\n        self.current_class = None\n        self.function_nesting_level = 0\n\n    def visit_FunctionDef(self, node: ast.FunctionDef):\n        \"\"\"Visit function definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\") and self.function_nesting_level == 0:\n            if not self._has_docstring(node):\n                func_name = f\"{self.current_class}.{node.name}\" if self.current_class else node.name\n                self.missing_docstrings.append((func_name, self.filename, node.lineno))\n\n        self.function_nesting_level += 1\n        self.generic_visit(node)\n        self.function_nesting_level -= 1\n\n    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):\n        \"\"\"Visit async function definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\") and self.function_nesting_level == 0:\n            if not self._has_docstring(node):\n                func_name = f\"{self.current_class}.{node.name}\" if self.current_class else node.name\n                self.missing_docstrings.append((func_name, self.filename, node.lineno))\n\n        self.function_nesting_level += 1\n        self.generic_visit(node)\n        self.function_nesting_level -= 1\n\n    def visit_ClassDef(self, node: ast.ClassDef):\n        \"\"\"Visit class definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\"):\n            if not self._has_docstring(node):\n                self.missing_docstrings.append((node.name, self.filename, node.lineno))\n\n        old_class = self.current_class\n        self.current_class = node.name\n        self.generic_visit(node)\n        self.current_class = old_class\n\n    def _has_docstring(self, node) -> bool:\n        \"\"\"Check if a node has a docstring.\"\"\"\n        return ast.get_docstring(node) is not None\n\n\ndef check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]:\n    \"\"\"Check docstrings in a single file.\"\"\"\n    try:\n        with open(filepath, encoding=\"utf-8\") as f:\n            content = f.read()\n\n        tree = ast.parse(content, filename=filepath)\n        checker = DocstringChecker(filepath)\n        checker.visit(tree)\n        return checker.missing_docstrings\n\n    except Exception as e:\n        print(f\"Error processing {filepath}: {e}\")\n        return []\n\n\ndef main():\n    \"\"\"Main function to check docstrings in 
specified files.\"\"\"\n\n    files_to_check = [\n        \"verl/trainer/ppo/ray_trainer.py\",\n        \"verl/trainer/main_ppo.py\",\n        \"verl/trainer/ppo/reward.py\",\n        \"verl/utils/reward_score/__init__.py\",\n        \"verl/trainer/ppo/core_algos.py\",\n        \"verl/experimental/agent_loop/agent_loop.py\",\n        \"verl/workers/sharding_manager/fsdp_vllm.py\",\n        \"verl/workers/sharding_manager/fsdp_ulysses.py\",\n    ]\n\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    repo_path = os.path.dirname(os.path.dirname(script_dir))\n\n    if not os.path.exists(repo_path):\n        print(f\"Repository path {repo_path} does not exist!\")\n        sys.exit(1)\n\n    os.chdir(repo_path)\n\n    all_missing_docstrings = []\n\n    print(\"Checking docstrings in specified files...\")\n    print(\"=\" * 60)\n\n    for file_path in files_to_check:\n        if not os.path.exists(file_path):\n            print(f\"Warning: File {file_path} does not exist!\")\n            continue\n\n        print(f\"Checking {file_path}...\")\n        missing = check_file_docstrings(file_path)\n        all_missing_docstrings.extend(missing)\n\n        if missing:\n            print(f\"  Found {len(missing)} missing docstrings\")\n        else:\n            print(\"  All functions and classes have docstrings ✓\")\n\n    print(\"=\" * 60)\n\n    if all_missing_docstrings:\n        print(f\"\\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:\")\n        print(\"-\" * 60)\n\n        by_file = {}\n        for name, filepath, lineno in all_missing_docstrings:\n            if filepath not in by_file:\n                by_file[filepath] = []\n            by_file[filepath].append((name, lineno))\n\n        for filepath in sorted(by_file.keys()):\n            print(f\"\\n{filepath}:\")\n            for name, lineno in sorted(by_file[filepath], key=lambda x: x[1]):\n                print(f\"  - {name} (line {lineno})\")\n\n        print(f\"\\nTotal missing docstrings: {len(all_missing_docstrings)}\")\n\n        raise Exception(f\"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!\")\n\n    else:\n        print(\"\\n✅ All functions and classes have proper docstrings!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_license.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom argparse import ArgumentParser\nfrom pathlib import Path\nfrom typing import Iterable\n\nlicense_head_bytedance = \"Copyright 2024 Bytedance Ltd. and/or its affiliates\"\nlicense_head_bytedance_25 = \"Copyright 2025 Bytedance Ltd. and/or its affiliates\"\n# Add custom license headers below\nlicense_head_prime = \"Copyright 2024 PRIME team and/or its affiliates\"\nlicense_head_individual = \"Copyright 2025 Individual Contributor:\"\nlicense_head_sglang = \"Copyright 2023-2024 SGLang Team\"\nlicense_head_modelbest = \"Copyright 2025 ModelBest Inc. and/or its affiliates\"\nlicense_head_amazon = \"Copyright 2025 Amazon.com Inc and/or its affiliates\"\nlicense_head_facebook = \"Copyright (c) 2016-     Facebook, Inc\"\nlicense_head_meituan = \"Copyright 2025 Meituan Ltd. and/or its affiliates\"\nlicense_headers = [\n    license_head_bytedance,\n    license_head_bytedance_25,\n    license_head_prime,\n    license_head_individual,\n    license_head_sglang,\n    license_head_modelbest,\n    license_head_amazon,\n    license_head_facebook,\n    license_head_meituan,\n]\n\n\ndef get_py_files(path_arg: Path) -> Iterable[Path]:\n    \"\"\"get py files under a dir. if already py file return it\n\n    Args:\n        path_arg (Path): path to scan for py files\n\n    Returns:\n        Iterable[Path]: list of py files\n    \"\"\"\n    if path_arg.is_dir():\n        return path_arg.glob(\"**/*.py\")\n    elif path_arg.is_file() and path_arg.suffix == \".py\":\n        return [path_arg]\n    return []\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\n        \"--directories\",\n        \"-d\",\n        required=True,\n        type=Path,\n        nargs=\"+\",\n        help=\"List of directories to check for license headers\",\n    )\n    args = parser.parse_args()\n\n    # Collect all Python files from specified directories\n    pathlist = set(path for path_arg in args.directories for path in get_py_files(path_arg))\n\n    for path in pathlist:\n        # because path is object not string\n        path_in_str = str(path.absolute())\n        print(path_in_str)\n        with open(path_in_str, encoding=\"utf-8\") as f:\n            file_content = f.read()\n\n            has_license = False\n            for lh in license_headers:\n                if lh in file_content:\n                    has_license = True\n                    break\n            assert has_license, f\"file {path_in_str} does not contain license\"\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_pr_description.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env python3\nimport json\nimport os\n\n# Number of lines to check\nNUM_LINES = 5\n\n\n# Custom exception types for clear error handling\nclass TemplateFileError(Exception):\n    pass\n\n\nclass PRBodyLoadError(Exception):\n    pass\n\n\nclass PRDescriptionError(Exception):\n    pass\n\n\n# Path to the PR template file\ntemplate_file = os.path.join(os.getenv(\"GITHUB_WORKSPACE\", \".\"), \".github\", \"PULL_REQUEST_TEMPLATE.md\")\n\n\ndef load_template(path):\n    \"\"\"\n    Load only the first NUM_LINES of the PR template file as a list of lines,\n    without stripping any characters.\n    \"\"\"\n    lines = []\n    try:\n        with open(path, encoding=\"utf-8\") as f:\n            for _ in range(NUM_LINES):\n                line = f.readline()\n                if not line:\n                    break\n                lines.append(line.strip())\n        return lines\n    except Exception as e:\n        raise TemplateFileError(f\"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}\") from e\n\n\ndef load_pr_body(event_path):\n    try:\n        with open(event_path, encoding=\"utf-8\") as f:\n            payload = json.load(f)\n        return payload.get(\"pull_request\", {}).get(\"body\", \"\") or \"\"\n    except Exception as e:\n        raise PRBodyLoadError(f\"Failed to read PR body from {event_path}: {e}\") from e\n\n\ndef check_pr_description(body, template_lines):\n    \"\"\"\n    Compare the first NUM_LINES lines of the PR body to the template lines.\n    If they match exactly, the placeholder was not modified.\n    \"\"\"\n    pr_lines = body.splitlines(keepends=True)\n    pr_first = [x.strip() for x in pr_lines[:NUM_LINES]]\n    if pr_first == template_lines:\n        raise PRDescriptionError(\n            \"It looks like you haven't updated the '### What does this PR do?' section. Please replace \"\n            \"the placeholder text with a concise description of what your PR does.\"\n        )\n    else:\n        print(pr_first)\n        print(template_lines)\n\n\ndef main():\n    event_path = os.getenv(\"GITHUB_EVENT_PATH\")\n    if not event_path:\n        raise OSError(\"GITHUB_EVENT_PATH is not set.\")\n\n    template_lines = load_template(template_file)\n    pr_body = load_pr_body(event_path)\n    check_pr_description(pr_body, template_lines)\n\n    print(\"✅ '### What does this PR do?' section has been filled out.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/check_pr_title.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport re\n\n# Get PR title from environment\npr_title = os.environ.get(\"PR_TITLE\", \"\").strip()\n\n# Define rules\nallowed_modules = [\"fsdp\", \"megatron\", \"sglang\", \"vllm\", \"rollout\", \"trainer\"]\nallowed_modules += [\"tests\", \"training_utils\", \"recipe\", \"hardware\", \"deployment\"]\nallowed_modules += [\"ray\", \"worker\", \"single_controller\", \"misc\", \"docker\", \"ci\"]\nallowed_modules += [\"perf\", \"model\", \"algo\", \"env\", \"tool\", \"ckpt\", \"doc\", \"data\", \"cfg\"]\nallowed_types = [\"feat\", \"fix\", \"refactor\", \"chore\", \"test\"]\n\n# Check for [1/N] prefix and extract the rest of the title\nprogress_match = re.match(r\"^\\[\\d/[\\dNn]\\]\\s*(.+)$\", pr_title, re.IGNORECASE)\nif progress_match:\n    pr_title = progress_match.group(1).strip()\n\n# Check for [BREAKING] prefix and extract the rest of the title\nbreaking_match = re.match(r\"^\\[BREAKING\\]\\s*(.+)$\", pr_title, re.IGNORECASE)\nif breaking_match:\n    core_pr_title = breaking_match.group(1).strip()\n    is_breaking = True\nelse:\n    core_pr_title = pr_title\n    is_breaking = False\n\n# Build dynamic regex pattern for modules (now working on core_pr_title)\nre_modules_pattern = re.compile(r\"^\\[([a-z_,\\s]+)\\]\", re.IGNORECASE)\nre_modules = re_modules_pattern.match(core_pr_title)\nif not re_modules:\n    print(f\"❌ Invalid PR title: '{pr_title}'\")\n    print(\"Expected format: [BREAKING][module] type: description\")\n    print(f\"Allowed modules: {', '.join(allowed_modules)}\")\n    raise Exception(\"Invalid PR title\")\nelse:\n    modules = re.findall(r\"[a-z_]+\", re_modules.group(1).lower())\n    if not all(module in allowed_modules for module in modules):\n        invalid_modules = [module for module in modules if module not in allowed_modules]\n        print(f\"❌ Invalid modules: {', '.join(invalid_modules)}\")\n        print(f\"Allowed modules: {', '.join(allowed_modules)}\")\n        raise Exception(\"Invalid PR title\")\n\ntypes_pattern = \"|\".join(re.escape(t) for t in allowed_types)\nre_types_pattern = re.compile(rf\"^\\[[a-z_,\\s]+\\]\\s+({types_pattern}):\\s+.+$\", re.IGNORECASE)\nmatch = re_types_pattern.match(core_pr_title)\n\nif not match:\n    print(f\"❌ Invalid PR title: '{pr_title}'\")\n    print(\"Expected format: [BREAKING][module] type: description\")\n    print(f\"Allowed types: {', '.join(allowed_types)}\")\n    raise Exception(\"Invalid PR title\")\n\nchange_type = match.group(1).lower()\n\n# Build the success message\nbreaking_info = \" (BREAKING CHANGE)\" if is_breaking else \"\"\nprint(f\"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}\")\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/test_config_docs.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\nfrom pathlib import Path\n\n\ndef validate_yaml_format(yaml_lines):\n    errors = []\n    i = 0\n\n    while i < len(yaml_lines):\n        line = yaml_lines[i]\n        stripped = line.strip()\n\n        # Skip empty lines\n        if stripped == \"\":\n            i += 1\n            continue\n\n        # Match YAML keys like \"field:\" or \"field: value\"\n        key_match = re.match(r\"^(\\s*)([a-zA-Z0-9_]+):\", line)\n        if key_match:\n            # Check if there's a comment above\n            if i == 0 or not yaml_lines[i - 1].strip().startswith(\"#\"):\n                errors.append(f\"Missing comment above line {i + 1}: {line.strip()}\")\n\n            # Check for inline comment\n            if \"#\" in line and not stripped.startswith(\"#\"):\n                comment_index = line.index(\"#\")\n                colon_index = line.index(\":\")\n                if comment_index > colon_index:\n                    errors.append(f\"Inline comment found on line {i + 1}: {line.strip()}\")\n\n            # Check for blank line after this key line (unless next is a deeper indent)\n            if i + 1 < len(yaml_lines):\n                next_line = yaml_lines[i + 1]\n                next_stripped = next_line.strip()\n\n                # If next is not empty and not a deeper nested line, enforce blank line\n                if next_stripped != \"\":\n                    errors.append(f\"Missing blank line after line {i + 1}: {line.strip()}\")\n\n        i += 1\n\n    return errors\n\n\ndef test_trainer_config_doc():\n    yamls_to_inspect = [\n        \"verl/trainer/config/ppo_trainer.yaml\",\n        \"verl/trainer/config/actor/actor.yaml\",\n        \"verl/trainer/config/actor/dp_actor.yaml\",\n        \"verl/trainer/config/critic/critic.yaml\",\n        \"verl/trainer/config/critic/dp_critic.yaml\",\n        \"verl/trainer/config/ref/ref.yaml\",\n        \"verl/trainer/config/ref/dp_ref.yaml\",\n        \"verl/trainer/config/rollout/rollout.yaml\",\n    ]\n    success = True\n    for yaml_to_inspect in yamls_to_inspect:\n        yaml_path = Path(yaml_to_inspect)  # path to your YAML file\n        with open(yaml_path) as f:\n            lines = f.readlines()\n\n        validation_errors = validate_yaml_format(lines)\n        if validation_errors:\n            success = False\n            print(\"YAML documentation format check failed:\")\n            print(f\"Please read the top block of {yaml_to_inspect} to see format rules:\\n\")\n            for err in validation_errors:\n                print(\" -\", err)\n\n    if not success:\n        raise Exception(\"Please fix documentation format.\")\n    else:\n        print(\"YAML format check passed ✅\")\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/test_import.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef test_import():\n    import verl\n\n    print(verl.__version__)\n\n\ndef test_single_controller_import():\n    import verl.single_controller\n\n    print(verl.single_controller.__version__)\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/type_coverage_check.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Custom type annotation check tool.\nTo inspect the type annotation for functions in the entire codebase, please run:\nfind verl -type f -name \"*.py\" | xargs -n 1 python3 tests/special_sanity/type_coverage_check.py --all-lines\n--debug --target-file\n\"\"\"\n\nimport argparse\nimport ast\nimport linecache\nimport subprocess\nfrom pathlib import Path\n\n\ndef get_changed_files() -> list[Path]:\n    result = subprocess.run(\n        [\"git\", \"diff\", \"--name-only\", \"--diff-filter=AM\", \"origin/main...HEAD\"], stdout=subprocess.PIPE, text=True\n    )\n    return [Path(f) for f in result.stdout.splitlines() if f.endswith(\".py\")]\n\n\ndef get_changed_lines(file_path: Path) -> set[int]:\n    result = subprocess.run(\n        [\"git\", \"diff\", \"-U0\", \"origin/main...HEAD\", \"--\", str(file_path)],\n        stdout=subprocess.PIPE,\n        text=True,\n    )\n    lines: set[int] = set()\n    for line in result.stdout.splitlines():\n        if line.startswith(\"@@\"):\n            for part in line.split():\n                try:\n                    if part.startswith(\"+\") and \",\" in part:\n                        start, count = map(int, part[1:].split(\",\"))\n                        lines.update(range(start, start + count))\n                    elif part.startswith(\"+\") and \",\" not in part:\n                        lines.add(int(part[1:]))\n                except Exception:\n                    # (vermouth1992) There are many edge cases here because + can be in the changed program\n                    pass\n    return lines\n\n\nCHECK_SUCCESS = 0\nCHECK_WARNING = 1\nCHECK_FAILURE = -1\n\n\ndef should_check_type(arg_name: str) -> bool:\n    if arg_name in (\"self\", \"cls\"):\n        return False\n    if arg_name.startswith(\"*\"):\n        return False\n    return True\n\n\ndef has_type_annotations(node: ast.AST, debug: bool = False) -> int:\n    if isinstance(node, ast.FunctionDef):\n        is_private = node.name.startswith(\"_\")\n        has_ann = (\n            all(arg.annotation is not None for arg in node.args.args if should_check_type(arg.arg))\n            and node.returns is not None\n        )\n        if has_ann or is_private:\n            return CHECK_SUCCESS\n        else:\n            if debug:\n                print(node, [(arg.annotation, arg.arg) for arg in node.args.args if should_check_type(arg.arg)])\n            return CHECK_FAILURE\n    return CHECK_SUCCESS\n\n\ndef check_file(\n    file_path: Path, changed_lines: set[int], debug: bool = False\n) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]:\n    with open(file_path) as f:\n        source: str = f.read()\n    tree = ast.parse(source, filename=str(file_path))\n    annotated = 0\n    total = 0\n    warning_lines: list[tuple[Path, int, str]] = []\n    failure_lines: list[tuple[Path, int, str]] = []\n\n    for node in ast.walk(tree):\n        if 
hasattr(node, \"lineno\") and node.lineno in changed_lines:\n            if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign):\n                total += 1\n                result = has_type_annotations(node, debug)\n                if result == CHECK_SUCCESS or result == CHECK_WARNING:\n                    annotated += 1\n                    if result == CHECK_WARNING:\n                        warning_lines.append(\n                            (file_path, node.lineno, linecache.getline(str(file_path), node.lineno).strip())\n                        )\n                else:\n                    source_line = linecache.getline(str(file_path), node.lineno).strip()\n                    failure_lines.append((file_path, node.lineno, source_line))\n\n    return annotated, total, warning_lines, failure_lines\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--threshold\", type=float, default=0.3, help=\"Minimum ratio of annotated lines required (0.0 - 1.0)\"\n    )\n    parser.add_argument(\"--target-file\", type=str, default=None, help=\"Path to the Python source file to analyse\")\n    parser.add_argument(\n        \"--all-lines\",\n        action=\"store_true\",\n        help=\"Check all lines in the file instead of only changed lines based on git\",\n    )\n    parser.add_argument(\"--debug\", action=\"store_true\", help=\"Add debugging logs\")\n    args = parser.parse_args()\n\n    total_changed = 0\n    total_annotated = 0\n    all_warnings: list[tuple[Path, int, str]] = []\n    all_failures: list[tuple[Path, int, str]] = []\n\n    target_files = [args.target_file] if args.target_file is not None else get_changed_files()\n    for fpath in target_files:\n        if \"tests/\" in str(fpath):\n            continue\n        if args.all_lines:\n            changed_lines = [i + 1 for i in range(len(open(fpath).readlines()))]\n        else:\n            changed_lines = get_changed_lines(fpath)\n        annotated, total, warning_lines, failure_lines = check_file(fpath, changed_lines, args.debug)\n        total_annotated += annotated\n        total_changed += total\n        all_warnings.extend(warning_lines)\n        all_failures.extend(failure_lines)\n\n    ratio = (total_annotated / total_changed) if total_changed else 1.0\n\n    print(\n        f\"🔍 Type coverage on {'all' if args.all_lines else 'changed'} lines: \"\n        f\"{total_annotated}/{total_changed} = {ratio:.2%}. Files inspected: {target_files}\"\n    )\n\n    if all_warnings:\n        print(\"\\n⚠️ Suggest Improve: Lines missing type annotations for inputs and outputs:\\n\")\n        for fname, lineno, line in all_warnings:\n            print(f\"{fname}:{lineno}: {line}\")\n\n    if all_failures:\n        print(\"⚠️ [ERROR] Lines missing type annotations for inputs and outputs:\\n\")\n        for fname, lineno, line in all_failures:\n            print(f\"{fname}:{lineno}: {line}\")\n\n    if ratio < args.threshold:\n        print(\n            f\"Please add type annotations for inputs and outputs to meet threshold {args.threshold}. \"\n            f\"Cases exempt from checking:\"\n        )\n        print(\"1. Private methods.\")\n        print(\"2. Args with name in ('self', 'cls'), or *args / **kwargs\")\n        print(\"3. 
Files under tests/\")\n        raise Exception(f\"\\n❌ Type coverage below threshold ({args.threshold:.0%}).\")\n    else:\n        if all_warnings or all_failures:\n            print(\"\")\n        print(\"✅ Type annotation coverage acceptable.\\n\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/validate_imported_docs.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nverify_imported_docs.py\n\nAssert that every function or class *explicitly imported* (via\n`from <module> import <name>`) in a given Python file has a docstring.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport ast\nimport importlib\nimport inspect\nimport pathlib\nimport sys\n\n\ndef _parse_args() -> argparse.Namespace:\n    p = argparse.ArgumentParser(description=\"Verify that imported functions/classes have docstrings.\")\n    p.add_argument(\n        \"--target-file\",\n        default=\"verl/trainer/ppo/ray_trainer.py\",\n        help=\"Path to the Python source file to analyse (e.g. verl/trainer/ppo/ray_trainer.py)\",\n    )\n    p.add_argument(\n        \"--allow-list\",\n        default=[\"omegaconf.open_dict\"],\n        help=\"a list of third_party dependencies that do not have proper docs :(\",\n    )\n    p.add_argument(\n        \"--project-root\",\n        default=\".\",\n        help=\"Directory to prepend to PYTHONPATH so local packages resolve (default: .)\",\n    )\n    p.add_argument(\n        \"--quiet\",\n        action=\"store_true\",\n        help=\"Suppress success message (still prints errors).\",\n    )\n    return p.parse_args()\n\n\ndef _import_attr(module_name: str, attr_name: str):\n    \"\"\"Import `module_name` then return `getattr(module, attr_name)`.\"\"\"\n    module = importlib.import_module(module_name)\n    return getattr(module, attr_name)\n\n\ndef _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str]) -> list[str]:\n    \"\"\"Return a list of error strings (empty == success).\"\"\"\n    # Ensure local packages resolve\n    sys.path.insert(0, str(project_root.resolve()))\n\n    tree = ast.parse(py_file.read_text(), filename=str(py_file))\n    problems: list[str] = []\n\n    for node in ast.walk(tree):\n        if not isinstance(node, ast.ImportFrom):\n            continue\n\n        # Relative imports (level > 0) get the leading dots stripped\n        module_name = \".\" * node.level + (node.module or \"\")\n        for alias in node.names:\n            if alias.name == \"*\":\n                problems.append(\n                    f\"{py_file}:{node.lineno} - wildcard import `from {module_name} import *` cannot be verified.\"\n                )\n                continue\n\n            imported_name = alias.name\n\n            try:\n                obj = _import_attr(module_name, imported_name)\n            except Exception:  # pragma: no cover – wide net for import quirks\n                pass\n                # For some reason the module cannot be imported, skip for now\n                # problems.append(\n                #     f\"{py_file}:{node.lineno} - could not resolve \"\n                #     f\"`{imported_name}` from `{module_name}` ({exc})\"\n                # )\n                continue\n\n            if f\"{module_name}.{imported_name}\" in allow_list:\n  
              continue\n            if inspect.isfunction(obj) or inspect.isclass(obj):\n                doc = inspect.getdoc(obj)\n                if not (doc and doc.strip()):\n                    kind = \"class\" if inspect.isclass(obj) else \"function\"\n                    problems.append(\n                        f\"{py_file}:{node.lineno} - {kind} `{module_name}.{imported_name}` is missing a docstring.\"\n                    )\n\n    return problems\n\n\ndef main() -> None:\n    args = _parse_args()\n    target_path = pathlib.Path(args.target_file).resolve()\n    project_root = pathlib.Path(args.project_root).resolve()\n\n    if not target_path.is_file():\n        raise Exception(f\"❌ Target file not found: {target_path}\")\n\n    errors = _check_file(target_path, project_root, args.allow_list)\n\n    if errors:\n        print(\"Docstring verification failed:\\n\")\n        print(\"\\n\".join(f\" • {e}\" for e in errors))\n        raise Exception(\"❌ Docstring verification failed.\")\n\n    if not args.quiet:\n        print(f\"✅ All explicitly imported functions/classes in {target_path} have docstrings.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_sanity/validate_structure.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env python3\n\"\"\"\nValidate that test file subfolders mirror the top-level package layout.\n\nUsage examples\n--------------\n\n# Typical run (defaults: impl_root=my_project, tests_root=tests)\npython check_tests_structure.py\n\n# Custom layout and extra allowed folders\npython check_tests_structure.py \\\n    --impl-root verl \\\n    --tests-root tests \\\n    --allow-dirs special_e2e special_sanity special_standalone special_distributed\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport sys\nfrom pathlib import Path\n\n\ndef discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[str]:\n    \"\"\"Return the set of first-level directories that tests may live under.\"\"\"\n    allowed = {p.name for p in impl_root.iterdir() if p.is_dir()}\n    allowed.update(extra)\n    return allowed\n\n\ndef find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str]) -> list[str]:\n    \"\"\"Return a list of error strings for test files in the wrong place.\"\"\"\n    errors: list[str] = []\n    for test_file in tests_root.rglob(\"test*.py\"):\n        if str(test_file) in allowed_files:\n            continue\n        rel_parts = test_file.relative_to(tests_root).parts\n        if len(rel_parts) < 2:\n            errors.append(f\"{test_file}: must be inside one of {sorted(allowed)} (not at tests root)\")\n            continue\n\n        first_folder = rel_parts[0]\n        if first_folder not in allowed:\n            errors.append(\n                f\"{test_file}: subfolder '{first_folder}' under tests/ is not an allowed module. 
\"\n                f\"The valid ones are: {sorted(allowed)}\"\n            )\n    return errors\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(description=\"Check that test files follow tests/<module>/… layout.\")\n    parser.add_argument(\n        \"--impl-root\",\n        type=Path,\n        default=\"verl\",\n        help=\"Implementation root (default: my_project)\",\n    )\n    parser.add_argument(\n        \"--tests-root\",\n        type=Path,\n        default=\"tests\",\n        help=\"Root of test tree (default: tests)\",\n    )\n    parser.add_argument(\n        \"--allow-dirs\",\n        nargs=\"*\",\n        default=[\"special_e2e\", \"special_sanity\", \"special_standalone\", \"special_distributed\"],\n        help=\"Extra top-level test folders that are exempt from the rule\",\n    )\n    parser.add_argument(\n        \"--allow-files\",\n        nargs=\"*\",\n        default=[\n            \"tests/test_protocol_on_cpu.py\",\n            \"tests/test_base_config_on_cpu.py\",\n            \"tests/test_protocol_v2_on_cpu.py\",\n        ],\n        help=\"Extra top-level test folders that are exempt from the rule\",\n    )\n    args = parser.parse_args()\n\n    if not args.impl_root.is_dir():\n        raise Exception(f\"Implementation root '{args.impl_root}' does not exist.\")\n    if not args.tests_root.is_dir():\n        raise Exception(f\"Tests root '{args.tests_root}' does not exist.\")\n\n    allowed = discover_allowed_modules(args.impl_root, args.allow_dirs)\n    violations = find_violations(args.tests_root, allowed, args.allow_files)\n\n    if violations:\n        print(\"❌  Test layout violations found:\\n\", file=sys.stderr)\n        for err in violations:\n            print(\"  -\", err, file=sys.stderr)\n\n        print(\n            f\"\\nGuideline:\\n  Place each test file under   tests/<module_name>/…\\n  where <module_name> is \"\n            f\"one of the top-level packages inside '{args.impl_root}', or is explicitly listed via --allow-dirs.\\n\",\n            file=sys.stderr,\n        )\n        raise Exception(\"❌  Test layout violations found.\")\n\n    print(\"✅  Tests folder structure looks good.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/special_standalone/README.md",
    "content": "The standalone test folder is reserved for tests that require dedicated environment (e.g. memory stress tests)\n"
  },
  {
    "path": "verl_distillation/tests/special_standalone/test_memory_buffers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest memory buffers\n- We start with two models with the same weights\n- We use Memory buffer to make one of the models and then compare the parameters\n\"\"\"\n\nimport gc\n\nimport torch\nfrom transformers import LlamaConfig, LlamaModel\n\n\ndef test_memory_buffers():\n    llama_config = LlamaConfig(\n        vocab_size=256,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=2,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n    )\n\n    model = LlamaModel(config=llama_config).cuda()\n    model_copy = LlamaModel(config=llama_config).cuda()\n    model_copy.load_state_dict(model.state_dict())\n\n    norm_factor = 1024**3\n\n    t_before = torch.cuda.get_device_properties(0).total_memory / norm_factor\n    r_before = torch.cuda.memory_reserved(0) / norm_factor\n    a_before = torch.cuda.memory_allocated(0) / norm_factor\n\n    print(f\"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB\")\n\n    t = torch.cuda.get_device_properties(0).total_memory / norm_factor\n    r = torch.cuda.memory_reserved(0) / norm_factor\n    a = torch.cuda.memory_allocated(0) / norm_factor\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    print(f\"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB\")\n\n    change_ratio = (a - a_before) / a_before\n    assert change_ratio < 0.01, f\"make sure the allocated change is less than 1%, Got {change_ratio}\"\n\n    for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True):\n        assert name1 == name2\n        assert torch.eq(param1.data, param2.data).all(), f\"{param1.data}, {param2.data}, {name1}\"\n\n\nif __name__ == \"__main__\":\n    test_memory_buffers()\n"
  },
  {
    "path": "verl_distillation/tests/test_base_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.base_config import BaseConfig\n\n\n@pytest.fixture\ndef base_config_mock():\n    \"\"\"Fixture to create a mock BaseConfig instance with test attributes.\"\"\"\n    mock_config = BaseConfig()\n    mock_config.test_attr = \"test_value\"\n    return mock_config\n\n\ndef test_getitem_success(base_config_mock):\n    \"\"\"Test __getitem__ with existing attribute (happy path).\"\"\"\n    assert base_config_mock[\"test_attr\"] == \"test_value\"\n\n\ndef test_getitem_nonexistent_attribute(base_config_mock):\n    \"\"\"Test __getitem__ with non-existent attribute (exception path 1).\"\"\"\n    with pytest.raises(AttributeError):\n        _ = base_config_mock[\"nonexistent_attr\"]\n\n\ndef test_getitem_invalid_key_type(base_config_mock):\n    \"\"\"Test __getitem__ with invalid key type (exception path 2).\"\"\"\n    with pytest.raises(TypeError):\n        _ = base_config_mock[123]  # type: ignore\n"
  },
  {
    "path": "verl_distillation/tests/test_protocol_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\n\nimport numpy as np\nimport pytest\nimport tensordict\nimport torch\nfrom packaging.version import parse as parse_version\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.protocol import (\n    deserialize_single_tensor,\n    deserialize_tensordict,\n    serialize_single_tensor,\n    serialize_tensordict,\n    union_numpy_dict,\n    union_tensor_dict,\n)\nfrom verl.utils import tensordict_utils as tu\n\n\ndef test_union_tensor_dict():\n    obs = torch.randn(100, 10)\n\n    data1 = TensorDict({\"obs\": obs, \"act\": torch.randn(100, 3)}, batch_size=[100])\n    data2 = TensorDict({\"obs\": obs, \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100)}, batch_size=[100])\n\n    data_with_copied_obs = TensorDict(\n        {\"obs\": obs.clone(), \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100)}, batch_size=[100]\n    )\n\n    union_tensor_dict(data1, data2)\n    with pytest.raises(AssertionError):\n        union_tensor_dict(data1, data_with_copied_obs)\n\n\ndef test_union_numpy_dict():\n    \"\"\"\n    A comprehensive test suite for union_numpy_dict, covering standard use\n    cases, N-dimensional arrays, object-dtype arrays, and NaN value handling.\n    \"\"\"\n    arr_3d = np.arange(8).reshape((2, 2, 2))\n    union_numpy_dict({\"a\": arr_3d}, {\"a\": arr_3d})\n    arr1 = np.array([1, \"hello\", np.array([2, 3])], dtype=object)\n    arr2 = np.array([1, \"hello\", np.array([2, 3])], dtype=object)\n    union_numpy_dict({\"a\": arr1}, {\"a\": arr2})\n    # --- Test Case 1: The original test with mixed object/float types ---\n    # This test case from the original test file is preserved.\n    data = np.random.random(100)\n    # This array intentionally mixes float('nan') and the string 'nan'\n    nan_data = [float(\"nan\") for _ in range(99)]\n    nan_data.append(\"nan\")\n    nan_data_arr = np.array(nan_data, dtype=object)\n\n    dict1 = {\"a\": data, \"b\": nan_data_arr}\n    dict2_same = {\"a\": data.copy(), \"b\": nan_data_arr.copy()}\n    dict3_different = {\"a\": np.random.random(100)}\n\n    union_numpy_dict(dict1, dict2_same)  # Should pass\n    with pytest.raises(AssertionError):\n        union_numpy_dict(dict1, dict3_different)\n\n    # --- Test Case 2: Standard 3D arrays (fixes the core bug) ---\n    arr_3d = np.arange(24, dtype=np.int32).reshape((2, 3, 4))\n    dict_3d_1 = {\"nd_array\": arr_3d}\n    dict_3d_2_same = {\"nd_array\": arr_3d.copy()}\n    dict_3d_3_different = {\"nd_array\": arr_3d + 1}\n\n    union_numpy_dict(dict_3d_1, dict_3d_2_same)  # Should pass\n    with pytest.raises(AssertionError, match=\"`nd_array` in tensor_dict1 and tensor_dict2 are not the same object.\"):\n        union_numpy_dict(dict_3d_1, dict_3d_3_different)\n\n    # --- Test Case 3: Nested 2D and 4D object-dtype arrays ---\n    sub_arr1 = np.array([1, 2])\n    sub_arr2 = np.array([3.0, 4.0])\n    # 2D object array\n    
arr_2d_obj = np.array([[sub_arr1, \"text\"], [sub_arr2, None]], dtype=object)\n    arr_2d_obj_diff = np.array([[sub_arr1, \"text\"], [sub_arr2, \"other\"]], dtype=object)\n\n    union_numpy_dict({\"data\": arr_2d_obj}, {\"data\": arr_2d_obj.copy()})  # Should pass\n    with pytest.raises(AssertionError):\n        union_numpy_dict({\"data\": arr_2d_obj}, {\"data\": arr_2d_obj_diff})\n\n    # 4D object array to ensure deep recursion is robust\n    arr_4d_obj = np.array([[[[sub_arr1]]], [[[sub_arr2]]]], dtype=object)\n    arr_4d_obj_diff = np.array([[[[sub_arr1]]], [[[np.array([9, 9])]]]], dtype=object)\n\n    union_numpy_dict({\"data\": arr_4d_obj}, {\"data\": arr_4d_obj.copy()})  # Should pass\n    with pytest.raises(AssertionError):\n        union_numpy_dict({\"data\": arr_4d_obj}, {\"data\": arr_4d_obj_diff})\n\n    # --- Test Case 4: Explicit NaN value comparison ---\n    # This verifies that our new _deep_equal logic correctly handles NaNs.\n    nan_arr = np.array([1.0, np.nan, 3.0])\n    dict_nan_1 = {\"data\": nan_arr}\n    dict_nan_2_same = {\"data\": np.array([1.0, np.nan, 3.0])}  # A new array with same values\n    dict_nan_3_different_val = {\"data\": np.array([1.0, 2.0, 3.0])}\n    dict_nan_4_different_pos = {\"data\": np.array([np.nan, 1.0, 3.0])}\n\n    # NaNs in the same position should be considered equal for merging.\n    union_numpy_dict(dict_nan_1, dict_nan_2_same)  # Should pass\n\n    with pytest.raises(AssertionError):\n        union_numpy_dict(dict_nan_1, dict_nan_3_different_val)\n    with pytest.raises(AssertionError):\n        union_numpy_dict(dict_nan_1, dict_nan_4_different_pos)\n\n    # --- Test Case 5: Circular reference handling ---\n    # Create two separate, but structurally identical, circular references.\n    # This should pass without a RecursionError.\n    circ_arr_1 = np.array([None], dtype=object)\n    circ_arr_1[0] = circ_arr_1\n\n    circ_arr_2 = np.array([None], dtype=object)\n    circ_arr_2[0] = circ_arr_2\n\n    union_numpy_dict({\"data\": circ_arr_1}, {\"data\": circ_arr_2})  # Should pass\n\n    # Create a circular reference and a non-circular one.\n    # This should fail with an AssertionError because they are different.\n    non_circ_arr = np.array([None], dtype=object)\n\n    with pytest.raises(AssertionError):\n        union_numpy_dict({\"data\": circ_arr_1}, {\"data\": non_circ_arr})\n\n\ndef test_tensor_dict_constructor():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 10, 3)\n    data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act})\n\n    assert data.batch.batch_size == torch.Size([100])\n\n    with pytest.raises(AssertionError):\n        data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act}, num_batch_dims=2)\n\n    with pytest.raises(AssertionError):\n        data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act}, num_batch_dims=3)\n\n\ndef test_tensor_dict_make_iterator():\n    obs = torch.randn(100, 10)\n    labels = [random.choice([\"abc\", \"cde\"]) for _ in range(100)]\n    dataset = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels})\n\n    data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)\n    data_list_1 = []\n    for data in data_iter_1:\n        data_list_1.append(data)\n\n    data_iter_2 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)\n    data_list_2 = []\n    for data in data_iter_2:\n        data_list_2.append(data)\n\n    for data1, data2 in zip(data_list_1, data_list_2, strict=True):\n        assert 
isinstance(data1, DataProto)\n        assert isinstance(data2, DataProto)\n        result = torch.all(torch.eq(data1.batch[\"obs\"], data2.batch[\"obs\"]))\n        if not result.item():\n            print(data1.batch[\"obs\"])\n            print(data2.batch[\"obs\"])\n            raise AssertionError()\n        non_tensor_result = np.all(np.equal(data1.non_tensor_batch[\"labels\"], data2.non_tensor_batch[\"labels\"]))\n        if not non_tensor_result.item():\n            print(data1.non_tensor_batch[\"labels\"])\n            print(data2.non_tensor_batch[\"labels\"])\n            raise AssertionError()\n\n\ndef test_reorder():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abdce\"})\n    data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))\n\n    assert torch.all(torch.eq(data.batch[\"obs\"], torch.tensor([4, 5, 3, 1, 2, 6])))\n    assert np.all(data.non_tensor_batch[\"labels\"] == np.array([\"d\", \"e\", \"c\", \"a\", \"b\", \"f\"]))\n    assert data.meta_info == {\"name\": \"abdce\"}\n\n\ndef test_chunk_concat():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abdce\"})\n\n    with pytest.raises(AssertionError):\n        data.chunk(5)\n\n    data_split = data.chunk(2)\n    assert len(data_split) == 2\n    assert torch.all(torch.eq(data_split[0].batch[\"obs\"], torch.tensor([1, 2, 3])))\n    assert np.all(data_split[0].non_tensor_batch[\"labels\"] == np.array([\"a\", \"b\", \"c\"]))\n    assert data_split[0].meta_info == {\"name\": \"abdce\"}\n\n    assert torch.all(torch.eq(data_split[1].batch[\"obs\"], torch.tensor([4, 5, 6])))\n    assert np.all(data_split[1].non_tensor_batch[\"labels\"] == np.array([\"d\", \"e\", \"f\"]))\n    assert data_split[1].meta_info == {\"name\": \"abdce\"}\n\n    concat_data = DataProto.concat(data_split)\n    assert torch.all(torch.eq(concat_data.batch[\"obs\"], data.batch[\"obs\"]))\n    assert np.all(concat_data.non_tensor_batch[\"labels\"] == data.non_tensor_batch[\"labels\"])\n    assert concat_data.meta_info == data.meta_info\n\n\ndef test_concat_metrics_from_multiple_workers():\n    \"\"\"Test that concat() properly merges metrics from all workers in distributed training.\"\"\"\n    # Simulate 3 workers each with their own metrics\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n    obs3 = torch.tensor([5, 6])\n\n    # Each worker has different metrics (as list of dict format)\n    worker1_metrics = [{\"loss\": 0.5, \"accuracy\": 0.9}]\n    worker2_metrics = [{\"loss\": 0.6, \"accuracy\": 0.85}]\n    worker3_metrics = [{\"loss\": 0.55, \"accuracy\": 0.88}]\n\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"metrics\": worker1_metrics, \"config_flag\": True})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"metrics\": worker2_metrics, \"config_flag\": True})\n    data3 = DataProto.from_dict(tensors={\"obs\": obs3}, meta_info={\"metrics\": worker3_metrics, \"config_flag\": True})\n\n    # Concat all workers' data\n    concat_data = DataProto.concat([data1, data2, data3])\n\n    # Verify tensors are concatenated\n    assert torch.all(torch.eq(concat_data.batch[\"obs\"], torch.tensor([1, 2, 3, 4, 5, 6])))\n\n    # Verify ALL workers' metrics are flattened to dict of lists\n    expected_metrics = {\"loss\": 
[0.5, 0.6, 0.55], \"accuracy\": [0.9, 0.85, 0.88]}\n    assert concat_data.meta_info[\"metrics\"] == expected_metrics\n\n    # Verify config flags are preserved from first worker\n    assert concat_data.meta_info[\"config_flag\"] is True\n\n\ndef test_concat_with_empty_and_non_list_meta_info():\n    \"\"\"Test concat() handles edge cases: empty meta_info, non-list values, and None.\"\"\"\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n\n    # Worker 1 has metrics, worker 2 doesn't\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"metrics\": [{\"loss\": 0.5}], \"flag\": True})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"flag\": True})\n\n    concat_data = DataProto.concat([data1, data2])\n\n    # Should flatten worker1's metrics to dict of lists\n    assert concat_data.meta_info[\"metrics\"] == {\"loss\": [0.5]}\n    assert concat_data.meta_info[\"flag\"] is True\n\n    # Test with non-list meta_info value\n    data3 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"single_value\": 42})\n    data4 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"single_value\": 42})\n\n    concat_data2 = DataProto.concat([data3, data4])\n    assert concat_data2.meta_info[\"single_value\"] == 42\n\n\ndef test_concat_first_worker_missing_metrics():\n    \"\"\"Test that metrics from other workers are preserved even when first worker has no metrics.\n\n    This is a critical edge case - the old buggy implementation only checked data[0].meta_info\n    and would lose all metrics if the first worker didn't have any.\n    \"\"\"\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n    obs3 = torch.tensor([5, 6])\n\n    # First worker has NO metrics, but workers 2 and 3 do\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"config_flag\": True})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"metrics\": {\"loss\": 0.6}, \"config_flag\": True})\n    data3 = DataProto.from_dict(tensors={\"obs\": obs3}, meta_info={\"metrics\": {\"loss\": 0.55}, \"config_flag\": True})\n\n    concat_data = DataProto.concat([data1, data2, data3])\n\n    # Should flatten metrics from workers 2 and 3 into dict of lists\n    expected_metrics = {\"loss\": [0.6, 0.55]}\n    assert concat_data.meta_info[\"metrics\"] == expected_metrics\n    assert concat_data.meta_info[\"config_flag\"] is True\n\n\ndef test_concat_non_list_metrics():\n    \"\"\"Test that concat() handles non-list metrics (single dict) correctly.\n\n    In some cases, metrics might be a single dict instead of a list.\n    The implementation should flatten them into a dict of lists.\n    \"\"\"\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n\n    # Metrics as single dict (not wrapped in list)\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"metrics\": {\"loss\": 0.5, \"accuracy\": 0.9}})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"metrics\": {\"loss\": 0.6, \"accuracy\": 0.85}})\n\n    concat_data = DataProto.concat([data1, data2])\n\n    # Should flatten to dict of lists\n    expected_metrics = {\"loss\": [0.5, 0.6], \"accuracy\": [0.9, 0.85]}\n    assert concat_data.meta_info[\"metrics\"] == expected_metrics\n\n\ndef test_concat_merge_different_non_metric_keys():\n    \"\"\"Test that concat() merges non-metric meta_info keys from all workers.\n\n    When different workers have different non-metric keys, all keys should be preserved.\n    This prevents silent data loss 
and aligns with the docstring stating meta_info is \"merged\".\n    \"\"\"\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n    obs3 = torch.tensor([5, 6])\n\n    # Each worker has some unique non-metric keys\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"config\": \"A\", \"shared_key\": \"X\"})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"extra_key\": \"B\", \"shared_key\": \"X\"})\n    data3 = DataProto.from_dict(tensors={\"obs\": obs3}, meta_info={\"another_key\": \"C\", \"shared_key\": \"X\"})\n\n    concat_data = DataProto.concat([data1, data2, data3])\n\n    # All unique keys should be preserved\n    assert concat_data.meta_info[\"config\"] == \"A\"\n    assert concat_data.meta_info[\"extra_key\"] == \"B\"\n    assert concat_data.meta_info[\"another_key\"] == \"C\"\n    assert concat_data.meta_info[\"shared_key\"] == \"X\"\n\n\ndef test_concat_conflicting_non_metric_keys():\n    \"\"\"Test that concat() raises an assertion error when non-metric keys have conflicting values.\n\n    This ensures data integrity by catching cases where workers have different values\n    for what should be the same configuration parameter.\n    \"\"\"\n    obs1 = torch.tensor([1, 2])\n    obs2 = torch.tensor([3, 4])\n\n    # Same key \"config\" but different values\n    data1 = DataProto.from_dict(tensors={\"obs\": obs1}, meta_info={\"config\": \"A\"})\n    data2 = DataProto.from_dict(tensors={\"obs\": obs2}, meta_info={\"config\": \"B\"})\n\n    # Should raise an assertion error due to conflicting values\n    with pytest.raises(AssertionError, match=\"Conflicting values for meta_info key 'config'\"):\n        DataProto.concat([data1, data2])\n\n\ndef test_pop():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 3)\n    dataset = DataProto.from_dict({\"obs\": obs, \"act\": act}, meta_info={\"2\": 2, \"1\": 1})\n    poped_dataset = dataset.pop(batch_keys=[\"obs\"], meta_info_keys=[\"2\"])\n\n    assert poped_dataset.batch.keys() == {\"obs\"}\n    assert poped_dataset.meta_info.keys() == {\"2\"}\n\n    assert dataset.batch.keys() == {\"act\"}\n    assert dataset.meta_info.keys() == {\"1\"}\n\n\ndef test_repeat():\n    # Create a DataProto object with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    # Test interleave=True\n    repeated_data_interleave = data.repeat(repeat_times=2, interleave=True)\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave.batch[\"obs\"], expected_obs_interleave))\n    assert (repeated_data_interleave.non_tensor_batch[\"labels\"] == expected_labels_interleave).all()\n    assert repeated_data_interleave.meta_info == {\"info\": \"test_info\"}\n\n    # Test interleave=False\n    repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False)\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave.batch[\"obs\"], expected_obs_no_interleave))\n    assert (repeated_data_no_interleave.non_tensor_batch[\"labels\"] == 
expected_labels_no_interleave).all()\n    assert repeated_data_no_interleave.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_pad_unpad():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=2)\n    assert pad_size == 1\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\"]\n\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3)\n    assert pad_size == 0\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    expected_labels = [\"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)\n    assert pad_size == 4\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\", \"a\"]\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_fold_unfold():\n    from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim\n\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    data1 = data.repeat(repeat_times=2, interleave=True)\n\n    data2 = fold_batch_dim(data1, new_batch_size=3)\n\n    torch.testing.assert_close(data2.batch[\"obs\"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]))\n    assert (data2.non_tensor_batch[\"labels\"] == [[\"a\", \"a\"], [\"b\", \"b\"], [\"c\", \"c\"]]).all()\n\n    data2.reorder(indices=torch.tensor([1, 2, 0]))\n\n    data3 = unfold_batch_dim(data2, batch_dims=2)\n\n    torch.testing.assert_close(data3.batch[\"obs\"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]))\n    assert 
(data3.non_tensor_batch[\"labels\"] == [\"b\", \"b\", \"c\", \"c\", \"a\", \"a\"]).all()\n    assert data3.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_torch_save_data_proto():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n    data.save_to_disk(\"test_data.pt\")\n    loaded_data = DataProto.load_from_disk(\"test_data.pt\")\n\n    assert torch.all(torch.eq(loaded_data.batch[\"obs\"], data.batch[\"obs\"]))\n    assert (loaded_data.non_tensor_batch[\"labels\"] == data.non_tensor_batch[\"labels\"]).all()\n    assert loaded_data.meta_info == data.meta_info\n\n    import os\n\n    os.remove(\"test_data.pt\")\n\n\ndef test_len():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = np.array([\"a\", \"b\", \"c\"], dtype=object)\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 3\n\n    data = DataProto(batch=None, non_tensor_batch={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 3\n\n    data = DataProto(batch=None, non_tensor_batch={}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 0\n\n    data = DataProto(batch=None, non_tensor_batch=None, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 0\n\n\ndef test_dataproto_index():\n    data_len = 100\n    idx_num = 10\n\n    obs = torch.randn(data_len, 10)\n    labels = [random.choice([\"abc\", \"cde\"]) for _ in range(data_len)]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels})\n    labels_np = np.array(labels)\n\n    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))\n    result_np_int = data[idx_np_int]\n    assert result_np_int.batch.keys() == data.batch.keys()\n    assert result_np_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_np_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_np_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_np_int.batch[\"obs\"].cpu().numpy(), obs[idx_np_int].numpy())\n    assert np.array_equal(result_np_int.non_tensor_batch[\"labels\"], labels_np[idx_np_int])\n\n    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))\n    result_torch_int = data[idx_torch_int]\n    assert result_torch_int.batch.keys() == data.batch.keys()\n    assert result_torch_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_torch_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_torch_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_torch_int.batch[\"obs\"].cpu().numpy(), obs[idx_torch_int].cpu().numpy())\n    assert np.array_equal(result_torch_int.non_tensor_batch[\"labels\"], labels_np[idx_torch_int.cpu().numpy()])\n\n    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]\n    result_list_int = data[idx_list_int]\n    assert result_list_int.batch.keys() == data.batch.keys()\n    assert result_list_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_list_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_list_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_list_int.batch[\"obs\"].cpu().numpy(), obs[idx_list_int].cpu().numpy())\n    assert np.array_equal(result_list_int.non_tensor_batch[\"labels\"], 
labels_np[idx_list_int])\n\n    idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)\n    result_np_bool = data[idx_np_bool]\n    assert result_np_bool.batch.keys() == data.batch.keys()\n    assert result_np_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_np_bool.batch[\"obs\"].shape[0] == idx_np_bool.sum()\n    assert result_np_bool.non_tensor_batch[\"labels\"].shape[0] == idx_np_bool.sum()\n    assert np.array_equal(result_np_bool.batch[\"obs\"].cpu().numpy(), obs[idx_np_bool].cpu().numpy())\n    assert np.array_equal(result_np_bool.non_tensor_batch[\"labels\"], labels_np[idx_np_bool])\n\n    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)\n    result_torch_bool = data[idx_torch_bool]\n    assert result_torch_bool.batch.keys() == data.batch.keys()\n    assert result_torch_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_torch_bool.batch[\"obs\"].shape[0] == idx_torch_bool.sum().item()\n    assert result_torch_bool.non_tensor_batch[\"labels\"].shape[0] == idx_torch_bool.sum().item()\n    assert np.array_equal(result_torch_bool.batch[\"obs\"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy())\n    assert np.array_equal(result_torch_bool.non_tensor_batch[\"labels\"], labels_np[idx_torch_bool])\n\n    idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]\n    result_list_bool = data[idx_list_bool]\n    assert result_list_bool.batch.keys() == data.batch.keys()\n    assert result_list_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_list_bool.batch[\"obs\"].shape[0] == sum(idx_list_bool)\n    assert result_list_bool.non_tensor_batch[\"labels\"].shape[0] == sum(idx_list_bool)\n    assert np.array_equal(result_list_bool.batch[\"obs\"].cpu().numpy(), obs[idx_list_bool].cpu().numpy())\n    assert np.array_equal(result_list_bool.non_tensor_batch[\"labels\"], labels_np[idx_list_bool])\n\n\ndef test_old_vs_new_from_single_dict():\n    class CustomProto(DataProto):\n        \"\"\"Uses the new, fixed from_single_dict.\"\"\"\n\n        pass\n\n    class OriginProto(DataProto):\n        \"\"\"Mimics the *old* from_single_dict (always returns a DataProto).\"\"\"\n\n        @classmethod\n        def from_single_dict(cls, data, meta_info=None, auto_padding=False):\n            tensors, non_tensors = {}, {}\n            for k, v in data.items():\n                if torch.is_tensor(v):\n                    tensors[k] = v\n                else:\n                    non_tensors[k] = v\n            # always calls DataProto.from_dict, ignoring `cls`\n            return DataProto.from_dict(\n                tensors=tensors,\n                non_tensors=non_tensors,\n                meta_info=meta_info,\n                auto_padding=auto_padding,\n            )\n\n    sample = {\"x\": torch.tensor([0])}\n\n    orig = OriginProto.from_single_dict(sample)\n    # old behavior: always DataProto, not a CustomOriginProto\n    assert type(orig) is DataProto\n    assert type(orig) is not OriginProto\n\n    cust = CustomProto.from_single_dict(sample)\n    # new behavior: respects subclass\n    assert type(cust) is CustomProto\n\n\ndef test_dataproto_no_batch():\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n    selected = data.select(non_tensor_batch_keys=[\"labels\"])\n    assert (selected.non_tensor_batch[\"labels\"] == labels).all()\n    pop_data = 
data.pop(non_tensor_batch_keys=[\"labels\"])\n    assert (pop_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert data.non_tensor_batch == {}\n\n\ndef test_sample_level_repeat():\n    # Create a DataProto object with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    # list\n    repeated_data_interleave = data.sample_level_repeat(repeat_times=[3, 1, 2])\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"a\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave.batch[\"obs\"], expected_obs_interleave))\n    assert (repeated_data_interleave.non_tensor_batch[\"labels\"] == expected_labels_interleave).all()\n    assert repeated_data_interleave.meta_info == {\"info\": \"test_info\"}\n\n    # torch.tensor\n    repeated_data_no_interleave = data.sample_level_repeat(repeat_times=torch.tensor([1, 2, 3]))\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"b\", \"c\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave.batch[\"obs\"], expected_obs_no_interleave))\n    assert (repeated_data_no_interleave.non_tensor_batch[\"labels\"] == expected_labels_no_interleave).all()\n    assert repeated_data_no_interleave.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_unfold_column_chunks():\n    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])\n    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])\n\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\"])\n\n    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])\n    expect_labels = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])\n    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])\n\n    labels = [[\"a1\", \"a2\"], [\"b1\", \"b2\"], [\"c1\", \"c2\"]]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\", \"labels\"])\n\n    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])\n    expect_labels = [[\"a1\"], [\"a2\"], [\"b1\"], [\"b2\"], [\"c1\"], [\"c2\"]]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n    obs1 = torch.tensor(\n        [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], 
[8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]]\n    )\n    obs2 = torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]])\n\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\"])\n\n    expect_obs1 = torch.tensor(\n        [\n            [[1, 1], [2, 2]],\n            [[3, 3], [4, 4]],\n            [[5, 5], [6, 6]],\n            [[7, 7], [8, 8]],\n            [[9, 9], [10, 10]],\n            [[11, 11], [12, 12]],\n        ]\n    )\n    expect_obs2 = torch.tensor(\n        [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]]\n    )\n    expect_labels = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n\ndef test_dataproto_chunk_after_index():\n    data_len = 4\n    obs = torch.randn(data_len, 4)\n    labels = [f\"label_{i}\" for i in range(data_len)]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"})\n\n    # Test with boolean numpy array\n    bool_mask = np.array([True, False, True, False])\n    selected = data[bool_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)  # int or List[int]\n\n    # Test with integer numpy array\n    int_mask = np.array([0, 2])\n    selected = data[int_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with boolean list\n    list_mask = [True, False, True, False]\n    selected = data[list_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with list\n    list_mask = [0, 2]\n    selected = data[list_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with torch tensor (bool)\n    torch_bool_mask = torch.tensor([True, False, True, False])\n    selected = data[torch_bool_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with torch tensor (int)\n    torch_int_mask = torch.tensor([0, 2])\n    selected = data[torch_int_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n\n@pytest.mark.skipif(\n    parse_version(tensordict.__version__) < parse_version(\"0.10\"), reason=\"requires at least tensordict 0.10\"\n)\ndef test_to_tensordict():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abdce\"})\n    output = data.to_tensordict()\n\n    assert torch.all(torch.eq(output[\"obs\"], obs)).item()\n    assert output[\"labels\"] == labels\n    assert output[\"name\"] == \"abdce\"\n\n\n@pytest.mark.skipif(\n    parse_version(tensordict.__version__) < 
parse_version(\"0.10\"), reason=\"requires at least tensordict 0.10\"\n)\ndef test_from_tensordict():\n    tensor_dict = {\n        \"obs\": torch.tensor([1, 2, 3, 4, 5, 6]),\n        \"labels\": [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"],\n    }\n    non_tensor_dict = {\"name\": \"abdce\"}\n    tensordict = tu.get_tensordict(tensor_dict, non_tensor_dict)\n    data = DataProto.from_tensordict(tensordict)\n\n    assert data.non_tensor_batch[\"labels\"].tolist() == tensor_dict[\"labels\"]\n    assert torch.all(torch.eq(data.batch[\"obs\"], tensor_dict[\"obs\"])).item()\n    assert data.meta_info[\"name\"] == \"abdce\"\n\n\ndef test_serialize_deserialize_single_tensor():\n    \"\"\"Test serialization and deserialization of a single tensor\"\"\"\n    # Create test tensor\n    original_tensor = torch.randn(3, 4, 5)\n\n    # Serialize\n    dtype, shape, data = serialize_single_tensor(original_tensor)\n\n    # Deserialize\n    reconstructed_tensor = deserialize_single_tensor((dtype, shape, data))\n\n    # Verify results\n    assert torch.allclose(original_tensor, reconstructed_tensor)\n    assert original_tensor.shape == reconstructed_tensor.shape\n    assert original_tensor.dtype == reconstructed_tensor.dtype\n\n\ndef test_serialize_deserialize_tensordict_regular_tensors():\n    \"\"\"Test serialization and deserialization of TensorDict with regular tensors\"\"\"\n    # Create test data\n    batch_size = (5, 3)\n    tensor1 = torch.randn(*batch_size, 4)\n    tensor2 = torch.randint(0, 10, (*batch_size, 2))\n\n    # Create TensorDict\n    original_tensordict = TensorDict({\"tensor1\": tensor1, \"tensor2\": tensor2}, batch_size=batch_size)\n\n    # Serialize\n    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)\n\n    # Deserialize\n    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))\n\n    # Verify results\n    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size\n    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())\n\n    for key in original_tensordict.keys():\n        original_tensor = original_tensordict[key]\n        reconstructed_tensor = reconstructed_tensordict[key]\n\n        assert torch.allclose(original_tensor, reconstructed_tensor)\n        assert original_tensor.shape == reconstructed_tensor.shape\n        assert original_tensor.dtype == reconstructed_tensor.dtype\n\n\ndef test_serialize_deserialize_tensordict_nested_tensors():\n    \"\"\"Test serialization and deserialization of TensorDict with nested tensors\"\"\"\n    # Create nested tensor\n    tensor_list = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)]\n    nested_tensor = torch.nested.as_nested_tensor(tensor_list)\n\n    # Create regular tensor for comparison\n    regular_tensor = torch.randn(3, 4, 5)\n\n    # Create TensorDict\n    original_tensordict = TensorDict({\"nested\": nested_tensor, \"regular\": regular_tensor}, batch_size=(3,))\n\n    # Serialize\n    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)\n\n    # Deserialize\n    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))\n\n    # Verify results\n    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size\n    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())\n\n    # Verify regular tensor\n    original_regular = original_tensordict[\"regular\"]\n    reconstructed_regular = 
reconstructed_tensordict[\"regular\"]\n\n    assert torch.allclose(original_regular, reconstructed_regular)\n    assert original_regular.shape == reconstructed_regular.shape\n    assert original_regular.dtype == reconstructed_regular.dtype\n\n    # Verify nested tensor\n    original_nested = original_tensordict[\"nested\"]\n    reconstructed_nested = reconstructed_tensordict[\"nested\"]\n\n    # Check if it's a nested tensor\n    assert original_nested.is_nested\n    assert reconstructed_nested.is_nested\n\n    # Check layout\n    assert original_nested.layout == reconstructed_nested.layout\n\n    # Check each tensor after unbinding\n    original_unbind = original_nested.unbind()\n    reconstructed_unbind = reconstructed_nested.unbind()\n\n    assert len(original_unbind) == len(reconstructed_unbind)\n\n    for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):\n        assert torch.allclose(orig, recon)\n        assert orig.shape == recon.shape\n        assert orig.dtype == recon.dtype\n\n\ndef test_serialize_deserialize_tensordict_mixed_types():\n    \"\"\"Test serialization and deserialization of TensorDict with mixed tensor types\"\"\"\n    # Create tensors with different data types\n    float_tensor = torch.randn(2, 3).float()\n    double_tensor = torch.randn(2, 3).double()\n    int_tensor = torch.randint(0, 10, (2, 3)).int()\n    long_tensor = torch.randint(0, 10, (2, 3)).long()\n    bool_tensor = torch.tensor([[True, False], [False, True]])\n    bfloat16_tensor = torch.randn(2, 3).bfloat16()\n\n    # Add fp8 tensor (if available)\n    # Note: FP8 is not natively supported in all PyTorch versions\n    # We'll check if it's available and conditionally include it\n    has_fp8 = hasattr(torch, \"float8_e5m2\") or hasattr(torch, \"float8_e4m3fn\")\n    if has_fp8:\n        try:\n            # Try to create an FP8 tensor (implementation may vary)\n            # This is a placeholder - actual FP8 support might require specific hardware\n            fp8_tensor = torch.randn(2, 3)\n            if hasattr(torch, \"float8_e5m2\"):\n                fp8_tensor = fp8_tensor.to(torch.float8_e5m2)\n            elif hasattr(torch, \"float8_e4m3fn\"):\n                fp8_tensor = fp8_tensor.to(torch.float8_e4m3fn)\n        except Exception:\n            has_fp8 = False\n\n    # Create nested tensor\n    tensor_list = [\n        torch.randn(2, 3),\n        torch.randn(3, 4),\n    ]\n    nested_tensor = torch.nested.as_nested_tensor(tensor_list)\n\n    # Create TensorDict with all available types\n    tensordict_data = {\n        \"float\": float_tensor,\n        \"double\": double_tensor,\n        \"int\": int_tensor,\n        \"long\": long_tensor,\n        \"bool\": bool_tensor,\n        \"bfloat16\": bfloat16_tensor,\n        \"nested\": nested_tensor,\n    }\n\n    # Conditionally add fp8 tensor if available\n    if has_fp8:\n        tensordict_data[\"fp8\"] = fp8_tensor\n\n    original_tensordict = TensorDict(\n        tensordict_data,\n        batch_size=(2,),\n    )\n\n    # Serialize\n    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)\n\n    # Deserialize\n    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))\n\n    # Verify results\n    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size\n    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())\n\n    for key in original_tensordict.keys():\n        original_tensor = 
original_tensordict[key]\n        reconstructed_tensor = reconstructed_tensordict[key]\n\n        if original_tensor.is_nested:\n            # For nested tensors, check each tensor after unbinding\n            original_unbind = original_tensor.unbind()\n            reconstructed_unbind = reconstructed_tensor.unbind()\n\n            assert len(original_unbind) == len(reconstructed_unbind)\n\n            for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):\n                assert torch.allclose(orig, recon, equal_nan=True)\n                assert orig.shape == recon.shape\n                assert orig.dtype == recon.dtype\n        else:\n            # For regular tensors, compare directly\n            assert torch.all(original_tensor == reconstructed_tensor)\n            assert original_tensor.shape == reconstructed_tensor.shape\n            assert original_tensor.dtype == reconstructed_tensor.dtype\n\n\ndef test_serialize_deserialize_tensordict_with_device():\n    \"\"\"Test serialization and deserialization of TensorDict with device information\"\"\"\n    # Create test data\n    batch_size = (2, 3)\n    tensor1 = torch.randn(*batch_size, 4)\n    tensor2 = torch.randint(0, 10, (*batch_size, 2))\n\n    # Create TensorDict with device information\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    original_tensordict = TensorDict({\"tensor1\": tensor1, \"tensor2\": tensor2}, batch_size=batch_size, device=device)\n\n    # Serialize\n    batch_size_serialized, device_serialized, encoded_items = serialize_tensordict(original_tensordict)\n\n    # Deserialize\n    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device_serialized, encoded_items))\n\n    # Verify results\n    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size\n    assert str(original_tensordict.device) == str(reconstructed_tensordict.device)\n    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())\n\n    for key in original_tensordict.keys():\n        original_tensor = original_tensordict[key]\n        reconstructed_tensor = reconstructed_tensordict[key]\n\n        assert torch.allclose(original_tensor.cpu(), reconstructed_tensor.cpu())\n        assert original_tensor.shape == reconstructed_tensor.shape\n        assert original_tensor.dtype == reconstructed_tensor.dtype\n"
  },
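  {
    "path": "verl_distillation/tests/example_tensor_serialization.py",
    "content": "# Added illustrative sketch, not part of the original test suite: round-trips a tensor\n# through the serialize/deserialize helpers exercised by the tests above. The file path\n# and the import location below are assumptions; point the import at the module that\n# actually defines these helpers in this repo.\n\nimport torch\n\n# Assumed import location; the tests above call these same functions.\nfrom verl.utils.tensordict_utils import deserialize_single_tensor, serialize_single_tensor\n\n\ndef roundtrip(tensor: torch.Tensor) -> torch.Tensor:\n    # serialize_single_tensor returns a (dtype, shape, data) triple, and\n    # deserialize_single_tensor accepts that triple as a single tuple.\n    dtype, shape, data = serialize_single_tensor(tensor)\n    return deserialize_single_tensor((dtype, shape, data))\n\n\nif __name__ == \"__main__\":\n    x = torch.randn(3, 4, 5)\n    y = roundtrip(x)\n    assert torch.allclose(x, y) and x.shape == y.shape and x.dtype == y.dtype\n"
  },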
  {
    "path": "verl_distillation/tests/test_protocol_v2_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nReplace DataProto with raw TensorDict\n\"\"\"\n\nimport copy\nimport random\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom verl.utils import tensordict_utils as tu\n\n\ndef test_union_tensor_dict():\n    obs = torch.randn(100, 10)\n\n    meta_info1 = {\"top_p\": 0.8}\n    meta_info2 = {\"top_p\": 0.9}\n    data1 = {\"obs\": obs, \"act\": torch.randn(100, 3), \"data_sources\": [\"gsm8k\"] * 100}\n    data2 = {\"obs\": obs, \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100), \"data_sources\": [\"gsm8k\"] * 100}\n\n    data_with_copied_obs = {\"obs\": obs.clone(), \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100)}\n\n    data1 = tu.get_tensordict(tensor_dict=data1)\n    data2 = tu.get_tensordict(tensor_dict=data2)\n    data_with_copied_obs = tu.get_tensordict(data_with_copied_obs)\n\n    tu.union_tensor_dict(data1, data2)\n    with pytest.raises(AssertionError):\n        # conflict in tensor values\n        tu.union_tensor_dict(data1, data_with_copied_obs)\n\n    data1 = tu.assign_non_tensor_dict(data1, meta_info1)\n    tu.union_tensor_dict(data1, data2)  # works ok\n\n    data2 = tu.assign_non_tensor_dict(data2, meta_info2)\n\n    with pytest.raises(AssertionError):\n        # conflict in NonTensorData\n        tu.union_tensor_dict(data1, data2)\n\n    data1.pop(\"top_p\")\n    data2.pop(\"top_p\")\n\n    data2[\"data_sources\"][0] = \"math\"\n    with pytest.raises(AssertionError):\n        # conflict in NonTensorData\n        tu.union_tensor_dict(data1, data2)\n\n\ndef test_tensor_dict_constructor():\n    obs = torch.ones(100, 10)\n    act = torch.zeros(100, 10, 3)\n    data_source = [\"gsm8k\"] * 100\n    non_tensor_dict = {\"name\": \"abdce\"}\n\n    data = tu.get_tensordict(\n        tensor_dict={\"obs\": obs, \"act\": act, \"data_source\": data_source}, non_tensor_dict=non_tensor_dict\n    )\n\n    assert data.batch_size == torch.Size([100])\n\n    # test slicing\n    assert torch.all(torch.eq(data[0][\"obs\"], torch.ones(10))).item()\n    assert torch.all(torch.eq(data[0][\"act\"], torch.zeros(10, 3))).item()\n    assert data[0][\"data_source\"] == \"gsm8k\"\n\n    assert torch.all(torch.eq(data[0:2][\"obs\"], torch.ones(2, 10))).item()\n    assert torch.all(torch.eq(data[0:2][\"act\"], torch.zeros(2, 10, 3))).item()\n    assert data[0:2][\"data_source\"] == [\"gsm8k\"] * 2\n\n    # test non tensor data\n    assert data[\"name\"] == \"abdce\"\n\n\ndef test_index_select_tensor_dict():\n    vocab_size = 128\n    a = torch.randint(low=0, high=vocab_size, size=(11,))\n    b = torch.randint(low=0, high=vocab_size, size=(13,))\n    c = torch.randint(low=0, high=vocab_size, size=(12,))\n    d = torch.randint(low=0, high=vocab_size, size=(15,))\n    input_ids = [a, b, c, d]\n    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)\n\n    padded_tensor = torch.randn(4, 10)\n    non_tensor_dict = 
{\"global_batch_size\": \"4\"}\n\n    data = tu.get_tensordict(\n        tensor_dict={\n            \"input_ids\": input_ids,\n            \"padded_tensor\": padded_tensor,\n        },\n        non_tensor_dict=non_tensor_dict,\n    )\n\n    assert data.batch_size == torch.Size([4])\n\n    # test index select\n    indices = torch.tensor([1, 3])\n    selected_data = tu.index_select_tensor_dict(data, indices)\n\n    assert selected_data.batch_size == torch.Size([2])\n\n    target_input_ids = torch.nested.as_nested_tensor([input_ids[idx] for idx in indices], layout=torch.jagged)\n    target_select_data = tu.get_tensordict(\n        tensor_dict={\n            \"input_ids\": target_input_ids,\n            \"padded_tensor\": padded_tensor[indices],\n        },\n        non_tensor_dict=non_tensor_dict,\n    )\n    tu.assert_tensordict_eq(selected_data, target_select_data)\n\n\ndef test_tensordict_with_images():\n    # each sample contains a sequence with multiple images of different sizes\n    vocab_size = 128\n    a = torch.randint(low=0, high=vocab_size, size=(11,))\n    b = torch.randint(low=0, high=vocab_size, size=(13,))\n    input_ids = [a, b]\n    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)\n\n    # must be numpy\n    # TODO(vermouth1992). We may use nested tensor too. But this requires nested over nested\n    a_images = [\n        torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),\n        torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),\n    ]\n    b_images = [\n        torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),\n        torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),\n        torch.randint(low=0, high=255, size=(3, 64, 64), dtype=torch.uint8).numpy(),\n    ]\n\n    images = [a_images, b_images]\n\n    data = tu.get_tensordict({\"input_ids\": input_ids, \"images\": images})\n\n    assert np.all(np.equal(data[0][\"images\"][0], a_images[0]))\n    assert torch.all(torch.eq(data[0][\"input_ids\"], a))\n\n\ndef test_tensordict_with_packing():\n    vocab_size = 128\n    a = torch.randint(low=0, high=vocab_size, size=(11,))\n    b = torch.randint(low=0, high=vocab_size, size=(13,))\n    input_ids = [a, b]\n    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)\n\n    data = tu.get_tensordict({\"input_ids\": input_ids})\n\n    # test cu_seqlens\n    cu_seqlens = torch.tensor([0, 11, 24])\n    assert torch.all(torch.eq(cu_seqlens, data[\"input_ids\"].offsets()))\n\n    # test index\n    assert torch.all(torch.eq(data[\"input_ids\"][0], a))\n    assert torch.all(torch.eq(data[\"input_ids\"][1], b))\n\n    assert torch.all(torch.eq(data[0][\"input_ids\"], a))\n    assert torch.all(torch.eq(data[1][\"input_ids\"], b))\n\n    data_lst = data.chunk(2)\n\n    assert torch.all(torch.eq(data_lst[0][\"input_ids\"][0], a))\n    assert torch.all(torch.eq(data_lst[1][\"input_ids\"][0], b))\n\n\ndef test_tensordict_eq():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    data_sources = [\"abc\", \"def\", \"abc\", \"def\", \"pol\", \"klj\"]\n    non_tensor_dict = {\"train_sample_kwargs\": {\"top_p\": 1.0}, \"val_sample_kwargs\": {\"top_p\": 0.7}}\n    data = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    data_sources = [\"abc\", \"def\", \"abc\", \"def\", \"pol\", \"klj\"]\n    non_tensor_dict = {\"train_sample_kwargs\": {\"top_p\": 
1.0}, \"val_sample_kwargs\": {\"top_p\": 0.7}}\n    data1 = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n\n    tu.assert_tensordict_eq(data, data1)\n\n    data2 = copy.deepcopy(data1)\n    data2[\"obs\"][0] += 1\n\n    with pytest.raises(AssertionError):\n        tu.assert_tensordict_eq(data, data2)\n\n    data2 = copy.deepcopy(data1)\n    data2[\"data_sources\"][0] = \"math\"\n\n    with pytest.raises(AssertionError):\n        tu.assert_tensordict_eq(data, data2)\n\n    data2 = copy.deepcopy(data1)\n    data2[\"train_sample_kwargs\"][\"top_p\"] = 0.9\n\n    with pytest.raises(AssertionError):\n        tu.assert_tensordict_eq(data, data2)\n\n    tensor_list = [\n        torch.tensor([1, 2, 3, 3, 2]),\n        torch.tensor([4, 5]),\n        torch.tensor([7, 8, 10, 14]),\n        torch.tensor([10, 11, 12]),\n        torch.tensor([13, 14, 15, 18]),\n        torch.tensor([16, 17]),\n    ]\n    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)\n    data_sources = [\"abc\", \"def\", \"abc\", \"def\", \"pol\", \"klj\"]\n    non_tensor_dict = {\"train_sample_kwargs\": {\"top_p\": 1.0}, \"val_sample_kwargs\": {\"top_p\": 0.7}}\n    data3 = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n\n    tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])\n    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)\n    data4 = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n    tu.assert_tensordict_eq(data3, data4)\n\n    tensor_list[0] = torch.tensor([1, 2, 4])\n    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)\n    data5 = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n    with pytest.raises(AssertionError):\n        tu.assert_tensordict_eq(data3, data5)\n\n    tensor_list[0] = torch.tensor([4, 5])\n    tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])\n    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)\n    data6 = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n    with pytest.raises(AssertionError):\n        tu.assert_tensordict_eq(data3, data6)\n\n\ndef test_tensor_dict_make_iterator():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    data_sources = [\"abc\", \"def\", \"abc\", \"def\", \"pol\", \"klj\"]\n    non_tensor_dict = {\"train_sample_kwargs\": {\"top_p\": 1.0}, \"val_sample_kwargs\": {\"top_p\": 0.7}}\n    dataset = tu.get_tensordict({\"obs\": obs, \"data_sources\": data_sources}, non_tensor_dict=non_tensor_dict)\n\n    dataloader = tu.make_iterator(\n        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={\"shuffle\": False, \"drop_last\": False}\n    )\n\n    expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]]\n\n    i = 0\n\n    for d in dataloader:\n        tu.assert_tensordict_eq(d, expected_tensor_dict[i])\n        i += 1\n\n    data_iter_1 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={\"shuffle\": True})\n    data_list_1 = []\n    for data in data_iter_1:\n        data_list_1.append(data)\n\n    data_iter_2 = tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={\"shuffle\": True})\n    data_list_2 = []\n    for data in data_iter_2:\n        data_list_2.append(data)\n\n    for data1, data2 in zip(data_list_1, data_list_2, 
strict=True):\n        tu.assert_tensordict_eq(data1, data2)\n\n\ndef test_reorder():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    non_tensor_dict = {\"name\": \"abdce\"}\n\n    data = tu.get_tensordict(tensor_dict={\"obs\": obs, \"labels\": labels}, non_tensor_dict=non_tensor_dict)\n    data = data[torch.tensor([3, 4, 2, 0, 1, 5])]\n\n    assert torch.all(torch.eq(data[\"obs\"], torch.tensor([4, 5, 3, 1, 2, 6])))\n    assert np.all(data[\"labels\"] == np.array([\"d\", \"e\", \"c\", \"a\", \"b\", \"f\"]))\n    assert data[\"name\"] == \"abdce\"\n\n\ndef test_chunk_concat():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"name\": \"abcde\"})\n\n    data_split = data.tensor_split(indices_or_sections=5, dim=0)\n\n    expected_idx_lst = [[0, 1], [2], [3], [4], [5]]\n\n    for d, expected_idx in zip(data_split, expected_idx_lst, strict=False):\n        tu.assert_tensordict_eq(d, data[expected_idx])\n\n    data_split = data.chunk(2)\n    assert len(data_split) == 2\n    assert torch.all(torch.eq(data_split[0][\"obs\"], torch.tensor([1, 2, 3])))\n    assert np.all(data_split[0][\"labels\"] == np.array([\"a\", \"b\", \"c\"]))\n    assert data_split[0][\"name\"] == \"abcde\"\n\n    assert torch.all(torch.eq(data_split[1][\"obs\"], torch.tensor([4, 5, 6])))\n    assert np.all(data_split[1][\"labels\"] == np.array([\"d\", \"e\", \"f\"]))\n    assert data_split[1][\"name\"] == \"abcde\"\n\n    concat_data = torch.cat(data_split, dim=0)\n    assert torch.all(torch.eq(concat_data[\"obs\"], data[\"obs\"]))\n    assert np.all(concat_data[\"labels\"] == data[\"labels\"])\n    assert concat_data[\"name\"] == data[\"name\"]\n\n\ndef test_pop():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 3)\n    dataset = tu.get_tensordict({\"obs\": obs, \"act\": act}, non_tensor_dict={\"2\": 2, \"1\": 1})\n\n    popped_dataset = tu.pop(dataset, keys=[\"obs\", \"2\"])\n\n    assert popped_dataset.batch_size[0] == 100\n\n    assert popped_dataset.keys() == {\"obs\", \"2\"}\n\n    assert dataset.keys() == {\"act\", \"1\"}\n\n\ndef test_repeat():\n    # Create a tensordict with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"info\": \"test_info\"})\n\n    # Interleaved repeat: each sample is repeated consecutively\n    repeated_data_interleave = data.repeat_interleave(repeats=2)\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave[\"obs\"], expected_obs_interleave))\n    assert repeated_data_interleave[\"labels\"] == expected_labels_interleave\n    assert repeated_data_interleave[\"info\"] == \"test_info\"\n\n    # Non-interleaved repeat: the whole batch is tiled\n    repeated_data_no_interleave = data.repeat(2)\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave[\"obs\"], expected_obs_no_interleave))\n    assert repeated_data_no_interleave[\"labels\"] == expected_labels_no_interleave\n    assert repeated_data_no_interleave[\"info\"] == \"test_info\"\n\n\ndef 
test_dataproto_pad_unpad():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = tu.get_tensordict(tensor_dict={\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"info\": \"test_info\"})\n\n    padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=2)\n\n    assert pad_size == 1\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\"]\n\n    assert torch.all(torch.eq(padded_data[\"obs\"], expected_obs))\n    assert padded_data[\"labels\"] == expected_labels\n    assert padded_data[\"info\"] == \"test_info\"\n\n    unpadd_data = tu.unpad(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data[\"obs\"], obs))\n    assert unpadd_data[\"labels\"] == labels\n    assert unpadd_data[\"info\"] == \"test_info\"\n\n    padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=3)\n    assert pad_size == 0\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    expected_labels = [\"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(padded_data[\"obs\"], expected_obs))\n    assert padded_data[\"labels\"] == expected_labels\n    assert padded_data[\"info\"] == \"test_info\"\n\n    unpadd_data = tu.unpad(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data[\"obs\"], obs))\n    assert unpadd_data[\"labels\"] == labels\n    assert unpadd_data[\"info\"] == \"test_info\"\n\n    padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=7)\n    assert pad_size == 4\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\", \"a\"]\n    assert torch.all(torch.eq(padded_data[\"obs\"], expected_obs))\n    assert padded_data[\"labels\"] == expected_labels\n    assert padded_data[\"info\"] == \"test_info\"\n\n    unpadd_data = tu.unpad(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data[\"obs\"], obs))\n    assert unpadd_data[\"labels\"] == labels\n    assert unpadd_data[\"info\"] == \"test_info\"\n\n\ndef test_torch_save_data_proto():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"info\": \"test_info\"})\n\n    filename = \"test_data.pt\"\n    torch.save(data, filename)\n    loaded_data = torch.load(filename, weights_only=False)\n\n    assert torch.all(torch.eq(loaded_data[\"obs\"], data[\"obs\"]))\n    assert loaded_data[\"labels\"] == data[\"labels\"]\n    assert loaded_data[\"info\"] == data[\"info\"]\n\n    import os\n\n    os.remove(filename)\n\n\ndef test_len():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = np.array([\"a\", \"b\", \"c\"], dtype=object)\n\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels.tolist()}, non_tensor_dict={\"info\": \"test_info\"})\n    assert len(data) == 3\n\n    data = tu.get_tensordict({\"labels\": labels.tolist()}, non_tensor_dict={\"info\": \"test_info\"})\n    assert len(data) == 3\n\n    data_item = data[0]\n    assert len(data_item) == 0\n\n    data = tu.get_tensordict({}, non_tensor_dict={\"info\": \"test_info\"})\n    assert len(data) == 0\n\n\ndef test_dataproto_index():\n    data_len = 100\n    idx_num = 10\n\n    obs = torch.randn(data_len, 10)\n    labels = [random.choice([\"abc\", \"cde\"]) for _ in range(data_len)]\n\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels})\n\n    labels_np = np.array(labels)\n\n  
  idx_np_int = np.random.randint(0, data_len, size=(idx_num,))\n    result_np_int = data[idx_np_int]\n    assert result_np_int.keys() == data.keys()\n    assert result_np_int[\"obs\"].shape[0] == idx_num\n    assert len(result_np_int[\"labels\"]) == idx_num\n    assert np.array_equal(result_np_int[\"obs\"].cpu().numpy(), obs[idx_np_int].numpy())\n    assert np.array_equal(result_np_int[\"labels\"], labels_np[idx_np_int])\n\n    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))\n    result_torch_int = data[idx_torch_int]\n    assert result_torch_int.keys() == data.keys()\n    assert result_torch_int[\"obs\"].shape[0] == idx_num\n    assert len(result_torch_int[\"labels\"]) == idx_num\n    assert np.array_equal(result_torch_int[\"obs\"].cpu().numpy(), obs[idx_torch_int].cpu().numpy())\n    assert np.array_equal(result_torch_int[\"labels\"], labels_np[idx_torch_int.cpu().numpy()])\n\n    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]\n    result_list_int = data[idx_list_int]\n    assert result_list_int.keys() == data.keys()\n    assert result_list_int[\"obs\"].shape[0] == idx_num\n    assert len(result_list_int[\"labels\"]) == idx_num\n    assert np.array_equal(result_list_int[\"obs\"].cpu().numpy(), obs[idx_list_int].cpu().numpy())\n    assert np.array_equal(result_list_int[\"labels\"], labels_np[idx_list_int])\n\n    # idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)\n    # result_np_bool = data[idx_np_bool]\n    # assert result_np_bool.keys() == data.keys()\n    # assert result_np_bool[\"obs\"].shape[0] == idx_np_bool.sum()\n    # assert len(result_np_bool[\"labels\"]) == idx_np_bool.sum()\n    # assert np.array_equal(result_np_bool[\"obs\"].cpu().numpy(), obs[idx_np_bool].cpu().numpy())\n    # assert np.array_equal(result_np_bool[\"labels\"], labels_np[idx_np_bool])\n\n    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)\n    result_torch_bool = data[idx_torch_bool]\n    assert result_torch_bool.keys() == data.keys()\n    assert result_torch_bool[\"obs\"].shape[0] == idx_torch_bool.sum().item()\n    assert len(result_torch_bool[\"labels\"]) == idx_torch_bool.sum().item()\n    assert np.array_equal(result_torch_bool[\"obs\"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy())\n    assert np.array_equal(result_torch_bool[\"labels\"], labels_np[idx_torch_bool])\n\n    # idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]\n    # result_list_bool = data[idx_list_bool]\n    # assert result_list_bool.keys() == data.keys()\n    # assert result_list_bool[\"obs\"].shape[0] == sum(idx_list_bool)\n    # assert len(result_list_bool[\"labels\"]) == sum(idx_list_bool)\n    # assert np.array_equal(result_list_bool[\"obs\"].cpu().numpy(), obs[idx_list_bool].cpu().numpy())\n    # assert np.array_equal(result_list_bool[\"labels\"], labels_np[idx_list_bool])\n\n\ndef test_select():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 3)\n    dataset = tu.get_tensordict({\"obs\": obs, \"act\": act}, non_tensor_dict={\"2\": 2, \"1\": 1})\n\n    subset = dataset.select(\"obs\", \"2\")\n\n    assert torch.all(torch.eq(subset[\"obs\"], dataset[\"obs\"]))\n    assert subset[\"2\"] == dataset[\"2\"]\n    assert \"act\" not in subset.keys()\n    assert \"1\" not in subset.keys()\n\n\ndef test_dataproto_no_batch():\n    labels = [\"a\", \"b\", \"c\"]\n    data = tu.get_tensordict(tensor_dict={\"labels\": labels}, non_tensor_dict={\"info\": \"test_info\"})\n    selected = data.select(\"labels\")\n\n    assert 
selected[\"labels\"] == labels\n    pop_data = tu.pop(data, keys=[\"labels\"])\n    assert pop_data[\"labels\"] == labels\n    assert \"labels\" not in data\n\n\ndef test_sample_level_repeat():\n    # Create a DataProto object with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n\n    data = tu.get_tensordict({\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"info\": \"test_info\"})\n\n    # list\n    repeated_data_interleave = data.repeat_interleave(repeats=torch.tensor([3, 1, 2]))\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"a\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave[\"obs\"], expected_obs_interleave))\n    assert repeated_data_interleave[\"labels\"] == expected_labels_interleave\n    assert repeated_data_interleave[\"info\"] == \"test_info\"\n\n    # torch.tensor\n    repeated_data_no_interleave = data.repeat_interleave(repeats=torch.tensor([1, 2, 3]))\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"b\", \"c\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave[\"obs\"], expected_obs_no_interleave))\n    assert repeated_data_no_interleave[\"labels\"] == expected_labels_no_interleave\n    assert repeated_data_no_interleave[\"info\"] == \"test_info\"\n\n\ndef test_dataproto_chunk_after_index():\n    data_len = 4\n    obs = torch.randn(data_len, 4)\n    labels = [f\"label_{i}\" for i in range(data_len)]\n\n    data = tu.get_tensordict(tensor_dict={\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"name\": \"abc\"})\n    # Test with boolean numpy array\n    bool_mask = torch.tensor([True, False, True, False])\n    selected = data[bool_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)  # int or List[int]\n\n    # Test with integer numpy array\n    int_mask = torch.tensor([0, 2])\n    selected = data[int_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)\n\n    # Test with boolean list\n    list_mask = [True, False, True, False]\n    selected = data[list_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)\n\n    # Test with list\n    list_mask = [0, 2]\n    selected = data[list_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)\n\n    # Test with torch tensor (bool)\n    torch_bool_mask = torch.tensor([True, False, True, False])\n    selected = data[torch_bool_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)\n\n    # Test with torch tensor (int)\n    torch_int_mask = torch.tensor([0, 2])\n    selected = data[torch_int_mask]\n    assert isinstance(selected.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch_size)\n"
  },
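  {
    "path": "verl_distillation/tests/example_tensordict_usage.py",
    "content": "# Added illustrative sketch, not part of the original test suite: a minimal end-to-end\n# walkthrough of verl.utils.tensordict_utils. The file path and this __main__ walkthrough\n# are hypothetical; every call below mirrors usage exercised in test_protocol_v2_on_cpu.py.\n\nimport torch\n\nfrom verl.utils import tensordict_utils as tu\n\n\ndef main():\n    # Build a tensordict mixing a tensor field, a non-tensor (list) field and metadata.\n    obs = torch.arange(6).reshape(3, 2)\n    labels = [\"a\", \"b\", \"c\"]\n    data = tu.get_tensordict(tensor_dict={\"obs\": obs, \"labels\": labels}, non_tensor_dict={\"info\": \"demo\"})\n\n    # Pad the batch up to a multiple of 2, then undo the padding.\n    padded, pad_size = tu.pad_to_divisor(data, size_divisor=2)\n    restored = tu.unpad(padded, pad_size=pad_size)\n    assert torch.equal(restored[\"obs\"], obs)\n    assert restored[\"labels\"] == labels\n\n    # Repeat each sample twice, split the result into two chunks, and concatenate back.\n    repeated = data.repeat_interleave(repeats=2)\n    chunks = repeated.chunk(2)\n    merged = torch.cat(chunks, dim=0)\n    assert len(merged) == 6\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },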
  {
    "path": "verl_distillation/tests/trainer/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the trainer module.\n\"\"\"\n"
  },
  {
    "path": "verl_distillation/tests/trainer/config/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/tests/trainer/config/legacy_ppo_megatron_trainer.yaml",
    "content": "data:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  train_max_samples: -1  # set to -1 to use full dataset\n  val_max_samples: -1  # set to -1 to use full dataset\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves\n  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n  return_raw_chat: False\n  return_full_prompt: False\n  shuffle: True\n  seed: null # An integer seed to use when shuffling the data. If not set or set to `null`, the data shuffling will not be seeded, resulting in a different data order on each run.\n  filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  trust_remote_code: False  # main_ppo will check this config to determine whether to use remote code for tokenizer\n  custom_cls:\n      path: null\n      name: null\n  sampler:\n    class_path: null\n    class_name: null\n  dataloader_num_workers: 8\n  return_multi_modal_inputs: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    custom_chat_template: null\n    external_lib: null\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: False\n    enable_gradient_checkpointing: False\n    gradient_checkpointing_kwargs:\n      ## Activation Checkpointing\n      activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'\n      # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk\n      # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity\n      activations_checkpoint_granularity: null # 'selective' or 'full'\n      # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention\n      activations_checkpoint_num_layers: null # not used with 'selective'\n    trust_remote_code: False\n  actor:\n    strategy: megatron  # This is for backward-compatibility\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: False\n    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n    use_torch_compile: True # False to disable torch compile\n    # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\n    clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n    entropy_coeff: 0\n    use_kl_loss: False # True for 
GRPO\n    kl_loss_coef: 0.001 # for grpo\n    kl_loss_type: low_var_kl # for grpo\n    ppo_epochs: 1\n    data_loader_seed: null\n    shuffle: False\n    policy_loss:   # policy loss config\n      loss_mode: \"vanilla\" # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617\n      clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss\n      clip_cov_lb: 1.0 # Lower bound for clip-cov loss\n      clip_cov_ub: 5.0 # Upper bound for clip-cov loss\n      kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss\n      ppo_kl_coef: 0.1 # KL divergence penalty coefficient\n    optim:\n      optimizer: adam\n      lr: 1e-6\n      clip_grad: 1.0\n      total_training_steps: -1  # must be overridden by the program\n      lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n      lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      lr_decay_steps: null\n      lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root\n      min_lr: 0.0 # minimum learning rate, default to 0.0\n      weight_decay: 0.01\n      weight_decay_incr_style: constant # select from constant/linear/cosine\n      lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n    megatron:\n      param_offload: False\n      grad_offload: False\n      optimizer_offload: False\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: null\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n      context_parallel_size: 1\n      sequence_parallel: True\n      use_distributed_optimizer: True\n      use_dist_checkpointing: False\n      dist_checkpointing_path: null\n      seed: 42\n      override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage\n      use_mbridge: False\n    profile: # profile the actor model in `update_policy`\n      use_profile: False # open it when you want to profile the actor model\n      profile_ranks: null # list, you can specify the ranks to profile\n      step_start: -1 # start step in update_policy\n      step_end: -1 # end step\n      save_path: null # the path to save the profile result\n    load_weight: True\n    checkpoint:\n      async_save: False # save checkpoint asynchronously\n      # What to include in saved checkpoints\n      # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n      save_contents: ['model', 'optimizer', 'extra']\n      # For more flexibility, you can specify the contents to load from the checkpoint.\n      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n  ref:\n    strategy: ${actor_rollout_ref.actor.strategy}\n    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}\n    megatron:\n      param_offload: False\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: null\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n      context_parallel_size: 1\n      sequence_parallel: 
True\n      use_distributed_optimizer: True\n      use_dist_checkpointing: False\n      dist_checkpointing_path: null\n      seed: ${actor_rollout_ref.actor.megatron.seed}\n      override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}\n      use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n    profile:\n      use_profile: False\n      profile_ranks: null\n      step_start: -1\n      step_end: -1\n      save_path: null\n    load_weight: True\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n  rollout:\n    name: vllm\n    mode: sync # sync: LLM, async: AsyncLLM\n    temperature: 1.0\n    top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n    top_p: 1\n    prompt_length: ${data.max_prompt_length}  # for xperf_gpt\n    response_length: ${data.max_response_length}\n    # for vllm rollout\n    dtype: bfloat16 # should align with FSDP\n    gpu_memory_utilization: 0.5\n    ignore_eos: False\n    enforce_eager: False\n    free_cache_engine: True\n    load_format: dummy\n    tensor_model_parallel_size: 2\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    disable_log_stats: True\n    enable_chunked_prefill: True # could get higher throughput\n    # for hf rollout\n    do_sample: True\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n    # number of responses (i.e. num sample times)\n    n: 1\n    engine_kwargs: # inference engine parameters, please refer vllm/sglang official doc for detail\n      vllm: {}\n      sglang: {}\n    val_kwargs:\n      # sampling parameters for validation\n      top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n      top_p: 1.0\n      temperature: 0\n      n: 1\n      do_sample: False # default eager for validation\n\n    # Multi-turn interaction config for tools or chat.\n    multi_turn:\n      # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well\n      enable: False\n\n      # null for no limit (default max_length // 3)\n      max_assistant_turns: null\n\n      # null for no tool\n      tool_config_path: null\n\n      # null for no limit (default max_length // 3)\n      max_user_turns: null\n\n      # max parallel call for tools in single turn\n      max_parallel_calls: 1\n\n      # max length of tool response\n      max_tool_response_length: 256\n\n      # truncate side of tool response: left, middle, right\n      tool_response_truncate_side: middle\n\n      # null for no interaction\n      interaction_config_path: null\n\n      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n      #   which may contain additional content such as reasoning content. 
This maintains the consistency between training and rollout, but it will lead to longer prompts.\n      use_inference_chat_template: False\n\n      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n      # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n      # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n      # - disable: disable tokenization sanity check\n      # - strict: enable strict tokenization sanity check (default)\n      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n      tokenization_sanity_check_mode: strict\n\n      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...\n      format: hermes\n\n    # [Experimental] agent loop based rollout configs\n    agent:\n\n      # Number of agent loop workers\n      num_workers: 8\n\n      custom_async_server:\n        path: null\n        name: null\n\n    # support logging rollout prob for debugging purpose\n    calculate_log_probs: False\n    # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\n\ncritic:\n  rollout_n: ${actor_rollout_ref.rollout.n}\n  strategy: ${actor_rollout_ref.actor.strategy}\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  optim:\n    optimizer: adam\n    lr: 1e-6\n    clip_grad: 1.0\n    total_training_steps: -1  # must be overridden by the program\n    lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n    lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n    lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n    lr_decay_steps: null\n    lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root\n    min_lr: 0.0 # minimum learning rate, default to 0.0\n    weight_decay: 0.01\n    weight_decay_incr_style: constant # select from constant/linear/cosine\n    lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n    lr_wsd_decay_steps: null\n    use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: False\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: False\n    enable_gradient_checkpointing: False\n    gradient_checkpointing_kwargs:\n      ## Activation Checkpointing\n      activations_checkpoint_method: null\n      activations_checkpoint_granularity: null\n      activations_checkpoint_num_layers: null\n  megatron:\n    param_offload: False\n    grad_offload: False\n    optimizer_offload: False\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n    context_parallel_size: 1\n    sequence_parallel: True\n    use_distributed_optimizer: True\n    use_dist_checkpointing: False\n    dist_checkpointing_path: null\n    seed: ${actor_rollout_ref.actor.megatron.seed}\n    override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}\n    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n  load_weight: True\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n  ppo_micro_batch_size_per_gpu: null\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2\n  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n  data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n  cliprange_value: 0.5\n  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}\n  checkpoint:\n    async_save: False # save checkpoint asynchronously\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: ['model', 'optimizer', 'extra']\n    load_contents: ${critic.checkpoint.save_contents}\n  # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\nreward_model:\n  enable: False\n  strategy: ${actor_rollout_ref.actor.strategy}\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  megatron:\n    param_offload: False\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null # 
change VPP interface for parallelism tests\n    context_parallel_size: 1\n    sequence_parallel: True\n    use_distributed_optimizer: False\n    use_dist_checkpointing: False\n    dist_checkpointing_path: null\n    seed: ${actor_rollout_ref.actor.megatron.seed}\n    override_transformer_config: {}\n    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    trust_remote_code: False\n    external_lib: ${actor_rollout_ref.model.external_lib}\n  load_weight: True\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  max_length: null\n  reward_manager: naive\n  launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob\n  sandbox_fusion:\n    url: null # faas url to run code in cloud sandbox\n    max_concurrent: 64 # max concurrent requests to sandbox\n    memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB\n  # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nalgorithm:\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: True\n  use_kl_in_reward: False\n  kl_penalty: kl  # how to estimate kl divergence\n  kl_ctrl:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: False\n  pf_ppo:\n    reweight_method: pow  # [\"pow\", \"max_min\", \"max_random\"]\n    weight_pow: 2.0\n\ntrainer:\n  balance_batch: True\n  total_epochs: 30\n  total_training_steps: null\n  profile_steps: null # [1,2,5] or [] or null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: ['console', 'wandb']\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n\n  # auto: find the last ckpt to resume. 
If none can be found, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  del_local_ckpt_after_load: False\n  val_before_train: True\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  # The timeout for ray worker group to wait for the register center to be ready\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  # see ppo_trainer.yaml for more details\n  controller_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n  worker_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n    capture-range: \"cudaProfilerApi\"\n    capture-range-end: null\n    kill: none\n  npu_profile:\n    options:\n      save_path: ./profiler_data\n      roles: [\"all\"]\n      level: level1\n      with_memory: False\n      record_shapes: False\n      with_npu: True\n      with_cpu: True\n      with_module: False\n      with_stack: False\n      analysis: True\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which might cause a hang if CPU usage is limited in systems like SLURM. Please set it to the number of allowed CPUs.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/tests/trainer/config/legacy_ppo_trainer.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# dataset config\ndata:\n\n  # Tokenizer class or path. If null, it will be inferred from the model.\n  tokenizer: null\n\n  # Whether to use shared memory for data loading.\n  use_shm: False\n\n  # Training set parquet. Can be a list or a single file.\n  # The program will read all files into memory, so it can't be too large (< 100GB).\n  # The path can be either a local path or an HDFS path.\n  # For HDFS path, we provide utils to download it to DRAM and convert it to a local path.\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n\n  # Validation parquet. Can be a list or a single file.\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n\n  # Maximum sample length to be used.\n  # Set to -1 to use full dataset, otherwise, randomly\n  # select the specified number of samples from train dataset\n  train_max_samples: -1\n\n  # Maximum sample length to be used.\n  # Set to -1 to use full dataset, otherwise, randomly\n  # select the specified number of samples from val dataset\n  val_max_samples: -1\n\n  # The field in the dataset where the prompt is located. Default is 'prompt'.\n  prompt_key: prompt\n\n  # The field used to select the reward function (if using different ones per example).\n  reward_fn_key: data_source\n\n  # Maximum prompt length. All prompts will be left-padded to this length.\n  # An error will be reported if the length is too long.\n  max_prompt_length: 512\n\n  # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.\n  max_response_length: 512\n\n  # Batch size sampled for one training iteration of different RL algorithms.\n  train_batch_size: 1024\n\n  # Batch size used during validation. Can be null.\n  val_batch_size: null\n\n  # Whether to return the original input_ids without adding chat template.\n  # This is used when the reward model's chat template differs from the policy.\n  # If using a model-based RM with different templates, this should be True.\n  return_raw_input_ids: False\n\n  # Whether to return the original chat (prompt) without applying chat template.\n  return_raw_chat: False\n\n  # Whether to return the full prompt with chat template.\n  return_full_prompt: False\n\n  # Whether to shuffle the data in the dataloader.\n  shuffle: True\n\n  # An integer seed to use when shuffling the data. If not set or set to\n  # `null`, the data shuffling will not be seeded, resulting in a different data order on each run.\n  seed: null\n\n  # num dataloader workers\n  dataloader_num_workers: 8\n\n  # Whether to shuffle the validation set.\n  validation_shuffle: False\n\n  # Whether to filter overlong prompts.\n  filter_overlong_prompts: False\n\n  # Number of workers for filtering overlong prompts.\n  # For large-scale datasets, filtering can be time-consuming.\n  # Use multiprocessing to speed up. Default is 1.\n  filter_overlong_prompts_workers: 1\n\n  # Truncate the input_ids or prompt if they exceed max_prompt_length.\n  # Options: 'error', 'left', or 'right'. Default is 'error'.\n  truncation: error\n\n  # The field in the multi-modal dataset where the image is located. 
Default is 'images'.\n  image_key: images\n\n  # The field in the multi-modal dataset where the video is located.\n  video_key: videos\n\n  # If the remote tokenizer has a Python file, this flag determines whether to allow using it.\n  trust_remote_code: False\n\n  # Optional: specify a custom dataset class path and name if overriding default loading behavior.\n  custom_cls:\n\n    # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.\n    path: null\n\n    # The name of the dataset class within the specified file.\n    name: null\n\n  # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.\n  return_multi_modal_inputs: True\n\n  # Data generation configuration for augmenting the dataset.\n  datagen:\n\n    # The path to the file containing your customized data generation class.\n    # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'\n    path: null\n\n    # The class name of the data generation class within the specified file.\n    # E.g. 'MockDataGenerator'\n    name: null\n\n  # settings related to data sampler\n  sampler:\n\n    # the path to the module containing a curriculum class which implements the\n    # AbstractSampler interface\n    class_path: null\n\n    # the name of the curriculum class like `MySampler`\n    class_name: null\n\n  # Additional kwargs when calling tokenizer.apply_chat_template\n  apply_chat_template_kwargs: {}\n\n# config for actor, rollout and reference model\nactor_rollout_ref:\n\n  # Whether it's a hybrid engine, currently only supports hybrid engine\n  hybrid_engine: true\n\n  # common configs for the model\n  model:\n\n    _target_: verl.workers.config.HFModelConfig\n\n    # Huggingface model path. This can be either local path or HDFS path.\n    path: ~/models/deepseek-llm-7b-chat\n\n    # Custom chat template for the model.\n    custom_chat_template: null\n\n    # Whether to use shared memory (SHM) for accelerating the loading of model weights\n    use_shm: false\n\n    # Additional Python packages to register huggingface models/tokenizers.\n    external_lib: null\n\n    # Used to override model's original configurations, mainly dropout\n    override_config: {}\n\n    # Enable gradient checkpointing for actor\n    enable_gradient_checkpointing: true\n\n    # Enable activation offloading for actor\n    enable_activation_offload: false\n\n    # Whether to remove padding tokens in inputs during training\n    use_remove_padding: false\n\n    # Set to positive value to enable LoRA (e.g., 32)\n    lora_rank: 0\n\n    # LoRA scaling factor\n    lora_alpha: 16\n\n    # Target modules to apply LoRA. Options: \"all-linear\" (not recommended for VLMs) or\n    # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]\n    target_modules: all-linear\n\n    # Exclude modules from applying Lora. Similar usage to target_modules and Peft.\n    # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora.\n    exclude_modules: null\n\n    # Whether to use Liger for linear layer fusion\n    use_liger: false\n\n    # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)\n    use_fused_kernels: false\n\n    # Options for fused kernels. If use_fused_kernels is true, this will be used.\n    fused_kernel_options:\n\n      # Implementation backend for fused kernels. 
Options: \"triton\" or \"torch\".\n      impl_backend: torch\n\n    # Whether to enable loading a remote code model\n    trust_remote_code: false\n\n  # actor configs\n  actor:\n\n    # fsdp, fsdp2 or megatron. fsdp backend used here.\n    strategy: fsdp\n\n    # Split each sample into sub-batches of this size for PPO\n    ppo_mini_batch_size: 256\n\n    # [Deprecated] Global micro batch size\n    ppo_micro_batch_size: null\n\n    # Local per-GPU micro batch size\n    ppo_micro_batch_size_per_gpu: null\n\n    # Whether to automatically adjust batch size at runtime\n    use_dynamic_bsz: false\n\n    # Max tokens per GPU in one PPO batch; affects gradient accumulation\n    # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}\n    ppo_max_token_len_per_gpu: 16384\n\n    # Gradient clipping for actor updates\n    grad_clip: 1.0\n\n    # PPO clip ratio\n    clip_ratio: 0.2\n\n    # Lower bound for asymmetric clipping (used in dual-clip PPO)\n    clip_ratio_low: 0.2\n\n    # Upper bound for asymmetric clipping (used in dual-clip PPO)\n    clip_ratio_high: 0.2\n\n    # policy loss config\n    policy_loss:\n\n      # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617\n      loss_mode: \"vanilla\"\n\n      # Ratio of tokens to be clipped for clip-cov loss\n      clip_cov_ratio: 0.0002\n\n      # Lower bound for clip-cov loss\n      clip_cov_lb: 1.0\n\n      # Upper bound for clip-cov loss\n      clip_cov_ub: 5.0\n\n      # Ratio of tokens to be applied kl penalty for kl-cov loss\n      kl_cov_ratio: 0.0002\n\n      # KL divergence penalty coefficient\n      ppo_kl_coef: 0.1\n\n    # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C\n    clip_ratio_c: 3.0\n\n    # Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\n    loss_agg_mode: token-mean\n\n    # Entropy regularization coefficient in PPO loss\n    entropy_coeff: 0\n\n    # Whether to use KL loss instead of KL reward penalty. True for GRPO\n    use_kl_loss: false\n\n    # Whether to use torch.compile()\n    use_torch_compile: true\n\n    # KL loss coefficient when use_kl_loss is enabled. For GRPO\n    kl_loss_coef: 0.001\n\n    # Type of KL divergence loss. 
Options: \"kl\"(k1), \"abs\", \"mse\"(k2), \"low_var_kl\"(k3), \"full\"\n    kl_loss_type: low_var_kl\n\n    # Number of PPO epochs per batch\n    ppo_epochs: 1\n\n    # Shuffle training data across PPO epochs\n    shuffle: false\n\n    # Sequence parallelism size for Ulysses-style model parallelism\n    ulysses_sequence_parallel_size: 1\n\n    # calculate entropy with chunking to reduce memory peak\n    entropy_from_logits_with_chunking: False\n\n    # recompute entropy\n    entropy_checkpointing: False\n\n    # checkpoint configs\n    checkpoint:\n\n      # What to include in saved checkpoints\n      # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n      save_contents: ['model', 'optimizer', 'extra']\n\n      # For more flexibility, you can specify the contents to load from the checkpoint.\n      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n\n    # optimizer configs\n    optim:\n\n      # Learning rate\n      lr: 1e-6\n\n      # Warmup steps; negative value delegates to lr_warmup_steps_ratio\n      lr_warmup_steps: -1\n\n      # Warmup steps ratio (used if lr_warmup_steps is negative)\n      lr_warmup_steps_ratio: 0.0\n\n      # Minimum LR ratio for cosine schedule\n      min_lr_ratio: 0.0\n\n      # Number of cosine cycles in LR schedule\n      num_cycles: 0.5\n\n      # LR scheduler type: \"constant\" or \"cosine\"\n      lr_scheduler_type: constant\n\n      # Total training steps (must be overridden at runtime)\n      total_training_steps: -1\n\n      # Weight decay\n      weight_decay: 0.01\n\n    # configs for FSDP\n    fsdp_config:\n\n      # policy for wrapping the model\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping a layer with FSDP\n        min_num_params: 0\n\n      # Whether to offload model parameters to CPU (trades speed for memory)\n      param_offload: false\n\n      # Whether to offload optimizer state to CPU\n      optimizer_offload: false\n\n      # Only for FSDP2: offload param/grad/optimizer during train\n      offload_policy: false\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: true\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  ref:\n\n    # actor_rollout_ref.ref: FSDP config same as actor. 
For models larger than 7B, it’s recommended to turn on offload for ref by default\n    strategy: ${actor_rollout_ref.actor.strategy}\n\n    # config for FSDP strategy\n    fsdp_config:\n\n      # whether to offload parameters in FSDP\n      param_offload: False\n\n      # whether to perform reshard after model forward to save memory.\n      # only for fsdp2, [True, False, int between 1 and fsdp_size]\n      reshard_after_forward: True\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n      # the wrap policy for FSDP model\n      wrap_policy:\n\n        # minimum number of params in a wrapped module\n        min_num_params: 0\n\n    # whether to enable torch.compile\n    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}\n\n    # [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n    # The batch size for one forward pass in the computation of log_prob. Global batch size.\n    log_prob_micro_batch_size: null\n\n    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\n    log_prob_micro_batch_size_per_gpu: null\n\n    # enable dynamic batch size (sequence packing) for log_prob computation\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n    # the max token length per GPU\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n\n    # sequence parallel size\n    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}\n\n    # calculate entropy with chunking to reduce memory peak\n    entropy_from_logits_with_chunking: False\n\n    # recompute entropy\n    entropy_checkpointing: False\n\n  # Rollout model config.\n  rollout:\n\n    # actor_rollout_ref.rollout.name: hf/vllm/sglang.\n    name: vllm\n\n    # sync: LLM, async: AsyncLLM\n    mode: sync\n\n    # Sampling temperature for rollout.\n    temperature: 1.0\n\n    # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n    top_k: -1\n\n    # Top-p sampling parameter. Default 1.0.\n    top_p: 1\n\n    # typically the same as data max prompt length\n    prompt_length: ${data.max_prompt_length}\n\n    # typically the same as data max response length\n    response_length: ${data.max_response_length}\n\n    # for vllm rollout\n    # Rollout model parameters type. Align with actor model's FSDP/Megatron type.\n    dtype: bfloat16\n\n    # Fraction of GPU memory used by vLLM/SGLang for KV cache.\n    gpu_memory_utilization: 0.5\n\n    # Whether to ignore EOS and continue generating after EOS is hit.\n    ignore_eos: False\n\n    # Whether to disable CUDA graph and fall back to eager execution.\n    enforce_eager: False\n\n    # Whether to free engine KVCache after generation. Set enforce_eager=True when enabled.\n    free_cache_engine: True\n\n    # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc.\n    # safetensors (for huge models; also set use_shm=True); dummy_dtensor: randomly init model weight\n    load_format: dummy\n\n    # for huge models, layered summon can save memory (prevent OOM) but makes it slower\n    layered_summon: False\n\n    
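# Example Hydra override at launch (hypothetical values):\n    #   python3 -m verl.trainer.main_ppo \\\n    #     actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    #     actor_rollout_ref.rollout.tensor_model_parallel_size=2\n\n    # TP size for rollout. 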
Only effective for vLLM.\n    tensor_model_parallel_size: 2\n\n    # max number of tokens in a batch\n    max_num_batched_tokens: 8192\n\n    # max length for rollout\n    max_model_len: null\n\n    # max length of sequences\n    max_num_seqs: 1024\n\n    # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.\n    log_prob_micro_batch_size: null\n\n    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\n    log_prob_micro_batch_size_per_gpu: null\n\n    # enable dynamic batch size (sequence packing) for log_prob computation\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n    # max token length for log_prob computation\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n\n    # disable logging statistics\n    disable_log_stats: True\n\n    # May get higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.\n    enable_chunked_prefill: True\n\n    # for hf rollout\n    # Whether to sample during training rollout. False uses greedy sampling.\n    do_sample: True\n\n    # number of responses (i.e. num sample times). > 1 for grpo\n    n: 1\n\n    # Whether to wake up inference engine in multi-stage to reduce peak memory during training-rollout transition.\n    multi_stage_wake_up: false\n\n    # Extra inference engine arguments; please refer to the official vLLM/SGLang docs for details\n    engine_kwargs:\n\n      # vllm engine config\n      vllm: {}\n\n      # sglang engine config\n      sglang: {}\n\n    # Sampling parameters used during validation.\n    val_kwargs:\n\n      # sampling parameters for validation\n      # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n      top_k: -1\n\n      # Top-p sampling parameter. Default 1.0.\n      top_p: 1.0\n\n      # Sampling temperature for rollout.\n      temperature: 0\n\n      # whether to repeat n times for validation\n      n: 1\n\n      # Whether to sample during training rollout. False uses greedy sampling.\n      do_sample: False\n\n    # Multi-turn interaction config for tools or chat.\n    multi_turn:\n\n      # Set to True for multi-turn tool-interaction tasks; rollout.name should be set to sglang as well\n      enable: False\n\n      # null for no limit (default max_length // 3)\n      max_assistant_turns: null\n\n      # null for no tool\n      tool_config_path: null\n\n      # null for no limit (default max_length // 3)\n      max_user_turns: null\n\n      # max parallel calls for tools in a single turn\n      max_parallel_calls: 1\n\n      # max length of tool response\n      max_tool_response_length: 256\n\n      # truncate side of tool response: left, middle, right\n      tool_response_truncate_side: middle\n\n      # null for no interaction\n      interaction_config_path: null\n\n      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n      #   which may contain additional content such as reasoning content. 
This maintains the consistency between training and rollout, but it will lead to longer prompts.\n      use_inference_chat_template: False\n\n      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n      # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n      # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n      # - disable: disable tokenization sanity check\n      # - strict: enable strict tokenization sanity check (default)\n      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n      tokenization_sanity_check_mode: strict\n\n      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...\n      format: hermes\n\n    # Support logging rollout probs for debugging purposes\n    calculate_log_probs: False\n\n    # [Experimental] agent loop based rollout configs\n    agent:\n\n      # Number of agent loop workers\n      num_workers: 8\n\n      # custom async server configs\n      custom_async_server:\n\n        # Path to the custom async server implementation\n        path: null\n\n        # Class name of the custom async server class (e.g. AsyncvLLMServer)\n        name: null\n\n  # profiler configs\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    
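# Example (hypothetical): with all_ranks: False, setting ranks: [0] profiles only rank 0,\n    # a common low-overhead choice.\n\n    # The ranks that will be profiled. 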
[] or [0,1,...]\n    ranks: []\n\n# configs for the critic\ncritic:\n\n  # Number of rollouts per update (mirrors actor rollout_n)\n  rollout_n: ${actor_rollout_ref.rollout.n}\n\n  # fsdp or fsdp2 strategy used for critic model training\n  strategy: ${actor_rollout_ref.actor.strategy}\n\n  # optimizer configs\n  optim:\n\n    # Learning rate\n    lr: 1e-5\n\n    # Warmup steps ratio; total steps will be injected at runtime\n    lr_warmup_steps_ratio: 0.\n\n    # Minimum LR ratio for cosine schedule\n    min_lr_ratio: 0.0\n\n    # LR scheduler type: \"constant\" or \"cosine\"\n    lr_scheduler_type: constant\n\n    # Total training steps (must be overridden at runtime)\n    total_training_steps: -1\n\n    # Weight decay\n    weight_decay: 0.01\n\n  # model config for the critic\n  model:\n\n    # Path to pretrained model weights\n    path: ~/models/deepseek-llm-7b-chat\n\n    # Whether to use shared memory for loading the model\n    use_shm: False\n\n    # Tokenizer path (defaults to actor's model path)\n    tokenizer_path: ${actor_rollout_ref.model.path}\n\n    # Hugging Face config override\n    override_config: { }\n\n    # External model implementation (optional)\n    external_lib: ${actor_rollout_ref.model.external_lib}\n\n    # Enable gradient checkpointing to save memory\n    enable_gradient_checkpointing: True\n\n    # Offload activations to CPU to reduce GPU memory usage\n    enable_activation_offload: False\n\n    # Use remove padding optimization (saves compute)\n    use_remove_padding: False\n\n    # Whether to trust remote code from Hugging Face models\n    trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}\n\n    # FSDP-specific config\n    fsdp_config:\n\n      # Whether to offload model parameters to CPU\n      param_offload: False\n\n      # Whether to offload optimizer state to CPU\n      optimizer_offload: False\n\n      # Only for FSDP2: offload param/grad/optimizer during train\n      offload_policy: False\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: True\n\n      # Policy for wrapping layers with FSDP\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping\n        min_num_params: 0\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n    # Set to positive value to enable LoRA (e.g., 32)\n    lora_rank: 0\n\n    # LoRA scaling factor\n    lora_alpha: 16\n\n    # LoRA target modules: \"all-linear\" or list of linear projection layers\n    target_modules: all-linear\n\n  # PPO mini-batch size per update\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n\n  # [Deprecated] Global micro batch size\n  ppo_micro_batch_size: null\n\n  # Local per-GPU micro batch size\n  ppo_micro_batch_size_per_gpu: null\n\n  # Forward-only batch size (global)\n  forward_micro_batch_size: ${critic.ppo_micro_batch_size}\n\n  # Forward-only batch size (per GPU)\n  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}\n\n  # Whether to automatically adjust batch size at runtime\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n  # Max tokens per GPU in one PPO batch (doubled for critic)\n  ppo_max_token_len_per_gpu: 32768\n\n  # Max token length per GPU in forward pass\n  forward_max_token_len_per_gpu: 
${critic.ppo_max_token_len_per_gpu}\n\n  # Sequence parallelism size for Ulysses-style model parallelism\n  ulysses_sequence_parallel_size: 1\n\n  # Number of PPO epochs per batch\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n\n  # Shuffle training data across PPO epochs\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n\n  # Gradient clipping for critic updates\n  grad_clip: 1.0\n\n  # PPO value function clipping range\n  cliprange_value: 0.5\n\n  # Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\n  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}\n\n  # checkpoint configs\n  checkpoint:\n\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: ['model', 'optimizer', 'extra']\n\n    # What to include when loading checkpoints\n    load_contents: ${critic.checkpoint.save_contents}\n\n  # profiler configs\n  # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. [] or [0,1,...]\n    ranks: []\n\n# configs for the reward model\nreward_model:\n\n  # Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions.\n  # In GSM8K and Math examples, we disable reward model.\n  # For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses.\n  # If False, the following parameters are not effective\n  enable: False\n\n  # FSDP strategy: \"fsdp\" or \"fsdp2\"\n  strategy: ${actor_rollout_ref.actor.strategy}\n\n  # model config for reward scoring\n  model:\n\n    # Input tokenizer. If the reward model’s chat template is inconsistent with the policy,\n    # we need to first decode to plaintext, then apply the rm’s chat_template.\n    # Then score with RM. If chat_templates are consistent, it can be set to null.\n    input_tokenizer: ${actor_rollout_ref.model.path}\n\n    # RM’s HDFS path or local path. 
Note that RM only supports AutoModelForSequenceClassification.\n    # Other model types need to define their own RewardModelWorker and pass it from the code.\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n\n    # Whether to use shared memory for loading the model\n    use_shm: False\n\n    # External model implementation (optional)\n    external_lib: ${actor_rollout_ref.model.external_lib}\n\n    # Use remove padding optimization (saves compute)\n    use_remove_padding: False\n\n    # Whether to use fused reward kernels for speedup\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n\n    # Whether to enable loading a remote code model, default to False\n    trust_remote_code: False\n\n    # FSDP-specific config\n    fsdp_config:\n\n      # Policy for wrapping layers with FSDP\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping\n        min_num_params: 0\n\n      # Whether to offload model parameters to CPU\n      param_offload: False\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: True\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n  # [Deprecated] Global micro batch size\n  micro_batch_size: null\n\n  # Local per-GPU micro batch size\n  micro_batch_size_per_gpu: null\n\n  # Maximum sequence length to process for scoring\n  max_length: null\n\n  # Sequence parallelism size for Ulysses-style model parallelism\n  ulysses_sequence_parallel_size: 1\n\n  # Whether to dynamically adjust batch size at runtime\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n\n  # Maximum number of tokens per GPU in one forward pass\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n\n  # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.\n  # Default is naive. If all verification functions are multiprocessing-safe,\n  # the reward manager can be set to prime for parallel verification.\n  reward_manager: naive\n\n  # Whether to launch custom reward function asynchronously during log_prob\n  launch_reward_fn_async: False\n\n  # Cloud/local sandbox fusion configuration for custom reward logic\n  sandbox_fusion:\n\n    # Cloud/local function URL for sandbox execution\n    url: null\n\n    # Max concurrent requests allowed to sandbox\n    max_concurrent: 64\n\n    # Max memory limit for each sandbox process in MB\n    memory_limit_mb: 1024\n\n  # profiler configs\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. [] or [0,1,...]\n    ranks: []\n\n# custom reward function definition\ncustom_reward_function:\n\n  # The path to the file containing your customized reward function.\n  # If not specified, pre-implemented reward functions will be used.\n  path: null\n\n  # The name of the reward function within the specified file. 
Default is 'compute_score'.\n  name: compute_score\n\n# config for the algorithm\nalgorithm:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.trainer.config.AlgoConfig\n\n  # Discount factor for future rewards\n  gamma: 1.0\n\n  # Trade-off between bias and variance in the GAE estimator\n  lam: 1.0\n\n  # Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n  adv_estimator: gae\n\n  # Whether to normalize advantages by std (specific to GRPO)\n  norm_adv_by_std_in_grpo: True\n\n  # Whether to enable in-reward KL penalty\n  use_kl_in_reward: False\n\n  # How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\"\n  kl_penalty: kl\n\n  # KL control configuration\n  kl_ctrl:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.trainer.config.KLControlConfig\n\n    # KL control type: \"fixed\" or \"adaptive\"\n    type: fixed\n\n    # Initial coefficient for KL penalty\n    kl_coef: 0.001\n\n    # Horizon value for adaptive controller (if enabled)\n    horizon: 10000\n\n    # Target KL divergence (used for adaptive controller)\n    target_kl: 0.1\n\n  # Whether to enable preference feedback PPO\n  use_pf_ppo: False\n\n  # Preference feedback PPO settings\n  pf_ppo:\n\n    # Method for reweighting samples: \"pow\", \"max_min\", or \"max_random\"\n    reweight_method: pow\n\n    # Power used for weight scaling in \"pow\" method\n    weight_pow: 2.0\n\n# config for the trainer\ntrainer:\n\n  # Whether to balance batch sizes across distributed workers\n  balance_batch: True\n\n  # Number of epochs in training\n  total_epochs: 30\n\n  # Total training steps (can be set explicitly or derived from epochs)\n  total_training_steps: null\n\n  # The steps that will be profiled. null means no profiling. null or [1,2,5,...]\n  profile_steps: null\n\n  # Controller Nvidia Nsight Systems options. Must be set when profile_steps is not None.\n  ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html\n  ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html\n  controller_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n  # Worker Nvidia Nsight Systems options. Must be set when profile_steps is not None.\n  worker_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n    
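# Example (hypothetical): setting profile_steps: [1, 5] above captures Nsight traces\n    # for training steps 1 and 5 using these options.\n\n    # Profiling only in a range of torch.cuda.profiler.start and stop. 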
Do not change this config.\n    capture-range: \"cudaProfilerApi\"\n\n    # Specify the desired behavior when a capture range ends.\n    # In verl we need the torch.cuda.profiler.start/stop pair to repeat n times.\n    # Valid values are \"repeat-shutdown:n\" or null.\n    # For normal whole step profiling, n = len(profile_steps);\n    # but for discrete profiling, n = len(profile_steps) * Number(subtasks).\n    # Or you can just leave it null and the program will use n = len(profile_steps) * 6;\n    capture-range-end: null\n\n    # Send signal to the target application's process group. We let the program exit by itself.\n    kill: none\n\n  # Config for npu profiler. Must be set when profile_steps is not None and torch_npu is available.\n  npu_profile:\n\n    # Options for the npu profiler\n    options:\n\n      # Storage path of collected data.\n      save_path: ./profiler_data\n\n      # The roles that will be profiled. Only takes effect in discrete mode.\n      # optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob.\n      # \"all\" means all roles will be profiled.\n      roles: [\"all\"]\n\n      # Collection level, optional values: level_none, level0, level1, level2.\n      level: level1\n\n      # Whether to enable memory analysis.\n      with_memory: False\n\n      # Whether to record tensor shape.\n      record_shapes: False\n\n      # Whether to record Device-side performance data.\n      with_npu: True\n\n      # Whether to record Host-side performance data.\n      with_cpu: True\n\n      # Whether to record Python call stack information.\n      with_module: False\n\n      # Whether to record operator call stack information.\n      with_stack: False\n\n      # Whether to automatically parse the data.\n      analysis: True\n\n  # Project name for experiment tracking (e.g., wandb)\n  project_name: verl_examples\n\n  # Experiment name for run identification in tracking tools\n  experiment_name: gsm8k\n\n  # Logging backends to use: \"console\", \"wandb\", etc.\n  logger: [ 'console', 'wandb' ]\n\n  # Number of generations to log during validation\n  log_val_generations: 0\n\n  # Directory for logging rollout data; no dump if null\n  rollout_data_dir: null\n\n  # Directory for logging validation data; no dump if null\n  validation_data_dir: null\n\n  # Number of nodes used in the training\n  nnodes: 1\n\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n\n  # Save frequency (by iteration) for model checkpoints\n  save_freq: -1\n\n  
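# Example (hypothetical): save_freq: 100 writes a checkpoint every 100 iterations; -1 disables periodic saving.\n\n  # ESI refers to the elastic server instance used during training, similar to the training plan. 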
For example,\n  # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.\n  # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance.\n  # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time.\n  # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.\n  esi_redundant_time: 0\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (only used when resume_mode is \"resume_path\")\n  resume_from_path: null\n\n  # Whether to run validation before training begins\n  val_before_train: True\n\n  # Whether to run validation only\n  val_only: False\n\n  # Validation frequency (in training iterations)\n  test_freq: -1\n\n  # Number of iterations to warm up the critic before updating policy\n  critic_warmup: 0\n\n  # Default path to distributed filesystem for saving checkpoints\n  default_hdfs_dir: null\n\n  # Whether to delete local checkpoints after loading\n  del_local_ckpt_after_load: False\n\n  # Default local directory for saving checkpoints\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\n  # Maximum number of actor checkpoints to keep\n  max_actor_ckpt_to_keep: null\n\n  # Maximum number of critic checkpoints to keep\n  max_critic_ckpt_to_keep: null\n\n  # Timeout (in seconds) for Ray worker to wait for registration\n  ray_wait_register_center_timeout: 300\n\n  # Device to run training on (e.g., \"cuda\", \"cpu\")\n  device: cuda\n\n# configs related to ray\nray_kwargs:\n  # configs related to ray initialization\n  ray_init:\n\n    # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.\n    num_cpus: null\n\n  # Path to save Ray timeline JSON for performance profiling\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/tests/trainer/config/test_algo_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.config import AlgoConfig, KLControlConfig\nfrom verl.trainer.ppo.core_algos import (\n    compute_gae_advantage_return,\n    compute_grpo_outcome_advantage,\n    get_adv_estimator_fn,\n)\nfrom verl.utils.config import omega_conf_to_dataclass\n\n\nclass TestAlgoConfig(unittest.TestCase):\n    \"\"\"Test the AlgoConfig dataclass and its integration with core algorithms.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        # Create a sample algorithm config as DictConfig (similar to what comes from YAML)\n        self.config_dict = {\n            \"_target_\": \"verl.trainer.config.AlgoConfig\",\n            \"gamma\": 0.99,\n            \"lam\": 0.95,\n            \"adv_estimator\": \"gae\",\n            \"norm_adv_by_std_in_grpo\": True,\n            \"use_kl_in_reward\": True,\n            \"kl_penalty\": \"kl\",\n            \"kl_ctrl\": {\n                \"_target_\": \"verl.trainer.config.KLControlConfig\",\n                \"type\": \"adaptive\",\n                \"kl_coef\": 0.002,\n                \"horizon\": 5000,\n                \"target_kl\": 0.05,\n            },\n            \"use_pf_ppo\": True,\n            \"pf_ppo\": {\"reweight_method\": \"max_min\", \"weight_pow\": 3.0},\n        }\n        self.omega_config = OmegaConf.create(self.config_dict)\n\n    def test_dataclass_creation_from_dict(self):\n        \"\"\"Test creating AlgoConfig from dictionary.\"\"\"\n        config = omega_conf_to_dataclass(self.config_dict)\n\n        self.assertIsInstance(config, AlgoConfig)\n        self.assertEqual(config.gamma, 0.99)\n        self.assertEqual(config.lam, 0.95)\n        self.assertEqual(config.adv_estimator, \"gae\")\n        self.assertTrue(config.norm_adv_by_std_in_grpo)\n        self.assertTrue(config.use_kl_in_reward)\n        self.assertEqual(config.kl_penalty, \"kl\")\n        self.assertTrue(config.use_pf_ppo)\n\n    def test_dataclass_creation_from_omega_config(self):\n        \"\"\"Test creating AlgoConfig from OmegaConf DictConfig.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        self.assertIsInstance(config, AlgoConfig)\n        self.assertEqual(config.gamma, 0.99)\n        self.assertEqual(config.lam, 0.95)\n\n    def test_nested_configs(self):\n        \"\"\"Test that nested configurations are properly converted.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        # Test KL control config\n        self.assertIsInstance(config.kl_ctrl, KLControlConfig)\n        self.assertEqual(config.kl_ctrl.type, \"adaptive\")\n        self.assertEqual(config.kl_ctrl.kl_coef, 0.002)\n        self.assertEqual(config.kl_ctrl.horizon, 5000)\n        self.assertEqual(config.kl_ctrl.target_kl, 0.05)\n\n        # Test PF PPO config\n        
self.assertEqual(config.pf_ppo.get(\"reweight_method\"), \"max_min\")\n        self.assertEqual(config.pf_ppo.get(\"weight_pow\"), 3.0)\n\n    def test_default_values(self):\n        \"\"\"Test that default values are properly set.\"\"\"\n        minimal_config = {\"gamma\": 0.8}\n        config = omega_conf_to_dataclass(minimal_config, AlgoConfig)\n\n        self.assertEqual(config.gamma, 0.8)\n        self.assertEqual(config.lam, 1.0)  # default value\n        self.assertEqual(config.adv_estimator, \"gae\")  # default value\n        self.assertTrue(config.norm_adv_by_std_in_grpo)  # default value\n        self.assertFalse(config.use_kl_in_reward)  # default value\n        self.assertEqual(config.kl_penalty, \"kl\")  # default value\n        self.assertFalse(config.use_pf_ppo)  # default value\n\n    def test_get_method_backward_compatibility(self):\n        \"\"\"Test the get method for backward compatibility.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        # Test existing attribute\n        self.assertEqual(config.get(\"gamma\"), 0.99)\n        self.assertEqual(config.get(\"gamma\", 1.0), 0.99)\n\n        # Test non-existing attribute\n        self.assertIsNone(config.get(\"non_existing\"))\n        self.assertEqual(config.get(\"non_existing\", \"default\"), \"default\")\n\n    def test_post_init_nested_configs(self):\n        \"\"\"Test that __post_init__ properly initializes nested configs when None.\"\"\"\n        # Create config without nested configs\n        minimal_config = AlgoConfig(gamma=0.9)\n\n        # Check that nested configs are initialized\n        self.assertIsNotNone(minimal_config.kl_ctrl)\n        self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)\n        assert not minimal_config.pf_ppo\n\n    def test_config_init_from_yaml(self):\n        import os\n\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n            cfg = compose(config_name=\"ppo_trainer\")\n        algo_config = omega_conf_to_dataclass(cfg.algorithm)\n        from verl.trainer.config import AlgoConfig\n\n        assert isinstance(algo_config, AlgoConfig)\n\n\nclass TestAlgoCompute(unittest.TestCase):\n    \"\"\"Test the AlgoConfig dataclass and its integration with core algorithms.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        self.algo_config = AlgoConfig(\n            gamma=0.99,\n            lam=0.95,\n            adv_estimator=\"gae\",\n            norm_adv_by_std_in_grpo=True,\n            use_kl_in_reward=True,\n            kl_penalty=\"kl\",\n            kl_ctrl=KLControlConfig(type=\"adaptive\", kl_coef=0.002, horizon=5000, target_kl=0.05),\n            use_pf_ppo=True,\n            pf_ppo={\"reweight_method\": \"max_min\", \"weight_pow\": 3.0},\n        )\n\n    def test_advantage_estimator_with_cfg(self):\n        \"\"\"Test integration with advantage estimators from core_algos.\"\"\"\n        config = self.algo_config\n\n        # Test GAE advantage estimator\n        adv_fn = get_adv_estimator_fn(config.adv_estimator)\n        self.assertIsNotNone(adv_fn)\n\n        # Test with actual GAE computation\n        batch_size, seq_len = 2, 5\n        token_level_rewards = torch.randn(batch_size, seq_len)\n        values = torch.randn(batch_size, seq_len)\n        response_mask = torch.ones(batch_size, seq_len)\n\n        advantages, returns = compute_gae_advantage_return(\n            
token_level_rewards=token_level_rewards,\n            values=values,\n            response_mask=response_mask,\n            gamma=config.gamma,\n            lam=config.lam,\n        )\n\n        self.assertEqual(advantages.shape, (batch_size, seq_len))\n        self.assertEqual(returns.shape, (batch_size, seq_len))\n\n    def test_grpo_advantage_estimator_with_cfg(self):\n        \"\"\"Test integration with GRPO advantage estimator.\"\"\"\n        grpo_config = AlgoConfig(adv_estimator=\"grpo\", norm_adv_by_std_in_grpo=True)\n\n        # Test GRPO advantage computation\n        batch_size, seq_len = 4, 3\n        token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]])\n        response_mask = torch.ones(batch_size, seq_len)\n        index = np.array([0, 0, 1, 1])  # Two groups\n\n        advantages, returns = compute_grpo_outcome_advantage(\n            token_level_rewards=token_level_rewards,\n            response_mask=response_mask,\n            index=index,\n            norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo,\n        )\n\n        self.assertEqual(advantages.shape, (batch_size, seq_len))\n        self.assertEqual(returns.shape, (batch_size, seq_len))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/trainer/config/test_legacy_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport unittest\nimport warnings\n\nfrom hydra import compose, initialize_config_dir\nfrom hydra.core.global_hydra import GlobalHydra\nfrom omegaconf import OmegaConf\n\n_BREAKING_CHANGES = [\n    \"critic.optim.lr\",  # mcore critic lr init value 1e-6 -> 1e-5\n    \"actor_rollout_ref.actor.optim.lr_warmup_steps\",  # None -> -1\n    \"critic.optim.lr_warmup_steps\",  # None -> -1\n    \"actor_rollout_ref.rollout.name\",  # vllm -> ???\n    \"actor_rollout_ref.actor.megatron.expert_tensor_parallel_size\",\n    \"actor_rollout_ref.ref.megatron.expert_tensor_parallel_size\",\n    \"critic.megatron.expert_tensor_parallel_size\",\n    \"reward_model.megatron.expert_tensor_parallel_size\",\n]\n\n\nclass TestConfigComparison(unittest.TestCase):\n    \"\"\"Test that current configs match their legacy counterparts exactly.\"\"\"\n\n    ignored_keys = [\n        \"enable_gradient_checkpointing\",\n        \"gradient_checkpointing_kwargs\",\n        \"activations_checkpoint_method\",\n        \"activations_checkpoint_granularity\",\n        \"activations_checkpoint_num_layers\",\n        \"discrete\",\n        \"profiler\",\n        \"profile\",\n        \"use_profile\",\n        \"npu_profile\",\n        \"profile_steps\",\n        \"worker_nsight_options\",\n        \"controller_nsight_options\",\n    ]\n\n    def _compare_configs_recursively(\n        self, current_config, legacy_config, path=\"\", legacy_allow_missing=True, current_allow_missing=False\n    ):\n        \"\"\"Recursively compare two OmegaConf configs and assert they are identical.\n\n        Args:\n            legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and\n              we allow that to happen\n        \"\"\"\n        if isinstance(current_config, dict) and isinstance(legacy_config, dict):\n            current_keys = set(current_config.keys())\n            legacy_keys = set(legacy_config.keys())\n\n            missing_in_current = legacy_keys - current_keys\n            missing_in_legacy = current_keys - legacy_keys\n\n            # Ignore specific keys that are allowed to be missing\n            for key in self.ignored_keys:\n                if key in missing_in_current:\n                    missing_in_current.remove(key)\n                if key in missing_in_legacy:\n                    missing_in_legacy.remove(key)\n\n            if missing_in_current:\n                msg = f\"Keys missing in current config at {path}: {missing_in_current}\"\n                if current_allow_missing:\n                    warnings.warn(msg, stacklevel=1)\n                else:\n                    self.fail(f\"Keys missing in current config at {path}: {missing_in_current}\")\n            if missing_in_legacy:\n                # if the legacy\n                msg = f\"Keys missing in legacy config at {path}: {missing_in_legacy}\"\n                if 
legacy_allow_missing:\n                    warnings.warn(msg, stacklevel=1)\n                else:\n                    self.fail(msg)\n\n            for key in current_keys:\n                current_path = f\"{path}.{key}\" if path else key\n                if key in legacy_config:\n                    self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)\n        elif isinstance(current_config, list) and isinstance(legacy_config, list):\n            self.assertEqual(\n                len(current_config),\n                len(legacy_config),\n                f\"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}\",\n            )\n            for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):\n                self._compare_configs_recursively(current_item, legacy_item, f\"{path}[{i}]\")\n        elif path not in _BREAKING_CHANGES:\n            self.assertEqual(\n                current_config,\n                legacy_config,\n                f\"Values differ at {path}: current={current_config}, legacy={legacy_config}\",\n            )\n\n    def test_ppo_trainer_config_matches_legacy(self):\n        \"\"\"Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly.\"\"\"\n        GlobalHydra.instance().clear()\n\n        try:\n            with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n                current_config = compose(config_name=\"ppo_trainer\")\n\n            legacy_config = OmegaConf.load(\"tests/trainer/config/legacy_ppo_trainer.yaml\")\n            current_dict = OmegaConf.to_container(current_config, resolve=True)\n            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)\n\n            if \"defaults\" in current_dict:\n                del current_dict[\"defaults\"]\n\n            self._compare_configs_recursively(current_dict, legacy_dict)\n        finally:\n            GlobalHydra.instance().clear()\n\n    def test_ppo_megatron_trainer_config_matches_legacy(self):\n        \"\"\"Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.\"\"\"\n\n        GlobalHydra.instance().clear()\n\n        try:\n            with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n                current_config = compose(config_name=\"ppo_megatron_trainer\")\n\n            legacy_config = OmegaConf.load(\"tests/trainer/config/legacy_ppo_megatron_trainer.yaml\")\n            current_dict = OmegaConf.to_container(current_config, resolve=True)\n            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)\n\n            if \"defaults\" in current_dict:\n                del current_dict[\"defaults\"]\n\n            self._compare_configs_recursively(\n                current_dict, legacy_dict, legacy_allow_missing=True, current_allow_missing=False\n            )\n        finally:\n            GlobalHydra.instance().clear()\n\n    def test_load_component(self):\n        \"\"\"Test that each standalone component config (actor/ref/rollout) composes successfully.\"\"\"\n\n        GlobalHydra.instance().clear()\n        configs_to_load = [\n            (\"verl/trainer/config/actor\", \"dp_actor\"),\n            (\"verl/trainer/config/actor\", \"megatron_actor\"),\n            (\"verl/trainer/config/ref\", 
\"dp_ref\"),\n            (\"verl/trainer/config/ref\", \"megatron_ref\"),\n            (\"verl/trainer/config/rollout\", \"rollout\"),\n        ]\n        for config_dir, config_file in configs_to_load:\n            try:\n                with initialize_config_dir(config_dir=os.path.abspath(config_dir)):\n                    compose(config_name=config_file)\n            finally:\n                GlobalHydra.instance().clear()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/trainer/ppo/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the PPO trainer module.\n\"\"\"\n"
  },
  {
    "path": "verl_distillation/tests/trainer/ppo/test_core_algos_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\nimport unittest\n\nimport numpy as np\nimport pytest\nimport torch\n\nimport verl.trainer.ppo.core_algos\nfrom verl.trainer.ppo.core_algos import (\n    compute_gae_advantage_return,\n    compute_grpo_outcome_advantage,\n    compute_grpo_vectorized_outcome_advantage,\n    compute_rloo_outcome_advantage,\n    compute_rloo_vectorized_outcome_advantage,\n    get_adv_estimator_fn,\n    register_adv_est,\n)\n\n\ndef mock_test_fn():\n    pass\n\n\nclass TestRegisterAdvEst(unittest.TestCase):\n    def setUp(self):\n        \"\"\"Clear the registry before each test\"\"\"\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {\n            \"gae\": lambda x: x * 2,\n            \"vtrace\": lambda x: x + 1,\n        }\n        self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY\n\n    def tearDown(self) -> None:\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()\n        return super().tearDown()\n\n    def test_register_new_function(self):\n        \"\"\"Test registering a new function with a string name\"\"\"\n\n        @register_adv_est(\"test_estimator\")\n        def test_fn():\n            pass\n\n        self.assertIn(\"test_estimator\", self.ADV_ESTIMATOR_REGISTRY)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"test_estimator\"], test_fn)\n\n    def test_register_with_enum(self):\n        \"\"\"Test registering with an enum value (assuming AdvantageEstimator exists)\"\"\"\n        from enum import Enum\n\n        class AdvantageEstimator(Enum):\n            TEST = \"test_enum_estimator\"\n\n        @register_adv_est(AdvantageEstimator.TEST)\n        def test_fn():\n            pass\n\n        self.assertIn(\"test_enum_estimator\", self.ADV_ESTIMATOR_REGISTRY)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"test_enum_estimator\"], test_fn)\n\n    def test_duplicate_registration_same_function(self):\n        \"\"\"Test that registering the same function twice doesn't raise an error\"\"\"\n        register_adv_est(\"duplicate_test\")(mock_test_fn)\n        register_adv_est(\"duplicate_test\")(mock_test_fn)\n\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"duplicate_test\"], mock_test_fn)\n\n    def test_duplicate_registration_different_function(self):\n        \"\"\"Test that registering different functions with same name raises ValueError\"\"\"\n\n        @register_adv_est(\"conflict_test\")\n        def test_fn1():\n            pass\n\n        with self.assertRaises(ValueError):\n\n            @register_adv_est(\"conflict_test\")\n            def test_fn2():\n                pass\n\n    def test_decorator_preserves_function(self):\n        \"\"\"Test that the decorator returns the original function\"\"\"\n\n        def test_fn():\n            return \"original\"\n\n        decorated = 
register_adv_est(\"preserve_test\")(test_fn)\n        self.assertEqual(decorated(), \"original\")\n\n    def test_multiple_registrations(self):\n        \"\"\"Test registering multiple different functions\"\"\"\n        init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY)\n\n        @register_adv_est(\"estimator1\")\n        def fn1():\n            pass\n\n        @register_adv_est(\"estimator2\")\n        def fn2():\n            pass\n\n        self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"estimator1\"], fn1)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"estimator2\"], fn2)\n\n    def test_get_adv_estimator_fn_valid_names(self):\n        \"\"\"Test that valid names return the correct function from registry.\"\"\"\n        # Test GAE\n        gae_fn = get_adv_estimator_fn(\"gae\")\n        assert gae_fn(5) == 10  # 5 * 2 = 10\n\n        # Test Vtrace\n        vtrace_fn = get_adv_estimator_fn(\"vtrace\")\n        assert vtrace_fn(5) == 6  # 5 + 1 = 6\n\n    def test_get_adv_estimator_fn_invalid_name(self):\n        \"\"\"Test that invalid names raise ValueError.\"\"\"\n        with pytest.raises(ValueError) as excinfo:\n            get_adv_estimator_fn(\"invalid_name\")\n        assert \"Unknown advantage estimator simply: invalid_name\" in str(excinfo.value)\n\n    def test_get_adv_estimator_fn_case_sensitive(self):\n        \"\"\"Test that name lookup is case-sensitive.\"\"\"\n        with pytest.raises(ValueError):\n            get_adv_estimator_fn(\"GAE\")  # Different case\n\n\ndef test_multi_turn_compute_gae_advantage_return():\n    \"\"\"Test multi-turn GAE skip observation tokens.\"\"\"\n    gamma = random.uniform(0.0, 1.0)\n    lam = random.uniform(0.0, 1.0)\n\n    rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float)\n\n    values1 = torch.tensor(\n        [\n            [\n                random.uniform(-100.0, 100.0),\n                random.random(),\n                4.0,\n                5.0,\n                6.0,\n                random.uniform(-100.0, 0),\n                random.random(),\n                7.0,\n                9.0,\n                0.0,\n                0.0,\n            ]\n        ],\n        dtype=torch.float,\n    )\n\n    values2 = torch.tensor(\n        [\n            [\n                random.random(),\n                random.uniform(-100.0, 100.0),\n                4.0,\n                5.0,\n                6.0,\n                random.random(),\n                random.uniform(0.0, 100.0),\n                7.0,\n                9.0,\n                0.0,\n                0.0,\n            ]\n        ],\n        dtype=torch.float,\n    )\n\n    response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float)\n\n    adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam)\n    adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam)\n\n    ret1 *= response_mask\n    ret2 *= response_mask\n    assert torch.equal(adv1, adv2), f\"{adv1=}, {adv2=}\"\n    assert torch.equal(ret1, ret2), f\"{ret1=}, {ret2=}\"\n    print(f\" [CORRECT] \\n\\n{adv1=}, \\n\\n{ret1=}\")\n\n\ndef _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:\n    \"\"\"Create a numpy index array ensuring each group has at least 2 samples.\"\"\"\n    assert num_groups * 2 <= batch_size, \"batch_size must allow >=2 samples per group\"\n    counts: list[int] = [2] * 
num_groups\n    remaining = batch_size - 2 * num_groups\n    for _ in range(remaining):\n        counts[random.randrange(num_groups)] += 1\n    index = []\n    for gid, c in enumerate(counts):\n        index.extend([gid] * c)\n    random.shuffle(index)\n    return np.asarray(index, dtype=np.int64)\n\n\ndef _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:\n    mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()\n    rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]\n    if len(rows_without_one) > 0:\n        mask[rows_without_one, -1] = 1.0\n    return mask\n\n\n@pytest.mark.parametrize(\n    \"batch_size,seq_len,num_groups,seed\",\n    [\n        (64, 128, 5, 0),\n        (128, 256, 8, 1),\n        (512, 512, 10, 2),\n    ],\n)\ndef test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):\n    torch.manual_seed(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    index = _make_group_index(batch_size, num_groups)\n    response_mask = _rand_mask(batch_size, seq_len)\n    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)\n    token_level_rewards = base_rewards * response_mask\n    adv1, ret1 = compute_rloo_outcome_advantage(\n        token_level_rewards=token_level_rewards,\n        response_mask=response_mask,\n        index=index,\n    )\n    adv2, ret2 = compute_rloo_vectorized_outcome_advantage(\n        token_level_rewards=token_level_rewards,\n        response_mask=response_mask,\n        index=index,\n    )\n    # Print concise diagnostics for visibility during test runs\n    adv_max_diff = (adv1 - adv2).abs().max().item()\n    ret_max_diff = (ret1 - ret2).abs().max().item()\n    total_mask_tokens = int(response_mask.sum().item())\n    print(\n        f\"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} \"\n        f\"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}\"\n    )\n    assert adv1.shape == adv2.shape == (batch_size, seq_len)\n    assert ret1.shape == ret2.shape == (batch_size, seq_len)\n    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)\n    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)\n\n\n@pytest.mark.parametrize(\n    \"batch_size,seq_len,num_groups,seed\",\n    [\n        (64, 128, 5, 0),\n        (128, 256, 8, 1),\n        (512, 512, 10, 2),\n    ],\n)\ndef test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):\n    # Set seeds for reproducibility\n    torch.manual_seed(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n\n    # Generate group indices (numpy array of shape [batch_size])\n    index = _make_group_index(batch_size, num_groups)\n\n    # Generate binary response mask (at least one valid token per row)\n    response_mask = _rand_mask(batch_size, seq_len)\n\n    # Generate token-level rewards and apply mask\n    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)\n    token_level_rewards = base_rewards * response_mask\n\n    # Compute GRPO outcome advantage (original implementation)\n    adv1, ret1 = compute_grpo_outcome_advantage(\n        token_level_rewards=token_level_rewards,\n        response_mask=response_mask,\n        index=index,\n    )\n\n    # Compute GRPO outcome advantage (vectorized implementation)\n    adv2, ret2 = compute_grpo_vectorized_outcome_advantage(\n        token_level_rewards=token_level_rewards,\n        response_mask=response_mask,\n        index=index,\n    )\n\n    # Diagnostic 
info for visibility (same style as RLOO test)\n    adv_max_diff = (adv1 - adv2).abs().max().item()\n    ret_max_diff = (ret1 - ret2).abs().max().item()\n    total_mask_tokens = int(response_mask.sum().item())\n    print(\n        f\"[GRPO] seed={seed} groups={num_groups} shape={adv1.shape} \"\n        f\"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}\"\n    )\n\n    # Assert shape and numerical equivalence\n    assert adv1.shape == adv2.shape == (batch_size, seq_len)\n    assert ret1.shape == ret2.shape == (batch_size, seq_len)\n    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)\n    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/trainer/ppo/test_metric_utils_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the metric utilities in verl.trainer.ppo.metric_utils.\n\"\"\"\n\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nimport numpy as np\nimport torch\n\nfrom verl.trainer.ppo.metric_utils import (\n    bootstrap_metric,\n    calc_maj_val,\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics,\n)\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\n\n\nclass TestReduceMetrics(unittest.TestCase):\n    \"\"\"Tests for the reduce_metrics function.\"\"\"\n\n    def test_reduce_metrics_basic(self):\n        \"\"\"Test that reduce_metrics correctly computes means.\"\"\"\n        metrics = {\n            \"loss\": [1.0, 2.0, 3.0],\n            \"accuracy\": [0.0, 0.5, 1.0],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertEqual(result[\"loss\"], 2.0)\n        self.assertEqual(result[\"accuracy\"], 0.5)\n\n    def test_reduce_metrics_empty(self):\n        \"\"\"Test that reduce_metrics handles empty lists.\"\"\"\n        metrics = {\n            \"empty\": [],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertTrue(np.isnan(result[\"empty\"]))\n\n    def test_reduce_metrics_single_value(self):\n        \"\"\"Test that reduce_metrics works with single values.\"\"\"\n        metrics = {\n            \"single\": [5.0],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertEqual(result[\"single\"], 5.0)\n\n\nclass TestComputeDataMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_data_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto object\n        self.batch = MagicMock()\n        self.batch.batch = {\n            \"token_level_scores\": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),\n            \"token_level_rewards\": torch.tensor([[0.5, 1.0], [1.5, 2.0]]),\n            \"advantages\": torch.tensor([[0.1, 0.2], [0.3, 0.4]]),\n            \"returns\": torch.tensor([[1.1, 1.2], [1.3, 1.4]]),\n            \"responses\": torch.zeros((2, 2)),  # 2 samples, 2 tokens each\n            \"attention_mask\": torch.tensor(\n                [\n                    [1, 1, 1, 1],  # 2 prompt tokens, 2 response tokens\n                    [1, 1, 1, 1],\n                ]\n            ),\n            \"response_mask\": torch.tensor(\n                [\n                    [1, 1],  # 2 response tokens\n                    [1, 1],\n                ]\n            ),\n            \"values\": torch.tensor([[0.9, 1.0], [1.1, 1.2]]),\n        }\n\n    def test_compute_data_metrics_with_critic(self):\n        \"\"\"Test compute_data_metrics with critic enabled.\"\"\"\n        metrics = compute_data_metrics(self.batch, use_critic=True)\n\n        # Check that all expected metrics are present\n        self.assertIn(\"critic/score/mean\", metrics)\n        
self.assertIn(\"critic/rewards/mean\", metrics)\n        self.assertIn(\"critic/advantages/mean\", metrics)\n        self.assertIn(\"critic/returns/mean\", metrics)\n        self.assertIn(\"critic/values/mean\", metrics)\n        self.assertIn(\"critic/vf_explained_var\", metrics)\n        self.assertIn(\"response_length/mean\", metrics)\n        self.assertIn(\"prompt_length/mean\", metrics)\n\n        # Check some specific values\n        self.assertAlmostEqual(metrics[\"critic/score/mean\"], 5.0)  # Sum of token_level_scores\n        self.assertAlmostEqual(metrics[\"critic/rewards/mean\"], 2.5)  # Sum of token_level_rewards\n\n    def test_compute_data_metrics_without_critic(self):\n        \"\"\"Test compute_data_metrics with critic disabled.\"\"\"\n        metrics = compute_data_metrics(self.batch, use_critic=False)\n\n        # Check that critic-specific metrics are not present\n        self.assertNotIn(\"critic/values/mean\", metrics)\n        self.assertNotIn(\"critic/vf_explained_var\", metrics)\n\n        # Check that other metrics are still present\n        self.assertIn(\"critic/score/mean\", metrics)\n        self.assertIn(\"critic/rewards/mean\", metrics)\n        self.assertIn(\"response_length/mean\", metrics)\n\n\nclass TestComputeTimingMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_timing_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto object\n        self.batch = MagicMock()\n        self.batch.batch = {\n            \"responses\": torch.zeros((2, 3)),  # 2 samples, 3 response tokens each\n            \"attention_mask\": torch.tensor(\n                [\n                    [1, 1, 1, 1, 1, 1],  # 3 prompt tokens, 3 response tokens\n                    [1, 1, 1, 1, 1, 1],\n                ]\n            ),\n        }\n\n        # Mock the _compute_response_info function to return known values\n        self.response_info = {\n            \"prompt_length\": torch.tensor([3.0, 3.0]),\n            \"response_length\": torch.tensor([3.0, 3.0]),\n            \"response_mask\": torch.ones((2, 3)),\n        }\n\n    @patch(\"verl.trainer.ppo.metric_utils._compute_response_info\")\n    def test_compute_timing_metrics(self, mock_compute_response_info):\n        \"\"\"Test compute_timing_metrics with various timing data.\"\"\"\n        mock_compute_response_info.return_value = self.response_info\n\n        timing_raw = {\n            \"gen\": 0.5,  # 500ms\n            \"ref\": 0.3,  # 300ms\n            \"values\": 0.2,  # 200ms\n        }\n\n        metrics = compute_timing_metrics(self.batch, timing_raw)\n\n        # Check raw timing metrics\n        self.assertEqual(metrics[\"timing_s/gen\"], 0.5)\n        self.assertEqual(metrics[\"timing_s/ref\"], 0.3)\n        self.assertEqual(metrics[\"timing_s/values\"], 0.2)\n\n        # Check per-token timing metrics\n        # gen uses only response tokens (6 tokens)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/gen\"], 0.5 * 1000 / 6, places=5)\n\n        # ref and values use all tokens (12 tokens)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/ref\"], 0.3 * 1000 / 12, places=5)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/values\"], 0.2 * 1000 / 12, places=5)\n\n\nclass TestComputeThroughputMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_throughout_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto 
object\n        self.batch = MagicMock()\n        self.batch.meta_info = {\n            \"global_token_num\": [100, 200, 300],  # 600 tokens total\n        }\n\n    def test_compute_throughout_metrics(self):\n        \"\"\"Test compute_throughout_metrics with various timing data.\"\"\"\n        timing_raw = {\n            \"step\": 2.0,  # 2 seconds per step\n        }\n\n        # Test with 1 GPU\n        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1)\n\n        self.assertEqual(metrics[\"perf/total_num_tokens\"], 600)\n        self.assertEqual(metrics[\"perf/time_per_step\"], 2.0)\n        self.assertEqual(metrics[\"perf/throughput\"], 600 / 2.0)  # 300 tokens/sec\n\n        # Test with 2 GPUs\n        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2)\n\n        self.assertEqual(metrics[\"perf/total_num_tokens\"], 600)\n        self.assertEqual(metrics[\"perf/time_per_step\"], 2.0)\n        self.assertEqual(metrics[\"perf/throughput\"], 600 / (2.0 * 2))  # 150 tokens/sec/GPU\n\n\nclass TestBootstrapMetric(unittest.TestCase):\n    \"\"\"Tests for the bootstrap_metric function.\"\"\"\n\n    def test_bootstrap_metric_basic(self):\n        \"\"\"Test bootstrap_metric with simple data and functions.\"\"\"\n        data = [1, 2, 3, 4, 5]\n        reduce_fns = [np.mean, np.max]\n\n        # Use a fixed seed for reproducibility\n        result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42)\n\n        # Check that we get two results (one for each reduce_fn)\n        self.assertEqual(len(result), 2)\n\n        # Each result should be a tuple of (mean, std)\n        mean_result, max_result = result\n        self.assertEqual(len(mean_result), 2)\n        self.assertEqual(len(max_result), 2)\n\n        # The mean of means should be close to the true mean (3.0)\n        self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3)\n\n        # The mean of maxes should be close to the expected value for samples of size 3\n        # For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5\n        self.assertGreater(max_result[0], 3.5)\n        self.assertLess(max_result[0], 5.0)\n\n    def test_bootstrap_metric_empty(self):\n        \"\"\"Test bootstrap_metric with empty data.\"\"\"\n        with self.assertRaises(ValueError):\n            bootstrap_metric([], subset_size=1, reduce_fns=[np.mean])\n\n\nclass TestCalcMajVal(unittest.TestCase):\n    \"\"\"Tests for the calc_maj_val function.\"\"\"\n\n    def test_calc_maj_val_basic(self):\n        \"\"\"Test calc_maj_val with simple data.\"\"\"\n        data = [\n            {\"pred\": \"A\", \"val\": 0.9},\n            {\"pred\": \"B\", \"val\": 0.8},\n            {\"pred\": \"A\", \"val\": 0.7},\n        ]\n\n        result = calc_maj_val(data, vote_key=\"pred\", val_key=\"val\")\n\n        # \"A\" is the majority vote, so we should get the first \"val\" for \"A\"\n        self.assertEqual(result, 0.9)\n\n    def test_calc_maj_val_tie(self):\n        \"\"\"Test calc_maj_val with tied votes.\"\"\"\n        data = [\n            {\"pred\": \"A\", \"val\": 0.9},\n            {\"pred\": \"B\", \"val\": 0.8},\n            {\"pred\": \"B\", \"val\": 0.7},\n            {\"pred\": \"A\", \"val\": 0.6},\n        ]\n\n        # In case of a tie, which prediction wins depends on the tie-breaking\n        # behavior of the vote counting (e.g., insertion order), so for this\n        # test we just verify that one of the valid values is returned\n        result = calc_maj_val(data, 
vote_key=\"pred\", val_key=\"val\")\n\n        self.assertTrue(result in [0.9, 0.8])\n\n\nclass TestProcessValidationMetrics(unittest.TestCase):\n    \"\"\"Tests for the process_validation_metrics function.\"\"\"\n\n    def test_process_validation_metrics_basic(self):\n        \"\"\"Test process_validation_metrics with simple data.\"\"\"\n        data_sources = [\"source1\", \"source1\", \"source2\"]\n        sample_inputs = [\"prompt1\", \"prompt1\", \"prompt2\"]\n        infos_dict = {\n            \"score\": [0.8, 0.9, 0.7],\n        }\n\n        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)\n\n        # Check the structure of the result\n        self.assertIn(\"source1\", result)\n        self.assertIn(\"source2\", result)\n\n        # Check that source1 has metrics for score\n        self.assertIn(\"score\", result[\"source1\"])\n\n        # Check that mean@2 is present for source1/score\n        self.assertIn(\"mean@2\", result[\"source1\"][\"score\"])\n\n        # Check the value of mean@2 for source1/score\n        self.assertAlmostEqual(result[\"source1\"][\"score\"][\"mean@2\"], 0.85)\n\n    def test_process_validation_metrics_with_pred(self):\n        \"\"\"Test process_validation_metrics with prediction data.\"\"\"\n        data_sources = [\"source1\", \"source1\", \"source1\"]\n        sample_inputs = [\"prompt1\", \"prompt1\", \"prompt1\"]\n        infos_dict = {\n            \"score\": [0.8, 0.9, 0.7],\n            \"pred\": [\"A\", \"B\", \"A\"],\n        }\n\n        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)\n\n        # Check that majority voting metrics are present\n        self.assertIn(\"maj@2/mean\", result[\"source1\"][\"score\"])\n\n        # For bootstrap with n=2, the majority vote could be either A or B\n        # depending on the random sampling, so we don't check the exact value\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/trainer/ppo/test_rollout_is.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nQuick Sanity Test for Rollout Importance Sampling\n\nThis is a standalone test script that can be run without pytest to quickly verify\nthe rollout IS implementation is working correctly. For comprehensive integration\ntests, see: tests/trainer/ppo/test_rollout_is_integration.py\n\nUsage:\n    python test_rollout_is.py\n\nThis tests:\n- Basic rollout IS functionality (3 levels, 2 modes)\n- Metrics completeness (32 total: 21 IS + 11 mismatch metrics)\n- Veto mechanism\n- Edge cases\n\"\"\"\n\nimport torch\n\nfrom verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights\n\n\ndef test_basic_rollout_is():\n    \"\"\"Test basic rollout IS functionality.\"\"\"\n    print(\"Testing basic rollout IS functionality...\")\n\n    # Create test data\n    batch_size, seq_length = 4, 10\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    # Create slightly different log probs (simulating BF16 vs FP32 mismatch)\n    old_log_prob = torch.randn(batch_size, seq_length, device=device)\n    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1\n    eos_mask = torch.ones(batch_size, seq_length, device=device)\n\n    # Test token-level truncate mode\n    print(\"\\n1. Testing token-level truncate mode...\")\n    weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=eos_mask,\n        rollout_is_level=\"token\",\n        rollout_is_mode=\"truncate\",\n        rollout_is_threshold=2.0,\n        rollout_is_veto_threshold=1e-4,\n    )\n\n    weights = weights_proto.batch[\"rollout_is_weights\"]\n    print(f\"   Weights shape: {weights.shape}\")\n    print(f\"   Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}\")\n    print(f\"   Max weight: {metrics['mismatch/rollout_is_max']:.4f}\")\n    print(f\"   Min weight: {metrics['mismatch/rollout_is_min']:.4f}\")\n    print(f\"   Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}\")\n    assert weights.shape == old_log_prob.shape\n    assert weights.max() <= 2.0, \"Weights should be capped at threshold\"\n    print(\"   ✓ Token-level truncate mode passed\")\n\n    # Test sequence-level mode\n    print(\"\\n2. 
Testing sequence-level mode...\")\n    weights_seq_proto, _, metrics_seq = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=eos_mask,\n        rollout_is_level=\"sequence\",\n        rollout_is_mode=\"truncate\",\n        rollout_is_threshold=5.0,\n        rollout_is_veto_threshold=1e-4,\n    )\n\n    weights_seq = weights_seq_proto.batch[\"rollout_is_weights\"]\n    print(f\"   Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}\")\n    print(f\"   Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}\")\n    # Check that all tokens in a sequence have the same weight\n    for i in range(batch_size):\n        seq_weights = weights_seq[i, eos_mask[i].bool()]\n        assert torch.allclose(seq_weights, seq_weights[0]), \"All tokens in sequence should have same weight\"\n    print(\"   ✓ Sequence-level mode passed\")\n\n    # Test geometric mean mode\n    print(\"\\n3. Testing geometric mean mode...\")\n    weights_geo_proto, _, metrics_geo = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=eos_mask,\n        rollout_is_level=\"geometric\",\n        rollout_is_mode=\"mask\",\n        rollout_is_threshold=1.5,\n        rollout_is_threshold_lower=0.5,\n        rollout_is_veto_threshold=1e-4,\n    )\n\n    print(f\"   Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}\")\n    print(f\"   Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}\")\n    print(\"   ✓ Geometric mean mode passed\")\n\n    # Test veto mechanism\n    print(\"\\n4. Testing veto mechanism...\")\n    # Create data with catastrophic outliers\n    old_log_prob_veto = torch.randn(2, 5, device=device)\n    rollout_log_prob_veto = old_log_prob_veto.clone()\n    # Make one token have catastrophically low ratio\n    rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0  # ratio ~= 3e-7\n    eos_mask_veto = torch.ones(2, 5, device=device)\n\n    weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob_veto,\n        rollout_log_prob=rollout_log_prob_veto,\n        response_mask=eos_mask_veto,\n        rollout_is_level=\"token\",\n        rollout_is_mode=\"truncate\",\n        rollout_is_threshold=2.0,\n        rollout_is_veto_threshold=1e-4,\n    )\n\n    weights_veto = weights_veto_proto.batch[\"rollout_is_weights\"]\n    print(f\"   Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}\")\n    # KEY FIX: Veto is applied via response_mask, not by zeroing weights\n    # Check that weights are NON-ZERO (safety-bounded ratios preserved, not zeroed)\n    assert weights_veto[0].sum() > 0, \"Weights should be non-zero (not zeroed by veto)\"\n    # Check that response_mask has veto applied\n    assert modified_response_mask_veto[0].sum() == 0, \"Vetoed sequence should have response_mask zeroed\"\n    assert modified_response_mask_veto[1].sum() > 0, \"Normal sequence should have response_mask unchanged\"\n    print(\"   ✓ Veto mechanism passed\")\n\n    # Test disabled IS (threshold=None)\n    print(\"\\n5. 
Testing disabled IS...\")\n    weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=eos_mask,\n        rollout_is_threshold=None,\n    )\n\n    assert weights_disabled is None, \"Should return None when threshold is None\"\n    assert torch.equal(modified_response_mask_disabled, eos_mask), \"Should return original mask unchanged\"\n    assert len(metrics_disabled) == 0, \"Should return empty metrics when disabled\"\n    print(\"   ✓ Disabled IS passed\")\n\n    print(\"\\n✓ All tests passed!\")\n\n\ndef test_metrics_completeness():\n    \"\"\"Test that all expected metrics are returned.\"\"\"\n    print(\"\\nTesting metrics completeness...\")\n\n    batch_size, seq_length = 3, 8\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    old_log_prob = torch.randn(batch_size, seq_length, device=device)\n    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2\n    eos_mask = torch.ones(batch_size, seq_length, device=device)\n\n    _, _, metrics = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=eos_mask,\n        rollout_is_level=\"token\",\n        rollout_is_mode=\"truncate\",\n        rollout_is_threshold=2.5,\n    )\n\n    # Expected IS metrics\n    expected_is_metrics = [\n        \"mismatch/rollout_is_mean\",\n        \"mismatch/rollout_is_max\",\n        \"mismatch/rollout_is_min\",\n        \"mismatch/rollout_is_std\",\n        \"mismatch/rollout_is_eff_sample_size\",\n        \"mismatch/rollout_is_veto_fraction\",\n        \"mismatch/rollout_is_catastrophic_token_fraction\",\n        \"mismatch/rollout_is_ratio_fraction_high\",\n        \"mismatch/rollout_is_ratio_fraction_low\",\n    ]\n\n    # Expected mismatch/diagnostic metrics (also included now)\n    expected_mismatch_metrics = [\n        \"mismatch/mismatch_training_ppl\",\n        \"mismatch/mismatch_training_log_ppl\",\n        \"mismatch/mismatch_kl\",\n        \"mismatch/mismatch_k3_kl\",\n        \"mismatch/mismatch_rollout_ppl\",\n        \"mismatch/mismatch_rollout_log_ppl\",\n        \"mismatch/mismatch_log_ppl_diff\",\n        \"mismatch/mismatch_log_ppl_abs_diff\",\n        \"mismatch/mismatch_log_ppl_diff_max\",\n        \"mismatch/mismatch_log_ppl_diff_min\",\n        \"mismatch/mismatch_ppl_ratio\",\n    ]\n\n    expected_metrics = expected_is_metrics + expected_mismatch_metrics\n\n    missing_metrics = [m for m in expected_metrics if m not in metrics]\n    if missing_metrics:\n        print(f\"   ✗ Missing metrics: {missing_metrics}\")\n        return False\n\n    print(f\"   ✓ All {len(expected_metrics)} expected metrics present\")\n    print(f\"   Total metrics returned: {len(metrics)}\")\n    return True\n\n\ndef test_mismatch_metrics():\n    \"\"\"Test mismatch metrics computation.\"\"\"\n    print(\"\\nTesting mismatch metrics computation...\")\n\n    batch_size, seq_length = 4, 12\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    # Create test data with some mismatch\n    old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0  # training policy\n    rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5  # rollout policy (more confident)\n    response_mask = torch.ones(batch_size, seq_length, device=device)\n\n    # Test with rollout log probs\n    
metrics = compute_mismatch_metrics(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=response_mask,\n    )\n\n    expected_metrics = [\n        \"mismatch_training_ppl\",\n        \"mismatch_training_log_ppl\",\n        \"mismatch_kl\",\n        \"mismatch_k3_kl\",\n        \"mismatch_rollout_ppl\",\n        \"mismatch_rollout_log_ppl\",\n        \"mismatch_log_ppl_diff\",\n        \"mismatch_log_ppl_abs_diff\",\n        \"mismatch_log_ppl_diff_max\",\n        \"mismatch_log_ppl_diff_min\",\n        \"mismatch_ppl_ratio\",\n    ]\n\n    for metric in expected_metrics:\n        assert metric in metrics, f\"Missing metric: {metric}\"\n\n    print(f\"   Training PPL: {metrics['mismatch_training_ppl']:.4f}\")\n    print(f\"   Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}\")\n    print(f\"   KL divergence: {metrics['mismatch_kl']:.6f}\")\n    print(f\"   K3 KL: {metrics['mismatch_k3_kl']:.6f}\")\n    print(f\"   PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}\")\n    print(f\"   ✓ All {len(expected_metrics)} mismatch metrics present\")\n\n    # Test without rollout log probs\n    metrics_no_rollout = compute_mismatch_metrics(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=None,\n        response_mask=response_mask,\n    )\n\n    assert \"mismatch_training_ppl\" in metrics_no_rollout\n    assert \"mismatch_rollout_ppl\" not in metrics_no_rollout\n    print(\"   ✓ Mismatch metrics work without rollout log probs\")\n\n\ndef test_mask_mode():\n    \"\"\"Test mask mode applies rejection via response_mask, keeps true IS weights.\"\"\"\n    print(\"\\nTesting mask mode behavior...\")\n\n    batch_size = 2\n    seq_length = 5\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    # Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected)\n    # Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted)\n    old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device)\n    rollout_log_prob = torch.tensor(\n        [\n            [-1.0] * seq_length,  # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37\n            [-2.5] * seq_length,  # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65\n        ],\n        device=device,\n    )\n    response_mask = torch.ones(batch_size, seq_length, device=device)\n\n    weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(\n        old_log_prob=old_log_prob,\n        rollout_log_prob=rollout_log_prob,\n        response_mask=response_mask,\n        rollout_is_level=\"token\",\n        rollout_is_mode=\"mask\",\n        rollout_is_threshold=2.0,\n        rollout_is_threshold_lower=0.5,\n        rollout_is_veto_threshold=None,\n    )\n\n    weights = weights_proto.batch[\"rollout_is_weights\"]\n\n    # KEY FIX: Weights should be safety-bounded ratios (NOT zeroed)\n    assert torch.all(weights[0, :] > 0), \"Weights should remain as safety-bounded ratios (not zeroed)\"\n    assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), (\n        \"First seq ratio should be ≈0.37\"\n    )\n    assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), (\n        \"Second seq ratio should be ≈1.65\"\n    )\n\n    # Rejection should be applied via response_mask\n    assert torch.all(modified_response_mask[0, :] == 0), \"First sequence should be rejected via mask\"\n    assert torch.all(modified_response_mask[1, :] == 1), \"Second sequence should be accepted\"\n\n    # Verify mask metrics exist\n    assert 
\"mismatch/rollout_is_masked_fraction\" in metrics\n    assert abs(metrics[\"mismatch/rollout_is_masked_fraction\"] - 0.5) < 0.01, \"Should reject 50% of tokens\"\n\n    print(f\"   First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)\")\n    print(f\"   Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)\")\n    print(f\"   First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)\")\n    print(f\"   Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)\")\n    print(f\"   Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.2f}\")\n    print(\"   ✓ Mask mode correctly separates IS weights from rejection\")\n\n\nif __name__ == \"__main__\":\n    print(\"=\" * 60)\n    print(\"Rollout Importance Sampling Test Suite\")\n    print(\"=\" * 60)\n\n    try:\n        test_basic_rollout_is()\n        test_metrics_completeness()\n        test_mismatch_metrics()\n        test_mask_mode()\n        print(\"\\n\" + \"=\" * 60)\n        print(\"ALL TESTS PASSED ✓\")\n        print(\"=\" * 60)\n    except Exception as e:\n        print(f\"\\n✗ Test failed with error: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        exit(1)\n"
  },
  {
    "path": "verl_distillation/tests/trainer/ppo/test_rollout_is_integration.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Integration tests for Rollout Importance Sampling.\"\"\"\n\nimport pytest\nimport torch\n\nfrom verl.trainer.ppo.core_algos import compute_policy_loss_vanilla\nfrom verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights\nfrom verl.workers.config.actor import ActorConfig\n\n\nclass TestRolloutISIntegration:\n    \"\"\"Integration tests for Rollout IS with PPO.\"\"\"\n\n    @pytest.fixture\n    def sample_data(self):\n        \"\"\"Create sample training data.\"\"\"\n        batch_size, seq_length = 4, 16\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        return {\n            \"old_log_prob\": torch.randn(batch_size, seq_length, device=device),\n            \"log_prob\": torch.randn(batch_size, seq_length, device=device),\n            \"rollout_log_prob\": torch.randn(batch_size, seq_length, device=device),\n            \"advantages\": torch.randn(batch_size, seq_length, device=device),\n            \"response_mask\": torch.ones(batch_size, seq_length, device=device),\n        }\n\n    @pytest.fixture\n    def config_with_rollout_is(self):\n        \"\"\"Create config for policy loss computation.\n\n        Note: rollout_is config has been moved to algorithm config.\n        This config only needs fields used by policy loss (clip_ratio, etc).\n        \"\"\"\n        config = ActorConfig(\n            strategy=\"fsdp\",\n            rollout_n=1,\n            ppo_micro_batch_size=2,\n            clip_ratio=0.2,\n        )\n        return config\n\n    def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):\n        \"\"\"Test that policy loss computation works with rollout IS weights.\n\n        Note: In production, IS weights are computed centrally in the trainer\n        (before advantage computation) and passed to policy loss.\n        This test simulates that workflow.\n        \"\"\"\n        # First compute IS weights (as trainer would do centrally)\n        rollout_is_weights_proto, _, _ = compute_rollout_importance_weights(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            rollout_log_prob=sample_data[\"rollout_log_prob\"],\n            response_mask=sample_data[\"response_mask\"],\n            rollout_is_level=\"token\",\n            rollout_is_mode=\"truncate\",\n            rollout_is_threshold=2.0,\n            rollout_is_veto_threshold=1e-4,\n        )\n\n        rollout_is_weights = rollout_is_weights_proto.batch[\"rollout_is_weights\"]\n\n        # Policy loss function receives pre-computed IS weights\n        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            log_prob=sample_data[\"log_prob\"],\n            advantages=sample_data[\"advantages\"],\n            response_mask=sample_data[\"response_mask\"],\n            
loss_agg_mode=\"token-mean\",\n            config=config_with_rollout_is,\n            rollout_is_weights=rollout_is_weights,\n        )\n\n        # Check loss is valid\n        assert isinstance(pg_loss, torch.Tensor)\n        assert pg_loss.ndim == 0  # Scalar\n        assert not torch.isnan(pg_loss)\n        assert not torch.isinf(pg_loss)\n\n    def test_rollout_is_weights_computation(self, sample_data):\n        \"\"\"Test rollout IS weights and metrics computation.\"\"\"\n        weights_proto, _, metrics = compute_rollout_importance_weights(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            rollout_log_prob=sample_data[\"rollout_log_prob\"],\n            response_mask=sample_data[\"response_mask\"],\n            rollout_is_level=\"token\",\n            rollout_is_mode=\"truncate\",\n            rollout_is_threshold=2.0,\n            rollout_is_veto_threshold=1e-4,\n        )\n\n        # Check weights\n        from verl.protocol import DataProto\n\n        assert isinstance(weights_proto, DataProto)\n        weights = weights_proto.batch[\"rollout_is_weights\"]\n        assert isinstance(weights, torch.Tensor)\n        assert weights.shape == sample_data[\"old_log_prob\"].shape\n\n        # Check metrics are returned\n        assert isinstance(metrics, dict)\n        assert len(metrics) > 0\n        assert \"mismatch/rollout_is_mean\" in metrics\n\n    def test_all_aggregation_levels(self, sample_data):\n        \"\"\"Test all three aggregation levels.\"\"\"\n        levels = [\"token\", \"sequence\", \"geometric\"]\n\n        for level in levels:\n            _, _, metrics = compute_rollout_importance_weights(\n                old_log_prob=sample_data[\"old_log_prob\"],\n                rollout_log_prob=sample_data[\"rollout_log_prob\"],\n                response_mask=sample_data[\"response_mask\"],\n                rollout_is_level=level,\n                rollout_is_mode=\"truncate\",\n                rollout_is_threshold=2.0,\n            )\n\n            assert \"mismatch/rollout_is_mean\" in metrics\n\n    def test_both_bounding_modes(self, sample_data):\n        \"\"\"Test both truncate and mask modes.\"\"\"\n        modes = [\"truncate\", \"mask\"]\n\n        for mode in modes:\n            _, _, metrics = compute_rollout_importance_weights(\n                old_log_prob=sample_data[\"old_log_prob\"],\n                rollout_log_prob=sample_data[\"rollout_log_prob\"],\n                response_mask=sample_data[\"response_mask\"],\n                rollout_is_level=\"token\",\n                rollout_is_mode=mode,\n                rollout_is_threshold=2.0,\n                rollout_is_threshold_lower=0.5,\n            )\n\n            assert \"mismatch/rollout_is_mean\" in metrics\n\n    def test_mismatch_metrics(self, sample_data):\n        \"\"\"Test mismatch diagnostic metrics computation.\"\"\"\n        metrics = compute_mismatch_metrics(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            rollout_log_prob=sample_data[\"rollout_log_prob\"],\n            response_mask=sample_data[\"response_mask\"],\n        )\n\n        # Check key metrics are present\n        assert \"mismatch_training_ppl\" in metrics\n        assert \"mismatch_rollout_ppl\" in metrics\n        assert \"mismatch_kl\" in metrics\n        assert isinstance(metrics[\"mismatch_kl\"], float)\n\n    def test_veto_mechanism(self):\n        \"\"\"Test veto mechanism with catastrophic outliers.\"\"\"\n        batch_size, seq_length = 2, 5\n        device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"\n\n        old_log_prob = torch.randn(batch_size, seq_length, device=device)\n        rollout_log_prob = old_log_prob.clone()\n\n        # Create catastrophic outlier in first sequence\n        rollout_log_prob[0, 2] += 15.0  # Makes ratio ~3e-7\n\n        response_mask = torch.ones(batch_size, seq_length, device=device)\n\n        _, _, metrics = compute_rollout_importance_weights(\n            old_log_prob=old_log_prob,\n            rollout_log_prob=rollout_log_prob,\n            response_mask=response_mask,\n            rollout_is_level=\"token\",\n            rollout_is_mode=\"truncate\",\n            rollout_is_threshold=2.0,\n            rollout_is_veto_threshold=1e-4,\n        )\n\n        # Should have vetoed one sequence\n        assert metrics[\"mismatch/rollout_is_veto_fraction\"] > 0\n        assert metrics[\"mismatch/rollout_is_veto_fraction\"] <= 1.0\n\n    def test_metrics_only_mode(self, sample_data, config_with_rollout_is):\n        \"\"\"Test metrics-only mode: compute IS weights/metrics but don't apply to loss.\n\n        This tests the use case where rollout_is_threshold is set (enables computation)\n        but rollout_is=False (disables weight application to policy loss).\n        \"\"\"\n        # Compute IS weights (as trainer would do)\n        rollout_is_weights_proto, _, is_metrics = compute_rollout_importance_weights(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            rollout_log_prob=sample_data[\"rollout_log_prob\"],\n            response_mask=sample_data[\"response_mask\"],\n            rollout_is_level=\"token\",\n            rollout_is_mode=\"truncate\",\n            rollout_is_threshold=2.0,\n        )\n\n        # Metrics should be computed\n        assert len(is_metrics) > 0\n        assert \"mismatch/rollout_is_mean\" in is_metrics\n\n        # In metrics-only mode, we compute loss WITHOUT applying weights\n        # (simulating rollout_is=False)\n        pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            log_prob=sample_data[\"log_prob\"],\n            advantages=sample_data[\"advantages\"],\n            response_mask=sample_data[\"response_mask\"],\n            loss_agg_mode=\"token-mean\",\n            config=config_with_rollout_is,\n            rollout_is_weights=None,  # Don't apply weights\n        )\n\n        # Compare to loss WITH weights (rollout_is=True)\n        rollout_is_weights = rollout_is_weights_proto.batch[\"rollout_is_weights\"]\n        pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(\n            old_log_prob=sample_data[\"old_log_prob\"],\n            log_prob=sample_data[\"log_prob\"],\n            advantages=sample_data[\"advantages\"],\n            response_mask=sample_data[\"response_mask\"],\n            loss_agg_mode=\"token-mean\",\n            config=config_with_rollout_is,\n            rollout_is_weights=rollout_is_weights,\n        )\n\n        # Losses should be different (weights have an effect)\n        assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "verl_distillation/tests/utils/_test_module.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Test module for import_utils.load_extern_type testing\nclass TestClass:\n    \"\"\"A test class to be imported by load_extern_type\"\"\"\n\n    def __init__(self, value=None):\n        self.value = value or \"default\"\n\n    def get_value(self):\n        return self.value\n\n\nTEST_CONSTANT = \"test_constant_value\"\n\n\ndef test_function():\n    return \"test_function_result\"\n"
  },
  {
    "path": "verl_distillation/tests/utils/dataset/test_create_rl_sampler_on_cpu.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ntest create_rl_sampler\n\"\"\"\n\nfrom collections.abc import Sized\n\nimport pytest\nimport torch\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch.utils.data import Dataset, RandomSampler\n\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.trainer.main_ppo import create_rl_sampler\n\n\nclass RandomCurriculumSampler(AbstractCurriculumSampler):\n    def __init__(\n        self,\n        data_source: Sized,\n        data_config: DictConfig,\n    ):\n        train_dataloader_generator = torch.Generator()\n        train_dataloader_generator.manual_seed(1)\n        sampler = RandomSampler(data_source=data_source)\n        self.sampler = sampler\n\n    def __iter__(self):\n        return self.sampler.__iter__()\n\n    def __len__(self) -> int:\n        return len(self.sampler)\n\n    def update(self, batch) -> None:\n        return\n\n\nclass MockIncorrectSampler:\n    \"\"\"A fake sampler class that does not adhere to the AbstractCurriculumSampler interface.\"\"\"\n\n    def __init__(self, data_source, data_config):\n        pass\n\n\nclass MockChatDataset(Dataset):\n    def __init__(self):\n        self.data = [\n            {\"prompt\": \"What's your name?\", \"response\": \"My name is Assistant.\"},\n            {\"prompt\": \"How are you?\", \"response\": \"I'm doing well, thank you.\"},\n            {\"prompt\": \"What is the capital of France?\", \"response\": \"Paris.\"},\n            {\n                \"prompt\": \"Tell me a joke.\",\n                \"response\": \"Why did the chicken cross the road? To get to the other side!\",\n            },\n            {\"prompt\": \"What is 2+2?\", \"response\": \"4\"},\n        ]\n\n    def __getitem__(self, index):\n        return self.data[index]\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef test_create_custom_curriculum_samper():\n    data_config = OmegaConf.create(\n        {\n            \"dataloader_num_workers\": 0,\n            \"sampler\": {\n                \"class_path\": \"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\",\n                \"class_name\": \"RandomCurriculumSampler\",\n            },\n        }\n    )\n\n    dataset = MockChatDataset()\n\n    # doesn't raise\n    create_rl_sampler(data_config, dataset)\n\n\ndef test_create_custom_curriculum_samper_wrong_class():\n    data_config = OmegaConf.create(\n        {\n            \"sampler\": {\n                \"class_path\": \"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\",\n                \"class_name\": \"MockIncorrectSampler\",\n            }\n        }\n    )\n\n    dataset = MockChatDataset()\n\n    # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises\n    with pytest.raises(AssertionError):\n        create_rl_sampler(data_config, dataset)\n"
  },
  {
    "path": "verl_distillation/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest the MultiTurnSFTDataset implementation\n\"\"\"\n\nimport os\n\nimport pandas as pd\nimport torch\nfrom transformers import AutoTokenizer\n\nfrom verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset\n\n\ndef test_multiturn_sft_dataset():\n    print(\"Starting test...\")\n    # Create a temporary parquet file with test data\n    test_data = {\n        \"messages\": [\n            [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n                {\"role\": \"assistant\", \"content\": \"2+2 equals 4.\"},\n                {\"role\": \"user\", \"content\": \"And what is 4+4?\"},\n                {\"role\": \"assistant\", \"content\": \"4+4 equals 8.\"},\n            ],\n            [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Tell me a joke.\"},\n                {\"role\": \"assistant\", \"content\": \"Why did the chicken cross the road?\"},\n                {\"role\": \"user\", \"content\": \"Why?\"},\n                {\"role\": \"assistant\", \"content\": \"To get to the other side!\"},\n            ],\n        ]\n    }\n\n    # Create test directory if it doesn't exist\n    os.makedirs(\"test_data\", exist_ok=True)\n    test_file = \"test_data/test.parquet\"\n\n    # Save test data to parquet\n    df = pd.DataFrame(test_data)\n    df.to_parquet(test_file)\n\n    # Initialize tokenizer and dataset\n    tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-Coder-7B-Instruct\")\n    config = {\"max_length\": 512, \"truncation\": \"error\", \"multiturn\": {\"messages_key\": \"messages\"}}\n    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)\n\n    # Test 1: Dataset Length\n    assert len(dataset) == 2, f\"Expected dataset length 2, got {len(dataset)}\"\n\n    # Get items for testing\n    item0 = dataset[0]  # Math conversation\n    item1 = dataset[1]  # Joke conversation\n\n    # Test 2: Required Keys and Types\n    required_keys = [\"input_ids\", \"attention_mask\", \"position_ids\", \"loss_mask\"]\n    for key in required_keys:\n        assert key in item0, f\"Missing key {key} in dataset item\"\n        assert isinstance(item0[key], torch.Tensor), f\"Expected torch.Tensor for {key}\"\n        assert item0[key].dtype == torch.long, f\"Expected torch.long for {key}, got {item0[key].dtype}\"\n\n    # Test 3: Shape Consistency\n    assert item0[\"loss_mask\"].shape == item0[\"input_ids\"].shape, \"Loss mask shape doesn't match input_ids shape\"\n    assert item0[\"attention_mask\"].shape == item0[\"input_ids\"].shape, (\n        \"Attention mask shape doesn't match input_ids shape\"\n    )\n    assert item0[\"position_ids\"].shape == item0[\"input_ids\"].shape, \"Position IDs shape doesn't match 
input_ids shape\"\n\n    # Test 4: Loss Mask Pattern - Math Conversation\n    loss_mask0 = item0[\"loss_mask\"]\n    input_ids0 = item0[\"input_ids\"]\n\n    # Find assistant response positions\n    assistant_positions0 = torch.where(loss_mask0 == 1)[0]\n    assert len(assistant_positions0) > 0, \"No assistant positions found in loss mask\"\n\n    # Decode and verify assistant responses\n    assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1])\n    print(f\"Math conversation assistant text: {assistant_text0}\")\n    assert \"2+2 equals 4\" in assistant_text0, \"First assistant response not found\"\n    assert \"4+4 equals 8\" in assistant_text0, \"Second assistant response not found\"\n\n    # Test 5: Loss Mask Pattern - Joke Conversation\n    loss_mask1 = item1[\"loss_mask\"]\n    input_ids1 = item1[\"input_ids\"]\n\n    # Find assistant response positions\n    assistant_positions1 = torch.where(loss_mask1 == 1)[0]\n    assert len(assistant_positions1) > 0, \"No assistant positions found in loss mask\"\n\n    # Decode and verify assistant responses\n    assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1])\n    print(f\"Joke conversation assistant text: {assistant_text1}\")\n    assert \"chicken cross the road\" in assistant_text1, \"First assistant response not found\"\n    assert \"other side\" in assistant_text1, \"Second assistant response not found\"\n\n    # Test 6: Attention Mask Pattern\n    attention_mask0 = item0[\"attention_mask\"]\n    sequence_length = torch.sum(attention_mask0)\n    assert sequence_length > 0, \"No tokens marked as attended in attention mask\"\n    assert torch.all(attention_mask0[:sequence_length] == 1), \"Incorrect attention mask pattern\"\n    if sequence_length < len(attention_mask0):\n        assert torch.all(attention_mask0[sequence_length:] == 0), \"Padding not properly masked\"\n\n    # Test 7: Position IDs Pattern\n    position_ids0 = item0[\"position_ids\"]\n    assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), (\n        \"Position IDs not sequential for non-padded tokens\"\n    )\n    if sequence_length < len(position_ids0):\n        assert torch.all(position_ids0[sequence_length:] == 0), \"Padding position IDs not zero\"\n\n    # Test 8: Verify loss mask for assistant responses\n    # Get the full conversation text\n    full_text = tokenizer.decode(input_ids0)\n    print(f\"\\nFull conversation text:\\n{full_text}\")\n\n    # Get the assistant responses\n    assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1])\n    print(f\"\\nAssistant responses (from loss mask):\\n{assistant_text}\")\n\n    # Verify that loss mask is set for all assistant responses\n    for msg in test_data[\"messages\"][0]:  # First conversation\n        if msg[\"role\"] == \"assistant\":\n            # The content should appear in the masked text\n            assert msg[\"content\"] in assistant_text, f\"Assistant message '{msg['content']}' not found in masked text\"\n\n            # The content should NOT appear in the non-masked text\n            non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])\n            assert msg[\"content\"] not in non_assistant_text, (\n                f\"Assistant message '{msg['content']}' found in non-assistant text\"\n            )\n\n    # Test 9: Verify non-assistant parts have loss_mask=0\n    # Get non-assistant text\n    non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])\n    print(f\"\\nNon-assistant text (from loss 
mask):\\n{non_assistant_text}\")\n\n    # Verify that system and user messages are in the non-assistant text\n    for msg in test_data[\"messages\"][0]:  # First conversation\n        if msg[\"role\"] in [\"system\", \"user\"]:\n            assert msg[\"content\"] in non_assistant_text, (\n                f\"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text\"\n            )\n\n            # And verify they're NOT in the assistant text\n            assert msg[\"content\"] not in assistant_text, (\n                f\"{msg['role'].title()} message '{msg['content']}' found in assistant text\"\n            )\n\n    # Test 10: Verify padding behavior\n    padding_config = {\"max_length\": 1024, \"truncation\": \"error\", \"multiturn\": {\"messages_key\": \"messages\"}}\n    small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)\n    padded_item = small_dataset[0]\n\n    # Get actual sequence length (before padding)\n    actual_length = torch.sum(padded_item[\"attention_mask\"])\n\n    # Verify padding tokens\n    assert torch.all(padded_item[\"input_ids\"][actual_length:] == tokenizer.pad_token_id), (\n        \"Padding tokens not set correctly\"\n    )\n    assert torch.all(padded_item[\"attention_mask\"][actual_length:] == 0), \"Attention mask not set correctly for padding\"\n    assert torch.all(padded_item[\"loss_mask\"][actual_length:] == 0), \"Loss mask not set correctly for padding\"\n\n    # test no-padding\n    config = {\n        \"max_length\": 512,\n        \"truncation\": \"error\",\n        \"multiturn\": {\"messages_key\": \"messages\"},\n        \"pad_mode\": \"no_padding\",\n    }\n    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)\n\n    item0 = dataset[0]\n\n    # Verify that the output contains expected keys for no-padding mode\n    required_keys = [\"input_ids\", \"position_ids\", \"loss_mask\"]\n    for key in required_keys:\n        assert key in item0, f\"Missing key {key} in no-padding mode dataset item\"\n        assert isinstance(item0[key], torch.Tensor), f\"Expected torch.Tensor for {key} in no-padding mode\"\n\n    # make sure assistant_text matches with expected\n    assistant_text = tokenizer.decode(item0[\"input_ids\"][item0[\"loss_mask\"] == 1])\n    assert assistant_text == \"2+2 equals 4.<|im_end|>\\n4+4 equals 8.<|im_end|>\\n\"\n\n    print(\"All tests passed!\")\n    print(\"Starting test...\")\n"
  },
  {
    "path": "verl_distillation/tests/utils/dataset/test_rl_collate_fn_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\n\ndef test_rl_collate_fn():\n    from verl.utils.dataset.rl_dataset import collate_fn\n\n    max_prompt_length = 5\n\n    test_data = [\n        {\n            # test tensor\n            \"input_ids\": torch.randint(0, 10, (max_prompt_length,)),\n            # test fixed length (1) list within a batch\n            \"messages\": [{\"role\": \"user\", \"content\": \"Hi.\"}],\n            # test variable length list within a batch\n            \"raw_prompt_ids\": [1, 2, 3, 4],\n            # test string\n            \"ability\": \"math\",\n            # test dict\n            \"reward_model\": {\"ground_truth\": 5, \"style\": \"rule\"},\n            # test empty dict\n            \"tools_kwargs\": {},\n        },\n        {\n            \"input_ids\": torch.randint(0, 10, (max_prompt_length,)),\n            \"messages\": [{\"role\": \"user\", \"content\": \"Hello.\"}],\n            \"raw_prompt_ids\": [1, 2, 3],\n            \"ability\": \"toolcall\",\n            \"reward_model\": {\n                \"ground_truth\": '[{\"name\": \"rgb_to_cmyk\", \"arguments\": {\"r\": 0, \"g\": 0, \"b\": 255}}]',\n                \"style\": \"rule\",\n            },\n            \"tools_kwargs\": {},\n        },\n    ]\n\n    batch_size = len(test_data)\n    batch = collate_fn(test_data)\n\n    # Tensor part\n    assert batch[\"input_ids\"].shape == (batch_size, max_prompt_length)\n    assert isinstance(batch[\"input_ids\"], torch.Tensor)\n\n    # Non-tensor parts\n    expected_types = {\n        \"messages\": list,\n        \"raw_prompt_ids\": list,\n        \"ability\": str,\n        \"reward_model\": dict,\n        \"tools_kwargs\": dict,\n    }\n\n    for key, dtype in expected_types.items():\n        assert batch[key].shape == (batch_size,), (\n            f\"Expected shape {(batch_size,)} for '{key}', but got {batch[key].shape}\"\n        )\n        assert isinstance(batch[key][0], dtype), (\n            f\"'{key}' should contain elements of type {dtype}, but got {type(batch[key][0])}\"\n        )\n"
  },
  {
    "path": "verl_distillation/tests/utils/dataset/test_rl_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import DataLoader\n\n\ndef get_gsm8k_data():\n    # prepare test dataset\n    local_folder = os.path.expanduser(\"~/verl-data/gsm8k/\")\n    local_path = os.path.join(local_folder, \"train.parquet\")\n    os.makedirs(local_folder, exist_ok=True)\n    return local_path\n\n\ndef test_rl_dataset():\n    from verl.utils import hf_tokenizer\n    from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n    tokenizer = hf_tokenizer(\"deepseek-ai/deepseek-coder-1.3b-instruct\")\n    local_path = get_gsm8k_data()\n    config = OmegaConf.create(\n        {\n            \"prompt_key\": \"prompt\",\n            \"max_prompt_length\": 256,\n            \"filter_overlong_prompts\": True,\n            \"filter_overlong_prompts_workers\": 2,\n        }\n    )\n    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)\n\n    dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)\n\n    a = next(iter(dataloader))\n\n    from verl import DataProto\n\n    tensors = {}\n    non_tensors = {}\n\n    for key, val in a.items():\n        if isinstance(val, torch.Tensor):\n            tensors[key] = val\n        else:\n            non_tensors[key] = val\n\n    data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)\n    assert \"input_ids\" in data_proto.batch\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    print(f\"type: type{output}\")\n    print(f\"\\n\\noutput: {output}\")\n\n\ndef test_rl_dataset_with_max_samples():\n    from verl.utils import hf_tokenizer\n    from verl.utils.dataset.rl_dataset import RLHFDataset\n\n    tokenizer = hf_tokenizer(\"deepseek-ai/deepseek-coder-1.3b-instruct\")\n    local_path = get_gsm8k_data()\n    config = OmegaConf.create(\n        {\n            \"prompt_key\": \"prompt\",\n            \"max_prompt_length\": 256,\n            \"filter_overlong_prompts\": True,\n            \"filter_overlong_prompts_workers\": 2,\n            \"max_samples\": 5,\n        }\n    )\n    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config, max_samples=5)\n    assert len(dataset) == 5\n\n\ndef test_image_rl_data():\n    from verl.utils import hf_processor, hf_tokenizer\n    from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n    tokenizer = hf_tokenizer(\"Qwen/Qwen2-VL-2B-Instruct\")\n    processor = hf_processor(\"Qwen/Qwen2-VL-2B-Instruct\")\n    config = OmegaConf.create(\n        {\n            \"prompt_key\": \"prompt\",\n            \"max_prompt_length\": 1024,\n            \"filter_overlong_prompts\": True,\n            \"filter_overlong_prompts_workers\": 1,\n        }\n    )\n    dataset = RLHFDataset(\n        data_files=os.path.expanduser(\"~/data/geo3k/train.parquet\"),\n        
tokenizer=tokenizer,\n        config=config,\n        processor=processor,\n    )\n\n    dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)\n\n    a = next(iter(dataloader))\n\n    from verl import DataProto\n\n    tensors = {}\n    non_tensors = {}\n\n    for key, val in a.items():\n        if isinstance(val, torch.Tensor):\n            tensors[key] = val\n        else:\n            non_tensors[key] = val\n\n    data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)\n\n    assert \"multi_modal_data\" in data_proto.non_tensor_batch, data_proto\n    assert \"multi_modal_inputs\" in data_proto.non_tensor_batch, data_proto\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    print(f\"type: type{output}\")\n    print(f\"\\n\\noutput: {output}\")\n"
  },
  {
    "path": "verl_distillation/tests/utils/dataset/test_sft_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.dataset.sft_dataset import SFTDataset\n\n\ndef get_gsm8k_data():\n    # prepare test dataset\n    local_folder = os.path.expanduser(\"~/verl-data/gsm8k/\")\n    local_path = os.path.join(local_folder, \"train.parquet\")\n    return local_path\n\n\ndef test_sft_cot_dataset():\n    tokenizer = hf_tokenizer(\"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\")\n    local_path = get_gsm8k_data()\n    from omegaconf import OmegaConf\n\n    dataset = SFTDataset(\n        parquet_files=local_path,\n        tokenizer=tokenizer,\n        config=OmegaConf.create(\n            {\n                \"prompt_key\": \"prompt\",\n                \"prompt_dict_keys\": [\"content\"],\n                \"response_key\": \"extra_info\",\n                \"response_dict_keys\": [\"answer\"],\n                \"max_length\": 512,\n            }\n        ),\n    )\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    assert len(output) > 1\n    assert isinstance(output, str)\n\n\ndef test_sft_dataset():\n    tokenizer = hf_tokenizer(\"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\")\n    local_path = get_gsm8k_data()\n    from omegaconf import OmegaConf\n\n    dataset = SFTDataset(\n        parquet_files=local_path,\n        tokenizer=tokenizer,\n        config=OmegaConf.create(\n            {\n                \"prompt_key\": \"extra_info\",\n                \"prompt_dict_keys\": [\"question\"],\n                \"response_key\": \"extra_info\",\n                \"response_dict_keys\": [\"answer\"],\n                \"max_length\": 512,\n            }\n        ),\n    )\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    assert len(output) > 1\n    assert isinstance(output, str)\n\n\ndef test_sft_dataset_with_max_samples():\n    tokenizer = hf_tokenizer(\"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\")\n    local_path = get_gsm8k_data()\n    from omegaconf import OmegaConf\n\n    dataset = SFTDataset(\n        parquet_files=local_path,\n        tokenizer=tokenizer,\n        config=OmegaConf.create(\n            {\n                \"prompt_key\": \"extra_info\",\n                \"prompt_dict_keys\": [\"question\"],\n                \"response_key\": \"extra_info\",\n                \"response_dict_keys\": [\"answer\"],\n                \"max_length\": 512,\n            }\n        ),\n        max_samples=5,\n    )\n\n    assert len(dataset) == 5\n"
  },
  {
    "path": "verl_distillation/tests/utils/debug/test_metrics.py",
    "content": "# Copyright 2025 Individual Contributor: TomQunChaoA\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nimport torch\n\nfrom verl.protocol import DataProto\nfrom verl.utils.debug.metrics import calculate_debug_metrics\n\n\nclass TestMetrics(unittest.TestCase):\n    def test_calculate_debug_metrics(self):\n        data = DataProto.from_dict(\n            {\n                \"rollout_log_probs\": torch.tensor(\n                    [\n                        [-1.5085, -0.1200, -0.6650, -0.4823, -0.1426, -1.5557, -2.8532, -0.3919, -0.4294, -0.4700],\n                        [-0.0585, -0.0573, -0.4681, -0.5187, -0.7451, -1.2737, -0.0682, -0.4284, -0.5754, -0.0611],\n                    ]\n                ),\n                \"old_log_probs\": torch.tensor(\n                    [\n                        [-1.8636, -0.7863, -0.2136, -0.4376, -2.0257, -0.2579, -1.1547, -0.5203, -0.3802, -0.9872],\n                        [-0.3507, -0.5426, -0.2725, -0.4637, -0.3577, -0.3733, -1.7560, -1.9542, -0.4229, -1.3098],\n                    ]\n                ),\n                \"loss_mask\": torch.tensor([[1, 0, 0, 0, 1, 1, 0, 1, 1, 0], [1, 0, 1, 0, 1, 1, 1, 0, 1, 1]]),\n                \"responses\": torch.zeros((2, 10)),\n            }\n        )\n        metrics = calculate_debug_metrics(data)\n        print(metrics)\n        assert metrics[\"training/rollout_probs_diff_valid\"] == 1\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/utils/megatron/test_pipeline_parallel.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\n\n\ndef test_make_batch_generator_no_vpp():\n    batches = [1, 2, 3]\n    vpp_size = 1\n    generator = make_batch_generator(batches, vpp_size)\n    assert list(generator) == batches\n\n\ndef test_make_batch_generator_with_vpp():\n    batches = [{\"data\": 1}, {\"data\": 2}]\n    vpp_size = 2\n    generators = make_batch_generator(batches, vpp_size)\n    assert isinstance(generators, list)\n    assert len(generators) == vpp_size\n\n    # Check each generator yields the original batches\n    for gen in generators:\n        assert list(gen) == batches\n\n\ndef test_make_batch_generator_empty():\n    batches = []\n    vpp_size = 1\n    generator = make_batch_generator(batches, vpp_size)\n    assert list(generator) == []\n\n    vpp_size = 3\n    generators = make_batch_generator(batches, vpp_size)\n    assert len(generators) == vpp_size\n    for gen in generators:\n        assert list(gen) == []\n\n\n@pytest.mark.parametrize(\n    \"layer_num,pp_size,gt\",\n    [\n        (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]),\n        (61, 7, [8, 9, 9, 9, 9, 9, 8]),\n        (61, 1, [61]),\n        (61, 0, ValueError),\n        (10, 16, ValueError),\n    ],\n)\ndef test_get_dynamic_pipeline_shards(layer_num, pp_size, gt):\n    if isinstance(gt, list):\n        shards = get_dynamic_pipeline_shards(layer_num, pp_size)\n        assert len(shards) == len(gt) == pp_size, f\"Expected {pp_size} shards, got {len(shards)}\"\n        assert all([shard == gt[i] for i, shard in enumerate(shards)]), f\"Expected shards {gt}, got {shards}\"\n    elif issubclass(gt, Exception):\n        with pytest.raises(gt):\n            shards = get_dynamic_pipeline_shards(layer_num, pp_size)\n"
  },
  {
    "path": "verl_distillation/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport multiprocessing\nimport os\nimport time\nfrom concurrent.futures import ProcessPoolExecutor\nfrom unittest.mock import patch\n\nimport pytest\n\n# Import the function to be tested\nfrom verl.utils.reward_score.sandbox_fusion.utils import check_correctness\n\n# Get SANDBOX_URL from environment variable\nSANDBOX_URL = os.environ.get(\"SANDBOX_FUSION_URL\")\n# Define skip condition and reason\nskip_reason = \"SANDBOX_FUSION_URL environment variable not set\"\nskip_condition = not SANDBOX_URL\n\n# --- Test code (for real API calls) ---\nCODE_SUCCESS = \"\"\"\nimport sys\ndata = sys.stdin.read()\nif data == 'input1':\n    print('output1\\\\n', end='')\nelif data == 'input2':\n    print('output2\\\\n', end='')\nelse:\n    print('unexpected input', end='')\n\"\"\"\n\nCODE_WRONG_OUTPUT = \"\"\"\nprint('wrong_output\\\\n', end='')\n\"\"\"\n\nCODE_COMPILE_ERROR = \"\"\"\na=b\n\"\"\"\n\nCODE_RUNTIME_ERROR = \"\"\"\nimport sys\nprint(\"About to raise error\", file=sys.stderr)\nraise ValueError(\"This is a runtime error\")\n\"\"\"\n\nCODE_TIMEOUT = \"\"\"\nimport time\nimport sys\nprint(\"Sleeping...\", file=sys.stderr)\ntime.sleep(10) # Sleep time should be longer than the timeout set in the test\nprint(\"Finished sleeping\", file=sys.stderr)\n\"\"\"\n\n# --- Test input/output data ---\nINPUT_OUTPUT_VALID = {\"inputs\": [\"input1\", \"input2\"], \"outputs\": [\"output1\\n\", \"output2\\n\"]}\n\nINPUT_OUTPUT_SINGLE = {\"inputs\": [\"input1\"], \"outputs\": [\"output1\\n\"]}\n\nINPUT_OUTPUT_MISMATCH = {\"inputs\": [\"input1\"], \"outputs\": [\"output1\\n\", \"output2\\n\"]}\n\nINPUT_OUTPUT_INVALID_MISSING_KEY = {\"inputs\": [\"input1\"]}\n\n# --- Integration test cases (calling real API) ---\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_success_correct():\n    \"\"\"Integration test: Code is correct, output is correct\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS)\n    assert results == [True, True]\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[0][\"stdout\"] == \"output1\\n\"\n    assert metadata_list[1][\"status\"] == \"success\"\n    assert metadata_list[1][\"stdout\"] == \"output2\\n\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_success_wrong_output():\n    \"\"\"Integration test: Code runs successfully, but output is wrong\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT)\n    assert results == [False, False]\n    assert metadata_list[0][\"status\"] == \"wrong_answer\"\n    assert metadata_list[0][\"stdout\"] == \"wrong_output\\n\"\n    assert metadata_list[1][\"status\"] == \"wrong_answer\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_compile_error():\n    \"\"\"Integration test: Code causes compile error\"\"\"\n   
 results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language=\"cpp\")\n    assert results == [-4, -4]\n    assert metadata_list[0][\"status\"] == \"compile_error\"\n    assert metadata_list[1][\"status\"] == \"compile_error\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_runtime_error():\n    \"\"\"Integration test: Code causes runtime error\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR)\n    assert results == [-2]\n    assert metadata_list[0][\"status\"] == \"runtime_error\"\n    # More assertions can be added based on the actual API response, e.g., exit_code, stderr\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_runtime_timeout():\n    \"\"\"Integration test: Code causes runtime timeout\"\"\"\n    test_timeout = 5  # Set a timeout shorter than the sleep time in CODE_TIMEOUT\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout)\n    assert results == [-3]\n    assert metadata_list[0][\"status\"] == \"timeout\"\n    # More assertions can be added based on the actual API response, e.g., run_status\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_concurrency_high_load():\n    \"\"\"Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong\n    answer, timeout)\"\"\"\n    concurrency_level = 100\n    # Indices for different expected outcomes\n    wrong_answer_indices = {10, 25, 50}\n    timeout_indices = {5, 30, 60, 90}  # Indices where we expect a timeout\n\n    # Generate 100 input/output pairs and code\n    high_load_inputs = []\n    high_load_outputs = []\n    expected_results_map = {}  # Store expected result for each index\n\n    for i in range(concurrency_level):\n        if i in timeout_indices:\n            # Use a special input to trigger timeout in the code\n            high_load_inputs.append(f\"input_timeout_{i}\")\n            # Output doesn't matter for timeout, but keep it consistent\n            high_load_outputs.append(f\"output_{i}\\n\")\n            expected_results_map[i] = -3  # Expect timeout\n        elif i in wrong_answer_indices:\n            high_load_inputs.append(f\"input_{i}\")\n            # Intentionally set wrong expected output\n            high_load_outputs.append(f\"wrong_output_{i}\\n\")\n            expected_results_map[i] = False  # Expect wrong answer\n        else:\n            high_load_inputs.append(f\"input_{i}\")\n            # Correct expected output\n            high_load_outputs.append(f\"output_{i}\\n\")\n            expected_results_map[i] = True  # Expect success\n\n    high_load_in_outs = {\"inputs\": high_load_inputs, \"outputs\": high_load_outputs}\n\n    # Code that handles normal inputs, and sleeps on specific \"timeout\" inputs\n    code_mixed_concurrent = \"\"\"\nimport sys\nimport time\ndata = sys.stdin.read()\nif data.startswith('input_timeout_'):\n    time.sleep(20) # Sleep longer than the test timeout\n    print(f\"output_{data.split('_')[-1]}\\\\n\", end='') # Still print something in case it finishes early\nelif data.startswith('input_'):\n    print(f\"output_{data.split('_')[-1]}\\\\n\", end='')\nelse:\n    print(\"unknown_input\\\\n\", end='')\n\"\"\"\n    # Set a reasonable timeout per case (must be less than the sleep time in the code)\n    test_timeout = 15  # Allow slightly more time due to potential API load, 
but less than 20s sleep\n\n    start_time = time.time()\n    results, metadata_list = check_correctness(\n        SANDBOX_URL,\n        high_load_in_outs,\n        code_mixed_concurrent,  # Use the new code\n        timeout=test_timeout,\n    )\n    end_time = time.time()\n    duration = end_time - start_time\n    print(\n        f\"\\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, \"\n        f\"{len(timeout_indices)} timeouts) duration: {duration:.2f} seconds\"\n    )\n\n    # Verify results against the expected map\n    assert len(results) == concurrency_level, f\"Expected {concurrency_level} results, got {len(results)}\"\n\n    correct_count = 0\n    wrong_count = 0\n    timeout_count = 0\n    unexpected_results = []\n    for i, r in enumerate(results):\n        expected = expected_results_map[i]\n        if r == expected:\n            if expected is True:\n                correct_count += 1\n            elif expected is False:\n                wrong_count += 1\n            elif expected == -3:\n                timeout_count += 1\n        else:\n            unexpected_results.append((i, r, f\"Expected {expected}\"))\n\n    print(\n        f\"Correct results (True): {correct_count}/\"\n        f\"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}\"\n    )\n    print(f\"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}\")\n    print(f\"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}\")\n\n    if unexpected_results:\n        print(\"Unexpected results found:\")\n        for idx, res, expected_str in unexpected_results[:10]:  # Print first 10 unexpected\n            print(f\"  Index {idx}: Got {res}, {expected_str}. 
Metadata: {metadata_list[idx]}\")\n        raise AssertionError(f\"Found {len(unexpected_results)} unexpected results.\")\n\n    assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), (\n        \"Incorrect number of successful results\"\n    )\n    assert wrong_count == len(wrong_answer_indices), \"Incorrect number of identified wrong answers\"\n    assert timeout_count == len(timeout_indices), \"Incorrect number of identified timeouts\"\n\n    # Verify metadata count and basic status of one of each type\n    assert len(metadata_list) == concurrency_level\n    # Find the first correct index\n    first_correct_index = next(\n        i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices\n    )\n    assert metadata_list[first_correct_index][\"status\"] == \"success\"\n    assert metadata_list[first_correct_index][\"stdout\"] == f\"output_{first_correct_index}\\n\"\n\n    # Check the status of the first intentionally wrong case\n    first_wrong_index = min(wrong_answer_indices)\n    assert metadata_list[first_wrong_index][\"status\"] == \"wrong_answer\"\n    assert metadata_list[first_wrong_index][\"stdout\"] == f\"output_{first_wrong_index}\\n\"\n    assert metadata_list[first_wrong_index][\"expected_output\"] == f\"wrong_output_{first_wrong_index}\\n\"\n\n    # Check the status of the first intentionally timeout case\n    first_timeout_index = min(timeout_indices)\n    assert metadata_list[first_timeout_index][\"status\"] == \"timeout\"\n    # For timeout, stdout might be None or empty depending on when the timeout occurred\n    # assert metadata_list[first_timeout_index][\"stdout\"] is None or metadata_list[first_timeout_index][\"stdout\"] == \"\"\n\n\n# --- Unit test cases (using mock) ---\n\n\n@patch(\"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\")\ndef test_unit_concurrency_order(mock_call_sandbox_api):\n    sandbox_url = \"mock_url\"\n    generation = \"print(input())\"\n    language = \"python\"\n    timeout = 5\n    in_outs = {\"inputs\": [\"input1\", \"input2\", \"input3\"], \"outputs\": [\"output1\", \"output2\", \"output3\"]}\n\n    def side_effect(*args, **kwargs):\n        stdin = kwargs.get(\"stdin\")\n        if stdin == \"input1\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output1\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input2\":\n            time.sleep(0.1)\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output2\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input3\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output3\", \"return_code\": 0}},\n                None,\n            )\n        else:\n            return (None, \"Unknown input in mock\")\n\n    mock_call_sandbox_api.side_effect = side_effect\n\n    results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language)\n\n    assert results == [True, True, True]\n    assert len(metadata_list) == 3\n    assert metadata_list[0][\"case_index\"] == 0\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[1][\"case_index\"] == 1\n    assert metadata_list[1][\"status\"] == \"success\"\n    assert metadata_list[2][\"case_index\"] == 2\n    assert 
metadata_list[2][\"status\"] == \"success\"\n    assert mock_call_sandbox_api.call_count == 3\n\n\n@patch(\"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\")\ndef test_unit_api_timeout_error_concurrent(mock_call_sandbox_api):\n    sandbox_url = \"mock_url\"\n    generation = \"print(input())\"\n    language = \"python\"\n    timeout = 5\n    in_outs = {\"inputs\": [\"input1\", \"input2_timeout\", \"input3\"], \"outputs\": [\"output1\", \"output2\", \"output3\"]}\n\n    api_error_message = \"API Call Failed: Gateway Timeout (504) on attempt 3/3\"\n\n    def side_effect(*args, **kwargs):\n        stdin = kwargs.get(\"stdin\")\n        if stdin == \"input1\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output1\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input2_timeout\":\n            return (None, api_error_message)\n        elif stdin == \"input3\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output3\", \"return_code\": 0}},\n                None,\n            )\n        else:\n            return (None, \"Unknown input in mock\")\n\n    mock_call_sandbox_api.side_effect = side_effect\n\n    results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language)\n\n    assert results == [True, -1, True]\n    assert len(metadata_list) == 3\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[1][\"status\"] == \"api_error\"\n    assert metadata_list[1][\"api_request_error\"] == api_error_message\n    assert metadata_list[2][\"status\"] == \"success\"\n    assert mock_call_sandbox_api.call_count == 3\n\n\n# --- Constants for the new concurrency test ---\n# Define a low global concurrency limit to test the semaphore's effect\nMAX_GLOBAL_CONCURRENCY_LIMIT_TEST = 5\n# Define the number of processes used in the test\nNUM_PROCESSES_TEST = 4\n# Define the number of tasks processed by check_correctness in each process (i.e., internal\n# ThreadPoolExecutor's concurrency potential)\nNUM_TASKS_PER_PROCESS_TEST = 3\n# Simulate API call duration to ensure calls can overlap\nSIMULATED_API_CALL_DURATION_TEST = 0.2  # seconds\n\n\n# --- Mock API call function for concurrency tracking ---\n# This function will replace the real call_sandbox_api and use shared variables to track concurrency\ndef _mock_api_call_for_concurrency_tracking(\n    active_calls_counter,  # multiprocessing.Value\n    max_calls_tracker,  # multiprocessing.Value\n    call_lock,  # multiprocessing.Lock\n    # Standard call_sandbox_api parameters\n    sandbox_fusion_url,\n    code,\n    stdin,\n    compile_timeout,\n    run_timeout,\n    memory_limit_mb,\n    language,\n):\n    # entry_time = time.time() # For detailed logging\n    with call_lock:\n        active_calls_counter.value += 1\n        if active_calls_counter.value > max_calls_tracker.value:\n            max_calls_tracker.value = active_calls_counter.value\n        # Optional debug log:\n        # print(f\"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. 
Active: \"\n        #       f\"{active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}\")\n\n    time.sleep(SIMULATED_API_CALL_DURATION_TEST)  # Simulate actual work duration\n\n    # exit_time = time.time() # For detailed logging\n    with call_lock:\n        active_calls_counter.value -= 1\n        # Optional debug log:\n        # print(f\"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: \"\n        #       f\"{active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s\")\n\n    # Return a simulated successful API response\n    return {\n        \"status\": \"Success\",\n        \"run_result\": {\"status\": \"Finished\", \"stdout\": f\"mock_output_for_{stdin}\", \"return_code\": 0},\n    }, None\n\n\n# --- Worker function for ProcessPoolExecutor ---\n# This function runs in each child process of ProcessPoolExecutor\ndef _process_pool_worker_for_concurrency_test(\n    sandbox_url,\n    in_outs,\n    generation,\n    memory_limit_mb,\n    language,\n    timeout,\n    mp_semaphore_for_check_correctness,\n    active_calls_counter,\n    max_calls_tracker,\n    call_lock,\n):\n    # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage\n    curried_mock_api_call = (\n        lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: (\n            _mock_api_call_for_concurrency_tracking(\n                active_calls_counter,\n                max_calls_tracker,\n                call_lock,\n                sandbox_fusion_url,\n                code,\n                stdin,\n                compile_timeout,\n                run_timeout,\n                memory_limit_mb,\n                language,\n            )\n        )\n    )\n\n    # ---- START DEBUG PRINTS ----\n    import os\n\n    import verl.utils.reward_score.sandbox_fusion.utils\n\n    print(\n        f\"[Worker PID:{os.getpid()}] Original call_sandbox_api: \"\n        f\"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}\",\n        flush=True,\n    )\n    # ---- END DEBUG PRINTS ----\n\n    with patch(\n        \"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\", side_effect=curried_mock_api_call\n    ) as mock_obj:\n        # ---- START DEBUG PRINTS ----\n        print(\n            f\"[Worker PID:{os.getpid()}] Patched call_sandbox_api: \"\n            f\"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}\",\n            flush=True,\n        )\n        print(f\"[Worker PID:{os.getpid()}] Mock object: {mock_obj}\", flush=True)\n        # ---- END DEBUG PRINTS ----\n        results, metadata_list = check_correctness(\n            sandbox_fusion_url=sandbox_url,\n            in_outs=in_outs,\n            generation=generation,\n            timeout=timeout,\n            memory_limit_mb=memory_limit_mb,\n            language=language,\n            concurrent_semaphore=mp_semaphore_for_check_correctness,  # Pass multiprocessing.Semaphore\n        )\n        # print(f\"Process {os.getpid()} finished check_correctness. 
Processed {len(results)} tasks.\")\n    return len(results)  # Return the number of processed tasks for basic validation\n\n\n# --- The actual test case for multiprocess concurrency control ---\ndef test_multiprocess_global_concurrency_limit_with_semaphore():\n    \"\"\"\n    Tests that the global concurrent_semaphore (multiprocessing.Semaphore)\n    correctly limits the number of concurrent calls to call_sandbox_api\n    across multiple processes, each potentially running multiple threads\n    via check_correctness's internal ThreadPoolExecutor.\n    \"\"\"\n    manager = multiprocessing.Manager()\n    active_calls_counter = manager.Value(\"i\", 0)  # Current active mock API calls\n    max_calls_tracker = manager.Value(\"i\", 0)  # Observed maximum concurrent mock API calls\n    call_lock = manager.Lock()  # Lock to protect counters\n\n    # Create a multiprocessing.Semaphore instance, this is the global semaphore we are testing.\n    # It will be passed to check_correctness and used by _process_single_case to limit calls to call_sandbox_api.\n    global_mp_semaphore = manager.Semaphore(MAX_GLOBAL_CONCURRENCY_LIMIT_TEST)\n\n    mock_sandbox_url = \"mock_url_for_concurrency_test\"\n    mock_generation = \"pass\"  # Specific code content is not important as API call is mocked\n    mock_memory_limit_mb = 1024\n    mock_language = \"python\"\n    mock_timeout = 5  # Timeout setting, not critical for mock calls\n\n    # Input/output data for each process\n    # NUM_TASKS_PER_PROCESS_TEST tasks will be handled by check_correctness's internal ThreadPoolExecutor\n    process_in_outs = {\n        \"inputs\": [f\"task_input_{i}\" for i in range(NUM_TASKS_PER_PROCESS_TEST)],\n        \"outputs\": [f\"task_output_{i}\" for i in range(NUM_TASKS_PER_PROCESS_TEST)],\n    }\n\n    futures = []\n    total_tasks_expected_to_run = NUM_PROCESSES_TEST * NUM_TASKS_PER_PROCESS_TEST\n\n    test_start_time = time.time()\n\n    with ProcessPoolExecutor(max_workers=NUM_PROCESSES_TEST) as executor:\n        for i in range(NUM_PROCESSES_TEST):\n            future = executor.submit(\n                _process_pool_worker_for_concurrency_test,  # Worker function\n                mock_sandbox_url,\n                process_in_outs,\n                mock_generation,\n                mock_memory_limit_mb,\n                mock_language,\n                mock_timeout,\n                global_mp_semaphore,  # Global semaphore to test\n                active_calls_counter,  # Shared variables for tracking\n                max_calls_tracker,\n                call_lock,\n            )\n            futures.append(future)\n\n    # Wait for all processes to complete and collect results\n    num_tasks_processed_per_worker = [f.result() for f in futures]\n    test_end_time = time.time()\n    total_execution_time = test_end_time - test_start_time\n\n    # Print some test statistics for debugging and validation\n    print(\"\\n--- Global Concurrency Test Stats ---\")\n    print(f\"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}\")\n    print(f\"Number of Processes (NUM_PROCESSES_TEST): {NUM_PROCESSES_TEST}\")\n    print(f\"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}\")\n    print(f\"Total Tasks Submitted: {total_tasks_expected_to_run}\")\n    print(f\"Simulated API Call Duration: {SIMULATED_API_CALL_DURATION_TEST}s\")\n    print(f\"Total Test Execution Time: {total_execution_time:.2f}s\")\n    print(f\"Max Concurrent Mock API Calls Observed: 
{max_calls_tracker.value}\")\n    # print(f\"Tasks processed per worker: {num_tasks_processed_per_worker}\")\n\n    # Verify that all submitted tasks have been processed\n    assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, (\n        \"Mismatch in the number of tasks processed.\"\n    )\n\n    # Verify that the mock API was called at least once\n    assert max_calls_tracker.value > 0, \"The mocked API call_sandbox_api was not called.\"\n\n    # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit\n    assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, (\n        f\"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit \"\n        f\"({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}).\"\n    )\n\n    # Optional: Rough check on execution time to verify semaphore is working to limit concurrency\n    # Theoretical minimum execution time = (Total tasks / Concurrency limit) * Single task duration\n    # Actual time will be longer due to various overheads\n    min_expected_duration = (\n        total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST\n    ) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST\n    # print(f\"Minimum Expected Execution Time (approx): {min_expected_duration:.2f}s\")\n    # Allow some margin, e.g., 80% of theoretical minimum time\n    assert total_execution_time >= min_expected_duration * 0.8, (\n        f\"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the \"\n        f\"semaphore might not be effectively limiting concurrency as expected \"\n        f\"(min expected: {min_expected_duration * 0.8:.2f}s).\"\n    )\n\n\ndef test_unit_invalid_input_format():\n    \"\"\"Unit test: Invalid in_outs format passed\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, None, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n    results, metadata_list = check_correctness(SANDBOX_URL, {}, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_unit_input_output_mismatch():\n    \"\"\"Unit test: Mismatch between the number of inputs and outputs\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS)\n    assert results == [-1]\n    assert len(metadata_list) == 1\n    assert metadata_list[0][\"error\"] == \"Input/output count mismatch\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_concurrency_all_timeout():\n    \"\"\"Integration test: High concurrency (100 cases) against real API, all causing timeout\"\"\"\n    concurrency_level = 100\n    code_infinite_loop = \"\"\"\ndef knight_moves(X, Y):\n    MOD = 10**9 + 7\n    dp = [[0] * (Y + 1) for _ in range(X + 1)]\n    dp[0][0] = 1\n    for i in range(1, X + 1):\n        for j in range(1, Y + 1):\n            dp[i][j] = (dp[i - 1][j] + dp[i][j - 1]) % MOD\n    return dp[X][Y]\n\ndef solve():\n    X, Y = map(int, input().split())\n    print(knight_moves(X, Y))\n\nif __name__ == \"__main__\":\n    solve()\n    \"\"\"
\"__main__\":\n    solve()\n    \"\"\"\n\n    # Generate 100 simple input/output pairs (content doesn't matter)\n    timeout_inputs = [\"324 384429\" for i in range(concurrency_level)]\n    timeout_outputs = [f\"output_{i}\\n\" for i in range(concurrency_level)]\n    timeout_in_outs = {\"inputs\": timeout_inputs, \"outputs\": timeout_outputs}\n\n    # Set a timeout for the test cases\n    test_timeout = 10  # Set a timeout value\n\n    start_time = time.time()\n    results, metadata_list = check_correctness(SANDBOX_URL, timeout_in_outs, code_infinite_loop, timeout=test_timeout)\n    end_time = time.time()\n    duration = end_time - start_time\n    print(f\"\\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds\")\n\n    # Verify all results are -3 (timeout)\n    assert len(results) == concurrency_level, f\"Expected {concurrency_level} results, got {len(results)}\"\n    all_timed_out = all(r == -3 for r in results)\n    if not all_timed_out:\n        non_timeout_indices = [i for i, r in enumerate(results) if r != -3]\n        print(f\"Indices that did not time out: {non_timeout_indices}\")\n        # Print metadata for the first few non-timeout cases for debugging\n        for i in non_timeout_indices[:5]:\n            print(f\"Metadata for non-timeout case {i}: {metadata_list[i]}\")\n    assert all_timed_out, f\"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}\"\n\n    # Verify metadata count and status of the first case\n    assert len(metadata_list) == concurrency_level\n    assert metadata_list[0][\"status\"] == \"timeout\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_fn_name_success_single_case():\n    \"\"\"Tests successful execution for a single test case with fn_name.\n    from livecodebench/code_generation_lite test 510\n    \"\"\"\n    generation_code = \"\"\"\nclass Solution:\n    def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> List[int]:\n        positions = defaultdict(list)\n        for idx, num in enumerate(nums):\n            positions[num].append(idx)\n\n        x_positions = positions[x]\n        answer = []\n        for k in queries:\n            if k > len(x_positions):\n                answer.append(-1)\n            else:\n                answer.append(x_positions[k-1])\n        return answer\n\"\"\"\n    in_outs = {\n        \"fn_name\": \"occurrencesOfElement\",\n        \"inputs\": [\"[1, 3, 1, 7]\\n[1, 3, 2, 4]\\n1\", \"[1, 2, 3]\\n[10]\\n5\"],\n        \"outputs\": [\"[0, -1, 2, -1]\", \"[-1]\"],\n    }\n\n    # Use a short timeout for fast tests\n    results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5)\n    # from verl.utils.reward_score.prime_code import apps_check_correctness\n    # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code,\n    #                                                        timeout=50000, debug=True)\n\n    assert results == [True, True]\n    assert \"error\" not in metadata_list[0]\n    assert metadata_list[0].get(\"status\") != \"compile_error\"\n    assert metadata_list[0].get(\"status\") != \"runtime_error\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_none_and_empty_stdin_passed_correctly():\n    \"\"\"\n    Tests that when stdin data is set to an empty string or None, it is still\n    is passed correctly to Sandbox Fusion as an empty string.\n    \"\"\"\n    echo_code = \"\"\"\nimport 
sys\nprint(f\"You said '{sys.stdin.readline().strip()}'\")\n\"\"\"\n    in_outs = {\n        \"inputs\": [None, \"\", \"hello\"],\n        \"outputs\": [\"You said ''\", \"You said ''\", \"You said 'hello'\"],\n    }\n\n    # Use a short timeout for fast tests\n    results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5)\n\n    assert results == [True, True, True]\n    assert \"error\" not in metadata_list[0]\n    assert metadata_list[0].get(\"status\") != \"compile_error\"\n    assert metadata_list[0].get(\"status\") != \"runtime_error\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_assert_case_success():\n    \"\"\"Tests successful execution for assert case.\n    from KodCode\n    \"\"\"\n    generation_code = \"\"\"\nfrom typing import List, Tuple\n\ndef merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:\n    if not intervals:\n        return []\n\n    # Sort intervals by the start time\n    intervals.sort(key=lambda x: x[0])\n\n    merged = [intervals[0]]\n\n    for current in intervals[1:]:\n        last = merged[-1]\n        # If intervals overlap, merge them\n        if current[0] <= last[1]:\n            merged[-1] = (last[0], max(last[1], current[1]))\n        else:\n            merged.append(current)\n\n    return merged\n\"\"\"\n    test_cases = {\n        \"fn_name\": \"merge_intervals\",\n        \"assert_case\": [\n            \"assert merge_intervals([(0, 1), (3, 5), (4, 7), (6, 8), (10, 12),\"\n            \" (12, 14)]) == [(0, 1), (3, 8), (10, 14)]\",\n            \"assert merge_intervals([(1, 2), (2, 3), (3, 4)]) == [(1, 4)]\",\n            \"assert merge_intervals([(1, 2), (3, 4), (5, 6)]) == [(1, 2), (3, 4), (5, 5)]\",\n        ],\n    }\n\n    assert_cases = test_cases.get(\"assert_case\")\n    test_cases.setdefault(\"inputs\", [\"\" for _ in assert_cases])\n    test_cases.setdefault(\"outputs\", [None for _ in assert_cases])\n\n    # Use a short timeout for fast tests\n    results, metadata_list = check_correctness(SANDBOX_URL, test_cases, generation_code, timeout=5)\n    assert results == [True, True, -2]\n    for i in range(2):\n        assert \"error\" not in metadata_list[i]\n        assert metadata_list[i].get(\"status\") == \"success\"\n        assert metadata_list[i].get(\"expected_output\") is None\n        assert metadata_list[i].get(\"status\") != \"runtime_error\"\n    assert \"error\" not in metadata_list[2]\n    assert metadata_list[2].get(\"status\") != \"success\"\n    assert metadata_list[2].get(\"expected_output\") is None\n    assert metadata_list[2].get(\"status\") == \"runtime_error\"\n"
  },
  {
    "path": "verl_distillation/tests/utils/reward_score/test_sandbox_on_cpu.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport json\nimport os\n\nimport pytest\n\nfrom verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion\nfrom verl.utils.reward_score.prime_code import apps_check_correctness\nfrom verl.workers.reward_manager.prime import parallel_compute_score_async\n\nprime_math_answers = [\n    \"\"\"\\\\begin{bmatrix}\\n -7 & 6 & -8 \\\\\\\\\\n 11 & -9 & 12 \\\\\\\\\\n 15 & -16 & 19 \\n \\\\end{bmatrix}\"\"\",\n    \"\"\"\\\\frac{\\\\sqrt{505}}{7}\"\"\",\n    \"\"\"x^2 + y^2 + 4x - 6y + 13\"\"\",\n]\nprime_math_gts = [\n    \"\"\"\\\\begin{pmatrix}\\n -7 & 6 & -8 \\\\\\\\\\n 11 & -9 & 12 \\\\\\\\\\n 15 & -16 & 19\\n \\\\end{pmatrix}\"\"\",  # mat test\n    \"\"\"\\\\frac{\\\\sqrt{505}}{7}\"\"\",  # frac test\n    \"\"\"(x + 2)^2 + (y - 3)^2 \"\"\",  # symbolic test\n]\n\nprime_code_answers = [\n    \"\"\"import sys\nfrom collections import deque\n\ndef main():\n    data = sys.stdin.read().split()\n    it = iter(data)\n    \n    # Read start and target positions\n    x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it))\n    \n    n = int(next(it))\n    allowed = set()\n    # The total number of allowed cells is at most 10^5.\n    for _ in range(n):\n        r = int(next(it))\n        a = int(next(it))\n        b = int(next(it))\n        for c in range(a, b + 1):\n            allowed.add((r, c))\n    \n    # Directions for the king (8 neighboring cells)\n    directions = [(-1, -1), (-1, 0), (-1, 1),\n                  (0, -1),           (0, 1),\n                  (1, -1),  (1, 0),  (1, 1)]\n    \n    start = (x0, y0)\n    target = (x1, y1)\n    \n    # BFS initialization\n    queue = deque()\n    queue.append((x0, y0, 0))\n    # Mark the starting cell as visited by removing it from allowed set.\n    allowed.discard(start)\n    \n    while queue:\n        x, y, moves = queue.popleft()\n        if (x, y) == target:\n            print(moves)\n            return\n        for dx, dy in directions:\n            nx, ny = x + dx, y + dy\n            if (nx, ny) in allowed:\n                allowed.remove((nx, ny))\n                queue.append((nx, ny, moves + 1))\n    \n    print(-1)\n\nif __name__ == '__main__':\n    main()\n\"\"\"\n] * 2\nprime_code_gts = [\n    \"\"\"{\\n \\\"inputs\\\": [\\n \\\"5 7 6 11\\\\n3\\\\n5 3 8\\\\n6 7 11\\\\n5 2 5\\\\n\\\",\\n \\\"3 4 3 10\\\\n3\\\\n3 1 4\\\\n4 5 9\\\\n3 10 10\\\\n\\\",\\n \\\"1 1 2 10\\\\n2\\\\n1 1 3\\\\n2 6 10\\\\n\\\",\\n \\\"9 8 7 8\\\\n9\\\\n10 6 6\\\\n10 6 6\\\\n7 7 8\\\\n9 5 6\\\\n8 9 9\\\\n9 5 5\\\\n9 8 8\\\\n8 5 6\\\\n9 10 10\\\\n\\\",\\n \\\"6 15 7 15\\\\n9\\\\n6 15 15\\\\n7 14 14\\\\n6 15 15\\\\n9 14 14\\\\n7 14 16\\\\n6 15 15\\\\n6 15 15\\\\n7 14 14\\\\n8 15 15\\\\n\\\",\\n \\\"13 16 20 10\\\\n18\\\\n13 16 16\\\\n20 10 10\\\\n19 10 10\\\\n12 15 15\\\\n20 10 10\\\\n18 11 11\\\\n19 10 10\\\\n19 10 10\\\\n20 10 10\\\\n19 10 10\\\\n20 10 10\\\\n20 10 10\\\\n19 10 10\\\\n18 
11 11\\\\n13 16 16\\\\n12 15 15\\\\n19 10 10\\\\n19 10 10\\\\n\\\",\\n \\\"89 29 88 30\\\\n16\\\\n87 31 31\\\\n14 95 95\\\\n98 88 89\\\\n96 88 88\\\\n14 97 97\\\\n13 97 98\\\\n100 88 88\\\\n88 32 32\\\\n99 88 89\\\\n90 29 29\\\\n87 31 31\\\\n15 94 96\\\\n89 29 29\\\\n88 32 32\\\\n97 89 89\\\\n88 29 30\\\\n\\\",\\n \\\"30 14 39 19\\\\n31\\\\n35 7 11\\\\n37 11 12\\\\n32 13 13\\\\n37 5 6\\\\n46 13 13\\\\n37 14 14\\\\n31 13 13\\\\n43 13 19\\\\n45 15 19\\\\n46 13 13\\\\n32 17 17\\\\n41 14 19\\\\n30 14 14\\\\n43 13 17\\\\n34 16 18\\\\n44 11 19\\\\n38 13 13\\\\n40 12 20\\\\n37 16 18\\\\n46 16 18\\\\n34 10 14\\\\n36 9 10\\\\n36 15 19\\\\n38 15 19\\\\n42 13 19\\\\n33 14 15\\\\n35 15 19\\\\n33 17 18\\\\n39 12 20\\\\n36 5 7\\\\n45 12 12\\\\n\\\",\\n \\\"2 1 1 1\\\\n2\\\\n1 1 2\\\\n2 1 2\\\\n\\\",\\n \\\"1 1 1 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\",\\n \\\"1 1 1000000000 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\"\\n ],\\n \\\"outputs\\\": [\\n \\\"4\\\\n\\\",\\n \\\"6\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"2\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"9\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\"\\n ]\\n}\"\"\",  # A correct sample # noqa: E501\n    \"\"\"{\\n \\\"inputs\\\": [\\n \\\"5 7 6 11\\\\n3\\\\n5 3 8\\\\n6 7 11\\\\n5 2 5\\\\n\\\",\\n \\\"3 4 3 10\\\\n3\\\\n3 1 4\\\\n4 5 9\\\\n3 10 10\\\\n\\\",\\n \\\"1 1 2 10\\\\n2\\\\n1 1 3\\\\n2 6 10\\\\n\\\",\\n \\\"9 8 7 8\\\\n9\\\\n10 6 6\\\\n10 6 6\\\\n7 7 8\\\\n9 5 6\\\\n8 9 9\\\\n9 5 5\\\\n9 8 8\\\\n8 5 6\\\\n9 10 10\\\\n\\\",\\n \\\"6 15 7 15\\\\n9\\\\n6 15 15\\\\n7 14 14\\\\n6 15 15\\\\n9 14 14\\\\n7 14 16\\\\n6 15 15\\\\n6 15 15\\\\n7 14 14\\\\n8 15 15\\\\n\\\",\\n \\\"13 16 20 10\\\\n18\\\\n13 16 16\\\\n20 10 10\\\\n19 10 10\\\\n12 15 15\\\\n20 10 10\\\\n18 11 11\\\\n19 10 10\\\\n19 10 10\\\\n20 10 10\\\\n19 10 10\\\\n20 10 10\\\\n20 10 10\\\\n19 10 10\\\\n18 11 11\\\\n13 16 16\\\\n12 15 15\\\\n19 10 10\\\\n19 10 10\\\\n\\\",\\n \\\"89 29 88 30\\\\n16\\\\n87 31 31\\\\n14 95 95\\\\n98 88 89\\\\n96 88 88\\\\n14 97 97\\\\n13 97 98\\\\n100 88 88\\\\n88 32 32\\\\n99 88 89\\\\n90 29 29\\\\n87 31 31\\\\n15 94 96\\\\n89 29 29\\\\n88 32 32\\\\n97 89 89\\\\n88 29 30\\\\n\\\",\\n \\\"30 14 39 19\\\\n31\\\\n35 7 11\\\\n37 11 12\\\\n32 13 13\\\\n37 5 6\\\\n46 13 13\\\\n37 14 14\\\\n31 13 13\\\\n43 13 19\\\\n45 15 19\\\\n46 13 13\\\\n32 17 17\\\\n41 14 19\\\\n30 14 14\\\\n43 13 17\\\\n34 16 18\\\\n44 11 19\\\\n38 13 13\\\\n40 12 20\\\\n37 16 18\\\\n46 16 18\\\\n34 10 14\\\\n36 9 10\\\\n36 15 19\\\\n38 15 19\\\\n42 13 19\\\\n33 14 15\\\\n35 15 19\\\\n33 17 18\\\\n39 12 20\\\\n36 5 7\\\\n45 12 12\\\\n\\\",\\n \\\"2 1 1 1\\\\n2\\\\n1 1 2\\\\n2 1 2\\\\n\\\",\\n \\\"1 1 1 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\",\\n \\\"1 1 1000000000 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\"\\n ],\\n \\\"outputs\\\": [\\n \\\"4\\\\n\\\",\\n \\\"6\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"9\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\"\\n ]\\n}\"\"\",  # noqa: E501\n]  # A failed sample with first several in-out passed\n\nprime_code_scores = [1.0, 0.9]\n\n\ndef test_parallelism():\n    \"\"\"\n    Test if process pool works properly\n    \"\"\"\n    
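# Repeatedly extend the batch with code and math samples until it holds at least 32 entries, mixing data sources so the process pool is exercised across reward types.\n    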
sequences_str = []\n    ground_truth = []\n    data_sources = []\n    while len(sequences_str) < 32:\n        sequences_str.extend(prime_code_answers)\n        ground_truth.extend(prime_code_gts)\n        data_sources.extend([\"codecontests\"] * len(prime_code_answers))\n\n        sequences_str.extend(prime_math_answers)\n        ground_truth.extend(prime_math_gts)\n        data_sources.extend([\"numina_aops_forum\"] * len(prime_math_answers))\n\n    scores = asyncio.run(\n        parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)\n    )\n    print(scores)\n\n\ndef test_prime_code():\n    \"\"\"\n    Test PRIME code sandbox.\n    \"\"\"\n    data_source = \"codecontests\"\n    for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):\n        score = default_compute_score(data_source, completion, ground_truth)\n        assert float(score) == score_\n\n\n# Use the pytest.mark.skipif decorator to skip the test\n@pytest.mark.skipif(not os.environ.get(\"SANDBOX_FUSION_URL\"), reason=\"SANDBOX_FUSION_URL environment variable not set\")\ndef test_prime_code_sandbox_fusion():\n    \"\"\"\n    Test PRIME code on sandbox fusion. Skips if SANDBOX_FUSION_URL is not set.\n    \"\"\"\n    data_source = \"codecontests\"\n    # Get the URL from the environment variable, as skipif ensures it is set at this point\n    sandbox_fusion_url = os.environ.get(\"SANDBOX_FUSION_URL\")\n\n    for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):\n        score = default_compute_score(\n            data_source, completion, ground_truth, extra_info={\"sandbox_fusion_url\": sandbox_fusion_url}\n        )\n        assert float(score) == score_\n\n\n@pytest.mark.skipif(not os.environ.get(\"SANDBOX_FUSION_URL\"), reason=\"SANDBOX_FUSION_URL environment variable not set\")\ndef test_continuous_score_consistency():\n    \"\"\"\n    Verify that continuous score calculation is consistent between prime_code and sandbox_fusion.\n    Uses a test case where the first 9 out of 11 sub-cases pass (expected score 0.9).\n    \"\"\"\n    completion = prime_code_answers[1]  # Use the second sample\n    ground_truth = prime_code_gts[1]  # Use the second sample (9/11 pass, first 9 pass)\n    expected_continuous_score = 0.9\n\n    # 1. Calculate score using prime_code (default) with continuous=True\n    prime_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True)\n\n    # 2. Calculate score using sandbox_fusion with continuous=True\n    fusion_score, _ = sandbox_fusion.compute_score(\n        os.environ.get(\"SANDBOX_FUSION_URL\"), None, completion, ground_truth, continuous=True\n    )\n\n    # 3. 
Assert scores are equal (using pytest.approx for float comparison)\n    assert float(prime_score) == pytest.approx(expected_continuous_score)\n    assert float(fusion_score) == pytest.approx(expected_continuous_score)\n    assert float(prime_score) == pytest.approx(float(fusion_score))\n    print(f\"Continuous Score (Prime Code): {prime_score}\")\n    print(f\"Continuous Score (Sandbox Fusion): {fusion_score}\")\n\n\ndef test_check_correctness():\n    completion = prime_code_answers[0]\n    ground_truth = json.loads(prime_code_gts[0])\n    ground_truth_single = {\"inputs\": ground_truth[\"inputs\"][:1], \"outputs\": ground_truth[\"outputs\"][:1]}\n    res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False)\n    print(res, meta)\n\n\ndef test_prime_math():\n    data_source = \"numina_aops_forum\"\n    for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True):\n        score = default_compute_score(data_source, completion, ground_truth)\n        assert float(score) == 1.0\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_activation_offload.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport shutil\nimport tempfile\n\nimport pytest\nimport torch\nimport torch.distributed\nimport torch.multiprocessing as mp\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config\n\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy\n\n\ndef create_random_input_ids(batch_size, seq_len, vocab_size):\n    from flash_attn.bert_padding import unpad_input\n\n    from verl.utils.model import compute_position_id_with_mask, create_random_mask\n\n    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n\n    attention_mask = create_random_mask(\n        input_ids, max_ratio_of_left_padding=0.1, min_ratio_of_valid_token=0.5, max_ratio_of_valid_token=0.7\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n\n    input_ids = unpad_input(input_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)\n    position_ids = unpad_input(position_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)\n    return input_ids, position_ids\n\n\ndef _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy=\"fsdp\"):\n    torch.cuda.set_device(rank)\n    torch.distributed.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"dp\",))\n\n    model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n    config = Qwen2Config(num_hidden_layers=4)\n\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        model = model.to(device=\"cuda\")\n\n    # Wrap model with FSDP\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    if strategy == \"fsdp\":\n        model = FSDP(\n            model,\n            use_orig_params=False,\n            device_id=torch.cuda.current_device(),\n            sharding_strategy=ShardingStrategy.FULL_SHARD,\n            mixed_precision=mixed_precision,\n            device_mesh=device_mesh,\n            auto_wrap_policy=get_fsdp_wrap_policy(module=model),\n        )\n    else:\n        mp_policy = MixedPrecisionPolicy(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n        )\n        fsdp_kwargs = {\n            \"mesh\": device_mesh,\n            \"mp_policy\": 
mp_policy,\n        }\n        apply_fsdp2(model, fsdp_kwargs, {})\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)\n\n    # Create checkpoint manager\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    checkpoint_manager = FSDPCheckpointManager(\n        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer\n    )\n\n    # Generate sample input\n    batch_size = 2\n    seq_len = 32\n    vocab_size = 32000\n    # First input for initial update\n    input_ids1, position_ids1 = create_random_input_ids(batch_size, seq_len, vocab_size)\n\n    # Second input for verification\n    input_ids2, position_ids2 = create_random_input_ids(batch_size, seq_len, vocab_size)\n\n    # Step 1: Initial update and save checkpoint\n    outputs1 = model(input_ids=input_ids1, position_ids=position_ids1)\n    loss1 = outputs1.logits.mean()\n    loss1.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Save checkpoint after first update\n    temp_dir = tempfile.mkdtemp()\n    checkpoint_path = os.path.join(temp_dir, \"checkpoint\")\n    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)\n\n    # Step 2: Second update and forward pass\n    outputs2 = model(input_ids=input_ids2, position_ids=position_ids2)\n    loss2 = outputs2.logits.mean()\n    loss2.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after second update\n    with torch.no_grad():\n        logits_without_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits\n\n    # Step 3: Wrap module with activation offloading and load checkpoint\n    enable_activation_offloading(model, strategy=strategy)\n    checkpoint_manager.load_checkpoint(checkpoint_path)\n\n    # Step 4: Repeat the second update with the same input\n    outputs3 = model(input_ids=input_ids2, position_ids=position_ids2)\n    loss3 = outputs3.logits.mean()\n    loss3.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after loading the checkpoint and repeating the update\n    with torch.no_grad():\n        logits_with_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits\n\n    # Step 5: Verify outputs match\n    torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)\n    print(f\"Activation offloading for {strategy} test passed on {world_size} GPUs!\")\n\n    # Cleanup\n    shutil.rmtree(temp_dir)\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\n@pytest.mark.parametrize(\"world_size\", (2, 4))\n@pytest.mark.parametrize(\"strategy\", (\"fsdp\", \"fsdp2\"))\ndef test_activation_offloading(world_size, strategy, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_file\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_fsdp_activation_offloading_test,\n        args=(world_size, rendezvous_file, strategy),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_config_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom dataclasses import dataclass, field\n\nfrom omegaconf import OmegaConf\n\nfrom verl.base_config import BaseConfig\nfrom verl.utils import omega_conf_to_dataclass\n\n\n@dataclass\nclass TestDataclass(BaseConfig):\n    hidden_size: int = 0\n    activation: str = \"relu\"\n\n\n@dataclass\nclass TestTrainConfig(BaseConfig):\n    batch_size: int = 0\n    model: TestDataclass = field(default_factory=TestDataclass)\n    override_config: dict = field(default_factory=dict)\n\n\n_cfg_str = \"\"\"train_config:\n  _target_: tests.utils.test_config_on_cpu.TestTrainConfig\n  batch_size: 32\n  model:\n    hidden_size: 768\n    activation: relu\n  override_config: {}\"\"\"\n\n\nclass TestConfigOnCPU(unittest.TestCase):\n    \"\"\"Test cases for configuration utilities on CPU.\n\n    Test Plan:\n    1. Test basic OmegaConf to dataclass conversion for simple nested structures\n    2. Test nested OmegaConf to dataclass conversion for complex hierarchical configurations\n    3. Verify all configuration values are correctly converted and accessible\n    \"\"\"\n\n    def setUp(self):\n        self.config = OmegaConf.create(_cfg_str)\n\n    def test_omega_conf_to_dataclass(self):\n        sub_cfg = self.config.train_config.model\n        cfg = omega_conf_to_dataclass(sub_cfg, TestDataclass)\n        self.assertEqual(cfg.hidden_size, 768)\n        self.assertEqual(cfg.activation, \"relu\")\n        assert isinstance(cfg, TestDataclass)\n\n    def test_nested_omega_conf_to_dataclass(self):\n        cfg = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig)\n        self.assertEqual(cfg.batch_size, 32)\n        self.assertEqual(cfg.model.hidden_size, 768)\n        self.assertEqual(cfg.model.activation, \"relu\")\n        assert isinstance(cfg, TestTrainConfig)\n        assert isinstance(cfg.model, TestDataclass)\n\n\nclass TestPrintCfgCommand(unittest.TestCase):\n    \"\"\"Test suite for the print_cfg.py command-line tool.\"\"\"\n\n    def test_command_with_override(self):\n        \"\"\"Test that the command runs without error when overriding config values.\"\"\"\n        import subprocess\n\n        # Run the command\n        result = subprocess.run(\n            [\"python3\", \"scripts/print_cfg.py\"],\n            capture_output=True,\n            text=True,\n        )\n\n        # Verify the command exited successfully\n        self.assertEqual(result.returncode, 0, f\"Command failed with stderr: {result.stderr}\")\n\n        # Verify the output contains expected config information\n        self.assertIn(\"critic\", result.stdout)\n        self.assertIn(\"profiler\", result.stdout)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_flops_counter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport pytest\n\nfrom verl.utils.flops_counter import FlopsCounter\n\nVALID_CONFIG_TYPE = {\"llama\", \"qwen2\", \"qwen3\", \"qwen3_moe\", \"deepseek_v3\", \"mistral\", \"gemma3_text\", \"apertus\"}\n\n\nclass Config:\n    def __init__(self, config_dict):\n        for key, value in config_dict.items():\n            setattr(self, key, value)\n\n\nCONFIG = {\n    \"llama\": {\n        \"config\": {  # llama2-7B\n            \"model_type\": \"llama\",\n            \"vocab_size\": 32000,\n            \"hidden_size\": 4096,\n            \"intermediate_size\": 11008,\n            \"num_hidden_layers\": 32,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 32,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*32*4096\n        # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*32*4096\n        \"expected_flops_tuple\": (153555818250240 / 1e12, 575955114393600 / 1e12),\n    },\n    \"qwen2\": {\n        \"config\": {  # Qwen/Qwen2.5-7B-Instruct\n            \"model_type\": \"qwen2\",\n            \"vocab_size\": 152064,\n            \"hidden_size\": 3584,\n            \"intermediate_size\": 18944,\n            \"num_hidden_layers\": 28,\n            \"num_attention_heads\": 28,\n            \"num_key_value_heads\": 4,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*28*3584\n        # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*28*3584\n        \"expected_flops_tuple\": (170388331954176 / 1e12, 622070178250752 / 1e12),\n    },\n    \"qwen3\": {\n        \"config\": {  # Qwen/Qwen3-8B\n            \"model_type\": \"qwen3\",\n            \"vocab_size\": 151936,\n            \"hidden_size\": 4096,\n            \"intermediate_size\": 12288,\n            \"num_hidden_layers\": 36,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 
6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*36*128*32\n        # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*36*128*32\n        \"expected_flops_tuple\": (185867930959872 / 1e12, 692924253732864 / 1e12),\n    },\n    \"qwen3_moe\": {\n        \"config\": {  # Qwen/Qwen3-30B-A3B-Base\n            \"model_type\": \"qwen3_moe\",\n            \"hidden_size\": 2048,\n            \"vocab_size\": 151936,\n            \"num_hidden_layers\": 48,\n            \"num_key_value_heads\": 4,\n            \"num_attention_heads\": 32,\n            \"head_dim\": 128,\n            \"moe_intermediate_size\": 768,\n            \"num_experts_per_tok\": 8,\n            \"num_experts\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 +\n        # hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*48*128*32\n        # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*48*128*32\n        \"expected_flops_tuple\": (85087060230144 / 1e12, 365944098521088 / 1e12),\n    },\n    \"deepseek_v3\": {\n        \"config\": {  # deepseek-ai/DeepSeek-Prover-V2-671B\n            \"model_type\": \"deepseek_v3\",\n            \"hidden_size\": 7168,\n            \"vocab_size\": 129280,\n            \"moe_intermediate_size\": 2048,\n            \"num_hidden_layers\": 61,\n            \"first_k_dense_replace\": 3,\n            \"num_attention_heads\": 128,\n            \"n_routed_experts\": 256,\n            \"num_experts_per_tok\": 8,\n            \"n_shared_experts\": 1,\n            \"kv_lora_rank\": 512,\n            \"qk_rope_head_dim\": 64,\n            \"v_head_dim\": 128,\n            \"intermediate_size\": 18432,\n            \"qk_nope_head_dim\": 128,\n            \"q_lora_rank\": 1536,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280\n        # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*61*192*128\n        # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*61*192*128\n        \"expected_flops_tuple\": (906535995703296 / 1e12, 3674028304760832 / 1e12),\n    },\n    \"mistral\": {\n        \"config\": {  # mistralai/Mistral-Small-24B-Instruct-2501\n            \"model_type\": \"mistral\",\n            \"vocab_size\": 131072,\n            \"hidden_size\": 5120,\n            \"intermediate_size\": 32768,\n            \"num_hidden_layers\": 40,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # Mistral uses same architecture as Llama, with GQA\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        
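# where sum(seqlen^2) = 512^2 + 1024^2 + 2048^2 = 5505024 for the first batch and 3*4096^2 = 50331648 for the second\n        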
# vocab part: 131072*5120*2 = 1342177280\n        # attn part per layer: 5120*(128*32+128*8+128*8+128*32) = 5120*10240 = 52428800\n        # mlp part per layer: 5120*32768*3 = 503316480\n        # total per layer: 52428800 + 503316480 = 555745280\n        # all layers: 1342177280 + 40*555745280 = 23571988480\n        # For batch [512, 1024, 2048], tokens_sum = 3584:\n        # dense flops: 6 * 23571988480 * 3584 = 506892040273920\n        # attn flops: 12 * 5505024 * 40 * 128 * 32 = 10823317585920\n        # total: 517715357859840 / 1e12 = 517.71535785984\n        # For batch [4096, 4096, 4096], tokens_sum = 12288:\n        # dense flops: 6 * 23571988480 * 12288 = 1737915566653440\n        # attn flops: 12 * 50331648 * 40 * 128 * 32 = 98956046499840\n        # total: 1836871613153280 / 1e12 = 1836.87161315328\n        \"expected_flops_tuple\": (517715357859840 / 1e12, 1836871613153280 / 1e12),\n    },\n    \"gemma3_text\": {\n        \"config\": {  # Gemma3-12B-IT-TextOnly\n            \"model_type\": \"gemma3_text\",\n            \"vocab_size\": 262208,\n            \"hidden_size\": 3840,\n            \"intermediate_size\": 15360,\n            \"num_hidden_layers\": 48,\n            \"num_attention_heads\": 16,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 256,\n            \"sliding_window\": 1024,\n            \"layer_types\": None,\n            # Will be auto-generated based on sliding_window_pattern\n            \"sliding_window_pattern\": 6,\n            # Every 6th layer is full attention\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # Gemma3 has alternating sliding window attention\n        # With sliding_window_pattern=6: layers 5,11,17,23,29,35,41,47 use full attention (8 layers)\n        # Other 40 layers use sliding window attention with window_size=1024\n        #\n        # Non-attention FLOPs:\n        # vocab part: 262208*3840*2 = 2013757440\n        # attn part per layer: 3840*(256*16+256*8+256*8+256*16) = 3840*12288 = 47185920\n        # mlp part per layer: 3840*15360*3 = 176947200\n        # total per layer: 47185920 + 176947200 = 224133120\n        # all layers: 2013757440 + 48*224133120 = 12772147200\n        #\n        # For batch [512, 1024, 2048], tokens_sum = 3584:\n        # dense flops: 6 * 12772147200 * 3584 = 274652253388800\n        # seqlen_square_sum: 180355072 (calculated with sliding window logic)\n        # attn flops: 12 * 180355072 * 256 * 16 = 8864812498944\n        # total: 283517065887744 / 1e12 = 283.517065887744\n        #\n        # For batch [4096, 4096, 4096], tokens_sum = 12288:\n        # dense flops: 6 * 12772147200 * 12288 = 941664868761600\n        # seqlen_square_sum: 905969664 (calculated with sliding window logic)\n        # attn flops: 12 * 905969664 * 256 * 16 = 44530220924928\n        # total: 986195089686528 / 1e12 = 986.195089686528\n        \"expected_flops_tuple\": (283517065887744 / 1e12, 986195089686528 / 1e12),\n    },\n    \"apertus\": {\n        \"config\": {  # swiss-ai/Apertus-8B\n            \"model_type\": \"apertus\",\n            \"vocab_size\": 131072,\n            \"hidden_size\": 4096,\n            \"intermediate_size\": 21504,\n            \"num_hidden_layers\": 32,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 32,\n            \"hidden_act\": \"xielu\",\n            # head_dim will be derived as 4096 / 32 = 128\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 
Calculation for Apertus (hidden_act=\"xielu\" -> MLP uses [k_mlp=2]*H*I params; qk_norm=True -> [k_qkn=2]*H):\n        # V=131072, H=4096, I=21504, L=32, k_mlp=2 (XIELU), k_qkn=2 (QK norm), S=6\n        # S*(2*V*H + L*(4*H**2 + k_mlp*H*I + k_qkn*H)) * (SUM[seqlen]) + 12*SUM[seqlen**2]*L*H\n        \"expected_flops_tuple\": (199154680725504 / 1e12, 732294071451648 / 1e12),\n    },\n}\n\n\n@pytest.mark.parametrize(\n    \"config_type\",\n    [\"llama\", \"qwen2\", \"qwen3\", \"qwen3_moe\", \"deepseek_v3\", \"mistral\", \"gemma3_text\", \"apertus\"],\n)\ndef test_flops_counter(config_type: str):\n    test_config = CONFIG[config_type]\n    config = Config(test_config[\"config\"])\n    flops_counter = FlopsCounter(config)\n    for batch_seqlens, expected_flops in zip(\n        test_config[\"batch_seqlens_tuple\"], test_config[\"expected_flops_tuple\"], strict=True\n    ):\n        # set delta time to 1 to get the flops\n        counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)\n        print(f\"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}\")\n        assert math.isclose(counted_flops, expected_flops), (\n            f\"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}\"\n        )\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_fs_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom pathlib import Path\n\nimport verl.utils.fs as fs\n\n\ndef test_record_and_check_directory_structure(tmp_path):\n    # Create test directory structure\n    test_dir = tmp_path / \"test_dir\"\n    test_dir.mkdir()\n    (test_dir / \"file1.txt\").write_text(\"test\")\n    (test_dir / \"subdir\").mkdir()\n    (test_dir / \"subdir\" / \"file2.txt\").write_text(\"test\")\n\n    # Create structure record\n    record_file = fs._record_directory_structure(test_dir)\n\n    # Verify record file exists\n    assert os.path.exists(record_file)\n\n    # Initial check should pass\n    assert fs._check_directory_structure(test_dir, record_file) is True\n\n    # Modify structure and verify check fails\n    (test_dir / \"new_file.txt\").write_text(\"test\")\n    assert fs._check_directory_structure(test_dir, record_file) is False\n\n\ndef test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):\n    # Mock HDFS dependencies\n    monkeypatch.setattr(fs, \"is_non_local\", lambda path: True)\n\n    # side_effect will simulate the copy by creating parent dirs + empty file\n    def fake_copy(src: str, dst: str, *args, **kwargs):\n        dst_path = Path(dst)\n        dst_path.parent.mkdir(parents=True, exist_ok=True)\n        dst_path.write_bytes(b\"\")  # touch an empty file\n\n    monkeypatch.setattr(fs, \"copy\", fake_copy)  # Mock actual HDFS copy\n\n    # Test parameters\n    test_cache = tmp_path / \"cache\"\n    hdfs_path = \"hdfs://test/path/file.txt\"\n\n    # Test initial copy\n    local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path))\n    assert local_path == expected_path\n    assert os.path.exists(local_path)\n\n\ndef test_always_recopy_flag(tmp_path, monkeypatch):\n    # Mock HDFS dependencies\n    monkeypatch.setattr(fs, \"is_non_local\", lambda path: True)\n\n    copy_call_count = 0\n\n    def fake_copy(src: str, dst: str, *args, **kwargs):\n        nonlocal copy_call_count\n        copy_call_count += 1\n        dst_path = Path(dst)\n        dst_path.parent.mkdir(parents=True, exist_ok=True)\n        dst_path.write_bytes(b\"\")\n\n    monkeypatch.setattr(fs, \"copy\", fake_copy)  # Mock actual HDFS copy\n\n    test_cache = tmp_path / \"cache\"\n    hdfs_path = \"hdfs://test/path/file.txt\"\n\n    # Initial copy (always_recopy=False)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    assert copy_call_count == 1\n\n    # Force recopy (always_recopy=True)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache, always_recopy=True)\n    assert copy_call_count == 2\n\n    # Subsequent normal call (always_recopy=False)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    assert copy_call_count == 2  # Should not increment\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_groupwise.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nos.environ.setdefault(\"VERL_FORCE_DEVICE\", \"cpu\")  # ensure CPU for tests\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom verl.utils import as_torch_index, group_mean_std\n\n\ndef test_as_torch_index_basic_integers():\n    g = as_torch_index([2, 2, 5, 7, 5, 2])\n    assert g.dtype == torch.long\n    assert g.device.type == \"cpu\"\n    # Values should be contiguous 0..G-1, keeping equal labels equal\n    assert g.tolist()[0] == g.tolist()[1]\n    assert len(torch.unique(g)) == 3  # {2,5,7} -> 3 groups\n\n\ndef test_as_torch_index_near_integer_floats():\n    arr = np.array([1.0000001, 2.0, 1.0, 3.0000000001], dtype=np.float64)\n    g = as_torch_index(arr)  # should round to integers then factorize\n    assert g.dtype == torch.long\n    assert len(torch.unique(g)) == 3  # {1,2,3}\n\n\ndef test_as_torch_index_factorization_mixed():\n    labels = [\"a\", \"b\", \"a\", \"c\", \"0042\", 42]\n    g = as_torch_index(labels)\n    # \"0042\" and 42 should NOT be the same group (strings are not coerced here)\n    assert g.tolist()[4] != g.tolist()[5]\n    assert len(torch.unique(g)) == 5\n\n\ndef test_group_mean_std_simple():\n    # groups: 0 -> [1, 3], 1 -> [2]\n    scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)\n    gidx = as_torch_index([0, 1, 0])\n\n    mean_g, std_g, cnt_g = group_mean_std(scores, gidx)\n    # group 0: mean = (1+3)/2 = 2\n    # sample std (unbiased) = sqrt( (sum(x^2) - (sum(x)^2)/n) / (n-1) )\n    # = sqrt( (1^2+3^2) - (1+3)^2/2 ) / (2-1) = sqrt(10 - 16/2) = sqrt(2)\n    assert torch.allclose(mean_g, torch.tensor([2.0, 0.0]))\n    assert torch.allclose(cnt_g, torch.tensor([2.0, 1.0]))\n    # singleton group -> std = 1.0\n    assert mean_g[1].item() == 0.0\n    assert std_g[1].item() == 1.0\n    assert pytest.approx(std_g[0].item(), rel=1e-6) == (2.0**0.5)\n\n\ndef test_group_mean_std_empty():\n    scores = torch.tensor([], dtype=torch.float32)\n    gidx = torch.tensor([], dtype=torch.long)\n    mean_g, std_g, cnt_g = group_mean_std(scores, gidx)\n    assert mean_g.numel() == 0 and std_g.numel() == 0 and cnt_g.numel() == 0\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_import_utils_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\n\nfrom verl.utils.import_utils import load_extern_type\n\n# Path to the test module\nTEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), \"_test_module.py\")\n\n\ndef test_load_extern_type_class():\n    \"\"\"Test loading a class from an external file\"\"\"\n    TestClass = load_extern_type(TEST_MODULE_PATH, \"TestClass\")\n\n    # Verify the class was loaded correctly\n    assert TestClass is not None\n    assert TestClass.__name__ == \"TestClass\"\n\n    # Test instantiation and functionality\n    instance = TestClass()\n    assert instance.value == \"default\"\n\n    # Test with a custom value\n    custom_instance = TestClass(\"custom\")\n    assert custom_instance.get_value() == \"custom\"\n\n\ndef test_load_extern_type_function():\n    \"\"\"Test loading a function from an external file\"\"\"\n    test_function = load_extern_type(TEST_MODULE_PATH, \"test_function\")\n\n    # Verify the function was loaded correctly\n    assert test_function is not None\n    assert callable(test_function)\n\n    # Test function execution\n    result = test_function()\n    assert result == \"test_function_result\"\n\n\ndef test_load_extern_type_constant():\n    \"\"\"Test loading a constant from an external file\"\"\"\n    constant = load_extern_type(TEST_MODULE_PATH, \"TEST_CONSTANT\")\n\n    # Verify the constant was loaded correctly\n    assert constant is not None\n    assert constant == \"test_constant_value\"\n\n\ndef test_load_extern_type_nonexistent_file():\n    \"\"\"Test behavior when file doesn't exist\"\"\"\n    with pytest.raises(FileNotFoundError):\n        load_extern_type(\"/nonexistent/path.py\", \"SomeType\")\n\n\ndef test_load_extern_type_nonexistent_type():\n    \"\"\"Test behavior when type doesn't exist in the file\"\"\"\n    with pytest.raises(AttributeError):\n        load_extern_type(TEST_MODULE_PATH, \"NonExistentType\")\n\n\ndef test_load_extern_type_none_path():\n    \"\"\"Test behavior when file path is None\"\"\"\n    result = load_extern_type(None, \"SomeType\")\n    assert result is None\n\n\ndef test_load_extern_type_invalid_module():\n    \"\"\"Test behavior when module has syntax errors\"\"\"\n    # Create a temporary file with syntax errors\n    import tempfile\n\n    with tempfile.NamedTemporaryFile(suffix=\".py\", mode=\"w+\", delete=False) as temp_file:\n        temp_file.write(\"This is not valid Python syntax :\")\n        temp_path = temp_file.name\n\n    try:\n        with pytest.raises(RuntimeError):\n            load_extern_type(temp_path, \"SomeType\")\n    finally:\n        # Clean up the temporary file\n        if os.path.exists(temp_path):\n            os.remove(temp_path)\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_linear_cross_entropy.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.experimental.torch_functional import FusedLinearForPPO\nfrom verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nfrom verl.utils.torch_functional import logprobs_from_logits\n\ncompute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)\nfused_linear_for_ppo = FusedLinearForPPO()\nfused_linear_for_ppo.compile(dynamic=True)\n\nMAX_TEST_CASES = os.environ.get(\"MAX_TEST_CASES\", 5)\n\n\ndef run_torch_entropy(\n    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction=\"none\"\n) -> list[torch.Tensor]:\n    hidden = hidden.squeeze(0).to(torch.float32)\n    weight = weight.transpose(0, 1).to(torch.float32)\n    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]\n    logits /= temperature\n    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]\n    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]\n    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]\n    entropy = entropy_a - entropy_b\n    logprobs = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)  # [num_tokens]\n    logprobs = torch.neg(logprobs)\n    return logprobs, entropy\n\n\ndef run_verl_original_entropy(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    temperature: float,\n) -> list[torch.Tensor]:\n    hidden = hidden.squeeze(0).to(torch.float32)\n    weight = weight.transpose(0, 1).to(torch.float32)\n    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]\n    logits /= temperature\n    # compute entropy\n    entropy = compute_entropy_from_logits(logits)  # ((total_nnz / sp) + pad)\n    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n    logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False)\n    return logprobs, entropy\n\n\n# To be tested\ndef run_verl_torch_fused_entropy(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    temperature: float,\n):\n    hidden = hidden.to(torch.float32)\n    weight = 
weight.to(torch.float32)\n    logprobs, entropy = fused_linear_for_ppo(\n        hidden,\n        weight,\n        labels,\n        temperature=temperature,\n    )\n    return logprobs.squeeze(0), entropy.squeeze(0)\n\n\nclass TestLinearCrossEntropy:\n    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:\n        self.test_case_idx = test_case_idx\n        self.temperature = temperature\n\n    def cleanup(self):\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        import gc\n\n        gc.collect()\n        torch.cuda.synchronize()\n\n    def generate_hyper(self):\n        global MAX_TEST_CASES\n\n        self.dtype = torch.bfloat16\n        if self.test_case_idx == 0:\n            self.batch_size = 1\n            self.num_tokens = 1937\n            self.hidden_size = 3584\n            self.vocab_size = 152064\n        elif self.test_case_idx == 1:\n            self.batch_size = 1\n            self.num_tokens = 2169\n            self.hidden_size = 896\n            self.vocab_size = 151936\n        elif self.test_case_idx == 2:\n            self.batch_size = 1\n            self.num_tokens = 1530\n            self.hidden_size = 2048\n            self.vocab_size = 32256\n        elif self.test_case_idx == 3:\n            self.batch_size = 1\n            self.num_tokens = 1388\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        elif self.test_case_idx == 4:\n            self.batch_size = 1\n            self.num_tokens = 8192\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        else:\n            raise ValueError(f\"Invalid test case index: {self.test_case_idx}\")\n        assert MAX_TEST_CASES <= 5, \"MAX_TEST_CASES should be less than or equal to 5.\"\n\n    def generate_forward_inputs(self):\n        hidden = (\n            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        weight = (\n            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device=\"cuda\")\n        return hidden, weight, labels\n\n    def generate_backward_inputs(self):\n        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-0.5, 0.5)\n        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-1, 1)\n        return g_entropy, g_logprobs\n\n    def verify_correctness(self, iterations=5):\n        self.cleanup()\n        self.generate_hyper()\n\n        torch_forward_latency = list()\n        torch_backward_latency = list()\n        verl_forward_latency = list()\n        verl_backward_latency = list()\n        verl_fused_forward_latency = list()\n        verl_fused_backward_latency = list()\n        kernel_forward_latency = list()\n        kernel_backward_latency = list()\n\n        start_event = torch.cuda.Event(enable_timing=True)\n        end_event = torch.cuda.Event(enable_timing=True)\n\n        for i in range(iterations):\n            print(f\"[INFO]: Iteration {i + 1} / {iterations}...\", end=\"\\r\")\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            start_event.record()\n            (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, 
labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(\n                hidden, weight, labels, self.temperature\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_fused_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_forward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)\n\n            torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)\n\n            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n            torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n\n            # backward\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n\n            start_event.record()\n            (d_torch_hidden, d_torch_weight) = torch.autograd.grad(\n                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_backward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (d_verl_hidden, d_verl_weight) = torch.autograd.grad(\n                (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_backward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad(\n                (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_fused_backward_latency.append(start_event.elapsed_time(end_event))\n\n            
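# kernel backward: time torch.autograd.grad through the custom linear_cross_entropy kernel\n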
            start_event.record()\n            (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(\n                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_backward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)\n\n            torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)\n\n            torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n\n        # drop the first (warmup) measurement\n        torch_forward_latency = torch_forward_latency[1:]\n        torch_backward_latency = torch_backward_latency[1:]\n        verl_forward_latency = verl_forward_latency[1:]\n        verl_backward_latency = verl_backward_latency[1:]\n        verl_fused_forward_latency = verl_fused_forward_latency[1:]\n        verl_fused_backward_latency = verl_fused_backward_latency[1:]\n        kernel_forward_latency = kernel_forward_latency[1:]\n        kernel_backward_latency = kernel_backward_latency[1:]\n\n        print(\"\\n[INFO]: Verified forward & backward correctness.\")\n\n        print(\n            f\"[INFO]: Forward pass: Torch implementation average time: \"\n            f\"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: Torch implementation average time: \"\n            f\"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: VeRL implementation average time: \"\n            f\"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: VeRL implementation average time: \"\n            f\"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: \"\n            f\"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: \"\n            f\"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: 
Kernel implementation average time: \"\n            f\"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: kernel implementation average time: \"\n            f\"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms\"\n        )\n\n    def check_storage(self, method_name, run_forward):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        torch.cuda.reset_peak_memory_stats()\n        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)\n        torch.cuda.synchronize()\n        torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        print(f\"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB\")\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_torch_hidden, d_torch_weight) = torch.autograd.grad(\n            (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        print(f\"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB\")\n\n    def check_storage_all(self):\n        self.check_storage(\"Torch\", run_torch_entropy)\n        self.check_storage(\"VeRL\", run_verl_original_entropy)\n        self.check_storage(\"VeRL Torch Fused\", run_verl_torch_fused_entropy)\n        self.check_storage(\"Kernel\", linear_cross_entropy)\n\n\nif __name__ == \"__main__\":\n    # torch.cuda.memory._record_memory_history()\n\n    for test_case_idx in range(MAX_TEST_CASES):\n        print(f\"[INFO] Running test case {test_case_idx}\")\n        test = TestLinearCrossEntropy(test_case_idx)\n\n        test.verify_correctness()\n        test.check_storage_all()\n\n    # torch.cuda.memory._dump_snapshot(\"test_linear_cross_entropy.pkl\")\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_mlflow_key_sanitization.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom unittest.mock import patch\n\nfrom verl.utils.tracking import _MlflowLoggingAdapter\n\n\nclass TestMlflowLoggingAdapter(unittest.TestCase):\n    def test_sanitize_key_and_warning(self):\n        adapter = _MlflowLoggingAdapter()\n        data = {\"valid_key\": 1.0, \"invalid@key!\": 2.0, \"another/valid-key\": 3.0, \"bad key#\": 4.0}\n        # Patch mlflow.log_metrics to capture the metrics actually sent\n        with (\n            patch(\"mlflow.log_metrics\") as mock_log_metrics,\n            patch.object(adapter, \"logger\") as mock_logger,\n        ):\n            adapter.log(data, step=5)\n            # Check that keys are sanitized\n            sent_metrics = mock_log_metrics.call_args[1][\"metrics\"]\n            self.assertIn(\"invalid_at_key_\", sent_metrics)  # @ becomes _at_, ! becomes _\n            self.assertIn(\"bad key_\", sent_metrics)  # # becomes _, space remains\n            self.assertNotIn(\"invalid@key!\", sent_metrics)\n            self.assertNotIn(\"bad key#\", sent_metrics)\n            # Check that a warning was logged for each sanitized key\n            warning_msgs = [str(call) for call in mock_logger.warning.call_args_list]\n            self.assertTrue(any(\"invalid@key!\" in msg and \"invalid_at_key_\" in msg for msg in warning_msgs))\n            self.assertTrue(any(\"bad key#\" in msg and \"bad key_\" in msg for msg in warning_msgs))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_model_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom types import SimpleNamespace  # Or use a mock object library\n\nimport pytest\n\nfrom verl.utils.model import update_model_config\n\n\n# Parametrize with different override scenarios\n@pytest.mark.parametrize(\n    \"override_kwargs\",\n    [\n        {\"param_a\": 5, \"new_param\": \"plain_added\"},\n        {\"param_a\": 2, \"nested_params\": {\"sub_param_x\": \"updated_x\", \"sub_param_z\": True}},\n    ],\n)\ndef test_update_model_config(override_kwargs):\n    \"\"\"\n    Tests that update_model_config correctly updates attributes,\n    handling both plain and nested overrides via parametrization.\n    \"\"\"\n    # Create a fresh mock config object for each test case\n    mock_config = SimpleNamespace(\n        param_a=1, nested_params=SimpleNamespace(sub_param_x=\"original_x\", sub_param_y=100), other_param=\"keep_me\"\n    )\n    # Apply the updates using the parametrized override_kwargs\n    update_model_config(mock_config, override_kwargs)\n\n    # Assertions to check if the config was updated correctly\n    if \"nested_params\" in override_kwargs:  # Case 2: Nested override\n        override_nested = override_kwargs[\"nested_params\"]\n        assert mock_config.nested_params.sub_param_x == override_nested[\"sub_param_x\"], \"Nested sub_param_x mismatch\"\n        assert mock_config.nested_params.sub_param_y == 100, \"Nested sub_param_y should be unchanged\"\n        assert hasattr(mock_config.nested_params, \"sub_param_z\"), \"Expected nested sub_param_z to be added\"\n        assert mock_config.nested_params.sub_param_z == override_nested[\"sub_param_z\"], \"Value of sub_param_z mismatch\"\n    else:  # Case 1: Plain override (nested params untouched)\n        assert mock_config.nested_params.sub_param_x == \"original_x\", \"Nested sub_param_x should be unchanged\"\n        assert mock_config.nested_params.sub_param_y == 100, \"Nested sub_param_y should be unchanged\"\n        assert not hasattr(mock_config.nested_params, \"sub_param_z\"), \"Nested sub_param_z should not exist\"\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_nvtx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nfrom verl.utils import omega_conf_to_dataclass\nfrom verl.utils.profiler.config import NsightToolConfig, ProfilerConfig\nfrom verl.utils.profiler.nvtx_profile import NsightSystemsProfiler\n\n\nclass TestProfilerConfig(unittest.TestCase):\n    def test_config_init(self):\n        import os\n\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n            cfg = compose(config_name=\"ppo_trainer\")\n        for config in [\n            cfg.actor_rollout_ref.actor.profiler,\n            cfg.actor_rollout_ref.rollout.profiler,\n            cfg.actor_rollout_ref.ref.profiler,\n            cfg.critic.profiler,\n            cfg.reward_model.profiler,\n        ]:\n            profiler_config = omega_conf_to_dataclass(config)\n            self.assertEqual(profiler_config.tool, config.tool)\n            self.assertEqual(profiler_config.enable, config.enable)\n            self.assertEqual(profiler_config.all_ranks, config.all_ranks)\n            self.assertEqual(profiler_config.ranks, config.ranks)\n            self.assertEqual(profiler_config.save_path, config.save_path)\n            self.assertEqual(profiler_config.ranks, config.ranks)\n            assert isinstance(profiler_config, ProfilerConfig)\n            with self.assertRaises(AttributeError):\n                _ = profiler_config.non_existing_key\n            assert config.get(\"non_existing_key\") == profiler_config.get(\"non_existing_key\")\n            assert config.get(\"non_existing_key\", 1) == profiler_config.get(\"non_existing_key\", 1)\n\n    def test_frozen_config(self):\n        \"\"\"Test that modifying frozen keys in ProfilerConfig raises exceptions.\"\"\"\n        from dataclasses import FrozenInstanceError\n\n        from verl.utils.profiler.config import ProfilerConfig\n\n        # Create a new ProfilerConfig instance\n        config = ProfilerConfig(all_ranks=False, ranks=[0])\n\n        with self.assertRaises(FrozenInstanceError):\n            config.all_ranks = True\n\n        with self.assertRaises(FrozenInstanceError):\n            config.ranks = [1, 2, 3]\n\n        with self.assertRaises(TypeError):\n            config[\"all_ranks\"] = True\n\n        with self.assertRaises(TypeError):\n            config[\"ranks\"] = [1, 2, 3]\n\n\nclass TestNsightSystemsProfiler(unittest.TestCase):\n    \"\"\"Test suite for NsightSystemsProfiler functionality.\n\n    Test Plan:\n    1. Initialization: Verify profiler state after creation\n    2. Basic Profiling: Test start/stop functionality\n    3. Discrete Mode: TODO: Test discrete profiling behavior\n    4. Annotation: Test the annotate decorator in both normal and discrete modes\n    5. 
Config Validation: Verify proper config initialization from OmegaConf\n    \"\"\"\n\n    def setUp(self):\n        self.config = ProfilerConfig(enable=True, all_ranks=True)\n        self.rank = 0\n        self.profiler = NsightSystemsProfiler(self.rank, self.config, tool_config=NsightToolConfig(discrete=False))\n\n    def test_initialization(self):\n        self.assertEqual(self.profiler.this_rank, True)\n        self.assertEqual(self.profiler.this_step, False)\n\n    def test_start_stop_profiling(self):\n        with patch(\"torch.cuda.profiler.start\") as mock_start, patch(\"torch.cuda.profiler.stop\") as mock_stop:\n            # Test start\n            self.profiler.start()\n            self.assertTrue(self.profiler.this_step)\n            mock_start.assert_called_once()\n\n            # Test stop\n            self.profiler.stop()\n            self.assertFalse(self.profiler.this_step)\n            mock_stop.assert_called_once()\n\n    # def test_discrete_profiling(self):\n    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)\n    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)\n\n    #     with patch(\"torch.cuda.profiler.start\") as mock_start, patch(\"torch.cuda.profiler.stop\") as mock_stop:\n    #         profiler.start()\n    #         self.assertTrue(profiler.this_step)\n    #         mock_start.assert_not_called()  # Shouldn't start immediately in discrete mode\n\n    #         profiler.stop()\n    #         self.assertFalse(profiler.this_step)\n    #         mock_stop.assert_not_called()  # Shouldn't stop immediately in discrete mode\n\n    def test_annotate_decorator(self):\n        mock_self = MagicMock()\n        mock_self.profiler = self.profiler\n        mock_self.profiler.this_step = True\n        decorator = mock_self.profiler.annotate(message=\"test\")\n\n        @decorator\n        def test_func(self, *args, **kwargs):\n            return \"result\"\n\n        with (\n            patch(\"torch.cuda.profiler.start\") as mock_start,\n            patch(\"torch.cuda.profiler.stop\") as mock_stop,\n            patch(\"verl.utils.profiler.nvtx_profile.mark_start_range\") as mock_start_range,\n            patch(\"verl.utils.profiler.nvtx_profile.mark_end_range\") as mock_end_range,\n        ):\n            result = test_func(mock_self)\n            self.assertEqual(result, \"result\")\n            mock_start_range.assert_called_once()\n            mock_end_range.assert_called_once()\n            mock_start.assert_not_called()  # Not discrete mode\n            mock_stop.assert_not_called()  # Not discrete mode\n\n    # def test_annotate_discrete_mode(self):\n    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)\n    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)\n    #     mock_self = MagicMock()\n    #     mock_self.profiler = profiler\n    #     mock_self.profiler.this_step = True\n\n    #     @NsightSystemsProfiler.annotate(message=\"test\")\n    #     def test_func(self, *args, **kwargs):\n    #         return \"result\"\n\n    #     with (\n    #         patch(\"torch.cuda.profiler.start\") as mock_start,\n    #         patch(\"torch.cuda.profiler.stop\") as mock_stop,\n    #         patch(\"verl.utils.profiler.nvtx_profile.mark_start_range\") as mock_start_range,\n    #         patch(\"verl.utils.profiler.nvtx_profile.mark_end_range\") as mock_end_range,\n    #     ):\n    #         result = test_func(mock_self)\n    #         self.assertEqual(result, \"result\")\n    #         
mock_start_range.assert_called_once()\n    #         mock_end_range.assert_called_once()\n    #         mock_start.assert_called_once()  # Should start in discrete mode\n    #         mock_stop.assert_called_once()  # Should stop in discrete mode\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_rollout_skip_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport shutil\nimport tempfile\nfrom pathlib import Path\nfrom unittest.mock import MagicMock\n\nimport pytest\nimport torch\n\nfrom verl.utils.rollout_skip import DataProto, RolloutSkip\n\nlen_prompt = 50\nlen_response = 100\n\n\ndef temp_dir():\n    # Create a temporary directory\n    temp_dir = Path(tempfile.mkdtemp())\n    yield temp_dir\n    # Cleanup\n    shutil.rmtree(temp_dir)\n\n\ndef build_generate_fn(gen_bs, n):\n    len_tokenizer = 1024\n\n    def iterate():\n        while True:\n            prompt = torch.randint(len_tokenizer, size=(gen_bs, len_prompt)).repeat_interleave(n, dim=0)\n            generate = torch.randint(len_tokenizer, size=(gen_bs * n, len_response))\n            data = DataProto.from_dict(tensors={\"prompt\": prompt, \"response\": generate})\n            yield data\n\n    mock_infer_engine = iterate()\n\n    def fn(batch, **kwargs):\n        # Simulate the inference engine returning the next batch\n        return next(mock_infer_engine)\n\n    return fn\n\n\n@pytest.fixture(params=[(32, 4), (64, 4), (64, 8)])\ndef mock_rollout_wg(request):\n    gen_bs, n = request.param\n    rollout_wg = MagicMock()\n\n    config = MagicMock()\n    config.actor_rollout_ref.rollout = {\n        \"n\": n,\n        \"skip_dump_dir\": next(temp_dir()),\n    }\n    config.data = {\"gen_batch_size\": gen_bs}\n\n    rollout_wg.generate_sequences = build_generate_fn(gen_bs, n)\n\n    yield config, rollout_wg\n    # Cleanup\n    shutil.rmtree(next(temp_dir()))\n\n\nclass TestRolloutSkip:\n    def test_initialization(self, capsys):\n        \"\"\"Test that RolloutSkip initializes correctly\"\"\"\n        config = MagicMock()\n        config.actor_rollout_ref.rollout = {\n            \"n\": 16,\n            \"skip_dump_dir\": \"tmp/rollout_dump\",\n        }\n        config.data = {\"gen_batch_size\": 128}\n        mock_rollout_wg = MagicMock()\n        skip = RolloutSkip(config, mock_rollout_wg)\n\n        assert skip.n == 16\n        assert skip.gbs == 128\n        assert str(skip.dumped_dir) == \"tmp/rollout_dump\"\n\n        assert skip._rollout_wg == mock_rollout_wg\n        skip.wrap_generate_sequences()\n        captured = capsys.readouterr()\n        assert \"Successfully patched\" in captured.out\n\n    def test_generate_without_wrap(self, mock_rollout_wg):\n        \"\"\"Test that generate_sequences works without wrapping\"\"\"\n\n        config, rollout_wg = mock_rollout_wg\n        _ = RolloutSkip(config, rollout_wg)\n\n        _result = rollout_wg.generate_sequences(MagicMock())\n        for _ in range(10):\n            result = rollout_wg.generate_sequences(MagicMock())\n            assert isinstance(result, DataProto)\n            # * make sure the data is different\n            assert torch.abs(_result.batch[\"prompt\"] - result.batch[\"prompt\"]).sum() > 0\n            assert torch.abs(_result.batch[\"response\"] - result.batch[\"response\"]).sum() > 
    def test_dump(self, mock_rollout_wg, capsys):\n        config, rollout_wg = mock_rollout_wg\n        skip = RolloutSkip(config, rollout_wg)\n        skip.wrap_generate_sequences()\n\n        result = rollout_wg.generate_sequences(MagicMock())\n        # * check if dump is OK\n        assert skip.curr_path_dump.exists()\n        captured = capsys.readouterr()\n        assert \"Successfully dump data in\" in captured.out\n        # * compare actual file size against a lower-bound estimate\n        file_size = skip.curr_path_dump.stat().st_size\n        est_file_size = (len_prompt + len_response) * skip.gbs * skip.n * result.batch[\"prompt\"].dtype.itemsize\n        assert file_size >= est_file_size, \"Dumped file size is smaller than expected\"\n\n    def test_generate_with_wrap(self, mock_rollout_wg, capsys):\n        \"\"\"Test that generate_sequences replays the dumped data when wrapped\"\"\"\n\n        config, rollout_wg = mock_rollout_wg\n        skip = RolloutSkip(config, rollout_wg)\n        skip.wrap_generate_sequences()\n\n        _result = rollout_wg.generate_sequences(MagicMock())\n\n        for _ in range(10):\n            result = rollout_wg.generate_sequences(MagicMock())\n            assert isinstance(result, DataProto)\n            # * make sure the data is identical to the previous (replayed) batch\n            assert torch.abs(_result.batch[\"prompt\"] - result.batch[\"prompt\"]).sum() == 0\n            assert torch.abs(_result.batch[\"response\"] - result.batch[\"response\"]).sum() == 0\n            captured = capsys.readouterr()\n            assert \"Successfully load pre-generated data from\" in captured.out\n            _result = result\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_rollout_trace_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op\n\n\n@pytest.fixture(autouse=True)\ndef reset_rollout_trace_config_singleton():\n    \"\"\"Fixture to reset the RolloutTraceConfig singleton before each test.\"\"\"\n    RolloutTraceConfig.reset()\n\n\n@pytest.fixture\ndef mock_weave_client():\n    \"\"\"Mocks the weave module and its client, yielding the mock client.\"\"\"\n    mock_weave = MagicMock()\n    mock_client = MagicMock()\n    mock_call = MagicMock()\n    mock_client.create_call.return_value = mock_call\n    mock_weave.init.return_value = mock_client\n\n    # Also mock the call_context if it's used internally by the decorator\n    mock_weave.trace.context.call_context.return_value = MagicMock()\n\n    with patch.dict(sys.modules, {\"weave\": mock_weave, \"weave.trace.context\": mock_weave.trace.context}):\n        yield mock_client\n\n\nclass TracedClass:\n    @rollout_trace_op\n    # @weave.op\n    # @mlflow.trace\n    async def my_method(self, a, b=\"default\"):\n        return f\"result: {a}, {b}\"\n\n    @rollout_trace_op\n    # @weave.op\n    # @mlflow.trace\n    async def middle_method(self, a, b=\"default\"):\n        await self.my_method(\"test_a1\", b=\"test_b1\")\n        return f\"result: {a}, {b}\"\n\n    @rollout_trace_op\n    # @mlflow.trace\n    async def my_method_with_exception(self):\n        raise ValueError(\"Test Exception\")\n\n    async def upper_method(self):\n        await self.my_method(\"test_a0\", b=\"test_b0\")\n        await self.middle_method(\"test_a2\", b=\"test_b2\")\n        return True\n\n\nclass UntracedClass:\n    @rollout_trace_op\n    async def my_method(self, x):\n        return x * 2\n\n\nasync def test_rollout_trace_on_untraced_class():\n    \"\"\"Tests that the decorator works correctly when no backend is configured.\"\"\"\n    instance = UntracedClass()\n    assert await instance.my_method(10) == 20\n\n\nasync def test_rollout_trace_with_tracer(mock_weave_client):\n    \"\"\"Tests that the decorator calls the tracer's methods correctly.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n    instance = TracedClass()\n    assert RolloutTraceConfig.get_client() is mock_weave_client\n\n    result = await instance.my_method(\"test_a\", b=\"test_b\")\n\n    assert result == \"result: test_a, test_b\"\n    mock_weave_client.create_call.assert_called_once()\n    call_kwargs = mock_weave_client.create_call.call_args.kwargs\n    assert call_kwargs[\"op\"] == \"TracedClass.my_method\"\n    expected_inputs = {\"a\": \"test_a\", \"b\": \"test_b\"}\n    assert call_kwargs[\"inputs\"] == expected_inputs\n\n    mock_call = mock_weave_client.create_call.return_value\n    mock_weave_client.finish_call.assert_called_once_with(mock_call, 
output=result)\n\n\nasync def test_rollout_trace_with_exception(mock_weave_client):\n    \"\"\"Tests that `finish_call` is called with the exception when one is raised.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n    instance = TracedClass()\n\n    with pytest.raises(ValueError, match=\"Test Exception\"):\n        await instance.my_method_with_exception()\n\n    mock_weave_client.create_call.assert_called_once()\n    mock_call = mock_weave_client.create_call.return_value\n    mock_weave_client.finish_call.assert_called_once()\n\n    # Check that finish_call was called with the exception\n    args, kwargs = mock_weave_client.finish_call.call_args\n    assert args[0] == mock_call\n    assert \"exception\" in kwargs\n    assert isinstance(kwargs[\"exception\"], ValueError)\n\n\nasync def test_rollout_trace_with_dummy_backend(mock_weave_client):\n    \"\"\"Tests that the tracer is not called when the backend is 'dummy'.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"dummy\")\n    instance = TracedClass()\n\n    await instance.my_method(\"test_a\")\n\n    mock_weave_client.create_call.assert_not_called()\n\n\n@pytest.mark.skipif(\n    os.environ.get(\"RUN_WEAVE_INTEGRATION_TESTS\", \"false\").lower() != \"true\",\n    reason=\"Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.\",\n)\nasync def test_rollout_trace_with_real_weave_backend():\n    \"\"\"Integration test with a real weave backend.\"\"\"\n\n    # This assumes that the weave environment (e.g., project) is configured\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n\n    instance = TracedClass()\n\n    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):\n        await instance.upper_method()\n\n    with pytest.raises(ValueError, match=\"Test Exception\"):\n        await instance.my_method_with_exception()\n\n    print(\"\\nWeave integration test ran successfully. Check your weave project for the trace.\")\n\n\n@pytest.mark.skipif(\n    os.environ.get(\"RUN_MLFLOW_INTEGRATION_TESTS\", \"false\").lower() != \"true\",\n    reason=\"Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.\",\n)\nasync def test_rollout_trace_with_real_mlflow_backend():\n    \"\"\"Integration test with a real mlflow backend.\"\"\"\n\n    # This assumes that the mlflow environment (e.g., tracking URI) is configured\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"mlflow\")\n\n    instance = TracedClass()\n\n    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name=\"agent_run\"):\n        assert await instance.upper_method()\n\n    # with pytest.raises(ValueError, match=\"Test Exception\"):\n    #     await instance.my_method_with_exception()\n\n    print(\"\\nMlflow integration test ran successfully. Check your mlflow experiment for the trace.\")\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_seqlen_balancing.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom verl import DataProto\nfrom verl.utils.model import create_random_mask\nfrom verl.utils.seqlen_balancing import (\n    ceildiv,\n    get_reverse_idx,\n    prepare_dynamic_batch,\n    rearrange_micro_batches,\n    restore_dynamic_batch,\n)\n\n\ndef test_seqlen_balancing():\n    input_ids = torch.randint(low=0, high=10, size=(20, 100))\n\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5\n    )\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n    micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)\n    batch = torch.cat(micro_batches)\n    micro_bsz_idx = []\n    for idx in micro_bsz_idx_lst:\n        micro_bsz_idx.extend(idx)\n    reverse_idx_map = get_reverse_idx(micro_bsz_idx)\n    reverse_idx_map = torch.tensor(reverse_idx_map)\n    new_batch = batch[reverse_idx_map]\n    torch.testing.assert_close(new_batch, dataproto.batch)\n\n\ndef test_dynamic_batch():\n    input_ids = torch.randint(low=0, high=10, size=(20, 100))\n\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5\n    )\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n    micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300)\n    input_ids = torch.cat([micro_batch.batch[\"input_ids\"] for micro_batch in micro_batches], dim=0)\n    input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst)\n    torch.testing.assert_close(input_ids, dataproto.batch[\"input_ids\"])\n\n\ndef _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):\n    # 1) init process group & CUDA\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    # 2) build a small random batch (each rank different length to force mismatch)\n    torch.manual_seed(42 + rank)\n    input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f\"cuda:{rank}\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids,\n        max_ratio_of_left_padding=0.1,\n        max_ratio_of_valid_token=0.9,\n        min_ratio_of_valid_token=0.5,\n    )\n    dp = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    proto = DataProto.from_single_dict(dp)\n    batch = proto.batch\n\n    # 3) call rearrange_micro_batches with one of the two params under test\n    micros, idx_lst = rearrange_micro_batches(\n        batch,\n        
max_token_len=max_token_len,\n        dp_group=dist.group.WORLD,\n        same_micro_num_in_dp=use_same_dp,\n        min_num_micro_batch=min_mb,\n    )\n\n    # 4) check the enforced counts\n    seq_len_effective: torch.Tensor = batch[\"attention_mask\"].sum(dim=1)\n    total_seqlen = seq_len_effective.sum().item()\n    # natural (unconstrained) micro-batch count on this rank\n    local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))\n\n    if min_mb is not None:\n        expected = max(local, min_mb)\n        assert len(micros) == expected\n    elif use_same_dp:\n        # gather all local_counts\n        counts = [torch.zeros(1, device=f\"cuda:{rank}\") for _ in range(world_size)]\n        counts[rank].fill_(local)\n        dist.all_gather(counts, counts[rank])\n        expected = max(int(c.item()) for c in counts)\n        assert len(micros) == expected\n    else:\n        # if neither, we get the local natural count\n        assert len(micros) == local\n\n    # 5) reconstruction sanity: concat→reverse_idx→orig\n    flat = torch.cat(micros, dim=0)\n    idx = []\n    for sub in idx_lst:\n        idx.extend(sub)\n    inv = get_reverse_idx(idx)\n    inv = torch.tensor(inv, device=flat.device)\n    reconstructed = flat[inv]\n    torch.testing.assert_close(reconstructed, batch)\n\n    dist.destroy_process_group()\n\n\ndef test_dataproto_split_uneven():\n    \"\"\"Test DataProto.split with uneven splits\"\"\"\n    # Create test data with 10 items\n    input_ids = torch.randint(low=0, high=10, size=(10, 5))\n    attention_mask = torch.ones(10, 5)\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n\n    # Test split with size 3 (should create chunks of [3, 3, 3, 1])\n    splits = dataproto.split(3)\n    assert len(splits) == 4\n    assert len(splits[0]) == 3\n    assert len(splits[1]) == 3\n    assert len(splits[2]) == 3\n    assert len(splits[3]) == 1\n\n    reconstructed = DataProto.concat(splits)\n    torch.testing.assert_close(reconstructed.batch[\"input_ids\"], dataproto.batch[\"input_ids\"])\n    torch.testing.assert_close(reconstructed.batch[\"attention_mask\"], dataproto.batch[\"attention_mask\"])\n\n    # Test split with size equal to length (should create one chunk)\n    splits = dataproto.split(10)\n    assert len(splits) == 1\n    assert len(splits[0]) == 10\n\n    # Test split with size larger than length (should create one chunk with all data)\n    splits = dataproto.split(15)\n    assert len(splits) == 1\n    assert len(splits[0]) == 10\n\n    # Test with non-tensor batch data\n    import numpy as np\n\n    data_with_non_tensor = {\n        \"input_ids\": input_ids,\n        \"attention_mask\": attention_mask,\n        \"labels\": np.array([f\"label_{i}\" for i in range(10)], dtype=object),\n    }\n    dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor)\n\n    splits = dataproto_with_non_tensor.split(3)\n    assert len(splits) == 4\n    assert len(splits[0]) == 3\n    assert len(splits[1]) == 3\n    assert len(splits[2]) == 3\n    assert len(splits[3]) == 1\n\n    # Verify non-tensor data integrity\n    reconstructed = DataProto.concat(splits)\n    np.testing.assert_array_equal(\n        reconstructed.non_tensor_batch[\"labels\"], dataproto_with_non_tensor.non_tensor_batch[\"labels\"]\n    )\n\n\ndef test_seqlen_balancing_distributed_params(tmp_path):\n    world_size = 2\n    init_file = tmp_path / \"dist_init\"\n    init_file.write_text(\"\")  # empty file\n    init_method = f\"file://{init_file}\"\n\n    # test min_num_micro_batch only\n    mp.spawn(\n        _worker,\n        args=(world_size, init_method, 300, False, 4),\n        nprocs=world_size,\n        join=True,\n    )\n\n    # test same_micro_num_in_dp only\n    mp.spawn(\n        _worker,\n        args=(world_size, init_method, 300, True, None),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_special_linear_cross_entropy_tp.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\nimport torch.distributed as dist\n\ntry:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nexcept ImportError:\n    # FIXME: remove these manually included paths\n    import sys\n\n    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), \"../../\")))\nfinally:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\nimport verl.utils.torch_functional as verl_F\n\ncompute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)\n\nMAX_TEST_CASES = os.environ.get(\"MAX_TEST_CASES\", 5)\nVERIFY_TORCH_SELF = os.environ.get(\"VERIFY_TORCH_SELF\", False)\nLOW_MEMORY = os.environ.get(\"LOW_MEMORY\", False)\nLOW_MEMORY_DIV_FACTOR = os.environ.get(\"LOW_MEMORY_DIV_FACTOR\", 16)\n\n\ndef run_torch_entropy(\n    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction=\"none\"\n) -> list[torch.Tensor]:\n    # [num_tokens, vocab_size]\n    if len(hidden.shape) > 2:\n        hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]\n    if len(labels.shape) > 1:\n        labels = labels.view(-1)\n    logits = torch.matmul(\n        hidden.to(torch.float32),\n        weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32),\n    )\n    logits /= temperature\n    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]\n    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]\n    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]\n    entropy = entropy_a - entropy_b\n    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]\n    logprobs = torch.neg(logprobs)\n    return logprobs, entropy\n\n\nclass TorchEntropyTP(torch.autograd.Function):\n    \"\"\"\n    it is used for testing the correctness of the kernel\n    it is not efficient and is not recommended to use in practice\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        hidden: torch.Tensor,\n        weight: torch.Tensor,\n        labels: torch.Tensor,\n        temperature: float,\n        dist_process_group: 
torch.distributed.ProcessGroup,\n    ):\n        # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size]\n        ctx.original_hidden_shape = hidden.shape\n        if len(hidden.shape) > 2:\n            hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]\n        if len(labels.shape) > 1:\n            labels = labels.view(-1)\n\n        logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T)  # [num_tokens, vocab_size]\n        logits /= temperature\n        whole_logits = torch.empty(\n            (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),\n            dtype=logits.dtype,\n            device=logits.device,\n        )\n        whole_logits_ref = [\n            whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]\n            for i in range(dist.get_world_size(dist_process_group))\n        ]\n        dist.all_gather(whole_logits_ref, logits, group=dist_process_group)\n\n        pd = torch.nn.functional.softmax(whole_logits, dim=-1)\n        entropy_a = torch.logsumexp(whole_logits, dim=-1)  # [num_tokens]\n        entropy_b = torch.sum(pd * whole_logits, dim=-1)  # [num_tokens]\n        entropy = entropy_a - entropy_b\n\n        logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction=\"none\")\n        logprobs = torch.neg(logprobs)\n\n        ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)\n        ctx.dist_process_group = dist_process_group\n        ctx.temperature = temperature\n        return logprobs, entropy\n\n    @staticmethod\n    def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):\n        hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors\n        dist_process_group = ctx.dist_process_group\n        temperature = ctx.temperature\n        batch_size, hidden_size = hidden.shape\n        vocab_size, hidden_size = weight.shape\n        rank = dist.get_rank(dist_process_group)\n\n        # Compute softmax probabilities\n        maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)\n        exp_logits = torch.exp(whole_logits - maximum)\n        accumulate = exp_logits.sum(dim=-1, keepdim=True)\n        pd = exp_logits / accumulate\n\n        # Gradient for entropy\n        # entropy = entropy_a - entropy_b\n        # entropy_a = log(sum(exp(logits)))\n        # entropy_b = sum(pd * logits)\n        # d_entropy_a/d_logits = pd\n        # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1)\n        # d_entropy/d_logits = d_entropy_a - d_entropy_b\n        # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1)\n        # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1))\n        d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))\n\n        # Gradient for logprobs\n        # logprobs = -cross_entropy = -log(pd[labels])\n        # d_logprobs/d_logits = (pd - one_hot(labels))\n        one_hot = torch.zeros_like(whole_logits)\n        one_hot.scatter_(1, labels.unsqueeze(1), 1)\n        g_logprobs = torch.neg(g_logprobs)\n        d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)\n        # NOTE: This will lead to wrong result\n        # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot\n\n        # Combine gradients\n        d_logits = d_logits_entropy + d_logits_logprobs\n        d_logits /= temperature\n\n        # Get local slice of gradients\n        local_d_logits = d_logits[:, rank * vocab_size : (rank 
+ 1) * vocab_size]\n\n        # Compute gradients for hidden and weight\n        d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))\n        d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))\n        d_hidden = d_hidden.view(ctx.original_hidden_shape)\n\n        return d_hidden, d_weight, None, None, None\n\n\nrun_torch_entropy_tp = TorchEntropyTP.apply\n\n\nclass TestLinearCrossEntropy_TensorParallel:\n    def __init__(self):\n        dist.init_process_group(backend=\"nccl\")\n        self.group = dist.group.WORLD\n\n        self.local_rank = dist.get_rank(self.group)\n        self.world_size = dist.get_world_size(self.group)\n        device = torch.device(f\"cuda:{self.local_rank}\")\n        torch.cuda.set_device(device)\n        print(f\"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}\")\n\n    def initialize(self, test_case_idx: int, temperature: float = 1.5):\n        self.test_case_idx = test_case_idx\n        self.temperature = temperature\n\n    def shutdown(self):\n        dist.destroy_process_group()\n\n    def cleanup(self):\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        import gc\n\n        gc.collect()\n        torch.cuda.synchronize()\n\n    def generate_hyper(self):\n        global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES\n\n        self.dtype = torch.bfloat16\n        if self.test_case_idx == 0:\n            self.batch_size = 1\n            self.num_tokens = 1937\n            self.hidden_size = 3584\n            self.vocab_size = 152064\n        elif self.test_case_idx == 1:\n            self.batch_size = 1\n            self.num_tokens = 2169\n            self.hidden_size = 896\n            self.vocab_size = 151936\n        elif self.test_case_idx == 2:\n            self.batch_size = 1\n            self.num_tokens = 1530\n            self.hidden_size = 2048\n            self.vocab_size = 32256\n        elif self.test_case_idx == 3:\n            self.batch_size = 1\n            self.num_tokens = 1388\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        elif self.test_case_idx == 4:\n            self.batch_size = 1\n            self.num_tokens = 8192\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        else:\n            raise ValueError(f\"Invalid test case index: {self.test_case_idx}\")\n        if LOW_MEMORY:\n            self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)\n        assert MAX_TEST_CASES <= 5, \"MAX_TEST_CASES should be less than or equal to 5.\"\n\n    def generate_forward_inputs(self):\n        hidden = (\n            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        weight = (\n            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device=\"cuda\")\n        return hidden, weight, labels\n\n    def generate_backward_inputs(self):\n        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-0.5, 0.5)\n        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-1, 1)\n        return g_entropy, g_logprobs\n\n    def verify_torch_itself(self, iterations: int = 5):\n        
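\"\"\"Verify the pure-torch TP path against a single-process reference:\n        all-gather the sharded weight, run run_torch_entropy on the full vocab,\n        and compare forward outputs and input/weight gradients.\"\"\"\n        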
self.cleanup()\n        self.generate_hyper()\n\n        for i in range(iterations):\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            # NOTE: we need to manually synchronize hidden and labels among Process Group\n            dist.broadcast(hidden, src=0, group=self.group)\n            dist.broadcast(labels, src=0, group=self.group)\n\n            # forward pass\n            # Create a tensor to hold the gathered weights from all ranks\n            # weight has shape [vocab_size, hidden_size]\n            # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size]\n\n            # Create a single contiguous tensor to hold all gathered weights\n            whole_weight = torch.empty(\n                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device\n            )\n\n            # Create views into the tensor for each rank's portion\n            whole_weight_views = [\n                whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)\n            ]\n\n            # Perform all_gather operation using the views\n            dist.all_gather(whole_weight_views, weight, group=self.group)\n\n            # Set requires_grad for autograd\n            whole_weight.requires_grad_()\n\n            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)\n\n            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n\n            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)\n\n            # backward pass\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n            dist.broadcast(g_entropy, src=0, group=self.group)\n            dist.broadcast(g_logprobs, src=0, group=self.group)\n\n            (single_d_hidden, single_d_weight) = torch.autograd.grad(\n                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n\n            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(\n                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)\n            # Extract the corresponding slice from single_d_weight for comparison\n            # tp_d_weight has shape [vocab_size, hidden_size]\n            # single_d_weight has shape [vocab_size * world_size, hidden_size]\n            torch.testing.assert_close(\n                tp_d_weight,\n                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],\n                atol=1e-2,\n                rtol=1e-4,\n            )\n\n            # atol=1e-3, rtol=1e-4)\n        if self.local_rank == 0:\n            print(\"[PASS] torch TP correctness is verified\")\n\n    def check_torch_storage(self):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        # NOTE: we need to manually 
synchronize hidden and labels among Process Group\n        dist.broadcast(hidden, src=0, group=self.group)\n        dist.broadcast(labels, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n        torch.cuda.synchronize()\n        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n        dist.broadcast(g_entropy, src=0, group=self.group)\n        dist.broadcast(g_logprobs, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(\n            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        # NOTE: all-reduce on hidden is conducted outside the kernel\n        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n        if self.local_rank == 0:\n            print(f\"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB\")\n            print(f\"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB\")\n\n    def verify_kernel_correctness(self, iterations: int = 5):\n        self.cleanup()\n        self.generate_hyper()\n\n        torch_forward_latency = list()\n        torch_backward_latency = list()\n        kernel_forward_latency = list()\n        kernel_backward_latency = list()\n\n        start_event = torch.cuda.Event(enable_timing=True)\n        end_event = torch.cuda.Event(enable_timing=True)\n\n        for i in range(iterations):\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            # NOTE: we need to manually synchronize hidden and labels among Process Group\n            dist.broadcast(hidden, src=0, group=self.group)\n            dist.broadcast(labels, src=0, group=self.group)\n\n            start_event.record()\n            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(\n                hidden, weight, labels, self.temperature, \"none\", self.group\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_forward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)\n            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)\n\n            # backward pass\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n            dist.broadcast(g_entropy, src=0, group=self.group)\n            dist.broadcast(g_logprobs, src=0, group=self.group)\n\n            start_event.record()\n            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(\n                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n         
   )\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_backward_latency.append(start_event.elapsed_time(end_event))\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            start_event.record()\n            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(\n                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_backward_latency.append(start_event.elapsed_time(end_event))\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)\n\n        # remove first latency\n        torch_forward_latency = torch_forward_latency[1:]\n        torch_backward_latency = torch_backward_latency[1:]\n        kernel_forward_latency = kernel_forward_latency[1:]\n        kernel_backward_latency = kernel_backward_latency[1:]\n\n        if self.local_rank == 0:\n            print(\"\\n[PASS]: Verified kernel forward & backward correctness.\")\n\n            print(\n                f\"[INFO]: Forward pass: Torch implementation average time: \"\n                f\"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Backward pass: torch implementation average time: \"\n                f\"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Forward pass: Kernel implementation average time: \"\n                f\"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Backward pass: kernel implementation average time: \"\n                f\"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms\"\n            )\n\n    def check_kernel_storage(self):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        # NOTE: we need to manually synchronize hidden and labels among Process Group\n        dist.broadcast(hidden, src=0, group=self.group)\n        dist.broadcast(labels, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(\n            hidden, weight, labels, self.temperature, \"none\", self.group\n        )\n        torch.cuda.synchronize()\n        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n        dist.broadcast(g_entropy, src=0, group=self.group)\n        dist.broadcast(g_logprobs, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(\n            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        kernel_backward_max_memory = 
torch.cuda.max_memory_allocated() / 1024 / 1024\n        # NOTE: all-reduce on hidden is conducted outside the kernel\n        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n        if self.local_rank == 0:\n            print(f\"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB\")\n            print(f\"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB\")\n\n\nif __name__ == \"__main__\":\n    # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/utils/test_special_linear_cross_entropy_tp.py\n\n    # Check if running with torchrun (distributed mode)\n    assert int(os.environ.get(\"WORLD_SIZE\", \"1\")) > 1, (\n        \"[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to \"\n        \"execute this script.\"\n    )\n    torch.manual_seed(233376 + int(os.environ.get(\"RANK\", 0)))\n\n    # set_backward_method(BackwardEnum._Total_Fuse_MN)\n    # set_backward_method(BackwardEnum._Split_Dlogits_N)\n\n    test = TestLinearCrossEntropy_TensorParallel()\n    for test_case_idx in range(MAX_TEST_CASES):\n        print(f\"[INFO] Running test case {test_case_idx}\")\n        test.initialize(test_case_idx)\n        if VERIFY_TORCH_SELF:\n            test.verify_torch_itself()\n        test.check_torch_storage()\n        test.verify_kernel_correctness()\n        test.check_kernel_storage()\n\n    test.shutdown()\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_special_mstx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nfrom verl.utils.profiler.config import NPUToolConfig, ProfilerConfig\nfrom verl.utils.profiler.mstx_profile import NPUProfiler\n\n\nclass TestNPUProfilerInitialization(unittest.TestCase):\n    def setUp(self):\n        NPUProfiler._define_count = 0\n\n    def test_init_with_default_config(self):\n        tool_config = NPUToolConfig()\n        profiler = NPUProfiler(rank=0, config=None, tool_config=tool_config)\n        self.assertFalse(profiler.enable)\n        self.assertFalse(hasattr(profiler, \"profile_npu\"))\n\n    def test_init_with_disabled_config(self):\n        config = ProfilerConfig(enable=False)\n        tool_config = NPUToolConfig()\n        profiler = NPUProfiler(rank=0, config=config, tool_config=tool_config)\n        self.assertFalse(profiler.enable)\n        self.assertFalse(hasattr(profiler, \"profile_npu\"))\n\n    def test_init_with_all_ranks_true(self):\n        config = ProfilerConfig(enable=True, all_ranks=True)\n        tool_config = NPUToolConfig()\n        profiler = NPUProfiler(rank=0, config=config, tool_config=tool_config)\n        self.assertTrue(profiler.this_rank)\n\n    def test_init_with_ranks_list(self):\n        config = ProfilerConfig(enable=True, ranks=[1, 2])\n        tool_config = NPUToolConfig()\n        profiler = NPUProfiler(rank=1, config=config, tool_config=tool_config)\n        self.assertTrue(profiler.this_rank)\n\n    def test_init_with_rank_not_in_ranks(self):\n        config = ProfilerConfig(enable=True, ranks=[1, 2])\n        tool_config = NPUToolConfig()\n        profiler = NPUProfiler(rank=3, config=config, tool_config=tool_config)\n        self.assertFalse(profiler.this_rank)\n\n\nclass TestNPUProfilerStart(unittest.TestCase):\n    def setUp(self):\n        NPUProfiler._define_count = 0\n        self.config = ProfilerConfig(enable=True, ranks=[0])\n        self.tool_config = NPUToolConfig(discrete=False)\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def test_start_when_enabled_and_this_rank(self, mock_get_profiler):\n        profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config)\n        profiler.start(role=\"worker\", profile_step=\"1\")\n        self.assertTrue(profiler.this_step)\n        self.assertEqual(NPUProfiler._define_count, 1)\n        mock_get_profiler.assert_called_once()\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def test_start_when_not_this_rank(self, mock_get_profiler):\n        profiler = NPUProfiler(rank=1, config=self.config, tool_config=self.tool_config)\n        profiler.start()\n        self.assertFalse(profiler.this_step)\n        self.assertEqual(NPUProfiler._define_count, 0)\n        mock_get_profiler.assert_not_called()\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def 
test_start_discrete_mode_does_not_increase_count(self, mock_get_profiler):\n        tool_config = NPUToolConfig(discrete=True)\n        profiler = NPUProfiler(rank=0, config=self.config, tool_config=tool_config)\n        profiler.start()\n        self.assertEqual(NPUProfiler._define_count, 0)\n        mock_get_profiler.assert_not_called()\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def test_multiple_start_calls_do_not_increase_count(self, mock_get_profiler):\n        profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config)\n        profiler.start()\n        profiler.start()\n        self.assertEqual(NPUProfiler._define_count, 1)\n        mock_get_profiler.assert_called_once()\n\n\nclass TestNPUProfilerStartStopInteraction(unittest.TestCase):\n    def setUp(self):\n        NPUProfiler._define_count = 0\n        self.config = ProfilerConfig(enable=True, ranks=[0])\n        self.tool_config = NPUToolConfig(discrete=False)\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def test_start_stop_cycle(self, mock_get_profiler):\n        mock_profile_npu = MagicMock()\n        mock_get_profiler.return_value = mock_profile_npu\n\n        profiler = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config)\n        profiler.start()\n        self.assertEqual(NPUProfiler._define_count, 1)\n        self.assertEqual(mock_profile_npu.start.call_count, 1)\n        profiler.stop()\n        self.assertEqual(NPUProfiler._define_count, 0)\n        self.assertEqual(mock_profile_npu.step.call_count, 1)\n        self.assertEqual(mock_profile_npu.stop.call_count, 1)\n\n    @patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\")\n    def test_multiple_instances_share_define_count(self, mock_get_profiler):\n        mock_profile_npu = MagicMock()\n        mock_get_profiler.return_value = mock_profile_npu\n\n        profiler1 = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config)\n        profiler2 = NPUProfiler(rank=0, config=self.config, tool_config=self.tool_config)\n        profiler1.start()\n        profiler2.start()\n        self.assertEqual(NPUProfiler._define_count, 1)\n        self.assertEqual(mock_profile_npu.start.call_count, 1)\n        profiler1.stop()\n        self.assertEqual(NPUProfiler._define_count, 0)\n\n\nclass TestNPUProfilerAnnotate(unittest.TestCase):\n    def setUp(self):\n        self.config = ProfilerConfig(enable=True, all_ranks=True)\n        self.tool_config = NPUToolConfig(discrete=False)\n        self.rank = 0\n\n    def test_annotate_decorator_applied_correctly(self):\n        mock_worker = MagicMock()\n        mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)\n        mock_worker.profiler.this_step = True\n\n        mock_mark_range = \"mocked_range_handle\"\n\n        with (\n            patch(\"verl.utils.profiler.mstx_profile.mark_start_range\") as mock_start_patch,\n            patch(\"verl.utils.profiler.mstx_profile.mark_end_range\") as mock_end_patch,\n        ):\n            mock_start_patch.return_value = mock_mark_range\n\n            with patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\") as mock_get_profiler:\n                decorator = mock_worker.profiler.annotate(message=\"test\")\n\n                @decorator\n                def test_func(self, *args, **kwargs):\n                    return \"result\"\n\n                result = test_func(mock_worker)\n\n                self.assertEqual(result, 
\"result\")\n                mock_start_patch.assert_called_once_with(message=\"test\")\n                mock_end_patch.assert_called_once_with(mock_mark_range)\n                mock_get_profiler.assert_not_called()\n\n    def test_annotate_when_profiler_disabled(self):\n        disabled_config = ProfilerConfig(enable=False)\n        mock_worker = MagicMock()\n        mock_worker.profiler = NPUProfiler(rank=self.rank, config=disabled_config, tool_config=self.tool_config)\n\n        with (\n            patch(\"verl.utils.profiler.mstx_profile.mark_start_range\") as mock_start_patch,\n            patch(\"verl.utils.profiler.mstx_profile.mark_end_range\") as mock_end_patch,\n            patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\") as mock_get_profiler,\n        ):\n            decorator = mock_worker.profiler.annotate(message=\"test\")\n\n            @decorator\n            def test_func(self, *args, **kwargs):\n                return \"result\"\n\n            result = test_func(mock_worker)\n\n            self.assertEqual(result, \"result\")\n            mock_start_patch.assert_not_called()\n            mock_end_patch.assert_not_called()\n            mock_get_profiler.assert_not_called()\n\n    def test_annotate_when_this_step_disabled(self):\n        mock_worker = MagicMock()\n        mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)\n        mock_worker.profiler.this_step = False\n\n        with (\n            patch(\"verl.utils.profiler.mstx_profile.mark_start_range\") as mock_start_patch,\n            patch(\"verl.utils.profiler.mstx_profile.mark_end_range\") as mock_end_patch,\n            patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\") as mock_get_profiler,\n        ):\n            decorator = mock_worker.profiler.annotate(message=\"test\")\n\n            @decorator\n            def test_func(self, *args, **kwargs):\n                return \"result\"\n\n            result = test_func(mock_worker)\n\n            self.assertEqual(result, \"result\")\n            mock_start_patch.assert_not_called()\n            mock_end_patch.assert_not_called()\n            mock_get_profiler.assert_not_called()\n\n    def test_annotate_discrete_mode_enabled(self):\n        discrete_tool_config = NPUToolConfig(discrete=True)\n        mock_worker = MagicMock()\n        mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=discrete_tool_config)\n        mock_worker.profiler.this_step = True\n\n        mock_mark_range = \"mocked_range_handle\"\n        mock_profile_npu = MagicMock()\n\n        with (\n            patch(\"verl.utils.profiler.mstx_profile.mark_start_range\") as mock_start_patch,\n            patch(\"verl.utils.profiler.mstx_profile.mark_end_range\") as mock_end_patch,\n            patch(\"verl.utils.profiler.mstx_profile.get_npu_profiler\") as mock_get_profiler,\n        ):\n            mock_start_patch.return_value = mock_mark_range\n            mock_get_profiler.return_value = mock_profile_npu\n            decorator = mock_worker.profiler.annotate(message=\"test\", role=\"test_role\")\n\n            @decorator\n            def test_func(self, *args, **kwargs):\n                return \"result\"\n\n            result = test_func(mock_worker)\n\n            self.assertEqual(result, \"result\")\n            mock_start_patch.assert_called_once_with(message=\"test\")\n            mock_end_patch.assert_called_once_with(mock_mark_range)\n            
mock_get_profiler.assert_called_once_with(\n                contents=mock_worker.profiler.profile_contents,\n                profile_level=mock_worker.profiler.profile_level,\n                profile_save_path=mock_worker.profiler.profile_save_path,\n                analysis=mock_worker.profiler.analysis,\n                role=\"test_role\",\n            )\n            mock_profile_npu.start.assert_called_once()\n            mock_profile_npu.step.assert_called_once()\n            mock_profile_npu.stop.assert_called_once()\n\n    def test_annotate_with_default_message(self):\n        mock_worker = MagicMock()\n        mock_worker.profiler = NPUProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)\n        mock_worker.profiler.this_step = True\n\n        mock_mark_range = \"mocked_range_handle\"\n        with (\n            patch(\"verl.utils.profiler.mstx_profile.mark_start_range\") as mock_start_patch,\n            patch(\"verl.utils.profiler.mstx_profile.mark_end_range\") as mock_end_patch,\n        ):\n            mock_start_patch.return_value = mock_mark_range\n            decorator = mock_worker.profiler.annotate()\n\n            @decorator\n            def test_func(self, *args, **kwargs):\n                return \"result\"\n\n            test_func(mock_worker)\n\n            mock_start_patch.assert_called_once_with(message=\"test_func\")\n            mock_end_patch.assert_called_once_with(mock_mark_range)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_temp_env_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\n\nfrom verl.utils.py_functional import temp_env_var\n\n\n@pytest.fixture(autouse=True)\ndef clean_env():\n    \"\"\"Fixture to clean up environment variables before and after each test.\"\"\"\n    # Store original environment state\n    original_env = dict(os.environ)\n\n    # Clean up any test variables that might exist\n    test_vars = [\"TEST_VAR\", \"TEST_VAR_2\", \"EXISTING_VAR\"]\n    for var in test_vars:\n        if var in os.environ:\n            del os.environ[var]\n\n    # Yield control to the test function\n    yield\n\n    # Restore original environment state after test\n    os.environ.clear()\n    os.environ.update(original_env)\n\n\ndef test_set_new_env_var():\n    \"\"\"Test setting a new environment variable that didn't exist before.\"\"\"\n    # Ensure variable doesn't exist\n    assert \"TEST_VAR\" not in os.environ\n\n    with temp_env_var(\"TEST_VAR\", \"test_value\"):\n        # Variable should be set inside context\n        assert os.environ[\"TEST_VAR\"] == \"test_value\"\n        assert \"TEST_VAR\" in os.environ\n\n    # Variable should be removed after context\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_restore_existing_env_var():\n    \"\"\"Test restoring an environment variable that already existed.\"\"\"\n    # Set up existing variable\n    os.environ[\"EXISTING_VAR\"] = \"original_value\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"temporary_value\"):\n        # Variable should be temporarily changed\n        assert os.environ[\"EXISTING_VAR\"] == \"temporary_value\"\n\n    # Variable should be restored to original value\n    assert os.environ[\"EXISTING_VAR\"] == \"original_value\"\n\n\ndef test_env_var_restored_on_exception():\n    \"\"\"Test that environment variables are restored even when exceptions occur.\"\"\"\n    # Set up existing variable\n    os.environ[\"EXISTING_VAR\"] = \"original_value\"\n\n    with pytest.raises(ValueError):\n        with temp_env_var(\"EXISTING_VAR\", \"temporary_value\"):\n            # Verify variable is set\n            assert os.environ[\"EXISTING_VAR\"] == \"temporary_value\"\n            # Raise exception\n            raise ValueError(\"Test exception\")\n\n    # Variable should still be restored despite exception\n    assert os.environ[\"EXISTING_VAR\"] == \"original_value\"\n\n\ndef test_nested_context_managers():\n    \"\"\"Test nested temp_env_var context managers.\"\"\"\n    # Set up original variable\n    os.environ[\"TEST_VAR\"] = \"original\"\n\n    with temp_env_var(\"TEST_VAR\", \"level1\"):\n        assert os.environ[\"TEST_VAR\"] == \"level1\"\n\n        with temp_env_var(\"TEST_VAR\", \"level2\"):\n            assert os.environ[\"TEST_VAR\"] == \"level2\"\n\n        # Should restore to level1\n        assert os.environ[\"TEST_VAR\"] == \"level1\"\n\n    # Should restore to original\n    assert os.environ[\"TEST_VAR\"] == \"original\"\n\n\ndef 
test_multiple_different_vars():\n    \"\"\"Test setting multiple different environment variables.\"\"\"\n    # Set up one existing variable\n    os.environ[\"EXISTING_VAR\"] = \"existing_value\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"modified\"):\n        with temp_env_var(\"TEST_VAR\", \"new_value\"):\n            assert os.environ[\"EXISTING_VAR\"] == \"modified\"\n            assert os.environ[\"TEST_VAR\"] == \"new_value\"\n\n    # Check restoration\n    assert os.environ[\"EXISTING_VAR\"] == \"existing_value\"\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_empty_string_value():\n    \"\"\"Test setting environment variable to empty string.\"\"\"\n    with temp_env_var(\"TEST_VAR\", \"\"):\n        assert os.environ[\"TEST_VAR\"] == \"\"\n        assert \"TEST_VAR\" in os.environ\n\n    # Should be removed after context\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_overwrite_with_empty_string():\n    \"\"\"Test overwriting existing variable with empty string.\"\"\"\n    os.environ[\"EXISTING_VAR\"] = \"original\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"\"):\n        assert os.environ[\"EXISTING_VAR\"] == \"\"\n\n    # Should restore original value\n    assert os.environ[\"EXISTING_VAR\"] == \"original\"\n\n\ndef test_context_manager_returns_none():\n    \"\"\"Test that context manager yields None.\"\"\"\n    with temp_env_var(\"TEST_VAR\", \"value\") as result:\n        assert result is None\n        assert os.environ[\"TEST_VAR\"] == \"value\"\n"
  },
  {
    "path": "verl_distillation/tests/utils/test_timeout_decorator_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport multiprocessing\nimport sys\nimport threading\nimport time\n\nimport pytest  # Import pytest\n\nfrom verl.utils.py_functional import timeout_limit as timeout\n\n# --- Test Task Functions ---\nTEST_TIMEOUT_SECONDS = 1.5  # Timeout duration for tests\nLONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5  # Duration slightly longer than timeout\n\n\n@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests\ndef quick_task(x):\n    \"\"\"A task that completes quickly.\"\"\"\n    time.sleep(0.1)\n    return \"quick_ok\"\n\n\n@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests\ndef slow_task(x):\n    \"\"\"A task that takes longer than the timeout.\"\"\"\n    time.sleep(LONG_TASK_DURATION)\n    return \"slow_finished\"  # This return value indicates it didn't time out\n\n\n# REMOVE global decorator here\ndef task_raises_value_error():  # Now truly not globally decorated\n    \"\"\"A task that intentionally raises a ValueError.\"\"\"\n    raise ValueError(\"Specific value error from task\")\n\n\n# --- Top-level function for signal test in subprocess ---\n# Keep this decorated globally for the specific subprocess test case\n@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\ndef top_level_decorated_quick_task_signal():\n    \"\"\"A pickleable top-level function decorated with signal timeout.\"\"\"\n    # Assuming this calls the logic of quick_task directly for the test purpose\n    time.sleep(0.1)\n    return \"quick_ok_signal_subprocess\"  # Different return for clarity if needed\n\n\n# --- Top-level function for signal test in subprocess ---\n# Keep this decorated globally for the specific subprocess test case\n@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\ndef top_level_decorated_slow_task_signal():\n    \"\"\"A pickleable top-level function decorated with signal timeout.\"\"\"\n    time.sleep(LONG_TASK_DURATION)\n    return \"slow_finished\"\n\n\n# --- NEW: Top-level helper function to run target in process ---\ndef run_target_and_put_in_queue(target_func, q):\n    \"\"\"\n    Top-level helper function to run a target function and put its result or exception into a queue.\n    This function is pickleable and can be used as the target for multiprocessing.Process.\n    \"\"\"\n    try:\n        result = target_func()\n        q.put((\"success\", result))\n    except Exception as e:\n        q.put((\"error\", e))\n\n\n# Use a module-level fixture to set the start method on macOS\n@pytest.fixture(scope=\"module\", autouse=True)  # Changed scope to module\ndef set_macos_start_method():\n    if sys.platform == \"darwin\":\n        # Force fork method on macOS to avoid pickling issues with globally decorated functions\n        # when running tests via pytest discovery.\n        current_method = multiprocessing.get_start_method(allow_none=True)\n        # Only set if not already set or if set to something else (less 
likely in test run)\n        if current_method is None or current_method != \"fork\":\n            try:\n                multiprocessing.set_start_method(\"fork\", force=True)\n            except RuntimeError:\n                # Might fail if context is already started, ignore in that case.\n                pass\n\n\ndef test_quick_task():  # Renamed from test_multiprocessing_quick_task\n    \"\"\"Tests timeout handles a quick task correctly.\"\"\"\n    # Call the globally decorated function directly\n    result = quick_task(1)\n    assert result == \"quick_ok\"  # Use pytest assert\n\n\ndef test_slow_task_timeout():  # Renamed from test_multiprocessing_slow_task_timeout\n    \"\"\"Tests timeout correctly raises TimeoutError for a slow task.\"\"\"\n    # Call the globally decorated function directly within pytest.raises\n    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises\n        slow_task(1)\n    # Check the error message from the multiprocessing implementation\n    assert f\"timed out after {TEST_TIMEOUT_SECONDS} seconds\" in str(excinfo.value)  # Use pytest assert\n\n\ndef test_internal_exception():  # Renamed from test_multiprocessing_internal_exception\n    \"\"\"Tests timeout correctly propagates internal exceptions.\"\"\"\n    # Apply the default timeout decorator dynamically to the undecorated function\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error)  # Apply decorator dynamically\n    with pytest.raises(ValueError) as excinfo:  # Use pytest.raises\n        decorated_task()  # Call the dynamically decorated function\n    assert str(excinfo.value) == \"Specific value error from task\"  # Use pytest assert\n\n\n# --- Test the signal implementation (use_signals=True) ---\n# Note: As per py_functional.py, use_signals=True currently falls back to\n# multiprocessing on POSIX. These tests verify that behavior.\n\n\ndef test_signal_quick_task_main_process():  # Removed self\n    \"\"\"Tests signal timeout handles a quick task correctly in the main process.\"\"\"\n\n    # Apply the signal decorator dynamically\n    def plain_quick_task_logic():\n        time.sleep(0.1)\n        return \"quick_ok_signal\"\n\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic)\n    assert decorated_task() == \"quick_ok_signal\"  # Use pytest assert\n\n\ndef test_signal_slow_task_main_process_timeout():  # Removed self\n    \"\"\"Tests signal timeout correctly raises TimeoutError for a slow task in the main process.\"\"\"\n\n    # Apply the signal decorator dynamically\n    def plain_slow_task_logic():\n        time.sleep(LONG_TASK_DURATION)\n        return \"slow_finished_signal\"\n\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic)\n    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises\n        decorated_task()\n    # Check the error message (falls back to multiprocessing message on POSIX)\n    assert f\"timed out after {TEST_TIMEOUT_SECONDS} seconds\" in str(excinfo.value)  # Use pytest assert\n\n\n@pytest.mark.skip(reason=\"this test won't pass. 
Just to show why use_signals should not be used\")\ndef test_signal_in_thread_does_not_timeout():\n    \"\"\"\n    Tests that signal-based timeout does NOT work reliably in a child thread.\n    The TimeoutError from the signal handler is not expected to be raised.\n    \"\"\"\n    result_container = []  # Use a list to store result from thread\n    exception_container = []  # Use a list to store exception from thread\n\n    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\n    def slow_task_in_thread():\n        try:\n            print(\"Thread: Starting slow task...\")\n            time.sleep(LONG_TASK_DURATION)\n            print(\"Thread: Slow task finished.\")\n            return \"slow_finished_in_thread\"\n        except Exception as e:\n            # Catch any exception within the thread's target function\n            print(f\"Thread: Caught exception: {e}\")\n            exception_container.append(e)\n            return None  # Indicate failure\n\n    def thread_target():\n        try:\n            # Run the decorated function inside the thread\n            res = slow_task_in_thread()\n            if res is not None:\n                result_container.append(res)\n        except Exception as e:\n            # This might catch exceptions happening *outside* the decorated function\n            # but still within the thread target, though less likely here.\n            print(f\"Thread Target: Caught exception: {e}\")\n            exception_container.append(e)\n\n    thread = threading.Thread(target=thread_target)\n    print(\"Main: Starting thread...\")\n    thread.start()\n    # Wait longer than the timeout + task duration to ensure the thread finishes\n    # regardless of whether timeout worked or not.\n    thread.join(timeout=LONG_TASK_DURATION + 1)\n\n    assert len(exception_container) == 1\n    assert isinstance(exception_container[0], TimeoutError)\n    assert not result_container\n\n\ndef test_in_thread_timeout():\n    result_container = []  # Use a list to store result from thread\n    exception_container = []  # Use a list to store exception from thread\n\n    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)\n    def slow_task_in_thread():\n        try:\n            print(\"Thread: Starting slow task...\")\n            time.sleep(LONG_TASK_DURATION)\n            print(\"Thread: Slow task finished.\")\n            return \"slow_finished_in_thread\"\n        except Exception as e:\n            # Catch any exception within the thread's target function\n            print(f\"Thread: Caught exception: {e}\")\n            exception_container.append(e)\n            return None  # Indicate failure\n\n    def thread_target():\n        try:\n            # Run the decorated function inside the thread\n            res = slow_task_in_thread()\n            if res is not None:\n                result_container.append(res)\n        except Exception as e:\n            # This might catch exceptions happening *outside* the decorated function\n            # but still within the thread target, though less likely here.\n            print(f\"Thread Target: Caught exception: {e}\")\n            exception_container.append(e)\n\n    thread = threading.Thread(target=thread_target)\n    print(\"Main: Starting thread...\")\n    thread.start()\n    # Wait longer than the timeout + task duration to ensure the thread finishes\n    # regardless of whether timeout worked or not.\n    thread.join(timeout=LONG_TASK_DURATION + 1)\n\n    assert len(exception_container) == 1\n    assert 
isinstance(exception_container[0], TimeoutError)\n    assert not result_container\n"
  },
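  {
    "path": "verl_distillation/tests/utils/sketch_signal_timeout_in_thread.py",
    "content": "\"\"\"Illustrative sketch accompanying the timeout tests above; not part of the original test suite and not imported by it.\n\nIt demonstrates the reason the skipped test exists: CPython runs signal handlers only in the main thread, so a SIGALRM-based timeout cannot interrupt work running in a worker thread. POSIX-only (signal.alarm), standard library only.\n\"\"\"\n\nimport signal\nimport threading\nimport time\n\n\ndef _raise_timeout(signum, frame):\n    raise TimeoutError(\"alarm fired\")\n\n\ndef main():\n    signal.signal(signal.SIGALRM, _raise_timeout)\n\n    # Case 1: the alarm interrupts a sleeping MAIN thread as expected.\n    signal.alarm(1)\n    try:\n        time.sleep(2)\n        print(\"main thread: not interrupted (unexpected)\")\n    except TimeoutError:\n        print(\"main thread: interrupted by SIGALRM (expected)\")\n    finally:\n        signal.alarm(0)\n\n    # Case 2: a worker thread is never interrupted, because the handler\n    # always runs in the main thread.\n    def worker():\n        time.sleep(2)\n        print(\"worker thread: ran to completion (signal timeout had no effect)\")\n\n    signal.alarm(1)\n    t = threading.Thread(target=worker)\n    t.start()\n    try:\n        t.join()\n    except TimeoutError:\n        # The handler fired here, in the main thread blocked in join(),\n        # not inside the worker.\n        print(\"main thread: alarm raised here, not in the worker\")\n    finally:\n        signal.alarm(0)\n    t.join()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },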
  {
    "path": "verl_distillation/tests/utils/test_torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std, masked_mean\n\n\ndef _worker_mean(rank: int, world_size: int, rendezvous_file: str):\n    # 1) set GPU and init NCCL\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    # each rank holds tensor [rank+1]\n    local = torch.tensor([float(rank + 1)], device=f\"cuda:{rank}\")\n    mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True)\n\n    values = [float(i + 1) for i in range(world_size)]\n    exp_mean = sum(values) / len(values)\n    exp_max = max(values)\n    exp_min = min(values)\n    var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1)\n    exp_std = var**0.5\n\n    # all ranks should see the same result\n    assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f\"mean@{rank}\"\n    assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f\"max@{rank}\"\n    assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f\"min@{rank}\"\n    assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f\"std@{rank}\"\n\n    dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\n    \"value,mask,gt\",\n    [\n        ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5),\n        ([1.0, 2.0, float(\"nan\"), 4.0], [1, 0, 0, 1], 2.5),\n        ([1.0, 2.0, float(\"nan\"), 4.0], [1, 0, 1, 0], float(\"nan\")),\n    ],\n)\ndef test_masked_mean(value, mask, gt):\n    res = masked_mean(torch.tensor(value), torch.tensor(mask))\n    gt = torch.tensor(gt)\n    assert torch.allclose(res, gt) or (torch.isnan(res) and torch.isnan(gt))\n\n\n@pytest.mark.parametrize(\"world_size\", [2, 4])\ndef test_distributed_mean_max_min_std(world_size, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_mean\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_worker_mean,\n        args=(world_size, rendezvous_file),\n        nprocs=world_size,\n        join=True,\n    )\n\n\ndef _worker_mask(rank: int, world_size: int, rendezvous_file: str):\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    # build per‐rank tensor and mask\n    local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f\"cuda:{rank}\")\n    if rank == 0:\n        mask = torch.tensor([1, 0], device=f\"cuda:{rank}\", dtype=torch.float32)\n    else:\n        mask = torch.tensor([0, 1], device=f\"cuda:{rank}\", dtype=torch.float32)\n\n    gmean = distributed_masked_mean(local_tensor, mask)\n\n    valid_values = [1.0] + [2 * i 
+ 2.0 for i in range(1, world_size)]\n    expected_mean = sum(valid_values) / len(valid_values)\n    assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f\"masked_mean@{rank}\"\n\n    dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\"world_size\", [2, 4])\ndef test_distributed_masked_mean(world_size, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_mask\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_worker_mask,\n        args=(world_size, rendezvous_file),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
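  {
    "path": "verl_distillation/tests/utils/sketch_masked_mean.py",
    "content": "\"\"\"Illustrative sketch accompanying test_torch_functional.py; not part of the original test suite.\n\nA minimal masked-mean implementation consistent with all three parametrized cases in test_masked_mean (a NaN under a zero mask is ignored; a NaN under a one mask propagates). verl's actual masked_mean may be implemented differently.\n\"\"\"\n\nimport torch\n\n\ndef masked_mean_sketch(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    # Zero out masked-off entries BEFORE summing, so a NaN sitting under a\n    # 0-mask cannot poison the result (0 * nan would still be nan).\n    selected = torch.where(mask.bool(), values, torch.zeros_like(values))\n    return selected.sum() / mask.sum()\n\n\nif __name__ == \"__main__\":\n    v = torch.tensor([1.0, 2.0, float(\"nan\"), 4.0])\n    print(masked_mean_sketch(v, torch.tensor([1.0, 0.0, 0.0, 1.0])))  # tensor(2.5000)\n    print(masked_mean_sketch(v, torch.tensor([1.0, 0.0, 1.0, 0.0])))  # tensor(nan)\n"
  },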
  {
    "path": "verl_distillation/tests/workers/actor/test_special_dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nimport torch\nimport torch.nn as nn\nfrom tensordict import TensorDict\nfrom transformers import AutoModelForCausalLM, Qwen3Config\n\nfrom verl import DataProto\nfrom verl.workers.actor.dp_actor import DataParallelPPOActor\nfrom verl.workers.config import FSDPActorConfig, OptimizerConfig\n\n\nclass MockTransformerModel(nn.Module):\n    \"\"\"Mock transformer model for testing DataParallelPPOActor\"\"\"\n\n    def __init__(self, vocab_size=1000, hidden_size=64):\n        super().__init__()\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.embedding = nn.Embedding(vocab_size, hidden_size)\n        self.transformer = nn.TransformerEncoder(\n            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True), num_layers=2\n        )\n        self.lm_head = nn.Linear(hidden_size, vocab_size)\n\n    def forward(self, input_ids, attention_mask=None, position_ids=None, use_cache=False, **kwargs):\n        batch_size, seq_len = input_ids.shape\n\n        embeddings = self.embedding(input_ids)\n        hidden_states = self.transformer(embeddings)\n        logits = self.lm_head(hidden_states)\n\n        class MockOutput:\n            def __init__(self, logits):\n                self.logits = logits\n\n        return MockOutput(logits)\n\n\nclass TestDataParallelPPOActor(unittest.TestCase):\n    \"\"\"Test DataParallelPPOActor compute_log_prob and update_policy methods\"\"\"\n\n    @classmethod\n    def setUpClass(cls):\n        \"\"\"Set up distributed environment\"\"\"\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=\"nccl\" if torch.cuda.is_available() else \"gloo\", init_method=\"env://\"\n            )\n\n        cls.rank = torch.distributed.get_rank()\n        cls.world_size = torch.distributed.get_world_size()\n\n        if torch.cuda.is_available():\n            torch.cuda.set_device(cls.rank)\n            cls.device = torch.device(f\"cuda:{cls.rank}\")\n        else:\n            cls.device = torch.device(\"cpu\")\n\n    def setUp(self):\n        \"\"\"Set up test fixtures\"\"\"\n        self.config = FSDPActorConfig(\n            strategy=\"fsdp2\",\n            ppo_mini_batch_size=4,\n            ppo_micro_batch_size_per_gpu=2,\n            ppo_epochs=1,\n            clip_ratio=0.2,\n            entropy_coeff=0.01,\n            grad_clip=1.0,\n            use_dynamic_bsz=False,\n            use_torch_compile=False,  # Disable torch.compile for testing\n            ulysses_sequence_parallel_size=1,\n            optim=OptimizerConfig(lr=1e-6),\n        )\n\n        self.mock_model = MockTransformerModel(vocab_size=1000, hidden_size=64).to(self.device)\n        self.mock_optimizer = torch.optim.Adam(self.mock_model.parameters(), lr=1e-4)\n\n        self.actor = DataParallelPPOActor(\n            
config=self.config, actor_module=self.mock_model, actor_optimizer=self.mock_optimizer\n        )\n\n    @classmethod\n    def tearDownClass(cls):\n        \"\"\"Clean up distributed environment\"\"\"\n        if torch.distributed.is_initialized():\n            torch.distributed.destroy_process_group()\n\n    def _create_test_data_for_compute_log_prob(self):\n        \"\"\"Create test DataProto for compute_log_prob method\"\"\"\n        batch_size = 2\n        prompt_length = 8\n        response_length = 4\n        total_length = prompt_length + response_length\n        vocab_size = 1000\n\n        input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)\n        attention_mask = torch.ones(batch_size, total_length).to(self.device)\n        position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)\n        responses = input_ids[:, -response_length:]  # Last part is the response\n\n        tensor_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"responses\": responses,\n            },\n            batch_size=[batch_size],\n        )\n\n        meta_info = {\"micro_batch_size\": batch_size, \"temperature\": 1.0, \"use_dynamic_bsz\": False}\n\n        return DataProto(batch=tensor_dict, meta_info=meta_info)\n\n    def _create_test_data_for_update_policy(self):\n        \"\"\"Create test DataProto for update_policy method\"\"\"\n        batch_size = 4  # Must match ppo_mini_batch_size\n        prompt_length = 8\n        response_length = 4\n        total_length = prompt_length + response_length\n        vocab_size = 1000\n\n        input_ids = torch.randint(0, vocab_size, (batch_size, total_length)).to(self.device)\n        attention_mask = torch.ones(batch_size, total_length).to(self.device)\n        position_ids = torch.arange(total_length).unsqueeze(0).expand(batch_size, -1).to(self.device)\n        responses = input_ids[:, -response_length:]\n        response_mask = torch.ones(batch_size, response_length).to(self.device)\n        old_log_probs = torch.randn(batch_size, response_length).to(self.device) * 0.1  # Small values\n        advantages = torch.randn(batch_size, response_length).to(self.device) * 0.5\n\n        tensor_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"responses\": responses,\n                \"response_mask\": response_mask,\n                \"old_log_probs\": old_log_probs,\n                \"advantages\": advantages,\n            },\n            batch_size=[batch_size],\n        )\n\n        meta_info = {\"temperature\": 1.0}\n\n        return DataProto(batch=tensor_dict, meta_info=meta_info)\n\n    def test_compute_log_prob(self):\n        \"\"\"Test compute_log_prob method\"\"\"\n        data = self._create_test_data_for_compute_log_prob()\n\n        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=True)\n\n        batch_size = data.batch[\"responses\"].shape[0]\n        response_length = data.batch[\"responses\"].shape[1]\n\n        self.assertIsInstance(log_probs, torch.Tensor)\n        self.assertEqual(log_probs.shape, (batch_size, response_length))\n        self.assertTrue(torch.all(torch.isfinite(log_probs)))\n\n        self.assertIsInstance(entropies, torch.Tensor)\n        
self.assertEqual(entropies.shape, (batch_size, response_length))\n        self.assertTrue(torch.all(torch.isfinite(entropies)))\n        self.assertTrue(torch.all(entropies >= 0))  # Entropy should be non-negative\n\n    def test_compute_log_prob_without_entropy(self):\n        \"\"\"Test compute_log_prob method without entropy calculation\"\"\"\n        data = self._create_test_data_for_compute_log_prob()\n\n        log_probs, entropies = self.actor.compute_log_prob(data, calculate_entropy=False)\n\n        batch_size = data.batch[\"responses\"].shape[0]\n        response_length = data.batch[\"responses\"].shape[1]\n\n        self.assertIsInstance(log_probs, torch.Tensor)\n        self.assertEqual(log_probs.shape, (batch_size, response_length))\n        self.assertTrue(torch.all(torch.isfinite(log_probs)))\n\n        self.assertIsNone(entropies)\n\n    def test_update_policy(self):\n        \"\"\"Test update_policy method\"\"\"\n        data = self._create_test_data_for_update_policy()\n\n        metrics = self.actor.update_policy(data)\n\n        self.assertIsInstance(metrics, dict)\n\n        expected_metric_keys = [\n            \"actor/pg_loss\",\n            \"actor/pg_clipfrac\",\n            \"actor/ppo_kl\",\n            \"actor/pg_clipfrac_lower\",\n            \"actor/grad_norm\",\n        ]\n\n        for key in expected_metric_keys:\n            self.assertIn(key, metrics)\n            if isinstance(metrics[key], list):\n                self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))\n            else:\n                self.assertIsInstance(metrics[key], (float, int))\n                self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))\n\n    def test_dataparallelppoactor_initialization(self):\n        \"\"\"Test DataParallelPPOActor initialization\"\"\"\n        self.assertIsNotNone(self.actor.actor_module)\n        self.assertIsNotNone(self.actor.actor_optimizer)\n        self.assertEqual(self.actor.config, self.config)\n\n        self.assertEqual(self.actor.config.strategy, \"fsdp2\")\n        self.assertEqual(self.actor.config.ppo_mini_batch_size, 4)\n        self.assertEqual(self.actor.config.clip_ratio, 0.2)\n\n    def test_dataparallelppoactor_with_qwen3_model(self):\n        \"\"\"Test DataParallelPPOActor with real Qwen3ForCausalLM model\"\"\"\n        qwen_config = Qwen3Config(\n            vocab_size=1000,\n            hidden_size=64,\n            intermediate_size=128,\n            num_hidden_layers=2,\n            num_attention_heads=4,\n            num_key_value_heads=2,\n            max_position_embeddings=512,\n            torch_dtype=torch.float32,\n            use_cache=False,\n        )\n\n        with torch.device(self.device):\n            qwen_model = AutoModelForCausalLM.from_config(config=qwen_config, torch_dtype=torch.float32).to(self.device)\n\n        qwen_optimizer = torch.optim.Adam(qwen_model.parameters(), lr=1e-4)\n\n        qwen_actor = DataParallelPPOActor(config=self.config, actor_module=qwen_model, actor_optimizer=qwen_optimizer)\n\n        data = self._create_test_data_for_compute_log_prob()\n        log_probs, entropies = qwen_actor.compute_log_prob(data, calculate_entropy=True)\n\n        batch_size = data.batch[\"responses\"].shape[0]\n        response_length = data.batch[\"responses\"].shape[1]\n\n        self.assertIsInstance(log_probs, torch.Tensor)\n        self.assertEqual(log_probs.shape, (batch_size, response_length))\n        self.assertTrue(torch.all(torch.isfinite(log_probs)))\n\n        
self.assertIsInstance(entropies, torch.Tensor)\n        self.assertEqual(entropies.shape, (batch_size, response_length))\n        self.assertTrue(torch.all(torch.isfinite(entropies)))\n        self.assertTrue(torch.all(entropies >= 0))\n\n        policy_data = self._create_test_data_for_update_policy()\n        metrics = qwen_actor.update_policy(policy_data)\n\n        self.assertIsInstance(metrics, dict)\n\n        expected_metric_keys = [\n            \"actor/pg_loss\",\n            \"actor/pg_clipfrac\",\n            \"actor/ppo_kl\",\n            \"actor/pg_clipfrac_lower\",\n            \"actor/grad_norm\",\n        ]\n\n        for key in expected_metric_keys:\n            self.assertIn(key, metrics)\n            if isinstance(metrics[key], list):\n                self.assertTrue(all(torch.isfinite(torch.tensor(v)) for v in metrics[key]))\n            else:\n                self.assertIsInstance(metrics[key], (float, int))\n                self.assertTrue(torch.isfinite(torch.tensor(metrics[key])))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
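  {
    "path": "verl_distillation/tests/workers/actor/sketch_ppo_clip_loss.py",
    "content": "\"\"\"Illustrative sketch accompanying test_special_dp_actor.py; not part of the original test suite.\n\nA minimal reference for the clipped PPO objective behind the metrics the tests assert on (actor/pg_loss, actor/pg_clipfrac, actor/ppo_kl), written out under the standard PPO definitions. It is not verl's DataParallelPPOActor implementation.\n\"\"\"\n\nimport torch\n\n\ndef ppo_clip_loss(log_probs, old_log_probs, advantages, response_mask, clip_ratio=0.2):\n    # Importance ratio between the current policy and the behavior policy.\n    ratio = torch.exp(log_probs - old_log_probs)\n    unclipped = -advantages * ratio\n    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)\n    per_token = torch.maximum(unclipped, clipped)  # pessimistic (clipped) objective\n\n    denom = response_mask.sum().clamp_min(1.0)\n    pg_loss = (per_token * response_mask).sum() / denom\n    # Diagnostics mirroring the metric names asserted in the tests.\n    pg_clipfrac = ((clipped > unclipped).float() * response_mask).sum() / denom\n    ppo_kl = ((old_log_probs - log_probs) * response_mask).sum() / denom  # first-order KL estimate\n    return pg_loss, pg_clipfrac, ppo_kl\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    lp = torch.randn(2, 4) * 0.1\n    olp = lp + torch.randn(2, 4) * 0.05\n    adv = torch.randn(2, 4)\n    mask = torch.ones(2, 4)\n    print(ppo_clip_loss(lp, olp, adv, mask))\n"
  },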
  {
    "path": "verl_distillation/tests/workers/config/test_actor_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport unittest\n\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import (\n    ActorConfig,\n    FSDPActorConfig,\n    McoreActorConfig,\n    OptimizerConfig,\n)\n\n\nclass TestActorConfig(unittest.TestCase):\n    \"\"\"Test the ActorConfig dataclass and its variants.\"\"\"\n\n    def test_config_inheritance(self):\n        \"\"\"Test that the inheritance hierarchy works correctly.\"\"\"\n        megatron_dict = {\n            \"_target_\": \"verl.workers.config.McoreActorConfig\",\n            \"strategy\": \"megatron\",\n            \"ppo_mini_batch_size\": 256,\n            \"ppo_micro_batch_size_per_gpu\": 256,\n            \"clip_ratio\": 0.2,\n            \"optim\": {\n                \"_target_\": \"verl.workers.config.McoreOptimizerConfig\",\n                \"lr\": 0.1,\n            },\n        }\n        fsdp_dict = {\n            \"_target_\": \"verl.workers.config.FSDPActorConfig\",\n            \"strategy\": \"fsdp\",\n            \"ppo_mini_batch_size\": 256,\n            \"ppo_micro_batch_size_per_gpu\": 256,\n            \"clip_ratio\": 0.2,\n            \"optim\": {\n                \"_target_\": \"verl.workers.config.FSDPOptimizerConfig\",\n                \"lr\": 0.1,\n            },\n        }\n\n        megatron_config = omega_conf_to_dataclass(megatron_dict)\n        fsdp_config = omega_conf_to_dataclass(fsdp_dict)\n\n        self.assertIsInstance(megatron_config, ActorConfig)\n        self.assertIsInstance(fsdp_config, ActorConfig)\n\n        self.assertEqual(megatron_config.ppo_mini_batch_size, fsdp_config.ppo_mini_batch_size)\n        self.assertEqual(megatron_config.clip_ratio, fsdp_config.clip_ratio)\n\n    def test_actor_config_from_yaml(self):\n        \"\"\"Test creating ActorConfig from YAML file.\"\"\"\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/actor\")):\n            cfg = compose(config_name=\"actor\", overrides=[\"strategy=fsdp\", \"ppo_micro_batch_size_per_gpu=128\"])\n\n        config = omega_conf_to_dataclass(cfg)\n\n        self.assertIsInstance(config, ActorConfig)\n        self.assertEqual(config.strategy, \"fsdp\")\n\n    def test_fsdp_actor_config_from_yaml(self):\n        \"\"\"Test creating FSDPActorConfig from YAML file.\"\"\"\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/actor\")):\n            cfg = compose(config_name=\"dp_actor\", overrides=[\"strategy=fsdp2\", \"ppo_micro_batch_size_per_gpu=128\"])\n\n        config = omega_conf_to_dataclass(cfg)\n\n        self.assertIsInstance(config, FSDPActorConfig)\n        self.assertEqual(config.strategy, \"fsdp2\")\n\n    def test_megatron_actor_config_from_yaml(self):\n        \"\"\"Test creating McoreActorConfig from YAML 
file.\"\"\"\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/actor\")):\n            cfg = compose(config_name=\"megatron_actor\", overrides=[\"ppo_micro_batch_size_per_gpu=128\"])\n\n        config = omega_conf_to_dataclass(cfg)\n\n        self.assertIsInstance(config, McoreActorConfig)\n        self.assertEqual(config.strategy, \"megatron\")\n\n    def test_config_get_method(self):\n        \"\"\"Test the get method for backward compatibility.\"\"\"\n        config_dict = {\n            \"_target_\": \"verl.workers.config.ActorConfig\",\n            \"strategy\": \"fsdp\",\n            \"ppo_mini_batch_size\": 256,\n            \"ppo_micro_batch_size_per_gpu\": 256,\n            \"optim\": {\n                \"_target_\": \"verl.workers.config.OptimizerConfig\",\n                \"lr\": 0.1,\n            },\n        }\n        config = omega_conf_to_dataclass(config_dict)\n\n        self.assertEqual(config.get(\"strategy\"), \"fsdp\")\n        self.assertEqual(config.get(\"ppo_mini_batch_size\"), 256)\n\n        self.assertIsNone(config.get(\"non_existing\"))\n        self.assertEqual(config.get(\"non_existing\", \"default\"), \"default\")\n\n    def test_config_dict_like_access(self):\n        \"\"\"Test dictionary-like access to config fields.\"\"\"\n        config_dict = {\n            \"_target_\": \"verl.workers.config.ActorConfig\",\n            \"strategy\": \"fsdp\",\n            \"ppo_mini_batch_size\": 256,\n            \"ppo_micro_batch_size_per_gpu\": 256,\n            \"optim\": {\n                \"_target_\": \"verl.workers.config.OptimizerConfig\",\n                \"lr\": 0.1,\n            },\n        }\n        config = omega_conf_to_dataclass(config_dict)\n\n        self.assertEqual(config[\"strategy\"], \"fsdp\")\n        self.assertEqual(config[\"ppo_mini_batch_size\"], 256)\n\n        field_names = list(config)\n        self.assertIn(\"strategy\", field_names)\n        self.assertIn(\"ppo_mini_batch_size\", field_names)\n\n        self.assertGreater(len(config), 0)\n\n    def test_frozen_fields_modification_raises_exception(self):\n        \"\"\"Test that modifying frozen fields raises an exception.\"\"\"\n        config_dict = {\n            \"_target_\": \"verl.workers.config.ActorConfig\",\n            \"strategy\": \"fsdp\",\n            \"ppo_mini_batch_size\": 256,\n            \"ppo_micro_batch_size_per_gpu\": 256,\n            \"optim\": {\n                \"_target_\": \"verl.workers.config.OptimizerConfig\",\n                \"lr\": 0.1,\n            },\n        }\n        config = omega_conf_to_dataclass(config_dict)\n\n        with self.assertRaises(AttributeError):\n            config.strategy = \"megatron\"\n\n        with self.assertRaises(AttributeError):\n            config.clip_ratio = 0.5\n\n        config.ppo_mini_batch_size = 512  # This should work since it's not in frozen fields anymore\n        self.assertEqual(config.ppo_mini_batch_size, 512)\n\n    def test_actor_config_validation_exceptions(self):\n        \"\"\"Test that ActorConfig.__post_init__ raises appropriate validation exceptions.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        with self.assertRaises((ValueError, AssertionError)) as cm:\n            ActorConfig(\n                strategy=\"fsdp\",\n                loss_agg_mode=\"invalid-mode\",\n                use_dynamic_bsz=True,\n                optim=optim,\n                ppo_micro_batch_size_per_gpu=4,\n            )\n   
     self.assertIn(\"Invalid loss_agg_mode\", str(cm.exception))\n\n        with self.assertRaises((ValueError, AssertionError)) as cm:\n            ActorConfig(\n                strategy=\"fsdp\",\n                use_dynamic_bsz=False,\n                ppo_micro_batch_size=4,\n                ppo_micro_batch_size_per_gpu=2,\n                optim=optim,\n            )\n        self.assertIn(\"You have set both\", str(cm.exception))\n\n        with self.assertRaises((ValueError, AssertionError)) as cm:\n            ActorConfig(\n                strategy=\"fsdp\",\n                use_dynamic_bsz=False,\n                ppo_micro_batch_size=None,\n                ppo_micro_batch_size_per_gpu=None,\n                optim=optim,\n            )\n        self.assertIn(\"Please set at least one\", str(cm.exception))\n\n        config = ActorConfig(\n            strategy=\"fsdp\",\n            use_dynamic_bsz=True,\n            ppo_micro_batch_size=None,\n            ppo_micro_batch_size_per_gpu=None,\n            optim=optim,\n        )\n        self.assertIsNotNone(config)  # Should not raise an exception\n\n    def test_fsdp_actor_config_validation_exceptions(self):\n        \"\"\"Test that FSDPActorConfig.validate() raises appropriate validation exceptions.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        config = FSDPActorConfig(\n            strategy=\"fsdp\",\n            ulysses_sequence_parallel_size=2,\n            use_dynamic_bsz=True,  # Skip batch size validation to focus on FSDP validation\n            optim=optim,\n        )\n\n        model_config = {\"use_remove_padding\": False}\n        with self.assertRaises(ValueError) as cm:\n            config.validate(n_gpus=8, train_batch_size=256, model_config=model_config)\n        self.assertIn(\"you must enable `use_remove_padding`\", str(cm.exception))\n\n    def test_actor_config_validate_method_exceptions(self):\n        \"\"\"Test that ActorConfig.validate() raises appropriate validation exceptions.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        config = ActorConfig(\n            strategy=\"fsdp\",\n            use_dynamic_bsz=False,\n            ppo_mini_batch_size=256,\n            ppo_micro_batch_size=8,\n            ppo_micro_batch_size_per_gpu=None,  # Ensure only one batch size setting is used\n            optim=optim,\n        )\n\n        with self.assertRaises(ValueError) as cm:\n            config.validate(n_gpus=8, train_batch_size=128)\n        self.assertIn(\"train_batch_size\", str(cm.exception))\n\n        with self.assertRaises(ValueError) as cm:\n            config.validate(n_gpus=16, train_batch_size=512)\n        self.assertIn(\"must be >= n_gpus\", str(cm.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
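  {
    "path": "verl_distillation/tests/workers/config/sketch_frozen_fields.py",
    "content": "\"\"\"Illustrative sketch accompanying test_actor_config_on_cpu.py; not part of the original test suite.\n\nA hypothetical stand-in for the frozen-field behavior the tests pin down: selected fields reject mutation after construction (AttributeError), while batch-size knobs such as ppo_mini_batch_size stay writable. verl's BaseConfig realizes this differently.\n\"\"\"\n\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass FrozenFieldsSketch:\n    strategy: str = \"fsdp\"\n    clip_ratio: float = 0.2\n    ppo_mini_batch_size: int = 256\n\n    # Plain class attribute (no annotation), so the dataclass machinery\n    # does not treat it as a field.\n    _FROZEN = (\"strategy\", \"clip_ratio\")\n\n    def __setattr__(self, name, value):\n        # Allow the first assignment (made by the generated __init__);\n        # reject any later mutation of a frozen field.\n        if name in self._FROZEN and name in self.__dict__:\n            raise AttributeError(f\"Field '{name}' is frozen after initialization\")\n        super().__setattr__(name, value)\n\n\nif __name__ == \"__main__\":\n    cfg = FrozenFieldsSketch()\n    cfg.ppo_mini_batch_size = 512  # batch-size fields remain mutable\n    try:\n        cfg.strategy = \"megatron\"\n    except AttributeError as e:\n        print(e)\n"
  },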
  {
    "path": "verl_distillation/tests/workers/config/test_critic_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom pathlib import Path\n\nimport pytest\nfrom hydra import compose, initialize_config_dir\n\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.profiler import ProfilerConfig\nfrom verl.workers.config import (\n    CriticConfig,\n    FSDPCriticConfig,\n    FSDPOptimizerConfig,\n    McoreCriticConfig,\n    McoreOptimizerConfig,\n    OptimizerConfig,\n)\n\n\nclass TestCriticConfig:\n    \"\"\"Test suite for critic configuration dataclasses.\"\"\"\n\n    @pytest.fixture\n    def config_dir(self):\n        \"\"\"Get the path to the config directory.\"\"\"\n        return Path(__file__).parent.parent.parent.parent / \"verl\" / \"trainer\" / \"config\" / \"critic\"\n\n    def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):\n        \"\"\"Test that McoreCriticConfig can be instantiated from megatron_critic.yaml.\"\"\"\n        yaml_path = config_dir / \"megatron_critic.yaml\"\n        assert yaml_path.exists(), f\"Config file not found: {yaml_path}\"\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/critic\")):\n            test_config = compose(config_name=\"megatron_critic\", overrides=[\"ppo_micro_batch_size_per_gpu=1\"])\n\n        megatron_config_obj = omega_conf_to_dataclass(test_config)\n\n        assert isinstance(megatron_config_obj, McoreCriticConfig)\n        assert isinstance(megatron_config_obj, CriticConfig)\n\n        expected_attrs = [\n            \"strategy\",\n            \"rollout_n\",\n            \"optim\",\n            \"model\",\n            \"ppo_mini_batch_size\",\n            \"ppo_max_token_len_per_gpu\",\n            \"cliprange_value\",\n            \"get\",\n            \"nccl_timeout\",\n            \"megatron\",\n            \"load_weight\",\n        ]\n        for attr in expected_attrs:\n            assert hasattr(megatron_config_obj, attr), f\"Missing attribute: {attr}\"\n\n        assert callable(megatron_config_obj.get)\n        assert megatron_config_obj.strategy == \"megatron\"\n\n    def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):\n        \"\"\"Test that FSDPCriticConfig can be instantiated from dp_critic.yaml.\"\"\"\n        yaml_path = config_dir / \"dp_critic.yaml\"\n        assert yaml_path.exists(), f\"Config file not found: {yaml_path}\"\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/critic\")):\n            test_config = compose(config_name=\"dp_critic\", overrides=[\"ppo_micro_batch_size_per_gpu=1\"])\n\n        fsdp_config_obj = omega_conf_to_dataclass(test_config)\n\n        assert isinstance(fsdp_config_obj, FSDPCriticConfig)\n        assert isinstance(fsdp_config_obj, CriticConfig)\n\n        expected_attrs = [\n            \"strategy\",\n            \"rollout_n\",\n            \"optim\",\n            \"model\",\n            \"ppo_mini_batch_size\",\n            
\"ppo_max_token_len_per_gpu\",\n            \"cliprange_value\",\n            \"get\",\n            \"forward_micro_batch_size\",\n            \"forward_micro_batch_size_per_gpu\",\n            \"ulysses_sequence_parallel_size\",\n            \"grad_clip\",\n        ]\n        for attr in expected_attrs:\n            assert hasattr(fsdp_config_obj, attr), f\"Missing attribute: {attr}\"\n\n        assert callable(fsdp_config_obj.get)\n        assert fsdp_config_obj.strategy == \"fsdp\"\n\n    def test_config_inheritance_hierarchy(self):\n        \"\"\"Test that the inheritance hierarchy is correct.\"\"\"\n        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))\n        assert isinstance(megatron_config, CriticConfig)\n        assert isinstance(megatron_config, McoreCriticConfig)\n\n        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))\n        assert isinstance(fsdp_config, CriticConfig)\n        assert isinstance(fsdp_config, FSDPCriticConfig)\n\n        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=OptimizerConfig(lr=0.1))\n        assert isinstance(critic_config, CriticConfig)\n        assert not isinstance(critic_config, McoreCriticConfig)\n        assert not isinstance(critic_config, FSDPCriticConfig)\n\n    def test_config_dict_interface(self):\n        \"\"\"Test that configs provide dict-like interface from BaseConfig.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=optim)\n\n        assert \"strategy\" in config\n        assert config[\"strategy\"] == \"fsdp2\"\n\n        assert config.get(\"strategy\") == \"fsdp2\"\n        assert config.get(\"nonexistent_key\", \"default\") == \"default\"\n\n        keys = list(config)\n        assert \"strategy\" in keys\n        assert \"rollout_n\" in keys\n\n        assert len(config) > 0\n\n    def test_frozen_fields_immutability(self):\n        \"\"\"Test that frozen fields raise exceptions when modified after creation.\"\"\"\n        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=OptimizerConfig(lr=0.1))\n        frozen_fields = [\"rollout_n\", \"strategy\", \"cliprange_value\"]\n\n        for field_name in frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(critic_config, field_name, \"modified_value\")\n\n        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))\n        megatron_frozen_fields = [\"nccl_timeout\", \"load_weight\", \"data_loader_seed\"]\n\n        for field_name in megatron_frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(megatron_config, field_name, \"modified_value\")\n\n        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))\n        fsdp_frozen_fields = [\"ulysses_sequence_parallel_size\", \"grad_clip\"]\n\n        for field_name in fsdp_frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(fsdp_config, field_name, \"modified_value\")\n\n    def test_batch_size_fields_modifiable(self):\n        \"\"\"Test that batch size fields can be modified after creation.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        critic_config = 
CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=optim)\n\n        critic_config.ppo_mini_batch_size = 8\n        critic_config.ppo_micro_batch_size = 4\n        critic_config.ppo_micro_batch_size_per_gpu = 2\n\n        assert critic_config.ppo_mini_batch_size == 8\n        assert critic_config.ppo_micro_batch_size == 4\n        assert critic_config.ppo_micro_batch_size_per_gpu == 2\n\n        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))\n\n        fsdp_config.forward_micro_batch_size = 16\n        fsdp_config.forward_micro_batch_size_per_gpu = 8\n\n        assert fsdp_config.forward_micro_batch_size == 16\n        assert fsdp_config.forward_micro_batch_size_per_gpu == 8\n\n    def test_profiler_config_type_validation(self):\n        \"\"\"Test that profiler field has correct type and validation.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=optim)\n        assert isinstance(critic_config.profiler, ProfilerConfig)\n        assert critic_config.profiler.all_ranks is False\n        assert critic_config.profiler.ranks == []\n\n        custom_profiler = ProfilerConfig(all_ranks=True, ranks=[0, 1])\n        critic_config_custom = CriticConfig(\n            profiler=custom_profiler, ppo_micro_batch_size_per_gpu=1, strategy=\"fsdp2\", optim=optim\n        )\n        assert isinstance(critic_config_custom.profiler, ProfilerConfig)\n        assert critic_config_custom.profiler.all_ranks is True\n        assert critic_config_custom.profiler.ranks == [0, 1]\n\n        profiler1 = ProfilerConfig(enable=True, ranks=[0, 1])\n        profiler2 = ProfilerConfig(all_ranks=True, ranks=[1, 2])\n\n        union_result = profiler1.union(profiler2)\n        assert union_result.enable is True\n        assert union_result.all_ranks is True\n        assert set(union_result.ranks) == {0, 1, 2}\n\n        intersect_result = profiler1.intersect(profiler2)\n        assert intersect_result.all_ranks is False\n        assert intersect_result.ranks == [1]\n\n    def test_critic_config_validation_logic(self):\n        \"\"\"Test the __post_init__ validation logic for CriticConfig.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        valid_config = CriticConfig(\n            strategy=\"fsdp2\", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=False, optim=optim\n        )\n        assert valid_config.ppo_micro_batch_size_per_gpu == 2\n\n        valid_config2 = CriticConfig(\n            strategy=\"fsdp2\",\n            ppo_micro_batch_size_per_gpu=None,\n            ppo_micro_batch_size=4,\n            ppo_mini_batch_size=8,\n            use_dynamic_bsz=False,\n            optim=optim,\n        )\n        assert valid_config2.ppo_micro_batch_size == 4\n\n        dynamic_config = CriticConfig(\n            strategy=\"fsdp2\", ppo_micro_batch_size_per_gpu=2, use_dynamic_bsz=True, optim=optim\n        )\n        assert dynamic_config.use_dynamic_bsz is True\n\n        with pytest.raises(ValueError, match=\"You have set both.*micro_batch_size.*AND.*micro_batch_size_per_gpu\"):\n            CriticConfig(\n                strategy=\"fsdp2\",\n                ppo_micro_batch_size=4,\n                ppo_micro_batch_size_per_gpu=2,\n                use_dynamic_bsz=False,\n                optim=optim,\n            )\n\n        with pytest.raises(\n            ValueError, match=\"Please set at least one of.*micro_batch_size.*or.*micro_batch_size_per_gpu\"\n      
  ):\n            CriticConfig(\n                strategy=\"fsdp2\",\n                ppo_micro_batch_size=None,\n                ppo_micro_batch_size_per_gpu=None,\n                use_dynamic_bsz=False,\n                optim=optim,\n            )\n\n    def test_micro_batch_size_divisibility_validation(self):\n        \"\"\"Test micro batch size divisibility validation in __post_init__.\"\"\"\n        optim = OptimizerConfig(lr=0.1)\n        valid_config = CriticConfig(\n            strategy=\"fsdp2\", ppo_micro_batch_size_per_gpu=2, ppo_mini_batch_size=8, use_dynamic_bsz=False, optim=optim\n        )\n        assert valid_config.ppo_mini_batch_size == 8\n        assert valid_config.ppo_micro_batch_size_per_gpu == 2\n\n        valid_config_with_mbs = CriticConfig(\n            strategy=\"fsdp2\", ppo_mini_batch_size=8, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim\n        )\n        assert valid_config_with_mbs.ppo_mini_batch_size == 8\n        assert valid_config_with_mbs.ppo_micro_batch_size == 4\n\n        with pytest.raises(ValueError, match=\"ppo_mini_batch_size.*must be divisible by.*ppo_micro_batch_size\"):\n            CriticConfig(\n                strategy=\"fsdp2\", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=False, optim=optim\n            )\n\n        dynamic_config = CriticConfig(\n            strategy=\"fsdp2\", ppo_mini_batch_size=7, ppo_micro_batch_size=4, use_dynamic_bsz=True, optim=optim\n        )\n        assert dynamic_config.use_dynamic_bsz is True\n\n    def test_fsdp_sequence_parallelism_validation(self):\n        \"\"\"Test FSDP sequence parallelism validation in FSDPCriticConfig.__post_init__.\"\"\"\n        valid_config = FSDPCriticConfig(\n            ppo_micro_batch_size_per_gpu=2,\n            ulysses_sequence_parallel_size=2,\n            model={\"use_remove_padding\": True},\n            optim=FSDPOptimizerConfig(lr=0.1),\n        )\n        assert valid_config.ulysses_sequence_parallel_size == 2\n\n        with pytest.raises(\n            ValueError, match=\"When using sequence parallelism for critic, you must enable.*use_remove_padding\"\n        ):\n            FSDPCriticConfig(\n                ppo_micro_batch_size_per_gpu=2,\n                ulysses_sequence_parallel_size=2,\n                model={\"use_remove_padding\": False},\n                optim=FSDPOptimizerConfig(lr=0.1),\n            )\n\n        valid_config_no_sp = FSDPCriticConfig(\n            ppo_micro_batch_size_per_gpu=2,\n            ulysses_sequence_parallel_size=1,\n            model={\"use_remove_padding\": False},\n            optim=FSDPOptimizerConfig(lr=0.1),\n        )\n        assert valid_config_no_sp.ulysses_sequence_parallel_size == 1\n"
  },
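  {
    "path": "verl_distillation/tests/workers/config/sketch_batch_size_validation.py",
    "content": "\"\"\"Illustrative sketch accompanying test_critic_config_on_cpu.py; not part of the original test suite.\n\nA hypothetical reimplementation of just the batch-size validation rules the tests exercise: micro_batch_size and micro_batch_size_per_gpu are mutually exclusive, at least one must be set unless dynamic batching is on, and the mini batch must divide evenly into micro batches.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\n\n@dataclass\nclass BatchSizeRulesSketch:\n    use_dynamic_bsz: bool = False\n    ppo_mini_batch_size: Optional[int] = None\n    ppo_micro_batch_size: Optional[int] = None\n    ppo_micro_batch_size_per_gpu: Optional[int] = None\n\n    def __post_init__(self):\n        if self.use_dynamic_bsz:\n            return  # dynamic batching skips all static checks\n        if self.ppo_micro_batch_size and self.ppo_micro_batch_size_per_gpu:\n            raise ValueError(\"You have set both ppo_micro_batch_size AND ppo_micro_batch_size_per_gpu\")\n        if not (self.ppo_micro_batch_size or self.ppo_micro_batch_size_per_gpu):\n            raise ValueError(\"Please set at least one of ppo_micro_batch_size or ppo_micro_batch_size_per_gpu\")\n        if self.ppo_mini_batch_size and self.ppo_micro_batch_size:\n            if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:\n                raise ValueError(\"ppo_mini_batch_size must be divisible by ppo_micro_batch_size\")\n\n\nif __name__ == \"__main__\":\n    for kwargs in (\n        {\"ppo_micro_batch_size\": 4, \"ppo_micro_batch_size_per_gpu\": 2},\n        {},\n        {\"ppo_mini_batch_size\": 7, \"ppo_micro_batch_size\": 4},\n    ):\n        try:\n            BatchSizeRulesSketch(**kwargs)\n        except ValueError as e:\n            print(e)\n"
  },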
  {
    "path": "verl_distillation/tests/workers/config/test_engine_config_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.workers.config.engine import FSDPEngineConfig, McoreEngineConfig\n\n\nclass TestMcoreEngineConfig:\n    def test_default_values(self):\n        config = McoreEngineConfig()\n        assert config.tensor_model_parallel_size == 1\n        assert config.sequence_parallel is False  # Should be auto-corrected\n        assert config.seed == 42\n\n    def test_post_init_validation(self):\n        # Test TP size 1 forces sequence_parallel=False\n        config = McoreEngineConfig(tensor_model_parallel_size=1)\n        assert config.sequence_parallel is False\n\n        # Test TP >1 keeps sequence_parallel=True\n        config = McoreEngineConfig(tensor_model_parallel_size=2)\n        assert config.sequence_parallel is True\n\n    def test_mutable_fields(self):\n        config = McoreEngineConfig()\n        config.sequence_parallel = True  # Should be mutable\n        with pytest.raises(AttributeError):\n            config.tensor_model_parallel_size = 2  # Frozen field\n\n    @pytest.mark.parametrize(\"offload_field\", [\"param_offload\", \"grad_offload\", \"optimizer_offload\"])\n    def test_offload_flags(self, offload_field):\n        config = McoreEngineConfig(**{offload_field: True})\n        assert getattr(config, offload_field) is True\n\n\nclass TestFSDPEngineConfigCPU:\n    def test_default_values(self):\n        config = FSDPEngineConfig()\n        assert config.param_offload is False\n        assert config.optimizer_offload is False\n        assert config.fsdp_size == -1\n\n    @pytest.mark.parametrize(\n        \"offload_params\",\n        [{\"param_offload\": True}, {\"optimizer_offload\": True}, {\"param_offload\": True, \"optimizer_offload\": True}],\n    )\n    def test_offload_combinations(self, offload_params):\n        config = FSDPEngineConfig(**offload_params)\n        assert config.param_offload == offload_params.get(\"param_offload\", False)\n        assert config.optimizer_offload == offload_params.get(\"optimizer_offload\", False)\n\n    def test_wrap_policy_configuration(self):\n        test_policy = {\"layer_class\": \"TransformerBlock\"}\n        config = FSDPEngineConfig(wrap_policy=test_policy)\n        assert config.wrap_policy == test_policy\n"
  },
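  {
    "path": "verl_distillation/tests/workers/config/sketch_tp_sequence_parallel.py",
    "content": "\"\"\"Illustrative sketch accompanying test_engine_config_on_cpu.py; not part of the original test suite.\n\nShows the auto-correction TestMcoreEngineConfig checks: Megatron-style sequence parallelism splits activations across the tensor-parallel group, so with a TP group of size 1 there is nothing to split and the flag is forced off.\n\"\"\"\n\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass TpSpSketch:\n    tensor_model_parallel_size: int = 1\n    sequence_parallel: bool = True\n\n    def __post_init__(self):\n        # Sequence parallelism is meaningless at TP=1; switch it off.\n        if self.tensor_model_parallel_size == 1:\n            self.sequence_parallel = False\n\n\nif __name__ == \"__main__\":\n    print(TpSpSketch(tensor_model_parallel_size=1).sequence_parallel)  # False (auto-corrected)\n    print(TpSpSketch(tensor_model_parallel_size=2).sequence_parallel)  # True\n"
  },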
  {
    "path": "verl_distillation/tests/workers/config/test_optim_config_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.workers.config.optimizer import FSDPOptimizerConfig\n\n\nclass TestFSDPOptimizerConfigCPU:\n    def test_default_configuration(self):\n        config = FSDPOptimizerConfig(lr=0.1)\n        assert config.min_lr_ratio is None\n        assert config.lr_scheduler_type == \"constant\"\n        assert config.num_cycles == 0.5\n\n    @pytest.mark.parametrize(\"lr_scheduler_type\", [\"constant\", \"cosine\"])\n    def test_valid_lr_scheduler_types(self, lr_scheduler_type):\n        config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1)\n        assert config.lr_scheduler_type == lr_scheduler_type\n\n    @pytest.mark.parametrize(\"warmup_style\", [\"constant\", \"cosine\"])\n    def test_valid_warmup_style_types(self, warmup_style):\n        config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)\n        assert config.lr_scheduler_type == warmup_style\n\n    def test_invalid_lr_scheduler_type(self):\n        with pytest.raises((ValueError, AssertionError)):\n            FSDPOptimizerConfig(lr_scheduler_type=\"invalid_style\", lr=0.1)\n\n    def test_invalid_warmup_style_type(self):\n        with pytest.raises((ValueError, AssertionError)):\n            FSDPOptimizerConfig(warmup_style=\"invalid_style\", lr=0.1)\n\n    @pytest.mark.parametrize(\"num_cycles\", [0.1, 1.0, 2.5])\n    def test_num_cycles_configuration(self, num_cycles):\n        config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1)\n        assert config.num_cycles == num_cycles\n"
  },
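  {
    "path": "verl_distillation/tests/workers/config/sketch_cosine_lr_schedule.py",
    "content": "\"\"\"Illustrative sketch accompanying test_optim_config_on_cpu.py; not part of the original test suite.\n\nShows what the num_cycles knob means for the \"cosine\" lr_scheduler_type, assuming the conventional formula used by transformers' get_cosine_schedule_with_warmup: the default num_cycles=0.5 traces half a cosine period, decaying the multiplier from 1.0 at the end of warmup to ~0.0 at the final step. FSDPOptimizerConfig's actual scheduler may differ.\n\"\"\"\n\nimport math\n\n\ndef cosine_lr_multiplier(step: int, warmup_steps: int, total_steps: int, num_cycles: float = 0.5) -> float:\n    if step < warmup_steps:\n        return step / max(1, warmup_steps)  # linear warmup\n    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))\n\n\nif __name__ == \"__main__\":\n    # Multiplier at start, end of warmup, midpoint, and final step.\n    for s in (0, 10, 55, 100):\n        print(s, round(cosine_lr_multiplier(s, warmup_steps=10, total_steps=100), 3))\n"
  },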
  {
    "path": "verl_distillation/tests/workers/critic/test_special_dp_critic.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport tempfile\nimport unittest\nfrom unittest.mock import Mock, patch\n\nimport torch\nimport torch.distributed\nfrom omegaconf import OmegaConf\nfrom tensordict import TensorDict\nfrom transformers import AutoConfig\n\nfrom verl import DataProto\nfrom verl.workers.config import FSDPCriticConfig, FSDPOptimizerConfig\nfrom verl.workers.config.critic import FSDPCriticModelCfg\nfrom verl.workers.config.engine import FSDPEngineConfig\nfrom verl.workers.fsdp_workers import CriticWorker\n\n\nclass TestCriticWorker(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        \"\"\"Set up distributed environment\"\"\"\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=\"nccl\" if torch.cuda.is_available() else \"gloo\", init_method=\"env://\"\n            )\n\n        cls.rank = torch.distributed.get_rank()\n        cls.world_size = torch.distributed.get_world_size()\n\n        if torch.cuda.is_available():\n            torch.cuda.set_device(cls.rank)\n            cls.device = torch.device(f\"cuda:{cls.rank}\")\n        else:\n            cls.device = torch.device(\"cpu\")\n\n    @classmethod\n    def tearDownClass(cls):\n        \"\"\"Clean up distributed environment\"\"\"\n        if torch.distributed.is_initialized():\n            torch.distributed.destroy_process_group()\n\n    def setUp(self):\n        \"\"\"Set up test fixtures\"\"\"\n\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.temp_dir = tempfile.mkdtemp()\n\n        config = AutoConfig.from_pretrained(\"Qwen/Qwen2.5-0.5B-Instruct\")\n        config.save_pretrained(self.temp_dir)\n\n        self.config = FSDPCriticConfig(\n            strategy=\"fsdp2\",\n            ppo_mini_batch_size=4,\n            ppo_micro_batch_size_per_gpu=2,\n            forward_micro_batch_size_per_gpu=2,\n            ppo_epochs=1,\n            cliprange_value=0.5,\n            grad_clip=1.0,\n            use_dynamic_bsz=False,\n            ulysses_sequence_parallel_size=1,\n            rollout_n=1,\n            optim=FSDPOptimizerConfig(lr=1e-6),\n            model=FSDPCriticModelCfg(\n                path=\"Qwen/Qwen2.5-0.5B-Instruct\",\n                tokenizer_path=\"Qwen/Qwen2.5-0.5B-Instruct\",\n                fsdp_config=FSDPEngineConfig(fsdp_size=-1),\n                use_remove_padding=False,\n            ),\n        )\n        assert self.world_size <= 4 // 2\n\n    def tearDown(self):\n        \"\"\"Clean up test fixtures\"\"\"\n        import shutil\n\n        shutil.rmtree(self.temp_dir, ignore_errors=True)\n\n    def _create_test_data_for_compute_values(self, batch_size=2, seq_len=10, response_len=5):\n        \"\"\"Create test data for compute_values method\"\"\"\n        input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)\n       
 attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)\n        position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)\n        responses = torch.randint(0, 1000, (batch_size, response_len), dtype=torch.long)\n        response_mask = torch.ones(batch_size, response_len, dtype=torch.float)\n\n        batch = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"responses\": responses,\n                \"response_mask\": response_mask,\n            },\n            batch_size=[batch_size],\n        )\n\n        data = DataProto(\n            batch=batch, meta_info={\"micro_batch_size\": 2, \"max_token_len\": seq_len, \"use_dynamic_bsz\": False}\n        )\n\n        return data\n\n    def _create_test_data_for_update_critic(self, batch_size=2, seq_len=10, response_len=5):\n        \"\"\"Create test data for update_critic method\"\"\"\n        input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)\n        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)\n        position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)\n        responses = torch.randint(0, 1000, (batch_size, response_len), dtype=torch.long)\n        response_mask = torch.ones(batch_size, response_len, dtype=torch.float)\n        values = torch.randn(batch_size, response_len, dtype=torch.float)\n        returns = torch.randn(batch_size, response_len, dtype=torch.float)\n\n        batch = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"responses\": responses,\n                \"response_mask\": response_mask,\n                \"values\": values,\n                \"returns\": returns,\n            },\n            batch_size=[batch_size],\n        )\n\n        data = DataProto(\n            batch=batch,\n            meta_info={\"global_token_num\": [response_len] * batch_size, \"batch_seqlens\": [response_len] * batch_size},\n        )\n\n        return data\n\n    def test_init_model(self):\n        \"\"\"Test CriticWorker.init_model() method\"\"\"\n        worker = CriticWorker(self.config)\n        worker.init_model()\n\n        self.assertIsNotNone(worker.critic_module)\n        self.assertIsNotNone(worker.critic_optimizer)\n        self.assertIsNotNone(worker.critic)\n        self.assertIsNotNone(worker.checkpoint_manager)\n\n    def test_compute_values(self):\n        \"\"\"Test CriticWorker.compute_values() method\"\"\"\n        worker = CriticWorker(self.config)\n        worker.init_model()\n\n        data = self._create_test_data_for_compute_values()\n\n        result = worker.compute_values(data)\n\n        self.assertIsInstance(result, DataProto)\n        self.assertIn(\"values\", result.batch)\n        values = result.batch[\"values\"]\n\n        batch_size, response_len = 2, 5\n        self.assertEqual(values.shape, (batch_size, response_len))\n\n        self.assertTrue(torch.isfinite(values).all())\n\n    def test_update_critic(self):\n        \"\"\"Test CriticWorker.update_critic() method\"\"\"\n        worker = CriticWorker(self.config)\n        worker.init_model()\n\n        data = self._create_test_data_for_update_critic()\n\n        result = worker.update_critic(data)\n\n        self.assertIsInstance(result, DataProto)\n        
self.assertIn(\"metrics\", result.meta_info)\n        metrics = result.meta_info[\"metrics\"]\n\n        expected_keys = [\"critic/vf_loss\", \"critic/vf_clipfrac\", \"critic/vpred_mean\", \"critic/grad_norm\"]\n        for key in expected_keys:\n            self.assertIn(key, metrics)\n\n        for key, value in metrics.items():\n            if isinstance(value, list | tuple):\n                for v in value:\n                    self.assertTrue(torch.isfinite(torch.tensor(v)).all())\n            else:\n                self.assertTrue(torch.isfinite(torch.tensor(value)).all())\n\n    @patch(\"transformers.AutoConfig.from_pretrained\")\n    def test_critic_attn_implementation_override_functionality(self, mock_config_from_pretrained):\n        \"\"\"Test that CriticWorker correctly uses attn_implementation from override_config\"\"\"\n\n        # Mock the AutoConfig return value\n        mock_config = Mock()\n        mock_config.tie_word_embeddings = False\n        mock_config.architectures = [\"LlamaForCausalLM\"]\n        mock_config.num_labels = 1\n        mock_config_from_pretrained.return_value = mock_config\n\n        # Test different attn_implementation values\n        test_cases = [\n            (\"eager\", \"eager\"),\n            (\"sdpa\", \"sdpa\"),\n            (\"flash_attention_2\", \"flash_attention_2\"),\n            (None, \"flash_attention_2\"),  # Default case\n        ]\n\n        for override_value, expected_value in test_cases:\n            mock_config_from_pretrained.reset_mock()\n\n            # Create config with override_config\n            config_dict = {\n                \"model\": {\n                    \"path\": \"/test/model/path\",\n                    \"tokenizer_path\": \"/test/tokenizer/path\",\n                    \"fsdp_config\": {\n                        \"fsdp_size\": 1,\n                        \"param_offload\": False,\n                        \"optimizer_offload\": False,\n                    },\n                },\n                \"optim\": {\"lr\": 1e-4, \"type\": \"AdamW\"},\n                \"strategy\": \"fsdp\",\n                \"ppo_mini_batch_size\": 1,\n                \"ppo_epochs\": 1,\n                \"rollout_n\": 1,\n                \"checkpoint\": {\"save_contents\": [], \"load_contents\": []},\n            }\n\n            # Add override_config with attn_implementation if specified\n            if override_value is not None:\n                config_dict[\"model\"][\"override_config\"] = {\"attn_implementation\": override_value}\n\n            # Convert to OmegaConf\n            test_config = OmegaConf.create(config_dict)\n\n            # Test the extraction logic that should happen in CriticWorker._build_critic_model_optimizer\n            override_config = OmegaConf.to_container(OmegaConf.create(test_config.model.get(\"override_config\", {})))\n            extracted_attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n            # Verify the extraction works correctly\n            self.assertEqual(\n                extracted_attn_implementation,\n                expected_value,\n                f\"Expected {expected_value}, got {extracted_attn_implementation} for override_value {override_value}\",\n            )\n\n    def test_critic_model_config_structure(self):\n        \"\"\"Test that critic model config properly incorporates override settings\"\"\"\n\n        # Test configuration scenarios\n        test_scenarios = [\n            {\"name\": \"default_flash_attention\", 
\"override_config\": {}, \"expected_attn\": \"flash_attention_2\"},\n            {\"name\": \"eager_override\", \"override_config\": {\"attn_implementation\": \"eager\"}, \"expected_attn\": \"eager\"},\n            {\"name\": \"sdpa_override\", \"override_config\": {\"attn_implementation\": \"sdpa\"}, \"expected_attn\": \"sdpa\"},\n            {\n                \"name\": \"mixed_config\",\n                \"override_config\": {\"attn_implementation\": \"eager\", \"dropout\": 0.1, \"num_labels\": 1},\n                \"expected_attn\": \"eager\",\n            },\n        ]\n\n        for scenario in test_scenarios:\n            with self.subTest(scenario=scenario[\"name\"]):\n                # Simulate the config processing logic from CriticWorker\n                override_config = scenario[\"override_config\"]\n\n                # Test the extraction logic\n                extracted_attn = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n                # Verify correct extraction\n                self.assertEqual(extracted_attn, scenario[\"expected_attn\"], f\"Failed for scenario {scenario['name']}\")\n\n                # Verify other configs are preserved\n                if \"dropout\" in override_config:\n                    self.assertEqual(override_config[\"dropout\"], 0.1)\n\n    def test_critic_hydra_config_compatibility(self):\n        \"\"\"Test that Hydra +prefix configurations work correctly for CriticWorker\"\"\"\n\n        # Simulate Hydra configuration with +prefix for critic\n        # This would come from: +critic.model.override_config.attn_implementation=eager\n        hydra_config_dict = {\n            \"critic\": {\"model\": {\"path\": \"/test/model/path\", \"override_config\": {\"attn_implementation\": \"eager\"}}}\n        }\n\n        omegaconf = OmegaConf.create(hydra_config_dict)\n\n        # Extract override config as would be done in CriticWorker\n        override_model_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {}))\n        )\n\n        # Test extraction\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        self.assertEqual(attn_implementation, \"eager\")\n\n    def test_critic_backward_compatibility(self):\n        \"\"\"Test that CriticWorker maintains backward compatibility with existing configurations\"\"\"\n\n        # Test cases for backward compatibility\n        compatibility_tests = [\n            {\"name\": \"no_override_config\", \"config\": {}, \"expected\": \"flash_attention_2\"},\n            {\"name\": \"empty_override_config\", \"config\": {\"override_config\": {}}, \"expected\": \"flash_attention_2\"},\n            {\n                \"name\": \"other_overrides_only\",\n                \"config\": {\"override_config\": {\"dropout\": 0.1, \"hidden_size\": 768}},\n                \"expected\": \"flash_attention_2\",\n            },\n        ]\n\n        for test in compatibility_tests:\n            with self.subTest(test=test[\"name\"]):\n                override_config = test[\"config\"].get(\"override_config\", {})\n                attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n                self.assertEqual(\n                    attn_implementation, test[\"expected\"], f\"Backward compatibility failed for {test['name']}\"\n                )\n\n    def test_critic_and_actor_independent_configuration(self):\n        \"\"\"Test that critic and actor 
can have independent attention implementation configurations\"\"\"\n\n        # Simulate a complete training configuration with both actor and critic\n        complete_config = {\n            \"actor_rollout_ref\": {\"model\": {\"override_config\": {\"attn_implementation\": \"eager\"}}},\n            \"critic\": {\"model\": {\"override_config\": {\"attn_implementation\": \"sdpa\"}}},\n        }\n\n        omegaconf = OmegaConf.create(complete_config)\n\n        # Extract actor config\n        actor_override = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.actor_rollout_ref.model.get(\"override_config\", {}))\n        )\n        actor_attn = actor_override.get(\"attn_implementation\", \"flash_attention_2\")\n\n        # Extract critic config\n        critic_override = OmegaConf.to_container(OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {})))\n        critic_attn = critic_override.get(\"attn_implementation\", \"flash_attention_2\")\n\n        # Verify independent configuration\n        self.assertEqual(actor_attn, \"eager\")\n        self.assertEqual(critic_attn, \"sdpa\")\n        self.assertNotEqual(actor_attn, critic_attn)  # Ensure they are indeed different\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_distillation/tests/workers/reward_manager/test_registry_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\n# Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module\nfrom verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register\n\n\n@pytest.fixture\ndef setup():\n    \"\"\"Setup test cases with a mock registry.\"\"\"\n    REWARD_MANAGER_REGISTRY.clear()\n    REWARD_MANAGER_REGISTRY.update({\"manager1\": \"Manager1Class\", \"manager2\": \"Manager2Class\"})\n    return REWARD_MANAGER_REGISTRY\n\n\ndef test_get_existing_manager(setup):\n    \"\"\"Test getting an existing reward manager class.\"\"\"\n    assert get_reward_manager_cls(\"manager1\") == \"Manager1Class\"\n    assert get_reward_manager_cls(\"manager2\") == \"Manager2Class\"\n\n\ndef test_get_nonexistent_manager(setup):\n    \"\"\"Test getting a non-existent reward manager raises ValueError.\"\"\"\n    with pytest.raises(ValueError) as excinfo:\n        get_reward_manager_cls(\"unknown_manager\")\n    assert \"Unknown reward manager: unknown_manager\" in str(excinfo.value)\n\n\ndef test_case_sensitivity(setup):\n    \"\"\"Test that manager names are case-sensitive.\"\"\"\n    with pytest.raises(ValueError):\n        get_reward_manager_cls(\"MANAGER1\")\n    with pytest.raises(ValueError):\n        get_reward_manager_cls(\"Manager1\")\n\n\ndef test_empty_registry(setup):\n    \"\"\"Test behavior when registry is empty.\"\"\"\n    REWARD_MANAGER_REGISTRY.clear()\n    with pytest.raises(ValueError) as excinfo:\n        get_reward_manager_cls(\"any_manager\")\n    assert \"Unknown reward manager: any_manager\" in str(excinfo.value)\n\n\ndef test_register_new_class(setup):\n    \"\"\"Test registering a new class with the decorator.\"\"\"\n\n    @register(\"test_manager\")\n    class TestManager:\n        pass\n\n    assert \"test_manager\" in REWARD_MANAGER_REGISTRY\n    assert REWARD_MANAGER_REGISTRY[\"test_manager\"] == TestManager\n\n\ndef test_register_different_classes_same_name(setup):\n    \"\"\"Test that registering different classes with same name raises ValueError.\"\"\"\n\n    @register(\"conflict_manager\")\n    class Manager1:\n        pass\n\n    with pytest.raises(ValueError):\n\n        @register(\"conflict_manager\")\n        class Manager2:\n            pass\n\n    assert REWARD_MANAGER_REGISTRY[\"conflict_manager\"] == Manager1\n\n\ndef test_decorator_returns_original_class(setup):\n    \"\"\"Test that the decorator returns the original class unchanged.\"\"\"\n\n    @register(\"return_test\")\n    class OriginalClass:\n        def method(setup):\n            return 42\n\n    assert OriginalClass().method() == 42\n    assert REWARD_MANAGER_REGISTRY[\"return_test\"] == OriginalClass\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/perf/vllm_async_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCompare vLLM AsyncLLM backend: ExternalRayDistributedExecutor(remote call) vs RayDistributedExecutor(compiled graph)\n\n1. Prepare openai/gsm8k dataset\npython3 examples/data_preprocess/gsm8k.py\n\n2. Run perf test\npython3 tests/workers/rollout/perf/vllm_async_rollout.py >perf.log 2>&1\n\nhardware: Nvidia 8*H20\npackages:\n- torch==2.6.0\n- vllm==0.8.5\n\n[DEBUG] backend: sync, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 21.27 secs\n[DEBUG] backend: zeromq, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 23.40 secs\n[DEBUG] backend: ray, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 25.33 secs\n\"\"\"\n\nimport os\nimport time\n\nimport ray\nfrom omegaconf import DictConfig\nfrom torch.utils.data import SequentialSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\n\nfrom tests.experimental.agent_loop.agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager\nfrom verl.protocol import DataProto\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n\ndef init_config(n_gpus_per_node) -> DictConfig:\n    import os\n\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(\n            config_name=\"ppo_trainer\",\n            overrides=[\n                \"actor_rollout_ref.actor.use_dynamic_bsz=true\",\n                \"actor_rollout_ref.actor.fsdp_config.param_offload=True\",\n                \"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\",\n            ],\n        )\n    config.trainer.n_gpus_per_node = n_gpus_per_node\n    config.data.train_batch_size = 128\n    config.data.return_raw_chat = True\n    config.actor_rollout_ref.model.path = \"Qwen/Qwen2.5-7B-Instruct\"\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2\n    config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9\n    config.actor_rollout_ref.rollout.multi_turn.format = \"hermes\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 16\n\n    return config\n\n\ndef initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]:\n    env_vars = {\n        \"NCCL_DEBUG\": \"WARN\",\n        \"VLLM_USE_V1\": \"1\",\n        \"VERL_VLLM_DISTRIBUTED_BACKEND\": backend,\n    }\n    ray.init(runtime_env={\"env_vars\": env_vars})\n\n    # STEP 1: init async llm server\n    server = init_agent_loop_manager(config)\n\n    # STEP 2: create dataloader\n    tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path)\n    dataset = RLHFDataset(\n        
data_files=os.path.expanduser(\"~/data/gsm8k/train.parquet\"),\n        tokenizer=tokenizer,\n        config=config.data,\n    )\n    dataloader = StatefulDataLoader(\n        dataset=dataset,\n        batch_size=config.data.get(\"gen_batch_size\", config.data.train_batch_size),\n        num_workers=config.data.get(\"dataloader_num_workers\", 8),\n        drop_last=True,\n        collate_fn=default_collate_fn,\n        sampler=SequentialSampler(dataset),\n    )\n\n    return server, dataloader\n\n\ndef perf_rollout(mode, backend, n_gpus_per_node, num_steps):\n    config = init_config(n_gpus_per_node)\n    config.actor_rollout_ref.rollout.mode = mode\n    agent_loop_manager, dataloader = initialize(config, backend)\n\n    for step, batch in enumerate(dataloader):\n        batch: DataProto = DataProto.from_single_dict(batch)\n        batch = batch.pop(\n            batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n            non_tensor_batch_keys=[\"raw_prompt_ids\", \"raw_prompt\"],\n        )\n        t_start = time.time()\n        gen_batch = agent_loop_manager.generate_sequences(batch)\n        t_end = time.time()\n        print(\n            f\"[DEBUG] backend: {backend}, n_gpus_per_node: {n_gpus_per_node}, batch_size: {len(gen_batch)}, \"\n            f\"step: {step}, step_time: {t_end - t_start:.2f} secs\"\n        )\n        if step + 1 >= num_steps:\n            break\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    num_steps = 1\n    n_gpus_per_node = 8\n\n    # test_cases = [(\"sync\", \"sync\"), (\"async\", \"zeromq\"), (\"async\", \"ray\")]\n    test_cases = [(\"async\", \"zeromq\"), (\"async\", \"ray\")]\n    for mode, backend in test_cases:\n        perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps)\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/resource/tool_configs/mcp_server.json",
    "content": "{\n    \"mcpServers\": {\n        \"Tavily Expert\": {\n            \"url\": \"https://tavily.api.tadata.com/mcp/tavily/your_expert\",\n            \"auth_token\": \"your_tavily_token\"\n        }\n    }\n}"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/resource/tool_configs/mcp_tool_config",
    "content": "tools:\n  - class_name: verl.tools.mcp_search_tool.MCPSearchTool\n    config:\n      rate_limit: 120\n      timeout: 120\n      type: mcp\n    mcp:\n      mcp_servers_config_path: ./resource/tool_configs/mcp_server.json\n      # optional\n      tool_selected_list: \n        - tavily_search_tool"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config",
    "content": "tools:\n  - class_name: \"verl.tools.sandbox_fusion_tools.SandboxFusionTool\"\n    config: \n      sandbox_fusion_url: \"https://xxx.apigateway-cn-beijing.volceapi.com/run_code\"\n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/resource/tool_configs/search_tool_config",
    "content": "tools:\n  - class_name: verl.tools.search_tool.SearchTool\n    config:\n      retrieval_service_url: http://127.0.0.1:8000/retrieve\n      num_workers: 120\n      rate_limit: 120\n      timeout: 30\n      type: native\n    tool_schema:\n      type: function\n      function:\n        name: search\n        description: Searches the web for relevant information based on the given query.\n        parameters:\n          type: object\n          properties:\n            query_list:\n              type: array\n              item:\n                type: string\n              description: A list of fully-formed semantic queries. The tool will return search results for each query.\n          required: \n            - query_list"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/rollout_sglang/test_http_server_engine.py",
    "content": "# Copyright 2025 z.ai\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# This file is adapted from multiple sources:\n# 1. THUDM/slime project\n#    Original source: https://github.com/THUDM/slime/blob/main/slime/backends/sglang_utils/http_server_engine.py\n#    Copyright 2025 z.ai\n#    Licensed under the Apache License, Version 2.0\n# 2. SGLang project\n#    Original source: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server_engine.py\n#    Copyright 2023-2024 SGLang Team\n#    Licensed under the Apache License, Version 2.0\n#\n# Modifications made by z.ai and ModelBest Inc. include but are not limited to:\n# - Enhanced error handling and retry logic\n# - Added async support with connection pooling\n# - Extended functionality for distributed weight updates\n# - Improved logging and monitoring capabilities\n# - Additional configuration options and optimizations\n\n\"\"\"Complete unit tests for HTTP Server Engine Adapters.\n\nThis module contains comprehensive unit tests for both HttpServerEngineAdapter\nand AsyncHttpServerEngineAdapter classes, covering all public methods,\nerror handling scenarios, edge cases, and boundary conditions using pytest and mock frameworks.\n\nTests use real SGLang modules for integration testing while mocking external dependencies.\n\"\"\"\n\nimport asyncio\nfrom unittest.mock import AsyncMock, Mock, patch\n\nimport aiohttp\nimport pytest\nimport requests\nfrom sglang.srt.managers.io_struct import (\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.utils import MultiprocessingSerializer\n\n# Import the module under test\nfrom verl.workers.rollout.sglang_rollout.http_server_engine import (\n    AsyncHttpServerAdapter,\n    HttpServerAdapter,\n    launch_server_process,\n)\n\n\n@pytest.fixture(scope=\"session\")\ndef event_loop():\n    \"\"\"Create an event loop for the entire test session.\"\"\"\n    loop = asyncio.new_event_loop()\n    yield loop\n    loop.close()\n\n\n@pytest.fixture\ndef basic_adapter_kwargs():\n    \"\"\"Provide basic kwargs for creating HTTP server adapters.\"\"\"\n    return {\n        \"host\": \"localhost\",\n        \"port\": 8000,\n        \"node_rank\": 0,\n        \"model_path\": \"/tmp/test_model\",\n    }\n\n\n@pytest.fixture\ndef router_adapter_kwargs():\n    \"\"\"Provide kwargs for creating adapters with router configuration.\"\"\"\n    return {\n        \"router_ip\": \"192.168.1.1\",\n        \"router_port\": 8080,\n        \"host\": \"localhost\",\n        \"port\": 8000,\n        \"node_rank\": 0,\n        \"model_path\": \"/tmp/test_model\",\n    }\n\n\n@pytest.fixture\ndef non_master_adapter_kwargs():\n    \"\"\"Provide kwargs for creating non-master node adapters.\"\"\"\n    return {\n        \"host\": \"localhost\",\n        \"port\": 8000,\n        \"node_rank\": 1,  # Non-master\n        \"model_path\": \"/tmp/test_model\",\n    }\n\n\n@pytest.fixture\ndef 
mock_launch_server_process():\n    \"\"\"Mock the launch_server_process function for testing without actual server startup.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.launch_server_process\") as mock_launch:\n        mock_process = Mock()\n        mock_process.is_alive.return_value = True\n        mock_process.pid = 12345\n        mock_launch.return_value = mock_process\n        yield mock_launch\n\n\n@pytest.fixture\ndef mock_multiprocessing_process():\n    \"\"\"Create mock multiprocessing.Process for testing without actual process creation.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.multiprocessing.Process\") as mock_process_class:\n        mock_process = Mock()\n        mock_process.is_alive.return_value = True\n        mock_process.pid = 12345\n        mock_process_class.return_value = mock_process\n        yield mock_process\n\n\n@pytest.fixture\ndef mock_requests_session():\n    \"\"\"Create mock requests.Session for testing HTTP interactions.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.Session\") as mock_session_class:\n        mock_session = Mock()\n        mock_response = Mock()\n        mock_response.status_code = 200\n        mock_response.json.return_value = {\"status\": \"success\"}\n        mock_session.get.return_value = mock_response\n        mock_session.post.return_value = mock_response\n        mock_session_class.return_value.__enter__.return_value = mock_session\n        yield mock_session\n\n\n@pytest.fixture\ndef mock_requests_post():\n    \"\"\"Mock requests.post for testing HTTP POST requests.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n        mock_response = Mock()\n        mock_response.status_code = 200\n        mock_response.json.return_value = {\"status\": \"success\"}\n        mock_post.return_value = mock_response\n        yield mock_post\n\n\n@pytest.fixture\ndef mock_requests_get():\n    \"\"\"Mock requests.get for testing HTTP GET requests.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.get\") as mock_get:\n        mock_response = Mock()\n        mock_response.status_code = 200\n        mock_response.json.return_value = {\"status\": \"success\"}\n        mock_get.return_value = mock_response\n        yield mock_get\n\n\n@pytest.fixture\ndef mock_aiohttp_session():\n    \"\"\"Create mock aiohttp.ClientSession for testing async HTTP interactions.\"\"\"\n    mock_session = AsyncMock()\n    mock_session.closed = False\n\n    # Mock response\n    mock_response = AsyncMock()\n    mock_response.status = 200\n    mock_response.json = AsyncMock(return_value={\"status\": \"success\"})\n    mock_response.raise_for_status = Mock()\n\n    # Mock context managers\n    mock_session.get.return_value.__aenter__.return_value = mock_response\n    mock_session.post.return_value.__aenter__.return_value = mock_response\n\n    return mock_session\n\n\n@pytest.fixture\ndef mock_kill_process_tree():\n    \"\"\"Mock kill_process_tree function for testing cleanup without actual process termination.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.kill_process_tree\") as mock_kill:\n        yield 
mock_kill\n\n\n# Test environment fixtures for real SGLang testing\n@pytest.fixture(scope=\"session\")\ndef sglang_test_model_path():\n    \"\"\"Provide a test model path for SGLang tests.\n\n    This can be overridden by environment variable SGLANG_TEST_MODEL_PATH\n    for tests that need a real model.\n    \"\"\"\n    import os\n\n    return os.getenv(\"SGLANG_TEST_MODEL_PATH\", \"/tmp/test_model\")\n\n\n@pytest.fixture\ndef real_adapter_kwargs(sglang_test_model_path):\n    \"\"\"Provide kwargs for creating adapters with real SGLang integration.\"\"\"\n    return {\n        \"host\": \"localhost\",\n        \"port\": 8000,\n        \"node_rank\": 0,\n        \"model_path\": sglang_test_model_path,\n    }\n\n\n@pytest.fixture(autouse=True)\ndef mock_server_args_post_init():\n    \"\"\"Mock ServerArgs.__post_init__ to skip model path validation.\"\"\"\n    from unittest.mock import patch\n\n    with patch(\n        \"verl.workers.rollout.sglang_rollout.http_server_engine.ServerArgs.__post_init__\", return_value=None\n    ) as mock_post_init:\n        yield mock_post_init\n\n\nclass TestLaunchServerProcess:\n    \"\"\"Test cases for launch_server_process function.\"\"\"\n\n    def test_launch_server_process_success(\n        self, mock_multiprocessing_process, mock_requests_session, real_adapter_kwargs\n    ):\n        \"\"\"Test successful server process launch and health check.\"\"\"\n        # Import real SGLang ServerArgs\n        from sglang.srt.server_args import ServerArgs\n\n        # Create server args using real ServerArgs\n        server_args = ServerArgs(**real_adapter_kwargs)\n\n        # Test\n        with patch(\n            \"verl.workers.rollout.sglang_rollout.http_server_engine.multiprocessing.Process\"\n        ) as mock_process_class:\n            mock_process_class.return_value = mock_multiprocessing_process\n            with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.Session\") as mock_session_class:\n                mock_session_class.return_value.__enter__.return_value = mock_requests_session\n\n                result = launch_server_process(server_args, first_rank_in_node=True)\n\n                # Assertions\n                assert result == mock_multiprocessing_process\n                mock_multiprocessing_process.start.assert_called_once()\n                assert mock_requests_session.get.call_count >= 2  # health_generate and flush_cache\n\n    def test_launch_server_process_non_master(self, mock_multiprocessing_process, non_master_adapter_kwargs):\n        \"\"\"Test server launch for non-master nodes (should return immediately).\"\"\"\n        from sglang.srt.server_args import ServerArgs\n\n        server_args = ServerArgs(**non_master_adapter_kwargs)\n\n        with patch(\n            \"verl.workers.rollout.sglang_rollout.http_server_engine.multiprocessing.Process\"\n        ) as mock_process_class:\n            mock_process_class.return_value = mock_multiprocessing_process\n            result = launch_server_process(server_args, first_rank_in_node=True)\n\n            assert result == mock_multiprocessing_process\n            mock_multiprocessing_process.start.assert_not_called()\n\n    def test_launch_server_process_timeout(self, mock_multiprocessing_process, real_adapter_kwargs):\n        \"\"\"Test timeout during server health check.\"\"\"\n        from sglang.srt.server_args import ServerArgs\n\n        server_args = ServerArgs(**real_adapter_kwargs)\n\n        with patch(\n            
\"verl.workers.rollout.sglang_rollout.http_server_engine.multiprocessing.Process\"\n        ) as mock_process_class:\n            mock_process_class.return_value = mock_multiprocessing_process\n            with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.Session\") as mock_session_class:\n                mock_session = Mock()\n                mock_session.get.side_effect = requests.RequestException(\"Connection failed\")\n                mock_session_class.return_value.__enter__.return_value = mock_session\n\n            import itertools\n\n            with patch(\n                \"verl.workers.rollout.sglang_rollout.http_server_engine.time.time\",\n                side_effect=itertools.chain([0], itertools.repeat(400)),  # 第一次返回0，之后一直返回400\n            ):\n                with pytest.raises(TimeoutError):\n                    launch_server_process(server_args, first_rank_in_node=True)\n\n                mock_multiprocessing_process.terminate.assert_called_once()\n\n    def test_launch_server_process_died(self, real_adapter_kwargs):\n        \"\"\"Test server process dies during startup.\"\"\"\n        from sglang.srt.server_args import ServerArgs\n\n        server_args = ServerArgs(**real_adapter_kwargs)\n\n        with patch(\n            \"verl.workers.rollout.sglang_rollout.http_server_engine.multiprocessing.Process\"\n        ) as mock_process_class:\n            mock_process = Mock()\n            mock_process.is_alive.return_value = False\n            mock_process_class.return_value = mock_process\n\n            with pytest.raises(RuntimeError, match=\"Server process terminated unexpectedly\"):\n                launch_server_process(server_args, first_rank_in_node=True)\n\n\nclass TestHttpServerEngineAdapter:\n    \"\"\"Test cases for HttpServerEngineAdapter class.\"\"\"\n\n    def test_init_with_router_registration(self, mock_launch_server_process, mock_requests_post, router_adapter_kwargs):\n        \"\"\"Test initialization with router registration.\"\"\"\n        adapter = HttpServerAdapter(**router_adapter_kwargs)\n\n        assert adapter.router_ip == \"192.168.1.1\"\n        assert adapter.router_port == 8080\n        assert adapter.process == mock_launch_server_process.return_value\n        mock_requests_post.assert_called_once()\n\n    def test_init_without_router(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test initialization without router registration.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        assert adapter.router_ip is None\n        assert adapter.router_port is None\n        assert adapter.process == mock_launch_server_process.return_value\n\n    def test_register_with_router_failure(self, mock_launch_server_process, router_adapter_kwargs):\n        \"\"\"Test router registration failure handling.\"\"\"\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            mock_post.side_effect = requests.RequestException(\"Connection failed\")\n\n            # Should not raise exception, just log error\n            adapter = HttpServerAdapter(**router_adapter_kwargs)\n\n            assert adapter.router_ip == \"192.168.1.1\"\n            mock_post.assert_called_once()\n\n    def test_make_request_success(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test successful HTTP request.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with 
patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            mock_response = Mock()\n            mock_response.status_code = 200\n            mock_response.json.return_value = {\"status\": \"success\"}\n            mock_post.return_value = mock_response\n\n            result = adapter._make_request(\"test_endpoint\", {\"param\": \"value\"})\n\n            assert result == {\"status\": \"success\"}\n            mock_post.assert_called_with(\n                \"http://localhost:8000/test_endpoint\",\n                json={\"param\": \"value\"},\n                timeout=adapter.timeout,\n            )\n\n    def test_make_request_get_method(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test HTTP GET request.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.get\") as mock_get:\n            mock_response = Mock()\n            mock_response.status_code = 200\n            mock_response.json.return_value = {\"data\": \"test\"}\n            mock_get.return_value = mock_response\n\n            result = adapter._make_request(\"test_endpoint\", method=\"GET\")\n\n            assert result == {\"data\": \"test\"}\n            mock_get.assert_called_with(\"http://localhost:8000/test_endpoint\", timeout=adapter.timeout)\n\n    def test_make_request_non_master(self, mock_launch_server_process):\n        \"\"\"Test request from non-master node returns empty dict.\"\"\"\n        kwargs = {\"host\": \"localhost\", \"port\": 8000, \"node_rank\": 1, \"model_path\": \"/tmp/test_model\"}\n        adapter = HttpServerAdapter(**kwargs)\n        result = adapter._make_request(\"test_endpoint\")\n\n        assert result == {}\n\n    def test_make_request_retry_logic(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test retry logic for failed requests.\"\"\"\n        adapter = HttpServerAdapter(max_attempts=3, **basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            with patch(\"time.sleep\") as mock_sleep:\n                # First two calls fail, third succeeds\n                mock_post.side_effect = [\n                    requests.exceptions.Timeout(),\n                    requests.exceptions.ConnectionError(),\n                    Mock(status_code=200, json=lambda: {\"success\": True}),\n                ]\n\n                result = adapter._make_request(\"test_endpoint\")\n\n                assert result == {\"success\": True}\n                assert mock_post.call_count == 3\n                assert mock_sleep.call_count == 2\n\n    def test_make_request_http_error(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test HTTP error handling.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            mock_response = Mock()\n            mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(\"404 Not Found\")\n            mock_post.return_value = mock_response\n\n            with pytest.raises(requests.exceptions.HTTPError):\n                adapter._make_request(\"test_endpoint\")\n\n    def test_make_request_max_attempts_exceeded(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test max retries exceeded.\"\"\"\n        adapter = 
HttpServerAdapter(max_attempts=1, **basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            with patch(\"time.sleep\"):\n                mock_post.side_effect = requests.exceptions.Timeout()\n\n                with pytest.raises(RuntimeError, match=\"Failed to complete request\"):\n                    adapter._make_request(\"test_endpoint\")\n\n                assert mock_post.call_count == 1  # only the initial attempt is made; no retries\n\n    def test_update_weights_from_tensor_strict(self, mock_launch_server_process, basic_adapter_kwargs):\n        import base64\n\n        from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput\n\n        from verl.workers.rollout.sglang_rollout.http_server_engine import HttpServerAdapter\n\n        basic_adapter_kwargs.setdefault(\"node_rank\", 0)\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"updated\"}\n\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=[b\"tensor1\", b\"tensor2\"],\n                load_format=\"safetensors\",\n                flush_cache=True,\n            )\n            result = adapter.update_weights_from_tensor(req)\n\n            assert result == {\"status\": \"updated\"}\n\n            expected_b64_1 = base64.b64encode(b\"tensor1\").decode(\"utf-8\")\n            expected_b64_2 = base64.b64encode(b\"tensor2\").decode(\"utf-8\")\n\n            mock_request.assert_called_once_with(\n                \"update_weights_from_tensor\",\n                {\n                    \"serialized_named_tensors\": [expected_b64_1, expected_b64_2],\n                    \"load_format\": \"safetensors\",\n                    \"flush_cache\": True,\n                },\n            )\n\n    def test_update_weights_from_tensor_empty(self, mock_launch_server_process, basic_adapter_kwargs):\n        from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput\n\n        from verl.workers.rollout.sglang_rollout.http_server_engine import HttpServerAdapter\n\n        basic_adapter_kwargs.setdefault(\"node_rank\", 0)\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"updated\"}\n\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=[],\n                load_format=\"safetensors\",\n                flush_cache=True,\n            )\n            result = adapter.update_weights_from_tensor(req)\n\n            assert result == {\"status\": \"updated\"}\n\n            mock_request.assert_called_once_with(\n                \"update_weights_from_tensor\",\n                {\n                    \"serialized_named_tensors\": [],\n                    \"load_format\": \"safetensors\",\n                    \"flush_cache\": True,\n                },\n            )\n\n    def test_update_weights_from_tensor_none(self, mock_launch_server_process, basic_adapter_kwargs):\n        from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput\n\n        from verl.workers.rollout.sglang_rollout.http_server_engine import HttpServerAdapter\n\n        basic_adapter_kwargs.setdefault(\"node_rank\", 0)\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") 
as mock_request:\n            mock_request.return_value = {\"status\": \"updated\"}\n\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=None,\n                load_format=\"safetensors\",\n                flush_cache=True,\n            )\n            result = adapter.update_weights_from_tensor(req)\n\n            assert result == {\"status\": \"updated\"}\n\n            mock_request.assert_called_once_with(\n                \"update_weights_from_tensor\",\n                {\n                    \"serialized_named_tensors\": [],\n                    \"load_format\": \"safetensors\",\n                    \"flush_cache\": True,\n                },\n            )\n\n    def test_generate(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test generate method.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"text\": \"Generated text\"}\n\n            result = adapter.generate(\n                prompt=\"Hello world\",\n                sampling_params={\"temperature\": 0.7},\n                return_logprob=True,\n            )\n\n            assert result == {\"text\": \"Generated text\"}\n            mock_request.assert_called_once_with(\n                \"generate\",\n                {\n                    \"text\": \"Hello world\",\n                    \"sampling_params\": {\"temperature\": 0.7},\n                    \"return_logprob\": True,\n                },\n                only_master=False,\n            )\n\n    def test_flush_cache(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test flush_cache method.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.get\") as mock_get:\n            with patch(\"time.sleep\") as mock_sleep:\n                # First call fails, second succeeds\n                mock_responses = [\n                    Mock(status_code=503),  # Service unavailable\n                    Mock(status_code=200, json=lambda: {\"cache_flushed\": True}),\n                ]\n                mock_get.side_effect = mock_responses\n\n                result = adapter.flush_cache()\n\n                assert result == {\"cache_flushed\": True}\n                assert mock_get.call_count == 2\n                mock_sleep.assert_called_once()\n\n    def test_flush_cache_non_master(self, mock_launch_server_process):\n        \"\"\"Test flush_cache for non-master node.\"\"\"\n        kwargs = {\"host\": \"localhost\", \"port\": 8000, \"node_rank\": 1, \"model_path\": \"/tmp/test_model\"}\n        adapter = HttpServerAdapter(**kwargs)\n        result = adapter.flush_cache()\n\n        assert result == {}\n\n    def test_memory_management_methods(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test memory release and resume methods.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n\n            # Test release_memory_occupation\n            result = adapter.release_memory_occupation([\"weights\", \"kv_cache\"])\n            assert result == {\"status\": \"success\"}\n            mock_request.assert_called_with(\"release_memory_occupation\", {\"tags\": [\"weights\", \"kv_cache\"]})\n\n      
      # Test resume_memory_occupation\n            result = adapter.resume_memory_occupation([\"weights\"])\n            assert result == {\"status\": \"success\"}\n            mock_request.assert_called_with(\"resume_memory_occupation\", {\"tags\": [\"weights\"]})\n\n    def test_generation_control_methods(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test generation control methods.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n\n    def test_shutdown(self, mock_launch_server_process, mock_kill_process_tree, router_adapter_kwargs):\n        \"\"\"Test shutdown method.\"\"\"\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            mock_response = Mock()\n            mock_response.status_code = 200\n            mock_post.return_value = mock_response\n\n            adapter = HttpServerAdapter(**router_adapter_kwargs)\n\n            adapter.shutdown()\n\n            # Should unregister from router\n            assert mock_post.call_count == 2  # Once for registration, once for unregistration\n            # Should kill process\n            mock_kill_process_tree.assert_called_once_with(mock_launch_server_process.return_value.pid)\n\n    def test_shutdown_with_errors(self, mock_launch_server_process, mock_kill_process_tree, router_adapter_kwargs):\n        \"\"\"Test shutdown method with errors.\"\"\"\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            # Mock registration success but unregistration failure\n            mock_post.side_effect = [\n                Mock(status_code=200),  # Registration success\n                requests.RequestException(\"Unregistration failed\"),  # Unregistration failure\n            ]\n\n            # Mock process kill failure\n            mock_kill_process_tree.side_effect = Exception(\"Kill failed\")\n\n            adapter = HttpServerAdapter(**router_adapter_kwargs)\n\n            # Should not raise exceptions\n            adapter.shutdown()\n\n            assert mock_post.call_count == 2\n            mock_kill_process_tree.assert_called_once_with(mock_launch_server_process.return_value.pid)\n\n    # Edge cases for HttpServerEngineAdapter\n    def test_empty_and_none_parameters(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test handling of empty and None parameters.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=None,\n                load_format=None,\n                flush_cache=None,\n            )\n\n            # Test generate with all None parameters\n            result = adapter.generate()\n            assert result == {\"status\": \"success\"}\n\n            # Test with a None tensor list (normalized to empty)\n            result = adapter.update_weights_from_tensor(req)\n            assert result == {\"status\": \"success\"}\n\n            # Test with empty tags\n            result = adapter.release_memory_occupation([])\n            assert result == {\"status\": \"success\"}\n\n    def test_large_payload_handling(self, mock_launch_server_process, basic_adapter_kwargs):\n        
\"\"\"Test handling of large payloads.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n\n            # Test with large tensor list\n            large_tensor_list = [MultiprocessingSerializer.serialize(f\"tensor_{i}\") for i in range(1000)]\n\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=large_tensor_list,\n                load_format=\"safetensors\",\n                flush_cache=True,\n            )\n            result = adapter.update_weights_from_tensor(req)\n            assert result == {\"status\": \"success\"}\n\n            # Test with large prompt\n            large_prompt = \"A\" * 10000\n            result = adapter.generate(prompt=large_prompt)\n            assert result == {\"status\": \"success\"}\n\n    def test_timeout_edge_cases(self, mock_launch_server_process):\n        \"\"\"Test various timeout scenarios.\"\"\"\n        # Test with very small timeout\n        kwargs = {\"host\": \"localhost\", \"port\": 8000, \"node_rank\": 0, \"model_path\": \"/tmp/test_model\", \"timeout\": 0.001}\n        adapter = HttpServerAdapter(**kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            mock_post.side_effect = requests.exceptions.Timeout()\n\n            with pytest.raises(RuntimeError, match=\"Failed to complete request\"):\n                adapter._make_request(\"test_endpoint\")\n\n    def test_extreme_configuration_values(self, mock_launch_server_process):\n        \"\"\"Test extreme configuration values.\"\"\"\n        # Test with extreme values\n        kwargs = {\n            \"host\": \"localhost\",\n            \"port\": 8000,\n            \"node_rank\": 0,\n            \"model_path\": \"/tmp/test_model\",\n            \"timeout\": 0.001,  # Very small\n            \"max_attempts\": 100,  # Very large\n            \"retry_delay\": 0.001,  # Very small\n        }\n        adapter = HttpServerAdapter(**kwargs)\n\n        assert adapter.timeout == 0.001\n        assert adapter.max_attempts == 100\n        assert adapter.retry_delay == 0.001\n\n\nclass TestAsyncHttpServerEngineAdapter:\n    \"\"\"Test cases for AsyncHttpServerEngineAdapter class.\"\"\"\n\n    def test_init(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test async adapter initialization.\"\"\"\n        adapter = AsyncHttpServerAdapter(max_connections=50, **basic_adapter_kwargs)\n\n        assert adapter.max_connections == 50\n\n    @pytest.mark.asyncio\n    async def test_make_async_request_success(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test successful async HTTP request.\"\"\"\n\n        # Instantiate adapter\n        adapter = AsyncHttpServerAdapter(**basic_adapter_kwargs)\n\n        mock_response = AsyncMock()\n        mock_response.status = 200\n        mock_response.json = AsyncMock(return_value={\"status\": \"success\"})\n        mock_response.raise_for_status = Mock()\n\n        mock_post_context_manager = AsyncMock()\n        mock_post_context_manager.__aenter__.return_value = mock_response\n\n        mock_session = AsyncMock(spec=aiohttp.ClientSession)\n        mock_session.closed = False\n        mock_session.post.return_value = mock_post_context_manager\n\n        mock_session_cm = AsyncMock()\n        mock_session_cm.__aenter__.return_value = mock_session\n\n      
  with patch.object(adapter, \"_get_session\", return_value=mock_session_cm):\n            result = await adapter._make_async_request(\"test_endpoint\", {\"param\": \"value\"})\n\n            # Assert result is correct\n            assert result == {\"status\": \"success\"}\n\n            # Verify post was called\n            mock_session.post.assert_called_once_with(\n                \"http://localhost:8000/test_endpoint\", json={\"param\": \"value\"}, timeout=adapter.timeout\n            )\n\n    @pytest.mark.asyncio\n    async def test_make_async_request_get_method(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test async GET request using aiohttp and proper context mocking.\"\"\"\n\n        # Instantiate the async adapter\n        adapter = AsyncHttpServerAdapter(**basic_adapter_kwargs)\n\n        mock_response = AsyncMock()\n        mock_response.status = 200\n        mock_response.json = AsyncMock(return_value={\"data\": \"test\"})\n        mock_response.raise_for_status = Mock()\n\n        mock_get_context_manager = AsyncMock()\n        mock_get_context_manager.__aenter__.return_value = mock_response\n\n        mock_session = AsyncMock(spec=aiohttp.ClientSession)\n        mock_session.closed = False\n        mock_session.get.return_value = mock_get_context_manager\n\n        mock_session_cm = AsyncMock()\n        mock_session_cm.__aenter__.return_value = mock_session\n\n        with patch.object(adapter, \"_get_session\", return_value=mock_session_cm):\n            result = await adapter._make_async_request(\"test_endpoint\", method=\"GET\")\n\n            # Validate\n            assert result == {\"data\": \"test\"}\n            mock_session.get.assert_called_once_with(\"http://localhost:8000/test_endpoint\", timeout=adapter.timeout)\n\n    @pytest.mark.asyncio\n    async def test_make_async_request_non_master(self, mock_launch_server_process):\n        \"\"\"Test async request from non-master node.\"\"\"\n        kwargs = {\"host\": \"localhost\", \"port\": 8000, \"node_rank\": 1, \"model_path\": \"/tmp/test_model\"}\n        adapter = AsyncHttpServerAdapter(**kwargs)\n        result = await adapter._make_async_request(\"test_endpoint\")\n\n        assert result == {}\n\n    @pytest.mark.asyncio\n    async def test_async_generate(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test async generate method.\"\"\"\n        adapter = AsyncHttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_async_request\", new_callable=AsyncMock) as mock_request:\n            mock_request.return_value = {\"text\": \"Generated text\"}\n\n            result = await adapter.generate(\n                prompt=\"Hello world\",\n                sampling_params={\"temperature\": 0.7},\n                return_logprob=True,\n            )\n\n            assert result == {\"text\": \"Generated text\"}\n            mock_request.assert_called_once()\n\n    @pytest.mark.asyncio\n    async def test_async_memory_management(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test async memory management methods.\"\"\"\n        adapter = AsyncHttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_async_request\", new_callable=AsyncMock) as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n\n            # Test release_memory_occupation\n            result = await adapter.release_memory_occupation([\"weights\"])\n            assert result == {\"status\": 
\"success\"}\n            mock_request.assert_called_with(\"release_memory_occupation\", {\"tags\": [\"weights\"]})\n\n            # Test resume_memory_occupation\n            result = await adapter.resume_memory_occupation([\"weights\"])\n            assert result == {\"status\": \"success\"}\n            mock_request.assert_called_with(\"resume_memory_occupation\", {\"tags\": [\"weights\"]})\n            assert (\n                mock_request.call_count == 2\n            )  # resume memory occupation will also call release memory occupation once\n\n\nclass TestErrorRecovery:\n    \"\"\"Test error recovery mechanisms.\"\"\"\n\n    def test_flush_cache_recovery(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test flush cache recovery from failures.\"\"\"\n        adapter = HttpServerAdapter(max_attempts=2, **basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.get\") as mock_get:\n            # Simulate multiple failures then success\n            mock_get.side_effect = [\n                requests.exceptions.ConnectionError(),\n                requests.exceptions.Timeout(),\n                Mock(status_code=503),  # Service unavailable\n                Mock(status_code=200, json=lambda: {\"cache_flushed\": True}),\n            ]\n\n            with patch(\"time.sleep\"):\n                result = adapter.flush_cache()\n                assert result == {\"cache_flushed\": True}\n\n    def test_flush_cache_max_attempts(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test flush cache max retries exceeded.\"\"\"\n        adapter = HttpServerAdapter(max_attempts=1, **basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.get\") as mock_get:\n            # All attempts fail\n            mock_get.side_effect = requests.exceptions.ConnectionError()\n\n            with patch(\"time.sleep\"):\n                result = adapter.flush_cache()\n                assert result == {}  # Should return empty dict on failure\n\n    def test_network_partition_recovery(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test recovery from network partition scenarios.\"\"\"\n        adapter = HttpServerAdapter(max_attempts=3, **basic_adapter_kwargs)\n\n        with patch(\"verl.workers.rollout.sglang_rollout.http_server_engine.requests.post\") as mock_post:\n            # Simulate network partition then recovery\n            mock_post.side_effect = [\n                requests.exceptions.ConnectionError(\"Network unreachable\"),\n                requests.exceptions.ConnectionError(\"Network unreachable\"),\n                Mock(status_code=200, json=lambda: {\"recovered\": True}),\n            ]\n\n            with patch(\"time.sleep\"):\n                result = adapter._make_request(\"test_endpoint\")\n                assert result == {\"recovered\": True}\n\n\nclass TestResourceManagement:\n    \"\"\"Test resource management and cleanup.\"\"\"\n\n    def test_resource_cleanup_on_exception(\n        self, mock_launch_server_process, mock_kill_process_tree, basic_adapter_kwargs\n    ):\n        \"\"\"Test resource cleanup when exceptions occur.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        # Simulate exception during operation\n        with patch.object(adapter, \"_make_request\", side_effect=Exception(\"Test error\")):\n            try:\n                adapter.generate(prompt=\"test\")\n            
except Exception:\n                pass\n\n        # Cleanup should still work\n        adapter.shutdown()\n        mock_kill_process_tree.assert_called_once_with(mock_launch_server_process.return_value.pid)\n\n    def test_multiple_shutdown_calls(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test multiple shutdown calls are safe.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        # Multiple shutdown calls should be safe\n        adapter.shutdown()\n        adapter.shutdown()\n        adapter.shutdown()\n\n\nclass TestDataTypeHandling:\n    \"\"\"Test handling of various data types.\"\"\"\n\n    def test_complex_data_structures(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test handling of complex data structures.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {\"status\": \"success\"}\n\n            # Test with complex sampling params\n            complex_sampling_params = {\n                \"temperature\": 0.7,\n                \"top_p\": 0.9,\n                \"top_k\": 50,\n                \"repetition_penalty\": 1.1,\n                \"stop_sequences\": [\"</s>\", \"\\n\\n\"],\n                \"max_tokens\": 100,\n                \"logit_bias\": {\"token_123\": 0.5, \"token_456\": -0.5},\n                \"nested_config\": {\n                    \"beam_search\": True,\n                    \"num_beams\": 4,\n                    \"early_stopping\": True,\n                },\n            }\n\n            result = adapter.generate(\n                prompt=\"Test prompt\",\n                sampling_params=complex_sampling_params,\n            )\n\n            assert result == {\"status\": \"success\"}\n            # Verify the complex structure was passed through\n            call_args = mock_request.call_args[0][1]\n            assert call_args[\"sampling_params\"] == complex_sampling_params\n\n\nclass TestIntegration:\n    \"\"\"Integration tests for both adapters.\"\"\"\n\n    def test_error_scenarios(self, mock_launch_server_process, basic_adapter_kwargs):\n        \"\"\"Test various error scenarios.\"\"\"\n        adapter = HttpServerAdapter(**basic_adapter_kwargs)\n\n        # Test with None payload\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {}\n            result = adapter.generate()\n            assert result == {}\n\n        # Test with empty parameters\n        with patch.object(adapter, \"_make_request\") as mock_request:\n            mock_request.return_value = {}\n            req = UpdateWeightsFromTensorReqInput(\n                serialized_named_tensors=None,\n                load_format=None,\n                flush_cache=None,\n            )\n            result = adapter.update_weights_from_tensor(req)\n            assert result == {}\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\nfrom vllm import SamplingParams\n\nfrom verl.third_party.vllm import LLM\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef main():\n    assert torch.cuda.is_available(), \"CUDA must be present to run FSDP vLLM example\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2-7B-Instruct\"\n\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True)\n    with torch.device(\"cuda\"):\n        actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n        actor_model.to(torch.bfloat16)\n\n    max_prompt_length = 16\n    response_length = 32\n    preencode_prompts = [\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    tokenizer.pad_token = tokenizer.eos_token\n    prompts = tokenizer(preencode_prompts, return_tensors=\"pt\", padding=True)\n    input_ids = prompts[\"input_ids\"]\n    attention_mask = prompts[\"attention_mask\"]\n    from verl.utils.torch_functional import pad_sequence_to_length\n\n    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda()\n    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda()\n\n    from transformers import GenerationConfig\n\n    generation_config = GenerationConfig(do_sample=False)\n    actor_model.cuda()\n    output = actor_model.generate(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        max_new_tokens=32,\n        # max_length=max_length,\n        eos_token_id=tokenizer.eos_token_id,\n        pad_token_id=tokenizer.pad_token_id,\n        generation_config=generation_config,\n        # renormalize_logits=True,\n        output_scores=False,  # this is potentially very large\n        return_dict_in_generate=True,\n        use_cache=False,\n    )  # may OOM when use_cache = True\n    seq = output.sequences\n    response = seq[:, max_prompt_length:]\n\n    print(f\"hf response: {tokenizer.batch_decode(response)}\")\n\n    tensor_model_parallel_size = 4\n    from 
torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n\n    state_dict = fsdp_model.state_dict()\n\n    sampling_params = SamplingParams(\n        temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False\n    )\n\n    print(actor_model_config)\n    llm = LLM(\n        model=None,\n        tokenizer=tokenizer,\n        model_hf_config=actor_model_config,\n        tensor_parallel_size=tensor_model_parallel_size,\n        enforce_eager=True,\n        dtype=\"bfloat16\",\n        load_format=\"dummy_dtensor\",\n        gpu_memory_utilization=0.8,\n        trust_remote_code=True,\n    )\n\n    # Warmup iterations\n    for _ in range(10):\n        torch.cuda.synchronize()\n        llm.sync_model_weights(actor_weights=state_dict, load_format=\"dtensor\")\n        torch.cuda.synchronize()\n        dist.barrier()\n\n    start_time = time.time()\n    llm.sync_model_weights(actor_weights=state_dict, load_format=\"dtensor\")\n    torch.cuda.synchronize()\n    dist.barrier()\n    end_time = time.time()\n\n    # Calculate elapsed time\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.6f} seconds\")\n\n    input_ids = input_ids.cuda()\n    attention_mask = attention_mask.cuda()\n    idx_list = []\n    batch_size = input_ids.shape[0]\n\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import _pre_process_inputs\n\n    for i in range(batch_size):\n        idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))\n    print(\"start generation\")\n    outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)\n    vllm_output = outputs[0].cuda()\n    if torch.distributed.get_rank() == 0:\n        print(f\"hf response: {tokenizer.batch_decode(response)}\")\n        print(f\"vllm response: {tokenizer.batch_decode(vllm_output)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport os\n\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nfrom omegaconf import OmegaConf\nfrom transformers import AutoTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout\n\n\ndef test_vllm_rollout_with_yarn_position_embeddings():\n    \"\"\"\n    Test the vLLM rollout with yarn position embeddings.\n    \"\"\"\n\n    local_rank, rank, world_size = initialize_global_process_group()\n    model_path = os.path.expanduser(\"~/models/OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN\")\n    config = OmegaConf.create(\n        {\n            \"name\": \"vllm\",\n            \"prompt_length\": 35000,\n            \"response_length\": 512,\n            \"dtype\": \"bfloat16\",\n            \"enforce_eager\": True,\n            \"gpu_memory_utilization\": 0.4,\n            \"enable_chunked_prefill\": False,\n            \"free_cache_engine\": False,\n            \"disable_log_stats\": True,\n            \"max_model_len\": 35000 + 512,\n            \"max_num_seqs\": 1024,\n            \"load_format\": \"auto\",\n            \"val_kwargs\": {\n                \"top_k\": -1,\n                \"top_p\": 1.0,\n                \"temperature\": 0,\n                \"n\": 1,\n                \"do_sample\": False,\n            },\n            \"tensor_model_parallel_size\": 4,\n            \"calculate_log_probs\": False,\n            \"do_sample\": False,\n            \"temperature\": 0.0,\n            \"max_num_batched_tokens\": 35000 + 512,\n        }\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side=\"left\")\n    tokenizer.pad_token = tokenizer.eos_token\n\n    # do_sample=False for temperate=0 deterministic\n    input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False)\n\n    rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n    model_config = HFModelConfig(path=model_path)\n    model_config.tokenizer.pad_token = tokenizer.eos_token\n\n    vllm_rollout = vLLMRollout(\n        config=rollout_config,\n        model_config=model_config,\n        device_mesh=None,\n    )\n    # rollout\n    rollout_response = vllm_rollout.generate_sequences(\n        prompts=input_dataproto,\n    )\n    if rank == 0:\n        print(\"VLLM Rollout Outputs:\")\n        print(tokenizer.batch_decode(rollout_response.batch[\"responses\"][:], skip_special_tokens=False))\n        for response in rollout_response.batch[\"responses\"]:\n            assert \"<|im_end|>\" in tokenizer.decode(response, skip_special_tokens=False), (\n           
     \"Response should contain <|im_end|> token\"\n            )\n    print(\"Checks passed.\")\n\n    del vllm_rollout\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.ipc_collect()\n    dist.barrier()\n    torch.distributed.destroy_process_group()\n\n\ndef prepare_input_dataproto(tokenizer, config, validate, do_sample=False):\n    base_phrase = \"Roses are red, sky is blue. \" * 4096\n    preencode_prompts = [\n        # 32810 tokens > 32768 tokens\n        [{\"role\": \"user\", \"content\": base_phrase + \"Who won the Champions League in 2019?\"}],\n        [{\"role\": \"user\", \"content\": base_phrase + \"The founder of Apple is\"}],\n        [{\"role\": \"user\", \"content\": base_phrase + \"What's your name\"}],\n    ]\n    formatted_prompts = [\n        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)\n        for conversation in preencode_prompts\n    ]\n    prompts = tokenizer(formatted_prompts, return_tensors=\"pt\", padding=\"max_length\", max_length=config.prompt_length)\n    input_dataproto = DataProto.from_dict(\n        {\n            \"input_ids\": prompts[\"input_ids\"],\n            \"attention_mask\": prompts[\"attention_mask\"],\n            \"position_ids\": compute_position_id_with_mask(prompts[\"attention_mask\"]),\n        },\n        meta_info={\n            \"bos_token_id\": tokenizer.bos_token_id,\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n            \"validate\": validate,\n            \"do_sample\": do_sample,\n            \"response_length\": config.response_length,\n            \"temperature\": config.temperature,\n        },\n    )\n    return input_dataproto\n\n\nif __name__ == \"__main__\":\n    test_vllm_rollout_with_yarn_position_embeddings()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\nimport torch\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom vllm import LLM, SamplingParams\n\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.torch_functional import pad_sequence_to_length\n\n\ndef levenshtein(s1, s2):\n    m, n = len(s1), len(s2)\n    # Initialize matrix of zeros\n    dp = [[0] * (n + 1) for _ in range(m + 1)]\n    # Initialize first column and first row of the matrix\n    for i in range(m + 1):\n        dp[i][0] = i  # Deletion from s1 to empty string\n    for j in range(n + 1):\n        dp[0][j] = j  # Insertion to s1 from empty string\n    # Compute the Levenshtein distance matrix\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            cost = 0 if s1[i - 1] == s2[j - 1] else 1  # No cost if characters match\n            dp[i][j] = min(\n                dp[i - 1][j] + 1,  # Deletion\n                dp[i][j - 1] + 1,  # Insertion\n                dp[i - 1][j - 1] + cost,  # Substitution\n            )\n    return dp[m][n]\n\n\ndef are_lists_similar(a, b):\n    if len(a) != len(b):\n        print(\"The lists are of different lengths.\")\n        return False\n\n    total_length = 0\n    total_diff = 0\n\n    for s1, s2 in zip(a, b, strict=True):\n        max_len = max(len(s1), len(s2))\n        total_length += max_len\n        diff = levenshtein(s1, s2)\n        total_diff += diff\n        print(f\"Comparing strings:\\n{s1}\\n{s2}\\nDifference: {diff} characters\\n\")\n\n    percentage_difference = (total_diff / total_length) * 100\n    print(f\"Total difference: {percentage_difference:.2f}%\")\n\n    return percentage_difference <= 15\n\n\n@pytest.mark.skip(\"https://github.com/vllm-project/vllm/issues/16993\")\ndef test_vllm_spmd():\n    assert torch.cuda.device_count() >= 2, \"At least 2 GPUs is required to run tp+dp tests.\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    # Initialize model and token\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\", trust_remote_code=True)\n\n    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model.to(torch.bfloat16)\n\n    # fill rollout config\n    max_prompt_length = 16\n    max_response_length = 32\n    preencode_prompts = [\n        \"Who won the Champions League in 
2019?\",\n        \"The founder of Apple is\",\n        \"What's your name?\",\n    ]\n    tokenizer.pad_token = tokenizer.eos_token\n    prompts = tokenizer(preencode_prompts, return_tensors=\"pt\", padding=True)\n    input_ids = prompts[\"input_ids\"]\n    attention_mask = prompts[\"attention_mask\"]\n\n    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)\n    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)\n\n    print(\"start generation\")\n    input_ids = input_ids.cuda()\n    attention_mask = attention_mask.cuda()\n\n    temperature = 0\n    top_p = 1\n    kwargs = dict(\n        n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True\n    )\n\n    tensor_parallel_size = 4\n\n    from torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n\n    state_dict = fsdp_model.state_dict()\n\n    sampling_params = SamplingParams(**kwargs)\n    llm = LLM(\n        model=local_model_path,\n        enable_sleep_mode=True,\n        tensor_parallel_size=tensor_parallel_size,\n        distributed_executor_backend=\"external_launcher\",\n        dtype=\"bfloat16\",\n        enforce_eager=True,\n        gpu_memory_utilization=0.8,\n        disable_custom_all_reduce=True,\n        skip_tokenizer_init=False,\n        enable_prefix_caching=True,\n        trust_remote_code=True,\n        seed=1,\n    )\n\n    outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)\n    vllm_response_tokens = []\n    for output in outputs:\n        generated_text = output.outputs[0].text\n        vllm_response_tokens.append(generated_text)\n\n    world_size = torch.distributed.get_world_size()\n    model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model\n    model.load_weights(\n        ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items())\n    )\n\n    outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)\n    verl_vllm_response_tokens = []\n    for output in outputs:\n        generated_text = output.outputs[0].text\n        verl_vllm_response_tokens.append(generated_text)\n\n    if torch.distributed.get_rank() == 0:\n        print(f\"vllm response: {vllm_response_tokens}\")\n        print(f\"verl-vllm response: {verl_vllm_response_tokens}\")\n    assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), \"Strings differ more than 10%:\\n\"\n    print(\"Check Pass\")\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_vllm_spmd()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_hf_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.rollout.hf_rollout import HFRollout\n\nBASE_HF_ROLLOUT_CONFIG = {\n    \"temperature\": 1.0,\n    \"top_k\": -1,\n    \"top_p\": 1,\n    \"prompt_length\": 64,\n    \"response_length\": 64,\n    \"do_sample\": True,\n    \"n\": 1,\n    \"val_kwargs\": {\n        \"top_k\": -1,\n        \"top_p\": 1.0,\n        \"temperature\": 0,\n        \"n\": 1,\n        \"do_sample\": False,\n    },\n}\n\n\ndef prepare_input_dataproto(tokenizer, config, validate):\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": \"Who won the Champions League in 2019?\"}],\n        [{\"role\": \"user\", \"content\": \"The founder of Apple is\"}],\n        [{\"role\": \"user\", \"content\": \"What's your name\"}],\n    ]\n    formatted_prompts = [\n        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)\n        for conversation in preencode_prompts\n    ]\n    prompts = tokenizer(formatted_prompts, return_tensors=\"pt\", padding=\"max_length\", max_length=config.prompt_length)\n    input_dataproto = DataProto.from_dict(\n        {\n            \"input_ids\": prompts[\"input_ids\"],\n            \"attention_mask\": prompts[\"attention_mask\"],\n            \"position_ids\": compute_position_id_with_mask(prompts[\"attention_mask\"]),\n        },\n        meta_info={\n            \"bos_token_id\": tokenizer.bos_token_id,\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n            \"validate\": validate,\n        },\n    )\n    return input_dataproto\n\n\ndef prepare_fsdp_model(model, world_size):\n    from torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    fsdp_model = FSDP(\n        model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, 
state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n    return fsdp_model\n\n\ndef test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False):\n    config = OmegaConf.create(BASE_HF_ROLLOUT_CONFIG)\n    config.update({\"n\": n, \"do_sample\": do_sample})\n\n    assert torch.cuda.device_count() >= 2, \"At least 2 GPUs are required to run tp+dp tests.\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    # Initialize model and tokenizer\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2-7B-Instruct\"\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\", trust_remote_code=True)\n    tokenizer.pad_token = tokenizer.eos_token\n\n    # Initialize FSDP model\n    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model.to(torch.bfloat16)\n    fsdp_model = prepare_fsdp_model(actor_model, world_size)\n\n    # Initialize HFRollout and start generate\n    hf_rollout = HFRollout(fsdp_model, OmegaConf.create(config))\n    input = prepare_input_dataproto(tokenizer, config, validate).to(torch.cuda.current_device())\n    outputs = hf_rollout.generate_sequences(input)\n\n    # check generated batch size is as expected\n    generated_batch_size = outputs.batch.batch_size[0]\n    assert generated_batch_size == input.batch.batch_size[0] * config.n\n\n    for i in range(generated_batch_size):\n        prompt_tokens = outputs.batch[\"prompts\"][i]\n        prompt_mask = prompt_tokens != tokenizer.pad_token_id\n        prompt_tokens = prompt_tokens[prompt_mask]\n        decoded_prompt = tokenizer.decode(prompt_tokens, skip_special_tokens=False)\n\n        response_tokens = outputs.batch[\"responses\"][i]\n        response_mask = response_tokens != tokenizer.pad_token_id\n        response_tokens = response_tokens[response_mask]\n        decoded_response = tokenizer.decode(response_tokens, skip_special_tokens=False)\n\n        attention_mask = outputs.batch[\"attention_mask\"][i]\n        position_ids = outputs.batch[\"position_ids\"][i]\n        prompt_length = outputs.batch[\"prompts\"].size(1)\n        response_length = outputs.batch[\"responses\"].size(1)\n\n        assert attention_mask.size(0) == prompt_length + response_length\n        assert position_ids.size(0) == prompt_length + response_length\n\n        # check response attention mask is as expected\n        response_attention = attention_mask[prompt_length:]\n        eos_positions = (outputs.batch[\"responses\"][i] == tokenizer.pad_token_id).nonzero(as_tuple=True)[0]\n        if len(eos_positions) > 0:\n            first_eos_pos = eos_positions[0].item()\n            assert response_attention[: first_eos_pos + 1].all(), \"Response attention mask should be 1 until EOS\"\n            if first_eos_pos + 1 < response_length:\n                assert not response_attention[first_eos_pos + 1 :].any(), (\n                    \"Response attention mask should be 0 after EOS\"\n                )\n        else:\n            assert response_attention.all(), \"Response attention mask should be all 1 if no EOS token\"\n\n        # check response position ids are as expected\n        prompt_positions = position_ids[:prompt_length]\n        response_positions = position_ids[prompt_length:]\n        valid_response_length = min(len(response_tokens), response_length)\n
        if valid_response_length > 0:\n            assert response_positions[0] == prompt_positions[-1] + 1\n            for j in range(1, valid_response_length):\n                assert response_positions[j] == response_positions[j - 1] + 1\n\n        # print generated text for inspection\n        if torch.distributed.get_rank() == 0:\n            print(f\"prompt: {decoded_prompt}\")\n            print(f\"response: {decoded_response}\")\n            print(\"=\" * 30)\n\n\nif __name__ == \"__main__\":\n    test_hf_rollout(n=2, do_sample=True, validate=False)\n    # test_hf_rollout(n=1, do_sample=False, validate=True)\n    # test_hf_rollout(n=1, do_sample=True, validate=False)\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\n\n\nimport asyncio\nimport os\nfrom copy import deepcopy\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport numpy as np\nimport pytest\nfrom tensordict import TensorDict\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import get_rollout_config, prepare_inputs\n\nfrom verl.protocol import DataProto\nfrom verl.tools.mcp_search_tool import MCPSearchTool\nfrom verl.tools.schemas import ToolResponse\nfrom verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nDEFAULT_USER_CONTENT_PREFIX = (\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\n    \"first every time you get new information. After reasoning, if you find you lack \"\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\n    \"and it will return the top searched results between <tool_response> and \"\n    \"</tool_response>. You can search as many times as your want. If you find no \"\n    \"further external knowledge needed, you can directly provide the answer inside \"\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\n    \"<answer> Beijing </answer>. 
Question: \"\n)\nuser_content = DEFAULT_USER_CONTENT_PREFIX.rstrip(\"\\n\") + \"How's the weather lately?\"\n\n\ndef get_search_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": user_content,\n    }\n\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search the web.\",\n        \"tool_calls\": [\n            {\n                \"id\": \"10\",\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"arguments\": {\n                        \"what_is_your_intent\": \"Search for the weather lately\",\n                        \"query\": \"the weather in Beijing today\",\n                        \"search_depth\": \"basic\",\n                        \"time_range\": \"day\",\n                        \"include_domains\": [\"google.com\", \"baidu.com\"],\n                        \"max_results\": 2,\n                    },\n                },\n            }\n        ],\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search again.\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"arguments\": {\n                        \"what_is_your_intent\": \"Search for the weather lately\",\n                        \"query\": \"the weather in Beijing tomorrow\",\n                        \"search_depth\": \"basic\",\n                        \"time_range\": \"day\",\n                        \"include_domains\": [\"google.com\", \"baidu.com\"],\n                        \"max_results\": 2,\n                    },\n                },\n            }\n        ],\n    }\n\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"<answer>Today is sunny and tomorrow will be cloudy in Beijing.</answer>\",\n    }\n\n    # Mock search tool responses\n    tool_return_0_msg = {\"role\": \"tool\", \"content\": [{\"type\": \"text\", \"text\": \"Today's weather in Beijing is sunny.\"}]}\n    tool_return_1_msg = {\n        \"role\": \"tool\",\n        \"content\": [{\"type\": \"text\", \"text\": \"Tomorrow's weather in Beijing is cloudy.\"}],\n    }\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\nclass TestRolloutWithMCPSearchTools:\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        config = AutoConfig.from_pretrained(self.local_model_path)\n        return config\n\n    @pytest.fixture\n    def search_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_search_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        preencode_tool_return_array = 
[\n            ToolResponse(text=qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True))\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def search_rollout_config(self):\n        max_prompt_length = 4096\n        max_response_length = 3000\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/mcp_tool_config\"\n        rollout_config = get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def search_data_proto(self, search_data, qwen_tokenizer):\n        preencode_prompts, _, _ = search_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n\n        tools_kwargs = np.array(\n            [\n                {\n                    \"tavily_search_tool\": {\n                        \"create_kwargs\": {\"ground_truth\": \"Today is sunny and tomorrow will be cloudy in Beijing.\"},\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance with sampling_params initialized.\"\"\"\n        tool_schema = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"description\": \"A powerful web search tool...\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"what_is_your_intent\": {\n                                \"type\": \"string\",\n                                \"description\": \"Describe your intent for using Tavily\",\n                            },\n                            \"query\": {\"type\": \"string\", \"description\": \"Search query\"},\n                            \"search_depth\": {\n                                \"type\": \"string\",\n                                \"description\": \"The depth of the search ('basic' or 'advanced')\",\n                            },\n                            \"topic\": {\n                                \"type\": \"string\",\n                                \"description\": \"The category of the search ('general' or 'news')\",\n                            },\n                            \"days\": {\n                                \"type\": \"integer\",\n                                \"description\": \"Number of days back to include in search 
results (only for \"\n                                \"'news' topic)\",\n                            },\n                            \"time_range\": {\n                                \"type\": \"string\",\n                                \"description\": \"Time range for results ('day', 'week', 'month', 'year', 'd', \"\n                                \"'w', 'm', 'y')\",\n                            },\n                            \"include_domains\": {\n                                \"type\": \"array\",\n                                \"description\": \"List of domains to specifically include in search results\",\n                            },\n                            \"exclude_domains\": {\n                                \"type\": \"array\",\n                                \"description\": \"List of domains to specifically exclude from search results\",\n                            },\n                            \"include_answer\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include an answer summary generated by an LLM\",\n                            },\n                            \"include_raw_content\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include the cleaned and parsed HTML content of each result\",\n                            },\n                            \"include_images\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include images from search results\",\n                            },\n                            \"include_image_descriptions\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include descriptions with images\",\n                            },\n                            \"max_results\": {\n                                \"type\": \"integer\",\n                                \"description\": \"Maximum number of results to return (5-20)\",\n                            },\n                            \"async_search\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to perform the search asynchronously\",\n                            },\n                        },\n                        \"required\": [\"what_is_your_intent\", \"query\"],\n                    },\n                    \"strict\": False,\n                },\n            }\n        ]\n        with (\n            patch.object(MCPClientManager, \"fetch_tool_schemas\", return_value=tool_schema),\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n        ):\n            rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)\n            model_config = HFModelConfig(path=self.local_model_path)\n            rollout = SGLangRollout(\n                config=rollout_config,\n                model_config=model_config,\n                device_mesh=None,\n            )\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": search_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n      
          \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return rollout\n\n    def test_tools_registration(self, mock_rollout):\n        assert len(mock_rollout._tool_schemas) != 0\n        assert \"tavily_search_tool\" in mock_rollout._tool_map.keys()\n        from verl.tools.mcp_search_tool import MCPSearchTool\n\n        assert isinstance(mock_rollout._tool_map[\"tavily_search_tool\"], MCPSearchTool)\n        # depend on the tokenizer\n        assert mock_rollout._tool_call_parser_type == \"qwen25\"\n\n    def test_rollout_req_creation(self, mock_rollout, search_data_proto):\n        req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n\n    def test_over_size_case(self, mock_rollout, search_data_proto, search_data):\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, _ = search_data\n        # here we mock a meta info with 'length'. indicate the response is truncate\n        mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 3000},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 2.23543,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"tavily_search_tool\") == []\n        # we should only have two message, one for prompt, second for response.\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @patch.object(MCPSearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n        # Mock search tool execution to return predefined responses\n        mock_execute.side_effect = [(msg, 0.0, {\"status\": \"success\"}) for msg in tool_return_array]\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n 
       req_list = [req]\n\n        mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() for i in expect_turn_array]\n        for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n            i.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 2.23543,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list])\n        )\n\n        # Verify conversation completed successfully with proper tool usage\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert \"tavily_search_tool\" in output_req.metrics\n        assert output_req.metrics[\"tavily_search_tool\"][0][\"status\"] == \"success\"\n        assert mock_execute.await_count == 2\n        assert len(output_req.messages) == 6\n        # Verify tool response messages contain expected content\n        search_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[search_counter].text\n                search_counter += 1\n        assert search_counter == 2\n\n    @patch.object(MCPSearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n        # Mock tool execution for large batch (100 requests * 2 calls each)\n        mock_execute.side_effect = [\n            (tool_return_array[0], 0.0, {\"status\": \"success\"}),\n            (tool_return_array[1], 0.0, {\"status\": \"success\"}),\n        ] * 100\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n\n        req_nums = 100\n        req_list = []\n        req_turns_map = {}\n        req_turns_counter = {}\n\n        for i in range(req_nums):\n            tmp_req = deepcopy(base_req)\n            tmp_req.batch_data_id = i\n            tmp_req.request_id = i\n            req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))\n\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n                fut.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"dummy\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if 
idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                            \"completion_tokens\": 100,\n                        },\n                    }\n                )\n            req_turns_map[i] = futures\n            req_turns_counter[i] = 0\n\n        async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs):\n            fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            return await fut\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list])\n            )\n\n        # Verify all requests completed successfully\n        assert len(output_req_list) == req_nums\n        for out_req in output_req_list:\n            assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n            assert \"tavily_search_tool\" in out_req.metrics\n            for metric in out_req.metrics[\"tavily_search_tool\"]:\n                assert metric[\"status\"] == \"success\"\n            assert len(out_req.messages) == 6\n            assert sum(1 for m in out_req.messages if m.role == \"tool\") == 2\n\n        assert mock_execute.await_count == 2 * req_nums\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py",
    "content": "# Copyright 2025 Amazon.com, Inc. or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport os\n\nimport pytest\n\nfrom verl.tools.schemas import ToolResponse\nfrom verl.utils.dataset.vision_utils import process_image\nfrom verl.utils.tokenizer import hf_processor\nfrom verl.workers.rollout.schemas import (\n    AsyncRolloutRequest,\n    AsyncRolloutRequestStateEnum,\n    TokenizationSanityCheckModeEnum,\n)\n\n\ndef _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False):\n    assert len(image_list) == len(description_list)\n    # Get the smallest dimensions across all images\n    processed_images = []\n    for img_url in image_list:\n        img = process_image(img_url)\n        processed_images.append(img)\n\n    min_width = min(img.size[0] for img in processed_images)\n    min_height = min(img.size[1] for img in processed_images)\n    min_size = (min_width, min_height)\n\n    if resize_image:\n        processed_images_resized = []\n        for img in processed_images:\n            img = img.resize(min_size)\n            processed_images_resized.append(img)\n        processed_images = processed_images_resized\n\n    # Initial message history\n    system_prompt = (\n        \"You will be provided with an image. 
Describe this image and then generate a new image for the next round\"\n    )\n    messages = [\n        {\n            \"role\": \"system\",\n            \"content\": system_prompt,\n        },\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"Here is the first image provided: \"},\n                {\"type\": \"image\", \"image\": [processed_images[0]]},\n            ],\n        },\n    ]\n\n    # Initial multi_modal_data with one image\n    multi_modal_data = {\"image\": [processed_images[0]], \"video\": []}\n    # Minimal required fields for AsyncRolloutRequest\n\n    req = AsyncRolloutRequest(\n        batch_data_id=0,\n        request_id=\"test-req-1\",\n        state=AsyncRolloutRequestStateEnum.PENDING,\n        messages=messages,\n        multi_modal_keys=[\"image\", \"video\"],\n        multi_modal_data=multi_modal_data.copy(),\n        tool_schemas=[],\n        tools_kwargs={},\n        interaction_kwargs={},\n        input_ids=None,\n        prompt_ids=None,\n        response_ids=None,\n        attention_mask=None,\n        prompt_attention_mask=None,\n        response_attention_mask=None,\n        position_ids=None,\n        prompt_position_ids=None,\n        response_position_ids=None,\n        loss_mask=None,\n        prompt_loss_mask=None,\n        response_loss_mask=None,\n        reward_scores={},\n        max_prompt_len=8192,\n        max_response_len=8192,\n        max_model_len=16384,\n        metrics={},\n        use_inference_chat_template=True,\n        tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT,\n        generation_prompt_ids=None,\n        base_conv_wo_gen_prompt_end_pos=0,\n        base_conv_with_gen_prompt_end_pos=0,\n        processing_class=processor,\n    )\n\n    prev_generated_len = 0\n    # Add the first assistant message and the first tool response message (image)\n    for idx, img in enumerate(processed_images):\n        if idx == 0:\n            continue\n        _ = req.get_generation_prompt_ids(processor)\n        req.add_assistant_message(processor, content=description_list[idx - 1])\n        before_tool_call_len = req.input_ids.shape[-1]\n        req.add_tool_response_messages(\n            processor, [ToolResponse(image=[img], text=\"Here is the new image you requested: \")]\n        )\n        after_tool_call_len = req.input_ids.shape[-1]\n        if prev_generated_len == 0:\n            prev_generated_len = after_tool_call_len - before_tool_call_len\n        else:\n            if resize_image:\n                assert after_tool_call_len - before_tool_call_len == prev_generated_len\n        assert req.multi_modal_data[\"image\"] == processed_images[: idx + 1]\n\n    _ = req.get_generation_prompt_ids(processor)\n    req.add_assistant_message(processor, content=description_list[-1])\n\n    messages = [msg.model_dump() for msg in req.messages]\n    tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None\n    full_prompt_info = req._handle_apply_chat_template(\n        processor,\n        messages,\n        multi_modal_data=req.multi_modal_data,\n        tools=tools,\n        add_generation_prompt=False,\n        tokenize=True,\n        return_dict=True,\n    )\n    full_prompt_ids = full_prompt_info[\"input_ids\"]\n    assert full_prompt_ids.eq(req.input_ids).all()\n\n    # We copy full_prompt_info into a new mapping before dropping the token fields,\n    # because np.array() only keeps the keys for a BatchFeature.\n
    full_prompt_multi_modal_inputs = full_prompt_info.copy()\n    full_prompt_multi_modal_inputs.pop(\"input_ids\", None)\n    full_prompt_multi_modal_inputs.pop(\"attention_mask\", None)\n\n    for key in full_prompt_multi_modal_inputs:\n        assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all()\n\n\n@pytest.mark.skipif(\n    hf_processor(os.path.expanduser(\"~/models/Qwen/Qwen2.5-VL-3B-Instruct\")) is None,\n    reason=\"Processor not available for Qwen/Qwen2.5-VL-3B-Instruct\",\n)\ndef test_add_tool_response_messages_image_delta():\n    processor = hf_processor(os.path.expanduser(\"~/models/Qwen/Qwen2.5-VL-3B-Instruct\"))\n\n    # From Qwen2.5-VL-3B-Instruct HF example\n    img_1_url = {\"image\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"}\n    img_1_description = \"A woman sits on the beach at sunset, smiling as she shares a high five with her large dog.\"\n    # GitHub Logo\n    img_2_url = {\"image\": \"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\"}\n    img_2_description = \"A GitHub Logo image\"\n    # Octocat\n    img_3_url = {\"image\": \"https://octodex.github.com/images/orderedlistocat.png\"}\n    img_3_description = \"An Octocat image\"\n\n    image_list = [img_1_url, img_2_url, img_3_url]\n    description_list = [img_1_description, img_2_description, img_3_description]\n    _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False)\n\n\n@pytest.mark.skipif(\n    hf_processor(os.path.expanduser(\"~/models/Qwen/Qwen2.5-VL-3B-Instruct\")) is None,\n    reason=\"Processor not available for Qwen/Qwen2.5-VL-3B-Instruct\",\n)\ndef test_add_tool_response_messages_image_delta_resize_image():\n    processor = hf_processor(os.path.expanduser(\"~/models/Qwen/Qwen2.5-VL-3B-Instruct\"))\n\n    # From Qwen2.5-VL-3B-Instruct HF example\n    img_1_url = {\"image\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"}\n    img_1_description = \"A woman sits on the beach at sunset, smiling as she shares a high five with her large dog.\"\n    # GitHub Logo\n    img_2_url = {\"image\": \"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\"}\n    img_2_description = \"A GitHub Logo image\"\n    # Octocat\n    img_3_url = {\"image\": \"https://octodex.github.com/images/orderedlistocat.png\"}\n    img_3_description = \"An Octocat image\"\n\n    image_list = [img_1_url, img_2_url, img_3_url]\n    description_list = [img_1_description, img_2_description, img_3_description]\n    _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True)\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_search_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\n\n\nimport asyncio\nimport os\nfrom copy import deepcopy\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport numpy as np\nimport pytest\nfrom tensordict import TensorDict\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import get_rollout_config, prepare_inputs\n\nfrom verl.protocol import DataProto\nfrom verl.tools.schemas import (\n    OpenAIFunctionParametersSchema,\n    OpenAIFunctionPropertySchema,\n    OpenAIFunctionSchema,\n    OpenAIFunctionToolSchema,\n    ToolResponse,\n)\nfrom verl.tools.search_tool import SearchTool\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nDEFAULT_USER_CONTENT_PREFIX = (\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\n    \"first every time you get new information. After reasoning, if you find you lack \"\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\n    \"and it will return the top searched results between <tool_response> and \"\n    \"</tool_response>. You can search as many times as your want. If you find no \"\n    \"further external knowledge needed, you can directly provide the answer inside \"\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\n    \"<answer> Beijing </answer>. 
Question: \"\n)\nuser_content = DEFAULT_USER_CONTENT_PREFIX.rstrip(\"\\n\") + \"How's the weather lately?\"\n\n\ndef get_search_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": user_content,\n    }\n\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search the web.\",\n        \"tool_calls\": [{\"type\": \"function\", \"function\": {\"name\": \"search\", \"arguments\": {\"query\": \"today's weather\"}}}],\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search again.\",\n        \"tool_calls\": [\n            {\"type\": \"function\", \"function\": {\"name\": \"search\", \"arguments\": {\"query\": \"tomorrow's weather\"}}}\n        ],\n    }\n\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"<answer>Today is sunny and tomorrow will be cloudy in Beijing.</answer>\",\n    }\n\n    # Mock search tool responses\n    tool_return_0_msg = {\"role\": \"tool\", \"content\": \"Today's weather in Beijing is sunny.\"}\n    tool_return_1_msg = {\"role\": \"tool\", \"content\": \"Tomorrow's weather in Beijing is cloudy.\"}\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\nclass TestRolloutWithSearchTools:\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        config = AutoConfig.from_pretrained(self.local_model_path)\n        return config\n\n    @pytest.fixture\n    def search_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_search_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        preencode_tool_return_array = [\n            ToolResponse(text=qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True))\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def search_rollout_config(self):\n        max_prompt_length = 4096\n        max_response_length = 3000\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/search_tool_config\"\n        rollout_config = get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def search_data_proto(self, search_data, qwen_tokenizer):\n        preencode_prompts, _, _ = search_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            
{\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n\n        tools_kwargs = np.array(\n            [\n                {\n                    \"search\": {\n                        \"create_kwargs\": {\n                            \"ground_truth\": \"Today is sunny and tomorrow will be cloudy in Beijing.\",\n                            \"data_source\": \"searchR1_nq\",\n                        },\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance with sampling_params initialized.\"\"\"\n        with (\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n        ):\n            rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)\n            model_config = HFModelConfig(path=self.local_model_path)\n            rollout = SGLangRollout(\n                config=rollout_config,\n                model_config=model_config,\n                device_mesh=None,\n            )\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": search_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n                \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return rollout\n\n    @patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None)\n    def test_tools_registration(\n        self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config\n    ):\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)\n        model_config = HFModelConfig(path=self.local_model_path)\n        rollout = SGLangRollout(\n            config=rollout_config,\n            model_config=model_config,\n            device_mesh=None,\n        )\n        assert len(rollout._tool_schemas) == 1\n        assert \"search\" in rollout._tool_map.keys()\n        from verl.tools.search_tool import SearchTool\n\n        assert isinstance(rollout._tool_map[\"search\"], SearchTool)\n        # depend on the tokenizer\n        assert rollout._tool_call_parser_type == \"qwen25\"\n\n    @patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None)\n    def test_rollout_req_creation(\n        self,\n        mock_env,\n        mock_engine,\n        mock_sampling,\n        
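# the three mock_* parameters are injected by the stacked @patch.object decorators; the rest are pytest fixtures\n        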
search_rollout_config,\n        qwen_tokenizer,\n        qwen_model_config,\n        search_data_proto,\n    ):\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(search_rollout_config, dataclass_type=RolloutConfig)\n        model_config = HFModelConfig(path=self.local_model_path)\n        rollout = SGLangRollout(\n            config=rollout_config,\n            model_config=model_config,\n            device_mesh=None,\n        )\n        req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n        print(type(req_list[0].tool_schemas[0]))\n        assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(\n            type=\"function\",\n            function=OpenAIFunctionSchema(\n                name=\"search\",\n                description=\"Searches the web for relevant information based on the given query.\",\n                parameters=OpenAIFunctionParametersSchema(\n                    type=\"object\",\n                    properties={\n                        \"query_list\": OpenAIFunctionPropertySchema(\n                            type=\"array\",\n                            description=\"A list of fully-formed semantic queries. The tool will return search \"\n                            \"results for each query.\",\n                            items={\"type\": \"string\"},\n                        )\n                    },\n                    required=[\"query_list\"],\n                ),\n                strict=False,\n            ),\n        )\n\n    def test_over_size_case(self, mock_rollout, search_data_proto, search_data):\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, _ = search_data\n        mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 3000},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 2.23543,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"search\") == []\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @patch.object(SearchTool, \"execute\", 
new_callable=AsyncMock)\n    def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n\n        # Mock search tool execution to return predefined responses\n        mock_execute.side_effect = [(msg, 0.0, {\"status\": \"success\"}) for msg in tool_return_array]\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"search\"].retrieval_service_url = \"mock://dummy\"\n\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() for i in expect_turn_array]\n        for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n            i.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 2.23543,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list])\n        )\n\n        # Verify conversation completed successfully with proper tool usage\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert \"search\" in output_req.metrics\n        assert output_req.metrics[\"search\"][0][\"status\"] == \"success\"\n        assert mock_execute.await_count == 2\n        assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n        # Verify tool response messages contain expected content\n        search_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[search_counter].text\n                search_counter += 1\n        assert search_counter == 2\n\n    @patch.object(SearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n\n        # Mock tool execution for large batch (100 requests * 2 calls each)\n        mock_execute.side_effect = [\n            (tool_return_array[0], 0.0, {\"status\": \"success\"}),\n            (tool_return_array[1], 0.0, {\"status\": \"success\"}),\n        ] * 100\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"search\"].retrieval_service_url = \"mock://dummy\"\n\n        base_req = 
mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n\n        req_nums = 100\n        req_list = []\n        req_turns_map = {}\n        req_turns_counter = {}\n\n        for i in range(req_nums):\n            tmp_req = deepcopy(base_req)\n            tmp_req.batch_data_id = i\n            tmp_req.request_id = i\n            req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))\n\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n                fut.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"dummy\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                            \"completion_tokens\": 100,\n                        },\n                    }\n                )\n            req_turns_map[i] = futures\n            req_turns_counter[i] = 0\n\n        async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs):\n            fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            return await fut\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list])\n            )\n\n        # Verify all requests completed successfully\n        assert len(output_req_list) == req_nums\n        for out_req in output_req_list:\n            assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n            assert \"search\" in out_req.metrics\n            for metric in out_req.metrics[\"search\"]:\n                assert metric[\"status\"] == \"success\"\n            assert len(out_req.messages) == 6  # user + 3 assistant + 2 tool\n            assert sum(1 for m in out_req.messages if m.role == \"tool\") == 2\n\n        assert mock_execute.await_count == 2 * req_nums\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport os\nimport time\nfrom copy import deepcopy\nfrom functools import wraps\nfrom unittest.mock import MagicMock, patch\n\nimport numpy as np\nimport pytest\nimport ray\nfrom tensordict import TensorDict\nfrom torch.testing._internal.common_distributed import MultiProcessTestCase\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import get_rollout_config, prepare_inputs\n\nfrom verl.protocol import DataProto\nfrom verl.tools.sandbox_fusion_tools import TokenBucketWorker\nfrom verl.tools.schemas import (OpenAIFunctionParametersSchema,\n                                OpenAIFunctionPropertySchema,\n                                OpenAIFunctionSchema, OpenAIFunctionToolSchema,\n                                ToolResponse)\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.schemas import (AsyncRolloutRequest,\n                                          AsyncRolloutRequestStateEnum,\n                                          Message)\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nsandbox_url = \"\"\n\n\ndef get_sandbox_fusion_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": \"\"\"\n            Solve the following problem step by step. You now have the ability to selectively \n            write executable Python code to enhance your reasoning process. \\n\\n**user question:**\\nThere \n            are 152 students at Dala High School. Assume the following: \\n- 100 students take a Math class \\n- 94 \n            students take a Science class \\n- 57 students take an English class \\n- 73 students take a Math class \n            and a Science class \\n- 24 students take a Math class and an English class \\n- 27 students take a Science \n            class and an English class \\n- 22 students take a Math class and a Science class and an English class\\n \\nHow \n            many students take neither a Math class nor a Science class nor an Eglish class?\\n\\nRemember to place the final \n            answer in the last part using the format: \\n<answer>\\n\\boxed{'The final answer goes here.'}\\n</answer>\n        \"\"\",\n    }\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            Okay, so I need to find out how many students at Dala High School are not taking any of the three classes: Math, \n            Science, or English. The total number of students is 152. Let me see... I remember this is a problem about sets \n            and maybe using the principle of inclusion-exclusion. 
Let me recall how that works.\\n\\nFirst, the inclusion-exclusion \n            principle for three sets says that the total number of students taking at least one of the classes is equal to the \n            sum of the numbers in each individual class, minus the sum of the numbers in each pair of classes, plus the number in \n            all three classes. Then, subtract that total from the overall number of students to get those not taking any of the \n            three. \\n\\nLet me write that down step by step. Let M be the set of students taking Math, S for Science, and E for English. \n            Then:\\n\\nTotal in at least one class = |M ∪ S ∪ E| = |M| + |S| + |E| - |M ∩ S| - |M ∩ E| - |S ∩ E| + |M ∩ S ∩ E|\\n\\nGiven the \n            numbers:\\n\\n|M| = 100\\n\\n|S| = 94\\n\\n|E| = 57\\n\\n|M ∩ S| = 73\\n\\n|M ∩ E| = 24\\n\\n|S ∩ E| = 27\\n\\n|M ∩ S ∩ E| = 22\\n\\nSo plugging \n            these into the formula:\\n\\nTotal = 100 + 94 + 57 - 73 - 24 - 27 + 22\\n\\nLet me compute that step by step using code to ensure \n            accuracy.\\n\n        \"\"\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\n                        \"code\": \"\"\"M = 100\\\\nS = 94\\\\nE = 57\\\\nM_S = 73\\\\nM_E = 24\\\\nS_E = 27\\\\nM_S_E = \n                        22\\\\n\\\\ntotal_in_any = M + S + E - M_S - M_E - S_E + M_S_E\\\\nstudents_neither = 152 - total_in_any\\\\nprint(students_neither)\"\"\",\n                    },\n                },\n            }\n        ],\n    }\n\n    tool_return_0_msg = {\n        \"role\": \"tool\",\n        \"content\": \"\"\"3\"\"\",\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            The code returns 3, indicating that 3 students take none of the classes. 
To verify, I'll break down each\n            category to ensure the total matches:\\n\\n- Only Math and Science: 73 - 22 = 51\\n- Only Math and English:\n            24 - 22 = 2\\n- Only Science and English: 27 - 22 = 5\\n- Only Math: 100 - 51 - 2 - 22 = 25\\n- Only Science:\n            94 - 51 - 5 - 22 = 16\\n- Only English: 57 - 2 - 5 - 22 = 28\\n\\nSumming all categories:\\n\n        \"\"\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\n                        \"code\": \"\"\"only_M_S = 73 - 22\\\\nonly_M_E = 24 - 22\\\\nonly_S_E = 27 - 22\\\\n\\\\nonly_M = 100 - only_M_S -\n                        only_M_E - 22\\\\nonly_S = 94 - only_M_S - only_S_E - 22\\\\nonly_E = 57 - only_M_E - only_S_E - 22\\\\n\\\\ntotal_verify\n                        = only_M + only_S + only_E + only_M_S + only_M_E + only_S_E + 22\\\\nprint(total_verify)\"\"\",\n                    },\n                },\n            }\n        ],\n    }\n\n    tool_return_1_msg = {\n        \"role\": \"tool\",\n        \"content\": \"\"\"149\"\"\",\n    }\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            The verification total is 149, so students not taking any classes are 152 - 149 = 3, confirming the initial\n            result.\\n\\n<answer>\\n\\\\boxed{3}\\n</answer>\n        \"\"\",\n    }\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\ndef skip_if_valid_sandbox(url):\n    def decorator(func):\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            if url == \"\" or url is None:\n                pytest.skip(\"No valid sandbox url provided\")\n            # run the wrapped test when a sandbox url is configured\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\nclass TestRolloutWithTools:\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        config = AutoConfig.from_pretrained(self.local_model_path)\n        return config\n\n    @pytest.fixture\n    def sandbox_fusion_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_sandbox_fusion_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        preencode_tool_return_array = [\n            ToolResponse(text=qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True))\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def sandbox_fusion_rollout_config(self):\n        max_prompt_length = 1024\n        max_response_length = 1024\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/sandbox_fusion_tool_config\"\n        rollout_config = 
get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer):\n        preencode_prompts, _, _ = sandbox_fusion_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n        tools_kwargs = np.array(\n            [\n                {\n                    \"code_interpreter\": {\n                        \"create_kwargs\": {\"ground_truth\": \"test-solution-str\"},\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance\"\"\"\n        with patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None), patch.object(\n            SGLangRollout, \"_init_inference_engine\", return_value=None\n        ), patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None):\n            rollout_config: RolloutConfig = omega_conf_to_dataclass(sandbox_fusion_rollout_config, dataclass_type=RolloutConfig)\n            model_config = HFModelConfig(path=self.local_model_path)\n            rollout = SGLangRollout(\n                config=rollout_config,\n                model_config=model_config,\n                device_mesh=None,\n            )\n            # set default sampling_params\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": sandbox_fusion_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n                \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return rollout\n\n    def test_tools_registration(self, mock_rollout):\n        \"\"\"Test tool registration functionality\"\"\"\n        assert len(mock_rollout._tool_schemas) == 1\n        assert \"code_interpreter\" in mock_rollout._tool_map.keys()\n        from verl.tools.sandbox_fusion_tools import SandboxFusionTool\n\n        assert isinstance(mock_rollout._tool_map[\"code_interpreter\"], SandboxFusionTool)\n        assert mock_rollout._tool_call_parser_type == \"qwen25\"\n\n    def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto):\n        \"\"\"Test request creation functionality\"\"\"\n        req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n        print(type(req_list[0].tool_schemas[0]))\n        assert 
req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(\n            type=\"function\",\n            function=OpenAIFunctionSchema(\n                name=\"code_interpreter\",\n                description=\"A tool for executing code.\",\n                parameters=OpenAIFunctionParametersSchema(\n                    type=\"object\",\n                    properties={\n                        \"code\": OpenAIFunctionPropertySchema(\n                            type=\"string\",\n                            description=\"The code to execute.\",\n                            enum=None,\n                        )\n                    },\n                    required=[\"code\"],\n                ),\n                strict=False,\n            ),\n        )\n\n    def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test over-size response truncation case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        # here we mock a meta_info with finish_reason 'length', indicating the response was truncated\n        mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 1024},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 9.9304039478302,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"code_interpreter\") == []\n        # we should only have two messages: one for the prompt, one for the response\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @skip_if_valid_sandbox(sandbox_url)\n    def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test basic tool call case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"code_interpreter\"].sandbox_fusion_url = sandbox_url\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        # mock successive engine calls, one per expected assistant turn\n        
mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() for i in expect_turn_array]\n        # every turn except the last ends with a tool call; the final turn stops\n        for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n            i.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 9.9304039478302,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        # here we verify whether the code sandbox is executed correctly\n        assert output_req.metrics == {\"code_interpreter\": [\"3\", \"149\"]}\n        assert mock_rollout._handle_engine_call.call_count == 3\n        assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n        code_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[code_counter].text\n                code_counter += 1\n        assert code_counter == 2\n\n    @skip_if_valid_sandbox(sandbox_url)\n    def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test batch tool call case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"code_interpreter\"].sandbox_fusion_url = sandbox_url\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req_nums = 100\n        req_list = []\n        req_turns_counter = {}\n        # this map should be a Map[id: List[Future]]\n        req_turns_map = {}\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        for i in range(req_nums):\n            _temp_req = deepcopy(req)\n            _temp_req.batch_data_id = i\n            _temp_req.request_id = i\n            req_list.append(MagicMock(wraps=_temp_req, spec=AsyncRolloutRequest))\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n                fut.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                            
\"completion_tokens\": 100,\n                            \"cached_tokens\": 0,\n                            \"e2e_latency\": 9.9304039478302,\n                        },\n                    }\n                )\n                if idx < len(expect_turn_array) - 1:\n                    assert mock_rollout._function_call_parser.has_tool_call(turn)\n                    assert mock_rollout._function_call_parser.parse_non_stream(turn)\n            req_turns_map[_temp_req.batch_data_id] = futures\n            req_turns_counter[_temp_req.batch_data_id] = 0\n\n        async def hacked_handle_engine_call(\n            self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs\n        ):\n            result = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            re = await result\n            return re\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(\n                    *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n                )\n            )\n            assert len(output_req_list) == req_nums\n            # FIGUER out how to count this\n            # assert rollout._handle_engine_call.call_count == 3 * req_nums\n            for output_req in output_req_list:\n                assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n                # here we verify whether the code sandbox is executed correctly\n                assert output_req.metrics == {\"code_interpreter\": [\"3\", \"149\"]}\n                assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n                code_counter = 0\n                for msg in output_req.messages:\n                    if msg.role == \"tool\":\n                        code_counter += 1\n                assert code_counter == 2\n\n    def test_sampling_params_functionality(self, mock_rollout):\n        \"\"\"Test sampling_params functionality\"\"\"\n        # test basic copy functionality\n        copied_params = mock_rollout.sampling_params.copy()\n        assert copied_params == mock_rollout.sampling_params\n        assert copied_params is not mock_rollout.sampling_params\n\n        # test parameter update\n        copied_params.update({\"temperature\": 0.8, \"top_p\": 0.9})\n        assert copied_params[\"temperature\"] == 0.8\n        assert copied_params[\"top_p\"] == 0.9\n\n        # ensure original parameters are not modified\n        assert \"temperature\" not in mock_rollout.sampling_params\n        assert \"top_p\" not in mock_rollout.sampling_params\n\n\nclass RayMultiProcessTestCase(MultiProcessTestCase):\n    def setUp(self):\n        super().setUp()\n        ray.init(ignore_reinit_error=True)\n        print(\"init_single cluster\")\n        self._spawn_processes()\n\n    def tearDown(self):\n        print(\"tearDown_single cluster\")\n        ray.shutdown()\n\n\n@ray.remote\nclass TestActor:\n    def __init__(self, rank, world_size):\n        self._world_size = world_size\n        self._rank = rank\n        self.rank_list = []\n        self.time_list = []\n\n    def record_rank(self, rank):\n        self.rank_list.append(rank)\n\n    def get_rank(self):\n        return self._rank\n\n    def ping(self):\n        return True\n\n    def 
record_execution_time(self, time):\n        self.time_list.append(time)\n\n    def get_time(self, timeout):\n        import time\n\n        now = time.time()\n        while time.time() - now < timeout:\n            # for start and end time\n            if len(self.time_list) == self._world_size * 2:\n                self.time_list.sort()\n                return self.time_list[-1] - self.time_list[0]\n            else:\n                time.sleep(1)\n                continue\n        return False\n\n    def verify_rank(self):\n        import time\n\n        now = time.time()\n        while time.time() - now < 10:\n            if len(self.rank_list) == self._world_size:\n                print(self.rank_list)\n                self.rank_list.sort()\n                for i in range(self._world_size):\n                    if self.rank_list[i] != i:\n                        return False\n                return True\n            else:\n                time.sleep(1)\n                continue\n        return False\n\n\nclass TestRayGlobalActorCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        # 2 ranks are enough to exercise cross-process actor access\n        return 2\n\n    def test_basic_multi_process_init(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        handle = TestActor.remote(self.rank, self.world_size)\n        re = ray.get(handle.get_rank.remote())\n        assert re == self.rank, f\"rank mismatch: {re} != {self.rank}\"\n\n    # def test_global_actor(self):\n    #     ray.init(\"auto\",namespace=\"test\",ignore_reinit_error=True)\n    #     handle = TestActor.options(get_if_exists=True,name=\"test-actor\").remote(self.rank,self.world_size)\n    #     handle.record_rank.remote(self.rank)\n    #     # since test actor's concurrency is 1, we need to wait for all processes to finish\n    #     time.sleep(5)\n    #     assert ray.get(handle.ping.remote()) == True # make sure actor handle is valid\n    #     if self.rank == 0:\n    #         assert ray.get(handle.verify_rank.remote()) == True\n    #     else:\n    #         # get_actor uses a weak ref, so we need to make sure the actor is not garbage collected\n    #         time.sleep(10)\n\n\nclass TestSingleNodeRateLimiterCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        return 1\n\n    def test_rate_limiter(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import (PoolMode,\n                                                     init_execution_pool)\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=3)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode\n        )\n        center = TestActor.options(get_if_exists=True, name=\"test-actor\").remote(self.rank, self.world_size)\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            import time\n\n            time.sleep(3)\n            return i\n\n        start = time.time()\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(6)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        end = time.time()\n        duration = end - start\n        center.record_execution_time.remote(start)\n        center.record_execution_time.remote(end)\n        print(f\"Total time: {duration:.2f} seconds for rank: 
{self.rank}\")\n\n        assert results == list(range(6))\n        # we have 6 task with rate limit of 3, therefore we need at least 2 round: 3*2=6 seconds\n        assert duration > 6\n        assert duration < 10\n\n    def test_rotten_execution(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import (PoolMode,\n                                                     init_execution_pool)\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode\n        )\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            if i == 10:\n                raise Exception(\"test\")\n            else:\n                return i\n\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(20)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        expect_result = [None] + list(range(10)) + list(range(11, 20))\n        sorted_data = sorted(results, key=lambda x: (x is not None, x))\n        assert sorted_data == expect_result, f\"results: {results}, expect_result: {expect_result}\"\n        rate_limiter = TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote()\n        rate = ray.get(rate_limiter.get_current_count.remote())\n        assert rate == 0, f\"rate: {rate}\"\n\n\nclass TestMultiNodeRateLimiterCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        return 2\n\n    def test_rate_limiter(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import (PoolMode,\n                                                     init_execution_pool)\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode\n        )\n        center = TestActor.options(get_if_exists=True, name=\"test-actor\").remote(self.rank, self.world_size)\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            import time\n\n            time.sleep(2)\n            return i\n\n        start = time.time()\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(6)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        end = time.time()\n        duration = end - start\n        center.record_execution_time.remote(start)\n        center.record_execution_time.remote(end)\n        print(f\"Total time: {duration:.2f} seconds for rank: {self.rank}\")\n        assert results == list(range(6))\n        time.sleep(5)\n        if self.rank == 0:\n            total_cost = ray.get(center.get_time.remote(10))\n            print(f\"for total cost: {total_cost}\")\n            # # we have 6 task each node * 2node = 12 task, each task take 2 second.\n            # with rate limit of 6,\n            # therefore we need at least 2 round: 12/6*2=4 seconds\n            assert total_cost > 4, total_cost\n        else:\n            time.sleep(10)\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_rollout_w_interaction.py\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    get_rollout_config,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\nfrom verl import DataProto\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\n\ndef test_async_sglang_rollout_w_interaction():\n    import os\n\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group()\n    clean_torchelastic_env()\n\n    max_prompt_length = 32\n    max_response_length = 16\n    dtype = \"bfloat16\"\n    tensor_parallel_size = 2\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": prompt, \"tool_calls\": None}]\n        for prompt in [\n            \"Who won the Champions League in 2019?\",\n            \"The founder of Apple is\",\n            \"What's the best way to learn python?\",\n        ]\n    ]\n    interaction_kwargs = [\n        {\"name\": \"gsm8k\", \"query\": \"Who won the Champions League in 2019?\", \"ground_truth\": \"Real Madrid\"},\n        {\"name\": \"gsm8k\", \"query\": \"The founder of Apple is\", \"ground_truth\": \"Steve Jobs\"},\n        {\"name\": \"gsm8k\", \"query\": \"What's the best way to learn python?\", \"ground_truth\": \"Learn python from scratch\"},\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n        for message in preencode_prompts\n    ]\n    input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    # Create a temporary interaction config file for testing\n    import tempfile\n\n    from omegaconf import OmegaConf\n\n    interaction_config = {\n        \"interaction\": [\n            {\"name\": \"gsm8k\", \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}}\n        ]\n    }\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n        OmegaConf.save(interaction_config, f.name)\n        interaction_config_path = f.name\n\n    rollout_config = get_rollout_config(\n        max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path\n    )\n    rollout_config: 
RolloutConfig = omega_conf_to_dataclass(rollout_config, dataclass_type=RolloutConfig)\n    model_config = HFModelConfig(path=local_model_path)\n    rollout = SGLangRollout(\n        config=rollout_config,\n        model_config=model_config,\n        device_mesh=None,\n    )\n\n    prompt_dict = TensorDict(\n        {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n        },\n        batch_size=input_ids.shape[0],\n    )\n    print(f\"preprocessed {input_ids.shape=}\")\n\n    messages = np.asarray(preencode_prompts)\n    prompts = DataProto(\n        batch=prompt_dict,\n        non_tensor_batch={\"raw_prompt\": messages, \"interaction_kwargs\": np.asarray(interaction_kwargs)},\n    )\n\n    prompts.meta_info.update(\n        {\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n        }\n    )\n\n    # log_gpu_memory_usage(\"Before generating sequences\", logger=None)\n    output = rollout.generate_sequences(prompts=prompts)\n    print(f\"generated {output.batch['responses'].shape=}\")\n    # log_gpu_memory_usage(\"After generating sequences\", logger=None)\n\n    sglang_output = output.to(\"cpu\")\n\n    sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch[\"responses\"])\n\n    print(f\"hf response: {hf_response_tokens}\")\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens)\n    print(\"SGLang w interaction Test Passed!\")\n\n    # Clean up temporary config file\n    import os\n\n    os.unlink(interaction_config_path)\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_async_sglang_rollout_w_interaction()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_tools.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_rollout_w_tools.py\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    get_rollout_config,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\nfrom verl import DataProto\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\n\ndef test_async_sglang_rollout_w_tool():\n    import os\n\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group()\n    clean_torchelastic_env()\n\n    max_prompt_length = 32\n    max_response_length = 16\n    dtype = \"bfloat16\"\n    tensor_parallel_size = 2\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": prompt, \"tool_calls\": None}]\n        for prompt in [\n            \"Who won the Champions League in 2019?\",\n            \"The founder of Apple is\",\n            \"What's the best way to learn python?\",\n        ]\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n        for message in preencode_prompts\n    ]\n    input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    rollout_config = get_rollout_config(\n        max_response_length,\n        max_prompt_length,\n        dtype,\n        tensor_parallel_size,\n        \"./resource/tool_configs/sandbox_fusion_tool_config\",\n    )\n    rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_config, dataclass_type=RolloutConfig)\n    model_config = HFModelConfig(path=local_model_path)\n    rollout = SGLangRollout(\n        config=rollout_config,\n        model_config=model_config,\n        device_mesh=None,\n    )\n\n    prompt_dict = TensorDict(\n        {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n        },\n        batch_size=input_ids.shape[0],\n    )\n    print(f\"preprocessed {input_ids.shape=}\")\n\n    messages = np.asarray(preencode_prompts)\n    prompts = DataProto(\n        batch=prompt_dict,\n        non_tensor_batch={\n            \"raw_prompt\": messages,\n            \"tools_kwargs\": np.array([{}] * input_ids.shape[0], dtype=object),\n        },\n    )\n\n    
prompts.meta_info.update(\n        {\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n        }\n    )\n\n    # log_gpu_memory_usage(\"Before generating sequences\", logger=None)\n    output = rollout.generate_sequences(prompts=prompts)\n    print(f\"generated {output.batch['responses'].shape=}\")\n    # log_gpu_memory_usage(\"After generating sequences\", logger=None)\n\n    sglang_output = output.to(\"cpu\")\n\n    sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch[\"responses\"])\n\n    print(f\"hf response: {hf_response_tokens}\")\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens)\n    print(\"SGLang w tool Test Passed!\")\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_async_sglang_rollout_w_tool()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_tools_token_out.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_rollout_w_tools.py\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    get_rollout_config,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\nfrom verl import DataProto\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\n\ndef test_async_sglang_rollout_w_tool():\n    import os\n\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group()\n    clean_torchelastic_env()\n\n    max_prompt_length = 32\n    max_response_length = 16\n    dtype = \"bfloat16\"\n    tensor_parallel_size = 2\n    skip_tokenizer_init = True\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": prompt, \"tool_calls\": None}]\n        for prompt in [\n            \"Who won the Champions League in 2019?\",\n            \"The founder of Apple is\",\n            \"What's the best way to learn python?\",\n        ]\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n        for message in preencode_prompts\n    ]\n    input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    rollout_config = get_rollout_config(\n        max_response_length,\n        max_prompt_length,\n        dtype,\n        tensor_parallel_size,\n        tool_config_path=\"./resource/tool_configs/sandbox_fusion_tool_config\",\n        skip_tokenizer_init=skip_tokenizer_init,\n    )\n    rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_config, dataclass_type=RolloutConfig)\n    model_config = HFModelConfig(path=local_model_path)\n    rollout = SGLangRollout(\n        config=rollout_config,\n        model_config=model_config,\n        device_mesh=None,\n    )\n\n    prompt_dict = TensorDict(\n        {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n        },\n        batch_size=input_ids.shape[0],\n    )\n    print(f\"preprocessed {input_ids.shape=}\")\n\n    messages = np.asarray(preencode_prompts)\n    prompts = DataProto(\n        batch=prompt_dict,\n        non_tensor_batch={\n            \"raw_prompt\": messages,\n            \"tools_kwargs\": 
np.array([{}] * input_ids.shape[0], dtype=object),\n        },\n    )\n\n    prompts.meta_info.update(\n        {\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n        }\n    )\n\n    # log_gpu_memory_usage(\"Before generating sequences\", logger=None)\n    output = rollout.generate_sequences(prompts=prompts)\n    print(f\"generated {output.batch['responses'].shape=}\")\n    # log_gpu_memory_usage(\"After generating sequences\", logger=None)\n\n    sglang_output = output.to(\"cpu\")\n\n    sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch[\"responses\"])\n\n    print(f\"hf response: {hf_response_tokens}\")\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens)\n    print(\"SGLang w tool Test Passed!\")\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_async_sglang_rollout_w_tool()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_multi_interaction.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"\nTest for multi-interaction support in SGLangRollout.\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_multi_interaction.py\n\"\"\"\n\nimport os\nimport tempfile\nfrom unittest.mock import MagicMock, patch\n\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig, OmegaConf\nfrom transformers import AutoTokenizer\n\nfrom verl.interactions.base import BaseInteraction\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\n\nclass MockInteraction(BaseInteraction):\n    \"\"\"Mock interaction for testing.\"\"\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.started_instances = set()\n\n    async def start_interaction(self, instance_id=None, **kwargs):\n        if instance_id is None:\n            instance_id = \"mock_instance\"\n        self.started_instances.add(instance_id)\n        return instance_id\n\n    async def generate_response(self, instance_id, messages, **kwargs):\n        return False, f\"Mock response from {self.name}\", 1.0, {}\n\n\ndef create_mock_config_with_multi_interactions():\n    \"\"\"Create a mock configuration with multiple interactions.\"\"\"\n    # Create temporary interaction config file\n    interaction_config = {\n        \"interaction\": [\n            {\n                \"name\": \"mock_agent1\",\n                \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                \"config\": {\"param1\": \"value1\"},\n            },\n            {\n                \"name\": \"mock_agent2\",\n                \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                \"config\": {\"param2\": \"value2\"},\n            },\n        ]\n    }\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n        OmegaConf.save(interaction_config, f.name)\n        interaction_config_path = f.name\n\n    # Create mock SGLangRollout config\n    config = DictConfig(\n        {\n            \"name\": \"sglang\",\n            \"multi_turn\": {\n                \"interaction_config_path\": interaction_config_path,\n                \"tool_config_path\": None,\n                \"enable\": True,\n                \"max_assistant_turns\": 5,\n                \"max_user_turns\": 3,\n                \"use_inference_chat_template\": True,\n                \"tokenization_sanity_check_mode\": \"off\",\n            },\n            \"prompt_length\": 32,\n            \"response_length\": 16,\n            \"max_model_len\": 512,\n            \"dtype\": \"bfloat16\",\n            \"gpu_memory_utilization\": 0.8,\n            \"load_format\": \"dummy\",\n 
           \"enforce_eager\": True,\n            \"free_cache_engine\": False,\n            \"calculate_log_probs\": False,\n            \"tensor_model_parallel_size\": 1,\n            \"n\": 1,\n            \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n        }\n    )\n\n    return config, interaction_config_path\n\n\ndef setup_distributed():\n    \"\"\"Initialize distributed environment if not already initialized.\"\"\"\n    if not dist.is_initialized():\n        dist.init_process_group(backend=\"nccl\" if torch.cuda.is_available() else \"gloo\")\n\n\nclass TestSGLangMultiInteraction:\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n\n    def test_initialize_multiple_interactions(self):\n        \"\"\"Test that SGLangRollout can initialize multiple interactions.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            # Mock SGLang engine and initialization methods like the reference test\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                # Create a real tokenizer like the reference test\n                tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                # Mock model config\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                # since this is a mock, we can set any rope scaling config\n                # to test the rope_scaling logic at the same time of this test\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n                model_config = HFModelConfig(path=self.local_model_path)\n                rollout = SGLangRollout(\n                    config=rollout_config,\n                    model_config=model_config,\n                    device_mesh=None,\n                )\n\n                # Check that interactions were initialized\n                assert len(rollout.interaction_map) == 2\n                assert \"mock_agent1\" in rollout.interaction_map\n                assert \"mock_agent2\" in rollout.interaction_map\n\n                # Use class name comparison instead of isinstance for multi-process compatibility\n                assert rollout.interaction_map[\"mock_agent1\"].__class__.__name__ == \"MockInteraction\"\n                assert rollout.interaction_map[\"mock_agent2\"].__class__.__name__ == \"MockInteraction\"\n\n                # Also check that they are instances of BaseInteraction (which should work across processes)\n                assert isinstance(rollout.interaction_map[\"mock_agent1\"], BaseInteraction)\n                assert isinstance(rollout.interaction_map[\"mock_agent2\"], BaseInteraction)\n\n                # Check that names were set correctly\n                assert rollout.interaction_map[\"mock_agent1\"].name == \"mock_agent1\"\n                assert 
rollout.interaction_map[\"mock_agent2\"].name == \"mock_agent2\"\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_interaction_selection_by_name(self):\n        \"\"\"Test that interactions are selected by name from interaction_kwargs.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n                model_config = HFModelConfig(path=self.local_model_path)\n                rollout = SGLangRollout(\n                    config=rollout_config,\n                    model_config=model_config,\n                    device_mesh=None,\n                )\n\n                # Test interaction selection logic\n                from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\n\n                # Create a mock request with specific interaction name\n                req = AsyncRolloutRequest(\n                    request_id=\"test_req\",\n                    state=AsyncRolloutRequestStateEnum.INTERACTING,\n                    messages=[Message(role=\"user\", content=\"test message\")],\n                    interaction_kwargs={\"name\": \"mock_agent2\", \"test_param\": \"value\"},\n                    input_ids=None,\n                    prompt_ids=None,\n                    response_ids=None,\n                    attention_mask=None,\n                    prompt_attention_mask=None,\n                    response_attention_mask=None,\n                    position_ids=None,\n                    prompt_position_ids=None,\n                    response_position_ids=None,\n                    loss_mask=None,\n                    prompt_loss_mask=None,\n                    response_loss_mask=None,\n                    reward_scores={},\n                    max_prompt_len=32,\n                    max_response_len=16,\n                    max_model_len=512,\n                    use_inference_chat_template=True,\n                    tokenization_sanity_check_mode=\"disable\",\n                    processing_class=tokenizer,\n                )\n\n                # Test that the correct interaction is selected\n                interaction_name = req.interaction_kwargs.get(\"name\", \"gsm8k\")\n                assert interaction_name == \"mock_agent2\"\n                assert interaction_name in rollout.interaction_map\n\n                selected_interaction = rollout.interaction_map[interaction_name]\n                assert selected_interaction.name == \"mock_agent2\"\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def 
test_fallback_to_default_interaction(self):\n        \"\"\"Test fallback to default interaction when name is not specified.\"\"\"\n        setup_distributed()\n        # Create config with gsm8k interaction\n        interaction_config = {\n            \"interaction\": [\n                {\n                    \"name\": \"gsm8k\",\n                    \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                    \"config\": {},\n                }\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(interaction_config, f.name)\n            interaction_config_path = f.name\n\n        config = DictConfig(\n            {\n                \"name\": \"sglang\",\n                \"multi_turn\": {\n                    \"interaction_config_path\": interaction_config_path,\n                    \"tool_config_path\": None,\n                    \"enable\": True,\n                    \"max_assistant_turns\": 5,\n                    \"max_user_turns\": 3,\n                    \"use_inference_chat_template\": True,\n                    \"tokenization_sanity_check_mode\": \"disable\",\n                },\n                \"prompt_length\": 32,\n                \"response_length\": 16,\n                \"max_model_len\": 512,\n                \"dtype\": \"bfloat16\",\n                \"gpu_memory_utilization\": 0.8,\n                \"load_format\": \"dummy\",\n                \"enforce_eager\": True,\n                \"free_cache_engine\": False,\n                \"calculate_log_probs\": False,\n                \"tensor_model_parallel_size\": 1,\n                \"n\": 1,\n                \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n            }\n        )\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n                model_config = HFModelConfig(path=self.local_model_path)\n                rollout = SGLangRollout(\n                    config=rollout_config,\n                    model_config=model_config,\n                    device_mesh=None,\n                )\n\n                # Test that default interaction name works\n                interaction_kwargs_without_name = {\"test_param\": \"value\"}\n                default_name = interaction_kwargs_without_name.get(\"name\", \"gsm8k\")\n                assert default_name == \"gsm8k\"\n                assert default_name in rollout.interaction_map\n\n        finally:\n            os.unlink(interaction_config_path)\n\n    def test_error_on_missing_interaction(self):\n        \"\"\"Test that error is raised when 
requested interaction is not found.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n                model_config = HFModelConfig(path=self.local_model_path)\n                rollout = SGLangRollout(\n                    config=rollout_config,\n                    model_config=model_config,\n                    device_mesh=None,\n                )\n\n                # Test error when requesting non-existent interaction\n                non_existent_name = \"non_existent_interaction\"\n                assert non_existent_name not in rollout.interaction_map\n\n                # This should raise ValueError in actual usage\n                available_interactions = list(rollout.interaction_map.keys())\n                assert \"mock_agent1\" in available_interactions\n                assert \"mock_agent2\" in available_interactions\n                assert non_existent_name not in available_interactions\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_backward_compatibility_no_interaction_config(self):\n        \"\"\"Test backward compatibility when no interaction config is provided.\"\"\"\n        setup_distributed()\n        # Create config without interaction config\n        config = DictConfig(\n            {\n                \"name\": \"sglang\",\n                \"multi_turn\": {\n                    \"interaction_config_path\": None,\n                    \"tool_config_path\": None,\n                    \"enable\": True,\n                    \"max_assistant_turns\": 5,\n                    \"max_user_turns\": 3,\n                    \"use_inference_chat_template\": True,\n                    \"tokenization_sanity_check_mode\": \"disable\",\n                },\n                \"prompt_length\": 32,\n                \"response_length\": 16,\n                \"max_model_len\": 512,\n                \"dtype\": \"bfloat16\",\n                \"gpu_memory_utilization\": 0.8,\n                \"load_format\": \"dummy\",\n                \"enforce_eager\": True,\n                \"free_cache_engine\": False,\n                \"calculate_log_probs\": False,\n                \"tensor_model_parallel_size\": 1,\n                \"n\": 1,\n                \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n            }\n        )\n\n        with (\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, 
\"_init_sampling_params\", return_value=None),\n        ):\n            tokenizer = AutoTokenizer.from_pretrained(self.local_model_path, padding_side=\"left\")\n            tokenizer.pad_token = tokenizer.eos_token\n\n            mock_model_config = MagicMock()\n            mock_model_config.max_position_embeddings = 2048\n            mock_model_config.rope_scaling = {\n                \"factor\": 4.0,\n                \"original_max_position_embeddings\": 32768,\n                \"type\": \"yarn\",\n            }\n\n            rollout_config: RolloutConfig = omega_conf_to_dataclass(config, dataclass_type=RolloutConfig)\n            model_config = HFModelConfig(path=self.local_model_path)\n            rollout = SGLangRollout(\n                config=rollout_config,\n                model_config=model_config,\n                device_mesh=None,\n            )\n\n            # Check that no interactions were initialized\n            assert len(rollout.interaction_map) == 0\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_rollout_sharding_manager.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\n_TENSOR_1MB = torch.zeros(512, 512)\n_BYTES_1MB = 1 << 20\n\n\n@pytest.mark.parametrize(\n    \"named_tensors, bucket_size_mb, gt_groups\",\n    [\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            0.5 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            1 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            1.5 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            2 * _BYTES_1MB,\n            [[\"a\", \"b\"]],\n        ),\n    ],\n)\ndef test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]):\n    named_tensors_iter = iter(named_tensors)\n    groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb))\n    assert len(groups) == len(gt_groups)\n    for group, gt_group in zip(groups, gt_groups, strict=True):\n        assert len(group) == len(gt_group)\n        for (name, _), (gt_name) in zip(group, gt_group, strict=True):\n            assert name == gt_name\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/test_sglang_spmd.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_spmd.py\n\"\"\"\n\nimport asyncio\nimport os\n\nimport torch\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.utils import broadcast_pyobj\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\n\ndef _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    token_ids = prompt_token_ids[non_pad_index:].tolist()\n    return token_ids\n\n\ndef test_sglang_spmd():\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group(spmd=True)\n    clean_torchelastic_env()\n\n    max_prompt_length = 16\n    max_response_length = 16\n\n    local_model_path = os.path.expanduser(\"~/models/Qwen/Qwen2.5-0.5B\")\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\"Who won the Champions League in 2019?\", \"The founder of Apple is\", \"What's your name?\"]\n    input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    tensor_parallel_size = 2\n    inference_device_mesh_cpu = init_device_mesh(\n        \"cpu\", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=[\"dp\", \"tp\", \"pp\"]\n    )\n    tp_rank = inference_device_mesh_cpu[\"tp\"].get_local_rank()\n\n    if tp_rank == 0:\n        llm = Engine(\n            model_path=local_model_path,\n            dtype=\"bfloat16\",\n            mem_fraction_static=0.5,\n            enable_memory_saver=True,\n            tp_size=inference_device_mesh_cpu[\"tp\"].size(),\n            attention_backend=\"fa3\",\n        )\n\n        input_ids = input_ids.cuda()\n        idx_list = []\n\n        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n        for i in range(input_ids.shape[0]):\n            idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))\n\n        sampling_params = dict(\n            n=1,\n            temperature=0,\n            top_p=1,\n            top_k=-1,\n            max_new_tokens=max_response_length,\n            presence_penalty=0.0,\n            frequency_penalty=0.0,\n            repetition_penalty=1.0,\n            skip_special_tokens=True,\n            spaces_between_special_tokens=True,\n            ignore_eos=False,\n        )\n\n        loop = asyncio.get_event_loop()\n        outputs = 
loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))\n    else:\n        outputs = None\n\n    [outputs] = broadcast_pyobj(\n        [outputs],\n        rank=inference_device_mesh_cpu[\"tp\"].get_local_rank(),\n        src=inference_device_mesh_cpu[\"tp\"].mesh[0].item(),\n        dist_group=inference_device_mesh_cpu[\"tp\"].get_group(),\n        force_cpu_device=False,\n    )\n\n    sglang_response_tokens = [output[\"text\"] for output in outputs]\n\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \"Strings differ by more than 10%\"\n    print(\"SPMD Test Passed!\")\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "verl_distillation/tests/workers/rollout/utils_sglang.py",
    "content": "# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom datetime import timedelta\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.torch_functional import pad_sequence_to_length\n\n\n# ====================== utils ======================\ndef levenshtein(s1, s2):\n    m, n = len(s1), len(s2)\n    dp = [[0] * (n + 1) for _ in range(m + 1)]\n    for i in range(m + 1):\n        dp[i][0] = i\n    for j in range(n + 1):\n        dp[0][j] = j\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            cost = 0 if s1[i - 1] == s2[j - 1] else 1\n            dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)\n    return dp[m][n]\n\n\ndef are_lists_similar(a, b, threshold=10):\n    if len(a) != len(b):\n        print(\"The lists are of different lengths.\")\n        return False\n    total_length = 0\n    total_diff = 0\n    for s1, s2 in zip(a, b, strict=True):\n        max_len = max(len(s1), len(s2))\n        total_length += max_len\n        total_diff += levenshtein(s1, s2)\n    percentage_difference = (total_diff / total_length) * 100\n    print(f\"Total difference: {percentage_difference:.2f}%\")\n    return percentage_difference <= threshold\n\n\ndef initialize_global_process_group(timeout_second=36000, spmd=False):\n    import torch.distributed\n\n    if not torch.distributed.is_initialized():  # Check if already initialized\n        print(\"Initializing process group...\")\n        torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))\n    else:\n        print(\"Process group already initialized.\")\n\n    local_rank = int(os.environ[\"LOCAL_RANK\"])\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n    torch.cuda.set_device(local_rank)\n\n    CUDA_VISIBLE_DEVICES = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not CUDA_VISIBLE_DEVICES:\n        if spmd:\n            # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size))\n            CUDA_VISIBLE_DEVICES = \",\".join(str(i) for i in range(world_size))\n        else:\n            CUDA_VISIBLE_DEVICES = str(local_rank)\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = CUDA_VISIBLE_DEVICES\n        print(f\"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}\")\n\n    return local_rank, rank, world_size\n\n\ndef clean_torchelastic_env():\n    for k in [\"TORCHELASTIC_USE_AGENT_STORE\"]:\n        if k in os.environ:\n            del os.environ[k]\n\n\ndef load_tokenizer_and_model(local_model_path, dtype=\"bfloat16\"):\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\")\n    tokenizer.pad_token = tokenizer.eos_token\n    model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map=\"cuda\")\n    
return tokenizer, model\n\n\ndef prepare_inputs(tokenizer, prompts, max_prompt_length):\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    tokenized = tokenizer(prompts, return_tensors=\"pt\", padding=True)\n    input_ids = pad_sequence_to_length(tokenized[\"input_ids\"], max_prompt_length, pad_token_id, left_pad=True)\n    attention_mask = pad_sequence_to_length(\n        tokenized[\"attention_mask\"], max_prompt_length, pad_token_id=0, left_pad=True\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n    position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True)\n    return input_ids, attention_mask, position_ids\n\n\ndef generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length):\n    generation_config = GenerationConfig(do_sample=False)\n    output = model.generate(\n        input_ids=input_ids.cuda(),\n        attention_mask=attention_mask.cuda(),\n        max_new_tokens=max_response_length,\n        eos_token_id=tokenizer.eos_token_id,\n        pad_token_id=tokenizer.pad_token_id,\n        generation_config=generation_config,\n        output_scores=False,\n        return_dict_in_generate=True,\n        use_cache=False,\n    )\n    seq = output.sequences\n    response = seq[:, input_ids.shape[1] :]\n    return tokenizer.batch_decode(response)\n\n\ndef get_rollout_config(\n    max_response_length,\n    max_prompt_length,\n    dtype,\n    tensor_parallel_size,\n    tool_config_path=None,\n    interaction_config_path=None,\n    skip_tokenizer_init=False,\n):\n    sampling_params = dict(\n        n=1,\n        temperature=0,\n        top_p=1,\n        top_k=-1,\n    )\n\n    rollout_config = OmegaConf.create(\n        {\n            \"name\": \"sglang\",\n            \"mode\": \"sync\",\n            \"load_format\": \"auto\",\n            \"enforce_eager\": False,\n            \"free_cache_engine\": True,\n            \"dtype\": dtype,\n            \"gpu_memory_utilization\": 0.5,\n            \"ignore_eos\": False,\n            \"max_num_batched_tokens\": 8192,\n            \"prompt_length\": max_prompt_length,\n            \"response_length\": max_response_length,\n            \"tensor_model_parallel_size\": tensor_parallel_size,\n            # set to 128MB only for testing\n            \"update_weights_bucket_megabytes\": 128,\n            # do not drop any samples in the test\n            \"over_sample_rate\": 0.0,\n            \"multi_turn\": {\n                \"max_assistant_turns\": 4,\n                \"max_user_turns\": 4,\n                \"enable\": True,\n                \"tool_config_path\": tool_config_path,\n                \"interaction_config_path\": interaction_config_path,\n                \"use_inference_chat_template\": False,\n                \"tokenization_sanity_check_mode\": \"strict\",\n            },\n            \"calculate_log_probs\": False,\n            \"max_model_len\": None,\n            \"skip_tokenizer_init\": skip_tokenizer_init,\n            **sampling_params,\n        }\n    )\n\n    return rollout_config\n"
  },
  {
    "path": "verl_distillation/tests/workers/test_fsdp_attn_implementation.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nTest for attn_implementation override configuration in FSDP workers.\n\nThis test verifies that the fix for honoring attn_implementation override config\nworks correctly in the ActorRolloutRefWorker._build_model_optimizer method.\n\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nimport torch\nfrom omegaconf import OmegaConf\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\n# Only run these tests if we can import verl components\ntry:\n    from verl.workers.config import FSDPEngineConfig  # noqa: F401\n    from verl.workers.fsdp_workers import (\n        ActorRolloutRefWorker,  # noqa: F401\n        CriticWorker,  # noqa: F401\n    )\n\n    VERL_AVAILABLE = True\nexcept ImportError:\n    VERL_AVAILABLE = False\n\n\n@pytest.mark.skipif(not VERL_AVAILABLE, reason=\"VERL components not available\")\nclass TestFSDPAttnImplementation:\n    \"\"\"Test cases for attn_implementation override in FSDP workers.\"\"\"\n\n    def test_attn_implementation_extraction_logic(self):\n        \"\"\"Test the core logic for extracting attn_implementation from override config.\"\"\"\n\n        # Test case 1: Default behavior\n        override_config = {}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"flash_attention_2\"\n\n        # Test case 2: Override to eager\n        override_config = {\"attn_implementation\": \"eager\"}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"eager\"\n\n        # Test case 3: Override to sdpa\n        override_config = {\"attn_implementation\": \"sdpa\"}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"sdpa\"\n\n        # Test case 4: Other configs don't affect attn_implementation\n        override_config = {\"other_setting\": \"value\", \"dropout\": 0.1}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"flash_attention_2\"\n\n    @patch(\"transformers.AutoConfig.from_pretrained\")\n    @patch(\"transformers.AutoModelForCausalLM.from_pretrained\")\n    def test_attn_implementation_passed_to_autoconfig(self, mock_model_from_pretrained, mock_config_from_pretrained):\n        \"\"\"Test that attn_implementation is correctly passed to AutoConfig.from_pretrained.\"\"\"\n\n        # Mock the AutoConfig return value\n        mock_config = Mock()\n        mock_config.tie_word_embeddings = False\n        mock_config.architectures = [\"LlamaForCausalLM\"]\n        mock_config_from_pretrained.return_value = mock_config\n\n        # Mock the model return value\n        mock_model = Mock()\n        mock_model_from_pretrained.return_value = mock_model\n\n        # Test data\n        test_cases = [\n            ({}, \"flash_attention_2\"),  # Default\n      
      ({\"attn_implementation\": \"eager\"}, \"eager\"),  # Override to eager\n            ({\"attn_implementation\": \"sdpa\"}, \"sdpa\"),  # Override to sdpa\n        ]\n\n        for override_config, expected_attn_impl in test_cases:\n            # Reset mocks\n            mock_config_from_pretrained.reset_mock()\n            mock_model_from_pretrained.reset_mock()\n\n            # Simulate the logic from FSDP workers\n            attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n            # This simulates what happens in _build_model_optimizer\n            AutoConfig.from_pretrained(\"test_path\", trust_remote_code=False, attn_implementation=attn_implementation)\n\n            # Verify AutoConfig.from_pretrained was called with correct attn_implementation\n            mock_config_from_pretrained.assert_called_once_with(\n                \"test_path\", trust_remote_code=False, attn_implementation=expected_attn_impl\n            )\n\n    @patch(\"transformers.AutoConfig.from_pretrained\")\n    @patch(\"transformers.AutoModelForCausalLM.from_pretrained\")\n    def test_attn_implementation_passed_to_model(self, mock_model_from_pretrained, mock_config_from_pretrained):\n        \"\"\"Test that attn_implementation is correctly passed to model.from_pretrained.\"\"\"\n\n        # Mock the AutoConfig return value\n        mock_config = Mock()\n        mock_config.tie_word_embeddings = False\n        mock_config.architectures = [\"LlamaForCausalLM\"]\n        mock_config_from_pretrained.return_value = mock_config\n\n        # Mock the model return value\n        mock_model = Mock()\n        mock_model_from_pretrained.return_value = mock_model\n\n        # Test with override config\n        override_config = {\"attn_implementation\": \"eager\"}\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n        # This simulates what happens in _build_model_optimizer\n        AutoModelForCausalLM.from_pretrained(\n            pretrained_model_name_or_path=\"test_path\",\n            torch_dtype=torch.bfloat16,\n            config=mock_config,\n            trust_remote_code=False,\n            attn_implementation=attn_implementation,\n        )\n\n        # Verify AutoModelForCausalLM.from_pretrained was called with correct attn_implementation\n        mock_model_from_pretrained.assert_called_once_with(\n            pretrained_model_name_or_path=\"test_path\",\n            torch_dtype=torch.bfloat16,\n            config=mock_config,\n            trust_remote_code=False,\n            attn_implementation=\"eager\",\n        )\n\n    def test_override_config_integration(self):\n        \"\"\"Test that override_config from Hydra configuration works correctly.\"\"\"\n\n        # Simulate the OmegaConf configuration structure used in VERL\n        config_dict = {\n            \"model\": {\"path\": \"/test/path\", \"override_config\": {\"attn_implementation\": \"eager\", \"dropout\": 0.1}}\n        }\n\n        # Convert to OmegaConf structure\n        omegaconf = OmegaConf.create(config_dict)\n\n        # Simulate what happens in the FSDP worker\n        override_model_config = OmegaConf.to_container(OmegaConf.create(omegaconf.model.get(\"override_config\", {})))\n\n        # Test extraction\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"eager\"\n\n        # Test that other configs are preserved\n        assert 
override_model_config.get(\"dropout\") == 0.1\n\n    def test_hydra_plus_prefix_config(self):\n        \"\"\"Test that Hydra +prefix configurations work correctly.\"\"\"\n\n        # This simulates the configuration when user specifies:\n        # +actor_rollout_ref.model.override_config.attn_implementation=eager\n\n        # The + prefix in Hydra adds new keys to the config\n        config_dict = {\n            \"actor_rollout_ref\": {\n                \"model\": {\n                    \"path\": \"/test/path\",\n                    \"override_config\": {\n                        \"attn_implementation\": \"eager\"  # This gets added via +prefix\n                    },\n                }\n            }\n        }\n\n        omegaconf = OmegaConf.create(config_dict)\n\n        # Extract override config as done in FSDP workers\n        override_model_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.actor_rollout_ref.model.get(\"override_config\", {}))\n        )\n\n        # Verify extraction works\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"eager\"\n\n    def test_backward_compatibility(self):\n        \"\"\"Test that the fix maintains backward compatibility.\"\"\"\n\n        # Test case 1: No override_config at all (old behavior)\n        config_without_override = {}\n        attn_implementation = config_without_override.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n        # Test case 2: Empty override_config\n        config_with_empty_override = {\"override_config\": {}}\n        override_config = config_with_empty_override.get(\"override_config\", {})\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n        # Test case 3: override_config with other settings but no attn_implementation\n        config_with_other_overrides = {\"override_config\": {\"dropout\": 0.1, \"hidden_size\": 1024}}\n        override_config = config_with_other_overrides.get(\"override_config\", {})\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n    def test_critic_attn_implementation_extraction_logic(self):\n        \"\"\"Test the core logic for extracting attn_implementation from override config for CriticWorker.\"\"\"\n\n        # Test case 1: Default behavior for critic\n        override_config = {}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"flash_attention_2\"\n\n        # Test case 2: Override to eager for critic\n        override_config = {\"attn_implementation\": \"eager\"}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"eager\"\n\n        # Test case 3: Override to sdpa for critic\n        override_config = {\"attn_implementation\": \"sdpa\"}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == \"sdpa\"\n\n        # Test case 4: Other configs don't affect attn_implementation for critic\n        override_config = {\"other_setting\": \"value\", \"dropout\": 0.1}\n        attn_impl = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_impl == 
\"flash_attention_2\"\n\n    @patch(\"transformers.AutoConfig.from_pretrained\")\n    def test_critic_attn_implementation_passed_to_autoconfig(self, mock_config_from_pretrained):\n        \"\"\"Test that attn_implementation is correctly passed to AutoConfig.from_pretrained in CriticWorker.\"\"\"\n\n        # Mock the AutoConfig return value\n        mock_config = Mock()\n        mock_config.tie_word_embeddings = False\n        mock_config.architectures = [\"LlamaForCausalLM\"]\n        mock_config.num_labels = 1\n        mock_config_from_pretrained.return_value = mock_config\n\n        # Test data for critic model\n        test_cases = [\n            ({}, \"flash_attention_2\"),  # Default\n            ({\"attn_implementation\": \"eager\"}, \"eager\"),  # Override to eager\n            ({\"attn_implementation\": \"sdpa\"}, \"sdpa\"),  # Override to sdpa\n        ]\n\n        for override_config, expected_attn_impl in test_cases:\n            # Reset mocks\n            mock_config_from_pretrained.reset_mock()\n\n            # Simulate the logic from CriticWorker _build_critic_model_optimizer\n            attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n            # This simulates what should happen in CriticWorker._build_critic_model_optimizer\n            # (This is where the fix needs to be applied in the actual implementation)\n            AutoConfig.from_pretrained(\n                \"test_path\",\n                attn_implementation=attn_implementation,\n                trust_remote_code=False,\n            )\n\n            # Verify AutoConfig.from_pretrained was called with correct attn_implementation\n            mock_config_from_pretrained.assert_called_once_with(\n                \"test_path\",\n                attn_implementation=expected_attn_impl,\n                trust_remote_code=False,\n            )\n\n    def test_critic_override_config_integration(self):\n        \"\"\"Test that override_config from Hydra configuration works correctly for CriticWorker.\"\"\"\n\n        # Simulate the OmegaConf configuration structure used in VERL for critic\n        config_dict = {\n            \"critic\": {\n                \"model\": {\"path\": \"/test/path\", \"override_config\": {\"attn_implementation\": \"eager\", \"dropout\": 0.1}}\n            }\n        }\n\n        # Convert to OmegaConf structure\n        omegaconf = OmegaConf.create(config_dict)\n\n        # Simulate what happens in the CriticWorker\n        override_model_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {}))\n        )\n\n        # Test extraction for critic\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"eager\"\n\n        # Test that other configs are preserved for critic\n        assert override_model_config.get(\"dropout\") == 0.1\n\n    def test_critic_hydra_plus_prefix_config(self):\n        \"\"\"Test that Hydra +prefix configurations work correctly for CriticWorker.\"\"\"\n\n        # This simulates the configuration when user specifies:\n        # +critic.model.override_config.attn_implementation=eager\n\n        # The + prefix in Hydra adds new keys to the config\n        config_dict = {\n            \"critic\": {\n                \"model\": {\n                    \"path\": \"/test/path\",\n                    \"override_config\": {\n                        \"attn_implementation\": \"eager\"  # This 
gets added via +prefix for critic\n                    },\n                }\n            }\n        }\n\n        omegaconf = OmegaConf.create(config_dict)\n\n        # Extract override config as done in CriticWorker\n        override_model_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {}))\n        )\n\n        # Verify extraction works for critic\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"eager\"\n\n    def test_both_actor_and_critic_configuration(self):\n        \"\"\"Test that both actor and critic can have different attn_implementation overrides simultaneously.\"\"\"\n\n        # This simulates a complete training configuration with both actor and critic overrides\n        config_dict = {\n            \"actor_rollout_ref\": {\"model\": {\"override_config\": {\"attn_implementation\": \"eager\"}}},\n            \"critic\": {\"model\": {\"override_config\": {\"attn_implementation\": \"sdpa\"}}},\n        }\n\n        omegaconf = OmegaConf.create(config_dict)\n\n        # Extract actor override config\n        actor_override_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.actor_rollout_ref.model.get(\"override_config\", {}))\n        )\n        actor_attn_implementation = actor_override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n        # Extract critic override config\n        critic_override_config = OmegaConf.to_container(\n            OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {}))\n        )\n        critic_attn_implementation = critic_override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n        # Verify both can be configured independently\n        assert actor_attn_implementation == \"eager\"\n        assert critic_attn_implementation == \"sdpa\"\n\n    def test_critic_backward_compatibility(self):\n        \"\"\"Test that the CriticWorker fix maintains backward compatibility.\"\"\"\n\n        # Test case 1: No override_config at all for critic (old behavior)\n        config_without_override = {}\n        attn_implementation = config_without_override.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n        # Test case 2: Empty override_config for critic\n        config_with_empty_override = {\"override_config\": {}}\n        override_config = config_with_empty_override.get(\"override_config\", {})\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n        # Test case 3: override_config with other settings but no attn_implementation for critic\n        config_with_other_overrides = {\"override_config\": {\"dropout\": 0.1, \"num_labels\": 1}}\n        override_config = config_with_other_overrides.get(\"override_config\", {})\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        assert attn_implementation == \"flash_attention_2\"\n\n\ndef test_attn_implementation_fix_integration():\n    \"\"\"Integration test to verify the entire fix works as expected.\"\"\"\n\n    # This test simulates the complete flow from configuration to model creation\n\n    # Step 1: Simulate Hydra configuration with +prefix\n    # user_config = \"+actor_rollout_ref.model.override_config.attn_implementation=eager\"\n\n    # This would 
result in a config structure like:\n    config_dict = {\"actor_rollout_ref\": {\"model\": {\"override_config\": {\"attn_implementation\": \"eager\"}}}}\n\n    # Step 2: Extract override_model_config as done in FSDP workers\n    omegaconf = OmegaConf.create(config_dict)\n    override_model_config = OmegaConf.to_container(\n        OmegaConf.create(omegaconf.actor_rollout_ref.model.get(\"override_config\", {}))\n    )\n\n    # Step 3: Apply the fix logic\n    attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n    # Step 4: Verify the fix works\n    assert attn_implementation == \"eager\"\n\n    # Step 5: Verify this would be passed to both AutoConfig and Model creation\n    # (This would normally be done with mocks, but we can test the parameter preparation)\n    config_params = {\"attn_implementation\": attn_implementation}\n    model_params = {\"attn_implementation\": attn_implementation}\n\n    assert config_params[\"attn_implementation\"] == \"eager\"\n    assert model_params[\"attn_implementation\"] == \"eager\"\n\n\ndef test_critic_attn_implementation_fix_integration():\n    \"\"\"Integration test to verify the entire fix works as expected for CriticWorker.\"\"\"\n\n    # This test simulates the complete flow from configuration to model creation for critic\n\n    # Step 1: Simulate Hydra configuration with +prefix for critic\n    # user_config = \"+critic.model.override_config.attn_implementation=sdpa\"\n\n    # This would result in a config structure like:\n    config_dict = {\"critic\": {\"model\": {\"override_config\": {\"attn_implementation\": \"sdpa\"}}}}\n\n    # Step 2: Extract override_model_config as should be done in CriticWorker\n    omegaconf = OmegaConf.create(config_dict)\n    override_model_config = OmegaConf.to_container(OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {})))\n\n    # Step 3: Apply the fix logic (what needs to be implemented in CriticWorker)\n    attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n    # Step 4: Verify the fix works for critic\n    assert attn_implementation == \"sdpa\"\n\n    # Step 5: Verify this would be passed to AutoConfig creation for critic\n    config_params = {\"attn_implementation\": attn_implementation}\n\n    assert config_params[\"attn_implementation\"] == \"sdpa\"\n\n\ndef test_complete_training_configuration():\n    \"\"\"Integration test for a complete training configuration with both actor and critic overrides.\"\"\"\n\n    # This test simulates a realistic training configuration where both\n    # actor and critic have different attention implementations\n    config_dict = {\n        \"actor_rollout_ref\": {\n            \"model\": {\n                \"path\": \"/shared/models/llama-7b\",\n                \"override_config\": {\"attn_implementation\": \"eager\", \"torch_dtype\": \"bfloat16\"},\n            }\n        },\n        \"critic\": {\n            \"model\": {\n                \"path\": \"/shared/models/llama-7b\",\n                \"override_config\": {\"attn_implementation\": \"sdpa\", \"num_labels\": 1},\n            }\n        },\n    }\n\n    omegaconf = OmegaConf.create(config_dict)\n\n    # Extract configurations as would be done in the workers\n    actor_override_config = OmegaConf.to_container(\n        OmegaConf.create(omegaconf.actor_rollout_ref.model.get(\"override_config\", {}))\n    )\n    critic_override_config = 
OmegaConf.to_container(OmegaConf.create(omegaconf.critic.model.get(\"override_config\", {})))\n\n    # Apply the fix logic for both\n    actor_attn_implementation = actor_override_config.get(\"attn_implementation\", \"flash_attention_2\")\n    critic_attn_implementation = critic_override_config.get(\"attn_implementation\", \"flash_attention_2\")\n\n    # Verify both configurations work independently\n    assert actor_attn_implementation == \"eager\"\n    assert critic_attn_implementation == \"sdpa\"\n\n    # Verify other configs are preserved\n    assert actor_override_config.get(\"torch_dtype\") == \"bfloat16\"\n    assert critic_override_config.get(\"num_labels\") == 1\n\n\nif __name__ == \"__main__\":\n    # Run basic tests\n    test_attn_implementation_fix_integration()\n    test_critic_attn_implementation_fix_integration()\n    test_complete_training_configuration()\n\n    if VERL_AVAILABLE:\n        # Run class-based tests\n        test_class = TestFSDPAttnImplementation()\n        test_class.test_attn_implementation_extraction_logic()\n        test_class.test_override_config_integration()\n        test_class.test_hydra_plus_prefix_config()\n        test_class.test_backward_compatibility()\n\n        # Run new critic tests\n        test_class.test_critic_attn_implementation_extraction_logic()\n        test_class.test_critic_override_config_integration()\n        test_class.test_critic_hydra_plus_prefix_config()\n        test_class.test_both_actor_and_critic_configuration()\n        test_class.test_critic_backward_compatibility()\n\n        print(\"✓ All FSDP attn_implementation tests passed!\")\n        print(\"✓ All CriticWorker attn_implementation tests passed!\")\n    else:\n        print(\"⚠ VERL components not available, skipping VERL-specific tests\")\n\n    print(\"✓ Integration tests passed!\")\n    print(\"✓ Critic integration tests passed!\")\n"
  },
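  {
    "path": "verl_distillation/tests/workers/attn_override_dotlist_sketch.py",
    "content": "# NOTE: Illustrative sketch added alongside the attn_implementation override tests above;\n# it is not part of the original test suite. The file name and the standalone use of\n# OmegaConf.from_dotlist are assumptions for illustration only.\n#\n# It shows how a Hydra-style \"+key=value\" override such as\n#   +actor_rollout_ref.model.override_config.attn_implementation=eager\n# expands into the nested config structure that the tests simulate with hand-written\n# dicts, and how the workers read it back with a default.\nfrom omegaconf import OmegaConf\n\n\ndef main():\n    # A dotlist entry expands into the nested mapping the FSDP workers read.\n    conf = OmegaConf.from_dotlist(\n        [\"actor_rollout_ref.model.override_config.attn_implementation=eager\"]\n    )\n    override_model_config = OmegaConf.to_container(\n        OmegaConf.create(conf.actor_rollout_ref.model.get(\"override_config\", {}))\n    )\n    # The workers fall back to flash_attention_2 when no override is given.\n    assert override_model_config.get(\"attn_implementation\", \"flash_attention_2\") == \"eager\"\n    print(\"dotlist override expands as expected\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },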
  {
    "path": "verl_distillation/tests/workers/test_fsdp_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nfrom omegaconf import OmegaConf\n\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\n\n\ndef test_actor_rollout_ref_worker_actor_ref_model():\n    \"\"\"Test specifying different reference/actor model\"\"\"\n    os.environ[\"RANK\"] = \"0\"\n    os.environ[\"WORLD_SIZE\"] = \"1\"\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"8888\"\n\n    config_str = \"\"\"\n    model:\n      path: Qwen/Qwen2.5-0.5B-Instruct\n    actor:\n      _target_: verl.workers.config.FSDPActorConfig\n      strategy: fsdp\n      fsdp_config:\n        _target_: verl.workers.config.FSDPEngineConfig\n        fsdp_size: -1\n        forward_prefetch: false\n      profiler:\n        tool: torch_memory\n        save_path: ./mem_snapshots\n        tool_config:\n          torch_memory:\n            _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n            trace_alloc_max_entries: 100000\n            stack_depth: 32\n    ref:\n      model:\n        path: Qwen/Qwen2.5-1.5B-Instruct\n      fsdp_config:\n        _target_: verl.workers.config.FSDPEngineConfig\n        fsdp_size: -1\n      profiler:\n        tool: torch_memory\n        save_path: ./mem_snapshots\n        tool_config:\n          torch_memory:\n            _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n            trace_alloc_max_entries: 100000\n            stack_depth: 32\n      log_prob_micro_batch_size: 1\n      ulysses_sequence_parallel_size: 1\n      entropy_from_logits_with_chunking: false\n    \"\"\"\n    dict_conf = OmegaConf.create(config_str)\n    actor_rollout_ref_worker = ActorRolloutRefWorker(dict_conf, role=\"ref\")\n    actor_rollout_ref_worker.init_model()\n\n    model_config = actor_rollout_ref_worker.ref_module_fsdp._fsdp_wrapped_module.config\n    assert model_config.hidden_size == 1536\n\n    # set ref.model to null, fallback to default case where actor is the same as reference\n    dict_conf[\"ref\"][\"model\"] = None\n    actor_rollout_ref_worker = ActorRolloutRefWorker(dict_conf, role=\"ref\")\n    actor_rollout_ref_worker.init_model()\n\n    model_config = actor_rollout_ref_worker.ref_module_fsdp._fsdp_wrapped_module.config\n    assert model_config.hidden_size == 896\n"
  },
  {
    "path": "verl_distillation/verl/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport logging\nimport os\nfrom importlib.metadata import PackageNotFoundError\nfrom importlib.metadata import version as get_version\n\nfrom packaging.version import parse as parse_version\n\nfrom .protocol import DataProto\nfrom .utils.device import is_npu_available\nfrom .utils.import_utils import import_external_libs\nfrom .utils.logging_utils import set_basic_config\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\nwith open(os.path.join(version_folder, \"version/version\")) as f:\n    __version__ = f.read().strip()\n\n\nset_basic_config(level=logging.WARNING)\n\n\n__all__ = [\"DataProto\", \"__version__\"]\n\n\nmodules = os.getenv(\"VERL_USE_EXTERNAL_MODULES\", \"\")\nif modules:\n    modules = modules.split(\",\")\n    import_external_libs(modules)\n\n\nif os.getenv(\"VERL_USE_MODELSCOPE\", \"False\").lower() == \"true\":\n    if importlib.util.find_spec(\"modelscope\") is None:\n        raise ImportError(\"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`\")\n    # Patch hub to download models from modelscope to speed up.\n    from modelscope.utils.hf_util import patch_hub\n\n    patch_hub()\n\nif is_npu_available:\n    from .models.transformers import npu_patch as npu_patch\n\n    package_name = \"transformers\"\n    required_version_spec = \"4.52.4\"\n    try:\n        installed_version = get_version(package_name)\n        installed = parse_version(installed_version)\n        required = parse_version(required_version_spec)\n\n        if installed < required:\n            raise ValueError(\n                f\"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is \"\n                f\"{installed}.\"\n            )\n    except PackageNotFoundError as e:\n        raise ImportError(\n            f\"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}\"\n        ) from e\n\n    # In verl, the driver process aggregates the computation results of workers via Ray.\n    # Therefore, after a worker completes its computation job, it will package the output\n    # using tensordict and transfer it to the CPU. Since the `to` operation of tensordict\n    # is non-blocking, when transferring data from a device to the CPU, it is necessary to\n    # ensure that a batch of data has been completely transferred before being used on the\n    # host; otherwise, unexpected precision issues may arise. Tensordict has already noticed\n    # this problem and fixed it. Ref: https://github.com/pytorch/tensordict/issues/725\n    # However, the relevant modifications only cover CUDA and MPS devices and do not take effect\n    # for third-party devices such as NPUs. 
This patch fixes this issue, and the relevant\n    # modifications can be removed once the fix is merged into tensordict.\n\n    import tensordict\n\n    if parse_version(tensordict.__version__) < parse_version(\"0.10.0\"):\n        from tensordict.base import TensorDictBase\n\n        def _sync_all_patch(self):\n            from torch._utils import _get_available_device_type, _get_device_module\n\n            device_type = _get_available_device_type()\n            if device_type is None:\n                return\n\n            device_module = _get_device_module(device_type)\n            device_module.synchronize()\n\n        TensorDictBase._sync_all = _sync_all_patch\n"
  },
  {
    "path": "verl_distillation/verl/base_config.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport collections\nfrom dataclasses import FrozenInstanceError, dataclass, fields\nfrom typing import Any\n\n\n# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary\n@dataclass\nclass BaseConfig(collections.abc.Mapping):\n    \"\"\"The BaseConfig provides dict-like interface for a dataclass config.\n\n    By default all fields in the config is not mutable, unless specified in\n    \"_mutable_fields\". The BaseConfig class implements the Mapping Abstract Base Class.\n    This allows instances of this class to be used like dictionaries.\n    \"\"\"\n\n    _mutable_fields = set()\n    _target_: str = \"\"\n\n    def __setattr__(self, name: str, value):\n        \"\"\"Set the value of an attribute. Check if the attr is mutable before setting the value.\"\"\"\n        # If the field already exists, it's considered frozen unless it's in _mutable_fields\n        if name in self.__dict__ and name not in getattr(self, \"_mutable_fields\", set()):\n            raise FrozenInstanceError(f\"Field '{name}' is frozen and cannot be modified\")\n        super().__setattr__(name, value)\n\n    def get(self, key: str, default: Any = None) -> Any:\n        \"\"\"Get the value associated with the given key. If the key does not exist, return the default value.\n\n        Args:\n            key (str): The attribute name to retrieve.\n            default (Any, optional): The value to return if the attribute does not exist. Defaults to None.\n\n        Returns:\n            Any: The value of the attribute or the default value.\n        \"\"\"\n        try:\n            return getattr(self, key)\n        except AttributeError:\n            return default\n\n    def __getitem__(self, key: str):\n        \"\"\"Implement the [] operator for the class. Allows accessing attributes like dictionary items.\n\n        Args:\n            key (str): The attribute name to retrieve.\n\n        Returns:\n            Any: The value of the attribute.\n\n        Raises:\n            AttributeError: If the attribute does not exist.\n            TypeError: If the key type is not string\n        \"\"\"\n        return getattr(self, key)\n\n    def __iter__(self):\n        \"\"\"Implement the iterator protocol. Allows iterating over the attribute names of the instance.\n\n        Yields:\n            str: The name of each field in the dataclass.\n        \"\"\"\n        for f in fields(self):\n            yield f.name\n\n    def __len__(self):\n        \"\"\"\n        Return the number of fields in the dataclass.\n\n        Returns:\n            int: The number of fields in the dataclass.\n        \"\"\"\n        return len(fields(self))\n"
  },
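  {
    "path": "verl_distillation/verl/base_config_example_sketch.py",
    "content": "# NOTE: Illustrative usage sketch, not part of the original package. It demonstrates\n# the dict-like, frozen-by-default behavior documented in BaseConfig; ExampleConfig and\n# this file name are assumptions for illustration only.\nfrom dataclasses import FrozenInstanceError, dataclass\n\nfrom verl.base_config import BaseConfig\n\n\n@dataclass\nclass ExampleConfig(BaseConfig):\n    # Only \"lr\" may be reassigned after construction.\n    _mutable_fields = {\"lr\"}\n    lr: float = 1e-3\n    name: str = \"demo\"\n\n\nif __name__ == \"__main__\":\n    cfg = ExampleConfig()\n    # Mapping interface: [] access, .get with default, iteration over field names.\n    assert cfg[\"name\"] == \"demo\"\n    assert cfg.get(\"missing\", 0) == 0\n    assert set(cfg) == {\"_target_\", \"lr\", \"name\"}\n    cfg.lr = 5e-4  # listed in _mutable_fields, so reassignment is allowed\n    try:\n        cfg.name = \"other\"  # frozen field\n    except FrozenInstanceError:\n        print(\"name is frozen, as expected\")\n"
  },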
  {
    "path": "verl_distillation/verl/experimental/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/experimental/agent_loop/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .agent_loop import AgentLoopBase, AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager\nfrom .single_turn_agent_loop import SingleTurnAgentLoop\nfrom .tool_agent_loop import ToolAgentLoop\n\n_ = [SingleTurnAgentLoop, ToolAgentLoop]\n\n__all__ = [\"AgentLoopBase\", \"AgentLoopManager\", \"AsyncLLMServerManager\", \"AgentLoopWorker\"]\n"
  },
  {
    "path": "verl_distillation/verl/experimental/agent_loop/agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport heapq\nimport logging\nimport os\nimport random\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Optional\n\nimport hydra\nimport numpy as np\nimport ray\nimport torch\nfrom cachetools import LRUCache\nfrom omegaconf import DictConfig, OmegaConf\nfrom pydantic import BaseModel, ConfigDict\nfrom tensordict import TensorDict\nfrom transformers import AutoProcessor, AutoTokenizer\n\nfrom verl.experimental.reward import RewardManagerWorker\nfrom verl.protocol import DataProto\nfrom verl.single_controller.ray.base import RayWorkerGroup\nfrom verl.utils import hf_processor, hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.rollout_trace import (RolloutTraceConfig, rollout_trace_attr,\n                                      rollout_trace_op)\nfrom verl.utils.transferqueue_utils import tqbridge\nfrom verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass AsyncLLMServerManager:\n    \"\"\"\n    A class to manage multiple OpenAI compatible LLM servers. This class provides\n    - Load balance: least requests load balancing\n    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching\n    \"\"\"\n\n    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):\n        \"\"\"Initialize the AsyncLLMServerManager.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n            max_cache_size (int, optional): max cache size for request_id to server mapping. 
Defaults to 10000.\n        \"\"\"\n        self.config = config\n        self.server_handles = server_handles\n        random.shuffle(self.server_handles)\n\n        # Least requests load balancing\n        self.weighted_servers = [[0, (hash(server), server)] for server in server_handles]\n        heapq.heapify(self.weighted_servers)\n\n        # LRU cache to map request_id to server\n        self.request_id_to_server = LRUCache(maxsize=max_cache_size)\n\n    def _choose_server(self, request_id: str) -> ray.actor.ActorHandle:\n        # TODO: implement server pressure awareness load balancing\n        if request_id in self.request_id_to_server:\n            return self.request_id_to_server[request_id]\n\n        server = self.weighted_servers[0][1][1]\n        self.weighted_servers[0][0] += 1\n        heapq.heapreplace(self.weighted_servers, self.weighted_servers[0])\n        self.request_id_to_server[request_id] = server\n        return server\n\n    @rollout_trace_op\n    async def generate(\n        self,\n        request_id,\n        *,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        image_data: Optional[list[Any]] = None,\n    ) -> TokenOutput:\n        \"\"\"Generate tokens from prompt ids.\n\n        Args:\n            request_id (str): request id for sticky session.\n            prompt_ids (List[int]): List of prompt token ids.\n            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.\n            image_data (Optional[List[Any]], optional): Image data for multi-modal generation. Defaults to None.\n\n        Returns:\n            TokenOutput: token output\n        \"\"\"\n        server = self._choose_server(request_id)\n        output = await server.generate.remote(\n            request_id=request_id,\n            prompt_ids=prompt_ids,\n            sampling_params=sampling_params,\n            image_data=image_data,\n        )\n        return output\n\n\nclass AgentLoopMetrics(BaseModel):\n    \"\"\"Agent loop performance metrics.\"\"\"\n\n    generate_sequences: float = 0.0\n    tool_calls: float = 0.0\n\n\nclass AgentLoopOutput(BaseModel):\n    \"\"\"Agent loop output.\"\"\"\n\n    prompt_ids: list[int]\n    \"\"\"Prompt token ids.\"\"\"\n    response_ids: list[int]\n    \"\"\"Response token ids including LLM generated token, tool response token.\"\"\"\n    response_mask: list[int]\n    \"\"\"Response mask, 1 for LLM generated token, 0 for tool response token.\"\"\"\n    distill_special_token_mask: Optional[list[int]] = None\n    \"\"\"distill mask, 1 for special token, 0 for normal token for ref model.\"\"\"\n    response_logprobs: Optional[list[float]] = None\n    \"\"\"Log probabilities for the response tokens.\"\"\"\n    multi_modal_data: Optional[dict[str, Any]] = None\n    \"\"\"Multi-modal data for multi-modal tools.\"\"\"\n    reward_score: Optional[float] = None\n    \"\"\"Reward score for the trajectory.\"\"\"\n    num_turns: int = 0\n    \"\"\"Number of chat turns, including user, assistant, tool.\"\"\"\n    metrics: AgentLoopMetrics\n    \"\"\"Auxiliary performance metrics\"\"\"\n    extra_fields: dict[str, Any] = {}\n    \"\"\"Extra fields for dynamic addition.\"\"\"\n\n\nclass _InternalAgentLoopOutput(AgentLoopOutput):\n    \"\"\"Internal agent loop output with padded sequences.\"\"\"\n\n    model_config = ConfigDict(arbitrary_types_allowed=True)\n\n    prompt_ids: torch.Tensor\n    \"\"\"Padded prompt token ids.\"\"\"\n    response_ids: torch.Tensor\n    \"\"\"Padded response token ids.\"\"\"\n    input_ids: torch.Tensor\n    \"\"\"Padded input ids (prompt_ids + response_ids).\"\"\"\n    position_ids: 
torch.Tensor\n    \"\"\"Padded position ids.\"\"\"\n    response_mask: torch.Tensor\n    \"\"\"Padded response mask.\"\"\"\n    attention_mask: torch.Tensor\n    \"\"\"Padded attention mask.\"\"\"\n    response_logprobs: Optional[torch.Tensor] = None\n    \"\"\"Padded log probabilities for the response tokens.\"\"\"\n    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None\n    \"\"\"Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw).\"\"\"\n    distill_special_token_mask: Optional[torch.Tensor] = None\n    \"\"\"distill mask, 1 for special token, 0 for normal token for ref model.\"\"\"\n    extra_fields: dict[str, Any] = {}\n    \"\"\"Extra fields for dynamic addition.\"\"\"\n\n\n# make hydra.utils.instantiate happy\nclass _DummyConfig:\n    def __init__(self, config: DictConfig) -> None:\n        self.config = config\n\n\nclass AgentLoopBase(ABC):\n    \"\"\"An agent loop takes an input message, chats with an OpenAI compatible LLM server, and interacts with various\n    environments.\"\"\"\n\n    _class_initialized = False\n\n    def __init__(\n        self,\n        trainer_config: _DummyConfig,\n        server_manager: AsyncLLMServerManager,\n        tokenizer: AutoTokenizer,\n        processor: AutoProcessor,\n        **kwargs,\n    ):\n        \"\"\"Initialize agent loop, each sample will have its own loop instance.\n\n        Args:\n            trainer_config (_DummyConfig): trainer config.\n            server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.\n            tokenizer (AutoTokenizer): Tokenizer for tokenizing messages.\n            processor (AutoProcessor): Processor for processing messages.\n        \"\"\"\n        self.init_class(config=trainer_config.config, tokenizer=tokenizer, processor=processor, **kwargs)\n        self.config = trainer_config.config\n        self.server_manager = server_manager\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.loop = asyncio.get_running_loop()\n\n    @classmethod\n    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, processor: AutoProcessor, **kwargs):\n        \"\"\"This is used to do heavy initialization work that should be shared across all instances. 
It's only called once.\n\n        Args:\n            config (DictConfig): trainer config.\n            tokenizer (AutoTokenizer): Tokenizer for tokenize messages.\n            processor (AutoProcessor): Processor for process multi_modal data.\n            **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.\n        \"\"\"\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n\n    @abstractmethod\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        \"\"\"Run agent loop to interact with LLM server and environment.\n\n        Args:\n            sampling_params (Dict[str, Any]): LLM sampling params.\n            **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.\n\n        Returns:\n            AgentLoopOutput: Agent loop output.\n        \"\"\"\n        raise NotImplementedError\n\n\n\"\"\"Agent loop registry: key is agent_name, value is a dict of agent loop config\nused by hydra.utils.instantiate to initialize agent loop instance.\n\nhttps://hydra.cc/docs/advanced/instantiate_objects/overview/\n\"\"\"\n_agent_loop_registry: dict[str, dict] = {}\n\n\ndef register(agent_name: str):\n    \"\"\"Register agent loop class.\"\"\"\n\n    def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:\n        fqdn = f\"{subclass.__module__}.{subclass.__qualname__}\"\n        _agent_loop_registry[agent_name] = {\"_target_\": fqdn}\n        return subclass\n\n    return decorator\n\n\nclass AgentLoopWorkerBase:\n    \"\"\"Agent loop worker takes a batch of messages and run each message in an agent loop.\"\"\"\n\n    def __init__(\n        self,\n        config: DictConfig,\n        server_handles: list[ray.actor.ActorHandle],\n        reward_router_address: str = None,\n    ):\n        \"\"\"Initialize agent loop manager.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n        \"\"\"\n        self.config = config\n\n        # for recipe to change\n        if not hasattr(self, \"server_manager\"):\n            self.server_manager = AsyncLLMServerManager(config, server_handles)\n\n        self.reward_router_address = reward_router_address\n\n        model_path = config.actor_rollout_ref.model.path\n        self.model_name = \"/\".join(model_path.split(\"/\")[-2:])\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)\n        self.processor = hf_processor(local_path, trust_remote_code=True)\n\n        agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path\n        if agent_loop_config_path:\n            agent_loop_configs = OmegaConf.load(agent_loop_config_path)\n            for agent_loop_config in agent_loop_configs:\n                _agent_loop_registry[agent_loop_config.name] = agent_loop_config\n        if self.config.actor_rollout_ref.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.actor_rollout_ref.model.custom_chat_template\n            self.tokenizer.chat_template = self.config.actor_rollout_ref.model.custom_chat_template\n\n        self.reward_manager_worker = RewardManagerWorker.options(\n            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                
node_id=ray.get_runtime_context().get_node_id(),\n                soft=False,\n            ),\n        ).remote(self.config, self.reward_router_address)\n\n        trace_config = self.config.actor_rollout_ref.rollout.get(\"trace\", {})\n        RolloutTraceConfig.init(\n            self.config.trainer.project_name,\n            self.config.trainer.experiment_name,\n            trace_config.get(\"backend\"),\n            trace_config.get(\"token2text\", False),\n        )\n\n    @tqbridge()\n    async def generate_sequences(self, batch: DataProto) -> DataProto:\n        \"\"\"Generate sequences from agent loop.\n\n        Args:\n            batch (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        config = self.config.actor_rollout_ref.rollout\n        sampling_params = dict(\n            temperature=config.temperature,\n            top_p=config.top_p,\n            repetition_penalty=1.0,\n            logprobs=config.calculate_log_probs,\n        )\n\n        # override sampling params for validation\n        if batch.meta_info.get(\"validate\", False):\n            sampling_params[\"top_p\"] = config.val_kwargs.top_p\n            sampling_params[\"temperature\"] = config.val_kwargs.temperature\n\n        # by default, we assume it's a single turn agent\n        if \"agent_name\" not in batch.non_tensor_batch:\n            default_agent_loop = config.agent.default_agent_loop\n            batch.non_tensor_batch[\"agent_name\"] = np.array([default_agent_loop] * len(batch), dtype=object)\n\n        if \"index\" in batch.non_tensor_batch:\n            index = batch.non_tensor_batch[\"index\"]\n        else:\n            index = np.arange(len(batch))\n\n        trajectory_info = await get_trajectory_info(\n            batch.meta_info.get(\"global_steps\", -1), index.tolist(), batch.meta_info.get(\"validate\", False)\n        )\n\n        tasks = []\n        for i in range(len(batch)):\n            kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}\n            tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))\n        outputs = await asyncio.gather(*tasks)\n\n        output = self._postprocess(outputs)\n        return output\n\n    async def _run_agent_loop(\n        self,\n        sampling_params: dict[str, Any],\n        trajectory: dict[str, Any],\n        *,\n        agent_name: str,\n        **kwargs,\n    ) -> _InternalAgentLoopOutput:\n        with rollout_trace_attr(\n            step=trajectory[\"step\"],\n           
 sample_index=trajectory[\"sample_index\"],\n            rollout_n=trajectory[\"rollout_n\"],\n            validate=trajectory[\"validate\"],\n            name=\"agent_loop\",\n        ):\n            assert agent_name in _agent_loop_registry, (\n                f\"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}\"\n            )\n\n            agent_loop_config = _agent_loop_registry[agent_name]\n            agent_loop = hydra.utils.instantiate(\n                config=agent_loop_config,\n                trainer_config=_DummyConfig(config=self.config),\n                server_manager=self.server_manager,\n                tokenizer=self.tokenizer,\n                processor=self.processor,\n            )\n            output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)\n\n            # Some AgentLoop may have already computed the reward score, e.g SWE-agent.\n\n            # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n            # prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])\n            # response_ids: right padded with zeros (e.g., [5,6,7,8,0,0,0,0])\n            # input_ids: concatenation of prompt + response\n            # Mask:\n            # For example, if the prompt is [1,2,3,4] and the response is [5,6,7,(tool start)8,9(tool end),10,11,12]\n            # - prompt_attention_mask: 0s for padding, 1s for tokens\n            #   e.g., [0,0,0,0,1,1,1,1]\n            # - response_attention_mask: 0s for padding, 1s for tokens\n            #   e.g., [1,1,1,1,1,1,1,1,1,1,1,0,0,0,0]\n            # attention_mask: concatenation of prompt_attention_mask and response_attention_mask\n            #   e.g., [0,0,0,0,1,1,1,1(prompt),1,1,1,1,1,1,1,1,1,1,1,0,0,0,0(response)]\n            # - response_mask: 1s for LLM generated tokens, 0 for tool response/padding tokens\n            #   e.g., [1,1,1,1,1,1,1,(tool start),0,0(tool end),1,1,0,0,0,0]\n            # - position_ids: sequential positions for tokens, starting at 0\n            #   e.g., [0,0,0,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,0,0,0,0]\n\n            self.tokenizer.padding_side = \"left\"\n            prompt_output = self.tokenizer.pad(\n                {\"input_ids\": output.prompt_ids},\n                padding=\"max_length\",\n                max_length=self.config.actor_rollout_ref.rollout.prompt_length,\n                return_tensors=\"pt\",\n                return_attention_mask=True,\n            )\n            if prompt_output[\"input_ids\"].dim() == 1:\n                prompt_output[\"input_ids\"] = prompt_output[\"input_ids\"].unsqueeze(0)\n                prompt_output[\"attention_mask\"] = prompt_output[\"attention_mask\"].unsqueeze(0)\n\n            self.tokenizer.padding_side = \"right\"\n            response_output = self.tokenizer.pad(\n                {\"input_ids\": output.response_ids},\n                padding=\"max_length\",\n                max_length=self.config.actor_rollout_ref.rollout.response_length,\n                return_tensors=\"pt\",\n                return_attention_mask=True,\n            )\n            if response_output[\"input_ids\"].dim() == 1:\n                response_output[\"input_ids\"] = response_output[\"input_ids\"].unsqueeze(0)\n                response_output[\"attention_mask\"] = response_output[\"attention_mask\"].unsqueeze(0)\n\n            response_mask_output = self.tokenizer.pad(\n                {\"input_ids\": output.response_mask},\n                
padding=\"max_length\",\n                max_length=self.config.actor_rollout_ref.rollout.response_length,\n                return_tensors=\"pt\",\n                return_attention_mask=False,\n            )\n            if response_mask_output[\"input_ids\"].dim() == 1:\n                response_mask_output[\"input_ids\"] = response_mask_output[\"input_ids\"].unsqueeze(0)\n\n            distill_special_token_mask = None\n            if output.distill_special_token_mask is not None:\n                distill_special_token_mask_output = self.tokenizer.pad(\n                    {\"input_ids\": output.distill_special_token_mask},\n                    padding=\"max_length\",\n                    max_length=self.config.actor_rollout_ref.rollout.response_length,\n                    return_tensors=\"pt\",\n                    return_attention_mask=False,\n                )\n                if distill_special_token_mask_output[\"input_ids\"].dim() == 1:\n                    distill_special_token_mask_output[\"input_ids\"] = distill_special_token_mask_output[\"input_ids\"].unsqueeze(0)\n                distill_special_token_mask = distill_special_token_mask_output[\"input_ids\"] * response_output[\"attention_mask\"]\n\n            response_logprobs = None\n            if output.response_logprobs is not None:\n                pad_size = self.config.actor_rollout_ref.rollout.response_length - len(output.response_logprobs)\n                response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0)\n\n            response_mask = response_mask_output[\"input_ids\"] * response_output[\"attention_mask\"]\n            attention_mask = torch.cat([prompt_output[\"attention_mask\"], response_output[\"attention_mask\"]], dim=1)\n            input_ids = torch.cat([prompt_output[\"input_ids\"], response_output[\"input_ids\"]], dim=1)\n\n            # Handle multi-modal inputs and position_ids calculation\n            # Only support Qwen2VLImageProcessor for multi-modal processing currently\n            # TODO: support other multi-modal inputs\n            multi_modal_inputs = None\n            if (\n                self.processor is not None\n                and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__\n            ):\n                from verl.models.transformers.qwen2_vl import get_rope_index\n\n                images = getattr(output, \"multi_modal_data\", {}).get(\"image\", None)\n                current_text = self.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True)\n                multi_modal_inputs = self.processor(text=[current_text], images=images, return_tensors=\"pt\")\n                multi_modal_inputs.pop(\"input_ids\", None)\n                multi_modal_inputs.pop(\"attention_mask\", None)\n\n                # We must use dict(multi_modal_inputs) to convert BatchFeature values to a new dict\n                # because np.array() only keeps the keys for BatchFeature.\n                multi_modal_inputs = dict(multi_modal_inputs)\n\n                image_grid_thw = multi_modal_inputs.get(\"image_grid_thw\")\n                video_grid_thw = multi_modal_inputs.get(\"video_grid_thw\")\n                second_per_grid_ts = multi_modal_inputs.get(\"second_per_grid_ts\")\n\n                vision_position_ids = get_rope_index(\n                    self.processor,\n                    input_ids=input_ids.squeeze(0),\n                    image_grid_thw=image_grid_thw,\n                    video_grid_thw=video_grid_thw,\n     
               second_per_grid_ts=second_per_grid_ts,\n                    attention_mask=attention_mask.squeeze(0),\n                ).unsqueeze(0)  # (1, 3, seq_len)\n\n                valid_mask = attention_mask[0].bool()\n                text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)\n                text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())\n                text_position_ids = text_position_ids.unsqueeze(0)\n                position_ids = torch.cat((text_position_ids, vision_position_ids), dim=1)  # (1, 4, seq_length)\n            else:\n                position_ids = compute_position_id_with_mask(attention_mask)  # (1, seq_len)\n            enable_async_reward = (\n                self.reward_router_address is not None and self.config.reward_model.enable_resource_pool\n            ) or not self.config.reward_model.enable\n            if output.reward_score is None and enable_async_reward and self.config.reward_model.get(\"compute_in_agent_loop\", False):\n                batch = TensorDict(\n                    {\n                        \"prompts\": prompt_output[\"input_ids\"],  # [1, prompt_length]\n                        \"responses\": response_output[\"input_ids\"],  # [1, response_length]\n                        \"attention_mask\": attention_mask,  # [1, prompt_length + response_length]\n                        \"input_ids\": input_ids,  # [1, prompt_length + response_length]\n                        \"position_ids\": position_ids,\n                    },\n                    batch_size=1,\n                )\n                non_tensor_batch = {\n                    **{k: np.array([v]) for k, v in kwargs.items()},\n                    \"__num_turns__\": np.array([output.num_turns]),\n                    \"tool_extra_fields\": np.array([output.extra_fields], dtype=object),\n                }\n\n                data = DataProto(\n                    batch=batch,\n                    non_tensor_batch=non_tensor_batch,\n                )\n                result = await self.reward_manager_worker.compute_score.remote(data)\n                output.reward_score = result[\"reward_score\"]\n                output.extra_fields[\"reward_extra_info\"] = result[\"reward_extra_info\"]\n\n            return _InternalAgentLoopOutput(\n                prompt_ids=prompt_output[\"input_ids\"],\n                response_ids=response_output[\"input_ids\"],\n                input_ids=input_ids,\n                position_ids=position_ids,\n                response_mask=response_mask,\n                attention_mask=attention_mask,\n                response_logprobs=response_logprobs,\n                distill_special_token_mask=distill_special_token_mask,\n                multi_modal_inputs=multi_modal_inputs,\n                multi_modal_data=output.multi_modal_data,\n                reward_score=output.reward_score,\n                num_turns=output.num_turns,\n                metrics=output.metrics,\n                extra_fields=output.extra_fields,\n            )\n\n    def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:\n        \"\"\"Process the padded outputs from _run_agent_loop and combine them into a batch.\"\"\"\n        # Convert lists back to tensors and stack them to create a batch.\n        prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0)\n        response_ids = torch.cat([input.response_ids for input in inputs], dim=0)\n        response_mask = torch.cat([input.response_mask for 
input in inputs], dim=0)\n        attention_mask = torch.cat([input.attention_mask for input in inputs], dim=0)\n        input_ids = torch.cat([input.input_ids for input in inputs], dim=0)\n        position_ids = torch.cat([input.position_ids for input in inputs], dim=0)\n        optional_outputs = {}\n        if inputs[0].response_logprobs is not None:\n            optional_outputs[\"rollout_log_probs\"] = torch.cat([input.response_logprobs for input in inputs], dim=0)\n        if inputs[0].distill_special_token_mask is not None:\n            optional_outputs[\"distill_special_token_mask\"] = torch.cat([input.distill_special_token_mask for input in inputs], dim=0)\n\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,  # [bsz, prompt_length]\n                \"responses\": response_ids,  # [bsz, response_length]\n                \"response_mask\": response_mask,  # [bsz, response_length]\n                \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n                \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n                # position_ids: [bsz, 3, prompt_length + response_length] or [bsz, prompt_length + response_length]\n                \"position_ids\": position_ids,\n                **optional_outputs,\n            },\n            batch_size=len(inputs),\n        )\n\n        scores = [input.reward_score for input in inputs]\n        if all(score is not None for score in scores):\n            prompt_length = prompt_ids.size(1)\n            response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1\n            rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)\n            rm_scores[torch.arange(response_mask.size(0)), response_length] = torch.tensor(scores, dtype=torch.float32)\n            batch[\"rm_scores\"] = rm_scores\n\n        non_tensor_batch = {\n            \"__num_turns__\": np.array([input.num_turns for input in inputs], dtype=np.int32),\n        }\n\n        # add reward_extra_info to non_tensor_batch\n        reward_extra_infos = [input.extra_fields.get(\"reward_extra_info\", {}) for input in inputs]\n        reward_extra_keys = list(reward_extra_infos[0].keys())\n        for key in reward_extra_keys:\n            non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos])\n\n        # Add multi_modal_inputs to non_tensor_batch if any samples have them\n        multi_modal_inputs_list = [input.multi_modal_inputs for input in inputs]\n        if any(mmi is not None for mmi in multi_modal_inputs_list):\n            non_tensor_batch[\"multi_modal_inputs\"] = np.array(multi_modal_inputs_list, dtype=object)\n\n        metrics = [input.metrics.model_dump() for input in inputs]\n        # Collect extra fields from all inputs and convert them to np.ndarray\n        extra_fields = {}\n        all_keys = set(key for input_item in inputs for key in input_item.extra_fields)\n        for key in all_keys:\n            temp_arr = np.empty(len(inputs), dtype=object)\n            temp_arr[:] = [input.extra_fields.get(key) for input in inputs]\n            extra_fields[key] = temp_arr\n\n        non_tensor_batch.update(extra_fields)\n        return DataProto(\n            batch=batch,\n            non_tensor_batch=non_tensor_batch,\n            meta_info={\"metrics\": metrics, \"reward_extra_keys\": reward_extra_keys},\n        )\n\n    def create_transferqueue_client(self, controller_infos, storage_infos, role):\n        \"\"\"Create a client for data 
system (transfer queue).\"\"\"\n        from verl.single_controller.ray.base import get_random_string\n        from verl.utils.transferqueue_utils import create_transferqueue_client\n\n        client_name = get_random_string(length=6)\n        create_transferqueue_client(\n            client_id=f\"{role}_worker_{client_name}\",\n            controller_infos=controller_infos,\n            storage_infos=storage_infos,\n        )\n\n\n@ray.remote\nclass AgentLoopWorker(AgentLoopWorkerBase):\n    \"\"\"Agent loop worker takes a batch of messages and runs each message in an agent loop.\"\"\"\n\n    def __init__(\n        self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], reward_router_address: str = None\n    ):\n        \"\"\"Initialize agent loop worker.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n            reward_router_address (str): reward router address.\n        \"\"\"\n        super().__init__(config, server_handles, reward_router_address)\n\n\nasync def get_trajectory_info(step, index, validate):\n    \"\"\"Get trajectory info.\n\n    Args:\n        step (int): global steps in the trainer.\n        index (list): from datastore extra_info.index column.\n        validate (bool): whether this is a validation step.\n\n    Returns:\n        list: trajectory.\n    \"\"\"\n    trajectory_info = []\n    rollout_n = 0\n    for i in range(len(index)):\n        if i > 0 and index[i - 1] == index[i]:\n            rollout_n += 1\n        else:\n            rollout_n = 0\n        trajectory_info.append({\"step\": step, \"sample_index\": index[i], \"rollout_n\": rollout_n, \"validate\": validate})\n    return trajectory_info\n\n\nclass AgentLoopManager:\n    \"\"\"Agent loop manager that manages a group of agent loop workers.\"\"\"\n\n    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):\n        \"\"\"Initialize agent loop manager.\n\n        Args:\n            config (DictConfig): trainer config.\n            worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.\n            rm_wg (RayWorkerGroup): reward model worker group, used when the reward model resource pool is enabled.\n        \"\"\"\n        self.config = config\n        self.worker_group = worker_group\n        self.reward_model_manager = None\n        self.reward_router_address = None\n        if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool:\n            from verl.experimental.reward import RewardModelManager\n\n            self.reward_model_manager = RewardModelManager(config.reward_model, rm_wg)\n            self.reward_router_address = self.reward_model_manager.get_router_address()\n\n        # for recipe to change\n        if not hasattr(self, \"rollout_replica_class\"):\n            self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)\n        if not hasattr(self, \"agent_loop_workers_class\"):\n            self.agent_loop_workers_class = AgentLoopWorker\n\n        self._initialize_llm_servers()\n        self._init_agent_loop_workers()\n\n        # Initially we're in sleep mode.\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.sleep()\n\n    def _initialize_llm_servers(self):\n        rollout_world_size = (\n            self.config.actor_rollout_ref.rollout.tensor_model_parallel_size\n            * self.config.actor_rollout_ref.rollout.data_parallel_size\n            * 
self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size\n        )\n        world_size = (\n            self.worker_group.world_size\n            if self.worker_group\n            else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes\n        )\n        num_replicas = world_size // rollout_world_size\n\n        rollout_config = self.config.actor_rollout_ref.rollout\n        model_config = self.config.actor_rollout_ref.model\n        self.rollout_replicas = [\n            self.rollout_replica_class(\n                replica_rank=replica_rank,\n                config=rollout_config,\n                model_config=model_config,\n                gpus_per_node=self.config.trainer.n_gpus_per_node,\n            )\n            for replica_rank in range(num_replicas)\n        ]\n        if self.worker_group:\n            self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas])\n        else:\n            self._run_all([server.init_standalone() for server in self.rollout_replicas])\n        self.server_handles = [server._server_handle for server in self.rollout_replicas]\n        self.server_addresses = [server._server_address for server in self.rollout_replicas]\n\n    def _init_agent_loop_workers(self):\n        self.agent_loop_workers = []\n        num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers\n\n        node_ids = [node[\"NodeID\"] for node in ray.nodes() if node[\"Alive\"] and node[\"Resources\"].get(\"CPU\", 0) > 0]\n        for i in range(num_workers):\n            # Round-robin scheduling over all the alive nodes\n            node_id = node_ids[i % len(node_ids)]\n            self.agent_loop_workers.append(\n                self.agent_loop_workers_class.options(\n                    name=f\"agent_loop_worker_{i}\",\n                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                        node_id=node_id, soft=True\n                    ),\n                ).remote(self.config, self.server_handles, self.reward_router_address)\n            )\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Split input batch and dispatch to agent loop workers.\n\n        Args:\n            prompts (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n        \"\"\"\n\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.wake_up()\n        if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:\n            self.reward_model_manager.wake_up()\n\n        chunks = prompts.chunk(len(self.agent_loop_workers))\n        outputs = ray.get(\n            [\n                worker.generate_sequences.remote(chunk)\n                for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True)\n            ]\n        )\n        output = DataProto.concat(outputs)\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.sleep()\n        if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:\n            self.reward_model_manager.sleep()\n\n        # calculate performance metrics\n        metrics = [output.meta_info.pop(\"metrics\") for output in outputs]  # List[List[Dict[str, str]]]\n        timing = self._performance_metrics(metrics, output)\n\n        output.meta_info = {\"timing\": timing, **outputs[0].meta_info}\n        return output\n\n    def _performance_metrics(self, metrics: 
list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:\n        timing = {}\n        t_generate_sequences = np.array([metric[\"generate_sequences\"] for chunk in metrics for metric in chunk])\n        t_tool_calls = np.array([metric[\"tool_calls\"] for chunk in metrics for metric in chunk])\n        timing[\"agent_loop/generate_sequences/min\"] = t_generate_sequences.min()\n        timing[\"agent_loop/generate_sequences/max\"] = t_generate_sequences.max()\n        timing[\"agent_loop/generate_sequences/mean\"] = t_generate_sequences.mean()\n        timing[\"agent_loop/tool_calls/min\"] = t_tool_calls.min()\n        timing[\"agent_loop/tool_calls/max\"] = t_tool_calls.max()\n        timing[\"agent_loop/tool_calls/mean\"] = t_tool_calls.mean()\n\n        # batch sequence generation is bounded by the slowest sample\n        slowest = np.argmax(t_generate_sequences + t_tool_calls)\n        attention_mask = output.batch[\"attention_mask\"][slowest]\n        prompt_length = output.batch[\"prompts\"].shape[1]\n        timing[\"agent_loop/slowest/generate_sequences\"] = t_generate_sequences[slowest]\n        timing[\"agent_loop/slowest/tool_calls\"] = t_tool_calls[slowest]\n        timing[\"agent_loop/slowest/prompt_length\"] = attention_mask[:prompt_length].sum().item()\n        timing[\"agent_loop/slowest/response_length\"] = attention_mask[prompt_length:].sum().item()\n\n        # on-policy distill\n        if \"distill_special_token_mask\" in output.batch:\n            distill_special_token_mask = output.batch[\"distill_special_token_mask\"]\n            timing[\"agent_loop/on_policy_distill/extend_token_seq_ratio\"] = (\n                (distill_special_token_mask.sum(dim=1) > 0).float().mean().item()\n            )\n\n        return timing\n\n    def wake_up(self):\n        \"\"\"Wake up all rollout replica instances.\"\"\"\n        self._run_all([replica.wake_up() for replica in self.rollout_replicas])\n\n    def sleep(self):\n        \"\"\"Sleep all rollout replica instances.\"\"\"\n        self._run_all([replica.sleep() for replica in self.rollout_replicas])\n\n    def _run_all(self, tasks: list[asyncio.Task]):\n        async def run_all():\n            await asyncio.gather(*tasks)\n\n        asyncio.run(run_all())\n"
  },
  {
    "path": "verl_distillation/verl/experimental/agent_loop/single_turn_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport copy\nimport logging\nimport os\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register\nfrom verl.utils.profiler import simple_timer\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@register(\"single_turn_agent\")\nclass SingleTurnAgentLoop(AgentLoopBase):\n    \"\"\"Naive agent loop that only do single turn chat completion.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length\n        self.response_length = self.config.actor_rollout_ref.rollout.response_length\n        self.apply_chat_template_kwargs = self.config.data.get(\"apply_chat_template_kwargs\", {})\n\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        messages = list(kwargs[\"raw_prompt\"])\n        image_data = copy.deepcopy((kwargs.get(\"multi_modal_data\") or {}).get(\"image\", None))\n\n        metrics = {}\n        request_id = uuid4().hex\n\n        # Use processor if available for multimodal support\n        if self.processor is not None:\n            raw_prompt = await self.loop.run_in_executor(\n                None,\n                lambda: self.processor.apply_chat_template(\n                    messages,\n                    add_generation_prompt=True,\n                    tokenize=False,\n                    **self.apply_chat_template_kwargs,\n                ),\n            )\n            model_inputs = self.processor(text=[raw_prompt], images=image_data, return_tensors=\"pt\")\n            prompt_ids = model_inputs.pop(\"input_ids\").squeeze(0).tolist()\n        else:\n            prompt_ids = await self.loop.run_in_executor(\n                None,\n                lambda: self.tokenizer.apply_chat_template(\n                    messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs\n                ),\n            )\n\n        with simple_timer(\"generate_sequences\", metrics):\n            output = await self.server_manager.generate(\n                request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data\n            )\n        response_mask = [1] * len(output.token_ids)\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=output.token_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None,\n            multi_modal_data={\"image\": image_data} if image_data is not None else {},\n            num_turns=2,\n            metrics=metrics,\n        )\n        return output\n"
  },
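  {
    "path": "verl_distillation/verl/experimental/agent_loop/echo_agent_loop_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of the original package. It shows the minimum a\n# custom agent loop needs in order to plug into the registry defined in agent_loop.py;\n# SingleTurnAgentLoop above is the real reference implementation. The agent name\n# \"echo_agent\" and this file name are assumptions for illustration only.\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom verl.experimental.agent_loop.agent_loop import (AgentLoopBase, AgentLoopMetrics,\n                                                     AgentLoopOutput, register)\n\n\n@register(\"echo_agent\")\nclass EchoAgentLoop(AgentLoopBase):\n    \"\"\"Toy single-turn loop: tokenize the prompt, generate once, return the output.\"\"\"\n\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        messages = list(kwargs[\"raw_prompt\"])\n        prompt_ids = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)\n        output = await self.server_manager.generate(\n            request_id=uuid4().hex, prompt_ids=prompt_ids, sampling_params=sampling_params\n        )\n        # Every response token here is LLM-generated, so the mask is all ones; a\n        # production loop would also clip ids and mask to rollout.response_length.\n        return AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=output.token_ids,\n            response_mask=[1] * len(output.token_ids),\n            num_turns=2,\n            metrics=AgentLoopMetrics(),\n        )\n"
  },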
  {
    "path": "verl_distillation/verl/experimental/agent_loop/tool_agent_loop.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport copy\nimport json\nimport logging\nimport os\nfrom enum import Enum\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.experimental.agent_loop.agent_loop import (AgentLoopBase,\n                                                     AgentLoopOutput, register)\nfrom verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser\nfrom verl.experimental.agent_loop.utils import (\n    add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually)\nfrom verl.interactions.base import BaseInteraction\nfrom verl.interactions.utils.interaction_registry import \\\n    initialize_interactions_from_config\nfrom verl.tools.schemas import ToolResponse\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\nfrom verl.utils.profiler import simple_timer\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass AgentState(Enum):\n    PENDING = \"pending\"\n    GENERATING = \"generating\"\n    PROCESSING_TOOLS = \"processing_tools\"\n    TERMINATED = \"terminated\"\n    INTERACTING = \"interacting\"\n\n\nclass AgentData:\n    \"\"\"Encapsulates all state variables for the agent loop.\"\"\"\n\n    def __init__(\n        self,\n        messages: list[dict[str, Any]],\n        image_data: Any,\n        metrics: dict[str, Any],\n        request_id: str,\n        tools_kwargs: dict[str, Any],\n        interaction: Optional[BaseInteraction] = None,\n        interaction_kwargs: Optional[dict[str, Any]] = None,\n    ):\n        self.messages = messages\n        self.image_data = image_data\n        self.metrics = metrics\n        self.request_id = request_id\n        self.tools_kwargs = tools_kwargs\n        self.interaction = interaction\n        self.interaction_kwargs = interaction_kwargs or {}\n\n        # State variables\n        self.prompt_ids: list[int] = []\n        self.response_ids: list[int] = []\n        self.response_mask: list[int] = []\n        self.distill_special_token_mask: list[int] = []\n        self.response_logprobs: list[float] = []\n        self.turn_scores: list[float] = []\n        self.tool_rewards: list[float] = []\n        self.user_turns = 0\n        self.assistant_turns = 0\n\n        # Temporary state for tool calls\n        self.tool_calls: list[FunctionCall] = []\n\n\n@register(\"tool_agent\")\nclass ToolAgentLoop(AgentLoopBase):\n    @classmethod\n    def init_class(cls, config, tokenizer, processor, **kwargs):\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n        print(\"Performing class-level ToolAgentLoop initialization\")\n\n        # Initialize tools from config file\n        cls.tokenizer = tokenizer\n        cls.processor = processor\n        cls.max_user_turns = 
config.actor_rollout_ref.rollout.multi_turn.max_user_turns\n        cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns\n        cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls\n        cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length\n        cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side\n        tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []\n        cls.tools = {tool.name: tool for tool in tool_list}\n        cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]\n        cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer)\n        cls.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format\n        print(f\"Initialized tools: {cls.tools}\")\n\n        cls.apply_chat_template_kwargs = config.data.get(\"apply_chat_template_kwargs\", {})\n        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length\n        cls.response_length = config.actor_rollout_ref.rollout.response_length\n        cls.system_prompt = tokenizer.apply_chat_template(\n            [{}], add_generation_prompt=False, tokenize=True, **cls.apply_chat_template_kwargs\n        )\n        cls.extend_vocab_start_token = config.actor_rollout_ref.rollout.extend_vocab_start_token\n        cls.mask_response_if_have_extend_token = config.actor_rollout_ref.rollout.mask_response_if_have_extend_token\n        # Initialize interactions from config file\n        cls.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path\n        if cls.interaction_config_file:\n            cls.interaction_map: dict[str, BaseInteraction] = cls._initialize_interactions(cls.interaction_config_file)\n\n    @rollout_trace_op\n    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:\n        messages = list(kwargs[\"raw_prompt\"])\n        image_data = copy.deepcopy(kwargs.get(\"multi_modal_data\", {}).get(\"image\", None))\n        metrics = {}\n        request_id = uuid4().hex\n        tools_kwargs = kwargs.get(\"tools_kwargs\", {})\n\n        # Initialize interaction if needed\n        interaction = None\n        interaction_kwargs = {}\n        if self.interaction_config_file:\n            interaction_kwargs = kwargs[\"extra_info\"][\"interaction_kwargs\"]\n            if \"name\" not in interaction_kwargs:\n                raise ValueError(\"'name' key is required in interaction_kwargs\")\n            interaction_name = interaction_kwargs[\"name\"]\n            if interaction_name not in self.interaction_map:\n                raise ValueError(\n                    f\"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: \"\n                    f\"{list(self.interaction_map.keys())}\"\n                )\n            interaction = self.interaction_map[interaction_name]\n            await interaction.start_interaction(request_id, **interaction_kwargs)\n        # Create AgentData instance to encapsulate all state\n        agent_data = AgentData(\n            messages=messages,\n            image_data=image_data,\n            metrics=metrics,\n            request_id=request_id,\n            tools_kwargs=tools_kwargs,\n            interaction=interaction,\n            interaction_kwargs=interaction_kwargs,\n        )\n\n        # State machine loop\n        state = AgentState.PENDING\n        while state != AgentState.TERMINATED:\n            if state == AgentState.PENDING:\n                state = await self._handle_pending_state(agent_data, sampling_params)\n            elif state == AgentState.GENERATING:\n                state = await self._handle_generating_state(agent_data, sampling_params)\n            elif state == AgentState.PROCESSING_TOOLS:\n                state = await self._handle_processing_tools_state(agent_data)\n            elif state == AgentState.INTERACTING:\n                state = await self._handle_interacting_state(agent_data)\n            else:\n                logger.error(f\"Invalid state: {state}\")\n                state = AgentState.TERMINATED\n\n        # Finalize output\n        response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]\n        prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]\n        multi_modal_data = {\"image\": agent_data.image_data} if agent_data.image_data is not None else {}\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=agent_data.response_mask[: self.response_length],\n            multi_modal_data=multi_modal_data,\n            response_logprobs=agent_data.response_logprobs[: self.response_length]\n            if agent_data.response_logprobs\n            else None,\n            distill_special_token_mask=agent_data.distill_special_token_mask[: self.response_length]\n            if agent_data.distill_special_token_mask\n            else None,\n            num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,\n            metrics=agent_data.metrics,\n            extra_fields={},\n        )\n        output.extra_fields.update({\"turn_scores\": agent_data.turn_scores, \"tool_rewards\": agent_data.tool_rewards})\n        return output\n\n    async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:\n        \"\"\"Handle the pending state: prepare the prompt and start generation.\"\"\"\n        if self.processor is not None:\n            raw_prompt = await self.loop.run_in_executor(\n                None,\n                lambda: self.processor.apply_chat_template(\n                    agent_data.messages,\n                    tools=self.tool_schemas,\n                    add_generation_prompt=True,\n                    tokenize=False,\n                    **self.apply_chat_template_kwargs,\n                ),\n            )\n            model_inputs = self.processor(text=[raw_prompt], images=agent_data.image_data, return_tensors=\"pt\")\n            agent_data.prompt_ids = model_inputs.pop(\"input_ids\").squeeze(0).tolist()\n        else:\n            agent_data.prompt_ids = await 
self.loop.run_in_executor(\n                None,\n                lambda: self.tokenizer.apply_chat_template(\n                    agent_data.messages,\n                    tools=self.tool_schemas,\n                    add_generation_prompt=True,\n                    tokenize=True,\n                    **self.apply_chat_template_kwargs,\n                ),\n            )\n        return AgentState.GENERATING\n\n    async def _handle_generating_state(\n        self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False\n    ) -> AgentState:\n        \"\"\"Handle the generating state: generate model response and check for tool calls.\"\"\"\n        add_messages: list[dict[str, Any]] = []\n\n        with simple_timer(\"generate_sequences\", agent_data.metrics):\n            output = await self.server_manager.generate(\n                request_id=agent_data.request_id,\n                prompt_ids=agent_data.prompt_ids,\n                sampling_params=sampling_params,\n                image_data=agent_data.image_data,\n            )\n\n        agent_data.assistant_turns += 1\n        agent_data.response_ids = output.token_ids\n        agent_data.prompt_ids += agent_data.response_ids\n        distill_special_token_mask = []\n        response_mask = [1] * len(agent_data.response_ids)\n        if self.extend_vocab_start_token is not None:\n            assert isinstance(self.extend_vocab_start_token, int)\n            # Mark tokens drawn from the extended (itemic) vocabulary\n            for token in agent_data.response_ids:\n                if token >= self.extend_vocab_start_token:\n                    distill_special_token_mask.append(1)\n                else:\n                    distill_special_token_mask.append(0)\n            try:\n                # Mask out everything after the first extended-vocab token\n                first_one_index = distill_special_token_mask.index(1)\n                response_mask[first_one_index + 1:] = [0] * (len(response_mask) - first_one_index - 1)\n            except ValueError:\n                pass\n            if self.mask_response_if_have_extend_token:\n                if sum(distill_special_token_mask) > 0:\n                    response_mask = [0] * len(agent_data.response_ids)\n        else:\n            distill_special_token_mask = [0] * len(agent_data.response_ids)\n\n        agent_data.response_mask += response_mask\n        agent_data.distill_special_token_mask += distill_special_token_mask\n        if output.log_probs:\n            agent_data.response_logprobs += output.log_probs\n\n        # Check termination conditions\n        if not ignore_termination and len(agent_data.response_mask) >= self.response_length:\n            return AgentState.TERMINATED\n        if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:\n            return AgentState.TERMINATED\n        if self.max_user_turns and agent_data.user_turns >= self.max_user_turns:\n            return AgentState.TERMINATED\n\n        # Extract tool calls\n        _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids)\n\n        # Handle interaction if needed\n        if self.interaction_config_file:\n            assistant_message = await self.loop.run_in_executor(\n                None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)\n            )\n            add_messages.append({\"role\": \"assistant\", \"content\": assistant_message})\n            agent_data.messages.extend(add_messages)\n\n        # Determine next state\n        if agent_data.tool_calls:\n            return AgentState.PROCESSING_TOOLS\n        elif self.interaction_config_file:\n            return AgentState.INTERACTING\n        else:\n            return AgentState.TERMINATED\n\n    async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:\n        \"\"\"Handle the processing tools state: execute tool calls and prepare tool responses.\"\"\"\n        add_messages: list[dict[str, Any]] = []\n        new_images_this_turn: list[Any] = []  # images returned by tools in this turn\n\n        tasks = []\n        tool_call_names = []\n        for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:\n            tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs))\n            tool_call_names.append(tool_call.name)\n\n        with simple_timer(\"tool_calls\", agent_data.metrics):\n            responses = await asyncio.gather(*tasks)\n\n        # Process tool responses and update multi_modal_data\n        for tool_response, tool_reward, _ in responses:\n            # Create message from tool response\n            if tool_response.image or tool_response.video:\n                # Multi-modal content with structured format\n                if not getattr(self.processor, \"image_processor\", None):\n                    raise ValueError(\n                        \"Multimedia data can only be processed by `processor`, but the processor is None. \"\n                        \"This error typically occurs when an LLM base model is used but the tool returns \"\n                        \"multimodal data. Please use a VLM as the base model.\"\n                    )\n                content = []\n                if tool_response.image:\n                    content.append({\"type\": \"image\"})\n                if tool_response.video:\n                    content.append({\"type\": \"video\"})\n                if tool_response.text:\n                    content.append({\"type\": \"text\", \"text\": tool_response.text})\n                message = {\"role\": \"tool\", \"content\": content}\n            else:\n                # Text-only content\n                message = {\"role\": \"tool\", \"content\": tool_response.text or \"\"}\n\n            add_messages.append(message)\n\n            # Handle image data\n            if tool_response.image:\n                if agent_data.image_data is None:\n                    agent_data.image_data = []\n                elif not isinstance(agent_data.image_data, list):\n                    agent_data.image_data = [agent_data.image_data]\n\n                # Add new image data, skipping None entries\n                if isinstance(tool_response.image, list):\n                    for img in tool_response.image:\n                        if img is not None:\n                            agent_data.image_data.append(img)\n                            new_images_this_turn.append(img)\n                else:\n                    if tool_response.image is not None:\n                        agent_data.image_data.append(tool_response.image)\n                        new_images_this_turn.append(tool_response.image)\n\n            # Handle video data\n            if tool_response.video:\n                # Currently not supported, raise an informative error\n                logger.warning(\"Multimedia type 'video' is not currently supported. Only 'image' is supported.\")\n                raise NotImplementedError(\n                    \"Multimedia type 'video' is not currently supported. Only 'image' is supported.\"\n                )\n\n            if tool_reward is not None:\n                agent_data.tool_rewards.append(tool_reward)\n\n        agent_data.messages.extend(add_messages)\n        # Update prompt with tool responses\n        if self.processor is not None:\n            raw_tool_response = await self.loop.run_in_executor(\n                None,\n                lambda: self.processor.apply_chat_template(\n                    add_messages,\n                    add_generation_prompt=True,\n                    tokenize=False,\n                    **self.apply_chat_template_kwargs,\n                ),\n            )\n            # Use only the new images from this turn for processing tool responses\n            current_images = new_images_this_turn if new_images_this_turn else None\n            model_inputs = self.processor(text=[raw_tool_response], images=current_images, return_tensors=\"pt\")\n            response_ids = model_inputs.pop(\"input_ids\").squeeze(0).tolist()\n        else:\n            if self.tool_parser_name == \"gpt-oss\":\n                logger.info(\"manually format tool responses for gpt-oss\")\n                # Format tool responses manually\n                tool_response_texts = []\n                for i, tool_msg in enumerate(add_messages):\n                    actual_tool_name = tool_call_names[i]\n                    formatted = format_gpt_oss_tool_response_manually(tool_msg[\"content\"], actual_tool_name)\n                    tool_response_texts.append(formatted)\n\n                tool_response_text = add_generation_prompt_for_gpt_oss(\"\".join(tool_response_texts))\n                response_ids = await self.loop.run_in_executor(\n                    None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)\n                )\n            else:\n                response_ids = await self.loop.run_in_executor(\n                    None,\n                    lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),\n                )\n                response_ids = response_ids[len(self.system_prompt) :]\n        if len(agent_data.response_mask) + len(response_ids) >= self.response_length:\n            return AgentState.TERMINATED\n        # Update prompt_ids and response_mask\n        agent_data.prompt_ids += response_ids\n        agent_data.response_mask += [0] * len(response_ids)\n        if agent_data.response_logprobs:\n            agent_data.response_logprobs += [0.0] * len(response_ids)\n        agent_data.user_turns += 1\n        return AgentState.GENERATING\n\n    async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:\n        \"\"\"Handle the interacting state: get user input from interaction.\"\"\"\n        (\n            should_terminate_sequence,\n            interaction_responses,\n            reward,\n            metrics,\n        ) = await agent_data.interaction.generate_response(\n            agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs\n        )\n        agent_data.user_turns += 1\n\n        add_messages: list[dict[str, Any]] = [{\"role\": \"user\", \"content\": interaction_responses}]\n        agent_data.messages.extend(add_messages)\n\n        if reward is not None:\n            agent_data.turn_scores.append(reward)\n\n        # Update prompt with user responses (similar to _handle_processing_tools_state)\n        if self.processor is not None:\n            raw_user_response = await self.loop.run_in_executor(\n                None,\n                lambda: self.processor.apply_chat_template(\n                    add_messages,\n                    add_generation_prompt=True,\n                    tokenize=False,\n                    **self.apply_chat_template_kwargs,\n                ),\n            )\n            model_inputs = self.processor(text=[raw_user_response], images=None, return_tensors=\"pt\")\n            response_ids = model_inputs.pop(\"input_ids\").squeeze(0).tolist()\n        else:\n            response_ids = await self.loop.run_in_executor(\n                None,\n                lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),\n            )\n        # Strip the system-prompt prefix tokens re-added by the chat template\n        response_ids = response_ids[len(self.system_prompt) :]\n\n        # Update prompt_ids and response_mask\n        agent_data.prompt_ids += response_ids\n        agent_data.response_mask += [0] * len(response_ids)\n        if agent_data.response_logprobs:\n            agent_data.response_logprobs += [0.0] * len(response_ids)\n\n        # Check termination condition\n        if should_terminate_sequence:\n            return AgentState.TERMINATED\n        else:\n            return AgentState.GENERATING\n\n    async def _call_tool(\n        self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]\n    ) -> tuple[ToolResponse, float, dict]:\n        \"\"\"Call tool and return tool response.\"\"\"\n        tool, instance_id = None, None\n        try:\n            # TODO: append malformed tool_call to the prompt: invalid function name or arguments\n            tool_name = tool_call.name\n            tool_args = json.loads(tool_call.arguments)\n            tool = self.tools[tool_name]\n            kwargs = tools_kwargs.get(tool_name, {})\n            instance_id, _ = await tool.create(create_kwargs=kwargs.get(\"create_kwargs\", {}))\n            tool_execution_response, tool_reward, res = await tool.execute(instance_id, tool_args)\n        except Exception as e:\n            logger.warning(f\"Error when executing tool: {e}\")\n            return (\n                ToolResponse(\n                    text=f\"Error when executing tool: {e}\",\n                ),\n                0.0,\n                {},\n            )\n        finally:\n            if tool and instance_id:\n                await tool.release(instance_id)\n\n        tool_response_text = tool_execution_response.text\n        if tool_response_text and len(tool_response_text) > self.max_tool_response_length:\n            if self.tool_response_truncate_side == \"left\":\n                tool_response_text = tool_response_text[: self.max_tool_response_length] + \"...(truncated)\"\n            elif self.tool_response_truncate_side == \"right\":\n                tool_response_text = \"(truncated)...\" + tool_response_text[-self.max_tool_response_length :]\n            else:\n                length = self.max_tool_response_length // 2\n                tool_response_text = tool_response_text[:length] + \"...(truncated)...\" + tool_response_text[-length:]\n\n        # Create ToolResponse from tool execution result\n        tool_response_kwargs = {\"text\": tool_response_text}\n\n        # Add multimedia data if present\n        for attr_name in [\"image\", \"video\"]:\n            if hasattr(tool_execution_response, attr_name):\n                attr_value = getattr(tool_execution_response, attr_name)\n                if attr_value is not None:\n                    tool_response_kwargs[attr_name] = attr_value\n\n        return ToolResponse(**tool_response_kwargs), tool_reward, res\n\n    @classmethod\n    def _initialize_interactions(cls, interaction_config_file):\n        \"\"\"Initialize interactions from configuration.\n\n        Returns:\n            dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.\n        \"\"\"\n        if interaction_config_file is None:\n            return {}\n\n        interaction_map = initialize_interactions_from_config(interaction_config_file)\n        logger.info(f\"Initialized interactions from configuration: {list(interaction_map.keys())}\")\n        return interaction_map\n
  },
  {
    "path": "verl_distillation/verl/experimental/agent_loop/tool_parser.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport json\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\n\nimport regex\nfrom pydantic import BaseModel\n\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass FunctionCall(BaseModel):\n    arguments: str\n    \"\"\"\n    The arguments to call the function with, as generated by the model in JSON\n    format. Note that the model does not always generate valid JSON, and may\n    hallucinate parameters not defined by your function schema. Validate the\n    arguments in your code before calling your function.\n    \"\"\"\n\n    name: str\n    \"\"\"The name of the function to call.\"\"\"\n\n\nclass ToolParser(ABC):\n    _registry: dict[str, type[\"ToolParser\"]] = {}\n\n    def __init__(self, tokenizer) -> None:\n        self.tokenizer = tokenizer\n\n    @abstractmethod\n    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:\n        \"\"\"Extract tool calls from the responses.\n\n        Args:\n            responses_ids (List[int]): The ids of the responses.\n\n        Returns:\n            Tuple[str, List[FunctionCall]]: Content and extracted tool calls.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def get_tool_parser(cls, name: str, tokenizer):\n        if name not in cls._registry:\n            raise ValueError(f\"Unknown tool parser: {name}\")\n        return cls._registry[name](tokenizer)\n\n    @classmethod\n    def register(cls, name: str):\n        def decorator(subclass: type[ToolParser]) -> type[ToolParser]:\n            cls._registry[name] = subclass\n            return subclass\n\n        return decorator\n\n\n@ToolParser.register(\"hermes\")\nclass HermesToolParser(ToolParser):\n    \"\"\"Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py\"\"\"\n\n    def __init__(self, tokenizer) -> None:\n        super().__init__(tokenizer)\n\n        self.tool_call_start_token: str = \"<tool_call>\"\n        self.tool_call_end_token: str = \"</tool_call>\"\n        self.tool_call_regex = regex.compile(r\"<tool_call>(.*?)</tool_call>\", regex.DOTALL)\n\n    @rollout_trace_op\n    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:\n        loop = asyncio.get_running_loop()\n        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)\n        if self.tool_call_start_token not in text or self.tool_call_end_token not in text:\n            return text, []\n\n        matches = self.tool_call_regex.findall(text)\n        function_calls = []\n        for match in matches:\n            try:\n                function_call = json.loads(match)\n                name, arguments = function_call[\"name\"], 
function_call[\"arguments\"]\n                function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))\n            except Exception as e:\n                logger.error(f\"Failed to decode tool call: {e}\")\n\n        # remaing text exclude tool call tokens\n        content = self.tool_call_regex.sub(\"\", text)\n\n        return content, function_calls\n\n\n@ToolParser.register(\"gpt-oss\")\nclass GptOssToolParser(ToolParser):\n    \"\"\"\n    Tool parser for gpt-oss model.\n    Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py\n\n    Args:\n        tokenizer: The tokenizer to use.\n    \"\"\"\n\n    def __init__(self, tokenizer) -> None:\n        super().__init__(tokenizer)\n        # check https://cookbook.openai.com/articles/openai-harmony for more details.\n        self.cot_pattern = regex.compile(\n            r\"<\\|start\\|>assistant<\\|channel\\|>analysis<\\|message\\|>.*?<\\|end\\|>\", regex.DOTALL\n        )\n        # <|start|>assistant may be pre-appended in prompts, so we need to remove it.\n        self.partial_cot_pattern = regex.compile(r\"<\\|channel\\|>analysis<\\|message\\|>(.*?)<\\|end\\|>\", regex.DOTALL)\n        self.tool_call_pattern = regex.compile(\n            r\"<\\|start\\|>assistant<\\|channel\\|>[^<]* to=functions\\.([^<]+) \"\n            r\"<\\|constrain\\|>json<\\|message\\|>(.*?)<\\|call\\|>\",\n            regex.DOTALL,\n        )\n\n    @rollout_trace_op\n    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:\n        loop = asyncio.get_running_loop()\n        # We need to keep special tokens for gpt-oss model for better tool call extraction.\n        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))\n        # Need to remove padding tokens for better tool call extraction.\n        text = text.replace(self.tokenizer.pad_token, \"\")\n        # Need to reomve COT since COT may contain tool call tokens.But they are not valid tool calls.\n        text = regex.sub(self.cot_pattern, \"\", text)\n        text = regex.sub(self.partial_cot_pattern, \"\", text)\n\n        # check if there are tool calls in the text by re.findall\n        matches = regex.findall(self.tool_call_pattern, text)\n        if not matches:\n            return text, []\n\n        function_calls = []\n        for match in matches:\n            try:\n                name, arguments = match[0], match[1]\n                # don't check if arguments is valid JSON and leave it to client\n                function_calls.append(FunctionCall(name=name, arguments=arguments))\n            except Exception as e:\n                logger.error(f\"Failed to decode tool call: {e}\")\n\n        # remaing text exclude tool call tokens\n        content = regex.sub(self.tool_call_pattern, \"\", text)\n\n        return content, function_calls\n"
  },
  {
    "path": "verl_distillation/verl/experimental/agent_loop/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# tokenizer.apply_chat_template is not working properly for gpt-oss model.\n# Because the chat template requires tool call messages to parse tool response messages\n# so we need to format the tool response manually.\ndef format_gpt_oss_tool_response_manually(tool_response: str, tool_call_name: str) -> str:\n    \"\"\"Format tool response for gpt-oss model.\n    Args:\n        tool_response: Tool response string\n        tool_call_name: Name of the tool that was called\n\n    Returns:\n        Formatted tool response string\n    \"\"\"\n    return f\"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{tool_response}<|end|>\"\n\n\ndef add_generation_prompt_for_gpt_oss(message_content: str) -> str:\n    \"\"\"Add generation prompt for gpt-oss model.\n    Args:\n        message_content: Message content string\n\n    Returns:\n        Message content string with generation prompt\n    \"\"\"\n    return message_content + \"<|start|>assistant\"\n"
  },
  {
    "path": "verl_distillation/verl/experimental/dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/experimental/dataset/sampler.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom abc import abstractmethod\nfrom collections.abc import Sized\n\nfrom omegaconf import DictConfig\nfrom torch.utils.data import Sampler\n\nfrom verl import DataProto\n\n\nclass AbstractSampler(Sampler[int]):\n    \"\"\"Abstract interface for custom samplers.\"\"\"\n\n    @abstractmethod\n    def __init__(\n        self,\n        data_source: Sized,\n        data_config: DictConfig,\n    ):\n        pass\n\n\nclass AbstractCurriculumSampler(AbstractSampler):\n    \"\"\"Experimental interface for curriculum learning samplers.\"\"\"\n\n    @abstractmethod\n    def update(self, batch: DataProto) -> None:\n        pass\n"
  },
  {
    "path": "verl_distillation/verl/experimental/dynamic_dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/experimental/dynamic_dataset/dynamicgen_dataset.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nDataset class that enables dynamic data generation strategies between iterations of training.\nThis class extends RLHFDataset and uses an AbstractDataGen instance to generate data.\n\nThis is especially useful in settings where proposer model generates new tasks based\non rollout data.\n\"\"\"\n\nimport logging\nfrom abc import ABC, abstractmethod\nfrom typing import Optional\n\nimport datasets\nfrom omegaconf import DictConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom verl import DataProto\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.import_utils import load_extern_type\n\nlogger = logging.getLogger(__name__)\n\n\nclass AbstractDataGenerator(ABC):\n    def __init__(self, config: DictConfig):\n        self.config = config\n\n    @abstractmethod\n    def generate(self, dataset: Dataset) -> datasets.Dataset:\n        \"\"\"\n        Generate method must be implemented by subclasses.\n        Args:\n            dataset: The dataset to generate from.\n        Returns:\n            Processed data or result as implemented by the subclass.\n        \"\"\"\n        pass\n\n\nclass MockDataGenerator(AbstractDataGenerator):\n    \"\"\"\n    A noop data gen class that only reappends the first datapoint.\n    This class is useful as a placeholder and testing.\n    \"\"\"\n\n    def __init__(self, config: DictConfig = None):\n        super().__init__(config)\n\n    def generate(self, dataset: Dataset) -> datasets.Dataset:\n        print(\"MockDataGenerator: No operation performed on the dataset.\")\n        return dataset.dataframe.select([0])\n\n\nclass DynamicGenDataset(RLHFDataset):\n    \"\"\"\n    A dataset class that uses a data generation strategy to process data.\n    This class extends RLHFDataset and uses an AbstractDataGen instance to generate data.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n    ):\n        super().__init__(data_files, tokenizer, config, processor)\n        self.datagen: AbstractDataGenerator = config.datagen\n        assert \"datagen\" in config and config.datagen.get(\"path\", None) is not None, (\n            f\"datagen path is not set in config: {config}\"\n        )\n        # Dynamically load the custom datagen class\n        datagen_cls = load_extern_type(config.datagen.path, config.datagen.name)\n\n        # Verify that the custom datagen class inherits from AbstractDataGenerator\n        abs_cls = AbstractDataGenerator\n        if not issubclass(datagen_cls, abs_cls):\n            raise TypeError(\n                f\"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'\"\n                + \" must inherit from {abs_cls}\"\n            )\n\n        self.data_generator = 
datagen_cls(config.datagen)\n        self.on_batch_end()\n\n    def append_dataframe(self, new_dataframe: datasets.Dataset):\n        new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe)\n        self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe])\n\n        logger.info(f\"new dataset len: {len(self.dataframe)}\")\n\n    def on_batch_end(self, batch: DataProto) -> None:\n        \"\"\"\n        Generate data using the provided data generation strategy.\n        Note: This method is intended to change the dataset after each training batch.\n        \"\"\"\n        new_data = self.data_generator.generate(self)\n        self.append_dataframe(new_data)\n"
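\n\n# Illustrative config sketch (the file path and class name below are\n# hypothetical; the field names follow what this module reads from config):\n#   data:\n#     datagen:\n#       path: /path/to/my_datagen.py   # file defining the generator class\n#       name: MyDataGenerator          # subclass of AbstractDataGenerator\n# DynamicGenDataset loads MyDataGenerator via load_extern_type and calls\n# on_batch_end() after each training batch to append newly generated rows.\n"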
  },
  {
    "path": "verl_distillation/verl/experimental/reward/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .reward_manager import RewardManagerWorker\nfrom .reward_model import RewardModelManager\n\n__all__ = [\"RewardModelManager\", \"RewardManagerWorker\"]\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_loop/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import get_reward_loop_manager_cls, register  # noqa: I001\nfrom .dapo import DAPORewardLoopManager\nfrom .naive import NaiveRewardLoopManager\n\n__all__ = [\n    \"DAPORewardLoopManager\",\n    \"NaiveRewardLoopManager\",\n    \"register\",\n    \"get_reward_loop_manager_cls\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_loop/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\n\nfrom omegaconf import DictConfig\nfrom transformers import AutoTokenizer\n\nfrom verl import DataProto\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass RewardLoopManagerBase(ABC):\n    _class_initialized = False\n\n    def __init__(self, config: DictConfig, tokenizer: AutoTokenizer):\n        \"\"\"Initialize agent loop.\n\n        Args:\n            config (DictConfig): YAML config.\n            tokenizer (AutoTokenizer): Tokenizer for tokenize messages.\n        \"\"\"\n        self.config = config\n        self.tokenizer = tokenizer\n        self.loop = asyncio.get_running_loop()\n        self.init_class(config, tokenizer)\n\n    @classmethod\n    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer):\n        \"\"\"Initialize class state shared across all instances.\"\"\"\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n\n    @abstractmethod\n    async def run_single(self, data: DataProto):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_loop/dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\n\nfrom verl import DataProto\nfrom verl.experimental.reward.reward_loop import register\nfrom verl.experimental.reward.reward_loop.base import RewardLoopManagerBase\nfrom verl.utils.reward_score import default_compute_score\n\n\n@register(\"dapo\")\nclass DAPORewardLoopManager(RewardLoopManagerBase):\n    \"\"\"Reward loop for DAPO.\"\"\"\n\n    def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None):\n        super().__init__(config, tokenizer)\n        self.compute_score = compute_score or default_compute_score\n        self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score)\n\n        # DAPO Reward Config\n        overlong_buffer_cfg = config.reward_model.get(\"reward_kwargs\", {}).get(\"overlong_buffer_cfg\", None)\n        self.overlong_buffer_cfg = overlong_buffer_cfg\n        self.max_resp_len = config.reward_model.get(\"reward_kwargs\", {}).get(\"max_resp_len\", None)\n        self.reward_router_address = reward_router_address\n        self.reward_model_tokenizer = reward_model_tokenizer\n\n        if self.overlong_buffer_cfg is not None:\n            assert self.max_resp_len is not None, (\n                f\"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None\"\n            )\n            assert self.max_resp_len >= self.overlong_buffer_cfg.len, (\n                \"max_resp_len must be larger than overlong_buffer.len\"\n            )\n\n    async def run_single(self, data: DataProto) -> dict:\n        assert len(data) == 1, \"Only support single data item\"\n        data_item = data[0]\n        response_ids = data_item.batch[\"responses\"]\n        response_length = response_ids.shape[-1]\n        valid_response_length = data_item.batch[\"attention_mask\"][-response_length:].sum()\n        valid_response_ids = response_ids[:valid_response_length]\n\n        data_source = data_item.non_tensor_batch[\"data_source\"]\n        ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n        extra_info = data_item.non_tensor_batch.get(\"extra_info\", {})\n\n        response_str = await self.loop.run_in_executor(\n            None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n        )\n        if self.is_async_reward_score:\n            result = await self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n                reward_router_address=self.reward_router_address,\n                reward_model_tokenizer=self.reward_model_tokenizer,\n            )\n        else:\n            result = await self.loop.run_in_executor(\n                None,\n                lambda: self.compute_score(\n                    data_source=data_source,\n                    
solution_str=response_str,\n                    ground_truth=ground_truth,\n                    extra_info=extra_info,\n                    reward_router_address=self.reward_router_address,\n                    reward_model_tokenizer=self.reward_model_tokenizer,\n                ),\n            )\n\n        reward_extra_info = {}\n\n        score: float\n        if isinstance(result, dict):\n            score = result[\"score\"]\n            for key, value in result.items():\n                reward_extra_info[key] = value\n        else:\n            score = result\n            reward_extra_info[\"acc\"] = score\n\n        reward = score\n\n        if self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable:\n            overlong_buffer_len = self.overlong_buffer_cfg.len\n            expected_len = self.max_resp_len - overlong_buffer_len\n            exceed_len = valid_response_length - expected_len\n            overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n            overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n            reward += overlong_reward\n            if self.overlong_buffer_cfg.log:\n                reward_extra_info[\"overlong_reward\"] = overlong_reward\n                reward_extra_info[\"overlong\"] = overlong_reward < 0\n\n        return {\"reward_score\": reward, \"reward_extra_info\": reward_extra_info}\n"
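\n\n# Worked example (illustrative numbers, not defaults): with max_resp_len=4096,\n# overlong_buffer.len=512 and penalty_factor=1.0, expected_len = 4096 - 512 = 3584;\n# a valid response of 4000 tokens exceeds it by 416, so\n# overlong_reward = min(-416 / 512 * 1.0, 0) = -0.8125 is added to the score.\n"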
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_loop/naive.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\n\nfrom verl import DataProto\nfrom verl.experimental.reward.reward_loop import register\nfrom verl.experimental.reward.reward_loop.base import RewardLoopManagerBase\nfrom verl.utils.reward_score import default_compute_score\n\n\n@register(\"naive\")\nclass NaiveRewardLoopManager(RewardLoopManagerBase):\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None):\n        super().__init__(config, tokenizer)\n        self.compute_score = compute_score or default_compute_score\n        self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score)\n        self.reward_router_address = reward_router_address\n        self.reward_model_tokenizer = reward_model_tokenizer\n\n    async def run_single(self, data: DataProto) -> dict:\n        assert len(data) == 1, \"Only support single data item\"\n        data_item = data[0]\n        response_ids = data_item.batch[\"responses\"]\n        response_length = response_ids.shape[-1]\n        valid_response_length = data_item.batch[\"attention_mask\"][-response_length:].sum()\n        valid_response_ids = response_ids[:valid_response_length]\n\n        data_source = data_item.non_tensor_batch[\"data_source\"]\n        ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n        extra_info = data_item.non_tensor_batch.get(\"extra_info\", {})\n        tool_extra_fields = data_item.non_tensor_batch.get(\"tool_extra_fields\", None)\n        if tool_extra_fields is not None:\n            extra_info.update(tool_extra_fields.items())\n\n        response_str = await self.loop.run_in_executor(\n            None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n        )\n        if self.is_async_reward_score:\n            result = await self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n                reward_router_address=self.reward_router_address,\n                reward_model_tokenizer=self.reward_model_tokenizer,\n            )\n        else:\n            result = await self.loop.run_in_executor(\n                None,\n                lambda: self.compute_score(\n                    data_source=data_source,\n                    solution_str=response_str,\n                    ground_truth=ground_truth,\n                    extra_info=extra_info,\n                    reward_router_address=self.reward_router_address,\n                    reward_model_tokenizer=self.reward_model_tokenizer,\n                ),\n            )\n\n        reward_extra_info = {}\n\n        score: float\n        if isinstance(result, dict):\n            score = result[\"score\"]\n            for key, value in result.items():\n                
reward_extra_info[key] = value\n        else:\n            score = result\n            reward_extra_info[\"acc\"] = score\n\n        reward = score\n\n        return {\"reward_score\": reward, \"reward_extra_info\": reward_extra_info}\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_loop/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable\n\nfrom verl.experimental.reward.reward_loop.base import RewardLoopManagerBase\n\n__all__ = [\"register\", \"get_reward_loop_manager_cls\"]\n\nREWARD_LOOP_MANAGER_REGISTRY: dict[str, type[RewardLoopManagerBase]] = {}\n\n\ndef register(name: str) -> Callable[[type[RewardLoopManagerBase]], type[RewardLoopManagerBase]]:\n    \"\"\"Decorator to register a reward loop manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward loop manager.\n    \"\"\"\n\n    def decorator(cls: type[RewardLoopManagerBase]) -> type[RewardLoopManagerBase]:\n        if name in REWARD_LOOP_MANAGER_REGISTRY and REWARD_LOOP_MANAGER_REGISTRY[name] != cls:\n            raise ValueError(\n                f\"reward loop manager {name} has already been registered: {REWARD_LOOP_MANAGER_REGISTRY[name]} vs {cls}\"\n            )\n        REWARD_LOOP_MANAGER_REGISTRY[name] = cls\n        return cls\n\n    return decorator\n\n\ndef get_reward_loop_manager_cls(name: str) -> type[RewardLoopManagerBase]:\n    \"\"\"Get the reward loop manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward loop manager.\n\n    Returns:\n        `(type)`: The reward loop manager class.\n    \"\"\"\n    if name not in REWARD_LOOP_MANAGER_REGISTRY:\n        raise ValueError(f\"Unknown reward loop manager: {name}\")\n    return REWARD_LOOP_MANAGER_REGISTRY[name]\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport ray\nfrom omegaconf import DictConfig\n\nfrom verl.experimental.reward.reward_loop import get_reward_loop_manager_cls\nfrom verl.protocol import DataProto\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@ray.remote\nclass RewardManagerWorker:\n    def __init__(self, config: DictConfig, reward_router_address: str = None):\n        self.config = config\n        self.reward_router_address = reward_router_address\n        self._init_reward_fn()\n\n    def _init_reward_fn(self):\n        input_tokenizer_local_path = copy_to_local(self.config.actor_rollout_ref.model.path)\n        self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=True)\n        self.reward_model_tokenizer = None\n        if self.config.reward_model.enable:\n            reward_model_tokenizer_local_path = copy_to_local(self.config.reward_model.model.path)\n            self.reward_model_tokenizer = hf_tokenizer(reward_model_tokenizer_local_path, trust_remote_code=True)\n        self.reward_fn = get_custom_reward_fn(self.config)\n        reward_loop_manager_cls = get_reward_loop_manager_cls(self.config.reward_model.reward_manager)\n        self.reward_loop = reward_loop_manager_cls(\n            self.config, self.input_tokenizer, self.reward_fn, self.reward_router_address, self.reward_model_tokenizer\n        )\n\n    async def compute_score(self, data: DataProto) -> DataProto:\n        return await self.reward_loop.run_single(data)\n"
  },
  {
    "path": "verl_distillation/verl/experimental/reward/reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport json\nimport logging\nimport os\n\nimport aiohttp\nfrom openai.types.chat import ChatCompletion\n\nfrom verl import DataProto\nfrom verl.single_controller.ray.base import RayWorkerGroup\nfrom verl.workers.config import HFModelConfig, RewardModelConfig\nfrom verl.workers.rollout.replica import get_rollout_replica_class\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass RewardModelManager:\n    \"\"\"Reward model manager.\"\"\"\n\n    def __init__(self, config: RewardModelConfig, worker_group: RayWorkerGroup = None):\n        \"\"\"\n        Initialize the reward model manager.\n\n        Args:\n            config (RewardModelConfig): Reward model configuration.\n            worker_group (RayWorkerGroup, optional): Worker group. Defaults to None.\n        \"\"\"\n        self.config = config\n        self.worker_group = worker_group\n        self._initialize_llm_servers()\n        self._initialize_router()\n        if self.config.rollout.free_cache_engine:\n            self.sleep()\n\n    def _initialize_llm_servers(self):\n        rollout_world_size = self.config.rollout.tensor_model_parallel_size\n        world_size = (\n            self.worker_group.world_size\n            if self.worker_group  # colocate mode\n            else self.config.n_gpus_per_node * self.config.nnodes  # standalone mode\n        )\n        num_replicas = world_size // rollout_world_size\n\n        rollout_replica_class = get_rollout_replica_class(self.config.rollout.name)\n        rollout_config = self.config.rollout\n        model_config = HFModelConfig(\n            path=self.config.model.path,\n            external_lib=self.config.model.external_lib,\n            trust_remote_code=self.config.model.trust_remote_code,\n        )\n        self.tokenizer = model_config.get_processor()\n        self.rollout_replicas = [\n            rollout_replica_class(\n                replica_rank=replica_rank,\n                config=rollout_config,\n                model_config=model_config,\n                gpus_per_node=self.config.n_gpus_per_node,\n                is_reward_model=True,\n            )\n            for replica_rank in range(num_replicas)\n        ]\n        if self.worker_group:\n            self._run_all([server.init_colocated(self.worker_group) for server in self.rollout_replicas])\n        else:\n            self._run_all([server.init_standalone() for server in self.rollout_replicas])\n        self.server_handles = [server._server_handle for server in self.rollout_replicas]\n        self.server_addresses = [server._server_address for server in self.rollout_replicas]\n\n    def _initialize_router(self):\n        worker_urls = [f\"http://{server_address}\" for server_address in self.server_addresses]\n\n        if self.config.rollout.name == \"sglang\":\n            from .router.sglang_router import 
launch_router_process\n        else:\n            from .router.naive_router import launch_router_process\n\n        self.router_address, _ = launch_router_process(worker_urls=worker_urls)\n\n    def get_router_address(self):\n        return self.router_address\n\n    def wake_up(self):\n        \"\"\"Wake up all rollout replica instances.\"\"\"\n        self._run_all([replica.wake_up() for replica in self.rollout_replicas])\n\n    def sleep(self):\n        \"\"\"Sleep all rollout replica instances.\"\"\"\n        self._run_all([replica.sleep() for replica in self.rollout_replicas])\n\n    def _run_all(self, tasks: list[asyncio.Task]):\n        async def run_all():\n            return await asyncio.gather(*tasks)\n\n        return asyncio.run(run_all())\n\n    async def chat_complete(self, chat_complete_request: dict):\n        url = f\"http://{self.router_address}/v1/chat/completions\"\n        try:\n            timeout = aiohttp.ClientTimeout(total=None)\n            session = aiohttp.ClientSession(timeout=timeout)\n            async with session.post(url, json=chat_complete_request) as resp:\n                output = await resp.text()\n                output = json.loads(output)\n                return ChatCompletion(**output)\n        except Exception as e:\n            raise e\n        finally:\n            await session.close()\n\n    def generate_sequences(self, prompts: DataProto, sampling_params: dict):\n        chat_complete_requests = [\n            {\n                \"model\": self.config.model.path,\n                \"messages\": list(messages),\n                **sampling_params,\n            }\n            for messages in prompts.non_tensor_batch.get(\"raw_prompt\")\n        ]\n        tasks = [self.chat_complete(chat_complete_request) for chat_complete_request in chat_complete_requests]\n        results = self._run_all(tasks)\n        return results\n"
  },
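The manager above exposes an OpenAI-compatible path end to end: `generate_sequences` turns each `raw_prompt` entry into a chat-completion payload and fans the requests out through the router. A minimal sketch of that payload shape follows; the model path, message text, and sampling parameters are made-up placeholders, not repo defaults.

```python
# Hypothetical payload built by RewardModelManager.generate_sequences for one
# prompt; sampling_params are splatted into the dict unchanged.
chat_complete_request = {
    "model": "/path/to/reward_model",  # taken from self.config.model.path
    "messages": [
        {"role": "user", "content": "Judge the following response ..."},
    ],
    # example sampling_params:
    "temperature": 0.0,
    "max_tokens": 512,
}
# chat_complete() POSTs this to http://<router_address>/v1/chat/completions
# and parses the JSON reply into openai.types.chat.ChatCompletion.
```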
  {
    "path": "verl_distillation/verl/experimental/reward/router/naive_router.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport logging\nimport multiprocessing\nimport os\nimport time\nfrom typing import Any\n\nimport aiohttp\nimport ray\nimport uvicorn\nfrom fastapi import FastAPI, Request\nfrom fastapi.responses import JSONResponse\n\nfrom verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nasync def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]:\n    if resp.status == 204 or (resp.content_length == 0):\n        return {}\n\n    try:\n        return await resp.json(content_type=None)\n    except Exception:\n        try:\n            text = await resp.text()\n        except Exception:\n            return {}\n        return {\n            \"content_type\": (resp.headers.get(\"Content-Type\") or \"\"),\n            \"text\": text,\n        }\n\n\ndef launch_router_process(\n    worker_urls: list[str],\n):\n    router_ip = ray.util.get_node_ip_address().strip(\"[]\")\n    router_port, _ = get_free_port(router_ip)\n    router_address = (\n        f\"[{router_ip}]:{router_port}\" if is_valid_ipv6_address(router_ip) else f\"{router_ip}:{router_port}\"\n    )\n\n    router_process = multiprocessing.Process(\n        target=run_router,\n        args=(\n            router_ip,\n            router_port,\n            worker_urls,\n        ),\n    )\n    router_process.daemon = True\n    router_process.start()\n    time.sleep(3)\n    assert router_process.is_alive()\n\n    logger.info(f\"Router is running on {router_address}\")\n    return router_address, router_process\n\n\ndef run_router(router_ip: str, router_port: int, worker_urls: list[str]):\n    router = NaiveRouter(worker_urls=worker_urls, verbose=False)\n    uvicorn.run(router.app, host=router_ip, port=router_port, log_level=\"warning\")\n\n\nclass NaiveRouter:\n    def __init__(\n        self,\n        worker_urls: list[str],\n        max_connections: int = 1024,\n        timeout: int = 60,\n        max_attempts: int = 3,\n        retry_delay: float = 2.0,\n        verbose: bool = False,\n    ) -> None:\n        \"\"\"A minimal async load-balancing router.\"\"\"\n        self.verbose = verbose\n        self.app = FastAPI()\n        self.worker_urls = worker_urls\n        self.request_counts = {url: 0 for url in worker_urls}\n\n        self.max_connections = max_connections\n        self.timeout = timeout\n        self.max_attempts = max_attempts\n        self.retry_delay = retry_delay\n\n        self.app = FastAPI()\n\n        # Register startup / shutdown hooks\n        self.app.on_event(\"startup\")(self._on_startup)\n        self.app.on_event(\"shutdown\")(self._on_shutdown)\n\n        # Catch-all proxy route\n        self.app.api_route(\"/{endpoint:path}\", methods=[\"GET\", \"POST\"])(self._make_async_request)\n\n        # Placeholder for aiohttp client\n        
self.client = None\n\n    async def _on_startup(self):\n        \"\"\"Initialize aiohttp client safely inside the event loop\"\"\"\n        connector = aiohttp.TCPConnector(\n            limit=self.max_connections,\n            limit_per_host=self.max_connections // 4,\n            ttl_dns_cache=300,\n            use_dns_cache=True,\n        )\n        timeout = aiohttp.ClientTimeout(total=None)\n        self.client = aiohttp.ClientSession(connector=connector, timeout=timeout)\n        if self.verbose:\n            logger.info(f\"[router] aiohttp client initialized with max_connections={self.max_connections}\")\n\n    async def _on_shutdown(self):\n        \"\"\"Gracefully close aiohttp client\"\"\"\n        if self.client and not self.client.closed:\n            await self.client.close()\n            if self.verbose:\n                logger.info(\"[router] aiohttp client closed\")\n\n    async def _make_async_request(self, request: Request, endpoint: str):\n        \"\"\"Proxy single request to a worker URL.\"\"\"\n        if not self.worker_urls:\n            return JSONResponse(status_code=503, content={\"error\": \"No available workers\"})\n\n        worker_url = self._select_worker()\n        target_url = f\"{worker_url}/{endpoint}\"\n\n        if self.verbose:\n            logger.debug(f\"[router] Forwarding request → {target_url}\")\n\n        # Copy request data\n        body = await request.body()\n        headers = dict(request.headers)\n\n        for attempt in range(self.max_attempts):\n            # Send request to worker\n            try:\n                async with self.client.request(request.method, target_url, data=body, headers=headers) as response:\n                    response.raise_for_status()\n                    output = await _read_async_response(response)\n                    self._release_worker(worker_url)\n                    return output\n            except asyncio.TimeoutError:\n                logger.warning(f\"Async request to {endpoint} timed out (attempt {attempt + 1})\")\n            except aiohttp.ClientConnectorError:\n                logger.warning(f\"Connection error for {endpoint} (attempt {attempt + 1})\")\n            except aiohttp.ClientResponseError as e:\n                logger.error(f\"HTTP error for {endpoint}: {e}\")\n                raise\n            except Exception as e:\n                logger.error(f\"Unexpected error for {endpoint}: {e}\")\n                if attempt == self.max_attempts - 1:\n                    raise\n\n            if attempt < self.max_attempts - 1:\n                await asyncio.sleep(self.retry_delay * (2**attempt))\n\n        raise RuntimeError(f\"Failed to complete async request to {endpoint} after {self.max_attempts} attempts\")\n\n    def _select_worker(self) -> str:\n        \"\"\"Select the least-loaded worker (simple round-robin by request count).\"\"\"\n        url = min(self.request_counts, key=self.request_counts.get)\n        self.request_counts[url] += 1\n        return url\n\n    def _release_worker(self, url: str) -> None:\n        \"\"\"Mark worker as free after request completes.\"\"\"\n        self.request_counts[url] = max(0, self.request_counts[url] - 1)\n"
  },
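The balancing policy in `NaiveRouter` is simply "fewest in-flight requests wins". A standalone toy illustration of that bookkeeping (worker URLs are placeholders; this is not repo code):

```python
request_counts = {"http://worker-a:8000": 0, "http://worker-b:8000": 0}

def select_worker() -> str:
    # Pick the worker with the fewest in-flight requests and mark one more.
    url = min(request_counts, key=request_counts.get)
    request_counts[url] += 1
    return url

def release_worker(url: str) -> None:
    # Decrement on completion; clamp at zero defensively.
    request_counts[url] = max(0, request_counts[url] - 1)

first = select_worker()   # worker-a (tie broken by insertion order)
second = select_worker()  # worker-b, since worker-a now has one in flight
release_worker(first)
assert request_counts[first] == 0
```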
  {
    "path": "verl_distillation/verl/experimental/reward/router/sglang_router.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport multiprocessing\nimport os\nimport time\n\nimport ray\nimport requests\nfrom sglang_router.launch_server import RouterArgs, launch_router\n\nfrom verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef launch_router_process(\n    worker_urls: list[str],\n    request_timeout: int = 180,\n    max_wait_time: int = 300,\n    timeout: int = 30,\n) -> str:\n    router_ip = ray.util.get_node_ip_address().strip(\"[]\")\n    router_port, _ = get_free_port(router_ip)\n    router_address = (\n        f\"[{router_ip}]:{router_port}\" if is_valid_ipv6_address(router_ip) else f\"{router_ip}:{router_port}\"\n    )\n    router_args = RouterArgs(\n        host=router_ip,\n        port=router_port,\n        worker_urls=worker_urls,\n        balance_abs_threshold=0,\n        log_level=\"warn\",\n        request_timeout_secs=request_timeout,\n    )\n    router_process = multiprocessing.Process(target=launch_router, args=(router_args,))\n    router_process.daemon = True\n    router_process.start()\n    time.sleep(3)\n    assert router_process.is_alive()\n\n    # health check\n    start_time = time.time()\n    url = f\"http://{router_address}/health\"\n    with requests.Session() as session:\n        while time.time() - start_time < max_wait_time:\n            try:\n                response = session.get(url, timeout=timeout)\n                if response.status_code == 200:\n                    break\n            except requests.RequestException as e:\n                logger.debug(f\"Health check failed: {e}\")\n\n            time.sleep(2)\n        else:\n            router_process.terminate()\n            raise RuntimeError(f\"Router health check failed after {max_wait_time} seconds.\")\n\n    logger.info(f\"Router is running on {router_address}\")\n    return router_address, router_process\n"
  },
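Both router variants share the same calling convention: they return `(address, process)`, and the caller keeps the process handle for teardown. A hedged usage sketch with placeholder worker URLs (assumes a Ray runtime is available, since the launcher resolves the node IP via `ray.util`):

```python
# Assumed usage of launch_router_process; the worker URLs are illustrative.
worker_urls = ["http://10.0.0.1:30000", "http://10.0.0.2:30000"]
router_address, router_process = launch_router_process(worker_urls=worker_urls)
print(f"OpenAI-compatible endpoint: http://{router_address}/v1/chat/completions")

# The router runs as a daemon process; terminate it explicitly when done.
router_process.terminate()
```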
  {
    "path": "verl_distillation/verl/interactions/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/interactions/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\n\nclass BaseInteraction:\n    def __init__(self, config: dict[str, Any]):\n        self.config = config\n        self.name: str = config.get(\"name\", \"interaction_agent\")  # More general agent default role name\n\n    async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4())\n        else:\n            return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict[str, Any]]:  # More clear response generation method\n        \"\"\"\n        Generates a response for the current turn of interaction.\n        Returns a tuple containing:\n        - should_terminate_sequence (bool): True if the interaction sequence should end.\n        - response_content (str): The textual content of the response.\n        - current_turn_score (float): The score for this specific turn/response.\n        - additional_data (dict): Any extra information or metadata.\n        \"\"\"\n        should_terminate_sequence: bool = False  # if True, end rollout\n        response_content: str = \"Your current result seems acceptable.\"\n        current_turn_score: float = 0.8\n        additional_data: dict[str, Any] = {}\n        return should_terminate_sequence, response_content, current_turn_score, additional_data\n\n    async def calculate_score(self) -> float:  # More clear score calculation method\n        \"\"\"\n        Calculates a score for the interaction,\n        potentially considering aspects like partial exposure & in-context task switching.\n        should be invoke at turn-level\n        \"\"\"\n        # ...implement the logic to calculate turn-level score...\n        score = 0.0\n        return score\n\n    async def finalize_interaction(self) -> None:  # More clear interaction end and resource release method\n        \"\"\"\n        Finalizes the interaction session and releases any associated state or resources.\n        Simulates: release state\n        \"\"\"\n        # ...implement the logic to release state...\n        pass\n"
  },
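Concrete interactions override `generate_response` and usually keep per-instance state keyed by `instance_id`. A minimal illustrative subclass (class name and keyword are made up, not part of the repo) that terminates once the assistant says a keyword:

```python
from typing import Any

from verl.interactions.base import BaseInteraction


class KeywordInteraction(BaseInteraction):
    """Toy interaction: stop as soon as the last assistant turn says 'done'."""

    async def generate_response(
        self, instance_id: str, messages: list[dict[str, Any]], **kwargs
    ) -> tuple[bool, str, float, dict[str, Any]]:
        # Find the most recent assistant message, if any.
        last = next((m for m in reversed(messages) if m.get("role") == "assistant"), {})
        finished = "done" in (last.get("content") or "")
        reward = 1.0 if finished else 0.0
        feedback = "Finished." if finished else "Keep going."
        return finished, feedback, reward, {}
```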
  {
    "path": "verl_distillation/verl/interactions/gsm8k_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import gsm8k\n\nfrom .base import BaseInteraction\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Gsm8kInteraction(BaseInteraction):\n    \"\"\"A demo interaction for calculating the reward of gsm8k.\n\n    - `start_interaction`: start a interaction instance for a trajectory.\n    - `generate_response`: generate the response of the assistant.\n    - `calculate_score`: calculate the score of the interaction.\n    - `finalize_interaction`: finalize the interaction instance.\n    \"\"\"\n\n    def __init__(self, config: dict):\n        super().__init__(config)\n        self._instance_dict = {}\n\n    async def start_interaction(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict]:\n        content = \"\"\n        for i in range(len(messages) - 1, -1, -1):\n            item = messages[i]\n            if item.get(\"role\") == \"assistant\":\n                content = item.get(\"content\")\n                break\n\n        self._instance_dict[instance_id][\"response\"] = content\n\n        reward = await self.calculate_score(instance_id)\n        if reward == 1.0:\n            response = \"Your response is correct!\"\n            should_terminate_sequence = True\n        else:\n            response = \"Your response is incorrect! You need to reflect on your answer and try again.\"\n            should_terminate_sequence = False\n\n        return should_terminate_sequence, response, reward, {}\n\n    async def calculate_score(self, instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"strict\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
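Driven manually, the lifecycle of this interaction looks roughly like the sketch below. The asyncio driver and the answer text are illustrative; the `#### 72` suffix assumes the strict-mode extraction convention of `gsm8k.compute_score`.

```python
import asyncio


async def demo():
    interaction = Gsm8kInteraction({"name": "gsm8k"})
    instance_id = await interaction.start_interaction(ground_truth="72")
    messages = [{"role": "assistant", "content": "... so the answer is #### 72"}]
    done, feedback, reward, _ = await interaction.generate_response(instance_id, messages)
    print(done, feedback, reward)  # True, "Your response is correct!", 1.0
    await interaction.finalize_interaction(instance_id)


asyncio.run(demo())
```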
  {
    "path": "verl_distillation/verl/interactions/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/interactions/utils/interaction_registry.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.util\nimport logging\nimport os\nimport sys\n\nfrom omegaconf import OmegaConf\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef get_interaction_class(cls_name):\n    \"\"\"Dynamically import and return the interaction class.\"\"\"\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    interaction_cls = getattr(module, class_name)\n    return interaction_cls\n\n\ndef initialize_interactions_from_config(interaction_config_file):\n    \"\"\"Initialize interactions from configuration file.\n\n    Args:\n        interaction_config_file: Path to the interaction configuration file.\n\n    Returns:\n        dict: A dictionary mapping interaction names to BaseInteraction instances.\n    \"\"\"\n    interaction_config = OmegaConf.load(interaction_config_file)\n    interaction_map = {}\n\n    for interaction_item in interaction_config.interaction:\n        cls_name = interaction_item.class_name\n        interaction_cls = get_interaction_class(cls_name)\n\n        # Extract config and name\n        config = OmegaConf.to_container(interaction_item.config, resolve=True)\n\n        # Get the interaction name - either from config or derive from class name\n        name = interaction_item.get(\"name\", None)\n        if name is None:\n            # If no name is specified, use the class name as default\n            class_simple_name = cls_name.split(\".\")[-1]\n            # Remove \"Interaction\" suffix if present, otherwise use full class name\n            if class_simple_name.endswith(\"Interaction\"):\n                name = class_simple_name[:-11].lower()  # Remove \"Interaction\" (11 chars)\n            else:\n                name = class_simple_name.lower()\n\n        # Check for duplicate names\n        if name in interaction_map:\n            raise ValueError(f\"Duplicate interaction name '{name}' found. Each interaction must have a unique name.\")\n\n        # Inject the name into the config\n        config[\"name\"] = name\n\n        # Create the interaction instance\n        interaction = interaction_cls(config=config)\n        interaction_map[name] = interaction\n\n        logger.info(f\"Initialized interaction '{name}' with class '{cls_name}'\")\n\n    return interaction_map\n"
  },
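A sketch of a config file this registry accepts, written from the loader's expectations above (the class paths match the files in this directory; the temp-file plumbing is only for illustration). Note how `Gsm8kInteraction` gets the derived name `gsm8k` because the `Interaction` suffix is stripped and the rest lowercased:

```python
import tempfile

yaml_text = """
interaction:
  - class_name: verl.interactions.gsm8k_interaction.Gsm8kInteraction
    config: {}
  - name: weather
    class_name: verl.interactions.weather_interaction.WeatherInteraction
    config: {}
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
    config_path = f.name

interaction_map = initialize_interactions_from_config(config_path)
# {'gsm8k': Gsm8kInteraction(...), 'weather': WeatherInteraction(...)}
print(sorted(interaction_map))
```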
  {
    "path": "verl_distillation/verl/interactions/weather_interaction.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom .base import BaseInteraction\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass WeatherInteraction(BaseInteraction):\n    \"\"\"A demo interaction for handling weather-related queries.\n\n    - `start_interaction`: start a interaction instance for a trajectory.\n    - `generate_response`: generate the response of the assistant.\n    - `calculate_score`: calculate the score of the interaction.\n    - `finalize_interaction`: finalize the interaction instance.\n    \"\"\"\n\n    def __init__(self, config: dict):\n        super().__init__(config)\n        self._instance_dict = {}\n\n    async def start_interaction(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict]:\n        content = \"no tool call\"\n        for i in range(len(messages) - 1, -1, -1):\n            item = messages[i]\n            if item.get(\"role\") == \"tool\":\n                content = item.get(\"content\")\n                break\n        self._instance_dict[instance_id][\"response\"] = content\n\n        reward = await self.calculate_score(instance_id)\n        if reward == 1.0:\n            response = \"Thank you for your weather query!\"\n            should_terminate_sequence = True\n        else:\n            response = \"Please use the weather tool to get the weather information.\"\n            should_terminate_sequence = True\n        return should_terminate_sequence, response, reward, {}\n\n    async def calculate_score(self, instance_id: str, **kwargs) -> float:\n        # For weather interaction, we can implement a more complex scoring logic\n        # For now, we'll just return a default score of 1.0\n        if self._instance_dict[instance_id][\"response\"] == \"no tool call\":\n            return 0.0\n        return 1.0\n\n    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "verl_distillation/verl/model_merger/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/model_merger/__main__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends.\n\nTo merge FSDP checkpoints:\n```sh\npython -m verl.model_merger merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nTo merge Megatron checkpoints:\n```sh\npython -m verl.model_merger merge \\\n    --backend megatron \\\n    --tie-word-embedding \\\n    --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nor use distribtued merge for large models like dpskv3 671B\n\n```sh\ntorchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\\\n    --backend megatron \\\n    --local_dir ./checkpoints/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\n\nFor more details, please refer to documentation:\nhttps://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model\n\"\"\"\n\nfrom .base_model_merger import generate_config_from_args, parse_args\n\n\ndef main():\n    args = parse_args()\n    config = generate_config_from_args(args)\n    print(f\"config: {config}\")\n\n    if config.backend == \"fsdp\":\n        from .fsdp_model_merger import FSDPModelMerger\n\n        merger = FSDPModelMerger(config)\n    elif config.backend == \"megatron\":\n        from .megatron_model_merger import MegatronModelMerger\n\n        merger = MegatronModelMerger(config)\n    else:\n        raise NotImplementedError(f\"Unknown backend: {config.backend}\")\n\n    merger.merge_and_save()\n    merger.cleanup()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/model_merger/base_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nimport torch\nfrom accelerate import init_empty_weights\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    AutoModelForVision2Seq,\n    GenerationConfig,\n)\n\nfrom verl.utils import hf_processor, hf_tokenizer\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"verl model merger\")\n    subparsers = parser.add_subparsers(dest=\"operation\", required=True, help=\"Specify 'merge' or 'test' operation.\")\n\n    base_op_parser = argparse.ArgumentParser(add_help=False)\n    base_op_parser.add_argument(\n        \"--backend\", type=str, required=True, choices=[\"fsdp\", \"megatron\"], help=\"The backend of the model\"\n    )\n    base_op_parser.add_argument(\"--local_dir\", type=str, default=None, help=\"Path to the saved model checkpoints.\")\n    base_op_parser.add_argument(\n        \"--tie-word-embedding\",\n        action=\"store_true\",\n        help=\"Whether to tie word embedding weights (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\"--trust-remote-code\", action=\"store_true\", help=\"Whether to trust remote code\")\n    base_op_parser.add_argument(\n        \"--is-value-model\",\n        action=\"store_true\",\n        help=\"Whether the model is a value model (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\n        \"--use_cpu_initialization\",\n        action=\"store_true\",\n        help=\"Whether to use CPU initialization for the model. 
This is useful for large models that cannot \"\n        \"fit into GPU memory during initialization.\",\n    )\n\n    merge_parser = subparsers.add_parser(\"merge\", parents=[base_op_parser], help=\"Merge model checkpoints and save.\")\n    merge_parser.add_argument(\n        \"--target_dir\", default=\"tmp\", type=str, help=\"Directory to save the merged huggingface model\"\n    )\n    merge_parser.add_argument(\n        \"--hf_upload_path\", default=None, type=str, help=\"Hugging Face repository ID to upload the model\"\n    )\n    merge_parser.add_argument(\n        \"--private\", action=\"store_true\", help=\"Whether to upload the model to a private Hugging Face repository\"\n    )\n\n    test_parser = subparsers.add_parser(\n        \"test\", parents=[base_op_parser], help=\"Test merged model against a reference Hugging Face model\"\n    )\n    test_parser.add_argument(\n        \"--test_hf_dir\", type=str, required=True, help=\"Path to the reference Hugging Face model directory for testing\"\n    )\n\n    args = parser.parse_args()\n    return args\n\n\n@dataclass\nclass ModelMergerConfig:\n    \"\"\"Configuration for model merger operations.\n\n    Args:\n        operation (str): Operation type - 'merge' or 'test'.\n        backend (str): Backend type for the model ('fsdp' or 'megatron').\n        target_dir (Optional[str]): Directory to save the merged huggingface model. Defaults to \"tmp\".\n        hf_upload_path (Optional[str]): Hugging Face repository ID to upload the model. Defaults to None.\n        private (bool): Whether to upload the model to a private Hugging Face repository. Defaults to False.\n        test_hf_dir (Optional[str]): Path to the reference Hugging Face model directory for testing. Defaults to None.\n        tie_word_embedding (bool): Whether to tie word embedding weights (currently only Megatron\n            supported). Defaults to False.\n        trust_remote_code (bool): Whether to trust remote code. Defaults to False.\n        is_value_model (bool): Whether the model is a value model (currently only Megatron\n            supported). Defaults to False.\n        local_dir (Optional[str]): Path to the saved model checkpoints. Defaults to None.\n        hf_model_config_path (Optional[str]): Path to HuggingFace model configuration files. Defaults to None.\n        hf_upload (bool): Whether to upload to HuggingFace (computed automatically). Not for initialization.\n        use_cpu_initialization (bool): Whether to use CPU initialization for large models. 
Defaults to False.\n    \"\"\"\n\n    operation: str  # 'merge' or 'test'\n    backend: str\n    target_dir: Optional[str] = \"tmp\"\n    hf_upload_path: Optional[str] = None\n    private: bool = False\n    test_hf_dir: Optional[str] = None\n    tie_word_embedding: bool = False\n    trust_remote_code: bool = False\n    is_value_model: bool = False\n    local_dir: Optional[str] = None\n    hf_model_config_path: Optional[str] = None\n    hf_upload: bool = field(init=False)\n    use_cpu_initialization: bool = False\n\n    def __post_init__(self):\n        self.hf_upload = self.operation == \"merge\" and bool(self.hf_upload_path)\n        if self.operation == \"test\":\n            self.target_dir = None\n            self.hf_upload_path = None\n            self.private = False\n\n\ndef generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig:\n    common_config_args = {\n        \"operation\": args.operation,\n        \"backend\": args.backend,\n        \"tie_word_embedding\": args.tie_word_embedding,\n        \"trust_remote_code\": args.trust_remote_code,\n        \"is_value_model\": args.is_value_model,\n        \"local_dir\": args.local_dir,\n        \"hf_model_config_path\": os.path.join(args.local_dir, \"huggingface\"),\n        \"use_cpu_initialization\": args.use_cpu_initialization,\n    }\n\n    if args.operation == \"merge\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            target_dir=args.target_dir,\n            hf_upload_path=args.hf_upload_path,\n            private=args.private,\n            test_hf_dir=None,\n        )\n        os.makedirs(config.target_dir, exist_ok=True)\n    elif args.operation == \"test\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            test_hf_dir=args.test_hf_dir,\n            # the following args are not used by test operation\n            target_dir=None,\n            hf_upload_path=None,\n            private=False,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown operation: {args.operation}\")\n    return config\n\n\nclass BaseModelMerger(ABC):\n    \"\"\"\n    Abstract base class for merging distributed model checkpoints into HuggingFace format.\n\n    This class provides common functionality for converting model checkpoints from different\n    distributed training backends (FSDP, Megatron) into standard HuggingFace format that\n    can be easily loaded and used for inference or further training.\n\n    The merger supports two main operations:\n    - merge: Convert and save checkpoints to HuggingFace format\n    - test: Validate merged checkpoints against a reference model\n\n    Args:\n        config (ModelMergerConfig): Configuration object containing paths, backend type,\n            and operation parameters.\n\n    Attributes:\n        config (ModelMergerConfig): The configuration object passed during initialization.\n        hf_model_config_path (str): Path to the HuggingFace model configuration files.\n        model_config (PretrainedConfig): Loaded HuggingFace model configuration.\n    \"\"\"\n\n    def __init__(self, config: ModelMergerConfig):\n        self.config = config\n        self.hf_model_config_path = config.hf_model_config_path\n        self.model_config = AutoConfig.from_pretrained(\n            self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code\n        )\n\n    def get_transformers_auto_model_class(self):\n        has_remote_code = hasattr(self.model_config, \"auto_map\") and any(\n            
self.model_config.architectures[0] in val for val in self.model_config.auto_map.values()\n        )\n        if has_remote_code:\n            auto_class = next(\n                k for k, v in self.model_config.auto_map.items() if self.model_config.architectures[0] in v\n            )\n            match auto_class:\n                case \"AutoModelForCausalLM\":\n                    return AutoModelForCausalLM\n                case \"AutoModelForTokenClassification\":\n                    return AutoModelForTokenClassification\n                case \"AutoModelForVision2Seq\":\n                    return AutoModelForVision2Seq\n                case _:\n                    raise NotImplementedError(f\"Unknown auto class {auto_class}\")\n        else:\n            if \"ForTokenClassification\" in self.model_config.architectures[0]:\n                return AutoModelForTokenClassification\n            elif \"ForCausalLM\" in self.model_config.architectures[0]:\n                return AutoModelForCausalLM\n            elif \"ForConditionalGeneration\" in self.model_config.architectures[0]:\n                return AutoModelForVision2Seq\n\n            raise NotImplementedError(f\"Unknown architecture {self.model_config.architectures}\")\n\n    def patch_model_generation_config(self, model):\n        \"\"\"\n        The generation_config created from the model config may differ from that of the pretrained model,\n        which can lead to errors when generating: https://github.com/volcengine/verl/issues/1246\n\n        This function patches the model with the generation_config loaded from the pretrained checkpoint.\n        \"\"\"\n        if model.can_generate():\n            try:\n                model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path)\n            except OSError:\n                print(\n                    f\"Warning: Generation config file not found in {self.hf_model_config_path}, using a \"\n                    f\"generation config created from the model config.\"\n                )\n        return model\n\n    def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Save lora adapter to safetensors.\n\n        Returns:\n            lora_path: str, the path to the lora adapter. None if no lora adapter found.\n\n        Note:\n            This function changes the 'state_dict' in place.\n        \"\"\"\n        lora_params_names = [name for name in state_dict.keys() if \"lora_\" in name]\n\n        if len(lora_params_names) == 0:\n            return None\n\n        import json\n        from collections import OrderedDict\n\n        import peft\n        from safetensors.torch import save_file\n\n        lora_params = OrderedDict()\n        target_modules = set()\n        lora_key = None\n\n        for name in lora_params_names:\n            lora_key = name.replace(\".default.weight\", \".weight\")\n            target_modules.add(lora_key.split(\".\")[-3])\n            lora_params[lora_key] = state_dict.pop(name)\n\n        lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])\n        peft_dict = {\n            \"r\": lora_rank,\n            \"lora_alpha\": 0,  # lora_alpha is not set. 
An error should be raised to inform the user to set it manually.\n            \"target_modules\": list(target_modules),\n        }\n        peft_config = peft.LoraConfig(**peft_dict).to_dict()\n        peft_config[\"task_type\"] = peft_config[\"task_type\"].value if peft_config[\"task_type\"] else None\n        peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value if peft_config[\"peft_type\"] else None\n        peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n\n        lora_path = os.path.join(self.config.target_dir, \"lora_adapter\")\n        os.makedirs(lora_path, exist_ok=True)\n        with open(os.path.join(lora_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n        save_file(lora_params, os.path.join(lora_path, \"adapter_model.safetensors\"))\n\n        for name in list(state_dict.keys()):\n            key = (\n                name.replace(\"base_model.model.\", \"\")\n                .replace(\".base_layer.weight\", \".weight\")\n                .replace(\".base_layer.bias\", \".bias\")\n            )\n            state_dict[key] = state_dict.pop(name)\n\n        return lora_path\n\n    def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n        with init_empty_weights():\n            model = auto_model_class.from_config(\n                self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code\n            )\n        model.to_empty(device=\"cpu\")\n        model = self.patch_model_generation_config(model)\n\n        lora_path = self.save_lora_adapter(state_dict)\n        if lora_path:\n            print(f\"Saving lora adapter to {lora_path}\")\n\n        print(f\"Saving model to {self.config.target_dir}\")\n        model.save_pretrained(self.config.target_dir, state_dict=state_dict)\n        del state_dict\n        del model\n\n        processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n        tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n        if processor is not None:\n            print(f\"Saving processor to {self.config.target_dir}\")\n            processor.save_pretrained(self.config.target_dir)\n        if tokenizer is not None:\n            print(f\"Saving tokenizer to {self.config.target_dir}\")\n            tokenizer.save_pretrained(self.config.target_dir)\n\n    def upload_to_huggingface(self):\n        import requests\n        from huggingface_hub import HfApi\n        from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError\n\n        api = HfApi()\n        try:\n            # Attempt to create repository\n            api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True)\n        except HfHubHTTPError as e:\n            # Handle authentication/API errors\n            if e.response.status_code == 401:\n                raise PermissionError(\n                    \"Hugging Face authentication failed. 
Verify your token is valid and has write permissions.\"\n                ) from e\n            elif e.response.status_code == 404:\n                raise RepositoryNotFoundError(f\"Repository path not found: {self.config.hf_upload_path}\") from e\n            else:\n                raise ConnectionError(f\"Failed to create repository ({e.response.status_code}): {e}\") from e\n        except requests.exceptions.ConnectionError as e:\n            raise ConnectionError(\"Network connection failed. Check your internet connection.\") from e\n\n        try:\n            # Attempt folder upload\n            api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type=\"model\")\n        except HfHubHTTPError as e:\n            if e.response.status_code == 401:\n                raise PermissionError(\"Authentication failed during upload. Token may have expired.\") from e\n            else:\n                raise RuntimeError(f\"Upload failed ({e.response.status_code}): {e}\") from e\n        except requests.exceptions.ConnectionError as e:\n            raise ConnectionError(\"Network interruption during upload. Try again with stable connection.\") from e\n        except OSError as e:\n            raise FileNotFoundError(f\"Local folder error: {self.config.target_dir} - {str(e)}\") from e\n        except Exception as e:\n            raise RuntimeError(f\"Unexpected error during upload: {str(e)}\") from e\n\n    @abstractmethod\n    def merge_and_save(self):\n        raise NotImplementedError(\"Subclasses should implement this method\")\n\n    @abstractmethod\n    def cleanup(self):\n        raise NotImplementedError(\"Subclasses should implement this method to clean up resources if needed\")\n"
  },
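`ModelMergerConfig.__post_init__` derives `hf_upload` and nulls out the upload-related fields for test runs. A small sketch with placeholder paths and repository id:

```python
cfg = ModelMergerConfig(
    operation="merge",
    backend="fsdp",
    local_dir="checkpoints/global_step_1/actor",
    hf_model_config_path="checkpoints/global_step_1/actor/huggingface",
    target_dir="/tmp/merged_hf_model",
    hf_upload_path="my-org/my-model",  # placeholder repo id
)
assert cfg.hf_upload  # merge operation + upload path set

test_cfg = ModelMergerConfig(
    operation="test",
    backend="fsdp",
    local_dir="checkpoints/global_step_1/actor",
    hf_model_config_path="checkpoints/global_step_1/actor/huggingface",
    test_hf_dir="/tmp/reference_hf_model",
)
assert test_cfg.target_dir is None and not test_cfg.hf_upload
```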
  {
    "path": "verl_distillation/verl/model_merger/fsdp_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nfrom concurrent.futures import ThreadPoolExecutor\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom torch.distributed._tensor import Placement, Shard\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom tqdm import tqdm\n\nfrom .base_model_merger import BaseModelMerger\n\n\nclass FSDPModelMerger(BaseModelMerger):\n    \"\"\"\n    Model merger for FSDP (Fully Sharded Data Parallel) checkpoints.\n\n    This class handles the conversion of FSDP distributed checkpoints into HuggingFace format.\n    FSDP shards model parameters across multiple processes, and this merger reconstructs\n    the full model by loading and concatenating the sharded parameters from all ranks.\n\n    The merger supports various FSDP configurations including:\n    - Pure FSDP (single dimension sharding)\n    - FSDP + DDP (data parallel + fully sharded data parallel)\n    - DTensor-based sharding with custom device meshes\n\n    Key features:\n    - Automatic detection of world size from checkpoint filenames\n    - Support for DTensor and non-DTensor checkpoints\n    - Parallel loading of checkpoint shards for efficiency\n    - Validation against reference HuggingFace models\n\n    Example:\n        To merge FSDP checkpoints:\n        ```python\n        config = ModelMergerConfig(\n            operation=\"merge\",\n            backend=\"fsdp\",\n            local_dir=\"path/to/fsdp/checkpoints\",\n            target_dir=\"path/to/output\"\n        )\n        merger = FSDPModelMerger(config)\n        merger.merge_and_save()\n        ```\n    \"\"\"\n\n    def _get_world_size(self) -> int:\n        \"\"\"_summary_\n        From FSDP json config file, extract the world size.\n\n        Returns:\n            int: world size\n        \"\"\"\n        config_path = Path(self.config.local_dir) / \"fsdp_config.json\"\n        if not config_path.exists():\n            raise FileNotFoundError(f\"Config file {config_path} does not exist.\")\n\n        with open(config_path) as f:\n            config = json.load(f)\n\n        # Extract world size from the config\n        world_size = config.get(\"world_size\", None)\n        if world_size is None:\n            raise ValueError(\"World size not found in the config file.\")\n\n        return world_size\n\n    def _load_rank_zero_state_dict(self, world_size: int) -> dict:\n        return torch.load(\n            Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_0.pt\",\n            map_location=\"cpu\",\n            weights_only=False,\n        )\n\n    def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]:\n        \"\"\"\n        Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict.\n        If no 
DTensor is found, infers a simple FSDP mesh based on world_size.\n        \"\"\"\n        pivot_key = sorted(list(state_dict.keys()))[0]\n        weight = state_dict[pivot_key]\n\n        if isinstance(weight, DTensor):\n            # get sharding info\n            device_mesh = weight.device_mesh\n            mesh = device_mesh.mesh\n            mesh_dim_names = device_mesh.mesh_dim_names\n        else:\n            # for non-DTensor\n            mesh = np.array([world_size], dtype=np.int64)\n            mesh_dim_names = (\"fsdp\",)\n\n        return mesh, mesh_dim_names\n\n    def _calculate_shard_configuration(\n        self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]\n    ) -> tuple[int, tuple[int, ...]]:\n        \"\"\"Calculates the total number of shards and the shape of the device mesh.\"\"\"\n        assert mesh_dim_names in ((\"fsdp\",), (\"ddp\", \"fsdp\")), f\"Unsupported mesh_dim_names {mesh_dim_names}\"\n\n        if \"tp\" in mesh_dim_names:\n            # TODO: \"tp\" is not supported yet due to the above assert\n            total_shards = mesh.shape[-1] * mesh.shape[-2]\n            mesh_shape = (mesh.shape[-2], mesh.shape[-1])\n        else:\n            total_shards = mesh.shape[-1]\n            mesh_shape = (mesh.shape[-1],)\n\n        return total_shards, mesh_shape\n\n    def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor:\n        \"\"\"Merges a list of tensors based on their DTensor placement\"\"\"\n        if placement.is_replicate():\n            return tensors[0]\n        elif placement.is_partial():\n            raise NotImplementedError(\"Partial placement is not supported yet\")\n        elif placement.is_shard():\n            return torch.cat(tensors, dim=placement.dim).contiguous()\n\n        raise NotImplementedError(f\"Unsupported placement: {placement}\")\n\n    def _load_and_merge_state_dicts(\n        self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]\n    ) -> dict[str, torch.Tensor]:\n        model_state_dict_lst = [None] * total_shards\n\n        def process_one_shard(rank: int, model_state_dict_lst: list):\n            model_path = Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_{rank}.pt\"\n            state_dict = torch.load(model_path, map_location=\"cpu\", weights_only=False)\n            model_state_dict_lst[rank] = state_dict\n            return state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)]\n            for future in tqdm(futures, desc=f\"Loading {total_shards} FSDP shards\", total=total_shards):\n                future.result()\n\n        # Merge state dicts from all shards\n        state_dict = {}\n        param_placements: dict[str, list] = {}\n\n        for key in set(model_state_dict_lst[0].keys()):\n            state_dict[key] = []\n            for model_state_shard in model_state_dict_lst:\n                # add tensor shard in order of rank to state_dict[key]\n                tensor = model_state_shard.pop(key)\n                if isinstance(tensor, DTensor):\n                    state_dict[key].append(tensor._local_tensor.bfloat16())\n\n                    placements = tuple(tensor.placements)\n                    # replicated placement at dp dimension can be discarded\n                    if mesh_dim_names[0] in (\"dp\", \"ddp\"):\n             
           placements = placements[1:]\n\n                    if key not in param_placements:\n                        param_placements[key] = placements\n                    else:\n                        assert param_placements[key] == placements\n                else:\n                    state_dict[key].append(tensor.bfloat16())\n\n        del model_state_dict_lst\n\n        # Merge tensors\n        for key in sorted(state_dict):\n            if not isinstance(state_dict[key], list):\n                print(f\"No need to merge key {key}\")\n                continue\n            if key in param_placements:\n                # merge shards\n                placements: tuple[Shard] = param_placements[key]\n                if len(mesh_shape) == 1:\n                    # 1-D list, FSDP without TP\n                    assert len(placements) == 1\n                    shards = state_dict[key]\n                    state_dict[key] = self._merge_by_placement(shards, placements[0])\n                else:\n                    # 2-D list, FSDP + TP\n                    raise NotImplementedError(\"FSDP + TP is not supported yet\")\n            else:\n                state_dict[key] = torch.cat(state_dict[key], dim=0)\n\n        return state_dict\n\n    def merge_and_save(self):\n        world_size = self._get_world_size()\n        rank_zero_state_dict = self._load_rank_zero_state_dict(world_size)\n\n        mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size)\n        print(f\"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}\")\n\n        total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names)\n        print(f\"Processing model shards with {total_shards} {mesh_shape} in total\")\n\n        merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names)\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._validate_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n\n        hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16)\n        hf_state_dict = hf_model.state_dict()\n        del hf_model\n\n        hf_model_keys = set(hf_state_dict.keys())\n        collected_keys = set(state_dict.keys())\n\n        missing_keys = hf_model_keys - collected_keys\n        assert len(missing_keys) == 0, f\"Missing keys in collected state dict: {list(sorted(missing_keys))}\"\n\n        extra_keys = collected_keys - hf_model_keys\n        assert len(extra_keys) == 0, f\"Extra keys in collected state dict: {list(sorted(extra_keys))}\"\n\n        for key in hf_model_keys:\n            hf_shape = hf_state_dict[key].shape\n            collected_shape = state_dict[key].shape\n            assert hf_shape == collected_shape, (\n                f\"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}\"\n            )\n\n            hf_dtype = hf_state_dict[key].dtype\n            
collected_dtype = state_dict[key].dtype\n            assert hf_dtype == collected_dtype, (\n                f\"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}\"\n            )\n\n            torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6)\n\n        print(\"FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.\")\n\n    def cleanup(self):\n        \"\"\"Cleanup temporary files if needed.\"\"\"\n        # FSDP merger does not create temporary files, so no cleanup is needed.\n        pass\n"
  },
  {
    "path": "verl_distillation/verl/model_merger/megatron_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nimport warnings\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Any, Callable, ContextManager\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\ntry:\n    # NPU patch\n    import mindspeed.megatron_adaptor  # noqa: F401\nexcept ImportError:\n    pass\n\nfrom accelerate import init_empty_weights\nfrom megatron.core import mpu\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed\nfrom safetensors.torch import load_file\nfrom transformers import (\n    AutoConfig,\n    PretrainedConfig,\n)\n\nfrom verl.models.mcore import hf_to_mcore_config\nfrom verl.utils.device import get_device_name, get_nccl_backend, get_torch_device\nfrom verl.utils.distributed import set_numa_affinity\nfrom verl.utils.megatron.dist_checkpointing import load_dist_checkpointing\nfrom verl.utils.megatron_utils import get_model\nfrom verl.utils.tokenizer import hf_processor, hf_tokenizer\n\nfrom .base_model_merger import BaseModelMerger, ModelMergerConfig\n\n\n@contextmanager\ndef noop_context() -> Any:\n    yield\n\n\ndef get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]:\n    \"\"\"Calculate the pipeline sharding configuration for Megatron-LM.\n\n    Args:\n        layer_num: Total number of layers in the model.\n        pp_size: Number of pipeline parallel ranks.\n\n    Returns:\n        layer number of each pp rank. 
Make the sharding of the pipeline as uniform as possible.\n    \"\"\"\n    if layer_num < pp_size:\n        raise ValueError(f\"layer_num {layer_num} must be greater than pp_size {pp_size}.\")\n\n    if pp_size < 1:\n        raise ValueError(f\"pp_size must be at least 1, got {pp_size}.\")\n    if pp_size == 1:\n        return [layer_num]\n\n    if pp_size == 2:\n        return [\n            layer_num // 2,\n            layer_num - layer_num // 2,\n        ]\n\n    middle_size = pp_size - 2\n    shards_strategy = []\n    for middle_layer_num in range(layer_num):\n        first_last_layer_num = layer_num - middle_layer_num * middle_size\n        first_layer_num = first_last_layer_num // 2\n        last_layer_num = first_last_layer_num - first_last_layer_num // 2\n        if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num:\n            shards_strategy.append(\n                (\n                    [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num],\n                    abs(first_layer_num - middle_layer_num),\n                )\n            )\n\n    # sort by diff of layer_num, to make it as uniform as possible\n    res = sorted(shards_strategy, key=lambda x: x[1])[0][0]\n    assert sum(res) == layer_num, f\"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}\"\n    return res\n\n\nclass MegatronModelMerger(BaseModelMerger):\n    \"\"\"\n    Model merger for Megatron-LM distributed checkpoints.\n\n    This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format.\n    Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute\n    large language models across multiple GPUs. This merger reconstructs the full model by\n    loading distributed checkpoints and applying the necessary transformations.\n\n    Key features:\n    - Support for tensor parallel, pipeline parallel, and data parallel configurations\n    - Automatic parameter name mapping from Megatron to HuggingFace conventions\n    - Handling of QKV and gate-up tensor splitting/merging\n    - Support for tied word embeddings and value models\n    - Integration with Megatron's distributed checkpointing system\n\n    The merger handles various model architectures and configurations:\n    - Standard transformer models (GPT-style)\n    - Models with tied word embeddings\n    - Value models for reinforcement learning\n    - Multi-layer attention (MLA) architectures\n    - Mixture of Experts (MoE) models\n\n    Args:\n        config (ModelMergerConfig): Configuration object with Megatron-specific settings\n            including tie_word_embedding and is_value_model flags.\n\n    Example:\n        To merge Megatron checkpoints:\n        ```python\n        config = ModelMergerConfig(\n            operation=\"merge\",\n            backend=\"megatron\",\n            local_dir=\"path/to/megatron/checkpoints\",\n            target_dir=\"path/to/output\",\n            tie_word_embedding=True\n        )\n        merger = MegatronModelMerger(config)\n        merger.merge_and_save()\n        ```\n    \"\"\"\n\n    def __init__(self, config: ModelMergerConfig):\n        super().__init__(config)\n        # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards\n        if \"WORLD_SIZE\" not in os.environ:\n            os.environ[\"RANK\"] = \"0\"\n            os.environ[\"LOCAL_RANK\"] = \"0\"\n            os.environ[\"WORLD_SIZE\"] = \"1\"\n            
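# assume a standalone run: set up a minimal single-process rendezvous so init_process_group below succeeds\n            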
os.environ[\"MASTER_ADDR\"] = \"localhost\"\n            os.environ[\"MASTER_PORT\"] = \"12355\"\n\n        set_numa_affinity()\n        torch.distributed.init_process_group(get_nccl_backend())\n\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n        local_rank = os.environ.get(\"LOCAL_RANK\", 0)\n        get_torch_device().set_device(f\"{get_device_name()}:{local_rank}\")\n\n        mpu.initialize_model_parallel(\n            tensor_model_parallel_size=1,\n            pipeline_model_parallel_size=self.world_size,\n            virtual_pipeline_model_parallel_size=None,\n            context_parallel_size=1,\n            expert_model_parallel_size=1,\n        )\n        model_parallel_cuda_manual_seed(0)\n        self.hf_config = AutoConfig.from_pretrained(\n            self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code\n        )\n        print(self.hf_config, flush=True)\n\n        self.params_mapping = {\n            # megatron core gpt model name, huggingface model name\n            # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the\n            # longer key within the containing relationship is processed first.\n            \"embedding.word_embeddings\": \"model.embed_tokens\",\n            # input layer norm for dpskv3\n            \"input_layernorm.weight\": \"input_layernorm.weight\",\n            \"input_layernorm.bias\": \"input_layernorm.bias\",\n            # attn\n            \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_qkv.layer_norm_bias\": \"input_layernorm.bias\",\n            \"self_attention.linear_qkv\": \"self_attn.qkv_proj\",\n            \"self_attention.q_layernorm\": \"self_attn.q_norm\",\n            \"self_attention.k_layernorm\": \"self_attn.k_norm\",\n            \"self_attention.linear_proj\": \"self_attn.o_proj\",\n            # mla\n            \"self_attention.linear_q_proj\": \"self_attn.q_proj\",\n            \"self_attention.linear_q_down_proj\": \"self_attn.q_a_proj\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n            \"self_attention.linear_q_up_proj\": \"self_attn.q_b_proj\",\n            \"self_attention.linear_kv_down_proj\": \"self_attn.kv_a_proj_with_mqa\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj\": \"self_attn.kv_b_proj\",\n            # mlp\n            \"pre_mlp_layernorm\": \"post_attention_layernorm\",\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc1.layer_norm_bias\": \"post_attention_layernorm.bias\",\n            \"mlp.linear_fc1\": \"mlp.gate_up_proj\",\n            \"mlp.linear_fc2\": \"mlp.down_proj\",\n            # moe\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n            \"mlp.router\": \"mlp.gate\",\n            \"mlp.shared_experts.linear_fc1\": \"mlp.shared_experts.gate_up_proj\",\n            \"mlp.shared_experts.linear_fc2\": \"mlp.shared_experts.down_proj\",\n            \"linear_fc1\": \"gate_up_proj\",\n            \"linear_fc2\": \"down_proj\",\n            # output\n            \"final_layernorm\": \"norm\",\n            \"output_layer\": \"lm_head\",\n        }\n\n        if \"Qwen2MoeForCausalLM\" in self.hf_config.architectures:\n     
       self.params_mapping[\"mlp.shared_experts.linear_fc1\"] = \"mlp.shared_expert.gate_up_proj\"\n            self.params_mapping[\"mlp.shared_experts.linear_fc2\"] = \"mlp.shared_expert.down_proj\"\n            self.params_mapping[\"mlp.shared_experts.gate_weight\"] = \"mlp.shared_expert_gate.weight\"\n\n    def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:\n        \"\"\"_summary_\n        Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.\n\n        Args:\n            model_ckpt_path (str): Path to the model checkpoint directory.\n\n        Returns:\n            State dict containing the model parameters.\n        \"\"\"\n\n        # init hf config\n        self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size)\n        print(f\"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}\")\n\n        tf_config = hf_to_mcore_config(\n            self.hf_config,\n            torch.bfloat16,\n            num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None,\n            num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None,\n        )\n        tf_config.use_cpu_initialization = self.config.use_cpu_initialization\n        tie_word_embeddings = getattr(self.hf_config, \"tie_word_embeddings\", False)\n\n        # init megatron model\n        def megatron_model_provider(pre_process, post_process):\n            from verl.models.mcore import init_mcore_model\n\n            parallel_model = init_mcore_model(\n                tf_config,\n                self.hf_config,\n                pre_process,\n                post_process,\n                share_embeddings_and_output_weights=tie_word_embeddings,\n                value=False,\n            )\n            return parallel_model\n\n        context: Callable[..., ContextManager] = (\n            init_empty_weights if self.config.use_cpu_initialization else noop_context\n        )\n        with context():\n            whole_model = get_model(\n                model_provider_func=megatron_model_provider,\n                model_type=ModelType.encoder_or_decoder,\n                wrap_with_ddp=False,\n                transformer_config=tf_config,\n            )\n\n        if self.config.use_cpu_initialization:\n            # convert meta device to empty tensor so it can use `copy_` function\n            whole_model[0].module = whole_model[0].module.to_empty(device=\"cpu\")\n\n        # load state dicts\n        sharded_state_dict = {}\n        for vpp_rank, model in enumerate(whole_model):\n            key = f\"model{vpp_rank}\" if len(whole_model) > 1 else \"model\"\n            mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n            sharded_state_dict[key] = model.sharded_state_dict()\n        model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path)\n        model_state_dict_list = []\n        for vpp_rank, model in enumerate(whole_model):\n            key = f\"model{vpp_rank}\" if len(whole_model) > 1 else \"model\"\n            mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n            model_state_dict_list.append(model_state_dict[key])\n\n        return model_state_dict_list\n\n    def _check_megatron_state_key(self, key: str) -> bool:\n        \"\"\"\n        Checks if the key is a valid Megatron state key.\n\n        Now the model merger only supports keys that start with 
\"decoder/embedding/output_layer\" in TransformerLayer.\n        Shall not use key starts with \"model.\"\n        \"\"\"\n        if key.startswith(\"model.\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with \"\n                f\"'decoder/embedding/output_layer' in TransformerLayer.\"\n            )\n\n        skip_checking_keys = [\"embedding.word_embeddings\", \"output_layer\"]\n        for skip_key in skip_checking_keys:\n            if skip_key in key:\n                print(f\"skip checking key {key}\")\n                return\n\n        # Exclude extra state keys\n        if not key.startswith(\"decoder\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer.\"\n            )\n\n    def _split_tensors(\n        self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Splits a tensor into multiple tensors based on the name.\n        This is used to handle qkv and gate_up tensors.\n        \"\"\"\n        if \"linear_fc1.weight\" in key:\n            # if the tensor is gate and proj\n            gate_lst = []\n            up_lst = []\n            gate, up = tensor.chunk(2)\n            gate_lst.append(gate)\n            up_lst.append(up)\n            gate = torch.cat(gate_lst, dim=0)\n            up = torch.cat(up_lst, dim=0)\n            return [gate, up]\n        elif \"self_attention.linear_qkv.\" in key and \"layer_norm\" not in key:\n            # if the tensor is qkv, for each param on tp, split into q, k, v\n            # concat q, k, v separately.\n            q_lst, k_lst, v_lst = [], [], []\n            assert config.num_attention_heads % config.num_key_value_heads == 0\n            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads\n            assert tensor.shape[0] % (num_q_per_kv + 2) == 0, (\n                f\"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}\"\n            )\n            kv_size = tensor.shape[0] // (num_q_per_kv + 2)\n            split_size = [kv_size * num_q_per_kv, kv_size, kv_size]\n\n            num_query_groups_per_partition = config.num_key_value_heads\n            for chunk in tensor.chunk(num_query_groups_per_partition):\n                split_size = [\n                    kv_size * num_q_per_kv // num_query_groups_per_partition,\n                    kv_size // num_query_groups_per_partition,\n                    kv_size // num_query_groups_per_partition,\n                ]\n                q, k, v = chunk.split(split_size)\n                q_lst.append(q)\n                k_lst.append(k)\n                v_lst.append(v)\n\n            return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)]\n        else:\n            return [tensor]\n\n    def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:\n        state_dict = {}\n        layers_cum = 0\n        if self.world_size > 1:\n            pipeline_cumsum = np.cumsum(self.pipeline_shards)\n            layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]\n\n        print(f\"{layers_cum=}\")\n        for model_state_dict in model_state_dict_list:\n            layers_handled = 0\n            keys = model_state_dict.keys()\n            for key in keys:\n                if \"extra_state\" in key:\n                 
   continue\n                if self.config.tie_word_embedding and (\"output_layer\" in key):\n                    print(\"skip lm_head and reward_head loading because of tie_word_embeddings\")\n                    continue\n\n                self._check_megatron_state_key(key)\n                hf_name = self._replace_name(key, self.params_mapping)\n                assert hf_name is not None, f\"Failed to convert layer name [{key}] from megatron to huggingface.\"\n                if \"model.layers.\" in hf_name:\n                    local_layer_no = int(hf_name.split(\".\")[2])\n                    layers_handled = max(local_layer_no, layers_handled)\n                    global_layer_no = local_layer_no + layers_cum\n                    new_key_list = hf_name.split(\".\")\n                    new_key_list[2] = str(global_layer_no)\n                    hf_name = \".\".join(new_key_list)\n                else:\n                    warnings.warn(f\"hf_name {hf_name} will not be fixed with layer number\", stacklevel=2)\n\n                if \"mlp.experts.\" in hf_name and \".weight\" in hf_name:\n                    name_prefix, expert_id = hf_name.split(\".weight\")\n                    for proj in [\"gate_up\", \"down\"]:\n                        if f\"{proj}_proj\" in hf_name:\n                            hf_name = hf_name.replace(\n                                f\"mlp.experts.{proj}_proj.weight{expert_id}\",\n                                f\"mlp.experts.{expert_id}.{proj}_proj.weight\",\n                            )\n\n                tensor = model_state_dict[key]\n                split_tensor = self._split_tensors(\n                    key, tensor, self.hf_config, is_value_model=self.config.is_value_model\n                )\n\n                if len(split_tensor) == 1:\n                    state_dict[hf_name] = split_tensor[0]\n                elif len(split_tensor) == 3:\n                    # split qkv\n                    for n, d in zip([\"q\", \"k\", \"v\"], split_tensor, strict=True):\n                        state_dict[hf_name.replace(\"qkv\", n)] = d\n                elif len(split_tensor) == 2:\n                    # split gate up\n                    state_dict[hf_name.replace(\"gate_up\", \"gate\")] = split_tensor[0]\n                    state_dict[hf_name.replace(\"gate_up\", \"up\")] = split_tensor[1]\n                shape_info = (\n                    split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor]\n                )\n                print(f\"converted {key} to {hf_name} with shape {shape_info}\")\n\n            layers_cum += layers_handled + 1  # zero based\n\n        return state_dict\n\n    def save_hf_model_and_tokenizer(self, merged_state_dict):\n        if self.world_size == 1:\n            return super().save_hf_model_and_tokenizer(merged_state_dict)\n\n        from safetensors.torch import save_file\n\n        layer_num = self.hf_config.num_hidden_layers\n\n        # FIXME: make configurable\n        saves_per_layer = 1 if layer_num < 30 else 2\n        saves_total = saves_per_layer * layer_num\n        saves_indexes = {}\n\n        # calculate the layer start index and key chunks\n        layer_this_rank = self.pipeline_shards[self.rank]\n        pipeline_cumsum = np.cumsum(self.pipeline_shards)\n        layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]\n        keys = list(merged_state_dict.keys())\n        keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer)\n    
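    # each chunk of keys becomes one safetensors shard written by this pp rank\n    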
     numel = 0\n\n        assert len(keys_chunk) == layer_this_rank * saves_per_layer, (\n            f\"Expected {layer_this_rank * saves_per_layer} chunks, but got {len(keys_chunk)} for rank {self.rank}.\"\n        )\n\n        # save to model shards manually\n        target_dir = Path(self.config.target_dir)\n        for i, keys in enumerate(keys_chunk):\n            sd_to_save = {k: merged_state_dict[k] for k in keys}\n            numel += sum([sd_to_save[k].numel() for k in sd_to_save])\n            save_idx = layer_start * saves_per_layer + i\n            save_path = target_dir / f\"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors\"\n\n            save_file(sd_to_save, save_path)\n            for k in keys:\n                saves_indexes[k] = str(save_path.name)\n\n        tensor = torch.tensor([numel]).to(get_device_name())\n        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n        numel = tensor.cpu().item()\n\n        all_save_indexes = [{} for _ in range(self.world_size)]\n        dist.all_gather_object(all_save_indexes, saves_indexes)\n        saves_indexes = {k: v for d in all_save_indexes for k, v in d.items()}\n        if self.rank == 0:\n            with open(target_dir / \"model.safetensors.index.json\", \"w\") as f:\n                json.dump(\n                    {\n                        \"metadata\": {\n                            \"total_size\": numel,\n                        },\n                        \"weight_map\": saves_indexes,\n                    },\n                    f,\n                    indent=4,\n                )\n            print(f\"model saved to {target_dir} with {numel=}\")\n\n            self.model_config.save_pretrained(self.config.target_dir)\n\n            processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n            tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n            if processor is not None:\n                print(f\"Saving processor to {self.config.target_dir}\")\n                processor.save_pretrained(self.config.target_dir)\n            if tokenizer is not None:\n                print(f\"Saving tokenizer to {self.config.target_dir}\")\n                tokenizer.save_pretrained(self.config.target_dir)\n\n    def merge_and_save(self):\n        from verl.utils.megatron_utils import get_dist_checkpoint_path\n\n        model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir)\n\n        model_state_dict = self._load_state_dicts(model_ckpt_path)\n        merged_state_dict = self._merge_state_dicts(model_state_dict)\n        del model_state_dict\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._validate_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Compares the merged Megatron state_dict against a reference safetensors model.\n        Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name.\n        \"\"\"\n        ref_state_dict = 
load_file(Path(self.config.test_hf_dir) / \"model.safetensors\")\n\n        for name, loaded_weight in state_dict.items():\n            # name = self._replace_name(original_name, self.params_mapping)\n            if not name or name.endswith(\".bias\") and name not in ref_state_dict:\n                continue\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if \"lm_head.weight\" in name:\n                if self.config.is_value_model or self.config.tie_word_embedding:\n                    continue\n            if name not in ref_state_dict:\n                raise RuntimeError(f\"key: {name} not exist in state_dict\")\n            param = ref_state_dict[name]\n            assert loaded_weight.dtype == param.dtype\n            torch.testing.assert_close(loaded_weight.to(\"cpu\"), param, atol=1e-2, rtol=5e-2)\n\n    def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str:\n        for m_name, v_name in name_mapping.items():\n            if m_name not in megatron_name:\n                continue\n\n            megatron_name = megatron_name.replace(\"decoder\", \"model\")\n            param_name = megatron_name.replace(m_name, v_name)\n\n            return param_name\n\n        return None  # Return None if no mapping found\n\n    def cleanup(self):\n        torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "verl_distillation/verl/models/README.md",
    "content": "# Models\nCommon modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. \n## Adding a New Huggingface Model\n### Step 1: Copy the model file from HF to verl\n- Add a new file under verl/models/hf\n- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf\n\n### Step 2: Modify the model file to use packed inputs\n- Remove all the code related to inference (kv cache)\n- Modify the inputs to include only\n    - input_ids (total_nnz,)\n    - cu_seqlens (total_nnz + 1,)\n    - max_seqlen_in_batch: int\n- Note that this requires using flash attention with causal mask.\n\n### Step 2.5: Add tests\n- Add a test to compare this version and the huggingface version\n- Following the infrastructure and add tests to tests/models/hf\n\n### Step 3: Add a function to apply tensor parallelism\n- Please follow\n    - https://pytorch.org/docs/stable/distributed.tensor.parallel.html\n    - https://pytorch.org/tutorials/intermediate/TP_tutorial.html\n- General comments\n    - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward.\n\n### Step 4: Add a function to apply data parallelism\n- Please use FSDP2 APIs\n- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413\n\n### Step 5: Add a function to apply pipeline parallelism\n- Comes in Pytorch 2.4\n- Currently only in alpha in nightly version\n- Check torchtitan for more details\n\n"
  },
  {
    "path": "verl_distillation/verl/models/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_llama_megatron import (\n    ParallelLlamaForCausalLM,\n    # rmpad with megatron\n    ParallelLlamaForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelLlamaForCausalLMRmPadPP,\n    ParallelLlamaForValueRmPad,\n    ParallelLlamaForValueRmPadPP,\n    # original model with megatron\n    ParallelLlamaModel,\n)\n\n__all__ = [\n    \"ParallelLlamaForCausalLM\",\n    \"ParallelLlamaForCausalLMRmPad\",\n    \"ParallelLlamaForCausalLMRmPadPP\",\n    \"ParallelLlamaForValueRmPad\",\n    \"ParallelLlamaForValueRmPadPP\",\n    \"ParallelLlamaModel\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/checkpoint_utils/llama_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not 
isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor.data.copy_(state_dict[name])\n\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp 
* 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(\n                total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(\n                total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor.data.copy_(tensor_chunk[tp_rank])\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    embed_tokens_weight = None\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n    _fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = 
mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (\n                mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk\n            )\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print_rank_0(f\"loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    _fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    print_rank_0(\"loading lm_head...\")\n    if pp_rank + 1 == pp_size:\n        lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _fetch_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            
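# regular LM head: the output projection is vocab-parallel, so copy it as a TP-sharded tensor\n           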
_fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    
if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} 
tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(\n 
                   torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * 
q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n  
              chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/checkpoint_utils/llama_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), (\n        f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    )\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface 
\ndef merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (PretrainedConfig):\n            HF config for the model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings, not used in llama, only kept to match the qwen2 interface\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank()}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].model.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, 
dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n   
 get_torch_device().empty_cache()\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        print_rank_0(\"collecting lm_head...\")\n
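\n        # For value models, both the lm_head weight and (when the module defines a reward_head) the reward_head\n        # weight are exported under their own names, so a loader can match either lm_head.weight or\n        # reward_head.weight.\n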
        if is_value_model:\n            if pp_rank == pp_size - 1:\n                print(f\"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}\")\n            _broadcast_tensor(\n                gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n            _broadcast_tensor(\n                gpt_model_module.reward_head.weight\n                if pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_head\", None) is not None\n                else None,\n                \"reward_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n        else:\n            _broadcast_tp_shard_tensor(\n                getattr(gpt_model_module.lm_head, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        if dtype not in [torch.float16, torch.bfloat16, torch.float32]:\n            print(f\"Unknown/unsupported dtype to save: {dtype}\")\n            exit(1)\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelLlamaAttention\nfrom .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad\nfrom .parallel_linear import (\n    LinearForLastLayer,\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n)\nfrom .parallel_mlp import ParallelLlamaMLP\nfrom .parallel_rmsnorm import ParallelLlamaRMSNorm\n\n__all__ = [\n    \"LinearForLastLayer\",\n    \"MergedColumnParallelLinear\",\n    \"QKVParallelLinear\",\n    \"ParallelLlamaAttention\",\n    \"ParallelLlamaDecoderLayer\",\n    \"ParallelLlamaDecoderLayerRmPad\",\n    \"ParallelLlamaMLP\",\n    \"ParallelLlamaRMSNorm\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import LlamaConfig\nfrom transformers.utils import is_flash_attn_2_available\n\nfrom verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__(dim, max_position_embeddings, base, device)\n\n        self.factor = config.rope_scaling[\"factor\"]  # `8` in the original implementation\n        self.high_freq_factor = config.rope_scaling[\"high_freq_factor\"]  # `1` in the original implementation\n        self.low_freq_factor = config.rope_scaling[\"low_freq_factor\"]  # `4` in the original implementation\n        self.old_context_len = config.rope_scaling[\n            \"original_max_position_embeddings\"\n        ]  # `8192` in the original implementation\n\n        low_freq_wavelen = self.old_context_len / self.low_freq_factor\n        high_freq_wavelen = self.old_context_len / self.high_freq_factor\n\n        wavelen = 2 * math.pi / self.inv_freq\n        # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor\n        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)\n        # otherwise: interpolate between the two, using a smooth factor\n        smooth_factor = (self.old_context_len 
/ wavelen - self.low_freq_factor) / (\n            self.high_freq_factor - self.low_freq_factor\n        )\n        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama\n        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelLlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, (\n            f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        )\n        assert self.num_key_value_heads % tp_size == 0, (\n            f\"num_key_value_heads must be divisible by tp_size. 
Got num_key_value_heads=\"\n            f\"{self.num_key_value_heads}, tp_size={tp_size}\"\n        )\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and \"\n                f\"`num_heads`: {self.num_heads}).\"\n            )\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            bias=config.attention_bias,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            bias=config.attention_bias,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = LlamaRotaryEmbedding(\n                self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            rope_type_key = \"type\" if \"type\" in self.config.rope_scaling else \"rope_type\"\n            scaling_type = self.config.rope_scaling[rope_type_key]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"llama3\":\n                self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding(\n                    self.head_dim,\n                    self.config,\n                    
max_position_embeddings=self.max_position_embeddings,\n                    base=self.rope_theta,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, \"\n                f\"but is {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, \"\n                f\"but is {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, 
seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin should be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(\n        q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    k_embed = apply_rotary_emb(\n        k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    return q_embed, k_embed\n\n\nclass ParallelLlamaAttentionRmPad(ParallelLlamaAttention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split(\n            [self.q_size, self.k_size, self.v_size], dim=-1\n        )  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x head_dim x hidden_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(\n            query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch\n        )\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin,\n        # position_ids, indices,\n\n        # TODO: llama does not have dropout in the config??\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states get silently cast in float32. Hence, we need to\n        # cast them back in float16 just to be sure everything works as expected.\n        # This might slow down training & inference, so it is recommended not to cast the LayerNorms\n        # in fp32. (LlamaRMSNorm handles it correctly)\n
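        # Save the incoming activation dtype: fp32 activations are cast to fp16 for flash-attn below, and the\n        # attention output is cast back to this dtype afterwards.\n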
        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad\nfrom .parallel_mlp import ParallelLlamaMLP\nfrom .parallel_rmsnorm import ParallelLlamaRMSNorm\n\n\nclass ParallelLlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelLlamaDecoderLayerRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\nimport torch\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass LinearForLastLayer(torch.nn.Linear):\n    def __init__(\n        self,\n        input_size,\n        output_size,\n        *,\n        config,\n        bias=True,\n    ):\n        super().__init__(in_features=input_size, out_features=output_size, bias=bias)\n        self.sequence_parallel = config.sequence_parallel\n        if self.sequence_parallel:\n            self.weight.sequence_parallel = True\n\n    def forward(\n        self,\n        input_,\n        weight=None,\n        runtime_gather_output=None,\n    ):\n        logits = super().forward(input_)\n        logits = logits.float()\n        if self.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits, None\n"
  },
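  {
    "path": "verl_distillation/examples/sketches/fused_qkv_layout.py",
    "content": "\"\"\"Illustrative sketch, not part of the original verl source tree: the fused QKV weight layout that QKVParallelLinear above builds on top of ColumnParallelLinear, reduced to a plain nn.Linear on one device. It shows why output_size = (num_heads + 2 * num_key_value_heads) * head_dim. Under tensor parallelism the rows are additionally grouped per query group and sharded across ranks (see _broadcast_tp_shard_tensor_qkv in mcore/loader.py). All names here are hypothetical.\"\"\"\n\nimport torch\nfrom torch import nn\n\nnum_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 16, 128\nqkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)\n\nx = torch.randn(4, hidden_size)\n# one fused matmul, then split back into the three projections\nq, k, v = qkv_proj(x).split(\n    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1\n)\nassert q.shape == (4, num_heads * head_dim)\nassert k.shape == v.shape == (4, num_kv_heads * head_dim)\n"
  },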
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelLlamaMLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
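  {
    "path": "verl_distillation/examples/sketches/swiglu_mlp.py",
    "content": "\"\"\"Illustrative sketch, not part of the original verl source tree: the gated SiLU (SwiGLU) MLP computed by ParallelLlamaMLP above, without tensor parallelism. The gate and up projections are fused into one matmul, mirroring MergedColumnParallelLinear, and the result is split before silu(gate) * up. All names here are hypothetical.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass SwiGLUMLP(nn.Module):\n    def __init__(self, hidden_size: int, intermediate_size: int):\n        super().__init__()\n        # one fused projection for gate and up, one down projection back to hidden_size\n        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)\n        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.intermediate_size = intermediate_size\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)\n        return self.down_proj(F.silu(gate) * up)\n\n\nif __name__ == \"__main__\":\n    mlp = SwiGLUMLP(hidden_size=32, intermediate_size=64)\n    assert mlp(torch.randn(2, 5, 32)).shape == (2, 5, 32)\n"
  },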
  {
    "path": "verl_distillation/verl/models/llama/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelLlamaRMSNorm(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
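  {
    "path": "verl_distillation/examples/sketches/rms_norm_reference.py",
    "content": "\"\"\"Illustrative sketch, not part of the original verl source tree: an unfused reference for the apex fused_rms_norm_affine kernel wrapped by ParallelLlamaRMSNorm above, handy for checking numerics on machines without apex. All names here are hypothetical.\"\"\"\n\nimport torch\n\n\ndef rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:\n    # normalize by the root mean square over the last dimension, then apply the affine scale\n    variance = x.float().pow(2).mean(dim=-1, keepdim=True)\n    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight\n\n\nif __name__ == \"__main__\":\n    x = torch.randn(2, 4, 8)\n    out = rms_norm_reference(x, torch.ones(8))\n    assert out.shape == x.shape\n"
  },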
  {
    "path": "verl_distillation/verl/models/llama/megatron/modeling_llama_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch LLaMA model with Megatron-style acceleration.\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import CausalLMOutputWithPast\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\nfrom verl.utils.megatron import tensor_parallel as tp_utils\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from meta LLama pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelLlamaModel(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. 
shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLM(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401, E402\n\n\nclass ParallelLlamaModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLMRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            total_nnz = cu_seqlens[-1]\n            logits = logits[:total_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(\n            logits, indices, batch_size, seqlen=sequence_length\n        )  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence 
parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelLlamaModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - This model only contains the layers of this pp stage and vpp chunk\n    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n                num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n            )\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron\n            # so we need to handle it here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # self.hidden_states should be passed by Megatron\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: LlamaConfig,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights=False,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPadPP(\n            config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process\n        )\n        assert share_embeddings_and_output_weights is False, (\n            \"LlamaForCausalLM does not support sharing embedding and output weights\"\n        )\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input 
from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n\n        # Note that input_ids, attention_mask and position_ids should be passed to every pp stage.\n        # In the first pp stage, input_ids are used; in later pp stages, hidden_states are used inside self.model\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                total_nnz = cu_seqlens[-1]\n                logits = logits[:total_nnz]  # (total_nnz_padded)\n            # add removed padding back. If input is already rmpad, we let the caller pad_input\n            logits = pad_input(\n                logits, indices, batch_size, seqlen=sequence_length\n            )  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n            return output\n        else:\n            return output\n"
  },
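  {
    "path": "verl_distillation/examples/sketches/rmpad_round_trip.py",
    "content": "\"\"\"Illustrative sketch, not part of the original verl source tree: the remove-padding round trip that the RmPad models above perform with flash_attn's unpad_input/pad_input, re-done with plain torch ops on a toy batch so the indices/cu_seqlens bookkeeping is visible. All names here are hypothetical.\"\"\"\n\nimport torch\n\ninput_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])\nattention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])\n\n# unpad: keep only positions where attention_mask == 1 -> (total_nnz,)\nindices = attention_mask.flatten().nonzero(as_tuple=False).squeeze(-1)\npacked = input_ids.flatten()[indices]\n# cumulative sequence lengths as used by flash attention, e.g. [0, 3, 5]\ncu_seqlens = torch.nn.functional.pad(attention_mask.sum(dim=-1).cumsum(dim=0), (1, 0))\n\n# ... the model would run on the packed (total_nnz,) tokens here ...\n\n# pad back: scatter the packed values into a flat buffer and reshape\nout = torch.zeros(input_ids.numel(), dtype=packed.dtype)\nout[indices] = packed\nassert torch.equal(out.view_as(input_ids), input_ids * attention_mask)\nassert cu_seqlens.tolist() == [0, 3, 5]\n"
  },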
  {
    "path": "verl_distillation/verl/models/mcore/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import (\n    get_mcore_forward_fn,\n    get_mcore_forward_fused_fn,\n    get_mcore_forward_no_padding_fn,\n    get_mcore_weight_converter,\n    hf_to_mcore_config,\n    init_mcore_model,\n)\n\n__all__ = [\n    \"hf_to_mcore_config\",\n    \"init_mcore_model\",\n    \"get_mcore_forward_fn\",\n    \"get_mcore_weight_converter\",\n    \"get_mcore_forward_fused_fn\",\n    \"get_mcore_forward_no_padding_fn\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/config_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# convert huggingface config to mcore transformer config\n\n\nimport warnings\nfrom typing import TypeVar\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.transformer import MLATransformerConfig, TransformerConfig\nfrom transformers import PretrainedConfig\n\nT = TypeVar(\"T\", bound=TransformerConfig)\n\n\ndef _get_base_transformer_config(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> dict:\n    \"\"\"\n    Create a base TransformerConfig with common parameters across different model architectures.\n    TODO: (ycl) use dataclass or converter config?\n\n    Args:\n        hf_config: HuggingFace model configuration\n        dtype: Data type for the model\n        override_transformer_config_kwargs: Additional parameters to override defaults\n\n    Returns:\n        TransformerConfig with common parameters\n    \"\"\"\n\n    # Common parallel state parameters\n    overlap_p2p_comm = (\n        mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n        and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    )\n    batch_p2p_comm = False\n\n    # Base configuration with common parameters\n    base_config = {\n        # Model architecture parameters\n        \"num_layers\": hf_config.num_hidden_layers,\n        \"hidden_size\": hf_config.hidden_size,\n        \"num_attention_heads\": hf_config.num_attention_heads,\n        \"num_query_groups\": hf_config.num_key_value_heads,\n        \"ffn_hidden_size\": hf_config.intermediate_size,\n        \"attention_dropout\": hf_config.attention_dropout,\n        \"hidden_dropout\": getattr(hf_config, \"hidden_dropout\", 0.0),\n        \"kv_channels\": getattr(hf_config, \"head_dim\", None),\n        \"layernorm_epsilon\": hf_config.rms_norm_eps,\n        \"add_bias_linear\": True,\n        # Activation and normalization\n        \"activation_func\": F.silu,\n        \"normalization\": \"RMSNorm\",\n        \"gated_linear_unit\": True,\n        # Data types\n        \"pipeline_dtype\": dtype,\n        \"params_dtype\": dtype,\n        \"bf16\": dtype is torch.bfloat16,\n        # Parallel configuration\n        \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n        \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n        \"expert_model_parallel_size\": mpu.get_expert_model_parallel_world_size(),\n        \"expert_tensor_parallel_size\": mpu.get_expert_tensor_parallel_world_size(),\n        \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n        \"context_parallel_size\": mpu.get_context_parallel_world_size(),\n        \"overlap_p2p_comm\": 
overlap_p2p_comm,\n        \"batch_p2p_comm\": batch_p2p_comm,\n        \"sequence_parallel\": mpu.get_tensor_model_parallel_world_size() > 1,\n        # Common settings\n        \"variable_seq_lengths\": True,\n        \"masked_softmax_fusion\": True,\n        \"moe_token_dispatcher_type\": \"alltoall\",\n    }\n\n    # Update with any provided overrides\n    # override_transformer_config_kwargs as kwargs shall never be none\n    base_config.update(override_transformer_config_kwargs)\n\n    return base_config\n\n\ndef _get_mla_transformer_config(\n    hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> dict:\n    \"\"\"\n    Create a MLATransformerConfig with common parameters across different model architectures.\n    This is specifically for MLA models like DeepseekV3.\n\n    Args:\n        hf_config: HuggingFace model configuration\n        mla_rope_config: MLA specific RoPE configuration\n        dtype: Data type for the model\n        override_transformer_config_kwargs: Additional parameters to override defaults\n\n    Returns:\n        MLATransformerConfig with common parameters\n    \"\"\"\n    base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs)\n    mla_config = {\n        # MLA specific parameters\n        \"q_lora_rank\": hf_config.q_lora_rank,\n        \"kv_lora_rank\": hf_config.kv_lora_rank,\n        \"qk_head_dim\": hf_config.qk_nope_head_dim,\n        \"qk_pos_emb_head_dim\": hf_config.qk_rope_head_dim,\n        \"v_head_dim\": hf_config.v_head_dim,\n        \"rotary_base\": hf_config.rope_theta,\n        \"rotary_scaling_factor\": mla_rope_config[\"factor\"],\n        \"rope_type\": mla_rope_config[\"type\"],\n        \"max_position_embeddings\": mla_rope_config[\"original_max_position_embeddings\"],\n        \"beta_fast\": mla_rope_config[\"beta_fast\"],\n        \"beta_slow\": mla_rope_config[\"beta_slow\"],\n        \"mscale\": mla_rope_config[\"mscale\"],\n        \"mscale_all_dim\": mla_rope_config[\"mscale_all_dim\"],\n    }\n\n    base_config.update(mla_config)\n    return base_config\n\n\ndef check_and_construct_configs(original_config: dict, cls: type[T]) -> T:\n    \"\"\"\n    Drop configuration keys that the installed (possibly older) Megatron version does not\n    support, then construct the config object.\n\n    Args:\n        original_config (dict): The original model configuration.\n        cls (type[T]): The TransformerConfig subclass to instantiate.\n\n    Returns:\n        T: An instance of `cls` built from the updated configuration.\n    \"\"\"\n    removed_keys = []\n    for key in original_config.keys():\n        if not hasattr(cls, key):\n            removed_keys.append(key)\n    if removed_keys:\n        warnings.warn(\n            f\"The following keys are not supported in the current Megatron version and will be removed: {removed_keys}\",\n            stacklevel=2,\n        )\n        for key in removed_keys:\n            original_config.pop(key)\n\n    original_config = mapping_string_to_attn_backend(original_config)\n    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:\n        print(f\"Overridden {cls.__name__} init config: {original_config}\")\n    return cls(**original_config)\n\n\ndef hf_to_mcore_config_dense(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # for LlamaForCausalLM or Qwen2ForCausalLM\n    qkv_bias = True if \"Qwen2\" in hf_config.architectures[0] else getattr(hf_config, \"attention_bias\", False)\n    qk_layernorm = 
True if \"Qwen3\" in hf_config.architectures[0] else False\n\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        add_qkv_bias=qkv_bias,\n        qk_layernorm=qk_layernorm,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    return check_and_construct_configs(args, TransformerConfig)\n\n\ndef hf_to_mcore_config_qwen2moe(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_shared_expert_overlap=True,\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=True,\n        add_qkv_bias=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    return check_and_construct_configs(args, TransformerConfig)\n\n\ndef hf_to_mcore_config_mixtral(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        num_moe_experts=hf_config.num_local_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        moe_router_pre_softmax=True,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_router_score_function=\"softmax\",\n        moe_shared_expert_intermediate_size=None,  # mixtral has no shared expert\n        moe_shared_expert_overlap=False,  # mixtral has no shared expert\n        moe_ffn_hidden_size=hf_config.intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        # moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        # Other optimizations\n        persist_layer_norm=True,\n        apply_rope_fusion=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    return check_and_construct_configs(args, TransformerConfig)\n\n\ndef hf_to_mcore_config_qwen3moe(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: 
dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=False,\n        qk_layernorm=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    return check_and_construct_configs(args, TransformerConfig)\n\n\ndef hf_to_mcore_config_dpskv3(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> MLATransformerConfig:\n    # DeepseekV3ForCausalLM\n    from megatron.core.transformer.enums import AttnBackend\n\n    from .patch_v012 import apply_patch\n\n    apply_patch()\n\n    mla_rope_config = {\n        \"beta_fast\": 32,\n        \"beta_slow\": 1,\n        \"factor\": 1,\n        \"mscale\": 1.0,\n        \"mscale_all_dim\": 1.0,\n        \"original_max_position_embeddings\": 4096,\n        \"type\": \"rope\",\n    }\n    if \"rope_scaling\" in hf_config and hf_config.rope_scaling is not None:\n        mla_rope_config.update(hf_config.rope_scaling)\n    moe_layer_freq = [1] * hf_config.num_hidden_layers\n    for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)):\n        moe_layer_freq[i] = 0\n\n    # disable MTP and quantization for now\n    if \"num_nextn_predict_layers\" in hf_config:\n        assert hf_config.num_nextn_predict_layers == 0, (\n            \"MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0\"\n        )\n    assert \"quantization_config\" not in hf_config or not hf_config.quantization_config, (\n        \"quantization is not supported for now, please modify the config.json to remove quantization_config\"\n    )\n\n    args: dict = _get_mla_transformer_config(\n        hf_config=hf_config,\n        mla_rope_config=mla_rope_config,\n        dtype=dtype,\n        # Additional parameters\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        attention_backend=AttnBackend.fused,\n        qk_layernorm=True,\n        # Standard MoE parameters\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_token_dispatcher_type=\"alltoall\",\n        moe_router_bias_update_rate=0.001,\n        moe_router_enable_expert_bias=True,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.n_routed_experts,\n        moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,\n        moe_aux_loss_coeff=getattr(hf_config, \"aux_loss_alpha\", 0.001),\n        moe_router_load_balancing_type=\"seq_aux_loss\",\n        moe_shared_expert_overlap=True,\n        # moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        
moe_router_score_function=\"sigmoid\",\n        moe_router_pre_softmax=True,\n        moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,\n        moe_layer_freq=moe_layer_freq,\n        # mcore 0.12 moe\n        moe_router_dtype=\"fp64\",\n        disable_bf16_reduced_precision_matmul=True,\n        # Other optimizations\n        # deallocate_pipeline_outputs=True,\n        # gradient_accumulation_fusion=True,\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    transformer_config = check_and_construct_configs(args, MLATransformerConfig)\n    # MTP\n    if \"num_nextn_predict_layers\" in hf_config:\n        transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers\n        transformer_config.mtp_loss_scaling_factor = 0.1\n\n    return transformer_config\n\n\ndef hf_to_mcore_config_qwen2_5_vl(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # Qwen2_5_VLForConditionalGeneration\n\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        add_bias_linear=False,\n        # qwen specific\n        add_qkv_bias=True,\n        mrope_section=hf_config.rope_scaling[\"mrope_section\"],\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = mapping_string_to_attn_backend(args)\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_llama4(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # Llama4ForConditionalGeneration\n    raise NotImplementedError(\"Llama4ForConditionalGeneration is not supported yet\")\n\n\ndef mapping_string_to_attn_backend(args: dict) -> dict:\n    if \"attention_backend\" in args and isinstance(args[\"attention_backend\"], str):\n        from megatron.core.transformer.enums import AttnBackend\n\n        args[\"attention_backend\"] = AttnBackend[args[\"attention_backend\"]]\n    return args\n"
  },
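  {
    "path": "verl_distillation/examples/sketches/hf_to_mcore_mapping.py",
    "content": "\"\"\"Illustrative sketch, not part of the original verl source tree: the core HF -> Megatron field mapping performed by _get_base_transformer_config above, shown on a toy LlamaConfig and without the mpu.get_*_world_size() lookups, which require an initialized Megatron parallel state. All names here are hypothetical.\"\"\"\n\nimport torch\nfrom transformers import LlamaConfig\n\nhf_config = LlamaConfig(\n    num_hidden_layers=2,\n    hidden_size=64,\n    num_attention_heads=4,\n    num_key_value_heads=2,\n    intermediate_size=128,\n)\n\n# the same renames that _get_base_transformer_config applies\nbase_config = {\n    \"num_layers\": hf_config.num_hidden_layers,\n    \"hidden_size\": hf_config.hidden_size,\n    \"num_attention_heads\": hf_config.num_attention_heads,\n    \"num_query_groups\": hf_config.num_key_value_heads,  # GQA groups\n    \"ffn_hidden_size\": hf_config.intermediate_size,\n    \"layernorm_epsilon\": hf_config.rms_norm_eps,\n    \"params_dtype\": torch.bfloat16,\n    \"bf16\": True,\n}\nprint(base_config)\n"
  },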
  {
    "path": "verl_distillation/verl/models/mcore/loader.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\nfrom .saver import _megatron_calc_global_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == src_rank:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        
assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.decoder.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == src_rank:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=src_rank, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor 
{name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : 
intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    num_query_groups_per_partition = models[0].config.num_query_groups // tp_size\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)\n                    k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)\n                   
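 # per query group, Megatron's fused QKV layout stores the group's q heads followed by its k and v head\n                   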
 v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)\n                    total_size_per_head = total_size // num_query_groups_per_partition\n                    for j in range(num_query_groups_per_partition):\n                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(\n                            torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)\n                        )\n\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)\n                    k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)\n                    v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)\n                    total_size_per_head = total_size // config.num_attention_heads\n                    for j in range(config.num_attention_heads):\n                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(\n                            torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)\n                        )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n 
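   # dp_rank 0 first receives the full weights over the model-parallel group; the\n    # barrier and broadcast_params below then replicate them to the other dp ranks\n 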
   if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            layer_name = f\"model.layers.{layer}\"\n            print_rank_0(f\"loading layer #{layer}, with layer_name model.layers.{layer}...\")\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            if f\"{layer_name}.self_attn.q_norm.weight\" in state_dict:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n            if f\"{layer_name}.self_attn.q_proj.bias\" in state_dict:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    bias=True,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading 
final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.output_layer.weight\n\n        if is_value_model:\n            # if torch.distributed.get_rank() == src_rank:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            elif \"score.weight\" in state_dict and state_dict[\"score.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"score.weight\")\n                print_rank_0(\"load lm_head from score weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n            # else:\n\n            #     _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n    pass\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/mbridge.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from mbridge import AutoBridge\n    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model\nexcept ImportError:\n    print(\"mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`\")\n    raise\n\n__all__ = [\"AutoBridge\", \"make_value_model\", \"freeze_moe_router\"]\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/model_forward.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom verl.utils.megatron_utils import unwrap_model\n\nfrom .util import (\n    postprocess_packed_seqs,\n    postprocess_packed_seqs_no_padding,\n    preprocess_packed_seqs,\n    preprocess_packed_seqs_no_padding,\n)\n\n\ndef model_forward_gen(vision_model: bool = False):\n    def model_forward(\n        model,\n        input_ids,\n        attention_mask,\n        position_ids,\n        multi_modal_inputs: dict,\n        logits_processor=None,\n        logits_processor_args: dict = None,\n        value_model=False,\n    ):\n        \"\"\"Forward pass for models with sequence packing.\"\"\"\n        pre_process = (\n            unwrap_model(model).pre_process if not vision_model else False\n        )  # vision model does not need pre_process, because we pack the input_ids to thd in the forward function\n        post_process = unwrap_model(model).post_process\n\n        model_kwargs = {}\n        if \"pixel_values\" in multi_modal_inputs:\n            model_kwargs[\"pixel_values\"] = multi_modal_inputs[\"pixel_values\"].to(input_ids.device)\n        if \"image_grid_thw\" in multi_modal_inputs:\n            model_kwargs[\"image_grid_thw\"] = multi_modal_inputs[\"image_grid_thw\"].to(input_ids.device)\n        if \"pixel_values_videos\" in multi_modal_inputs:\n            model_kwargs[\"pixel_values_videos\"] = multi_modal_inputs[\"pixel_values_videos\"].to(input_ids.device)\n        if \"video_grid_thw\" in multi_modal_inputs:\n            model_kwargs[\"video_grid_thw\"] = multi_modal_inputs[\"video_grid_thw\"].to(input_ids.device)\n\n        batch_size, seq_len = attention_mask.shape[:2]\n        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n        input_ids_rmpad = input_ids_rmpad.contiguous()\n\n        input_args = dict(\n            input_ids=input_ids_rmpad,\n            attention_mask=None,\n            position_ids=position_ids if not vision_model else None,  # vision models will calculate position_ids\n            packed_seq_params=packed_seq_params,\n            **model_kwargs,\n        )\n\n        if vision_model:\n            # workaround for supporting sequence packing with context parallelism\n            # cp split with sequence packing will make model lose vision token information, so we need to keep\n            # the original input_ids and pack them after vision embedding is calculated,\n            # cooporate with mbridge\n            input_args[\"input_ids\"] = input_ids\n            input_args[\"attention_mask\"] = attention_mask\n\n        output_orig = model(**input_args)\n        if post_process and logits_processor is not None:\n            args = {\n                k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]\n           
     for k, v in logits_processor_args.items()\n            }\n            output_dict = logits_processor(output_orig, **args)\n            output = {\n                k: postprocess_packed_seqs(\n                    v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n                )\n                for k, v in output_dict.items()\n            }\n        else:\n            output = postprocess_packed_seqs(\n                output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n            )\n        if value_model and post_process:\n            output = output[..., 0]\n        return output\n\n    return model_forward\n\n\ndef gptmodel_forward_no_padding(\n    model,\n    input_ids,\n    multi_modal_inputs: dict,\n    logits_processor=None,\n    logits_processor_args: dict = None,\n    value_model=False,\n):\n    \"\"\"Default forward pass for GPT models with optional sequence packing.\"\"\"\n    pre_process = unwrap_model(model).pre_process\n    post_process = unwrap_model(model).post_process\n\n    model_kwargs = {}\n    if \"pixel_values\" in multi_modal_inputs:\n        model_kwargs[\"pixel_values\"] = multi_modal_inputs[\"pixel_values\"].to(input_ids.device)\n    if \"image_grid_thw\" in multi_modal_inputs:\n        model_kwargs[\"image_grid_thw\"] = multi_modal_inputs[\"image_grid_thw\"].to(input_ids.device)\n\n    batch_size = input_ids.shape[0]\n    input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)\n    input_ids_rmpad = input_ids_rmpad.contiguous()\n    output_orig = model(\n        input_ids=input_ids_rmpad,\n        attention_mask=None,\n        position_ids=None,\n        packed_seq_params=packed_seq_params,\n        **model_kwargs,\n    )\n\n    if post_process and logits_processor is not None:\n        args = {k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()}\n        output_dict = logits_processor(output_orig, **args)\n        output = {\n            k: postprocess_packed_seqs_no_padding(\n                v, packed_seq_params, input_ids, batch_size, post_process=post_process\n            )\n            for k, v in output_dict.items()\n        }\n    else:\n        output = postprocess_packed_seqs_no_padding(\n            output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process\n        )\n\n    if value_model and post_process:\n        # output = output[..., 0]\n        # while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e.\n        # ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any)\n        # so we use `squeeze` to remove the last dimension\n        output = output.squeeze(-1)\n\n    return output\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/model_forward_1f1b_overlap.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional\n\nimport torch\nfrom megatron.core.models.common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan\nfrom megatron.core.models.gpt.gpt_model import GPTModel\nfrom megatron.core.utils import make_viewless_tensor\nfrom torch import Tensor\n\nfrom verl.models.mcore.util import preprocess_packed_seqs\nfrom verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nfrom verl.utils.megatron_utils import unwrap_model\nfrom verl.utils.model import CausalLMOutputForPPO\n\nfrom .util import postprocess_packed_seqs, postprocess_packed_seqs_for_dict_output\n\n\ndef gptmodel_forward_1f1b_overlap(\n    model: GPTModel,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    labels: Tensor = None,\n    labels_mask: Tensor = None,\n    multi_modal_inputs: Optional[dict] = None,\n    logits_processor: Optional[Callable] = None,\n    logits_processor_args: Optional[dict] = None,\n    temperature: float = 1.0,\n) -> TransformerModelChunkSchedulePlan:\n    pre_process: bool = unwrap_model(model).pre_process\n    post_process: bool = unwrap_model(model).post_process\n    assert logits_processor is None, \"only support fused kernel\"\n    batch_size, seq_len = attention_mask.shape[:2]\n    input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n    input_ids_rmpad = input_ids_rmpad.contiguous()\n\n    schedule_plan = model.build_schedule_plan(\n        input_ids=input_ids_rmpad,\n        attention_mask=attention_mask,\n        labels=labels,\n        position_ids=position_ids,\n        packed_seq_params=packed_seq_params,\n    )\n    if post_process:\n        attention_mask_out = attention_mask\n\n        def _postprocess(\n            self,\n            hidden_states,\n            input_ids,\n            position_ids,\n            labels,\n            rotary_pos_emb,\n            rotary_pos_cos,\n            rotary_pos_sin,\n            mtp_in_postprocess=None,\n            loss_mask=None,\n            decoder_input=None,\n            attention_mask=None,\n            inference_params=None,\n            packed_seq_params=None,\n            sequence_len_offset=None,\n            runtime_gather_output=None,\n            extra_block_kwargs=None,\n            inference_context=None,\n        ):\n            \"\"\"patched from https://github.com/NVIDIA/Megatron-LM/blob/core_r0.14.0/megatron/core/models/gpt/gpt_model.py#L412\"\"\"\n            \"\"\"Postprocesses decoder hidden states to generate logits or compute loss.\n\n            Applies Multi-Token Prediction if enabled, generates output logits through\n            the output layer, and computes language model loss when labels are provided.\n            
\"\"\"\n            from megatron.core import parallel_state\n            from megatron.core.tensor_parallel import gather_from_sequence_parallel_region\n\n            in_inference_mode = inference_context is not None and not self.training\n            if in_inference_mode:\n                assert runtime_gather_output, \"Inference must always gather TP logits\"\n\n            # logits and loss\n            output_weight = None\n            if self.share_embeddings_and_output_weights:\n                output_weight = self.shared_embedding_or_output_weight()\n\n            if mtp_in_postprocess:\n                hidden_states = self.mtp(\n                    input_ids=input_ids,\n                    position_ids=position_ids,\n                    hidden_states=hidden_states,\n                    attention_mask=attention_mask,\n                    inference_params=inference_params,\n                    rotary_pos_emb=rotary_pos_emb,\n                    rotary_pos_cos=rotary_pos_cos,\n                    rotary_pos_sin=rotary_pos_sin,\n                    packed_seq_params=packed_seq_params,\n                    sequence_len_offset=sequence_len_offset,\n                    embedding=self.embedding,\n                    **(extra_block_kwargs or {}),\n                )\n\n            if not self.post_process:\n                return hidden_states\n\n            if self.mtp_process:\n                from megatron.core.transformer.multi_token_prediction import (\n                    MTPLossAutoScaler,\n                    MTPLossLoggingHelper,\n                    roll_tensor,\n                )\n\n                mtp_labels = labels.clone()\n                hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)\n                hidden_states = hidden_states_list[0]\n                if loss_mask is None:\n                    # if loss_mask is not provided, use all ones as loss_mask\n                    loss_mask = torch.ones_like(mtp_labels)\n                for mtp_layer_number in range(self.config.mtp_num_layers):\n                    # output\n                    mtp_logits, _ = self.output_layer(\n                        hidden_states_list[mtp_layer_number + 1],\n                        weight=output_weight,\n                        runtime_gather_output=runtime_gather_output,\n                    )\n                    # Calc loss for the current Multi-Token Prediction (MTP) layers.\n                    mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)\n                    loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)\n                    mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)\n                    mtp_loss = loss_mask * mtp_loss\n                    if self.training:\n                        # TODO(shifangx): remove the use of parallel_state here\n                        # after moving loss logging to loss_func in pretrain_gpt.py\n                        MTPLossLoggingHelper.save_loss_to_tracker(\n                            torch.sum(mtp_loss) / num_tokens,\n                            mtp_layer_number,\n                            self.config.mtp_num_layers,\n                            avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),\n                        )\n                    mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers\n                    if 
self.config.calculate_per_token_loss:\n                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)\n                    else:\n                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)\n\n            if logits_processor is not None:\n                logits, _ = self.output_layer(\n                    hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output\n                )\n                output_orig = logits.transpose(0, 1).contiguous()\n                args = {\n                    k: preprocess_packed_seqs(v, attention_mask_out, pre_process=True)[0]\n                    for k, v in logits_processor_args.items()\n                }\n                output_dict = logits_processor(output_orig, **args)\n                output = {\n                    k: postprocess_packed_seqs(\n                        v, packed_seq_params, attention_mask_out, batch_size, seq_len, post_process=post_process\n                    )\n                    for k, v in output_dict.items()\n                }\n            else:\n                # fused kernel\n\n                labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True)\n                labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True)\n                labels_rmpad = labels_rmpad.contiguous()\n                labels_mask_rmpad = labels_mask_rmpad.contiguous()\n\n                output = CausalLMOutputForPPO(\n                    loss=None,\n                    logits=None,\n                    past_key_values=None,\n                    hidden_states=hidden_states,\n                    attentions=None,\n                )\n                if self.config.sequence_parallel:\n                    hidden_states = gather_from_sequence_parallel_region(hidden_states)\n                logprobs, entropy = linear_cross_entropy(\n                    hidden_states,\n                    self.output_layer.weight,\n                    labels_rmpad,\n                    temperature,\n                    \"none\",\n                    parallel_state.get_tensor_model_parallel_group(),\n                )\n                output.entropy = entropy\n                output.log_probs = logprobs\n\n                output = postprocess_packed_seqs_for_dict_output(\n                    labels_mask_rmpad,\n                    output,\n                    packed_seq_params,\n                    attention_mask,\n                    batch_size,\n                    seq_len,\n                    post_process=post_process,\n                )\n            output_ = [output[\"log_probs\"]]\n            # TODO: 1f1b overlap currently supports only one tensor output\n            # if \"entropy\" in output:\n            #     output_.append(output[\"entropy\"])\n            output_ = tuple(output_)\n            return output_\n\n        def _custom_post_process_node_forward_impl(self, hidden_states):\n            if self.gpt_model.decoder.final_layernorm and not self.gpt_model.mtp_process:\n                hidden_states = self.gpt_model.decoder.final_layernorm(hidden_states)\n                # TENorm produces a \"viewed\" tensor. 
This will result in schedule.py's\n                # deallocate_output_tensor() throwing an error, so a viewless tensor is\n                # created to prevent this.\n                hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)\n\n            # Run GPTModel._postprocess\n            output = self.gpt_model._postprocess(\n                hidden_states=hidden_states,\n                input_ids=self.chunk_state.input_ids,\n                position_ids=self.chunk_state.position_ids,\n                labels=self.chunk_state.labels,\n                decoder_input=self.chunk_state.decoder_input,\n                rotary_pos_emb=self.chunk_state.rotary_pos_emb,\n                rotary_pos_cos=self.chunk_state.rotary_pos_cos,\n                rotary_pos_sin=self.chunk_state.rotary_pos_sin,\n                mtp_in_postprocess=False,\n                loss_mask=self.chunk_state.loss_mask,\n                attention_mask=self.chunk_state.attention_mask,\n                packed_seq_params=self.chunk_state.packed_seq_params,\n                sequence_len_offset=self.chunk_state.sequence_len_offset,\n                runtime_gather_output=self.chunk_state.runtime_gather_output,\n                extra_block_kwargs=self.chunk_state.extra_block_kwargs,\n            )\n            return output\n\n        schedule_plan.post_process.forward_impl = _custom_post_process_node_forward_impl.__get__(\n            schedule_plan.post_process, schedule_plan.post_process.__class__\n        )\n        unwrap_model(model)._postprocess = _postprocess.__get__(unwrap_model(model), unwrap_model(model).__class__)\n\n    return schedule_plan\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/model_forward_fused.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport megatron.core as mcore\nimport torch\nfrom megatron.core import parallel_state\nfrom megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk\nfrom megatron.core.inference.contexts import BaseInferenceContext\nfrom megatron.core.models.gpt.gpt_model import GPTModel\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region\nfrom megatron.core.utils import deprecate_inference_params\nfrom torch import Tensor\n\nfrom verl.models.mcore.util import preprocess_packed_seqs\nfrom verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nfrom verl.utils.megatron_utils import unwrap_model\nfrom verl.utils.model import CausalLMOutputForPPO\n\nfrom .util import postprocess_packed_seqs_for_dict_output\n\n\ndef _get_patching_model(model: torch.nn.Module):\n    model = unwrap_model(model)\n    if isinstance(model, GPTModel):\n        return model\n\n    if not (hasattr(model, \"language_model\") and isinstance(model.language_model, GPTModel)):\n        print(f\"Model {model.__class__.__name__} is not a supported for fused forward\")\n        return None\n\n    return model.language_model\n\n\ndef patch_fused_forward(model: torch.nn.Module):\n    assert mcore.__version__ >= \"0.13.0\", \"Fused forward patching requires mecore >= 0.13.0\"\n    model = _get_patching_model(model)\n    if model is not None:\n        model.forward_backup = model.forward\n        model.forward = _fused_GPTModel_forward.__get__(model, model.__class__)\n\n\ndef unpatch_fused_forward(model: torch.nn.Module):\n    model = _get_patching_model(model)\n    if model is not None:\n        model.forward = model.forward_backup\n\n\ndef fused_forward_model_gen(vision_model: bool = False):\n    def fused_forward_model(\n        model,\n        input_ids: Tensor,\n        position_ids: Tensor,\n        attention_mask: Tensor,\n        labels: Tensor,\n        labels_mask: Tensor,\n        temperature: float,\n        multi_modal_inputs: dict,\n    ):\n        pre_process: bool = (\n            unwrap_model(model).pre_process if not vision_model else False\n        )  # vision model does not need pre_process, because we pack the input_ids to thd in the forward function\n        post_process: bool = unwrap_model(model).post_process\n\n        model_kwargs = {}\n        if \"pixel_values\" in multi_modal_inputs:\n            model_kwargs[\"pixel_values\"] = multi_modal_inputs[\"pixel_values\"].to(input_ids.device)\n        if \"image_grid_thw\" in multi_modal_inputs:\n            model_kwargs[\"image_grid_thw\"] = multi_modal_inputs[\"image_grid_thw\"].to(input_ids.device)\n        if 
\"pixel_values_videos\" in multi_modal_inputs:\n            model_kwargs[\"pixel_values_videos\"] = multi_modal_inputs[\"pixel_values_videos\"].to(input_ids.device)\n        if \"video_grid_thw\" in multi_modal_inputs:\n            model_kwargs[\"video_grid_thw\"] = multi_modal_inputs[\"video_grid_thw\"].to(input_ids.device)\n\n        batch_size, seq_len = attention_mask.shape[:2]\n        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n        input_ids_rmpad = input_ids_rmpad.contiguous()\n        labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True)\n        labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True)\n        labels_rmpad = labels_rmpad.contiguous()\n        labels_mask_rmpad = labels_mask_rmpad.contiguous()\n\n        input_args = dict(\n            input_ids=input_ids_rmpad,\n            attention_mask=None,\n            position_ids=position_ids if not vision_model else None,  # vision models will calculate position_ids\n            packed_seq_params=packed_seq_params,\n            labels=labels_rmpad,\n            temperature=temperature,\n            **model_kwargs,\n        )\n\n        if vision_model:\n            # workaround for supporting sequence packing with context parallelism\n            # cp split with sequence packing will make model lose vision token information, so we need to keep\n            # the original input_ids and pack them after vision embedding is calculated,\n            # cooporate with mbridge\n            input_args[\"input_ids\"] = input_ids\n            input_args[\"attention_mask\"] = attention_mask\n\n        output_orig: CausalLMOutputForPPO = model(**input_args)\n\n        if post_process:\n            # output_orig is in type of CausalLMOutputForPPO\n            output = postprocess_packed_seqs_for_dict_output(\n                labels_mask_rmpad,\n                output_orig,\n                packed_seq_params,\n                attention_mask,\n                batch_size,\n                seq_len,\n                post_process=post_process,\n            )\n        else:\n            output = output_orig\n        return output\n\n    return fused_forward_model\n\n\ndef _fused_GPTModel_forward(\n    model,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    decoder_input: Tensor = None,\n    labels: Tensor = None,\n    inference_context: BaseInferenceContext = None,\n    packed_seq_params: PackedSeqParams = None,\n    extra_block_kwargs: dict = None,\n    runtime_gather_output: Optional[bool] = None,\n    *,\n    inference_params: Optional[BaseInferenceContext] = None,\n    loss_mask: Optional[Tensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> CausalLMOutputForPPO:\n    \"\"\"\n    Patch self._postprocess in forward for GPT models to enable fused kernel support.\n    https://github.com/NVIDIA/Megatron-LM/blob/core_v0.13.0/megatron/core/models/gpt/gpt_model.py\n\n    TODO: Currently we still need to patch `forward` because we need to pass `temperature`\n    explicitly to `self._postprocess` when calling, maybe there can be a better way to handle this?\n    \"\"\"\n\n    inference_context = deprecate_inference_params(inference_context, inference_params)\n\n    preproc_output = model._preprocess(\n        input_ids=input_ids,\n        position_ids=position_ids,\n        decoder_input=decoder_input,\n        inference_context=inference_context,\n        
packed_seq_params=packed_seq_params,\n    )\n\n    (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = preproc_output[:5]\n\n    # Run decoder.\n    hidden_states = model.decoder(\n        hidden_states=decoder_input,\n        attention_mask=attention_mask,\n        inference_context=inference_context,\n        rotary_pos_emb=rotary_pos_emb,\n        rotary_pos_cos=rotary_pos_cos,\n        rotary_pos_sin=rotary_pos_sin,\n        packed_seq_params=packed_seq_params,\n        sequence_len_offset=sequence_len_offset,\n        **(extra_block_kwargs or {}),\n        **kwargs,\n    )\n\n    if not model.post_process:\n        return hidden_states\n\n    output = CausalLMOutputForPPO(\n        loss=None,\n        logits=None,\n        past_key_values=None,\n        hidden_states=hidden_states,\n        attentions=None,\n    )\n\n    if model.config.sequence_parallel:\n        hidden_states = gather_from_sequence_parallel_region(hidden_states)\n    logprobs, entropy = linear_cross_entropy(\n        hidden_states,\n        model.output_layer.weight,\n        labels,\n        temperature,\n        \"none\",\n        parallel_state.get_tensor_model_parallel_group(),\n    )\n\n    if has_config_logger_enabled(model.config):\n        payload = OrderedDict(\n            {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids,\n                \"attention_mask\": attention_mask,\n                \"decoder_input\": decoder_input,\n                \"logprobs\": logprobs,\n                \"entropy\": entropy,\n            }\n        )\n        log_config_to_disk(model.config, payload, prefix=\"input_and_logits\")\n\n    output.entropy = entropy\n    output.log_probs = logprobs\n\n    return output\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/model_initializer.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# use mcore transformer config to initialize the model\nimport inspect\nfrom abc import ABC, abstractmethod\n\nfrom megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec\nfrom megatron.core.models.gpt.gpt_model import GPTModel\n\nfrom .config_converter import PretrainedConfig, TransformerConfig\n\n\nclass BaseModelInitializer(ABC):\n    \"\"\"Base class for model initializers.\"\"\"\n\n    def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):\n        self.tfconfig = tfconfig\n        self.hf_config = hf_config\n        self.has_vp_stage = inspect.signature(get_gpt_decoder_block_spec).parameters.get(\"vp_stage\", None) is not None\n\n    @abstractmethod\n    def get_transformer_layer_spec(self, vp_stage=None):\n        \"\"\"Get the transformer layer specification.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py\"\"\"\n        pass\n\n    def get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        if \"rope_scaling\" in self.hf_config:\n            if self.hf_config.rope_scaling is not None:\n                # assert self.hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n                rope_scaling_args[\"seq_len_interpolation_factor\"] = self.hf_config.rope_scaling[\"factor\"]\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        pre_process: bool = True,\n        post_process: bool = True,\n        share_embeddings_and_output_weights: bool = False,\n        value: bool = False,\n        **extra_kwargs,\n    ) -> GPTModel:\n        \"\"\"Initialize a GPT model with the given configuration.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py\n\n        Args:\n            pre_process (bool): include embedding layer.\n            post_process (bool): including an output layer.\n            share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.\n            value (bool): add an extra linear layer for classification or regression.\n\n        Returns:\n            GPTModel: An initialized GPT model instance\n        \"\"\"\n        vp_stage = extra_kwargs.get(\"vp_stage\", None)\n        transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)\n        rope_scaling_args = self.get_rope_scaling_args()\n        mtp_block_spec = extra_kwargs.get(\"mtp_block_spec\", None)\n        model = GPTModel(\n            config=self.tfconfig,\n            transformer_layer_spec=transformer_layer_spec,\n            vocab_size=self.hf_config.vocab_size,\n            
max_sequence_length=self.hf_config.max_position_embeddings,\n            pre_process=pre_process,\n            post_process=post_process,\n            share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n            position_embedding_type=\"rope\",\n            rotary_base=self.hf_config.rope_theta,\n            **rope_scaling_args,\n            mtp_block_spec=mtp_block_spec,\n            **({} if not self.has_vp_stage else {\"vp_stage\": vp_stage}),\n        )\n\n        if post_process and value:\n            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n            model.output_layer = LinearForLastLayer(\n                input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig\n            )\n\n        return model\n\n\nclass DenseModel(BaseModelInitializer):\n    \"\"\"Initializer for dense models like Llama and Qwen2.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n\n\nclass Qwen2MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n\n        # Patch layer spec for shared experts\n        for i in range(len(transformer_layer_spec.layer_specs)):\n            transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params[\"gate\"] = True\n\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass MixtralModel(BaseModelInitializer):\n    \"\"\"Initializer for Mixtral models.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", False)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen3MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen3 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        transformer_layer_spec = 
get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass DeepseekV3Model(BaseModelInitializer):\n    \"\"\"Initializer for DeepseekV3 models.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n        return transformer_layer_spec\n\n    def get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        **kwargs,\n    ):\n        vp_stage = kwargs.get(\"vp_stage\", None)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            self.tfconfig.moe_router_load_balancing_type = \"none\"\n        # MTP\n        if self.tfconfig.mtp_num_layers is not None and self.tfconfig.mtp_num_layers > 0:\n            transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)\n            mtp_block_spec = get_gpt_mtp_block_spec(\n                self.tfconfig, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage\n            )\n            kwargs[\"mtp_block_spec\"] = mtp_block_spec\n\n        model = super().initialize(**kwargs)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                if hasattr(layer.mlp, \"router\"):\n                    layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen25VLModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2.5 VL models.\"\"\"\n\n    def get_transformer_layer_spec(self, vp_stage=None):\n        extra_kwargs = {} if not self.has_vp_stage else {\"vp_stage\": vp_stage}\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)\n        return transformer_layer_spec\n\n    def initialize(\n        self,\n        pre_process=None,\n        post_process=None,\n        share_embeddings_and_output_weights=False,\n        value=False,\n        **extra_kwargs,\n    ):\n        tfconfig = self.tfconfig\n        hf_config = self.hf_config\n        # Qwen2_5_VLForConditionalGeneration\n        from copy import deepcopy\n\n        transformer_layer_spec = self.get_transformer_layer_spec()\n\n        from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear\n        from megatron.core.models.gpt.moe_module_specs import MLPSubmodules\n        from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec\n\n        from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config\n\n        vision_transformer_config = get_vision_model_config(deepcopy(tfconfig))\n        vision_transformer_config.pipeline_model_parallel_size = 1\n        vision_transformer_config.first_pipeline_num_layers = None\n\n        vision_projection_config = get_vision_projection_config(\n            
deepcopy(tfconfig),\n            vision_transformer_config.hidden_size,\n            spatial_merge_size=hf_config.vision_config.spatial_merge_size,\n        )\n        vision_projection_layer_spec = MLPSubmodules(\n            linear_fc1=TEColumnParallelLinear,\n            linear_fc2=TERowParallelLinear,\n        )\n        vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec()\n\n        qwen25_vl_model = Qwen2_5VLModel(\n            language_transformer_config=tfconfig,\n            language_transformer_layer_spec=transformer_layer_spec,\n            language_vocab_size=hf_config.vocab_size,\n            language_max_sequence_length=hf_config.max_position_embeddings,\n            vision_transformer_config=vision_transformer_config,\n            vision_transformer_layer_spec=vision_transformer_layer_spec,\n            vision_projection_config=vision_projection_config,\n            vision_projection_layer_spec=vision_projection_layer_spec,\n            vision_projection_type=\"mlp\",\n            language_rotary_base=hf_config.rope_theta,\n            pre_process=pre_process,\n            post_process=post_process,\n            add_decoder=True,\n            add_encoder=True,\n            parallel_output=True,\n            language_share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        )\n\n        if post_process and value:\n            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n            qwen25_vl_model.language_model.output_layer = LinearForLastLayer(\n                input_size=tfconfig.hidden_size, output_size=1, config=tfconfig\n            )\n\n        return qwen25_vl_model\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/patch_v012.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# there is some bug in mcore 0.12, so we need to patch it\n# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None\n\n\ndef apply_patch():\n    import torch\n    from megatron.core import parallel_state, tensor_parallel\n    from megatron.core.transformer.multi_latent_attention import (\n        MLASelfAttention,\n        apply_rotary_pos_emb,\n        deprecate_inference_params,\n        gather_from_sequence_parallel_region,\n        gather_from_tensor_model_parallel_region,\n        scatter_to_sequence_parallel_region,\n    )\n\n    def patch_get_query_key_value_tensors(\n        self,\n        hidden_states,\n        key_value_states=None,\n        position_ids=None,\n        packed_seq_params=None,\n        inference_context=None,\n        *,\n        inference_params=None,\n    ):\n        \"\"\"\n        Derives `query`, `key` and `value` tensors from `hidden_states`.\n        \"\"\"\n        # s = sequence length, b = batch size, h = hidden size, n = num attention heads\n        # Attention heads [s, b, n*h]\n        assert hidden_states.ndim == 3, f\"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        # =========================================\n        # Prepare RoPE and seqlen related params\n        # =========================================\n        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(\n            inference_context, None, hidden_states, self.config, packed_seq_params\n        )\n\n        # rotary_pos_emb:[s, b, 1, 64]\n        mscale = 1.0\n        if self.config.rope_type == \"rope\":\n            packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\"\n            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)\n        else:\n            rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)\n\n        # =========================================\n        # QKV down projection and layernorm\n        # =========================================\n        if self.config.q_lora_rank is not None:\n            # if linear_q_down_proj is ColumnParallelLinear:\n            #     q_compressed: [s, b, q_lora_rank / TP]\n            # elif linear_q_down_proj is Linear:\n            #     q_compressed: [s / TP, b, q_lora_rank]\n            q_compressed, _ = self.linear_q_down_proj(hidden_states)\n\n            # When output is sharded (ColumnParallelLinear), two things are needed to be\n            # identical to a normal Linear.\n            #   1. Manually gather output to restore output dim q_lora_rank;\n            #   2. 
Scatter sequence back to s / TP if sequence-parallel since it was\n            #      gathered by ColumnParallelLinear.\n            if q_compressed.size(-1) != self.config.q_lora_rank:\n                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)\n                if self.config.sequence_parallel:\n                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)\n\n            q_compressed = self.q_layernorm(q_compressed)\n        else:\n            q_compressed = hidden_states\n\n        # if linear_kv_down_proj is ColumnParallelLinear:\n        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]\n        # elif linear_kv_down_proj is Linear:\n        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n        kv_combined, _ = self.linear_kv_down_proj(hidden_states)\n        if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:\n            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)\n            # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(\n                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1\n            )\n            if self.config.sequence_parallel:\n                # kv_compressed:[s / TP, b, kv_lora_rank]\n                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)\n        else:\n            # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(\n                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1\n            )\n            if parallel_state.get_tensor_model_parallel_world_size() > 1:\n                # k_pos_emb: [s, b, qk_pos_emb_head_dim]\n                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)\n\n        kv_compressed = self.kv_layernorm(kv_compressed)\n\n        # =========================================\n        # QKV up projection and RoPE apply\n        # =========================================\n        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):\n            if self.config.q_lora_rank is not None:\n                q, _ = self.linear_q_up_proj(q_compressed)\n            else:\n                # hidden_states:[s, b, 2048], q: [s, b, n * 192]\n                q, _ = self.linear_q_proj(q_compressed)\n\n            q_len, bsz, _ = q.size()\n\n            # q: [s, b, n, 192]\n            q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)\n\n            # kv: [s, b, 2048]\n            kv, _ = self.linear_kv_up_proj(kv_compressed)\n\n            # kv: [s, b, n, 256]\n            kv = kv.view(\n                q_len,\n                bsz,\n                self.num_attention_heads_per_partition,\n                self.config.qk_head_dim + self.config.v_head_dim,\n            )\n\n            if inference_context is not None:\n                # add offset to the sequence start for inference\n                sequence_start = inference_context.sequence_len_offset\n                sequence_end = sequence_start + q_len\n                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]\n            else:\n                # Shorten rotary_pos_emb to the sequence length when inference_params\n          
      # is not provided. This makes sure we can run forward directly with\n                # any sequence length. During training, the sequence length is always\n                # the full rotary_pos_emb length.\n                rotary_pos_emb = rotary_pos_emb[0:q_len]\n\n            # [s, b, 64] -> [s, b, 1, 64]\n            k_pos_emb = torch.unsqueeze(k_pos_emb, 2)\n\n            # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64]\n            q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)\n\n            # k_no_pe: [s, b, n, 128], value: [s, b, n, 128]\n            k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)\n\n            if packed_seq_params is not None:\n                cu_seqlens_q = packed_seq_params.cu_seqlens_q\n                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv\n                q_pos_emb = q_pos_emb.squeeze(1)\n                k_pos_emb = k_pos_emb.squeeze(1)\n                q_no_pe = q_no_pe.squeeze(1)\n                k_no_pe = k_no_pe.squeeze(1)\n                value = value.squeeze(1)\n            else:\n                cu_seqlens_q = cu_seqlens_kv = None\n\n            # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64]\n            q_pos_emb = apply_rotary_pos_emb(\n                q_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_q,\n                mscale=mscale,\n            )\n            k_pos_emb = apply_rotary_pos_emb(\n                k_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_kv,\n                mscale=mscale,\n            )\n\n            # query: [s, b, n, 192]\n            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)\n            if packed_seq_params is not None:\n                k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n            else:\n                # key: [s, b, n, 192]\n                k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n\n            query = query.contiguous()\n            key = key.contiguous()\n            value = value.contiguous()\n            return query, key, value\n\n        if self.recompute_up_proj:\n            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()\n            query, key, value = self.qkv_up_checkpoint.checkpoint(\n                qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb\n            )\n        else:\n            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)\n\n        return query, key, value\n\n    MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .model import Qwen2_5VLModel\nfrom .vision_config import get_vision_model_config, get_vision_projection_config\n\n__all__ = [\"Qwen2_5VLModel\", \"get_vision_model_config\", \"get_vision_projection_config\"]\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/attention.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core.transformer.attention import *\n\nfrom .rope_utils import apply_rotary_pos_emb_absolute\n\n\nclass Qwen2_5VLSelfAttention(SelfAttention):\n    \"\"\"\n    Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute\n    instead of apply_rotary_pos_emb\n    \"\"\"\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        attention_mask: Tensor,\n        key_value_states: Optional[Tensor] = None,\n        inference_context: Optional[BaseInferenceContext] = None,\n        rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,\n        rotary_pos_cos: Optional[Tensor] = None,\n        rotary_pos_sin: Optional[Tensor] = None,\n        attention_bias: Optional[Tensor] = None,\n        packed_seq_params: Optional[PackedSeqParams] = None,\n        sequence_len_offset: Optional[int] = None,\n        *,\n        inference_params: Optional[BaseInferenceContext] = None,\n    ) -> Tuple[Tensor, Tensor]:\n        \"\"\"\n        Perform a forward pass through the attention module.\n\n        Args:\n            hidden_states (Tensor): Hidden states.\n            attention_mask (Tensor): Attention mask.\n            key_value_states (Optional[Tensor]): Key/value states (for cross attention).\n            inference_context (Optional[BaseInferenceContext]): Inference context that manages\n                KV cache.\n            rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary\n                embedding tensor(s).\n            rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.\n            rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.\n            attention_bias (Optional[Tensor]): Attention bias.\n            packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format.\n            sequence_len_offset (Optional[int]): Sequence length offset used for\n                inference CUDA graphs.\n\n        Return:\n            (Tuple[Tensor, Tensor]) Attention output and bias.\n\n        \"\"\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        if inference_context and inference_context.is_dynamic_batching():\n            assert flash_decode_and_prefill_kernel is not None, (\n                \"Internal use only: install package `nvidia_chunked_flash_attn`.\"\n            )\n\n        # hidden_states: [sq, b, h]\n        if self.config.flash_decode and not self.training and inference_context is not None:\n            rotary_pos_emb = None\n        else:\n            assert rotary_pos_cos is None and rotary_pos_sin is None\n\n        # For self attention we just duplicate the rotary_pos_emb if it isn't already\n        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, 
tuple):\n            rotary_pos_emb = (rotary_pos_emb,) * 2\n\n        # =====================\n        # Query, Key, and Value\n        # =====================\n        # Get the query, key and value tensors based on the type of attention -\n        # self or cross attn.\n        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)\n\n        # ===================================================\n        # Adjust key, value, and rotary_pos_emb for inference\n        # ===================================================\n\n        # This branch only runs in the decode phase of flash decoding and returns after the linear\n        # projection. This conditional is not used in the prefill phase or non-flash-decoding cases.\n        if (\n            self.config.flash_decode\n            and inference_context is not None\n            and inference_context.is_decode_only()\n            and not self.training\n            and rotary_pos_cos is not None\n        ):\n            assert self.layer_number in inference_context.key_value_memory_dict\n            assert inference_context.sequence_len_offset is not None\n            inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number]\n            output = self.flash_decode(\n                sequence_len_offset=sequence_len_offset,\n                query_layer=query,\n                key_layer=key,\n                value_layer=value,\n                inference_key_memory=inference_key_memory,\n                inference_value_memory=inference_value_memory,\n                rotary_cos=rotary_pos_cos,\n                rotary_sin=rotary_pos_sin,\n            )\n            out = output.transpose(0, 1).contiguous()\n            context_layer = out.view(out.size(0), out.size(1), -1)\n            output, bias = self.linear_proj(context_layer)\n            return output, bias\n\n        # Use latest mcore 0.13 API and forward-compatible with previous versions.\n        outputs = self._adjust_key_value_for_inference(\n            inference_context,\n            query,\n            key,\n            value,\n            rotary_pos_emb,\n            rotary_pos_cos,\n            rotary_pos_sin,\n            sequence_len_offset,\n        )\n\n        query, key, value, rotary_pos_emb, attn_mask_type = outputs[:5]\n\n        if packed_seq_params is not None:\n            query = query.squeeze(1)\n            key = key.squeeze(1)\n            value = value.squeeze(1)\n\n        # ================================================\n        # relative positional embedding (rotary embedding)\n        # ================================================\n        if rotary_pos_emb is not None and not self.config.flash_decode:\n            q_pos_emb, k_pos_emb = rotary_pos_emb\n\n            if packed_seq_params is not None:\n                if packed_seq_params.cu_seqlens_q_padded is not None:\n                    cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded\n                else:\n                    cu_seqlens_q = packed_seq_params.cu_seqlens_q\n                if packed_seq_params.cu_seqlens_kv_padded is not None:\n                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded\n                else:\n                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv\n            else:\n                cu_seqlens_q = cu_seqlens_kv = None\n\n            if q_pos_emb is not None:\n                # TODO VIJAY: simplify\n                if inference_context is None or 
inference_context.is_static_batching():\n                    query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)\n                else:\n                    query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q)\n            if k_pos_emb is not None:\n                key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)\n\n            # TODO, can apply positional embedding to value_layer so it has\n            # absolute positional embedding.\n            # otherwise, only relative positional embedding takes effect\n            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)\n\n        # ==================================\n        # core attention computation\n        # ==================================\n\n        if self.checkpoint_core_attention and self.training:\n            core_attn_out = self._checkpointed_attention_forward(\n                query,\n                key,\n                value,\n                attention_mask,\n                attn_mask_type=attn_mask_type,\n                attention_bias=attention_bias,\n                packed_seq_params=packed_seq_params,\n            )\n        else:\n            if inference_context is None or inference_context.is_static_batching():\n                # Static batching attention kernel.\n                core_attn_out = self.core_attention(\n                    query,\n                    key,\n                    value,\n                    attention_mask,\n                    attn_mask_type=attn_mask_type,\n                    attention_bias=attention_bias,\n                    packed_seq_params=packed_seq_params,\n                )\n\n            else:\n                # Dynamic batching attention kernel.\n                q, k, v = (query, key, value)\n                cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths()\n                cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths()\n\n                core_attn_out = self.flash_decode_and_prefill(\n                    q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths\n                )\n                core_attn_out = core_attn_out.squeeze(0).unsqueeze(1)\n                core_attn_out = rearrange(core_attn_out, \"s b h d -> s b (h d)\")\n\n        if packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\":\n            # reshape to same output shape as unpacked case\n            # (t, np, hn) -> (t, b=1, h=np*hn)\n            # t is the pack size = sum (sq_i)\n            # note that batch is a dummy dimension in the packed case\n            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)\n\n        # =================\n        # Output. [sq, b, h]\n        # =================\n\n        output, bias = self.linear_proj(core_attn_out)\n\n        return output, bias\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\n\nimport torch\nfrom megatron.core import InferenceParams, mpu, tensor_parallel\nfrom megatron.core.models.gpt.gpt_model import GPTModel\n\n# from .transformer_config import Qwen2VLTransformerConfig\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.transformer import MegatronModule\nfrom megatron.core.transformer.spec_utils import ModuleSpec\nfrom megatron.core.transformer.transformer_config import TransformerConfig\n\nfrom verl.models.mcore.util import preprocess_packed_seqs\n\nfrom .attention import Qwen2_5VLSelfAttention\nfrom .vision_model import Qwen2_5VisionModel\n\n\n# Note: This is under development and may be missing features.\nclass Qwen2_5VLModel(MegatronModule):\n    \"\"\"Qwen2.5VL multi-modal model.\n\n    Args:\n        language_transformer_config (TransformerConfig): Transformer config for the language model.\n        language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the\n            language model.\n        language_vocab_size (int): Language model vocabulary size.\n        language_max_sequence_length (int): Language model maximum sequence length. This is used for\n            positional embedding.\n        vision_transformer_config (TransformerConfig): Transformer config for the vision model.\n        vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the\n            vision model.\n        vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to\n            language model inputs.\n        vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision\n            projection.\n        vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP.\n        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This\n            is typically True for training and False for inference.\n        language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings\n            in the language model. Defaults to 1.0.\n        pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism).\n            Defaults to True.\n        post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline\n            parallelism). Defaults to True.\n        add_encoder (bool): Construct the encoder module (used with pipeline parallelism). 
Defaults to True.\n            When we use pipelining, the encoder\n            will live on only a subset of the pipeline stages (specifically, only the first stage).\n        add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True.\n            When we use pipelining, the decoder\n            will live on only a subset of the pipeline stages (specifically, every stage after the first one).\n        language_rotary_base (int): Rotary base for the language model's rotary embeddings. Defaults to 10000.\n        fp16_lm_cross_entropy (bool): Compute the language-model cross entropy in fp16. Defaults to False.\n        language_share_embeddings_and_output_weights (bool): Tie the language model's input embeddings and\n            output weights. Defaults to False.\n        image_token_id (int): Placeholder token id for image patches. Defaults to 151655.\n        video_token_id (int): Placeholder token id for video patches. Defaults to 151656.\n    \"\"\"\n\n    def __init__(\n        self,\n        language_transformer_config: TransformerConfig,\n        language_transformer_layer_spec: ModuleSpec,\n        language_vocab_size: int,\n        language_max_sequence_length: int,\n        vision_transformer_config: TransformerConfig,\n        vision_transformer_layer_spec: ModuleSpec,\n        vision_projection_config: TransformerConfig,\n        vision_projection_layer_spec: ModuleSpec,\n        vision_projection_type: str = \"mlp\",\n        parallel_output: bool = True,\n        language_rotary_percent: float = 1.0,\n        pre_process: bool = True,\n        post_process: bool = True,\n        add_encoder: bool = True,\n        add_decoder: bool = True,\n        language_rotary_base: int = 10000,\n        fp16_lm_cross_entropy: bool = False,\n        language_share_embeddings_and_output_weights: bool = False,\n        image_token_id: int = 151655,\n        video_token_id: int = 151656,\n    ) -> None:\n        super().__init__(config=language_transformer_config)\n\n        # patch self_attention to use qwen2_5_vl attention\n        vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention\n        for layer_spec in language_transformer_layer_spec.layer_specs:\n            layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention\n\n        logging.getLogger(__name__).warning(\"Qwen2VL model is under development and may be missing features.\")\n\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.add_encoder = add_encoder\n        self.add_decoder = add_decoder\n\n        self.encoder_hidden_state = None\n        self.vision_model = None\n        self.vision_projection = None\n        self.language_model = None\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n\n        self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size\n\n        # This attribute is needed to check if an all-reduce is required\n        # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.\n        self.share_embeddings_and_output_weights = False\n        if self.pre_process:\n            self.vision_model = Qwen2_5VisionModel(\n                vision_transformer_config,\n                vision_transformer_layer_spec,\n                vision_projection_config,\n                vision_projection_layer_spec,\n                projection_type=vision_projection_type,\n                pre_process=True,\n                post_process=True,\n            )\n\n        self.language_model = GPTModel(\n            config=language_transformer_config,\n            
transformer_layer_spec=language_transformer_layer_spec,\n            vocab_size=language_vocab_size,\n            max_sequence_length=language_max_sequence_length,\n            parallel_output=parallel_output,\n            position_embedding_type=\"mrope\",\n            rotary_percent=language_rotary_percent,\n            pre_process=self.pre_process,\n            post_process=self.post_process,\n            rotary_base=language_rotary_base,\n            fp16_lm_cross_entropy=fp16_lm_cross_entropy,\n            share_embeddings_and_output_weights=language_share_embeddings_and_output_weights,\n            scatter_embedding_sequence_parallel=False,\n        )\n        assert mpu.get_context_parallel_world_size() <= 1, \"please use mbridge for qwen2_5_vl with context parallelism\"\n        self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights\n\n    def shared_embedding_or_output_weight(self):\n        \"\"\"This is a convenience method to surface the language model's word embeddings, which is\n        necessary for `finalize_model_grads._allreduce_word_embedding_grads`.\"\"\"\n        if self.add_decoder:\n            return self.language_model.shared_embedding_or_output_weight()\n        return None\n\n    def set_input_tensor(self, input_tensor) -> None:\n        # This is usually handled in schedules.py but some inference code still\n        # gives us non-lists or None\n        if not isinstance(input_tensor, list):\n            input_tensor = [input_tensor]\n        assert len(input_tensor) == 1, \"input_tensor should only be length 1 for Qwen2VL\"\n\n        if self.pre_process:\n            self.encoder_hidden_state = input_tensor[0]\n        else:\n            self.language_model.set_input_tensor(input_tensor[0])\n\n    def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):\n        \"\"\"Freeze model modules.\n\n        Make specific modules non-trainable by setting requires_grad to False for the module's parameters.\n\n        Args:\n            freeze_language_model (bool): Freeze the language model module.\n            freeze_vision_model (bool): Freeze the vision model module.\n            freeze_vision_projection (bool): Freeze the vision projection module.\n        \"\"\"\n        modules = []\n        if freeze_language_model and self.language_model is not None:\n            modules.append(self.language_model)\n        if freeze_vision_model and self.vision_model is not None:\n            modules.append(self.vision_model)\n        if freeze_vision_projection and self.vision_projection is not None:\n            modules.append(self.vision_projection)\n\n        for module in modules:\n            for param in module.parameters():\n                param.requires_grad = False\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        attention_mask: torch.Tensor = None,\n        labels: torch.Tensor = None,\n        inference_params: InferenceParams = None,\n        packed_seq_params: PackedSeqParams = None,\n        extra_block_kwargs: dict = None,\n        pixel_values: torch.Tensor = None,\n        pixel_values_videos: torch.Tensor = None,\n        image_grid_thw: torch.Tensor = None,\n        video_grid_thw: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Forward function of the Qwen2VL model.\n        ### there is a workaround for supporting sequence packing with context parallelism\n        # 
cp split with sequence packing will make the model lose vision token information, so we need to keep\n        # the original input_ids and pack them after the vision embedding is calculated,\n        # cooperating with verl's models/mcore/model_forward.py.\n        # pack the combined_embeddings to thd here; we check if packed_seq_params is None to determine if\n        #  we need to pack the combined_embeddings to thd\n        # this function needs the position_ids and attention_mask in BSHD format, whether packed_seq is used or not\n\n        Args:\n            image_data (torch.Tensor): input image of shape [total_thw_size, n_features].\n            input_ids (torch.Tensor): input text ids [batch, text_seq_len].\n            position_ids (torch.Tensor): input text position ids [batch, text_seq_len].\n            attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len,\n                combined_seq_len].\n            labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].\n            inference_params (InferenceParams): Inference-time parameters including KV cache.\n\n            video_start_index:\n                0 -- all video\n                len(video_seq) -- all image\n                others -- mixture\n            *_input_mask: should not be None in the first PP stage\n        Returns:\n            output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape\n                [b, s, vocab_size].\n        \"\"\"\n        video_start_index = 0\n        vision_grid_thw = None\n        vision_data = None\n        if image_grid_thw is not None:\n            image_mask = input_ids == self.image_token_id\n            vision_grid_thw = image_grid_thw\n            vision_data = pixel_values\n            video_start_index = image_mask.sum().item()\n        if video_grid_thw is not None:\n            video_mask = input_ids == self.video_token_id\n            if vision_grid_thw is not None:\n                vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0)\n                vision_data = torch.cat([vision_data, pixel_values_videos], dim=0)\n            else:\n                vision_grid_thw = video_grid_thw\n                vision_data = pixel_values_videos\n        use_inference_kv_cache = (\n            inference_params is not None and \"image_tokens_count\" in inference_params.key_value_memory_dict\n        )\n        if use_inference_kv_cache:\n            raise NotImplementedError()\n\n        if self.pre_process:\n            vision_embeds = None\n            if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0:\n                vision_embeds = self.vision_model(\n                    vision_data=vision_data,  # If None, vision model should use intermediate outputs (EPP > 1)\n                    grid_thw=vision_grid_thw,  # should be provided in each EPP stage\n                )\n\n            # If running inference, the language model KV cache will be updated for image token positions.\n            # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later.\n            if inference_params is not None:\n                raise NotImplementedError()\n                # inference_params.key_value_memory_dict[\"image_tokens_count\"] = (\n                #     vision_embeddings.shape[0]\n                # )\n\n            # If running inference, we can skip image token computation if they were computed already earlier\n            # for this 
sample.\n            if use_inference_kv_cache:\n                language_embeddings: torch.Tensor = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n                # NOTE: why not cat the vision embeddings here? Are the combined embeddings unused in this branch?\n                combined_embeddings = language_embeddings\n            elif vision_embeds is not None:\n                if video_start_index == 0:\n                    image_embeds = None\n                    video_embeds = vision_embeds\n                elif video_start_index == vision_embeds.shape[0]:\n                    image_embeds = vision_embeds\n                    video_embeds = None\n                elif 0 < video_start_index < vision_embeds.shape[0]:\n                    image_embeds = vision_embeds[:video_start_index]\n                    video_embeds = vision_embeds[video_start_index:]\n                else:\n                    raise ValueError(\n                        f\"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got \"\n                        f\"{video_start_index}\"\n                    )\n\n                combined_embeddings = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n\n                if image_embeds is not None or video_embeds is not None:\n                    combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()\n                    if image_embeds is not None:\n                        image_mask = (input_ids == self.image_token_id).contiguous()\n                        if image_mask.sum() > 0:\n                            combined_embeddings = combined_embeddings.clone()\n                            combined_embeddings[image_mask] = image_embeds.to(\n                                dtype=combined_embeddings.dtype, device=combined_embeddings.device\n                            )\n                    if video_embeds is not None:\n                        video_mask = (input_ids == self.video_token_id).contiguous()\n                        if video_mask.sum() > 0:\n                            combined_embeddings = combined_embeddings.clone()\n                            combined_embeddings[video_mask] = video_embeds.to(\n                                dtype=combined_embeddings.dtype, device=combined_embeddings.device\n                            )\n                    combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()\n\n            else:\n                combined_embeddings = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n\n            if packed_seq_params is not None:\n                combined_embeddings = (\n                    preprocess_packed_seqs(\n                        combined_embeddings.transpose(0, 1).contiguous(), attention_mask, pre_process=True\n                    )[0]\n                    .transpose(0, 1)\n                    .contiguous()\n                )\n            if self.config.sequence_parallel:\n                combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)\n                combined_embeddings = combined_embeddings.contiguous()\n        else:\n            combined_embeddings = None\n        from 
.rope_utils import get_rope_index\n\n        # BSHD\n        position_ids, _ = get_rope_index(\n            input_ids,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            attention_mask=attention_mask,\n        )\n        # THD\n        if packed_seq_params is not None:\n            position_ids = (\n                preprocess_packed_seqs(position_ids.permute(1, 2, 0), attention_mask, pre_process=True)[0]\n                .permute(2, 0, 1)\n                .contiguous()\n            )\n            attention_mask = None\n\n        output = self.language_model(\n            input_ids=None,\n            position_ids=position_ids,  # None in encoder\n            attention_mask=attention_mask,  # None in encoder\n            decoder_input=combined_embeddings,  # only not None in the first decoder PP stage\n            labels=labels,  # only not None in the last decoder PP stage\n            # inference_params=inference_params,  # currently always None\n            packed_seq_params=packed_seq_params,  # None in BSHD mode, set when sequence packing (THD) is used\n            **(extra_block_kwargs or {}),\n            **kwargs,\n        )\n\n        return output\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/rope_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import Optional\n\nimport torch\nfrom megatron.core.models.common.embeddings.rope_utils import *\nfrom megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd\nfrom torch import Tensor\n\nlogger = logging.getLogger(__name__)\n\n\n# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index\ndef get_rope_index(\n    input_ids: Optional[torch.LongTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n):\n    \"\"\"\n    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.\n\n    Explanation:\n\n        Each embedding sequence contains vision embedding and text embedding or just contains text embedding.\n\n        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.\n\n        Examples:\n\n            input_ids: [T T T T T], here T is for text.\n            temporal position_ids: [0, 1, 2, 3, 4]\n            height position_ids: [0, 1, 2, 3, 4]\n            width position_ids: [0, 1, 2, 3, 4]\n\n        For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part\n        and 1D rotary position embedding for text part.\n\n        Examples:\n\n            Temporal (Time): 3 patches, representing different segments of the video in time.\n            Height: 2 patches, dividing each frame vertically.\n            Width: 2 patches, dividing each frame horizontally.\n            We also have some important parameters:\n            fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each\n            second.\n            tokens_per_second: This is a crucial parameter. It dictates how many \"time-steps\" or \"temporal\n                               tokens\" are conceptually packed into a one-second interval of the video.\n                               In this case, we have 25 tokens per second. So each second of the video will be\n                               represented with 25 separate time points. It essentially defines the temporal\n                               granularity.\n            temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.\n            interval: The step size for the temporal position IDs, calculated as tokens_per_second *\n            temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. 
This means that each temporal patch will\n            have a difference of 50 in the temporal position IDs.\n            input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.\n            vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]\n            vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]\n            vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]\n            text temporal position_ids: [101, 102, 103, 104, 105]\n            text height position_ids: [101, 102, 103, 104, 105]\n            text width position_ids: [101, 102, 103, 104, 105]\n            Here we calculate the text start position_ids as the max vision position_ids plus 1.\n\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n            The temporal, height and width of feature shape of each image in LLM.\n        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n            The temporal, height and width of feature shape of each video in LLM.\n        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):\n            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n    Returns:\n        position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)\n        mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)\n    \"\"\"\n    spatial_merge_size = 2\n    tokens_per_second = 2\n    image_token_id = 151655\n    video_token_id = 151656\n    vision_start_token_id = 151652\n    mrope_position_deltas = []\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        total_input_ids = input_ids\n        if attention_mask is None:\n            attention_mask = torch.ones_like(total_input_ids)\n        position_ids = torch.ones(\n            3,\n            input_ids.shape[0],\n            input_ids.shape[1],\n            dtype=input_ids.dtype,\n            device=input_ids.device,\n        )\n        image_index, video_index = 0, 0\n        attention_mask = attention_mask.to(total_input_ids.device)\n        for i, input_ids in enumerate(total_input_ids):\n            input_ids = input_ids[attention_mask[i] == 1]\n            image_nums, video_nums = 0, 0\n            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)\n            vision_tokens = input_ids[vision_start_indices + 1]\n            image_nums = (vision_tokens == image_token_id).sum()\n            video_nums = (vision_tokens == video_token_id).sum()\n            input_tokens = input_ids.tolist()\n            llm_pos_ids_list: list = []\n            st = 0\n            remain_images, remain_videos = image_nums, video_nums\n            for _ in range(image_nums + video_nums):\n                if image_token_id in input_tokens and remain_images > 0:\n                    ed_image = input_tokens.index(image_token_id, st)\n              
  else:\n                    ed_image = len(input_tokens) + 1\n                if video_token_id in input_tokens and remain_videos > 0:\n                    ed_video = input_tokens.index(video_token_id, st)\n                else:\n                    ed_video = len(input_tokens) + 1\n                if ed_image < ed_video:\n                    t, h, w = (\n                        image_grid_thw[image_index][0],\n                        image_grid_thw[image_index][1],\n                        image_grid_thw[image_index][2],\n                    )\n                    second_per_grid_t = 0\n                    image_index += 1\n                    remain_images -= 1\n                    ed = ed_image\n\n                else:\n                    t, h, w = (\n                        video_grid_thw[video_index][0],\n                        video_grid_thw[video_index][1],\n                        video_grid_thw[video_index][2],\n                    )\n                    if second_per_grid_ts is not None:\n                        second_per_grid_t = second_per_grid_ts[video_index]\n                    else:\n                        second_per_grid_t = 1.0\n                    video_index += 1\n                    remain_videos -= 1\n                    ed = ed_video\n                llm_grid_t, llm_grid_h, llm_grid_w = (\n                    t.item(),\n                    h.item() // spatial_merge_size,\n                    w.item() // spatial_merge_size,\n                )\n                text_len = ed - st\n\n                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n                range_tensor = torch.arange(llm_grid_t).view(-1, 1)\n                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)\n\n                time_tensor = expanded_range * second_per_grid_t * tokens_per_second\n\n                time_tensor_long = time_tensor.long()\n                t_index = time_tensor_long.flatten()\n\n                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n                st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n            if st < len(input_tokens):\n                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                text_len = len(input_tokens) - st\n                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)\n            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))\n        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)\n        return position_ids, mrope_position_deltas\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)\n            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, 
keepdim=True)[0]\n            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]\n        else:\n            position_ids = (\n                torch.arange(input_ids.shape[1], device=input_ids.device)\n                .view(1, 1, -1)\n                .expand(3, input_ids.shape[0], -1)\n            )\n            mrope_position_deltas = torch.zeros(\n                [input_ids.shape[0], 1],\n                device=input_ids.device,\n                dtype=input_ids.dtype,\n            )\n\n        return position_ids, mrope_position_deltas\n\n\ndef apply_rotary_pos_emb_thd_absolute(\n    t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False\n) -> Tensor:\n    \"\"\"A baseline implementation of applying RoPE for `thd` format.\n\n    Args:\n        t (Tensor): Input tensor T is of shape [t, h, d]\n        cu_seqlens(Tensor):  Cumulative sum of sequence lengths in a batch for `t`,\n        with shape [b + 1] and dtype torch.int32.\n        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]\n\n    Returns:\n        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.\n    \"\"\"\n    return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1)\n\n\ndef apply_rotary_pos_emb_absolute(\n    t: Tensor,\n    freqs: Tensor,\n    config: TransformerConfig,\n    cu_seqlens: Optional[Tensor] = None,\n):\n    \"\"\"\n    Reroute to the appropriate apply_rotary_pos_emb function depending on\n    bshd (conventional) / thd (packed seq) format\n\n    In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim]\n    \"\"\"\n\n    if config.apply_rope_fusion:\n        if cu_seqlens is None:\n            # NOTE: TE backends do not support mRoPE in bshd format when bs > 1\n            if freqs.shape[1] > 1:\n                return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)\n            else:\n                return fused_apply_rotary_pos_emb(t, freqs)\n        else:\n            # NOTE: as expected, thd format can use bshd\n            return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1)\n    else:\n        if cu_seqlens is None:\n            return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)\n        else:\n            return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved)\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/vision_config.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state\nfrom megatron.core.transformer import TransformerConfig\n\n\ndef get_vision_model_config(config: TransformerConfig) -> TransformerConfig:\n    # Given a Transformer Config from decoder, build vision encoder config\n    # diff: out_hidden_size & intermediate_size\n\n    # mlp: hidden_size -> intermediate_size -> embed_dim, silu\n    # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on\n    if config.num_layers in [28, 36]:\n        config.ffn_hidden_size = 3420\n    else:\n        config.ffn_hidden_size = 3456\n\n    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:\n        config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size()  # depth\n    else:\n        config.num_layers = 32  # depth\n    config.num_attention_heads = 16  # num_heads\n    config.add_bias_linear = True  # all nn.Linear has bias (MLP, attn)\n    config.add_qkv_bias = True  # qkv_proj in attn has bias\n    config.hidden_size = 1280  # hidden_size\n    config.hidden_dropout = 0.0\n    config.attention_dropout = 0.0\n\n    # config.gated_linear_unit = False # no gated\n    # config.activation_func = quick_gelu # hidden_act\n    config.kv_channels = config.hidden_size // config.num_attention_heads\n    config.num_query_groups = config.num_attention_heads  # no GQA\n    config.layernorm_zero_centered_gamma = False  # False\n    config.apply_query_key_layer_scaling = False  # factor=math.sqrt(head_dim)\n    config.bias_activation_fusion = False  # no swiglu, set false\n    config.bias_dropout_fusion = False  # no dropout, set false\n    config.attention_softmax_in_fp32 = True  # use True\n    # config.normalization = 'LayerNorm' # use RMSNorm\n    config.seq_length = 1\n\n    config.tp_comm_overlap = False\n    config.sequence_parallel = False\n    config.temporal_patch_size = 2\n    config.patch_size = 14\n    config.in_channels = 3\n    config.spatial_merge_size = 2\n\n    config.fullatt_block_indexes = [7, 15, 23, 31]\n    config._qwen2_5_vl_window_size = 112\n    return config\n\n\ndef get_vision_projection_config(\n    config: TransformerConfig, embed_dim: int, spatial_merge_size: int\n) -> TransformerConfig:\n    # merger:\n    # context_dim = hidden_size * merge_size**2\n    # out_hidden_size = hidden_size\n    # context_dim -> context_dim -> out_hidden_size\n    # MLP:\n    # input_size -> ffn_hidden_size -> hidden_size\n    # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True)\n    config.gated_linear_unit = False\n    config.bias_activation_fusion = False\n    config.add_bias_linear = True\n    config.ffn_hidden_size = embed_dim * (spatial_merge_size**2)\n    config.activation_func = torch.nn.functional.gelu\n    
config.tp_comm_overlap = False\n    config.sequence_parallel = False\n    return config\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/vision_model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import InferenceParams\nfrom megatron.core.models.common.vision_module.vision_module import VisionModule\nfrom megatron.core.models.vision.multimodal_projector import MultimodalProjector\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.transformer.enums import ModelType\nfrom megatron.core.transformer.spec_utils import ModuleSpec\nfrom megatron.core.transformer.transformer_config import TransformerConfig\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock\n\n\n# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py\nclass PatchEmbed(nn.Module):\n    def __init__(\n        self,\n        patch_size: int = 14,\n        temporal_patch_size: int = 2,\n        in_channels: int = 3,\n        embed_dim: int = 1152,\n    ) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.temporal_patch_size = temporal_patch_size\n        self.in_channels = in_channels\n        self.embed_dim = embed_dim\n\n        kernel_size = [temporal_patch_size, patch_size, patch_size]\n        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(\n            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size\n        )\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\n# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py\nclass VisionRotaryEmbedding(nn.Module):\n    def __init__(self, dim: int, theta: float = 10000.0) -> None:\n        super().__init__()\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs.float()\n\n\nclass Qwen2_5VisionModel(VisionModule):\n    \"\"\"Qwen2.5 ViT vision model.\n\n    Args:\n        transformer_config (TransformerConfig): Transformer config.\n        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.\n        ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.\n        add_class_token (bool, optional): 
Include a class token. Defaults to True.\n        class_token_len (int): Class token length. Defaults to 1.\n        patch_dim (int): Image patch size.\n        img_h (int): Input image height.\n        img_w (int): Input image width.\n    \"\"\"\n\n    def __init__(\n        self,\n        transformer_config: TransformerConfig,\n        transformer_layer_spec: ModuleSpec,\n        projection_config: TransformerConfig,\n        projection_layer_spec: ModuleSpec,\n        projection_type: str = \"mlp\",\n        pre_process: bool = True,\n        post_process: bool = False,\n    ) -> None:\n        super().__init__(config=transformer_config)\n\n        self.spatial_merge_size = transformer_config.spatial_merge_size\n\n        embed_dim = transformer_config.hidden_size\n        num_heads = transformer_config.num_attention_heads\n        temporal_patch_size = transformer_config.temporal_patch_size\n        patch_size = transformer_config.patch_size\n        in_channels = transformer_config.in_channels\n\n        self.patch_size = transformer_config.patch_size\n        self.fullatt_block_indexes = transformer_config.fullatt_block_indexes\n        self.window_size = transformer_config._qwen2_5_vl_window_size\n        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size\n\n        self.max_sequence_length = transformer_config.seq_length\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            temporal_patch_size=temporal_patch_size,\n            in_channels=in_channels,\n            embed_dim=embed_dim,\n        )\n\n        head_dim = embed_dim // num_heads\n        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)\n\n        self.model_type = ModelType.encoder_or_decoder\n        self.pre_process = pre_process\n        self.post_process = post_process\n\n        # Transformer layers.\n        # TODO: Follow-up changes will make pre and post_process configurable. 
They are needed for supporting\n        # pipeline parallelism.\n        # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here.\n        self.decoder = TransformerBlock(\n            config=transformer_config,\n            spec=transformer_layer_spec,\n            pre_process=self.pre_process,\n            post_process=self.post_process,\n            post_layer_norm=True,\n        )\n\n        self.merge_hidden_size = projection_config.ffn_hidden_size\n        self.square_merge_size = self.merge_hidden_size // embed_dim\n\n        if self.post_process:\n            self.projection = MultimodalProjector(\n                projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size\n            )\n        else:\n            self.projection = None\n\n        self.input_tensor = None\n\n    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:\n        \"\"\"Sets input tensor to the model.\n\n        Args:\n            input_tensor (Tensor): Sets the input tensor for the model.\n        \"\"\"\n        if self.pre_process:  # always True\n            self.input_tensor = input_tensor\n        else:\n            raise NotImplementedError()\n\n    def rot_pos_emb(self, grid_thw):\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n        pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        return rotary_pos_emb\n\n    def get_window_index(self, grid_thw):\n        window_index: list = []\n        cu_window_seqlens: list = [0]\n        window_index_id = 0\n        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size\n\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.spatial_merge_size,\n                grid_w // self.spatial_merge_size,\n            )\n            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)\n            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size\n            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size\n            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size\n            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size\n            index_padded = F.pad(index, (0, pad_w, 0, pad_h), \"constant\", -100)\n            index_padded = index_padded.reshape(\n                grid_t,\n          
      num_windows_h,\n                vit_merger_window_size,\n                num_windows_w,\n                vit_merger_window_size,\n            )\n            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(\n                grid_t,\n                num_windows_h * num_windows_w,\n                vit_merger_window_size,\n                vit_merger_window_size,\n            )\n            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)\n            index_padded = index_padded.reshape(-1)\n            index_new = index_padded[index_padded != -100]\n            window_index.append(index_new + window_index_id)\n            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]\n            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())\n            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()\n        window_index = torch.cat(window_index, dim=0)\n\n        return window_index, cu_window_seqlens\n\n    def forward(\n        self,\n        vision_data: Optional[torch.Tensor],\n        grid_thw: torch.Tensor,\n        inference_params: Optional[InferenceParams] = None,\n        extra_block_kwargs: dict = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward function of the Qwen2 Vision Model. This function passes the input tensors\n        through the embedding layer and then the transformer.\n\n        Args:\n            x (torch.Tensor): input image/video data of shape [n_tokens, n_dims]\n            grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame\n            packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend\n\n        Returns:\n            x (torch.Tensor): output after final transformer block of shape [b, s, h].\n        \"\"\"\n        assert grid_thw is not None\n        assert self.input_tensor is None\n        assert inference_params is None\n\n        # Rotary positional embeddings (embedding is None for PP intermediate devices)\n        vision_data = self.patch_embed(vision_data)\n        window_index, cu_window_seqlens = self.get_window_index(grid_thw)\n        cu_window_seqlens = torch.tensor(\n            cu_window_seqlens,\n            device=vision_data.device,\n            dtype=torch.int32,\n        )\n        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)\n\n        seq_len, _ = vision_data.size()\n        vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        vision_data = vision_data[window_index, :, :]\n        vision_data = vision_data.reshape(seq_len, 1, -1)\n\n        rotary_pos_emb = self.rot_pos_emb(grid_thw)\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        rotary_pos_emb = rotary_pos_emb[window_index, :, :]\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2)\n\n        hidden_states = self.decoder(\n            hidden_states=vision_data,\n            attention_mask=None,\n            inference_params=inference_params,\n            rotary_pos_emb=rotary_pos_emb,\n            packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens),\n            packed_seq_params_full=self.build_packed_seq_params(grid_thw),\n            fullatt_block_indexes=self.fullatt_block_indexes,\n            **(extra_block_kwargs or {}),\n        )\n\n        hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size))\n        reverse_indices = 
torch.argsort(window_index)\n        return hidden_states[reverse_indices, :]\n\n    def build_packed_seq_params(\n        self,\n        grid_thw: Optional[torch.Tensor],\n        cu_seqlens: Optional[torch.Tensor] = None,\n    ) -> PackedSeqParams:\n        # NOTE: each frame is a sequence (rather than each grid)\n        if grid_thw is not None:\n            seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])\n            cu_seqlens = seqlens.cumsum(dim=0)\n            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int()\n        else:\n            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n\n        max_seqlen_q = seqlens.max()\n        return PackedSeqParams(\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_kv=cu_seqlens,\n            qkv_format=\"thd\",\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_kv=max_seqlen_q,\n        )\n"
  },
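To make the packed-sequence layout in `build_packed_seq_params` above concrete, here is a minimal, self-contained sketch (not part of the repo) of how the full-attention `cu_seqlens` are derived from `grid_thw`: each temporal frame of each image/video becomes one packed sequence in THD format. The `grid_thw` values below are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical input: one image with 1 frame of 4x6 patches,
# one video with 2 frames of 2x2 patches.
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])

# Tokens per frame = h * w, repeated t times per image/video,
# mirroring the grid_thw branch of build_packed_seq_params.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0), (1, 0), value=0).int()

print(seqlens.tolist())     # [24, 4, 4]
print(cu_seqlens.tolist())  # [0, 24, 28, 32] -> frame boundaries for full attention
```

The windowed variant follows the same idea, except the boundaries come from `get_window_index`, so sequences correspond to local attention windows rather than whole frames.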
  {
    "path": "verl_distillation/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom megatron.core.transformer.transformer_block import *\n\n\nclass Qwen2_5VisionTransformerBlock(TransformerBlock):\n    def _checkpointed_forward(\n        self,\n        hidden_states: Tensor,\n        attention_mask: Tensor,\n        context: Tensor,\n        context_mask: Tensor,\n        rotary_pos_emb: Tensor,\n        attention_bias: Tensor,\n        packed_seq_params: PackedSeqParams,\n        packed_seq_params_full: PackedSeqParams,\n        fullatt_block_indexes,\n    ):\n        \"\"\"Forward method with activation checkpointing.\"\"\"\n\n        def custom(start: int, end: int):\n            def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb):\n                for index in range(start, end):\n                    if index in fullatt_block_indexes:\n                        packed_seq_params_now = packed_seq_params_full\n                    else:\n                        packed_seq_params_now = packed_seq_params\n                    layer = self._get_layer(index)\n                    hidden_states, context = layer(\n                        hidden_states=hidden_states,\n                        attention_mask=attention_mask,\n                        context=context,\n                        context_mask=context_mask,\n                        rotary_pos_emb=rotary_pos_emb,\n                        attention_bias=attention_bias,\n                        inference_context=None,\n                        packed_seq_params=packed_seq_params_now,\n                    )\n                return hidden_states, context\n\n            return custom_forward\n\n        def checkpoint_handler(forward_func):\n            \"\"\"Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`\"\"\"\n            if self.config.fp8:\n                return te_checkpoint(\n                    forward_func,\n                    self.config.distribute_saved_activations,\n                    tensor_parallel.random.get_cuda_rng_tracker,\n                    parallel_state.get_tensor_model_parallel_group(),\n                    hidden_states,\n                    attention_mask,\n                    context,\n                    context_mask,\n                    rotary_pos_emb,\n                )\n            else:\n                return tensor_parallel.checkpoint(\n                    forward_func,\n                    self.config.distribute_saved_activations,\n                    hidden_states,\n                    attention_mask,\n                    context,\n                    context_mask,\n                    rotary_pos_emb,\n                )\n\n        if self.config.recompute_method == \"uniform\":\n            # Uniformly divide the total number of Transformer layers and checkpoint\n            # the input activation 
of each divided chunk.\n            # This further reduces memory usage, at the cost of recomputing every chunk.\n            layer_idx = 0\n            while layer_idx < self.num_layers_per_pipeline_rank:\n                hidden_states, context = checkpoint_handler(\n                    custom(layer_idx, layer_idx + self.config.recompute_num_layers)\n                )\n\n                layer_idx += self.config.recompute_num_layers\n\n        elif self.config.recompute_method == \"block\":\n            # Checkpoint the input activation of only a set number of individual\n            # Transformer layers and skip the rest.\n            # This makes full use of device memory while avoiding redundant re-computation.\n            recompute_skip_num_layers = 0\n            for layer_idx in range(self.num_layers_per_pipeline_rank):\n                # Skip recomputation when input grad computation is not needed.\n                # Need to have at least one input tensor with gradient computation\n                # for the re-entrant autograd engine.\n                if self.config.fp8 and not hidden_states.requires_grad:\n                    recompute_skip_num_layers += 1\n                if (\n                    layer_idx >= recompute_skip_num_layers\n                    and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers\n                ):\n                    hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))\n                else:\n                    hidden_states, context = custom(layer_idx, layer_idx + 1)(\n                        hidden_states, attention_mask, context, context_mask, rotary_pos_emb\n                    )\n        else:\n            raise ValueError(\"Invalid activation recompute method.\")\n\n        return hidden_states\n\n    def forward(\n        self,\n        hidden_states: Union[Tensor, WrappedTensor],\n        attention_mask: Optional[Tensor],\n        context: Optional[Tensor] = None,\n        context_mask: Optional[Tensor] = None,\n        rotary_pos_emb: Optional[Tensor] = None,\n        rotary_pos_cos: Optional[Tensor] = None,\n        rotary_pos_sin: Optional[Tensor] = None,\n        attention_bias: Optional[Tensor] = None,\n        inference_context: Optional[BaseInferenceContext] = None,\n        packed_seq_params: Optional[PackedSeqParams] = None,\n        sequence_len_offset: Optional[Tensor] = None,\n        packed_seq_params_full: Optional[PackedSeqParams] = None,\n        fullatt_block_indexes=None,\n        *,\n        inference_params: Optional[BaseInferenceContext] = None,\n    ):\n        \"\"\"\n        Perform the forward pass through the transformer block.\n\n        This method handles the core computation of the transformer, including\n        self-attention, optional cross-attention, and feed-forward operations.\n\n        Args:\n            hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h]\n                where s is the sequence length, b is the batch size, and h is the hidden size.\n                Can be passed as a WrappedTensor during inference to avoid an obsolete\n                reference in the calling function.\n            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking\n                self-attention.\n            context (Tensor, optional): Context tensor for cross-attention.\n            context_mask (Tensor, optional): Mask for the cross-attention context.\n            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.\n            
attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable\n                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].\n                Used as an alternative to the attention mask for TE cuDNN attention.\n            inference_context (BaseInferenceContext, optional): Parameters for inference-time\n                optimizations.\n            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence\n                processing.\n\n        Returns:\n            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape\n            [s, b, h], and optionally the updated context tensor if cross-attention is used.\n        \"\"\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        # Delete the obsolete reference to the initial input tensor if necessary\n        if isinstance(hidden_states, WrappedTensor):\n            hidden_states = hidden_states.unwrap()\n\n        if not self.pre_process:\n            # See set_input_tensor()\n            hidden_states = self.input_tensor\n\n        # Update the inference parameters with the current batch size in case it is variable\n        if inference_context and not self.training:\n            inference_context.current_batch_size = hidden_states.size(1)\n\n        # Viewless tensor.\n        # - We only need to create a viewless tensor in the case of micro batch\n        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'\n        #   above creates a view tensor, and '.contiguous()' is a pass-through.\n        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating\n        #   the need to make it viewless.\n        #\n        #   However, we don't explicitly check mbs == 1 here because\n        #   make_viewless_tensor() has negligible overhead when its input\n        #   is already viewless.\n        #\n        # - For the 'else' case above, calling make_viewless_tensor() here is\n        #   likely redundant, since p2p_communication.py (likely originator)\n        #   already creates viewless tensors. 
That said, make_viewless_tensor()\n        #   is called here to be future-proof and corner-case-proof.\n        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)\n\n        if self.config.sequence_parallel:\n            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()\n        else:\n            rng_context = nullcontext()\n\n        # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(),\n        # otherwise do nothing extra at the outer level\n        # if we are using other fp8 recipes, then the context manager enter&exit are free\n        # we can wrap fp8_context within the for loop over layers, so that we can fine-grained\n        # control which layer will be fp8 or bf16\n        use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed\n        use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed\n        outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext()\n\n        with rng_context, outer_fp8_context:\n            # Forward pass.\n            if self.config.recompute_granularity == \"full\" and self.training:\n                hidden_states = self._checkpointed_forward(\n                    hidden_states=hidden_states,\n                    attention_mask=attention_mask,\n                    context=context,\n                    context_mask=context_mask,\n                    rotary_pos_emb=rotary_pos_emb,\n                    attention_bias=attention_bias,\n                    packed_seq_params=packed_seq_params,\n                    packed_seq_params_full=packed_seq_params_full,\n                    fullatt_block_indexes=fullatt_block_indexes,\n                )\n            else:\n                for l_no, layer in enumerate(self.layers):\n                    inner_fp8_context = (\n                        get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext()\n                    )\n                    if l_no in fullatt_block_indexes:\n                        packed_seq_params_now = packed_seq_params_full\n                    else:\n                        packed_seq_params_now = packed_seq_params\n                    with self.offload_context, inner_fp8_context:\n                        hidden_states, context = layer(\n                            hidden_states=hidden_states,\n                            attention_mask=attention_mask,\n                            context=context,\n                            context_mask=context_mask,\n                            rotary_pos_emb=rotary_pos_emb,\n                            rotary_pos_cos=rotary_pos_cos,\n                            rotary_pos_sin=rotary_pos_sin,\n                            attention_bias=attention_bias,\n                            inference_context=inference_context,\n                            packed_seq_params=packed_seq_params_now,\n                            sequence_len_offset=sequence_len_offset,\n                        )\n\n                    if (\n                        torch.is_grad_enabled()\n                        and self.config.cpu_offloading\n                        and self.group_prefetch_offload_commit_async is not None\n                    ):\n                        hidden_states = self.group_prefetch_offload_commit_async(hidden_states)\n\n        # Final layer norm.\n        if self.final_layernorm is not None:\n            hidden_states = 
self.final_layernorm(hidden_states)\n            # TENorm produces a \"viewed\" tensor. This will result in schedule.py's\n            # deallocate_output_tensor() throwing an error, so a viewless tensor is\n            # created to prevent this.\n            hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)\n\n        return hidden_states\n"
  },
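The `_checkpointed_forward` above dispatches between two activation-recompute schedules. A small sketch, under hypothetical sizes (12 layers per pipeline rank, `recompute_num_layers=4`), of which layer ranges each schedule checkpoints:

```python
# "uniform": checkpoint the input of every chunk of recompute_num_layers
# consecutive layers, so all layers are recomputed chunk by chunk.
num_layers, recompute_num_layers = 12, 4  # hypothetical sizes

uniform_chunks = [
    (start, start + recompute_num_layers)
    for start in range(0, num_layers, recompute_num_layers)
]
print(uniform_chunks)  # [(0, 4), (4, 8), (8, 12)]

# "block": checkpoint only the first recompute_num_layers layers
# (ignoring the fp8 skip adjustment) and run the rest without recompute.
block_checkpointed = [i for i in range(num_layers) if i < recompute_num_layers]
print(block_checkpointed)  # [0, 1, 2, 3]
```

Independently of the schedule, each layer in the loop picks `packed_seq_params_full` when its index is in `fullatt_block_indexes` (full attention over frames) and the windowed `packed_seq_params` otherwise.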
  {
    "path": "verl_distillation/verl/models/mcore/readme.md",
    "content": "# verl Megatron-Core Models\nNow we use [mbridge](https://github.com/iseekyan/mbridge) to support megatron models. And we will migrate to [megatron-bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) in the future.\n\nWith the mbridge, we can use allmost all the Megatron-Core features to support new models with little effort. And no offline weights conversion is needed, all the weights conversion is done online. We can directly save the mcore model to huggingface format during training.\n\nAlso, we can easily upgrade the mcore version to the latest version. In most cases, the upgrade is seamless. (except when the mcore API changes and we need to update the verl code accordingly)\n\n## How to support new models (new)\n1. make sure the model is supported by vLLM\n2. Support the model in [mbridge](https://github.com/iseekyan/mbridge), see its currently supported models for example.\n    - we will migrate to [megatron-bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) in the future.\n3. Register the model forward function in verl, see the example in `verl/verl/models/mcore/registry.py`.\n\n\n\n# Below are deprecated\nThe earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features.\n\nThe migration has been successful with the help of the mcore team and the community. What we have done is:\n1. update `Megatron` version to `0.14.0`\n2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel`\n3. support sequence packing/thd format.\n4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`.\n5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format.\n\nWe are working on the following features:\n- support `Qwen2MoeForCausalLM`\n- support `MixtralForCausalLM`\n- support `DeepseekV3ForCausalLM`\n- support `expert parallel`\n\nFeatures we invite the community to contribute:\n- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format.\n    - conversion of large models with multiple GPUs\n    - conversion of large models with single GPU\n- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format.\n- support llama4\n- support qwen2.5-vl\n\nTo track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033).\n\n## How things work now\nTo engage the community in contributing, here are the key steps in our mcore integration process and features under development. \n\nThe huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two.\nmain steps:\n1. modelling the huggingface model with mcore `GPTModel`\n    - a. convert the huggingface config to mcore `TransformerConfig`\n    - b. init the mcore `GPTModel` with the converted config\n    - c. load the huggingface model weights to the `GPTModel`\n2. 
online weight conversion from mcore to huggingface (because the rollout engine `vLLM` uses the huggingface format)\n    - a. bridge the gap between the mcore and huggingface weight formats and name mappings\n    - b. online resharding of the mcore weights to the rollout engine\n        - this part is complicated by the composition of multiple parallel strategies between mcore and the rollout engine\n3. support the mcore features in verl\n    - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`\n    - b. support recompute and other mcore speed-up features\n\n4. checkpointing\n    - a. support recovering the verl training.\n    - b. support exporting the mcore checkpoint to huggingface format, for downstream inference.\n\n### Modelling the huggingface model with mcore `GPTModel`\nThe first step is to convert the huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`.\n\nThere are two ways of loading the huggingface model weights into the `GPTModel`:\n1. Runtime loading\n    - every rank loads the entire huggingface model weights and then shards and converts them to mcore weights.\n    - speed is slow and memory consumption is high.\n    - this way is deprecated and will not support new models.\n2. Offline loading\n    - use an offline script to convert the huggingface model weights to mcore weights and save them in the mcore `dist_checkpointing` format.\n    - online loading and sharding are automatically done by the mcore `dist_checkpointing` format. Loading is fast and memory consumption is low.\n    - the offline script is in `verl/scripts/converter_hf_to_mcore.py`.\n\n### Online weight conversion from mcore to huggingface\nSee the function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details.\n\nIt should be refactored for extensibility and better performance.\n\n### Support the mcore features in verl\nMost `GPTModel` features are supported out of the box in verl by changing the `TransformerConfig`, except those involving parallel strategies, such as `expert parallel`.\nParallel-strategy features require changes to the online weight conversion (especially the resharding part) and to verl work dispatching.\n\n### Checkpointing\nThe existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`, and the script to convert a checkpoint to huggingface format is in `verl/scripts/model_merger`.\n\nThe existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored to use the `dist_checkpointing` format.\n\n\n## How to support new models\n1. make sure the model is supported by vLLM\n2. modelling the huggingface model with mcore `GPTModel` (the [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference)\n    - a. convert the huggingface config to mcore `TransformerConfig`\n    - b. init the mcore `GPTModel` with the converted config\n    - c. load the huggingface model weights into the `GPTModel`\n    - d. for VLMs the interface might be different; it is ok to add a new model class with `GPTModel` as its module.\n3. offline weights conversion from huggingface to the mcore `dist_checkpointing` format\n4. 
support online weight conversion from mcore to huggingface\n    - it is recommended to initialize a vLLM model with the converted mcore weights and then test whether the generated sequences are correct.\n\n\n## How to scale up to larger models like deepseek-v3 or other 100B+ models\nThe greatest challenge in scaling up to larger models is memory consumption.\n\nThe necessary features under development for scaling up are:\n1. Training engine part\n    - expert parallel\n2. Rollout engine part\n    - pipeline parallel\n    - expert parallel\n    - more efficient and general weight resharding and loading\n3. Offline weights conversion\n    - support weights larger than a single GPU's memory\n"
  },
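The readme's step 3 ("Register the model forward function in verl") boils down to the enum-plus-dict dispatch used in `registry.py` in the next file: key on `hf_config.architectures[0]`, then look the model up in per-purpose registries. A self-contained sketch of that pattern with illustrative names only (the real registries map `SupportedModel` members to config converters, initializers, forward functions, and weight converters):

```python
from enum import Enum
from typing import Callable


class Supported(Enum):
    # Illustrative subset; verl keys on hf_config.architectures[0].
    LLAMA = "LlamaForCausalLM"
    QWEN3 = "Qwen3ForCausalLM"


def dense_forward() -> str:
    # Stand-in for a real model forward function.
    return "dense forward"


FORWARD_REGISTRY: dict[Supported, Callable] = {
    Supported.LLAMA: dense_forward,
    Supported.QWEN3: dense_forward,
}


def get_forward_fn(architecture: str) -> Callable:
    # Dispatch on the HF architecture string, as registry.py does.
    try:
        model = Supported(architecture)
    except ValueError as err:
        raise NotImplementedError(f"{architecture} not supported") from err
    return FORWARD_REGISTRY[model]


print(get_forward_fn("Qwen3ForCausalLM")())  # -> "dense forward"
```

Supporting a new model is then a matter of adding one enum member and one entry per registry, rather than touching the dispatch logic.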
  {
    "path": "verl_distillation/verl/models/mcore/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRegistry module for model architecture components.\n\"\"\"\n\nfrom enum import Enum\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\n\nfrom .config_converter import (\n    PretrainedConfig,\n    TransformerConfig,\n    hf_to_mcore_config_dense,\n    hf_to_mcore_config_dpskv3,\n    hf_to_mcore_config_llama4,\n    hf_to_mcore_config_mixtral,\n    hf_to_mcore_config_qwen2_5_vl,\n    hf_to_mcore_config_qwen2moe,\n    hf_to_mcore_config_qwen3moe,\n)\nfrom .model_forward import gptmodel_forward_no_padding, model_forward_gen\nfrom .model_forward_fused import fused_forward_model_gen\nfrom .model_initializer import (\n    BaseModelInitializer,\n    DeepseekV3Model,\n    DenseModel,\n    MixtralModel,\n    Qwen2MoEModel,\n    Qwen3MoEModel,\n    Qwen25VLModel,\n)\nfrom .weight_converter import (\n    McoreToHFWeightConverterDense,\n    McoreToHFWeightConverterDpskv3,\n    McoreToHFWeightConverterMixtral,\n    McoreToHFWeightConverterQwen2_5_VL,\n    McoreToHFWeightConverterQwen2Moe,\n    McoreToHFWeightConverterQwen3Moe,\n)\n\n\nclass SupportedModel(Enum):\n    LLAMA = \"LlamaForCausalLM\"  # tested\n    QWEN2 = \"Qwen2ForCausalLM\"  # tested\n    QWEN2_MOE = \"Qwen2MoeForCausalLM\"  # pending\n    DEEPSEEK_V3 = \"DeepseekV3ForCausalLM\"  # not tested\n    MIXTRAL = \"MixtralForCausalLM\"  # tested\n    QWEN2_5_VL = \"Qwen2_5_VLForConditionalGeneration\"  # not supported\n    LLAMA4 = \"Llama4ForConditionalGeneration\"  # not tested\n    QWEN3 = \"Qwen3ForCausalLM\"  # tested\n    QWEN3_MOE = \"Qwen3MoeForCausalLM\"  # tested\n    GLM4_MOE = \"Glm4MoeForCausalLM\"\n\n    QWEN3_TOKEN_CLASSIFICATION = \"Qwen3ForTokenClassification\"\n    QWEN3_MOE_VL = \"Qwen3VLMoeForConditionalGeneration\"\n    QWEN3_VL = \"Qwen3VLForConditionalGeneration\"\n\n\n# Registry for model configuration converters\nMODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {\n    SupportedModel.LLAMA: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,\n    SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,\n    SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,\n    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,\n    SupportedModel.LLAMA4: hf_to_mcore_config_llama4,\n    SupportedModel.QWEN3: hf_to_mcore_config_dense,\n    SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,\n    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,\n}\n\n# Registry for model initializers\nMODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {\n    SupportedModel.LLAMA: DenseModel,\n    SupportedModel.QWEN2: DenseModel,\n    SupportedModel.QWEN2_MOE: Qwen2MoEModel,\n    SupportedModel.MIXTRAL: MixtralModel,\n    
SupportedModel.DEEPSEEK_V3: DeepseekV3Model,\n    SupportedModel.QWEN2_5_VL: Qwen25VLModel,\n    SupportedModel.LLAMA4: DenseModel,\n    SupportedModel.QWEN3: DenseModel,\n    SupportedModel.QWEN3_MOE: Qwen3MoEModel,\n    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: DenseModel,\n}\n\n# Registry for model forward functions\nMODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: model_forward_gen(),\n    SupportedModel.QWEN2: model_forward_gen(),\n    SupportedModel.QWEN2_MOE: model_forward_gen(),\n    SupportedModel.MIXTRAL: model_forward_gen(),\n    SupportedModel.DEEPSEEK_V3: model_forward_gen(),\n    SupportedModel.LLAMA4: model_forward_gen(),\n    SupportedModel.QWEN3: model_forward_gen(),\n    SupportedModel.QWEN3_MOE: model_forward_gen(),\n    SupportedModel.QWEN2_5_VL: model_forward_gen(True),\n    SupportedModel.QWEN3_MOE_VL: model_forward_gen(True),\n    SupportedModel.QWEN3_VL: model_forward_gen(True),\n    SupportedModel.GLM4_MOE: model_forward_gen(),\n    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: model_forward_gen(),\n}\n\n# Registry for no-padding model forward functions\nMODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: gptmodel_forward_no_padding,\n    SupportedModel.QWEN2: gptmodel_forward_no_padding,\n    SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding,\n    SupportedModel.MIXTRAL: gptmodel_forward_no_padding,\n    SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,\n    SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,\n    SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding,\n    SupportedModel.QWEN3_VL: gptmodel_forward_no_padding,\n    SupportedModel.LLAMA4: gptmodel_forward_no_padding,\n    SupportedModel.QWEN3: gptmodel_forward_no_padding,\n    SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,\n    SupportedModel.GLM4_MOE: gptmodel_forward_no_padding,\n    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding,\n}\n\n# Registry for fused model forward functions\nMODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: fused_forward_model_gen(),\n    SupportedModel.QWEN2: fused_forward_model_gen(),\n    SupportedModel.QWEN2_MOE: fused_forward_model_gen(),\n    SupportedModel.MIXTRAL: fused_forward_model_gen(),\n    SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(),\n    SupportedModel.QWEN2_5_VL: fused_forward_model_gen(True),\n    SupportedModel.QWEN3_MOE_VL: fused_forward_model_gen(True),\n    SupportedModel.QWEN3_VL: fused_forward_model_gen(True),\n    SupportedModel.LLAMA4: fused_forward_model_gen(),\n    SupportedModel.QWEN3: fused_forward_model_gen(),\n    SupportedModel.QWEN3_MOE: fused_forward_model_gen(),\n    SupportedModel.GLM4_MOE: fused_forward_model_gen(),\n}\n\n# Registry for model weight converters\nMODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = {\n    SupportedModel.LLAMA: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,\n    SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,\n    SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,\n    SupportedModel.QWEN3: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,\n    SupportedModel.QWEN2_5_VL: 
McoreToHFWeightConverterQwen2_5_VL,\n    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: McoreToHFWeightConverterDense,\n}\n\n\ndef get_supported_model(model_type: str) -> SupportedModel:\n    try:\n        return SupportedModel(model_type)\n    except ValueError as err:\n        supported_models = [e.value for e in SupportedModel]\n        raise NotImplementedError(\n            f\"Model Type: {model_type} not supported. Supported models: {supported_models}\"\n        ) from err\n\n\ndef hf_to_mcore_config(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    \"\"\"Convert huggingface PretrainedConfig to mcore TransformerConfig.\n\n    Args:\n        hf_config: The huggingface PretrainedConfig.\n        dtype: The dtype of the model.\n        **override_transformer_config_kwargs: The kwargs to override the transformer config.\n\n    Returns:\n        The mcore TransformerConfig.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs)\n\n\ndef init_mcore_model(\n    tfconfig: TransformerConfig,\n    hf_config: PretrainedConfig,\n    pre_process: bool = True,\n    post_process: bool = None,\n    *,\n    share_embeddings_and_output_weights: bool = False,\n    value: bool = False,\n    **extra_kwargs,  # may be used for vlm and moe\n) -> nn.Module:\n    \"\"\"\n    Initialize a Mcore model.\n\n    Args:\n        tfconfig: The transformer config.\n        hf_config: The HuggingFace config.\n        pre_process: Optional pre-processing function.\n        post_process: Optional post-processing function.\n        share_embeddings_and_output_weights: Whether to share embeddings and output weights.\n        value: Whether to use value.\n        **extra_kwargs: Additional keyword arguments.\n\n    Returns:\n        The initialized model.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    initializer_cls = MODEL_INITIALIZER_REGISTRY[model]\n    initializer = initializer_cls(tfconfig, hf_config)\n    return initializer.initialize(\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        value=value,\n        **extra_kwargs,\n    )\n\n\ndef get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the forward function for given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_REGISTRY[model]\n\n\ndef get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the forward function for given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_NOPAD_REGISTRY[model]\n\n\ndef get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the forward function for given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = 
get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_FUSED_REGISTRY[model]\n\n\ndef get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:\n    \"\"\"\n    Get the weight converter for given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    tfconfig = hf_to_mcore_config(hf_config, dtype)\n    return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)\n"
  },
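A hedged usage sketch of the registry entry points defined in the file above. It assumes verl and Megatron-Core are importable, that Megatron's parallel state has already been initialized elsewhere, and that the model name (hypothetical here) resolves to a config whose architecture appears in `SupportedModel`; none of that setup is shown.

```python
import torch
from transformers import AutoConfig

from verl.models.mcore.registry import hf_to_mcore_config, init_mcore_model

# Assumes megatron.core parallel state is already initialized elsewhere.
hf_config = AutoConfig.from_pretrained("Qwen/Qwen3-1.7B")  # hypothetical model name

# Convert the HF config to a Megatron-Core TransformerConfig, then build
# the mcore model via the initializer registered for this architecture.
tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
model = init_mcore_model(tfconfig, hf_config, pre_process=True, post_process=True)
print(type(model))
```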
  {
    "path": "verl_distillation/verl/models/mcore/saver.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(\n    tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0\n):\n    \"\"\"Calculate global rank with support for CP/EP parallelism\"\"\"\n\n    # Get parallel sizes for each dimension\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    # ep_size = mpu.get_expert_model_parallel_world_size()\n\n    # Verify total GPU count matches (must be consistent with parallel_state.py)\n    total_size = tp_size * dp_size * pp_size * cp_size\n    assert total_size == torch.distributed.get_world_size(), (\n        f\"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}\"\n    )\n\n    # Core calculation logic (corresponds to RankGenerator order parameter)\n    # Assumes default order is \"tp-cp-ep-dp-pp\"\n    return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, 
tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (PretrainedConfig):\n            HF config for the model.\n        dtype: dtype of the model parameters.\n        is_value_model: whether the model is a value model.\n        tie_word_embeddings: whether the input embeddings and lm_head weights are tied.\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank()}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        # wrap a single module into a list rather than iterating over it\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].decoder.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].decoder.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] does not exist, skip collecting\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, 
group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if 
torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = 
qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                        q_weight_list.append(q_part)\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                        q_weight_list.append(q_part)\n                        if i * config.num_key_value_heads % tp_size == 0:\n                            k_weight_list.append(k_part)\n                            v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # Embeddings\n    # -------------------\n    if dp_rank == 0 and cp_rank == 0:  # models are identical across cp ranks\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.qk_layernorm:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            
_broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.add_qkv_bias:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                lm_head_weight = None\n                if pp_rank == pp_size - 1:\n                    lm_head_weight = getattr(gpt_model_module.output_layer, \"weight\", None)\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\", src_pp_rank=pp_size - 1)\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.output_layer, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n\n\ndef merge_megatron_ckpt_gptmodel_qwen_moe(\n    wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_qwen_moe is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_qwen2_5_vl(\n    wrapped_models, config, dtype, is_value_model=False, 
tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_dpskv3 is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_mixtral(\n    wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_mixtral is not implemented\")\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/util.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.packed_seq_params import PackedSeqParams\n\nfrom verl.utils.model import CausalLMOutputForPPO\n\n\ndef preprocess_packed_seqs(\n    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True\n) -> tuple[torch.Tensor, PackedSeqParams]:\n    \"\"\"\n    Preprocess packed sequences\n    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1\n    gets second and second last chunks, and so on), this is for load balancing with causal masking.\n    See https://github.com/NVIDIA/TransformerEngine/issues/1368\n    \"\"\"\n    batch_size = input_ids.shape[0]\n\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    cp_rank = mpu.get_context_parallel_rank()\n    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size\n\n    pad_size = (align_size - seqlens_in_batch % align_size) % align_size\n    seqlens_in_batch_padded = seqlens_in_batch + pad_size\n\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)\n    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)\n\n    # ----------------------------------------------------------------------------\n    # Move the index information needed in the subsequent loop to the CPU at once,\n    # to avoid frequent .item() calls in the loop that cause D2H synchronization\n    # ----------------------------------------------------------------------------\n    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths\n    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding\n    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)\n\n    # Pure Python int calculation to avoid further synchronization\n    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)\n\n    shape = list(input_ids.shape[1:])\n    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size\n    if pre_process:\n        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)\n        for i in range(batch_size):\n            # Use Python int, so no GPU→CPU sync in the loop\n            if cp_size <= 1:\n                seqlen = seqlens_in_batch_cpu[i]\n                start_idx = cu_seqlens_padded_cpu[i]\n                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]\n                continue\n\n            seqlen_padded_i = 
seqlens_in_batch_padded_cpu[i]\n            seqlen = seqlen_padded_i // cp_size\n            half_seqlen = seqlen // 2\n            start_idx = cu_seqlens_padded_cpu[i] // cp_size\n            # split to 2 chunks\n            d = input_ids[i, attention_mask[i]]\n            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[\n                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)\n            ]\n\n            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)\n            remain_end = seqlen_padded_i - half_seqlen * cp_rank\n            remain_end = min(remain_end, d.shape[0])\n            remain_len = remain_end - remain_start\n            if remain_len > 0:\n                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[\n                    remain_start:remain_end\n                ]\n\n    packed_seq_params = PackedSeqParams(\n        qkv_format=\"thd\",\n        cu_seqlens_q=cu_seqlens_padded,\n        max_seqlen_q=max_seqlen_in_batch,\n        cu_seqlens_kv=cu_seqlens_padded,\n        max_seqlen_kv=max_seqlen_in_batch,\n        cu_seqlens_q_padded=cu_seqlens_padded,\n        cu_seqlens_kv_padded=cu_seqlens_padded,\n    )\n    if pre_process:\n        return input_ids_rmpad.unsqueeze(0), packed_seq_params\n    else:\n        return input_ids, packed_seq_params\n\n\ndef postprocess_packed_seqs(\n    output: torch.Tensor,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Postprocess packed sequences\n    \"\"\"\n    if not post_process:\n        return output\n\n    # -------------------------------------------------------------------------\n    # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,\n    # to avoid a large number of .item() calls in the loop\n    # -------------------------------------------------------------------------\n    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()\n    seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()\n\n    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim\n    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)\n\n    cp_size = mpu.get_context_parallel_world_size()\n    # all gather output across context parallel group\n    if cp_size > 1:\n        # output shape: [1, packed_len, hidden_dim]\n        # need to gather across cp group and concatenate in sequence dimension\n        output_list = [torch.empty_like(output) for _ in range(cp_size)]\n        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())\n        output_list[mpu.get_context_parallel_rank()] = output\n    else:\n        output_list = [output]\n    for i in range(batch_size):\n        if cp_size <= 1:\n            s = seq_lens_cpu[i]\n            start_idx = cu_padded_cpu[i]\n            output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]\n            continue\n        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size\n        half_seqlen = s_len_padded_chunk // 2\n        s_len = seq_lens_cpu[i]\n        s_len_padded = s_len_padded_chunk * cp_size\n        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)\n        for j in range(cp_size):\n            o = output_list[j][0]\n            # split to 
2 chunks\n            packed_start_idx = cu_padded_cpu[i] // cp_size\n            o0, o1 = (\n                o[packed_start_idx : packed_start_idx + half_seqlen],\n                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],\n            )\n            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0\n            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1\n        output_new[i, attention_mask[i]] = tmp[:s_len]\n\n    return output_new\n\n\ndef preprocess_packed_seqs_no_padding(\n    input_ids: torch.Tensor, pre_process: bool = True\n) -> tuple[torch.Tensor, PackedSeqParams]:\n    \"\"\"\n    Preprocess packed sequences\n    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1\n    gets second and second last chunks, and so on), this is for load balancing with causal masking.\n    See https://github.com/NVIDIA/TransformerEngine/issues/1368\n    \"\"\"\n    batch_size = input_ids.shape[0]\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    cp_rank = mpu.get_context_parallel_rank()\n    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size\n    seqlens_in_batch = input_ids.offsets().diff()\n\n    pad_size = (align_size - seqlens_in_batch % align_size) % align_size\n    seqlens_in_batch_padded = seqlens_in_batch + pad_size\n\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)\n    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)\n\n    # ----------------------------------------------------------------------------\n    # Move the index information needed in the subsequent loop to the CPU at once,\n    # to avoid frequent .item() calls in the loop that cause D2H synchronization\n    # ----------------------------------------------------------------------------\n    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths\n    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding\n    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)\n\n    # Pure Python int calculation to avoid further synchronization\n    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)\n\n    shape = list(input_ids.shape[1:])\n    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size\n    if pre_process:\n        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)\n        for i in range(batch_size):\n            # Use Python int, so no GPU→CPU sync in the loop\n            if cp_size <= 1:\n                seqlen = seqlens_in_batch_cpu[i]\n                start_idx = cu_seqlens_padded_cpu[i]\n                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i]\n                continue\n\n            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]\n            seqlen = seqlen_padded_i // cp_size\n            half_seqlen = seqlen // 2\n            start_idx = cu_seqlens_padded_cpu[i] // cp_size\n            # split to 2 chunks\n            d = input_ids[i]\n            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[\n                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)\n            ]\n\n            remain_start = seqlen_padded_i - 
half_seqlen * (cp_rank + 1)\n            remain_end = seqlen_padded_i - half_seqlen * cp_rank\n            remain_end = min(remain_end, d.shape[0])\n            remain_len = remain_end - remain_start\n            if remain_len > 0:\n                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[\n                    remain_start:remain_end\n                ]\n\n    packed_seq_params = PackedSeqParams(\n        qkv_format=\"thd\",\n        cu_seqlens_q=cu_seqlens_padded,\n        max_seqlen_q=max_seqlen_in_batch,\n        cu_seqlens_kv=cu_seqlens_padded,\n        max_seqlen_kv=max_seqlen_in_batch,\n        cu_seqlens_q_padded=cu_seqlens_padded,\n        cu_seqlens_kv_padded=cu_seqlens_padded,\n    )\n    if pre_process:\n        return input_ids_rmpad.unsqueeze(0), packed_seq_params\n    else:\n        return input_ids, packed_seq_params\n\n\ndef postprocess_packed_seqs_no_padding(\n    output: torch.Tensor,\n    packed_seq_params: PackedSeqParams,\n    input_ids: torch.Tensor,\n    batch_size: int,\n    post_process: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Postprocess packed sequences\n    \"\"\"\n    if not post_process:\n        return output\n\n    # -------------------------------------------------------------------------\n    # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,\n    # to avoid a large number of .item() calls in the loop\n    # -------------------------------------------------------------------------\n    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()\n    # The reason why we use input_ids.offsets() instead of packed_seq_params.cu_seqlens_q.diff()\n    # is that the latter one is the padded length, while the former one is the original length.\n    cu_seqlens = input_ids.offsets()\n    seq_lens_cpu: list[int] = cu_seqlens.diff().tolist()\n\n    output_new = []\n\n    cp_size = mpu.get_context_parallel_world_size()\n    # all gather output across context parallel group\n    if cp_size > 1:\n        # output shape: [1, packed_len, hidden_dim]\n        # need to gather across cp group and concatenate in sequence dimension\n        output_list = [torch.empty_like(output) for _ in range(cp_size)]\n        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())\n        output_list[mpu.get_context_parallel_rank()] = output\n    else:\n        output_list = [output]\n\n    for i in range(batch_size):\n        if cp_size <= 1:\n            s = seq_lens_cpu[i]\n            start_idx = cu_padded_cpu[i]\n            output_new.append(output[0][start_idx : start_idx + s])\n            continue\n        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size\n        half_seqlen = s_len_padded_chunk // 2\n        s_len = seq_lens_cpu[i]\n        s_len_padded = s_len_padded_chunk * cp_size\n        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)\n        for j in range(cp_size):\n            o = output_list[j][0]\n            # split to 2 chunks\n            packed_start_idx = cu_padded_cpu[i] // cp_size\n            o0, o1 = (\n                o[packed_start_idx : packed_start_idx + half_seqlen],\n                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],\n            )\n            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0\n            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1\n        
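# tmp now holds the padded sequence reassembled from all CP ranks: rank j\n        # contributed chunk j at the front and its mirrored chunk at the back,\n        # undoing the load-balanced split from preprocess; padding is trimmed below.\n        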
output_new.append(tmp[:s_len])\n\n    output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged)\n\n    return output_new_tensor\n\n\ndef remove_left_padding(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    position_ids: torch.Tensor,\n    sequence_parallel: bool = False,\n    pre_process: bool = True,\n):\n    \"\"\"\n    Remove left padding from input_ids, attention_mask and position_ids\n    return new_input_ids, new_attention_mask, new_position_ids\n    \"\"\"\n    assert attention_mask.ndim == 2\n    assert position_ids.ndim == 2\n    cp_size = mpu.get_context_parallel_world_size()\n    assert cp_size == 1, \"Context parallel size without seq_pack is not supported\"\n    batch_size = input_ids.shape[0]\n    shape = list(input_ids.shape)  # batch_size, seq_len,...\n    seq_lens = attention_mask.sum(dim=1)\n    seq_len = seq_lens.max().item()\n    if sequence_parallel:\n        sp_world_size = mpu.get_tensor_model_parallel_world_size()\n        pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size\n        seq_len = seq_len + pad_size\n    shape[1] = seq_len\n    if pre_process:\n        new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)\n    new_attention_mask = torch.zeros(\n        dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len)\n    )\n    new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))\n    for i in range(batch_size):\n        if pre_process:\n            new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]]\n        new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]]\n        new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]]\n    if pre_process:\n        return new_input_ids, new_attention_mask, new_position_ids\n    else:\n        return input_ids, new_attention_mask, new_position_ids\n\n\ndef recover_left_padding(\n    result,\n    attention_mask: torch.Tensor,\n    original_attention_mask: torch.Tensor,\n    origin_seqlen: int,\n    post_process: bool = True,\n):\n    \"\"\"\n    Recover left padding from result\n    return result\n    \"\"\"\n    if not post_process:\n        return result\n    shape = list(result.shape)\n    batch_size = shape[0]\n    shape[1] = origin_seqlen\n    new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)\n    for i in range(batch_size):\n        new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]\n    return new_result\n\n\ndef postprocess_packed_seqs_for_dict_output(\n    labels_mask: torch.Tensor,\n    output: CausalLMOutputForPPO,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> dict[str, torch.Tensor]:\n    \"\"\"_summary_\n    For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc.\n    This function post-processes each tensor in the output dictionary.\n    Args:\n        output (CausalLMOutputForPPO): _description_\n        packed_seq_params (PackedSeqParams): _description_\n        attention_mask (torch.Tensor): _description_\n        batch_size (int): _description_\n        seq_len (int): _description_\n        post_process (bool, optional): _description_. 
Defaults to True.\n    Returns:\n        dict[str, torch.Tensor]: post-processed \"entropy\" and \"log_probs\" tensors of shape [batch_size, seq_len]\n    \"\"\"\n    ret = {}\n    output.entropy = output.entropy.view(1, -1)\n    output.log_probs = output.log_probs.view(1, -1)\n    output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0)\n    ret[\"entropy\"] = postprocess_packed_seqs(\n        output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    ret[\"log_probs\"] = postprocess_packed_seqs(\n        output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    return ret\n"
  },
  {
    "path": "verl_distillation/verl/models/mcore/weight_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# online convert mcore weight to pure huggingface weight, no any fusion\n# including format conversion and name mapping\n# not including resharding\nimport torch\nfrom megatron.core.transformer import TransformerConfig\nfrom transformers import PretrainedConfig\n\n\nclass McoreToHFWeightConverterBase:\n    def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig):\n        self.hf_config = hf_config\n        self.mcore_config = mcore_config\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor:\n        raise NotImplementedError\n\n\nclass McoreToHFWeightConverterDense(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.bias'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"self_attention.linear_qkv.bias\" in name or \"self_attention.linear_qkv.weight\" in name:\n            param_type = name.split(\".\")[-1]\n            assert param_type == \"bias\" or param_type == \"weight\"\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\")\n            assert len(params) == 3\n        elif \"self_attention.linear_proj.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.o_proj.weight\")\n            assert len(params) == 1\n        elif \"self_attention.linear_qkv.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.input_layernorm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.q_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_norm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.k_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_norm.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 
'decoder.layers.0.mlp.linear_fc1.weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"mlp.linear_fc1.weight\" in name:\n            # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.linear_fc1.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n\n        if \"self_attention\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # 'decoder.layers.0.mlp.shared_experts.gate_weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight'\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.gate_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert_gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.linear_fc1.weight\" in name:  # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight\")\n            
convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight\")\n            assert len(params) == 2\n        elif \"shared_experts.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense):\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"language_model.embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"language_model.decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"language_model.output_layer.weight\": \"lm_head.weight\",\n            \"vision_model.patch_embed.proj.weight\": \"visual.patch_embed.proj.weight\",\n            \"vision_model.decoder.final_layernorm.weight\": \"visual.merger.ln_q.weight\",\n            \"vision_model.projection.encoder.linear_fc1.weight\": \"visual.merger.mlp.0.weight\",\n            \"vision_model.projection.encoder.linear_fc1.bias\": \"visual.merger.mlp.0.bias\",\n            \"vision_model.projection.encoder.linear_fc2.weight\": \"visual.merger.mlp.2.weight\",\n            \"vision_model.projection.encoder.linear_fc2.bias\": \"visual.merger.mlp.2.bias\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n\n        if \"self_attention\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        model_type, _, _, layer_number = name.split(\".\")[:4]\n\n        convert_names = []\n        if model_type == \"language_model\":\n            name_map_after_layer = {\n                \"self_attention.linear_qkv.bias\": [\n                    \"self_attn.q_proj.bias\",\n                    \"self_attn.k_proj.bias\",\n                    \"self_attn.v_proj.bias\",\n                ],\n                \"self_attention.linear_qkv.weight\": [\n                    \"self_attn.q_proj.weight\",\n                    \"self_attn.k_proj.weight\",\n                    \"self_attn.v_proj.weight\",\n                ],\n                \"self_attention.linear_proj.weight\": \"self_attn.o_proj.weight\",\n                \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            }\n            
name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n        elif model_type == \"vision_model\":\n            name_map_after_layer = {\n                \"self_attention.linear_proj.weight\": \"attn.proj.weight\",\n                \"self_attention.linear_proj.bias\": \"attn.proj.bias\",\n                \"self_attention.linear_qkv.layer_norm_weight\": \"norm1.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer, None)\n            if mapped_name is None:\n                assert \"linear_qkv\" in name_after_layer\n                assert len(params) == 3\n                new_param = torch.cat(params, dim=0)\n                params = [new_param]\n                if \"bias\" in name_after_layer:\n                    convert_names.append(f\"visual.blocks.{layer_number}.attn.qkv.bias\")\n                else:\n                    convert_names.append(f\"visual.blocks.{layer_number}.attn.qkv.weight\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"visual.blocks.{layer_number}.{mapped_name}\")\n        else:\n            raise NotImplementedError(f\"Unsupported model type: {model_type}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        model_type, _, _, layer_number = name.split(\".\")[:4]\n\n        convert_names = []\n        if model_type == \"language_model\":\n            name_map_after_layer = {\n                \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n                \"mlp.linear_fc1.bias\": [\"mlp.gate_proj.bias\", \"mlp.up_proj.bias\"],\n                \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n                \"mlp.linear_fc2.bias\": \"mlp.down_proj.bias\",\n                \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n\n        elif model_type == \"vision_model\":\n            name_map_after_layer = {\n                \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n                \"mlp.linear_fc1.bias\": [\"mlp.gate_proj.bias\", \"mlp.up_proj.bias\"],\n                \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n                \"mlp.linear_fc2.bias\": \"mlp.down_proj.bias\",\n                \"mlp.linear_fc1.layer_norm_weight\": \"norm2.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            
mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"visual.blocks.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"visual.blocks.{layer_number}.{mapped_name}\")\n        else:\n            raise NotImplementedError(f\"Unsupported model type: {model_type}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore\n        # 'decoder.layers.0.input_layernorm.weight'\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight'\n        # hf\n        # 'model.layers.0.input_layernorm.weight'\n        # 'model.layers.0.self_attn.o_proj.weight'\n        # 'model.layers.0.self_attn.q_proj.weight'\n        # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight'\n        # 'model.layers.0.self_attn.kv_a_layernorm.weight'\n        # 'model.layers.0.self_attn.kv_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_proj.weight'\n        # 'model.layers.0.self_attn.q_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_layernorm.weight'\n        name_map_after_layer = {\n            \"input_layernorm.weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_proj.weight\": \"self_attn.o_proj.weight\",\n            \"self_attention.linear_q_proj.weight\": \"self_attn.q_proj.weight\",\n            \"self_attention.linear_kv_down_proj.weight\": \"self_attn.kv_a_proj_with_mqa.weight\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj.weight\": \"self_attn.kv_b_proj.weight\",\n            \"self_attention.linear_q_down_proj.weight\": \"self_attn.q_a_proj.weight\",\n            \"self_attention.linear_q_up_proj.weight\": \"self_attn.q_b_proj.weight\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n        }\n        assert len(params) == 1\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        convert_names.append(f\"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore dense\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        # 'decoder.layers.0.mlp.linear_fc1.weight'\n        #       ---\n        # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight'\n        #       ---\n        # 
'decoder.layers.1.mlp.shared_experts.linear_fc2.weight'\n        # hf dense\n        # 'model.layers.0.post_attention_layernorm.weight'\n        # 'model.layers.0.mlp.down_proj.weight'\n        # 'model.layers.0.mlp.gate_proj.weight'\n        # 'model.layers.0.mlp.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.gate_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.down_proj.weight'\n\n        # mcore moe\n        # 'decoder.layers.1.pre_mlp_layernorm.weight'\n        # 'decoder.layers.1.mlp.router.weight'\n        # 'decoder.layers.1.mlp.router.expert_bias'\n        # 'decoder.layers.1.mlp.experts.linear_fc1.weight0'\n        #       ---\n        # 'decoder.layers.1.mlp.experts.linear_fc2.weight0'\n        # hf moe\n        # 'model.layers.1.post_attention_layernorm.weight'\n        # 'model.layers.1.mlp.gate.weight'\n        # 'model.layers.1.mlp.gate.e_score_correction_bias'\n        # 'model.layers.1.mlp.experts.0.gate_proj.weight'\n        # 'model.layers.1.mlp.experts.0.up_proj.weight'\n        # 'model.layers.1.mlp.experts.0.down_proj.weight'\n\n        name_map_after_layer = {\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n            \"mlp.shared_experts.linear_fc2.weight\": \"mlp.shared_experts.down_proj.weight\",\n            \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n            \"mlp.shared_experts.linear_fc1.weight\": [\n                \"mlp.shared_experts.gate_proj.weight\",\n                \"mlp.shared_experts.up_proj.weight\",\n            ],\n            \"pre_mlp_layernorm.weight\": \"post_attention_layernorm.weight\",\n            \"mlp.router.weight\": \"mlp.gate.weight\",\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n        }\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        if name_after_layer in name_map_after_layer:\n            mapped_name = name_map_after_layer[name_after_layer]\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n        else:\n            if \"mlp.experts.linear_fc1.weight\" in name:\n                expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n                assert len(params) == 2\n            elif \"mlp.experts.linear_fc2.weight\" in name:\n                expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n                assert len(params) == 1\n            else:\n                raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n        return convert_names, params\n\n    def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        assert self.mcore_config.mtp_num_layers == 1, 
\"only support one mtp layer for now\"\n        assert self.mcore_config.num_layers == 61, \"only support 61 layers for now\"\n        direct_name_mapping = {\n            \"mtp.layers.0.enorm.weight\": \"model.layers.61.enorm.weight\",\n            \"mtp.layers.0.hnorm.weight\": \"model.layers.61.hnorm.weight\",\n            \"mtp.layers.0.eh_proj.weight\": \"model.layers.61.eh_proj.weight\",\n            \"mtp.layers.0.final_layernorm.weight\": \"model.layers.61.shared_head.norm.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params[0]]\n        assert \"mtp.layers.0.transformer_layer\" in name, \"only support transformer layer for now\"\n        # use proxy name to convert\n        proxy_name = name.replace(\"mtp.layers.0.transformer_layer\", \"decoder.layers.61\")\n        if \"self_attention\" in proxy_name or \"input_layernorm.weight\" in proxy_name:\n            convert_names, params = self._convert_attention_param(proxy_name, params)\n        elif \"mlp\" in proxy_name:\n            convert_names, params = self._convert_mlp_param(proxy_name, params)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n        if \"mtp\" in name:\n            return self._convert_mtp_param(name, params_one_group)\n        elif \"self_attention\" in name or \"input_layernorm.weight\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # decoder.layers.0.mlp.router.weight\n        # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7\n        # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7\n\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.gate.weight\")\n        elif \"mlp.experts.linear_fc1.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight\")\n        elif \"mlp.experts.linear_fc2.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight\")\n        else:\n            raise 
NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # qwen3 moe no share expert\n\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n"
  },
  {
    "path": "verl_distillation/verl/models/qwen2/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_qwen2_megatron import (\n    ParallelQwen2ForCausalLM,\n    # rmpad with megatron\n    ParallelQwen2ForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelQwen2ForCausalLMRmPadPP,\n    ParallelQwen2ForValueRmPad,\n    ParallelQwen2ForValueRmPadPP,\n    # original model with megatron\n    ParallelQwen2Model,\n)\n\n__all__ = [\n    \"ParallelQwen2ForCausalLM\",\n    \"ParallelQwen2ForCausalLMRmPad\",\n    \"ParallelQwen2ForCausalLMRmPadPP\",\n    \"ParallelQwen2ForValueRmPad\",\n    \"ParallelQwen2ForValueRmPadPP\",\n    \"ParallelQwen2Model\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = 
[wrapped_models]  # wrap a bare module in a list (list() on an nn.Module raises TypeError)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor = tensor.data.copy_(state_dict[name], non_blocking=True)\n\n    # NOTE: currently identical to _fetch_tp_shard_tensor below\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : 
intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        
_fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (\n                mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk\n            )\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print(f\"{torch.distributed.get_rank()} loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        print(\n            f\"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, \"\n            f\"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}\"\n        )\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.bias\",\n            f\"{layer_name}.self_attn.k_proj.bias\",\n            f\"{layer_name}.self_attn.v_proj.bias\",\n            bias=True,\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    
_fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    if tie_word_embeddings:\n        print_rank_0(\"tie_word_embeddings skip load lm_head\")\n    else:\n        print_rank_0(\"loading lm_head...\")\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n            if is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _fetch_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
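The loader above (and the saver later in this directory) derives the same global-to-local layer mapping via `_megatron_calc_layer_map`. Below is a minimal, dependency-free sketch of that mapping logic, with `pp_size`/`virtual_pp_size` passed as plain ints rather than queried from `megatron.core.mpu`; the helper name `calc_layer_map` and the toy sizes are illustrative only.

```python
# Standalone sketch of the layer mapping used by the checkpoint utilities above.
# pp_size / virtual_pp_size are plain ints here instead of mpu queries.
def calc_layer_map(num_hidden_layers: int, pp_size: int, virtual_pp_size: int = 1):
    num_layers_per_model = num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == num_hidden_layers

    layer_map = {}
    for pp_rank in range(pp_size):
        for vpp_rank in range(virtual_pp_size):
            # virtual pipeline chunks are interleaved across pipeline stages
            offset = vpp_rank * (num_hidden_layers // virtual_pp_size) + pp_rank * num_layers_per_model
            for layer_idx in range(num_layers_per_model):
                layer_map[offset + layer_idx] = (pp_rank, vpp_rank, layer_idx)
    return layer_map

# e.g. 8 layers, pp_size=2, virtual_pp_size=2:
# global layer 0 -> (pp 0, vpp 0, local 0), layer 2 -> (pp 1, vpp 0, local 0),
# layer 4 -> (pp 0, vpp 1, local 0), layer 6 -> (pp 1, vpp 1, local 0), ...
print(calc_layer_map(8, pp_size=2, virtual_pp_size=2))
```

With virtual pipeline parallelism, consecutive global layers land on different pp ranks, which is why these utilities need an explicit map instead of contiguous slicing.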
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = 
list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n    
        sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n     
           )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(\n                        total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                    )\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(\n                        total_size * tp_size, 
config.hidden_size, dtype=params_dtype, device=get_device_id()\n                    )\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n              
  f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                bias=True,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie_word_embeddings skip load lm_head\")\n        else:\n            print_rank_0(\"loading lm_head...\")\n            lm_head_weight = None\n            if pp_rank + 1 == pp_size:\n                lm_head_weight = gpt_model_module.lm_head.weight\n\n            if is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _broadcast_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
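Every `_broadcast_*` helper in this deprecated loader follows one two-step pattern: rank 0 first ships the tensor shape (or `None` for a missing key) with `broadcast_object_list`, so the other ranks can allocate a matching buffer or all agree to skip, and only then broadcasts the tensor data itself. A minimal sketch of that handshake, exercised here as a single-process `gloo` group purely for illustration; the address/port and the `broadcast_tensor` helper are assumptions, not part of the verl API.

```python
import torch
import torch.distributed as dist

# single-process process group, just to exercise the collective calls
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

state_dict = {"model.norm.weight": torch.randn(16)}

def broadcast_tensor(name, group=None):
    # step 1: the source rank shares the shape (None signals "key missing, skip")
    shape = state_dict[name].shape if dist.get_rank() == 0 and name in state_dict else None
    obj_list = [shape]
    dist.broadcast_object_list(obj_list, src=0, group=group)
    shape = obj_list[0]
    if shape is None:
        return None
    # step 2: non-source ranks allocate a buffer, then all ranks broadcast the data
    tensor = state_dict[name] if dist.get_rank() == 0 else torch.empty(shape)
    dist.broadcast(tensor, src=0, group=group)
    return tensor

t = broadcast_tensor("model.norm.weight")
assert t.shape == (16,)
dist.destroy_process_group()
```

The shape handshake is what makes the "all or none ranks in the mp_group should reach here" comments in the helpers hold: the skip decision is taken collectively, so no rank blocks on a broadcast its peers never enter.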
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), (\n        f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    )\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings\n    Returns:\n        state_dict (dict):\n      
      The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].model.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == 
src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = 
torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # 
Embeddings\n    # -------------------\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                _broadcast_tensor(\n                    gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 
1,\n                )\n                _broadcast_tensor(\n                    gpt_model_module.reward_head.weight\n                    if pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_head\", None) is not None\n                    else None,\n                    \"reward_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.lm_head, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
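`_broadcast_tp_shard_tensor_gate_up` in the saver above is the exact inverse of the loaders' gate/up fusion: the loader interleaves each TP rank's gate slice and up slice into one `gate_up_proj` block, and the saver slices them back apart. A toy round-trip check of that layout, with plain tensors, illustrative toy sizes, and no distributed setup:

```python
import torch

intermediate_size, hidden_size, tp_size = 8, 4, 2
gate = torch.randn(intermediate_size, hidden_size)
up = torch.randn(intermediate_size, hidden_size)

# fuse (loader side): rank i's block is [gate slice i; up slice i]
isz_tp = intermediate_size // tp_size
fused = torch.empty(intermediate_size * 2, hidden_size)
for i in range(tp_size):
    fused[isz_tp * 2 * i : isz_tp * 2 * (i + 1)] = torch.cat(
        [gate[i * isz_tp : (i + 1) * isz_tp], up[i * isz_tp : (i + 1) * isz_tp]], dim=0
    )

# split (saver side): undo the interleaving and re-concatenate full gate / up
gate_parts, up_parts = [], []
for i in range(tp_size):
    chunk = fused[isz_tp * 2 * i : isz_tp * 2 * (i + 1)]
    gate_parts.append(chunk[:isz_tp])
    up_parts.append(chunk[isz_tp:])

assert torch.equal(torch.cat(gate_parts), gate)
assert torch.equal(torch.cat(up_parts), up)
```

The interleaving matters: chunking a naive `[gate; up]` concatenation by `tp_size` would hand rank 0 only gate rows and the last rank only up rows, instead of giving every rank its own `[gate_i; up_i]` pair.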
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelQwen2Attention\nfrom .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n__all__ = [\n    \"ParallelQwen2Attention\",\n    \"ParallelQwen2DecoderLayer\",\n    \"ParallelQwen2DecoderLayerRmPad\",\n    \"ParallelQwen2MLP\",\n    \"ParallelQwen2RMSNorm\",\n]\n"
  },
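The fused QKV layout is the contract between the checkpoint utilities above and the attention layer defined next: each TP rank's chunk stores its query, key, and value shards contiguously, which is why `ParallelQwen2Attention.forward` can recover them with `qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)`. A toy sketch of the packing and the per-rank split for the `num_key_value_heads >= tp_size` case; the sizes are illustrative and no Megatron state is involved.

```python
import torch

hidden_size, num_heads, num_kv_heads, tp_size = 16, 4, 2, 2
head_dim = hidden_size // num_heads
q = torch.randn(hidden_size, hidden_size)
k = torch.randn(num_kv_heads * head_dim, hidden_size)
v = torch.randn(num_kv_heads * head_dim, hidden_size)

q_tp = hidden_size // tp_size                # per-rank query rows
kv_tp = num_kv_heads * head_dim // tp_size   # per-rank key (and value) rows
total = q_tp + 2 * kv_tp

# pack (loader side): rank i's chunk is [q_i; k_i; v_i], stacked rank by rank
fused = torch.cat(
    [torch.cat([q[i * q_tp : (i + 1) * q_tp],
                k[i * kv_tp : (i + 1) * kv_tp],
                v[i * kv_tp : (i + 1) * kv_tp]], dim=0) for i in range(tp_size)],
    dim=0,
)

# consume (attention side): each rank splits its chunk back into q/k/v shards
for i, chunk in enumerate(torch.chunk(fused, tp_size, dim=0)):
    q_i, k_i, v_i = chunk.split([q_tp, kv_tp, kv_tp], dim=0)
    assert torch.equal(q_i, q[i * q_tp : (i + 1) * q_tp])
    assert torch.equal(k_i, k[i * kv_tp : (i + 1) * kv_tp])
    assert torch.equal(v_i, v[i * kv_tp : (i + 1) * kv_tp])
```

When `num_key_value_heads < tp_size`, the loaders instead replicate the same KV slice across several ranks, and the saver's `i * config.num_key_value_heads % tp_size == 0` check keeps only one copy when merging back.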
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional\n\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers.utils import is_flash_attn_2_available\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401\n\nimport torch\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass Qwen2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelQwen2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, (\n            f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        )\n        assert self.num_key_value_heads % tp_size == 0, (\n            f\"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=\"\n            f\"{self.num_key_value_heads}, tp_size={tp_size}\"\n        )\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and \"\n                f\"`num_heads`: {self.num_heads}).\"\n            )\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            # bias=config.attention_bias,\n            bias=True,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            # 
bias=config.attention_bias,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        self.rotary_emb = Qwen2RotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, \"\n                f\"but is {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, \"\n                f\"but is {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # 
(batch_size, seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin should be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(\n        q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    k_embed = apply_rotary_emb(\n        k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    return q_embed, k_embed\n\n\nclass ParallelQwen2AttentionRmPad(ParallelQwen2Attention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split(\n            [self.q_size, self.k_size, self.v_size], dim=-1\n        )  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # flash_attn_varlen_func requires the input to have the shape\n        # (total_nnz, num_heads_per_tp, head_dim)\n        # therefore we just reshape the unpadded qkv projections\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(\n            query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch\n        )\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin,\n        # position_ids, indices,\n\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states get silently cast to float32. Hence, we need to\n        # cast them back to float16 just to be sure everything works as expected.\n        # This might slow down training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. (Qwen2RMSNorm handles it correctly)\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
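  {
    "path": "verl_distillation/docs/sketches/rope_gqa_sketch.py",
    "content": "\"\"\"Illustrative sketch, NOT part of the original verl sources: a minimal,\ndependency-free check of two building blocks used in parallel_attention.py.\nIt verifies that repeat_kv matches torch.repeat_interleave (the equivalence its\ndocstring claims) and that rotate_half-based RoPE application preserves the\nnorm of every head vector, since RoPE is a pure rotation. The file path and all\nconstants here are hypothetical.\"\"\"\n\nimport torch\n\n\ndef rotate_half(x):\n    # Same as parallel_attention.py: swap the two halves and negate the second.\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    # Same as parallel_attention.py: expand each KV head n_rep times along dim 1.\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nif __name__ == \"__main__\":\n    kv = torch.randn(2, 4, 8, 16)  # (batch, num_kv_heads, seqlen, head_dim)\n    assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))\n\n    # Build the cos/sin caches the same way Qwen2RotaryEmbedding does (base 10000).\n    dim, seqlen = 16, 8\n    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))\n    freqs = torch.einsum(\"i,j->ij\", torch.arange(seqlen).float(), inv_freq)\n    emb = torch.cat((freqs, freqs), dim=-1)\n    q = torch.randn(2, 4, seqlen, dim)\n    q_rot = (q * emb.cos()) + (rotate_half(q) * emb.sin())\n    # A rotation cannot change vector norms.\n    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)\n    print(\"repeat_kv and RoPE sanity checks passed\")\n"
  },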
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n\nclass ParallelQwen2DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelQwen2DecoderLayerRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.hidden_size = config.hidden_size\n        self.layer_idx = layer_idx\n        self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\n\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n"
  },
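  {
    "path": "verl_distillation/docs/sketches/qkv_packing_sketch.py",
    "content": "\"\"\"Illustrative sketch, NOT part of the original verl sources: the weight\npacking behind QKVParallelLinear in parallel_linear.py, reproduced with a plain\nnn.Linear for the tp_size=1 case. Q, K and V share a single matmul of output\nwidth (num_heads + 2 * num_key_value_heads) * head_dim and are recovered with\none split, exactly as ParallelQwen2Attention.forward does. All sizes here are\nhypothetical.\"\"\"\n\nimport torch\nfrom torch import nn\n\nhidden_size, num_heads, num_key_value_heads, head_dim = 64, 8, 2, 8\n\n# Fused projection: one weight holds [q | k | v] stacked along the output dim.\nqkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_key_value_heads) * head_dim, bias=True)\n\nx = torch.randn(3, 5, hidden_size)\nq_size = num_heads * head_dim  # 64\nkv_size = num_key_value_heads * head_dim  # 16\nquery, key, value = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)\nassert query.shape == (3, 5, q_size)\nassert key.shape == (3, 5, kv_size) and value.shape == (3, 5, kv_size)\n"
  },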
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelQwen2MLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
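  {
    "path": "verl_distillation/docs/sketches/gate_up_mlp_sketch.py",
    "content": "\"\"\"Illustrative sketch, NOT part of the original verl sources: the fused\ngate/up projection of ParallelQwen2MLP for the tp_size=1 case, using plain\nnn.Linear. One matmul produces both SwiGLU operands, which are split and\ncombined as down_proj(act(gate) * up); Qwen2 configs use \"silu\" as hidden_act.\nAll sizes here are hypothetical.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nhidden_size, intermediate_size = 32, 128\ngate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # [gate | up]\ndown_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n\nx = torch.randn(4, hidden_size)\ngate, up = gate_up_proj(x).split(intermediate_size, dim=-1)\nout = down_proj(F.silu(gate) * up)\nassert out.shape == (4, hidden_size)\n"
  },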
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelQwen2RMSNorm(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
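  {
    "path": "verl_distillation/docs/sketches/rmsnorm_reference_sketch.py",
    "content": "\"\"\"Illustrative sketch, NOT part of the original verl sources: an eager-mode\nreference for the apex fused_rms_norm_affine call in parallel_rmsnorm.py.\nRMSNorm rescales by the root mean square of the features (no mean subtraction,\nunlike LayerNorm) and applies a learned elementwise weight. The function name\nand sizes here are hypothetical.\"\"\"\n\nimport torch\n\n\ndef rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:\n    input_dtype = hidden_states.dtype\n    hidden_states = hidden_states.float()  # compute in fp32 for stability\n    variance = hidden_states.pow(2).mean(-1, keepdim=True)\n    hidden_states = hidden_states * torch.rsqrt(variance + eps)\n    return weight * hidden_states.to(input_dtype)\n\n\nx = torch.randn(2, 7, 16)\nout = rms_norm_reference(x, torch.ones(16))\n# After RMS normalization each 16-dim vector has RMS ~1, hence norm ~sqrt(16) = 4.\nassert torch.allclose(out.norm(dim=-1), torch.full((2, 7), 4.0), atol=1e-2)\n"
  },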
  {
    "path": "verl_distillation/verl/models/qwen2/megatron/modeling_qwen2_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2 model.\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.qwen2.configuration_qwen2 import Qwen2Config\nfrom transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast\n\nfrom verl.utils.device import get_device_name\nfrom verl.utils.megatron import sequence_parallel as sp_utils\nfrom verl.utils.megatron import tensor_parallel as tp_utils\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from Qwen2 pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelQwen2Model(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelQwen2DecoderLayer(config, megatron_config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLM(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.model = ParallelQwen2Model(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Returns:\n            CausalLMOutputWithPast whose `logits` have shape `(batch_size, sequence_length, vocab_size)`.\n        \"\"\"\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401, E402\n\n\nclass ParallelQwen2ModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config: Qwen2Config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Returns:\n            CausalLMOutputWithPast whose `logits` have shape `(batch_size, sequence_length, vocab_size)`.\n        \"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            total_nnz = cu_seqlens[-1]\n            logits = logits[:total_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(\n            logits, indices, batch_size, seqlen=sequence_length\n        )  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as 
sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelQwen2ModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - This model only contains layer in this pp stage and vpp chunk\n    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n                num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n            )\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def 
set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron\n            # so we need to handle it here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # self.hidden_states should be passed by Megatron\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: Qwen2Config,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPadPP(\n            config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process\n        )\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n        if pre_process or post_process:\n            self.setup_embeddings_and_output_layer()\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,\n            **column_kwargs,\n        )\n\n    def setup_embeddings_and_output_layer(self) -> None:\n        \"\"\"Sets up embedding layer in first stage and output layer in last stage.\n\n        This function initializes word embeddings in the final stage when we are\n        using pipeline parallelism and sharing word embeddings, and sets up param\n        attributes on the embedding and output layers.\n        \"\"\"\n        # Set `is_embedding_or_output_parameter` attribute.\n        if self.pre_process:\n            self.model.embed_tokens.weight.is_embedding_or_output_parameter = True\n        if self.post_process and self.lm_head.weight is not None:\n            self.lm_head.weight.is_embedding_or_output_parameter = True\n\n        if not self.share_embeddings_and_output_weights:\n            return\n\n        if parallel_state.get_pipeline_model_parallel_world_size() == 1:\n            # Zero out wgrad if sharing embeddings between two layers on same\n            # pipeline stage to make sure grad accumulation into main_grad is\n            # correct and does not include garbage values (e.g., from torch.empty).\n            self.shared_embedding_or_output_weight().zero_out_wgrad = True\n            return\n\n        if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:\n            self.shared_embedding_or_output_weight().shared_embedding = True\n\n        if self.post_process and not self.pre_process:\n            assert not parallel_state.is_pipeline_first_stage()\n            # set word_embeddings weights to 0 here, then copy first\n            # stage's weights using all_reduce below.\n            self.lm_head.weight.data.fill_(0)\n            self.lm_head.weight.shared = True\n            self.lm_head.weight.shared_embedding = True\n\n        if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group():\n            weight = self.shared_embedding_or_output_weight()\n            weight.data = weight.data.to(get_device_name())\n            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())\n\n    def shared_embedding_or_output_weight(self) -> torch.Tensor:\n        if self.pre_process:\n            return self.model.embed_tokens.weight\n        elif self.post_process:\n            return self.lm_head.weight\n        return None\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = '\n        # f'{self.config.vocab_size}') # [4, 32, 4096]\n        output_weight = None\n        if self.share_embeddings_and_output_weights:\n            output_weight = self.shared_embedding_or_output_weight()\n        logits = self.lm_head(hidden_states, weight=output_weight)[0]\n        # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Returns:\n            On the last pipeline stage, a CausalLMOutputWithPast whose `logits` have shape\n            `(batch_size, sequence_length, vocab_size)`; on earlier stages, the hidden states\n            to be sent to the next stage.\n        \"\"\"\n\n        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.\n        # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension # torch.Size([8, 32, 16])\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                total_nnz = cu_seqlens[-1]\n                logits = logits[:total_nnz]  # (total_nnz_padded)\n            # add removed padding back. 
If input is already rmpad, we let the caller pad_input\n            logits = pad_input(\n                logits, indices, batch_size, seqlen=sequence_length\n            )  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n            return output\n        else:\n            return output\n"
  },
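  {
    "path": "verl_distillation/docs/sketches/rmpad_roundtrip_sketch.py",
    "content": "\"\"\"Illustrative sketch, NOT part of the original verl sources: a pure-PyTorch\nmirror of the flash_attn.bert_padding unpad_input/pad_input bookkeeping used by\nthe RmPad models above. Real training imports the flash-attn helpers; this\nhypothetical version only demonstrates the indices/cu_seqlens semantics on CPU.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef unpad(x: torch.Tensor, attention_mask: torch.Tensor):\n    # Flatten (batch, seqlen) and keep only the non-pad positions.\n    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))\n    return x.reshape(-1, *x.shape[2:])[indices], indices, cu_seqlens, int(seqlens.max())\n\n\ndef pad(x_unpad: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:\n    out = torch.zeros(batch * seqlen, *x_unpad.shape[1:], dtype=x_unpad.dtype)\n    out[indices] = x_unpad\n    return out.reshape(batch, seqlen, *x_unpad.shape[1:])\n\n\nmask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])\nx = torch.randn(2, 4, 8)\nx_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad(x, mask)\nassert x_unpad.shape == (5, 8)\nassert cu_seqlens.tolist() == [0, 3, 5] and max_seqlen_in_batch == 3\nx_pad = pad(x_unpad, indices, batch=2, seqlen=4)\n# Padding positions come back as zeros; real positions round-trip exactly.\nassert torch.equal(x_pad * mask.unsqueeze(-1), x_pad)\nassert torch.equal(x_pad[0, :3], x[0, :3])\n"
  },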
  {
    "path": "verl_distillation/verl/models/registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nfrom typing import Optional\n\nimport torch.nn as nn\n\n# Supported models in Megatron-LM\n# Architecture -> (module, class).\n_MODELS = {\n    \"LlamaForCausalLM\": (\n        \"llama\",\n        (\"ParallelLlamaForCausalLMRmPadPP\", \"ParallelLlamaForValueRmPadPP\", \"ParallelLlamaForCausalLMRmPad\"),\n    ),\n    \"Qwen2ForCausalLM\": (\n        \"qwen2\",\n        (\"ParallelQwen2ForCausalLMRmPadPP\", \"ParallelQwen2ForValueRmPadPP\", \"ParallelQwen2ForCausalLMRmPad\"),\n    ),\n    \"MistralForCausalLM\": (\n        \"mistral\",\n        (\"ParallelMistralForCausalLMRmPadPP\", \"ParallelMistralForValueRmPadPP\", \"ParallelMistralForCausalLMRmPad\"),\n    ),\n    \"ApertusForCausalLM\": (\n        \"apertus\",\n        (\"ParallelApertusForCausalLMRmPadPP\", \"ParallelApertusForValueRmPadPP\", \"ParallelApertusForCausalLMRmPad\"),\n    ),\n}\n\n\n# return model class\nclass ModelRegistry:\n    @staticmethod\n    def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]:\n        if model_arch not in _MODELS:\n            return None\n\n        megatron = \"megatron\"\n\n        module_name, model_cls_name = _MODELS[model_arch]\n        if not value:  # actor/ref\n            model_cls_name = model_cls_name[0]\n        elif value:  # critic/rm\n            model_cls_name = model_cls_name[1]\n\n        module = importlib.import_module(f\"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron\")\n        return getattr(module, model_cls_name, None)\n\n    @staticmethod\n    def get_supported_archs() -> list[str]:\n        return list(_MODELS.keys())\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/apertus.py",
    "content": "# Copyright 2025 The SwissAI Initiative\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\nfrom typing import Callable, Optional\n\nimport torch\n\nif sys.version_info >= (3, 11):\n    pass\nelse:\n    pass\n\nfrom transformers.cache_utils import Cache\nfrom transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb\nfrom transformers.utils import logging\n\n# Import compatibility wrapper for flash_attn_supports_top_left_mask\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef apertus_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    Key differences from Llama attention:\n    - QK normalization applied after Q/K projections\n\n        NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n    from transformers.models.apertus.modeling_apertus import eager_attention_forward\n\n    bsz, q_len, _ = hidden_states.shape\n\n    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    query_states = self.q_norm(query_states)\n    key_states = self.k_norm(key_states)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    
attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once(\n                \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. \"\n                \"Falling back to eager attention. This warning can be removed using the argument \"\n                '`attn_implementation=\"eager\"` when loading the model.'\n            )\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/dense_common.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Union\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_base_model(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Cache] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n) -> CausalLMOutputWithPast:\n    r\"\"\"\n    Copy paste LLaMa's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py\n\n    This function should be generic enough for all pure text models.\n    ```\"\"\"\n\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = (\n        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    return outputs\n\n\ndef forward_with_torch_backend(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Union[\"Cache\", list[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: int | torch.Tensor = 0,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | CausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        
position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_torch_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return CausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef forward_with_triton_backend(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Union[\"Cache\", list[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: int | torch.Tensor = 0,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | CausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_triton_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n\n    return CausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/glm4v.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport itertools\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check\nfrom transformers.models.glm4v.modeling_glm4v import (\n    Glm4vCausalLMOutputWithPast,\n    Glm4vForConditionalGeneration,\n    Glm4vTextAttention,\n)\nfrom transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10\n\nfrom verl.utils.device import is_npu_available\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_group,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n\n    _flash_supports_window_size = \"window_size\" in inspect.signature(flash_attn_func).parameters\n    _flash_supports_deterministic = \"deterministic\" in inspect.signature(flash_attn_func).parameters\n    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\nif is_npu_available:\n    from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func\n    from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func\n    from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask\n\n    _flash_supports_window_size = \"window_size\" in inspect.signature(flash_attn_func).parameters\n    _flash_supports_deterministic = \"deterministic\" in inspect.signature(flash_attn_func).parameters\n    _flash_use_top_left_mask = flash_attn_supports_top_left_mask()\n\n_flash_deterministic_enabled = os.getenv(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n\n\ndef get_rope_index(\n    processor,\n    input_ids: torch.Tensor,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Gets the position ids for GLM4V in padding-free format.\n    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.\n    \"\"\"\n    spatial_merge_size = processor.image_processor.merge_size\n    image_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|image|>\")\n    video_start_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|begin_of_video|>\")\n    video_end_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|end_of_video|>\")\n\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        if attention_mask is None:\n            
attention_mask = torch.ones_like(input_ids)\n\n        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)\n        image_index, video_index = 0, 0\n        video_group_index = 0\n\n        input_ids_filtered = input_ids[attention_mask == 1]\n        input_tokens = input_ids_filtered.tolist()\n\n        input_token_type = []\n        video_check_flg = False\n        for token in input_tokens:\n            if token == video_start_token_id:\n                video_check_flg = True\n            elif token == video_end_token_id:\n                video_check_flg = False\n\n            if token == image_token_id and not video_check_flg:\n                input_token_type.append(\"image\")\n            elif token == image_token_id and video_check_flg:\n                input_token_type.append(\"video\")\n            else:\n                input_token_type.append(\"text\")\n\n        input_type_group = []\n        for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):\n            group = list(group)\n            start_index = group[0][0]\n            end_index = group[-1][0] + 1\n            input_type_group.append((key, start_index, end_index))\n\n        llm_pos_ids_list = []\n        video_frame_num = 1\n\n        for modality_type, start_idx, end_idx in input_type_group:\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n\n            if modality_type == \"image\":\n                t, h, w = (\n                    image_grid_thw[image_index][0],\n                    image_grid_thw[image_index][1],\n                    image_grid_thw[image_index][2],\n                )\n                llm_grid_t, llm_grid_h, llm_grid_w = (\n                    t.item(),\n                    h.item() // spatial_merge_size,\n                    w.item() // spatial_merge_size,\n                )\n\n                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()\n                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)\n\n                image_index += 1\n                video_frame_num = 1\n\n            elif modality_type == \"video\":\n                t, h, w = (\n                    video_frame_num,\n                    video_grid_thw[video_index][1],\n                    video_grid_thw[video_index][2],\n                )\n\n                llm_grid_t, llm_grid_h, llm_grid_w = (\n                    t,\n                    h.item() // spatial_merge_size,\n                    w.item() // spatial_merge_size,\n                )\n\n                for t_idx in range(llm_grid_t):\n                    t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()\n                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()\n                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()\n                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)\n\n                video_group_index += 1\n\n                if video_group_index >= video_grid_thw[video_index][0]:\n                    video_index += 1\n                    video_group_index = 0\n\n          
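      # advance the frame index so the next video segment gets a new temporal position\n          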
      video_frame_num += 1\n\n            else:\n                text_len = end_idx - start_idx\n                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n                video_frame_num = 1\n\n        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)\n        else:\n            position_ids = torch.arange(input_ids.shape[0], device=input_ids.device).view(1, -1).expand(3, -1)\n\n    return position_ids\n\n\ndef prepare_fa2_from_position_ids(\n    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor\n):\n    assert position_ids.ndim == 2  # (batch_size, seq_length)\n    query = query.contiguous().view(-1, query.size(-2), query.size(-1))\n    key = key.contiguous().view(-1, key.size(-2), key.size(-1))\n    value = value.contiguous().view(-1, value.size(-2), value.size(-1))\n    position_ids = position_ids.view(-1)\n    cu_seqlens = torch.cat(\n        (\n            (position_ids == 0).nonzero().view(-1).to(torch.int32),\n            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),\n        )\n    )\n    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope\n    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))\n\n\ndef _custom_flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    query_length: int,\n    is_causal: bool = True,\n    position_ids: Optional[torch.Tensor] = None,\n    use_top_left_mask: bool = False,\n    deterministic: Optional[bool] = None,\n    **kwargs,\n):\n    \"\"\"\n    Patches flash attention forward to handle 3D position ids in mrope. 
(3, batch_size, seq_length)\n    \"\"\"\n    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).\n    flash_kwargs = {}\n\n    if _flash_supports_deterministic:\n        flash_kwargs[\"deterministic\"] = deterministic if deterministic is not None else _flash_deterministic_enabled\n\n    if kwargs.get(\"softcap\") is not None:\n        flash_kwargs[\"softcap\"] = kwargs.pop(\"softcap\")\n\n    query_states, key_states, value_states = fa_peft_integration_check(\n        query_states, key_states, value_states, target_dtype=torch.bfloat16\n    )\n\n    if position_ids is not None:\n        assert position_ids.ndim == 2  # (batch_size, seq_length / sp_size)\n\n    sp_size = get_ulysses_sequence_parallel_world_size()\n    if sp_size > 1:\n        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)\n        validate_ulysses_config(query_states.size(2), sp_size)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)\n        # all_gather fills position_ids_lst in place; its return value is not the gathered tensor\n        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]\n        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.cat(position_ids_lst, dim=-1)  # (batch_size, seq_length)\n\n    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():\n        batch_size = query_states.size(0)\n        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(\n            query_states, key_states, value_states, position_ids\n        )\n        attn_output = flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=max_seqlen_k,\n            dropout_p=kwargs.pop(\"dropout\", 0.0),\n            softmax_scale=kwargs.pop(\"softmax_scale\", None),\n            causal=is_causal,\n            **flash_kwargs,\n        )\n        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))\n    else:\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            query_length,\n            is_causal=is_causal,\n            use_top_left_mask=use_top_left_mask,\n            deterministic=deterministic,\n            **kwargs,\n        )  # do not pass position_ids to old flash_attention_forward\n\n    if sp_size > 1:\n        # (batch_size, seq_length, num_head, head_size)\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    return attn_output\n\n\ndef glm4v_attn_forward(\n    self: \"Glm4vTextAttention\",\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> tuple[torch.Tensor, None, None]:\n    from transformers.models.glm4v.modeling_glm4v import apply_multimodal_rotary_pos_emb, repeat_kv\n\n    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size\n    query_states = self.q_proj(hidden_states)  # 
(batch_size, seq_length / sp_size, num_heads * head_size)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    # Because the input can be padded, the absolute sequence length depends on the max position id.\n    cos, sin = position_embeddings\n    query_states, key_states = apply_multimodal_rotary_pos_emb(\n        query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"]\n    )\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # This is before the transpose\n    q_len = query_states.shape[2]\n\n    # FA2 uses non-transposed inputs\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    attn_output = _custom_flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        query_length=q_len,\n        is_causal=getattr(self, \"is_causal\", True),\n        dropout=dropout_rate,\n        use_top_left_mask=_flash_use_top_left_mask,\n        position_ids=position_ids,  # important: pass position ids\n    )  # (batch_size, seq_length / sp_size, num_head, head_size)\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, None\n\n\ndef _get_input_embeds(\n    model: \"Glm4vForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n):\n    inputs_embeds = model.get_input_embeddings()(input_ids)\n    if pixel_values is not None:\n        pixel_values = pixel_values.type(model.visual.dtype)\n        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()\n        n_image_features = image_embeds.shape[0]\n        if n_image_tokens != n_image_features:\n            raise ValueError(\n                f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}\"\n            )\n\n        mask = input_ids == model.config.image_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        image_mask = mask_expanded.to(inputs_embeds.device)\n\n        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n    if pixel_values_videos is not None:\n        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)\n        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)\n        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()\n        n_video_features = video_embeds.shape[0]\n        if 
n_video_tokens != n_video_features:\n            raise ValueError(\n                f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\"\n            )\n\n        mask = input_ids == model.config.video_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        video_mask = mask_expanded.to(inputs_embeds.device)\n\n        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n    if pixel_values is None and pixel_values_videos is None:  # handle mixed text-image data\n        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)\n        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)\n        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        inputs_embeds += 0.0 * image_embeds.mean()\n\n    if attention_mask is not None:\n        attention_mask = attention_mask.to(inputs_embeds.device)\n\n    return inputs_embeds, attention_mask\n\n\ndef process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:\n    if position_ids.ndim != 3 or position_ids.size(0) != 4:\n        # we concat the text position ids with the 3D vision position ids by default\n        # see https://github.com/huggingface/transformers/pull/39447\n        raise ValueError(\"position_ids should be a 3D tensor of shape (4, batch_size, seq_length).\")\n\n    return position_ids\n\n\n@dataclass\nclass Glm4vCausalLMOutputForPPO(Glm4vCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef glm4v_base_forward(\n    self: \"Glm4vForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    **kwargs,\n):\n    kwargs[\"inputs_embeds\"], kwargs[\"attention_mask\"] = _get_input_embeds(\n        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw\n    )  # avoid lora module having multiple keyword arguments\n    return self.language_model(\n        input_ids=None,\n        **kwargs,\n    )\n\n\ndef glm4v_forward(\n    self: \"Glm4vForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    **kwargs,\n):\n    return self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=process_position_ids(position_ids),\n        pixel_values=pixel_values,\n        pixel_values_videos=pixel_values_videos,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=video_grid_thw,\n        **kwargs,\n    )\n\n\ndef forward_with_normal_backend(\n    self: Glm4vForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    
temperature: float = 1.0,\n    **kwargs,\n) -> \"Glm4vCausalLMOutputWithPast\":\n    outputs = glm4v_forward(self, input_ids, **kwargs)\n    hidden_states = outputs[0]\n    logits = self.lm_head(hidden_states)\n\n    return Glm4vCausalLMOutputWithPast(\n        logits=logits,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_torch_backend(\n    self: Glm4vForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> tuple | Glm4vCausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = glm4v_forward(self, input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n    return Glm4vCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_triton_backend(\n    self: Glm4vForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> tuple | Glm4vCausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = glm4v_forward(self, input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n    return Glm4vCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/kimi_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\n\nfrom verl.models.transformers.monkey_patch import is_transformers_version_in_range\n\n# Import compatibility wrapper for flash_attn_supports_top_left_mask\nfrom verl.utils.transformers_compat import flash_attn_supports_top_left_mask\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef _ulysses_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n\n    if self.q_lora_rank is None:\n        q = self.q_proj(hidden_states)\n    else:\n        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)\n    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n    kv = (\n        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n        .transpose(1, 2)\n    )\n\n    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n\n    # patch\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads\n        k_pe = repeat_kv(k_pe, ulysses_sp_size)  # to keep heads=1 after a2a\n        k_nope = repeat_kv(k_nope, num_key_value_groups)\n        value_states = repeat_kv(value_states, num_key_value_groups)\n        q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1)\n        k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1)\n        k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, 
head_dim=1)\n        # (batch_size, num_head / sp_size, seq_length, head_size)\n        full_q_len = q.size(2)  # full_q_len = seq_length\n\n    else:\n        full_q_len = q_len\n\n    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n    cos, sin = self.rotary_emb(value_states, seq_len=full_q_len)\n    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n    query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)\n    query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n    query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n    key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)\n    key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n    key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n    if self.q_head_dim != self.v_head_dim:\n        value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n    # TODO: These transpose are quite inefficient but Flash Attention requires the layout\n    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        dropout=dropout_rate,\n        sliding_window=None,\n        is_causal=self.is_causal,\n        use_top_left_mask=flash_attn_supports_top_left_mask(),\n        position_ids=position_ids,  # important: pass position ids\n        softmax_scale=self.softmax_scale,\n    )\n\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    if self.q_head_dim != self.v_head_dim:\n        attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if is_transformers_version_in_range(min_version=\"4.53.0\"):\n        return attn_output, None\n    else:\n        return attn_output, None, None\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/llama.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\nfrom typing import Callable, Optional\n\nimport torch\n\nif sys.version_info >= (3, 11):\n    pass\nelse:\n    pass\n\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb\nfrom transformers.utils import logging\n\n# Import compatibility wrapper for flash_attn_supports_top_left_mask\nfrom verl.utils.transformers_compat import flash_attn_supports_top_left_mask\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef llama_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].\n    \"\"\"\n    output_attentions = False\n\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    # trade off: repeat first and then all to all\n    # key_states = repeat_kv(key_states, self.num_key_value_groups)\n    # value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = 
gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # TODO: These transpose are quite inefficient but Flash Attention requires the layout\n    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n    # therefore the input hidden states gets silently casted in float32. Hence, we need\n    # cast them back in the correct dtype just to be sure everything works as expected.\n    # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n    # in fp32. (LlamaRMSNorm handles it correctly)\n\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(\n            f\"The input hidden states seems to be silently casted in float32, this might be related to \"\n            f\"the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the \"\n            f\"input in {target_dtype}.\"\n        )\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=getattr(self, \"sliding_window\", None),\n        use_top_left_mask=flash_attn_supports_top_left_mask(),\n        is_causal=self.is_causal,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if not output_attentions:\n        attn_weights = None\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef llama_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n    from transformers.models.llama.modeling_llama import eager_attention_forward\n\n    bsz, q_len, _ = hidden_states.shape\n\n    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once(\n                \"`torch.nn.functional.scaled_dot_product_attention` 
does not support `output_attentions=True`. \"\n                \"Falling back to eager attention. This warning can be removed using the argument \"\n                '`attn_implementation=\"eager\"` when loading the model.'\n            )\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/monkey_patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nApply monkey-patch function to models\n\"\"\"\n\nimport sys\nfrom types import SimpleNamespace\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_utils import PreTrainedModel\n\nfrom verl.utils.import_utils import is_trl_available\nfrom verl.utils.transformers_compat import is_transformers_version_in_range\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_group,\n    get_ulysses_sequence_parallel_world_size,\n    slice_input_tensor,\n)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,\n    seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)\n    \"\"\"\n    batch, slen, num_key_value_heads, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)\n    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)\n\n\ndef _ulysses_flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    query_length: int,\n    *args,\n    position_ids: Optional[torch.Tensor] = None,\n    **kwargs,\n):\n    \"\"\"Insert all-to-all before and after flash attention.\n    DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509\n\n    For transformers>=4.55, the flash attention api has changed,\n    we need to pass the query_length after doing ulysses all2all.\n    See https://github.com/huggingface/transformers/issues/40399\n\n    Args:\n        query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)\n        key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size)\n\n    Returns:\n        torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)\n\n    \"\"\"\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        assert position_ids is not None, \"position_ids is required for Ulysses sequence parallelism\"\n\n        # NOTE: repeat kv heads to be divided by sequence parallel. 
Instead of repeating nheads_q//nheads_k,\n        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.\n        # For example:\n        # - nheads_k=4, sp=8, repeats=2\n        # - nheads_k=8, sp=8, repeats=1\n        # - nheads_k=16, sp=8, repeats=1\n        repeats = max(ulysses_sp_size // key_states.size(2), 1)\n        key_states = repeat_kv(key_states, repeats)\n        value_states = repeat_kv(value_states, repeats)\n\n        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)\n\n        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate\n        # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.\n        # https://github.com/huggingface/transformers/pull/33932\n\n        # (bsz, seq_len/n) -> (bsz, seq_len)\n        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]\n        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.concat(position_ids_list, dim=-1)\n\n    # (bsz, seq_len, n_head/n, head_dim)\n    query_length = query_states.size(1)\n    attn_output = _flash_attention_forward(\n        query_states, key_states, value_states, attention_mask, query_length, *args, position_ids=position_ids, **kwargs\n    )\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n\n    return attn_output\n\n\ndef patch_vlm_for_ulysses_input_slicing(model_class: type):\n    \"\"\"\n    Applies a monkey patch to the forward method of a given model class\n    to enable Ulysses sequence parallelism input slicing.\n    \"\"\"\n\n    def _create_ulysses_wrapped_decoder_forward(original_forward):\n        def ulysses_wrapped_decoder_forward(self, *args, **kwargs):\n            inputs_embeds = kwargs.get(\"inputs_embeds\")\n            position_ids = kwargs.get(\"position_ids\")\n            visual_pos_masks = kwargs.get(\"visual_pos_masks\")\n            deepstack_visual_embeds = kwargs.get(\"deepstack_visual_embeds\")\n            call_kwargs = kwargs.copy()\n\n            current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n            slice_now = (\n                inputs_embeds is not None\n                and current_ulysses_sp_size > 1\n                and getattr(self, \"_needs_initial_slice\", True)\n            )\n            if slice_now:\n                call_kwargs[\"inputs_embeds\"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)\n                call_kwargs[\"position_ids\"] = slice_input_tensor(position_ids, dim=-1, padding=False)\n                # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models\n                if visual_pos_masks is not None:\n                    original_visual_mask = visual_pos_masks\n                    sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)\n                    call_kwargs[\"visual_pos_masks\"] = sliced_visual_mask\n\n                    if deepstack_visual_embeds is 
not None:\n                        sliced_embeds = []\n\n                        num_visual_before = original_visual_mask.sum().item()\n                        num_visual_in_shard = sliced_visual_mask.sum().item()\n\n                        if num_visual_in_shard > 0 and num_visual_before > 0:\n                            # Calculate which visual embeddings belong to this shard\n                            # We need to find the offset of visual tokens in this shard\n                            from verl.utils.ulysses import get_ulysses_sequence_parallel_rank\n\n                            rank = get_ulysses_sequence_parallel_rank()\n                            seq_len = original_visual_mask.shape[1]\n                            local_seq_len = seq_len // current_ulysses_sp_size\n                            start_idx = rank * local_seq_len\n                            end_idx = start_idx + local_seq_len\n\n                            # Get total visual tokens before and up to the end of the shard's sequence slice\n                            # This correctly handles batches by summing across all samples\n                            visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0\n                            visual_end = original_visual_mask[:, :end_idx].sum().item()\n\n                            # Slice each tensor in deepstack_visual_embeds\n                            for embed in deepstack_visual_embeds:\n                                sliced_embeds.append(embed[visual_start:visual_end])\n                        else:\n                            # No visual tokens in this shard, create empty tensors to maintain gradient flow\n                            for embed in deepstack_visual_embeds:\n                                sliced_embeds.append(embed[:0])\n                        call_kwargs[\"deepstack_visual_embeds\"] = sliced_embeds\n\n                self._needs_initial_slice = False\n            try:\n                return original_forward(self, *args, **call_kwargs)\n            finally:\n                if slice_now:\n                    self._needs_initial_slice = True\n\n        return ulysses_wrapped_decoder_forward\n\n    original_forward = model_class.forward\n    wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward)\n    model_class.forward = wrapped_forward\n    print(f\"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.\")\n\n\ndef patch_forward_with_backends(\n    model: PreTrainedModel,\n    use_fused_kernels: bool = False,\n    fused_kernels_backend: str = None,\n):\n    \"\"\"\n    Choose the forward function based on the model and backend.\n    Args:\n        model (PreTrainedModel): The model to apply the monkey patch.\n        use_fused_kernels (bool): Whether to use fused kernels.\n        fused_kernels_backend (str): The backend to use for fused kernels.\n    \"\"\"\n    if not use_fused_kernels or fused_kernels_backend not in [\"triton\", \"torch\"]:\n        print(\n            f\"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is \"\n            f\"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}\"\n        )\n        return\n\n    forward_with_torch_backend_function = model.__class__.forward\n    forward_with_triton_backend_function = model.__class__.forward\n    if model.config.model_type in [\"qwen2_5_vl\", \"qwen2_vl\"]:\n        from verl.models.transformers.qwen2_vl import forward_with_torch_backend, 
forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n    elif model.config.model_type in [\"qwen3_vl\", \"qwen3_vl_moe\"]:\n        from verl.models.transformers.qwen3_vl import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n    elif model.config.model_type == \"glm4v\":\n        from verl.models.transformers.glm4v import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n    else:\n        from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n\n    if fused_kernels_backend == \"triton\":\n        model.__class__.forward = forward_with_triton_backend_function\n        print(f\"Using Triton backend for fused kernels in {model.__class__.__name__}\")\n    elif fused_kernels_backend == \"torch\":\n        model.__class__.forward = forward_with_torch_backend_function\n        print(f\"Using Torch backend for fused kernels in {model.__class__.__name__}\")\n    else:\n        raise ValueError(f\"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.\")\n\n\ndef apply_monkey_patch(\n    model: PreTrainedModel,\n    ulysses_sp_size: int = 1,\n    use_remove_padding: bool = True,\n    use_fused_kernels: bool = False,\n    fused_kernels_backend: str = None,\n):\n    \"\"\"\n    Apply monkey patches to the model for Ulysses sequence parallelism and fused kernels.\n\n    At the end of this function the model's forward function is patched for fused kernels.\n    If the model does not support fused kernels, return immediately after patching.\n    \"\"\"\n\n    \"\"\"Replace _flash_attention_forward with _ulysses_flash_attention_forward\"\"\"\n    module = sys.modules[model.__module__]\n\n    try:\n        num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads\n    except AttributeError:\n        num_attention_heads, num_key_value_heads = (\n            model.config.text_config.num_attention_heads,\n            model.config.text_config.num_key_value_heads,\n        )\n\n    assert num_attention_heads % ulysses_sp_size == 0, (\n        f\"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}\"\n    )\n    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (\n        f\"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size \"\n        f\"{ulysses_sp_size} or vice versa. When ulysses_sp_size % num_key_value_heads == 0, \"\n        f\"kv heads are repeated to ensure correctness.\"\n    )\n\n    if is_trl_available():\n        from trl import AutoModelForCausalLMWithValueHead  # type: ignore\n\n        def state_dict(self, *args, **kwargs):\n            return torch.nn.Module.state_dict(self, *args, **kwargs)\n\n        AutoModelForCausalLMWithValueHead.state_dict = state_dict\n        print(\"Monkey patch state_dict in AutoModelForCausalLMWithValueHead. 
\")\n\n    # TODO: VLM models only, unify monkey patch to LLM models.\n    if model.config.model_type in [\"qwen2_5_vl\", \"qwen2_vl\"]:\n        # Step 1: patch model to support image-text mixed data\n        if is_transformers_version_in_range(min_version=\"4.52.0\"):\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n                Qwen2_5_VLForConditionalGeneration,\n                Qwen2_5_VLModel,\n                Qwen2_5_VLTextModel,\n            )\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import (\n                Qwen2VLForConditionalGeneration,\n                Qwen2VLModel,\n                Qwen2VLTextModel,\n            )\n        else:\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel\n\n            Qwen2_5_VLModel = SimpleNamespace(forward=None)\n            Qwen2VLModel = SimpleNamespace(forward=None)\n\n        from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward\n\n        Qwen2_5_VLModel.forward = qwen2_vl_base_forward\n        Qwen2VLModel.forward = qwen2_vl_base_forward\n        Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend\n        Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend\n        print(f\"Monkey patch {model.__class__.__name__} model forward\")\n\n        # Step 2: patch attention to support ulysses parallelism\n        if is_transformers_version_in_range(min_version=\"4.54.0\"):\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention\n        elif is_transformers_version_in_range(min_version=\"4.53.0\"):\n            raise RuntimeError(\"Transformers 4.53.* is bugged. 
Use transformers 4.54.0 or later.\")\n        else:\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n                Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,\n            )\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention\n\n        if use_remove_padding or ulysses_sp_size > 1:\n            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward\n\n            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward\n            Qwen2VLAttention.forward = qwen2_vl_attn_forward\n            print(f\"Monkey patch {model.__class__.__name__} attention layer\")\n\n        # Step 3: patch input for multimodal sequence parallelism\n        if ulysses_sp_size > 1:\n            patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)\n            patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)\n\n    elif model.config.model_type in [\"qwen3_vl\", \"qwen3_vl_moe\"]:\n        # Step 1: patch model to support image-text mixed data\n        from transformers.models.qwen3_vl.modeling_qwen3_vl import (\n            Qwen3VLForConditionalGeneration,\n            Qwen3VLModel,\n            Qwen3VLTextModel,\n        )\n        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (\n            Qwen3VLMoeForConditionalGeneration,\n            Qwen3VLMoeModel,\n            Qwen3VLMoeTextModel,\n        )\n\n        from verl.models.transformers.qwen3_vl import forward_with_normal_backend, qwen3_vl_base_forward\n\n        Qwen3VLModel.forward = qwen3_vl_base_forward\n        Qwen3VLMoeModel.forward = qwen3_vl_base_forward\n        Qwen3VLForConditionalGeneration.forward = forward_with_normal_backend\n        Qwen3VLMoeForConditionalGeneration.forward = forward_with_normal_backend\n        print(f\"Monkey patch {model.__class__.__name__} model forward\")\n\n        # Step 2: patch input for multimodal sequence parallelism\n        if ulysses_sp_size > 1:\n            patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel)\n            patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel)\n\n    elif model.config.model_type == \"glm4v\":\n        # Step 1: patch model to support image-text mixed data\n\n        from transformers.models.glm4v.modeling_glm4v import (\n            Glm4vForConditionalGeneration,\n            Glm4vModel,\n            Glm4vTextAttention,\n            Glm4vTextModel,\n        )\n\n        from verl.models.transformers.glm4v import forward_with_normal_backend, glm4v_base_forward\n\n        Glm4vModel.forward = glm4v_base_forward\n        Glm4vForConditionalGeneration.forward = forward_with_normal_backend\n        print(f\"Monkey patch {model.__class__.__name__} model forward\")\n\n        # Step 2: patch attention to support ulysses parallelism\n        if use_remove_padding or ulysses_sp_size > 1:\n            from verl.models.transformers.glm4v import glm4v_attn_forward\n\n            Glm4vTextAttention.forward = glm4v_attn_forward\n            print(f\"Monkey patch {model.__class__.__name__} attention layer\")\n\n        # Step 3: patch input for multimodal sequence parallelism\n        if ulysses_sp_size > 1:\n            patch_vlm_for_ulysses_input_slicing(Glm4vTextModel)\n\n    elif model.config.model_type == \"kimi_vl\":\n        if use_remove_padding or ulysses_sp_size > 1:\n            # TODO: Changes need to be made when transformers are adapted.\n            from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward\n\n            
module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward\n            print(\"Monkey patch FlashAttention2.forward in KimiVL\")\n\n        if ulysses_sp_size > 1:\n            patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)\n\n        if use_fused_kernels:\n            print(\"Fused kernels are not supported for KimiVL\")\n\n        return\n\n    if use_remove_padding or ulysses_sp_size > 1:\n        if hasattr(module, \"_flash_attention_forward\"):  # transformers <= 4.47.1 or legacy models\n            module._flash_attention_forward = _ulysses_flash_attention_forward\n            print(f\"Monkey patch _flash_attention_forward in {model.__module__}\")\n        else:\n            from transformers.integrations import flash_attention\n\n            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward\n            print(f\"Monkey patch _flash_attention_forward in {flash_attention.__name__}\")\n\n    patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)\n"
  },
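  {
    "path": "verl_distillation/docs/sketches/ulysses_alltoall_shapes_sketch.py",
    "content": "\"\"\"\nSingle-process sketch of the shape bookkeeping behind the Ulysses all-to-all in\nmonkey_patch.py. The real gather_seq_scatter_heads / gather_heads_scatter_seq in\nverl.utils.ulysses use torch.distributed all-to-all collectives; here each \"rank\" is just a\nlist entry, which is enough to show the sequence <-> head redistribution and its inverse.\nStandalone illustration; nothing here is imported by the training code.\n\"\"\"\n\nimport torch\n\n\ndef gather_seq_scatter_heads_local(shards, sp_size):\n    # shards[r]: rank r's shard of shape (bsz, seq_len / sp, n_heads, head_dim)\n    full = torch.cat(shards, dim=1)  # gather the sequence dimension\n    return list(full.chunk(sp_size, dim=2))  # scatter the head dimension instead\n\n\ndef gather_heads_scatter_seq_local(shards, sp_size):\n    # inverse redistribution, back to sequence shards with all heads\n    full = torch.cat(shards, dim=2)  # gather the head dimension\n    return list(full.chunk(sp_size, dim=1))  # scatter the sequence dimension\n\n\nif __name__ == \"__main__\":\n    sp_size, bsz, seq_len, n_heads, head_dim = 4, 2, 16, 8, 32\n    x = torch.randn(bsz, seq_len, n_heads, head_dim)\n    seq_shards = list(x.chunk(sp_size, dim=1))  # what each rank holds before attention\n\n    head_shards = gather_seq_scatter_heads_local(seq_shards, sp_size)\n    # after the first all-to-all every rank sees the full sequence but only n_heads / sp heads\n    assert head_shards[0].shape == (bsz, seq_len, n_heads // sp_size, head_dim)\n\n    back = gather_heads_scatter_seq_local(head_shards, sp_size)\n    assert all(torch.equal(a, b) for a, b in zip(back, seq_shards))\n\n    # kv-head repeat rule from apply_monkey_patch: with fewer kv heads than sp ranks,\n    # kv heads are repeated so every rank still receives at least one head\n    n_kv_heads = 2\n    repeats = max(sp_size // n_kv_heads, 1)\n    assert (n_kv_heads * repeats) % sp_size == 0\n"
  },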
  {
    "path": "verl_distillation/verl/models/transformers/npu_patch.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Copyright 2025 The Qwen Team and The HuggingFace Inc. team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n\r\nfrom importlib.metadata import version as get_version\r\nfrom typing import Optional\r\n\r\nimport torch\r\nimport torch.nn.functional as F\r\nimport torch_npu\r\nfrom torch_npu import npu_rotary_mul as apply_rotary_emb\r\nfrom transformers.modeling_utils import PretrainedConfig, PreTrainedModel\r\nfrom transformers.models.qwen2_5_vl import modeling_qwen2_5_vl\r\nfrom transformers.models.qwen3 import modeling_qwen3\r\nfrom transformers.models.qwen3_moe import modeling_qwen3_moe\r\nfrom transformers.utils import logging\r\n\r\nlogger = logging.get_logger(__name__)\r\n\r\n\r\n# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in\r\n# subsequent versions\r\n# https://github.com/huggingface/transformers/pull/38491\r\ndef apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu(\r\n    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor\r\n) -> tuple[torch.Tensor, torch.Tensor]:\r\n    cos = cos.chunk(2, dim=-1)[0].contiguous()\r\n    sin = sin.chunk(2, dim=-1)[0].contiguous()\r\n    cos = cos.repeat(1, 2)\r\n    sin = sin.repeat(1, 2)\r\n    q_embed = apply_rotary_emb(\r\n        q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()\r\n    ).type_as(q)\r\n    k_embed = apply_rotary_emb(\r\n        k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()\r\n    ).type_as(k)\r\n    return q_embed, k_embed\r\n\r\n\r\n# This api can improve performance on ASCEND NPU\r\ndef rms_norm_forward(self, x):\r\n    return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]\r\n\r\n\r\ndef silu_forward(self, hidden_state):\r\n    \"\"\"NPU optimized silu\"\"\"\r\n    gate_up = torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1)\r\n    return self.down_proj(torch_npu.npu_swiglu(gate_up, dim=-1))\r\n\r\n\r\ndef apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\r\n    cos = cos.unsqueeze(unsqueeze_dim)\r\n    sin = sin.unsqueeze(unsqueeze_dim)\r\n    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)\r\n    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)\r\n    return q_embed.to(q.dtype), k_embed.to(k.dtype)\r\n\r\n\r\nclass GmmFunction(torch.autograd.Function):\r\n    @staticmethod\r\n    def forward(ctx, x, weight, group_list, split_size):\r\n        ctx.save_for_backward(x, weight)\r\n        ctx.group_list = group_list\r\n        ctx.split_size = split_size\r\n\r\n        outputs = torch_npu.npu_grouped_matmul([x], [weight], group_list=group_list, group_type=0, split_item=2)\r\n        return outputs[0]\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_outputs):\r\n        x, weight = ctx.saved_tensors\r\n        group_list = ctx.group_list\r\n        wt = weight.permute(0, 2, 1)\r\n  
      xt = x.permute(1, 0)\r\n        dx = torch_npu.npu_grouped_matmul([grad_outputs], [wt], group_list=group_list, group_type=0, split_item=2)\r\n        split_size = ctx.split_size\r\n        xt_list = torch.split(xt, split_size, dim=1)\r\n        grad_outputs_list = torch.split(grad_outputs, split_size, dim=0)\r\n        with torch.npu.amp.autocast(enabled=False):\r\n            dw = torch.stack([torch.matmul(xt_list[i], grad_outputs_list[i]) for i in range(len(xt_list))])\r\n\r\n        return dx[0], dw, None, None\r\n\r\n\r\ndef moe_block_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\r\n    \"\"\"NPU-optimized sparse MoE block: route tokens to experts, then compute them with grouped matmul.\"\"\"\r\n    batch_size, sequence_length, hidden_dim = hidden_states.shape\r\n    hidden_states = hidden_states.view(-1, hidden_dim)\r\n    # router_logits: (batch * sequence_length, n_experts)\r\n    router_logits = self.gate(hidden_states)\r\n\r\n    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\r\n    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\r\n    if self.norm_topk_prob:  # only diff with mixtral sparse moe block!\r\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\r\n    # we cast back to the input dtype\r\n    routing_weights = routing_weights.to(hidden_states.dtype)\r\n\r\n    # One hot encode the selected experts to create an expert mask\r\n    # this will be used to easily index which expert is going to be solicited\r\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\r\n\r\n    # Instead of looping over the experts, concatenate all expert weights so the\r\n    # computation can be done with one grouped matmul per projection\r\n    input_dtype = hidden_states.dtype\r\n    up_weight_list = [e.up_proj.weight.t().to(input_dtype) for e in self.experts]\r\n    gate_weight_list = [e.gate_proj.weight.t().to(input_dtype) for e in self.experts]\r\n    down_weight_list = [e.down_proj.weight.t().to(input_dtype) for e in self.experts]\r\n    w1 = torch.stack(up_weight_list)\r\n    w2 = torch.stack(gate_weight_list)\r\n    w3 = torch.stack(down_weight_list)\r\n\r\n    # Copied from mindspeed moe_utils.py:permute\r\n    routing_map = selected_experts\r\n    flatten_indices = routing_map.view(-1)\r\n    sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]\r\n    permuted_tokens = hidden_states.index_select(0, sorted_indices // self.top_k)\r\n\r\n    tokens_per_experts = torch.sum(expert_mask, dim=(1, 2))\r\n    group_list = torch.cumsum(tokens_per_experts, dim=0)\r\n\r\n    cpu_group_list = group_list.to(\"cpu\", non_blocking=False)\r\n    cpu_group_list = [0] + cpu_group_list.tolist()\r\n    split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]\r\n\r\n    up_res = GmmFunction.apply(permuted_tokens, w1, group_list, split_size)\r\n    gate_res = GmmFunction.apply(permuted_tokens, w2, group_list, split_size)\r\n    act_res = torch_npu.npu_swiglu(torch.cat([gate_res, up_res], dim=-1))\r\n    down_res = GmmFunction.apply(act_res, w3, group_list, split_size)\r\n\r\n    probs = routing_weights\r\n    num_unpermuted_tokens = probs.numel()\r\n    topk = self.top_k\r\n    permuted_tokens = down_res\r\n\r\n    unpermuted_tokens = torch.zeros(\r\n        [num_unpermuted_tokens, permuted_tokens.shape[-1]],\r\n        
dtype=permuted_tokens.dtype,\r\n        device=permuted_tokens.device,\r\n    )\r\n    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)\r\n    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))\r\n    unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)\r\n    unpermuted_tokens = unpermuted_tokens.sum(dim=1).to(hidden_states.dtype)\r\n    final_hidden_states = unpermuted_tokens\r\n\r\n    return final_hidden_states, router_logits\r\n\r\n\r\n@classmethod\r\ndef _check_and_enable_flash_attn_2(\r\n    cls,\r\n    config,\r\n    torch_dtype: Optional[torch.dtype] = None,\r\n    device_map: Optional[str | dict[str, int]] = None,\r\n    check_device_map: bool = True,\r\n    hard_check_only: bool = False,\r\n) -> PretrainedConfig:\r\n    \"\"\"\r\n    Checks the availability of Flash Attention 2 and compatibility with the current model.\r\n\r\n    If all checks pass and `hard_check_only` is False, the method will set the config attribute\r\n    `attn_implementation` to \"flash_attention_2\" so that the model can initialize\r\n    the correct attention module.\r\n    \"\"\"\r\n    if not cls._supports_flash_attn_2:\r\n        raise ValueError(\r\n            f\"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where the\"\r\n            f\" model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new\"\r\n            \" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new\"\r\n        )\r\n\r\n    if not hard_check_only:\r\n        config._attn_implementation = \"flash_attention_2\"\r\n    logger.info(\"Detect using FlashAttention2 on Ascend NPU.\")\r\n    return config\r\n\r\n\r\nmodeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward\r\nmodeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward\r\nmodeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu\r\nmodeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward\r\nmodeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward\r\nmodeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu\r\nmodeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward\r\nmodeling_qwen3.Qwen3MLP.forward = silu_forward\r\n\r\nif get_version(\"transformers\") == \"4.52.4\":\r\n    PreTrainedModel._check_and_enable_flash_attn_2 = _check_and_enable_flash_attn_2\r\n"
  },
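  {
    "path": "verl_distillation/docs/sketches/moe_permute_unpermute_sketch.py",
    "content": "\"\"\"\nCPU-only sketch of the token permute / grouped-matmul / unpermute routing used by\nmoe_block_forward in npu_patch.py. A plain per-expert matmul stands in for\ntorch_npu.npu_grouped_matmul and for the experts' swiglu FFNs; the point is only to show\nthat sorting token copies by expert id, computing contiguous per-expert segments, and\nindex_copy_-ing the rows back reproduces the dense top-k mixture. Standalone illustration;\nnothing here is imported by the training code.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    num_tokens, hidden_dim, num_experts, top_k = 12, 16, 4, 2\n\n    hidden_states = torch.randn(num_tokens, hidden_dim)\n    router_logits = torch.randn(num_tokens, num_experts)\n    routing_weights, selected_experts = torch.topk(F.softmax(router_logits, dim=1), top_k, dim=-1)\n    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)\n\n    # a plain matmul stands in for each expert's FFN\n    expert_weights = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_experts)]\n\n    # reference: dense loop over every (token, expert) assignment\n    reference = torch.zeros_like(hidden_states)\n    for t in range(num_tokens):\n        for j in range(top_k):\n            e = selected_experts[t, j].item()\n            reference[t] += routing_weights[t, j] * (hidden_states[t] @ expert_weights[e])\n\n    # permuted path, mirroring moe_block_forward: sort token copies by expert id ...\n    flat_experts = selected_experts.reshape(-1)\n    sorted_indices = torch.sort(flat_experts.float(), stable=True)[1]\n    permuted_tokens = hidden_states.index_select(0, sorted_indices // top_k)\n\n    # ... run one contiguous matmul per expert segment (the grouped matmul) ...\n    counts = torch.bincount(flat_experts, minlength=num_experts).tolist()\n    segments, start = [], 0\n    for e, n in enumerate(counts):\n        segments.append(permuted_tokens[start : start + n] @ expert_weights[e])\n        start += n\n    permuted_out = torch.cat(segments, dim=0)\n\n    # ... then scatter rows back to their original (token, slot) positions and mix\n    unpermuted = torch.zeros_like(permuted_out)\n    unpermuted.index_copy_(0, sorted_indices, permuted_out)\n    mixed = (unpermuted.reshape(num_tokens, top_k, hidden_dim) * routing_weights.unsqueeze(-1)).sum(dim=1)\n\n    assert torch.allclose(mixed, reference, atol=1e-4)\n"
  },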
  {
    "path": "verl_distillation/verl/models/transformers/qwen2.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv\nfrom transformers.utils import logging\n\n# Import compatibility wrapper for flash_attn_supports_top_left_mask\nfrom verl.utils.transformers_compat import flash_attn_supports_top_left_mask\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef qwen2_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n):\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.\n    \"\"\"\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons,\n    # so the input hidden states get silently cast to float32. Hence, we need to\n    # cast them back to float16 just to be sure everything works as expected.\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(\n            f\"The input hidden states seem to have been silently cast to float32; this might be related to \"\n            f\"the fact that you have upcasted embedding or layer norm layers to float32. We will cast the \"\n            f\"input back to {target_dtype}.\"\n        )\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    # Reshape to the expected shape for Flash Attention\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n    else:\n        sliding_window = None\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        is_causal=self.is_causal,\n        use_top_left_mask=flash_attn_supports_top_left_mask(),\n    )\n\n    # use full_q_len to reshape\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    attn_weights = None  # flash attention kernels do not return attention weights\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef qwen2_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: tuple[torch.Tensor, torch.Tensor],\n    attention_mask: 
Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    bsz, q_len, _ = hidden_states.shape\n    hidden_shape = (bsz, q_len, -1, self.head_dim)\n\n    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    sliding_window = None\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n\n    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once(\n                \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. \"\n                \"Falling back to eager attention. 
This warning can be removed using the argument \"\n                '`attn_implementation=\"eager\"` when loading the model.'\n            )\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        sliding_window=sliding_window,  # main diff with Llama\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
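  {
    "path": "verl_distillation/docs/sketches/repeat_kv_sketch.py",
    "content": "\"\"\"\nTiny check of the equivalence claimed in the repeat_kv docstring in monkey_patch.py:\nthe expand + reshape trick equals torch.repeat_interleave on the head dimension. The\nvariant sketched here keeps heads on dim 2, matching monkey_patch.repeat_kv; the HF helper\nimported in qwen2.py uses the same trick with heads on dim 1. Standalone illustration;\nnothing here is imported by the training code.\n\"\"\"\n\nimport torch\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    # (batch, seq_len, num_kv_heads, head_dim) -> (batch, seq_len, num_kv_heads * n_rep, head_dim)\n    batch, slen, num_kv_heads, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_kv_heads, n_rep, head_dim)\n    return hidden_states.reshape(batch, slen, num_kv_heads * n_rep, head_dim)\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    kv = torch.randn(2, 5, 4, 8)  # (batch, seq_len, num_kv_heads, head_dim)\n\n    # expand + reshape is exactly repeat_interleave on the head dim\n    assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=2))\n\n    # GQA: 12 query heads attend to kv repeated from 4 heads (n_rep = 12 // 4)\n    assert repeat_kv(kv, 12 // 4).shape == (2, 5, 12, 8)\n"
  },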
  {
    "path": "verl_distillation/verl/models/transformers/qwen2_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check\nfrom transformers.models.qwen2_vl.modeling_qwen2_vl import (\n    Qwen2VLAttention,\n    Qwen2VLCausalLMOutputWithPast,\n    Qwen2VLForConditionalGeneration,\n)\nfrom transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10\n\nfrom verl.utils.device import is_npu_available\nfrom verl.utils.transformers_compat import is_transformers_version_in_range\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_group,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n\n    _flash_supports_window_size = \"window_size\" in inspect.signature(flash_attn_func).parameters\n    _flash_supports_deterministic = \"deterministic\" in inspect.signature(flash_attn_func).parameters\n    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\nif is_npu_available:\n    from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func\n    from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func\n    from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask\n\n    _flash_supports_window_size = \"window_size\" in inspect.signature(flash_attn_func).parameters\n    _flash_supports_deterministic = \"deterministic\" in inspect.signature(flash_attn_func).parameters\n    _flash_use_top_left_mask = flash_attn_supports_top_left_mask()\n\n_flash_deterministic_enabled = os.getenv(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n\n\ndef get_rope_index(\n    processor,\n    input_ids: torch.Tensor,\n    image_grid_thw: Optional[torch.Tensor] = None,\n    video_grid_thw: Optional[torch.Tensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.\n    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.\n    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405\n    \"\"\"\n    spatial_merge_size = processor.image_processor.merge_size\n    tokens_per_second = 2\n    image_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|image_pad|>\")\n    video_token_id = 
processor.tokenizer.convert_tokens_to_ids(\"<|video_pad|>\")\n    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|vision_start|>\")\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        if attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)\n\n        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)\n        image_index, video_index = 0, 0\n        input_ids = input_ids[attention_mask == 1]\n        image_nums, video_nums = 0, 0\n        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)\n        vision_tokens = input_ids[vision_start_indices + 1]\n        image_nums = (vision_tokens == image_token_id).sum()\n        video_nums = (vision_tokens == video_token_id).sum()\n        input_tokens = input_ids.tolist()\n        llm_pos_ids_list: list = []\n        st = 0\n        remain_images, remain_videos = image_nums, video_nums\n        for _ in range(image_nums + video_nums):\n            if image_token_id in input_tokens and remain_images > 0:\n                ed_image = input_tokens.index(image_token_id, st)\n            else:\n                ed_image = len(input_tokens) + 1\n            if video_token_id in input_tokens and remain_videos > 0:\n                ed_video = input_tokens.index(video_token_id, st)\n            else:\n                ed_video = len(input_tokens) + 1\n            if ed_image < ed_video:\n                t, h, w = (\n                    image_grid_thw[image_index][0],\n                    image_grid_thw[image_index][1],\n                    image_grid_thw[image_index][2],\n                )\n                second_per_grid_t = 0\n                image_index += 1\n                remain_images -= 1\n                ed = ed_image\n            else:\n                t, h, w = (\n                    video_grid_thw[video_index][0],\n                    video_grid_thw[video_index][1],\n                    video_grid_thw[video_index][2],\n                )\n                second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0\n\n                video_index += 1\n                remain_videos -= 1\n                ed = ed_video\n\n            llm_grid_t, llm_grid_h, llm_grid_w = (\n                t.item(),\n                h.item() // spatial_merge_size,\n                w.item() // spatial_merge_size,\n            )\n            text_len = ed - st\n\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)\n            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()\n            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n            st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n        if st < len(input_tokens):\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            text_len = len(input_tokens) - st\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + 
st_idx)\n\n        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)\n        else:\n            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)\n\n    return position_ids\n\n\ndef prepare_fa2_from_position_ids(\n    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor\n):\n    assert position_ids.ndim == 2  # (batch_size, seq_length)\n    query = query.contiguous().view(-1, query.size(-2), query.size(-1))\n    key = key.contiguous().view(-1, key.size(-2), key.size(-1))\n    value = value.contiguous().view(-1, value.size(-2), value.size(-1))\n    position_ids = position_ids.view(-1)\n    cu_seqlens = torch.cat(\n        (\n            (position_ids == 0).nonzero().view(-1).to(torch.int32),\n            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),\n        )\n    )\n    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope\n    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))\n\n\ndef _custom_flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    query_length: int,\n    is_causal: bool = True,\n    position_ids: Optional[torch.Tensor] = None,\n    sliding_window: Optional[int] = None,\n    use_top_left_mask: bool = False,\n    deterministic: Optional[bool] = None,\n    **kwargs,\n):\n    \"\"\"\n    Patches flash attention forward to handle 3D position ids in mrope. 
(3, batch_size, seq_length)\n    \"\"\"\n    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).\n    use_sliding_windows = (\n        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window\n    )\n    flash_kwargs = {\"window_size\": (sliding_window, sliding_window)} if use_sliding_windows else {}\n\n    if _flash_supports_deterministic:\n        flash_kwargs[\"deterministic\"] = deterministic if deterministic is not None else _flash_deterministic_enabled\n\n    if kwargs.get(\"softcap\") is not None:\n        flash_kwargs[\"softcap\"] = kwargs.pop(\"softcap\")\n\n    query_states, key_states, value_states = fa_peft_integration_check(\n        query_states, key_states, value_states, target_dtype=torch.bfloat16\n    )\n\n    if position_ids is not None:\n        assert position_ids.ndim == 2  # (batch_size, seq_length / sp_size)\n\n    sp_size = get_ulysses_sequence_parallel_world_size()\n    if sp_size > 1:\n        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)\n        validate_ulysses_config(query_states.size(2), sp_size)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)\n        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]\n        # all_gather fills position_ids_lst in place and returns None\n        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.cat(position_ids_lst, dim=-1)  # (batch_size, seq_length)\n\n    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():\n        batch_size = query_states.size(0)\n        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(\n            query_states, key_states, value_states, position_ids\n        )\n        attn_output = flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=max_seqlen_k,\n            dropout_p=kwargs.pop(\"dropout\", 0.0),\n            softmax_scale=kwargs.pop(\"softmax_scale\", None),\n            causal=is_causal,\n            **flash_kwargs,\n        )\n        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))\n    else:\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            query_length,\n            is_causal=is_causal,\n            sliding_window=sliding_window,\n            use_top_left_mask=use_top_left_mask,\n            deterministic=deterministic,\n            **kwargs,\n        )  # do not pass position_ids to old flash_attention_forward\n\n    if sp_size > 1:\n        # (batch_size, seq_length, num_head, head_size)\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    return attn_output\n\n\ndef qwen2_vl_attn_forward(\n    self: \"Qwen2VLAttention\",\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    
**kwargs,\n) -> tuple[torch.Tensor, None, None]:\n    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv\n\n    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size\n    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    # Because the input can be padded, the absolute sequence length depends on the max position id.\n    cos, sin = position_embeddings\n    query_states, key_states = apply_multimodal_rotary_pos_emb(\n        query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"]\n    )\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    sliding_window = None\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n\n    # This is before the transpose\n    q_len = query_states.shape[2]\n\n    # FA2 uses non-transposed inputs\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if position_ids.ndim == 3:\n        position_ids = position_ids[0]\n\n    attn_output = _custom_flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        query_length=q_len,\n        is_causal=getattr(self, \"is_causal\", True),\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        use_top_left_mask=_flash_use_top_left_mask,\n        position_ids=position_ids,  # important: pass position ids\n    )  # (batch_size, seq_length / sp_size, num_head, head_size)\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n    attn_output = self.o_proj(attn_output)\n    if is_transformers_version_in_range(min_version=\"4.54.0\"):\n        return attn_output, None\n    else:\n        return attn_output, None, None\n\n\ndef _get_input_embeds(\n    model: \"Qwen2VLForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n):\n    inputs_embeds = model.get_input_embeddings()(input_ids)\n    if pixel_values is not None:\n        pixel_values = pixel_values.type(model.visual.dtype)\n        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()\n        n_image_features = image_embeds.shape[0]\n        if n_image_tokens != n_image_features:\n            raise ValueError(\n                f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features 
{n_image_features}\"\n            )\n\n        mask = input_ids == model.config.image_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        image_mask = mask_expanded.to(inputs_embeds.device)\n\n        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n    if pixel_values_videos is not None:\n        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)\n        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)\n        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()\n        n_video_features = video_embeds.shape[0]\n        if n_video_tokens != n_video_features:\n            raise ValueError(\n                f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\"\n            )\n\n        mask = input_ids == model.config.video_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        video_mask = mask_expanded.to(inputs_embeds.device)\n\n        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n    if pixel_values is None and pixel_values_videos is None:  # handle mixed text-image data\n        config = model.config.vision_config\n        patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2\n        pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)\n        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)\n        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        inputs_embeds += 0.0 * image_embeds.mean()\n\n    if attention_mask is not None:\n        attention_mask = attention_mask.to(inputs_embeds.device)\n\n    return inputs_embeds, attention_mask\n\n\ndef process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:\n    if position_ids.ndim != 3 or position_ids.size(0) != 4:\n        # we concat the text position ids with the 3D vision position ids by default\n        # see https://github.com/huggingface/transformers/pull/39447\n        raise ValueError(\"position_ids should be a 3D tensor of shape (4, batch_size, seq_length).\")\n\n    if is_transformers_version_in_range(max_version=\"4.53.3\"):\n        # transformers < 4.54.0 only accepts vision position ids, so we discard the text position ids here\n        position_ids = position_ids[1:]\n\n    return position_ids\n\n\n@dataclass\nclass Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef qwen2_vl_base_forward(\n    self: \"Qwen2VLForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    **kwargs,\n):\n    kwargs[\"inputs_embeds\"], kwargs[\"attention_mask\"] = _get_input_embeds(\n        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, 
video_grid_thw\n    )  # avoid lora module having multiple keyword arguments\n    return self.language_model(input_ids=None, **kwargs)\n\n\ndef qwen2_vl_forward(\n    self: \"Qwen2VLForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    **kwargs,\n):\n    if is_transformers_version_in_range(min_version=\"4.52.0\"):\n        return self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=process_position_ids(position_ids),\n            pixel_values=pixel_values,\n            pixel_values_videos=pixel_values_videos,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            **kwargs,\n        )\n    else:\n        inputs_embeds, attention_mask = _get_input_embeds(\n            self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw\n        )\n        return self.model(\n            input_ids=None,\n            attention_mask=attention_mask,\n            position_ids=process_position_ids(position_ids),\n            inputs_embeds=inputs_embeds,\n            **kwargs,\n        )\n\n\ndef forward_with_normal_backend(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> \"Qwen2VLCausalLMOutputWithPast\":\n    outputs = qwen2_vl_forward(self, input_ids, **kwargs)\n    hidden_states = outputs[0]\n    logits = self.lm_head(hidden_states)\n\n    return Qwen2VLCausalLMOutputWithPast(\n        logits=logits,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_torch_backend(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> tuple | Qwen2VLCausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = qwen2_vl_forward(self, input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n    return Qwen2VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_triton_backend(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> tuple | Qwen2VLCausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = qwen2_vl_forward(self, 
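# qwen2_vl_forward dispatches on the installed transformers version: >= 4.52.0 lets self.model consume vision inputs directly, older versions receive precomputed inputs_embeds\n        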
input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n    return Qwen2VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n"
  },
  {
    "path": "verl_distillation/verl/models/transformers/qwen3_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.models.qwen3_vl.modeling_qwen3_vl import (\n    Qwen3VLCausalLMOutputWithPast,\n    Qwen3VLForConditionalGeneration,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef get_rope_index(\n    processor,\n    input_ids: torch.Tensor,\n    image_grid_thw: Optional[torch.Tensor] = None,\n    video_grid_thw: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    **kwargs,\n) -> torch.Tensor:\n    \"\"\"\n    Gets the position ids for Qwen3-VL, it should be generated before sharding the sequence.\n    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.\n    https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L916\n    \"\"\"\n    spatial_merge_size = processor.image_processor.merge_size\n    image_token_id = processor.image_token_id\n    video_token_id = processor.video_token_id\n    vision_start_token_id = processor.vision_start_token_id\n\n    # Since we use timestamps to seperate videos,\n    # like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>,\n    # the video_grid_thw should also be split\n    if video_grid_thw is not None:\n        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)\n        video_grid_thw[:, 0] = 1\n\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        if attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)\n\n        position_ids = torch.ones(3, input_ids.shape[0], dtype=input_ids.dtype, device=input_ids.device)\n        image_index, video_index = 0, 0\n        attention_mask = attention_mask.to(input_ids.device)\n        input_ids = input_ids[attention_mask == 1]\n        image_nums, video_nums = 0, 0\n        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)\n        vision_tokens = input_ids[vision_start_indices + 1]\n        image_nums = (vision_tokens == image_token_id).sum()\n        video_nums = (vision_tokens == video_token_id).sum()\n        input_tokens = input_ids.tolist()\n        llm_pos_ids_list: list = []\n        st = 0\n        remain_images, remain_videos = image_nums, video_nums\n        for _ in range(image_nums + video_nums):\n            if image_token_id in input_tokens and remain_images > 0:\n                ed_image = input_tokens.index(image_token_id, st)\n            else:\n                ed_image = len(input_tokens) + 1\n            if video_token_id in input_tokens and remain_videos > 0:\n                ed_video = input_tokens.index(video_token_id, st)\n            else:\n                ed_video = len(input_tokens) + 1\n    
        if ed_image < ed_video:\n                t, h, w = (\n                    image_grid_thw[image_index][0],\n                    image_grid_thw[image_index][1],\n                    image_grid_thw[image_index][2],\n                )\n                image_index += 1\n                remain_images -= 1\n                ed = ed_image\n            else:\n                t, h, w = (\n                    video_grid_thw[video_index][0],\n                    video_grid_thw[video_index][1],\n                    video_grid_thw[video_index][2],\n                )\n                video_index += 1\n                remain_videos -= 1\n                ed = ed_video\n\n            llm_grid_t, llm_grid_h, llm_grid_w = (\n                t.item(),\n                h.item() // spatial_merge_size,\n                w.item() // spatial_merge_size,\n            )\n            text_len = ed - st\n\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            # t_index is always 0 because llm_grid_t is always 1\n            # (we use timestamps to encode the temporal information for videos)\n            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()\n            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n            st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n        if st < len(input_tokens):\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            text_len = len(input_tokens) - st\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(attention_mask.device)\n        else:\n            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)\n\n    return position_ids\n\n\ndef _get_input_embeds(\n    model: \"Qwen3VLForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n):\n    inputs_embeds = model.get_input_embeddings()(input_ids)\n    image_mask, video_mask = None, None\n    if pixel_values is not None:\n        pixel_values = pixel_values.type(model.visual.dtype)\n        image_embeds, deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()\n        n_image_features = image_embeds.shape[0]\n        if n_image_tokens != n_image_features:\n            raise ValueError(\n                f\"Image features and image tokens do not match: tokens: 
{n_image_tokens}, features {n_image_features}\"\n            )\n\n        mask = input_ids == model.config.image_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        image_mask = mask_expanded.to(inputs_embeds.device)\n\n        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n    if pixel_values_videos is not None:\n        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)\n        video_embeds, deepstack_video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)\n        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()\n        n_video_features = video_embeds.shape[0]\n        if n_video_tokens != n_video_features:\n            raise ValueError(\n                f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\"\n            )\n\n        mask = input_ids == model.config.video_token_id\n        mask_unsqueezed = mask.unsqueeze(-1)\n        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n        video_mask = mask_expanded.to(inputs_embeds.device)\n\n        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n    visual_pos_masks = None\n    deepstack_visual_embeds = None\n    if image_mask is not None and video_mask is not None:\n        # aggregate visual_pos_masks and deepstack_visual_embeds\n        image_mask = image_mask[..., 0]\n        video_mask = video_mask[..., 0]\n        visual_pos_masks = image_mask | video_mask\n        deepstack_visual_embeds = []\n        image_mask_joint = image_mask[visual_pos_masks]\n        video_mask_joint = video_mask[visual_pos_masks]\n        for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds, strict=False):\n            embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)\n            embed_joint[image_mask_joint, :] = img_embed\n            embed_joint[video_mask_joint, :] = vid_embed\n            deepstack_visual_embeds.append(embed_joint)\n    elif image_mask is not None:\n        image_mask = image_mask[..., 0]\n        visual_pos_masks = image_mask\n        deepstack_visual_embeds = deepstack_image_embeds\n    elif video_mask is not None:\n        video_mask = video_mask[..., 0]\n        visual_pos_masks = video_mask\n        deepstack_visual_embeds = deepstack_video_embeds\n\n    if pixel_values is None and pixel_values_videos is None:\n        config = model.config.vision_config\n        patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2\n        pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)\n        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)\n        image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)\n        inputs_embeds += 0.0 * image_embeds.mean()\n        for emb in dummy_deepstack_image_embeds or []:\n            inputs_embeds += 0.0 * emb.mean()\n\n    if attention_mask is not None:\n        attention_mask = attention_mask.to(inputs_embeds.device)\n\n    return {\n        \"inputs_embeds\": inputs_embeds,\n        \"attention_mask\": attention_mask,\n        \"visual_pos_masks\": 
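# mask of sequence positions holding vision tokens, consumed together with deepstack_visual_embeds by the language model\n        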
visual_pos_masks,\n        \"deepstack_visual_embeds\": deepstack_visual_embeds,\n    }\n\n\n@dataclass\nclass Qwen3VLCausalLMOutputForPPO(Qwen3VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef qwen3_vl_base_forward(\n    self: \"Qwen3VLForConditionalGeneration\",\n    input_ids: torch.LongTensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    pixel_values: Optional[torch.FloatTensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    **kwargs,\n):\n    input_kwargs = _get_input_embeds(\n        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw\n    )  # avoid lora module having multiple keyword arguments\n    kwargs.update(input_kwargs)\n    return self.language_model(\n        input_ids=None,\n        **kwargs,\n    )\n\n\ndef forward_with_normal_backend(\n    self: \"Qwen3VLForConditionalGeneration\",\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> \"Qwen3VLCausalLMOutputForPPO\":\n    outputs = self.model(input_ids, **kwargs)\n    hidden_states = outputs[0]\n    logits = self.lm_head(hidden_states)\n\n    return Qwen3VLCausalLMOutputForPPO(\n        logits=logits,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_torch_backend(\n    self: \"Qwen3VLForConditionalGeneration\",\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> \"Qwen3VLCausalLMOutputForPPO\":\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = self.model(input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n    return Qwen3VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n\n\ndef forward_with_triton_backend(\n    self: \"Qwen3VLForConditionalGeneration\",\n    input_ids: torch.LongTensor = None,\n    labels: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **kwargs,\n) -> \"Qwen3VLCausalLMOutputForPPO\":\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = self.model(input_ids, **kwargs)\n    hidden_states = outputs[0]\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n    
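# targets are the labels rolled left by one, so position i is scored against the token at i + 1 (next-token log-probs)\n    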
    rolled_labels,\n        temperature,\n        \"none\",\n    )\n    return Qwen3VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        hidden_states=outputs.hidden_states,\n    )\n"
  },
  {
    "path": "verl_distillation/verl/models/weight_loader_registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef get_weight_loader(arch: str):\n    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {\n        \"LlamaForCausalLM\": load_state_dict_to_megatron_gptmodel,\n        \"Qwen2ForCausalLM\": load_state_dict_to_megatron_gptmodel,\n    }\n\n    if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch]\n    raise ValueError(\n        f\"Model architectures {arch} loader are not supported for now. Supported architectures: \"\n        f\"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}\"\n    )\n\n\ndef get_weight_saver(arch: str):\n    from verl.models.mcore.saver import (\n        merge_megatron_ckpt_gptmodel,\n        merge_megatron_ckpt_gptmodel_dpskv3,\n        merge_megatron_ckpt_gptmodel_mixtral,\n        merge_megatron_ckpt_gptmodel_qwen2_5_vl,\n        merge_megatron_ckpt_gptmodel_qwen_moe,\n    )\n\n    _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = {\n        \"LlamaForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen2ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"MixtralForCausalLM\": merge_megatron_ckpt_gptmodel_mixtral,\n        \"Qwen2MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n        \"Qwen2_5_VLForConditionalGeneration\": merge_megatron_ckpt_gptmodel_qwen2_5_vl,\n        \"DeepseekV3ForCausalLM\": merge_megatron_ckpt_gptmodel_dpskv3,\n        \"Qwen3ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen3ForTokenClassification\": merge_megatron_ckpt_gptmodel,\n        \"Qwen3MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n    }\n    if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch]\n    raise ValueError(\n        f\"Model architectures {arch} saver are not supported for now. Supported architectures: \"\n        f\"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}\"\n    )\n"
  },
  {
    "path": "verl_distillation/verl/protocol.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement base data transfer protocol between any two functions, modules.\nWe can subclass Protocol to define more detailed batch info with specific keys\n\"\"\"\n\nimport contextlib\nimport copy\nimport logging\nimport math\nimport os\nimport pickle\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, Optional\n\nimport numpy as np\nimport ray\nimport tensordict\nimport torch\nimport torch.distributed\nfrom packaging import version\nfrom packaging.version import parse as parse_version\nfrom tensordict import TensorDict\nfrom torch.utils.data import DataLoader\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.py_functional import union_two_dict\nfrom verl.utils.torch_functional import allgather_dict_tensors\n\n__all__ = [\"DataProto\", \"union_tensor_dict\"]\n\nwith contextlib.suppress(Exception):\n    tensordict.set_lazy_legacy(False).set()\n    if parse_version(tensordict.__version__) < parse_version(\"0.10.0\"):\n        tensordict.set_list_to_stack(True).set()\n\n\nclass _DataProtoConfigMeta(type):\n    _config = {}\n\n    auto_padding_key = \"_verl_auto_padding\"\n\n    @property\n    def auto_padding(cls):\n        enabled_by_env = os.getenv(\"VERL_AUTO_PADDING\", \"FALSE\").upper() in [\"TRUE\", \"1\"]\n        return enabled_by_env or cls._config.get(cls.auto_padding_key, False)\n\n    @auto_padding.setter\n    def auto_padding(cls, enabled: bool):\n        assert isinstance(enabled, bool), f\"enabled must be a boolean, got {enabled} as {type(enabled)}\"\n        cls._config[cls.auto_padding_key] = enabled\n\n\nclass DataProtoConfig(metaclass=_DataProtoConfigMeta):\n    pass\n\n\n_padding_size_key = \"_padding_size_key_x123d\"\n\n\ndef pad_dataproto_to_divisor(data: \"DataProto\", size_divisor: int):\n    \"\"\"Pad a DataProto to size divisible by size_divisor\n\n    Args:\n        size_divisor (int): size divisor\n\n    Returns:\n        data: (DataProto): the padded DataProto\n        pad_size (int)\n    \"\"\"\n    assert isinstance(data, DataProto), \"data must be a DataProto\"\n    if len(data) % size_divisor != 0:\n        pad_size = size_divisor - len(data) % size_divisor\n        padding_protos = []\n        remaining_pad = pad_size\n        while remaining_pad > 0:\n            take_size = min(remaining_pad, len(data))\n            padding_protos.append(data[:take_size])\n            remaining_pad -= take_size\n        data_padded = DataProto.concat([data] + padding_protos)\n    else:\n        if len(data) == 0:\n            logging.warning(\"padding a DataProto with no item, no changed made\")\n        pad_size = 0\n        data_padded = data\n    return data_padded, pad_size\n\n\ndef unpad_dataproto(data: \"DataProto\", pad_size):\n    \"\"\"Unpad the data proto with pad_size. i.e. 
`data[:-pad_size]`\"\"\"\n    if pad_size != 0:\n        data = data[:-pad_size]\n    return data\n\n\ndef union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n    \"\"\"Union two tensordicts.\"\"\"\n    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (\n        f\"Two tensor dicts must have identical batch sizes. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n    )\n    for key in tensor_dict2.keys():\n        if key not in tensor_dict1.keys():\n            tensor_dict1[key] = tensor_dict2[key]\n        else:\n            assert tensor_dict1[key].equal(tensor_dict2[key]), (\n                f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n            )\n\n    return tensor_dict1\n\n\ndef _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool:\n    \"\"\"\n    Recursively compares two NumPy arrays for strict equality, with special\n    handling for object-dtype arrays, NaN values, and circular references.\n    This function assumes that the two arguments provided are NumPy arrays.\n\n    Args:\n        array1: The first NumPy array.\n        array2: The second NumPy array.\n        visited: Set of object ids on the current comparison path, used to detect cycles.\n\n    Returns:\n        True if the arrays' dtypes, shapes, and all elements are equal.\n    \"\"\"\n    # Check dtype and shape first, as this is the fastest failure path.\n    if array1.dtype != array2.dtype or array1.shape != array2.shape:\n        return False\n\n    # For non-object dtypes, use NumPy's implementation with equal_nan=True.\n    if array1.dtype != \"object\":\n        return np.array_equal(array1, array2, equal_nan=True)\n\n    # For object-dtype arrays, we must recursively compare each element.\n    # We delegate to _deep_equal to handle elements, as they could be any\n    # type, including other nested arrays or NaNs.\n    return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False))\n\n\ndef _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:\n    \"\"\"\n    Recursively performs a deep comparison between two Python objects.\n    - Handles NaN values correctly (NaN == NaN evaluates to True).\n    - Handles circular references.\n    - Dispatches to _array_equal if both objects are NumPy arrays.\n    - Otherwise, uses standard '==' comparison.\n    \"\"\"\n    if type(a) is not type(b):\n        return False\n\n    # If we have seen this object ID before on this path, it's a cycle.\n    # Since we already know the types match, we can safely assume this part\n    # of the structure is equal.\n    obj_id = id(a)\n    if obj_id in visited:\n        return True\n\n    visited.add(obj_id)\n\n    # Perform the specific comparison based on type\n    result = False\n    if isinstance(a, float) and math.isnan(a) and math.isnan(b):\n        result = True\n    elif isinstance(a, np.ndarray):\n        # We know b is also an ndarray due to the initial type check\n        result = _array_equal(a, b, visited)\n    else:\n        # Standard equality for all other types\n        result = a == b\n\n    # Clean up the visited set on the way out of the recursion\n    visited.remove(obj_id)\n    return result\n\n\ndef union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:\n    for key, val in tensor_dict2.items():\n        if key in tensor_dict1:\n            assert isinstance(tensor_dict2[key], np.ndarray)\n            assert isinstance(tensor_dict1[key], np.ndarray)\n            # to properly deal with nan and object 
type\n            assert _deep_equal(tensor_dict1[key], tensor_dict2[key], visited=set()), (\n                f\"`{key}` in tensor_dict1 and tensor_dict2 are not the same object.\"\n            )\n        tensor_dict1[key] = val\n\n    return tensor_dict1\n\n\ndef list_of_dict_to_dict_of_list(list_of_dict: list[dict]):\n    if len(list_of_dict) == 0:\n        return {}\n    keys = list_of_dict[0].keys()\n    output = {key: [] for key in keys}\n    for data in list_of_dict:\n        for key, item in data.items():\n            assert key in output\n            output[key].append(item)\n    return output\n\n\ndef fold_batch_dim(data: \"DataProto\", new_batch_size):\n    \"\"\"\n    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]\n    \"\"\"\n    batch_size = data.batch.batch_size[0]\n\n    assert batch_size % new_batch_size == 0\n\n    tensor: TensorDict = data.batch\n    non_tensor = data.non_tensor_batch\n\n    tensor = tensor.view(new_batch_size, -1)\n    tensor.auto_batch_size_(batch_dims=1)\n\n    for key, val in non_tensor.items():\n        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))\n\n    return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)\n\n\ndef unfold_batch_dim(data: \"DataProto\", batch_dims=2):\n    \"\"\"\n    Unfold the first n dims as new batch dim\n    \"\"\"\n    tensor: TensorDict = data.batch\n    non_tensor = data.non_tensor_batch\n    tensor.auto_batch_size_(batch_dims=batch_dims)\n    tensor = tensor.view(-1)\n\n    batch_size = tensor.batch_size[0]\n\n    non_tensor_new = {}\n\n    for key, val in non_tensor.items():\n        non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))\n\n    return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)\n\n\ndef serialize_single_tensor(obj: torch.Tensor) -> tuple[str, tuple[int, ...], int | memoryview]:\n    data = obj.flatten().contiguous().view(torch.uint8).numpy()\n    dtype = str(obj.dtype).removeprefix(\"torch.\")\n    return dtype, obj.shape, data\n\n\ndef serialize_tensordict(batch: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple[str, Any]]]:\n    encoded_items: dict[str, tuple[Any]] = {}\n    for k, v in batch.items():\n        if not v.is_nested:\n            encoded_items[k] = serialize_single_tensor(v)\n        else:\n            layout = str(v.layout).removeprefix(\"torch.\")\n            data = [serialize_single_tensor(tensor) for tensor in v.unbind()]\n            encoded_items[k] = (layout, data)\n\n    batch_size = tuple(batch.batch_size)\n    device = str(batch.device) if batch.device is not None else None\n    return batch_size, device, encoded_items\n\n\ndef deserialize_single_tensor(arr: Any) -> torch.Tensor:\n    dtype, shape, data = arr\n\n    torch_dtype = getattr(torch, dtype)\n    assert isinstance(torch_dtype, torch.dtype)\n\n    buffer = bytearray(data)\n    # Create uint8 array\n    arr = torch.frombuffer(buffer, dtype=torch.uint8)\n    # Convert back to proper shape & type\n    return arr.view(torch_dtype).view(shape)\n\n\ndef deserialize_tensordict(arr: Any) -> TensorDict:\n    batch_size, device, encoded_items = arr\n    decoded_items: dict[str, Any] = {}\n\n    for k, v in encoded_items.items():\n        if len(v) == 3:\n            # decode single tensor\n            decoded_items[k] = deserialize_single_tensor(v)\n        elif len(v) == 2:\n            # decode nested tensor\n            layout, data = v\n            
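# rebuild the nested tensor row by row, restoring the layout it was serialized with\n            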
torch_layout = getattr(torch, layout)\n            decoded_items[k] = torch.nested.as_nested_tensor(\n                [deserialize_single_tensor(tensor) for tensor in data], layout=torch_layout\n            )\n        else:\n            raise ValueError(f\"Invalid tensor encoding format, expected length 2 or 3, got {len(v)}\")\n\n    return TensorDict(source=decoded_items, batch_size=batch_size, device=device)\n\n\ndef collate_fn(x: list[\"DataProtoItem\"]):\n    batch = []\n    non_tensor_batch = []\n    for data in x:\n        batch.append(data.batch)\n        non_tensor_batch.append(data.non_tensor_batch)\n    batch = torch.stack(batch).contiguous()\n    non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)\n    for key, val in non_tensor_batch.items():\n        non_tensor_batch[key] = np.array(val, dtype=object)\n    return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n\n@dataclass\nclass DataProtoItem:\n    # TODO(zhangchi.usc1992) add consistency check\n    batch: TensorDict = None\n    non_tensor_batch: dict = field(default_factory=dict)\n    meta_info: dict = field(default_factory=dict)\n\n\n@dataclass\nclass DataProto:\n    \"\"\"\n    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.\n    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.\n    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the\n    same batch size should be put inside batch.\n    \"\"\"\n\n    batch: TensorDict = None\n    non_tensor_batch: dict = field(default_factory=dict)\n    meta_info: dict = field(default_factory=dict)\n\n    def __post_init__(self):\n        # perform necessary checking\n        self.check_consistency()\n\n    def __len__(self):\n        if self.batch is not None:\n            return self.batch.batch_size[0]\n        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:\n            random_key = list(self.non_tensor_batch.keys())[0]\n            return self.non_tensor_batch[random_key].shape[0]\n        else:\n            return 0\n\n    def __getitem__(self, item):\n        \"\"\"\n        Enhanced indexing for DataProto objects.\n\n        Args:\n            item: Can be one of:\n                - int: A single index\n                - slice: A slice object (start:stop:step)\n                - list: A list of indices\n                - numpy.ndarray: An array of indices\n                - torch.Tensor: A tensor of indices\n\n        Returns:\n            DataProto: For all indexing types except single integers\n            DataProtoItem: Only for single integer indices\n        \"\"\"\n        # Case 1: Slice object - use the slice method\n        if isinstance(item, slice):\n            return self.slice(item.start, item.stop, item.step)\n\n        # Case 2: List, numpy array, or torch tensor - use select_idxs\n        elif isinstance(item, list | np.ndarray | torch.Tensor):\n            return self.select_idxs(item)\n\n        # Case 3: Single integer - return DataProtoItem for backward compatibility\n        elif isinstance(item, int | np.integer):\n            tensor_data = self.batch[item] if self.batch is not None else None\n            non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}\n            return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)\n\n        # Case 4: 
Unsupported type\n        else:\n            raise TypeError(f\"Indexing with {type(item)} is not supported\")\n\n    def __getstate__(self):\n        if version.parse(tensordict.__version__) >= version.parse(\"0.5.0\") and self.batch is not None:\n            batch = self.batch.contiguous().consolidate()\n        else:\n            batch = self.batch\n\n        if os.getenv(\"VERL_DATAPROTO_SERIALIZATION_METHOD\") == \"numpy\":\n            if batch is not None:\n                batch = serialize_tensordict(self.batch)\n\n            return (\n                batch,\n                self.non_tensor_batch,\n                self.meta_info,\n            )\n        else:\n            import io\n\n            buffer = io.BytesIO()\n            torch.save(batch, buffer)\n            buffer_bytes = buffer.getvalue()\n            return buffer_bytes, self.non_tensor_batch, self.meta_info\n\n    def __setstate__(self, data):\n        batch_deserialized_bytes, non_tensor_batch, meta_info = data\n\n        if os.getenv(\"VERL_DATAPROTO_SERIALIZATION_METHOD\") == \"numpy\":\n            if batch_deserialized_bytes is not None:\n                self.batch = deserialize_tensordict(batch_deserialized_bytes)\n            else:\n                self.batch = None\n        else:\n            import io\n\n            batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)\n            batch = torch.load(\n                batch_deserialized,\n                weights_only=False,\n                map_location=\"cpu\" if not get_torch_device().is_available() else None,\n            )\n            self.batch = batch\n\n        self.non_tensor_batch = non_tensor_batch\n        self.meta_info = meta_info\n\n    def save_to_disk(self, filepath):\n        with open(filepath, \"wb\") as f:\n            pickle.dump(self, f)\n\n    @staticmethod\n    def load_from_disk(filepath) -> \"DataProto\":\n        with open(filepath, \"rb\") as f:\n            data = pickle.load(f)\n            return data\n\n    def print_size(self, prefix=\"\"):\n        size_of_tensordict = 0\n        if self.batch is not None:\n            for _, tensor in self.batch.items():\n                size_of_tensordict += tensor.element_size() * tensor.numel()\n        size_of_numpy_array = 0\n        for _, numpy_array in self.non_tensor_batch.items():\n            size_of_numpy_array += numpy_array.nbytes\n\n        size_of_numpy_array /= 1024**3\n        size_of_tensordict /= 1024**3\n\n        message = f\"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB\"\n\n        if prefix:\n            message = f\"{prefix}, \" + message\n        print(message)\n\n    def check_consistency(self):\n        \"\"\"Check the consistency of the DataProto. 
Mainly for batch and non_tensor_batch.\n        We expose this as a public function so that users can call it directly\n        \"\"\"\n        if self.batch is not None:\n            assert len(self.batch.batch_size) == 1, \"only support num_batch_dims=1\"\n\n        if self.non_tensor_batch is not None:\n            for key, val in self.non_tensor_batch.items():\n                assert isinstance(val, np.ndarray)\n\n        if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0:\n            # TODO: we can actually lift this restriction if needed\n            assert len(self.batch.batch_size) == 1, \"only support num_batch_dims=1 when non_tensor_batch is not empty.\"\n\n            batch_size = self.batch.batch_size[0]\n            for key, val in self.non_tensor_batch.items():\n                assert isinstance(val, np.ndarray), (\n                    f\"data in the non_tensor_batch must be a numpy.array with dtype=object, but for \"\n                    f\"{key=}, got {type(val)=}\"\n                )\n                assert val.shape[0] == batch_size, (\n                    f\"key {key} length {len(val)} is not equal to batch size {batch_size}\"\n                )\n\n    @classmethod\n    def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False):\n        \"\"\"Create a DataProto from a dict of tensors and non_tensors\"\"\"\n        tensors = {}\n        non_tensors = {}\n\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor):\n                tensors[key] = val\n            elif isinstance(val, np.ndarray):\n                non_tensors[key] = val\n            else:\n                raise ValueError(f\"Unsupported type in data {type(val)}\")\n\n        return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding)\n\n    @classmethod\n    def from_dict(\n        cls,\n        tensors: Optional[dict[str, torch.Tensor]] = None,\n        non_tensors=None,\n        meta_info=None,\n        num_batch_dims=1,\n        auto_padding=False,\n    ):\n        \"\"\"Create a DataProto from a dict of tensors. This assumes that\n        1. All the tensors in tensors share the same dim0\n        2. Only dim0 is the batch dim\n        \"\"\"\n\n        assert num_batch_dims > 0, \"num_batch_dims must be greater than zero\"\n        if non_tensors is not None:\n            assert num_batch_dims == 1, \"only support num_batch_dims=1 when non_tensors is not None.\"\n\n        if tensors is None:\n            tensors = {}\n        if meta_info is None:\n            meta_info = {}\n        if non_tensors is None:\n            non_tensors = {}\n\n        assert isinstance(non_tensors, dict)\n\n        # get and check batch size\n        batch_size = None\n        pivot_key = None\n        for key, tensor in tensors.items():\n            if batch_size is None:\n                batch_size = tensor.shape[:num_batch_dims]\n                pivot_key = key\n            else:\n                current_batch = tensor.shape[:num_batch_dims]\n                assert batch_size == current_batch, (\n                    f\"Not all the tensors in tensors have the same batch size with batch_dims={num_batch_dims}. 
\"\n                    f\"Got {pivot_key} has {batch_size}, {key} has {current_batch}\"\n                )\n\n        for key, val in non_tensors.items():\n            if not isinstance(val, np.ndarray):\n                non_tensors[key] = np.array(val, dtype=object)\n\n        tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None\n        if auto_padding:\n            meta_info[DataProtoConfig.auto_padding_key] = True\n        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)\n\n    @classmethod\n    def from_tensordict(\n        cls,\n        tensor_dict: TensorDict = None,\n        meta_info=None,\n        num_batch_dims=1,\n    ):\n        \"\"\"Create a DataProto from a TensorDict. This assumes that\n        1. All the tensor in tensor_dict have the same dim0\n        2. Only dim0 is the batch dim\n        \"\"\"\n        assert version.parse(tensordict.__version__) >= version.parse(\"0.10.0\"), (\n            \"Build DataProto from TensorDict at least requires tensordict version 0.10.0\"\n        )\n        from tensordict import NonTensorData, NonTensorStack\n\n        assert num_batch_dims > 0, \"num_batch_dims must be greater than zero\"\n        if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()):\n            assert num_batch_dims == 1, \"only support num_batch_dims=1 when tensor_dict contains non tensor data.\"\n\n        if meta_info is None:\n            meta_info = {}\n        batch = {}\n        non_tensor_batch = {}\n        batch_size = None\n        for key, val in tensor_dict.items():\n            if isinstance(val, torch.Tensor):\n                batch[key] = val\n                if batch_size is None:\n                    batch_size = val.shape[:num_batch_dims]\n            elif isinstance(val, NonTensorStack):\n                non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object)\n            elif isinstance(val, NonTensorData):\n                meta_info[key] = val.data\n\n        return cls(\n            batch=TensorDict(batch, batch_size=batch_size),\n            non_tensor_batch=non_tensor_batch,\n            meta_info=meta_info,\n        )\n\n    def to(self, device) -> \"DataProto\":\n        \"\"\"move the batch to device\n\n        Args:\n            device (torch.device, str): torch device\n\n        Returns:\n            DataProto: the current DataProto\n\n        \"\"\"\n        if self.batch is not None:\n            self.batch = self.batch.to(device)\n        return self\n\n    def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> \"DataProto\":\n        \"\"\"Select a subset of the DataProto via batch_keys and meta_info_keys\n\n        Args:\n            batch_keys (list, optional): a list of strings indicating the keys in batch to select\n            meta_info_keys (list, optional): a list of keys indicating the meta info to select\n\n        Returns:\n            DataProto: the DataProto with the selected batch_keys and meta_info_keys\n        \"\"\"\n        # TODO (zhangchi.usc1992) whether to copy\n        if batch_keys is not None:\n            batch_keys = tuple(batch_keys)\n            sub_batch = self.batch.select(*batch_keys)\n        else:\n            sub_batch = self.batch\n\n        if non_tensor_batch_keys is not None:\n            non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}\n        else:\n            non_tensor_batch = 
self.non_tensor_batch\n\n        if deepcopy:\n            non_tensor_batch = copy.deepcopy(non_tensor_batch)\n\n        if meta_info_keys is not None:\n            sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}\n        else:\n            sub_meta_info = self.meta_info\n\n        if deepcopy:\n            sub_meta_info = copy.deepcopy(sub_meta_info)\n\n        return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)\n\n    def select_idxs(self, idxs):\n        \"\"\"\n        Select specific indices from the DataProto.\n\n        Args:\n            idxs (torch.Tensor or numpy.ndarray or list): Indices to select\n\n        Returns:\n            DataProto: A new DataProto containing only the selected indices\n        \"\"\"\n        if isinstance(idxs, list):\n            idxs = torch.tensor(idxs)\n            if idxs.dtype != torch.bool:\n                idxs = idxs.type(torch.int32)\n\n        if isinstance(idxs, np.ndarray):\n            idxs_np = idxs\n            idxs_torch = torch.from_numpy(idxs)\n        else:  # torch.Tensor\n            idxs_torch = idxs\n            idxs_np = idxs.detach().cpu().numpy()\n\n        batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]\n\n        if self.batch is not None:\n            # Use TensorDict's built-in indexing capabilities\n            selected_batch = TensorDict(\n                source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},\n                batch_size=(batch_size,),\n                device=self.batch.device,\n            )\n        else:\n            selected_batch = None\n\n        selected_non_tensor = {}\n        for key, val in self.non_tensor_batch.items():\n            selected_non_tensor[key] = val[idxs_np]\n\n        return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)\n\n    def slice(self, start=None, end=None, step=None):\n        \"\"\"\n        Slice the DataProto and return a new DataProto object.\n        This is an improved version of direct slicing which returns a DataProtoItem.\n\n        Args:\n            start (int, optional): Start index. Defaults to None (start from beginning).\n            end (int, optional): End index (exclusive). Defaults to None (go to end).\n            step (int, optional): Step size. 
Defaults to None (step=1).\n\n        Returns:\n            DataProto: A new DataProto containing the sliced data\n\n        Examples:\n            # Using the slice method directly\n            sliced_data = data_proto.slice(10, 20)\n\n            # Using enhanced indexing (returns DataProto)\n            sliced_data = data_proto[10:20]\n            sliced_data = data_proto[::2]  # Every other element\n\n            # Using list indexing (returns DataProto)\n            indices = [1, 5, 10]\n            selected_data = data_proto[indices]\n\n            # Single index still returns DataProtoItem\n            single_item = data_proto[5]\n        \"\"\"\n        # Create a slice object\n        slice_obj = slice(start, end, step)\n\n        # Handle the batch data\n        if self.batch is not None:\n            # Use TensorDict's built-in slicing capabilities\n            sliced_batch = self.batch[slice_obj]\n        else:\n            sliced_batch = None\n\n        # Handle the non-tensor batch data\n        sliced_non_tensor = {}\n        for key, val in self.non_tensor_batch.items():\n            sliced_non_tensor[key] = val[slice_obj]\n\n        # Return a new DataProto object\n        return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)\n\n    def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> \"DataProto\":\n        \"\"\"Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`\n\n        Args:\n            batch_keys (list, optional): a list of strings indicating the keys in batch to pop\n            meta_info_keys (list, optional): a list of keys indicating the meta info to pop\n\n        Returns:\n            DataProto: the DataProto with the popped batch_keys and meta_info_keys\n        \"\"\"\n        if batch_keys is None:\n            batch_keys = []\n        if meta_info_keys is None:\n            meta_info_keys = []\n        if non_tensor_batch_keys is None:\n            non_tensor_batch_keys = []\n\n        tensors = {}\n        # tensor batch\n        for key in batch_keys:\n            assert key in self.batch.keys()\n            tensors[key] = self.batch.pop(key)\n        non_tensors = {}\n        # non tensor batch\n        for key in non_tensor_batch_keys:\n            assert key in self.non_tensor_batch.keys()\n            non_tensors[key] = self.non_tensor_batch.pop(key)\n        meta_info = {}\n        for key in meta_info_keys:\n            assert key in self.meta_info.keys()\n            meta_info[key] = self.meta_info.pop(key)\n        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)\n\n    def rename(self, old_keys=None, new_keys=None) -> \"DataProto\":\n        \"\"\"\n        Note that this function only renames keys in the batch\n        \"\"\"\n\n        def validate_input(keys):\n            if keys is not None:\n                if isinstance(keys, str):\n                    keys = [keys]\n                elif isinstance(keys, list):\n                    pass\n                else:\n                    raise TypeError(f\"keys must be a list or a string, but got {type(keys)}\")\n            return keys\n\n        old_keys = validate_input(old_keys)\n        new_keys = validate_input(new_keys)\n\n        if len(new_keys) != len(old_keys):\n            raise ValueError(\n                f\"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}\"\n            )\n\n        
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))\n\n        return self\n\n    def union(self, other: \"DataProto\") -> \"DataProto\":\n        \"\"\"Union with another DataProto. Union batch and meta_info separately.\n        Throw an error if\n\n        - there are conflicting keys in batch and they are not equal\n        - the batch sizes of the two data batches are not the same\n        - there are conflicting keys in meta_info and they are not the same.\n\n        Args:\n            other (DataProto): another DataProto to union\n\n        Returns:\n            DataProto: the DataProto after union\n        \"\"\"\n        self.batch = union_tensor_dict(self.batch, other.batch)\n        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)\n        self.meta_info = union_two_dict(self.meta_info, other.meta_info)\n        return self\n\n    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):\n        r\"\"\"Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a normal\n        PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.\n\n\n        Args:\n            mini_batch_size (int): mini-batch size when iterating the dataset. We require that\n                ``batch.batch_size[0] % mini_batch_size == 0``.\n            epochs (int): number of epochs when iterating the dataset.\n            dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The\n                dataloader_kwargs is the kwargs passed to the DataLoader.\n\n        Returns:\n            Iterator: an iterator that yields a mini-batch of data at a time. The total number of iteration\n                steps is ``self.batch.batch_size * epochs // mini_batch_size``\n        \"\"\"\n        assert self.batch.batch_size[0] % mini_batch_size == 0, f\"{self.batch.batch_size[0]} % {mini_batch_size} != 0\"\n        # we can directly create a dataloader from TensorDict\n        if dataloader_kwargs is None:\n            dataloader_kwargs = {}\n\n        if seed is not None:\n            generator = torch.Generator()\n            generator.manual_seed(seed)\n        else:\n            generator = None\n\n        assert isinstance(dataloader_kwargs, dict)\n        train_dataloader = DataLoader(\n            dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs\n        )\n\n        def get_data():\n            for _ in range(epochs):\n                for d in train_dataloader:\n                    d.meta_info = self.meta_info\n                    yield d\n\n        return iter(get_data())\n\n    def is_padding_enabled(self):\n        \"\"\"\n        Check if padding is enabled for the DataProto.\n        Returns:\n            bool: True if padding is enabled, False otherwise.\n        \"\"\"\n        dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False)\n        return dataproto_specific_padding or DataProtoConfig.auto_padding\n\n    def padding(self, padding_size, padding_candidate=\"\"):\n        \"\"\"Pad the DataProto by concatenating with padding_candidate.repeat(padding_size)\n\n        Args:\n            padding_size (int): the number of repeated padding_candidate\n            padding_candidate: the item to be repeated and appended to the DataProto, only supporting [\"first\", \"last\"]\n        \"\"\"\n        if padding_size == 0:\n            return\n        padding_candidate = 
self.select_idxs([0 if padding_candidate == \"first\" else len(self) - 1])\n        padding_part = padding_candidate.repeat(padding_size)\n        padded_dp = DataProto.concat([self, padding_part])\n        self.batch = padded_dp.batch\n        self.non_tensor_batch = padded_dp.non_tensor_batch\n\n    def chunk(self, chunks: int) -> list[\"DataProto\"]:\n        \"\"\"Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.\n\n        Args:\n            chunks (int): the number of chunks to split on dim=0\n\n        Returns:\n            List[DataProto]: a list of DataProto after splitting\n        \"\"\"\n        if not self.is_padding_enabled():\n            assert len(self) % chunks == 0, (\n                f\"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.\"\n            )\n\n        bsz_in_batch = None\n        if self.batch is not None:\n            batch_lst = self.batch.chunk(chunks=chunks, dim=0)\n            bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])\n            chunk_indices = np.cumsum(bsz_in_batch)[:-1]\n        else:\n            batch_lst = [None for _ in range(chunks)]\n\n        non_tensor_batch_lst = [{} for _ in range(chunks)]\n        for key, val in self.non_tensor_batch.items():\n            assert isinstance(val, np.ndarray)\n            if bsz_in_batch is not None:\n                non_tensor_lst = np.array_split(val, chunk_indices.tolist())\n            else:\n                non_tensor_lst = np.array_split(val, chunks)\n            assert len(non_tensor_lst) == chunks\n            for i in range(chunks):\n                non_tensor_batch_lst[i][key] = non_tensor_lst[i]\n\n        output = []\n        for i in range(chunks):\n            output.append(\n                type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)\n            )\n\n        return output\n\n    def split(self, split_size: int) -> list[\"DataProto\"]:\n        \"\"\"Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.\n\n        Args:\n            split_size (int): the size of each split\n\n        Returns:\n            List[DataProto]: a list of DataProto after splitting\n        \"\"\"\n        return [self[i : i + split_size] for i in range(0, len(self), split_size)]\n\n    @staticmethod\n    def concat(data: list[\"DataProto\"]) -> \"DataProto\":\n        \"\"\"Concat a list of DataProto. 
The batch is concatenated among dim=0.\n        The meta_info is merged, with special handling for metrics from different workers.\n\n        Args:\n            data (List[DataProto]): list of DataProto\n\n        Returns:\n            DataProto: concatenated DataProto\n        \"\"\"\n        batch_lst = []\n        for batch in data:\n            batch_lst.append(batch.batch)\n        new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None\n\n        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])\n        for key, val in non_tensor_batch.items():\n            non_tensor_batch[key] = np.concatenate(val, axis=0)\n\n        # Merge meta_info with special handling for metrics\n        merged_meta_info = {}\n        if data:\n            # Merge non-metric meta_info and aggregate metrics from all workers.\n            all_metrics = []\n            for d in data:\n                for k, v in d.meta_info.items():\n                    if k == \"metrics\":\n                        if v is not None:\n                            if isinstance(v, list):\n                                all_metrics.extend(v)\n                            else:\n                                all_metrics.append(v)\n                    else:\n                        if k in merged_meta_info:\n                            # Ensure consistency for overlapping non-metric keys\n                            assert merged_meta_info[k] == v, f\"Conflicting values for meta_info key '{k}'\"\n                        else:\n                            merged_meta_info[k] = v\n\n            # Flatten list of dicts to dict of lists for consistent metrics structure\n            if all_metrics:\n                merged_meta_info[\"metrics\"] = list_of_dict_to_dict_of_list(all_metrics)\n\n        cls = type(data[0]) if len(data) > 0 else DataProto\n        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info)\n\n    def reorder(self, indices):\n        \"\"\"\n        Note that this operation is in-place\n        \"\"\"\n        indices_np = indices.detach().numpy()\n        self.batch = self.batch[indices]\n        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}\n\n    def repeat(self, repeat_times=2, interleave=True):\n        \"\"\"\n        Repeat the batch data a specified number of times.\n\n        Args:\n            repeat_times (int): Number of times to repeat the data.\n            interleave (bool): Whether to interleave the repeated data.\n\n        Returns:\n            DataProto: A new DataProto with repeated data.\n        \"\"\"\n        if self.batch is not None:\n            if interleave:\n                # Interleave the data\n                repeated_tensors = {\n                    key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()\n                }\n            else:\n                # Stack the data\n                repeated_tensors = {\n                    key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])\n                    for key, tensor in self.batch.items()\n                }\n\n            repeated_batch = TensorDict(\n                source=repeated_tensors,\n                batch_size=(self.batch.batch_size[0] * repeat_times,),\n            )\n        else:\n            repeated_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in 
self.non_tensor_batch.items():\n            if interleave:\n                repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)\n            else:\n                repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))\n\n        return type(self)(\n            batch=repeated_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n    def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):\n        \"\"\"Split each entry along the second dim into `n_split` chunks and unfold them into the first (batch) dim.\n        Useful for passing grouped tensors that should not be shuffled in the dataset.\n        Keys not in `split_keys` are repeated along the batch dim to match the new shape.\n        Note that if `split_keys` is not provided, nothing is split; every key is repeated `n_split` times along\n        the batch dim.\n        \"\"\"\n        if self.batch is not None:\n            unfolded_batch = {}\n            for key in self.batch.keys():\n                if split_keys is not None and key in split_keys:\n                    shape = list(self.batch[key].shape)\n                    shape[0] = self.batch[key].shape[0] * n_split\n                    shape[1] = self.batch[key].shape[1] // n_split\n                    unfolded_batch[key] = self.batch[key].reshape(*shape)\n                else:\n                    unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)\n            # locate the `unfolded_batch` as a TensorDict on the same device as the original batch\n            unfolded_batch = TensorDict(\n                source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device\n            )\n        else:\n            unfolded_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in self.non_tensor_batch.items():\n            if split_keys is not None and key in split_keys:\n                shape = list(val.shape)\n                shape[0] = val.shape[0] * n_split\n                shape[1] = val.shape[1] // n_split\n                repeated_non_tensor_batch[key] = val.reshape(*shape)\n            else:\n                repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)\n\n        return type(self)(\n            batch=unfolded_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n    def sample_level_repeat(self, repeat_times):\n        \"\"\"\n        Repeat each row of the batch data a specified number of times.\n\n        Args:\n            repeat_times (torch.Tensor, list, tuple, ndarray): Number of times to repeat each row.\n\n        Returns:\n            DataProto: A new DataProto with repeated data.\n        \"\"\"\n        if isinstance(repeat_times, tuple):\n            repeat_times = list(repeat_times)\n        elif isinstance(repeat_times, torch.Tensor):\n            assert len(repeat_times.shape) == 1\n            repeat_times = repeat_times.tolist()\n        elif isinstance(repeat_times, np.ndarray):\n            assert len(repeat_times.shape) == 1\n            repeat_times = repeat_times.tolist()\n        else:\n            assert isinstance(repeat_times, list), (\n                f\"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}\"\n            )\n        repeat_times = torch.tensor(repeat_times)\n\n        if self.batch is not None:\n            # Interleave the data\n            repeated_tensors = {\n                key: 
tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()\n            }\n\n            repeated_batch = TensorDict(\n                source=repeated_tensors,\n                batch_size=(repeat_times.sum().item(),),\n                device=self.batch.device,\n            )\n        else:\n            repeated_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in self.non_tensor_batch.items():\n            repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)\n\n        return type(self)(\n            batch=repeated_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n    def to_tensordict(self) -> TensorDict:\n        \"\"\"Convert this DataProto to TensorDict. Note that this requires tensordict version at least 0.10\n\n        Returns:\n\n        \"\"\"\n        assert parse_version(tensordict.__version__) >= parse_version(\"0.10\"), (\n            \"Convert DataProto to TensorDict at least requires tensordict version 0.10\"\n        )\n        tensor_batch = self.batch.to_dict()\n        non_tensor_batch = self.non_tensor_batch\n\n        from verl.utils import tensordict_utils as tu\n\n        common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())\n        assert len(common_keys) == 0, f\"tensor_batch and non_tensor_batch have common keys {common_keys}\"\n\n        for key, val in non_tensor_batch.items():\n            assert isinstance(val, np.ndarray)\n            tensor_batch[key] = val.tolist()\n        output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)\n        return output\n\n    def get_data_info(self) -> str:\n        \"\"\"Return formatted information about stored data with nested type details.\n\n        Returns:\n            str: Formatted string showing tensor details and recursive metadata types\n        \"\"\"\n        info = [\"batch\"]\n\n        for key, tensor in self.batch.items():\n            if hasattr(tensor, \"shape\") and hasattr(tensor, \"dtype\") and hasattr(tensor, \"device\"):\n                info.append(f\"  {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}\")\n            elif hasattr(tensor, \"shape\") and hasattr(tensor, \"dtype\"):\n                info.append(f\"  {key}: {tuple(tensor.shape)} ({tensor.dtype})\")\n            else:\n                info.append(f\"  {key}: {type(tensor).__name__}\")\n\n        info.append(\"non_tensor_batch\")\n        for key, array in self.non_tensor_batch.items():\n            info.append(f\"  {key}: ndarray{array.shape} ({array.dtype})\")\n\n        info.append(\"meta_info\")\n        for k, v in self.meta_info.items():\n            type_info = self._get_type_info(v)\n            info.append(f\"  {k}: {type_info}\")\n\n        return \"\\n\".join(info)\n\n    def _get_type_info(self, value):\n        \"\"\"Recursively get type information for nested structures\"\"\"\n        if isinstance(value, list):\n            elem_types = {self._get_type_info(v) for v in value[:3]}\n            return f\"list[{'|'.join(elem_types) if elem_types else '...'}]\"\n        if isinstance(value, tuple):\n            elem_types = [self._get_type_info(v) for v in value]\n            return f\"tuple({', '.join(elem_types)})\"\n        if isinstance(value, dict):\n            if not value:\n                return \"dict\"\n            k, v = next(iter(value.items()))\n            return f\"dict[{self._get_type_info(k)}: 
{self._get_type_info(v)}]\"\n        if isinstance(value, np.ndarray):\n            return f\"ndarray{value.shape} ({value.dtype})\"\n        return type(value).__name__\n\n\n@dataclass\nclass DataProtoFuture:\n    \"\"\"\n    DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver doesn't have to\n    wait for data, so asynchronous execution becomes possible.\n    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.\n    - collect_fn is a Callable that reduces the list of futures to a DataProto\n    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size\n        and then selects the piece belonging to the current dp rank\n\n    Potential optimization: dispatch_fn(collect_fn) could be fused so that only the needed data is fetched on the\n    destination.\n    - DataProtoFuture only supports directly passing the output of one method to the input of another. You can't\n    perform any other operation on a DataProtoFuture on the driver.\n    \"\"\"\n\n    collect_fn: Callable\n    futures: list[ray.ObjectRef]\n    dispatch_fn: Callable = None\n\n    @staticmethod\n    def concat(data: list[ray.ObjectRef]) -> \"DataProtoFuture\":\n        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)\n        return output\n\n    def chunk(self, chunks: int) -> list[\"DataProtoFuture\"]:\n        from functools import partial\n\n        arg_future_lst = []\n        for i in range(chunks):\n            # note that we can't capture i and chunks directly in the closure (late binding), hence the partial\n            def dispatch_fn(x, i, chunks):\n                return x.chunk(chunks=chunks)[i]\n\n            arg_future = DataProtoFuture(\n                collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures\n            )\n            arg_future_lst.append(arg_future)\n        return arg_future_lst\n\n    def get(self):\n        output = ray.get(self.futures)  # one DataProto per dp rank\n        for o in output:\n            assert isinstance(o, DataProto)\n        output = self.collect_fn(output)  # concat across dp ranks\n        if self.dispatch_fn is not None:\n            output = self.dispatch_fn(output)  # split in batch dim and select the current dp rank's piece\n        return output\n\n\ndef all_gather_data_proto(data: DataProto, process_group):\n    # Note that this is an in-place operation, just like torch.distributed.all_gather\n    group_size = torch.distributed.get_world_size(group=process_group)\n    assert isinstance(data, DataProto)\n    prev_device = data.batch.device\n    data = data.to(get_device_id())\n    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)\n    data = data.to(prev_device)\n    # all gather non_tensor_batch\n    all_non_tensor_batch = [None for _ in range(group_size)]\n    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group)\n    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}\n"
  },
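The DataProto methods above compose into a simple round-trip. A minimal usage sketch, assuming verl is installed; it relies on `DataProto.from_dict`, defined earlier in `verl/protocol.py` and not shown in this excerpt:

```python
# Illustrative only: exercises repeat/chunk/concat/make_iterator on a tiny batch.
import torch

from verl.protocol import DataProto

dp = DataProto.from_dict(
    tensors={"obs": torch.arange(8).reshape(8, 1)},
    meta_info={"temperature": 1.0},
)

# repeat(interleave=True) duplicates rows as [0, 0, 1, 1, ...]
dp2 = dp.repeat(repeat_times=2, interleave=True)
assert len(dp2) == 16

# chunk() requires an equal split unless auto padding is enabled
parts = dp2.chunk(chunks=4)
restored = DataProto.concat(parts)
assert len(restored) == 16

# make_iterator() requires mini_batch_size to divide the batch size evenly
for mb in restored.make_iterator(mini_batch_size=4, epochs=1, seed=0):
    assert mb.batch["obs"].shape[0] == 4
```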
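DataProtoFuture defers all data movement to `get()`: the object refs pass through `concat` and `chunk` untouched, and only the final consumer resolves them. A minimal sketch, assuming a local ray cluster; `remote_fn` is a hypothetical stand-in for a worker-group method that returns one DataProto per dp rank:

```python
import ray
import torch

from verl.protocol import DataProto, DataProtoFuture

ray.init()

@ray.remote
def remote_fn(i: int) -> DataProto:
    # hypothetical per-rank output: two rows tagged with the rank index
    return DataProto.from_dict(tensors={"x": torch.full((2, 1), float(i))})

futures = [remote_fn.remote(i) for i in range(4)]  # nothing fetched yet
fut = DataProtoFuture.concat(futures)              # still lazy
shard = fut.chunk(chunks=4)[1]                     # select dp rank 1, still lazy
dp = shard.get()                                   # ray.get, then concat, then chunk[1]
assert torch.all(dp.batch["x"] == 1.0)
```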
  {
    "path": "verl_distillation/verl/py.typed",
    "content": ""
  },
  {
    "path": "verl_distillation/verl/single_controller/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nfrom . import base\nfrom .base import *\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\n# Note(haibin.lin): single_controller.__version__ is deprecated\nwith open(os.path.join(os.path.join(version_folder, os.pardir), \"version/version\")) as f:\n    __version__ = f.read().strip()\n\n\n__all__ = base.__all__\n"
  },
  {
    "path": "verl_distillation/verl/single_controller/base/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .worker import Worker\nfrom .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup\n\n__all__ = [\"Worker\", \"WorkerGroup\", \"ClassWithInitArgs\", \"ResourcePool\"]\n"
  },
  {
    "path": "verl_distillation/verl/single_controller/base/decorator.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom functools import partial, wraps\nfrom types import FunctionType\n\nfrom verl.protocol import DataProtoFuture, _padding_size_key\nfrom verl.utils.py_functional import DynamicEnum\nfrom verl.utils.transferqueue_utils import BatchMeta\n\n# here we add a magic number of avoid user-defined function already have this attribute\nMAGIC_ATTR = \"attrs_3141562937\"\n\n\nclass Dispatch(DynamicEnum):\n    \"\"\"Enum class defining different dispatch modes for distributed computation.\n\n    Each mode represents a specific strategy for distributing data across\n    different ranks in a distributed system. The modes are used to control\n    how data is partitioned and processed across different worker groups.\n    \"\"\"\n\n    _registry = {}\n    _next_value = 0\n\n\ndef init_predefined_dispatch_mode():\n    Dispatch.register(\"RANK_ZERO\")\n    Dispatch.register(\"ONE_TO_ALL\")\n    Dispatch.register(\"ALL_TO_ALL\")\n    Dispatch.register(\"DP_COMPUTE\")\n    Dispatch.register(\"DP_COMPUTE_PROTO\")\n    Dispatch.register(\"DP_COMPUTE_PROTO_WITH_FUNC\")\n    Dispatch.register(\"DP_COMPUTE_METRIC\")\n    # This is a special dispatch mode for vllm ExternalRayDistributedExecutor\n    Dispatch.register(\"DIRECT_ROLLOUT_METHOD\")\n\n\nclass Execute(DynamicEnum):\n    \"\"\"Enum class defining different execution modes for distributed computation.\n\n    These modes control how a function should be executed across different ranks\n    in a distributed system.\n    \"\"\"\n\n    _registry = {}\n    _next_value = 0\n\n\ndef init_predefined_execute_mode():\n    Execute.register(\"ALL\")\n    Execute.register(\"RANK_ZERO\")\n\n\n# Initialize the two Dynamic Enum Classes\ninit_predefined_dispatch_mode()\ninit_predefined_execute_mode()\n\n\ndef _split_args_kwargs_data_proto(chunks, *args, **kwargs):\n    from verl.protocol import DataProto, DataProtoFuture\n\n    splitted_args = []\n    for arg in args:\n        assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta)\n        splitted_args.append(arg.chunk(chunks=chunks))\n\n    splitted_kwargs = {}\n    for key, val in kwargs.items():\n        assert isinstance(val, DataProto | DataProtoFuture | BatchMeta)\n        splitted_kwargs[key] = val.chunk(chunks=chunks)\n\n    return splitted_args, splitted_kwargs\n\n\ndef _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):\n    from verl.protocol import DataProto, DataProtoFuture\n\n    data_proto_len = None\n    padding_size = None\n\n    def _padding_and_split_data(obj, chunks):\n        nonlocal data_proto_len, padding_size\n        assert isinstance(obj, DataProto | DataProtoFuture)\n        if isinstance(obj, DataProto) and obj.is_padding_enabled():\n            # for padding, we only support DataProto with same length\n            if data_proto_len is None:\n                data_proto_len = len(obj)\n                padding_size = (chunks - 
(data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0\n            else:\n                assert data_proto_len == len(obj), (\n                    f\"expecting all arg share same length of {data_proto_len}, but got {len(obj)}\"\n                )\n            obj.padding(padding_size=padding_size)\n        return obj.chunk(chunks=chunks)\n\n    splitted_args = [_padding_and_split_data(arg, chunks) for arg in args]\n    splitted_kwargs = {key: _padding_and_split_data(val, chunks) for key, val in kwargs.items()}\n    if padding_size is not None:\n        splitted_kwargs[_padding_size_key] = padding_size\n\n    return splitted_args, splitted_kwargs\n\n\ndef dispatch_one_to_all(worker_group, *args, **kwargs):\n    args = tuple([arg] * worker_group.world_size for arg in args)\n    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}\n    return args, kwargs\n\n\ndef dummy_direct_rollout_call(worker_group, *args, **kwargs):\n    raise NotImplementedError(\"Direct rollout call is forbidden.\")\n\n\ndef dispatch_all_to_all(worker_group, *args, **kwargs):\n    return args, kwargs\n\n\ndef collect_all_to_all(worker_group, output):\n    return output\n\n\ndef _concat_data_proto_or_future(output: list):\n    import ray\n\n    from verl.protocol import DataProto, DataProtoFuture\n\n    # make sure all the elements in output has the same type\n    for o in output:\n        assert type(o) is type(output[0])\n\n    o = output[0]\n\n    if isinstance(o, DataProto):\n        return DataProto.concat(output)\n    elif isinstance(o, ray.ObjectRef):\n        return DataProtoFuture.concat(output)\n    elif isinstance(o, BatchMeta):\n        return BatchMeta.concat(output)\n    else:\n        raise NotImplementedError\n\n\ndef dispatch_dp_compute(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    for arg in args:\n        assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size\n    for k, v in kwargs.items():\n        assert isinstance(v, tuple | list) and len(v) == worker_group.world_size\n    return args, kwargs\n\n\ndef collect_dp_compute(worker_group, output):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    assert len(output) == worker_group.world_size\n    return output\n\n\ndef dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    # Note: enable auto padding for dp compute DatapProto\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(\n        worker_group.world_size,\n        *args,\n        **kwargs,\n    )\n    return splitted_args, splitted_kwargs\n\n\ndef dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    assert isinstance(args[0], FunctionType)  # NOTE: The first one args is a function!\n\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)\n    splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args\n    return splitted_args_with_func, splitted_kwargs\n\n\ndef collect_dp_compute_data_proto(worker_group, output):\n    import ray\n\n    from verl.protocol import 
DataProto\n\n    for o in output:\n        assert isinstance(o, DataProto | ray.ObjectRef), f\"expecting {o} to be DataProto, but got {type(o)}\"\n\n    output = collect_dp_compute(worker_group, output)\n    return _concat_data_proto_or_future(output)\n\n\ndef dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):\n    import os\n\n    from verl.single_controller.base.worker_group import WorkerGroup\n    from verl.utils.ray_utils import parallel_put\n\n    assert isinstance(worker_group, WorkerGroup)\n\n    max_workers = max(1, min(len(args[0]), os.cpu_count()))\n\n    args = [parallel_put(arg, max_workers=max_workers) for arg in args]\n    kwargs = {k: parallel_put(v, max_workers=max_workers) for k, v in kwargs.items()}\n\n    all_args = []\n    for arg in args:\n        assert isinstance(arg, tuple | list) and len(arg) == dp_size\n        transformed_args = []\n        for i in range(worker_group.world_size):\n            local_dp_rank = dp_rank_mapping[i]\n            transformed_args.append(arg[local_dp_rank])\n        all_args.append(transformed_args)\n    all_args = tuple(all_args)\n\n    all_kwargs = {}\n    for k, v in kwargs.items():\n        assert isinstance(v, tuple | list) and len(v) == dp_size\n        transformed_v = []\n        for i in range(worker_group.world_size):\n            local_dp_rank = dp_rank_mapping[i]\n            transformed_v.append(v[local_dp_rank])\n        all_kwargs[k] = transformed_v\n    return all_args, all_kwargs\n\n\ndef collect_nd_compute(collect_mask: list[bool], worker_group, output):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    assert len(output) == worker_group.world_size\n\n    output_in_dp = []\n    for global_rank in range(worker_group.world_size):\n        collect_dp_rank = collect_mask[global_rank]\n        if collect_dp_rank:\n            output_in_dp.append(output[global_rank])\n    return output_in_dp\n\n\ndef dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(dp_size, *args, **kwargs)\n    return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs)\n\n\ndef collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output):\n    output = collect_nd_compute(collect_mask, worker_group, output)\n    import ray\n\n    from verl.protocol import DataProto\n\n    for o in output:\n        assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta), (\n            f\"expecting {o} to be DataProto or BatchMeta, but got {type(o)}\"\n        )\n    return _concat_data_proto_or_future(output)\n\n\ndef dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n\n    # query dispatch info of the worker group\n    if mesh_name not in worker_group._dispatch_info:\n        worker_group._dispatch_info[mesh_name] = worker_group._query_dispatch_info(mesh_name)\n        assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size\n\n    dp_rank_mapping = worker_group._dispatch_info[mesh_name]\n    # perform dispatch\n    dp_size = max(dp_rank_mapping) + 1\n    return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)\n\n\ndef collect_lazy_compute_data_proto(mesh_name, worker_group, *args, 
**kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n\n    # the dispatch info is stored in the worker group\n    assert mesh_name in worker_group._dispatch_info\n\n    if mesh_name not in worker_group._collect_info:\n        worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name)\n        assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size\n\n    # a boolean of whether the dp_rank is used for collect\n    collect_mask = worker_group._collect_info[mesh_name]\n    # perform dispatch\n    return collect_nd_compute_dataproto(collect_mask, worker_group, *args, **kwargs)\n\n\ndef make_nd_compute_dataproto_dispatch_fn(mesh_name):\n    return {\n        \"dispatch_fn\": partial(dispatch_lazy_compute_data_proto, mesh_name),\n        \"collect_fn\": partial(collect_lazy_compute_data_proto, mesh_name),\n    }\n\n\n# Global registry for dispatch mode.\nDISPATCH_MODE_FN_REGISTRY = {\n    Dispatch.ONE_TO_ALL: {\n        \"dispatch_fn\": dispatch_one_to_all,\n        \"collect_fn\": collect_all_to_all,\n    },\n    Dispatch.ALL_TO_ALL: {\n        \"dispatch_fn\": dispatch_all_to_all,\n        \"collect_fn\": collect_all_to_all,\n    },\n    Dispatch.DP_COMPUTE: {\"dispatch_fn\": dispatch_dp_compute, \"collect_fn\": collect_dp_compute},\n    Dispatch.DP_COMPUTE_PROTO: {\n        \"dispatch_fn\": dispatch_dp_compute_data_proto,\n        \"collect_fn\": collect_dp_compute_data_proto,\n    },\n    Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {\n        \"dispatch_fn\": dispatch_dp_compute_data_proto_with_func,\n        \"collect_fn\": collect_dp_compute_data_proto,\n    },\n    Dispatch.DP_COMPUTE_METRIC: {\"dispatch_fn\": dispatch_dp_compute_data_proto, \"collect_fn\": collect_dp_compute},\n    Dispatch.DIRECT_ROLLOUT_METHOD: {\n        \"dispatch_fn\": dummy_direct_rollout_call,\n        \"collect_fn\": dummy_direct_rollout_call,\n    },\n}\n\n\ndef get_predefined_dispatch_fn(dispatch_mode):\n    return DISPATCH_MODE_FN_REGISTRY[dispatch_mode]\n\n\ndef register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn):\n    \"\"\"\n    Register a new dispatch mode.\n    \"\"\"\n    dispatch_mode = Dispatch.register(dispatch_mode_name)\n    _check_dispatch_mode(dispatch_mode)\n    assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f\"dispatch_mode_name {dispatch_mode_name} already exists\"\n    DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {\"dispatch_fn\": dispatch_fn, \"collect_fn\": collect_fn}\n\n\ndef update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn):\n    \"\"\"\n    Update the dispatch mode.\n    \"\"\"\n    _check_dispatch_mode(dispatch_mode)\n    assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f\"dispatch_mode {dispatch_mode} not found\"\n    DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {\"dispatch_fn\": dispatch_fn, \"collect_fn\": collect_fn}\n\n\ndef get_predefined_execute_fn(execute_mode):\n    \"\"\"\n    Note that here we only asks execute_all and execute_rank_zero to be implemented\n    Leave the choice of how these two functions handle argument 'blocking' to users\n    \"\"\"\n    predefined_execute_mode_fn = {\n        Execute.ALL: {\"execute_fn_name\": \"execute_all\"},\n        Execute.RANK_ZERO: {\"execute_fn_name\": \"execute_rank_zero\"},\n    }\n    return predefined_execute_mode_fn[execute_mode]\n\n\ndef _check_dispatch_mode(dispatch_mode):\n    assert isinstance(dispatch_mode, Dispatch | dict), (\n        f\"dispatch_mode must be a Dispatch or 
a Dict. Got {dispatch_mode}\"\n    )\n    if isinstance(dispatch_mode, dict):\n        necessary_keys = [\"dispatch_fn\", \"collect_fn\"]\n        for key in necessary_keys:\n            assert key in dispatch_mode, f\"key {key} should be in dispatch_mode if it is a dictionary\"\n\n\ndef _check_execute_mode(execute_mode):\n    assert isinstance(execute_mode, Execute), f\"execute_mode must be a Execute. Got {execute_mode}\"\n\n\ndef _materialize_futures(*args, **kwargs):\n    new_args = []\n    for arg in args:\n        if isinstance(arg, DataProtoFuture):\n            arg = arg.get()\n        # add more type to materialize\n        new_args.append(arg)\n    for k, v in kwargs.items():\n        if isinstance(v, DataProtoFuture):\n            kwargs[k] = v.get()\n\n    new_args = tuple(new_args)\n    return new_args, kwargs\n\n\ndef register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):\n    \"\"\"Register a function with distributed execution configuration.\n\n    This decorator registers a function with specific dispatch and execution modes\n    for distributed computation. It handles both synchronous and asynchronous\n    functions, and optionally materializes futures before execution.\n\n    Args:\n        dispatch_mode:\n            Dispatch mode for computation distribution. Default: Dispatch.ALL_TO_ALL.\n        execute_mode:\n            Execute mode for computation distribution. Default: Execute.ALL.\n        blocking:\n            Whether the execution should be blocking. Defaults to True.\n        materialize_futures:\n            Whether to materialize the data before dispatching. Defaults to True.\n\n    Returns:\n        A decorator that wraps the original function with distributed execution\n        configuration.\n    \"\"\"\n    from verl.utils.transferqueue_utils import tqbridge\n\n    _check_dispatch_mode(dispatch_mode=dispatch_mode)\n    _check_execute_mode(execute_mode=execute_mode)\n\n    def decorator(func):\n        func = tqbridge()(func)\n\n        @wraps(func)\n        def inner(*args, **kwargs):\n            if materialize_futures:\n                args, kwargs = _materialize_futures(*args, **kwargs)\n            return func(*args, **kwargs)\n\n        @wraps(func)\n        async def async_inner(*args, **kwargs):\n            if materialize_futures:\n                args, kwargs = _materialize_futures(*args, **kwargs)\n            return await func(*args, **kwargs)\n\n        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner\n        attrs = {\"dispatch_mode\": dispatch_mode, \"execute_mode\": execute_mode, \"blocking\": blocking}\n        setattr(wrapper, MAGIC_ATTR, attrs)\n        return wrapper\n\n    return decorator\n"
  },
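The dispatch registry above is extensible at runtime via `register_dispatch_mode`. A minimal sketch of adding a custom mode and attaching it to a method with `@register`; the mode name, `dispatch_broadcast`, `collect_first`, and `MyWorker` are illustrative, not part of verl:

```python
from verl.single_controller.base.decorator import (
    Dispatch,
    Execute,
    register,
    register_dispatch_mode,
)

def dispatch_broadcast(worker_group, *args, **kwargs):
    # replicate every argument to all workers, like ONE_TO_ALL
    args = tuple([arg] * worker_group.world_size for arg in args)
    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
    return args, kwargs

def collect_first(worker_group, output):
    # keep only rank 0's result instead of the full list
    return output[0]

register_dispatch_mode("BROADCAST_COLLECT_FIRST", dispatch_broadcast, collect_first)

class MyWorker:  # in practice, a subclass of verl's Worker
    @register(dispatch_mode=Dispatch.BROADCAST_COLLECT_FIRST, execute_mode=Execute.ALL)
    def ping(self, msg: str) -> str:
        return msg
```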
  {
    "path": "verl_distillation/verl/single_controller/base/worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nthe class for Worker\n\"\"\"\n\nimport os\nimport socket\nimport warnings\nfrom dataclasses import dataclass\n\nimport ray\n\nfrom verl.utils.device import (\n    get_torch_device,\n    get_visible_devices_keyword,\n    is_npu_available,\n)\n\nfrom .decorator import Dispatch, Execute, register\n\n\n@dataclass\nclass DistRankInfo:\n    tp_rank: int\n    dp_rank: int\n    pp_rank: int\n    cp_rank: int\n\n\n@dataclass\nclass DistGlobalInfo:\n    tp_size: int\n    dp_size: int\n    pp_size: int\n    cp_size: int\n\n\nclass WorkerHelper:\n    @staticmethod\n    def _get_node_ip():\n        if os.getenv(\"WG_BACKEND\", None) == \"ray\":\n            return ray.util.get_node_ip_address()\n        else:\n            raise NotImplementedError(\"WG_BACKEND now just support ray mode.\")\n\n    @staticmethod\n    def _get_free_port():\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            return sock.getsockname()[1]\n\n    def get_availale_master_addr_port(self):\n        warnings.warn(\n            \"This function is deprecated due to typo in name; Please use `get_available_master_addr_port` instead\",\n            stacklevel=2,\n        )\n        return self.get_available_master_addr_port()\n\n    def get_available_master_addr_port(self):\n        return self._get_node_ip().strip(\"[]\"), str(self._get_free_port())\n\n\n# we assume that in each WorkerGroup, there is a Master Worker\nclass Worker(WorkerHelper):\n    \"\"\"A distributed worker that handles initialization and configuration for distributed training.\n\n    This class manages worker initialization, configuration, and provides methods for executing\n    distributed operations. It handles communication settings, device configuration, and worker\n    metadata management.\n    \"\"\"\n\n    fused_worker_attr_name = \"fused_worker_dict\"\n\n    def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_collect: bool):\n        \"\"\"Register the dp_rank for a given mesh name. 
This function is meant to be called by the worker\n\n        Args:\n            mesh_name (str):\n                Name of the mesh to register dp_rank for.\n            dp_rank (int):\n                dp_rank to register for the given mesh name.\n            is_collect (bool):\n                Whether the dp_rank is used for collect.\n        \"\"\"\n        if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank:\n            raise ValueError(f\"mesh_name {mesh_name} has been registered\")\n        self.__dispatch_dp_rank[mesh_name] = dp_rank\n        self.__collect_dp_rank[mesh_name] = is_collect\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def _query_dispatch_info(self, mesh_name: str):\n        \"\"\"Query the dispatch info for a given mesh name.\n\n        Args:\n            mesh_name (str):\n                Name of the mesh to query dispatch info for.\n\n        Returns:\n            int:\n                The dp_rank for the given mesh name.\n        \"\"\"\n        assert mesh_name in self.__dispatch_dp_rank, f\"{mesh_name} is not registered in {self.__class__.__name__}\"\n        # note that each rank store its own dp_rank\n        return self.__dispatch_dp_rank[mesh_name]\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def _query_collect_info(self, mesh_name: str):\n        \"\"\"Query the collect info for a given mesh name.\n\n        Args:\n            mesh_name (str):\n                Name of the mesh to query collect info for.\n\n        Returns:\n            bool:\n                Whether the dp_rank is used for collect.\n        \"\"\"\n        assert mesh_name in self.__collect_dp_rank, f\"{mesh_name} is not registered in {self.__class__.__name__}\"\n        return self.__collect_dp_rank[mesh_name]\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)\n    def create_transferqueue_client(self, controller_infos, storage_infos, role=\"train\"):\n        from verl.utils.transferqueue_utils import create_transferqueue_client\n\n        create_transferqueue_client(\n            client_id=f\"{role}_worker_{self.rank}\",\n            controller_infos=controller_infos,\n            storage_infos=storage_infos,\n        )\n\n    @classmethod\n    def env_keys(cls):\n        \"\"\"The keys of the environment variables that are used to configure the Worker.\"\"\"\n        return [\n            \"WORLD_SIZE\",\n            \"RANK\",\n            \"LOCAL_WORLD_SIZE\",\n            \"LOCAL_RANK\",\n            \"MASTER_ADDR\",\n            \"MASTER_PORT\",\n            get_visible_devices_keyword().upper(),\n        ]\n\n    def __init__(self, cuda_visible_devices=None) -> None:\n        \"\"\"Initialize the worker with environment settings and device configuration.\n\n        Args:\n            cuda_visible_devices (str, optional):\n                CUDA visible devices configuration. Defaults to None.\n        \"\"\"\n        # construct a meta from environment variable. 
Note that the import must be inside the class because\n        # it is executed remotely\n        import os\n\n        self._setup_env_cuda_visible_devices()\n\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n        rank = int(os.environ[\"RANK\"])\n        self._rank = rank\n        self._world_size = world_size\n\n        master_addr = os.environ[\"MASTER_ADDR\"]\n        master_port = os.environ[\"MASTER_PORT\"]\n\n        local_world_size = int(os.getenv(\"LOCAL_WORLD_SIZE\", \"1\"))\n        local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n        store = {\n            \"_world_size\": world_size,\n            \"_rank\": rank,\n            \"_local_world_size\": local_world_size,\n            \"_local_rank\": local_rank,\n            \"_master_addr\": master_addr,\n            \"_master_port\": master_port,\n        }\n        if cuda_visible_devices is not None:\n            store[f\"_{get_visible_devices_keyword()}\".lower()] = cuda_visible_devices\n\n        self._configure_with_store(store=store)\n\n        self.fused_worker_dict = {}\n        self.__dispatch_dp_rank = {}\n        self.__collect_dp_rank = {}\n\n    def get_fused_worker_by_name(self, worker_name: str):\n        \"\"\"Get a fused worker by its name.\n\n        Args:\n            worker_name (str):\n                Name of the worker to retrieve\n        \"\"\"\n        return self.fused_worker_dict.get(worker_name, None)\n\n    def _setup_env_cuda_visible_devices(self):\n        from verl.utils.ray_utils import ray_noset_visible_devices\n\n        is_ray_noset_visible_devices = ray_noset_visible_devices()\n\n        # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES``\n        rocr_val = os.environ.get(\"ROCR_VISIBLE_DEVICES\", None)\n        hip_val = os.environ.get(\"HIP_VISIBLE_DEVICES\", None)\n        cuda_val = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if hip_val:\n            # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency.\n            # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES\n            # at this point.\n            val = os.environ.pop(\"HIP_VISIBLE_DEVICES\")\n            hip_val = None\n            if cuda_val:\n                assert val == cuda_val, (\n                    f\"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values \"\n                    f\"found: {val} and {cuda_val}.\"\n                )\n            else:\n                cuda_val = val\n                os.environ[\"CUDA_VISIBLE_DEVICES\"] = val\n                # os.environ[\"HIP_VISIBLE_DEVICES\"] = val\n\n        if rocr_val:\n            # You must take care if both HIP/CUDA and ROCR env vars are set as they have\n            # different meanings. Both env vars accept either a list of ints or a\n            # list of UUIDs. 
The ROCR env var is processed first which then reduces\n            # the number of GPUs that HIP can select from.\n            # https://github.com/pytorch/pytorch/pull/144026\n            # To avoid the complexity of this, we simply gives out error if both are set\n            # (Also to keep consistency with ray's practice with 2.45.0).\n            # Otherwise, we will set ROCR_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES\n            # and remove ROCR_VISIBLE_DEVICES.\n            if cuda_val:\n                raise ValueError(\"Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.\")\n\n            cuda_val = os.environ.pop(\"ROCR_VISIBLE_DEVICES\")\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = cuda_val\n            rocr_val = None\n\n        if is_ray_noset_visible_devices:\n            # NOTE: Ray will automatically set the *_VISIBLE_DEVICES\n            # environment variable for each actor, unless\n            # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set,\n            # so we need to set local rank when the flag is set.\n            device_name = \"NPU\" if is_npu_available else \"GPU\"\n            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]\n            os.environ[\"LOCAL_RANK\"] = local_rank\n            get_torch_device().set_device(int(local_rank))\n\n    def _configure_with_store(self, store: dict):\n        \"\"\"\n        This function should only be called inside by WorkerGroup\n        \"\"\"\n        store_env_dict = {f\"_{key.lower()}\": store.get(f\"_{key.lower()}\", None) for key in type(self).env_keys()}\n        self.__dict__.update(store_env_dict)  # this is hacky\n        # print(f\"__dict__: {self.__dict__}\")\n        for key in type(self).env_keys():\n            val = self.__dict__.get(f\"_{key.lower()}\", None)\n            if val is not None:\n                # print(f\"set {key} to {val}\")\n                os.environ[key] = str(val)\n        os.environ[\"REDIS_STORE_SERVER_HOST\"] = (\n            str(self._master_addr).replace(\"[\", \"\").replace(\"]\", \"\") if self._master_addr else \"\"\n        )\n\n    def get_master_addr_port(self):\n        \"\"\"Get the master address and port for distributed communication.\"\"\"\n        return self._master_addr, self._master_port\n\n    def get_cuda_visible_devices(self):\n        \"\"\"Get the CUDA visible devices configuration.\"\"\"\n        import os\n\n        visible_devices = os.environ.get(get_visible_devices_keyword().upper(), \"not set\")\n        return visible_devices\n\n    @property\n    def world_size(self):\n        \"\"\"Get the total number of workers in the distributed setup.\"\"\"\n        return self._world_size\n\n    @property\n    def rank(self):\n        \"\"\"Get the rank of this worker in the distributed setup.\"\"\"\n        return self._rank\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)\n    def execute_with_func_generator(self, func, *args, **kwargs):\n        \"\"\"Execute a function with function generator dispatch mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        \"\"\"\n        ret_proto = func(self, *args, **kwargs)\n        return ret_proto\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n    def execute_func_rank_zero(self, func, *args, **kwargs):\n        \"\"\"Execute a 
function in rank zero execution mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        \"\"\"\n        result = func(*args, **kwargs)\n        return result\n"
  },
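A `Worker` derives its identity entirely from the environment variables listed in `env_keys`. A minimal sketch of constructing one directly in a single-process, CPU-only setting; normally the worker group populates these variables, and the values below are illustrative:

```python
import os

# the variables listed in Worker.env_keys(); values are illustrative
os.environ.update(
    {
        "WORLD_SIZE": "1",
        "RANK": "0",
        "LOCAL_WORLD_SIZE": "1",
        "LOCAL_RANK": "0",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "12355",
    }
)

from verl.single_controller.base.worker import Worker

w = Worker()
assert (w.rank, w.world_size) == (0, 1)
assert w.get_master_addr_port() == ("127.0.0.1", "12355")
```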
  {
    "path": "verl_distillation/verl/single_controller/base/worker_group.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nthe class of WorkerGroup\n\"\"\"\n\nimport logging\nimport signal\nimport threading\nimport time\nfrom typing import Any, Callable\n\nfrom .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn\n\n\nclass ResourcePool:\n    \"\"\"\n    Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations.\n    The class provides methods to calculate world size, local world sizes, and local ranks\n    across all nodes in the pool.\n    \"\"\"\n\n    def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None:\n        \"\"\"Initialize the ResourcePool with node processes and GPU configuration.\n\n        Args:\n            process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list.\n            max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10.\n            n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8.\n        \"\"\"\n        if process_on_nodes is None:\n            process_on_nodes = []\n        self._store = process_on_nodes\n        self.max_colocate_count = max_colocate_count\n        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node\n\n    def add_node(self, process_count):\n        self._store.append(process_count)\n\n    @property\n    def world_size(self):\n        \"\"\"Total number of processes across all nodes in the pool.\"\"\"\n        return sum(self._store)\n\n    def __call__(self) -> Any:\n        return self._store\n\n    @property\n    def store(self):\n        return self._store\n\n    def local_world_size_list(self) -> list[int]:\n        \"\"\"Returns a flat list where each process has its local world size.\"\"\"\n        nested_local_world_size_list = [\n            [local_world_size for _ in range(local_world_size)] for local_world_size in self._store\n        ]\n        return [item for row in nested_local_world_size_list for item in row]\n\n    def local_rank_list(self) -> list[int]:\n        \"\"\"Returns a flat list of local ranks for all processes across all nodes.\"\"\"\n        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]\n        return [item for row in nested_local_rank_list for item in row]\n\n\nclass ClassWithInitArgs:\n    \"\"\"\n    Wrapper class that stores constructor arguments for deferred instantiation.\n    This class is particularly useful for remote class instantiation where\n    the actual construction needs to happen at a different time or location.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        \"\"\"Initialize the ClassWithInitArgs instance.\n\n        Args:\n            cls: The class to be instantiated later\n            *args: Positional 
arguments for the class constructor\n            **kwargs: Keyword arguments for the class constructor\n        \"\"\"\n        self.cls = cls\n        self.args = args\n        self.kwargs = kwargs\n\n        self.fused_worker_used = False\n\n    def __call__(self) -> Any:\n        \"\"\"Instantiate the stored class with the stored arguments.\"\"\"\n        return self.cls(*self.args, **self.kwargs)\n\n\ndef check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None:\n    \"\"\"Continuously monitors worker processes and raises SIGABRT if any worker dies.\n\n    Args:\n        workers (List):\n            List of worker objects to monitor\n        is_alive (Callable):\n            Function to check if a worker is alive\n        gap_time (float):\n            Time interval between checks\n    \"\"\"\n    import time\n\n    while True:\n        for worker in workers:\n            if not is_alive(worker):\n                logging.warning(f\"worker {worker} is not alive sending signal to main thread\")\n                signal.raise_signal(signal.SIGABRT)\n        time.sleep(gap_time)\n\n\nclass WorkerGroup:\n    \"\"\"\n    Base class for managing a group of workers in a distributed system.\n    The class provides methods for worker management, aliveness checking, and method binding.\n    \"\"\"\n\n    fused_worker_execute_fn_name = \"_fuw_execute\"\n\n    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:\n        self._is_init_with_detached_workers = resource_pool is None\n\n        self.fused_worker_used = False\n\n        if resource_pool is not None:\n            # handle the case when WorkGroup is attached to an existing one\n            self._procecss_dispatch_config = resource_pool()\n        else:\n            self._procecss_dispatch_config = None\n\n        self._workers = []\n        self._worker_names = []\n\n        self._dispatch_info = {}\n        self._collect_info = {}\n\n        self._master_addr = None\n        self._master_port = None\n\n        self._checker_thread: threading.Thread = None\n\n    def _is_worker_alive(self, worker):\n        \"\"\"Check if a worker is alive. 
Must be implemented by derived classes.\"\"\"\n        raise NotImplementedError(\"WorkerGroup._is_worker_alive called, should be implemented in derived class.\")\n\n    def _block_until_all_workers_alive(self) -> None:\n        \"\"\"Blocks until all workers in the group are alive.\"\"\"\n        while True:\n            all_state = [self._is_worker_alive(worker) for worker in self._workers]\n            if False in all_state:\n                time.sleep(1)\n            else:\n                break\n\n    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:\n        \"\"\"Starts a background thread to monitor worker aliveness.\n\n        Args:\n            every_n_seconds (int): Interval between aliveness checks\n        \"\"\"\n        # before starting checking worker aliveness, make sure all workers are already alive\n        self._block_until_all_workers_alive()\n\n        self._checker_thread = threading.Thread(\n            target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)\n        )\n        self._checker_thread.start()\n\n    @property\n    def world_size(self):\n        \"\"\"Number of workers in the group.\"\"\"\n        return len(self._workers)\n\n    def _bind_worker_method(self, user_defined_cls, func_generator):\n        \"\"\"Binds worker methods to the WorkerGroup based on registered attributes.\n\n        Args:\n            user_defined_cls (type): The class containing methods to bind\n            func_generator (Callable): Function that generates the bound method\n\n        Returns:\n            List[str]: List of method names that were successfully bound\n        \"\"\"\n        method_names = []\n        for method_name in dir(user_defined_cls):\n            try:\n                method = getattr(user_defined_cls, method_name)\n                assert callable(method), f\"{method_name} in {user_defined_cls} is not callable\"\n            except Exception:\n                # if it is a property, it will fail because Class doesn't have instance property\n                continue\n\n            if hasattr(method, MAGIC_ATTR):\n                # this method is decorated by register\n                attribute = getattr(method, MAGIC_ATTR)\n                assert isinstance(attribute, dict), f\"attribute must be a dictionary. 
Got {type(attribute)}\"\n                assert \"dispatch_mode\" in attribute, \"attribute must contain dispatch_mode in its key\"\n\n                dispatch_mode = attribute[\"dispatch_mode\"]\n                execute_mode = attribute[\"execute_mode\"]\n                blocking = attribute[\"blocking\"]\n\n                # get dispatch fn\n                if isinstance(dispatch_mode, Dispatch):\n                    # get default dispatch fn\n                    fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)\n                    dispatch_fn = fn[\"dispatch_fn\"]\n                    collect_fn = fn[\"collect_fn\"]\n                else:\n                    assert isinstance(dispatch_mode, dict)\n                    assert \"dispatch_fn\" in dispatch_mode\n                    assert \"collect_fn\" in dispatch_mode\n                    dispatch_fn = dispatch_mode[\"dispatch_fn\"]\n                    collect_fn = dispatch_mode[\"collect_fn\"]\n\n                # get execute_fn_name\n                execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)\n                wg_execute_fn_name = execute_mode[\"execute_fn_name\"]\n\n                # get execute_fn from string\n                try:\n                    execute_fn = getattr(self, wg_execute_fn_name)\n                    assert callable(execute_fn), \"execute_fn must be callable\"\n                except Exception:\n                    print(f\"execute_fn {wg_execute_fn_name} is invalid\")\n                    raise\n\n                # bind a new method to the RayWorkerGroup\n                func = func_generator(\n                    self,\n                    method_name,\n                    dispatch_fn=dispatch_fn,\n                    collect_fn=collect_fn,\n                    execute_fn=execute_fn,\n                    blocking=blocking,\n                )\n\n                try:\n                    setattr(self, method_name, func)\n                    method_names.append(method_name)\n                except Exception as e:\n                    raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\n        return method_names\n"
  },
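`ResourcePool`'s bookkeeping is easiest to verify on a concrete layout. A minimal sketch for two nodes running 2 and 4 processes respectively:

```python
from verl.single_controller.base.worker_group import ResourcePool

# two nodes: node 0 runs 2 processes, node 1 runs 4
pool = ResourcePool(process_on_nodes=[2, 4])

assert pool.world_size == 6
# each process reports the size of its own node...
assert pool.local_world_size_list() == [2, 2, 4, 4, 4, 4]
# ...and its rank within that node
assert pool.local_rank_list() == [0, 1, 0, 1, 2, 3]
```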
  {
    "path": "verl_distillation/verl/single_controller/ray/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls,\n    create_colocated_worker_cls_fused,\n)\n\n__all__ = [\n    \"RayClassWithInitArgs\",\n    \"RayResourcePool\",\n    \"RayWorkerGroup\",\n    \"create_colocated_worker_cls\",\n    \"create_colocated_worker_cls_fused\",\n]\n"
  },
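These exports combine as follows in typical verl usage. A minimal sketch, assuming a running ray cluster; `MyWorker` is an illustrative subclass, and `RayResourcePool`/`RayWorkerGroup` are defined in `ray/base.py` below:

```python
import ray

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

ray.init()

@ray.remote
class MyWorker(Worker):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def ping(self, msg: str) -> str:
        return f"rank {self.rank}: {msg}"

# one node with two CPU-only workers
pool = RayResourcePool(process_on_nodes=[2], use_gpu=False)
wg = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=RayClassWithInitArgs(cls=MyWorker))
print(wg.ping("hello"))  # ONE_TO_ALL dispatch, ALL_TO_ALL collect: one reply per worker
```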
  {
    "path": "verl_distillation/verl/single_controller/ray/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport inspect\nimport logging\nimport socket\nfrom copy import deepcopy\nfrom typing import Any, Optional\n\nimport ray\nfrom ray.experimental.state.api import get_actor\nfrom ray.util.placement_group import PlacementGroup, placement_group\nfrom ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy\n\nfrom verl.protocol import DataProto, _padding_size_key\nfrom verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup\nfrom verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch\nfrom verl.utils.py_functional import temp_env_var\n\n__all__ = [\"Worker\"]\n\n\ndef get_random_string(length: int) -> str:\n    import random\n    import string\n\n    letters_digits = string.ascii_letters + string.digits\n    return \"\".join(random.choice(letters_digits) for _ in range(length))\n\n\ndef func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):\n    class Functor:\n        def __call__(this, *args, **kwargs):\n            args, kwargs = dispatch_fn(self, *args, **kwargs)\n            padding_count = kwargs.pop(_padding_size_key, 0)\n            output = execute_fn(method_name, *args, **kwargs)\n            if blocking:\n                output = ray.get(output)\n            output = collect_fn(self, output)\n            if padding_count > 0:\n                if isinstance(output, DataProto):\n                    indices = [i for i in range(len(output))][:-padding_count]\n                    output = output.select_idxs(indices)\n                elif isinstance(output, list):\n                    output = output[:-padding_count]\n            return output\n\n    # use class type to pass the method_name to get a better observability\n    return type(method_name, (Functor,), {})()\n\n\ndef sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]:\n    \"\"\"\n    Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.\n\n    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK\n    to be consistent across nodes when resume from checkpoint.\n\n    With this function, if there's only one resource pool and there's no node change, RANK should be consistent\n    across nodes in multiple ray jobs, even if the whole ray cluster is restarted.\n    \"\"\"\n    node_ip = {node[\"NodeID\"]: node[\"NodeManagerAddress\"] for node in ray.nodes()}\n    pg_ip = {}\n    for pg in pgs:\n        specs = ray._private.state.state.placement_group_table(pg.id)\n        # all bunles should be on the same node\n        node_id = specs[\"bundles_to_node_id\"][0]\n        pg_ip[pg.id] = node_ip[node_id]\n    return sorted(pgs, key=lambda pg: pg_ip[pg.id])\n\n\n@ray.remote\ndef get_master_addr_port() -> tuple[str, str]:\n    addr = 
ray.util.get_node_ip_address().strip(\"[]\")\n    with socket.socket() as sock:\n        sock.bind((\"\", 0))\n        port = sock.getsockname()[1]\n    return addr, str(port)\n\n\nclass RayResourcePool(ResourcePool):\n    def __init__(\n        self,\n        process_on_nodes: Optional[list[int]] = None,\n        use_gpu: bool = True,\n        name_prefix: str = None,\n        max_colocate_count: int = 10,\n        detached=False,\n        accelerator_type: Optional[str] = None,\n    ) -> None:\n        super().__init__(process_on_nodes, max_colocate_count)\n        self.use_gpu = use_gpu\n        # print(f\"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}\")\n        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix\n        self.pgs = None\n        self.detached = detached\n        self.accelerator_type = accelerator_type\n\n    def get_placement_groups(self, strategy=\"STRICT_PACK\", name=None, device_name=\"cuda\"):\n        if self.pgs is not None:\n            return self.pgs\n\n        pg_name_prefix = (\n            name if name else f\"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:\"\n        )\n        # print(f\"pg_name_prefix = {pg_name_prefix}\")\n        if device_name == \"npu\":\n            device_name = \"NPU\"\n        elif device_name == \"cuda\":\n            device_name = \"GPU\"\n\n        bundle = {\"CPU\": self.max_colocate_count}\n        if self.use_gpu:\n            bundle[device_name] = 1\n            if self.accelerator_type is not None:\n                bundle[self.accelerator_type] = 1e-4\n        pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]\n\n        lifetime = \"detached\" if self.detached else None\n\n        pgs = [\n            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)\n            for idx, bundles in enumerate(pg_scheme)\n        ]\n\n        ray.get([pg.ready() for pg in pgs])\n\n        self.pgs = pgs\n        return pgs\n\n\ndef extract_pg_from_exist(\n    resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool\n) -> list:\n    src_pgs = [\n        pg\n        for role_name, resource_pool in resource_pools.items()\n        for pg in resource_pool.get_placement_groups()\n        if role_name in src_role_names\n    ]\n\n    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)\n    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)\n\n    unsorted_pgs: list[tuple[int, PlacementGroup]] = []\n    searching_idx = 0\n    for request_process, original_idx in sorted_process_on_nodes:\n        assert searching_idx < len(sorted_src_pgs), f\"not enough nodes for request: searching the {searching_idx}-th node\"\n        assert request_process <= sorted_src_pgs[searching_idx].bundle_count, (\n            f\"requesting {request_process} processes, but the bundle count cannot satisfy it\"\n        )\n        unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))\n        searching_idx += 1\n\n    return [pg for _, pg in sorted(unsorted_pgs)]\n\n\ndef merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:\n    assert rp1.use_gpu == rp2.use_gpu, \"Both RayResourcePools must have the same use_gpu setting\"\n    assert rp1.max_colocate_count == rp2.max_colocate_count, \"Both RayResourcePools must have the same max_colocate_count\"\n    
assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, \"Both RayResourcePools must have the same n_gpus_per_node\"\n    assert rp1.detached == rp2.detached, \"Detached ResourcePool cannot be merged with non-detached ResourcePool\"\n\n    new_store = rp1.store + rp2.store\n\n    merged = type(rp1)(new_store, rp1.use_gpu, f\"{rp1.name_prefix}_{rp2.name_prefix}\")\n    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()\n\n    return merged\n\n\nclass RayClassWithInitArgs(ClassWithInitArgs):\n    \"\"\"A wrapper class for Ray actors with initialization arguments.\n\n    This class extends ClassWithInitArgs to provide additional functionality for\n    configuring and creating Ray actors with specific resource requirements and\n    scheduling strategies.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        # self._options = kwargs.pop('options', dict())\n        super().__init__(cls, *args, **kwargs)\n        self._options = {}\n        self._additional_resource = {}\n\n    def set_additional_resource(self, additional_resource):\n        \"\"\"Set additional resource requirements for the actor.\n\n        Args:\n            additional_resource: Dictionary specifying additional resource requirements\n        \"\"\"\n        self._additional_resource = additional_resource\n\n    def update_options(self, options: dict):\n        \"\"\"Update the Ray actor creation options.\n\n        Args:\n            options: Dictionary of options to update\n        \"\"\"\n        self._options.update(options)\n\n    def __call__(\n        self,\n        placement_group,\n        placement_group_bundle_idx,\n        use_gpu: bool = True,\n        num_gpus=1,\n        sharing_with=None,\n        device_name=\"cuda\",\n    ) -> Any:\n        \"\"\"Create and return a Ray actor with the configured options.\n\n        Args:\n            placement_group: Ray placement group for scheduling\n            placement_group_bundle_idx: Index of the bundle in the placement group\n            use_gpu: Whether to use GPU resources\n            num_gpus: Number of GPUs to allocate\n            sharing_with: Actor to share resources with\n            device_name: Device for training\n\n        Returns:\n            A Ray actor handle with the configured options\n        \"\"\"\n        if sharing_with is not None:\n            target_node_id = ray.get(sharing_with.get_node_id.remote())\n            visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())\n            options = {\"scheduling_strategy\": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}\n            return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs)\n\n        options = {\n            \"scheduling_strategy\": PlacementGroupSchedulingStrategy(\n                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx\n            )\n        }\n        options.update(self._options)\n\n        if use_gpu and device_name == \"cuda\":\n            options[\"num_gpus\"] = num_gpus\n        if use_gpu and device_name == \"npu\":\n            options[\"resources\"] = {\"NPU\": num_gpus}\n\n        # apply any user-specified additional resources\n        if len(self._additional_resource) > 0:\n            for k, v in self._additional_resource.items():\n                options[k] = v\n\n        # print(\"cls:\", self.cls)\n        # print(\"args: \", self.args)\n        # print(\"kwargs: \", self.kwargs)\n        return self.cls.options(**options).remote(*self.args, 
**self.kwargs)\n\n\nclass RayWorkerGroup(WorkerGroup):\n    \"\"\"A group of Ray workers that can be managed collectively.\n\n    This class extends WorkerGroup to provide Ray-specific functionality for\n    creating and managing groups of Ray actors with specific resource requirements\n    and scheduling strategies.\n    \"\"\"\n\n    def __init__(\n        self,\n        resource_pool: RayResourcePool = None,\n        ray_cls_with_init: RayClassWithInitArgs = None,\n        bin_pack: bool = True,\n        name_prefix: str = None,\n        detached=False,\n        worker_names=None,\n        worker_handles: list[ray.actor.ActorHandle] = None,\n        ray_wait_register_center_timeout: int = 300,\n        **kwargs,\n    ) -> None:\n        \"\"\"Initialize a RayWorkerGroup.\n\n        Args:\n            resource_pool: Resource pool for worker allocation\n            ray_cls_with_init: Class with initialization arguments for workers\n            bin_pack: Whether to use strict bin packing for resource allocation\n            name_prefix: Prefix for worker names\n            detached: Whether workers should be detached\n            worker_names: Names of existing workers to attach to\n            worker_handles: Handles of existing workers to attach to\n            ray_wait_register_center_timeout: Timeout for waiting on the register center\n            **kwargs: Additional keyword arguments\n        \"\"\"\n        super().__init__(resource_pool=resource_pool, **kwargs)\n        self.ray_cls_with_init = ray_cls_with_init\n        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix\n        self._ray_wait_register_center_timeout = ray_wait_register_center_timeout\n        # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker.\n        self.fused_worker_used = ray_cls_with_init.fused_worker_used\n        # if a WorkerGroup is spawned from a Colocate WorkerGroup, this indicates which sub-class is bound to\n        # this WorkerGroup.\n        self.sub_cls_name = \"\"\n        self.device_name = kwargs.get(\"device_name\", \"cuda\")\n        self.profile_steps = kwargs.get(\"profile_steps\", None)\n        self.worker_nsight_options = kwargs.get(\"worker_nsight_options\", None)\n        self.customized_worker_env = kwargs.get(\"worker_env\", {})\n        if self.worker_nsight_options is not None and self.worker_nsight_options[\"capture-range-end\"] is None:\n            self.worker_nsight_options[\"capture-range-end\"] = f\"repeat-shutdown:{6 * len(self.profile_steps)}\"\n\n        if worker_names is not None and (not self.fused_worker_used):\n            assert self._is_init_with_detached_workers\n            self._worker_names = worker_names\n\n        if self._is_init_with_detached_workers:\n            self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles)\n        else:\n            self._init_with_resource_pool(\n                resource_pool=resource_pool,\n                ray_cls_with_init=ray_cls_with_init,\n                bin_pack=bin_pack,\n                detached=detached,\n                worker_env=self.customized_worker_env,\n            )\n\n        if ray_cls_with_init is not None:\n            self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)\n\n        self.wg_dict = None\n        self.method_names = []\n\n    def _is_worker_alive(self, worker: ray.actor.ActorHandle):\n        \"\"\"Check if a worker actor is still alive.\n\n        Args:\n            worker: Ray actor handle to check\n\n        Returns:\n            bool: True if the 
worker is alive, False otherwise\n        \"\"\"\n        worker_state_dict = get_actor(worker._actor_id.hex())\n        return worker_state_dict.get(\"state\", \"undefined\") == \"ALIVE\" if worker_state_dict is not None else False\n\n    def _init_with_detached_workers(self, worker_names, worker_handles):\n        # ray.get_actor holds a weak reference to the actor, which can cause actors to be garbage collected\n        # unexpectedly if we only hold the spawned RayWorkerGroup. By passing actor handles explicitly, the spawned\n        # RayWorkerGroup holds strong references to these actors.\n        # https://github.com/ray-project/ray/pull/45699\n        workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names]\n        self._workers = workers\n        self._world_size = len(worker_names)\n\n    def _get_master_addr_port(self, pg):\n        \"\"\"Get the master address and port for this worker group.\"\"\"\n        self._master_addr, self._master_port = ray.get(\n            get_master_addr_port.options(\n                scheduling_strategy=PlacementGroupSchedulingStrategy(\n                    placement_group=pg, placement_group_bundle_index=0\n                ),\n            ).remote()\n        )\n\n    def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached, worker_env=None):\n        \"\"\"Initialize the worker group by creating new workers from a resource pool.\n\n        Args:\n            resource_pool: Resource pool for worker allocation\n            ray_cls_with_init: Class with initialization arguments for workers\n            bin_pack: Whether to use strict bin packing for resource allocation\n            detached: Whether workers should be detached\n        \"\"\"\n        use_gpu = resource_pool.use_gpu\n\n        strategy = \"PACK\"\n        if bin_pack:\n            strategy = \"STRICT_PACK\"\n        pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)\n        world_size = resource_pool.world_size\n        self._world_size = world_size\n        # cia.add_kwarg(\"_world_size\", world_size)\n        num_gpus = 1 / resource_pool.max_colocate_count\n\n        rank = -1\n        local_world_size = resource_pool.store[0]\n        for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):\n            assert local_world_size <= pg.bundle_count, (\n                f\"when generating workers for {self.name_prefix}, local_world_size ({local_world_size}) \"\n                f\"exceeds the bundle count ({pg.bundle_count}) of placement group {pg_idx}\"\n            )\n            if pg_idx == 0:\n                self._get_master_addr_port(pg)\n\n            for local_rank in range(local_world_size):\n                rank += 1\n\n                # we pass environment variables via options so that the Worker can read them during setup\n                env_vars = {\n                    \"WORLD_SIZE\": str(world_size),\n                    \"RANK\": str(rank),\n                    \"WG_PREFIX\": self.name_prefix,\n                    \"WG_BACKEND\": \"ray\",\n                    \"RAY_LOCAL_WORLD_SIZE\": str(local_world_size),\n                    \"MASTER_ADDR\": self._master_addr,\n                    \"MASTER_PORT\": self._master_port,\n                }\n                if worker_env is not None:\n                    logging.debug(f\"Appending ray class env, origin: {env_vars}, customized env: {worker_env}\")\n                    conflict_env_vars = set(env_vars.keys()) & set(worker_env.keys())\n                    if len(conflict_env_vars) > 0:\n                        logging.error(\n                            f\"User customized env vars conflict with 
system env: {conflict_env_vars} \"\n                            f\"Overriding may cause unexpected behavior.\"\n                        )\n                        raise ValueError(f\"Cannot override protected system env: {conflict_env_vars}\")\n                    env_vars.update(worker_env)\n\n                cia_name = type(ray_cls_with_init.cls).__name__\n                match = re.search(r\"ActorClass\\(([^)]+)\\)\", cia_name)  # ray.remote(Obj) -> \"ActorClass(Obj)\"\n                cia_name = match.group(1) if match else cia_name  # \"ActorClass(Obj)\" -> \"Obj\"\n                name = f\"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}\"  # e.g. Worker_2:5\n\n                if self.profile_steps and self.device_name == \"cuda\":\n                    ray_cls_with_init.update_options(\n                        {\n                            \"runtime_env\": {\n                                \"env_vars\": env_vars,\n                                \"nsight\": self.worker_nsight_options,\n                            },\n                            \"name\": name,\n                        }\n                    )\n                else:\n                    ray_cls_with_init.update_options({\"runtime_env\": {\"env_vars\": env_vars}, \"name\": name})\n\n                if detached:\n                    ray_cls_with_init.update_options({\"lifetime\": \"detached\"})\n\n                # create a worker\n                worker = ray_cls_with_init(\n                    placement_group=pg,\n                    placement_group_bundle_idx=local_rank,\n                    use_gpu=use_gpu,\n                    num_gpus=num_gpus,\n                    device_name=self.device_name,\n                )\n                self._workers.append(worker)\n                self._worker_names.append(name)\n\n    @property\n    def worker_names(self):\n        return self._worker_names\n\n    @classmethod\n    def from_detached(\n        cls,\n        name_prefix=None,\n        worker_names=None,\n        worker_handles=None,\n        ray_cls_with_init=None,\n        **kwargs,\n    ):\n        \"\"\"Create a worker group from existing detached workers.\n\n        Args:\n            name_prefix: Prefix for worker names\n            worker_names: Names of existing workers to attach to\n            worker_handles: Handles of existing workers to attach to\n            ray_cls_with_init: Class with initialization arguments for workers\n\n        Returns:\n            A new RayWorkerGroup instance\n        \"\"\"\n        worker_group = cls(\n            resource_pool=None,\n            ray_cls_with_init=ray_cls_with_init,\n            name_prefix=name_prefix,\n            worker_names=worker_names,\n            worker_handles=worker_handles,\n            **kwargs,\n        )\n        return worker_group\n\n    def spawn(self, prefix_set):\n        \"\"\"Spawn a dictionary of worker groups, each exposing the subset of methods that carry its prefix.\n\n        Args:\n            prefix_set: Set of prefixes to create worker groups for\n\n        Returns:\n            Dictionary of worker groups keyed by prefix\n        \"\"\"\n        if self.fused_worker_used:\n            return self.spawn_fused(prefix_set)\n\n        def _rebind_actor_methods(worker_group, actor_name):\n            prefix: str = actor_name + \"_\"\n            for method_name in dir(worker_group):\n                if method_name.startswith(prefix):\n                    original_method_name = method_name.removeprefix(prefix)\n                    method = getattr(worker_group, method_name)\n               
     setattr(worker_group, original_method_name, method)\n\n        new_worker_group_dict = {}\n        for prefix in prefix_set:\n            new_worker_group = self.from_detached(\n                name_prefix=self.name_prefix,\n                worker_names=self._worker_names,\n                worker_handles=self._workers,\n                ray_cls_with_init=self.ray_cls_with_init,\n                profile_steps=self.profile_steps,\n                worker_nsight_options=self.worker_nsight_options,\n            )\n\n            _rebind_actor_methods(new_worker_group, prefix)\n            new_worker_group_dict[prefix] = new_worker_group\n        return new_worker_group_dict\n\n    def spawn_fused(self, prefix_set):\n        \"\"\"Create a dictionary of worker groups for fused workers.\n\n        Args:\n            prefix_set: Set of prefixes to create worker groups for\n\n        Returns:\n            Dictionary of worker groups keyed by prefix\n        \"\"\"\n        wg_dict = dict()\n        for key in prefix_set:\n            new_wg = deepcopy(self)\n            new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator)\n            new_wg.sub_cls_name = key\n            wg_dict[key] = new_wg\n        return wg_dict\n\n    def fuse(self, prefix_set):\n        \"\"\"Fuse multiple worker groups into the current worker group.\n\n        Args:\n            prefix_set: Set of prefixes to fuse into the worker group\n        \"\"\"\n        if self.wg_dict is None:\n            self.wg_dict = self.spawn(prefix_set)\n        for role_name, role_wg in self.wg_dict.items():\n            setattr(self, role_name, role_wg)\n        self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)\n\n    def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on a single worker remotely.\n\n        Args:\n            worker: The worker actor handle\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        \"\"\"\n        if self.fused_worker_used and method_name not in self.method_names:\n            remote_call = getattr(worker, self.fused_worker_execute_fn_name)\n            return remote_call.remote(f\"{self.sub_cls_name}_fwmn_{method_name}\", *args, **kwargs)\n        # fused worker not used\n        remote_call = getattr(worker, method_name)\n        return remote_call.remote(*args, **kwargs)\n\n    def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on rank zero worker synchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Result of the method execution\n        \"\"\"\n        return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))\n\n    def execute_rank_zero_async(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on rank zero worker asynchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        
\"\"\"\n        return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs)\n\n    def execute_rank_zero(self, method_name: str, *args, **kwargs):\n        \"\"\"Alias for execute_rank_zero_async.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        \"\"\"\n        return self.execute_rank_zero_async(method_name, *args, **kwargs)\n\n    def execute_all(self, method_name: str, *args, **kwargs):\n        \"\"\"Alias for execute_all_async.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of remote object references to the method executions\n        \"\"\"\n        return self.execute_all_async(method_name, *args, **kwargs)\n\n    def execute_all_sync(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on all workers synchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of results from all workers\n        \"\"\"\n        return ray.get(self.execute_all_async(method_name, *args, **kwargs))\n\n    def execute_all_async(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on all workers asynchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of remote object references to the method executions\n        \"\"\"\n        # Here, we assume that if all arguments in args and kwargs are lists,\n        # and their lengths match len(self._workers), we'll distribute each\n        # element in these lists to the corresponding worker\n        # print(f\"execute_all_async: method {method_name}({args}, {kwargs})\")\n        length = len(self._workers)\n        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):\n            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):\n                # print(f\"splitting args and kwargs into {length} shards\")\n                result = []\n                for i in range(length):\n                    sliced_args = tuple(arg[i] for arg in args)\n                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}\n                    result.append(\n                        self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs)\n                    )\n                return result\n\n        return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers]\n\n    @property\n    def master_address(self):\n        return self._master_addr\n\n    @property\n    def master_port(self):\n        return self._master_port\n\n    @property\n    def workers(self):\n        return self._workers\n\n    @property\n    def world_size(self):\n        return self._world_size\n\n\n\"\"\"\nUtilities that enables creating workers inside the same ray.Actor,\nwith code written 
in separate ray.Actors.\n\"\"\"\n\n\n# deprecated, switching to FusedWorker\ndef _bind_workers_method_to_parent(cls, key, user_defined_cls):\n    \"\"\"\n    Binds the methods of each worker to the WorkerDict.\n    Note that we only bind public methods that are decorated with @register.\n    \"\"\"\n\n    for method_name in dir(user_defined_cls):\n        try:\n            method = getattr(user_defined_cls, method_name)\n            assert callable(method), f\"{method_name} in {user_defined_cls} is not callable\"\n        except Exception:\n            # if it is a property, this fails because the class object does not carry instance properties\n            continue\n\n        if hasattr(method, MAGIC_ATTR):\n\n            def generate_function(name, key=key):\n                def func(self, *args, **kwargs):\n                    # dispatch to the actual worker\n                    return getattr(self.worker_dict[key], name)(*args, **kwargs)\n\n                async def async_func(self, *args, **kwargs):\n                    # dispatch to the actual worker\n                    return await getattr(self.worker_dict[key], name)(*args, **kwargs)\n\n                wrapper = async_func if inspect.iscoroutinefunction(method) else func  # noqa: B023\n\n                return wrapper\n\n            func = generate_function(method_name)\n            # pass MAGIC_ATTR through for the outer worker group\n            attrs = getattr(method, MAGIC_ATTR)\n            setattr(func, MAGIC_ATTR, attrs)\n            try:\n                # bind direct rollout method to class without prefix\n                if attrs[\"dispatch_mode\"] == Dispatch.DIRECT_ROLLOUT_METHOD and \"rollout\" in key:\n                    assert not hasattr(cls, method_name), (\n                        f\"conflict direct rollout method {method_name} with role {key}\"\n                    )\n                    setattr(cls, method_name, func)\n                    print(f\"bind role {key} method {method_name} to class {cls}\")\n                else:\n                    method_name_with_prefix = key + \"_\" + method_name\n                    setattr(cls, method_name_with_prefix, func)\n            except Exception as e:\n                raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\n\ndef _unwrap_ray_remote(cls):\n    if hasattr(cls, \"__ray_actor_class__\"):\n        cls = cls.__ray_actor_class__\n    return cls\n\n\ndef _determine_fsdp_megatron_base_class(mros: list):\n    \"\"\"\n    - megatron: base class should be MegatronWorker\n    - fsdp: base class should be Worker\n    \"\"\"\n    for cls in mros[0]:\n        if cls.__name__ == \"MegatronWorker\":\n            return cls\n        if cls.__name__ == \"Worker\":\n            return cls\n    raise ValueError(f\"Cannot determine base class for {mros}\")\n\n\n# deprecated, switching to FusedWorker\ndef create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function should return a class instance that delegates the calls to every\n    cls in cls_dict\n    \"\"\"\n    cls_dict = {}\n    init_args_dict = {}\n    worker_cls = _determine_fsdp_megatron_base_class(\n        [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]\n    )\n    assert issubclass(worker_cls, Worker), f\"worker_cls {worker_cls} should be a subclass of Worker\"\n    print(f\"colocated worker base class {worker_cls}\")\n\n    for key, cls in class_dict.items():\n        cls_dict[key] = cls.cls\n        init_args_dict[key] = {\"args\": cls.args, \"kwargs\": 
cls.kwargs}\n\n    assert cls_dict.keys() == init_args_dict.keys()\n\n    # TODO: create a class with customizable name\n    class WorkerDict(worker_cls):\n        def __init__(self):\n            super().__init__()\n            self.worker_dict = {}\n            for key, user_defined_cls in cls_dict.items():\n                user_defined_cls = _unwrap_ray_remote(user_defined_cls)\n                # directly instantiate the class without remote\n                # in worker class, e.g. <verl.single_controller.base.worker.Worker>\n                # when DISABLE_WORKER_INIT == 1 it will return immediately\n                with temp_env_var(\"DISABLE_WORKER_INIT\", \"1\"):\n                    self.worker_dict[key] = user_defined_cls(\n                        *init_args_dict[key].get(\"args\", ()), **init_args_dict[key].get(\"kwargs\", {})\n                    )\n\n    # now monkey-patch the methods from inner class to WorkerDict\n    for key, user_defined_cls in cls_dict.items():\n        user_defined_cls = _unwrap_ray_remote(user_defined_cls)\n        _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)\n\n    remote_cls = ray.remote(WorkerDict)\n    remote_cls = RayClassWithInitArgs(cls=remote_cls)\n    return remote_cls\n\n\nFusedWorkerCLSName = \"FusedWorker\"\n\n\ndef create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function returns a FusedWorker class.\n\n    `FusedWorker.{class_name}` -> FusedClass\n        Use `class_name` as a param to directly access the underlying class.\n\n    `FusedWorker._fuw_execute(\"{class_name}_fwmn_{method_name}\", *args, **kwargs)`\n        First param must be \"{class_name}_fwmn_{method_name}\" in order to access `method_name`\n        of underlying class `{class_name}`.\n\n    `FusedWorker.fused_worker_dict` -> {\"class_name\": FusedClass}\n        Stores all underlying classes.\n\n    `FusedClass.fused_worker_dict` -> {\"class_name\": FusedClass}\n        The same as `FusedWorker.fused_worker_dict`, enables underlying class to access other\n        underlying classes.\n    \"\"\"\n    raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()}\n    init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()}\n    init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()}\n    cls_names = list(class_dict.keys())\n\n    # FusedWorker_Actor_Critic\n    class_name_renamed = \"_\".join([FusedWorkerCLSName] + cls_names)\n\n    class FusedWorker(Worker):\n        def __init__(self, *args, **kwargs):\n            super().__init__(*args, **kwargs)\n            self.cls_names = cls_names\n            self.raw_cls_dict = raw_cls_dict\n            self.init_args_dict = init_args_dict\n            self.init_kwargs_dict = init_kwargs_dict\n\n            for cls_name, udc, ud_args, ud_kwargs in zip(\n                self.cls_names,\n                self.raw_cls_dict.values(),\n                self.init_args_dict.values(),\n                self.init_kwargs_dict.values(),\n                strict=True,\n            ):\n                with temp_env_var(\"DISABLE_WORKER_INIT\", \"1\"):\n                    udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed\n                    udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f\"{name_prefixed}_\"\n                    # cls_name = \"actor\", \"critic\", udc = ActorWorker, CriticWorker\n                    self.fused_worker_dict[cls_name] 
= udc(*ud_args, **ud_kwargs)\n                    setattr(self, cls_name, self.fused_worker_dict[cls_name])\n\n            # inject fused_worker_dict into each sub-worker so they are aware of each other's existence\n            for _, worker in self.fused_worker_dict.items():\n                setattr(worker, Worker.fused_worker_attr_name, self.fused_worker_dict)\n\n        def _fuw_execute(self, method_name: str, *args, **kwargs):\n            # for fused_worker, method_name is in the form \"{cls_name}_fwmn_{method_name}\",\n            # where fwmn stands for \"fused worker method name\"\n            names = method_name.split(\"_fwmn_\")\n            cls_name = names[0]\n            method_name = names[1]\n\n            assert cls_name in self.fused_worker_dict, (\n                f\"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict\"\n            )\n            udc_method = getattr(self.fused_worker_dict[cls_name], method_name)\n            return udc_method(*args, **kwargs)\n\n    renamed_fused_worker_cls = type(class_name_renamed, (FusedWorker,), {})\n    renamed_fused_worker_cls.is_fused_worker = True\n    renamed_fused_worker_cls.raw_cls_dict = raw_cls_dict\n\n    return renamed_fused_worker_cls\n\n\ndef create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function returns a RayClassWithInitArgs instance of FusedWorker, which is a replacement\n    for `create_colocated_worker_cls`. A WorkerGroup constructed using this class will be a colocated\n    WorkerGroup, referred to as `ColocateWorkerGroup` below.\n\n    `ColocateWorkerGroup.spawn(prefix_set)`\n        returns a dict of WorkerGroups {\"class_name\": WorkerGroup}; each WorkerGroup in this dict will\n        have the methods of the underlying class `class_name` attached.\n\n    `ColocateWorkerGroup.fuse(prefix_set)`\n        After executing this function, `ColocateWorkerGroup.{class_name}` will return a WorkerGroup\n        with the methods of the underlying class `class_name` attached.\n    \"\"\"\n    raw_colocated_worker_cls = create_colocated_worker_raw_cls(class_dict)\n\n    remote_cls = ray.remote(raw_colocated_worker_cls)\n    cia = RayClassWithInitArgs(cls=remote_cls)\n    cia.fused_worker_used = True\n\n    return cia\n"
  },
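  {
    "path": "verl_distillation/examples/ray_worker_group_sketch.py",
    "content": "# Illustrative usage sketch (hypothetical file, not part of verl): a minimal, hedged example of wiring together\n# RayResourcePool, RayClassWithInitArgs, and RayWorkerGroup from verl.single_controller.ray.base. It assumes the\n# @register decorator and Dispatch modes from verl.single_controller.base.decorator behave as in upstream verl;\n# the worker class, its method, and the pool sizes below are made up for illustration.\nimport ray\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass EchoWorker(Worker):\n    # @register attaches MAGIC_ATTR, so _bind_worker_method exposes `echo` on the worker group;\n    # Dispatch.ONE_TO_ALL broadcasts the same arguments to every rank.\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def echo(self, msg):\n        return f\"rank {self.rank}: {msg}\"\n\n\nif __name__ == \"__main__\":\n    ray.init()\n    # one node with two CPU-only workers (use_gpu=True would add device bundles to each placement group)\n    pool = RayResourcePool(process_on_nodes=[2], use_gpu=False)\n    wg = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=RayClassWithInitArgs(cls=EchoWorker))\n    # blocking dispatch collects one result per rank, e.g. ['rank 0: hi', 'rank 1: hi']\n    print(wg.echo(\"hi\"))\n"
  },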
  {
    "path": "verl_distillation/verl/third_party/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/third_party/sglang/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/third_party/sglang/parallel_state.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The SGlang team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"Model and data parallel groups.\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport sglang.srt.distributed.parallel_state as ps\nimport torch\nimport torch.distributed\nfrom sglang.srt.distributed.parallel_state import (\n    get_pp_group,\n    get_world_group,\n    init_distributed_environment,\n    init_model_parallel_group,\n)\n\n\"\"\"\nThis version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.\n- We assume the Megatron tp+dp+pp world is already established before calling this function.\n\n\"\"\"\n\n# Device mesh for using DTensor\n_DEVICE_MESH = None\n\n# Tensor model parallel group that the current rank belongs to.\n_TP = None\n# Pipeline model parallel group that the current rank belongs to.\n_PP = None\n\n\n# This method is for initializing the ParallelGroup when using HybridEngine\n# NOTE(linjunrong): this function is for megatron\ndef initialize_parallel_state(\n    distributed_init_method: str = \"env://\",\n    backend: str = \"nccl\",\n    tensor_model_parallel_size: int = 1,\n    num_tp_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n):\n    # torch.distributed.all_reduce does not free the input tensor until\n    # the synchronization point. This causes the memory usage to grow\n    # as the number of all_reduce calls increases. This env var disables\n    # this behavior.\n    # Related issue:\n    # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573\n    os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n\n    # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.\n    rank = int(os.getenv(\"RANK\", \"-1\"))\n    local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n    # Use the world_size set by TORCHRUN\n    world_size = int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n    assert world_size != -1, \"The world_size is set to -1, not initialized by TORCHRUN\"\n    init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)\n    if torch.distributed.get_world_size() > 1:\n        # NOTE: build a separate inference group with infer tp & micro dp\n        initialize_model_parallel_for_sglang(\n            tensor_model_parallel_size=tensor_model_parallel_size,\n            num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp,\n        )\n    else:\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n\n\n# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call\n# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the conterparts\n# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None.\n# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for\n# verl itself as how it was done in verl.third_party.vllm.parallel_state. 
Note that the process is a little\n# bit different.\ndef ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    # get the backend of _DEVICE_WORLD_GROUP\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        f\"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        f\"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. \"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\n# TODO(sgm): deviates from v0.5.4; pp is not supported for now\n# NOTE(linjunrong): the SGLang version uses _TP instead of ps._TP\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None\n    # and _PIPELINE_MODEL_PARALLEL_GROUP is not None)\n\n\ndef initialize_model_parallel_for_sglang(\n    tensor_model_parallel_size: int,\n    num_tensor_model_parallel_groups_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n) -> None:\n    # Get world size and rank. 
Ensure some consistency.\n    assert torch.distributed.is_initialized()\n\n    assert isinstance(tensor_model_parallel_size, int)\n\n    # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group\n    # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group\n\n    # Build the tensor model-parallel groups.\n    assert ps._TP is None, \"tensor model parallel group is already initialized\"\n\n    global _TP\n\n    world_size: int = torch.distributed.get_world_size()\n\n    backend = torch.distributed.get_backend()\n\n    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n\n    if num_tensor_model_parallel_groups_per_train_tp == 1:\n        # if tensor_model_parallel_size == train_tensor_parallel_size:\n        # using the same tp group as Megatron/vllm\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups):\n            ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n            group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n        # _MICRO_DATA_PARALLEL_GROUP is moved to the hybrid engine\n    else:\n        # initialize a micro_dp group and a tp group\n        # assume training tp=4, infer tp=2; then the weight is partitioned as\n        # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference\n\n        # Build the inference tp groups\n        # train_tp = train_tensor_parallel_size\n        train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size\n        # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):\n            start = train_tp * i\n            end = train_tp * (i + 1)\n            for j in range(num_tensor_model_parallel_groups_per_train_tp):\n                # offset the strided ranks by j (avoids shadowing the outer loop variable)\n                ranks = [rank + j for rank in range(start, end, num_tensor_model_parallel_groups_per_train_tp)]\n                group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # Build the pipeline model-parallel groups.\n    # global _PIPELINE_MODEL_PARALLEL_GROUP\n    # global _PIPELINE_GLOBAL_RANKS\n    # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, (\"pipeline model parallel group is already initialized\")\n\n    # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()\n    # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()\n\n    # TODO: init using device mesh (hybrid engine is not supported yet)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: 
int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n    ps._PP = _PP  # for verl\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    NOTE: This method is a hack of the open-sourced version that omits the\n    assertion of world_size == tp * pp\n\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. 
Ensure some consistency.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)\n\n    # NOTE(sgm) we don't assert world_size == tp * pp\n    # DP is not managed by vllm but by the VeRL WorkerGroup\n    # if (world_size !=\n    #         tensor_model_parallel_size * pipeline_model_parallel_size):\n    #     raise RuntimeError(\n    #         f\"world_size ({world_size}) is not equal to \"\n    #         f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n    #         f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\")\n\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_tensor_model_parallel_groups):\n        ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))\n        group_ranks.append(ranks)\n\n    # the message queue broadcaster is only used in the tensor model parallel group\n    if ps._TP is not None:\n        _TP = ps._TP\n    else:\n        _TP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # TODO: init using device mesh (hybrid engine is not supported yet)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    if ps._PP is not None:\n        _PP = ps._PP\n    else:\n        _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n        ps._PP = _PP\n\n\n\"\"\"\nDevice mesh utilities\n\"\"\"\n\n\ndef get_device_mesh():\n    assert _DEVICE_MESH is not None, \"device mesh is not initialized\"\n    return _DEVICE_MESH\n\n\n\"\"\"\nTensor model parallel utilities\n\"\"\"\n\n\n# NOTE(linjunrong): In the vllm version of parallel_state.py, verl created its own _TP and _PP because verl wants to\n# use the process groups for some extra purposes. Under the hood, there is no difference between them and the original\n# ones in vllm.distributed.parallel_state. 
However, that implementation needs to hack the init process of the inference\n# engine; as we do not maintain a separate SGLang fork here, we just use the original _TP and _PP directly.\ndef get_tensor_model_parallel_group():\n    \"\"\"Get the tensor model parallel group the caller rank belongs to.\"\"\"\n\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP.device_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_src_rank():\n    \"\"\"Calculate the global rank corresponding to the first local rank\n    in the tensor model parallel group.\"\"\"\n    global_rank = torch.distributed.get_rank()\n    local_world_size = get_tensor_model_parallel_world_size()\n    return (global_rank // local_world_size) * local_world_size\n"
  },
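  {
    "path": "verl_distillation/examples/parallel_state_layout_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of verl): reproduces, in plain Python, the TP/PP rank-grouping\n# arithmetic used by initialize_model_parallel() in verl/third_party/sglang/parallel_state.py, so the docstring's\n# 8-GPU example (tp=2, pp=4) can be checked without torch.distributed.\n\n\ndef tp_pp_group_ranks(world_size: int, tp: int, pp: int):\n    # tensor model-parallel groups: consecutive blocks of `tp` ranks\n    tp_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]\n    # pipeline model-parallel groups: ranks strided by the number of PP groups\n    num_pp_groups = world_size // pp\n    pp_groups = [list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)]\n    return tp_groups, pp_groups\n\n\nif __name__ == \"__main__\":\n    tp_groups, pp_groups = tp_pp_group_ranks(world_size=8, tp=2, pp=4)\n    assert tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]\n    assert pp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]\n    print(\"TP groups:\", tp_groups)\n    print(\"PP groups:\", pp_groups)\n"
  },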
  {
    "path": "verl_distillation/verl/third_party/torch/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "verl_distillation/verl/third_party/torch/distributed/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
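The header note above states why this vendored tree exists: the torch 2.6.0 `set_model_state_dict` path materializes full tensors on every rank and can OOM, while the torch 2.7.0 copy carried here supports `broadcast_from_rank0` so only rank 0 needs the full weights. A minimal sketch of the intended import swap, assuming the vendored package mirrors the upstream module layout (both names below appear in this package's `__all__`; this is not a prescribed API, just an illustration):

```python
# Instead of the upstream torch 2.6.0 API, which can OOM on load:
#   from torch.distributed.checkpoint.state_dict import set_model_state_dict
# import the vendored torch 2.7.0 copy:
from verl.third_party.torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

# broadcast_from_rank0 requires full_state_dict=True (enforced by
# _verify_options in state_dict.py later in this package); rank 0 alone
# holds the full weights and streams them to the other ranks.
options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
```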
  {
    "path": "verl_distillation/verl/third_party/torch/distributed/_state_dict_utils.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n\n\n# ruff: noqa: B028, UP038, UP007, E721, E501\n# mypy: allow-untyped-defs\nimport copy\nimport io\nimport math\nimport weakref\nfrom collections.abc import Mapping, MutableMapping\nfrom typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union, cast\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.distributed._functional_collectives import AsyncCollectiveTensor\n\nif dist.is_available() or TYPE_CHECKING:\n    from torch.distributed import distributed_c10d\n    from torch.distributed._shard.sharded_tensor import ShardedTensor\n    from torch.distributed.tensor import DTensor, Replicate, distribute_tensor\n    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset\n\n\ndef _identity_func(\n    obj: torch.Tensor,\n    pg: Optional[dist.ProcessGroup],\n    device: Optional[torch.device],\n    companion_obj: Any,\n) -> torch.Tensor:\n    return obj\n\n\ndef _all_gather_sharded_tensor(\n    sharded_tensor: \"ShardedTensor\",\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    if pg is None:\n        pg = distributed_c10d._get_default_group()\n    world_size = dist.get_world_size(pg)\n    shards = sharded_tensor.local_shards()\n    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]\n    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]\n    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size\n    pg_device = distributed_c10d._get_pg_default_device(pg) if device is None else device\n    if shards:\n        local_tensor = shards[0].tensor.flatten()\n        if local_tensor.device.type != pg_device.type:\n            local_tensor = local_tensor.to(pg_device)\n        num_padding = chunk_size - local_tensor.numel()\n        if num_padding > 0:\n            local_tensor = F.pad(local_tensor, [0, num_padding])\n    else:\n        local_tensor = torch.zeros(chunk_size, dtype=sharded_tensor.dtype, device=pg_device)\n\n    tensor = torch.empty(\n        chunk_size * world_size,\n        dtype=local_tensor.dtype,\n        device=pg_device,\n    )\n    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)\n\n    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())\n    return tensor\n\n\nclass CompanionMismatch(Exception):\n    pass\n\n\ndef _iterate_state_dict(\n    iter_object: Any,\n    sharded_tensor_func: Callable,\n    dtensor_func: Callable,\n    tensor_func: Callable,\n    *,\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n    cpu_offload: bool = False,\n    companion_obj: Any = None,\n    ranks_only: tuple[int, ...] 
= (),\n    type_check: bool = True,\n    non_blocking: bool = True,\n) -> dict[str, Any]:\n    \"\"\"Iterate through the state dict, applying the given functions to each tensor type.\n\n    Args:\n        iter_object (Any): the target state_dict.\n        sharded_tensor_func (Callable): the function to apply to ShardedTensor\n        dtensor_func (Callable): the function to apply to DTensor\n        tensor_func (Callable): the function to apply to Tensor\n        pg (Optional[dist.ProcessGroup]): process group passed to tensor functions\n        device (Optional[torch.device]): device passed to tensor functions\n        cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored\n            if a companion_obj is supplied.\n        companion_obj (Any): A companion object to the state dict. If this object\n            is supplied, we attempt to copy the tensor to the companion object.\n        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n        non_blocking (bool): whether to use non-blocking copy when copying to the companion object.\n    \"\"\"\n    # TODO: should we use pytree?\n    cpu_device = torch.device(\"cpu\")\n    if isinstance(iter_object, ShardedTensor):\n        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, DTensor):\n        ret = dtensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, torch.Tensor):\n        ret = tensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) or iter_object is None:\n        ret = iter_object\n    elif isinstance(iter_object, dict):\n        if companion_obj is not None and (\n            not isinstance(companion_obj, dict) or set(companion_obj.keys()) != set(iter_object.keys())\n        ):\n            msg = \"\" if isinstance(companion_obj, dict) else f\"{set(companion_obj.keys())=} {set(iter_object.keys())=}\"\n            raise CompanionMismatch(msg)\n\n        ret = {\n            key: _iterate_state_dict(\n                value,\n                sharded_tensor_func,\n                dtensor_func,\n                tensor_func,\n                pg=pg,\n                device=device,\n                cpu_offload=cpu_offload,\n                companion_obj=companion_obj[key] if companion_obj is not None else None,\n                ranks_only=ranks_only,\n                type_check=type_check,\n                non_blocking=non_blocking,\n            )\n            for key, value in iter_object.items()\n        }\n    elif isinstance(iter_object, (list, tuple)):\n        if companion_obj is not None and (\n            not isinstance(companion_obj, (list, tuple)) or len(companion_obj) != len(iter_object)\n        ):\n            raise CompanionMismatch\n\n        ret = [\n            _iterate_state_dict(\n                v,\n                sharded_tensor_func,\n                dtensor_func,\n                tensor_func,\n                pg=pg,\n                device=device,\n                cpu_offload=cpu_offload,\n                
companion_obj=companion_obj[idx] if companion_obj is not None else None,\n                ranks_only=ranks_only,\n                type_check=type_check,\n                non_blocking=non_blocking,\n            )\n            for idx, v in enumerate(iter_object)\n        ]\n        if isinstance(iter_object, tuple):\n            ret = tuple(ret)\n    elif not type_check:\n        ret = copy.deepcopy(iter_object)\n    else:\n        raise ValueError(f\"Unexpected value type {type(iter_object)}\")\n\n    if not ranks_only or dist.get_rank(pg) in ranks_only:\n        if isinstance(ret, torch.Tensor):\n            if cpu_offload and companion_obj is None:\n                ret = ret.to(cpu_device)\n\n            if companion_obj is not None:\n                if isinstance(companion_obj, DTensor):\n                    assert isinstance(ret, DTensor)\n                    companion_obj._local_tensor.copy_(ret._local_tensor, non_blocking=non_blocking)\n                else:\n                    companion_obj.copy_(ret, non_blocking=non_blocking)\n                ret = companion_obj\n    else:\n        ret = {} if isinstance(ret, dict) else None\n\n    return ret\n\n\ndef _gather_state_dict(\n    state_dict: dict[str, Any],\n    *,\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n    cpu_offload: bool = False,\n    ranks_only: tuple[int, ...] = (),\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, this API gathers all the ShardedTensors or DTensors in\n    the state_dict.\n\n    Args:\n        state_dict (Dict[str, Any]): the target sharded state_dict.\n        pg (Optional[dist.ProcessGroup]): the process group that is used to\n            gather ShardedTensor. Note that gathering a DTensor will use\n            the DeviceMesh. So this argument will be ignored when gathering a\n            DTensor.\n        device: (Optional[torch.device]): the device that is used to\n            perform allgather for ShardedTensor. Note that gathering a DTensor\n            will use the DeviceMesh. So this argument will be ignored when\n            gathering a DTensor.\n        cpu_offload (bool): whether to offload the tensors to CPU memory. The\n            default value is False.\n        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check: (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  
The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        The gathered state dictionary.\n    \"\"\"\n\n    def sharded_tensor_func(value, pg, device, companion_obj):\n        # ShardedTensor does not seem to record the original device type.\n        # So if the tensor is moved to CPU, we won't know the original type.\n        # As a result, we have to rely on the user to tell us the correct one.\n        cpu_device = torch.device(\"cpu\")\n        output_tensor = _all_gather_sharded_tensor(value, pg, device)\n        local_shard_device = value.local_shards()[0].tensor.device if value.local_shards() else cpu_device\n        if output_tensor.device != local_shard_device:\n            value = output_tensor.to(local_shard_device)\n        else:\n            value = output_tensor\n        return value\n\n    def dtensor_func(value, pg, device, companion_obj):\n        if value.device != value.device_mesh.device_type:\n            value = value.to(value.device_mesh.device_type)\n        # FSDP all_gather: [Shard(0)] -> [Replicate()]\n        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]\n        # 2D FSDP + TP all_gather:\n        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]\n        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]\n        placements = [Replicate() for _ in value.placements]\n        value = value.redistribute(\n            device_mesh=value.device_mesh,\n            placements=placements,\n        )\n        # Call `wait()` to force the tensor to be synchronous with respect\n        # to the main stream.\n        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.\n        value = value.to_local()\n        if isinstance(value, AsyncCollectiveTensor):\n            value = value.wait()\n        return value\n\n    return _iterate_state_dict(\n        state_dict,\n        sharded_tensor_func,\n        dtensor_func,\n        _identity_func,\n        pg=pg,\n        device=device,\n        cpu_offload=cpu_offload,\n        ranks_only=ranks_only,\n        type_check=type_check,\n    )\n\n\ndef _offload_state_dict_to_cpu(\n    state_dict: dict[str, Any],\n    *,\n    ranks_only: tuple[int, ...] = (),\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, this API offloads all the tensors to CPU memory.\n\n    Args:\n        state_dict (Dict[str, Any]): the target state_dict.\n        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check: (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  
The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        The offloaded state dictionary.\n    \"\"\"\n\n    ret = _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        _identity_func,\n        _identity_func,\n        pg=None,\n        device=None,\n        cpu_offload=True,\n        ranks_only=ranks_only,\n        type_check=type_check,\n    )\n    return ret\n\n\n@torch.no_grad()\ndef _copy_state_dict(\n    state_dict: dict[str, Any],\n    copy_state_dict: dict[str, Any],\n    non_blocking: bool = False,\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Copies all tensors in a given state dict into a different state_dict with the\n    same structure. Additionally, a copied state dict with the same value references\n    is returned. Editing the keys on this state dict will not affect the\n    passed in copy_state_dict (but the value references are the same).\n\n    .. warning::\n        It is expected by this function that state_dict and copy_state_dict share\n        the same structure and data types.\n\n    .. warning::\n        The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Args:\n        state_dict (Dict[str, Any]): the target state_dict.\n        copy_state_dict (Dict[str, Any]):\n            The state dict we are copying into. This state_dict must have exactly\n            the same structure as the source `state_dict`.\n        non_blocking: (bool): Whether copy ops should be performed asynchronously\n        type_check (bool): check if the instance data type is a supported type\n            that can be saved by DCP. The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        State Dict copy\n    \"\"\"\n\n    return _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        _identity_func,\n        _identity_func,\n        pg=None,\n        device=None,\n        cpu_offload=False,\n        ranks_only=(),\n        companion_obj=copy_state_dict,\n        type_check=type_check,\n        non_blocking=non_blocking,\n    )\n\n\n@torch.no_grad()\ndef _create_cpu_state_dict(\n    state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, create another state_dict with the same structure and elements.\n    However, all tensors in the returned state_dict are new tensors on CPU. These\n    tensors can be placed on pin_memory or share_memory based on the provided arguments.\n\n    .. warning::\n        Setting both `pin_memory` and `share_memory` to True significantly increases the\n        latency of this method because of the nuances which require us to register memory\n        as pinned directly as opposed to relying on the pin_memory cache allocator. 
This\n        option should only be used for long lived tensors which are required to be shared.\n        This is not the case as long as at least one of `pin_memory` or `share_memory` is\n         set to False.\n\n    \"\"\"\n\n    def tensor_func(\n        obj: torch.Tensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        _: Any,\n    ) -> torch.Tensor:\n        if len(obj.size()) == 0:\n            return torch.tensor(0, dtype=obj.dtype)\n\n        if share_memory:\n            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)\n            t = t.share_memory_()\n            if pin_memory:\n\n                def unpin_memory(t):\n                    succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))\n                    assert succ == 0, f\"Unpinning shared memory failed with error-code: {succ}\"\n\n                weakref.finalize(t, unpin_memory, t)\n                succ = int(\n                    torch.cuda.cudart().cudaHostRegister(\n                        t.data_ptr(),\n                        t.numel() * t.element_size(),\n                        1,  # lines up with 'cudaHostRegisterPortable'\n                    )\n                )\n                assert succ == 0, f\"Pinning shared memory failed with error-code: {succ}\"\n            return t\n        elif pin_memory:\n            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()\n        else:\n            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)\n\n    def dtensor_func(\n        obj: DTensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        _: Any,\n    ) -> DTensor:\n        if len(obj.size()) == 0:\n            return obj\n\n        if obj.device != torch.device(\"cpu\"):\n            ret = cast(DTensor, obj.to(device=\"cpu\"))\n        else:\n            ret = copy.deepcopy(obj)\n        ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None)\n        return ret\n\n    ret = _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        dtensor_func,\n        tensor_func,\n        pg=None,\n        device=None,\n        cpu_offload=False,\n        ranks_only=(),\n        type_check=False,\n    )\n    return ret\n\n\ndef _check_state_dict_similarity(\n    state_dict: dict[str, Any],\n    compared_state_dict: dict[str, Any],\n) -> bool:\n    \"\"\"\n    Given two state_dicts, check if the structures are the same. 
And\n    if a [key, tensor] pair exists in one state_dict there must be\n    a corresponding pair, [key, other_tensor], in the other state_dict,\n    where tensor and other_tensor have the same size and dtype.\n\n    Return the check result.\n    \"\"\"\n\n    def tensor_func(\n        obj: torch.Tensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        companion_obj: Any,\n    ) -> torch.Tensor:\n        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():\n            raise CompanionMismatch\n        return obj\n\n    try:\n        _iterate_state_dict(\n            state_dict,\n            _identity_func,\n            _identity_func,\n            tensor_func,\n            pg=None,\n            device=None,\n            cpu_offload=False,\n            ranks_only=(),\n            companion_obj=compared_state_dict,\n            type_check=False,\n        )\n    except CompanionMismatch:\n        return False\n\n    return True\n\n\nclass _TensorInfo(NamedTuple):\n    size: torch.Size\n    dtype: torch.dtype\n\n\ndef _broadcast_tensors(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    keys: list[str],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    tensors = []\n    for key in keys:\n        if dist.get_rank() == 0:\n            full_state = full_state_dict[key]\n            assert isinstance(full_state, torch.Tensor)\n            full_tensor = full_state.detach().to(device)\n        else:\n            tensor_info = full_state_dict[key]\n            full_tensor = torch.empty(\n                size=tensor_info.size,\n                device=device,\n                dtype=tensor_info.dtype,\n            )\n        tensors.append(full_tensor)\n        local_state = local_state_dict.get(key, None)\n        if local_state is None:\n            continue\n        elif isinstance(local_state, DTensor):\n            local_state_dict[key] = (local_state, full_tensor)\n        else:\n            local_state_dict[key] = full_tensor\n\n    if pg is None:\n        pg = dist.distributed_c10d._get_default_group()\n\n    if len(tensors) > 1:\n        dist._broadcast_coalesced(pg, tensors, 500, 0)\n    else:\n        dist.broadcast(tensors[0], src=0, group=pg)\n\n    _distribute_tensors(local_state_dict, keys, device, pg)\n\n\ndef _distribute_tensors(\n    local_state_dict: dict[str, Any],\n    keys: list[str],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    if pg is None:\n        pg = dist.distributed_c10d._get_default_group()\n    for key in keys:\n        _local_state = local_state_dict.get(key, None)\n        if _local_state is None or torch.is_tensor(_local_state):\n            continue\n\n        local_state = _local_state[0]\n        full_tensor = _local_state[1]\n\n        shape, offset = compute_local_shape_and_global_offset(\n            full_tensor.shape, local_state.device_mesh, local_state.placements\n        )\n        slices = [\n            slice(cur_offset, cur_offset + cur_shape) for cur_shape, cur_offset in zip(shape, offset, strict=False)\n        ]\n        if local_state.is_meta:\n            # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost.\n            local_tensor = full_tensor[slices].detach().clone()\n            # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. 
For example,\n            # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)).\n            ret = DTensor.from_local(\n                local_tensor,\n                local_state.device_mesh,\n                local_state.placements,\n                shape=local_state.shape,\n                stride=local_state.stride(),\n            )\n        else:\n            ret = local_state\n            # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint.\n            ret.to_local().copy_(full_tensor[slices])\n        local_state_dict[key] = ret\n\n\ndef _broadcast_state_dict(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n    strict: bool = False,\n    cpu_offload: bool = False,\n) -> None:\n    # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`.\n    # If strict is True, any keys in `local_state_dict` but not in `full_state_dict`\n    # will be removed from `local_state_dict`.\n    ret = {}\n    if dist.get_rank() == 0:\n        for key, value in full_state_dict.items():\n            if not torch.is_tensor(value):\n                ret[key] = value\n            elif value.dim() == 0:\n                ret[key] = value.cpu()\n            else:\n                ret[key] = _TensorInfo(value.size(), value.dtype)\n\n    broadcast_list = [ret]\n    dist.broadcast_object_list(broadcast_list, src=0, group=pg)\n    ret = broadcast_list[0]\n    # Gather values\n    keys = []\n    local_state_dict_keys = set(local_state_dict.keys())\n    global_keys = set()\n    for key, value in ret.items():\n        global_keys.add(key)\n        if not isinstance(value, _TensorInfo):\n            if key in local_state_dict:\n                local_state_dict[key] = value\n            continue\n\n        if dist.get_rank() == 0:\n            ret[key] = full_state_dict[key]\n\n        keys.append(key)\n        # Broadcast every tensor to avoid OOM for now.\n        if len(keys) >= 1:\n            _broadcast_tensors(ret, local_state_dict, keys, device, pg)\n            if cpu_offload:\n                for key in keys:\n                    local_state_dict[key] = local_state_dict[key].cpu()\n            keys.clear()\n\n    if strict:\n        if missing_keys := (local_state_dict_keys - global_keys):\n            for key in missing_keys:\n                local_state_dict.pop(key)\n\n    if keys:\n        _broadcast_tensors(ret, local_state_dict, keys, device, pg)\n        if cpu_offload:\n            for key in keys:\n                local_state_dict[key] = local_state_dict[key].cpu()\n\n\ndef _distribute_state_dict(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has\n    # full_state_dict. 
Skip the broadcast in ``_broadcast_state_dict`` and\n    # distribute tensors in each rank\n    for key, value in full_state_dict.items():\n        if key not in full_state_dict:\n            continue\n        if not torch.is_tensor(value):\n            local_state_dict[key] = value\n        elif value.dim() == 0:\n            local_state_dict[key] = value.cpu()\n        else:\n            assert isinstance(value, torch.Tensor)\n            local_state = local_state_dict.get(key, None)\n            if local_state is None:\n                continue\n            elif isinstance(local_state, DTensor):\n                local_state_dict[key] = distribute_tensor(\n                    value.detach().to(device),\n                    local_state.device_mesh,\n                    local_state.placements,\n                )\n            else:\n                local_state_dict[key] = value.detach().to(device)\n\n\n# These APIs are from torch.distributed.checkpoint.\n# TODO: We should consolidate the code here as not all modules can depend on\n# DCP.\nPATH_ITEM = Union[str, int]\nOBJ_PATH = tuple[PATH_ITEM, ...]\nFLATTEN_MAPPING = dict[str, OBJ_PATH]\nSTATE_DICT_TYPE = dict[str, Any]\nCONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]\n\n\ndef _traverse_state_dict(\n    state_dict: STATE_DICT_TYPE,\n    visitor: Callable[[OBJ_PATH, Any], None],\n) -> None:\n    \"\"\"\n    Invoke ``visitor`` for each value recursively in ``state_dict``.\n    Mapping, list, and tuple will be flattened and other value types are treated\n    as the terminal values and will invoke ``visitor``.\n    \"\"\"\n\n    def _traverse_obj(path: OBJ_PATH, value: Any) -> None:\n        if isinstance(value, Mapping):\n            for k, v in value.items():\n                _traverse_obj(path + (str(k),), v)\n        elif isinstance(value, (list, tuple)):\n            for i, v in enumerate(value):\n                _traverse_obj(path + (i,), v)\n        else:\n            visitor(path, value)\n\n    for key, value in state_dict.items():\n        _traverse_obj((str(key),), value)\n\n\ndef _flatten_state_dict(\n    state_dict: STATE_DICT_TYPE,\n) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:\n    \"\"\"\n    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.\n\n    Use ``unflatten_state_dict`` to revert this process.\n    Returns:\n        A tuple with the flattened state_dict and a mapping from original to new state_dict.\n    N.B. 
The new keys are derived from the object paths, joined by dot.\n        For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.\n    \"\"\"\n    flattened: STATE_DICT_TYPE = {}\n    mappings: FLATTEN_MAPPING = {}\n\n    def flat_copy(path: OBJ_PATH, value: Any) -> None:\n        new_fqn = \".\".join(map(str, path))\n        if new_fqn in flattened:\n            raise ValueError(f\"duplicated flatten key {new_fqn}\")\n        flattened[new_fqn] = value\n        mappings[new_fqn] = path\n\n    _traverse_state_dict(state_dict, flat_copy)\n    return flattened, mappings\n\n\ndef _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None:\n    \"\"\"Set ``value`` in ``root_dict`` along the ``path`` object path.\"\"\"\n    cur_container = cast(CONTAINER_TYPE, root_dict)\n\n    def extend_list(lst: list[Any], idx: int) -> None:\n        while len(lst) <= idx:\n            lst.append(None)\n\n    for i in range(1, len(path)):\n        prev_key = path[i - 1]\n        key = path[i]\n        def_val: CONTAINER_TYPE | list[Any] = {} if type(key) == str else []\n\n        if isinstance(cur_container, Mapping):\n            cur_container = cast(CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val))\n        else:\n            extend_list(cur_container, prev_key)\n            if cur_container[prev_key] is None:\n                cur_container[prev_key] = def_val\n            cur_container = cur_container[prev_key]\n\n    key = path[-1]\n    if type(key) == int:\n        extend_list(cast(list[Any], cur_container), key)\n\n    cur_container[key] = value\n\n\ndef _unflatten_state_dict(state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING) -> STATE_DICT_TYPE:\n    \"\"\"Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.\"\"\"\n    nested: STATE_DICT_TYPE = {}\n    for key, value in state_dict.items():\n        _set_element(nested, mapping[key], value)\n    return nested\n"
  },
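Beyond the distributed path, several helpers in `_state_dict_utils.py` above also run in a single process on plain tensors. A self-contained sketch of the flatten/unflatten round trip and the reusable-CPU-buffer pattern (`_create_cpu_state_dict` once, then `_copy_state_dict` per snapshot); the sample state dict is purely illustrative:

```python
import torch

from verl.third_party.torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
    _flatten_state_dict,
    _unflatten_state_dict,
)

# Illustrative nested state dict (plain tensors; no ShardedTensor/DTensor).
state = {"layer": {"weight": torch.randn(4, 4)}, "steps": [torch.zeros(2), 7]}

# Nested dicts/lists flatten into dot-joined keys; the mapping inverts it.
flat, mapping = _flatten_state_dict(state)
assert set(flat) == {"layer.weight", "steps.0", "steps.1"}
nested = _unflatten_state_dict(flat, mapping)
assert torch.equal(nested["layer"]["weight"], state["layer"]["weight"])

# Allocate CPU buffers with the same structure once, then copy into them on
# every snapshot instead of re-allocating (optionally with pin_memory=True).
cpu_buffers = _create_cpu_state_dict(state)
_copy_state_dict(state, cpu_buffers, non_blocking=False)
assert torch.equal(cpu_buffers["layer"]["weight"], state["layer"]["weight"])
```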
  {
    "path": "verl_distillation/verl/third_party/torch/distributed/checkpoint/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
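The `state_dict.py` that follows carries the functional change: with `StateDictOptions(broadcast_from_rank0=True)`, only rank 0 materializes the checkpoint, and `_broadcast_state_dict`/`_distribute_tensors` stream each tensor out and reshard it into the local DTensors. A hedged usage sketch, assuming the vendored API keeps the upstream torch 2.7.0 signature; it must run under torchrun against a DTensor-sharded (e.g. FSDP) model, and `CKPT_PATH` is a placeholder:

```python
import torch
import torch.distributed as dist

from verl.third_party.torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

CKPT_PATH = "full_model_state.pt"  # hypothetical checkpoint file


def load_from_rank0(model: torch.nn.Module) -> None:
    # Rank 0 alone loads the full (unsharded) state dict into CPU memory;
    # every other rank passes an empty dict and receives its shards via the
    # broadcast path, which is what avoids the per-rank full-tensor OOM.
    full_sd = torch.load(CKPT_PATH, map_location="cpu") if dist.get_rank() == 0 else {}
    set_model_state_dict(
        model,
        full_sd,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
```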
  {
    "path": "verl_distillation/verl/third_party/torch/distributed/checkpoint/state_dict.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n\n# ruff: noqa: B028, UP038, UP007, E721\n# mypy: allow-untyped-defs\nimport contextlib\nimport functools\nimport gc\nimport warnings\nfrom collections.abc import Generator, Iterable\nfrom dataclasses import asdict, dataclass, field\nfrom itertools import chain\nfrom typing import Any, Callable, Optional, Union, cast, no_type_check\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom torch.distributed._shard.sharded_tensor import ShardedTensor\nfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n    _CHECKPOINT_PREFIX,\n)\nfrom torch.distributed.fsdp import (\n    FullOptimStateDictConfig,\n    FullStateDictConfig,\n    OptimStateDictConfig,\n    ShardedOptimStateDictConfig,\n    ShardedStateDictConfig,\n    StateDictConfig,\n    StateDictType,\n)\nfrom torch.distributed.fsdp import (\n    FullyShardedDataParallel as FSDP,\n)\nfrom torch.distributed.fsdp._common_utils import (\n    FSDP_WRAPPED_MODULE,\n    _get_module_fsdp_state_if_fully_sharded_module,\n)\nfrom torch.distributed.tensor import DTensor\nfrom torch.nn.modules.module import _IncompatibleKeys\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils._pytree import tree_map_only\n\nfrom verl.third_party.torch.distributed._state_dict_utils import (\n    _broadcast_state_dict,\n    _distribute_state_dict,\n    _flatten_state_dict,\n    _gather_state_dict,\n    _offload_state_dict_to_cpu,\n    _unflatten_state_dict,\n)\n\n__all__ = [\n    \"FQNS_T\",\n    \"PrimitiveType\",\n    \"ValueType\",\n    \"DictValueType\",\n    \"ListDictValueType\",\n    \"OptimizerStateType\",\n    \"StateDictOptions\",\n    \"get_model_state_dict\",\n    \"get_optimizer_state_dict\",\n    \"get_state_dict\",\n    \"set_model_state_dict\",\n    \"set_optimizer_state_dict\",\n    \"set_state_dict\",\n]\n\n\n_FLAT_PARAM = \"_flat_param\"\n_PG = \"param_groups\"\n_PARAMS = \"params\"\n_STATE = \"state\"\n\nFQNS_T = set[str]\nPrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]\nValueType = Union[PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, \"ValueType\"]]\nDictValueType = dict[str, ValueType]\nListDictValueType = list[DictValueType]\nOptimizerStateType = dict[str, DictValueType | ListDictValueType]\n\n\n_patched_state_dict: set[Callable] = set()\n\n\n@contextlib.contextmanager\ndef _gc_context():\n    is_enabled = gc.isenabled()\n    gc.disable()\n    try:\n        yield\n    finally:\n        if is_enabled:\n            gc.enable()\n\n\n@dataclass\nclass StateDictOptions:\n    \"\"\"\n    This dataclass specifies how get_state_dict/set_state_dict will work.\n\n    - ``full_state_dict``: if this is set to True, all the tensors in the\n      returned state_dict will be gathered. No ShardedTensor and DTensor\n      will be in the returned state_dict.\n\n    - ``cpu_offload``: offload all the tensors to cpu. 
To prevent CPU OOM, if\n      ``full_state_dict`` is also true, then only the rank0 will get the\n      state_dict and all other ranks will get empty state_dict.\n\n    - ``ignore_frozen_params``: if the value is True, the returned state_dict\n      won't contain any frozen parameters -- the ``requires_grad`` is False.\n      The default value is False.\n\n    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option\n      indicates whether to keep the submodule prefixes from the state_dict keys.\n      For example, if the submodule is ``module.pretrain`` and the full FQN of\n      the parameter is ``pretrain.layer1.weight``. When this option\n      is True, the parameter's key in the returned state_dict will be\n      ``pretrain.layer1.weight``. If the option is False, the key will be\n      ``layer1.weight``.\n      Note that if ``keep_submodule_prefixes`` is False, there may be conflicted\n      FQNs, hence there should be only one submodule in ``submodules``.\n\n    - ``strict``: the ``strict`` option when ``set_state_dict`` calls\n      model.load_state_dict().\n\n    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a\n       full state_dict and will broadcast the tensors in the state_dict/\n       optim_state_dict one by one to other ranks. Other ranks will receive\n       the tensors and shard according to the local shards in the model and\n       optimizer. ``full_state_dict`` must be set to True when using this option.\n       This option currently only supports DTensor, not the legacy ShardedTensor.\n    \"\"\"\n\n    full_state_dict: bool = False\n    cpu_offload: bool = False\n    ignore_frozen_params: bool = False\n    keep_submodule_prefixes: bool = True\n    strict: bool = True\n    broadcast_from_rank0: bool = False\n    flatten_optimizer_state_dict: bool = False\n    dsd_fqn_modifiers: str = \"_fqn_modifiers\"\n\n\n@dataclass\nclass _StateDictInfo(StateDictOptions):\n    fqn_param_mapping: dict[\n        str | torch.Tensor,\n        FQNS_T | torch.Tensor,\n    ] = field(default_factory=dict)\n    shared_params_mapping: dict[\n        str | torch.Tensor,\n        FQNS_T | torch.Tensor,\n    ] = field(default_factory=dict)\n    submodule_prefixes: set[str] = field(default_factory=set)\n    handle_model: bool = True\n    handle_optim: bool = True\n    fsdp_context: Callable = contextlib.nullcontext\n    fsdp_modules: list[nn.Module] = field(default_factory=list)\n\n\n@functools.cache\ndef _get_fqns(\n    model: nn.Module,\n    name: str,\n    dsd_fqn_modifiers: str = \"_fqn_modifiers\",\n    skip_ddp_prefix: bool = True,\n    skip_compiler_prefix: bool = True,\n) -> FQNS_T:\n    \"\"\"\n    This API is used to convert the name of a parameter to the FQNs. For FSDP\n    without `use_orig_params`, the name of FlatParameter can be mapped to\n    multiple original parameters. 
As a result, the return type of this function\n    is `set[str]`.\n\n    Args:\n        model (nn.Module): the root model.\n        name (str): the name\n        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix\n\n    Returns:\n        The canonical FQNs based on the model traversal.\n    \"\"\"\n\n    # Remove the checkpoint prefix, if it exists.\n    name = name.replace(_CHECKPOINT_PREFIX, \"\")\n    if \".\" not in name:\n        return {name}\n\n    obj_names = name.split(\".\")\n    fqn_obj_names = []\n    curr_obj = model\n    for i, curr_obj_name in enumerate(obj_names):\n        if isinstance(curr_obj, DDP):\n            assert curr_obj_name == \"module\"\n            curr_obj = curr_obj.module\n            if not skip_ddp_prefix:\n                fqn_obj_names.append(curr_obj_name)\n        elif isinstance(curr_obj, FSDP):\n            if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:\n                prefix = \".\".join(fqn_obj_names)\n                flat_param = getattr(curr_obj, _FLAT_PARAM)\n                if prefix:\n                    prefix = f\"{prefix}.\"\n                return {f\"{prefix}{fqn}\" for fqn in flat_param._fqns}\n            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)\n            if curr_obj_name != FSDP_WRAPPED_MODULE:\n                fqn_obj_names.append(curr_obj_name)\n                curr_obj = getattr(curr_obj, curr_obj_name)\n        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):\n            assert curr_obj_name == \"_orig_mod\"\n            curr_obj = curr_obj._orig_mod\n            if not skip_compiler_prefix:\n                fqn_obj_names.append(curr_obj_name)\n        else:\n            # In some modules, _fqn_modifiers would not be shown in the state_dict keys;\n            # skip them in the fqn to ensure load_state_dict succeeds for them.\n            if hasattr(curr_obj, dsd_fqn_modifiers):\n                if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(curr_obj_name):\n                    if hasattr(curr_obj, removed_fqn):\n                        curr_obj = getattr(curr_obj, removed_fqn)\n            fqn_obj_names.append(curr_obj_name)\n            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:\n                if i != len(obj_names) - 1:\n                    raise RuntimeError(\"Expect `_extra_state` to be the last obj name\")\n            else:\n                curr_obj = getattr(curr_obj, curr_obj_name)\n\n    return {\".\".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, \"\")}\n\n\nclass _EXTRA_STATE:\n    pass\n\n\ndef _iterate_valid_model_state(model, dsd_fqn_modifiers=\"_fqn_modifiers\"):\n    visited_modules: set[nn.Module] = set()\n\n    def recurse(module: nn.Module, curr_fqn: str) -> Generator:\n        visited_modules.add(module)\n\n        curr_fqn = f\"{curr_fqn}.\" if curr_fqn else \"\"\n        for name, submodule in module.named_children():\n            if submodule in visited_modules:\n                continue\n            # if users have state_dict_hooks in their model, they can add the state_dict key changes\n            # at dsd_fqn_modifiers in input to align with the function of state_dict_hook\n            if hasattr(module, dsd_fqn_modifiers) and name in getattr(module, dsd_fqn_modifiers)().values():\n                # skip _fqn_modifiers here thus remove the last `.` added\n                new_fqn = curr_fqn[:-1]\n            else:\n                new_fqn = f\"{curr_fqn}{name}\"\n            yield from recurse(submodule, 
new_fqn)\n\n        for name, obj in chain(module.named_buffers(recurse=False), module.named_parameters(recurse=False)):\n            if name in module._non_persistent_buffers_set:\n                continue\n            new_fqn = f\"{curr_fqn}{name}\"\n            yield new_fqn, obj\n\n        if getattr(module.__class__, \"get_extra_state\", nn.Module.get_extra_state) != nn.Module.get_extra_state:\n            new_fqn = f\"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}\"\n            yield new_fqn, _EXTRA_STATE()\n\n    yield from recurse(model, \"\")\n\n\ndef _verify_options(\n    model: nn.Module,\n    optims: tuple[torch.optim.Optimizer, ...],\n    optim_only: bool,\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> _StateDictInfo:\n    \"\"\"\n    Verify the model and options passed by the user and generate _StateDictInfo.\n    \"\"\"\n    if submodules:\n        warnings.warn(\n            \"Getting submodules only model/optim state_dict is deprecated and \"\n            \"will be removed in 2.5. This feature can be achieved by manually \"\n            \"filtering out the state_dict returned from get_state_dict.\",\n            FutureWarning,\n        )\n    if optim_only and not optims:\n        raise RuntimeError(\"Optimizers are not passed in but optim_only is set to True.\")\n\n    options = options or StateDictOptions()\n\n    fqn_param_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {}\n    shared_params_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {}\n    for name, param in _iterate_valid_model_state(model):\n        if isinstance(param, _EXTRA_STATE):\n            continue\n\n        fqns = _get_fqns(model, name)\n        fqn = fqn_param_mapping.get(param, None)\n        if fqn is not None:\n            cast(set[str], fqn_param_mapping[param]).update(fqns)\n            shared_params_mapping[param] = fqn_param_mapping[param]\n        else:\n            # We need to make a copy as _get_fqns is lru_cached\n            fqn_param_mapping[param] = fqns.copy()\n        for fqn in fqns:\n            if not isinstance(param, _EXTRA_STATE):\n                fqn_param_mapping[fqn] = param\n\n    for param_, fqns_ in list(shared_params_mapping.items()):\n        for fqn in fqns_:\n            shared_params_mapping[fqn] = cast(torch.Tensor, param_)\n\n    submodule_prefixes: set[str] = set()\n    if submodules:\n        submodules = set(submodules)\n        for name, module in model.named_modules():\n            if module not in submodules:\n                continue\n            fqns = _get_fqns(model, name)\n            assert len(fqns) == 1, \"Submodule FQN should only have 1 instance\"\n            submodule_prefixes.update(f\"{fqn}.\" for fqn in fqns)\n\n    if options.broadcast_from_rank0 and not options.full_state_dict:\n        raise ValueError(\"full_state_dict must be True when broadcast_from_rank0 is True.\")\n    fsdp_modules = FSDP.fsdp_modules(model)\n    state_dict_config: StateDictConfig\n    optim_state_dict_config: OptimStateDictConfig\n    fsdp_context: Callable\n    if fsdp_modules:\n        # FSDP APIs only work if at least one FSDP instance exists.\n        if options.full_state_dict:\n            state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload)\n            optim_state_dict_config = FullOptimStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n                rank0_only=(options.cpu_offload or 
options.broadcast_from_rank0),\n            )\n            state_dict_type = StateDictType.FULL_STATE_DICT\n        else:\n            state_dict_config = ShardedStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n            )\n            optim_state_dict_config = ShardedOptimStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n            )\n            state_dict_type = StateDictType.SHARDED_STATE_DICT\n\n        @contextlib.contextmanager\n        def fsdp_state_dict_type_without_warning(\n            module,\n            state_dict_type,\n            state_dict_config,\n            optim_state_dict_config,\n        ):\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"FSDP.state_dict_type\", category=FutureWarning)\n                with FSDP.state_dict_type(\n                    module=module,\n                    state_dict_type=state_dict_type,\n                    state_dict_config=state_dict_config,\n                    optim_state_dict_config=optim_state_dict_config,\n                ):\n                    yield\n\n        fsdp_context = functools.partial(\n            fsdp_state_dict_type_without_warning,\n            module=model,\n            state_dict_type=state_dict_type,\n            state_dict_config=state_dict_config,\n            optim_state_dict_config=optim_state_dict_config,\n        )\n    else:\n        fsdp_context = contextlib.nullcontext\n\n    return _StateDictInfo(\n        **asdict(options),\n        fqn_param_mapping=fqn_param_mapping,\n        shared_params_mapping=shared_params_mapping,\n        submodule_prefixes=submodule_prefixes,\n        fsdp_context=fsdp_context,\n        fsdp_modules=cast(list[nn.Module], fsdp_modules),\n        handle_model=not optim_only,\n        handle_optim=(len(optims) > 0),\n    )\n\n\ndef _verify_state_dict(\n    model_state_dict: dict[str, ValueType],\n    optim_state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> None:\n    for module in info.fsdp_modules:\n        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)\n        assert fsdp_state is not None, \"Expected a fsdp_state with a fsdp module.\"\n\n    # Verify if the model_state_dict and optim_state_dict are valid. This API\n    # should give the users an explicit error message to debug or report.\n    if (\n        info.handle_model\n        and not model_state_dict\n        and not info.submodule_prefixes\n        and not info.ignore_frozen_params\n        and not (info.cpu_offload and info.full_state_dict)\n        and info.strict\n        and not info.broadcast_from_rank0\n    ):\n        raise RuntimeError(\n            \"The option indicates that model state_dict is required to save \"\n            \"or load, but model state_dict is empty.\"\n            f\"rank = {dist.get_rank()=}.\"\n        )\n\n    if info.handle_optim:\n        if not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0):\n            raise RuntimeError(\n                \"The option indicates that model state_dict is required to save, \"\n                f\"or load but optim state_dict is empty. {optim_state_dict}\"\n            )\n\n    for key in model_state_dict.keys():\n        if _FLAT_PARAM in key:\n            raise RuntimeError(f\"{key} contains {_FLAT_PARAM}. 
This can happen if the model is not the root module.\")\n\n\ndef _state_dict_fn(obj: nn.Module | torch.optim.Optimizer, api: str) -> Callable:\n    call = getattr(obj, api)\n    if call in _patched_state_dict:\n        call = functools.partial(getattr(obj.__class__, api), self=obj)\n    return call\n\n\ndef _maybe_full_or_cpu_state_dict(state_dict: dict[str, Any], info: _StateDictInfo) -> dict[str, Any]:\n    if info.full_state_dict:\n        ranks_only = () if (not info.cpu_offload or not torch.distributed.is_initialized()) else (0,)\n        return _gather_state_dict(state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only)\n    elif info.cpu_offload:\n        return _offload_state_dict_to_cpu(state_dict)\n    else:\n        return state_dict\n\n\n@torch.no_grad()\ndef _get_model_state_dict(model: nn.Module, info: _StateDictInfo) -> dict[str, ValueType]:\n    if not info.handle_model:\n        return {}\n\n    with info.fsdp_context():\n        state_dict = _state_dict_fn(model, \"state_dict\")()\n\n    for key in list(state_dict.keys()):\n        fqns = _get_fqns(model, key)\n        assert len(fqns) == 1, (key, fqns)\n        fqn = next(iter(fqns))\n        if fqn != key:\n            # As we only support FSDP, DDP, and TP, the only cases are\n            # wrapper-based DDP and compiler. Verify if the assumption\n            # is correct.\n            def verify(key, fqn) -> bool:\n                if len(fqn) >= len(key):\n                    return False\n                fqn_split = fqn.split(\".\")\n                key_split = key.split(\".\")\n                fqn_idx = 0\n                for key_idx, key_name in enumerate(key_split):\n                    if key_name == fqn_split[fqn_idx]:\n                        fqn_idx += 1\n                        if fqn_idx == len(fqn_split):\n                            return key_idx == len(key_split) - 1\n                    elif key_name in (\"module\", \"_orig_mod\"):\n                        continue\n                    else:\n                        return False\n                return True\n\n            if not verify(key, fqn):\n                raise RuntimeError(f\"An unexpected key, {key}, exists. 
FQN is {fqn}\")\n            state_dict[fqn] = state_dict.pop(key)\n\n    if info.submodule_prefixes:\n        new_state_dict: dict[str, ValueType] = {}\n        # TODO: make this faster.\n        for fqn in state_dict.keys():\n            for prefix in info.submodule_prefixes:\n                if not fqn.startswith(prefix):\n                    continue\n                if info.keep_submodule_prefixes:\n                    new_state_dict[fqn] = state_dict[fqn]\n                else:\n                    new_fqn = fqn[len(prefix) :]\n                    new_state_dict[new_fqn] = state_dict[fqn]\n        state_dict = new_state_dict\n\n    if info.ignore_frozen_params:\n        for key, param in model.named_parameters():\n            if param.requires_grad:\n                continue\n            fqns = _get_fqns(model, key)\n            for fqn in fqns:\n                state_dict.pop(fqn)\n\n    for key, p in list(state_dict.items()):\n        if torch.is_tensor(p) and p.is_meta:\n            state_dict.pop(key)\n\n    return _maybe_full_or_cpu_state_dict(state_dict, info)\n\n\n@torch.no_grad()\ndef _load_model_state_dict(\n    model: nn.Module,\n    state_dict: dict[str, ValueType],\n    info: _StateDictInfo,\n) -> _IncompatibleKeys:\n    if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):\n        return _IncompatibleKeys({}, {})\n\n    local_state_dict = {}\n    for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers):\n        fqns = _get_fqns(model, key, info.dsd_fqn_modifiers)\n        fqns_with_prefix = _get_fqns(\n            model,\n            key,\n            info.dsd_fqn_modifiers,\n            skip_ddp_prefix=False,\n            skip_compiler_prefix=False,\n        )\n\n        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix, strict=False):\n            if (not info.broadcast_from_rank0 or dist.get_rank() == 0) and fqn != fqn_with_prefix:\n                load_value = state_dict.pop(fqn, None)\n                if load_value is None:\n                    if info.strict:\n                        raise RuntimeError(f\"Missing key: {fqn}.\")\n                else:\n                    state_dict[fqn_with_prefix] = load_value\n            local_state_dict[fqn_with_prefix] = value\n\n    assign = False\n    if info.broadcast_from_rank0 or info.full_state_dict:\n        devices = set()\n        for key, value in local_state_dict.items():\n            if torch.is_tensor(value) and value.dim() > 0:\n                devices.add(value.device)\n        # In lora state_dict, there could be multiple devices, with meta device inside.\n        # Take the other device in the broadcast/distribtue, and set assign to True\n        if torch.device(\"meta\") in devices:\n            devices.remove(torch.device(\"meta\"))\n            assign = True\n        if len(devices) == 0:\n            devices.add(dist.distributed_c10d._get_pg_default_device())\n        elif len(devices) > 1:\n            raise ValueError(\"Multiple devices found\")\n\n        if info.broadcast_from_rank0:\n            _broadcast_state_dict(\n                state_dict,\n                local_state_dict,\n                device=devices.pop(),\n                strict=info.strict,\n                cpu_offload=info.cpu_offload,\n            )\n        elif info.full_state_dict:\n            _distribute_state_dict(state_dict, local_state_dict, device=devices.pop())\n        for fqn, local_state in local_state_dict.items():\n            state_dict[fqn] = local_state\n\n    with 
 info.fsdp_context():\n        return cast(\n            _IncompatibleKeys,\n            _state_dict_fn(model, \"load_state_dict\")(state_dict=state_dict, strict=info.strict, assign=assign),\n        )\n\n\ndef _init_optim_state(optim: torch.optim.Optimizer) -> None:\n    \"\"\"\n    Initialize optim states by calling the step() with zero grads.\n    \"\"\"\n    if optim.state:\n        # The optimizer state is initialized.\n        return\n\n    # There are some stateless optimizers like SGD. These optimizers will\n    # not return in the above condition. So if gradients exist, we should also\n    # return. If gradients do not exist, the following initialization should\n    # not disturb SGD because the gradients and lr are both zero.\n    for param_group in optim.param_groups:\n        for param in param_group[_PARAMS]:\n            if param.grad is not None:\n                return\n\n    for param_group in optim.param_groups:\n        for param in param_group[_PARAMS]:\n            if param.requires_grad:\n                param.grad = torch.zeros_like(param)\n\n    # Some optimizers will update parameters regardless of grads due to lr, so\n    # set lr to zero when calling `step()`.\n    lrs = []\n    for param_group in optim.param_groups:\n        if \"lr\" in param_group:\n            lrs.append(param_group[\"lr\"])\n            param_group[\"lr\"] = torch.tensor(0.0) if isinstance(param_group[\"lr\"], torch.Tensor) else 0.0\n    optim.step(closure=None)\n    # Whether to recover the \"lr\" should not matter too much as we will\n    # restore the checkpointed values later.\n    for param_group in optim.param_groups:\n        if \"lr\" in param_group:\n            param_group[\"lr\"] = lrs.pop(0)\n    optim.zero_grad(set_to_none=True)\n\n\ndef _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:\n    \"\"\"\n    This API flattens the optimizer state_dict to support optimizer resharding for\n    MPMD, e.g., pipeline parallelism.\n\n    Without the API, the original optimizer state_dict looks like:\n    {\n        \"state\": {\n            \"layer1.weight\": {\n                \"step\": 10, \"exp_avg\": SomeTensor, \"exp_avg_sq\": SomeTensor\n            },\n            \"layer2.weight\": {\n                \"step\": 10, \"exp_avg\": SomeTensor, \"exp_avg_sq\": SomeTensor\n            },\n        },\n        \"param_group\": [\n            {\n                \"lr\": 0.0,\n                \"betas\": (0.9, 0.95), ...,\n                \"params\": [\"layer1.weight\", \"layer2.weight\"]\n            }\n        ]\n    }\n\n    With this API, the optimizer state_dict looks like:\n    {\n        \"state.layer1.weight.step\": 10,\n        \"state.layer2.weight.step\": 10,\n        \"state.layer1.weight.exp_avg\": SomeTensor,\n        \"state.layer2.weight.exp_avg\": SomeTensor,\n        \"state.layer1.weight.exp_avg_sq\": SomeTensor,\n        \"state.layer2.weight.exp_avg_sq\": SomeTensor,\n        \"param_group.layer1.weight.lr\" : 0.1,\n        \"param_group.layer2.weight.lr\" : 0.1,\n        \"param_group.layer1.weight.betas\" : (0.9, 0.95),\n        \"param_group.layer2.weight.betas\" : (0.9, 0.95),\n    }\n\n    Note that if any of the values is a container, like the betas in the example,\n    this API won't flatten it.\n    \"\"\"\n\n    def _raise_if_type_not_supported(v):\n        if not isinstance(v, (torch.Tensor, int, float)):\n            raise NotImplementedError(\n                f\"Flattening optimizer state_dict only supports tensor, int, float states now.
Type is {type(v)}.\"\n            )\n\n    ret: dict[str, ValueType] = {}\n    for fqn, state in cast(DictValueType, state_dict[_STATE]).items():\n        for k, v in cast(DictValueType, state).items():\n            _raise_if_type_not_supported(v)\n            ret[f\"{_STATE}.{fqn}.{k}\"] = v\n\n    for param_group in cast(ListDictValueType, state_dict[_PG]):\n        fqns = param_group.pop(_PARAMS)\n        for fqn in cast(list[str], fqns):\n            for k, v in param_group.items():\n                ret[f\"{_PG}.{fqn}.{k}\"] = v\n    return ret\n\n\ndef _unflatten_optim_state_dict(\n    optim: torch.optim.Optimizer,\n    state_dict: dict[str, ValueType],\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    \"\"\"\n    This API unflattens the state_dict generated by _flatten_optim_state_dict().\n    See the docstring of _flatten_optim_state_dict() for more detail.\n    \"\"\"\n    state: DictValueType = {}\n    pg_state: ListDictValueType = []\n    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}\n\n    for param_group in optim.param_groups:\n        pg_state.append({_PARAMS: []})\n        for param in param_group[_PARAMS]:\n            for fqn in info.fqn_param_mapping[param]:\n                # If a parameter is shared, only one of the FQN will be used.\n                # So we need to verify which if this fqn is actually used in\n                # the state_dict.\n                if fqn in info.shared_params_mapping:\n                    in_params = False\n                    for k in param_group.keys():\n                        if k == _PARAMS:\n                            continue\n                        flatten_key = f\"{_PG}.{fqn}.{k}\"\n                        if flatten_key in state_dict:\n                            in_params = True\n                        break\n                else:\n                    in_params = True\n\n                if not in_params:\n                    continue\n\n                params = pg_state[-1][_PARAMS]\n                assert isinstance(params, list)  # typing\n                params.append(fqn)\n                if not param.requires_grad:\n                    continue\n                state[fqn] = {}\n                for state_name in optim.state[param].keys():\n                    cast(DictValueType, state[fqn])[state_name] = state_dict[f\"{_STATE}.{fqn}.{state_name}\"]\n\n        first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]\n        for k in param_group.keys():\n            if k == _PARAMS:\n                continue\n            value = state_dict[f\"{_PG}.{first_param_fqn}.{k}\"]\n            if k not in pg_state[-1]:\n                pg_state[-1][k] = value\n            elif pg_state[-1][k] != value:\n                raise RuntimeError(\n                    \"All the parameters in the same parameter group should have \"\n                    f\"the same saved param_group value. 
 But {first_param_fqn}.{k} \"\n                    f\"is {value} while other(s) is {pg_state[-1][k]}.\"\n                )\n\n    return return_osd\n\n\n@torch.no_grad()\ndef _get_optim_state_dict(\n    model: nn.Module,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    if not info.handle_optim:\n        return {}\n\n    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}\n    for optim in optimizers:\n        _init_optim_state(optim)\n        osd = _state_dict_fn(optim, \"state_dict\")()\n        if info.fsdp_modules:\n            with info.fsdp_context():\n                osd = FSDP.optim_state_dict(model, optim, osd)\n\n            # We need to specially handle FlatParameter FSDP as\n            # FlatParameter FSDP converts the FQNs.\n            # There are no easy ways to do this conversion systematically.\n            # We can only use a string replacement without a correctness check.\n            if not osd:\n                continue\n            for k in list(osd[_STATE].keys()):\n                if \"_orig_mod\" in k:\n                    osd[_STATE][k.replace(\"_orig_mod.\", \"\")] = osd[_STATE].pop(k)\n            for g in osd[_PG]:\n                params = [k.replace(\"_orig_mod.\", \"\") for k in g[_PARAMS]]\n                g[_PARAMS] = params\n        else:\n            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))\n            param_pid_mapping = dict(zip(params, range(len(params)), strict=False))\n            fqn_pid_mapping = {}\n            for key, param in model.named_parameters():\n                fqns = _get_fqns(model, key)\n                assert len(fqns) == 1\n                fqn = next(iter(fqns))\n                if param not in param_pid_mapping:\n                    continue\n                pid = param_pid_mapping[param]\n                fqn_pid_mapping[fqn] = pid\n                fqn_pid_mapping[pid] = fqn\n\n            for key in list(osd[_STATE].keys()):\n                fqn = fqn_pid_mapping[key]\n                osd[_STATE][fqn] = osd[_STATE].pop(key)\n\n            for group in osd[_PG]:\n                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]\n\n        if not osd:\n            continue\n\n        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])\n        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])\n\n    if info.flatten_optimizer_state_dict:\n        optim_state_dict = cast(OptimizerStateType, _flatten_optim_state_dict(optim_state_dict))\n\n    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)\n\n\ndef _split_optim_state_dict(\n    model: nn.Module,\n    optim: torch.optim.Optimizer,\n    optim_state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    \"\"\"\n    Extract the corresponding optim state_dict from ``optim_state_dict`` for\n    ``optim`` and return the result optim state_dict.\n\n    Args:\n        model (nn.Module): the root model.\n        optim (torch.optim.Optimizer): the optimizer.\n        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that\n            contains the optim state_dict of ``optim``.\n        info (_StateDictInfo): state dict information.\n\n    Returns:\n        The optim state_dict of ``optim``.\n    \"\"\"\n\n    state: DictValueType = {}\n    pg_state: ListDictValueType = []\n    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}\n    pg_mapping: dict[int, int] = {}\n\n    if
 all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()):\n        return optim_state_dict\n\n    for param_group in optim.param_groups:\n        pg_state.append({_PARAMS: []})\n        for param in param_group[_PARAMS]:\n            for fqn in info.fqn_param_mapping[param]:\n                if fqn in info.shared_params_mapping:\n                    in_params = False\n                    for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                        if fqn in cast(list[str], loaded_param_group[_PARAMS]):\n                            in_params = True\n                            break\n                else:\n                    in_params = True\n                if not in_params:\n                    continue\n\n                params = pg_state[-1][_PARAMS]\n                assert isinstance(params, list)\n                params.append(fqn)\n                if param.requires_grad:\n                    state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]\n                for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                    if fqn in cast(list[str], loaded_param_group[_PARAMS]):\n                        pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1\n\n        if len(param_group[_PARAMS]) == 0:\n            # Param_group with empty params.\n            ret = []\n            for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                if len(cast(list[str], loaded_param_group[_PARAMS])) == 0:\n                    ret.append(loaded_param_group)\n            if len(ret) != 1:\n                raise ValueError(\n                    \"There are param groups that have zero parameters. \"\n                    \"In such a case, DSD only supports exactly one param group \"\n                    \"with zero parameters, \"\n                    \"but the loaded state_dict has zero or more than one param groups \"\n                    \"with zero parameters.\"\n                )\n            if len(optim_state_dict[_PG]) != len(optim.param_groups):\n                raise ValueError(\n                    \"When there is a parameter group that has zero parameters, multiple optimizers are not supported.\"\n                )\n            pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1\n\n    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n        pg_idx = pg_mapping.get(id(param_group), -1)\n        if pg_idx == -1:\n            continue\n\n        for key, value in param_group.items():\n            if key == _PARAMS:\n                continue\n            # TODO: check if value is the same if exists.\n            pg_state[pg_idx][key] = value\n\n    return return_osd\n\n\n@torch.no_grad()\ndef _load_optim_state_dict(\n    model: nn.Module,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> None:\n    if not info.handle_optim:\n        return\n\n    for optim in optimizers:\n        _init_optim_state(optim)\n        if state_dict:\n            if _STATE in state_dict:\n                optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)\n            else:\n                optim_state_dict = _unflatten_optim_state_dict(optim, cast(dict[str, ValueType], state_dict), info)\n        else:\n            optim_state_dict = {}\n        if info.fsdp_modules:\n            # We need to specially handle FlatParameter FSDP as\n
            # FlatParameter FSDP converts the FQNs.\n            for original_fqn, _ in model.named_parameters():\n                fqns = _get_fqns(model, original_fqn)\n                fqns_with_compiler = _get_fqns(model, original_fqn, skip_compiler_prefix=False)\n                if fqns == fqns_with_compiler:\n                    continue\n\n                assert len(fqns) == 1\n                fqn = fqns.pop()\n                fqn_with_compiler = fqns_with_compiler.pop()\n                for g in optim_state_dict[_PG]:\n                    val = cast(dict[str, Any], g)\n                    params = [key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]]\n                    val[_PARAMS] = params\n                osd_state = cast(DictValueType, optim_state_dict[_STATE])\n                for k in list(osd_state.keys()):\n                    if fqn in k:\n                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)\n\n            with info.fsdp_context():\n                optim_state_dict = FSDP.optim_state_dict_to_load(model, optim, optim_state_dict)\n        elif info.full_state_dict:\n            info.full_state_dict = False\n            local_state_dict = _get_optim_state_dict(model, (optim,), info)\n            info.full_state_dict = True\n            device = None\n\n            def _device(t):\n                if t.dim() > 0:\n                    nonlocal device\n                    if device is None:\n                        device = t.device\n                    elif device != t.device:\n                        raise ValueError(\"Device mismatch\")\n                return t\n\n            _ = tree_map_only(torch.Tensor, _device, local_state_dict)\n            assert device is not None\n            flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)\n            flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)\n            if info.broadcast_from_rank0:\n                _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)\n            else:\n                _distribute_state_dict(flatten_osd, flatten_local_osd, device=device)\n            # The loop below handles the case where optim has parameters that\n            # differ from those in optim_state_dict: entries that exist only in\n            # the loaded state_dict are merged into the local one, so optim may\n            # end up with additional parameters.\n            for optim_key in flatten_osd.keys():\n                if optim_key not in flatten_local_osd:\n                    assert optim_key in osd_mapping\n                    flatten_local_osd[optim_key] = flatten_osd[optim_key]\n                    local_osd_mapping[optim_key] = osd_mapping[optim_key]\n            optim_state_dict = _unflatten_state_dict(flatten_local_osd, local_osd_mapping)\n            for pg in optim_state_dict[_PG]:\n                if _PARAMS not in pg:\n                    cast(dict[str, ValueType], pg)[_PARAMS] = []\n\n        # Note that we do not have to convert the FQN back to param id here if\n        # order in optim.param_groups[idx][_PARAMS] is the same as the one in\n        # optim_state_dict[_PG][idx][_PARAMS].\n        _state_dict_fn(optim, \"load_state_dict\")(state_dict=optim_state_dict)\n\n\ndef get_model_state_dict(\n    model: nn.Module,\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> dict[str, ValueType]:\n    \"\"\"\n    Return the model state_dict of ``model``.\n\n    See ``get_state_dict`` for detailed usage.\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        The state_dict for ``model``.\n\n    :rtype: typing.Dict[str, ValueType]\n    \"\"\"\n    with _gc_context():\n        info = _verify_options(\n            model,\n            (),\n            optim_only=False,\n            submodules=submodules,\n            options=options,\n        )\n        model_state_dict = _get_model_state_dict(model, info)\n        _verify_state_dict(model_state_dict, {}, info)\n        return model_state_dict\n\n\ndef get_optimizer_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> OptimizerStateType:\n    \"\"\"\n    Return the combined state_dict for optimizers.\n\n    See ``get_state_dict`` for detailed usage.\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. 
See\n            `StateDictOptions` for the details.\n\n    Returns:\n        The state_dict for ``optimizers``.\n\n    :rtype: OptimizerStateType\n    \"\"\"\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(\n            model,\n            optimizers,\n            optim_only=True,\n            submodules=submodules,\n            options=options,\n        )\n        optim_state_dict = _get_optim_state_dict(model, optimizers, info)\n        _verify_state_dict({}, optim_state_dict, info)\n        return optim_state_dict\n\n\ndef get_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> tuple[dict[str, ValueType], OptimizerStateType]:\n    \"\"\"\n    Return the model state_dict and optimizers state_dict.\n\n    ``get_state_dict`` can process any module that is parallelized by PyTorch\n    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any\n    combination of these parallelisms. The main functions of ``get_state_dict``\n    are: 1.) returning a model and optimizer state_dict that can be resharded\n    with a different number of trainers and/or different parallelisms.\n    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call\n    these APIs.\n    3.) sanity checking the result state_dict.\n\n    The keys of the result state dictionary are the canonical FQNs (Fully\n    Qualified Names).  A canonical FQN refers to the FQN based on a parameter's\n    position in an nn.Module hierarchy. More specifically, a canonical FQN to a\n    parameter is the FQN returned by ``module.named_parameters()`` or\n    ``module.named_buffers()`` when the module is not distributed by any\n    parallelisms. Since the optimizer internally uses parameter IDs to represent\n    a parameter, there will be a conversion from the parameter IDs to the\n    canonical FQNs when calling this API.\n\n    ``get_state_dict`` can also process a module that is not parallelized. In\n    such a case, ``get_state_dict`` only performs one function -- converting the\n    optimizer parameter IDs to the canonical FQNs.\n\n    Example:\n        >>> # xdoctest: +SKIP\n        >>> import torch\n        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        >>> from torch.nn.parallel import DistributedDataParallel as DDP\n        >>> from torch.distributed.checkpoint.state_dict import get_state_dict\n\n        >>> fsdp_model = FSDP(copy.deepcopy(model))\n        >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)\n        >>> ddp_model = DDP(copy.deepcopy(model))\n        >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)\n\n\n        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)\n        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(\n        ...     fsdp_model, fsdp_optim\n        ... 
)\n\n        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),\n        >>> # the asserts will fail.\n        >>> assert ddp_state_dict == fsdp_state_dict\n        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict\n\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``Tuple`` that contains the model state_dict and the optimizer state_dict.\n\n    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]\n    \"\"\"\n\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(\n            model,\n            optimizers,\n            optim_only=False,\n            submodules=submodules,\n            options=options,\n        )\n        model_state_dict = _get_model_state_dict(model, info)\n        optim_state_dict = _get_optim_state_dict(model, optimizers, info)\n        _verify_state_dict(model_state_dict, optim_state_dict, info)\n        return model_state_dict, optim_state_dict\n\n\ndef _unflatten_model_state_dict(\n    model: nn.Module,\n    state_dict: dict[nn.Module, dict[str, ValueType]] | dict[str, ValueType],\n) -> dict[str, ValueType]:\n    if not state_dict:\n        return {}\n\n    if isinstance(next(iter(state_dict.keys())), nn.Module):\n        warnings.warn(\n            \"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]`` \"\n            \"is deprecated and will be removed in 2.5. If you need this \"\n            \"feature, please preprocess the model_state_dict to achieve the \"\n            \"same functionality.\",\n            FutureWarning,\n        )\n        cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)\n        new_state_dict: dict[str, ValueType] = {}\n        for submodule, sub_state_dict in cast_state_dict.items():\n            for name, m in model.named_modules():\n                if m != submodule:\n                    continue\n\n                fqns = _get_fqns(model, name)\n                assert len(fqns) == 1, \"FQNs for a submodule should only have 1 element\"\n                prefix = f\"{next(iter(fqns))}.\"\n                new_state_dict.update({prefix + subfqn: value for subfqn, value in sub_state_dict.items()})\n        return new_state_dict\n    else:\n        return cast(dict[str, ValueType], state_dict)\n\n\ndef set_model_state_dict(\n    model: nn.Module,\n    model_state_dict: dict[str, ValueType],\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> _IncompatibleKeys:\n    \"\"\"Load the model state_dict.\n\n    The counterpart of ``get_model_state_dict`` to set the state_dict to the\n    model. See ``set_state_dict`` for detailed usage.\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        model_state_dict: (Dict[str, ValueType]):\n           the model state_dict to load.
 If the key of the ``model_state_dict``\n           is nn.Module, the key is a submodule of ``model`` and the value should\n           be the state_dict of the submodule. When loading the state_dict,\n           the prefix of the submodule will be appended to the state_dict.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n            * **missing_keys** is a list of str containing the missing keys\n            * **unexpected_keys** is a list of str containing the unexpected keys\n\n    :type model_state_dict: typing.Dict[str, ValueType]\n    \"\"\"\n    model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict)\n    with _gc_context():\n        info = _verify_options(model, (), optim_only=False, options=options)\n\n        _verify_state_dict(model_state_dict, {}, info)\n        return _load_model_state_dict(model, model_state_dict, info)\n\n\ndef set_optimizer_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    optim_state_dict: OptimizerStateType,\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Load the optimizers state_dict.\n\n    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the\n    optimizers. See ``set_state_dict`` for detailed usage.\n\n    WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after\n        ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be\n        initialized correctly.\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        optimizers (Union[Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        optim_state_dict: OptimizerStateType:\n            the optimizer state_dict to load.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        None\n\n    :type optim_state_dict: typing.OptimizerStateType\n    \"\"\"\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(model, optimizers, optim_only=True, options=options)\n\n        _verify_state_dict({}, optim_state_dict, info)\n        _load_optim_state_dict(model, optimizers, optim_state_dict, info)\n\n\ndef set_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    model_state_dict: dict[str, ValueType],\n    optim_state_dict: OptimizerStateType,\n    options: Optional[StateDictOptions] = None,\n) -> _IncompatibleKeys:\n    \"\"\"Load the model state_dict and optimizers state_dict.\n\n    The counterpart of ``get_state_dict`` to set the state_dict to the model and\n    optimizers.
  The given ``model_state_dict`` and ``optim_state_dict`` do not\n    have to be returned by ``get_state_dict`` but must meet the following\n    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,\n    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,\n    3) optimizer state_dict cannot contain the parameter IDs; the keys should be\n    the canonical FQNs.\n\n    WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()``\n        is called on the optimizers. Otherwise, the optimizer states won't be initialized\n        correctly.\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        optimizers (Union[Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):\n           the model state_dict to load. If the key of the ``model_state_dict``\n           is nn.Module, the key is a submodule of ``model`` and the value should\n           be the state_dict of the submodule. When loading the state_dict,\n           the prefix of the submodule will be appended to the state_dict.\n        optim_state_dict: OptimizerStateType:\n            the optimizer state_dict to load.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n            * **missing_keys** is a list of str containing the missing keys of the model state_dict.\n            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.\n\n    :type model_state_dict: typing.Dict[str, ValueType]\n    :type optim_state_dict: typing.OptimizerStateType\n    \"\"\"\n\n    model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict)\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(model, optimizers, optim_only=not model_state_dict, options=options)\n\n        _verify_state_dict(model_state_dict, optim_state_dict, info)\n        _load_optim_state_dict(model, optimizers, optim_state_dict, info)\n        return _load_model_state_dict(model, model_state_dict, info)\n\n\n# TODO: correct the state_dict function signature.\n# TODO: this API is not yet fully tested. Make it private\n@no_type_check\ndef _patch_model_state_dict(\n    model: nn.Module,\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.\n\n    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to\n    be a partial function to call ``get_state_dict`` and ``set_state_dict``.\n\n    Example:\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.checkpoint.state_dict import patch_model_state_dict\n\n        model = FSDP(model)\n        patch_model_state_dict(model)\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded.
 See\n            `StateDictOptions` for the details.\n    Returns:\n        None\n    \"\"\"\n\n    _state_dict_call = functools.partial(\n        get_model_state_dict,\n        model=model,\n        options=options,\n    )\n\n    def state_dict_call():\n        return _state_dict_call()\n\n    model.state_dict = state_dict_call\n\n    _load_state_dict_call = functools.partial(\n        set_model_state_dict,\n        model=model,\n        options=options,\n    )\n\n    def load_state_dict_call(state_dict: dict[str, Any]):\n        _load_state_dict_call(model_state_dict=state_dict)\n\n    model.load_state_dict = load_state_dict_call\n\n    _patched_state_dict.add(state_dict_call)\n    _patched_state_dict.add(load_state_dict_call)\n\n\n# TODO: correct the load_state_dict function signature.\n# TODO: this API is not yet fully tested. Make it private\n@no_type_check\ndef _patch_optimizer_state_dict(\n    model: nn.Module,\n    *,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.\n\n    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to\n    be a partial function to call ``get_state_dict`` and ``set_state_dict``.\n\n    Note that if there are multiple optimizers, all of the optimizers will be patched.\n    So users only need to call ``state_dict()`` on one of them to get the full result.\n\n    Example:\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.checkpoint.state_dict import patch_optimizer_state_dict\n\n        model = FSDP(model)\n        patch_optimizer_state_dict(model, optimizers=(optim,))\n\n    Args:\n        model (nn.Module): the nn.Module of the model.\n        optimizers (tuple[torch.optim.Optimizer, ...]):\n            The optimizers that are used to optimize ``model``.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n    Returns:\n        None\n    \"\"\"\n\n    _state_dict_call = functools.partial(\n        get_optimizer_state_dict,\n        model=model,\n        optimizers=optimizers,\n        options=options,\n    )\n\n    def state_dict_call():\n        return _state_dict_call()\n\n    _load_state_dict_call = functools.partial(\n        set_optimizer_state_dict,\n        model=model,\n        optimizers=optimizers,\n        options=options,\n    )\n\n    def load_state_dict_call(state_dict: dict[str, Any]):\n        _load_state_dict_call(optim_state_dict=state_dict)\n\n    _patched_state_dict.add(state_dict_call)\n    _patched_state_dict.add(load_state_dict_call)\n    optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n    for optim in optimizers:\n        optim.state_dict = state_dict_call\n        optim.load_state_dict = load_state_dict_call\n"
  },
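The file above tracks `torch.distributed.checkpoint.state_dict`. As a hedged usage sketch (assuming an initialized process group and an FSDP-wrapped model; the helper names and checkpoint path are hypothetical, not from this repo), a rank0 full-checkpoint save/load built on `get_state_dict`/`set_state_dict` could look like:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)


def save_full_checkpoint(model: nn.Module, optim: torch.optim.Optimizer, path: str) -> None:
    # full_state_dict gathers shards into plain tensors; cpu_offload bounds GPU
    # memory (maps to FULL_STATE_DICT with rank0_only under FSDP, per _verify_options).
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    model_sd, optim_sd = get_state_dict(model, optim, options=options)
    if dist.get_rank() == 0:
        torch.save({"model": model_sd, "optim": optim_sd}, path)


def load_full_checkpoint(model: nn.Module, optim: torch.optim.Optimizer, path: str) -> None:
    # broadcast_from_rank0 requires full_state_dict=True (enforced in _verify_options);
    # only rank 0 materializes the checkpoint, other ranks pass empty dicts.
    if dist.get_rank() == 0:
        ckpt = torch.load(path, map_location="cpu")
    else:
        ckpt = {"model": {}, "optim": {}}
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_state_dict(
        model,
        optim,
        model_state_dict=ckpt["model"],
        optim_state_dict=ckpt["optim"],
        options=options,
    )
```

The keys produced and consumed here are canonical FQNs, so the same checkpoint can be reloaded under a different number of ranks or a different parallelism.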
  {
    "path": "verl_distillation/verl/third_party/vllm/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom packaging import version as vs\n\nfrom verl.utils.device import is_npu_available\nfrom verl.utils.import_utils import is_sglang_available\n\n\ndef get_version(pkg):\n    try:\n        return version(pkg)\n    except PackageNotFoundError:\n        return None\n\n\npackage_name = \"vllm\"\npackage_version = get_version(package_name)\nvllm_version = None\nVLLM_SLEEP_LEVEL = 1\n\nif package_version is None:\n    if not is_sglang_available():\n        raise ValueError(\n            f\"vllm version {package_version} not supported and SGLang also not Found. Currently supported \"\n            f\"vllm versions are 0.7.0+\"\n        )\nelif is_npu_available:\n    # sleep_mode=2 is not supported on vllm-ascend for now, will remove this restriction when this ability is ready.\n    VLLM_SLEEP_LEVEL = 1\n    from vllm import LLM\n    from vllm.distributed import parallel_state\nelif vs.parse(package_version) >= vs.parse(\"0.7.0\"):\n    vllm_version = package_version\n    if vs.parse(package_version) >= vs.parse(\"0.8.5\"):\n        VLLM_SLEEP_LEVEL = 2\n    from vllm import LLM\n    from vllm.distributed import parallel_state\nelse:\n    if vs.parse(package_version) in [vs.parse(\"0.5.4\"), vs.parse(\"0.6.3\")]:\n        raise ValueError(\n            f\"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer \"\n            f\"supported. Please use vLLM 0.7.0 or later.\"\n        )\n    if not is_sglang_available():\n        raise ValueError(\n            f\"vllm version {package_version} not supported and SGLang also not Found. Currently supported \"\n            f\"vllm versions are 0.7.0+\"\n        )\n\n__all__ = [\"LLM\", \"parallel_state\"]\n"
  },
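The shim above selects behavior from the installed vLLM version. A minimal sketch of the same probe-and-branch pattern, reusing the thresholds from the file; `pick_sleep_level` is an illustrative helper, not part of the repo:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging import version as vs


def get_version(pkg: str):
    # Mirrors the helper in the shim: None means the package is not installed.
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def pick_sleep_level(pkg: str = "vllm") -> int:
    v = get_version(pkg)
    if v is None or vs.parse(v) < vs.parse("0.7.0"):
        # The real module raises here unless SGLang (or the NPU path) applies.
        return 1
    # Per the shim, sleep_mode level 2 is only enabled from vLLM 0.8.5 onward.
    return 2 if vs.parse(v) >= vs.parse("0.8.5") else 1


print(pick_sleep_level())
```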
  {
    "path": "verl_distillation/verl/tools/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
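(The `verl/tools` package init above is license-only; the tool contract itself is defined in `base_tool.py` below.)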
  {
    "path": "verl_distillation/verl/tools/base_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\n\nclass BaseTool:\n    \"\"\"Base class for tools.\n\n    A tool should support the following methods:\n\n    - `get_openai_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        self.config = config\n        self.tool_schema = tool_schema or self.get_openai_tool_schema()\n        assert self.tool_schema is not None, \"Tool schema is not set!\"\n        self.name = self.tool_schema.function.name\n        print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2))\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n            tool_creation_response: The response of the tool when creating the instance.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4()), ToolResponse()\n        else:\n            return instance_id, ToolResponse()\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        \"\"\"Execute the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n            parameters: The json string of the parameters of the tool.\n\n        Returns: tool_response, tool_reward_score, tool_metrics\n            tool_response: The ToolResponse object containing text, image, and/or video content.\n            tool_reward_score: The step reward score of the tool.\n            tool_metrics: The metrics of the tool.\n        \"\"\"\n        return ToolResponse(text=\"Updated the tool state.\"), 0.0, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        \"\"\"Calculate the reward of the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The reward of the tool.\n        \"\"\"\n        return 0.0\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        \"\"\"Release the tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n        \"\"\"\n        pass\n"
  },
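BaseTool defines the per-trajectory lifecycle (`create` -> `execute` -> `calc_reward` -> `release`), with state keyed by `instance_id`. A minimal sketch of a custom tool following that contract; `EchoTool` and its behavior are illustrative only, not part of the repo:

```python
from typing import Any, Optional
from uuid import uuid4

from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse


class EchoTool(BaseTool):
    """Illustrative tool: echoes the last submitted text and rewards non-empty submissions."""

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        self._instance_dict: dict[str, str] = {}

    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:
        instance_id = instance_id or str(uuid4())
        self._instance_dict[instance_id] = ""  # per-trajectory state
        return instance_id, ToolResponse()

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        text = str(parameters.get("text", ""))
        self._instance_dict[instance_id] = text
        return ToolResponse(text=f"echo: {text}"), 0.0, {}

    async def calc_reward(self, instance_id: str, **kwargs) -> float:
        # Reward is derived purely from the stored instance state.
        return 1.0 if self._instance_dict[instance_id] else 0.0

    async def release(self, instance_id: str, **kwargs) -> None:
        self._instance_dict.pop(instance_id, None)
```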
  {
    "path": "verl_distillation/verl/tools/geo3k_tool.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import geo3k\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Geo3kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of geo3k.\n    - `get_openai_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_geo3k_reward\",\n                \"description\": \"A tool for calculating the reward of geo3k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question, enclosed in \\\\boxed{}\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> tuple[str, ToolResponse]:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id, ToolResponse()\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n        self._instance_dict[instance_id][\"response\"] = answer\n        reward = await self.calc_reward(instance_id)\n        # penalty for non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        
self._instance_dict[instance_id][\"reward\"] = reward\n        return ToolResponse(text=f\"Current parsed {answer=} {reward=}\"), tool_reward, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        return geo3k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            use_boxed=False,\n            format_score=0.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
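A hedged driver for the create/execute loop above. The schema is the one documented in Geo3kTool's own docstring; the ground truth is made up, and this assumes `rollout_trace_op` is a no-op when tracing is not configured. `execute` applies a -0.05 penalty whenever the new reward does not improve on the stored one:

```python
import asyncio

from verl.tools.geo3k_tool import Geo3kTool
from verl.tools.schemas import OpenAIFunctionToolSchema


async def main():
    schema = OpenAIFunctionToolSchema.model_validate({
        "type": "function",
        "function": {
            "name": "calc_geo3k_reward",
            "description": "A tool for calculating the reward of geo3k",
            "parameters": {
                "type": "object",
                "properties": {
                    "answer": {"type": "string", "description": "The answer to the question"},
                },
                "required": ["answer"],
            },
        },
    })
    tool = Geo3kTool(config={}, tool_schema=schema)
    instance_id, _ = await tool.create(ground_truth="42")
    # First submission: reward is compared against the stored 0.0 baseline.
    response, tool_reward, metrics = await tool.execute(instance_id, {"answer": "42"})
    print(response.text, tool_reward)
    await tool.release(instance_id)


asyncio.run(main())
```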
  {
    "path": "verl_distillation/verl/tools/gsm8k_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import gsm8k\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Gsm8kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of gsm8k.\n\n    - `get_openai_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_gsm8k_reward\",\n                \"description\": \"A tool for calculating the reward of gsm8k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> tuple[str, ToolResponse]:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        if ground_truth is None:\n            ground_truth = kwargs.get(\"create_kwargs\", {}).get(\"ground_truth\", None)\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id, ToolResponse()\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n\n        if answer.startswith(\"#### \"):\n            self._instance_dict[instance_id][\"response\"] = answer\n        else:\n            self._instance_dict[instance_id][\"response\"] = \"#### \" + answer\n\n        reward = await self.calc_reward(instance_id)\n        # penalty for 
non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        self._instance_dict[instance_id][\"reward\"] = reward\n\n        return ToolResponse(text=f\"Current parsed {answer=} {reward=}\"), tool_reward, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"flexible\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
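A small sketch of the answer normalization Gsm8kTool.execute applies before scoring: the "flexible" gsm8k scorer expects the final answer after "#### ", so bare answers are prefixed. The helper name is illustrative:

```python
def normalize_gsm8k_answer(answer) -> str:
    # Mirrors the branch in Gsm8kTool.execute: coerce to str, then ensure the
    # "#### " prefix the gsm8k scorer parses the final answer from.
    answer = answer if isinstance(answer, str) else str(answer)
    return answer if answer.startswith("#### ") else "#### " + answer


assert normalize_gsm8k_answer("72") == "#### 72"
assert normalize_gsm8k_answer("#### 72") == "#### 72"
assert normalize_gsm8k_answer(72) == "#### 72"
```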
  {
    "path": "verl_distillation/verl/tools/image_zoom_in_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport threading\nfrom contextlib import ExitStack\nfrom enum import Enum\nfrom math import ceil, floor\nfrom typing import Any, Callable, Optional, TypeVar\nfrom uuid import uuid4\n\nimport ray\nimport ray.actor\nfrom qwen_vl_utils import fetch_image\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nT = TypeVar(\"T\")\n\n\n# Adapted from verl/tools/sandbox_fusion_tools.py\nclass PoolMode(Enum):\n    \"\"\"Execution pool mode enumeration.\"\"\"\n\n    ThreadMode = 1\n    ProcessMode = 2\n\n\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\nclass TokenBucketWorker:\n    \"\"\"Ray actor for rate limiting using token bucket algorithm.\"\"\"\n\n    def __init__(self, rate_limit: int):\n        self.rate_limit = rate_limit\n        self.current_count = 0  # For observability\n        self._semaphore = threading.Semaphore(rate_limit)\n\n    @ray.method(concurrency_group=\"acquire\")\n    def acquire(self):\n        \"\"\"Acquire a token from the bucket.\"\"\"\n        self._semaphore.acquire()\n        self.current_count += 1\n\n    @ray.method(concurrency_group=\"release\")\n    def release(self):\n        \"\"\"Release a token back to the bucket.\"\"\"\n        self._semaphore.release()\n        self.current_count -= 1\n\n    def get_current_count(self):\n        \"\"\"Get current number of acquired tokens.\"\"\"\n        return self.current_count\n\n\nclass VisualExecutionWorker:\n    \"\"\"Worker for executing visual processing operations with optional rate limiting.\"\"\"\n\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n    def _init_rate_limit(self, rate_limit):\n        \"\"\"Initialize singleton rate limiter.\"\"\"\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n    def ping(self):\n        \"\"\"Health check method.\"\"\"\n        return True\n\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n        \"\"\"Execute function with optional rate limiting.\"\"\"\n        if self.rate_limit_worker:\n            with ExitStack() as stack:\n                stack.callback(self.rate_limit_worker.release.remote)\n                ray.get(self.rate_limit_worker.acquire.remote())\n                try:\n                    return fn(*fn_args, **fn_kwargs)\n                except Exception as e:\n                    # TODO we should make this available to the tool caller\n                    logger.warning(f\"Error when executing visual processing: {e}\")\n        else:\n            return fn(*fn_args, **fn_kwargs)\n\n\ndef 
init_visual_execution_pool(\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\n):\n    \"\"\"Initialize visual execution pool.\"\"\"\n    if mode == PoolMode.ThreadMode:\n        return (\n            ray.remote(VisualExecutionWorker)\n            .options(max_concurrency=num_workers)\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\n        )\n    else:\n        raise NotImplementedError(\"Process mode is not implemented yet\")\n\n\nclass ImageZoomInTool(BaseTool):\n    \"\"\"A tool for zooming in on an image by cropping it based on a bounding box.\n\n    This tool provides a zoom-in functionality by cropping a region from an image,\n    with rate limiting and concurrent execution support through Ray.\n\n    Methods:\n        get_openai_tool_schema: Return the tool schema in OpenAI format\n        create: Create a tool instance for a trajectory\n        execute: Execute the zoom-in operation\n        calc_reward: Calculate the reward with respect to tool state\n        release: Release the tool instance\n    \"\"\"\n\n    MIN_DIMENSION = 28\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"image_zoom_in_tool\",\n                \"description\": (\n                    \"Zoom in on a specific region of an image by cropping it based on a bounding box (bbox) and an \"\n                    \"optional object label.\"\n                ),\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"bbox_2d\": {\n                            \"type\": \"array\",\n                            \"items\":{\"type\":\"number\"},\n                            \"minItems\":4,\n                            \"maxItems\":4,\n                            \"description\": (\n                                \"The bounding box of the region to zoom in, as [x1, y1, x2, y2], where (x1, y1) is \"\n                                \"the top-left corner and (x2, y2) is the bottom-right corner.\"\n                            ),\n                        },\n                        \"label\": {\n                            \"type\": \"string\",\n                            \"description\": \"The name or label of the object in the specified bounding box (optional).\",\n                        },\n                    },\n                    \"required\": [\"bbox_2d\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n        # Worker and rate limiting configuration\n        self.num_workers = config.get(\"num_workers\", 20)\n        self.rate_limit = config.get(\"rate_limit\", 50)\n        self.timeout = config.get(\"timeout\", 30)\n\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\n        self.execution_pool = init_visual_execution_pool(\n            num_workers=self.num_workers,\n            enable_global_rate_limit=self.enable_global_rate_limit,\n            rate_limit=self.rate_limit,\n            mode=PoolMode.ThreadMode,\n        )\n        logger.info(f\"Initialized ImageZoomInTool with config: {config}\")\n\n    def _validate_bbox(self, left: float, top: float, right: float, bottom: float) -> 
bool:\n        \"\"\"Validate the bounding box dimensions and aspect ratio.\"\"\"\n        try:\n            if not (left < right and top < bottom):\n                logger.warning(f\"Invalid bbox shape: left={left}, top={top}, right={right}, bottom={bottom}\")\n                return False\n\n            height = bottom - top\n            width = right - left\n\n            # Prevent division by zero for zero-sized boxes\n            if min(height, width) == 0:\n                logger.warning(f\"Bbox has zero width or height: left={left}, top={top}, right={right}, bottom={bottom}\")\n                return False\n\n            if max(height, width) / min(height, width) > 100:\n                logger.warning(f\"Bbox aspect ratio > 100: left={left}, top={top}, right={right}, bottom={bottom}\")\n                return False\n\n            return True\n        except Exception as e:\n            logger.warning(f\"Bbox validation error: {e}\")\n            return False\n\n    def _maybe_resize_bbox(self, bbox_2d: list[float], image_width: int, image_height: int) -> Optional[list[float]]:\n        \"\"\"\n        Clamp, validate, and potentially resize a bounding box.\n\n        This function ensures the final bounding box is within image bounds and meets the minimum\n        dimension requirements. If the initial box is too small, it attempts to expand it\n        from its center. It performs a final check to guarantee the output dimensions are valid.\n\n        Returns:\n            A valid bounding box as a list of coordinates, or None if validation fails.\n        \"\"\"\n        left, top, right, bottom = bbox_2d\n\n        # 1. Clamp the initial bounding box to the image dimensions.\n        left = max(0.0, float(left))\n        top = max(0.0, float(top))\n        right = min(float(image_width), float(right))\n        bottom = min(float(image_height), float(bottom))\n\n        # 2. If clamped bbox is invalid, return immediately.\n        if not self._validate_bbox(left, top, right, bottom):\n            return None\n\n        current_bbox = [left, top, right, bottom]\n        height = bottom - top\n        width = right - left\n\n        # 3. If the box is too small, attempt to resize it.\n        if height < self.MIN_DIMENSION or width < self.MIN_DIMENSION:\n            logger.info(f\"Bbox {width}x{height} is smaller than {self.MIN_DIMENSION}, attempting resize.\")\n            center_x = (left + right) / 2.0\n            center_y = (top + bottom) / 2.0\n\n            min_dim = min(height, width)\n            if min_dim == 0:  # Safeguard for zero-area boxes\n                return None\n\n            # 1. Calculate the target dimensions to make the smallest side MIN_DIMENSION.\n            ratio = self.MIN_DIMENSION / min_dim\n            target_width = width * ratio\n            target_height = height * ratio\n\n            # 2. If the target size is larger than the image, scale it down to fit.\n            #    This preserves the aspect ratio while respecting image boundaries.\n            if target_width > image_width:\n                scale_down = image_width / target_width\n                target_width = image_width\n                target_height *= scale_down\n\n            if target_height > image_height:\n                scale_down = image_height / target_height\n                target_height = image_height\n                target_width *= scale_down\n\n            # 3. 
Determine the coordinates for the box centered on the original center.\n            new_half_width = target_width / 2.0\n            new_half_height = target_height / 2.0\n            new_left = center_x - new_half_width\n            new_top = center_y - new_half_height\n\n            # 4. Shift the box if it extends beyond the image boundaries to keep its size.\n            if new_left < 0:\n                new_left = 0\n            if new_top < 0:\n                new_top = 0\n            if new_left + target_width > image_width:\n                new_left = image_width - target_width\n            if new_top + target_height > image_height:\n                new_top = image_height - target_height\n\n            new_right = new_left + target_width\n            new_bottom = new_top + target_height\n\n            # Use floor and ceil for final integer coordinates.\n            current_bbox = [floor(new_left), floor(new_top), ceil(new_right), ceil(new_bottom)]\n\n        # 4. Final validation on the resulting bounding box (either original or resized).\n        final_left, final_top, final_right, final_bottom = current_bbox\n        if not self._validate_bbox(final_left, final_top, final_right, final_bottom):\n            logger.warning(f\"Final bbox is invalid after processing: {current_bbox}\")\n            return None\n\n        final_height = floor(final_bottom) - floor(final_top)\n        final_width = floor(final_right) - floor(final_left)\n\n        if final_height < self.MIN_DIMENSION or final_width < self.MIN_DIMENSION:\n            logger.warning(\n                f\"Final bbox size ({final_width}x{final_height}) is still smaller than the minimum ({self.MIN_DIMENSION}). \"\n                f\"Original bbox: {bbox_2d}, original image size: {image_width}x{image_height}\"\n            )\n            return None\n\n        return current_bbox\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:\n        \"\"\"\n        Creates a new instance of the image zoom-in tool.\n\n        This method initializes a new session for an image, which can then be used\n        for operations like zooming. It fetches the image from various sources\n        and stores it internally.\n\n        Args:\n            instance_id: An optional unique identifier for the instance. If not\n                provided, a new UUID will be generated.\n            **kwargs: Should contain 'image' key with image data, or 'create_kwargs'\n                containing {'image': image_data}. 
Image can be one of the following:\n                - A PIL.Image.Image object.\n                - A string containing an HTTP or HTTPS URL.\n                - A string containing a local file path.\n                - A string containing a file URI (e.g., \"file:///path/to/image.jpg\").\n                - A string containing a base64-encoded image in the format of \"data:image/jpeg;base64,...\"\n\n        Returns:\n            Tuple of (instance_id, ToolResponse)\n        \"\"\"\n        if instance_id is None:\n            instance_id = str(uuid4())\n\n        # Handle create_kwargs parameter if passed\n        create_kwargs = kwargs.get(\"create_kwargs\", {})\n        if create_kwargs:\n            kwargs.update(create_kwargs)\n\n        # Get image from kwargs\n        image = kwargs.get(\"image\")\n        if image is None:\n            raise ValueError(\"Missing required 'image' parameter in kwargs\")\n\n        img = fetch_image({\"image\": image})\n        self._instance_dict[instance_id] = {\n            \"image\": img,\n            \"response\": \"\",\n            \"reward\": 0.0,\n        }\n        return instance_id, ToolResponse()\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        bbox_2d = parameters.get(\"bbox_2d\")\n        label = parameters.get(\"label\", \"\")\n\n        if not bbox_2d or len(bbox_2d) != 4:\n            return (\n                ToolResponse(text=\"Error: bbox_2d parameter is missing or not a list of 4 numbers.\"),\n                -0.05,\n                {\"success\": False},\n            )\n\n        # Guard against unknown instances instead of raising a KeyError below.\n        if instance_id not in self._instance_dict:\n            return (\n                ToolResponse(text=f\"Error: unknown tool instance {instance_id}.\"),\n                -0.05,\n                {\"success\": False},\n            )\n\n        instance_data = self._instance_dict[instance_id]\n        image = instance_data[\"image\"]\n        image_width, image_height = image.size\n\n        try:\n            resized_bbox = self._maybe_resize_bbox(bbox_2d, image_width=image_width, image_height=image_height)\n\n            if resized_bbox is None:\n                error_msg = (\n                    f\"Error: The specified bounding box {bbox_2d} is invalid or results in a crop smaller than \"\n                    f\"the minimum size of {self.MIN_DIMENSION}x{self.MIN_DIMENSION}.\"\n                )\n                logger.warning(f\"Tool execution failed: {error_msg}\")\n                return ToolResponse(text=error_msg), -0.05, {\"success\": False}\n\n            cropped_image = image.crop(resized_bbox)\n            logger.info(f\"Cropped image size: {cropped_image.size}\")\n        except Exception as e:\n            logger.error(f\"Error processing image zoom-in: {e}\")\n            return ToolResponse(text=f\"Error processing image zoom-in: {e}\"), -0.05, {\"success\": False}\n\n        response_text = f\"Zoomed in on the image to the region {bbox_2d}.\"\n        if label:\n            response_text = f\"Zoomed in on the image to the region {bbox_2d} with label {label}.\"\n\n        return (\n            ToolResponse(\n                image=[cropped_image],\n                text=response_text,\n            ),\n            0.0,\n            {\"success\": True},\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        if instance_id in self._instance_dict:\n            del self._instance_dict[instance_id]\n
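\n\nif __name__ == \"__main__\":\n    # Minimal usage sketch (illustrative, not part of the original module): drives\n    # the tool end-to-end on a synthetic image. Assumes Ray can start locally and\n    # that the minimal schema below is sufficient for BaseTool; adapt as needed.\n    import asyncio\n\n    from PIL import Image\n\n    async def _demo():\n        schema = OpenAIFunctionToolSchema.model_validate(\n            {\"type\": \"function\", \"function\": {\"name\": \"image_zoom_in_tool\", \"description\": \"Zoom in on an image region.\"}}\n        )\n        tool = ImageZoomInTool(config={}, tool_schema=schema)\n        instance_id, _ = await tool.create(image=Image.new(\"RGB\", (640, 480)))\n        response, _, metrics = await tool.execute(instance_id, {\"bbox_2d\": [100, 100, 300, 300]})\n        print(metrics, response.text)\n        await tool.release(instance_id)\n\n    asyncio.run(_demo())\n"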
  },
  {
    "path": "verl_distillation/verl/tools/mcp_base_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom fastmcp.exceptions import ClientError\n\nfrom verl.tools.utils.mcp_clients.McpClientManager import ClientManager\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MCPBaseTool(BaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        self.timeout = config.get(\"timeout\", 30)\n\n        # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool\n        logger.info(f\"Initialized MCPBaseTool with config: {config}\")\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        \"\"\"Return the OpenAI tool schema.\"\"\"\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n            tool_crtool_creation_response: The response of the tool when creating the instance.\n        \"\"\"\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"reward\": [],\n        }\n        return instance_id, ToolResponse()\n\n    async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]:\n        err_msg = \"\"\n        metadata = {}\n        try:\n            call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)\n            logger.debug(f\"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}\")\n            result, metadata = self._parse_tool_result(call_tool_result.content)\n        except ClientError as e:\n            err_msg = f\"\\n Tool call failed: {e}\"\n        except ConnectionError as e:\n            err_msg = f\"\\n Connection failed: {e}\"\n        except Exception as e:\n            err_msg = f\"\\n An unexpected error occurred: {e}\"\n        finally:\n            if err_msg:\n                result = err_msg\n                metadata[\"api_request_error\"] = err_msg\n            else:\n                metadata[\"api_request_error\"] = None\n        return result, metadata\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        if self.name == \"\" or self.name is None or parameters is None:\n            error_msg = \"Error: 'parameters' is missing 
\n            logger.error(f\"[MCPBaseTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}\")\n            return ToolResponse(text=json.dumps({\"result\": error_msg})), 0.0, {}\n\n        try:\n            result_text, metadata = await self._call_tool(instance_id, parameters)\n\n            # Store results in instance dictionary\n            self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\n\n            # Convert metadata to metrics\n            metrics = {\n                \"query_count\": metadata.get(\"query_count\", 0),\n                \"status\": metadata.get(\"status\", \"unknown\"),\n                \"total_results\": metadata.get(\"total_results\", 0),\n                \"api_request_error\": metadata.get(\"api_request_error\"),\n            }\n\n            return ToolResponse(text=result_text), 0.0, metrics\n\n        except Exception as e:\n            error_result = json.dumps({\"result\": f\"Tool execution failed: {e}\"})\n            logger.error(f\"[MCPBaseTool] Execution failed: {e}\")\n            return ToolResponse(text=error_result), 0.0, {\"error\": str(e)}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> list:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        if instance_id in self._instance_dict:\n            del self._instance_dict[instance_id]\n\n    def _parse_tool_result(self, content: list) -> tuple[str, dict]:\n        tools_content = [part.text for part in filter(lambda x: x.type == \"text\", content)]\n        return \" \".join(tools_content), {}\n
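\n\nclass _ExampleTextTool(MCPBaseTool):\n    # Illustrative subclass (not part of the original module): the intended\n    # extension point is _parse_tool_result, which reshapes raw MCP content\n    # parts into the (text, metadata) pair consumed by execute() above.\n    def _parse_tool_result(self, content: list) -> tuple[str, dict]:\n        texts = [part.text for part in content if part.type == \"text\"]\n        return \"\\n\".join(texts), {\"num_text_parts\": len(texts)}\n"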
  },
  {
    "path": "verl_distillation/verl/tools/mcp_search_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport re\n\nfrom verl.tools.mcp_base_tool import MCPBaseTool\n\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MCPSearchTool(MCPBaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n\n    def _parse_tool_result(self, content: list) -> tuple[str, dict]:\n        res = \"\"\n        res_cnt = 0\n        query_list = []\n        metadata = {\n            \"api_request_error\": \"\",\n            \"status\": \"unknown\",\n            \"total_results\": 0,\n        }\n        try:\n            for part in content:\n                if part.type != \"text\":\n                    continue\n                text = part.text.replace(\"'\", '\"')\n                query_match = re.search(r'query\"\\s*:\\s*\"([^\"]+)\"', text)\n                query = query_match.group(1) if query_match else \"\"\n                query_list.append(query)\n\n                title_matches = re.findall(r'\"title\"\\s*:', text)\n                title_count = len(title_matches)\n\n                results_match = re.search(r'\"results\"\\s*:\\s*(\\[.*?\\])', text, re.DOTALL)\n                results_content = results_match.group(1) if results_match else \"\"\n\n                res += results_content\n                res_cnt += title_count\n        except json.JSONDecodeError:\n            err_msg = \"json parse error.\"\n            logger.error(err_msg)\n            metadata[\"api_request_error\"] = err_msg\n            metadata[\"status\"] = \"error\"\n\n        # update metadata\n        metadata[\"status\"] = \"success\"\n        metadata[\"queries\"] = query_list\n        metadata[\"query_count\"] = len(query_list)\n        metadata[\"total_results\"] = res_cnt\n        return res, metadata\n"
  },
  {
    "path": "verl_distillation/verl/tools/sandbox_fusion_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport threading\nfrom contextlib import ExitStack\nfrom enum import Enum\nfrom typing import Any, Callable, Optional, TypeVar\nfrom uuid import uuid4\n\nimport ray\n\nfrom verl.tools.base_tool import BaseTool\nfrom verl.utils.reward_score.sandbox_fusion.utils import _process_single_case\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nT = TypeVar(\"T\")\n\n\nclass PoolMode(Enum):\n    ThreadMode = 1\n    ProcessMode = 2\n\n\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\nclass TokenBucketWorker:\n    def __init__(self, rate_limit: int):\n        self.rate_limit = rate_limit\n        # this only used for observalability\n        self.current_count = 0\n        self._semaphore = threading.Semaphore(rate_limit)\n\n    @ray.method(concurrency_group=\"acquire\")\n    def acquire(self):\n        self._semaphore.acquire()\n        self.current_count += 1\n\n    @ray.method(concurrency_group=\"release\")\n    def release(self):\n        self._semaphore.release()\n        self.current_count -= 1\n\n    def get_current_count(self):\n        return self.current_count\n\n\nclass ExecutionWorker:\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n    def _init_rate_limit(self, rate_limit):\n        # TODO validation for rate_limit\n        # A Singleton Rate Limitor\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n    def ping(self):\n        return True\n\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n        with ExitStack() as stack:\n            stack.callback(self.rate_limit_worker.release.remote)\n            ray.get(self.rate_limit_worker.acquire.remote())\n            try:\n                return fn(*fn_args, **fn_kwargs)\n            except Exception as e:\n                # TODO we should make this available to the tool caller\n                logger.warning(f\"Error when executing code: {e}\")\n\n\ndef init_execution_pool(\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\n):\n    if mode == PoolMode.ThreadMode:\n        return (\n            ray.remote(ExecutionWorker)\n            .options(max_concurrency=num_workers)\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\n        )\n    else:\n        raise NotImplementedError(\"Process mode is not implemented yet\")\n        # return ray.util.multiprocessing.Pool(processes=num_workers)\n\n\nclass SandboxFusionTool(BaseTool):\n    \"\"\"A tool for executing the code using sanbox fusion 
image.\n\n    - `get_openai_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward with respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"code_interpreter\",\n                \"description\": \"A tool for executing code\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"code\": {\n                            \"type\": \"string\",\n                            \"description\": \"the code to be executed and graded\",\n                        },\n                    },\n                    \"required\": [\"code\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        # TODO: better documentation for the config\n        self.num_workers = config.get(\"num_workers\", 10)\n        self.rate_limit = config.get(\"rate_limit\", 10)\n        self.default_timeout = config.get(\"default_timeout\", 30)\n        self.default_language = config.get(\"default_language\", \"python\")\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\n        self.execution_pool = init_execution_pool(\n            num_workers=self.num_workers,\n            enable_global_rate_limit=self.enable_global_rate_limit,\n            rate_limit=self.rate_limit,\n            mode=PoolMode.ThreadMode,\n        )\n        self.sandbox_fusion_url = config.get(\"sandbox_fusion_url\", \"\")\n        self.memory_limit_mb = config.get(\"memory_limit_mb\", 1024)\n        if self.sandbox_fusion_url == \"\":\n            raise ValueError(\"sandbox_fusion_url is not set\")\n        logger.info(f\"Initialized SandboxFusionTool with config: {config}\")\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> tuple[str, ToolResponse]:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": [],\n        }\n        return instance_id, ToolResponse()\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\n        code = parameters.get(\"code\", \"\")\n        timeout = parameters.get(\"timeout\", self.default_timeout)\n        language = parameters.get(\"language\", self.default_language)\n        if not isinstance(code, str):\n            code = str(code)\n\n        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n        # sandbox has no score or metrics, use Nones\n        return ToolResponse(text=result), None, None\n\n    def execute_code(self, instance_id, code, timeout=30, language=\"python\"):\n        result_status, metadata = _process_single_case(\n            0, None, None, self.sandbox_fusion_url, 
code, timeout, self.memory_limit_mb, language\n        )\n        # the sandbox returns no correctness verdict here, so we only surface stdout/stderr\n        if metadata[\"run_status\"] == \"Finished\":\n            actual_output = metadata[\"stdout\"] + metadata[\"stderr\"]\n            logger.debug(f\"actual_output from sandbox fusion: {actual_output},{instance_id}\")\n            # return the plain string; execute() wraps it in a ToolResponse\n            return actual_output\n        else:\n            return \"no stdout here\"\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> list:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        if instance_id in self._instance_dict:\n            del self._instance_dict[instance_id]\n
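\n\nif __name__ == \"__main__\":\n    # Illustrative sketch (not part of the original module): shows the Ray-based\n    # execution pool pattern used above without a sandbox service. Assumes Ray\n    # can start locally; rate_limit=2 bounds the number of in-flight calls.\n    import time\n\n    def slow_task(i):\n        time.sleep(0.1)\n        return i\n\n    pool = init_execution_pool(num_workers=4, rate_limit=2)\n    print(ray.get([pool.execute.remote(slow_task, i) for i in range(8)]))  # [0, 1, ..., 7]\n"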
  },
  {
    "path": "verl_distillation/verl/tools/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field, model_validator\n\n\nclass OpenAIFunctionPropertySchema(BaseModel):\n    \"\"\"The schema of a parameter in OpenAI format.\"\"\"\n\n    type: str\n    description: str | None = None\n    enum: list[str] | None = None\n\n\nclass OpenAIFunctionParametersSchema(BaseModel):\n    \"\"\"The schema of parameters in OpenAI format.\"\"\"\n\n    type: str\n    properties: dict[str, OpenAIFunctionPropertySchema]\n    required: list[str]\n\n\nclass OpenAIFunctionSchema(BaseModel):\n    \"\"\"The schema of a function in OpenAI format.\"\"\"\n\n    name: str\n    description: str\n    parameters: OpenAIFunctionParametersSchema = Field(\n        default_factory=lambda: OpenAIFunctionParametersSchema(type=\"object\", properties={}, required=[])\n    )\n    strict: bool = False\n\n\nclass OpenAIFunctionToolSchema(BaseModel):\n    \"\"\"The schema of a tool in OpenAI format.\"\"\"\n\n    type: str\n    function: OpenAIFunctionSchema\n\n\nclass OpenAIFunctionParsedSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: str  # JSON string\n\n\nclass OpenAIFunctionCallSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: dict[str, Any]\n\n    @staticmethod\n    def from_openai_function_parsed_schema(\n        parsed_schema: OpenAIFunctionParsedSchema,\n    ) -> tuple[\"OpenAIFunctionCallSchema\", bool]:\n        has_decode_error = False\n        try:\n            arguments = json.loads(parsed_schema.arguments)\n        except json.JSONDecodeError:\n            arguments = {}\n            has_decode_error = True\n        # If the arguments is not a dict, it means the arguments is not a valid JSON string\n        if not isinstance(arguments, dict):\n            arguments = {}\n            has_decode_error = True\n\n        return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error\n\n\nclass OpenAIFunctionToolCall(BaseModel):\n    \"\"\"The tool call in OpenAI format.\"\"\"\n\n    id: str\n    type: Literal[\"function\"] = \"function\"\n    function: OpenAIFunctionCallSchema\n\n\nclass ToolResponse(BaseModel):\n    \"\"\"The response from a tool execution.\"\"\"\n\n    text: str | None = None\n    image: list[Any] | None = None\n    video: list[Any] | None = None\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def initialize_request(cls, values):\n        if \"image\" in values and not isinstance(values[\"image\"], list):\n            raise ValueError(\n                f\"Image must be a list, but got {type(values['image'])}. Please check the tool.execute(). \"\n                f\"For single images, wrap in a list: [image]. 
\"\n                f\"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}.\"\n            )\n        if \"video\" in values and not isinstance(values[\"video\"], list):\n            raise ValueError(\n                f\"Video must be a list, but got {type(values['video'])}. Please check the tool.execute(). \"\n                f\"For single videos, wrap in a list: [video]. \"\n                f\"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}.\"\n            )\n\n        return values\n\n    def is_empty(self) -> bool:\n        return not self.text and not self.image and not self.video\n\n    def is_text_only(self) -> bool:\n        return self.text and not self.image and not self.video\n"
  },
  {
    "path": "verl_distillation/verl/tools/search_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport os\r\nimport threading\r\nfrom contextlib import ExitStack\r\nfrom enum import Enum\r\nfrom typing import Any, Callable, Optional, TypeVar\r\nfrom uuid import uuid4\r\n\r\nimport ray\r\nimport ray.actor\r\n\r\nfrom verl.tools.utils.search_r1_like_utils import perform_single_search_batch\r\nfrom verl.utils.rollout_trace import rollout_trace_op\r\n\r\nfrom .base_tool import BaseTool\r\nfrom .schemas import OpenAIFunctionToolSchema, ToolResponse\r\n\r\nlogger = logging.getLogger(__name__)\r\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\r\n\r\nT = TypeVar(\"T\")\r\n\r\n\r\n# Adapted from verl/tools/sandbox_fusion_tools.py\r\nclass PoolMode(Enum):\r\n    \"\"\"Execution pool mode enumeration.\"\"\"\r\n\r\n    ThreadMode = 1\r\n    ProcessMode = 2\r\n\r\n\r\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\r\nclass TokenBucketWorker:\r\n    \"\"\"Ray actor for rate limiting using token bucket algorithm.\"\"\"\r\n\r\n    def __init__(self, rate_limit: int):\r\n        self.rate_limit = rate_limit\r\n        self.current_count = 0  # For observability\r\n        self._semaphore = threading.Semaphore(rate_limit)\r\n\r\n    @ray.method(concurrency_group=\"acquire\")\r\n    def acquire(self):\r\n        \"\"\"Acquire a token from the bucket.\"\"\"\r\n        self._semaphore.acquire()\r\n        self.current_count += 1\r\n\r\n    @ray.method(concurrency_group=\"release\")\r\n    def release(self):\r\n        \"\"\"Release a token back to the bucket.\"\"\"\r\n        self._semaphore.release()\r\n        self.current_count -= 1\r\n\r\n    def get_current_count(self):\r\n        \"\"\"Get current number of acquired tokens.\"\"\"\r\n        return self.current_count\r\n\r\n\r\nclass SearchExecutionWorker:\r\n    \"\"\"Worker for executing search operations with optional rate limiting.\"\"\"\r\n\r\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\r\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\r\n\r\n    def _init_rate_limit(self, rate_limit):\r\n        \"\"\"Initialize singleton rate limiter.\"\"\"\r\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\r\n\r\n    def ping(self):\r\n        \"\"\"Health check method.\"\"\"\r\n        return True\r\n\r\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\r\n        \"\"\"Execute function with optional rate limiting.\"\"\"\r\n        if self.rate_limit_worker:\r\n            with ExitStack() as stack:\r\n                stack.callback(self.rate_limit_worker.release.remote)\r\n                ray.get(self.rate_limit_worker.acquire.remote())\r\n                try:\r\n                    return fn(*fn_args, **fn_kwargs)\r\n                except 
Exception as e:\r\n                    # TODO we should make this available to the tool caller\r\n                    logger.warning(f\"Error when executing search: {e}\")\r\n        else:\r\n            return fn(*fn_args, **fn_kwargs)\r\n\r\n\r\ndef init_search_execution_pool(\r\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\r\n):\r\n    \"\"\"Initialize search execution pool.\"\"\"\r\n    if mode == PoolMode.ThreadMode:\r\n        return (\r\n            ray.remote(SearchExecutionWorker)\r\n            .options(max_concurrency=num_workers)\r\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\r\n        )\r\n    else:\r\n        raise NotImplementedError(\"Process mode is not implemented yet\")\r\n\r\n\r\nclass SearchTool(BaseTool):\r\n    \"\"\"Search tool for retrieving information using external retrieval services.\r\n\r\n    This tool provides search functionality with rate limiting and concurrent execution\r\n    support through Ray. It integrates with external retrieval services to perform\r\n    semantic search operations.\r\n\r\n    Methods:\r\n        get_openai_tool_schema: Return the tool schema in OpenAI format\r\n        create: Create a tool instance for a trajectory\r\n        execute: Execute the search tool\r\n        calc_reward: Calculate the reward with respect to tool state\r\n        release: Release the tool instance\r\n    \"\"\"\r\n\r\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\r\n        \"\"\"Initialize SearchTool with configuration and schema.\r\n\r\n        Args:\r\n            config: Configuration dictionary containing tool settings\r\n            tool_schema: OpenAI function tool schema definition\r\n\r\n        Example tool_schema:\r\n            {\r\n                \"type\": \"function\",\r\n                \"function\": {\r\n                    \"name\": \"search\",\r\n                    \"description\": \"Searches for relevant information based on queries.\",\r\n                    \"parameters\": {\r\n                        \"type\": \"object\",\r\n                        \"properties\": {\r\n                            \"query_list\": {\r\n                                \"type\": \"array\",\r\n                                \"items\": {\"type\": \"string\"},\r\n                                \"description\": \"List of search queries\"\r\n                            }\r\n                        },\r\n                        \"required\": [\"query_list\"]\r\n                    }\r\n                }\r\n            }\r\n        \"\"\"\r\n        super().__init__(config, tool_schema)\r\n        self._instance_dict = {}\r\n\r\n        # Worker and rate limiting configuration\r\n        self.num_workers = config.get(\"num_workers\", 120)\r\n        self.rate_limit = config.get(\"rate_limit\", 120)\r\n        self.timeout = config.get(\"timeout\", 30)\r\n\r\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\r\n        self.execution_pool = init_search_execution_pool(\r\n            num_workers=self.num_workers,\r\n            enable_global_rate_limit=self.enable_global_rate_limit,\r\n            rate_limit=self.rate_limit,\r\n            mode=PoolMode.ThreadMode,\r\n        )\r\n\r\n        # Retrieval service configuration\r\n        self.retrieval_service_url = config.get(\"retrieval_service_url\")\r\n        assert self.retrieval_service_url, \"Configuration must include 
'retrieval_service_url'\"\r\n        self.topk = config.get(\"topk\", 3)\r\n        if self.retrieval_service_url == \"\":\r\n            raise ValueError(\"retrieval_service_url is not set\")\r\n\r\n        logger.info(f\"Initialized SearchTool with config: {config}\")\r\n\r\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\r\n        \"\"\"Return the OpenAI tool schema.\"\"\"\r\n        return self.tool_schema\r\n\r\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:\r\n        \"\"\"Create a tool instance.\r\n\r\n        Args:\r\n            instance_id: The instance id of the tool.\r\n\r\n        Returns:\r\n            The instance id of the tool.\r\n            tool_creation_response: The response of the tool when creating the instance.\r\n        \"\"\"\r\n        if instance_id is None:\r\n            instance_id = str(uuid4())\r\n        self._instance_dict[instance_id] = {\r\n            \"response\": \"\",\r\n            \"reward\": [],\r\n        }\r\n        return instance_id, ToolResponse()\r\n\r\n    def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int):\r\n        \"\"\"Execute search operation using retrieval service.\r\n\r\n        Args:\r\n            instance_id: Tool instance ID\r\n            query_list: List of search queries\r\n            retrieval_service_url: URL of the retrieval service\r\n            topk: Number of top results to return\r\n            timeout: Request timeout in seconds\r\n\r\n        Returns:\r\n            Tuple of (result_text, metadata)\r\n        \"\"\"\r\n        result_text, metadata = perform_single_search_batch(\r\n            retrieval_service_url=retrieval_service_url,\r\n            query_list=query_list,\r\n            topk=topk,\r\n            concurrent_semaphore=None,  # Ray handles concurrency control\r\n            timeout=timeout,\r\n        )\r\n        logger.debug(f\"Search result for instance {instance_id}: {result_text}\")\r\n        return result_text, metadata\r\n\r\n    @rollout_trace_op\r\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:\r\n        \"\"\"Execute the search tool.\r\n\r\n        Args:\r\n            instance_id: The instance ID of the tool\r\n            parameters: Tool parameters containing query_list and optional timeout\r\n\r\n        Returns: tool_response, tool_reward_score, tool_metrics\r\n            tool_response: The response str of the tool.\r\n            tool_reward_score: The step reward score of the tool.\r\n            tool_metrics: The metrics of the tool.\r\n        \"\"\"\r\n        timeout = self.timeout\r\n        query_list_from_params = parameters.get(\"query_list\")\r\n\r\n        if not query_list_from_params or not isinstance(query_list_from_params, list):\r\n            error_msg = \"Error: 'query_list' is missing, empty, or not a list in parameters.\"\r\n            logger.error(f\"[SearchTool] {error_msg} Received parameters: {parameters}\")\r\n            return ToolResponse(text=json.dumps({\"result\": error_msg})), 0.0, {}\r\n\r\n        # Execute search using Ray execution pool\r\n        try:\r\n            result_text, metadata = await self.execution_pool.execute.remote(\r\n                self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout\r\n            )\r\n\r\n            # Store results in instance dictionary\r\n  
\r\n            self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\r\n\r\n            # Convert metadata to metrics\r\n            metrics = {\r\n                \"query_count\": metadata.get(\"query_count\", 0),\r\n                \"status\": metadata.get(\"status\", \"unknown\"),\r\n                \"total_results\": metadata.get(\"total_results\", 0),\r\n                \"api_request_error\": metadata.get(\"api_request_error\"),\r\n            }\r\n\r\n            return ToolResponse(text=result_text), 0.0, metrics\r\n\r\n        except Exception as e:\r\n            error_result = json.dumps({\"result\": f\"Search execution failed: {e}\"})\r\n            logger.error(f\"[SearchTool] Execution failed: {e}\")\r\n            return ToolResponse(text=error_result), 0.0, {\"error\": str(e)}\r\n\r\n    async def calc_reward(self, instance_id: str, **kwargs) -> list:\r\n        return self._instance_dict[instance_id][\"reward\"]\r\n\r\n    async def release(self, instance_id: str, **kwargs) -> None:\r\n        if instance_id in self._instance_dict:\r\n            del self._instance_dict[instance_id]\r\n
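\r\n\r\nif __name__ == \"__main__\":\r\n    # Illustrative sketch (not part of the original module). The URL below is an\r\n    # assumption: point it at a running Search-R1-style retrieval service.\r\n    import asyncio\r\n\r\n    schema = OpenAIFunctionToolSchema.model_validate(\r\n        {\"type\": \"function\", \"function\": {\"name\": \"search\", \"description\": \"Searches for relevant information.\"}}\r\n    )\r\n    tool = SearchTool(config={\"retrieval_service_url\": \"http://127.0.0.1:8000/retrieve\"}, tool_schema=schema)\r\n\r\n    async def _demo():\r\n        instance_id, _ = await tool.create()\r\n        response, _, metrics = await tool.execute(instance_id, {\"query_list\": [\"what is ray\"]})\r\n        print(metrics.get(\"status\"), str(response.text)[:200])\r\n        await tool.release(instance_id)\r\n\r\n    asyncio.run(_demo())\r\n"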
  },
  {
    "path": "verl_distillation/verl/tools/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/tools/utils/mcp_clients/McpClientManager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport asyncio\r\nimport json\r\nimport logging\r\nfrom typing import Any\r\n\r\nfrom fastmcp import Client\r\nfrom fastmcp.client.transports import SSETransport\r\n\r\nfrom verl.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\nclass MCPClientManager:\r\n    rootServerName = \"mcpServers\"\r\n    initialized = False\r\n    clients = []\r\n    tool_client_mapping = {}\r\n    rate_limiter = None\r\n\r\n    async def initialize(self, config_path, rate_limit: float = 10.0):\r\n        if self.initialized:\r\n            return\r\n        \"\"\"Initialize the MCP Client Manager and start all clients\"\"\"\r\n        result = self._load_config(config_path)\r\n        servers = result[self.rootServerName]\r\n        exclude_sse_servers = {self.rootServerName: {}}\r\n        for server_name in servers.keys():\r\n            server = servers[server_name]\r\n            if \"auth_token\" in server:\r\n                transport = SSETransport(url=server[\"url\"], headers={\"Authorization\": f\"Bearer {server['auth_token']}\"})\r\n                client = Client(transport)\r\n                self.clients.append(client)\r\n            else:\r\n                exclude_sse_servers[self.rootServerName][server_name] = server\r\n\r\n        if exclude_sse_servers[self.rootServerName]:\r\n            self.clients.append(Client(exclude_sse_servers))\r\n\r\n        # Initialize rate limiter\r\n        self.rate_limiter = TokenBucket(rate_limit)\r\n        self.initialized = True\r\n\r\n    async def call_tool(self, tool_name, parameters, timeout):\r\n        # Apply rate limiting\r\n        while not self.rate_limiter.acquire():\r\n            await asyncio.sleep(0.1)\r\n\r\n        client = self.get_client_with_tool_name(tool_name)\r\n        async with client:\r\n            return await client.call_tool_mcp(tool_name, parameters)\r\n\r\n    async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]:\r\n        tool_schemas = []\r\n        for client in self.clients:\r\n            async with client:\r\n                tools = await client.list_tools_mcp()\r\n                for tool in tools.tools:\r\n                    if not tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n                    elif tool.name in tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n\r\n        return tool_schemas\r\n\r\n    def get_client_with_tool_name(self, tool_name: str):\r\n        return self.tool_client_mapping[tool_name]\r\n\r\n    def _load_config(self, file: str) -> dict[str, Any]:\r\n        try:\r\n            with open(file) as f:\r\n                return 
\r\n        except FileNotFoundError:\r\n            logger.warning(f'the \"{file}\" file was not found')\r\n        except Exception as e:\r\n            logger.error(f'there was an error reading the \"{file}\" file: {e}')\r\n\r\n        return {}\r\n\r\n\r\nClientManager = MCPClientManager()\r\n
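\r\n# Illustrative config sketch (not part of the original module): initialize()\r\n# expects a JSON file with a top-level \"mcpServers\" object. Servers carrying an\r\n# \"auth_token\" are connected over SSE with a Bearer header; the rest are passed\r\n# to fastmcp.Client as one grouped config. Any field other than \"url\" and\r\n# \"auth_token\" below is an assumption about your deployment.\r\n#\r\n# {\r\n#   \"mcpServers\": {\r\n#     \"search\": {\"url\": \"http://localhost:8080/sse\", \"auth_token\": \"<token>\"},\r\n#     \"other\": {\"url\": \"http://localhost:9090/mcp\"}\r\n#   }\r\n# }\r\n"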
  },
  {
    "path": "verl_distillation/verl/tools/utils/mcp_clients/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport threading\nimport time\n\nfrom mcp import Tool\n\nlogger = logging.getLogger(__file__)\n\n\nclass TokenBucket:\n    def __init__(self, rate_limit: float):\n        self.rate_limit = rate_limit  # tokens per second\n        self.tokens = rate_limit\n        self.last_update = time.time()\n        self.lock = threading.Lock()\n\n    def acquire(self) -> bool:\n        with self.lock:\n            now = time.time()\n            # Add new tokens based on time elapsed\n            new_tokens = (now - self.last_update) * self.rate_limit\n            self.tokens = min(self.rate_limit, self.tokens + new_tokens)\n            self.last_update = now\n\n            if self.tokens >= 1:\n                self.tokens -= 1\n                return True\n            return False\n\n\ndef mcp2openai(mcp_tool: Tool) -> dict:\n    \"\"\"Convert a MCP Tool to an OpenAI ChatCompletionTool.\"\"\"\n    openai_format = {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": mcp_tool.name,\n            \"description\": mcp_tool.description,\n            \"parameters\": mcp_tool.inputSchema,\n            \"strict\": False,\n        },\n    }\n    if not openai_format[\"function\"][\"parameters\"].get(\"required\", None):\n        openai_format[\"function\"][\"parameters\"][\"required\"] = []\n    return openai_format\n"
  },
  {
    "path": "verl_distillation/verl/tools/utils/search_r1_like_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport threading\r\nimport time\r\nimport traceback\r\nimport uuid\r\nfrom typing import Any, Optional\r\n\r\nimport requests\r\n\r\nDEFAULT_TIMEOUT = 30  # Default search request timeout\r\nMAX_RETRIES = 10\r\nINITIAL_RETRY_DELAY = 1\r\nAPI_TIMEOUT = 10\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\ndef call_search_api(\r\n    retrieval_service_url: str,\r\n    query_list: list[str],\r\n    topk: int = 3,\r\n    return_scores: bool = True,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> tuple[Optional[dict[str, Any]], Optional[str]]:\r\n    \"\"\"\r\n    Calls the remote search API to perform retrieval with retry logic for various errors,\r\n    using increasing delay between retries. Logs internal calls with a unique ID.\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        return_scores: Whether to return scores.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (response_json, error_message).\r\n        If successful, response_json is the API's returned JSON object, error_message is None.\r\n        If failed after retries, response_json is None, error_message contains the error information.\r\n    \"\"\"\r\n    request_id = str(uuid.uuid4())\r\n    log_prefix = f\"[Search Request ID: {request_id}] \"\r\n\r\n    payload = {\"queries\": query_list, \"topk\": topk, \"return_scores\": return_scores}\r\n\r\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\r\n\r\n    last_error = None\r\n\r\n    for attempt in range(MAX_RETRIES):\r\n        try:\r\n            logger.info(\r\n                f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}\"\r\n            )\r\n            response = requests.post(\r\n                retrieval_service_url,\r\n                headers=headers,\r\n                json=payload,\r\n                timeout=timeout,\r\n            )\r\n\r\n            # Check for Gateway Timeout (504) and other server errors for retrying\r\n            if response.status_code in [500, 502, 503, 504]:\r\n                last_error = (\r\n                    f\"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt \"\r\n                    f\"{attempt + 1}/{MAX_RETRIES}\"\r\n                )\r\n                logger.warning(last_error)\r\n                if attempt < MAX_RETRIES - 1:\r\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                    time.sleep(delay)\r\n                continue\r\n\r\n            # Check for other HTTP errors (e.g., 4xx)\r\n 
           response.raise_for_status()\r\n\r\n            # If successful (status code 2xx)\r\n            logger.info(f\"{log_prefix}Search API call successful on attempt {attempt + 1}\")\r\n            return response.json(), None\r\n\r\n        except requests.exceptions.ConnectionError as e:\r\n            last_error = f\"{log_prefix}Connection Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.Timeout as e:\r\n            last_error = f\"{log_prefix}Timeout Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.RequestException as e:\r\n            last_error = f\"{log_prefix}API Request Error: {e}\"\r\n            break  # Exit retry loop on other request errors\r\n        except json.JSONDecodeError as e:\r\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\r\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}\"\r\n            break  # Exit retry loop on JSON decode errors\r\n        except Exception as e:\r\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"\r\n            break  # Exit retry loop on other unexpected errors\r\n\r\n    # If loop finishes without returning success, return the last recorded error\r\n    logger.error(f\"{log_prefix}Search API call failed. 
Last error: {last_error}\")\r\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\r\n\r\n\r\ndef _passages2string(retrieval_result):\r\n    \"\"\"Convert retrieval results to formatted string.\"\"\"\r\n    format_reference = \"\"\r\n    for idx, doc_item in enumerate(retrieval_result):\r\n        content = doc_item[\"document\"][\"contents\"]\r\n        title = content.split(\"\\n\")[0]\r\n        text = \"\\n\".join(content.split(\"\\n\")[1:])\r\n        format_reference += f\"Doc {idx + 1} (Title: {title})\\n{text}\\n\\n\"\r\n    return format_reference.strip()\r\n\r\n\r\ndef perform_single_search_batch(\r\n    retrieval_service_url: str,\r\n    query_list: list[str],\r\n    topk: int = 3,\r\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> tuple[str, dict[str, Any]]:\r\n    \"\"\"\r\n    Performs a single batch search for multiple queries (original search tool behavior).\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        concurrent_semaphore: Optional semaphore for concurrency control.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (result_text, metadata).\r\n        result_text: The search result JSON string.\r\n        metadata: Metadata dictionary for the batch search.\r\n    \"\"\"\r\n    logger.info(f\"Starting batch search for {len(query_list)} queries.\")\r\n\r\n    api_response = None\r\n    error_msg = None\r\n\r\n    try:\r\n        if concurrent_semaphore:\r\n            with concurrent_semaphore:\r\n                api_response, error_msg = call_search_api(\r\n                    retrieval_service_url=retrieval_service_url,\r\n                    query_list=query_list,\r\n                    topk=topk,\r\n                    return_scores=True,\r\n                    timeout=timeout,\r\n                )\r\n        else:\r\n            api_response, error_msg = call_search_api(\r\n                retrieval_service_url=retrieval_service_url,\r\n                query_list=query_list,\r\n                topk=topk,\r\n                return_scores=True,\r\n                timeout=timeout,\r\n            )\r\n    except Exception as e:\r\n        error_msg = f\"API Request Exception during batch search: {e}\"\r\n        logger.error(f\"Batch search: {error_msg}\")\r\n        traceback.print_exc()\r\n\r\n    metadata = {\r\n        \"query_count\": len(query_list),\r\n        \"queries\": query_list,\r\n        \"api_request_error\": error_msg,\r\n        \"api_response\": None,\r\n        \"status\": \"unknown\",\r\n        \"total_results\": 0,\r\n        \"formatted_result\": None,\r\n    }\r\n\r\n    result_text = json.dumps({\"result\": \"Search request failed or timed out after retries.\"}, ensure_ascii=False)\r\n\r\n    if error_msg:\r\n        metadata[\"status\"] = \"api_error\"\r\n        result_text = json.dumps({\"result\": f\"Search error: {error_msg}\"}, ensure_ascii=False)\r\n        logger.error(f\"Batch search: API error occurred: {error_msg}\")\r\n    elif api_response:\r\n        logger.debug(f\"Batch search: API Response: {api_response}\")\r\n        metadata[\"api_response\"] = api_response\r\n\r\n        try:\r\n            raw_results = api_response.get(\"result\", [])\r\n            if raw_results:\r\n                pretty_results = []\r\n   
             total_results = 0\r\n\r\n                for retrieval in raw_results:\r\n                    formatted = _passages2string(retrieval)\r\n                    pretty_results.append(formatted)\r\n                    total_results += len(retrieval) if isinstance(retrieval, list) else 1\r\n\r\n                final_result = \"\\n---\\n\".join(pretty_results)\r\n                result_text = json.dumps({\"result\": final_result}, ensure_ascii=False)\r\n                metadata[\"status\"] = \"success\"\r\n                metadata[\"total_results\"] = total_results\r\n                metadata[\"formatted_result\"] = final_result\r\n                logger.info(f\"Batch search: Successful, got {total_results} total results\")\r\n            else:\r\n                result_text = json.dumps({\"result\": \"No search results found.\"}, ensure_ascii=False)\r\n                metadata[\"status\"] = \"no_results\"\r\n                metadata[\"total_results\"] = 0\r\n                logger.info(\"Batch search: No results found\")\r\n        except Exception as e:\r\n            error_msg = f\"Error processing search results: {e}\"\r\n            result_text = json.dumps({\"result\": error_msg}, ensure_ascii=False)\r\n            metadata[\"status\"] = \"processing_error\"\r\n            logger.error(f\"Batch search: {error_msg}\")\r\n    else:\r\n        metadata[\"status\"] = \"unknown_api_state\"\r\n        result_text = json.dumps(\r\n            {\"result\": \"Unknown API state (no response and no error message).\"}, ensure_ascii=False\r\n        )\r\n        logger.error(\"Batch search: Unknown API state.\")\r\n\r\n    return result_text, metadata\r\n"
  },
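  {
    "path": "examples/search_batch_usage_sketch.py",
    "content": "# Hypothetical usage sketch (not part of verl): shows how the batch search\n# helper above is typically invoked. The module import path and the retrieval\n# service URL are assumptions; adjust both to your deployment.\nimport threading\n\n# Assumed import path for the search utility module above.\nfrom verl.tools.utils.search_r1_like_utils import perform_single_search_batch\n\nif __name__ == \"__main__\":\n    # Cap concurrent requests against the retrieval service at 4.\n    semaphore = threading.Semaphore(4)\n    result_text, metadata = perform_single_search_batch(\n        retrieval_service_url=\"http://127.0.0.1:8000/retrieve\",  # hypothetical endpoint\n        query_list=[\"what is generative recommendation?\"],\n        topk=3,\n        concurrent_semaphore=semaphore,\n        timeout=30,\n    )\n    # result_text is a JSON string like {\"result\": \"Doc 1 (Title: ...) ...\"}.\n    # metadata[\"status\"] is one of: success, no_results, api_error,\n    # processing_error, unknown_api_state.\n    print(metadata[\"status\"])\n    print(result_text)\n"
  },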
  {
    "path": "verl_distillation/verl/tools/utils/tool_registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport importlib\nimport logging\nimport os\nimport sys\nimport threading\nfrom enum import Enum\n\nfrom omegaconf import OmegaConf\n\nfrom verl.tools.schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass ToolType(Enum):\n    NATIVE = \"native\"\n    MCP = \"mcp\"\n\n\nasync def initialize_mcp_tool(tool_cls, tool_config) -> list:\n    from verl.tools.utils.mcp_clients.McpClientManager import ClientManager\n\n    tool_list = []\n    mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path\n    tool_selected_list = tool_config.mcp.tool_selected_list if \"tool_selected_list\" in tool_config.mcp else None\n    await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit)\n    # Wait for MCP client to be ready\n    max_retries = 10\n    retry_interval = 2  # seconds\n    for i in range(max_retries):\n        tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list)\n        if tool_schemas:\n            break\n        if i < max_retries - 1:\n            logger.debug(f\"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}\")\n            await asyncio.sleep(retry_interval)\n    else:\n        raise RuntimeError(\"Failed to initialize MCP tools after maximum retries\")\n    # mcp registry\n    assert len(tool_schemas), \"mcp tool is empty\"\n    for tool_schema_dict in tool_schemas:\n        logger.debug(f\"tool_schema_dict: {tool_schema_dict}\")\n        tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n        tool = tool_cls(\n            config=OmegaConf.to_container(tool_config.config, resolve=True),\n            tool_schema=tool_schema,\n        )\n        tool_list.append(tool)\n    return tool_list\n\n\ndef get_tool_class(cls_name):\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    tool_cls = getattr(module, class_name)\n    return tool_cls\n\n\ndef initialize_tools_from_config(tools_config_file):\n    tools_config = OmegaConf.load(tools_config_file)\n    tool_list = []\n\n    # Use a temporary event loop in a new thread because event\n    # loop may already exist in new async architecture while retaining\n    # backwards compatibility\n    tmp_event_loop = asyncio.new_event_loop()\n    thread = threading.Thread(target=tmp_event_loop.run_forever, name=\"mcp tool list fetcher\", daemon=True)\n\n    def run_coroutine(coroutine):\n        if not thread.is_alive():\n            thread.start()\n\n        future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop)\n    
    return future.result()\n\n    async def stop_loop():\n        tmp_event_loop.stop()\n\n    try:\n        for tool_config in tools_config.tools:\n            cls_name = tool_config.class_name\n            tool_type = ToolType(tool_config.config.type)\n            tool_cls = get_tool_class(cls_name)\n\n            match tool_type:\n                case ToolType.NATIVE:\n                    if tool_config.get(\"tool_schema\", None) is None:\n                        tool_schema = None\n                    else:\n                        tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)\n                        tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n                    tool = tool_cls(\n                        config=OmegaConf.to_container(tool_config.config, resolve=True),\n                        tool_schema=tool_schema,\n                    )\n                    tool_list.append(tool)\n                case ToolType.MCP:\n                    mcp_tools = run_coroutine(initialize_mcp_tool(tool_cls, tool_config))\n                    tool_list.extend(mcp_tools)\n                case _:\n                    raise NotImplementedError\n    finally:\n        if thread.is_alive():\n            asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop)\n            thread.join()\n\n    return tool_list\n"
  },
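  {
    "path": "examples/tool_registry_config_sketch.py",
    "content": "# Hypothetical sketch (not part of verl): illustrates the config shape\n# consumed by initialize_tools_from_config above. The tool class name and\n# schema below are illustrative assumptions, not shipped verl examples.\nimport tempfile\n\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\n\n# One native tool entry; an MCP entry would instead set config.type: mcp and\n# provide config.mcp.mcp_servers_config_path.\nTOOLS_YAML = \"\"\"\ntools:\n  - class_name: my_pkg.my_tools.EchoTool  # hypothetical tool class\n    config:\n      type: native\n    tool_schema:\n      type: function\n      function:\n        name: echo\n        description: Echo the input text back to the caller.\n        parameters:\n          type: object\n          properties:\n            text:\n              type: string\n          required: [text]\n\"\"\"\n\nif __name__ == \"__main__\":\n    with tempfile.NamedTemporaryFile(\"w\", suffix=\".yaml\", delete=False) as f:\n        f.write(TOOLS_YAML)\n    # Each entry is resolved via get_tool_class() and instantiated with its\n    # config dict plus the validated OpenAIFunctionToolSchema.\n    tools = initialize_tools_from_config(f.name)\n    print([type(t).__name__ for t in tools])\n"
  },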
  {
    "path": "verl_distillation/verl/trainer/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom . import algorithm, config\nfrom .algorithm import *  # noqa: F401\nfrom .config import *  # noqa: F401\n\n__all__ = config.__all__ + algorithm.__all__\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/_generated_ppo_megatron_trainer.yaml",
    "content": "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\n# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file.\n# Do not modify this file directly.\n# The file is usually only for reference and never used.\n\nactor_rollout_ref:\n  actor:\n    optim:\n      _target_: verl.workers.config.McoreOptimizerConfig\n      lr: 1.0e-06\n      lr_warmup_steps_ratio: 0.0\n      total_training_steps: -1\n      weight_decay: 0.01\n      lr_warmup_steps: -1\n      betas:\n      - 0.9\n      - 0.999\n      clip_grad: 1.0\n      optimizer: adam\n      lr_warmup_init: 0.0\n      lr_decay_steps: null\n      lr_decay_style: constant\n      min_lr: 0.0\n      weight_decay_incr_style: constant\n      lr_wsd_decay_style: exponential\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: false\n      override_optimizer_config: {}\n    megatron:\n      _target_: verl.workers.config.McoreEngineConfig\n      param_offload: false\n      grad_offload: false\n      optimizer_offload: false\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: 1\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null\n      context_parallel_size: 1\n      sequence_parallel: true\n      use_distributed_optimizer: true\n      use_dist_checkpointing: false\n      dist_checkpointing_path: null\n      seed: 42\n      override_ddp_config: {}\n      override_transformer_config:\n        recompute_granularity: null\n        recompute_modules:\n        - core_attn\n        recompute_method: null\n        recompute_num_layers: null\n        attention_backend: flash\n      override_mcore_model_config: {}\n      use_mbridge: false\n      forward_only: false\n    _target_: verl.workers.config.McoreActorConfig\n    strategy: megatron\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: false\n    ppo_max_token_len_per_gpu: 16384\n    clip_ratio: 0.2\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    freeze_vision_tower: false\n    policy_loss:\n      _target_: verl.workers.config.PolicyLossConfig\n      loss_mode: vanilla\n      clip_cov_ratio: 0.0002\n      clip_cov_lb: 1.0\n      clip_cov_ub: 5.0\n      kl_cov_ratio: 0.0002\n      ppo_kl_coef: 0.1\n    clip_ratio_c: 3.0\n    loss_agg_mode: token-mean\n    entropy_coeff: 0\n    use_kl_loss: false\n    use_torch_compile: true\n    kl_loss_coef: 0.001\n    kl_loss_type: low_var_kl\n    ppo_epochs: 1\n    shuffle: false\n    checkpoint:\n      _target_: verl.trainer.config.CheckpointConfig\n      save_contents:\n      - model\n      - optimizer\n      - extra\n      load_contents: ${.save_contents}\n      async_save: false\n    use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: false\n      all_ranks: false\n      ranks: []\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config:\n        nsys:\n          _target_: verl.utils.profiler.config.NsightToolConfig\n          discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n        npu:\n          _target_: verl.utils.profiler.config.NPUToolConfig\n          
contents: []\n          level: level1\n          analysis: true\n          discrete: false\n        torch:\n          _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n          step_start: 0\n          step_end: null\n        torch_memory:\n          _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n          trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n          stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n    data_loader_seed: null\n    load_weight: true\n  ref:\n    strategy: megatron\n    use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: false\n      all_ranks: false\n      ranks: []\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config:\n        nsys:\n          _target_: verl.utils.profiler.config.NsightToolConfig\n          discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n        npu:\n          _target_: verl.utils.profiler.config.NPUToolConfig\n          contents: []\n          level: level1\n          analysis: true\n          discrete: false\n        torch:\n          _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n          step_start: 0\n          step_end: null\n        torch_memory:\n          _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n          trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n          stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n    megatron:\n      _target_: verl.workers.config.MegatronEngineConfig\n      param_offload: false\n      grad_offload: false\n      optimizer_offload: false\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: 1\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null\n      context_parallel_size: 1\n      sequence_parallel: true\n      use_distributed_optimizer: true\n      use_dist_checkpointing: false\n      dist_checkpointing_path: null\n      seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n      override_ddp_config: {}\n      override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n      override_mcore_model_config: {}\n      use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n      forward_only: false\n    load_weight: true\n  rollout:\n    _target_: verl.workers.config.RolloutConfig\n    name: ???\n    mode: sync\n    temperature: 1.0\n    top_k: -1\n    top_p: 1\n    prompt_length: ${oc.select:data.max_prompt_length,512}\n    response_length: ${oc.select:data.max_response_length,512}\n    dtype: bfloat16\n    gpu_memory_utilization: 0.5\n    ignore_eos: false\n    enforce_eager: false\n    cudagraph_capture_sizes: null\n    free_cache_engine: true\n    tensor_model_parallel_size: 2\n    data_parallel_size: 1\n    expert_parallel_size: 1\n    pipeline_model_parallel_size: 1\n    
max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    enable_chunked_prefill: true\n    enable_prefix_caching: true\n    load_format: dummy\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    disable_log_stats: true\n    do_sample: true\n    'n': 1\n    over_sample_rate: 0\n    multi_stage_wake_up: false\n    engine_kwargs:\n      vllm: {}\n      sglang: {}\n    val_kwargs:\n      _target_: verl.workers.config.SamplingConfig\n      top_k: -1\n      top_p: 1.0\n      temperature: 0\n      'n': 1\n      do_sample: false\n    multi_turn:\n      _target_: verl.workers.config.MultiTurnConfig\n      enable: false\n      max_assistant_turns: null\n      tool_config_path: null\n      max_user_turns: null\n      max_parallel_calls: 1\n      max_tool_response_length: 256\n      tool_response_truncate_side: middle\n      interaction_config_path: null\n      use_inference_chat_template: false\n      tokenization_sanity_check_mode: strict\n      format: hermes\n      num_repeat_rollouts: null\n    calculate_log_probs: false\n    agent:\n      _target_: verl.workers.config.AgentLoopConfig\n      num_workers: 8\n      default_agent_loop: single_turn_agent\n      agent_loop_config_path: null\n      custom_async_server:\n        _target_: verl.workers.config.CustomAsyncServerConfig\n        path: null\n        name: null\n    update_weights_bucket_megabytes: 512\n    trace:\n      _target_: verl.workers.config.TraceConfig\n      backend: null\n      token2text: false\n    skip_rollout: false\n    skip_dump_dir: /tmp/rollout_dump\n    skip_tokenizer_init: true\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}\n      all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}\n      ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n  hybrid_engine: true\n  nccl_timeout: 600\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    custom_chat_template: null\n    external_lib: null\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: false\n    use_fused_kernels: false\n    trust_remote_code: false\n    use_remove_padding: false\ndata:\n  tokenizer: null\n  use_shm: false\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  train_max_samples: -1\n  val_max_samples: -1\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null\n  tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,\n    null}\n  return_raw_input_ids: false\n  return_raw_chat: false\n  return_full_prompt: false\n  shuffle: true\n  seed: null\n  dataloader_num_workers: 8\n  image_patch_size: 14\n  validation_shuffle: false\n  filter_overlong_prompts: false\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  image_key: images\n  video_key: videos\n  trust_remote_code: false\n  
custom_cls:\n    path: null\n    name: null\n  return_multi_modal_inputs: true\n  sampler:\n    class_path: null\n    class_name: null\n  datagen:\n    path: null\n    name: null\n  apply_chat_template_kwargs: {}\ncritic:\n  optim:\n    _target_: verl.workers.config.McoreOptimizerConfig\n    lr: 1.0e-05\n    lr_warmup_steps_ratio: 0.0\n    total_training_steps: -1\n    weight_decay: 0.01\n    lr_warmup_steps: -1\n    betas:\n    - 0.9\n    - 0.999\n    clip_grad: 1.0\n    optimizer: adam\n    lr_warmup_init: 0.0\n    lr_decay_steps: null\n    lr_decay_style: constant\n    min_lr: 0.0\n    weight_decay_incr_style: constant\n    lr_wsd_decay_style: exponential\n    lr_wsd_decay_steps: null\n    use_checkpoint_opt_param_scheduler: false\n    override_optimizer_config: {}\n  megatron:\n    _target_: verl.workers.config.McoreEngineConfig\n    param_offload: false\n    grad_offload: false\n    optimizer_offload: false\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: 1\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null\n    context_parallel_size: 1\n    sequence_parallel: true\n    use_distributed_optimizer: true\n    use_dist_checkpointing: false\n    dist_checkpointing_path: null\n    seed: 42\n    override_ddp_config: {}\n    override_transformer_config:\n      recompute_granularity: null\n      recompute_modules:\n      - core_attn\n      recompute_method: null\n      recompute_num_layers: null\n      attention_backend: flash\n    override_mcore_model_config: {}\n    use_mbridge: false\n    forward_only: false\n  _target_: verl.workers.config.McoreCriticConfig\n  rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n  strategy: megatron\n  enable: null\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: false\n    external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n    trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n    _target_: verl.trainer.config.BaseModelConfig\n  ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n  ppo_micro_batch_size: null\n  ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n  use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n  ppo_max_token_len_per_gpu: 32768\n  forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n  shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n  cliprange_value: 0.5\n  loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n  checkpoint:\n    _target_: verl.trainer.config.CheckpointConfig\n    save_contents:\n    - model\n    - optimizer\n    - extra\n    load_contents: ${.save_contents}\n    async_save: false\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    tool: ${oc.select:global_profiler.tool,null}\n    enable: false\n    all_ranks: false\n    ranks: []\n    save_path: ${oc.select:global_profiler.save_path,null}\n    tool_config:\n      nsys:\n        _target_: verl.utils.profiler.config.NsightToolConfig\n        discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n      npu:\n        _target_: verl.utils.profiler.config.NPUToolConfig\n        contents: []\n        level: level1\n        analysis: 
true\n        discrete: false\n      torch:\n        _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n        step_start: 0\n        step_end: null\n      torch_memory:\n        _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n        stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n  nccl_timeout: 600\n  load_weight: true\n  data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}\nreward_model:\n  enable: false\n  enable_resource_pool: false\n  n_gpus_per_node: 0\n  nnodes: 0\n  strategy: megatron\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: false\n  micro_batch_size: null\n  micro_batch_size_per_gpu: null\n  max_length: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n  launch_reward_fn_async: false\n  sandbox_fusion:\n    url: null\n    max_concurrent: 64\n    memory_limit_mb: 1024\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    tool: ${oc.select:global_profiler.tool,null}\n    enable: false\n    all_ranks: false\n    ranks: []\n    save_path: ${oc.select:global_profiler.save_path,null}\n    tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}\n  nccl_timeout: 600\n  megatron:\n    _target_: verl.workers.config.MegatronEngineConfig\n    param_offload: false\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: 1\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null\n    context_parallel_size: 1\n    sequence_parallel: true\n    use_distributed_optimizer: false\n    use_dist_checkpointing: false\n    dist_checkpointing_path: null\n    seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n    override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n    use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n  load_weight: true\ncustom_reward_function:\n  path: null\n  name: compute_score\nalgorithm:\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: true\n  use_kl_in_reward: false\n  kl_penalty: kl\n  kl_ctrl:\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: false\n  pf_ppo:\n    reweight_method: pow\n    weight_pow: 2.0\n  rollout_is_threshold: null\n  rollout_is_threshold_lower: null\n  rollout_is_level: token\n  rollout_is_mode: truncate\n  rollout_is_veto_threshold: null\n  rollout_is: false\ntrainer:\n  balance_batch: true\n  total_epochs: 30\n  total_training_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger:\n  - console\n  - wandb\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n  resume_mode: auto\n  resume_from_path: null\n  del_local_ckpt_after_load: false\n  val_before_train: true\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: 
null\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  rollout_data_dir: null\nglobal_profiler:\n  _target_: verl.utils.profiler.ProfilerConfig\n  tool: null\n  steps: null\n  profile_continuous_steps: false\n  save_path: outputs/profile\n  global_tool_config:\n    nsys:\n      discrete: false\n      controller_nsight_options:\n        trace: cuda,nvtx,cublas,ucx\n        cuda-memory-usage: 'true'\n        cuda-graph-trace: graph\n      worker_nsight_options:\n        trace: cuda,nvtx,cublas,ucx\n        cuda-memory-usage: 'true'\n        cuda-graph-trace: graph\n        capture-range: cudaProfilerApi\n        capture-range-end: null\n        kill: none\n    torch_memory:\n      trace_alloc_max_entries: 100000\n      stack_depth: 32\n      context: all\n      stacks: all\n      kw_args: {}\ntransfer_queue:\n  enable: false\nray_kwargs:\n  ray_init:\n    num_cpus: null\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/_generated_ppo_trainer.yaml",
    "content": "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\n# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.\n# Do not modify this file directly.\n# The file is usually only for reference and never used.\n\nactor_rollout_ref:\n  actor:\n    optim:\n      _target_: verl.workers.config.FSDPOptimizerConfig\n      optimizer: AdamW\n      optimizer_impl: torch.optim\n      lr: 1.0e-06\n      lr_warmup_steps_ratio: 0.0\n      total_training_steps: -1\n      weight_decay: 0.01\n      lr_warmup_steps: -1\n      betas:\n      - 0.9\n      - 0.999\n      clip_grad: 1.0\n      min_lr_ratio: 0.0\n      num_cycles: 0.5\n      lr_scheduler_type: constant\n      warmup_style: null\n      override_optimizer_config: null\n    fsdp_config:\n      _target_: verl.workers.config.FSDPEngineConfig\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      optimizer_offload: false\n      offload_policy: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n      model_dtype: fp32\n      use_orig_params: false\n      ulysses_sequence_parallel_size: 1\n      entropy_from_logits_with_chunking: false\n      use_torch_compile: true\n      entropy_checkpointing: false\n      forward_only: false\n      strategy: fsdp\n    _target_: verl.workers.config.FSDPActorConfig\n    strategy: fsdp\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: false\n    ppo_max_token_len_per_gpu: 16384\n    clip_ratio: 0.2\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    freeze_vision_tower: false\n    policy_loss:\n      _target_: verl.workers.config.PolicyLossConfig\n      loss_mode: vanilla\n      clip_cov_ratio: 0.0002\n      clip_cov_lb: 1.0\n      clip_cov_ub: 5.0\n      kl_cov_ratio: 0.0002\n      ppo_kl_coef: 0.1\n    clip_ratio_c: 3.0\n    loss_agg_mode: token-mean\n    entropy_coeff: 0\n    use_kl_loss: false\n    use_torch_compile: true\n    kl_loss_coef: 0.001\n    kl_loss_type: low_var_kl\n    ppo_epochs: 1\n    shuffle: false\n    checkpoint:\n      _target_: verl.trainer.config.CheckpointConfig\n      save_contents:\n      - model\n      - optimizer\n      - extra\n      load_contents: ${.save_contents}\n      async_save: false\n    use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: false\n      all_ranks: false\n      ranks: []\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config:\n        nsys:\n          _target_: verl.utils.profiler.config.NsightToolConfig\n          discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n        npu:\n          _target_: verl.utils.profiler.config.NPUToolConfig\n          contents: []\n          level: level1\n          analysis: true\n          discrete: false\n        torch:\n          _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n          step_start: 0\n          step_end: null\n        torch_memory:\n          _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n          trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n          stack_depth: 
${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n    grad_clip: 1.0\n    ulysses_sequence_parallel_size: 1\n    entropy_from_logits_with_chunking: false\n    entropy_checkpointing: false\n    use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}\n  ref:\n    strategy: ${actor_rollout_ref.actor.strategy}\n    use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: false\n      all_ranks: false\n      ranks: []\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config:\n        nsys:\n          _target_: verl.utils.profiler.config.NsightToolConfig\n          discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n        npu:\n          _target_: verl.utils.profiler.config.NPUToolConfig\n          contents: []\n          level: level1\n          analysis: true\n          discrete: false\n        torch:\n          _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n          step_start: 0\n          step_end: null\n        torch_memory:\n          _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n          trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n          stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n    fsdp_config:\n      _target_: verl.workers.config.FSDPEngineConfig\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      optimizer_offload: false\n      offload_policy: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n      model_dtype: fp32\n      use_orig_params: false\n      ulysses_sequence_parallel_size: 1\n      entropy_from_logits_with_chunking: false\n      use_torch_compile: true\n      entropy_checkpointing: false\n      forward_only: false\n      strategy: fsdp\n    model: null\n    ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}\n    entropy_from_logits_with_chunking: false\n    entropy_checkpointing: false\n  rollout:\n    _target_: verl.workers.config.RolloutConfig\n    name: ???\n    mode: sync\n    temperature: 1.0\n    top_k: -1\n    top_p: 1\n    prompt_length: ${oc.select:data.max_prompt_length,512}\n    response_length: ${oc.select:data.max_response_length,512}\n    dtype: bfloat16\n    gpu_memory_utilization: 0.5\n    ignore_eos: false\n    enforce_eager: false\n    cudagraph_capture_sizes: null\n    free_cache_engine: true\n    tensor_model_parallel_size: 2\n    data_parallel_size: 1\n    expert_parallel_size: 1\n    pipeline_model_parallel_size: 1\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    enable_chunked_prefill: true\n    enable_prefix_caching: true\n    load_format: dummy\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: 
${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    disable_log_stats: true\n    do_sample: true\n    'n': 1\n    over_sample_rate: 0\n    multi_stage_wake_up: false\n    engine_kwargs:\n      vllm: {}\n      sglang: {}\n    val_kwargs:\n      _target_: verl.workers.config.SamplingConfig\n      top_k: -1\n      top_p: 1.0\n      temperature: 0\n      'n': 1\n      do_sample: false\n    multi_turn:\n      _target_: verl.workers.config.MultiTurnConfig\n      enable: false\n      max_assistant_turns: null\n      tool_config_path: null\n      max_user_turns: null\n      max_parallel_calls: 1\n      max_tool_response_length: 256\n      tool_response_truncate_side: middle\n      interaction_config_path: null\n      use_inference_chat_template: false\n      tokenization_sanity_check_mode: strict\n      format: hermes\n      num_repeat_rollouts: null\n    calculate_log_probs: false\n    agent:\n      _target_: verl.workers.config.AgentLoopConfig\n      num_workers: 8\n      default_agent_loop: single_turn_agent\n      agent_loop_config_path: null\n      custom_async_server:\n        _target_: verl.workers.config.CustomAsyncServerConfig\n        path: null\n        name: null\n    update_weights_bucket_megabytes: 512\n    trace:\n      _target_: verl.workers.config.TraceConfig\n      backend: null\n      token2text: false\n    skip_rollout: false\n    skip_dump_dir: /tmp/rollout_dump\n    skip_tokenizer_init: true\n    profiler:\n      _target_: verl.utils.profiler.ProfilerConfig\n      tool: ${oc.select:global_profiler.tool,null}\n      enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}\n      all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}\n      ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}\n      save_path: ${oc.select:global_profiler.save_path,null}\n      tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}\n    layered_summon: false\n  model:\n    _target_: verl.workers.config.HFModelConfig\n    path: ~/models/deepseek-llm-7b-chat\n    hf_config_path: null\n    tokenizer_path: null\n    use_shm: false\n    trust_remote_code: false\n    custom_chat_template: null\n    external_lib: null\n    override_config: {}\n    enable_gradient_checkpointing: true\n    enable_activation_offload: false\n    use_remove_padding: false\n    lora_rank: 0\n    lora_alpha: 16\n    target_modules: all-linear\n    exclude_modules: null\n    lora_adapter_path: null\n    use_liger: false\n    use_fused_kernels: false\n    fused_kernel_options:\n      impl_backend: torch\n  hybrid_engine: true\n  nccl_timeout: 600\ndata:\n  tokenizer: null\n  use_shm: false\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  train_max_samples: -1\n  val_max_samples: -1\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null\n  tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,\n    null}\n  return_raw_input_ids: false\n  return_raw_chat: false\n  return_full_prompt: false\n  shuffle: true\n  seed: null\n  dataloader_num_workers: 8\n  image_patch_size: 14\n  validation_shuffle: false\n  filter_overlong_prompts: false\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  image_key: images\n  video_key: videos\n  trust_remote_code: false\n  custom_cls:\n    path: null\n    name: null\n  return_multi_modal_inputs: true\n  sampler:\n    class_path: null\n  
  class_name: null\n  datagen:\n    path: null\n    name: null\n  apply_chat_template_kwargs: {}\ncritic:\n  optim:\n    _target_: verl.workers.config.FSDPOptimizerConfig\n    optimizer: AdamW\n    optimizer_impl: torch.optim\n    lr: 1.0e-05\n    lr_warmup_steps_ratio: 0.0\n    total_training_steps: -1\n    weight_decay: 0.01\n    lr_warmup_steps: -1\n    betas:\n    - 0.9\n    - 0.999\n    clip_grad: 1.0\n    min_lr_ratio: 0.0\n    num_cycles: 0.5\n    lr_scheduler_type: constant\n    warmup_style: null\n    override_optimizer_config: null\n  model:\n    fsdp_config:\n      _target_: verl.workers.config.FSDPEngineConfig\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      optimizer_offload: false\n      offload_policy: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n      model_dtype: fp32\n      use_orig_params: false\n      ulysses_sequence_parallel_size: 1\n      entropy_from_logits_with_chunking: false\n      use_torch_compile: true\n      entropy_checkpointing: false\n      forward_only: false\n      strategy: fsdp\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n    override_config: {}\n    external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n    trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n    _target_: verl.workers.config.FSDPCriticModelCfg\n    use_shm: false\n    enable_gradient_checkpointing: true\n    enable_activation_offload: false\n    use_remove_padding: false\n    lora_rank: 0\n    lora_alpha: 16\n    target_modules: all-linear\n  _target_: verl.workers.config.FSDPCriticConfig\n  rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n  strategy: fsdp\n  enable: null\n  ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n  ppo_micro_batch_size: null\n  ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n  use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n  ppo_max_token_len_per_gpu: 32768\n  forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n  shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n  cliprange_value: 0.5\n  loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n  checkpoint:\n    _target_: verl.trainer.config.CheckpointConfig\n    save_contents:\n    - model\n    - optimizer\n    - extra\n    load_contents: ${.save_contents}\n    async_save: false\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    tool: ${oc.select:global_profiler.tool,null}\n    enable: false\n    all_ranks: false\n    ranks: []\n    save_path: ${oc.select:global_profiler.save_path,null}\n    tool_config:\n      nsys:\n        _target_: verl.utils.profiler.config.NsightToolConfig\n        discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n      npu:\n        _target_: verl.utils.profiler.config.NPUToolConfig\n        contents: []\n        level: level1\n        analysis: true\n        discrete: false\n      torch:\n        _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n        step_start: 0\n        step_end: null\n      torch_memory:\n        _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n        trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n        
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n  forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}\n  forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}\n  ulysses_sequence_parallel_size: 1\n  grad_clip: 1.0\nreward_model:\n  enable: false\n  enable_resource_pool: false\n  n_gpus_per_node: 0\n  nnodes: 0\n  strategy: fsdp\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: false\n    use_shm: false\n    use_remove_padding: false\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n    fsdp_config:\n      _target_: verl.workers.config.FSDPEngineConfig\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n  micro_batch_size: null\n  micro_batch_size_per_gpu: null\n  max_length: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n  launch_reward_fn_async: false\n  sandbox_fusion:\n    url: null\n    max_concurrent: 64\n    memory_limit_mb: 1024\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    tool: ${oc.select:global_profiler.tool,null}\n    enable: false\n    all_ranks: false\n    ranks: []\n    save_path: ${oc.select:global_profiler.save_path,null}\n    tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}\n  ulysses_sequence_parallel_size: 1\ncustom_reward_function:\n  path: null\n  name: compute_score\nalgorithm:\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: true\n  use_kl_in_reward: false\n  kl_penalty: kl\n  kl_ctrl:\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: false\n  pf_ppo:\n    reweight_method: pow\n    weight_pow: 2.0\n  rollout_is_threshold: null\n  rollout_is_threshold_lower: null\n  rollout_is_level: token\n  rollout_is_mode: truncate\n  rollout_is_veto_threshold: null\n  rollout_is: false\ntrainer:\n  balance_batch: true\n  total_epochs: 30\n  total_training_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger:\n  - console\n  - wandb\n  log_val_generations: 0\n  rollout_data_dir: null\n  validation_data_dir: null\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n  resume_mode: auto\n  resume_from_path: null\n  val_before_train: true\n  val_only: false\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  del_local_ckpt_after_load: false\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  use_legacy_worker_impl: auto\nglobal_profiler:\n  _target_: verl.utils.profiler.ProfilerConfig\n  tool: null\n  steps: null\n  profile_continuous_steps: false\n  save_path: outputs/profile\n  global_tool_config:\n    nsys:\n      _target_: verl.utils.profiler.config.NsightToolConfig\n      discrete: false\n      controller_nsight_options:\n        trace: cuda,nvtx,cublas,ucx\n        cuda-memory-usage: 'true'\n        cuda-graph-trace: graph\n      worker_nsight_options:\n        trace: cuda,nvtx,cublas,ucx\n        
cuda-memory-usage: 'true'\n        cuda-graph-trace: graph\n        capture-range: cudaProfilerApi\n        capture-range-end: null\n        kill: none\n    torch_memory:\n      trace_alloc_max_entries: 100000\n      stack_depth: 32\n      context: all\n      stacks: all\n      kw_args: {}\ntransfer_queue:\n  enable: false\nray_kwargs:\n  ray_init:\n    num_cpus: null\n  timeline_json_file: null\n"
  },
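  {
    "path": "examples/generated_trainer_yaml_interpolation_sketch.py",
    "content": "# Hypothetical sketch (not part of verl): shows how the ${oc.select:...} and\n# relative ${.key} interpolations used throughout the two generated trainer\n# YAMLs above resolve under plain OmegaConf. The snippet is a tiny\n# hand-written excerpt, not the full generated config.\nfrom omegaconf import OmegaConf\n\ncfg = OmegaConf.create(\n    \"\"\"\n    actor_rollout_ref:\n      actor:\n        use_dynamic_bsz: false\n        checkpoint:\n          save_contents: [model, optimizer, extra]\n          load_contents: ${.save_contents}\n      ref:\n        log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    \"\"\"\n)\n\n# ${.save_contents} resolves against siblings in the same node (checkpoint);\n# oc.select falls back to its default when the referenced path is absent.\nassert cfg.actor_rollout_ref.ref.log_prob_use_dynamic_bsz is False\nassert list(cfg.actor_rollout_ref.actor.checkpoint.load_contents) == [\"model\", \"optimizer\", \"extra\"]\nprint(OmegaConf.to_container(cfg, resolve=True))\n"
  },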
  {
    "path": "verl_distillation/verl/trainer/config/actor/actor.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# Target class for this configuration\n_target_: verl.workers.config.ActorConfig\n\n# the abstract actor configs\n# fsdp, fsdp2 or megatron. must be set.\nstrategy: ???\n\n# Split each sample into sub-batches of this size for PPO\nppo_mini_batch_size: 256\n\n# [Deprecated] Global micro batch size\nppo_micro_batch_size: null\n\n# Local per-GPU micro batch size\nppo_micro_batch_size_per_gpu: null\n\n# Whether to automatically adjust batch size at runtime\n# oc.select: the default val for ref.log_prob_use_dynamic_bsz\nuse_dynamic_bsz: false\n\n# Max tokens per GPU in one PPO batch; affects gradient accumulation\n# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}\n# oc.select: the default val for ref.log_prob_max_token_len_per_gpu\nppo_max_token_len_per_gpu: 16384\n\n# PPO clip ratio\nclip_ratio: 0.2\n\n# Lower bound for asymmetric clipping (used in dual-clip PPO)\nclip_ratio_low: 0.2\n\n# Upper bound for asymmetric clipping (used in dual-clip PPO)\nclip_ratio_high: 0.2\n\n# Whether to freeze vision model, if set true, it will be freeze vision model\nfreeze_vision_tower: false\n\n# policy loss config\npolicy_loss:\n\n  # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.PolicyLossConfig\n\n  # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617\n  loss_mode: \"vanilla\"\n\n  # Ratio of tokens to be clipped for clip-cov loss\n  clip_cov_ratio: 0.0002\n\n  # Lower bound for clip-cov loss\n  clip_cov_lb: 1.0\n\n  # Upper bound for clip-cov loss\n  clip_cov_ub: 5.0\n\n  # Ratio of tokens to be applied kl penalty for kl-cov loss\n  kl_cov_ratio: 0.0002\n\n  # KL divergence penalty coefficient\n  ppo_kl_coef: 0.1\n\n# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C\nclip_ratio_c: 3.0\n\n# Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\nloss_agg_mode: token-mean\n\n# Entropy regularization coefficient in PPO loss\nentropy_coeff: 0\n\n# Whether to use KL loss instead of KL reward penalty. True for GRPO\nuse_kl_loss: false\n\n# Whether to use torch.compile()\n# oc.select: the default val for ref.use_torch_compile\nuse_torch_compile: true\n\n# float val to replace the ref_log_prob\nref_log_prob_replace_val: -10.0\n\n# KL loss coefficient when use_kl_loss is enabled. For GRPO\nkl_loss_coef: 0.001\n\n# Type of KL divergence loss. 
Options: \"kl\"(k1), \"abs\", \"mse\"(k2), \"low_var_kl\"(k3), \"full\"\nkl_loss_type: low_var_kl\n\n# Number of PPO epochs per batch\nppo_epochs: 1\n\n# Shuffle training data across PPO epochs\nshuffle: false\n\n# checkpoint configs\ncheckpoint:\n\n  # Target dataclass for this configuration\n  _target_: verl.trainer.config.CheckpointConfig\n\n  # What to include in saved checkpoints\n  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n  save_contents: ['model', 'optimizer', 'extra']\n\n  # For more flexibility, you can specify the contents to load from the checkpoint.\n  # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg\n  load_contents: ${.save_contents}\n\n  # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.\n  async_save: False\n\n# optimizer configs\noptim:\n\n  # Learning rate\n  lr: 1e-6\n\n  # Warmup steps ratio (used if lr_warmup_steps is 0 or negative)\n  lr_warmup_steps_ratio: 0.0\n\n  # Total training steps (must be overridden at runtime)\n  total_training_steps: -1\n\n  # Weight decay\n  weight_decay: 0.01\n\n  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n  lr_warmup_steps: -1\n\n\n# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)\nuse_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}\n\n# profile the actor model in `update_policy` \nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # profiler tool, default same as profiler.tool in global config\n  # choices: nsys, npu, torch\n  tool: ${oc.select:global_profiler.tool,null}\n\n  # whether enable profile on Actor\n  enable: False\n  \n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. 
[] or [0,1,...]\n  ranks: []\n\n  # profile results saving path\n  save_path: ${oc.select:global_profiler.save_path,null}\n\n  # Tool-specific config that only relates to this role\n  tool_config:\n\n    # nsys tool config\n    nsys:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NsightToolConfig\n\n      # True: each task has its own database. False: all tasks in one training step share one database.\n      discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n\n    # npu config\n    npu:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NPUToolConfig\n\n      # Contents to profile, can be empty\n      # options: npu, cpu, memory, shapes, module, stack\n      contents: []\n\n      # Collection level, optional values: level_none, level0, level1, level2.\n      level: \"level1\"\n\n      # Whether to automatically parse the data.\n      analysis: True\n\n      # True: each task has its own database. False: all tasks in one training step share one database.\n      discrete: False\n\n    # torch profiler config\n    torch:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n\n      # Mini-batch index at which profiling starts\n      # NOTICE: different from the global steps config, which refers to iterations;\n      # this field relates only to mini-batches\n      step_start: 0\n\n      # Mini-batch index at which profiling stops\n      step_end: null\n\n    # torch memory profiler config\n    torch_memory:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n\n      # Maximum number of memory allocation entries to track\n      trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n\n      # Stack trace depth for memory allocations\n      stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n"
  },
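  {
    "path": "examples/actor_config_dataclass_sketch.py",
    "content": "# Hypothetical sketch (not part of verl): shows what the `_target_` fields in\n# actor.yaml above are for. Per the comments in that file,\n# verl.utils.omega_conf_to_dataclass turns a DictConfig into the dataclass\n# named by `_target_`; check the exact import path and signature against your\n# verl version.\nfrom omegaconf import OmegaConf\n\nfrom verl.utils import omega_conf_to_dataclass  # import location assumed from the comments\n\nif __name__ == \"__main__\":\n    ckpt_cfg = OmegaConf.create(\n        {\n            \"_target_\": \"verl.trainer.config.CheckpointConfig\",\n            \"save_contents\": [\"model\", \"optimizer\", \"extra\"],\n            \"load_contents\": [\"model\", \"optimizer\", \"extra\"],\n            \"async_save\": False,\n        }\n    )\n    # Returns a CheckpointConfig dataclass instance rather than a DictConfig.\n    ckpt = omega_conf_to_dataclass(ckpt_cfg)\n    print(type(ckpt).__name__, ckpt.save_contents)\n"
  },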
  {
    "path": "verl_distillation/verl/trainer/config/actor/dp_actor.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # fsdp optimizer config\n  - ../optim@optim: fsdp\n\n  # fsdp engine config\n  - ../engine@fsdp_config: fsdp\n\n  # dp actor config, inheriting from trainer/config/actor/actor.yaml\n  - actor\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# Target class for this configuration\n_target_: verl.workers.config.FSDPActorConfig\n\n# TODO(haibin.lin): switch to fsdp2\nstrategy: fsdp\n\n# Gradient clipping for actor updates, specific to the strategy.\ngrad_clip: 1.0\n\n# Sequence parallelism size for Ulysses-style model parallelism\n# oc.select: the default val for ref.ulysses_sequence_parallel_size\nulysses_sequence_parallel_size: 1\n\n# calculate entropy with chunking to reduce memory peak\nentropy_from_logits_with_chunking: False\n\n# recompute entropy\nentropy_checkpointing: False\n\n# Whether to remove padding tokens in inputs during training\nuse_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}"
  },
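These actor/critic/ref YAMLs lean heavily on OmegaConf's built-in `oc.select` resolver to inherit a value from the actor config while providing a fallback. A minimal, self-contained sketch of that behavior (the keys mirror `use_remove_padding` and `ulysses_sequence_parallel_size` above and are otherwise illustrative):

```python
from omegaconf import OmegaConf

# `oc.select` returns the referenced key if it exists, else the given default.
cfg = OmegaConf.create({
    "actor_rollout_ref": {"model": {"use_remove_padding": True}},
    # referenced key exists -> its value is selected
    "actor": {"use_remove_padding": "${oc.select:actor_rollout_ref.model.use_remove_padding,false}"},
    # referenced key missing -> the default (1) is used
    "ref": {"sp_size": "${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}"},
})
assert cfg.actor.use_remove_padding is True
assert cfg.ref.sp_size == 1
```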
  {
    "path": "verl_distillation/verl/trainer/config/actor/megatron_actor.yaml",
    "content": "# megatron actor config, inheriting from trainer/config/actor/actor.yaml\ndefaults:\n  # megatron optimizer config\n  - ../optim@optim: megatron\n\n  # megatron engine config\n  - ../engine@megatron: megatron\n\n  - actor\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n_target_: verl.workers.config.McoreActorConfig\n\nstrategy: megatron\n\ndata_loader_seed: null\n\nload_weight: True\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/algorithm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom verl.base_config import BaseConfig\n\n__all__ = [\"AlgoConfig\", \"FilterGroupsConfig\", \"KLControlConfig\"]\n\n\n@dataclass\nclass KLControlConfig(BaseConfig):\n    \"\"\"Configuration for KL control.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        type (str): Type of KL control. Can be \"fixed\" or \"adaptive\".\n        kl_coef (float): Initial coefficient for KL penalty.\n        horizon (int): Horizon value for adaptive controller.\n        target_kl (float): Target KL divergence for adaptive controller.\n    \"\"\"\n\n    type: str = \"fixed\"\n    kl_coef: float = 0.001\n    horizon: int = 10000\n    target_kl: float = 0.1\n\n\n@dataclass\nclass FilterGroupsConfig(BaseConfig):\n    \"\"\"Configuration for filter groups (used in DAPO and Entropy).\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        enable (bool): Whether to enable filter groups.\n        metric (Optional[str]): Metric to use for filtering: \"acc\", \"score\", \"seq_reward\", \"seq_final_reward\", etc.\n        max_num_gen_batches (int): Non-positive values mean no upper limit.\n    \"\"\"\n\n    enable: bool = False\n    metric: Optional[str] = None\n    max_num_gen_batches: int = 0\n\n\n@dataclass\nclass AlgoConfig(BaseConfig):\n    \"\"\"Configuration for the algorithm.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        gamma (float): Discount factor for future rewards.\n        lam (float): Trade-off between bias and variance in the GAE estimator.\n        adv_estimator (str): Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n        norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO).\n        use_kl_in_reward (bool): Whether to enable in-reward KL penalty.\n        kl_penalty (str): How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\".\n        kl_ctrl (KLControlConfig): KL control configuration.\n        use_pf_ppo (bool): Whether to enable preference feedback PPO.\n        pf_ppo (dict[str, Any]): Preference feedback PPO settings.\n        filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy\n        rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,\n            float value = enabled (compute weights and metrics). This is the main on/off switch.\n        rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. 
If None, defaults to 1/upper.\n        rollout_is_level (str): Aggregation level: \"token\", \"sequence\", or \"geometric\".\n        rollout_is_mode (str): Bounding mode: \"truncate\" (cap upper only) or \"mask\" (zero outside bounds).\n        rollout_is_veto_threshold (float or None): Per-token veto threshold for catastrophic outliers. None to disable.\n        rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,\n            False = compute metrics only (useful for monitoring before enabling correction). Default: False.\n    \"\"\"\n\n    gamma: float = 1.0\n    lam: float = 1.0\n    adv_estimator: str = \"gae\"\n    norm_adv_by_std_in_grpo: bool = True\n    use_kl_in_reward: bool = False\n    kl_penalty: str = \"kl\"\n    kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)\n    use_pf_ppo: bool = False\n    pf_ppo: dict[str, Any] = field(default_factory=dict)\n    filter_groups: Optional[FilterGroupsConfig] = None\n    # Rollout Importance Sampling\n    # Controls computation of IS weights and mismatch metrics\n    rollout_is_threshold: Optional[float] = None  # null = disabled, float = enabled\n    rollout_is_threshold_lower: Optional[float] = None\n    rollout_is_level: str = \"token\"\n    rollout_is_mode: str = \"truncate\"\n    rollout_is_veto_threshold: Optional[float] = None\n    # Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)\n    # True = apply weights to loss, False = compute metrics only (no weight application)\n    rollout_is: bool = False\n"
  },
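For the `adaptive` KL controller, `target_kl` and `horizon` parameterize the standard adaptive-coefficient rule from Ziegler et al. (2019). A sketch of that rule under the field names above, not necessarily verl's exact implementation:

```python
def adaptive_kl_update(kl_coef: float, current_kl: float,
                       target_kl: float = 0.1, horizon: int = 10000,
                       n_steps: int = 1) -> float:
    # Proportional error, clipped for stability; kl_coef drifts toward the
    # value that keeps the measured KL near target_kl over `horizon` steps.
    proportional_error = max(min(current_kl / target_kl - 1, 0.2), -0.2)
    return kl_coef * (1 + proportional_error * n_steps / horizon)

# If the measured KL overshoots the target, the penalty coefficient grows.
assert adaptive_kl_update(0.001, current_kl=0.2) > 0.001
```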
  {
    "path": "verl_distillation/verl/trainer/config/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom verl.base_config import BaseConfig\n\n__all__ = [\"CheckpointConfig\", \"ProfileConfig\", \"BaseModelConfig\"]\n\n\n@dataclass\nclass CheckpointConfig(BaseConfig):\n    \"\"\"Configuration for model checkpointing.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        save_contents (list[str]): What to include in saved checkpoints.\n            Options: 'model', 'optimizer', 'extra', 'hf_model'.\n        load_contents (list[str]): Contents to load from checkpoint. Defaults to same as save_contents.\n        async_save (bool): Whether to save checkpoints asynchronously. Only implemented for Megatron as of now.\n    \"\"\"\n\n    save_contents: list[str] = field(default_factory=lambda: [\"model\", \"optimizer\", \"extra\"])\n    load_contents: list[str] = field(default_factory=lambda: [\"model\", \"optimizer\", \"extra\"])\n    async_save: bool = False\n\n\n@dataclass\nclass ProfileConfig(BaseConfig):\n    \"\"\"Configuration for profiling.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        profile_ranks (Optional[list[int]]): List of ranks to profile. None means all ranks.\n        step_start (int): Starting step for profiling.\n        step_end (int): Ending step for profiling.\n        save_path (Optional[str]): Path to save profiling results.\n    \"\"\"\n\n    profile_ranks: Optional[list[int]] = None\n    step_start: int = -1\n    step_end: int = -1\n    save_path: Optional[str] = None\n\n\n@dataclass\nclass BaseModelConfig(BaseConfig):\n    \"\"\"Base configuration for a model.\n    Contains core settings for loading and initializing a pretrained model checkpoint.\n\n    Args:\n        path (str): Path to pretrained model weights.\n        tokenizer_path (Optional[str]): Tokenizer path (defaults to actor's model path if not set).\n        override_config (dict): Hugging Face config override.\n        external_lib (Optional[str]): External model implementation (optional).\n        trust_remote_code (bool): Whether to trust remote code from Hugging Face models.\n    \"\"\"\n\n    path: str = \"~/models/deepseek-llm-7b-chat\"\n    tokenizer_path: Optional[str] = None\n    override_config: dict[str, Any] = field(default_factory=dict)\n    external_lib: Optional[str] = None\n    trust_remote_code: bool = False\n"
  },
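As an illustration of using these dataclasses directly (the import path is assumed from this file's location in the tree), a checkpoint config that also exports an HF-format copy while loading only model weights:

```python
from verl.trainer.config.config import CheckpointConfig  # assumed module path

ckpt = CheckpointConfig(
    # 'hf_model' additionally saves a consolidated Hugging Face-format model
    save_contents=["model", "optimizer", "extra", "hf_model"],
    load_contents=["model"],
    async_save=False,  # async save is only implemented for Megatron as of now
)
assert "hf_model" in ckpt.save_contents
```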
  {
    "path": "verl_distillation/verl/trainer/config/critic/critic.yaml",
    "content": "# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n_target_: verl.workers.config.CriticConfig\n\n# Number of rollouts per update (mirrors actor rollout_n)\nrollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n\n# fsdp or fsdp2 strategy used for critic model training\nstrategy: ???\n\n# whether to enable the critic worker.\n# by default it is only enabled if advantage estimator is gae\n# set it to True manually if you always want to enable critic worker\nenable: null\n\n# optimizer configs\noptim:\n\n  # Learning rate\n  lr: 1e-5\n\n  # Warmup steps ratio; total steps will be injected at runtime\n  lr_warmup_steps_ratio: 0.0\n\n  # Total training steps (must be overridden at runtime)\n  total_training_steps: -1\n\n  # Weight decay\n  weight_decay: 0.01\n\n  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n  lr_warmup_steps: -1\n\n\n# model config for the critic\nmodel:\n\n  # Path to pretrained model weights\n  path: ~/models/deepseek-llm-7b-chat\n\n  # Tokenizer path (defaults to actor's model path)\n  tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n\n  # Hugging Face config override\n  override_config: {}\n\n  # External model implementation (optional)\n  external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n\n  # Whether to trust remote code from Hugging Face models\n  trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n\n# PPO mini-batch size per update\nppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n\n# [Deprecated] Global micro batch size\nppo_micro_batch_size: null\n\n# Local per-GPU micro batch size\nppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n\n# Whether to automatically adjust batch size at runtime\nuse_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# Max tokens per GPU in one PPO batch (doubled for critic)\nppo_max_token_len_per_gpu: 32768\n\n# Max token length per GPU in forward pass\nforward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n\n# Number of PPO epochs per batch\nppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n\n# Shuffle training data across PPO epochs\nshuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n\n# PPO value function clipping range\ncliprange_value: 0.5\n\n# Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\nloss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n\n# checkpoint configs\ncheckpoint:\n\n  # Target dataclass for this configuration\n  _target_: verl.trainer.config.CheckpointConfig\n\n  # What to include in saved checkpoints\n  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n  save_contents: ['model', 'optimizer', 'extra']\n\n  # What to include when loading checkpoints\n  load_contents: ${.save_contents}\n\n  # Whether to save checkpoints asynchronously. 
Only effective for Megatron as of now.\n  async_save: False\n\n# profile the critic model in `update_critic`\nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # profiler tool, default same as profiler.tool in global config\n  # choices: nsys, npu, torch, torch_memory\n  tool: ${oc.select:global_profiler.tool,null}\n\n  # whether enable profile on Critic\n  enable: False\n\n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: []\n\n  # profile results saving path\n  save_path: ${oc.select:global_profiler.save_path,null}\n\n  # specific tool config which only related to the role\n  tool_config:\n\n    # nsys tool config\n    nsys:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NsightToolConfig\n    \n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n    \n    # npu config\n    npu:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NPUToolConfig\n\n      # Contents to profile, can be empty\n      # options: npu, cpu, memory, shapes, module, stack\n      contents: []\n\n      # Collection level, optional values: level_none, level0, level1, level2.\n      level: \"level1\"\n\n      # Whether to automatically parse the data.\n      analysis: True\n\n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: False\n    \n    # torch profiler config\n    torch:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n\n      # start profile mini-batch in training\n      # NOTICE: different with global steps config which refers to iteration\n      # This field only related with mini-batch\n      step_start: 0\n\n      # stop profile mini-batch in training\n      step_end: null\n\n    # torch memory profiler config\n    torch_memory:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n\n      # Maximum number of memory allocation entries to track\n      trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n\n      # Stack trace depth for memory allocations\n      stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}\n      "
  },
  {
    "path": "verl_distillation/verl/trainer/config/critic/dp_critic.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # fsdp optimizer config\n  - ../optim@optim: fsdp\n\n  # fsdp engine config\n  - ../engine@model.fsdp_config: fsdp\n\n  # dp actor config, inheriting from trainer/config/critic/critic.yaml\n  - critic\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n_target_: verl.workers.config.FSDPCriticConfig\n\n# distribution strategy. Options: fsdp (deprecating), fsdp2\nstrategy: fsdp\n\n# model config for the critic\nmodel:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.FSDPCriticModelCfg\n\n  # Whether to use shared memory for loading the model\n  use_shm: False\n\n  # Enable gradient checkpointing to save memory\n  enable_gradient_checkpointing: True\n\n  # Offload activations to CPU to reduce GPU memory usage\n  enable_activation_offload: False\n\n  # Use remove padding optimization (saves compute)\n  use_remove_padding: False\n\n  # Set to positive value to enable LoRA (e.g., 32)\n  lora_rank: 0\n\n  # LoRA scaling factor\n  lora_alpha: 16\n\n  # LoRA target modules: \"all-linear\" or list of linear projection layers\n  target_modules: all-linear\n\n# Forward-only batch size during inference (global)\nforward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}\n\n# Forward-only batch size during inference (per GPU)\nforward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}\n\n# Sequence parallelism size for Ulysses-style model parallelism\nulysses_sequence_parallel_size: 1\n\n# Gradient clipping for critic updates\ngrad_clip: 1.0\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/critic/megatron_critic.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # megatron optimizer config\n  - ../optim@optim: megatron\n\n  # megatron engine config\n  - ../engine@megatron: megatron\n\n  # dp actor config, inheriting from trainer/config/critic/critic.yaml\n  - critic\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n_target_: verl.workers.config.McoreCriticConfig\n\nstrategy: megatron\n\n# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\nnccl_timeout: 600\n\n# model config for the critic\nmodel:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.trainer.config.BaseModelConfig\n\n  # override default empty mapping\n  override_config:\n\n    model_config: {}\n\n    moe_config:\n\n      freeze_moe_router: False\n\n# Whether to load initial weights\nload_weight: True\n\n# seed for data loader\ndata_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/data/legacy_data.yaml",
    "content": "# Tokenizer class or path. If null, it will be inferred from the model.\ntokenizer: null\n\n# Whether to use shared memory for data loading.\nuse_shm: False\n\n# Training set parquet. Can be a list or a single file.\n# The program will read all files into memory, so it can't be too large (< 100GB).\n# The path can be either a local path or an HDFS path.\n# For HDFS path, we provide utils to download it to DRAM and convert it to a local path.\ntrain_files: ~/data/rlhf/gsm8k/train.parquet\n\n# Validation parquet. Can be a list or a single file.\nval_files: ~/data/rlhf/gsm8k/test.parquet\n\n# Maximum sample length to be used.\n# Set to -1 to use full dataset, otherwise, randomly\n# select the specified number of samples from train dataset\ntrain_max_samples: -1\n\n# Maximum sample length to be used.\n# Set to -1 to use full dataset, otherwise, randomly\n# select the specified number of samples from val dataset\nval_max_samples: -1\n\n# The field in the dataset where the prompt is located. Default is 'prompt'.\nprompt_key: prompt\n\n# The field used to select the reward function (if using different ones per example).\nreward_fn_key: data_source\n\n# Maximum prompt length. All prompts will be left-padded to this length.\n# An error will be reported if the length is too long.\n# oc.select: default val for rollout.prompt_length\nmax_prompt_length: 512\n\n# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.\n# oc.select: default val for rollout.response_length\nmax_response_length: 512\n\n# Batch size sampled for one training iteration of different RL algorithms.\ntrain_batch_size: 1024\n\n# Batch size used during validation. Can be null.\nval_batch_size: null\n\n# use tool config to calculate true prompt length\ntool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, null}\n\n# Whether to return the original input_ids without adding chat template.\n# This is used when the reward model's chat template differs from the policy.\n# If using a model-based RM with different templates, this should be True.\nreturn_raw_input_ids: False\n\n# Whether to return the original chat (prompt) without applying chat template.\nreturn_raw_chat: False\n\n# Whether to return the full prompt with chat template.\nreturn_full_prompt: False\n\n# Whether to shuffle the data in the dataloader.\nshuffle: True\n\n# Seed to use when shuffling the data\nseed: null\n\n# num dataloader workers\ndataloader_num_workers: 8\n\n# image patch size\nimage_patch_size: 14\n\n# Whether to shuffle the validation set.\nvalidation_shuffle: False\n\n# Whether to filter overlong prompts.\nfilter_overlong_prompts: False\n\n# Number of workers for filtering overlong prompts.\n# For large-scale datasets, filtering can be time-consuming.\n# Use multiprocessing to speed up. Default is 1.\nfilter_overlong_prompts_workers: 1\n\n# Truncate the input_ids or prompt if they exceed max_prompt_length.\n# Options: 'error', 'left', 'right', 'middle'. Default is 'error'.\ntruncation: error\n\n# The field in the multi-modal dataset where the image is located. 
Default is 'images'.\nimage_key: images\n\n# The field in the multi-modal dataset where the video is located.\nvideo_key: videos\n\n# If the remote tokenizer has a Python file, this flag determines whether to allow using it.\ntrust_remote_code: False\n\n# Optional: specify a custom dataset class path and name if overriding default loading behavior.\ncustom_cls:\n\n  # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.\n  path: null\n\n  # The name of the dataset class within the specified file.\n  name: null\n\n# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.\nreturn_multi_modal_inputs: True\n\n# settings related to data sampler\nsampler:\n\n  # the path to the module containing a curriculum class which implements the\n  # AbstractSampler interface\n  class_path: null\n\n  # the name of the curriculum class like `MySampler`\n  class_name: null\n\n# Data generation configuration for augmenting the dataset.\ndatagen:\n\n  # The path to the file containing your customized data generation class.\n  # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'\n  path: null\n\n  # The class name of the data generation class within the specified file.\n  # E.g. 'MockDataGenerator'\n  name: null\n\n# Additional kwargs when calling tokenizer.apply_chat_template\napply_chat_template_kwargs: {}\n"
  },
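For orientation, a toy parquet matching the defaults above (`prompt_key: prompt`, `reward_fn_key: data_source`); the chat-formatted prompt column is a plausible shape, not the full verl schema:

```python
import pandas as pd

# One row with a chat-style prompt and a data_source tag used to select the
# reward function. Real verl datasets carry more fields (e.g. reward_model).
df = pd.DataFrame({
    "prompt": [[{"role": "user", "content": "What is 2 + 3?"}]],
    "data_source": ["openai/gsm8k"],
})
df.to_parquet("train.parquet")  # readable back via pd.read_parquet
```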
  {
    "path": "verl_distillation/verl/trainer/config/engine/fsdp.yaml",
    "content": "# Target class for this configuration\n_target_: verl.workers.config.FSDPEngineConfig\n\n# policy for wrapping the model\nwrap_policy:\n\n  # Minimum number of parameters to trigger wrapping a layer with FSDP\n  min_num_params: 0\n\n# Whether to offload model parameters to CPU (trades speed for memory)\n# Note that this differs from the offload_policy in FSDP\nparam_offload: false\n\n# Whether to offload optimizer state to CPU\n# Note that this differs from the offload_policy in FSDP\noptimizer_offload: false\n\n# Only for FSDP2: offload param/grad/optimizer during train\noffload_policy: false\n\n# Only for FSDP2: Reshard after forward pass to reduce memory footprint\nreshard_after_forward: true\n\n# Number of GPUs in each FSDP shard group; -1 means auto\nfsdp_size: -1\n\n# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n# before the current forward computation.\nforward_prefetch: False\n\n# model dtype of fsdp\nmodel_dtype: fp32\n\n# Whether to use original parameters in fsdp. Only avaiable in fsdp1\nuse_orig_params: false\n\n# ulysses sequence parallel size\nulysses_sequence_parallel_size: 1\n\n# Whether to use entropy_from_logits_with_chunking in fsdp.\nentropy_from_logits_with_chunking: false\n\n# Whether to use torch compile in fsdp.\nuse_torch_compile: true\n\n# Whether to use entropy checkpointing in fsdp.\nentropy_checkpointing: false\n\n# Whether to use forward only in fsdp.\nforward_only: false\n\n# fsdp or fsdp2\nstrategy: fsdp\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/engine/megatron.yaml",
    "content": "# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n_target_: verl.workers.config.McoreEngineConfig\n\n# Whether to offload model parameters to CPU\nparam_offload: False\n\n# Whether to offload gradients to CPU\ngrad_offload: False\n\n# Whether to offload optimizer state to CPU\noptimizer_offload: False\n\n# tensor model parallel size\ntensor_model_parallel_size: 1\n\n# expert model parallel size\nexpert_model_parallel_size: 1\n\n# expert tensor parallel size\nexpert_tensor_parallel_size: 1\n\n# pipeline model parallel size\npipeline_model_parallel_size: 1\n\n# virtual pipeline model parallel size\nvirtual_pipeline_model_parallel_size: null\n\n# context parallel size\ncontext_parallel_size: 1\n\n# sequence parallel\nsequence_parallel: True\n\n# Whether to use distributed optimizer\nuse_distributed_optimizer: True\n\n# Whether to use distributed checkpointing\nuse_dist_checkpointing: False\n\n# distributed checkpointing path\ndist_checkpointing_path: null\n\n# oc.select: default val for ref.megatron.seed\nseed: 42\n\n# Allow to override Distributed Data Parallel (DDP) config\noverride_ddp_config: {}\n\n# additional transformer config like: num_layers_in_first(/last)_pipeline_stage\n# oc.select: default val for ref.megatron.override_transformer_config\noverride_transformer_config:\n  # Recompute configuration, same as in megatron.training.arguments\n  # default use minimal performance-interference recompute methods\n  # Recompute granualarity, choices: [\"full\", \"selective\"]\n  recompute_granularity: null\n\n  # Recompute modules, multiple choices: [\"core_attn\", \"moe_act\", \"layernorm\", \"mla_up_proj\", \"mlp\", \"moe\"]\n  # Please use correct module in matched model\n  recompute_modules: [\"core_attn\"]\n\n  # 'uniform', 'block'\n  # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk\n  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity\n  recompute_method: null\n\n  # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention\n  recompute_num_layers: null\n\n  # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl\n  attention_backend: flash\n\noverride_mcore_model_config: {}\n\n# oc.select: default val for ref.megatron.use_mbridge\nuse_mbridge: False\n\n# whether to use forward only\nforward_only: False\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/evaluation.yaml",
    "content": "data:\n  path: /tmp/math_Qwen2-7B-Instruct.parquet\n  prompt_key: prompt\n  response_key: responses\n  data_source_key: data_source\n  reward_model_key: reward_model\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/generation.yaml",
    "content": "trainer:\n  nnodes: 1\n  n_gpus_per_node: 8\n  device: cuda\n\ndata:\n  path: ~/data/rlhf/math/test.parquet\n  prompt_key: prompt\n  n_samples: 5\n  output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet\n  batch_size: 128\n\nmodel:\n  path: ~/models/Qwen2-7B-Instruct\n  external_lib: null\nrollout:\n  _target_: verl.workers.config.RolloutConfig\n  name: vllm\n  mode: sync # sync: LLM, async: AsyncLLM\n  temperature: 1.0\n  top_k: 50 # 0 for hf rollout, -1 for vllm rollout\n  top_p: 0.7\n  prompt_length: 1536\n  response_length: 512\n  # for vllm rollout\n  dtype: bfloat16 # should align with FSDP\n  gpu_memory_utilization: 0.5\n  ignore_eos: False\n  enforce_eager: True\n  free_cache_engine: True\n  load_format: auto\n  tensor_model_parallel_size: 1\n  data_parallel_size: 1\n  max_num_batched_tokens: 8192\n  max_model_len: null\n  max_num_seqs: 1024\n  log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n  log_prob_micro_batch_size_per_gpu: 8\n  # for hf rollout\n  do_sample: True\n  disable_log_stats: True\n  enable_chunked_prefill: True\n  n: 1\n  # support logging rollout prob for debugging purpose\n  calculate_log_probs: False\nactor:\n  strategy: fsdp  # This is for backward-compatibility\n  ulysses_sequence_parallel_size: 1 # sp size\n  entropy_from_logits_with_chunking: False  # calculate entropy with chunking to reduce memory peak\n  entropy_checkpointing: False  # recompute entropy\n  fsdp_config:\n    fsdp_size: -1\n    forward_prefetch: False  # FSDP1 forward_prefetch configuration\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/model/hf_model.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n_target_: verl.workers.config.HFModelConfig\n\n# path to the huggingface model\npath: ~/models/deepseek-llm-7b-chat\n\n# config to the huggingface config. In case it is not the same as path\nhf_config_path: null\n\n# path to the huggingface tokenizer. In case it is not the same as path\ntokenizer_path: null\n\n# whether to use shared memory for model loading\nuse_shm: False\n\n# whether to trust remote code.\ntrust_remote_code: False\n\n# custom chat template for the model\ncustom_chat_template: null\n\n# whether to use external libs for the model\nexternal_lib: null\n\n# override hf config\noverride_config: {}\n\n# whether to enable gradient checkpointing. Only valid when we use hf model definition\nenable_gradient_checkpointing: True\n\n# whether to enable activation offload. Only valid when we use hf model definition\nenable_activation_offload: False\n\n# whether to use remove padding. Only valid when we use hf model definition\nuse_remove_padding: False\n\n# Set to positive value to enable LoRA (e.g., 32)\nlora_rank: 0\n\n# LoRA scaling factor\nlora_alpha: 16\n\n# Target modules for LoRA adaptation\ntarget_modules: all-linear\n\n# Exclude modules from LoRA adaptation\nexclude_modules: null\n\n# Path to pre-trained LoRA adapter to load for continued training\nlora_adapter_path: null\n\n# whether to use liger. Only valid when we use hf model definition\nuse_liger: False\n\n# whether to use fused kernels.\nuse_fused_kernels: False\n\n# fused kernel options.\nfused_kernel_options:\n\n  # the implementation backend for fused kernels.\n  impl_backend: torch\n"
  },
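If LoRA were enabled here, the three fields would correspond roughly to a PEFT `LoraConfig` as below; this shows the mapping only, not verl's internal wiring:

```python
from peft import LoraConfig

lora_cfg = LoraConfig(
    r=32,                         # lora_rank: set to a positive value to enable
    lora_alpha=16,                # lora_alpha: scaling factor
    target_modules="all-linear",  # target_modules: all linear projection layers
)
```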
  {
    "path": "verl_distillation/verl/trainer/config/npu_profile/npu_profile.yaml",
    "content": "# Options for the npu profiler\noptions:\n\n  # Storage path of collected data.\n  save_path: ./profiler_data\n\n  # The roles that will be profiled. Only takes effect in discrete mode.\n  # optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob.\n  # \"all\" means all roles will be profiled.\n  roles: [\"all\"]\n\n  # Collection level, optional values: level_none, level0, level1, level2.\n  level: level1\n\n  # Whether to enable memory analysis.\n  with_memory: False\n\n  # Whether to record tensor shape.\n  record_shapes: False\n\n  # Whether to record Device-side performance data.\n  with_npu: True\n\n  # Whether to record Host-side performance data.\n  with_cpu: True\n\n  # Whether to record Python call stack information.\n  with_module: False\n\n  # Whether to record operator call stack information.\n  with_stack: False\n\n  # Whether to automatically parse the data.\n  analysis: True"
  },
  {
    "path": "verl_distillation/verl/trainer/config/optim/fsdp.yaml",
    "content": "# Target class for this configuration\n_target_: verl.workers.config.FSDPOptimizerConfig\n\n# Optimizer class name (e.g., \"AdamW\", \"AdamW8bit\", \"_AdamW\", \"Adam\")\noptimizer: AdamW\n\n# Module path to import optimizer\n# Examples: \"torch.optim\", \"torchao.optim\", \"bitsandbytes.optim\"\noptimizer_impl: torch.optim\n\n# Learning rate\nlr: 1e-3\n\n# LR warmup steps ratio\nlr_warmup_steps_ratio: 0.0\n\n# Total training steps\ntotal_training_steps: -1\n\n# Weight decay\nweight_decay: 0.01\n\n# LR warmup steps\nlr_warmup_steps: -1\n\n# Betas for Adam optimizer\nbetas: [0.9, 0.999]\n\n# Clip gradient\nclip_grad: 1.0\n\n# Minimum LR ratio for cosine schedule\nmin_lr_ratio: 0.0\n\n# Number of cosine cycles in LR schedule\nnum_cycles: 0.5\n\n# LR scheduler type: \"constant\" or \"cosine\"\nlr_scheduler_type: constant\n\n# deprecated\nwarmup_style: null\n\n# Additional optimizer-specific keyword arguments\n# Example for torchao with bf16 stochastic rounding:\n# optimizer_impl: torchao.optim\n# optimizer: _AdamW\n# override_optimizer_config:\n#   bf16_stochastic_round: true\noverride_optimizer_config: null\n"
  },
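The `optimizer_impl` / `optimizer` pair reads naturally as a module path plus a class name; a minimal sketch of that resolution with the defaults above (illustrative, not verl's exact loader):

```python
import importlib

import torch

optim_mod = importlib.import_module("torch.optim")  # optimizer_impl
optim_cls = getattr(optim_mod, "AdamW")             # optimizer

model = torch.nn.Linear(8, 8)
optimizer = optim_cls(
    model.parameters(),
    lr=1e-3,             # lr
    betas=(0.9, 0.999),  # betas
    weight_decay=0.01,   # weight_decay
)
```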
  {
    "path": "verl_distillation/verl/trainer/config/optim/megatron.yaml",
    "content": "_target_: verl.workers.config.McoreOptimizerConfig\n\n# Learning rate\nlr: 1e-3\n\n# LR warmup steps ratio\nlr_warmup_steps_ratio: 0.0\n\n# Total training steps\ntotal_training_steps: -1\n\n# Weight decay\nweight_decay: 0.01\n\n# LR warmup steps\nlr_warmup_steps: -1\n\n# Betas for Adam optimizer\nbetas: [0.9, 0.999]\n\n# Clip gradient\nclip_grad: 1.0\n\n# optimizer type\noptimizer: adam\n\n# initial learning rate for warmup, default to 0.0\nlr_warmup_init: 0.0\n\nlr_decay_steps: null\n\n# select from constant/linear/cosine/inverse_square_root\nlr_decay_style: constant\n\n# minimum learning rate, default to 0.0\nmin_lr: 0.0\n\n# select from constant/linear/cosine\nweight_decay_incr_style: constant\n\n# select from constant/exponential/cosine\nlr_wsd_decay_style: exponential\n\nlr_wsd_decay_steps: null\n\n# use checkpoint optimizer parameter scheduler\nuse_checkpoint_opt_param_scheduler: False\n\noverride_optimizer_config: {}\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/ppo_megatron_trainer.yaml",
    "content": "# specify the default per-component configs\ndefaults:\n  # <folder_name>@<field_name>.<field_name>: <yaml_file_name>\n  # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml\n  - actor@actor_rollout_ref.actor: megatron_actor\n  # data: trainer/config/data/legacy_data.yaml\n  - data@data: legacy_data\n  # load the reference default config, then apply the fields in the current yaml\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  - ref@actor_rollout_ref.ref: megatron_ref\n  # Rollout model config.\n  - rollout@actor_rollout_ref.rollout: rollout\n  # Critic model config.\n  - critic@critic: megatron_critic\n  # Reward model config.\n  - reward_model@reward_model: megatron_reward_model\n  - _self_\n\nactor_rollout_ref:\n  hybrid_engine: True\n\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n\n  model:\n\n    path: ~/models/deepseek-llm-7b-chat\n\n    custom_chat_template: null\n\n    external_lib: null\n\n    override_config:\n      model_config: {}\n\n      moe_config:\n        freeze_moe_router: False\n\n    use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency)\n\n    trust_remote_code: False\n\n    # Whether to remove padding tokens in inputs during training\n    use_remove_padding: false\n\n  rollout:\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nalgorithm:\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: True\n  use_kl_in_reward: False\n  kl_penalty: kl # how to estimate kl divergence\n  kl_ctrl:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: False\n  pf_ppo:\n    reweight_method: pow # [\"pow\", \"max_min\", \"max_random\"]\n    weight_pow: 2.0\n\n  # Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies\n  # Main control: Upper threshold for IS weights (null = disabled, float = enabled)\n  # When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)\n  rollout_is_threshold: null\n\n  # Lower threshold for IS weights (null = auto-reciprocal of upper)\n  rollout_is_threshold_lower: null\n\n  # Aggregation level: \"token\" (biased), \"sequence\" (unbiased), \"geometric\" (experimental)\n  rollout_is_level: token\n\n  # Bounding mode: \"truncate\" (cap upper only), \"mask\" (zero outside bounds)\n  rollout_is_mode: truncate\n\n  # Per-token veto threshold for catastrophic outliers (null to disable)\n  rollout_is_veto_threshold: null\n\n  # Whether to apply IS weights to policy loss\n  # true = apply weights to loss, false = compute metrics only (no weight application)\n  # Useful for monitoring mismatch before enabling correction\n  rollout_is: false\n\ntrainer:\n  balance_batch: True\n  total_epochs: 30\n  total_training_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: [\"console\", \"wandb\"]\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  
save_freq: -1\n  esi_redundant_time: 0\n\n  # auto: find the last ckpt to resume. If can't find, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  del_local_ckpt_after_load: False\n  val_before_train: True\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  # The timeout for ray worker group to wait for the register center to be ready\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  # Directory for logging rollout data; no dump if null\n  rollout_data_dir: null\n\nglobal_profiler:\n  _target_: verl.utils.profiler.ProfilerConfig\n  tool: null # choose between nsys, npu, torch, torch_memory\n  steps: null # profile steps\n  profile_continuous_steps: False\n  save_path: \"outputs/profile\" # profiler saving path\n  # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config\n  global_tool_config:\n    # nsys config\n    nsys:\n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: False\n\n      # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None.\n      ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html\n      ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html\n      controller_nsight_options:\n        # Select the API(s) to be traced.\n        trace: \"cuda,nvtx,cublas,ucx\"\n\n        # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n        cuda-memory-usage: \"true\"\n\n        # CUDA graphs will be traced as a whole\n        cuda-graph-trace: \"graph\"\n\n      # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None.\n      worker_nsight_options:\n        # Select the API(s) to be traced.\n        trace: \"cuda,nvtx,cublas,ucx\"\n\n        # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n        cuda-memory-usage: \"true\"\n\n        # CUDA graphs will be traced as a whole\n        cuda-graph-trace: \"graph\"\n\n        # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.\n        capture-range: \"cudaProfilerApi\"\n\n        # Specify the desired behavior when a capture range ends.\n        # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times.\n        # valid values are \"repeat-shutdown:n\" or null.\n        # For normal whole step profiling, n = len(profile_steps);\n        # but for discrete profiling, n = len(profile_steps) * Number(subtasks).\n        # Or you can just leave it null and the program will use n = len(profile_steps) * 6;\n        capture-range-end: null\n\n        # Send signal to the target application's process group. 
We let the program to exit by itself.\n        kill: none\n\n    # enable memory visualization for debugging memory usage\n    torch_memory:\n      #  Maximum number of allocation entries to record\n      trace_alloc_max_entries: 100_000\n      # The depth of the call stack to capture for each allocation\n      stack_depth: 32\n      # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both.\n      context: \"all\"\n      # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both.\n      stacks: \"all\"\n      # devices, record_context etc.\n      kw_args: {}\n\n# configs for TransferQueue\ntransfer_queue:\n\n  # Whether to enable transfer queue\n  enable: False\n\nray_kwargs:\n  ray_init:\n    num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_distillation/verl/trainer/config/ppo_trainer.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# specify the default per-component configs\ndefaults:\n\n  # <folder_name>@<field_name>.<field_name>: <yaml_file_name>\n  # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml\n  - actor@actor_rollout_ref.actor: dp_actor\n\n  # data: trainer/config/data/legacy_data.yaml\n  - data@data: legacy_data\n\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  - ref@actor_rollout_ref.ref: dp_ref\n\n  # Rollout model config.\n  - rollout@actor_rollout_ref.rollout: rollout\n\n  # Model config.\n  - model@actor_rollout_ref.model: hf_model\n\n  # Critic model config.\n  - critic@critic: dp_critic\n\n  # Reward model config.\n  - reward_model@reward_model: dp_reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  # self config override anything above\n  - _self_\n\n# config for actor, rollout and reference model\nactor_rollout_ref:\n\n  # Whether it's a hybrid engine, currently only supports hybrid engine\n  hybrid_engine: true\n\n  # Timeout for operations executed against the process group\n  nccl_timeout: 600\n\n  # Rollout model config.\n  rollout:\n\n    # for huge model, layered summon can save memory (prevent OOM) but make it slower\n    layered_summon: False\n\n# custom reward function definition\ncustom_reward_function:\n\n  # The path to the file containing your customized reward function.\n  # If not specified, pre-implemented reward functions will be used.\n  path: null\n\n  # The name of the reward function within the specified file. 
Default is 'compute_score'.\n  name: compute_score\n\n# config for the algorithm\nalgorithm:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.trainer.config.AlgoConfig\n\n  # Discount factor for future rewards\n  gamma: 1.0\n\n  # Trade-off between bias and variance in the GAE estimator\n  lam: 1.0\n\n  # Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n  adv_estimator: gae\n\n  # Whether to normalize advantages by std (specific to GRPO)\n  norm_adv_by_std_in_grpo: True\n\n  # Whether to enable in-reward KL penalty\n  use_kl_in_reward: False\n\n  # How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\"\n  kl_penalty: kl\n\n  # KL control configuration\n  kl_ctrl:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.trainer.config.KLControlConfig\n\n    # KL control type: \"fixed\" or \"adaptive\"\n    type: fixed\n\n    # Initial coefficient for KL penalty\n    kl_coef: 0.001\n\n    # Horizon value for adaptive controller (if enabled)\n    horizon: 10000\n\n    # Target KL divergence (used for adaptive controller)\n    target_kl: 0.1\n\n  # Whether to enable preference feedback PPO\n  use_pf_ppo: False\n\n  # Preference feedback PPO settings\n  pf_ppo:\n\n    # Method for reweighting samples: \"pow\", \"max_min\", or \"max_random\"\n    reweight_method: pow\n\n    # Power used for weight scaling in \"pow\" method\n    weight_pow: 2.0\n\n  # Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies\n  # Main control: Upper threshold for IS weights (null = disabled, float = enabled)\n  # When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)\n  rollout_is_threshold: null\n\n  # Lower threshold for IS weights (null = auto-reciprocal of upper)\n  rollout_is_threshold_lower: null\n\n  # Aggregation level: \"token\" (biased), \"sequence\" (unbiased), \"geometric\" (experimental)\n  rollout_is_level: token\n\n  # Bounding mode: \"truncate\" (cap upper only), \"mask\" (zero outside bounds)\n  rollout_is_mode: truncate\n\n  # Per-token veto threshold for catastrophic outliers (null to disable)\n  rollout_is_veto_threshold: null\n\n  # Whether to apply IS weights to policy loss\n  # true = apply weights to loss, false = compute metrics only (no weight application)\n  # Useful for monitoring mismatch before enabling correction\n  rollout_is: false\n\n  # distill advantage clip params\n  distill_adv_max_clip: 1e9\n  distill_adv_min_clip: -1e9\n\n# config for the trainer\ntrainer:\n\n  # Whether to balance batch sizes across distributed workers\n  balance_batch: True\n\n  # Number of epochs in training\n  total_epochs: 30\n\n  # Total training steps (can be set explicitly or derived from epochs)\n  total_training_steps: null\n\n  # Project name for experiment tracking (e.g., wandb)\n  project_name: verl_examples\n\n  # Experiment name for run identification in tracking tools\n  experiment_name: gsm8k\n\n  # Logging backends to use: \"console\", \"wandb\", etc.\n  logger: [\"console\", \"wandb\"]\n\n  # Number of generations to log during validation\n  log_val_generations: 0\n\n  # Directory for logging rollout data; no dump if null\n  rollout_data_dir: null\n\n  # Directory for logging validation data; no dump if null\n  validation_data_dir: null\n\n  # Number of nodes used in the training\n  nnodes: 1\n\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n\n  # Save 
frequency (by iteration) for model checkpoints\n  save_freq: -1\n\n  # ESI refers to the elastic server instance used during training, similar to the training plan. For example,\n  # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.\n  # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance.\n  # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time.\n  # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.\n  esi_redundant_time: 0\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (only used when resume_mode is \"resume_path\")\n  resume_from_path: null\n\n  # Whether to run validation before training begins\n  val_before_train: True\n\n  # Whether to run validation only\n  val_only: False\n\n  # Validation frequency (in training iterations)\n  test_freq: -1\n\n  # Number of iterations to warm up the critic before updating policy\n  critic_warmup: 0\n\n  # Default path to distributed filesystem for saving checkpoints\n  default_hdfs_dir: null\n\n  # Whether to delete local checkpoints after loading\n  del_local_ckpt_after_load: False\n\n  # Default local directory for saving checkpoints\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\n  # Maximum number of actor checkpoints to keep\n  max_actor_ckpt_to_keep: null\n\n  # Maximum number of critic checkpoints to keep\n  max_critic_ckpt_to_keep: null\n\n  # Timeout (in seconds) for Ray worker to wait for registration\n  ray_wait_register_center_timeout: 300\n\n  # Device to run training on (e.g., \"cuda\", \"cpu\")\n  device: cuda\n\n  # whether to use legacy worker implementation\n  #  mode: \"auto\", \"enable\", or \"disable\"\n  use_legacy_worker_impl: auto\n\n# profiler configs\nglobal_profiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # Profiling tool: choose between nsys, npu, torch, torch_memory\n  tool: null\n\n  # profile steps\n  steps: null\n\n  # Whether to combine continuous steps into one database.\n  ## If True, worker.profiler.discrete must be False, [1,2] in one, [5] in another.\n  ## If False, [1] in one, [2] in another, [5] in another.\n  profile_continuous_steps: False\n\n  # Path to save profiling contents\n  save_path: \"outputs/profile\"\n\n  # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config\n  global_tool_config:\n\n    # nsys config\n    nsys:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NsightToolConfig\n\n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: False\n\n      # controller Nvidia Nsight Systems Options. 
Must set when profile_steps is not None.\n      ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html\n      ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html\n      controller_nsight_options:\n\n        # Select the API(s) to be traced.\n        trace: \"cuda,nvtx,cublas,ucx\"\n\n        # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n        cuda-memory-usage: \"true\"\n\n        # CUDA graphs will be traced as a whole\n        cuda-graph-trace: \"graph\"\n\n      # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None.\n      worker_nsight_options:\n\n        # Select the API(s) to be traced.\n        trace: \"cuda,nvtx,cublas,ucx\"\n\n        # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n        cuda-memory-usage: \"true\"\n\n        # CUDA graphs will be traced as a whole\n        cuda-graph-trace: \"graph\"\n\n        # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.\n        capture-range: \"cudaProfilerApi\"\n\n        # Specify the desired behavior when a capture range ends.\n        # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times.\n        # valid values are \"repeat-shutdown:n\" or null.\n        # For normal whole step profiling, n = len(profile_steps);\n        # but for discrete profiling, n = len(profile_steps) * Number(subtasks).\n        # Or you can just leave it null and the program will use n = len(profile_steps) * 6;\n        capture-range-end: null\n\n        # Send signal to the target application's process group. We let the program to exit by itself.\n        kill: none\n\n    # enable memory visualization for debugging memory usage\n    torch_memory:\n\n      #  Maximum number of allocation entries to record\n      trace_alloc_max_entries: 100_000\n\n      # The depth of the call stack to capture for each allocation\n      stack_depth: 32\n\n      # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both.\n      context: \"all\"\n\n      # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both.\n      stacks: \"all\"\n\n      # devices, record_context etc.\n      kw_args: {}\n\n# configs for TransferQueue\ntransfer_queue:\n\n  # Whether to enable transfer queue\n  enable: False\n\n# configs related to ray\nray_kwargs:\n\n  # configs related to ray initialization\n  ray_init:\n\n    # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.\n    num_cpus: null\n\n  # Path to save Ray timeline JSON for performance profiling\n  timeline_json_file: null\n"
  },
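To make the rollout-IS knobs concrete, here is an illustrative token-level computation of the weights covering the `truncate` and `mask` bounding modes and the auto-reciprocal lower bound; this is a sketch of the semantics described above, not verl's implementation (the veto threshold and the `sequence`/`geometric` levels are omitted):

```python
from typing import Optional

import torch

def rollout_is_weights(logp_train: torch.Tensor, logp_rollout: torch.Tensor,
                       upper: float = 2.0, lower: Optional[float] = None,
                       mode: str = "truncate") -> torch.Tensor:
    if lower is None:
        lower = 1.0 / upper                       # auto-reciprocal of upper
    ratio = torch.exp(logp_train - logp_rollout)  # pi_train / pi_rollout, per token
    if mode == "truncate":                        # cap the upper side only
        return ratio.clamp(max=upper)
    if mode == "mask":                            # zero out weights outside [lower, upper]
        return ratio * ((ratio >= lower) & (ratio <= upper)).float()
    raise ValueError(f"unknown mode: {mode}")
```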
  {
    "path": "verl_distillation/verl/trainer/config/ref/dp_ref.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # dp ref config, inheriting from trainer/config/ref/ref.yaml\n  - ref\n  \n  # fsdp engine config\n  - ../engine@fsdp_config: fsdp\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# ref model is assumed to be identical to actor model. Specify model.path for using a different ref model.\n# Potential use case involves on policy distillation where we calculate KL divergence between student actor\n# and teacher ref\nmodel: null\n\n# sequence parallel size\n# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1\nulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}\n\n# calculate entropy with chunking to reduce memory peak\nentropy_from_logits_with_chunking: False\n\n# recompute entropy\nentropy_checkpointing: False\n"
  },
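For the on-policy distillation use case the comment in dp_ref.yaml alludes to, the per-token forward KL between a student actor and a teacher ref can be computed from full-vocabulary logits; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def token_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """Per-token KL(student || teacher); returns shape [batch, seq_len]."""
    s = F.log_softmax(student_logits, dim=-1)
    t = F.log_softmax(teacher_logits, dim=-1)
    return (s.exp() * (s - t)).sum(dim=-1)

kl = token_kl(torch.randn(2, 4, 32), torch.randn(2, 4, 32))
assert kl.shape == (2, 4) and (kl >= 0).all()  # forward KL is non-negative
```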
  {
    "path": "verl_distillation/verl/trainer/config/ref/megatron_ref.yaml",
    "content": "# megatron ref config, inheriting from trainer/config/ref/ref.yaml\ndefaults:\n  - ref\n\n  # megatron engine config\n  - ../engine@megatron: megatron\n  \n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\nmegatron:\n  _target_: verl.workers.config.MegatronEngineConfig\n  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n\nload_weight: True"
  },
  {
    "path": "verl_distillation/verl/trainer/config/ref/ref.yaml",
    "content": "# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default\nstrategy: ${actor_rollout_ref.actor.strategy}\n\n# whether to enable torch.compile\n# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1\nuse_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n\n# [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n# The batch size for one forward pass in the computation of log_prob. Global batch size.\nlog_prob_micro_batch_size: null\n\n# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\nlog_prob_micro_batch_size_per_gpu: null\n\n# enable dynamic batch size (sequence packing) for log_prob computation\n# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false\nlog_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# the max token length per GPU\n# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384\nlog_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n\n# float val to replace the ref_log_prob\nref_log_prob_replace_val: -10.0\n\n# profile the ref model in `compute_log_prob`\nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # choices: nsys, npu, torch, torch_memory\n  tool: ${oc.select:global_profiler.tool,null}\n\n  # whether enable profile on Ref\n  enable: False\n\n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: []\n\n  # profile results saving path\n  save_path: ${oc.select:global_profiler.save_path,null}\n\n  # specific tool config which only related to the role\n  tool_config:\n\n    # nsys tool config\n    nsys:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NsightToolConfig\n    \n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}\n    \n    # npu config\n    npu:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.NPUToolConfig\n\n      # Contents to profile, can be empty\n      # options: npu, cpu, memory, shapes, module, stack\n      contents: []\n\n      # Collection level, optional values: level_none, level0, level1, level2.\n      level: \"level1\"\n\n      # Whether to automatically parse the data.\n      analysis: True\n\n      # True for each task has its own database, False for all tasks in one training step share one database.\n      discrete: False\n    \n    # torch profiler config\n    torch:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n      _target_: verl.utils.profiler.config.TorchProfilerToolConfig\n\n      # start profile mini-batch in training\n      # NOTICE: different with global steps config which refers to iteration\n      # This field only related with mini-batch\n      step_start: 0\n\n      # stop profile mini-batch in training\n      step_end: null\n\n    # torch memory profiler config\n    torch_memory:\n\n      # Required when using verl.utils.omega_conf_to_dataclass to instantiate 
dataclass configs\n      _target_: verl.utils.profiler.config.TorchMemoryToolConfig\n\n      # Maximum number of memory allocation entries to track\n      trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}\n\n      # Stack trace depth for memory allocations\n      stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}"
  },
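The `${oc.select:key,default}` interpolations above are how ref.yaml inherits the actor's settings without duplicating them: the actor key wins when present, and the fallback after the comma applies otherwise. A minimal, self-contained sketch of that resolution behavior (plain OmegaConf, no verl imports; keys taken from the config above):

```python
from omegaconf import OmegaConf

# oc.select resolves the referenced key if it exists, else the fallback after
# the comma -- this is how ref.yaml mirrors actor settings with sane defaults.
cfg = OmegaConf.create(
    """
    actor_rollout_ref:
      actor:
        use_dynamic_bsz: true
      ref:
        log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
        log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
    """
)
print(cfg.actor_rollout_ref.ref.log_prob_use_dynamic_bsz)        # True  (inherited from actor)
print(cfg.actor_rollout_ref.ref.log_prob_max_token_len_per_gpu)  # 16384 (fallback; actor key absent)
```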
  {
    "path": "verl_distillation/verl/trainer/config/reward_model/dp_reward_model.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml\n  - reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: fsdp\n\nmodel:\n\n  # Whether to use shared memory for loading the model\n  use_shm: False\n\n  # Use remove padding optimization (saves compute)\n  use_remove_padding: False\n\n  # Whether to use fused reward kernels for speedup\n  use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n\n  # FSDP-specific config\n  fsdp_config:\n\n    # Target configuration dataclass\n    _target_: verl.workers.config.FSDPEngineConfig\n\n    # Policy for wrapping layers with FSDP\n    wrap_policy:\n\n      # Minimum number of parameters to trigger wrapping\n      min_num_params: 0\n\n    # Whether to offload model parameters to CPU\n    param_offload: False\n\n    # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n    reshard_after_forward: True\n\n    # Number of GPUs in each FSDP shard group; -1 means auto\n    fsdp_size: -1\n\n    # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n    # before the current forward computation.\n    forward_prefetch: False\n\n# Sequence parallelism size for Ulysses-style model parallelism\nulysses_sequence_parallel_size: 1"
  },
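The `defaults` list at the top of this file is standard Hydra composition: the base `reward_model` config is loaded first, then `_self_` layers the current file's fields (e.g. `strategy: fsdp`) on top. A small sketch of that mechanism using throwaway configs in a temp directory (file contents here are illustrative, not verl's actual layout):

```python
import pathlib
import tempfile

from hydra import compose, initialize_config_dir

# Two toy configs reproducing the base-then-_self_ override pattern.
root = pathlib.Path(tempfile.mkdtemp())
(root / "reward_model.yaml").write_text("strategy: ???\nmodel:\n  path: base-rm\n")
(root / "dp_reward_model.yaml").write_text(
    "defaults:\n  - reward_model\n  - _self_\n\nstrategy: fsdp\n"
)

with initialize_config_dir(config_dir=str(root), version_base=None):
    cfg = compose(config_name="dp_reward_model")

print(cfg.strategy)    # fsdp    (the ??? placeholder is overridden by _self_)
print(cfg.model.path)  # base-rm (inherited from the base config)
```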
  {
    "path": "verl_distillation/verl/trainer/config/reward_model/megatron_reward_model.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml\n  - reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\n# seconds, default is 10 minutes for torch, you can set it to a larger value\n# if you have long-running operations like 32B or 72B model using megatron\nnccl_timeout: 600\n\n# Megatron parallelism & checkpointing config\nmegatron:\n\n  # Target configuration dataclass\n  _target_: verl.workers.config.MegatronEngineConfig\n\n  # Whether to offload model parameters to CPU\n  param_offload: False\n\n  # Number of GPUs in tensor model parallel group\n  tensor_model_parallel_size: 1\n\n  # Number of GPUs in expert model parallel group\n  expert_model_parallel_size: 1\n\n  # Expert tensor parallel size\n  expert_tensor_parallel_size: 1\n\n  # Number of pipeline model parallel stages\n  pipeline_model_parallel_size: 1\n\n  # change VPP interface for parallelism tests\n  virtual_pipeline_model_parallel_size: null\n\n  # Context parallel size\n  context_parallel_size: 1\n\n  # Whether to use sequence parallelism\n  sequence_parallel: True\n\n  # Whether to use distributed optimizer\n  use_distributed_optimizer: False\n\n  # Whether to enable distributed checkpointing\n  use_dist_checkpointing: False\n\n  # Path for distributed checkpoints\n  dist_checkpointing_path: null\n\n  # RNG seed for megatron\n  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n\n  # Any overrides to transformer config\n  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n\n  # Whether to use mbridge for faster comms\n  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n\n# Whether to load weights (default True)\nload_weight: True"
  },
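The parallelism sizes above multiply together to carve up the world size; whatever is left over becomes data parallelism. A sketch of that standard Megatron bookkeeping (our helper, not a verl API):

```python
def megatron_dp_size(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    """Data-parallel size implied by tensor/pipeline/context parallel sizes.

    Illustrative arithmetic only; verl/Megatron compute this internally.
    """
    model_parallel = tp * pp * cp
    assert world_size % model_parallel == 0, "world size must be divisible by tp * pp * cp"
    return world_size // model_parallel


# With the defaults above (tp = pp = cp = 1), all 16 GPUs do data parallelism;
# raising the TP and PP sizes to 2 each leaves 4-way data parallelism.
print(megatron_dp_size(16))              # 16
print(megatron_dp_size(16, tp=2, pp=2))  # 4
```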
  {
    "path": "verl_distillation/verl/trainer/config/reward_model/reward_model.yaml",
    "content": "# configs for the reward model\n\n# Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions.\n# In GSM8K and Math examples, we disable reward model.\n# For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses.\n# If False, the following parameters are not effective\nenable: False\n\n# Whether to deploy the model to a separate resource pool.\n# If true, n_gpus_per_node & nnodes will be used to determine the resource node.\nenable_resource_pool: False\nn_gpus_per_node: 0\nnnodes: 0\n\n# FSDP strategy: \"fsdp\" or \"fsdp2\"\nstrategy: ???\n\n# model config for reward scoring\nmodel:\n\n  # Input tokenizer. If the reward model's chat template is inconsistent with the policy,\n  # we need to first decode to plaintext, then apply the rm's chat_template.\n  # Then score with RM. If chat_templates are consistent, it can be set to null.\n  # set this to null if the chat template is identical\n  input_tokenizer: ${actor_rollout_ref.model.path}\n\n  # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification.\n  # Other model types need to define their own RewardModelWorker and pass it from the code.\n  path: ~/models/FsfairX-LLaMA3-RM-v0.1\n\n  # External model implementation (optional)\n  external_lib: ${actor_rollout_ref.model.external_lib}\n\n  # Whether to enable loading a remote code model, default to False\n  trust_remote_code: False\n\n# [Deprecated] Global micro batch size\n# will be deprecated, use micro_batch_size_per_gpu\nmicro_batch_size: null\n\n# Local per-GPU micro batch size\nmicro_batch_size_per_gpu: null\n\n# Maximum sequence length to process for scoring\nmax_length: null\n\n# Whether to dynamically adjust batch size at runtime\nuse_dynamic_bsz: ${critic.use_dynamic_bsz}\n\n# Maximum number of tokens per GPU in one forward pass\nforward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n\n# Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.\n# Default is naive. If all verification functions are multiprocessing-safe,\n# the reward manager can be set to prime for parallel verification.\nreward_manager: naive\n\n# Whether to launch custom reward function asynchronously during log_prob\n# custom reward function executed async on CPU, during log_prob\nlaunch_reward_fn_async: False\n\n# Cloud/local sandbox fusion configuration for custom reward logic\nsandbox_fusion:\n\n  # Cloud /local function URL for sandbox execution\n  url: null\n\n  # Max concurrent requests allowed to sandbox\n  max_concurrent: 64\n\n  # Max memory limit for each sandbox process in MB\n  memory_limit_mb: 1024\n\n# profile the reward model in `compute_reward` \nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # profiler tool, default same as profiler.tool in global config\n  # choices: nsys, npu, torch\n  tool: ${oc.select:global_profiler.tool,null}\n\n  # whether enable profile on ref\n  enable: False\n  \n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: []\n\n  # profile results saving path\n  save_path: ${oc.select:global_profiler.save_path,null}\n\n  # specific tool config\n  tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}"
  },
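The `input_tokenizer` field encodes the re-tokenization path described in its comment: when the RM's chat template differs from the policy's, responses are decoded back to plaintext and re-encoded for the RM. A hedged sketch of that flow (the helper name and message structure are ours; verl's RewardModelWorker implements this internally):

```python
# Sketch only: decode with the policy tokenizer, re-apply the RM's chat
# template, and re-encode so the sequence-classification RM can score it.
def reencode_for_rm(response_ids, policy_tokenizer, rm_tokenizer):
    text = policy_tokenizer.decode(response_ids, skip_special_tokens=True)
    messages = [{"role": "assistant", "content": text}]
    chat = rm_tokenizer.apply_chat_template(messages, tokenize=False)
    return rm_tokenizer(chat, return_tensors="pt").input_ids
```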
  {
    "path": "verl_distillation/verl/trainer/config/rollout/rollout.yaml",
    "content": "# Target class for this configuration\n_target_: verl.workers.config.RolloutConfig\n\n# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future\nname: ???\n\n# sync: LLM, async: AsyncLLM\nmode: sync\n\n# Sampling temperature for rollout.\ntemperature: 1.0\n\n# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\ntop_k: -1\n\n# Top-p sampling parameter. Default 1.0.\ntop_p: 1\n\n# typically the same as data max prompt length\n# same as data.max_prompt_length if it exists\nprompt_length: ${oc.select:data.max_prompt_length,512}\n\n# typically the same as data max response length\n# same as data.max_response_length if it exists\nresponse_length: ${oc.select:data.max_response_length,512}\n\n# for vllm rollout\n# Rollout model parameters type. Align with actor model's FSDP/Megatron type.\ndtype: bfloat16\n\n# Fraction of GPU memory used by vLLM/SGLang for KV cache.\ngpu_memory_utilization: 0.5\n\n# Whether to ignore EOS and continue generating after EOS is hit.\nignore_eos: False\n\n# Whether to disable CUDA graph. Default False to best performance.\nenforce_eager: False\n\n# batch size of cudagraph to capture. Require enforce_eager: False to use this option\n# Since cudagraph in inference engine can not be offloaded during update policy,\n# you can use smaller batch size to save memory used in cuda graph, eg: [1 ,2, 4, 8, 16, 32]\n# supported engines: vllm\ncudagraph_capture_sizes: null\n\n# Whether to free engine KVCache after generation.\nfree_cache_engine: True\n\n# TP size for rollout. Not effective for hf\ntensor_model_parallel_size: 2\n\n# DP size for rollout\ndata_parallel_size: 1\n\n# EP size for rollout\nexpert_parallel_size: 1\n\n# PP size for rollout.\npipeline_model_parallel_size: 1\n\n# max number of tokens in a batch\nmax_num_batched_tokens: 8192\n\n# max length for rollout\nmax_model_len: null\n\n# max length of sequences\nmax_num_seqs: 1024\n\n# may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len.\nenable_chunked_prefill: True\n\n# Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations.\nenable_prefix_caching: True\n\n# Which loader to use for rollout model weights: dummy, hf, megatron, etc.\n# safetensors (for huge model, and set use_shm=True); dummy: randomly init model weight\nload_format: dummy\n\n# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.\nlog_prob_micro_batch_size: null\n\n# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\nlog_prob_micro_batch_size_per_gpu: null\n\n# enable dynamic batch size (sequence packing) for log_prob computation\n# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false\nlog_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# max token length for log_prob computation\n# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384\nlog_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n\n# disable logging statistics\ndisable_log_stats: True\n\n# for hf rollout\n# Whether to sample during training rollout. False uses greedy sampling.\ndo_sample: True\n\n# number of responses (i.e. num sample times). 
> 1 for GRPO\nn: 1\n\n# The over_sample_rate parameter controls the early termination threshold for training rollouts,\n# where the system will abort the remaining requests when (1 - over_sample_rate) * total_requests completions are reached.\nover_sample_rate: 0\n\n# Whether to wake up the inference engine in multiple stages for SGLang\n# to reduce peak memory during the training-rollout transition.\n# This is only effective for SGLang rollout.\nmulti_stage_wake_up: false\n\n# Extra inference engine arguments (vllm, sglang); please refer to the vllm/sglang official docs for details\nengine_kwargs:\n\n  # vllm engine config\n  vllm: {}\n\n  # sglang engine config\n  sglang: {}\n\n# Sampling parameters used during validation.\nval_kwargs:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.SamplingConfig\n\n  # sampling parameters for validation\n  # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n  top_k: -1\n\n  # Top-p sampling parameter. Default 1.0.\n  top_p: 1.0\n\n  # Sampling temperature for rollout.\n  temperature: 0\n\n  # whether to repeat n times for validation\n  n: 1\n\n  # Whether to sample during training rollout. False uses greedy sampling.\n  do_sample: False\n\n# Multi-turn interaction config for tools or chat.\nmulti_turn:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.MultiTurnConfig\n\n  # set to True for multi-turn tool interaction tasks; rollout.name should be set to sglang as well\n  enable: False\n\n  # null for no limit (default max_length // 3)\n  max_assistant_turns: null\n\n  # null for no tool\n  tool_config_path: null\n\n  # null for no limit (default max_length // 3)\n  max_user_turns: null\n\n  # max parallel calls for tools in a single turn\n  max_parallel_calls: 1\n\n  # max length of tool response\n  max_tool_response_length: 256\n\n  # truncate side of tool response: left, middle, right\n  tool_response_truncate_side: middle\n\n  # null for no interaction\n  interaction_config_path: null\n\n  # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n  # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n  #   which may contain additional content such as reasoning content. This maintains consistency between training and rollout, but leads to longer prompts.\n  use_inference_chat_template: False\n\n  # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n  # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n  # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n  # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n  # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n  # - disable: disable tokenization sanity check\n  # - strict: enable strict tokenization sanity check (default)\n  # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n  tokenization_sanity_check_mode: strict\n\n  # Format of the multi-turn interaction.
Options: hermes, llama3_json, ...\n  format: hermes\n\n  # Number of repeat rollouts for each interaction\n  num_repeat_rollouts: null\n\n# Support logging rollout probs for debugging purposes.\n# \"Truncated importance sampling\" requires rollout log probs; set to True when turning on truncated importance sampling\ncalculate_log_probs: False\n\n# mask special tokens in the response, for on-policy distillation\nextend_vocab_start_token: null\n\n# mask the response if it contains an extend token, for on-policy distillation\nmask_response_if_have_extend_token: False\n\n# [Experimental] agent loop based rollout configs\nagent:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.AgentLoopConfig\n\n  # Number of agent loop workers\n  num_workers: 8\n\n  # default agent loop to use if `agent_name` is not set in the RL dataset\n  default_agent_loop: single_turn_agent\n\n  # custom agent loop config path, which should contain a list of configs to initialize AgentLoop instances.\n  # https://hydra.cc/docs/advanced/instantiate_objects/overview/\n  #\n  # - name: react_agent\n  #   _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop\n  #   tools: [\"get_current_temperature\"]\n  # - name: math_expression\n  #   _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n  #   min_terms: 2\n  #   max_terms: 6\n  agent_loop_config_path: null\n\n  # custom async server configs\n  custom_async_server:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n    _target_: verl.workers.config.CustomAsyncServerConfig\n\n    # Path to the custom async server implementation\n    path: null\n\n    # Class name of the custom async server class (e.g. AsyncvLLMServer)\n    name: null\n\n# Specifies the tensor bucket size (in megabytes) for batched weight updates during rollout operations.\n# This parameter controls the maximum payload size for a single weight update request.\n# Reference: https://github.com/volcengine/verl/pull/2418\n# Currently only supported in SGLang rollout implementations\n# Larger values may improve throughput but increase memory overhead\n# Detailed performance comparison:\n# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720\n# Default value (512MB) is optimized for typical GPU memory configurations\n# For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`\n# 2.
Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n# when using Tensor Parallelism (TP) >= 8.\nupdate_weights_bucket_megabytes: 512\n\n# trace rollout data\ntrace:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.workers.config.TraceConfig\n\n  # trace backend; supports mlflow, weave\n  backend: null\n\n  # whether to translate token ids to text in the output\n  token2text: False\n\n# When enabled (True), the trainer will attempt to load previously generated rollout data from the specified directory instead of computing new rollouts.\n# If no cached data is found or loading fails, new rollouts will be generated and automatically saved.\n# This feature is useful for debugging or when you want to reuse computation results across multiple runs.\nskip_rollout: False\n\n# Specifies the filesystem path where rollout data should be cached when skip_rollout is enabled.\n# Note: a path under /tmp/ray/session* is not recommended, as these are temporary Ray cluster directories.\nskip_dump_dir: /tmp/rollout_dump\n\n# Whether to skip tokenizer initialization for the rollout engine\n# When enabled (True), the rollout assumes token-in, token-out generation\nskip_tokenizer_init: True\n\n# profile the rollout model in `generate_sequence`\nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # profiler tool, default same as profiler.tool in global config\n  # choices: nsys, npu, torch\n  tool: ${oc.select:global_profiler.tool,null}\n\n  # whether to enable profiling on the rollout\n  enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}\n\n  # Whether to profile all ranks.\n  all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}\n\n  # profile results saving path\n  save_path: ${oc.select:global_profiler.save_path,null}\n\n  # specific tool config\n  tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}\n"
  },
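One knob above rewards a worked example: `over_sample_rate` sets an early-termination point for straggler requests, and per the comment the arithmetic is just the following (our helper, not verl's implementation):

```python
import math

def completion_threshold(total_requests: int, over_sample_rate: float) -> int:
    """Number of completions after which the remaining requests are aborted:
    (1 - over_sample_rate) * total_requests, per the rollout.yaml comment."""
    return math.ceil((1.0 - over_sample_rate) * total_requests)

print(completion_threshold(512, 0.0))    # 512 -> wait for every request (default)
print(completion_threshold(512, 0.125))  # 448 -> drop the slowest 64 stragglers
```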
  {
    "path": "verl_distillation/verl/trainer/config/sft_trainer.yaml",
    "content": "defaults:\n  - optim: fsdp\n  - _self_\n\ndata:\n  train_batch_size: 256\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: 4  # this is also val batch size\n  train_files: ~/data/gsm8k/train.parquet\n  val_files: ~/data/gsm8k/test.parquet\n  train_max_samples: -1  # set to -1 to use full dataset\n  val_max_samples: -1  # set to -1 to use full dataset\n  # Single-turn settings\n  prompt_key: question\n  response_key: answer\n  prompt_dict_keys: null\n  response_dict_keys: null\n  # Multi-turn settings\n  multiturn:\n    enable: false  # Set to true to use multi-turn dataset\n    messages_key: messages  # Key for messages list in multi-turn mode\n    tools_key: tools  # Key for tools list in multi-turn mode\n    enable_thinking_key: enable_thinking  # Whether to enable thinking in multi-turn mode\n  max_length: 1024\n  truncation: error\n  balance_dp_token: False\n  chat_template: null\n  custom_cls:\n    path: null\n    name: null\n  use_shm: False\n  apply_chat_template_kwargs: {}\nmodel:\n  partial_pretrain: ~/models/gemma-1.1-7b-it\n  use_shm: False\n  fsdp_config:\n    model_dtype: fp32\n    wrap_policy:\n      min_num_params: 0\n    cpu_offload: False\n    offload_params: False\n  external_lib: null\n  enable_gradient_checkpointing: True\n  trust_remote_code: False\n  lora_rank: 0  # Set to positive value to enable LoRA (e.g., 32)\n  lora_alpha: 16  # LoRA scaling factor\n  target_modules: all-linear  # Target modules for LoRA adaptation\n  use_liger: False\n  strategy: fsdp2\noptim:\n  lr: 1e-5\n  betas: [0.9, 0.95]\n  weight_decay: 0.01\n  lr_warmup_steps_ratio: 0.1\n  clip_grad: 1.0\n  lr_scheduler: cosine\nulysses_sequence_parallel_size: 1\nuse_remove_padding: False\ntrainer:\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  default_hdfs_dir: null\n  project_name: gsm8k-sft\n  experiment_name: test\n  total_epochs: 4\n  total_training_steps: null\n  logger: [ 'console', 'wandb' ]\n  seed: 1\n  save_freq: -1\n  test_freq: -1\n  nnodes: 1\n  n_gpus_per_node: 8\n  max_ckpt_to_keep: null  # Maximum number of checkpoints to keep, set to null to keep all\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (used when resume_mode is \"resume_path\" or \"auto\")\n  resume_from_path: null\n\n  # Checkpoint configuration\n  checkpoint:\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: [\"model\", \"optimizer\", \"extra\"]\n\n    # For more flexibility, you can specify the contents to load from the checkpoint.\n    load_contents: ${trainer.checkpoint.save_contents}\n  device: cuda\n"
  },
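The three batch-size fields in `data` interact: the global `train_batch_size` is divided by the data-parallel size, and the per-rank share is split into micro batches of `micro_batch_size_per_gpu`, with the quotient acting as gradient-accumulation steps. A sketch of that bookkeeping (mirroring `_normalize_config_bsz` in the trainer below; the function name is ours):

```python
def grad_accum_steps(train_batch_size: int, micro_batch_size_per_gpu: int, dp_size: int) -> int:
    """Micro-batch steps per optimizer step implied by the SFT config."""
    assert train_batch_size % dp_size == 0, "global batch must divide by dp size"
    per_rank = train_batch_size // dp_size
    assert per_rank % micro_batch_size_per_gpu == 0, "per-rank batch must divide into micro batches"
    return per_rank // micro_batch_size_per_gpu

# Defaults above on one 8-GPU node: 256 / 8 = 32 samples per rank,
# processed as 8 micro batches of 4 before each optimizer step.
print(grad_accum_steps(256, 4, 8))  # 8
```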
  {
    "path": "verl_distillation/verl/trainer/config/sft_trainer_engine.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# <folder_name>@<field_name>.<field_name>: <yaml_file_name>\n\ndefaults:\n  - model@model: hf_model\n  - engine@engine: fsdp\n  - optim@optim: fsdp\n  - _self_\n\ndata:\n  train_batch_size: 256 # global batch size\n  micro_batch_size_per_gpu: 4  # this is also val batch size\n  max_token_len_per_gpu: 8192\n  use_dynamic_bsz: True\n  train_files: ~/data/gsm8k/train.parquet\n  val_files: null\n  train_max_samples: -1  # set to -1 to use full dataset\n  val_max_samples: -1  # set to -1 to use full dataset\n  # Multi-turn settings\n  messages_key: messages  # Key for messages list in multi-turn mode\n  tools_key: tools  # Key for tools list in multi-turn mode\n  enable_thinking_key: enable_thinking  # Whether to enable thinking in multi-turn mode\n  pad_mode: no_padding\n  # for right padding\n  max_length: 1024\n  truncation: error\n  balance_dp_token: False # to be implement\n  custom_cls:\n    path: null\n    name: null\n  use_shm: False\n  apply_chat_template_kwargs: {}\n\n# Checkpoint configuration\ncheckpoint:\n  _target_: verl.trainer.config.CheckpointConfig\n  # What to include in saved checkpoints\n  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n  save_contents: [\"model\", \"optimizer\", \"extra\"]\n\n  # For more flexibility, you can specify the contents to load from the checkpoint.\n  load_contents: ${checkpoint.save_contents}\n\ntrainer:\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  default_hdfs_dir: null\n  project_name: gsm8k-sft\n  experiment_name: test\n  total_epochs: 4\n  total_training_steps: null\n  logger: [ 'console', 'wandb' ]\n  seed: 1\n  save_freq: -1\n  test_freq: -1\n  max_ckpt_to_keep: null  # Maximum number of checkpoints to keep, set to null to keep all\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (used when resume_mode is \"resume_path\" or \"auto\")\n  resume_from_path: null  \n  device: cuda\n"
  },
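With `use_dynamic_bsz: True`, micro batches are formed by a token budget (`max_token_len_per_gpu`) rather than a fixed sample count, so short sequences pack together and long ones travel alone. An illustrative greedy packer showing the idea (a sketch only; verl's real implementation also balances load across ranks):

```python
def pack_by_token_budget(seq_lens, max_token_len_per_gpu=8192):
    """Greedily group sequence indices so each group stays under the budget.
    Sketch of the dynamic-batch-size idea, not verl's actual algorithm."""
    batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seq_lens):
        if current and current_tokens + n_tokens > max_token_len_per_gpu:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches

print(pack_by_token_budget([4096, 4096, 2048, 8192]))  # [[0, 1], [2], [3]]
```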
  {
    "path": "verl_distillation/verl/trainer/constants_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\n\nfrom ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR\n\nPPO_RAY_RUNTIME_ENV = {\n    \"env_vars\": {\n        \"TOKENIZERS_PARALLELISM\": \"true\",\n        \"NCCL_DEBUG\": \"WARN\",\n        \"VLLM_LOGGING_LEVEL\": \"WARN\",\n        \"VLLM_ALLOW_RUNTIME_LORA_UPDATING\": \"true\",\n        \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",\n        # To prevent hanging or crash during synchronization of weights between actor and rollout\n        # in disaggregated mode. See:\n        # https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues\n        # https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445\n        \"NCCL_CUMEM_ENABLE\": \"0\",\n    },\n}\n\n\ndef get_ppo_ray_runtime_env():\n    \"\"\"\n    A filter function to return the PPO Ray runtime environment.\n    To avoid repeat of some environment variables that are already set.\n    \"\"\"\n    working_dir = (\n        json.loads(os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR, \"{}\")).get(\"runtime_env\", {}).get(\"working_dir\", None)\n    )\n\n    runtime_env = {\n        \"env_vars\": PPO_RAY_RUNTIME_ENV[\"env_vars\"].copy(),\n        **({\"working_dir\": None} if working_dir is None else {}),\n    }\n    for key in list(runtime_env[\"env_vars\"].keys()):\n        if os.environ.get(key) is not None:\n            runtime_env[\"env_vars\"].pop(key, None)\n    return runtime_env\n"
  },
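A quick usage sketch of the filtering behavior above (assuming verl is importable; whether a given variable survives depends on your shell environment):

```python
import os

from verl.trainer.constants_ppo import get_ppo_ray_runtime_env

# Variables already set in the launching process are dropped from the returned
# runtime_env, so Ray will not override what the user configured.
os.environ["NCCL_DEBUG"] = "INFO"
runtime_env = get_ppo_ray_runtime_env()

assert "NCCL_DEBUG" not in runtime_env["env_vars"]  # user's value is respected
# Defaults remain for anything not set in the environment, e.g. on a clean shell:
# runtime_env["env_vars"]["NCCL_CUMEM_ENABLE"] == "0"
```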
  {
    "path": "verl_distillation/verl/trainer/fsdp_sft_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA lightweight one-file FSDP SFT Trainer\nTODO(zhangchi.usc1992)\n- Add calculation of mfu\n- Add validation\n\"\"\"\n\nimport os\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n\nimport logging\nimport re\nimport time\nfrom contextlib import nullcontext\n\nimport hydra\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig, OmegaConf\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.utils.data import Dataset, DistributedSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\nfrom transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel\n\nimport verl.utils.hdfs_io as hdfs_io\nfrom verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.dataset import SFTDataset\nfrom verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset\nfrom verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available\nfrom verl.utils.distributed import destroy_global_process_group, initialize_global_process_group\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    fsdp2_clip_grad_norm_,\n    fsdp2_load_full_state_dict,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n)\nfrom verl.utils.logger import log_with_rank\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.utils.py_functional import convert_to_regular_types\nfrom verl.utils.torch_dtypes import PrecisionType\nfrom verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup\nfrom verl.utils.tracking import Tracking\nfrom verl.utils.ulysses import (\n    gather_outputs_and_unpad,\n    get_ulysses_sequence_parallel_world_size,\n    ulysses_pad_and_slice_inputs,\n)\nfrom verl.workers.config.optimizer import build_optimizer\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef extract_step(path):\n    match = re.search(r\"global_step_(\\d+)\", path)\n    if match:\n        return int(match.group(1))\n    return None\n\n\nclass FSDPSFTTrainer:\n    def __init__(\n        self,\n        config,\n        device_mesh: DeviceMesh,\n        
ulysses_device_mesh: DeviceMesh,\n        tokenizer,\n        train_dataset: Dataset,\n        val_dataset: Dataset,\n    ):\n        self.config = config\n        self.device_mesh = device_mesh\n        self.ulysses_device_mesh = ulysses_device_mesh\n        self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self.tokenizer = tokenizer\n        if self.config.data.chat_template is not None:\n            raise ValueError(\"Apply Chat template from config is not supported yet.\")\n\n        # normalize dp size\n        self._normalize_config_bsz()\n\n        # Set sequence parallel size\n        self.config.ulysses_sequence_parallel_size = getattr(self.config, \"ulysses_sequence_parallel_size\", 1)\n        self.use_remove_padding = getattr(self.config, \"use_remove_padding\", False)\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}\")\n            print(f\"Using remove padding: {self.use_remove_padding}\")\n\n        self._build_dataloader(train_dataset, val_dataset)\n\n        self.lora = self.config.model.get(\"lora_adapter_path\") is not None or self.config.model.lora_rank > 0\n\n        # Initialize resume-related variables\n        self.resume_global_step = 0\n\n        # build model\n        self._build_model_optimizer()\n\n        # Initialize checkpoint manager\n        self._init_checkpoint_manager()\n\n        self.load_checkpoint()\n\n        if self.device_mesh.get_rank() == 0:\n            print(self.config)\n        self.device_name = self.config.trainer.device\n\n    def _normalize_config_bsz(self):\n        dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Normalize batch size by dp {dp_size}\")\n\n        assert self.config.data.train_batch_size % dp_size == 0, (\n            f\"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}\"\n        )\n\n        self.config.data.train_batch_size //= dp_size\n\n        assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0\n\n    def _build_dataloader(self, train_dataset, val_dataset):\n        # build dataset\n        config = self.config\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        # build dataloader\n        # Use data parallel rank and size instead of global rank and world size\n\n        # If doing SP, we need to use the local rank and size\n        if self.config.ulysses_sequence_parallel_size > 1:\n            rank = self.ulysses_device_mesh.get_local_rank(\"dp\")\n            world_size = self.ulysses_device_mesh.size(0)\n            if self.ulysses_device_mesh.get_rank() == 0:\n                print(f\"Using SP rank {rank} and size {world_size} for data distribution\")\n                print(\"Each SP rank gets different data, but the same data WITHIN the same rank\")\n        else:\n            rank = self.device_mesh.get_rank()\n            world_size = self.device_mesh.size()\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Using FSDP rank {rank} and size {world_size} for data distribution\")\n\n        # Set pin_memory_device when pin_memory is enabled.\n        device_name = get_device_name()\n\n        self.train_sampler = DistributedSampler(\n            self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True\n 
       )\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=config.data.train_batch_size,\n            sampler=self.train_sampler,\n            num_workers=8,\n            pin_memory=True,\n            drop_last=True,\n            pin_memory_device=device_name,\n        )\n\n        self.val_sampler = DistributedSampler(\n            self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True\n        )\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=config.data.micro_batch_size_per_gpu,\n            sampler=self.val_sampler,\n            num_workers=8,\n            pin_memory=True,\n            drop_last=True,\n            pin_memory_device=device_name,\n        )\n\n    def _build_model_optimizer(self):\n        # TODO (zhangchi.usc1992):\n        # 1. support pretrain from random weights\n        # 2. support init directly from sharded weights\n        local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        log_gpu_memory_usage(\"Before model allocation\", logger=logger)\n\n        trust_remote_code = self.config.model.trust_remote_code\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n        # load config first\n        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)\n        self.model_config = config\n        if hasattr(self.model_config, \"max_position_embeddings\"):\n            self.model_config.max_position_embeddings = max(\n                self.model_config.max_position_embeddings, self.config.data.max_length\n            )\n        if self.config.ulysses_sequence_parallel_size > 1:\n            assert self.use_remove_padding, \"Sequence parallel is only supported when remove_padding is enabled\"\n\n        # This may be very large\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context():\n            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(\n                local_model_path,\n                config=config,\n                torch_dtype=torch_dtype,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:\n                from verl.models.transformers.monkey_patch import apply_monkey_patch\n\n                apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)\n\n            # Apply Liger kernel if use_liger is enabled\n            if self.config.model.get(\"use_liger\", False):\n                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n\n                _apply_liger_kernel_to_instance(model=self.model)\n\n            if self.lora:\n                self.model.enable_input_require_grads()\n\n                lora_adapter_path = self.config.model.get(\"lora_adapter_path\")\n                if 
lora_adapter_path is not None:\n                    from peft import PeftModel\n\n                    print(f\"Loading pre-trained LoRA adapter for sft from: {lora_adapter_path}\")\n\n                    local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.use_shm)\n\n                    self.model = PeftModel.from_pretrained(self.model, local_adapter_path, is_trainable=True)\n                    peft_config = self.model.peft_config[\"default\"]\n                    # Ensure task_type is TaskType enum, not string\n                    if isinstance(peft_config.task_type, str):\n                        peft_config.task_type = TaskType.CAUSAL_LM\n                else:\n                    # Convert config to regular Python types before creating PEFT model\n                    lora_config = {\n                        \"task_type\": TaskType.CAUSAL_LM,\n                        \"r\": self.config.model.lora_rank,\n                        \"lora_alpha\": self.config.model.lora_alpha,\n                        \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                        \"bias\": \"none\",\n                    }\n                    self.model = get_peft_model(self.model, LoraConfig(**lora_config))\n                self.model = self.model.to(torch_dtype)\n\n        if self.config.model.enable_gradient_checkpointing:\n            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        log_gpu_memory_usage(\"After model allocation\", logger=logger)\n\n        mixed_precision = MixedPrecision(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32\n        )\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            self.model,\n            config=self.config.model.fsdp_config.wrap_policy,\n            is_lora=self.lora,\n        )\n\n        if self.device_mesh.get_rank() == 0:\n            print(auto_wrap_policy)\n\n        if not self.config.model.fsdp_config.cpu_offload:\n            cpu_offload = None\n        else:\n            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)\n\n        fsdp_strategy = self.config.model.strategy\n        if fsdp_strategy == \"fsdp\":\n            self.fsdp_model = FSDP(\n                self.model,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=ShardingStrategy.FULL_SHARD,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                forward_prefetch=False,\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n            )\n\n            fsdp_kwargs = {\n                \"mesh\": self.device_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": True,\n            }\n            full_state = self.model.state_dict()\n            apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config)\n            
fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload)\n            self.fsdp_model = self.model\n        else:\n            raise NotImplementedError(f\"not implement {fsdp_strategy}\")\n\n        log_gpu_memory_usage(\"After FSDP wrapping\", logger=logger)\n\n        self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim)\n\n        log_gpu_memory_usage(\"After initialize optimizer\", logger=logger)\n\n        self.steps_per_epoch = len(self.train_dataloader)\n        self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs\n\n        if self.device_mesh.get_rank() == 0:\n            print(\n                f\"Number of steps/epoch {self.steps_per_epoch}, number of epochs \"\n                f\"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}\"\n            )\n\n        num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio)\n\n        if not hasattr(self.config.optim, \"lr_scheduler\") or self.config.optim.lr_scheduler == \"cosine\":\n            self.lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps\n            )\n        elif self.config.optim.lr_scheduler == \"wsd\":\n            self.lr_scheduler = get_wsd_schedule_with_warmup(\n                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps\n            )\n        else:\n            raise ValueError(f\"Unknown lr scheduler: {self.config.optim.lr_scheduler}\")\n\n    def _compute_loss_and_backward(self, batch, do_backward=True, n_micro_batches=1):\n        \"\"\"Compute loss with optional sequence parallelism and remove padding features\"\"\"\n        use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1\n\n        # Move inputs to GPU and prepare loss mask\n        input_ids = batch[\"input_ids\"].to(self.device_name)\n        attention_mask = batch[\"attention_mask\"].to(self.device_name)\n        position_ids = batch[\"position_ids\"].to(self.device_name)\n        loss_mask = batch.pop(\"loss_mask\")[:, 1:].reshape(-1).to(self.device_name)\n        loss_fct = nn.CrossEntropyLoss(reduction=\"none\")\n\n        # Context manager for sequence parallel if needed\n        context = self.sharding_manager if use_sp else nullcontext()\n        with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            if not use_sp:\n                # Standard forward pass without sequence parallel\n                labels = input_ids[:, 1:].contiguous()\n                output = self.fsdp_model(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                logits = output.logits\n\n                shift_logits = logits[..., :-1, :].contiguous()\n                shift_labels = labels.contiguous()\n                # Flatten the tokens\n                shift_logits = shift_logits.view(-1, self.model.config.vocab_size)\n                shift_labels = shift_labels.view(-1)\n                # Enable model parallelism\n                shift_labels = shift_labels.to(shift_logits.device)\n                loss = loss_fct(shift_logits, shift_labels)\n                loss = loss * loss_mask.to(loss.device)\n            else:\n                # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP 
ranks\n                # i.e., each GPU has <1 sequence, and each SP group has 1 sequence\n                # 1. All SP ranks will receive the *SAME* batch\n                # 2. Different SP groups will receive *DIFFERENT* batches\n                # This is implemented by the DistributedSampler\n\n                batch_size, seqlen = input_ids.shape\n                # Remove padding\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # Unpad position_ids to align rotary\n                position_ids_rmpad = index_first_axis(\n                    rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                ).transpose(0, 1)\n\n                # Pad and slice inputs for sequence parallelism\n                input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n                )\n                # For computing loss\n                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()\n                )\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n\n                # Forward pass\n                output = self.fsdp_model(\n                    input_ids=input_ids_rmpad_sliced,\n                    attention_mask=None,  # Not needed with flash attention varlen\n                    position_ids=position_ids_rmpad_padded,\n                    use_cache=False,\n                )\n\n                # Compute loss locally then aggregate\n                logits_rmpad = output.logits.squeeze(0)\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)\n                loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)\n                # Gather and unpad for sequence parallelism\n                loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n                # This is the loss collected from all ulysses ranks\n                full_loss = pad_input(\n                    hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n                )\n                full_loss = full_loss.squeeze(-1)[:, :-1]  # Remove last token's loss\n                full_loss = full_loss.reshape(-1)\n                loss_mask = loss_mask.to(full_loss.device)\n                loss = full_loss * loss_mask\n\n            valid_token_this_rank = torch.sum(loss_mask)\n\n            if self.config.data.balance_dp_token:\n                torch.distributed.all_reduce(valid_token_this_rank)\n                dp_size = self.ulysses_device_mesh.size(\"dp\") if use_sp else torch.distributed.get_world_size()\n            else:\n                dp_size = 1\n\n            loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size\n\n            loss = loss / n_micro_batches  # normalize loss\n\n            if do_backward:\n                loss.backward()\n            return loss\n\n    def training_step(self, batch: TensorDict):\n        start_time = time.time()\n\n     
   self.fsdp_model.train()\n\n        log_gpu_memory_usage(\"Before optimizer zero_grad\", logger=logger)\n\n        self.optimizer.zero_grad()\n\n        log_gpu_memory_usage(\"After optimizer zero_grad\", logger=logger)\n\n        micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)\n        n_micro_batches = len(micro_batches)\n        step_loss = 0\n        for micro_batch in micro_batches:\n            loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches)\n            step_loss += loss.item()\n\n        if self.config.model.strategy == \"fsdp\":\n            grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)\n        elif self.config.model.strategy == \"fsdp2\":\n            grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad)\n        else:\n            raise NotImplementedError(f\"not implement {self.config.model.strategy}\")\n\n        log_gpu_memory_usage(\"Before optimizer step\", logger=logger)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.optimizer.zero_grad()\n        else:\n            self.optimizer.step()\n\n        log_gpu_memory_usage(\"After optimizer step\", logger=logger)\n\n        self.lr_scheduler.step()\n\n        # reduce loss across dp ranks\n        lr = self.lr_scheduler.get_last_lr()[0]\n\n        log_gpu_memory_usage(\"After offload weights\", logger=logger)\n\n        step_loss = torch.tensor(step_loss).to(self.device_name)\n\n        # compute time spent per step\n        end_time = time.time()\n        spend_time_per_step = end_time - start_time\n\n        if is_cuda_available:\n            torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)\n        elif is_npu_available:\n            torch.distributed.all_reduce(step_loss)\n            step_loss /= self.device_mesh.size(0)\n        return {\n            \"train/loss\": step_loss.detach().item(),\n            \"train/lr(1e-3)\": lr * 1e3,\n            \"train/time(s)\": spend_time_per_step,\n        }\n\n    def validation_step(self, batch: TensorDict):\n        self.fsdp_model.eval()\n        with torch.no_grad():\n            loss = self._compute_loss_and_backward(batch, do_backward=False)\n            if is_cuda_available:\n                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)\n            elif is_npu_available:\n                torch.distributed.all_reduce(loss)\n                loss /= self.device_mesh.size(0)\n        return loss\n\n    def save_checkpoint(self, step):\n        \"\"\"Save checkpoint using FSDPCheckpointManager with improved tracking\"\"\"\n        from verl.utils.fs import local_mkdir_safe\n\n        # Determine checkpoint path\n        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f\"global_step_{step}\")\n\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Saving checkpoint to: {local_global_step_folder}\")\n\n        # Get max checkpoints to keep\n        max_ckpt_to_keep = getattr(self.config.trainer, \"max_ckpt_to_keep\", None)\n\n        # Use checkpoint manager to save\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        # Save dataloader state\n        if self.device_mesh.get_rank() 
== 0:\n            local_mkdir_safe(local_global_step_folder)\n            dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = self.train_dataloader.state_dict()\n            torch.save(dataloader_state_dict, dataloader_local_path)\n            print(f\"Saved dataloader state to: {dataloader_local_path}\")\n\n            # Update latest checkpoint tracker (atomic write)\n            tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir)\n            temp_tracker_file = tracker_file + \".tmp\"\n            with open(temp_tracker_file, \"w\") as f:\n                f.write(str(step))\n            os.rename(temp_tracker_file, tracker_file)\n            print(f\"Updated checkpoint tracker: {tracker_file}\")\n\n        # Copy to HDFS if configured\n        if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, \"default_hdfs_dir\", None):\n            hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)\n            hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)\n\n        torch.distributed.barrier()\n\n    def _init_checkpoint_manager(self):\n        \"\"\"Initialize checkpoint manager with proper configuration\"\"\"\n        # Get checkpoint configuration from config, with defaults\n        checkpoint_config = getattr(self.config.trainer, \"checkpoint\", {})\n\n        # Set default values if not specified\n        save_contents = checkpoint_config.get(\"save_contents\", [\"model\", \"optimizer\", \"extra\"])\n        load_contents = checkpoint_config.get(\"load_contents\", save_contents)\n\n        # Create checkpoint config dict\n        checkpoint_config_dict = {\n            \"load_contents\": load_contents,\n            \"save_contents\": save_contents,\n        }\n\n        # Convert to DictConfig for compatibility\n        checkpoint_config_dict = DictConfig(checkpoint_config_dict)\n\n        # Initialize checkpoint manager\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.fsdp_model,\n            optimizer=self.optimizer,\n            lr_scheduler=self.lr_scheduler,\n            processing_class=self.tokenizer,\n            checkpoint_config=checkpoint_config_dict,\n        )\n\n    def load_checkpoint(self):\n        # Determine resume path based on configuration\n        checkpoint_path = self._determine_resume_path()\n\n        if checkpoint_path is None:\n            return 0\n\n        # extract resume step from checkpoint path\n        resume_step = extract_step(checkpoint_path)\n        if resume_step is None:\n            log_with_rank(\n                f\"Warning: Could not extract step number from {checkpoint_path}, starting from step 0\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                level=logging.WARNING,\n                log_only_rank_0=True,\n            )\n            return 0\n        self.resume_global_step = resume_step\n\n        # Use checkpoint manager to load model state\n        self.checkpoint_manager.load_checkpoint(checkpoint_path)\n        log_with_rank(\n            f\"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})\",\n            logger=logger,\n            rank=self.device_mesh.get_rank(),\n            log_only_rank_0=True,\n        )\n\n        # Always load dataloader state for 
StatefulDataLoader\n        self._load_dataloader_state(checkpoint_path)\n\n        return resume_step\n\n    def _load_dataloader_state(self, checkpoint_path: str):\n        \"\"\"Load dataloader state from checkpoint\"\"\"\n        dataloader_path = os.path.join(checkpoint_path, \"data.pt\")\n\n        if os.path.exists(dataloader_path):\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = torch.load(dataloader_path, map_location=\"cpu\", weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n\n            log_with_rank(\n                f\"Successfully loaded dataloader state from {dataloader_path}\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                log_only_rank_0=True,\n            )\n\n        else:\n            log_with_rank(\n                f\"Warning: No dataloader state found at {dataloader_path}, will start from scratch\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                level=logging.WARNING,\n                log_only_rank_0=True,\n            )\n\n    def _determine_resume_path(self):\n        \"\"\"Determine the path to resume from based on resume_mode configuration\"\"\"\n        resume_mode = getattr(self.config.trainer, \"resume_mode\", \"auto\")\n        resume_from_path = getattr(self.config.trainer, \"resume_from_path\", None)\n\n        if resume_mode == \"disable\":\n            return None\n        elif resume_mode == \"auto\":\n            if resume_from_path is not None:\n                assert os.path.exists(resume_from_path), (\n                    \"resume_from_path must be null or an existing path when resume_mode is 'auto'\"\n                )\n                assert \"global_step_\" in resume_from_path, \"resume_from_path must specify the global_steps\"\n                return resume_from_path\n            # Try to find the latest checkpoint in the default directory\n            return self._find_latest_checkpoint()\n        elif resume_mode == \"resume_path\":\n            assert os.path.exists(resume_from_path), (\n                \"resume_from_path must be an existing path when resume_mode is 'resume_path'\"\n            )\n            assert \"global_step_\" in resume_from_path, \"resume_from_path must specify the global_steps\"\n            return resume_from_path\n        else:\n            raise ValueError(f\"Invalid resume_mode: {resume_mode}. 
Must be 'auto', 'disable', or 'resume_path'\")\n\n    def _find_latest_checkpoint(self):\n        \"\"\"Find the latest checkpoint in the default local directory\"\"\"\n        checkpoint_dir = self.config.trainer.default_local_dir\n\n        if not os.path.exists(checkpoint_dir):\n            return None\n\n        latest_checkpoint = find_latest_ckpt_path(checkpoint_dir)\n\n        if latest_checkpoint and self.device_mesh.get_rank() == 0:\n            step_num = extract_step(latest_checkpoint)\n            print(f\"Found latest checkpoint: {latest_checkpoint} (step {step_num})\")\n\n        return latest_checkpoint\n\n    def fit(self):\n        rank = self.device_mesh.get_rank()\n\n        # TODO: add a unified tracking\n        if rank == 0:\n            tracking = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n                config=OmegaConf.to_container(self.config, resolve=True),\n            )\n\n        global_step = self.resume_global_step  # Start from resumed step\n        last_valid_metric = None\n        # compute the total training steps;\n        # in SFT, the total number of training steps is mainly used for early exit\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        log_with_rank(\n            f\"Total training steps: {self.total_training_steps}\",\n            logger=logger,\n            rank=self.device_mesh.get_rank(),\n            log_only_rank_0=True,\n        )\n\n        # With StatefulDataLoader, we don't need to manually calculate epochs and steps;\n        # the dataloader will automatically resume from where it left off\n        if global_step > 0:\n            log_with_rank(\n                f\"StatefulDataLoader will automatically resume from global step: {global_step}\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                log_only_rank_0=True,\n            )\n\n        # Calculate which epoch we're starting from for sampler.set_epoch()\n        start_epoch = global_step // self.steps_per_epoch\n\n        train_time = 0\n        for epoch in range(start_epoch, self.config.trainer.total_epochs):\n            self.train_sampler.set_epoch(epoch=epoch)\n\n            for step_in_epoch, data in enumerate(\n                tqdm(\n                    self.train_dataloader,\n                    initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0,\n                    total=self.steps_per_epoch,\n                    desc=f\"Epoch {epoch + 1}/{self.config.trainer.total_epochs}\",\n                    disable=rank != 0,\n                )\n            ):\n                global_step += 1\n                data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name)\n                metric = self.training_step(data)\n                train_time += metric[\"train/time(s)\"]\n                if rank == 0:\n                    tracking.log(data=metric, step=global_step)\n\n                is_last_step = global_step >= self.total_training_steps\n                # Guard the modulo so a freq of 0 (feature disabled) cannot raise ZeroDivisionError\n                is_valid_step = self.config.trainer.test_freq > 0 and global_step % self.config.trainer.test_freq == 0\n                is_save_step = self.config.trainer.save_freq > 0 and global_step % 
self.config.trainer.save_freq == 0\n\n                # early exit or validation step\n                if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step):\n                    # Perform validation\n                    val_losses = []\n                    for val_data in self.val_dataloader:\n                        val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(\n                            self.device_name\n                        )\n                        val_loss = self.validation_step(val_data)\n                        val_losses.append(val_loss)\n                    if rank == 0:\n                        val_loss = torch.mean(torch.stack(val_losses))\n                        metric = {\"val/loss\": val_loss.detach().item()}\n                        tracking.log(data=metric, step=global_step)\n                        last_valid_metric = metric\n                    torch.distributed.barrier()\n\n                if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step):\n                    self.save_checkpoint(step=global_step)\n\n                if is_last_step:\n                    if rank == 0:\n                        print(f\"Total time for train steps: {train_time:.2f}s\")\n                        print(f\"Final validation metrics: {last_valid_metric}\")\n                    return\n\n\ndef run_sft(config):\n    device_name = get_device_name()\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=(\"fsdp\",))\n    dp_size = world_size // config.ulysses_sequence_parallel_size\n    ulysses_device_mesh = init_device_mesh(\n        device_type=device_name,\n        mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),\n        mesh_dim_names=(\"dp\", \"sp\"),\n    )\n    # build tokenizer and datasets first\n    from verl.utils import hf_tokenizer\n\n    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)\n    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)\n    train_dataset = create_sft_dataset(\n        config.data.train_files, config.data, tokenizer, max_samples=config.data.get(\"train_max_samples\", -1)\n    )\n    val_dataset = create_sft_dataset(\n        config.data.val_files, config.data, tokenizer, max_samples=config.data.get(\"val_max_samples\", -1)\n    )\n\n    trainer = FSDPSFTTrainer(\n        config=config,\n        device_mesh=device_mesh,\n        ulysses_device_mesh=ulysses_device_mesh,\n        tokenizer=tokenizer,\n        train_dataset=train_dataset,\n        val_dataset=val_dataset,\n    )\n\n    trainer.fit()\n\n    destroy_global_process_group()\n\n\n@hydra.main(config_path=\"config\", config_name=\"sft_trainer\", version_base=None)\ndef main(config):\n    run_sft(config)\n\n\ndef create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1):\n    \"\"\"Create a dataset.\"\"\"\n    # build dataset\n    # First check if a custom dataset class is specified\n    if data_config.custom_cls.get(\"path\", None):\n        from verl.utils.import_utils import load_extern_type\n\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n    # Then check if multi-turn dataset should be used\n    elif data_config.get(\"multiturn\", {}).get(\"enable\", False):\n        dataset_cls = MultiTurnSFTDataset\n    # Default to single-turn dataset\n    
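# Selection summary for reference: data.custom_cls.path/name -> dynamically\n    # loaded custom class; data.multiturn.enable -> MultiTurnSFTDataset;\n    # otherwise -> SFTDataset.\n    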
else:\n        dataset_cls = SFTDataset\n\n    # Create datasets based on the selected class\n    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)\n    return dataset\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/trainer/main_eval.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nOffline evaluate the performance of a generated file using reward model and ground truth verifier.\nThe input is a parquet file that contains N generated sequences and (optional) the ground truth.\n\n\"\"\"\n\nfrom collections import defaultdict\n\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport ray\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.utils.fs import copy_to_local\n\n\n@ray.remote\ndef process_item(config, data_source, response_lst, reward_data):\n    reward_fn = get_custom_reward_fn(config)\n    ground_truth = reward_data[\"ground_truth\"]\n    score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]\n    return data_source, np.mean(score_lst)\n\n\n@hydra.main(config_path=\"config\", config_name=\"evaluation\", version_base=None)\ndef main(config):\n    local_path = copy_to_local(config.data.path, use_shm=config.data.get(\"use_shm\", False))\n    dataset = pd.read_parquet(local_path)\n    responses = dataset[config.data.response_key]\n    data_sources = dataset[config.data.data_source_key]\n    reward_model_data = dataset[config.data.reward_model_key]\n\n    total = len(dataset)\n\n    # Initialize Ray\n    if not ray.is_initialized():\n        ray.init(**OmegaConf.to_container(config.ray_kwargs.get(\"ray_init\", {})))\n\n    # evaluate test_score based on data source\n    data_source_reward = defaultdict(list)\n    # Create remote tasks\n    remote_tasks = [\n        process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)\n    ]\n\n    # Process results as they come in\n    with tqdm(total=total) as pbar:\n        while len(remote_tasks) > 0:\n            # Use ray.wait to get completed tasks\n            done_ids, remote_tasks = ray.wait(remote_tasks)\n            for result_id in done_ids:\n                data_source, score = ray.get(result_id)\n                data_source_reward[data_source].append(score)\n                pbar.update(1)\n\n    metric_dict = {}\n    for data_source, rewards in data_source_reward.items():\n        metric_dict[f\"test_score/{data_source}\"] = np.mean(rewards)\n\n    print(metric_dict)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/trainer/main_generation.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nGenerate responses given a dataset of prompts\n\"\"\"\n\nimport os\n\nimport hydra\nimport numpy as np\nimport ray\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n# os.environ['TORCH_COMPILE_DISABLE'] = '1'\n\nfrom pprint import pprint\n\nimport pandas as pd\nfrom omegaconf import OmegaConf\n\nfrom verl import DataProto\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.hdfs_io import makedirs\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\n\n\n@hydra.main(config_path=\"config\", config_name=\"generation\", version_base=None)\ndef main(config):\n    run_generation(config)\n\n\ndef run_generation(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        default_runtime_env = {\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}}\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    ray.get(main_task.remote(config))\n\n\n@ray.remote(num_cpus=1)\ndef main_task(config):\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    local_path = copy_to_local(config.model.path)\n    trust_remote_code = config.data.get(\"trust_remote_code\", False)\n    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n\n    if config.rollout.temperature == 0.0:\n        assert config.data.n_samples == 1, \"When temperature=0, n_samples must be 1.\"\n    assert config.data.n_samples >= 1, \"n_samples should always >= 1\"\n\n    # read dataset. 
Note that the dataset should directly contain chat template format (e.g., a list of dictionaries)\n    dataset = pd.read_parquet(config.data.path)\n    chat_lst = dataset[config.data.prompt_key].tolist()\n\n    chat_lst = [chat.tolist() for chat in chat_lst]\n\n    tokenizer.padding_side = \"left\"\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role=\"rollout\")\n    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)\n    wg = RayWorkerGroup(\n        resource_pool=resource_pool,\n        ray_cls_with_init=ray_cls_with_init,\n        device_name=config.trainer.device,\n    )\n    wg.init_model()\n\n    total_samples = len(dataset)\n    config_batch_size = config.data.batch_size\n    apply_chat_template_kwargs = config.data.get(\"apply_chat_template_kwargs\", {})\n    num_batch = -(-total_samples // config_batch_size)  # ceiling division\n    output_lst = [[] for _ in range(config.data.n_samples)]\n\n    for batch_idx in range(num_batch):\n        print(f\"[{batch_idx + 1}/{num_batch}] Start to process.\")\n        batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size]\n        inputs = tokenizer.apply_chat_template(\n            batch_chat_lst,\n            add_generation_prompt=True,\n            padding=True,\n            truncation=True,\n            max_length=config.rollout.prompt_length,\n            return_tensors=\"pt\",\n            return_dict=True,\n            tokenize=True,\n            **apply_chat_template_kwargs,\n        )\n        input_ids = inputs[\"input_ids\"]\n        attention_mask = inputs[\"attention_mask\"]\n        position_ids = compute_position_id_with_mask(attention_mask)\n        batch_dict = {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n\n        data = DataProto.from_dict(batch_dict)\n        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)\n\n        # START TO GENERATE FOR n_samples TIMES\n        print(f\"[{batch_idx + 1}/{num_batch}] Start to generate.\")\n        for n_sample in range(config.data.n_samples):\n            output_padded = wg.generate_sequences(data_padded)\n            output = unpad_dataproto(output_padded, pad_size=pad_size)\n\n            output_texts = []\n            for i in range(len(output)):\n                data_item = output[i]\n                prompt_length = data_item.batch[\"prompts\"].shape[-1]\n                valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n                valid_response_ids = data_item.batch[\"responses\"][:valid_response_length]\n                response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n                output_texts.append(response_str)\n\n            output_lst[n_sample].extend(output_texts)\n\n    # convert output_lst from (n_samples, n_data) to (n_data, n_samples)\n    output_lst = np.array(output_lst, dtype=object)\n    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()\n\n    # add to the data frame\n    dataset[\"responses\"] = output_lst\n\n    # write to a new parquet\n    output_dir = os.path.dirname(config.data.output_path)\n    makedirs(output_dir, exist_ok=True)\n    dataset.to_parquet(config.data.output_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/trainer/main_generation_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nGenerate responses given a dataset of prompts\n\"\"\"\n\nimport os\n\nimport aiohttp\nimport hydra\nimport numpy as np\nimport ray\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n# os.environ['TORCH_COMPILE_DISABLE'] = '1'\n\nimport asyncio\nfrom pprint import pprint\n\nimport pandas as pd\nfrom omegaconf import OmegaConf\nfrom openai.types.chat import ChatCompletion\n\nfrom verl.utils.hdfs_io import makedirs\nfrom verl.workers.rollout.replica import get_rollout_replica_class\n\n\nasync def start_server(config):\n    tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size\n    num_replicas = (config.trainer.n_gpus_per_node * config.trainer.nnodes) // tp_size\n    rollout_config = config.actor_rollout_ref.rollout\n    model_config = config.actor_rollout_ref.model\n    # create standalone rollout server\n    rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name)\n    rollout_servers = [\n        rollout_server_class(\n            replica_rank=replica_rank,\n            config=rollout_config,\n            model_config=model_config,\n            gpus_per_node=config.trainer.n_gpus_per_node,\n        )\n        for replica_rank in range(num_replicas)\n    ]\n    await asyncio.gather(*[server.init_standalone() for server in rollout_servers])\n\n    server_handles = [server._server_handle for server in rollout_servers]\n    server_addresses = [server._server_address for server in rollout_servers]\n    assert len(server_handles) == num_replicas\n    assert len(server_addresses) == num_replicas\n\n    return server_handles, server_addresses\n\n\nasync def submit_request(server_address, **chat_complete_request):\n    try:\n        extra_headers = chat_complete_request.pop(\"extra_headers\", {})\n        timeout = aiohttp.ClientTimeout(total=None)\n        session = aiohttp.ClientSession(timeout=timeout)\n        async with session.post(\n            url=f\"http://{server_address}/v1/chat/completions\",\n            headers={\"Authorization\": \"Bearer token-abc123\", **extra_headers},\n            json=chat_complete_request,\n        ) as resp:\n            data = await resp.json()\n            return ChatCompletion(**data)\n    finally:\n        await session.close()\n\n\nasync def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):\n    # here we should sample n_samples for each chat_lst.\n    # we use aiohttp to avoid hang in AsyncOpenAI when the number of requests is large.\n\n    # client = AsyncOpenAI(\n    #     api_key=\"123-abc\",\n    #     base_url=f\"http://{server_address}/v1\",\n    # )\n\n    chat_complete_request = [\n        {\n            \"model\": model_path,\n            \"messages\": messages,\n            **sampling_params,\n        }\n        for messages in chat_lst\n        for _ in 
range(n_samples)\n    ]\n\n    tasks = [submit_request(server_address, **req) for req in chat_complete_request]\n    results = await asyncio.gather(*tasks)\n    return results\n\n\nasync def generate(\n    server_addresses: list, model_path: str, n_samples: int, sampling_params: dict, chat_numpy: np.ndarray\n):\n    num_replicas = len(server_addresses)\n    chat_sub_array = np.array_split(chat_numpy, num_replicas)\n    chat_sub_array = [chat.tolist() for chat in chat_sub_array]\n    assert len(server_addresses) == len(chat_sub_array)\n    results = await asyncio.gather(\n        *[\n            generate_per_replica(server_addresses[i], model_path, n_samples, sampling_params, chat_sub_array[i])\n            for i in range(num_replicas)\n        ]\n    )\n    return results\n\n\n@hydra.main(config_path=\"config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    ray.init(runtime_env={\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_USE_V1\": \"1\"}})\n\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    n_samples = config.actor_rollout_ref.rollout.n\n\n    if config.actor_rollout_ref.rollout.temperature == 0.0:\n        assert n_samples == 1, \"When temperature=0, n_samples must be 1.\"\n    assert n_samples >= 1, \"n_samples should always >= 1\"\n\n    sampling_params = {\n        \"temperature\": config.actor_rollout_ref.rollout.temperature,\n        \"top_p\": config.actor_rollout_ref.rollout.top_p,\n        # \"top_k\": config.actor_rollout_ref.rollout.top_k,\n        \"max_tokens\": config.actor_rollout_ref.rollout.response_length,\n    }\n\n    from omegaconf import ListConfig\n\n    train_files = config.data.train_files\n    if not isinstance(train_files, list | ListConfig):\n        train_files = [train_files]\n\n    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)\n\n    datasets = []\n    for train_file in train_files:\n        dataset = pd.read_parquet(train_file)\n        datasets.append(dataset)\n\n    # concat dataset\n    dataset = pd.concat(datasets, axis=0, ignore_index=True)\n    chat_lst = dataset[config.data.prompt_key].tolist()\n    chat_lst = [chat.tolist() for chat in chat_lst]\n    chat_numpy = np.array(chat_lst)\n\n    # start native server\n    server_handles, server_addresses = asyncio.run(start_server(config))\n\n    # run generate\n    gen_results = asyncio.run(\n        generate(server_addresses, config.actor_rollout_ref.model.path, n_samples, sampling_params, chat_numpy)\n    )\n\n    # reshape results into a numpy array\n    import itertools\n\n    results = list(itertools.chain.from_iterable(gen_results))\n\n    # extract content from results\n    results = np.array([result.choices[0].message.content for result in results])\n    results = np.reshape(results, (-1, n_samples))\n\n    assert results.shape == (len(chat_lst), n_samples)\n\n    results = results.tolist()\n\n    # add to the data frame\n    dataset[\"responses\"] = results\n\n    # write to a new parquet\n    output_dir = os.path.dirname(config.data.output_path)\n    makedirs(output_dir, exist_ok=True)\n    print(f\"Saving results to {config.data.output_path}\")\n    dataset.to_parquet(config.data.output_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/trainer/main_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other mpain.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.experimental.dataset.sampler import AbstractSampler\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.trainer.ppo.utils import need_critic, need_reference_policy\nfrom verl.utils.config import validate_config\nfrom verl.utils.device import is_cuda_available\nfrom verl.utils.import_utils import load_extern_type\n\n\n@hydra.main(config_path=\"config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for PPO training with Hydra configuration management.\n\n    Args:\n        config_dict: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    run_ppo(config)\n\n\n# Define a function to run the PPO-like training process\ndef run_ppo(config, task_runner_class=None) -> None:\n    \"\"\"Initialize Ray cluster and run distributed PPO training process.\n\n    Args:\n        config: Training configuration object containing all necessary parameters\n                for distributed PPO training including Ray initialization settings,\n                model paths, and training hyperparameters.\n        task_runner_class: For recipe to change TaskRunner.\n    \"\"\"\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        # Set environment variables in the runtime environment to control tokenizer parallelism,\n        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating\n        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration\n        default_runtime_env = get_ppo_ray_runtime_env()\n        ray_init_kwargs = config.ray_kwargs.get(\"ray_init\", {})\n        runtime_env_kwargs = ray_init_kwargs.get(\"runtime_env\", {})\n\n        if config.transfer_queue.enable:\n            # Add runtime environment variables for transfer queue\n            runtime_env_vars = runtime_env_kwargs.get(\"env_vars\", {})\n            runtime_env_vars[\"TRANSFER_QUEUE_ENABLE\"] = \"1\"\n            runtime_env_kwargs[\"env_vars\"] = runtime_env_vars\n\n\n        for k, v in runtime_env_kwargs['env_vars'].items():\n            if not isinstance(v, str):\n                runtime_env_kwargs['env_vars'][k] = str(v)\n\n        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)\n        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, \"runtime_env\": runtime_env})\n        print(f\"ray init kwargs: {ray_init_kwargs}\")\n        ray.init(**OmegaConf.to_container(ray_init_kwargs))\n\n    if task_runner_class is None:\n        
task_runner_class = ray.remote(num_cpus=1)(TaskRunner)  # please make sure main_task is not scheduled on head\n\n    # Create a remote instance of the TaskRunner class, and\n    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete\n    if (\n        is_cuda_available\n        and config.global_profiler.tool == \"nsys\"\n        and config.global_profiler.get(\"steps\") is not None\n        and len(config.global_profiler.get(\"steps\", [])) > 0\n    ):\n        from verl.utils.import_utils import is_nvtx_available\n\n        assert is_nvtx_available(), \"nvtx is not available in CUDA platform. Please 'pip3 install nvtx'\"\n        nsight_options = OmegaConf.to_container(\n            config.global_profiler.global_tool_config.nsys.controller_nsight_options\n        )\n        runner = task_runner_class.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = task_runner_class.remote()\n    ray.get(runner.run.remote(config))\n\n    # [Optional] get the path of the timeline trace file from the configuration, default to None\n    # This file is used for performance analysis\n    timeline_json_file = config.ray_kwargs.get(\"timeline_json_file\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\nclass TaskRunner:\n    \"\"\"Ray remote class for executing distributed PPO training tasks.\n\n    This class encapsulates the main training logic and runs as a Ray remote actor\n    to enable distributed execution across multiple nodes and GPUs.\n\n    Attributes:\n        role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes\n        mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation\n    \"\"\"\n\n    def __init__(self):\n        self.role_worker_mapping = {}\n        self.mapping = {}\n\n    def add_actor_rollout_worker(self, config):\n        \"\"\"Add actor rollout worker based on the actor strategy.\"\"\"\n        from verl.single_controller.ray import RayWorkerGroup\n\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import Role\n\n        self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)\n\n        return actor_rollout_cls, ray_worker_group_cls\n\n    def add_critic_worker(self, config):\n        \"\"\"Add critic worker to role mapping.\"\"\"\n        if config.critic.strategy in {\"fsdp\", \"fsdp2\"}:\n            use_legacy_worker_impl = config.trainer.get(\"use_legacy_worker_impl\", \"auto\")\n            if use_legacy_worker_impl in [\"auto\", \"enable\"]:\n                from 
verl.workers.fsdp_workers import CriticWorker\n            elif use_legacy_worker_impl == \"disable\":\n                from verl.workers.roles import CriticWorker\n\n                print(\"Using new worker implementation\")\n            else:\n                raise ValueError(f\"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}\")\n\n        elif config.critic.strategy == \"megatron\":\n            from verl.workers.megatron_workers import CriticWorker\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import Role\n\n        self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)\n\n    def init_resource_pool_mgr(self, config):\n        \"\"\"Initialize resource pool manager.\"\"\"\n        from verl.trainer.ppo.ray_trainer import Role\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        # TODO Here you can use the new registration method to support dynamic registration of roles\n        if config.reward_model.enable_resource_pool:\n            if config.reward_model.n_gpus_per_node <= 0:\n                raise ValueError(\"config.reward_model.n_gpus_per_node must be greater than 0\")\n            if config.reward_model.nnodes <= 0:\n                raise ValueError(\"config.reward_model.nnodes must be greater than 0\")\n\n            reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes\n            resource_pool_spec[\"reward_pool\"] = reward_pool\n\n        self.mapping[Role.ActorRollout] = global_pool_id\n        self.mapping[Role.Critic] = global_pool_id\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager\n\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)\n        return resource_pool_manager\n\n    def add_reward_model_worker(self, config):\n        \"\"\"Add reward model worker if enabled.\"\"\"\n        from verl.trainer.ppo.ray_trainer import Role\n\n        if config.reward_model.enable:\n            use_legacy_worker_impl = config.trainer.get(\"use_legacy_worker_impl\", \"auto\")\n            if use_legacy_worker_impl in [\"auto\", \"enable\"]:\n                if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                    from verl.workers.fsdp_workers import RewardModelWorker\n                elif config.reward_model.strategy == \"megatron\":\n                    from verl.workers.megatron_workers import RewardModelWorker\n                else:\n                    raise NotImplementedError\n            elif use_legacy_worker_impl == \"disable\":\n                from verl.workers.roles import RewardModelWorker\n\n                print(\"Using new worker implementation\")\n            else:\n                raise ValueError(f\"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}\")\n\n            self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            if config.reward_model.enable_resource_pool:\n                self.mapping[Role.RewardModel] = \"reward_pool\"\n            else:\n                self.mapping[Role.RewardModel] = \"global_pool\"\n\n    def add_ref_policy_worker(self, config, ref_policy_cls):\n        \"\"\"Add reference policy worker if KL loss or KL reward is used.\"\"\"\n        from verl.trainer.ppo.ray_trainer import Role\n\n        if config.algorithm.use_kl_in_reward or 
config.actor_rollout_ref.actor.use_kl_loss:\n            self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)\n            self.mapping[Role.RefPolicy] = \"global_pool\"\n\n    def run(self, config):\n        \"\"\"Execute the main PPO training workflow.\n\n        This method sets up the distributed training environment, initializes\n        workers, datasets, and reward functions, then starts the training process.\n\n        Args:\n            config: Training configuration object containing all parameters needed\n                   for setting up and running the PPO training process.\n        \"\"\"\n        # Print the initial configuration. `resolve=True` will evaluate symbolic values.\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)\n        self.add_critic_worker(config)\n\n        # We should adopt a multi-source reward function here:\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # finally, we combine all the rewards together\n        # The reward type depends on the tag of the data\n        self.add_reward_model_worker(config)\n\n        # Add a reference policy worker if KL loss or KL reward is used.\n        self.add_ref_policy_worker(config, actor_rollout_cls)\n\n        # validate config\n        validate_config(\n            config=config,\n            use_reference_policy=need_reference_policy(self.role_worker_mapping),\n            use_critic=need_critic(config),\n        )\n\n        # Download the checkpoint from HDFS to the local machine.\n        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor.\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Load the reward manager for training and validation.\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n\n        resource_pool_manager = self.init_resource_pool_mgr(config)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets.\n        train_dataset = create_rl_dataset(\n            config.data.train_files,\n            config.data,\n            tokenizer,\n            processor,\n            is_train=True,\n            max_samples=config.data.get(\"train_max_samples\", -1),\n        )\n        val_dataset = 
create_rl_dataset(\n            config.data.val_files,\n            config.data,\n            tokenizer,\n            processor,\n            is_train=False,\n            max_samples=config.data.get(\"val_max_samples\", -1),\n        )\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the PPO trainer.\n        trainer = RayPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=self.role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        # Initialize the workers of the trainer.\n        trainer.init_workers()\n\n        # Start the training process.\n        trainer.fit()\n\n\ndef create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):\n    \"\"\"Create a dataset.\n\n    Arguments:\n        data_paths: List of paths to data files.\n        data_config: The data config.\n        tokenizer (Tokenizer): The tokenizer.\n        processor (Processor): The processor.\n\n    Returns:\n        dataset (Dataset): The dataset.\n    \"\"\"\n    from torch.utils.data import Dataset\n\n    from verl.utils.dataset.rl_dataset import RLHFDataset\n\n    # Check if a custom dataset class is specified in the data configuration\n    # and if the path to the custom class is provided\n    if \"custom_cls\" in data_config and data_config.custom_cls.get(\"path\", None) is not None:\n        # Dynamically load the custom dataset class\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n        # Verify that the custom dataset class inherits from torch.utils.data.Dataset\n        if not issubclass(dataset_cls, Dataset):\n            raise TypeError(\n                f\"The custom dataset class '{data_config.custom_cls.name}' from \"\n                f\"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset\"\n            )\n    elif \"datagen\" in data_config and data_config.datagen.get(\"path\", None) is not None and is_train:\n        # If a data generation strategy is specified, use the DynamicGenDataset class\n        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset\n\n        dataset_cls = DynamicGenDataset\n        print(\"Using DynamicGenDataset for data generation.\")\n    else:\n        # Use the default RLHFDataset class if no custom class is specified\n        dataset_cls = RLHFDataset\n    print(f\"Using dataset class: {dataset_cls.__name__}\")\n\n    # Instantiate the dataset using the determined dataset class\n    dataset = dataset_cls(\n        data_files=data_paths,\n        tokenizer=tokenizer,\n        processor=processor,\n        config=data_config,\n        max_samples=max_samples,\n    )\n\n    return dataset\n\n\ndef create_rl_sampler(data_config, dataset):\n    \"\"\"Create a sampler for the dataset.\n\n    Arguments:\n        data_config: The data config.\n        dataset (Dataset): The dataset.\n\n    Returns:\n        sampler (Sampler): The sampler.\n    \"\"\"\n    import torch\n    from torch.utils.data import RandomSampler, SequentialSampler\n\n    if data_config.sampler is not None and 
data_config.sampler.get(\"class_path\", None) is not None:\n        curriculum_class = load_extern_type(\n            data_config.sampler.class_path,\n            data_config.sampler.class_name,\n        )\n        sampler = curriculum_class(\n            data_source=dataset,\n            data_config=data_config,\n        )\n        assert isinstance(sampler, AbstractSampler)\n        assert data_config.get(\"dataloader_num_workers\", 8) == 0, (\n            \"If using curriculum, num_workers must be 0 to prevent data caching. \"\n            \"If the dataloader caches data before the batch is done the \"\n            \"curriculum sampler won't have the opportunity to reorder it. \"\n        )\n\n    # Use a sampler to facilitate checkpoint resumption.\n    # If shuffling is enabled in the data configuration, create a random sampler.\n    elif data_config.shuffle:\n        train_dataloader_generator = torch.Generator()\n        seed = data_config.get(\"seed\")\n        if seed is not None:\n            train_dataloader_generator.manual_seed(seed)\n        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)\n    else:\n        # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.\n        sampler = SequentialSampler(data_source=dataset)\n\n    return sampler\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/trainer/ppo/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/trainer/ppo/core_algos.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCore functions to implement PPO algorithms.\nThe function implemented in this file should be used by trainer with different distributed strategies to\nimplement PPO-like algorithms.\n\"\"\"\n\n__all__ = [\"register_adv_est\", \"get_adv_estimator_fn\", \"AdvantageEstimator\"]\n\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Any, Callable, Optional\n\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.trainer.config import AlgoConfig\nfrom verl.utils import as_torch_index, group_mean_std\nfrom verl.utils.import_utils import deprecated\nfrom verl.workers.config import ActorConfig\n\nPolicyLossFn = Callable[\n    [\n        torch.Tensor,  # old_log_prob\n        torch.Tensor,  # log_prob\n        torch.Tensor,  # advantages\n        torch.Tensor,  # response_mask\n        str,  # loss_agg_mode\n        Optional[DictConfig | AlgoConfig],  # config\n        torch.Tensor | None,  # rollout_log_probs\n    ],\n    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],\n]\n\nPOLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}\n\n\ndef register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:\n    \"\"\"Register a policy loss function with the given name.\n\n    Args:\n        name (str): The name to register the policy loss function under.\n\n    Returns:\n        function: Decorator function that registers the policy loss function.\n    \"\"\"\n\n    def decorator(func: PolicyLossFn) -> PolicyLossFn:\n        POLICY_LOSS_REGISTRY[name] = func\n        return func\n\n    return decorator\n\n\ndef get_policy_loss_fn(name):\n    \"\"\"Get the policy loss with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the policy loss.\n\n    Returns:\n        `(callable)`: The policy loss function.\n    \"\"\"\n    loss_name = name\n    if loss_name not in POLICY_LOSS_REGISTRY:\n        raise ValueError(\n            f\"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}\"\n        )\n    return POLICY_LOSS_REGISTRY[loss_name]\n\n\nclass AdvantageEstimator(str, Enum):\n    \"\"\"Using an enumeration class to avoid spelling errors in adv_estimator.\n\n    Note(haibin.lin): this enum class is immutable after creation. 
Extending this\n    enum for new estimators may not be necessary since users can always just call\n    `verl.trainer.ppo.core_algos.register` with string name for a custom advantage\n    estimator instead.\n    \"\"\"\n\n    GAE = \"gae\"\n    GRPO = \"grpo\"\n    REINFORCE_PLUS_PLUS = \"reinforce_plus_plus\"\n    REINFORCE_PLUS_PLUS_BASELINE = \"reinforce_plus_plus_baseline\"\n    REMAX = \"remax\"\n    RLOO = \"rloo\"\n    OPO = \"opo\"\n    GRPO_PASSK = \"grpo_passk\"\n    GPG = \"gpg\"\n    RLOO_VECTORIZED = \"rloo_vectorized\"\n    GRPO_VECTORIZED = \"grpo_vectorized\"\n    ON_POLICY_DISTILL = \"on_policy_distill\"\n\n\nADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}\n\n\ndef register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:\n    \"\"\"Decorator to register an advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    \"\"\"\n\n    def decorator(fn):\n        name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n        if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:\n            raise ValueError(\n                f\"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}\"\n            )\n        ADV_ESTIMATOR_REGISTRY[name] = fn\n        return fn\n\n    return decorator\n\n\ndef get_adv_estimator_fn(name_or_enum):\n    \"\"\"Get the advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    Returns:\n        `(callable)`: The advantage estimator function.\n    \"\"\"\n    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n    if name not in ADV_ESTIMATOR_REGISTRY:\n        raise ValueError(f\"Unknown advantage estimator: {name}\")\n    return ADV_ESTIMATOR_REGISTRY[name]\n\n\nclass AdaptiveKLController:\n    \"\"\"\n    Adaptive KL controller described in the paper:\n    https://arxiv.org/pdf/1909.08593.pdf\n    \"\"\"\n\n    def __init__(self, init_kl_coef, target_kl, horizon):\n        self.value = init_kl_coef\n        self.target = target_kl\n        self.horizon = horizon\n\n    def update(self, current_kl, n_steps):\n        \"\"\"Update the KL coefficient based on current KL divergence.\n\n        Args:\n            current_kl (float): Current KL divergence value.\n            n_steps (int): Number of steps taken.\n        \"\"\"\n        target = self.target\n        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current_kl, n_steps):\n        \"\"\"Update method for fixed KL controller (no-op).\n\n        Args:\n            current_kl (float): Current KL divergence value (unused).\n            n_steps (int): Number of steps taken (unused).\n        \"\"\"\n        pass\n\n\ndef get_kl_controller(kl_ctrl):\n    \"\"\"Factory function to create appropriate KL controller based on configuration.\n\n    Args:\n        kl_ctrl: Configuration object containing KL controller settings.\n\n    Returns:\n        KL controller instance (FixedKLController or AdaptiveKLController).\n\n    Raises:\n        NotImplementedError: If controller type is not 
supported.\n        AssertionError: If adaptive controller horizon is not positive.\n    \"\"\"\n    if kl_ctrl.type == \"fixed\":\n        return FixedKLController(kl_coef=kl_ctrl.kl_coef)\n    elif kl_ctrl.type == \"adaptive\":\n        assert kl_ctrl.horizon > 0, f\"horizon must be larger than 0. Got {kl_ctrl.horizon}\"\n        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)\n    else:\n        raise NotImplementedError\n\n\n@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est(\"gae\")\ndef compute_gae_advantage_return(\n    token_level_rewards: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    gamma: torch.Tensor,\n    lam: torch.Tensor,\n):\n    \"\"\"Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        values: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.\n        gamma: `(float)`\n            discount factor used in RL\n        lam: `(float)`\n            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n\n    \"\"\"\n    with torch.no_grad():\n        nextvalues = 0\n        lastgaelam = 0\n        advantages_reversed = []\n        gen_len = token_level_rewards.shape[-1]\n\n        for t in reversed(range(gen_len)):\n            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]\n            lastgaelam_ = delta + gamma * lam * lastgaelam\n\n            # skip values and TD-error on observation tokens\n            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues\n            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam\n\n            advantages_reversed.append(lastgaelam)\n        advantages = torch.stack(advantages_reversed[::-1], dim=1)\n\n        returns = advantages + values\n        advantages = verl_F.masked_whiten(advantages, response_mask)\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.ON_POLICY_DISTILL)\ndef compute_on_policy_distill_reverse_kl(\n    teacher_log_prob: torch.Tensor,\n    student_log_prob: torch.Tensor,\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Token-level advantage for on-policy distillation.\n\n    Rewards the student with the negative per-token reverse KL, estimated on the\n    student's own samples as -(log pi_student - log pi_teacher), i.e.\n    log pi_teacher - log pi_student. The same tensor is returned as both\n    advantages and returns.\n\n    Args:\n        teacher_log_prob: `(torch.Tensor)`\n            shape is (bs, response_length)\n        student_log_prob: `(torch.Tensor)`\n            shape is (bs, response_length)\n        config: `(Optional[AlgoConfig])`\n            algorithm configuration object (unused)\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    reverse_kl = student_log_prob - teacher_log_prob\n    return -reverse_kl, -reverse_kl\n\n\n
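# A minimal usage sketch of the estimator above (illustrative tensors only, not\n# part of the training path):\n#\n#     teacher_lp = torch.tensor([[-0.5, -1.0]])\n#     student_lp = torch.tensor([[-0.7, -0.9]])\n#     adv, ret = compute_on_policy_distill_reverse_kl(teacher_lp, student_lp)\n#     # adv == teacher_lp - student_lp == [[0.2, -0.1]]: positive where the\n#     # teacher assigns the sampled token a higher log-prob than the student.\n\n\n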
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.\n@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est(\"grpo\")\ndef compute_grpo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for GRPO, operating only on Outcome reward\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length)\n        index: `(np.ndarray)`\n            index array for grouping\n        epsilon: `(float)`\n            small value to avoid division by zero\n        norm_adv_by_std_in_grpo: `(bool)`\n            whether to scale the GRPO advantage\n        config: `(Optional[AlgoConfig])`\n            algorithm configuration object\n\n    Note:\n        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.\n        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape is (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape is (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                scores_tensor = torch.stack(id2score[idx])\n                id2mean[idx] = torch.mean(scores_tensor)\n                id2std[idx] = torch.std(scores_tensor)\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            if norm_adv_by_std_in_grpo:\n                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)\n            else:\n                scores[i] = scores[i] - id2mean[index[i]]\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)\ndef compute_grpo_vectorized_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Vectorized GRPO (outcome-only):\n      For each group g:\n      a_i = \\\\frac{r_i - \\\\mu_g}{\\\\sigma_g} (or without dividing by \\\\sigma_g),\n      then broadcast the scalar across the token dimension (multiplied by response_mask).\n    \"\"\"\n    with torch.no_grad():\n        scores = token_level_rewards.sum(dim=-1)\n        g = as_torch_index(index, device=scores.device)\n        mean_g, std_g, _ = group_mean_std(scores, g, eps=epsilon)\n        if norm_adv_by_std_in_grpo:\n            scalars = (scores - mean_g[g]) / (std_g[g] + epsilon)\n        else:\n            scalars = scores - mean_g[g]\n        advantages = scalars.unsqueeze(-1) * response_mask\n        return advantages, advantages\n\n\n@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est(\"grpo_passk\")\ndef compute_grpo_passk_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for Pass@k using a GRPO-style outcome reward formulation.\n    Only the best response per group gets a non-zero advantage: r_max - r_second_max.\n\n    Implemented as 
described in https://arxiv.org/abs/2503.19595.\n\n    Args:\n        token_level_rewards: (bs, response_length)\n        response_mask: (bs, response_length)\n        index: (bs,) → group ID per sample\n        epsilon: float for numerical stability\n        config: (AlgoConfig) algorithm settings, which contains \"norm_adv_by_std_in_grpo\"\n\n    Returns:\n        advantages: (bs, response_length)\n        returns: (bs, response_length)\n    \"\"\"\n    assert config is not None\n    # if True, normalize advantage by std within group\n    norm_adv_by_std_in_grpo = config.get(\"norm_adv_by_std_in_grpo\", True)\n    scores = token_level_rewards.sum(dim=-1)  # (bs,)\n    advantages = torch.zeros_like(scores)\n\n    id2scores = defaultdict(list)\n    id2indices = defaultdict(list)\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            idx = index[i]\n            id2scores[idx].append(scores[i])\n            id2indices[idx].append(i)\n\n        for idx in id2scores:\n            rewards = torch.stack(id2scores[idx])  # (k,)\n            if rewards.numel() < 2:\n                raise ValueError(\n                    f\"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}.\"\n                )\n            topk, topk_idx = torch.topk(rewards, 2)\n            r_max, r_second_max = topk[0], topk[1]\n            i_max = id2indices[idx][topk_idx[0].item()]\n            advantage = r_max - r_second_max\n            if norm_adv_by_std_in_grpo:\n                std = torch.std(rewards)\n                advantage = advantage / (std + epsilon)\n            advantages[i_max] = advantage\n\n    advantages = advantages.unsqueeze(-1) * response_mask\n    return advantages, advantages\n\n\n@register_adv_est(\n    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE\n)  # or simply: @register_adv_est(\"reinforce_plus_plus_baseline\")\ndef compute_reinforce_plus_plus_baseline_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: torch.Tensor,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    response_length = token_level_rewards.shape[-1]\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.stack(id2score[idx]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = scores[i] - id2mean[index[i]]\n\n        scores = scores.unsqueeze(-1).tile([1, 
response_length]) * response_mask\n        scores = verl_F.masked_whiten(scores, response_mask) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est(\"rloo\")\ndef compute_rloo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.stack(id2score[idx]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            response_num = len(id2score[index[i]])\n            if response_num > 1:\n                scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (\n                    response_num - 1\n                )\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est(\"opo\")\ndef compute_opo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    response_length = response_mask.sum(dim=-1)\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2len = defaultdict(list)\n    id2bsl = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n            id2len[index[i]].append(response_length[i])\n\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2bsl[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                score_tensor = torch.stack(id2score[idx])\n                len_tensor = torch.stack(id2len[idx])\n                id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum()\n            else:\n               
 raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = scores[i] - id2bsl[index[i]]\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est(\"reinforce_plus_plus\")\ndef compute_reinforce_plus_plus_outcome_advantage(\n    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for REINFORCE++.\n    This implementation is based on the paper: https://arxiv.org/abs/2501.03262\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    assert config is not None\n    gamma = config.gamma\n    with torch.no_grad():\n        returns = torch.zeros_like(token_level_rewards)\n        running_return = 0\n\n        for t in reversed(range(token_level_rewards.shape[1])):\n            running_return = token_level_rewards[:, t] + gamma * running_return\n            returns[:, t] = running_return\n            # Reset after EOS\n            running_return = running_return * response_mask[:, t]\n\n        advantages = verl_F.masked_whiten(returns, response_mask)\n        advantages = advantages * response_mask\n\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est(\"remax\")\ndef compute_remax_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    reward_baselines: torch.Tensor,\n    response_mask: torch.Tensor,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for ReMax, operating only on Outcome reward\n    This implementation is based on the paper: https://arxiv.org/abs/2310.10505\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        reward_baselines: `(torch.Tensor)`\n            shape: (bs,)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n\n    with torch.no_grad():\n        returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])\n        advantages = returns - reward_baselines.unsqueeze(-1) * response_mask\n\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est(\"gpg\")\ndef compute_gpg_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    f_norm: float = 1.0,\n    alpha: float = 1.0,\n    config=None,\n    **kwargs,\n):\n    \"\"\"\n    Compute advantage for GPG, operating only on Outcome reward\n    (with only one scalar reward for each response).\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, 
response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        index: `(np.ndarray)`\n            shape: (bs,)\n        epsilon: (float)\n        f_norm: (float)\n        alpha: (float)\n        config: (dict) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        m = torch.count_nonzero(scores)\n        alpha = bsz / m.clamp(min=1)  # recomputed from the fraction of non-zero rewards; overrides the alpha argument\n\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                scores_tensor = torch.stack(id2score[idx])\n                id2mean[idx] = torch.mean(scores_tensor)\n                id2std[idx] = torch.std(scores_tensor)\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)  # or simply: @register_adv_est(\"rloo_vectorized\")\ndef compute_rloo_vectorized_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    with torch.no_grad():\n        inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device)\n\n        c = torch.bincount(inv)[inv].to(scores.dtype)\n        adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1)\n\n        adv = adv.unsqueeze(-1) * response_mask\n\n    return adv, adv\n\n\ndef compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):\n    \"\"\"Compute token-level rewards with KL penalty.\n\n    Args:\n        token_level_scores (torch.Tensor): Token-level reward scores.\n        old_log_prob (torch.Tensor): Log probabilities from current policy.\n        ref_log_prob (torch.Tensor): Log probabilities from reference policy.\n        kl_ratio (float): KL penalty coefficient.\n\n    Returns:\n        torch.Tensor: Token-level rewards with KL penalty applied.\n    \"\"\"\n    kl = old_log_prob - ref_log_prob\n    return token_level_scores - kl * kl_ratio\n\n\n
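# Illustrative sketch (editor's addition, not part of the verl API): the\n# loss_agg_mode choices of agg_loss (defined below) weight unequal-length\n# sequences differently: \"token-mean\" pools all valid tokens, while the\n# \"seq-mean-*\" modes first reduce each sequence and then average over the batch.\ndef _demo_agg_loss_modes():\n    loss_mat = torch.tensor([[1.0, 1.0, 1.0], [4.0, 0.0, 0.0]])\n    loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])\n    # token-mean -> 1.75, seq-mean-token-sum -> 3.5, seq-mean-token-mean -> 2.5\n    return {\n        mode: agg_loss(loss_mat, loss_mask, mode)\n        for mode in (\"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\")\n    }\n\n\ndef agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):\n    \"\"\"\n    Aggregate the loss matrix into a scalar.\n\n    Args:\n        loss_mat: 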
`(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_mask: `(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_agg_mode: (str)\n            method to aggregate the loss matrix into a scalar; one of \"token-mean\",\n            \"seq-mean-token-sum\", \"seq-mean-token-mean\", or \"seq-mean-token-sum-norm\".\n    Returns:\n        loss: `a scalar torch.Tensor`\n            aggregated loss\n    \"\"\"\n    if loss_agg_mode == \"token-mean\":\n        loss = verl_F.masked_mean(loss_mat, loss_mask)\n    elif loss_agg_mode == \"seq-mean-token-sum\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences\n        loss = verl_F.masked_mean(seq_losses, seq_mask)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-mean\":\n        seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean\n        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences\n        loss = verl_F.masked_mean(seq_losses, seq_mask)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-sum-norm\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)\n        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor\n        # (loss_mask.shape[-1]) should ideally be constant\n        # throughout training to well-replicate the DrGRPO paper.\n        # TODO: Perhaps add user-defined normalizer argument to\n        # agg_loss to ensure divisor stays constant throughout.\n    else:\n        raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n\n    return loss\n\n\n@deprecated(\"verl.trainer.ppo.core_algos.compute_policy_loss_vanilla\")\ndef compute_policy_loss(\n    old_log_prob,\n    log_prob,\n    advantages,\n    response_mask,\n    cliprange=None,\n    cliprange_low=None,\n    cliprange_high=None,\n    clip_ratio_c=3.0,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    Compute the clipped policy objective and related metrics for PPO.\n\n    Adapted from\n    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        clip_ratio_c (float, optional):\n            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.\n            Defaults to 3.0.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. 
Defaults to \"token-mean\".\n    \"\"\"\n    assert clip_ratio_c > 1.0, (\n        \"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,\"\n        + f\" but get the value: {clip_ratio_c}.\"\n    )\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability\n    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n    pg_losses2 = -advantages * torch.clamp(\n        ratio, 1 - cliprange_low, 1 + cliprange_high\n    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A\n    clip_pg_losses1 = torch.maximum(\n        pg_losses1, pg_losses2\n    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)\n    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n\n    pg_losses3 = -advantages * clip_ratio_c\n    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n    pg_clipfrac_lower = verl_F.masked_mean(\n        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask\n    )\n\n    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"vanilla\")  # type: ignore[arg-type]\ndef compute_policy_loss_vanilla(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[DictConfig | AlgoConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for PPO.\n\n    Adapted from\n    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        config: `(verl.trainer.config.ActorConfig)`:\n            config for the actor.\n        rollout_log_probs: `(torch.Tensor)`:\n            log probabilities of actions under the rollout policy, shape (batch_size, response_length).\n    \"\"\"\n\n    assert config is not None\n    assert not isinstance(config, AlgoConfig)\n    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. 
See https://arxiv.org/abs/1707.06347.\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.\n        \"clip_ratio_c\", 3.0\n    )\n\n    cliprange = clip_ratio\n    cliprange_low = clip_ratio_low\n    cliprange_high = clip_ratio_high\n\n    assert clip_ratio_c > 1.0, (\n        \"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,\"\n        + f\" but get the value: {clip_ratio_c}.\"\n    )\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability\n    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n    pg_losses2 = -advantages * torch.clamp(\n        ratio, 1 - cliprange_low, 1 + cliprange_high\n    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A\n    clip_pg_losses1 = torch.maximum(\n        pg_losses1, pg_losses2\n    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)\n    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n\n    pg_losses3 = -advantages * clip_ratio_c\n    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n    pg_clipfrac_lower = verl_F.masked_mean(\n        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask\n    )\n\n    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"gspo\")\ndef compute_policy_loss_gspo(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"seq-mean-token-mean\",\n    config: Optional[DictConfig | ActorConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for GSPO.\n\n    See https://arxiv.org/pdf/2507.18071 for more details.\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. 
For GSPO, it is recommended to use \"seq-mean-token-mean\".\n    \"\"\"\n\n    assert config is not None\n    assert isinstance(config, ActorConfig)\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio\n\n    negative_approx_kl = log_prob - old_log_prob\n\n    # compute sequence-level importance ratio:\n    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =\n    # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]\n    seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)\n    negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths\n\n    # Combined ratio at token level:\n    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]\n    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]\n    log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)\n    log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)  # clamp for numerical stability\n\n    # finally exp() to remove log\n    seq_importance_ratio = torch.exp(log_seq_importance_ratio)\n\n    pg_losses1 = -advantages * seq_importance_ratio\n    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)\n    pg_losses = torch.maximum(pg_losses1, pg_losses2)\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    # for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=\"seq-mean-token-mean\")\n\n    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)\n    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)\n\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n
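# Illustrative sketch (editor's addition, not part of the verl API): the GSPO\n# sequence-level ratio above is the geometric mean of per-token ratios, i.e.\n# exp of the mean token log-ratio over the valid response tokens.\ndef _demo_gspo_seq_ratio():\n    negative_approx_kl = torch.tensor([[0.2, -0.1, 0.5]])  # token log-ratios\n    response_mask = torch.ones(1, 3)\n    seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)\n    # exp((0.2 - 0.1 + 0.5) / 3) = exp(0.2)\n    return torch.exp(torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths)\n\n\n@register_policy_loss(\"gpg\")\ndef compute_policy_loss_gpg(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[DictConfig | AlgoConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Adapted from\n    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495\n    Args:\n        log_prob: `(torch.Tensor)`\n            shape: (bs, response_length)\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n    return:\n        pg_loss: `a scalar torch.Tensor`\n            policy gradient loss computed via GPG\n    \"\"\"\n    pg_losses = -log_prob * advantages\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), 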
torch.tensor(0.0)\n\n\n@register_policy_loss(\"clip_cov\")\ndef compute_policy_loss_clip_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[DictConfig | AlgoConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for Clip-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        clip_cvo_ratio (float, optional):\n            Ratio for clipping the covariance. Defaults to 0.0002.\n        clip_cov_lb (float, optional):\n            Lower bound for clipping covariance. Defaults to 1.0.\n        clip_cov_ub (float, optional):\n            Upper bound for clipping covariance. 
Defaults to 5.0.\n    \"\"\"\n    assert config is not None\n    assert not isinstance(config, AlgoConfig), \"passing AlgoConfig not supported yet\"\n    assert config.policy_loss is not None\n\n    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002\n    cliprange = config.clip_ratio\n    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange\n    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange\n    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0\n    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0\n\n    assert clip_cov_ratio > 0, \"clip_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    corr = torch.ones_like(advantages)\n    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\n    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)\n\n    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (\n        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)\n    )\n    cov_all[response_mask == 0] = -torch.inf\n    cov_all[clip_by_origin] = -torch.inf\n\n    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)\n    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)\n    top_k_idx = torch.nonzero(top_k_idx)\n\n    if len(top_k_idx) > 0:\n        perm = torch.randperm(len(top_k_idx))\n        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]\n    else:\n        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)\n\n    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0\n\n    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)\n\n    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)\n\n\n@register_policy_loss(\"kl_cov\")\ndef compute_policy_loss_kl_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[DictConfig | AlgoConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for Clip-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates 
for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        kl_cov_ratio (float, optional):\n            Ratio for selecting the top-k covariance values. Defaults to 0.0002.\n        ppo_kl_coef (float, optional):\n            Coefficient for the KL penalty term in the loss. Defaults to 1.\n    \"\"\"\n    assert config is not None\n    assert not isinstance(config, AlgoConfig), \"passing AlgoConfig not supported yet\"\n    assert config.policy_loss is not None\n\n    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002\n    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0\n\n    assert kl_cov_ratio > 0, \"kl_cov_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    abs_kl = negative_approx_kl.abs()\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)\n    pg_losses1 = -advantages * ratio\n    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl\n    pg_losses = pg_losses1\n\n    all_valid = response_mask > 0\n    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]\n    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()\n    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()\n\n    k = min(kl_cov_ratio, len(all_valid_adv))\n\n    if k != 0:\n        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())\n        k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))\n        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices\n\n        if len(large_cov_idxs) != 0:\n            large_cov_idxs = all_valid_idx[large_cov_idxs]\n            pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[\n                large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]\n            ]\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)\n\n\n@register_policy_loss(\"geo_mean\")\ndef compute_policy_loss_geo_mean(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[DictConfig | AlgoConfig] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for GMPO.\n\n    Adapted from paper https://arxiv.org/abs/2507.20673\n    https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages 
(torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            not used\n    \"\"\"\n\n    assert config is not None\n    assert not isinstance(config, AlgoConfig)\n    clip_ratio = config.clip_ratio  # Clipping parameter. See https://arxiv.org/abs/1707.06347.\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n\n    cliprange = clip_ratio\n    cliprange_low = clip_ratio_low\n    cliprange_high = clip_ratio_high\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability (uncomment it if you like)\n    # negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    # Clipping at token-level & Clipping wider\n    sgn_advantage = torch.sign(advantages)\n    negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high)\n    negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp)\n    negative_approx_kl_min = sgn_advantage * negative_approx_kl_min\n\n    # Geometric-Mean Policy Optimization\n    response_mask_sum = response_mask.sum(dim=-1)\n    ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8))\n    # we only support sequence-level advantage for now;\n    # otherwise the computation below would be inconsistent with the paper\n    advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)\n    pg_losses = -advantage * ratio\n\n    # Apply rollout importance sampling weights if provided\n    # For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level\n    if rollout_is_weights is not None:\n        # Aggregate token-level weights to sequence level using geometric mean for consistency\n        # Note: rollout_is_weights is always 2D regardless of rollout_is_level\n        seq_is_weights = torch.exp(\n            (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)\n        )\n        pg_losses = pg_losses * seq_is_weights\n\n    pg_loss = torch.mean(pg_losses)\n\n    # higher: the ratio is too large and needs clamping to clip_high (when adv > 0)\n    clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)\n    pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)\n    pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\ndef compute_entropy_loss(logits, response_mask, loss_agg_mode: str = \"token-mean\"):\n    \"\"\"Compute categorical entropy loss (for backward compatibility)\n\n    Args:\n        logits (torch.Tensor): shape is (bs, response_length, vocab_size)\n        response_mask (torch.Tensor): shape is (bs, response_length)\n\n    Returns:\n        entropy: a scalar torch.Tensor\n\n    \"\"\"\n    # compute entropy\n    token_entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)\n    entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return entropy_loss\n\n\ndef compute_value_loss(\n    vpreds: torch.Tensor,\n    returns: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    cliprange_value: float,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    Compute the clipped value-function loss for PPO.\n\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151\n\n    Args:\n        vpreds (torch.FloatTensor):\n            Predicted values from the value head, shape (batch_size, response_length).\n        returns (torch.FloatTensor):\n            Ground-truth returns, shape (batch_size, response_length).\n        values (torch.FloatTensor):\n            Old (baseline) values from the value head, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the value loss calculation.\n        cliprange_value (float):\n            Clip range for value prediction updates.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n\n    Returns:\n        vf_loss (torch.FloatTensor):\n            A scalar tensor containing the aggregated value-function loss.\n        vf_clipfrac (float):\n            Fraction of elements where the clipped loss was used.\n    \"\"\"\n    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)\n    vf_losses1 = (vpreds - returns) ** 2\n    vf_losses2 = (vpredclipped - returns) ** 2\n    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)\n    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)\n    return vf_loss, vf_clipfrac\n\n\ndef kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:\n    \"\"\"Compute KL divergence given logprob and ref_logprob. If the kl_penalty method ends with '+', a\n    straight-through trick binds the k2 gradient to the chosen estimator's value for an unbiased KL\n    gradient estimate.\n    See more description in http://joschu.net/blog/kl-approx.html\n\n    Args:\n        logprob:\n        ref_logprob:\n\n    Returns:\n        kl_estimate\n    \"\"\"\n    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)\n    if not kl_penalty.endswith(\"+\") or kl_penalty in (\"mse\", \"k2\"):\n        return forward_score\n\n    \"\"\"\n    The k1 and k3 estimators are unbiased estimators of the KL value, but their gradients are not unbiased\n    estimators of the KL gradient. The k2 estimator, on the other hand, gives the right gradient estimator,\n    so we use a straight-through trick here if the kl_penalty method ends with '+', e.g., k3+. 
\n    \"\"\"\n    backward_score = 0.5 * (logprob - ref_logprob).square()\n\n    return backward_score - backward_score.detach() + forward_score.detach()\n\n\ndef kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:\n    \"\"\"Compute KL divergence given logprob and ref_logprob.\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104\n    See more description in http://joschu.net/blog/kl-approx.html\n\n    Args:\n        logprob:\n        ref_logprob:\n\n    Returns:\n        kl_estimate\n    \"\"\"\n    if kl_penalty in (\"kl\", \"k1\"):\n        return logprob - ref_logprob\n\n    if kl_penalty == \"abs\":\n        return (logprob - ref_logprob).abs()\n\n    if kl_penalty in (\"mse\", \"k2\"):\n        return 0.5 * (logprob - ref_logprob).square()\n\n    # J. Schulman. Approximating kl divergence, 2020.\n    # # URL http://joschu.net/blog/kl-approx.html.\n    if kl_penalty in (\"low_var_kl\", \"k3\"):\n        kl = ref_logprob - logprob\n        # For numerical stability\n        kl = torch.clamp(kl, min=-20, max=20)\n        ratio = torch.exp(kl)\n        kld = (ratio - kl - 1).contiguous()\n        return torch.clamp(kld, min=-10, max=10)\n\n    if kl_penalty == \"full\":\n        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary\n        raise NotImplementedError\n\n    raise NotImplementedError\n\n\ndef compute_pf_ppo_reweight_data(\n    data,\n    reweight_method: str = \"pow\",\n    weight_pow: float = 2.0,\n):\n    \"\"\"Reweight the data based on the token_level_scores.\n\n    Args:\n        data: DataProto object, containing batch, non_tensor_batch and meta_info\n        reweight_method: str, choices: \"pow\", \"max_min\", \"max_random\"\n        weight_pow: float, the power of the weight\n\n    Returns:\n\n    \"\"\"\n\n    @torch.no_grad()\n    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:\n        \"\"\"Compute importance weights for resampling based on scores.\n\n        Args:\n            scores (torch.Tensor): Tensor of scores to compute weights from.\n            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').\n            weight_pow (float): Power exponent for 'pow' method.\n\n        Returns:\n            torch.Tensor: Computed importance weights.\n\n        Raises:\n            ValueError: If reweight_method is not supported.\n        \"\"\"\n        if reweight_method == \"pow\":\n            weights = torch.pow(torch.abs(scores), weight_pow)\n        elif reweight_method == \"max_min\":\n            max_score = torch.max(scores)\n            min_score = torch.min(scores)\n            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)\n        elif reweight_method == \"max_random\":\n            max_score = torch.max(scores)\n            weights = torch.where(scores == max_score, 0.4, 0.1)\n        else:\n            raise ValueError(f\"Unsupported reweight_method: {reweight_method}\")\n        return weights\n\n    scores = data.batch[\"token_level_scores\"].sum(dim=-1)\n    weights = compute_weights(scores, reweight_method, weight_pow)\n    weights = torch.clamp(weights + 1e-8, min=1e-8)\n\n    batch_size = scores.shape[0]\n    sample_indices = torch.multinomial(weights, batch_size, replacement=True)\n\n    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}\n\n    
sample_indices_np = sample_indices.numpy()\n    resampled_non_tensor_batch = {}\n    for key, array in data.non_tensor_batch.items():\n        if isinstance(array, np.ndarray):\n            resampled_non_tensor_batch[key] = array[sample_indices_np]\n        else:\n            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]\n\n    resampled_meta_info = {}\n    for key, value in data.meta_info.items():\n        if isinstance(value, list) and len(value) == batch_size:\n            resampled_meta_info[key] = [value[i] for i in sample_indices_np]\n        else:\n            resampled_meta_info[key] = value\n\n    from copy import deepcopy\n\n    resampled_data = deepcopy(data)\n    resampled_data.batch = type(data.batch)(resampled_batch)\n    resampled_data.batch.batch_size = data.batch.batch_size\n    resampled_data.non_tensor_batch = resampled_non_tensor_batch\n    resampled_data.meta_info = resampled_meta_info\n\n    return resampled_data\n"
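\n\n# Illustrative sketch (editor's addition, not part of the verl API): comparing\n# the k1/k2/k3 estimators of kl_penalty_forward (defined above) on toy\n# log-probs; all three estimate the same KL value with different bias and\n# variance trade-offs for the gradient.\ndef _demo_kl_estimators():\n    logprob = torch.tensor([[-1.0, -2.0]])\n    ref_logprob = torch.tensor([[-1.5, -1.5]])\n    return {k: kl_penalty_forward(logprob, ref_logprob, k) for k in (\"k1\", \"k2\", \"k3\")}\n"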
  },
  {
    "path": "verl_distillation/verl/trainer/ppo/metric_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMetrics related to the PPO trainer.\n\"\"\"\n\nfrom collections import defaultdict\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport numpy as np\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.import_utils import deprecated\n\n\n@deprecated(\"verl.utils.metric.reduce_metrics\")\ndef reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:\n    \"\"\"\n    Reduces a dictionary of metric lists by computing the mean of each list.\n\n    Args:\n        metrics: A dictionary mapping metric names to lists of metric values.\n\n    Returns:\n        A dictionary with the same keys but with each list replaced by its mean value.\n\n    Example:\n        >>> metrics = {\"loss\": [1.0, 2.0, 3.0], \"accuracy\": [0.8, 0.9, 0.7]}\n        >>> reduce_metrics(metrics)\n        {\"loss\": 2.0, \"accuracy\": 0.8}\n    \"\"\"\n    from verl.utils.metric import reduce_metrics\n\n    return reduce_metrics(metrics)\n\n\ndef _compute_response_info(batch: DataProto) -> dict[str, Any]:\n    \"\"\"\n    Computes information about prompts and responses from a batch.\n\n    This is an internal helper function that extracts masks and lengths for prompts and responses.\n\n    Args:\n        batch: A DataProto object containing batch data with responses and attention masks.\n\n    Returns:\n        A dictionary containing:\n            - response_mask: Attention mask for the response tokens\n            - prompt_length: Tensor of prompt lengths for each item in the batch\n            - response_length: Tensor of response lengths for each item in the batch\n    \"\"\"\n    response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-response_length]\n    response_mask = batch.batch[\"attention_mask\"][:, -response_length:]\n\n    prompt_length = prompt_mask.sum(-1).float()\n    response_length = response_mask.sum(-1).float()  # (batch_size,)\n\n    return dict(\n        response_mask=response_mask,\n        prompt_length=prompt_length,\n        response_length=response_length,\n    )\n\ndef compute_on_policy_distill_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:\n    \"\"\"\n    Computes various metrics from a batch of data for PPO training.\n\n    This function calculates metrics related to scores, rewards, advantages, returns, values,\n    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)\n    for each metric category.\n\n    Args:\n        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.\n        use_critic: Whether to include critic-specific metrics. 
Defaults to True.\n\n    Returns:\n        A dictionary of metrics including:\n            - critic/advantages/mean, max, min: Statistics about advantages\n            - critic/returns/mean, max, min: Statistics about returns\n            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)\n            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)\n            - response_length/mean, max, min, clip_ratio: Statistics about response lengths\n            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths\n            - num_turns/mean, max, min: Statistics about the number of multi-turn conversations\n    \"\"\"\n\n    advantages = batch.batch[\"advantages\"]\n    returns = batch.batch[\"returns\"]\n\n    max_response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = batch.batch[\"response_mask\"].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(batch)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    aborted_mask = (response_length == 0).bool()\n    non_aborted_mask = ~aborted_mask\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n\n    if use_critic:\n        values = batch.batch[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n\n    # Aborted samples and non-aborted response length statistics\n    # response_length_non_aborted/*: statistics computed on non-aborted samples only\n    aborted_ratio = torch.mean(aborted_mask.float()).detach().item()\n\n    non_aborted_response_length = response_length[non_aborted_mask]\n    if non_aborted_response_length.numel() > 0:\n        non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()\n        non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()\n        non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()\n        non_aborted_response_length_clip_ratio = (\n            torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()\n        )\n    else:\n        raise ValueError(\"All samples are aborted, this should not happen.\")\n\n    metrics = {\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        \"critic/advantages/std\": torch.std(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        \"critic/returns/max\": torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        \"critic/returns/std\": torch.std(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n         
       \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": torch.min(valid_values).detach().item(),\n                \"critic/values/std\": torch.std(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if use_critic\n            else {}\n        ),\n        # response length\n        \"response_length/mean\": torch.mean(response_length).detach().item(),\n        \"response_length/max\": torch.max(response_length).detach().item(),\n        \"response_length/min\": torch.min(response_length).detach().item(),\n        \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float())\n        .detach()\n        .item(),\n        # response length (non-aborted only)\n        # These statistics exclude aborted samples to avoid skew from zeros\n        \"response_length_non_aborted/mean\": non_aborted_response_length_mean,\n        \"response_length_non_aborted/max\": non_aborted_response_length_max,\n        \"response_length_non_aborted/min\": non_aborted_response_length_min,\n        \"response_length_non_aborted/clip_ratio\": non_aborted_response_length_clip_ratio,\n        # aborted ratio\n        # Fraction of samples whose response length is zero\n        \"response/aborted_ratio\": aborted_ratio,\n        # prompt length\n        \"prompt_length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt_length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt_length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n    }\n\n    # multi-turn conversation\n    if \"__num_turns__\" in batch.non_tensor_batch:\n        num_turns = batch.non_tensor_batch[\"__num_turns__\"]\n        metrics[\"num_turns/min\"] = num_turns.min()\n        metrics[\"num_turns/max\"] = num_turns.max()\n        metrics[\"num_turns/mean\"] = num_turns.mean()\n\n    if \"tool_call_counts\" in batch.non_tensor_batch:\n        tool_call_counts = batch.non_tensor_batch[\"tool_call_counts\"]\n        metrics[\"tool_call_counts/min\"] = tool_call_counts.min()\n        metrics[\"tool_call_counts/max\"] = tool_call_counts.max()\n        metrics[\"tool_call_counts/mean\"] = tool_call_counts.mean()\n\n    return metrics\n\ndef compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:\n    \"\"\"\n    Computes various metrics from a batch of data for PPO training.\n\n    This function calculates metrics related to scores, rewards, advantages, returns, values,\n    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)\n    for each metric category.\n\n    Args:\n        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.\n        use_critic: Whether to include critic-specific metrics. 
Defaults to True.\n\n    Returns:\n        A dictionary of metrics including:\n            - critic/score/mean, max, min: Statistics about sequence scores\n            - critic/rewards/mean, max, min: Statistics about sequence rewards\n            - critic/advantages/mean, max, min: Statistics about advantages\n            - critic/returns/mean, max, min: Statistics about returns\n            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)\n            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)\n            - response_length/mean, max, min, clip_ratio: Statistics about response lengths\n            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths\n            - num_turns/mean, max, min: Statistics about the number of multi-turn conversations\n    \"\"\"\n    sequence_score = batch.batch[\"token_level_scores\"].sum(-1)\n    sequence_reward = batch.batch[\"token_level_rewards\"].sum(-1)\n\n    advantages = batch.batch[\"advantages\"]\n    returns = batch.batch[\"returns\"]\n\n    max_response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = batch.batch[\"response_mask\"].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(batch)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    aborted_mask = (response_length == 0).bool()\n    non_aborted_mask = ~aborted_mask\n\n    non_aborted_sequence_score = sequence_score[non_aborted_mask]\n    non_aborted_sequence_reward = sequence_reward[non_aborted_mask]\n\n    score_mean = torch.mean(non_aborted_sequence_score).detach().item()\n    score_max = torch.max(non_aborted_sequence_score).detach().item()\n    score_min = torch.min(non_aborted_sequence_score).detach().item()\n\n    reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()\n    reward_max = torch.max(non_aborted_sequence_reward).detach().item()\n    reward_min = torch.min(non_aborted_sequence_reward).detach().item()\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n\n    if use_critic:\n        values = batch.batch[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n\n    # Aborted samples and non-aborted response length statistics\n    # response_length_non_aborted/*: statistics computed on non-aborted samples only\n    aborted_ratio = torch.mean(aborted_mask.float()).detach().item()\n\n    non_aborted_response_length = response_length[non_aborted_mask]\n    if non_aborted_response_length.numel() > 0:\n        non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()\n        non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()\n        non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()\n        non_aborted_response_length_clip_ratio = (\n            torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()\n        )\n    else:\n        raise ValueError(\"All samples are aborted, this should not happen.\")\n\n    metrics = {\n        # score\n        \"critic/score/mean\": score_mean,\n        
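# Note: \"score\" aggregates token_level_scores (the raw reward signal, before any in-reward\n        # KL penalty), while \"rewards\" aggregates token_level_rewards (after the penalty is applied).\n        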
\"critic/score/max\": score_max,\n        \"critic/score/min\": score_min,\n        # reward\n        \"critic/rewards/mean\": reward_mean,\n        \"critic/rewards/max\": reward_max,\n        \"critic/rewards/min\": reward_min,\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        \"critic/returns/max\": torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n                \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": torch.min(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if use_critic\n            else {}\n        ),\n        # response length\n        \"response_length/mean\": torch.mean(response_length).detach().item(),\n        \"response_length/max\": torch.max(response_length).detach().item(),\n        \"response_length/min\": torch.min(response_length).detach().item(),\n        \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float())\n        .detach()\n        .item(),\n        # response length (non-aborted only)\n        # These statistics exclude aborted samples to avoid skew from zeros\n        \"response_length_non_aborted/mean\": non_aborted_response_length_mean,\n        \"response_length_non_aborted/max\": non_aborted_response_length_max,\n        \"response_length_non_aborted/min\": non_aborted_response_length_min,\n        \"response_length_non_aborted/clip_ratio\": non_aborted_response_length_clip_ratio,\n        # aborted ratio\n        # Fraction of samples whose response length is zero\n        \"response/aborted_ratio\": aborted_ratio,\n        # prompt length\n        \"prompt_length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt_length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt_length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n    }\n\n    # multi-turn conversation\n    if \"__num_turns__\" in batch.non_tensor_batch:\n        num_turns = batch.non_tensor_batch[\"__num_turns__\"]\n        metrics[\"num_turns/min\"] = num_turns.min()\n        metrics[\"num_turns/max\"] = num_turns.max()\n        metrics[\"num_turns/mean\"] = num_turns.mean()\n\n    if \"tool_call_counts\" in batch.non_tensor_batch:\n        tool_call_counts = batch.non_tensor_batch[\"tool_call_counts\"]\n        metrics[\"tool_call_counts/min\"] = tool_call_counts.min()\n        metrics[\"tool_call_counts/max\"] = tool_call_counts.max()\n        metrics[\"tool_call_counts/mean\"] = tool_call_counts.mean()\n\n    return metrics\n\n\ndef compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:\n    \"\"\"\n    Computes timing metrics for different processing stages in PPO training.\n\n    This function calculates both raw timing metrics (in seconds) and per-token 
timing metrics\n    (in milliseconds) for various processing stages like generation, reference computation,\n    value computation, advantage computation, and model updates.\n\n    Args:\n        batch: A DataProto object containing batch data with responses and attention masks.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n\n    Returns:\n        A dictionary containing:\n            - timing_s/{name}: Raw timing in seconds for each stage\n            - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage\n\n    Note:\n        Different stages use different token counts for normalization:\n        - \"gen\" uses only response tokens\n        - Other stages (\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\") use all tokens\n          (prompt + response)\n    \"\"\"\n    response_info = _compute_response_info(batch)\n    num_prompt_tokens = torch.sum(response_info[\"prompt_length\"]).item()\n    num_response_tokens = torch.sum(response_info[\"response_length\"]).item()\n    num_overall_tokens = num_prompt_tokens + num_response_tokens\n\n    num_tokens_of_section = {\n        \"gen\": num_response_tokens,\n        **{name: num_overall_tokens for name in [\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\"]},\n    }\n\n    return {\n        **{f\"timing_s/{name}\": value for name, value in timing_raw.items()},\n        **{\n            f\"timing_per_token_ms/{name}\": timing_raw[name] * 1000 / num_tokens_of_section[name]\n            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())\n        },\n    }\n\n\ndef compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:\n    \"\"\"\n    Computes throughput metrics for PPO training.\n\n    This function calculates performance metrics related to token processing speed,\n    including the total number of tokens processed, time per step, and throughput\n    (tokens per second per GPU).\n\n    Args:\n        batch: A DataProto object containing batch data with meta information about token counts.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n                   Must contain a \"step\" key with the total step time.\n        n_gpus: Number of GPUs used for training.\n\n    Returns:\n        A dictionary containing:\n            - perf/total_num_tokens: Total number of tokens processed in the batch\n            - perf/time_per_step: Time taken for the step in seconds\n            - perf/throughput: Tokens processed per second per GPU\n\n    Note:\n        The throughput is calculated as total_tokens / (time * n_gpus) to normalize\n        across different GPU counts.\n    \"\"\"\n    total_num_tokens = sum(batch.meta_info[\"global_token_num\"])\n    time = timing_raw[\"step\"]\n    # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)\n    # f'Actual TFLOPs/s/GPU': estimated_flops/(n_gpus),\n    # f'Theoretical TFLOPs/s/GPU': promised_flops,\n    return {\n        \"perf/total_num_tokens\": total_num_tokens,\n        \"perf/time_per_step\": time,\n        \"perf/throughput\": total_num_tokens / (time * n_gpus),\n    }\n\n\ndef bootstrap_metric(\n    data: list[Any],\n    subset_size: int,\n    reduce_fns: list[Callable[[np.ndarray], float]],\n    n_bootstrap: int = 1000,\n    seed: int = 42,\n) -> list[tuple[float, float]]:\n    \"\"\"\n    Performs bootstrap resampling to estimate statistics of metrics.\n\n    
This function uses bootstrap resampling to estimate the mean and standard deviation\n    of metrics computed by the provided reduction functions on random subsets of the data.\n\n    Args:\n        data: List of data points to bootstrap from.\n        subset_size: Size of each bootstrap sample.\n        reduce_fns: List of functions that compute a metric from a subset of data.\n        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.\n        seed: Random seed for reproducibility. Defaults to 42.\n\n    Returns:\n        A list of tuples, where each tuple contains (mean, std) for a metric\n        corresponding to each reduction function in reduce_fns.\n\n    Example:\n        >>> data = [1, 2, 3, 4, 5]\n        >>> reduce_fns = [np.mean, np.max]\n        >>> bootstrap_metric(data, 3, reduce_fns)\n        [(3.0, 0.5), (4.5, 0.3)]  # Example values\n    \"\"\"\n    np.random.seed(seed)\n\n    bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]\n    for _ in range(n_bootstrap):\n        bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)\n        bootstrap_data = [data[i] for i in bootstrap_idxs]\n        for i, reduce_fn in enumerate(reduce_fns):\n            bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))\n    return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]\n\n\ndef calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:\n    \"\"\"\n    Calculate a value based on majority voting.\n\n    This function identifies the most common value for a specified vote key\n    in the data, then returns the corresponding value for that majority vote.\n\n    Args:\n        data: List of dictionaries, where each dictionary contains both vote_key and val_key.\n        vote_key: The key in each dictionary used for voting/counting.\n        val_key: The key in each dictionary whose value will be returned for the majority vote.\n\n    Returns:\n        The value associated with the most common vote.\n\n    Example:\n        >>> data = [\n        ...     {\"pred\": \"A\", \"val\": 0.9},\n        ...     {\"pred\": \"B\", \"val\": 0.8},\n        ...     {\"pred\": \"A\", \"val\": 0.7}\n        ... ]\n        >>> calc_maj_val(data, vote_key=\"pred\", val_key=\"val\")\n        0.9  # Returns the first \"val\" for the majority vote \"A\"\n    \"\"\"\n    vote2vals = defaultdict(list)\n    for d in data:\n        vote2vals[d[vote_key]].append(d[val_key])\n\n    vote2cnt = {k: len(v) for k, v in vote2vals.items()}\n    maj_vote = max(vote2cnt, key=vote2cnt.get)\n\n    maj_val = vote2vals[maj_vote][0]\n\n    return maj_val\n\n\ndef process_validation_metrics(\n    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42\n) -> dict[str, dict[str, dict[str, float]]]:\n    \"\"\"\n    Process validation metrics into a structured format with statistical analysis.\n\n    This function organizes validation metrics by data source and prompt, then computes\n    various statistical measures including means, standard deviations, best/worst values,\n    and majority voting results. It also performs bootstrap sampling to estimate statistics\n    for different sample sizes.\n\n    Args:\n        data_sources: List of data source identifiers for each sample.\n        sample_uids: List of sample uids corresponding to each sample.\n        infos_dict: Dictionary mapping variable names to lists of values for each sample.\n        seed: Random seed for bootstrap sampling. 
Defaults to 42.\n\n    Returns:\n        A nested dictionary with the structure:\n        {\n            data_source: {\n                variable_name: {\n                    metric_name: value\n                }\n            }\n        }\n\n        Where metric_name includes:\n        - \"mean@N\": Mean value across N samples\n        - \"std@N\": Standard deviation across N samples\n        - \"best@N/mean\": Mean of the best values in bootstrap samples of size N\n        - \"best@N/std\": Standard deviation of the best values in bootstrap samples\n        - \"worst@N/mean\": Mean of the worst values in bootstrap samples\n        - \"worst@N/std\": Standard deviation of the worst values in bootstrap samples\n        - \"maj@N/mean\": Mean of majority voting results in bootstrap samples (if \"pred\" exists)\n        - \"maj@N/std\": Standard deviation of majority voting results (if \"pred\" exists)\n\n    Example:\n        >>> data_sources = [\"source1\", \"source1\", \"source2\"]\n        >>> sample_uids = [\"uid1\", \"uid1\", \"uid2\"]\n        >>> infos_dict = {\"score\": [0.8, 0.9, 0.7], \"pred\": [\"A\", \"A\", \"B\"]}\n        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)\n        >>> # result will contain statistics for each data source and variable\n    \"\"\"\n    # Group metrics by data source, prompt and variable\n    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for sample_idx, data_source in enumerate(data_sources):\n        uid = sample_uids[sample_idx]\n        var2vals = data_src2uid2var2vals[data_source][uid]\n        for var_name, var_vals in infos_dict.items():\n            var2vals[var_name].append(var_vals[sample_idx])\n\n    # Calculate metrics for each group\n    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))\n    for data_source, uid2var2vals in data_src2uid2var2vals.items():\n        for uid, var2vals in uid2var2vals.items():\n            for var_name, var_vals in var2vals.items():\n                if isinstance(var_vals[0], str):\n                    continue\n\n                metric = {}\n                n_resps = len(var_vals)\n                metric[f\"mean@{n_resps}\"] = np.mean(var_vals)\n\n                if n_resps > 1:\n                    metric[f\"std@{n_resps}\"] = np.std(var_vals)\n\n                    ns = []\n                    n = 2\n                    while n < n_resps:\n                        ns.append(n)\n                        n *= 2\n                    ns.append(n_resps)\n\n                    for n in ns:\n                        [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(\n                            data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed\n                        )\n                        metric[f\"best@{n}/mean\"], metric[f\"best@{n}/std\"] = bon_mean, bon_std\n                        metric[f\"worst@{n}/mean\"], metric[f\"worst@{n}/std\"] = won_mean, won_std\n                        if var2vals.get(\"pred\", None) is not None:\n                            vote_data = [\n                                {\"val\": val, \"pred\": pred} for val, pred in zip(var_vals, var2vals[\"pred\"], strict=True)\n                            ]\n                            [(maj_n_mean, maj_n_std)] = bootstrap_metric(\n                                data=vote_data,\n                                subset_size=n,\n                                reduce_fns=[partial(calc_maj_val, 
vote_key=\"pred\", val_key=\"val\")],\n                                seed=seed,\n                            )\n                            metric[f\"maj@{n}/mean\"], metric[f\"maj@{n}/std\"] = maj_n_mean, maj_n_std\n\n                data_src2uid2var2metric[data_source][uid][var_name] = metric\n\n    # Aggregate metrics across uids\n    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for data_source, uid2var2metric in data_src2uid2var2metric.items():\n        for uid, var2metric in uid2var2metric.items():\n            for var_name, metric in var2metric.items():\n                for metric_name, metric_val in metric.items():\n                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)\n\n    data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))\n    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():\n        for var_name, metric2uid_vals in var2metric2uid_vals.items():\n            for metric_name, uid_vals in metric2uid_vals.items():\n                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)\n\n    return data_src2var2metric2val\n"
  },
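For orientation, here is a minimal, self-contained sketch (plain numpy, outside the library; the local `bootstrap` helper merely mirrors the contract of `bootstrap_metric` above) of how the `best@N` / `worst@N` statistics reported by `process_validation_metrics` behave on toy per-prompt scores:

```python
# Standalone sketch of bootstrap best@N / worst@N estimation (not library code).
import numpy as np


def bootstrap(data, subset_size, reduce_fns, n_bootstrap=1000, seed=42):
    # Same contract as bootstrap_metric above: resample with replacement,
    # apply each reduce_fn, and report (mean, std) across bootstrap draws.
    rng = np.random.default_rng(seed)
    stats = [[] for _ in reduce_fns]
    for _ in range(n_bootstrap):
        idxs = rng.choice(len(data), size=subset_size, replace=True)
        sample = [data[i] for i in idxs]
        for i, fn in enumerate(reduce_fns):
            stats[i].append(fn(sample))
    return [(float(np.mean(s)), float(np.std(s))) for s in stats]


# Five sampled responses for one prompt, each with a scalar score.
scores = [0.2, 0.4, 0.6, 0.8, 1.0]
(best4, best4_std), (worst4, worst4_std) = bootstrap(scores, 4, [np.max, np.min])
print(f"best@4  = {best4:.3f} +/- {best4_std:.3f}")   # roughly 0.89 +/- 0.16
print(f"worst@4 = {worst4:.3f} +/- {worst4_std:.3f}")  # roughly 0.31 +/- 0.16
```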
  {
    "path": "verl_distillation/verl/trainer/ppo/mismatch_helper.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRollout Importance Sampling (IS) Helper Module\n\nThis module handles importance sampling weight computation for correcting\ndistribution mismatch between rollout policy (e.g., vLLM BFloat16) and\ntraining policy (e.g., FSDP FP32).\n\nKey Features:\n1. Three aggregation levels: token, sequence, geometric\n2. Two handling modes: truncate, mask\n3. Per-token veto mechanism for catastrophic outliers\n4. Memory-efficient computation to prevent CUDA OOM\n5. Comprehensive metrics tracking\n\nUsage Notes:\n- compute_rollout_importance_weights() computes both IS weights and mismatch metrics\n- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch()\n- Also used in dp_actor.py for distributed worker computations\n- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights()\n\nReferences:\n- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n- Off-policy RL: https://fengyao.notion.site/off-policy-rl\n\"\"\"\n\nfrom typing import Any, Optional\n\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.protocol import DataProto\n\n\ndef compute_rollout_importance_weights(\n    old_log_prob: torch.Tensor,\n    rollout_log_prob: torch.Tensor,\n    response_mask: torch.Tensor,\n    rollout_is_level: str = \"token\",\n    rollout_is_mode: str = \"truncate\",\n    rollout_is_threshold: Optional[float] = None,\n    rollout_is_threshold_lower: Optional[float] = None,\n    rollout_is_veto_threshold: Optional[float] = None,\n) -> tuple[Optional[DataProto], torch.Tensor, dict[str, Any]]:\n    \"\"\"Compute importance sampling weights and rejection mask for rollout-training mismatch.\n\n    This function computes IS weights to correct for distribution mismatch between rollout\n    and training policies, and applies rejection sampling for outliers.\n\n    Key Design: Separation of IS Weights and Rejection Sampling\n    - IS weights (rollout_is_weights): Ratios π_train/π_rollout with processing applied:\n      * Safety-bounded to prevent overflow:\n        - Token level: exp(clamp(log_ratio, -20, 20)) per token\n        - Sequence level: exp(clamp(sum(log_ratio), -20, 20)) broadcast to all tokens\n        - Geometric level: exp(clamp(mean(log_ratio), -20, 20)) broadcast to all tokens\n      * Truncate mode: upper clamped via .clamp(max=upper_threshold)\n      * Mask mode: safety-bounded ratios preserved (no threshold clamping)\n      * All modes: zeroed at padding positions\n      Used for policy gradient calculations\n    - Response mask (modified_response_mask): Has rejection applied (mask mode + veto)\n      Used for loss aggregation to exclude rejected samples from training\n\n    Reference:\n        When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n\n    Memory-efficient 
implementation:\n    - Log-space computation to prevent overflow\n    - Safety bounds (exp(±20)) on all exponentiations\n    - Metrics computed without large intermediate tensors\n\n    Args:\n        old_log_prob: Log probs from training policy (FSDP FP32), shape (batch_size, seq_length)\n        rollout_log_prob: Log probs from rollout policy (vLLM BF16), shape (batch_size, seq_length)\n        response_mask: Valid token mask (1=valid, 0=padding), shape (batch_size, seq_length)\n        rollout_is_level: IS weight aggregation level\n            - \"token\": Per-token ratios ρ_t = π_train(t)/π_rollout(t) (biased but low variance)\n            - \"sequence\": Sequence product ρ_seq = ∏ρ_t (unbiased but high variance)\n            - \"geometric\": Geometric mean ρ_geo = (∏ρ_t)^(1/T) (experimental trade-off)\n        rollout_is_mode: Treatment of outlier IS weights\n            - \"truncate\": Clamp weights at upper threshold only. No rejection for outlier ratios,\n              but veto can still apply (TIS)\n            - \"mask\": Reject tokens/sequences outside [lower, upper] via response_mask (MIS/rejection sampling)\n        rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0)\n        rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper)\n        rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this,\n            reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default None.\n\n    Returns:\n        Tuple of (weights_proto, modified_response_mask, metrics):\n            weights_proto: DataProto with processed IS weights, key \"rollout_is_weights\",\n                shape (batch_size, seq_length). Processing applied:\n                - Safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]:\n                  * Token level: bounds per-token ratios\n                  * Sequence/geometric level: bounds aggregated ratio (broadcast to all tokens)\n                - Truncate mode: upper clamped via .clamp(max=upper_threshold)\n                - Mask mode: safety-bounded ratios preserved (no threshold clamping)\n                - All modes: zeroed at padding positions (response_mask == 0)\n                None if rollout_is_threshold is None.\n            modified_response_mask: Response mask with rejection applied:\n                - truncate mode: unchanged for outlier ratios, but veto rejection still applied\n                - mask mode: tokens outside [lower, upper] masked to 0\n                - veto: sequences with catastrophic tokens masked to 0 (applied in both modes)\n                Shape (batch_size, seq_length).\n            metrics: Dict of IS and mismatch metrics, all scalars with \"mismatch/\" prefix\n    \"\"\"\n    if rollout_is_threshold is None:\n        return None, response_mask, {}\n\n    # Parse thresholds: if lower not specified, use 1/upper (reciprocal)\n    upper_threshold = rollout_is_threshold\n    if rollout_is_threshold_lower is not None:\n        lower_threshold = rollout_is_threshold_lower\n    else:\n        # Default: lower = 1/upper (reciprocal)\n        lower_threshold = 1.0 / upper_threshold\n\n    # Step 1: Compute raw importance weights based on the specified level\n    log_ratio = old_log_prob - rollout_log_prob\n\n    # Pre-compute log thresholds\n    device = old_log_prob.device\n    log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))\n    log_threshold_lower = 
torch.log(torch.tensor(lower_threshold, device=device))\n\n    # Safety bound to prevent numerical overflow (exp(20) ≈ 485M)\n    SAFETY_BOUND = 20.0\n\n    # Store unclamped values in log-space for accurate metrics\n    if rollout_is_level == \"token\":\n        # Token-level IS: π_train(a|s) / π_rollout(a|s) per token\n        log_ratio_for_metrics = log_ratio\n\n        # Apply safety bound to prevent overflow\n        log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)\n        rollout_is_weights = torch.exp(log_ratio_safe)\n\n    elif rollout_is_level == \"sequence\":\n        # Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence\n        # Product of token ratios: exp(Σ log(π_train/π_rollout))\n        log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)\n        log_ratio_for_metrics = log_ratio_sum  # Store for metrics\n\n        # Apply safety bound to prevent overflow\n        log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)\n        rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)\n\n    elif rollout_is_level == \"geometric\":\n        # Geometric mean IS: (∏ π_train/π_rollout)^(1/T)\n        # Equivalent to exp(mean(log(π_train/π_rollout)))\n        log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)\n        log_ratio_for_metrics = log_ratio_mean  # Store for metrics\n\n        # Geometric mean rarely explodes due to averaging, but apply safety bound anyway\n        log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)\n        rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)\n\n    else:\n        raise ValueError(f\"Invalid rollout_is_level: {rollout_is_level}. 
Must be 'token', 'sequence', or 'geometric'.\")\n\n    # Step 1.5: Apply per-token veto check in log space (memory efficient)\n    if rollout_is_veto_threshold is not None:\n        log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))\n\n        # Check if any token ratio is below veto threshold (in log space)\n        # log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold\n        catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()\n\n        # For each sequence, check if it has any catastrophic token\n        # Use broadcasting instead of expand_as to save memory\n        has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)\n\n        # Create veto mask: 0 if sequence has catastrophic token, 1 otherwise\n        veto_mask = (~has_catastrophic).float()\n    else:\n        # No veto mechanism\n        catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)\n        has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)\n        veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)\n\n    # Step 2: Compute comprehensive metrics\n    metrics = compute_is_metrics(\n        rollout_is_weights=rollout_is_weights,\n        log_ratio_for_metrics=log_ratio_for_metrics,\n        response_mask=response_mask,\n        rollout_is_level=rollout_is_level,\n        rollout_is_threshold=upper_threshold,\n        rollout_is_threshold_lower=lower_threshold,\n        log_threshold_upper=log_threshold_upper,\n        log_threshold_lower=log_threshold_lower,\n        has_catastrophic=has_catastrophic,\n        catastrophic_tokens=catastrophic_tokens,\n        SAFETY_BOUND=SAFETY_BOUND,\n    )\n\n    # Step 3: Apply outlier handling and rejection sampling\n    # Key design principle: IS weights and rejection are separate mechanisms\n    # - rollout_is_weights: IS weight ratios with mode-specific processing\n    #   * Truncate mode: upper clamped to prevent extreme values\n    #   * Mask mode: safety-bounded ratios preserved (no threshold clamping, rejection via mask)\n    #   Used for policy gradient calculations\n    # - modified_response_mask: Has rejection applied (excludes outliers from training)\n    #   Used for loss denominator: ensures rejected samples don't dilute gradients\n\n    if rollout_is_mode == \"truncate\":\n        # Truncated IS (TIS): clamp weights to prevent extreme importance ratios\n        # Weights are modified by clamping; no rejection via mask for outlier ratios\n        # Veto rejection (if enabled) will still be applied to modified_response_mask below\n        rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)\n        modified_response_mask = response_mask  # Unchanged for outlier ratios (veto applied later)\n\n    elif rollout_is_mode == \"mask\":\n        # Masked IS (MIS): rejection sampling for outlier IS weights\n        # Reject tokens/sequences with IS ratios outside [lower, upper] via response_mask\n        # IS weights themselves are NOT threshold-clamped (remain safety-bounded only)\n        mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)\n        mask = mask.float()\n\n        # Compute rejection rate metrics\n        metrics[\"rollout_is_masked_fraction\"] = verl_F.masked_mean(1 - mask, response_mask)\n        if rollout_is_level in [\"sequence\", \"geometric\"]:\n            # Sequence-level: all tokens have same weight, check 
first token\n            metrics[\"rollout_is_seq_masked_fraction\"] = (1 - mask[:, 0]).mean()\n        else:\n            # Token-level: sequence rejected if ANY token is rejected\n            seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0\n            metrics[\"rollout_is_seq_masked_fraction\"] = seq_has_masked.float().mean()\n\n        # Apply rejection via response_mask (NOT by clamping IS weights)\n        modified_response_mask = response_mask * mask\n        # rollout_is_weights kept as safety-bounded ratios (no threshold clamping)\n\n    else:\n        raise ValueError(f\"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.\")\n\n    # Apply veto: reject entire sequences with catastrophic tokens (ratio < veto_threshold)\n    # Veto is independent of mode - it applies to modified_response_mask after mode-specific handling\n    modified_response_mask = modified_response_mask * veto_mask\n    # Note: rollout_is_weights unaffected by veto (already clamped in truncate mode, or kept as-is in mask mode)\n\n    # Zero out padding positions in IS weights for correct aggregation\n    # This is different from rejection - padding must be zeroed regardless of mode\n    rollout_is_weights = rollout_is_weights * response_mask\n\n    # Wrap in DataProto for consistency with worker methods\n    rollout_is_weights_proto = DataProto.from_dict(tensors={\"rollout_is_weights\": rollout_is_weights})\n\n    # Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics\n    mismatch_metrics = compute_mismatch_metrics(\n        old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask\n    )\n    metrics.update(mismatch_metrics)\n\n    # Convert all tensor metrics to scalars for logging\n    # Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()\n    metrics_scalar = {}\n    for key, value in metrics.items():\n        if isinstance(value, torch.Tensor):\n            metrics_scalar[f\"mismatch/{key}\"] = value.item()\n        else:\n            metrics_scalar[f\"mismatch/{key}\"] = value\n\n    return rollout_is_weights_proto, modified_response_mask, metrics_scalar\n\n\ndef compute_is_metrics(\n    rollout_is_weights: torch.Tensor,\n    log_ratio_for_metrics: torch.Tensor,\n    response_mask: torch.Tensor,\n    rollout_is_level: str,\n    rollout_is_threshold: float,\n    rollout_is_threshold_lower: float,\n    log_threshold_upper: torch.Tensor,\n    log_threshold_lower: torch.Tensor,\n    has_catastrophic: torch.Tensor,\n    catastrophic_tokens: torch.Tensor,\n    SAFETY_BOUND: float,\n) -> dict[str, Any]:\n    \"\"\"Compute comprehensive metrics for importance sampling weights.\n\n    Reference:\n        When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n\n    This function computes metrics using a mix of true unclamped values (for max/min/fractions\n    in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)\n    to balance accuracy with numerical stability and avoid overflow.\n    \"\"\"\n    # Validate that we have at least one valid sample\n    assert response_mask.any(), \"Expected at least one valid sample in response_mask\"\n\n    metrics = {}\n    device = rollout_is_weights.device\n\n    # Track veto statistics\n    metrics[\"rollout_is_veto_fraction\"] = has_catastrophic.float().mean()\n    metrics[\"rollout_is_catastrophic_token_fraction\"] = 
verl_F.masked_mean(catastrophic_tokens.float(), response_mask)\n\n    # Compute metrics based on IS level\n    if rollout_is_level in [\"sequence\", \"geometric\"]:\n        # For sequence/geometric, compute true statistics from log-space\n        # This reflects the actual distribution before clamping\n\n        # True max/min in log space\n        log_max = log_ratio_for_metrics.max()\n        log_min = log_ratio_for_metrics.min()\n\n        # Convert to regular space with safety bound\n        metrics[\"rollout_is_max\"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))\n        metrics[\"rollout_is_min\"] = torch.exp(log_min)\n\n        # Mean uses clamped weights to avoid overflow\n        metrics[\"rollout_is_mean\"] = verl_F.masked_mean(rollout_is_weights, response_mask)\n\n        # Compute fraction exceeding threshold in log space (accurate)\n        exceeds_upper = log_ratio_for_metrics > log_threshold_upper\n        below_lower = log_ratio_for_metrics < log_threshold_lower\n\n        if rollout_is_level == \"sequence\":\n            # For sequence level, all tokens in a sequence have the same weight\n            metrics[\"rollout_is_ratio_fraction_high\"] = exceeds_upper.float().mean()\n            metrics[\"rollout_is_ratio_fraction_low\"] = below_lower.float().mean()\n        else:  # geometric\n            # Need to expand to match token dimensions\n            exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)\n            below_lower_expanded = below_lower.expand_as(response_mask)\n            metrics[\"rollout_is_ratio_fraction_high\"] = verl_F.masked_mean(\n                exceeds_upper_expanded.float(), response_mask\n            )\n            metrics[\"rollout_is_ratio_fraction_low\"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)\n\n    else:\n        # Token-level: compute directly from weights\n        metrics[\"rollout_is_mean\"] = verl_F.masked_mean(rollout_is_weights, response_mask)\n\n        # Fraction exceeding thresholds\n        rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold\n        rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower\n        metrics[\"rollout_is_ratio_fraction_high\"] = verl_F.masked_mean(\n            rollout_is_above_threshold.float(), response_mask\n        )\n        metrics[\"rollout_is_ratio_fraction_low\"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)\n\n        # Max/min for token level\n        mask_bool = response_mask.bool()\n        metrics[\"rollout_is_max\"] = rollout_is_weights.masked_fill(~mask_bool, float(\"-inf\")).max()\n        metrics[\"rollout_is_min\"] = rollout_is_weights.masked_fill(~mask_bool, float(\"inf\")).min()\n\n    # Compute standard deviation using clamped weights to avoid overflow\n    mask_count = response_mask.sum()\n    if mask_count > 1:\n        # Use clamped weights for variance to avoid squaring huge values\n        weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)\n        # Use mean from clamped weights for consistency\n        mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)\n        rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()\n        metrics[\"rollout_is_std\"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))\n    else:\n        metrics[\"rollout_is_std\"] = torch.tensor(0.0, device=device)\n\n    # Effective sample size (use clamped weights to avoid overflow)\n 
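  # The normalized effective sample size below is 1 / E[(w / mean(w))^2], which lies in (0, 1]:\n    # it equals 1.0 when all weights are identical and shrinks toward 0 as weight variance grows.\n 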
   weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)\n    mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)\n    is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)\n    metrics[\"rollout_is_eff_sample_size\"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)\n\n    # Per-sequence breakdown metrics\n    if rollout_is_weights.dim() > 1:\n        # Compute mean IS weight per sequence\n        seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)\n\n        # Per-sequence statistics\n        metrics[\"rollout_is_seq_mean\"] = seq_mean_weights.mean()\n        metrics[\"rollout_is_seq_std\"] = (\n            seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)\n        )\n        metrics[\"rollout_is_seq_max\"] = seq_mean_weights.max()\n        metrics[\"rollout_is_seq_min\"] = seq_mean_weights.min()\n\n        # Identify most problematic sequences\n        seq_deviation = (seq_mean_weights - 1.0).abs()\n        metrics[\"rollout_is_seq_max_deviation\"] = seq_deviation.max()\n\n        # Fraction of sequences with high IS weights\n        metrics[\"rollout_is_seq_fraction_high\"] = (seq_mean_weights > rollout_is_threshold).float().mean()\n        metrics[\"rollout_is_seq_fraction_low\"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()\n\n    return metrics\n\n\ndef compute_mismatch_metrics(\n    old_log_prob: torch.Tensor,\n    rollout_log_prob: Optional[torch.Tensor],\n    response_mask: torch.Tensor,\n) -> dict[str, Any]:\n    \"\"\"Compute training-inference mismatch metrics (helper function).\n\n    This helper function operates on raw tensors and is used internally by:\n    - compute_rollout_importance_weights() in this module (automatically included)\n    - Tests (test_rollout_is.py, test_rollout_is_integration.py)\n\n    These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)\n    and the training policy (e.g., FSDP), which can cause training instability.\n\n    Key metrics:\n    - mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)\n    - mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)\n    - training_ppl: Perplexity of training policy\n    - rollout_ppl: Perplexity of rollout policy\n    - log_ppl_diff: Difference in log perplexities\n    - ppl_ratio: Ratio of training PPL to rollout PPL\n\n    Args:\n        old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)\n        rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)\n        response_mask: Mask for valid tokens, shape (batch_size, seq_length)\n\n    Returns:\n        Dictionary of mismatch metrics (without prefix)\n\n    Reference:\n    - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda\n    \"\"\"\n    # Validate that we have at least one valid token\n    assert response_mask.any(), \"Expected at least one valid token in response_mask\"\n\n    metrics = {}\n\n    # 1. 
Training policy perplexity (always available)\n    # Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))\n    # where |T| is the number of tokens generated by the model\n    mean_log_prob_training = verl_F.masked_mean(old_log_prob, response_mask, axis=-1)  # (batch_size,)\n    training_ppl = torch.exp(-mean_log_prob_training).mean()  # Batch mean of per-sequence PPL\n    metrics[\"mismatch_training_ppl\"] = training_ppl.detach().item()\n\n    # Also log log-ppl for easier analysis (avoids exponential scale)\n    metrics[\"mismatch_training_log_ppl\"] = (-mean_log_prob_training).mean().detach().item()\n\n    # 2. Compute rollout mismatch metrics (only if rollout_log_probs available)\n    if rollout_log_prob is not None:\n        # 2a. mismatch_kl: Direct estimator for KL(π_rollout || π_training)\n        # This is the standard KL divergence: E[log(π_rollout) - log(π_training)]\n        # Positive value means rollout policy is more confident than training policy\n        metrics[\"mismatch_kl\"] = verl_F.masked_mean(rollout_log_prob - old_log_prob, response_mask).detach().item()\n\n        # 2b. mismatch_k3_kl: K3 estimator for KL(π_rollout || π_training)\n        # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1]\n        # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout\n        log_ratio = old_log_prob - rollout_log_prob\n        mismatch_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1\n        metrics[\"mismatch_k3_kl\"] = verl_F.masked_mean(mismatch_k3_kl_matrix, response_mask).detach().item()\n\n        # 2c. Rollout policy perplexity\n        mean_log_prob_rollout = verl_F.masked_mean(rollout_log_prob, response_mask, axis=-1)  # (batch_size,)\n        rollout_ppl = torch.exp(-mean_log_prob_rollout).mean()  # Batch mean of per-sequence PPL\n        metrics[\"mismatch_rollout_ppl\"] = rollout_ppl.detach().item()\n        metrics[\"mismatch_rollout_log_ppl\"] = (-mean_log_prob_rollout).mean().detach().item()\n\n        # 2d. Log PPL difference (sequence-level perplexity difference)\n        # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training\n        # Since ppl = exp(-log_prob), we have:\n        #   log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff\n        # Positive value means training assigns lower probability (higher PPL) than rollout\n        log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training\n        metrics[\"mismatch_log_ppl_diff\"] = log_ppl_diff.mean().detach().item()\n        metrics[\"mismatch_log_ppl_abs_diff\"] = log_ppl_diff.abs().mean().detach().item()\n        metrics[\"mismatch_log_ppl_diff_max\"] = log_ppl_diff.max().detach().item()\n        metrics[\"mismatch_log_ppl_diff_min\"] = log_ppl_diff.min().detach().item()\n\n        # 2e. PPL ratio (how much higher is training PPL vs rollout PPL)\n        # IMPORTANT: Compute per-sequence ratio first, then average\n        # For numerical stability, compute in log space using log_ppl_diff\n        # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)\n        # This is the inverse of geometric IS: ppl_ratio_i = 1 / geometric_is_i for each sequence\n        ppl_ratio = torch.exp(log_ppl_diff).mean()  # mean(exp(log_ppl_diff)) = mean(ppl_ratio_i)\n        metrics[\"mismatch_ppl_ratio\"] = ppl_ratio.detach().item()\n\n    return metrics\n"
  },
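As a quick sanity check on the formulas above, here is a toy, standalone illustration (plain torch, outside the library; the tensor values are made up) of the three IS aggregation levels and the K3 KL estimator used by `compute_rollout_importance_weights` and `compute_mismatch_metrics`:

```python
# Toy illustration of token / sequence / geometric IS aggregation and the
# K3 KL estimator described above (sketch only, not library code).
import torch

old_log_prob = torch.tensor([[-0.9, -1.1, -0.5]])      # training policy (e.g., FSDP FP32)
rollout_log_prob = torch.tensor([[-1.0, -1.0, -0.6]])  # rollout policy (e.g., vLLM BF16)

log_ratio = old_log_prob - rollout_log_prob            # log(pi_train / pi_rollout) per token

token_is = torch.exp(log_ratio)            # per-token ratios: low variance, biased
seq_is = torch.exp(log_ratio.sum(-1))      # product over tokens: unbiased, high variance
geo_is = torch.exp(log_ratio.mean(-1))     # geometric mean: damped trade-off

# K3 estimator of KL(pi_rollout || pi_train): E[r - log(r) - 1] with r = pi_train / pi_rollout
k3_kl = (torch.exp(log_ratio) - log_ratio - 1).mean()

print(token_is)                 # tensor([[1.1052, 0.9048, 1.1052]])
print(seq_is, geo_is, k3_kl)    # ~1.1052, ~1.0339, ~0.0051
```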
  {
    "path": "verl_distillation/verl/trainer/ppo/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport json\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom pprint import pprint\nfrom typing import Optional\n\nimport numpy as np\nimport ray\nimport torch\nfrom omegaconf import OmegaConf, open_dict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.ray import (RayClassWithInitArgs, RayResourcePool,\n                                        RayWorkerGroup)\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.config import AlgoConfig\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (compute_data_metrics,\n                                           compute_throughout_metrics,\n                                           compute_timing_metrics,\n                                           process_validation_metrics)\nfrom verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.trainer.ppo.utils import (Role, WorkerType, need_critic,\n                                    need_reference_policy, need_reward_model)\nfrom verl.utils.checkpoint.checkpoint_manager import (find_latest_ckpt_path,\n                                                      should_save_ckpt_esi)\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.rollout_skip import RolloutSkip\nfrom verl.utils.seqlen_balancing import (calculate_workload,\n                                         get_seqlen_balanced_partitions,\n                                         log_seqlen_unbalance)\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. 
The resource pool will be initialized first.\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        \"\"\"Create Ray resource pools for distributed training.\n\n        Initializes resource pools based on the resource pool specification,\n        with each pool managing GPU resources across multiple nodes.\n        For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups.\n        For Megatron backend, uses max_colocate_count>1 for different models.\n        \"\"\"\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1, which merges all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1,\n            # which can utilize different WorkerGroups for different models\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of GPUs in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray._private.state.available_resources_per_node()\n        node_available_gpus = {\n            node: node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0)\n            for node, node_info in node_available_resources.items()\n        }\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n\ndef apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    \"\"\"Apply KL penalty to the token-level rewards.\n\n    This function computes the KL divergence between the reference policy and current policy,\n    then applies a penalty to the token-level rewards based on this divergence.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.\n        kl_penalty (str, optional): Type of KL penalty to apply. 
Defaults to \"kl\".\n\n    Returns:\n        tuple: A tuple containing:\n            - The updated data with token-level rewards adjusted by KL penalty\n            - A dictionary of metrics related to the KL penalty\n    \"\"\"\n    response_mask = data.batch[\"response_mask\"]\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_response_mask(data: DataProto):\n    \"\"\"Compute the attention mask for the response part of the sequence.\n\n    This function extracts the portion of the attention mask that corresponds to the model's response,\n    which is used for masking computations that should only apply to response tokens.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n\n    Returns:\n        torch.Tensor: The attention mask for the response tokens.\n    \"\"\"\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_advantage(\n    data: DataProto,\n    adv_estimator: AdvantageEstimator,\n    gamma: float = 1.0,\n    lam: float = 1.0,\n    num_repeat: int = 1,\n    distill_adv_max_clip: float = None,\n    distill_adv_min_clip: float = None,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> DataProto:\n    \"\"\"Compute advantage estimates for policy optimization.\n\n    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.\n    The advantage estimates are used to guide policy optimization in RL algorithms.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).\n        gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.\n        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.\n        num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.\n        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in\n            GRPO. Defaults to True.\n        config (dict, optional): Configuration dictionary for algorithm settings. 
Defaults to None.\n\n    Returns:\n        DataProto: The updated data with computed advantages and returns.\n    \"\"\"\n    # Back-compatible with trainers that do not compute response mask in fit\n    if \"response_mask\" not in data.batch.keys():\n        data.batch[\"response_mask\"] = compute_response_mask(data)\n    # prepare response group\n    if adv_estimator == AdvantageEstimator.GAE:\n        # Compute advantages and returns using Generalized Advantage Estimation (GAE)\n        advantages, returns = core_algos.compute_gae_advantage_return(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            values=data.batch[\"values\"],\n            response_mask=data.batch[\"response_mask\"],\n            gamma=gamma,\n            lam=lam,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n        if config.get(\"use_pf_ppo\", False):\n            data = core_algos.compute_pf_ppo_reweight_data(\n                data,\n                config.pf_ppo.get(\"reweight_method\"),\n                config.pf_ppo.get(\"weight_pow\"),\n            )\n    elif adv_estimator == AdvantageEstimator.GRPO:\n        # Initialize the mask for GRPO calculation\n        grpo_calculation_mask = data.batch[\"response_mask\"]\n\n        # Call compute_grpo_outcome_advantage with parameters matching its definition\n        advantages, returns = core_algos.compute_grpo_outcome_advantage(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            response_mask=grpo_calculation_mask,\n            index=data.non_tensor_batch[\"uid\"],\n            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    elif adv_estimator == AdvantageEstimator.ON_POLICY_DISTILL:\n        advantages, returns = core_algos.compute_on_policy_distill_reverse_kl(\n            teacher_log_prob=data.batch[\"ref_log_prob\"],\n            student_log_prob=data.batch[\"old_log_probs\"],\n        )\n\n        if distill_adv_max_clip is not None:\n            advantages = torch.clamp(advantages, max=distill_adv_max_clip)\n        if distill_adv_min_clip is not None:\n            advantages = torch.clamp(advantages, min=distill_adv_min_clip)\n\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    else:\n        # handle all adv estimator types other than GAE, GRPO, and on-policy distillation\n        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)\n        adv_kwargs = {\n            \"token_level_rewards\": data.batch[\"token_level_rewards\"],\n            \"response_mask\": data.batch[\"response_mask\"],\n            \"config\": config,\n        }\n        if \"uid\" in data.non_tensor_batch:  # optional\n            adv_kwargs[\"index\"] = data.non_tensor_batch[\"uid\"]\n        if \"reward_baselines\" in data.batch:  # optional\n            adv_kwargs[\"reward_baselines\"] = data.batch[\"reward_baselines\"]\n\n        # calculate advantage estimator\n        advantages, returns = adv_estimator_fn(**adv_kwargs)\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    return data\n\n\nclass RayPPOTrainer:\n    \"\"\"Distributed PPO trainer using Ray for scalable reinforcement learning.\n\n    This trainer orchestrates distributed PPO training across multiple nodes and GPUs,\n    managing actor rollouts, critic training, and reward computation with Ray backend.\n    Supports various model 
architectures including FSDP, Megatron, vLLM, and SGLang integration.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). 
Defaults to None.\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only the hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)\n        self.use_rm = need_reward_model(self.role_worker_mapping)\n        self.use_critic = need_critic(self.config)\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n        self.validation_generations_logger = ValidationGenerationsLogger(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n        )\n\n        # if ref_in_actor is True, the reference policy will be the actor without LoRA applied\n        self.ref_in_actor = (\n            config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n            or config.actor_rollout_ref.model.get(\"lora_adapter_path\") is not None\n        )\n\n        # define in-reward KL control\n        # KL loss control is currently not supported\n        if self.config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)\n\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files,\n                self.config.data,\n                self.tokenizer,\n                self.processor,\n                max_samples=self.config.data.get(\"train_max_samples\", -1),\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files,\n                self.config.data,\n                self.tokenizer,\n                self.processor,\n                max_samples=self.config.data.get(\"val_max_samples\", -1),\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        num_workers = self.config.data[\"dataloader_num_workers\"]\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            # generation batch size; falls back to train_batch_size when gen_batch_size is unset\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=num_workers,\n           
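 # drop_last=True so every optimization step sees a full generation batch (a partial final batch is skipped)\n           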
 drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=num_workers,\n            shuffle=self.config.data.get(\"validation_shuffle\", True),\n            drop_last=False,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: \"\n            f\"{len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? 
Error: {e}\")\n\n    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path, logger=None):\n        \"\"\"Dump rollout/validation samples as JSONL.\"\"\"\n        os.makedirs(dump_path, exist_ok=True)\n        filename = os.path.join(dump_path, f\"{self.global_steps}.jsonl\")\n\n        n = len(inputs)\n        base_data = {\n            \"input\": inputs,\n            \"output\": outputs,\n            \"score\": scores,\n            \"step\": [self.global_steps] * n,\n        }\n\n        for k, v in reward_extra_infos_dict.items():\n            if len(v) == n:\n                base_data[k] = v\n        \n        if logger is not None and 'wandb' in logger.logger:\n            import pandas as pd\n            df = pd.DataFrame(base_data)\n            import wandb\n            logger.logger['wandb'].log({\"completions\": wandb.Table(dataframe=df)})\n            return\n\n        lines = []\n        for i in range(n):\n            entry = {\n                k: int(v[i]) if any(t in str(type(v[i])) for t in ['int64', 'bool']) else v[i] \n                for k, v in base_data.items()\n            }\n            lines.append(json.dumps(entry, ensure_ascii=False))\n\n        with open(filename, \"w\") as f:\n            f.write(\"\\n\".join(lines) + \"\\n\")\n\n        print(f\"Dumped generations to {filename}\")\n\n    def _log_rollout_data(\n        self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str,\n        logger = None\n    ):\n        \"\"\"Log rollout data to disk.\n        Args:\n            batch (DataProto): The batch containing rollout data\n            reward_extra_infos_dict (dict): Additional reward information to log\n            timing_raw (dict): Timing information for profiling\n            rollout_data_dir (str): Directory path to save the rollout data\n        \"\"\"\n        with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n            inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n            outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n            scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n            sample_gts = [item.non_tensor_batch.get(\"reward_model\", {}).get(\"ground_truth\", None) for item in batch]\n\n            reward_extra_infos_to_dump = reward_extra_infos_dict.copy()\n            if \"request_id\" in batch.non_tensor_batch:\n                reward_extra_infos_dict.setdefault(\n                    \"request_id\",\n                    batch.non_tensor_batch[\"request_id\"].tolist(),\n                )\n\n            self._dump_generations(\n                inputs=inputs,\n                outputs=outputs,\n                gts=sample_gts,\n                scores=scores,\n                reward_extra_infos_dict=reward_extra_infos_to_dump,\n                dump_path=rollout_data_dir,\n                logger=logger,\n            )\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort 
by input text\n\n        # Use fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _get_gen_batch(self, batch: DataProto) -> DataProto:\n        reward_model_keys = set({\"data_source\", \"reward_model\", \"extra_info\", \"uid\"}) & batch.non_tensor_batch.keys()\n\n        # pop those keys for generation\n        batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n        non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys\n        gen_batch = batch.pop(\n            batch_keys=batch_keys_to_pop,\n            non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),\n        )\n\n        # For agent loop, we need reward model keys to compute score.\n        if self.async_rollout_mode:\n            gen_batch.non_tensor_batch.update(batch.non_tensor_batch)\n\n        return gen_batch\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_gts = []\n        sample_scores = []\n        sample_turns = []\n        sample_uids = []\n\n        for test_data in self.val_dataloader:\n            test_batch = DataProto.from_single_dict(test_data)\n\n            if \"uid\" not in test_batch.non_tensor_batch:\n                test_batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object\n                )\n\n            # repeat test batch\n            test_batch = test_batch.repeat(\n                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n            )\n\n            # we only do validation on rule-based rm\n            if self.config.reward_model.enable and test_batch[0].non_tensor_batch[\"reward_model\"][\"style\"] == \"model\":\n                return {}\n\n            # Store original inputs\n            input_ids = test_batch.batch[\"input_ids\"]\n            # TODO: Can we keep special tokens except for padding tokens?\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            sample_inputs.extend(input_texts)\n            sample_uids.extend(test_batch.non_tensor_batch[\"uid\"])\n\n            ground_truths = [\n                item.non_tensor_batch.get(\"reward_model\", {}).get(\"ground_truth\", None) for item in test_batch\n            ]\n            sample_gts.extend(ground_truths)\n\n            test_gen_batch = self._get_gen_batch(test_batch)\n            test_gen_batch.meta_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,\n                \"validate\": True,\n                \"global_steps\": self.global_steps,\n            }\n            print(f\"test_gen_batch meta info: {test_gen_batch.meta_info}\")\n\n            # pad to be divisible by dp_size\n            size_divisor = (\n                
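# sync mode pads to the rollout worker group's world size; async mode pads to the number of agent-loop workers\n                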
self.actor_rollout_wg.world_size\n                if not self.async_rollout_mode\n                else self.config.actor_rollout_ref.rollout.agent.num_workers\n            )\n            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)\n            if not self.async_rollout_mode:\n                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)\n            else:\n                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)\n\n            # unpad\n            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)\n\n            print(\"validation generation end\")\n\n            # Store generated outputs\n            output_ids = test_output_gen_batch.batch[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            test_batch = test_batch.union(test_output_gen_batch)\n            test_batch.meta_info[\"validate\"] = True\n\n            # evaluate using reward_function\n            if self.val_reward_fn is None:\n                raise ValueError(\"val_reward_fn must be provided for validation.\")\n            result = self.val_reward_fn(test_batch, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            reward_extra_infos_dict[\"reward\"].extend(scores)\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n\n            # collect num_turns of each prompt\n            if \"__num_turns__\" in test_batch.non_tensor_batch:\n                sample_turns.append(test_batch.non_tensor_batch[\"__num_turns__\"])\n\n            data_source_lst.append(test_batch.non_tensor_batch.get(\"data_source\", [\"unknown\"] * reward_tensor.shape[0]))\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            # no tracking logger is in scope here, so samples are always written to disk\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                gts=sample_gts,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == 
core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        if len(sample_turns) > 0:\n            sample_turns = np.concatenate(sample_turns)\n            metric_dict[\"val-aux/num_turns/min\"] = sample_turns.min()\n            metric_dict[\"val-aux/num_turns/max\"] = sample_turns.max()\n            metric_dict[\"val-aux/num_turns/mean\"] = sample_turns.mean()\n\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=str(Role.ActorRollout),\n            )\n            self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cfg = omega_conf_to_dataclass(self.config.critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)\n            self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=str(Role.RefPolicy),\n            )\n            self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs 
= {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.global_profiler, \"steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.global_profiler, \"steps\")\n            # Only require nsight worker options when tool is nsys\n            if OmegaConf.select(self.config.global_profiler, \"tool\") == \"nsys\":\n                assert (\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                    is not None\n                ), \"worker_nsight_options must be set when using nsys with profile_steps\"\n                wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, \"worker_nsight_options\")\n                )\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[str(Role.Critic)]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[str(Role.RefPolicy)]\n            self.ref_policy_wg.init_model()\n\n        self.rm_wg = None\n        # initialization of rm_wg will be deprecated in the future\n        if self.use_rm:\n            self.rm_wg = all_wg[str(Role.RewardModel)]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = all_wg[str(Role.ActorRollout)]\n        self.actor_rollout_wg.init_model()\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\":\n            from verl.experimental.agent_loop import AgentLoopManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AgentLoopManager(\n                config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg\n            )\n\n    def _save_checkpoint(self):\n        from verl.utils.fs import local_mkdir_safe\n\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = 
self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated,\"\n                + \" set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n        )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic))\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(\n                    self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", str(Role.Critic)\n                )\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n        local_mkdir_safe(local_global_step_folder)\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            # NOTE: while there is no checkpoint to load, we still need to offload the model and optimizer to CPU\n            self.actor_rollout_wg.load_checkpoint(None)\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                self.actor_rollout_wg.load_checkpoint(None)\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                
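# resume_from_path may be relative (e.g. \"ckpts/global_step_300\"); it is resolved against the working directory below\n                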
global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = os.path.join(global_step_folder, str(Role.Critic))\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader,\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n    def _start_profiling(self, do_profile: bool) -> None:\n        \"\"\"Start profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.start_profile(role=\"e2e\", profile_step=self.global_steps)\n            if self.use_reference_policy:\n                self.ref_policy_wg.start_profile(profile_step=self.global_steps)\n            if self.use_critic:\n                self.critic_wg.start_profile(profile_step=self.global_steps)\n            if self.use_rm:\n                self.rm_wg.start_profile(profile_step=self.global_steps)\n\n    def _stop_profiling(self, do_profile: bool) -> None:\n        \"\"\"Stop profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.stop_profile()\n            if self.use_reference_policy:\n                self.ref_policy_wg.stop_profile()\n            if self.use_critic:\n                self.critic_wg.stop_profile()\n            if self.use_rm:\n                self.rm_wg.stop_profile()\n\n    def _balance_batch(self, batch: DataProto, metrics, logging_prefix=\"global_seqlen\", keep_minibatch=False):\n        \"\"\"Reorder the data on single controller such that each dp rank gets similar total tokens\"\"\"\n        attention_mask = batch.batch[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = batch.batch[\"attention_mask\"].view(batch_size, -1).sum(-1)  # (train_batch_size,)\n        global_seqlen_lst = calculate_workload(global_seqlen_lst)\n        world_size = self.actor_rollout_wg.world_size\n        if keep_minibatch:\n            # Decouple the DP balancing and mini-batching.\n            minibatch_size = self.config.actor_rollout_ref.actor.get(\"ppo_mini_batch_size\")\n            minibatch_num = len(global_seqlen_lst) // minibatch_size\n            global_partition_lst = [[] for _ in range(world_size)]\n            for i in 
range(minibatch_num):\n                rearrange_minibatch_lst = get_seqlen_balanced_partitions(\n                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],\n                    k_partitions=world_size,\n                    equal_size=True,\n                )\n                for j, part in enumerate(rearrange_minibatch_lst):\n                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])\n        else:\n            global_partition_lst = get_seqlen_balanced_partitions(\n                global_seqlen_lst, k_partitions=world_size, equal_size=True\n            )\n        # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.\n        for idx, partition in enumerate(global_partition_lst):\n            partition.sort(key=lambda x: (global_seqlen_lst[x], x))\n            ordered_partition = partition[::2] + partition[1::2][::-1]\n            global_partition_lst[idx] = ordered_partition\n        # reorder based on index. The data will be automatically equally partitioned by dispatch function\n        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])\n        batch.reorder(global_idx)\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n\n    def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]:\n        \"\"\"Compute IS weights and apply rejection sampling for rollout-training mismatch.\n\n        Computes importance sampling weights to correct for distribution mismatch between\n        rollout and training policies. Applies rejection sampling (mask mode/veto) by\n        modifying response_mask. 
Always updates response_mask; conditionally adds IS weights.\n\n        Key behavior:\n        - response_mask: ALWAYS updated with rejection (mask mode + veto excluded from training)\n        - rollout_is_weights: Added to batch ONLY if config.algorithm.rollout_is=True\n\n        This separation ensures:\n        - Rejection works even when IS weights are disabled (rollout_is=False)\n        - Metrics can be monitored before enabling IS weight application\n\n        Args:\n            batch: DataProto with old_log_probs, rollout_log_probs, response_mask\n\n        Returns:\n            Tuple of (updated_batch, metrics):\n                updated_batch: Batch with modified response_mask (always) and rollout_is_weights (if rollout_is=True)\n                metrics: Dict of IS and mismatch metrics, all with \"mismatch/\" prefix\n        \"\"\"\n        # Compute rollout IS weights if enabled and data is available\n        # rollout_is_threshold is the main on/off switch (None = disabled, float = enabled)\n        rollout_is_threshold = self.config.algorithm.get(\"rollout_is_threshold\", None)\n        if rollout_is_threshold is not None and rollout_is_threshold > 0 and \"rollout_log_probs\" in batch.batch:\n            # Compute IS weights and get modified response_mask\n            rollout_is_weights, modified_response_mask, rollout_is_metrics = compute_rollout_importance_weights(\n                old_log_prob=batch.batch[\"old_log_probs\"],\n                rollout_log_prob=batch.batch[\"rollout_log_probs\"],\n                response_mask=batch.batch[\"response_mask\"],\n                rollout_is_level=self.config.algorithm.rollout_is_level,\n                rollout_is_mode=self.config.algorithm.rollout_is_mode,\n                rollout_is_threshold=self.config.algorithm.rollout_is_threshold,\n                rollout_is_threshold_lower=self.config.algorithm.get(\"rollout_is_threshold_lower\", None),\n                rollout_is_veto_threshold=self.config.algorithm.get(\"rollout_is_veto_threshold\", None),\n            )\n\n            # ALWAYS update response_mask with rejection (even if rollout_is=False)\n            # - Mask mode: tokens with outlier IS ratios excluded\n            # - Veto: sequences with catastrophic tokens excluded\n            # This ensures correct loss normalization (rejected samples not in denominator)\n            batch.batch[\"response_mask\"] = modified_response_mask\n\n            # Conditionally add IS weights based on rollout_is config flag\n            # - rollout_is=True: Enable IS weight correction in policy loss\n            # - rollout_is=False: Metrics-only mode (rejection still applied via mask)\n            apply_weights = self.config.algorithm.get(\"rollout_is\", False)\n\n            if apply_weights:\n                # Add IS weights (safety-bounded, mode-processed) to enable weight correction\n                batch = batch.union(rollout_is_weights)\n\n            return batch, rollout_is_metrics\n\n        # Return unchanged batch and empty metrics if IS is disabled\n        return batch, {}\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The lightweight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            
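# Tracking fans metrics out to every backend named in trainer.logger (e.g. console, wandb)\n            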
project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        if self.config.actor_rollout_ref.rollout.get(\"skip_rollout\", False):\n            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)\n            rollout_skip.wrap_generate_sequences()\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n        self.max_steps_duration = 0\n\n        prev_step_profile = False\n        curr_step_profile = (\n            self.global_steps in self.config.global_profiler.steps\n            if self.config.global_profiler.steps is not None\n            else False\n        )\n        next_step_profile = False\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(\n                        not prev_step_profile and curr_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # add uid to batch\n                batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                )\n\n                gen_batch = self._get_gen_batch(batch)\n\n                # pass global_steps to trace\n                gen_batch.meta_info[\"global_steps\"] = self.global_steps\n                gen_batch_output = gen_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                )\n\n                is_last_step = self.global_steps >= self.total_training_steps\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)\n\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if 
self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        if self.reward_fn is None:\n                            raise ValueError(\"A reward_fn is required for REMAX advantage estimation.\")\n\n                        with marked_timer(\"gen_max\", timing_raw, color=\"purple\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            if not self.async_rollout_mode:\n                                gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n                            else:\n                                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)\n                            batch = batch.union(gen_baseline_output)\n                            # compute reward model score on batch\n                            rm_scores = None\n                            if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                                rm_scores = self.rm_wg.compute_rm_score(batch)\n                                batch = batch.union(rm_scores)\n                            reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            keys_to_pop = set(gen_baseline_output.batch.keys())\n                            if rm_scores is not None:\n                                keys_to_pop.update(rm_scores.batch.keys())\n                            batch.pop(batch_keys=list(keys_to_pop))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del rm_scores, gen_baseline_batch, gen_baseline_output\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    if \"response_mask\" not in batch.batch.keys():\n                        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                        # compute reward model score\n                        if self.use_rm and \"rm_scores\" not in batch.batch.keys():\n                            reward_tensor = self.rm_wg.compute_rm_score(batch)\n                            batch = batch.union(reward_tensor)\n\n                        if self.config.reward_model.launch_reward_fn_async:\n                            future_reward = compute_reward_async.remote(\n                                data=batch, config=self.config, tokenizer=self.tokenizer\n                            
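# launched as a Ray task; the future is resolved with ray.get() just before advantage computation\n                            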
)\n                        else:\n                            reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                    # recompute old_log_probs\n                    with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = batch.batch[\"response_mask\"]\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                        if \"rollout_log_probs\" in batch.batch.keys():\n                            # TODO: we may want to add diff of probs too.\n                            from verl.utils.debug.metrics import \\\n                                calculate_debug_metrics\n\n                            metrics.update(calculate_debug_metrics(batch))\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with marked_timer(str(Role.RefPolicy), timing_raw, color=\"olive\"):\n                            if not self.ref_in_actor:\n                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            else:\n                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        if self.config.reward_model.launch_reward_fn_async:\n                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                        batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        if reward_extra_infos_dict:\n                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                        # compute rewards. 
apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            batch, kl_metrics = apply_kl_penalty(\n                                batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(kl_metrics)\n                        else:\n                            batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                        # Compute rollout importance sampling weights centrally (once per batch)\n                        # This corrects for mismatch between rollout policy and training policy\n                        # Also computes mismatch metrics (KL, PPL, etc.)\n                        batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)\n                        # IS and mismatch metrics already have mismatch/ prefix\n                        metrics.update(is_metrics)\n\n                        # compute advantages, executed on the driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                            \"norm_adv_by_std_in_grpo\", True\n                        )  # GRPO adv normalization factor\n\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                            config=self.config.algorithm,\n                        )\n\n                    # update critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                            batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir, logger)\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw, color=\"green\"):\n 
                       val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.\n                esi_close_to_expiration = should_save_ckpt_esi(\n                    max_steps_duration=self.max_steps_duration,\n                    redundant_time=self.config.trainer.esi_redundant_time,\n                )\n                # Check if the conditions for saving a checkpoint are met.\n                # The conditions include a mandatory condition (1) and\n                # one of the following optional conditions (2/3/4):\n                # 1. The save frequency is set to a positive value.\n                # 2. It's the last training step.\n                # 3. The current step number is a multiple of the save frequency.\n                # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration\n                ):\n                    if esi_close_to_expiration:\n                        print(\"Force saving checkpoint: ESI instance expiration approaching.\")\n                    with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                        self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    next_step_profile = (\n                        self.global_steps + 1 in self.config.global_profiler.steps\n                        if self.config.global_profiler.steps is not None\n                        else False\n                    )\n                    self._stop_profiling(\n                        curr_step_profile and not next_step_profile\n                        if self.config.global_profiler.profile_continuous_steps\n                        else curr_step_profile\n                    )\n                    prev_step_profile = curr_step_profile\n                    curr_step_profile = next_step_profile\n\n                steps_duration = timing_raw[\"step\"]\n                self.max_steps_duration = max(self.max_steps_duration, steps_duration)\n\n                # training metrics\n                metrics.update(\n                    {\n                        \"training/global_step\": self.global_steps,\n                        \"training/epoch\": epoch,\n                    }\n                )\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflpo and theoretical tflpo\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                # Note: mismatch metrics (KL, PPL, etc.) 
are collected during advantage computation above\n\n                # this is experimental and may be changed/removed in the future in favor of a general-purpose one\n                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):\n                    self.train_dataloader.sampler.update(batch=batch)\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                progress_bar.update(1)\n                self.global_steps += 1\n\n                if (\n                    hasattr(self.config.actor_rollout_ref.actor, \"profiler\")\n                    and self.config.actor_rollout_ref.actor.profiler.tool == \"torch_memory\"\n                ):\n                    self.actor_rollout_wg.dump_memory_snapshot(\n                        tag=f\"post_update_step{self.global_steps}\", sub_dir=f\"step{self.global_steps}\"\n                    )\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                # this is experimental and may be changed/removed in the future\n                # in favor of a general-purpose data buffer pool\n                if hasattr(self.train_dataset, \"on_batch_end\"):\n                    # The dataset may be changed after each training batch\n                    self.train_dataset.on_batch_end(batch=batch)\n"
  },
  {
    "path": "verl_distillation/verl/trainer/ppo/reward.py",
    "content": "# Copyright 2025 Individual Contributor: Thibaut Barroyer\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.util\nimport inspect\nimport multiprocessing\nimport os\nimport sys\nimport warnings\nfrom functools import partial\nfrom typing import Any, Optional\n\nimport ray\nimport torch\nfrom omegaconf import DictConfig\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.utils.transferqueue_utils import tqbridge\nfrom verl.workers.reward_manager import get_reward_manager_cls\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn\n\n\ndef _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):\n    \"\"\"Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence.\n\n    This function is used to merge additional keyword arguments with the original function's arguments.\n    \"\"\"\n    merged_kwargs = {**kwargs, **extra_kwargs}\n    return raw_fn(*args, **merged_kwargs)\n\n\nasync def _call_with_kwargs_async(raw_fn, extra_kwargs, *args, **kwargs):\n    \"\"\"Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence.\n\n    This function is used to merge additional keyword arguments with the original function's arguments.\n    \"\"\"\n    merged_kwargs = {**kwargs, **extra_kwargs}\n    return await raw_fn(*args, **merged_kwargs)\n\n\ndef get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:\n    \"\"\"Load and return a custom reward function from external file.\n\n    Dynamically imports a reward function from a specified file path and wraps\n    it with additional keyword arguments from the configuration.\n\n    Args:\n        config (dict): Configuration dictionary containing custom_reward_function\n                      settings with 'path', 'name', and 'reward_kwargs' fields.\n\n    Returns:\n        callable or None: Wrapped reward function with merged kwargs, or None\n                         if no custom reward function is configured.\n\n    Raises:\n        FileNotFoundError: If the specified reward function file doesn't exist.\n        RuntimeError: If there's an error loading the module from file.\n        AttributeError: If the specified function name isn't found in the module.\n    \"\"\"\n\n    reward_fn_config = config.get(\"custom_reward_function\") or {}\n    file_path = reward_fn_config.get(\"path\")\n    if not file_path:\n        return None\n\n    function_name = reward_fn_config.get(\"name\")\n    assert function_name is not None\n\n    module = sys.modules.get(\"custom_module\", None)\n    if module is None:\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Reward function file '{file_path}' not found.\")\n\n        spec = importlib.util.spec_from_file_location(\"custom_module\", file_path)\n        assert spec is not None\n        module = importlib.util.module_from_spec(spec)\n        try:\n            
sys.modules[\"custom_module\"] = module\n            assert spec.loader is not None\n            spec.loader.exec_module(module)\n        except Exception as e:\n            raise RuntimeError(f\"Error loading module from '{file_path}': {e}\") from e\n\n    if not hasattr(module, function_name):\n        raise AttributeError(f\"Reward function '{function_name}' not found in '{module.__file__}'.\")\n\n    print(f\"using customized reward function '{function_name}' from '{module.__file__}'\")\n    raw_fn = getattr(module, function_name)\n\n    reward_kwargs = dict(reward_fn_config.get(\"reward_kwargs\", {}))\n\n    if not inspect.iscoroutinefunction(raw_fn):\n        return partial(_call_with_kwargs, raw_fn, reward_kwargs)\n    else:\n        return partial(_call_with_kwargs_async, raw_fn, reward_kwargs)\n\n\ndef load_reward_manager(\n    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any\n) -> AbstractRewardManager:\n    \"\"\"\n    Load and initialize a reward manager based on the configuration.\n\n    Args:\n        config: PPO trainer configuration object containing reward_model fields.\n        tokenizer: Tokenizer object used for processing text.\n        num_examine: Number of samples to examine.\n        **reward_kwargs: Additional keyword arguments for the reward manager.\n\n    Returns:\n        An instance of the specified reward manager class.\n    \"\"\"\n\n    # Try to get a custom reward function based on the configuration\n    # user defined reward manager can be registered in custom_reward_fn\n    compute_score = get_custom_reward_fn(config)\n    final_compute_score = compute_score\n\n    # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:\n    # naive: NaiveRewardManager\n    # prime: PrimeRewardManager\n    # batch: BatchRewardManager\n    # dapo: DAPORewardManager\n    # Note(haibin.lin): For custom reward managers, please make sure they are imported and\n    # registered via `verl.workers.reward_manager.register`\n    # By default reward_manager is set to naive (NaiveRewardManager)\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n    if compute_score is None:\n        sandbox_config = config.reward_model.get(\"sandbox_fusion\")\n        sandbox_url = sandbox_config.get(\"url\") if sandbox_config else None\n        memory_limit_mb = sandbox_config.get(\"memory_limit_mb\", 1024) if sandbox_config else 1024\n        if sandbox_url:\n            sandbox_manager = multiprocessing.Manager()\n            # Create a semaphore to control concurrent access to the sandbox\n            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get(\"max_concurrent\", 64))\n            final_compute_score = partial(\n                default_compute_score,\n                sandbox_fusion_url=sandbox_url,\n                concurrent_semaphore=_concurrent_semaphore,\n                memory_limit_mb=memory_limit_mb,\n            )\n        else:\n            final_compute_score = default_compute_score\n\n    # Instantiate and return the reward manager with the specified parameters\n    return reward_manager_cls(\n        tokenizer=tokenizer,\n        num_examine=num_examine,\n        compute_score=final_compute_score,\n        reward_fn_key=config.data.reward_fn_key,\n        **reward_kwargs,\n    )\n\n\n@tqbridge(put_data=False)\ndef compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> 
tuple[torch.Tensor, dict[str, Any]]:\n    \"\"\"\n    Compute reward for a batch of data.\n\n    Args:\n        data: DataProto object containing the input data.\n        reward_fn: Reward function to compute the reward.\n\n    Returns:\n        Tuple of reward tensor and extra info dictionary.\n    \"\"\"\n    try:\n        reward_result = reward_fn(data, return_dict=True)\n        reward_tensor = reward_result[\"reward_tensor\"]\n        reward_extra_infos_dict = reward_result.get(\"reward_extra_info\", {})\n    except Exception as e:\n        print(f\"Error in reward_fn: {e}\")\n        # The reward manager may not support `return_dict`; fall back to the tensor-only call\n        reward_tensor = reward_fn(data)\n        reward_extra_infos_dict = {}\n\n    return reward_tensor, reward_extra_infos_dict\n\n\n@ray.remote(num_cpus=1)\ndef compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None):\n    \"\"\"\n    Load the reward manager and compute the reward for a batch of data.\n    This is meant to be run in a separate Ray worker.\n    \"\"\"\n    if reward_fn is None:\n        assert config is not None and tokenizer is not None, (\n            \"config and tokenizer must not be None when reward_fn is None\"\n        )\n\n        warnings.warn(\"using config and tokenizer with compute_reward_async is deprecated\", stacklevel=2)\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n\n    return compute_reward(data, reward_fn)\n"
  },
  {
    "path": "verl_distillation/verl/trainer/ppo/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom enum import Enum\n\nfrom omegaconf import DictConfig\n\nfrom verl.single_controller.base import Worker\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator\n\nWorkerType = type[Worker]\n\n\nclass Role(Enum):\n    \"\"\"\n    To create more roles dynamically, you can subclass Role and add new members\n    \"\"\"\n\n    Actor = 0\n    Rollout = 1\n    ActorRollout = 2\n    Critic = 3\n    RefPolicy = 4\n    RewardModel = 5\n    ActorRolloutRef = 6\n\n    def __str__(self):\n        return self._get_role_string()\n\n    def _get_role_string(self):\n        role_mapping = {\n            Role.Actor: \"actor\",\n            Role.Rollout: \"rollout\",\n            Role.ActorRollout: \"actor_rollout\",\n            Role.Critic: \"critic\",\n            Role.RefPolicy: \"ref\",\n            Role.RewardModel: \"rm\",\n            Role.ActorRolloutRef: \"actor_rollout_ref\",\n        }\n        return role_mapping.get(self, self.name.lower())\n\n    @classmethod\n    def from_string(cls, name: str):\n        string_mapping = {\n            \"actor\": cls.Actor,\n            \"rollout\": cls.Rollout,\n            \"actor_rollout\": cls.ActorRollout,\n            \"critic\": cls.Critic,\n            \"ref\": cls.RefPolicy,\n            \"rm\": cls.RewardModel,\n            \"actor_rollout_ref\": cls.ActorRolloutRef,\n        }\n        role = string_mapping.get(name.lower())\n        if role is None:\n            raise ValueError(f\"No Role found for string: {name}\")\n        return role\n\n\ndef need_reference_policy(\n    role_worker_mapping: dict[Role, WorkerType],\n) -> bool:\n    \"\"\"Given a role worker mapping, do we need ref policy.\"\"\"\n    return Role.RefPolicy in role_worker_mapping\n\n\ndef need_reward_model(\n    role_worker_mapping: dict[Role, WorkerType],\n) -> bool:\n    \"\"\"Given a role worker mapping, do we need reward model.\"\"\"\n    return Role.RewardModel in role_worker_mapping\n\n\ndef need_critic(config: DictConfig) -> bool:\n    \"\"\"Given a config, do we need critic.\"\"\"\n    if config.critic.enable is not None:\n        return bool(config.critic.enable)\n    elif config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n        return True\n    else:\n        warnings.warn(\n            \"Disabled critic as algorithm.adv_estimator != gae. If it is not intended, please set critic.enable=True\",\n            stacklevel=2,\n        )\n        return False\n"
  },
  {
    "path": "verl_distillation/verl/trainer/runtime_env.yaml",
    "content": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  CUDA_DEVICE_MAX_CONNECTIONS: \"1\"\n"
  },
  {
    "path": "verl_distillation/verl/trainer/sft_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport os\nfrom functools import partial\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n\nimport logging\n\nimport hydra\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import DistributedSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.checkpoint import CheckpointHandler\nfrom verl.utils.dataset.dataset_utils import SFTTensorCollator\nfrom verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset\nfrom verl.utils.device import get_device_name, is_cuda_available, is_npu_available\nfrom verl.utils.distributed import destroy_global_process_group\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.logger import log_with_rank\nfrom verl.utils.tracking import Tracking\n\nif is_cuda_available:\n    pass\nelif is_npu_available:\n    pass\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass SFTTrainer:\n    def __init__(\n        self,\n        config,\n    ):\n        self.config = config\n\n        self.rank = torch.distributed.get_rank()\n\n        self._build_config()\n        self._build_dataset()\n\n        self._build_engine()\n\n        self._build_dataloader()\n\n        self._init_engine()\n\n        self._build_ckpt_handler()\n\n        # Initialize resume-related variables\n        self.resume_global_step = self.ckpt_handler.load_checkpoint()\n\n        self.device_name = self.config.trainer.device\n\n        from verl.workers.roles.utils.losses import sft_loss\n\n        self.loss_fn = partial(sft_loss, config=None)\n\n        self.flops_counter = FlopsCounter(self.model_config.hf_config)\n\n        if self.rank == 0:\n            print(self.config)\n\n    def _build_ckpt_handler(self):\n        resume_mode = getattr(self.config.trainer, \"resume_mode\", \"auto\")\n        resume_from_path = getattr(self.config.trainer, \"resume_from_path\", None)\n        max_ckpt_to_keep = getattr(self.config.trainer, \"max_ckpt_to_keep\", None)\n        default_hdfs_dir = getattr(self.config.trainer, \"default_hdfs_dir\", None)\n\n        self.ckpt_handler = CheckpointHandler(\n            engine=self.engine,\n            train_dataloader=self.train_dataloader,\n            default_local_dir=self.config.trainer.default_local_dir,\n            max_ckpt_to_keep=max_ckpt_to_keep,\n            default_hdfs_dir=default_hdfs_dir,\n            resume_mode=resume_mode,\n            resume_from_path=resume_from_path,\n        )\n\n    def _build_config(self):\n        from verl.utils.config import omega_conf_to_dataclass\n\n        self.model_config = omega_conf_to_dataclass(self.config.model)\n        self.engine_config = 
omega_conf_to_dataclass(self.config.engine)\n        self.optimizer_config = omega_conf_to_dataclass(self.config.optim)\n        self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint)\n\n    def _build_engine(self):\n        from verl.workers.engine import BaseEngine, EngineRegistry\n\n        self.engine: BaseEngine = EngineRegistry.new(\n            model_type=\"language_model\",\n            backend=self.engine_config.strategy,\n            model_config=self.model_config,\n            engine_config=self.engine_config,\n            optimizer_config=self.optimizer_config,\n            checkpoint_config=self.checkpoint_config,\n        )\n\n    def _init_engine(self):\n        # patch optimizer config\n        if self.config.trainer.total_training_steps is not None:\n            self.total_training_steps = self.config.trainer.total_training_steps\n        else:\n            self.total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n        self.optimizer_config.total_training_steps = self.total_training_steps\n\n        self.steps_per_epoch = len(self.train_dataloader)\n\n        # manage save and test frequency\n        self.save_freq = self.config.trainer.save_freq\n        if self.save_freq == \"after_each_epoch\":\n            self.save_freq = self.steps_per_epoch\n\n        self.test_freq = self.config.trainer.test_freq\n        if self.test_freq == \"after_each_epoch\":\n            self.test_freq = self.steps_per_epoch\n\n        self.engine.initialize()\n\n    def _build_dataset(self):\n        config = self.config\n        tokenizer = self.model_config.tokenizer\n        train_dataset = create_sft_dataset(\n            config.data.train_files, config.data, tokenizer, max_samples=config.data.get(\"train_max_samples\", -1)\n        )\n        if config.data.val_files:\n            val_dataset = create_sft_dataset(\n                config.data.val_files, config.data, tokenizer, max_samples=config.data.get(\"val_max_samples\", -1)\n            )\n        else:\n            val_dataset = None\n\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n    def _build_dataloader(self):\n        # build dataset\n        config = self.config\n        # build dataloader\n        # Use data parallel rank and size instead of global rank and world size\n\n        # Set pin_memory_device when pin_memory is enabled.\n        device_name = get_device_name()\n\n        dp_rank = self.engine.get_data_parallel_rank()\n        dp_size = self.engine.get_data_parallel_size()\n\n        self.train_sampler = DistributedSampler(\n            self.train_dataset, shuffle=True, num_replicas=dp_size, rank=dp_rank, drop_last=True\n        )\n\n        self.global_batch_size = config.data.train_batch_size\n        self.train_batch_size_per_dp = self.global_batch_size // dp_size\n        self.collate_fn = SFTTensorCollator(config.data.pad_mode)\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.train_batch_size_per_dp,\n            sampler=self.train_sampler,\n            collate_fn=self.collate_fn,\n            num_workers=8,\n            pin_memory=True,\n            drop_last=True,\n            pin_memory_device=device_name,\n        )\n\n        if self.val_dataset:\n            self.val_sampler = DistributedSampler(\n                self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True\n            )\n            self.val_dataloader = 
StatefulDataLoader(\n                dataset=self.val_dataset,\n                batch_size=self.train_batch_size_per_dp,\n                sampler=self.val_sampler,\n                collate_fn=self.collate_fn,\n                num_workers=8,\n                pin_memory=True,\n                drop_last=True,\n                pin_memory_device=device_name,\n            )\n        else:\n            self.val_dataloader = None\n\n    def fit(self):\n        is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0\n\n        # TODO: add a unified tracking\n        if is_logging:\n            tracking = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n                config=OmegaConf.to_container(self.config, resolve=True),\n            )\n\n        global_step = self.resume_global_step  # Start from resumed step\n        last_valid_metric = None\n\n        log_with_rank(\n            f\"Total training steps: {self.total_training_steps}\",\n            logger=logger,\n            rank=0,\n            log_only_rank_0=True,\n        )\n\n        # With StatefulDataLoader, we don't need to manually calculate epochs and steps\n        # The dataloader will automatically resume from where it left off\n        if global_step > 0:\n            log_with_rank(\n                f\"StatefulDataLoader will automatically resume from global step: {global_step}\",\n                logger=logger,\n                rank=0,\n                log_only_rank_0=True,\n            )\n\n        # Calculate which epoch we're starting from for sampler.set_epoch()\n        start_epoch = global_step // self.steps_per_epoch\n\n        meta_info = {\n            \"use_remove_padding\": self.config.model.use_remove_padding,\n            \"use_dynamic_bsz\": self.config.data.use_dynamic_bsz,\n            \"max_token_len_per_gpu\": self.config.data.max_token_len_per_gpu,\n            \"micro_batch_size_per_gpu\": self.config.data.micro_batch_size_per_gpu,\n            \"temperature\": 1.0,\n            \"global_batch_size\": self.global_batch_size,\n            \"pad_mode\": self.config.data.pad_mode,\n            \"pad_token_id\": self.model_config.tokenizer.pad_token_id,\n        }\n\n        train_time = 0\n        total_tokens = 0\n        for epoch in range(start_epoch, self.config.trainer.total_epochs):\n            self.train_sampler.set_epoch(epoch=epoch)\n\n            for step_in_epoch, data in enumerate(\n                tqdm(\n                    self.train_dataloader,\n                    initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0,\n                    total=self.steps_per_epoch,\n                    desc=f\"Epoch {epoch + 1}/{self.config.trainer.total_epochs}\",\n                    disable=not is_logging,\n                )\n            ):\n                global_step += 1\n\n                # construct tensordict\n                data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)\n\n                with self.engine.train_mode():\n                    with Timer(name=\"update_policy\", logger=None) as timer:\n                        output = self.engine.train_batch(data=data, loss_function=self.loss_fn)\n                lr = self.engine.lr_scheduler_step()\n\n                if self.engine.is_mp_src_rank_with_outputs():\n                    metrics = 
output[\"metrics\"]\n\n                    loss = torch.sum(torch.tensor(metrics[\"loss\"], device=self.device_name))\n\n                    # mean over dp group\n                    is_nested = data[\"input_ids\"].is_nested\n                    if is_nested:\n                        batch_seqlens: torch.Tensor = data[\"input_ids\"].offsets().diff()\n                    else:\n                        batch_seqlens: torch.Tensor = data[\"attention_mask\"].sum(dim=-1)\n                    batch_seqlens = batch_seqlens.to(self.device_name)  # (global_bsz // dp)\n\n                    output_tensor = torch.randint(\n                        0,\n                        100,\n                        (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),\n                        device=self.device_name,\n                    )  # (global_bsz,)\n\n                    torch.distributed.all_gather_into_tensor(\n                        output_tensor=output_tensor,\n                        input_tensor=batch_seqlens,\n                        group=self.engine.get_data_parallel_group(),\n                    )\n                    torch.distributed.all_reduce(\n                        loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()\n                    )\n\n                    batch_seqlens = output_tensor.tolist()\n                    loss = loss.item()\n\n                    # TODO: we can actual accumulate metrics for N steps and perform aggregate metrics\n                    metrics[\"loss\"] = loss\n                    metrics[\"train/loss\"] = metrics.pop(\"loss\")\n                    metrics[\"train/grad_norm\"] = metrics.pop(\"grad_norm\")\n                    metrics[\"train/lr\"] = lr\n                    metrics[\"train/global_tokens\"] = output_tensor.sum().item()\n                    total_tokens += metrics[\"train/global_tokens\"]\n                    metrics[\"train/total_tokens(B)\"] = total_tokens / 1e9\n                    # mfu\n                    delta_time = timer.last\n                    estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)\n                    metrics[\"train/mfu\"] = estimated_flops / promised_flops / torch.distributed.get_world_size()\n\n                    if self.engine.get_data_parallel_rank() == 0:\n                        tracking.log(data=metrics, step=global_step)\n\n                is_last_step = global_step >= self.total_training_steps\n                is_valid_step = global_step % self.test_freq == 0\n                is_save_step = global_step % self.save_freq == 0\n\n                # early exit or validation step\n                if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step):\n                    # Perform validation\n                    val_losses = []\n                    for val_data in self.val_dataloader:\n                        with self.engine.eval_mode():\n                            # construct tensordict\n                            val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)\n                            output = self.engine.infer_batch(data=val_data, loss_function=self.loss_fn)\n                            if self.engine.is_mp_src_rank_with_outputs():\n                                val_losses.extend(output[\"metrics\"][\"loss\"])\n\n                    if self.engine.is_mp_src_rank_with_outputs():\n                        val_loss = torch.mean(torch.tensor(val_losses, 
device=self.device_name))\n                        # average over data parallel group\n                        torch.distributed.all_reduce(\n                            val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()\n                        )\n\n                    if is_logging:\n                        metric = {\"val/loss\": val_loss.detach().item()}\n                        tracking.log(data=metric, step=global_step)\n                        last_valid_metric = metric\n                    torch.distributed.barrier()\n\n                if is_last_step or (self.save_freq > 0 and is_save_step):\n                    self.ckpt_handler.save_checkpoint(step=global_step)\n\n                if is_last_step:\n                    if is_logging:\n                        print(f\"Total time for train steps: {train_time:.2f}s\")\n                        print(f\"Final validation metrics: {last_valid_metric}\")\n                    return\n\n\ndef run_sft(config):\n    from verl.utils.distributed import initialize_global_process_group\n\n    initialize_global_process_group()\n    trainer = SFTTrainer(config=config)\n    trainer.fit()\n    destroy_global_process_group()\n\n\n@hydra.main(config_path=\"config\", config_name=\"sft_trainer_engine\", version_base=None)\ndef main(config):\n    run_sft(config)\n\n\ndef create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1):\n    \"\"\"Create a dataset.\"\"\"\n    # build dataset\n    # First check if a custom dataset class is specified\n    if data_config.custom_cls.get(\"path\", None):\n        from verl.utils.import_utils import load_extern_type\n\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n    else:\n        # Default to multi-turn dataset\n        dataset_cls = MultiTurnSFTDataset\n\n    # Create datasets based on the selected class\n    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)\n    return dataset\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_distillation/verl/utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom . import config, tokenizer\nfrom .config import omega_conf_to_dataclass, validate_config\nfrom .groupwise import as_torch_index, group_mean_std\nfrom .tokenizer import hf_processor, hf_tokenizer\n\n__all__ = (\n    tokenizer.__all__\n    + config.__all__\n    + [\"hf_processor\", \"hf_tokenizer\", \"omega_conf_to_dataclass\", \"validate_config\"]\n    + [\"as_torch_index\", \"group_mean_std\"]\n)\n"
  },
  {
    "path": "verl_distillation/verl/utils/activation_offload.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functionality for CPU offloading of tensors saved for backward pass.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport logging\nimport os\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.fsdp_utils import FSDPModule as FSDP2\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef _get_unique_tensor_key(tensor):\n    key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype)\n    return key\n\n\nclass FSDPParameterFilter:\n    def __init__(self):\n        self.model_parameters_storage = set()\n\n    def __call__(self, tensor):\n        return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage\n\n    def update_model_parameters(self, model):\n        new_storage = set()\n        for p in model.parameters():\n            new_storage.add(p.data.untyped_storage().data_ptr())\n        self.model_parameters_storage = new_storage\n\n\nclass CpuOffloadHookWithOffloadHandler:\n    \"\"\"Context-manager that offloads/recovers tensors through an offload hander.\n\n    The hook just offloads/recovers the tensor object to the handler through `tensor_push`\n    and `tensor_pop` interface. 
How the offload-handler manages the offloading, recovering\n    or prefetching timing is transparent to this hook.\n    \"\"\"\n\n    def __init__(\n        self,\n        offload_handler: OffloadHandler,\n        handler_extra_kwargs: Optional[dict[str, Any]] = None,\n    ) -> None:\n        if handler_extra_kwargs is None:\n            handler_extra_kwargs = {}\n        self.offload_handler: OffloadHandler = offload_handler\n        self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs\n        self.inside_context = False\n\n    def __enter__(self):\n        self.inside_context = True\n        torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor)\n\n    def __exit__(self, *args: Any):\n        self.inside_context = False\n        torch._C._autograd._pop_saved_tensors_default_hooks()\n\n    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:\n        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)\n        return retrieve_identifier\n\n    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:\n        tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)\n        return tensor\n\n\nclass OffloadHandler:\n    \"\"\"A base class for CPU offload-handler.\"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        \"\"\"Tensor push.\"\"\"\n        raise NotImplementedError(\n            \"`tensor_push` is not implemented in the OffloadHandler base class. Inherit this class and implement your \"\n            \"custom tensor_push.\"\n        )\n\n    def tensor_pop(self, tensor_tag: Any, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        raise NotImplementedError(\n            \"`tensor_pop` is not implemented in the OffloadHandler base class. Inherit this class and implement your \"\n            \"custom tensor_pop.\"\n        )\n\n\nclass GroupCommitFunction(torch.autograd.Function):\n    \"\"\"This is a dummy op whose output is identical to its input.\n    However, it is necessary for marking a timepoint for the offload handler to\n    accomplish all synchronizations. 
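As a sketch (assuming `h` is a layer output and `handler` an offload handler),\n    `h = group_prefetch_offload_commit(h, handler)` returns `h` unchanged while marking\n    a group boundary for the handler. 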
Implementing it as a function is necessary\n    because we need to take actions in both the forward and backward passes.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, tensor, cpu_offload_handler):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler.on_group_commit_forward()\n        ctx.cpu_offload_handler = cpu_offload_handler\n        # return the identical tensor\n        return tensor\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler = ctx.cpu_offload_handler\n        cpu_offload_handler.on_group_commit_backward()\n        return grad_output, None\n\n\ngroup_prefetch_offload_commit = GroupCommitFunction.apply\n\n\nclass SynchronizedGroupOffloadHandler(OffloadHandler):\n    \"\"\"Offload Handler that offloads/reloads in a synchronized way.\n    The device-to-host and host-to-device copying happen in the same stream\n    as the computation kernels, thus the copying will block computation.\n    \"\"\"\n\n    def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None:\n        super().__init__()\n\n        self.num_offload_group = num_offload_group\n        self.tensor_need_offloading_checker = tensor_need_offloading_checker\n\n        self.groupid_reset()\n\n    def groupid_reset(self):\n        \"\"\"Groupid reset.\"\"\"\n        # Data structures to label saved tensors and book-keep their cpu copies.\n        # Currently, on push, we create a new cpu copy of the tensor; on pop, we copy\n        # the tensor back to gpu and delete the cpu copy.\n        # These will increment whenever `group_commit()` is invoked\n        self.current_group, self.tensor_count_current_group = (0, 0)\n        self.torch_tensor_count = 0\n        self.tensor_tag_to_state = {}\n\n    def on_group_commit_forward(self):\n        \"\"\"On group commit forward.\"\"\"\n        # finishing up with updating current group and tensor count\n        self.current_group += 1  # increment\n        self.tensor_count_current_group = 0  # reset\n\n    def on_group_commit_backward(self):\n        \"\"\"On group commit backward.\"\"\"\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n    @staticmethod\n    def offload(src_tensor, pin_memory=True):\n        \"\"\"Offload.\"\"\"\n\n        cpu_backup = torch.empty(\n            src_tensor.size(),\n            dtype=src_tensor.dtype,\n            layout=src_tensor.layout,\n            device=\"cpu\",\n            pin_memory=pin_memory,\n        )\n        cpu_backup.copy_(src_tensor, non_blocking=True)\n        state = (src_tensor.device, cpu_backup)\n        return state\n\n    @staticmethod\n    def reload(state, non_blocking=None):\n        \"\"\"Reload.\"\"\"\n        dev, cpu_backup = state\n        if non_blocking is None:\n            non_blocking = cpu_backup.is_pinned()\n        return cpu_backup.to(dev, non_blocking=non_blocking)\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs):\n        \"\"\"Tensor push.\"\"\"\n        # obtain a unique tensor tag\n        tensor_tag = (self.current_group, self.tensor_count_current_group)\n        self.tensor_count_current_group += 1\n        assert tensor_tag not in self.tensor_tag_to_state\n        if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor):\n            state = SynchronizedGroupOffloadHandler.offload(tensor)\n            self.tensor_tag_to_state[tensor_tag] = state\n        else:\n            # 
will be offloaded together after group commit\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        assert tensor_tag in self.tensor_tag_to_state\n        state = self.tensor_tag_to_state.pop(tensor_tag)\n        if isinstance(state, tuple):\n            tensor = SynchronizedGroupOffloadHandler.reload(state)\n        else:\n            tensor = state\n        return tensor\n\n\nclass AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):\n    \"\"\"Compared to the synchronized handler, this uses more memory because of the double\n    buffer, but achieves better performance due to the overlapping. D2h and h2d copying are\n    completely hidden behind computation if the computation time of a layer is longer\n    than the host-device communication time. Bulk offloading with delay and bulk reloading\n    with prefetch are implemented.\"\"\"\n\n    def __init__(\n        self,\n        num_offload_group,  # must be <= actual number of groups (number of commits)\n        num_model_group,\n        tensor_need_offloading_checker=(lambda t: True),\n    ) -> None:\n        super().__init__(\n            num_offload_group=num_offload_group,\n            tensor_need_offloading_checker=tensor_need_offloading_checker,\n        )\n        # Number of layers in the model\n        self.num_layers = num_model_group\n        # Data Structure to maintain reference to activation tensors\n        self.tensor_tag_to_buf = {}\n        # Tracking the number of layers offloaded\n        self.offloaded_group_count = 0\n        # Core data structure that decides the window for offloading\n        self.layer_window_map = {}\n        self.group_offload_mapping = {}\n\n        # Logic to load-balance offloading across computation\n        # for optimal CPU/GPU interconnect usage\n        constant = 0\n        for i in range(self.num_offload_group):\n            self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1\n            if i < (self.num_layers % self.num_offload_group):\n                self.layer_window_map[i] += i + 1\n                constant = i + 1\n            else:\n                self.layer_window_map[i] += constant\n\n        # allocate streams and events for synchronization\n        self.d2h_stream = get_torch_device().Stream()\n        self.h2d_stream = get_torch_device().Stream()\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        torch_stray_tensor = isinstance(\n            tensor,\n            torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor,\n        )\n        need_offload = not torch_stray_tensor\n        need_offload = need_offload and self.tensor_need_offloading_checker(tensor)\n\n        if need_offload:\n            # obtain a unique tensor tag\n            tensor_tag = (self.current_group, self.tensor_count_current_group)\n            self.tensor_count_current_group += 1\n\n            assert tensor_tag not in self.tensor_tag_to_state\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n            if self.current_group < self.num_offload_group:\n                self.tensor_tag_to_buf[tensor_tag] = tensor\n        else:\n            tensor_tag = tensor\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        if isinstance(tensor_tag, torch.Tensor):\n            return tensor_tag\n        assert 
tensor_tag in self.tensor_tag_to_state\n        tensor = self.tensor_tag_to_state.pop(tensor_tag)\n        self.tensor_tag_to_buf.pop(tensor_tag, None)\n\n        # the tensor should have been copied back in on_group_commit_backward()\n        # which invokes bulk_reload_group.\n        assert not isinstance(tensor, tuple)\n        return tensor\n\n    def bulk_offload_group(self, group_to_offload):\n        \"\"\"Bulk offload group.\"\"\"\n        offload_mapping = {}\n        offload_size = 0\n        with get_torch_device().stream(self.d2h_stream):\n            for tensor_tag, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_tag\n                if group_id == group_to_offload:\n                    assert not isinstance(state, tuple)\n                    key = _get_unique_tensor_key(state)\n                    if key not in offload_mapping:\n                        offload_mapping[key] = state\n                    # if offload, return the reference to cpu copy\n                    self.tensor_tag_to_state[tensor_tag] = (key, state.shape)\n            for key, tensor in offload_mapping.items():\n                state = SynchronizedGroupOffloadHandler.offload(tensor)\n                offload_size += tensor.numel() * tensor.element_size()\n                offload_mapping[key] = state\n\n            self.group_offload_mapping[group_to_offload] = offload_mapping\n\n    def synchronize_on_group_commit_forward(self, current_group):\n        \"\"\"Synchronize on group commit forward.\"\"\"\n\n        # For the first group, kickstart the offload after we have\n        # the first compute completion\n        if current_group == 0:\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            self.bulk_offload_group(current_group)\n\n        # Window map data structure helps us synchronize based on number\n        # of layers offloaded\n        if self.layer_window_map[self.offloaded_group_count] == current_group:\n            # Stream synchronization both ways\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.d2h_stream)\n\n            # Time to free the activation memory after usage\n            for tensor_tag, _ in self.tensor_tag_to_buf.items():\n                if tensor_tag[0] == self.offloaded_group_count:\n                    self.tensor_tag_to_buf[tensor_tag] = None\n\n            # Time to offload the next group\n            if self.offloaded_group_count < (self.num_offload_group - 1):\n                self.bulk_offload_group(self.offloaded_group_count + 1)\n\n            # Increment the offload group count to keep track\n            self.offloaded_group_count += 1\n\n    def on_group_commit_forward(self):\n        \"\"\"This function will cause host device synchronization\"\"\"\n        # handle synchronization events\n        self.synchronize_on_group_commit_forward(self.current_group)\n\n        super().on_group_commit_forward()\n\n    @torch.no_grad\n    def bulk_reload_group(self, group_to_reload):\n        \"\"\"Bulk reload group.\"\"\"\n        assert group_to_reload < self.num_offload_group\n\n        with get_torch_device().stream(self.h2d_stream):\n            # move back tensors\n            offload_mapping = self.group_offload_mapping.pop(group_to_reload)\n            assert offload_mapping is not None\n            for key, state in offload_mapping.items():\n                offload_mapping[key] = 
SynchronizedGroupOffloadHandler.reload(state)\n            for tensor_label, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_label\n                if group_id == group_to_reload and not isinstance(state, torch.Tensor):\n                    assert isinstance(state, tuple), f\"{group_id} {state}\"\n                    key, shape = state\n                    recovered_tensor = offload_mapping[key].view(shape)\n                    self.tensor_tag_to_state[tensor_label] = recovered_tensor\n\n    def on_group_commit_backward(self):\n        # First decrement the current group counter.\n        # The last commit in forward incremented it; each backward commit decrements it,\n        # so it should finally reach 0.\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n        # Layer window data structure helps us to reload at the right times\n        if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:\n            # Stream synchronization both ways\n            self.h2d_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n\n            # Time to reload the next group\n            self.bulk_reload_group(self.offloaded_group_count - 1)\n\n            # Decrease the offloading group counter\n            self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0\n\n        # Last group computation needs to wait till all the reloads complete\n        if self.current_group == 0:\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n            self.offloaded_group_count = 0\n\n\ndef get_activation_offload_context(\n    num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True)\n):\n    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(\n        num_offload_group=num_layers,\n        num_model_group=model_layers,\n        tensor_need_offloading_checker=tensor_need_offloading_checker,\n    )\n\n    def group_prefetch_offload_commit_async(tensor):\n        return group_prefetch_offload_commit(tensor, cpu_offload_handler)\n\n    return (\n        CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),\n        group_prefetch_offload_commit_async,\n    )\n\n\nclass ActivationHandler:\n    def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt):\n        self._offload_ctx = offload_ctx\n        self._sync_func = sync_func\n        self._enable_ckpt = enable_ckpt\n        self._tensor_filter = tensor_filter\n        if enable_ckpt:\n            self.checkpoint_fn = functools.partial(\n                torch.utils.checkpoint.checkpoint,\n                use_reentrant=True,\n            )\n\n    def pre_forward(self, module):\n        if module.training:\n            self._offload_ctx.__enter__()\n            self._tensor_filter.update_model_parameters(module)\n\n    def post_forward(self, module):\n        if module.training:\n            self._offload_ctx.__exit__(None, None, None)\n\n    def _pack_kwargs(self, *args, **kwargs):\n        kwarg_keys = []\n        flat_args = list(args)\n        for k, v in kwargs.items():\n            kwarg_keys.append(k)\n            flat_args.append(v)\n\n        return tuple(flat_args), tuple(kwarg_keys)\n\n    def _unpack_kwargs(self, flat_args, kwarg_keys):\n        assert len(kwarg_keys) <= len(flat_args), f\"too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}\"\n        if len(kwarg_keys) == 0:\n            return flat_args, {}\n        args = flat_args[: -len(kwarg_keys)]\n        kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True))\n        return args, kwargs\n\n    def _ckpt_forward(self, forward_method, *args, **kwargs):\n        flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs)\n\n        def my_function(*inputs):\n            # unpack back into args and kwargs\n            nonlocal forward_method, kwarg_keys\n            unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys)\n            # run original module\n            return forward_method(*unpacked_args, **unpacked_kwargs)\n\n        return self.checkpoint_fn(\n            my_function,\n            *flat_args,\n        )\n\n    def forward(self, module, forward_method, *args, **kwargs):\n        if not module.training:\n            return forward_method(*args, **kwargs)\n        if not self._enable_ckpt:\n            ret = forward_method(*args, **kwargs)\n        else:\n            ret = self._ckpt_forward(forward_method, *args, **kwargs)\n        binded_tensor = ret\n        if isinstance(ret, tuple):\n            binded_tensor = ret[0]\n        binded_tensor = self._sync_func(binded_tensor)\n        final_ret = binded_tensor\n        if isinstance(ret, tuple):\n            final_ret = (final_ret,) + ret[1:]\n        return final_ret\n\n    def wrap_module_forward_method(self, module):\n        orig_method = module.forward\n        handler = self\n\n        @functools.wraps(orig_method)\n        def wrapped_method(model_self, *args, **kwargs):\n            nonlocal handler\n            handler.pre_forward(model_self)\n            out = handler.forward(model_self, orig_method, *args, **kwargs)\n            handler.post_forward(model_self)\n            return out\n\n        module.forward = wrapped_method.__get__(module, type(module))\n\n\ndef enable_activation_offloading(model, strategy, enable_ckpt=False):\n    \"\"\"\n    Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation\n    groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th\n    activation group happen at the same time, and there are at most two activation groups in GPU memory.\n\n    Args:\n        model: the model to enable activation offloading\n        strategy: the training strategy of the model, such as \"fsdp\"\n        enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model\n\n    Note:\n        For best efficiency, activation offloading is usually combined with activation checkpointing. However, this\n        implementation of activation offloading is conflicted with the implementation of activation checkpointing in\n        some training strategies. 
This function resolves this conflict, and therefore requires the \"strategy\" and\n        \"enable_ckpt\" arguments.\n\n    Returns:\n        None. The forward methods of the selected layers are wrapped in place.\n    \"\"\"\n\n    assert strategy == \"fsdp\" or strategy == \"fsdp2\", \"activation offloading only supports the fsdp and fsdp2 strategies\"\n    layers = []\n\n    def get_layers(module):\n        for name, child in module.named_children():\n            if not isinstance(child, FSDP | FSDP2):\n                get_layers(child)\n            else:\n                wrapped_module = child\n                if isinstance(child, FSDP):\n                    wrapped_module = child._fsdp_wrapped_module\n                # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation\n                # size of torch.nn.Embedding is small, so it's not necessary to offload it.\n                if not isinstance(wrapped_module, torch.nn.Embedding):\n                    layers.append(child)\n\n    get_layers(model)\n    if len(layers) < 3:\n        logger.warning(f\"Found only {len(layers)} fsdp layers; not necessary to enable async activation offloading\")\n        return\n\n    tensor_filter = FSDPParameterFilter()\n    context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter)\n    if enable_ckpt:\n        # The implementation of activation checkpointing in the transformers library is incompatible with\n        # activation offloading, so it is disabled here. ActivationHandler provides its own\n        # activation checkpointing, so these two features can still be enabled at the same time.\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing_disable\"):\n                module.gradient_checkpointing_disable()\n\n    handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt)\n    for layer in layers:\n        module = layer\n        if isinstance(layer, FSDP):\n            module = module._fsdp_wrapped_module\n        handler.wrap_module_forward_method(module)\n"
  },
  {
    "path": "verl_distillation/verl/utils/attention_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable\n\n_index_first_axis, _pad_input, _rearrange, _unpad_input = None, None, None, None\n\n\ndef _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]:\n    \"\"\"Dynamically import attention functions based on available hardware.\"\"\"\n\n    from verl.utils.device import is_cuda_available, is_npu_available\n\n    global _index_first_axis, _pad_input, _rearrange, _unpad_input\n\n    if is_cuda_available:\n        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n    elif is_npu_available:\n        from verl.utils.npu_utils import index_first_axis, pad_input, rearrange, unpad_input\n\n    _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input\n\n    return _index_first_axis, _pad_input, _rearrange, _unpad_input\n\n\ndef index_first_axis(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `index_first_axis` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.index_first_axis`\n      - On NPU: `transformers.integrations.npu_flash_attention.index_first_axis`\n        (falls back to `transformers.modeling_flash_attention_utils._index_first_axis`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    func, *_ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef pad_input(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `pad_input` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.pad_input`\n      - On NPU: `transformers.integrations.npu_flash_attention.pad_input`\n        (falls back to `transformers.modeling_flash_attention_utils._pad_input`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    _, func, *_ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef rearrange(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `rearrange` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.rearrange`\n      - On NPU: `transformers.integrations.npu_flash_attention.rearrange`\n        (falls back to `einops.rearrange` if no dedicated NPU implementation exists).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    *_, func, _ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef unpad_input(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `unpad_input` across CUDA and NPU backends.\n\n    Dynamically 
dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.unpad_input`\n      - On NPU: `transformers.integrations.npu_flash_attention.unpad_input`\n        (falls back to `transformers.modeling_flash_attention_utils._unpad_input`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    *_, func = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\n__all__ = [\"index_first_axis\", \"pad_input\", \"rearrange\", \"unpad_input\"]\n"
  },
  {
    "path": "verl_distillation/verl/utils/checkpoint/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .checkpoint_handler import CheckpointHandler\n\n__all__ = [\"CheckpointHandler\"]\n"
  },
  {
    "path": "verl_distillation/verl/utils/checkpoint/checkpoint_handler.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# TODO: add unit tests\n\nimport logging\nimport os\nimport re\n\nimport torch\n\nimport verl.utils.hdfs_io as hdfs_io\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename\nfrom verl.utils.logger import log_with_rank\nfrom verl.workers.engine import BaseEngine\n\n\ndef extract_step(path):\n    match = re.search(r\"global_step_(\\d+)\", path)\n    if match:\n        return int(match.group(1))\n    return None\n\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass CheckpointHandler:\n    \"\"\"\n    Checkpoint handler handles the path, global_step of a checkpoint folder.\n    Currently, it only works with a single model.\n    We can expand it to support multiple models. It is expected to be used with SPMD style (e.g., torchrun)\n    \"\"\"\n\n    def __init__(\n        self,\n        engine: BaseEngine,\n        train_dataloader,\n        *,\n        default_local_dir,\n        max_ckpt_to_keep=None,\n        default_hdfs_dir=None,\n        resume_mode=\"auto\",\n        resume_from_path=None,\n    ):\n        self.default_local_dir = default_local_dir\n        self.max_ckpt_to_keep = max_ckpt_to_keep\n        self.default_hdfs_dir = default_hdfs_dir\n        self.resume_mode = resume_mode\n        self.resume_from_path = resume_from_path\n        self.engine = engine\n        self.train_dataloader = train_dataloader\n        self.rank = torch.distributed.get_rank()\n\n    def save_checkpoint(self, step):\n        \"\"\"Save checkpoint using FSDPCheckpointManager with improved tracking\"\"\"\n        from verl.utils.fs import local_mkdir_safe\n\n        # Determine checkpoint path\n        local_global_step_folder = os.path.join(self.default_local_dir, f\"global_step_{step}\")\n        if self.rank == 0:\n            print(f\"Saving checkpoint to: {local_global_step_folder}\")\n\n        # Get max checkpoints to keep\n        max_ckpt_to_keep = self.max_ckpt_to_keep\n\n        # Use checkpoint manager to save\n        self.engine.save_checkpoint(\n            local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        # Save dataloader state. 
Note that we only save the iterator state of the train_dataloader.\n        # It is identical across model-parallel ranks within a dp rank, so only the mp source rank saves it.\n        if self.engine.is_mp_src_rank_with_outputs():\n            dp_rank = self.engine.get_data_parallel_rank()\n            local_mkdir_safe(local_global_step_folder)\n            dataloader_local_path = os.path.join(local_global_step_folder, f\"data_{dp_rank}.pt\")\n\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = self.train_dataloader.state_dict()\n            torch.save(dataloader_state_dict, dataloader_local_path)\n            print(f\"Saved dataloader state to: {dataloader_local_path}\")\n\n        if self.rank == 0:\n            # Update latest checkpoint tracker (atomic write)\n            tracker_file = get_checkpoint_tracker_filename(self.default_local_dir)\n            temp_tracker_file = tracker_file + \".tmp\"\n            with open(temp_tracker_file, \"w\") as f:\n                f.write(str(step))\n            os.rename(temp_tracker_file, tracker_file)\n            print(f\"Updated checkpoint tracker: {tracker_file}\")\n\n        # Copy to HDFS if configured\n        if self.rank == 0 and self.default_hdfs_dir:\n            hdfs_io.makedirs(self.default_hdfs_dir, exist_ok=True)\n            hdfs_io.copy(src=local_global_step_folder, dst=self.default_hdfs_dir, dirs_exist_ok=True)\n\n        torch.distributed.barrier()\n\n    def load_checkpoint(self):\n        # Determine resume path based on configuration\n        checkpoint_path = self._determine_resume_path()\n\n        if checkpoint_path is None:\n            return 0\n\n        # extract resume step from checkpoint path\n        resume_step = extract_step(checkpoint_path)\n        if resume_step is None:\n            log_with_rank(\n                f\"Warning: Could not extract step number from {checkpoint_path}, starting from step 0\",\n                logger=logger,\n                rank=self.rank,\n                level=logging.WARNING,\n                log_only_rank_0=True,\n            )\n            return 0\n        self.resume_global_step = resume_step\n\n        # Use checkpoint manager to load model state\n        self.engine.load_checkpoint(checkpoint_path)\n        # Always load dataloader state for StatefulDataLoader\n        self._load_dataloader_state(checkpoint_path)\n\n        return resume_step\n\n    def _load_dataloader_state(self, checkpoint_path: str):\n        \"\"\"Load dataloader state from checkpoint\"\"\"\n        dp_rank = self.engine.get_data_parallel_rank()\n        dataloader_path = os.path.join(checkpoint_path, f\"data_{dp_rank}.pt\")\n\n        if os.path.exists(dataloader_path):\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = torch.load(dataloader_path, map_location=\"cpu\", weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n\n            log_with_rank(\n                f\"Successfully loaded dataloader state from {dataloader_path}\",\n                logger=logger,\n                rank=self.rank,\n                log_only_rank_0=True,\n            )\n\n        else:\n            log_with_rank(\n                f\"Warning: No dataloader state found at {dataloader_path}, will start from scratch\",\n                logger=logger,\n                rank=self.rank,\n                level=logging.WARNING,\n                log_only_rank_0=True,\n            )\n\n    def _determine_resume_path(self):\n        \"\"\"Determine the path to resume from based on resume_mode configuration\"\"\"\n        resume_mode = self.resume_mode\n        resume_from_path = self.resume_from_path\n\n        if resume_mode == \"disable\":\n            return None\n        elif resume_mode == \"auto\":\n            if resume_from_path is not None:\n                assert os.path.exists(resume_from_path), (\n                    \"resume_from_path must be null or an existing path when resume_mode is 'auto'\"\n                )\n                assert \"global_step_\" in resume_from_path, \"resume_from_path must point to a global_step_ folder\"\n                return resume_from_path\n            # Try to find the latest checkpoint in the default directory\n            return self._find_latest_checkpoint()\n        elif resume_mode == \"resume_path\":\n            assert os.path.exists(resume_from_path), (\n                \"resume_from_path must be an existing path when resume_mode is 'resume_path'\"\n            )\n            assert \"global_step_\" in resume_from_path, \"resume_from_path must point to a global_step_ folder\"\n            return resume_from_path\n        else:\n            raise ValueError(f\"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'\")\n\n    def _find_latest_checkpoint(self):\n        \"\"\"Find the latest checkpoint in the default local directory\"\"\"\n        checkpoint_dir = self.default_local_dir\n\n        if not os.path.exists(checkpoint_dir):\n            return None\n\n        latest_checkpoint = find_latest_ckpt_path(checkpoint_dir)\n\n        if latest_checkpoint and self.rank == 0:\n            step_num = extract_step(latest_checkpoint)\n            print(f\"Found latest checkpoint: {latest_checkpoint} (step {step_num})\")\n\n        return latest_checkpoint\n
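\n\n# A minimal usage sketch (a hedged illustration, not part of this module's API). It assumes\n# torch.distributed is already initialized (e.g. via torchrun) and that `engine`,\n# `train_dataloader`, `total_steps` and `save_freq` are provided by the caller:\n#\n#     handler = CheckpointHandler(\n#         engine=engine,\n#         train_dataloader=train_dataloader,\n#         default_local_dir=\"checkpoints\",\n#         max_ckpt_to_keep=3,\n#         resume_mode=\"auto\",\n#     )\n#     start_step = handler.load_checkpoint()  # returns 0 when there is nothing to resume\n#     for step in range(start_step + 1, total_steps + 1):\n#         ...  # one training step\n#         if step % save_freq == 0:\n#             handler.save_checkpoint(step)\n"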
  },
  {
    "path": "verl_distillation/verl/utils/checkpoint/checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport random\nimport shutil\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.utils.device import get_device_name, get_torch_device\n\n\nclass BaseCheckpointManager:\n    \"\"\"\n    A checkpoint manager that saves and loads the following states in a SPMD way:\n    - model\n    - optimizer\n    - lr_scheduler\n    - extra_states\n\n    We save\n    - sharded model states and optimizer states\n    - full lr_scheduler states\n    - huggingface tokenizer and config for ckpt merge\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        optimizer: torch.optim.Optimizer,\n        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,\n        processing_class: PreTrainedTokenizer | ProcessorMixin = None,\n        checkpoint_config: DictConfig | CheckpointConfig = None,\n    ):\n        self.checkpoint_config = checkpoint_config\n        checkpoint_load_contents = checkpoint_config.get(\"load_contents\", None) if checkpoint_config else None\n        checkpoint_save_contents = checkpoint_config.get(\"save_contents\", None) if checkpoint_config else None\n        if checkpoint_load_contents is None:\n            checkpoint_load_contents = [\"model\", \"optimizer\", \"extra\"]\n        if checkpoint_save_contents is None:\n            checkpoint_save_contents = [\"model\", \"optimizer\", \"extra\"]\n        self.previous_global_step = None\n        self.previous_saved_paths = []\n\n        self.model = model\n        self.optimizer = optimizer\n        self.lr_scheduler = lr_scheduler\n        self.processing_class = processing_class\n        self.checkpoint_load_contents = checkpoint_load_contents\n        self.checkpoint_save_contents = checkpoint_save_contents\n\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n\n    @property\n    def should_save_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.\n        \"\"\"\n        return \"model\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.\n        \"\"\"\n        return \"optimizer\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.\n        \"\"\"\n        return \"extra\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_hf_model(self) -> bool:\n        \"\"\"\n        Returns True if 'hf_model' is in 
checkpoint_save_contents, indicating the model should be converted to hf\n        model and saved.\n        \"\"\"\n        return \"hf_model\" in self.checkpoint_save_contents\n\n    @property\n    def should_load_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.\n        \"\"\"\n        return \"model\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.\n        \"\"\"\n        return \"optimizer\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.\n        \"\"\"\n        return \"extra\" in self.checkpoint_load_contents\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):\n        raise NotImplementedError\n\n    def save_checkpoint(\n        self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None\n    ):\n        raise NotImplementedError\n\n    @staticmethod\n    def checkpath(local_path: str, hdfs_path: str):\n        assert local_path is not None or hdfs_path is not None, \"local_path and hdfs_path cannot be both None\"\n        return local_path is not None, local_path if local_path is not None else hdfs_path\n\n    def remove_previous_save_local_path(self, path):\n        if isinstance(path, str):\n            path = [path]\n        for p in path:\n            abs_path = os.path.abspath(p)\n            print(f\"Checkpoint manager remove previous save local path: {abs_path}\")\n            if not os.path.exists(abs_path):\n                continue\n            shutil.rmtree(abs_path, ignore_errors=True)\n\n    @staticmethod\n    def get_rng_state():\n        rng_state = {\n            \"cpu\": torch.get_rng_state(),\n            \"numpy\": np.random.get_state(),\n            \"random\": random.getstate(),\n        }\n\n        if get_device_name() != \"cpu\":\n            rng_state[get_device_name()] = get_torch_device().get_rng_state()\n\n        return rng_state\n\n    @staticmethod\n    def load_rng_state(rng_state):\n        torch.set_rng_state(rng_state[\"cpu\"])\n        np.random.set_state(rng_state[\"numpy\"])\n        random.setstate(rng_state[\"random\"])\n\n        if get_device_name() != \"cpu\":\n            get_torch_device().set_rng_state(rng_state[get_device_name()])\n\n\ndef find_latest_ckpt_path(path, directory_format=\"global_step_{}\"):\n    \"\"\"\n    Return the most recent checkpoint directory based on a tracker file.\n\n    Args:\n        path (str): Base directory containing the checkpoint tracker.\n        directory_format (str): Template for checkpoint subfolders with one\n            placeholder for the iteration number (default \"global_step_{}\").\n\n    Returns:\n        str or None: Full path to the latest checkpoint directory, or\n        None if the tracker or checkpoint folder is missing.\n    \"\"\"\n    if path is None:\n        return None\n\n    tracker_file = get_checkpoint_tracker_filename(path)\n    if not os.path.exists(tracker_file):\n        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:\n            print(f\"Checkpoint tracker file does not exist: {tracker_file}\")\n       
 return None\n\n    with open(tracker_file, \"rb\") as f:\n        iteration = int(f.read().decode())\n    ckpt_path = os.path.join(path, directory_format.format(iteration))\n    if not os.path.exists(ckpt_path):\n        print(f\"Checkpoint does not exist: {ckpt_path}\")\n        return None\n\n    print(f\"Found checkpoint: {ckpt_path}\")\n    return ckpt_path\n\n\ndef get_checkpoint_tracker_filename(root_path: str):\n    \"\"\"\n    The tracker file records the latest checkpoint during training to restart from.\n    \"\"\"\n    return os.path.join(root_path, \"latest_checkpointed_iteration.txt\")\n\n\ndef should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool:\n    \"\"\"\n    Determine if a checkpoint should be saved based on capacity (ESI) expiration.\n\n    Args:\n        max_steps_duration: Max estimated time (seconds) required to complete one training step\n        save_ckpt_duration: Estimated time (seconds) required to save checkpoint (default: 60)\n        redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0)\n    \"\"\"\n    exp_ts_mlp = os.getenv(\"MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP\")  # vemlp\n    exp_ts_aws = os.getenv(\"SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP\")  # aws\n    if exp_ts_mlp:\n        try:\n            import time\n\n            remaining = float(exp_ts_mlp) - time.time()\n        except ValueError:\n            return False\n        return (\n            remaining > 0\n            and max_steps_duration > 0\n            and remaining <= save_ckpt_duration + max_steps_duration + redundant_time\n        )\n    elif exp_ts_aws:\n        from datetime import datetime, timedelta\n\n        expiration_time = datetime.fromtimestamp(int(exp_ts_aws))\n        time_difference = expiration_time - datetime.now()\n        threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60\n        return time_difference < timedelta(minutes=threshold_minutes)\n    else:\n        return False\n
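\n\n# Worked example for should_save_ckpt_esi (a hedged sketch with assumed numbers): suppose\n# the vemlp capacity block expires in 400 seconds (remaining = 400), one training step\n# takes at most max_steps_duration = 300 s, and save_ckpt_duration = 60 s. The threshold\n# is 60 + 300 + 0 = 360 s, so 400 <= 360 is False and training can safely run another step.\n# Once remaining drops to 350 s, 0 < 350 <= 360 holds and a checkpoint should be saved\n# before the capacity block is reclaimed.\n"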
  },
  {
    "path": "verl_distillation/verl/utils/checkpoint/fsdp_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport warnings\nfrom dataclasses import asdict, dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.distributed\nfrom accelerate import init_empty_weights\nfrom omegaconf import DictConfig\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin\nfrom transformers.dynamic_module_utils import custom_object_save\n\nfrom verl.utils.device import is_cuda_available\nfrom verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe\nfrom verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx\nfrom verl.utils.logger import log_with_rank\n\nfrom .checkpoint_manager import BaseCheckpointManager\n\n# Setup logging\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"INFO\"))\n\n\n@dataclass\nclass FSDPConfig:\n    \"\"\"Configuration for FSDP checkpointing.\n\n    Args:\n        FSDP_version (int): Version of FSDP being used.\n        world_size (int): Number of processes in the distributed training setup.\n    \"\"\"\n\n    FSDP_version: int\n    world_size: int\n\n\nclass FSDPCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Manage FSDP checkpointing in SPMD training.\n\n    - Saves/loads per-rank sharded model & optimizer states\n    - Persists full lr_scheduler and RNG state\n    - Stores HF tokenizer/processor and model/config for unified restore\n\n    Args:\n        model (FSDP): Wrapped model instance.\n        optimizer (Optimizer): Training optimizer.\n        lr_scheduler (LRScheduler): Learning-rate scheduler.\n        processing_class (PreTrainedTokenizer or ProcessorMixin, optional):\n            Pre-/post-processing artifact handler.\n        checkpoint_contents DictConfig: Configuration for checkpoint contents.\n            - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].\n            - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].\n    \"\"\"\n\n    def __init__(\n        self,\n        model: FSDP,\n        optimizer: Optional[torch.optim.Optimizer] = None,\n        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,\n        processing_class: PreTrainedTokenizer | ProcessorMixin = None,\n        checkpoint_config: DictConfig = None,\n        **kwargs,\n    ):\n        if processing_class is None and \"tokenizer\" in kwargs:\n            warnings.warn(\n                \"`tokenizer` is deprecated. 
use `processing_class` instead.\", DeprecationWarning, stacklevel=2\n            )\n            processing_class = kwargs.pop(\"tokenizer\")\n\n        super().__init__(\n            model,\n            optimizer,\n            lr_scheduler=lr_scheduler,\n            processing_class=processing_class,\n            checkpoint_config=checkpoint_config,\n        )\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        \"\"\"\n        Load an FSDP checkpoint for this rank.\n\n        Downloads and loads:\n          - model and optimizer shards\n          - extra state dict (scheduler + RNG)\n\n        Args:\n            local_path: Directory with per-rank checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            del_local_after_load: Remove local files after loading.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # check if the checkpoint_load_contents is valid\n        if self.should_load_model:\n            assert self.model is not None, \"model must be provided when checkpoint_contents.load includes ['model']\"\n        if self.should_load_optimizer:\n            assert self.optimizer is not None, (\n                \"optimizer must be provided when checkpoint_contents.load includes ['optimizer']\"\n            )\n\n        # every rank download its own checkpoint\n        state_dict_cfg = (\n            ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n            if self.should_load_model\n            else None\n        )\n        optim_cfg = (\n            ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n            if self.should_load_optimizer\n            else None\n        )\n        with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):\n            if self.should_load_model:\n                remote_model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                local_model_path = copy_to_local(remote_model_path)\n                model_state_dict = torch.load(local_model_path, weights_only=False)\n                self.model.load_state_dict(model_state_dict)\n                log_with_rank(f\"Loaded model from {remote_model_path}\", rank=self.rank, logger=logger)\n\n            if self.should_load_optimizer:\n                remote_optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                local_optim_path = copy_to_local(remote_optim_path)\n                optimizer_state_dict = torch.load(local_optim_path, weights_only=False)\n                self.optimizer.load_state_dict(optimizer_state_dict)\n                log_with_rank(f\"Loaded optimizer from {remote_optim_path}\", rank=self.rank, logger=logger)\n\n        if self.should_load_extra:\n            remote_extra_state_path = os.path.join(\n                local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\"\n            )\n            local_extra_state_path = copy_to_local(remote_extra_state_path)\n            extra_state_dict = torch.load(local_extra_state_path, weights_only=False)\n            # recover random state\n            if \"rng\" in extra_state_dict:\n                # 'rng' may not exist for backward compatibility\n                self.load_rng_state(extra_state_dict[\"rng\"])\n                log_with_rank(f\"Loaded rng from {remote_extra_state_path}\", rank=self.rank, 
logger=logger)\n\n            lr_scheduler_state_dict = extra_state_dict[\"lr_scheduler\"]\n            if lr_scheduler_state_dict is not None and self.lr_scheduler is not None:\n                self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n                log_with_rank(f\"Loaded lr_scheduler from {remote_extra_state_path}\", rank=self.rank, logger=logger)\n\n        if self.rank == 0 and del_local_after_load:\n            try:\n                os.remove(local_model_path) if is_non_local(local_model_path) else None\n                os.remove(local_optim_path) if is_non_local(local_optim_path) else None\n                os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None\n            except Exception as e:\n                log_with_rank(\n                    f\"remove local resume ckpt file after loading failed, exception {e} will be ignored\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n        # wait for everyone to load checkpoints\n        torch.distributed.barrier()\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save an FSDP checkpoint for this rank.\n\n        Writes:\n          - model & optimizer shard files\n          - extra state dict (scheduler + RNG)\n          - HF tokenizer/processor and model/config on rank 0\n          - optional full HF model under 'huggingface/' if requested\n\n        Rotates old checkpoints, keeping at most `max_ckpt_to_keep`.\n\n        Args:\n            local_path: Target directory for checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            global_step: Current training step (used for bookkeeping).\n            max_ckpt_to_keep: Number of recent checkpoints to retain.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path, only rank 0 should do this\n        if (\n            self.rank == 0\n            and max_ckpt_to_keep\n            and isinstance(max_ckpt_to_keep, int)\n            and max_ckpt_to_keep > 0\n            and len(self.previous_saved_paths) >= max_ckpt_to_keep\n        ):\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = local_mkdir_safe(local_path)\n        torch.distributed.barrier()\n\n        # check if the checkpoint_save_contents is valid\n        if self.should_save_model:\n            assert self.model is not None, \"model must be provided when checkpoint_contents.save includes ['model']\"\n        if self.should_save_optimizer:\n            assert self.optimizer is not None, (\n                \"optimizer must be provided when checkpoint_contents.save includes ['optimizer']\"\n            )\n\n        # every rank will save its own model and optim shard\n        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, 
optim_cfg):\n                model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                extra_path = os.path.join(local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\")\n\n                if self.should_save_model:\n                    model_state_dict = self.model.state_dict()\n                    torch.save(model_state_dict, model_path)\n                    log_with_rank(f\"Saved model to {os.path.abspath(model_path)}\", rank=self.rank, logger=logger)\n\n                if self.should_save_optimizer:\n                    optimizer_state_dict = self.optimizer.state_dict()\n                    torch.save(optimizer_state_dict, optim_path)\n                    log_with_rank(f\"Saved optim to {os.path.abspath(optim_path)}\", rank=self.rank, logger=logger)\n\n                if self.should_save_extra:\n                    lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None\n                    extra_state_dict = {\n                        \"lr_scheduler\": lr_scheduler_state_dict,\n                        \"rng\": self.get_rng_state(),\n                    }\n                    torch.save(extra_state_dict, extra_path)\n                    log_with_rank(f\"Saved extra_state to {os.path.abspath(extra_path)}\", rank=self.rank, logger=logger)\n\n        if self.rank == 0:\n            # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether\n            # huggingface model is requested to be saved or not.\n\n            if fsdp_version(self.model) == 1:\n                unwrap_model = self.model._fsdp_wrapped_module\n            else:\n                unwrap_model = self.model\n\n            hf_config_tokenizer_path = os.path.join(local_path, \"huggingface\")\n            local_mkdir_safe(hf_config_tokenizer_path)\n            model_config = unwrap_model.config\n            generation_config = None\n            if unwrap_model.can_generate() and hasattr(model_config, \"name_or_path\") and model_config.name_or_path:\n                try:\n                    # Some models' name_or_path is empty if not initialized from pretrained;\n                    # in these cases, we don't save the generation config.\n                    generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)\n                    generation_config.save_pretrained(hf_config_tokenizer_path)\n                except Exception:\n                    # if the generation config isn't available, we don't save it\n                    pass\n\n            model_config.save_pretrained(hf_config_tokenizer_path)\n            if self.processing_class is not None:\n                self.processing_class.save_pretrained(hf_config_tokenizer_path)\n            log_with_rank(\n                f\"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}\",\n                rank=self.rank,\n                logger=logger,\n                log_only_rank_0=True,\n            )\n\n            # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be\n            # loaded from the Hub.\n            if hasattr(model_config, \"auto_map\"):\n                custom_object_save(unwrap_model, hf_config_tokenizer_path, config=model_config)\n\n            # Also save runtime FSDP
 config\n            fsdp_config_path = os.path.join(local_path, \"fsdp_config.json\")\n            fsdp_config = FSDPConfig(\n                FSDP_version=fsdp_version(self.model),\n                world_size=self.world_size,\n            )\n            with open(fsdp_config_path, \"w\") as f:\n                json.dump(asdict(fsdp_config), f, indent=4)\n\n        # wait for everyone to dump to local\n        torch.distributed.barrier()\n\n        if self.should_save_hf_model:\n            # Only rank 0 will save the hf model, and we offload to cpu to save LLMs\n            # which may be too large to fit on one GPU\n            state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)\n\n            if self.rank == 0:\n                hf_local_path = os.path.join(local_path, \"huggingface\")\n                os.makedirs(hf_local_path, exist_ok=True)\n\n                if \"ForTokenClassification\" in model_config.architectures[0]:\n                    from transformers import AutoModelForTokenClassification\n\n                    auto_model_cls = AutoModelForTokenClassification\n                elif \"ForCausalLM\" in model_config.architectures[0]:\n                    from transformers import AutoModelForCausalLM\n\n                    auto_model_cls = AutoModelForCausalLM\n                elif \"ForConditionalGeneration\" in model_config.architectures[0]:\n                    # Handle different transformers versions for Vision2Seq models\n                    import transformers\n                    from packaging import version\n\n                    if version.parse(transformers.__version__) >= version.parse(\"4.54.0\"):\n                        # transformers >= 4.54.0 uses AutoModelForImageTextToText\n                        from transformers import AutoModelForImageTextToText\n\n                        auto_model_cls = AutoModelForImageTextToText\n                    else:\n                        # transformers < 4.54.0 uses AutoModelForVision2Seq\n                        from transformers import AutoModelForVision2Seq\n\n                        auto_model_cls = AutoModelForVision2Seq\n                else:\n                    raise NotImplementedError(f\"Unknown architecture {model_config.architectures}\")\n\n                with init_empty_weights():\n                    save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)\n                save_model.to_empty(device=\"cpu\")\n\n                if save_model.can_generate():\n                    if generation_config is not None:\n                        save_model.generation_config = generation_config\n                    else:\n                        print(\n                            f\"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found, \"\n                            f\"using a generation config created from the model config when saving hf_model.\"\n                        )\n\n                save_model.save_pretrained(hf_local_path, state_dict=state_dict)\n                log_with_rank(\n                    f\"Saved hf_model to {os.path.abspath(hf_local_path)}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n                del state_dict\n                del save_model\n\n            # wait for rank0 to dump hf_model to local\n            torch.distributed.barrier()\n\n        self.previous_saved_paths.append(local_path)\n
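\n\n# Resulting on-disk layout for a world size of 2 (a sketch inferred from the save paths\n# above; which files appear depends on checkpoint_config's save_contents):\n#\n#     <local_path>/\n#         model_world_size_2_rank_0.pt\n#         model_world_size_2_rank_1.pt\n#         optim_world_size_2_rank_0.pt\n#         optim_world_size_2_rank_1.pt\n#         extra_state_world_size_2_rank_0.pt\n#         extra_state_world_size_2_rank_1.pt\n#         fsdp_config.json\n#         huggingface/  # config + tokenizer, plus the full hf model when 'hf_model' is requested\n"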
  },
  {
    "path": "verl_distillation/verl/utils/checkpoint/megatron_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport random\nfrom collections.abc import Callable\nfrom dataclasses import asdict\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom megatron.core import mpu, tensor_parallel\nfrom megatron.core.dist_checkpointing.mapping import ShardedObject\nfrom megatron.core.transformer.enums import AttnBackend\nfrom transformers import GenerationConfig\n\nfrom verl.models.weight_loader_registry import get_weight_saver\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.fs import is_non_local, local_mkdir_safe\nfrom verl.utils.logger import log_with_rank\nfrom verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing\nfrom verl.utils.megatron_utils import (\n    get_dist_checkpoint_path,\n    get_hf_model_checkpoint_path,\n    get_transformer_config_checkpoint_path,\n)\n\nfrom .checkpoint_manager import BaseCheckpointManager\n\n# Setup logging\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"INFO\"))\n\n\nclass MegatronCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Checkpoint manager for Megatron-LM distributed training.\n\n    This class manages the saving and loading of model checkpoints in a Megatron-LM\n    distributed training environment. 
It handles various aspects of checkpointing\n    including model states, optimizer states, learning rate schedulers, and random\n    number generator states, ensuring compatibility with HuggingFace formats.\n\n    Key features:\n    - Distributed checkpoint saving and loading using Megatron's dist_checkpointing\n    - Support for tensor parallel, pipeline parallel, and data parallel configurations\n    - Automatic handling of model state dictionaries across multiple pipeline stages\n    - Integration with HuggingFace model configurations and tokenizers\n    - Random number generator state management for reproducibility\n    - Support for both synchronous and asynchronous checkpoint operations\n\n    The manager automatically handles:\n    - Directory structure creation based on global steps and process ranks\n    - Model configuration and tokenizer saving in HuggingFace format\n    - Optimizer and scheduler state persistence\n    - CUDA RNG state management for deterministic training\n    - Checkpoint cleanup and retention policies\n\n    Args:\n        model: The Megatron model instance to checkpoint\n        optimizer: The optimizer instance (optional)\n        lr_scheduler: The learning rate scheduler instance (optional)\n\n    Attributes:\n        model: Reference to the Megatron model being checkpointed\n        optimizer: Reference to the optimizer (if provided)\n        lr_scheduler: Reference to the learning rate scheduler (if provided)\n        rank: Current process rank in the distributed setup\n\n    Example:\n        ```python\n        checkpoint_manager = MegatronCheckpointManager(\n            model=megatron_model,\n            optimizer=optimizer,\n            lr_scheduler=scheduler\n        )\n\n        checkpoint_manager.save_checkpoint(\n            local_path=\"checkpoints/step_1000\",\n            global_step=1000\n        )\n\n        checkpoint_manager.load_checkpoint(\n            local_path=\"checkpoints/step_1000\"\n        )\n        ```\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        checkpoint_config,\n        model_config,\n        transformer_config,\n        role,\n        model: torch.nn.ModuleList,\n        arch: str,\n        hf_config,\n        param_dtype: torch.dtype,\n        share_embeddings_and_output_weights: bool,\n        processing_class,\n        optimizer,\n        optimizer_scheduler,\n        use_distributed_optimizer: bool,\n        use_checkpoint_opt_param_scheduler: bool = False,\n        use_dist_checkpointing: bool = True,\n        bridge=None,\n        **kwargs,\n    ):\n        super().__init__(\n            model,\n            optimizer=optimizer,\n            lr_scheduler=optimizer_scheduler,\n            processing_class=processing_class,\n            checkpoint_config=checkpoint_config,\n        )\n        self.arch = arch\n        self.config = config\n        self.transformer_config = transformer_config\n        self.role = role\n        self.is_value_model = False\n        if self.role in [\"reward\", \"critic\"]:\n            self.is_value_model = True\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.param_dtype = param_dtype\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.model_path = self.config.model.path\n        self.use_distributed_optimizer = use_distributed_optimizer\n        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler\n        self.bridge = bridge\n        self.rank 
= torch.distributed.get_rank()\n        self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model\n        self.use_hf_checkpoint = not self.use_dist_checkpointing\n\n        self.weight_saver = None\n        if self.bridge is None:\n            self.weight_saver = get_weight_saver(self.arch)\n\n    def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False):\n        \"\"\"collect rng state across data parallel ranks\"\"\"\n        rng_state = {\n            \"random_rng_state\": random.getstate(),\n            \"np_rng_state\": np.random.get_state(),\n            \"torch_rng_state\": torch.get_rng_state(),\n            \"rng_tracker_states\": tensor_parallel.get_cuda_rng_tracker().get_states(),\n        }\n\n        if get_device_name() != \"cpu\":\n            rng_state[f\"{get_device_name()}_rng_state\"] = get_torch_device().get_rng_state()\n\n        rng_state_list = None\n        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:\n            rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())]\n            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())\n        else:\n            rng_state_list = [rng_state]\n\n        if use_dist_ckpt:\n            pp_rank = mpu.get_pipeline_model_parallel_rank()\n            pp_size = mpu.get_pipeline_model_parallel_world_size()\n            tp_rank = mpu.get_tensor_model_parallel_rank()\n            tp_size = mpu.get_tensor_model_parallel_world_size()\n            rng_state_list = ShardedObject(\n                \"rng_state\",\n                rng_state_list,\n                (pp_size, tp_size),\n                (pp_rank, tp_rank),\n                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),\n            )\n\n        return rng_state_list\n\n    def get_checkpoint_name(\n        self,\n        checkpoints_path,\n        pipeline_parallel=None,\n        tensor_rank=None,\n        pipeline_rank=None,\n        cp_rank=None,\n        expert_parallel=None,\n        expert_rank=None,\n        return_base_dir=True,\n        basename=\"model.pt\",\n    ):\n        \"\"\"Determine the directory name for this rank's checkpoint.\"\"\"\n        # Use both the tensor and pipeline MP rank.\n        if pipeline_parallel is None:\n            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1\n        if tensor_rank is None:\n            tensor_rank = mpu.get_tensor_model_parallel_rank()\n        if pipeline_rank is None:\n            pipeline_rank = mpu.get_pipeline_model_parallel_rank()\n        if cp_rank is None:\n            cp_rank = mpu.get_context_parallel_rank()\n        if expert_parallel is None:\n            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1\n        if expert_rank is None:\n            expert_rank = mpu.get_expert_model_parallel_rank()\n\n        # Use both the tensor and pipeline MP rank. 
If using the distributed\n        # optimizer, then the optimizer's path must additionally include the\n        # data parallel rank.\n\n        # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path\n        if not pipeline_parallel:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}\")\n        else:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}\")\n\n        if expert_parallel:\n            common_path = common_path + f\"_{expert_rank:03d}\"\n\n        os.makedirs(common_path, exist_ok=True)\n\n        if return_base_dir:\n            return common_path\n        return os.path.join(common_path, basename)\n\n    def generate_state_dict(\n        self,\n        generate_model: bool = True,\n        generate_optimizer: bool = True,\n        generate_extra: bool = True,\n        is_loading: bool = False,\n    ):\n        # For save dist checkpointing\n        state_dict = {}\n\n        # Should always generate model state dict\n        # All ranks Save Model to reduce memory pressure\n        # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure\n        for vpp_rank, model in enumerate(self.model):\n            if len(self.model) > 1:\n                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                key = f\"model{vpp_rank}\" if len(self.model) > 1 else \"model\"\n            else:\n                key = \"model\"\n            if hasattr(model, \"module\"):\n                model = model.module\n            state_dict[key] = model.sharded_state_dict()\n\n        # Optimizer State Dict\n        if generate_optimizer:\n            torch.distributed.barrier()\n            optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict, is_loading=is_loading)\n            state_dict[\"optimizer\"] = optimizer_sharded_states\n\n            if self.lr_scheduler is not None:\n                lr_state_dict = self.lr_scheduler.state_dict()\n                state_dict[\"lr_scheduler\"] = lr_state_dict\n\n        if not generate_model:\n            state_dict.pop(\"model\", None)\n\n        # RNG States State Dict\n        if generate_extra:\n            torch.distributed.barrier()\n            rng_state = self.get_rng_state()\n            state_dict[\"rng_state\"] = rng_state\n\n        return state_dict\n\n    def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True):\n        # access rng_state for data parallel rank\n        if data_parallel_random_init:\n            rng_states = rng_states[mpu.get_data_parallel_rank()]\n        else:\n            rng_states = rng_states[0]\n        random.setstate(rng_states[\"random_rng_state\"])\n        np.random.set_state(rng_states[\"np_rng_state\"])\n        torch.set_rng_state(rng_states[\"torch_rng_state\"])\n\n        if get_device_name() != \"cpu\":\n            get_torch_device().set_rng_state(rng_states[f\"{get_device_name()}_rng_state\"])\n\n        # Check for empty states array\n        if not rng_states[\"rng_tracker_states\"]:\n            raise KeyError\n        tensor_parallel.get_cuda_rng_tracker().set_states(rng_states[\"rng_tracker_states\"])\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        if local_path is not None:\n            assert os.path.exists(local_path), f\"Checkpoint path {local_path} does not exist.\"\n\n        # 
For load optimizer dist_ckpt\n        import transformer_engine\n\n        torch.serialization.add_safe_globals([torch.optim.AdamW])\n        torch.serialization.add_safe_globals([transformer_engine.pytorch.optimizers.fused_adam.FusedAdam])\n\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        # Get State Dict for loading\n        sharded_state_dict = self.generate_state_dict(\n            self.should_load_model and self.use_dist_checkpointing,\n            self.should_load_optimizer,\n            self.should_load_extra,\n            is_loading=True,\n        )\n        log_with_rank(f\"Generated state dict for loading: {sharded_state_dict.keys()}\", rank=self.rank, logger=logger)\n\n        # Load Dist Checkpointing\n        state_dict = load_dist_checkpointing(\n            sharded_state_dict=sharded_state_dict,\n            ckpt_dir=dist_checkpoint_path,\n        )\n\n        if self.should_load_model and self.use_dist_checkpointing:\n            assert \"model\" in state_dict or any(\n                f\"model{vpp_rank}\" in state_dict for vpp_rank in range(len(self.model))\n            ), f\"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) == 1:\n                    model_state_dict = state_dict[\"model\"]\n                else:\n                    assert f\"model{vpp_rank}\" in state_dict, f\"model{vpp_rank} not found in state_dict\"\n                    model_state_dict = state_dict[f\"model{vpp_rank}\"]\n                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                self.model[vpp_rank].load_state_dict(model_state_dict)\n            log_with_rank(f\"Loaded sharded model checkpoint from {local_path}\", rank=self.rank, logger=logger)\n        elif self.should_load_model and self.use_hf_checkpoint:\n            hf_model_path = get_hf_model_checkpoint_path(local_path)\n            self.bridge.load_weights(self.model, hf_model_path)\n            log_with_rank(f\"Loaded HF model checkpoint from {hf_model_path} with bridge\", rank=self.rank, logger=logger)\n\n        if self.should_load_optimizer:\n            assert \"optimizer\" in state_dict, (\n                f\"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            )\n            optimizer_state_dict = state_dict[\"optimizer\"]\n            self.optimizer.load_state_dict(optimizer_state_dict)\n            log_with_rank(f\"Loaded optimizer checkpoint from {local_path}\", rank=self.rank, logger=logger)\n            if self.use_checkpoint_opt_param_scheduler:\n                assert \"lr_scheduler\" in state_dict, (\n                    f\"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file \"\n                    f\"{local_path}.\"\n                )\n                lr_scheduler_state_dict = state_dict[\"lr_scheduler\"]\n                if self.lr_scheduler is not None:\n                    self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n                    log_with_rank(f\"Loaded LR scheduler checkpoint from {local_path}\", rank=self.rank, logger=logger)\n\n        if self.should_load_extra:\n            assert \"rng_state\" in state_dict, (\n                f\"RNG state dict not found in {state_dict.keys()}. 
Please check the checkpoint file {local_path}.\"\n            )\n            rng_state = state_dict[\"rng_state\"]\n            self.load_rng_states(rng_state)\n            log_with_rank(f\"Loaded RNG states from {local_path}\", rank=self.rank, logger=logger)\n\n        if del_local_after_load:\n            try:\n                os.remove(local_path) if is_non_local(local_path) else None\n            except Exception as e:\n                log_with_rank(\n                    f\"remove local resume ckpt file after loading failed, exception {e} will be ignored\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path\n        if (\n            max_ckpt_to_keep\n            and isinstance(max_ckpt_to_keep, int)\n            and max_ckpt_to_keep > 0\n            and len(self.previous_saved_paths) >= max_ckpt_to_keep\n        ):\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = local_mkdir_safe(local_path)\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        # Note that model weights, optimizer states, and extra states are generated\n        # together in a state dict, so we save them in one pass\n        if self.use_dist_checkpointing:\n            # Generate state dict for saving\n            state_dict = self.generate_state_dict(\n                self.should_save_model, self.should_save_optimizer, self.should_save_extra\n            )\n            log_with_rank(f\"Generated state dict for saving: {state_dict.keys()}\", rank=self.rank, logger=logger)\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) > 1:\n                    model_i_keys = state_dict[f\"model{vpp_rank}\"].keys()\n                    log_with_rank(f\"Generated state dict for saving: {model_i_keys}\", rank=self.rank, logger=logger)\n                else:\n                    log_with_rank(\n                        f\"Generated state dict for saving: {state_dict['model'].keys()}\", rank=self.rank, logger=logger\n                    )\n            # Start Async save if enabled\n            async_save_request = save_dist_checkpointing(\n                sharded_state_dict=state_dict,\n                ckpt_path=dist_checkpoint_path,\n                async_save=self.checkpoint_config.async_save,\n            )\n\n            # Synchronize all async save requests\n            if not self.checkpoint_config.async_save:\n                assert async_save_request is None, \"Async save request should be None when not using async save.\"\n                torch.distributed.barrier()\n        else:\n            assert self.use_hf_checkpoint, \"When not using distributed checkpointing, use_hf_checkpoint should be True.\"\n            # Generate optimizer and extra state dicts\n            state_dict = self.generate_state_dict(\n                generate_model=False,\n                generate_optimizer=self.should_save_optimizer,\n                generate_extra=self.should_save_extra,\n            )\n            # Save optimizer and extra states to local path\n            # 
Start Async save if enabled\n            async_save_request = save_dist_checkpointing(\n                sharded_state_dict=state_dict,\n                ckpt_path=dist_checkpoint_path,\n                async_save=self.checkpoint_config.async_save,\n            )\n\n            # Synchronize all async save requests\n            if not self.checkpoint_config.async_save:\n                assert async_save_request is None, \"Async save request should be None when not using async save.\"\n                torch.distributed.barrier()\n\n        if self.should_save_model:\n            if self.use_hf_checkpoint:\n                # Use mbridge to save HF model checkpoint\n                log_with_rank(f\"Saving HF model checkpoint to {local_path} with bridge\", rank=self.rank, logger=logger)\n                hf_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                self.bridge.save_weights(self.model, hf_ckpt_path)\n                log_with_rank(f\"Saved bridge checkpoint to {hf_ckpt_path}\", rank=self.rank, logger=logger)\n\n            # Only rank 0 saves the hf config and tokenizer to huggingface path\n            # No matter whether we save hf model or not\n            if self.rank == 0:\n                # Save tokenizer\n                hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)\n                if self.processing_class is not None:\n                    self.processing_class.save_pretrained(hf_config_tokenizer_path)\n                # Save huggingface config\n                self.hf_config.save_pretrained(hf_config_tokenizer_path)\n                if hasattr(self.hf_config, \"name_or_path\") and self.hf_config.name_or_path:\n                    try:\n                        generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)\n                        generation_config.save_pretrained(hf_config_tokenizer_path)\n                    except Exception:\n                        # if the generation config isn't available, we don't save it\n                        pass\n                log_with_rank(\n                    f\"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n\n        if self.should_save_extra:\n            if self.rank == 0:\n                # Save transformer config\n                print(self.transformer_config)\n                transformer_config_dict = asdict(self.transformer_config)\n                to_convert_types = {torch.dtype: str, AttnBackend: str}\n                ignore_types = [Callable]\n                pop_keys = []\n                for key, value in transformer_config_dict.items():\n                    if type(value) in to_convert_types:\n                        transformer_config_dict[key] = to_convert_types[type(value)](value)\n                    if type(value) in ignore_types:\n                        pop_keys.append(key)\n                    if callable(value):\n                        pop_keys.append(key)\n                for key in pop_keys:\n                    transformer_config_dict.pop(key)\n                transformer_config_path = get_transformer_config_checkpoint_path(local_path)\n                with open(transformer_config_path, \"w\") as f:\n                    json.dump(transformer_config_dict, f, indent=2)\n\n        if self.should_save_hf_model and not self.use_hf_checkpoint:\n            # wait for everyone to dump to local\n   
         if self.bridge is not None:\n                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                self.bridge.save_weights(self.model, hf_model_ckpt_path)\n            else:\n                state_dict = self.weight_saver(\n                    self.model,\n                    self.hf_config,\n                    dtype=self.param_dtype,\n                    is_value_model=self.is_value_model,\n                    tie_word_embeddings=self.share_embeddings_and_output_weights,\n                )\n\n                torch.distributed.barrier()\n                if self.rank == 0:\n                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                    import warnings\n\n                    from accelerate import init_empty_weights\n\n                    with init_empty_weights(), warnings.catch_warnings():\n                        warnings.simplefilter(\"ignore\")\n                        if \"mistral7b-rm\" in self.config.model.path:\n                            from transformers import MistralForSequenceClassification\n\n                            model = MistralForSequenceClassification.from_pretrained(\n                                self.config.model.path\n                            )  # use score head instead of lm_head\n                            state_dict[\"score.weight\"] = state_dict[\"score.weight\"]\n                        else:\n                            from transformers import AutoModelForCausalLM\n\n                            model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype=\"auto\")\n                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)\n                    log_with_rank(\n                        f\"Saved Huggingface model to {hf_model_ckpt_path}\",\n                        rank=self.rank,\n                        logger=logger,\n                        log_only_rank_0=True,\n                    )\n\n                    if hdfs_path is not None:\n                        log_with_rank(\n                            f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger, log_only_rank_0=True\n                        )\n                        from verl.utils import hdfs_io\n\n                        hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                        hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)\n                        log_with_rank(\n                            f\"HDFS checkpoint uploaded to {hdfs_path}\",\n                            rank=self.rank,\n                            logger=logger,\n                            log_only_rank_0=True,\n                        )\n\n        def finalize_save_fn():\n            # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided\n            log_with_rank(\n                f\"Dist checkpointing save completed for {dist_checkpoint_path}\", rank=self.rank, logger=logger\n            )\n            if self.rank == 0:\n                if hdfs_path is not None:\n                    log_with_rank(f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger)\n                    from verl.utils import hdfs_io\n\n                    hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                    hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True)\n                    hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)\n\n        if self.checkpoint_config.async_save:\n        
    assert async_save_request is not None, \"Async save request should not be None when using async save.\"\n            async_save_request.add_finalize_fn(finalize_save_fn)\n        else:\n            finalize_save_fn()\n\n        self.previous_saved_paths.append(local_path)\n"
  },
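The rank-0 branch above writes a standard Hugging Face model directory via `model.save_pretrained(...)`, so a merged checkpoint can be reloaded with stock `transformers` APIs. A minimal sketch, assuming the tokenizer files were saved alongside the config; the directory name is hypothetical (use whatever `get_hf_model_checkpoint_path(local_path)` produced in your run):

```python
# Minimal sketch: reload a merged HF checkpoint produced by the saver above.
# "checkpoints/global_step_100/huggingface" is a hypothetical path.
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "checkpoints/global_step_100/huggingface"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)  # assumes tokenizer files are present
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype="auto")
```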
  {
    "path": "verl_distillation/verl/utils/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import is_dataclass\nfrom typing import Any, Optional\n\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\n\n__all__ = [\"omega_conf_to_dataclass\", \"validate_config\"]\n\n\ndef omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:\n    \"\"\"\n    Convert an OmegaConf DictConfig to a dataclass.\n\n    Args:\n        config: The OmegaConf DictConfig or dict to convert.\n        dataclass_type: The dataclass type to convert to. When dataclass_type is None,\n            the DictConfig must contain _target_ to be instantiated via hydra.instantiate API.\n\n    Returns:\n        The dataclass instance.\n    \"\"\"\n    # Got an empty config\n    if not config:\n        return dataclass_type if dataclass_type is None else dataclass_type()\n    # Got an object\n    if not isinstance(config, DictConfig | ListConfig | dict | list):\n        return config\n\n    if dataclass_type is None:\n        assert \"_target_\" in config, (\n            \"When dataclass_type is not provided, config must contain _target_. \"\n            \"See trainer/config/ppo_trainer.yaml algorithm section for an example. 
\"\n            f\"Got config: {config}\"\n        )\n        from hydra.utils import instantiate\n\n        return instantiate(config, _convert_=\"partial\")\n\n    if not is_dataclass(dataclass_type):\n        raise ValueError(f\"{dataclass_type} must be a dataclass\")\n    cfg = OmegaConf.create(config)  # in case it's a dict\n    # pop _target_ to avoid hydra instantiate error, as most dataclass do not have _target_\n    # Updated (vermouth1992) We add _target_ to BaseConfig so that it is compatible.\n    # Otherwise, this code path can't support recursive instantiation.\n    # if \"_target_\" in cfg:\n    #     cfg.pop(\"_target_\")\n    cfg_from_dataclass = OmegaConf.structured(dataclass_type)\n    # let cfg override the existing vals in `cfg_from_dataclass`\n    cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)\n    # now convert to `dataclass_type`\n    config_object = OmegaConf.to_object(cfg_merged)\n    return config_object\n\n\ndef update_dict_with_config(dictionary: dict, config: DictConfig):\n    for key in dictionary:\n        if hasattr(config, key):\n            dictionary[key] = getattr(config, key)\n\n\ndef validate_config(\n    config: DictConfig,\n    use_reference_policy: bool,\n    use_critic: bool,\n) -> None:\n    \"\"\"Validate an OmegaConf DictConfig.\n\n    Args:\n        config (DictConfig): The OmegaConf DictConfig to validate.\n        use_reference_policy (bool): is ref policy needed\n        use_critic (bool): is critic needed\n    \"\"\"\n    # number of GPUs total\n    n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes\n\n    if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n        if config.actor_rollout_ref.actor.strategy == \"megatron\":\n            model_parallel_size = (\n                config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size\n                * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size\n            )\n            assert (\n                n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0\n            ), (\n                f\"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times \"\n                f\"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})\"\n            )\n            megatron_dp = n_gpus // (\n                model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size\n            )\n            minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu\n        else:\n            minimal_bsz = n_gpus\n\n        # 1. Check total batch size for data correctness\n        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n\n        assert real_train_batch_size % minimal_bsz == 0, (\n            f\"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size \"\n            f\"({minimal_bsz})\"\n        )\n\n    # A helper function to check \"micro_batch_size\" vs \"micro_batch_size_per_gpu\"\n    # We throw an error if the user sets both. 
The new convention is \"..._micro_batch_size_per_gpu\".\n    def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n        \"\"\"Validate mutually exclusive micro batch size configuration options.\n\n        Ensures that users don't set both deprecated micro_batch_size and\n        the new micro_batch_size_per_gpu parameters simultaneously.\n\n        Args:\n            mbs: Deprecated micro batch size parameter value.\n            mbs_per_gpu: New micro batch size per GPU parameter value.\n            name (str): Configuration section name for error messages.\n\n        Raises:\n            ValueError: If both parameters are set or neither is set.\n        \"\"\"\n        settings = {\n            \"reward_model\": \"micro_batch_size\",\n            \"actor_rollout_ref.ref\": \"log_prob_micro_batch_size\",\n            \"actor_rollout_ref.rollout\": \"log_prob_micro_batch_size\",\n        }\n\n        if name in settings:\n            param = settings[name]\n            param_per_gpu = f\"{param}_per_gpu\"\n\n            if mbs is None and mbs_per_gpu is None:\n                raise ValueError(f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\")\n\n            if mbs is not None and mbs_per_gpu is not None:\n                raise ValueError(\n                    f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove \"\n                    f\"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).\"\n                )\n\n    # Actor validation done in ActorConfig.__post_init__ and validate()\n    actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)\n    actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)\n\n    if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n        if use_reference_policy:\n            # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.ref.log_prob_micro_batch_size,\n                config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.ref\",\n            )\n\n        #  The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu\n        check_mutually_exclusive(\n            config.actor_rollout_ref.rollout.log_prob_micro_batch_size,\n            config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,\n            \"actor_rollout_ref.rollout\",\n        )\n\n    # Check for reward model micro-batch size conflicts\n    if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:\n        check_mutually_exclusive(\n            config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, \"reward_model\"\n        )\n\n    if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:\n        print(\"NOTICE: You have both enabled in-reward kl and kl loss.\")\n\n    # critic\n    if use_critic:\n        critic_config = omega_conf_to_dataclass(config.critic)\n        critic_config.validate(n_gpus, config.data.train_batch_size)\n\n    if config.data.get(\"val_batch_size\", None) is not None:\n        print(\n            \"WARNING: val_batch_size is deprecated.\"\n            + \" Validation datasets are sent to inference engines as a whole batch,\"\n            + \" which will schedule the memory themselves.\"\n        )\n\n    # check eval config\n    if config.actor_rollout_ref.rollout.val_kwargs.do_sample:\n        assert config.actor_rollout_ref.rollout.temperature > 0, (\n            \"validation gen temperature should be greater than 0 when enabling do_sample\"\n        )\n\n    # check LoRA rank in vLLM\n    if config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0 and config.actor_rollout_ref.rollout.name == \"vllm\":\n        assert config.actor_rollout_ref.model.lora_rank <= 512, \"LoRA rank in vLLM must be less than or equal to 512\"\n\n    print(\"[validate_config] All configuration checks passed successfully!\")\n"
  },
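`omega_conf_to_dataclass` merges the user's `DictConfig` overrides onto the dataclass defaults before materializing the typed object. A minimal sketch with a hypothetical `OptimCfg` dataclass (not part of verl):

```python
# Sketch: convert a DictConfig into a typed dataclass; OptimCfg is illustrative.
from dataclasses import dataclass

from omegaconf import OmegaConf

from verl.utils.config import omega_conf_to_dataclass


@dataclass
class OptimCfg:  # hypothetical config schema
    lr: float = 1e-5
    warmup_steps: int = 0


cfg = OmegaConf.create({"lr": 3e-6})
optim = omega_conf_to_dataclass(cfg, OptimCfg)
assert optim.lr == 3e-6 and optim.warmup_steps == 0  # override merged onto defaults
```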
  {
    "path": "verl_distillation/verl/utils/dataset/README.md",
    "content": "# Dataset Format\n## RLHF dataset\nWe combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers.\n\nMath problems\n```json\n{\n    \"data_source\": \"openai/gsm8k\",\n    \"prompt\": [{\"role\": \"user\", \"content\": \"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \\\"####\\\"\"}],\n    \"ability\": \"math\",\n    \"reward_model\": {\n        \"style\": \"rule\",\n        \"ground_truth\": [\"72\"]\n    },\n}\n```\n"
  },
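A row in this format can be produced with plain pandas/pyarrow; a minimal sketch (file name and values are illustrative):

```python
# Sketch: write a one-row RLHF parquet file matching the schema described above.
import pandas as pd

row = {
    "data_source": "openai/gsm8k",
    "prompt": [{"role": "user", "content": 'What is 2 + 2? Output the final answer after "####"'}],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": ["4"]},
}
pd.DataFrame([row]).to_parquet("rlhf_example.parquet")  # requires pyarrow
```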
  {
    "path": "verl_distillation/verl/utils/dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .rl_dataset import RLHFDataset\nfrom .rm_dataset import RMDataset\nfrom .sft_dataset import SFTDataset\n\n__all__ = [\"RLHFDataset\", \"RMDataset\", \"SFTDataset\"]\n"
  },
  {
    "path": "verl_distillation/verl/utils/dataset/dataset_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom enum import Enum\n\nimport torch\n\n\nclass DatasetPadMode(str, Enum):\n    \"\"\"Padding mode for dataset\"\"\"\n\n    RIGHT = \"right\"\n    LEFT_RIGHT = \"left_right\"\n    NO_PADDING = \"no_padding\"\n\n\nclass SFTTensorCollator:\n    \"\"\"\n    A custom collate_fn that handles batching of sequences.\n    1. for variable-length sequences, convert them into NestedTensors.\n    2. for fixed-length sequences, use default_collate.\n    \"\"\"\n\n    def __init__(self, pad_mode: DatasetPadMode = DatasetPadMode.LEFT_RIGHT):\n        self.pad_mode = pad_mode\n\n    def __call__(self, batch: list[dict[str, any]]) -> dict[str, any]:\n        if self.pad_mode == DatasetPadMode.NO_PADDING:\n            return self.collate_variable_batch(batch)\n        elif self.pad_mode in [DatasetPadMode.RIGHT, DatasetPadMode.LEFT_RIGHT]:\n            from torch.utils.data import default_collate\n\n            return default_collate(batch)\n        else:\n            raise NotImplementedError(f\"pad_mode {self.pad_mode} not implemented\")\n\n    def collate_variable_batch(self, batch: list[dict[str, any]]) -> dict[str, any]:\n        \"\"\"\n        Collates a list of samples into a single batch.\n\n        Args:\n            batch: A list of dictionary samples from the dataset.\n\n        Returns:\n            A dictionary representing the batched data, with variable-length\n            sequences converted to NestedTensors.\n        \"\"\"\n\n        final_batch = {}\n\n        tensor_keys = [key for key in batch[0].keys() if isinstance(batch[0][key], torch.Tensor)]\n\n        # Handle tensor values by creating a NestedTensor.\n        for key in tensor_keys:\n            tensors = [item[key] for item in batch]\n            final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)\n\n        return final_batch\n"
  },
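In `NO_PADDING` mode, `SFTTensorCollator` batches variable-length samples as jagged NestedTensors instead of padding them to a common length. A minimal sketch (requires a PyTorch version with `layout=torch.jagged` support):

```python
# Sketch: collate two variable-length samples without padding.
import torch

from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator

collator = SFTTensorCollator(pad_mode=DatasetPadMode.NO_PADDING)
batch = collator(
    [
        {"input_ids": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([4, 5])},
    ]
)
print(batch["input_ids"].is_nested)  # True: no padding tokens were added
```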
  {
    "path": "verl_distillation/verl/utils/dataset/multiturn_sft_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMulti-turn SFT dataset that supports training on conversation data with multiple turns\n\"\"\"\n\nimport logging\nfrom typing import Any, Optional\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom omegaconf import ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.dataset.dataset_utils import DatasetPadMode\nfrom verl.utils.fs import copy_local_path_from_hdfs\n\n\ndef convert_nested_value_to_list_recursive(data_item):\n    if isinstance(data_item, dict):\n        return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()}\n    elif isinstance(data_item, list):\n        return [convert_nested_value_to_list_recursive(elem) for elem in data_item]\n    elif isinstance(data_item, np.ndarray):\n        # Convert to list, then recursively process the elements of the new list\n        return convert_nested_value_to_list_recursive(data_item.tolist())\n    else:\n        # Base case: item is already a primitive type (int, str, float, bool, etc.)\n        return data_item\n\n\nclass MultiTurnSFTDataset(Dataset):\n    \"\"\"\n    Dataset for multi-turn conversations where each assistant response should be trained\n    \"\"\"\n\n    def __init__(self, parquet_files: str | list[str], tokenizer, config=None, max_samples: int = -1):\n        # Set defaults and extract parameters from config if provided\n        config = config or {}\n        self.pad_mode = config.get(\"pad_mode\", \"right\")\n        assert self.pad_mode in [\"right\", \"no_padding\"], (\n            f\"Expect pad_mode to be 'right' or 'no_padding'. 
Got {self.pad_mode}\"\n        )\n        self.truncation = config.get(\"truncation\", \"error\")\n        # for right padding\n        self.max_length = config.get(\"max_length\", 1024)\n        # Get messages_key from the new multiturn config structure\n        multiturn_config = config.get(\"multiturn\", {})\n        self.messages_key = multiturn_config.get(\"messages_key\", \"messages\")\n        self.tools_key = multiturn_config.get(\"tools_key\", \"tools\")\n        self.enable_thinking_key = multiturn_config.get(\"enable_thinking_key\", \"enable_thinking\")\n        self.apply_chat_template_kwargs = config.get(\"apply_chat_template_kwargs\", {})\n        self.shuffle = config.get(\"shuffle\", False)\n        self.seed = config.get(\"seed\")\n        self.max_samples = max_samples\n        assert self.truncation in [\"error\", \"left\", \"right\"]\n\n        if not isinstance(parquet_files, list | ListConfig):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n\n        self._download()\n        self._read_files_and_process()\n\n    def _download(self):\n        for i, parquet_file in enumerate(self.parquet_files):\n            self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True)\n\n    def _read_files_and_process(self):\n        def series_to_item(ls):\n            import numpy\n            import pandas\n\n            while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:\n                ls = ls[0]\n            return ls\n\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            dataframe = pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n\n        total = len(self.dataframe)\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n        if self.max_samples > 0 and self.max_samples < total:\n            if self.shuffle:\n                rngs_args = (self.seed,) if self.seed is not None else ()\n                rng = np.random.default_rng(*rngs_args)\n                indices = rng.choice(total, size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n            self.dataframe = self.dataframe.iloc[indices.tolist()]\n            print(f\"selected {self.max_samples} random samples out of {total}\")\n\n        # Extract messages list from dataframe\n        self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()\n\n        # Extract tools list from dataframe\n        if self.tools_key in self.dataframe.columns:\n            self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist()\n        else:\n            self.tools = None\n        # Extract enable_thinking list from dataframe\n        if self.enable_thinking_key in self.dataframe.columns:\n            self.enable_thinking = self.dataframe[self.enable_thinking_key].tolist()\n        else:\n            self.enable_thinking = None\n\n    def __len__(self):\n        return len(self.messages)\n\n    def _process_message_tokens(\n        self,\n        messages: list[dict[str, Any]],\n        start_idx: int,\n        end_idx: int,\n        is_assistant: bool = False,\n        enable_thinking: Optional[bool] = None,\n        tools: Optional[list[dict[str, Any]]] = None,\n    ) 
-> tuple[list[int], list[int], list[int]]:\n        \"\"\"\n        Process tokens for a single message or a group of messages.\n\n        Args:\n            messages: List of message dictionaries\n            start_idx: Start index in messages list\n            end_idx: End index in messages list\n            is_assistant: Whether this is an assistant message\n            enable_thinking: Whether to enable thinking mode\n\n        Returns:\n            Tuple of (tokens, loss_mask, attention_mask)\n        \"\"\"\n        if start_idx > 0:\n            prev_applied_text = self.tokenizer.apply_chat_template(\n                messages[:start_idx],\n                tokenize=False,\n                add_generation_prompt=False,\n                enable_thinking=enable_thinking,\n                tools=tools,\n                **self.apply_chat_template_kwargs,\n            )\n            if is_assistant:\n                prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template(\n                    messages[:start_idx],\n                    tokenize=False,\n                    add_generation_prompt=True,\n                    enable_thinking=enable_thinking,\n                    tools=tools,\n                    **self.apply_chat_template_kwargs,\n                )\n\n        else:\n            prev_applied_text = \"\"\n\n        cur_applied_text = self.tokenizer.apply_chat_template(\n            messages[:end_idx],\n            tokenize=False,\n            add_generation_prompt=False,\n            enable_thinking=enable_thinking,\n            tools=tools,\n            **self.apply_chat_template_kwargs,\n        )\n        # Get tokens for the current message only\n        if is_assistant:\n            generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :]\n            generation_prompt_tokens = self.tokenizer.encode(\n                generation_prompt_text,\n                add_special_tokens=False,\n            )\n            _message_tokens = self.tokenizer.encode(\n                cur_applied_text[len(prev_applied_text_w_generation_prompt) :],\n                add_special_tokens=False,\n            )\n            message_tokens = generation_prompt_tokens + _message_tokens\n            loss_mask = [0] * (len(generation_prompt_tokens)) + [1] * (\n                len(message_tokens) - len(generation_prompt_tokens)\n            )\n        else:\n            message_tokens = self.tokenizer.encode(\n                cur_applied_text[len(prev_applied_text) :],\n                add_special_tokens=False,\n            )\n            loss_mask = [0] * len(message_tokens)\n\n        attention_mask = [1] * len(message_tokens)\n\n        return message_tokens, loss_mask, attention_mask\n\n    def _validate_and_convert_tokens(\n        self,\n        full_tokens: torch.Tensor,\n        concat_tokens: list[int],\n        concat_loss_mask: list[int],\n        concat_attention_mask: list[int],\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Validate tokenization and convert to tensors.\n\n        Args:\n            full_tokens: Full conversation tokens\n            concat_tokens: Concatenated tokens\n            concat_loss_mask: Concatenated loss mask\n            concat_attention_mask: Concatenated attention mask\n\n        Returns:\n            Tuple of (input_ids, loss_mask, attention_mask) as tensors\n        \"\"\"\n        full_tokens_list = full_tokens.tolist()\n\n        if len(concat_tokens) != len(full_tokens_list) 
or not all(\n            a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True)\n        ):\n            logging.warning(\n                f\"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens \"\n                f\"length: {len(concat_tokens)}. Using concatenated version.\"\n                # f\"full tokens text: {self.tokenizer.decode(full_tokens_list)}\"\n                # f\"concat tokens text: {self.tokenizer.decode(concat_tokens)}\"\n            )\n            return (\n                torch.tensor(concat_tokens, dtype=torch.long),\n                torch.tensor(concat_loss_mask, dtype=torch.long),\n                torch.tensor(concat_attention_mask, dtype=torch.long),\n            )\n\n        return (\n            full_tokens,\n            torch.tensor(concat_loss_mask, dtype=torch.long),\n            torch.tensor(concat_attention_mask, dtype=torch.long),\n        )\n\n    def __getitem__(self, item):\n        tokenizer = self.tokenizer\n        messages = self.messages[item]\n        tools = self.tools[item] if self.tools is not None else None\n        enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None\n\n        # First, get the full conversation tokens\n        try:\n            full_tokens = tokenizer.apply_chat_template(\n                messages,\n                tools=tools,\n                tokenize=True,\n                return_tensors=\"pt\",\n                add_generation_prompt=False,\n                enable_thinking=enable_thinking,\n                **self.apply_chat_template_kwargs,\n            )\n        except Exception as e:\n            logging.error(\n                f\"Error applying chat template: {e}\\nMessages: {messages}\\nTools: {tools}\\nEnable thinking: \"\n                f\"{enable_thinking}\"\n            )\n            raise\n\n        # Track concatenated tokens for validation\n        concat_tokens = []\n        concat_loss_mask = []\n        concat_attention_mask = []\n\n        i = 0\n        while i < len(messages):\n            cur_messages = messages[i]\n            if cur_messages[\"role\"] == \"assistant\":\n                # Process assistant message\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools\n                )\n                i += 1\n            elif cur_messages[\"role\"] == \"tool\":\n                # Process consecutive tool messages\n                st = i\n                ed = i + 1\n                while ed < len(messages) and messages[ed][\"role\"] == \"tool\":\n                    ed += 1\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, st, ed, enable_thinking=enable_thinking, tools=tools\n                )\n                i = ed\n            elif cur_messages[\"role\"] in [\"user\", \"system\"]:\n                # Process user or system message\n                if cur_messages[\"role\"] == \"system\" and i != 0:\n                    raise ValueError(\"System message should be the first message\")\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, i, i + 1, enable_thinking=enable_thinking, tools=tools\n                )\n                i += 1\n            else:\n                raise ValueError(f\"Unknown role: {cur_messages['role']}\")\n\n            # override 
loss mask with mask in the dataset to handle multi-turn conversation\n            override_loss_mask = cur_messages.get(\"loss_mask\", None)\n            if override_loss_mask is not None:\n                if isinstance(override_loss_mask, np.ndarray):\n                    override_loss_mask = override_loss_mask.item()\n                assert isinstance(override_loss_mask, int), f\"loss_mask should be int, got {type(override_loss_mask)}\"\n                assert override_loss_mask in [0, 1], f\"loss_mask should be 0 or 1, got {override_loss_mask}\"\n                loss_mask = [override_loss_mask] * len(tokens)\n\n            concat_tokens.extend(tokens)\n            concat_loss_mask.extend(loss_mask)\n            concat_attention_mask.extend(attention_mask)\n\n        # Validate and convert tokens\n        input_ids, loss_mask, attention_mask = self._validate_and_convert_tokens(\n            full_tokens[0], concat_tokens, concat_loss_mask, concat_attention_mask\n        )\n\n        # encode prompt\n        if messages[0][\"role\"] == \"system\":\n            assert messages[1][\"role\"] == \"user\"\n            assert messages[2][\"role\"] == \"assistant\"\n        elif messages[0][\"role\"] == \"user\":\n            assert messages[1][\"role\"] == \"assistant\"\n        else:\n            raise ValueError(f\"Unknown role: {messages[0]['role']}\")\n\n        sequence_length = input_ids.shape[0]\n        # Handle sequence length\n        if self.pad_mode == DatasetPadMode.RIGHT:\n            if sequence_length < self.max_length:\n                # Pad sequences\n                pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0\n                padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype)\n                padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype)\n                padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype)\n\n                input_ids = torch.cat((input_ids, padded_input_ids))\n                attention_mask = torch.cat((attention_mask, padded_attention_mask))\n                loss_mask = torch.cat((loss_mask, padded_loss_mask))\n            elif sequence_length > self.max_length:\n                if self.truncation == \"left\":\n                    input_ids = input_ids[-self.max_length :]\n                    attention_mask = attention_mask[-self.max_length :]\n                    loss_mask = loss_mask[-self.max_length :]\n                elif self.truncation == \"right\":\n                    input_ids = input_ids[: self.max_length]\n                    attention_mask = attention_mask[: self.max_length]\n                    loss_mask = loss_mask[: self.max_length]\n                elif self.truncation == \"error\":\n                    raise ValueError(f\"{sequence_length=} is larger than {self.max_length=}\")\n                else:\n                    raise ValueError(f\"Unknown truncation method {self.truncation}\")\n\n            # Create position IDs\n            position_ids = torch.arange(len(input_ids), dtype=torch.long)\n            # Zero out position IDs for padding\n            position_ids = position_ids * attention_mask\n\n            return {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n                \"loss_mask\": loss_mask,\n            }\n        elif 
self.pad_mode == DatasetPadMode.NO_PADDING:\n            # truncate input_ids if it is longer than max_length\n            if len(input_ids) > self.max_length:\n                input_ids = input_ids[: self.max_length]\n                loss_mask = loss_mask[: self.max_length]\n            # create position IDs\n            position_ids = torch.arange(len(input_ids), dtype=torch.long)\n            # return variable-length tensors without padding\n            return {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids,\n                \"loss_mask\": loss_mask,\n            }\n        else:\n            raise ValueError(f\"Unknown pad mode {self.pad_mode}\")\n"
  },
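For reference, a row consumed by `MultiTurnSFTDataset` (with the default `messages_key="messages"`) can look like the sketch below; the optional per-message `loss_mask` field overrides the role-based mask, as handled in `__getitem__` above. Values are illustrative:

```python
# Sketch: one multi-turn SFT row. Only assistant tokens get loss by default;
# an explicit per-message loss_mask (0 or 1) overrides that behavior.
row = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello! How can I help?", "loss_mask": 1},
    ]
}
```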
  {
    "path": "verl_distillation/verl/utils/dataset/onerec_dataset.py",
    "content": "import ast\nimport copy\nimport logging\nimport os\nimport random\nimport re\nfrom typing import Any, Optional\n\nimport datasets\nimport numpy as np\nfrom omegaconf import DictConfig, ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\nclass OneRecDataset(Dataset):\n    \"\"\"Onerec数据集读取与预处理。\n\n    - 缓存Parquet文件到本地；\n    - 利用HF Dataset读取并转换chat结构；\n    - 根据配置过滤超长prompt；\n    - 支持多模态预处理与位置编码。\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n        max_samples: int = -1,\n    ) -> None:\n        if not isinstance(data_files, (list, ListConfig)):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(data_files)\n        self.original_data_files = copy.deepcopy(data_files)\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.max_samples = max_samples\n        self.config = config\n\n        self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.return_multi_modal_inputs = config.get(\"return_multi_modal_inputs\", True)\n        self.enable_think = config.get(\"enable_think\", True)\n        self.think_mode = config.get(\"think_mode\", \"force_think\")\n        self.shuffle = config.get(\"shuffle\", False)\n        self.seed = config.get(\"seed\", None)\n\n        auto_workers = max(1, (os.cpu_count() or 4) // 4)\n        self.num_workers = min(\n            config.get(\"filter_overlong_prompts_workers\", auto_workers),\n            os.cpu_count() or auto_workers,\n        )\n        self.use_shm = config.get(\"use_shm\", False)\n        self.serialize_dataset = False\n\n        #self._download()\n        self._read_files_and_tokenize()\n\n    # ---------------------------------------------------------------------\n    # 数据准备\n    # ---------------------------------------------------------------------\n    def _download(self, use_origin_parquet: bool = False) -> None:\n        from verl.utils.fs import copy_to_local\n\n        target_files = self.original_data_files if use_origin_parquet else self.data_files\n        for idx, parquet_file in enumerate(target_files):\n            local_path = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n            target_files[idx] = local_path\n\n        if use_origin_parquet:\n            self.data_files = target_files\n\n    def _read_files_and_tokenize(self) -> None:\n        #dataframes: list[datasets.Dataset] = []\n        
self.dataframe = datasets.load_dataset(\"parquet\", data_files=self.data_files)[\"train\"]\n        #for parquet_file in self.data_files:\n        #    dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n        #    dataframes.append(dataframe)\n\n        #self.dataframe = datasets.concatenate_datasets(dataframes)  # type: ignore[attr-defined]\n        logger.info(\"dataset len: %s\", len(self.dataframe))\n\n        if self.max_samples > 0 and self.max_samples < len(self.dataframe):\n            if self.shuffle:\n                rngs_args = (self.seed,) if self.seed is not None else ()\n                rng = np.random.default_rng(*rngs_args)\n                indices = rng.choice(len(self.dataframe), size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n            self.dataframe = self.dataframe.select(indices.tolist())\n            print(f\"selected {self.max_samples} random samples out of {len(self.dataframe)}\")\n\n        self.dataframe = self.dataframe.map(\n            self._extract_prompt_fields,\n            num_proc=self.num_workers,\n            desc=\"Extract prompts and reward annotations\",\n        )\n\n        # 过滤掉处理失败的样本\n        original_len = len(self.dataframe)\n        self.dataframe = self.dataframe.filter(\n            self._is_valid_sample,\n            num_proc=self.num_workers,\n            desc=\"Filtering out failed samples\",\n        )\n        filtered_len = len(self.dataframe)\n        logger.info(\"Filtered out %s failed samples, remaining: %s\", original_len - filtered_len, filtered_len)\n\n        logger.info(\"processed dataset len: %s\", len(self.dataframe))\n        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)\n\n    def _extract_prompt_fields(self, row: dict[str, Any]) -> dict[str, Any]:\n        try:\n            raw_messages = row.get(\"messages\")\n            if isinstance(raw_messages, str):\n                messages = ast.literal_eval(raw_messages)\n            else:\n                messages = raw_messages or []\n            \n            # 多轮对话清洗成单轮对话\n            user_cnt = 0\n            assistant_cnt = 0\n            clean_chats = []\n\n            for msg in messages:\n                if user_cnt > 0 and assistant_cnt > 0:\n                    break\n                role = msg.get(\"role\", None)\n                content = msg.get(\"content\", None)\n                if role is None or content is None:\n                    raise ValueError(\"role or content is None!\")\n                content_text = \"\"\n                if isinstance(content, str):\n                    content_text = content\n                elif isinstance(content, dict) and content.get(\"type\") == \"text\":\n                    content_text = content[\"text\"]\n                elif isinstance(content, list):\n                    for seg in content:\n                        if isinstance(seg, str):\n                            content_text += seg\n                        elif isinstance(seg, dict) and seg.get(\"type\") == \"text\":\n                            content_text += seg.get(\"text\", \"\")\n                        \n                if role == \"user\" and content_text.strip() == \"\":\n                    raise ValueError(\"content is empty!\")\n                \n                # # drop system prompt\n                # if role == \"system\":\n                #     if \"<think></think>\" in content_text or \"<answer></answer>\" in content_text:\n  
              #         continue\n                    \n                clean_chats.append({\n                    \"role\": role,\n                    \"content\": content_text\n                })\n                if role == \"user\":\n                    user_cnt += 1\n                \n                if role == \"assistant\":\n                    assistant_cnt += 1\n\n            if not clean_chats or len(clean_chats) < 2:\n                raise ValueError(\"Sample has empty messages; please check data integrity.\")\n\n            prompt_messages = clean_chats[:-1]\n\n            # 根据配置决定是否给 user 消息添加 /think /no_think 指令\n            if self.enable_think:\n                think_suffix = \"\"\n                if self.think_mode == \"force_think\":\n                    think_suffix = \" /think\"\n                elif self.think_mode == \"force_nothink\":\n                    think_suffix = \" /no_think\"\n                elif self.think_mode == \"auto\":\n                    tm_idx = random.randint(0, 2)\n                    think_suffix = \" /think\" if tm_idx == 1 else \" /no_think\" if tm_idx == 2 else \"\"\n                else:\n                    raise ValueError(\"think_mode is unexcept\")\n\n                for message in prompt_messages:\n                    if message[\"role\"] == \"user\":\n                        message[\"content\"] = message[\"content\"] + think_suffix\n\n            ground_truth_message = clean_chats[-1][\"content\"]\n\n            reward_payload = {\n                \"ground_truth\": ground_truth_message,\n                \"style\": \"rule\",\n            }\n\n            row[self.prompt_key] = prompt_messages\n            row[\"reward_model\"] = reward_payload\n            return row\n        except Exception as e:\n            # 标记处理失败的样本\n            row[\"_processing_failed\"] = True\n            row[\"_processing_error\"] = str(e)\n            return row\n\n    def _is_valid_sample(self, row: dict[str, Any]) -> bool:\n        \"\"\"检查样本是否处理成功\"\"\"\n        return not row.get(\"_processing_failed\", False)\n\n    # ---------------------------------------------------------------------\n    # 过滤与恢复\n    # ---------------------------------------------------------------------\n    def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset) -> datasets.Dataset:\n        if not self.filter_overlong_prompts:\n            return dataframe\n\n        tokenizer = self.tokenizer\n        processor = self.processor\n        prompt_key = self.prompt_key\n        image_key = self.image_key\n        video_key = self.video_key\n\n        if processor is not None:\n            from verl.utils.dataset.vision_utils import (process_image,\n                                                         process_video)\n\n            def doc_length(doc: dict[str, Any]) -> int:\n                messages = self._build_messages(dict(doc))\n                raw_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n                images = [process_image(image) for image in doc.get(image_key, [])]\n                videos = [process_video(video) for video in doc.get(video_key, [])]\n                encoded = processor(text=[raw_prompt], images=images or None, videos=videos or None, return_tensors=\"pt\")\n                return int(encoded[\"input_ids\"].shape[-1])\n\n        else:\n\n            def doc_length(doc: dict[str, Any]) -> int:\n                messages = doc[prompt_key]\n                return 
len(tokenizer.apply_chat_template(messages, add_generation_prompt=True))\n\n        filtered = dataframe.filter(\n            lambda doc: doc_length(doc) <= self.max_prompt_length,\n            num_proc=self.num_workers,\n            desc=f\"Filtering prompts longer than {self.max_prompt_length} tokens\",\n        )\n\n        # 获取data_source字段值为\"distill\"和\"sft\"的indices\n        if \"data_source\" in filtered.features:\n            self.distill_indices = [i for i, doc in enumerate(filtered) if doc.get(\"data_source\") == \"distill\"]\n            self.sft_indices = [i for i, doc in enumerate(filtered) if doc.get(\"data_source\") == \"sft\"]\n            logger.info(f\"distill samples: {len(self.distill_indices)}, sft samples: {len(self.sft_indices)}\")\n        else:\n            logger.warning(\"data_source field not found in filtered dataset\")\n\n        logger.info(\"filtered dataset len: %s\", len(filtered))\n        return filtered\n\n    def resume_dataset_state(self) -> None:\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)\n            self._read_files_and_tokenize()\n        else:\n            logger.warning(\"resume with serialized dataloader, consider restarting from scratch for better perf\")\n\n    # ---------------------------------------------------------------------\n    # Dataset 接口\n    # ---------------------------------------------------------------------\n    def __len__(self) -> int:  # type: ignore[override]\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict[str, Any]) -> list[dict[str, Any]]:\n        messages: list[dict[str, Any]] = example.pop(self.prompt_key)\n\n        if self.image_key in example or self.video_key in example:\n            for message in messages:\n                content = message[\"content\"]\n                segments = [segment for segment in re.split(r\"(<image>|<video>)\", content) if segment]\n                parsed_segments = []\n                for segment in segments:\n                    if segment == \"<image>\":\n                        parsed_segments.append({\"type\": \"image\"})\n                    elif segment == \"<video>\":\n                        parsed_segments.append({\"type\": \"video\"})\n                    else:\n                        parsed_segments.append({\"type\": \"text\", \"text\": segment})\n                message[\"content\"] = parsed_segments\n\n        return messages\n\n    def __getitem__(self, index: int) -> dict[str, Any]:  # type: ignore[override]\n        row: dict[str, Any] = dict(self.dataframe[index])\n        messages = self._build_messages(dict(row))\n        model_inputs: dict[str, Any] = {}\n\n        if self.processor is not None:\n            from verl.utils.dataset.vision_utils import (process_image,\n                                                         process_video)\n\n            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            multi_modal_data: dict[str, Any] = {}\n\n            images = None\n            if self.image_key in row and row.get(self.image_key):\n                images = [process_image(image) for image in row.pop(self.image_key)]\n                multi_modal_data[\"image\"] = images\n\n            videos = None\n            if self.video_key in row and row.get(self.video_key):\n                videos = [process_video(video) for video in 
row.pop(self.video_key)]\n                multi_modal_data[\"video\"] = [video.numpy() for video in videos]\n\n            model_inputs = self.processor(\n                text=[raw_prompt],\n                images=images,\n                videos=videos,\n                return_tensors=\"pt\",\n            )\n\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            row[\"multi_modal_data\"] = multi_modal_data\n            if self.return_multi_modal_inputs:\n                mm_inputs = dict(model_inputs)\n                mm_inputs.pop(\"second_per_grid_ts\", None)\n                row[\"multi_modal_inputs\"] = mm_inputs\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = verl_F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if (\n            self.processor is not None\n            and hasattr(self.processor, \"image_processor\")\n            and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__\n        ):\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            position_ids = [\n                get_rope_index(\n                    self.processor,\n                    input_ids=input_ids[0],\n                    image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                    video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                    second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                    attention_mask=attention_mask[0],\n                )\n            ]\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row[\"input_ids\"] = input_ids[0]\n        row[\"attention_mask\"] = attention_mask[0]\n        row[\"position_ids\"] = position_ids[0]\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            raw_prompt_ids = self._truncate_ids(raw_prompt_ids)\n\n        row[\"raw_prompt_ids\"] = raw_prompt_ids\n        if self.return_raw_chat:\n            row[\"raw_prompt\"] = messages\n        if self.return_full_prompt:\n            row[\"full_prompts\"] = raw_prompt\n\n        extra_info = row.get(\"extra_info\", {}) or {}\n        row[\"index\"] = extra_info.get(\"index\", index)\n        row[\"tools_kwargs\"] = extra_info.get(\"tools_kwargs\", {})\n        row[\"interaction_kwargs\"] = extra_info.get(\"interaction_kwargs\", {})\n\n        # 确保 data_source 或 source 字段被保留（用于按task统计）\n        # 原始 parquet 数据中应该包含 source 或 data_source 字段\n        # 如果都不存在，设置一个默认值\n        if \"source\" in row or \"data_source\" in row:\n            # 字段已存在，无需处理（会自动被 collate_fn 收集）\n            pass\n        else:\n            # 如果两个字段都不存在，设置一个默认值\n            row[\"data_source\"] = \"unknown\"\n            logger.warning(\"No source/data_source field found for index %s, set to 'unknown'\", 
row[\"index\"])\n\n        if self.need_tools_kwargs and not row[\"tools_kwargs\"]:\n            logger.warning(\"tools_kwargs is empty for index %s, data source: %s\", row[\"index\"], row.get(\"data_source\", row.get(\"source\", \"unknown\")))\n\n        return row\n\n    def _truncate_ids(self, token_ids: list[int]) -> list[int]:\n        if self.truncation == \"left\":\n            return token_ids[-self.max_prompt_length :]\n        if self.truncation == \"right\":\n            return token_ids[: self.max_prompt_length]\n        if self.truncation == \"middle\":\n            left = self.max_prompt_length // 2\n            right = self.max_prompt_length - left\n            return token_ids[:left] + token_ids[-right:]\n        if self.truncation == \"error\":\n            raise RuntimeError(\n                f\"Prompt length {len(token_ids)} exceeds max_prompt_length={self.max_prompt_length}. \"\n                \"Consider increasingmax_prompt_length or enabling truncation.\"\n            )\n        raise ValueError(f\"Unsupported truncation mode: {self.truncation}\")\n\n    def __getstate__(self) -> dict[str, Any]:\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n            if \"dataframe\" in state:\n                del state[\"dataframe\"]\n            return state\n        return self.__dict__.copy()"
  },
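To make the prompt cleaning above concrete: with `enable_think=True` and `think_mode="force_think"`, `_extract_prompt_fields` keeps the first user/assistant pair, appends the think directive to the user turn, and moves the final assistant message into the reward payload. A sketch with illustrative values (`<item_123>` stands in for a hypothetical itemic token):

```python
# Sketch: input row and the fields _extract_prompt_fields derives from it
# under think_mode="force_think" (values are illustrative).
row_in = {
    "messages": [
        {"role": "user", "content": "Recommend a short video."},
        {"role": "assistant", "content": "<item_123>"},
    ]
}
# After processing:
#   row["prompt"]       == [{"role": "user", "content": "Recommend a short video. /think"}]
#   row["reward_model"] == {"ground_truth": "<item_123>", "style": "rule"}
```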
  {
    "path": "verl_distillation/verl/utils/dataset/rl_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport logging\nimport os\nimport re\nimport traceback\nfrom collections import defaultdict\nfrom typing import Optional\n\nimport datasets\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig, ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\n\ndef collate_fn(data_list: list[dict]) -> dict:\n    \"\"\"\n    Collate a batch of sample dicts into batched tensors and arrays.\n\n    Args:\n        data_list: List of dicts mapping feature names to torch.Tensor or other values.\n\n    Returns:\n        Dict where tensor entries are stacked into a torch.Tensor of shape\n        (batch_size, \\\\*dims) and non-tensor entries are converted to\n        np.ndarray of dtype object with shape (batch_size,).\n    \"\"\"\n    tensors = defaultdict(list)\n    non_tensors = defaultdict(list)\n\n    for data in data_list:\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor):\n                tensors[key].append(val)\n            else:\n                non_tensors[key].append(val)\n\n    for key, val in tensors.items():\n        tensors[key] = torch.stack(val, dim=0)\n\n    for key, val in non_tensors.items():\n        non_tensors[key] = np.fromiter(val, dtype=object, count=len(val))\n\n    return {**tensors, **non_tensors}\n\n\nclass RLHFDataset(Dataset):\n    \"\"\"\n    Load and preprocess RLHF data from Parquet files.\n\n    - Caches files locally.\n    - Reads into a HuggingFace Dataset and tokenizes prompts.\n    - Optionally handles images/videos via a ProcessorMixin.\n    - Filters prompts over a max length.\n    - Supports resuming from checkpoints.\n\n    Args:\n        data_files (str or list): Path(s) to Parquet file(s).\n        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.\n        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.\n        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n        max_samples: int = -1,\n    ):\n        if not isinstance(data_files, list | ListConfig):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(data_files)\n        self.original_data_files = copy.deepcopy(data_files)  # use for resume\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.max_samples = max_samples\n        self.config = config\n\n        
self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.image_patch_size = config.get(\"image_patch_size\", 14)\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n        self.apply_chat_template_kwargs = config.get(\"apply_chat_template_kwargs\", {})\n\n        self.tool_config_path = config.get(\"tool_config_path\", None)\n        self.tool_schemas = None\n        if self.tool_config_path:\n            try:\n                from verl.tools.utils.tool_registry import initialize_tools_from_config\n\n                tool_list = initialize_tools_from_config(self.tool_config_path)\n                # match ToolAgentLoop behaviour: model_dump to plain dicts\n                self.tool_schemas = [\n                    tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list\n                ]\n            except Exception as e:\n                logger.warning(\"Failed to initialize tools from %s: %s\", self.tool_config_path, e)\n                self.tool_schemas = None\n\n        self.num_workers = config.get(\"filter_overlong_prompts_workers\", max(1, os.cpu_count() // 4))\n        self.num_workers = min(self.num_workers, os.cpu_count())\n        self.use_shm = config.get(\"use_shm\", False)\n        self.chat_template_func = config.get(\"chat_template_func\", None)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.serialize_dataset = False\n        self.return_multi_modal_inputs = config.get(\"return_multi_modal_inputs\", True)\n        self.shuffle = config.get(\"shuffle\", False)\n        self.seed = config.get(\"seed\")\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self, use_origin_parquet=False):\n        from verl.utils.fs import copy_to_local\n\n        data_files = self.data_files if not use_origin_parquet else self.original_data_files\n        for i, parquet_file in enumerate(data_files):\n            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n            dataframes.append(dataframe)\n        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        total = len(self.dataframe)\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n        if self.max_samples > 0 and self.max_samples < total:\n            if self.shuffle:\n                rngs_args = (self.seed,) if self.seed is not None else ()\n                rng = np.random.default_rng(*rngs_args)\n                indices = rng.choice(total, size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n           
self.dataframe = self.dataframe.select(indices.tolist())\n            print(f\"selected {self.max_samples} samples out of {total}\")\n\n        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)\n\n    def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):\n        # filter out too long prompts\n        if self.filter_overlong_prompts:\n            tokenizer = self.tokenizer\n            processor = self.processor\n            prompt_key = self.prompt_key\n            image_key = self.image_key\n            video_key = self.video_key\n\n            if processor is not None:\n                from verl.utils.dataset.vision_utils import process_image, process_video\n\n                def doc2len(doc) -> int:\n                    try:\n                        messages = self._build_messages(doc)\n                        # pass tool schemas if available so the processor can format prompts\n                        apply_kwargs = dict(**self.apply_chat_template_kwargs)\n                        if self.tool_schemas is not None:\n                            apply_kwargs[\"tools\"] = self.tool_schemas\n\n                        raw_prompt = self.processor.apply_chat_template(\n                            messages, add_generation_prompt=True, tokenize=False, **apply_kwargs\n                        )\n                        if image_key in doc and doc[image_key]:\n                            images = [\n                                process_image(image, image_patch_size=self.image_patch_size) for image in doc[image_key]\n                            ]\n                        else:\n                            images = None\n\n                        if video_key in doc and doc[video_key]:\n                            videos, video_metadata = zip(\n                                *[\n                                    process_video(\n                                        video, image_patch_size=self.image_patch_size, return_video_metadata=True\n                                    )\n                                    for video in doc[video_key]\n                                ],\n                                strict=True,\n                            )\n                            videos = list(videos)\n                            video_metadata = list(video_metadata)\n                            videos_kwargs = {\"video_metadata\": video_metadata, \"do_sample_frames\": False}\n                        else:\n                            videos = None\n                            videos_kwargs = {}\n\n                        return len(\n                            processor(text=[raw_prompt], images=images, videos=videos, videos_kwargs=videos_kwargs)[\n                                \"input_ids\"\n                            ][0]\n                        )\n                    except Exception:\n                        print(\"Error processing one of the samples, marking it overlong so it is filtered out...\")\n                        traceback.print_exc()\n                        return self.max_prompt_length + 1\n\n            else:\n\n                def doc2len(doc) -> int:\n                    try:\n                        apply_kwargs = dict(**self.apply_chat_template_kwargs)\n                        if self.tool_schemas is not None:\n                            apply_kwargs[\"tools\"] = self.tool_schemas\n\n                        return len(\n                            tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs)\n           
              )\n                    except Exception:\n                        print(\"Error processing one of the samples, marking it overlong so it is filtered out...\")\n                        traceback.print_exc()\n                        return self.max_prompt_length + 1\n\n            dataframe = dataframe.filter(\n                lambda doc: doc2len(doc) <= self.max_prompt_length,\n                num_proc=self.num_workers,\n                desc=f\"Filtering prompts longer than {self.max_prompt_length} tokens\",\n            )\n\n            print(f\"filtered dataset len: {len(dataframe)}\")\n        return dataframe\n\n    def resume_dataset_state(self):\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        # resume dataframe if it's not serialized in data.pt\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)  # download and resume from original parquet files\n            self._read_files_and_tokenize()\n        else:\n            print(\"an old dataloader ckpt file is being used; please train from scratch for better ckpt performance\")\n\n    def __len__(self):\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict):\n        messages: list = example.pop(self.prompt_key)\n\n        if self.image_key in example or self.video_key in example:\n            for message in messages:\n                content = message[\"content\"]\n                content_list = []\n                segments = re.split(\"(<image>|<video>)\", content)\n                segments = [item for item in segments if item != \"\"]\n                for segment in segments:\n                    if segment == \"<image>\":\n                        content_list.append({\"type\": \"image\"})\n                    elif segment == \"<video>\":\n                        content_list.append({\"type\": \"video\"})\n                    else:\n                        content_list.append({\"type\": \"text\", \"text\": segment})\n\n                message[\"content\"] = content_list\n\n        return messages\n\n    def __getitem__(self, item):\n        \"\"\"\n        Note that we also return the raw_input_ids so that it can be combined with other chat templates\n        \"\"\"\n        row_dict: dict = self.dataframe[item]\n        messages = self._build_messages(row_dict)\n        model_inputs = {}\n\n        if self.processor is not None:\n            from verl.utils.dataset.vision_utils import process_image, process_video\n\n            raw_prompt = self.processor.apply_chat_template(\n                messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs\n            )\n            multi_modal_data = {}\n\n            images = None\n            row_dict_images = row_dict.pop(self.image_key, None)\n            if row_dict_images:\n                images = [process_image(image, image_patch_size=self.image_patch_size) for image in row_dict_images]\n\n                # because the key is \"image\" instead of \"images\" in vllm, we need to use \"image\" here\n                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n                multi_modal_data[\"image\"] = images\n\n            videos = None\n            videos_kwargs = {}\n            row_dict_videos = row_dict.pop(self.video_key, None)\n            if row_dict_videos:\n                videos, video_metadata = zip(\n                    *[\n                        process_video(video, 
image_patch_size=self.image_patch_size, return_video_metadata=True)\n                        for video in row_dict_videos\n                    ],\n                    strict=True,\n                )\n                videos = list(videos)\n                video_metadata = list(video_metadata)\n                videos_kwargs = {\"video_metadata\": video_metadata, \"do_sample_frames\": False}\n\n                # because the key is \"video\" instead of \"videos\" in vllm, we need to use \"video\" here\n                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n                multi_modal_data[\"video\"] = [\n                    (video.numpy(), metadata) for video, metadata in zip(videos, video_metadata, strict=True)\n                ]\n\n            model_inputs = self.processor(\n                text=[raw_prompt], images=images, videos=videos, videos_kwargs=videos_kwargs, return_tensors=\"pt\"\n            )\n\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            if \"second_per_grid_ts\" in model_inputs:\n                model_inputs.pop(\"second_per_grid_ts\")\n\n            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature\n            row_dict[\"multi_modal_data\"] = multi_modal_data\n\n            # We will do batch.union() in the trainer,\n            # so we cannot have \"multi_modal_inputs\" in row_dict if rollout generates new multi_modal_inputs\n            if self.return_multi_modal_inputs:\n                row_dict[\"multi_modal_inputs\"] = dict(model_inputs)\n\n                # second_per_grid_ts isn't used for training, just for mrope\n                row_dict[\"multi_modal_inputs\"].pop(\"second_per_grid_ts\", None)\n\n        else:\n            if self.apply_chat_template_kwargs.get(\"chat_template\") is None:\n                assert hasattr(self.tokenizer, \"chat_template\"), (\n                    \"chat_template should be provided in apply_chat_template_kwargs or tokenizer config, \"\n                    \"models like GLM can copy chat_template.jinja from instruct models\"\n                )\n            raw_prompt = self.tokenizer.apply_chat_template(\n                messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs\n            )\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = verl_F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if self.processor is not None and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__:\n            # qwen-vl mrope\n            if \"Qwen3VLProcessor\" in self.processor.__class__.__name__:\n                from verl.models.transformers.qwen3_vl import get_rope_index\n            else:\n                from verl.models.transformers.qwen2_vl import get_rope_index\n\n            vision_position_ids = get_rope_index(\n                self.processor,\n                input_ids=input_ids[0],\n                
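# m-rope: get_rope_index returns one position track per temporal/height/width axis\n                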
image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                attention_mask=attention_mask[0],\n            )  # (3, seq_length)\n            valid_mask = attention_mask[0].bool()\n            text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)\n            text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())\n            position_ids = [torch.cat((text_position_ids, vision_position_ids), dim=0)]  # (1, 4, seq_length)\n        elif self.processor is not None and \"Glm4vImageProcessor\" in self.processor.image_processor.__class__.__name__:\n            from verl.models.transformers.glm4v import get_rope_index\n\n            vision_position_ids = get_rope_index(\n                self.processor,\n                input_ids=input_ids[0],\n                image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                attention_mask=attention_mask[0],\n            )  # (3, seq_length)\n            valid_mask = attention_mask[0].bool()\n            text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)\n            text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())\n            position_ids = [torch.cat((text_position_ids, vision_position_ids), dim=0)]  # (1, 4, seq_length)\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row_dict[\"input_ids\"] = input_ids[0]\n        row_dict[\"attention_mask\"] = attention_mask[0]\n        row_dict[\"position_ids\"] = position_ids[0]\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        row_dict[\"raw_prompt_ids\"] = raw_prompt_ids\n        # return the raw chat messages without applying the chat template\n        if self.return_raw_chat:\n            row_dict[\"raw_prompt\"] = messages\n\n        # get prompts with chat template\n        if self.return_full_prompt:\n            row_dict[\"full_prompts\"] = raw_prompt  # full prompt string with chat template applied\n\n        # add index for each prompt\n        if \"extra_info\" not in row_dict or row_dict[\"extra_info\"] is None:\n            row_dict[\"extra_info\"] = dict()\n        index = row_dict.get(\"extra_info\", {}).get(\"index\", 0)\n        tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"tools_kwargs\", {})\n        interaction_kwargs = row_dict.get(\"extra_info\", {}).get(\"interaction_kwargs\", {})\n        need_tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"need_tools_kwargs\", self.need_tools_kwargs)\n        if need_tools_kwargs and not tools_kwargs:\n            logger.warning(\"tools_kwargs is empty for index 
%s, data source: %s\", index, row_dict[\"data_source\"])\n        row_dict[\"index\"] = index\n        row_dict[\"tools_kwargs\"] = tools_kwargs\n        row_dict[\"interaction_kwargs\"] = interaction_kwargs\n        return row_dict\n\n    def __getstate__(self):\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n\n            if \"dataframe\" in state:\n                del state[\"dataframe\"]\n            return state\n\n        return self.__dict__.copy()\n
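\n\n# Usage sketch (illustrative, not part of this module; argument names are assumed):\n# given a parquet file whose \"prompt\" column holds chat messages,\n#   config = {\"prompt_key\": \"prompt\", \"max_prompt_length\": 1024, \"filter_overlong_prompts\": True}\n#   ds = RLHFDataset(data_files=[\"train.parquet\"], tokenizer=tok, config=config)\n#   sample = ds[0]  # dict with input_ids / attention_mask / position_ids / raw_prompt_ids\n"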
  },
  {
    "path": "verl_distillation/verl/utils/dataset/rm_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import Optional\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom verl.utils import hf_tokenizer\n\n\ndef download_files_distributed(download_fn):\n    import torch.distributed\n\n    if torch.distributed.is_initialized():\n        if torch.distributed.get_rank() == 0:\n            # download files\n            download_fn()\n\n        torch.distributed.barrier()\n    else:\n        # download anyway\n        download_fn()\n\n\nclass RMDataset(Dataset):\n    def __init__(\n        self,\n        parquet_files: str | list[str],\n        tokenizer,\n        prompt_key=\"prompt\",\n        chosen_key=\"chosen\",\n        rejected_key=\"rejected\",\n        max_length=1024,\n        add_eos=True,\n        cache_dir=\"~/.cache/verl/rm\",\n        max_samples: int = -1,\n        shuffle: bool = False,\n        seed: Optional[int] = None,\n    ):\n        if not isinstance(parquet_files, list):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        self.max_samples = max_samples\n        self.shuffle = shuffle\n        self.seed = seed\n        self.cache_dir = os.path.expanduser(cache_dir)\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer = tokenizer\n\n        self.prompt_key = prompt_key\n        self.chosen_key = chosen_key\n        self.rejected_key = rejected_key\n\n        self.add_eos = add_eos\n        self.max_length = max_length\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self):\n        def _download_files():\n            from verl.utils.fs import copy, is_non_local\n\n            os.makedirs(self.cache_dir, exist_ok=True)\n            assert os.path.exists(self.cache_dir)\n            for i, parquet_file in enumerate(self.parquet_files):\n                if is_non_local(parquet_file):\n                    dst = os.path.join(self.cache_dir, os.path.basename(parquet_file))\n                    if not os.path.exists(dst):\n                        copy(src=parquet_file, dst=dst)\n                    self.parquet_files[i] = dst\n\n        download_files_distributed(_download_files)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            # read parquet files and cache\n            dataframe = pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n\n        total = len(self.dataframe)\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n        if self.max_samples > 0 and self.max_samples < total:\n            if self.shuffle:\n                rngs_args = (self.seed,) if self.seed is not None else ()\n                rng = np.random.default_rng(*rngs_args)\n                indices = 
rng.choice(total, size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n            self.dataframe = self.dataframe.iloc[indices.tolist()]\n            print(f\"selected {self.max_samples} samples out of {total}\")\n\n        self.prompts = self.dataframe[self.prompt_key].tolist()\n        self.chosen_responses = self.dataframe[self.chosen_key].tolist()\n        self.rejected_responses = self.dataframe[self.rejected_key].tolist()\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def _pad_to_length(self, input_ids, attention_mask):\n        curr_length = input_ids.shape[-1]\n\n        if curr_length < self.max_length:\n            input_ids = torch.cat(\n                (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1\n            )\n            attention_mask = torch.cat(\n                (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1\n            )\n        elif curr_length > self.max_length:\n            input_ids = input_ids[: self.max_length]\n            attention_mask = attention_mask[: self.max_length]\n\n        return input_ids, attention_mask\n\n    def __getitem__(self, item):\n        prompt = self.prompts[item]\n        chosen_response = self.chosen_responses[item]\n        rejected_response = self.rejected_responses[item]\n\n        prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n        chosen_response_ids = self.tokenizer(chosen_response, return_tensors=\"pt\")[\"input_ids\"][0]\n        rejected_response_ids = self.tokenizer(rejected_response, return_tensors=\"pt\")[\"input_ids\"][0]\n\n        if self.add_eos:\n            chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1)\n            rejected_response_ids = torch.cat(\n                (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1\n            )\n\n        chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1)\n        chosen_attention_mask = torch.ones_like(chosen_input_ids)\n\n        rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1)\n        rejected_attention_mask = torch.ones_like(rejected_input_ids)\n\n        chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask)\n        rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask)\n\n        input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0)\n        attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0)\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n        }\n
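\n\n# Usage sketch (illustrative; paths and the model call are placeholders): each item\n# stacks chosen and rejected along dim 0, so both can be scored in one forward pass:\n#   ds = RMDataset(parquet_files=\"prefs.parquet\", tokenizer=\"Qwen/Qwen3-1.7B\")\n#   item = ds[0]  # input_ids: (2, max_length) -> [chosen, rejected]\n#   scores = reward_model(item[\"input_ids\"], item[\"attention_mask\"])\n"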
  },
  {
    "path": "verl_distillation/verl/utils/dataset/sft_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSFT dataset\n- We assume user pass a single parquet file.\n- We load all the data into the memory.\nEach parquet file contains\n\"\"\"\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom omegaconf.listconfig import ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.model import compute_position_id_with_mask\n\n\nclass SFTDataset(Dataset):\n    \"\"\"\n    This is an in-memory SFTDataset\n\n    Arguments:\n        config (OmegaConf): the data config\n    \"\"\"\n\n    def __init__(self, parquet_files: str | ListConfig, tokenizer, config, max_samples: int = -1):\n        prompt_key = config.get(\"prompt_key\", \"prompt\")\n        prompt_dict_keys = config.get(\"prompt_dict_keys\", None)\n        response_key = config.get(\"response_key\", \"response\")\n        response_dict_keys = config.get(\"response_dict_keys\", None)\n        max_length = config.get(\"max_length\", 1024)\n        truncation = config.get(\"truncation\", \"error\")\n        use_shm = config.get(\"use_shm\", False)\n        self.shuffle = config.get(\"shuffle\", False)\n        self.seed = config.get(\"seed\")\n        self.apply_chat_template_kwargs = config.get(\"apply_chat_template_kwargs\", {})\n\n        assert truncation in [\"error\", \"left\", \"right\"]\n        self.truncation = truncation\n        self.use_shm = use_shm\n\n        if not isinstance(parquet_files, ListConfig):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        self.max_samples = max_samples\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n\n        self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key]\n        self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key]\n        self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else []\n        self.response_dict_keys = response_dict_keys if response_dict_keys else []\n\n        self.max_length = max_length\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self):\n        for i, parquet_file in enumerate(self.parquet_files):\n            self.parquet_files[i] = copy_to_local(parquet_file, verbose=True, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        def series_to_item(ls):\n            import numpy\n            import pandas\n\n            while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:\n                ls = ls[0]\n            return ls\n\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            # read parquet files and cache\n            dataframe = 
pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n\n        total = len(self.dataframe)\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n        if self.max_samples > 0 and self.max_samples < total:\n            if self.shuffle:\n                rngs_args = (self.seed,) if self.seed is not None else ()\n                rng = np.random.default_rng(*rngs_args)\n                indices = rng.choice(total, size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n            self.dataframe = self.dataframe.iloc[indices.tolist()]\n            print(f\"selected {self.max_samples} samples out of {total}\")\n\n        self.prompts = self.dataframe[self.prompt_key]\n        for key in self.prompt_dict_keys:\n            # type(x): pandas.core.series.Series\n            # type(x[0]): numpy.ndarray\n            # type(x[0][0]): dict\n            try:\n                self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1)  # noqa: B023\n            except Exception:\n                print(f\"self.prompts={self.prompts}\")\n                raise\n        if isinstance(self.prompts, pd.DataFrame):\n            self.prompts = self.prompts.squeeze()\n        self.prompts = self.prompts.tolist()\n        self.responses = self.dataframe[self.response_key]\n        for key in self.response_dict_keys:\n            try:\n                self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1)  # noqa: B023\n            except Exception:\n                print(f\"self.responses={self.responses}\")\n                raise\n        if isinstance(self.responses, pd.DataFrame):\n            self.responses = self.responses.squeeze()\n        self.responses = self.responses.tolist()\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, item):\n        tokenizer = self.tokenizer\n\n        prompt = self.prompts[item]\n        response = self.responses[item]\n\n        # apply chat template\n        prompt_chat = [{\"role\": \"user\", \"content\": prompt}]\n\n        # string\n        prompt_chat_str = tokenizer.apply_chat_template(\n            prompt_chat, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs\n        )\n        response_chat_str = response + tokenizer.eos_token\n\n        # tokenize\n        prompt_ids_output = tokenizer(prompt_chat_str, return_tensors=\"pt\", add_special_tokens=False)\n        prompt_ids = prompt_ids_output[\"input_ids\"][0]\n        prompt_attention_mask = prompt_ids_output[\"attention_mask\"][0]\n\n        response_ids_output = tokenizer(response_chat_str, return_tensors=\"pt\", add_special_tokens=False)\n        response_ids = response_ids_output[\"input_ids\"][0]\n        response_attention_mask = response_ids_output[\"attention_mask\"][0]\n\n        prompt_length = prompt_ids.shape[0]\n        response_length = response_ids.shape[0]\n\n        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)\n        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)\n\n        # padding to max length\n        sequence_length = input_ids.shape[0]\n        if sequence_length < self.max_length:\n            padded_input_ids = (\n                torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype)\n                * self.tokenizer.pad_token_id\n            )\n            
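# pad on the right with pad_token_id; padded positions get attention_mask=0 and\n            # are excluded from the loss via the mask below\n            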
padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype)\n\n            input_ids = torch.cat((input_ids, padded_input_ids))\n            attention_mask = torch.cat((attention_mask, padded_attention_mask))\n        elif sequence_length > self.max_length:\n            if self.truncation == \"left\":\n                # left truncation drops prompt tokens and is usually not reasonable for SFT\n                input_ids = input_ids[-self.max_length :]\n                attention_mask = attention_mask[-self.max_length :]\n            elif self.truncation == \"right\":\n                input_ids = input_ids[: self.max_length]\n                attention_mask = attention_mask[: self.max_length]\n            elif self.truncation == \"error\":\n                raise NotImplementedError(f\"{sequence_length=} is larger than {self.max_length=}\")\n            else:\n                raise NotImplementedError(f\"Unknown truncation method {self.truncation}\")\n\n        position_ids = compute_position_id_with_mask(attention_mask)\n\n        loss_mask = attention_mask.clone()\n        if prompt_length > 1:\n            # mask out prompt for SFT.\n            loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0\n        # mask out the last token in response\n        loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n            \"loss_mask\": loss_mask,\n        }\n
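\n\n# Loss-mask sketch (illustrative): with prompt length P and response length R,\n# positions P-1 .. P+R-2 are kept, i.e. exactly the positions whose next-token\n# targets are the R response tokens (including the appended EOS); prompt and\n# padding positions contribute no loss.\n"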
  },
  {
    "path": "verl_distillation/verl/utils/dataset/vision_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom io import BytesIO\nfrom typing import Optional\n\nimport torch\nfrom PIL import Image\nfrom qwen_vl_utils import fetch_image, fetch_video\n\n\ndef process_image(image: dict | Image.Image, image_patch_size: int = 14) -> Image.Image:\n    if isinstance(image, Image.Image):\n        return image.convert(\"RGB\")\n\n    if \"bytes\" in image:\n        assert \"image\" not in image, \"Cannot have both `bytes` and `image`\"\n        image[\"image\"] = Image.open(BytesIO(image[\"bytes\"]))\n\n    return fetch_image(image, image_patch_size=image_patch_size)\n\n\nVIDEO_FORMAT_HELP = \"\"\"Currently, we only support the video formats introduced in qwen2-vl.\nRefer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.\n\neg.\n{\n    \"type\": \"video\",\n    \"video\": [\n        \"file:///path/to/frame1.jpg\",\n        \"file:///path/to/frame2.jpg\"\n    ]\n}\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\"\n}\n# Defaults to fps=2, min_frames=4, max_frames=768\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\",\n    \"fps\": 2,\n    \"min_frames\": 1,\n    \"max_frames\": 32\n}\n\"\"\"\n\n\ndef process_video(\n    video: dict,\n    image_patch_size: int = 14,\n    nframes: Optional[int] = None,\n    fps: Optional[float] = None,\n    fps_min_frames: Optional[int] = None,\n    fps_max_frames: Optional[int] = None,\n    return_video_sample_fps: bool = False,\n    return_video_metadata: bool = False,\n) -> torch.Tensor:\n    \"\"\"Converts a video dict into a [n_frames, 3, H, W] tensor\n\n    Add video sample FPS in a future MR\n    \"\"\"\n\n    if not isinstance(video, dict) or \"video\" not in video:\n        raise NotImplementedError(VIDEO_FORMAT_HELP)\n    assert nframes is None or fps is None, \"Can't use both `nframes` or `fps`\"\n\n    # Shallow copy... 
since we might want to add some keys\n    video = dict(video)\n\n    contains_sampling_rules = \"nframes\" in video or \"fps\" in video\n    if not contains_sampling_rules:\n        if nframes is not None:\n            video[\"nframes\"] = nframes\n        elif fps is not None:\n            video[\"fps\"] = fps\n            if fps_min_frames is not None:\n                video[\"min_frames\"] = fps_min_frames\n            if fps_max_frames is not None:\n                video[\"max_frames\"] = fps_max_frames\n\n    return fetch_video(\n        video,\n        image_patch_size=image_patch_size,\n        return_video_sample_fps=return_video_sample_fps,\n        return_video_metadata=return_video_metadata,\n    )\n\n\ndef process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs):\n    # Adjust image bounds based on left padding and cumulative sequence lengths\n    # This is necessary for MiniCPM-o's vision-language alignment\n    left_padding_length = torch.argmax(attention_mask, dim=1)\n    image_bounds = []\n    for i in range(len(multi_modal_inputs[\"image_bound\"])):\n        image_bound = (\n            multi_modal_inputs[\"image_bound\"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i]\n        )\n        image_bounds.append(image_bound)\n\n    # Flatten pixel values list for MiniCPM-o processing\n    pixel_values = []\n    for i in range(len(multi_modal_inputs[\"pixel_values\"])):\n        pixel_values.extend(multi_modal_inputs[\"pixel_values\"][i])\n\n    multi_modal_inputs[\"pixel_values\"] = [pixel_values]\n    multi_modal_inputs[\"image_bound\"] = [torch.vstack(image_bounds)]\n    multi_modal_inputs[\"tgt_sizes\"] = [torch.vstack(multi_modal_inputs[\"tgt_sizes\"])]\n    multi_modal_inputs[\"input_ids\"] = input_ids\n    multi_modal_inputs[\"attention_mask\"] = attention_mask\n    multi_modal_inputs[\"position_ids\"] = position_ids\n    return {\"data\": multi_modal_inputs}\n
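\n\n# Usage sketch (illustrative; the path is a placeholder): sampling rules inside the\n# dict take precedence over the keyword arguments:\n#   frames, meta = process_video(\n#       {\"type\": \"video\", \"video\": \"file:///path/to/video.mp4\", \"fps\": 2},\n#       return_video_metadata=True,\n#   )  # frames: [n_frames, 3, H, W]\n"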
  },
  {
    "path": "verl_distillation/verl/utils/debug/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# APIs kept for backward compatibility purpose\n# For new features please develop in verl/utils/profiler/\nfrom ..profiler import *  # noqa: F401\n"
  },
  {
    "path": "verl_distillation/verl/utils/debug/metrics.py",
    "content": "# Copyright 2025 Individual Contributor: TomQunChaoA\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\n\nimport torch\n\nfrom verl.protocol import DataProto\n\nlogger = logging.getLogger(__file__)\n\n\ndef calculate_token_list_diff(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    # verify inputs\n    if tensor1.numel() == 0 or tensor2.numel() == 0:\n        return torch.zeros(tensor1.shape[0], dtype=torch.long, device=tensor1.device)\n    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape or mask.shape != tensor2.shape:\n        print(\n            f\"<WARN> dim of tensor1, tensor2, mask is not equal, {(tensor1.shape)=},{(tensor2.shape)=}, {(mask.shape)=}\"\n        )\n        return torch.ones_like(tensor1)\n    # transfer to same device\n    if tensor2.device != tensor1.device:\n        tensor2 = tensor2.to(tensor1.device)\n    if mask.device != tensor1.device:\n        mask = mask.to(tensor1.device)\n\n    # calculate diff\n    diff_mask = tensor1 != tensor2\n\n    valid_diff_mask = diff_mask & (mask == 1)\n\n    diff_counts = valid_diff_mask.sum(dim=1)\n\n    return diff_counts\n\n\ndef pearson_correlation_coefficient(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    # implemention of https://arxiv.org/pdf/2506.13585\n    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape or mask.shape != tensor2.shape:\n        return 0\n    mt1 = torch.masked_select(tensor1, mask)\n    mt2 = torch.masked_select(tensor2, mask)\n    result = torch.corrcoef(torch.stack([mt1, mt2], dim=0))\n    return result[0][1].detach().item()\n\n\ndef calculate_log_prob_diff(log_probs1: torch.Tensor, log_probs2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    full_diff = torch.abs(log_probs1 - log_probs2)\n    return torch.masked_select(full_diff, mask)\n\n\ndef calculate_debug_metrics(data: DataProto) -> dict:\n    \"\"\"\n    calculate rollout vs actor logprobs diff, for debugging purpose\n\n    Args:\n        data: DataProto\n            the data batch to calculate\n            rollout_log_probs: log_probs record when rollout forward tokens\n            old_log_probs(actor log probs): log_probs record when actor forward tokens\n            loss_mask or attention_mask: to mask unrelated token\n            responses: the response tokens, for calculating size\n    Returns:\n        dict: metrics\n            \"training/rollout_probs_diff_valid\": 1->input is valid, 0->input is invalid\n            \"training/rollout_probs_diff_max\": max value of logprob diff of rollout vs. actor\n            \"training/rollout_probs_diff_mean\": mean value of logprob diff of rollout vs. actor\n            \"training/rollout_probs_diff_std\": std value of logprob diff of rollout vs. actor\n            \"training/rollout_actor_probs_pearson_corr\": logprob's pearson corrcoef of rollout vs. 
actor, reference to https://arxiv.org/pdf/2506.13585\n    \"\"\"\n\n    rollout_old_log_probs = data.batch[\"rollout_log_probs\"]\n    actor_old_log_probs = data.batch[\"old_log_probs\"]\n    if \"response_mask\" in data.batch:\n        logger.debug(\"response mask found, using it to mask log probs\")\n        log_prob_mask = data.batch[\"response_mask\"]\n    elif \"attention_mask\" in data.batch:\n        log_prob_mask = data.batch[\"attention_mask\"]\n    else:\n        logger.warning(f\"no mask info found, using all log probs, {(data.batch.keys())=}\")\n        log_prob_mask = torch.ones_like(rollout_old_log_probs)\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n\n    response_mask = log_prob_mask[:, -response_length:]\n    # calculate pearson corrcoef\n    actor_probs = torch.exp(actor_old_log_probs)\n    rollout_probs = torch.exp(rollout_old_log_probs)\n    response_mask_bool = response_mask.bool()\n    pearson_corrcoef = pearson_correlation_coefficient(actor_probs, rollout_probs, response_mask_bool)\n    rollout_probs_diff = calculate_log_prob_diff(actor_probs, rollout_probs, response_mask_bool)\n    return {\n        \"training/rollout_probs_diff_valid\": 1,\n        \"training/rollout_probs_diff_max\": torch.max(rollout_probs_diff).detach().item(),\n        \"training/rollout_probs_diff_mean\": torch.mean(rollout_probs_diff).detach().item(),\n        \"training/rollout_probs_diff_std\": torch.std(rollout_probs_diff).detach().item(),\n        \"training/rollout_actor_probs_pearson_corr\": pearson_corrcoef,\n    }\n
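\n\n# Note: despite the helper's name, rollout_probs_diff above is computed on\n# probabilities (exp of the log-probs), so the rollout_probs_diff_* metrics are\n# absolute probability gaps in [0, 1].\n"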
  },
  {
    "path": "verl_distillation/verl/utils/debug/performance.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# APIs kept for backward compatibility purpose\n# This file is deprecated, for new features please develop in profiler/performance.py\nfrom verl.utils.profiler.performance import reduce_timing, simple_timer  # noqa: F401\n"
  },
  {
    "path": "verl_distillation/verl/utils/debug/trajectory_tracker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTrajectory tracker can be inserted into code to save the intermediate results.\nThe results will be dump to hdfs for offline comparison.\nEach process will have a client that first move all the tensors to CPU\n\"\"\"\n\nimport io\nimport os\nimport tempfile\nfrom collections import deque\n\nimport ray\nimport torch\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nremote_copy = ray.remote(copy)\n\n\n@ray.remote\ndef save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):\n    filename = name + \".pth\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        local_filepath = os.path.join(tmpdirname, filename)\n        with open(local_filepath, \"wb\") as f:\n            f.write(data.getbuffer())\n        # upload to hdfs\n\n        if verbose:\n            print(f\"Saving {local_filepath} to {hdfs_dir}\")\n        try:\n            copy(local_filepath, hdfs_dir)\n        except Exception as e:\n            print(e)\n\n\n@ray.remote\nclass TrajectoryTracker:\n    def __init__(self, hdfs_dir, verbose) -> None:\n        self.hdfs_dir = hdfs_dir\n        makedirs(hdfs_dir)\n        self.verbose = verbose\n\n        self.handle = deque()\n\n    def dump(self, data: io.BytesIO, name):\n        # get a temp file and write to it\n        self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose))\n\n    def wait_for_hdfs(self):\n        while len(self.handle) != 0:\n            future = self.handle.popleft()\n            ray.get(future)\n\n\ndef dump_data(data, name):\n    enable = os.getenv(\"VERL_ENABLE_TRACKER\", \"0\") == \"1\"\n    if not enable:\n        return\n    buffer = io.BytesIO()\n    torch.save(data, buffer)\n    tracker = get_trajectory_tracker()\n    ray.get(tracker.dump.remote(buffer, name))\n\n\ndef get_trajectory_tracker():\n    hdfs_dir = os.getenv(\"VERL_TRACKER_HDFS_DIR\", default=None)\n    verbose = os.getenv(\"VERL_TRACKER_VERBOSE\", default=\"0\") == \"1\"\n    assert hdfs_dir is not None\n    tracker = TrajectoryTracker.options(name=\"global_tracker\", get_if_exists=True, lifetime=\"detached\").remote(\n        hdfs_dir, verbose\n    )\n    return tracker\n\n\nif __name__ == \"__main__\":\n    # testing\n    os.environ[\"VERL_ENABLE_TRACKER\"] = \"1\"\n    os.environ[\"VERL_TRACKER_HDFS_DIR\"] = \"~/debug/test\"\n\n    @ray.remote\n    def process(iter):\n        data = {\"obs\": torch.randn(10, 20)}\n        dump_data(data, f\"process_{iter}_obs\")\n\n    ray.init()\n\n    output_lst = []\n\n    for i in range(10):\n        output_lst.append(process.remote(i))\n\n    out = ray.get(output_lst)\n\n    tracker = get_trajectory_tracker()\n    ray.get(tracker.wait_for_hdfs.remote())\n"
  },
  {
    "path": "verl_distillation/verl/utils/device.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# This code is inspired by the torchtune.\n# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE\n\nimport logging\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_torch_npu_available() -> bool:\n    \"\"\"Check the availability of NPU\"\"\"\n    try:\n        if hasattr(torch, \"npu\") and callable(getattr(torch.npu, \"is_available\", None)):\n            return torch.npu.is_available()\n        return False\n    except ImportError:\n        return False\n\n\nis_cuda_available = torch.cuda.is_available()\nis_npu_available = is_torch_npu_available()\n\n\ndef get_visible_devices_keyword() -> str:\n    \"\"\"Function that gets visible devices keyword name.\n    Returns:\n        'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES`\n    \"\"\"\n    return \"CUDA_VISIBLE_DEVICES\" if is_cuda_available else \"ASCEND_RT_VISIBLE_DEVICES\"\n\n\ndef get_device_name() -> str:\n    \"\"\"Function that gets the torch.device based on the current machine.\n    This currently only supports CPU, CUDA, NPU.\n    Returns:\n        device\n    \"\"\"\n    if is_cuda_available:\n        device = \"cuda\"\n    elif is_npu_available:\n        device = \"npu\"\n    else:\n        device = \"cpu\"\n    return device\n\n\ndef get_torch_device() -> any:\n    \"\"\"Return the corresponding torch attribute based on the device type string.\n    Returns:\n        module: The corresponding torch device namespace, or torch.cuda if not found.\n    \"\"\"\n    device_name = get_device_name()\n    try:\n        return getattr(torch, device_name)\n    except AttributeError:\n        logger.warning(f\"Device namespace '{device_name}' not found in torch, try to load torch.cuda.\")\n        return torch.cuda\n\n\ndef get_device_id() -> int:\n    \"\"\"Return current device id based on the device type.\n    Returns:\n        device index\n    \"\"\"\n    return get_torch_device().current_device()\n\n\ndef get_nccl_backend() -> str:\n    \"\"\"Return nccl backend type based on the device type.\n    Returns:\n        nccl backend type string.\n    \"\"\"\n    if is_cuda_available:\n        return \"nccl\"\n    elif is_npu_available:\n        return \"hccl\"\n    else:\n        raise RuntimeError(f\"No available nccl backend found on device type {get_device_name()}.\")\n\n\ndef set_expandable_segments(enable: bool) -> None:\n    \"\"\"Enable or disable expandable segments for cuda.\n    Args:\n        enable (bool): Whether to enable expandable segments. Used to avoid OOM.\n    \"\"\"\n    if is_cuda_available:\n        torch.cuda.memory._set_allocator_settings(f\"expandable_segments:{enable}\")\n"
  },
  {
    "path": "verl_distillation/verl/utils/distributed.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Utilities for distributed training.\"\"\"\n\nimport ctypes\nimport os\nfrom datetime import timedelta\n\nimport ray\nimport torch.distributed\n\nfrom verl.utils.device import get_device_name, get_nccl_backend, get_torch_device, is_npu_available\n\n\ndef set_numa_affinity():\n    if is_npu_available:\n        # TODO (FightingZhen) libnuma.so is not available in e2e_ascend CI image, remove this code after image update.\n        return\n\n    initialized = False\n    try:\n        libnuma = ctypes.CDLL(\"libnuma.so\")\n        if libnuma.numa_available() < 0:\n            return\n\n        import pynvml\n\n        pynvml.nvmlInit()\n        initialized = True\n        device_name = \"NPU\" if is_npu_available else \"GPU\"\n        local_rank = int(ray.get_runtime_context().get_accelerator_ids()[device_name][0])\n        handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)\n        pynvml.nvmlDeviceSetCpuAffinity(handle)\n    except ImportError:\n        print(\"Warning: pynvml not available, skipping NUMA affinity setup\")\n    except Exception as e:\n        print(f\"Warning: Failed to set NUMA affinity: {e}\")\n    finally:\n        if initialized:\n            pynvml.nvmlShutdown()\n\n\ndef initialize_global_process_group(timeout_second=36000):\n    torch.distributed.init_process_group(\n        get_nccl_backend(),\n        timeout=timedelta(seconds=timeout_second),\n        init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n    )\n    local_rank = int(os.environ[\"LOCAL_RANK\"])\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n\n    if torch.distributed.is_initialized():\n        get_torch_device().set_device(local_rank)\n    return local_rank, rank, world_size\n\n\ndef destroy_global_process_group():\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\ndef initialize_global_process_group_ray(timeout_second=None):\n    # in current ray environment, LOCAL_RANK is always zero.\n\n    import torch.distributed\n\n    timeout = timedelta(seconds=timeout_second) if timeout_second is not None else None\n\n    if not torch.distributed.is_initialized():\n        rank = int(os.environ.get(\"RANK\", 0))\n        world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n        torch.distributed.init_process_group(\n            backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\",\n            rank=rank,\n            world_size=world_size,\n            timeout=timeout,\n            init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n        )\n"
  },
  {
    "path": "verl_distillation/verl/utils/experimental/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/utils/experimental/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\n\ndef _fused_linear_for_ppo_fwd(\n    hidden_states: torch.FloatTensor,\n    vocab_weights: torch.FloatTensor,\n    input_ids: torch.LongTensor,\n    temperature: float = 1.0,\n) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    # Slower but more numerically stable to do log_softmax than probs.log()\n    probs = logits.softmax(dim=-1)\n    log_probs = logits.log_softmax(dim=-1)\n\n    token_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n\n    return token_log_probs.to(orig_dtype), entropy.to(orig_dtype)\n\n\ndef _fused_linear_for_ppo_bwd(\n    dlog_probs: Optional[torch.FloatTensor],\n    dentropy: Optional[torch.FloatTensor],\n    hidden_states: torch.FloatTensor,\n    vocab_weights: torch.FloatTensor,\n    input_ids: torch.LongTensor,\n    temperature: float = 1.0,\n) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    probs = logits.softmax(dim=-1)\n\n    dlogits = 0\n\n    # Gradient from log_probs\n    if dlog_probs is not None:\n        one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1)\n        dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs)\n\n    # Gradient from entropy\n    if dentropy is not None:\n        log_probs = logits.log_softmax(dim=-1)\n        entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n        dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1))\n\n    dlogits = dlogits.to(orig_dtype) / temperature\n\n    dhidden_states = dlogits @ vocab_weights\n    dvocab_weights = dlogits.t() @ hidden_states\n\n    return dhidden_states, dvocab_weights\n\n\nclass FusedLinearForPPOFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n        chunk_size: int = 512,\n    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n        ctx.set_materialize_grads(False)\n\n        # Cast to a 2D tensor of the shape [T, D] for ease of working\n        orig_ndim = hidden_states.ndim\n        assert orig_ndim in (2, 3), f\"Invalid hidden_states shape, received {hidden_states.shape}\"\n\n        orig_batch_size = -1\n        if orig_ndim == 3:\n            assert input_ids.ndim == 2, f\"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}\"\n            orig_batch_size = hidden_states.shape[0]\n            hidden_states = 
hidden_states.flatten(0, 1)\n            input_ids = input_ids.flatten(0, 1)\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad\n        log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n        entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n\n        # Perform forward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n\n            chunk_log_probs, chunk_entropy = _fused_linear_for_ppo_fwd(\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n            log_probs[chunk_start:chunk_end] = chunk_log_probs\n            entropy[chunk_start:chunk_end] = chunk_entropy\n\n        # Cast the output back to the original input dimension\n        if orig_ndim == 3:\n            log_probs = log_probs.view(orig_batch_size, -1)\n            entropy = entropy.view(orig_batch_size, -1)\n\n        ctx.save_for_backward(hidden_states, vocab_weights, input_ids)\n        ctx.orig_batch_size = orig_batch_size\n        ctx.orig_ndim = orig_ndim\n        ctx.temperature = temperature\n        ctx.chunk_size = chunk_size\n\n        return log_probs, entropy\n\n    @staticmethod\n    def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]):\n        assert dlog_probs is not None or dentropy is not None\n\n        hidden_states, vocab_weights, input_ids = ctx.saved_tensors\n        orig_batch_size = ctx.orig_batch_size\n        orig_ndim = ctx.orig_ndim\n        temperature = ctx.temperature\n        chunk_size = ctx.chunk_size\n\n        # Here orig_ndim refers to the orig_ndim of hidden_states\n        if orig_ndim == 3:\n            if dlog_probs is not None:\n                dlog_probs = dlog_probs.flatten()\n            if dentropy is not None:\n                dentropy = dentropy.flatten()\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        dhidden_states = None\n        if hidden_states.requires_grad:\n            dhidden_states = torch.zeros_like(hidden_states)\n        dvocab_weights = None\n        if vocab_weights.requires_grad:\n            dvocab_weights = torch.zeros_like(vocab_weights)\n\n        # Perform backward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n            chunk_dlog_probs = None\n            if dlog_probs is not None:\n                chunk_dlog_probs = dlog_probs[chunk_start:chunk_end]\n            chunk_dentropy = None\n            if dentropy is not None:\n                chunk_dentropy = dentropy[chunk_start:chunk_end]\n\n            h, v = _fused_linear_for_ppo_bwd(\n                dlog_probs=chunk_dlog_probs,\n                dentropy=chunk_dentropy,\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n\n            if hidden_states.requires_grad:\n                dhidden_states[chunk_start:chunk_end] += h\n            if vocab_weights.requires_grad:\n                dvocab_weights += v\n\n        # Cast the 
output back to the original input dimension\n        if orig_ndim == 3 and hidden_states.requires_grad:\n            hidden_size = hidden_states.shape[-1]\n            dhidden_states = dhidden_states.view(orig_batch_size, -1, hidden_size)\n\n        return (\n            dhidden_states,  # hidden_states\n            dvocab_weights,  # vocab_weights\n            None,  # input_ids\n            None,  # temperature\n            None,  # chunk_size\n        )\n\n\nclass FusedLinearForPPO(torch.nn.Module):\n    def __init__(self, chunk_size: int = 512):\n        super().__init__()\n\n        self.chunk_size = chunk_size\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n        input_ids = input_ids.to(torch.int64)\n        return FusedLinearForPPOFunction.apply(\n            hidden_states,\n            vocab_weights,\n            input_ids,\n            temperature,\n            self.chunk_size,\n        )\n"
  },
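  {
    "path": "verl_distillation/examples/fused_linear_for_ppo_example.py",
    "content": "# Illustrative usage sketch for FusedLinearForPPO; the import path, shapes,\n# and values below are assumptions -- adjust them to your setup.\nimport torch\n\nfrom verl.utils.fused_linear_for_ppo import FusedLinearForPPO\n\n\ndef main():\n    torch.manual_seed(0)\n    batch_size, seq_len, hidden_dim, vocab_size = 2, 8, 16, 32\n    hidden_states = torch.randn(batch_size, seq_len, hidden_dim, requires_grad=True)\n    vocab_weights = torch.randn(vocab_size, hidden_dim, requires_grad=True)\n    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))\n\n    # Chunked fused computation of per-token log-probs and entropy, so only a\n    # [chunk_size, V] slice of the logits is materialized at a time.\n    fused = FusedLinearForPPO(chunk_size=4)\n    log_probs, entropy = fused(hidden_states, vocab_weights, input_ids, temperature=1.0)\n\n    # Unchunked reference for a sanity check.\n    logits = (hidden_states @ vocab_weights.t()).to(torch.float32)\n    ref_log_probs = logits.log_softmax(-1).gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)\n    probs = logits.softmax(-1)\n    ref_entropy = torch.logsumexp(logits, -1) - (probs * logits).sum(-1)\n    assert torch.allclose(log_probs, ref_log_probs, atol=1e-5)\n    assert torch.allclose(entropy, ref_entropy, atol=1e-5)\n\n    # Both outputs are differentiable w.r.t. hidden_states and vocab_weights.\n    (log_probs.sum() + entropy.sum()).backward()\n    print('dhidden_states norm:', hidden_states.grad.norm().item())\n\n\nif __name__ == '__main__':\n    main()\n"
  },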
  {
    "path": "verl_distillation/verl/utils/flops_counter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom transformers import PretrainedConfig\n\nfrom verl.utils.device import get_torch_device\n\nVALID_CONFIG_TYPE = {\n    \"llama\",\n    \"qwen2\",\n    \"qwen2_moe\",\n    \"qwen2_vl\",\n    \"qwen2_5_vl\",\n    \"qwen3\",\n    \"qwen3_moe\",\n    \"qwen3_vl\",\n    \"qwen3_vl_moe\",\n    \"deepseek_v3\",\n    \"minicpmv\",\n    \"minicpmo\",\n    \"mistral\",\n    \"gemma3_text\",\n    \"seed_oss\",\n    \"apertus\",\n    \"glm4v\",\n}\n\n\ndef get_device_flops(unit=\"T\"):\n    \"\"\"Get the theoretical FLOPS (Floating Point Operations Per Second) capacity of the current device.\n\n    Args:\n        unit (str): The unit to return the FLOPS in. Supported values are:\n            \"B\" - Billion (1e9)\n            \"K\" - Thousand (1e3)\n            \"M\" - Million (1e6)\n            \"G\" - Giga (1e9)\n            \"T\" - Tera (1e12, default)\n            \"P\" - Peta (1e15)\n\n    Returns:\n        float: The theoretical FLOPS capacity of the current device in the specified unit.\n        Returns float('inf') for unknown GPU types.\n    \"\"\"\n\n    def unit_convert(number, level):\n        units = [\"B\", \"K\", \"M\", \"G\", \"T\", \"P\"]\n        if number <= 0:\n            return number\n        ptr = 0\n        while ptr < len(units) and units[ptr] != level:\n            number /= 1000\n            ptr += 1\n        return number\n\n    device = get_torch_device()\n    if device == torch.cpu:\n        device_name = \"CPU\"\n    else:\n        device_name = get_torch_device().get_device_name()\n    flops = float(\"inf\")  # INF flops for unkown gpu type\n\n    if \"CPU\" in device_name:\n        # use a general CPU flops placeholder to make the function CPU compatible\n        flops = 448e9\n    elif \"GB200\" in device_name:\n        flops = 2.5e15\n    elif \"B200\" in device_name:\n        flops = 2.25e15\n    elif \"MI300X\" in device_name:\n        flops = 1336e12\n    elif \"H100\" in device_name or \"H800\" in device_name or \"H200\" in device_name:\n        flops = 989e12\n    elif \"A100\" in device_name or \"A800\" in device_name:\n        flops = 312e12\n    elif \"L40S\" in device_name:\n        flops = 362.05e12\n    elif \"L40\" in device_name:\n        flops = 181.05e12\n    elif \"A40\" in device_name:\n        flops = 149.7e12\n    elif \"L20\" in device_name:\n        flops = 119.5e12\n    elif \"H20\" in device_name:\n        flops = 148e12\n    elif \"910B\" in device_name:\n        flops = 354e12\n    elif \"Ascend910\" in device_name:\n        flops = 354e12\n    elif \"RTX 3070 Ti\" in device_name:\n        flops = 21.75e12\n    flops_unit = unit_convert(flops, unit)\n    return flops_unit\n\n\nclass FlopsCounter:\n    \"\"\"\n    Used to count mfu during training loop\n\n    Example:\n        flops_counter = FlopsCounter(config)\n        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, 
delta_time)\n\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig):\n        if config.model_type not in VALID_CONFIG_TYPE:\n            print(\n                f\"Only config types in {VALID_CONFIG_TYPE} are supported, but got {config.model_type}. MFU will \"\n                f\"always be zero.\"\n            )\n\n        self.estimate_func = {\n            \"qwen2\": self._estimate_qwen2_flops,\n            \"llama\": self._estimate_qwen2_flops,\n            \"qwen2_moe\": self._estimate_qwen2_moe_flops,\n            \"qwen2_vl\": self._estimate_qwen2_flops,\n            \"qwen2_5_vl\": self._estimate_qwen2_flops,\n            \"qwen3\": self._estimate_qwen2_flops,\n            \"qwen3_moe\": self._estimate_qwen2_moe_flops,\n            \"qwen3_vl\": self._estimate_qwen2_flops,\n            \"qwen3_vl_moe\": self._estimate_qwen2_moe_flops,\n            \"deepseek_v3\": self._estimate_deepseek_v3_flops,\n            \"minicpmv\": self._estimate_qwen2_flops,\n            \"minicpmo\": self._estimate_qwen2_flops,\n            \"mistral\": self._estimate_qwen2_flops,\n            \"gemma3_text\": self._estimate_gemma3_flops,\n            \"seed_oss\": self._estimate_qwen2_flops,\n            \"apertus\": self._estimate_apertus_flops,\n            \"glm4v\": self._estimate_qwen2_flops,\n        }\n        self.config = getattr(config, \"text_config\", config)\n\n    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):\n        return 0\n\n    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer params\n        # Qwen2/LLaMA use SwiGLU, with gate, up, and down linear layers in the MLP\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer params\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        moe_intermediate_size = self.config.moe_intermediate_size\n        num_hidden_layers = self.config.num_hidden_layers\n        first_k_dense_replace = 
self.config.first_k_dense_replace\n        num_query_heads = self.config.num_attention_heads\n        moe_num_expert = self.config.n_routed_experts\n\n        moe_topk = self.config.num_experts_per_tok\n        share_expert_num = self.config.n_shared_experts\n\n        # non-attn per layer params\n        moe_gate_N = hidden_size * moe_num_expert\n        # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts\n        moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3\n        # MLA attn\n        attn_linear_N = 0\n        q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim\n        if self.config.q_lora_rank is None:\n            attn_linear_N += hidden_size * num_query_heads * q_head_dim\n        else:\n            attn_linear_N += hidden_size * self.config.q_lora_rank\n            attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank\n\n        attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim)\n        attn_linear_N += (\n            num_query_heads\n            * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim)\n            * self.config.kv_lora_rank\n        )\n        attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer params\n        moe_N = (\n            (moe_gate_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace)\n            + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace\n            + emd_and_lm_head_N\n        )\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * moe_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen * num_hidden_layers\n\n        attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n\n        return flops_achieved\n\n    def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        moe_intermediate_size = self.config.moe_intermediate_size\n        moe_topk = self.config.num_experts_per_tok\n        num_experts = self.config.num_experts\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer params\n        # gate + MoE experts\n        moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer params\n        dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & 
all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_gemma3_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer parm\n        # Gemma3 uses GeGLU (gelu_pytorch_tanh), having 3 matrices in MLP (inherited from Gemma2MLP)\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer parm\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        # Gemma3 alternates between full and sliding window attention based on layer_types\n        seqlen_square_sum = 0\n\n        layer_types = getattr(self.config, \"layer_types\", None)\n        sliding_window = getattr(self.config, \"sliding_window\", 1024)  # default 1024\n        # default pattern: every 6th layer is full\n        sliding_window_pattern = getattr(self.config, \"sliding_window_pattern\", 6)\n\n        # If layer_types is not provided, generate it based on sliding_window_pattern\n        if layer_types is None and sliding_window is not None and sliding_window_pattern is not None:\n            layer_types = [\n                \"sliding_attention\" if bool((i + 1) % sliding_window_pattern) else \"full_attention\"\n                for i in range(num_hidden_layers)\n            ]\n\n        if layer_types:\n            # Calculate attention flops per layer based on attention type\n            for layer_idx in range(num_hidden_layers):\n                is_sliding = False\n                if layer_types and layer_idx < len(layer_types):\n                    is_sliding = layer_types[layer_idx] == \"sliding_attention\"\n\n                for seqlen in batch_seqlens:\n                    if is_sliding and sliding_window:\n                        # Sliding window limits each token to attend to at most window_size tokens\n                        effective_seqlen = min(seqlen, sliding_window)\n                        seqlen_square_sum += seqlen * effective_seqlen\n                    else:\n                        # Full attention\n                        seqlen_square_sum += seqlen * seqlen\n        else:\n            # If no layer_types 
config, assume all layers use full attention\n            for seqlen in batch_seqlens:\n                seqlen_square_sum += seqlen * seqlen\n            seqlen_square_sum *= num_hidden_layers\n\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_apertus_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # Apertus MLP with XIELU activation uses only 2 linear layers (up_proj, down_proj)\n        # No gate_proj for XIELU, unlike SwiGLU which has 3 layers\n        mlp_N = hidden_size * intermediate_size * 2\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n\n        # ApertusConfig has qk_norm defaulting to True.\n        # This adds params for q_norm (on H) and k_norm (on num_kv_heads * head_dim)\n        qk_norm_params_per_layer = hidden_size + num_key_value_heads * head_dim  # q_norm + k_norm\n\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer params\n        dense_N = (mlp_N + attn_linear_N + qk_norm_params_per_layer) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def estimate_flops(self, batch_seqlens, delta_time):\n        \"\"\"\n        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.\n\n        Args:\n            batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the\n                current batch.\n            delta_time (float): The time taken to process the batch, in seconds.\n\n        Returns:\n            estimated_flops (float): The estimated FLOPS based on the input tokens and time.\n            promised_flops (float): The expected FLOPS of the current device.\n        \"\"\"\n        tokens_sum = sum(batch_seqlens)\n        func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)\n        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)\n        promised_flops = get_device_flops()\n        return estimated_flops, promised_flops\n"
  },
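  {
    "path": "verl_distillation/examples/flops_counter_example.py",
    "content": "# Illustrative MFU-logging sketch for FlopsCounter; the model id, sequence\n# lengths, and timing below are placeholders rather than a prescribed recipe.\nimport time\n\nfrom transformers import AutoConfig\n\nfrom verl.utils.flops_counter import FlopsCounter\n\n\ndef main():\n    # Any config whose model_type is in VALID_CONFIG_TYPE works the same way.\n    config = AutoConfig.from_pretrained('Qwen/Qwen3-1.7B')\n    counter = FlopsCounter(config)\n\n    batch_seqlens = [512, 1024, 2048]  # valid (non-padding) tokens per sequence\n    start = time.perf_counter()\n    # ... one forward + backward step over this batch would run here ...\n    delta_time = max(time.perf_counter() - start, 1e-6)\n\n    # Both values are in TFLOPS; their ratio is the step's MFU.\n    achieved, promised = counter.estimate_flops(batch_seqlens, delta_time)\n    print('achieved TFLOPS:', achieved, 'promised TFLOPS:', promised)\n    if promised != float('inf'):\n        print('MFU:', achieved / promised)\n\n\nif __name__ == '__main__':\n    main()\n"
  },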
  {
    "path": "verl_distillation/verl/utils/fs.py",
    "content": "#!/usr/bin/env python\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# -*- coding: utf-8 -*-\n\"\"\"File-system agnostic IO APIs\"\"\"\n\nimport hashlib\nimport os\nimport shutil\nimport tempfile\n\ntry:\n    from hdfs_io import copy, exists, makedirs  # for internal use only\nexcept ImportError:\n    from .hdfs_io import copy, exists, makedirs\n\n__all__ = [\"copy\", \"exists\", \"makedirs\"]\n\n_HDFS_PREFIX = \"hdfs://\"\n\n\ndef is_non_local(path):\n    \"\"\"Check if a path is a non-local (HDFS) path.\n\n    Args:\n        path (str): The path to check.\n\n    Returns:\n        bool: True if the path is an HDFS path, False otherwise.\n    \"\"\"\n    return path.startswith(_HDFS_PREFIX)\n\n\ndef md5_encode(path: str) -> str:\n    \"\"\"Generate an MD5 hash of a path string.\n\n    This function is used to create unique identifiers for paths, typically\n    for creating cache directories or lock files.\n\n    Args:\n        path (str): The path to encode.\n\n    Returns:\n        str: The hexadecimal MD5 hash of the path.\n    \"\"\"\n    return hashlib.md5(path.encode()).hexdigest()\n\n\ndef get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:\n    \"\"\"Generate a unique local cache path for an HDFS resource.\n    Creates a MD5-hashed subdirectory in cache_dir to avoid name conflicts,\n    then returns path combining this subdirectory with the HDFS basename.\n\n    Args:\n        hdfs_path (str): Source HDFS path to be cached\n        cache_dir (str): Local directory for storing cached files\n\n    Returns:\n        str: Absolute local filesystem path in format:\n            {cache_dir}/{md5(hdfs_path)}/{basename(hdfs_path)}\n    \"\"\"\n    # make a base64 encoding of hdfs_path to avoid directory conflict\n    encoded_hdfs_path = md5_encode(hdfs_path)\n    temp_dir = os.path.join(cache_dir, encoded_hdfs_path)\n    os.makedirs(temp_dir, exist_ok=True)\n    dst = os.path.join(temp_dir, os.path.basename(hdfs_path))\n    return dst\n\n\ndef verify_copy(src: str, dest: str) -> bool:\n    \"\"\"\n    verify the copy of src to dest by comparing their sizes and file structures.\n\n    return:\n        bool: True if the copy is verified, False otherwise.\n    \"\"\"\n    if not os.path.exists(src):\n        return False\n    if not os.path.exists(dest):\n        return False\n\n    if os.path.isfile(src) != os.path.isfile(dest):\n        return False\n\n    if os.path.isfile(src):\n        src_size = os.path.getsize(src)\n        dest_size = os.path.getsize(dest)\n        if src_size != dest_size:\n            return False\n        return True\n\n    src_files = set()\n    dest_files = set()\n\n    for root, dirs, files in os.walk(src):\n        rel_path = os.path.relpath(root, src)\n        dest_root = os.path.join(dest, rel_path) if rel_path != \".\" else dest\n\n        if not os.path.exists(dest_root):\n            return False\n\n        for entry in os.listdir(root):\n            src_entry = 
os.path.join(root, entry)\n            src_files.add(os.path.relpath(src_entry, src))\n\n        for entry in os.listdir(dest_root):\n            dest_entry = os.path.join(dest_root, entry)\n            dest_files.add(os.path.relpath(dest_entry, dest))\n\n    if src_files != dest_files:\n        return False\n\n    for rel_path in src_files:\n        src_entry = os.path.join(src, rel_path)\n        dest_entry = os.path.join(dest, rel_path)\n\n        if os.path.isdir(src_entry) != os.path.isdir(dest_entry):\n            return False\n\n        if os.path.isfile(src_entry):\n            src_size = os.path.getsize(src_entry)\n            dest_size = os.path.getsize(dest_entry)\n            if src_size != dest_size:\n                return False\n\n    return True\n\n\ndef copy_to_shm(src: str):\n    \"\"\"\n    Copy the model into /dev/shm so that loading it multiple times is more efficient.\n    \"\"\"\n    shm_model_root = \"/dev/shm/verl-cache/\"\n    src_abs = os.path.abspath(os.path.normpath(src))\n    dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode(\"utf-8\")).hexdigest())\n    os.makedirs(dest, exist_ok=True)\n    dest = os.path.join(dest, os.path.basename(src_abs))\n    if os.path.exists(dest) and verify_copy(src, dest):\n        # inform the user and let them decide\n        print(\n            f\"[WARNING]: The memory model path {dest} already exists. If it is not what you want, please clear it \"\n            f\"and restart the task.\"\n        )\n    else:\n        if os.path.isdir(src):\n            shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True)\n        else:\n            shutil.copy2(src, dest)\n    return dest\n\n\ndef _record_directory_structure(folder_path):\n    record_file = os.path.join(folder_path, \".directory_record.txt\")\n    with open(record_file, \"w\") as f:\n        for root, dirs, files in os.walk(folder_path):\n            for dir_name in dirs:\n                relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n                f.write(f\"dir:{relative_dir}\\n\")\n            for file_name in files:\n                if file_name != \".directory_record.txt\":\n                    relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                    f.write(f\"file:{relative_file}\\n\")\n    return record_file\n\n\ndef _check_directory_structure(folder_path, record_file):\n    if not os.path.exists(record_file):\n        return False\n    existing_entries = set()\n    for root, dirs, files in os.walk(folder_path):\n        for dir_name in dirs:\n            relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n            existing_entries.add(f\"dir:{relative_dir}\")\n        for file_name in files:\n            if file_name != \".directory_record.txt\":\n                relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                existing_entries.add(f\"file:{relative_file}\")\n    with open(record_file) as f:\n        recorded_entries = set(f.read().splitlines())\n    return existing_entries == recorded_entries\n\n\ndef copy_to_local(\n    src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False, use_shm: bool = False\n) -> str:\n    \"\"\"Copy files/directories from HDFS to local cache with validation.\n\n    Args:\n        src (str): Source path - HDFS path (hdfs://...), local filesystem path, or Hugging Face model ID\n        cache_dir (str, optional): Local directory for cached 
files. Uses system tempdir if None\n        filelock (str): Base name for file lock. Defaults to \".file.lock\"\n        verbose (bool): Enable copy operation logging. Defaults to False\n        always_recopy (bool): Force fresh copy ignoring cache. Defaults to False\n        use_shm (bool): Enable shared memory copy. Defaults to False\n\n    Returns:\n        str: Local filesystem path to copied resource\n    \"\"\"\n    # Save to a local path for persistence.\n    local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy)\n\n    if use_shm and isinstance(local_path, str) and not os.path.exists(local_path):\n        try:\n            from huggingface_hub import snapshot_download\n\n            resolved = snapshot_download(local_path)\n            if isinstance(resolved, str) and os.path.exists(resolved):\n                local_path = resolved\n        except ImportError:\n            pass\n        except Exception as e:\n            print(f\"WARNING: Failed to download model from Hugging Face: {e}\")\n\n    # Load into shm to improve efficiency.\n    if use_shm:\n        return copy_to_shm(local_path)\n    return local_path\n\n\ndef copy_local_path_from_hdfs(\n    src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False\n) -> str:\n    \"\"\"Deprecated. Please use copy_to_local instead.\"\"\"\n    from filelock import FileLock\n\n    assert src[-1] != \"/\", f\"Make sure the last char in src is not / because it will cause an error. Got {src}\"\n\n    if is_non_local(src):\n        # download from hdfs to local\n        if cache_dir is None:\n            # get a temp folder\n            cache_dir = tempfile.gettempdir()\n        os.makedirs(cache_dir, exist_ok=True)\n        assert os.path.exists(cache_dir)\n        local_path = get_local_temp_path(src, cache_dir)\n        # get a specific lock\n        filelock = md5_encode(src) + \".lock\"\n        lock_file = os.path.join(cache_dir, filelock)\n        with FileLock(lock_file=lock_file):\n            if always_recopy and os.path.exists(local_path):\n                if os.path.isdir(local_path):\n                    shutil.rmtree(local_path, ignore_errors=True)\n                else:\n                    os.remove(local_path)\n            if not os.path.exists(local_path):\n                if verbose:\n                    print(f\"Copy from {src} to {local_path}\")\n                copy(src, local_path)\n                if os.path.isdir(local_path):\n                    _record_directory_structure(local_path)\n            elif os.path.isdir(local_path):\n                # always_recopy=False, local path exists, and it is a folder: check whether there is anything missed\n                record_file = os.path.join(local_path, \".directory_record.txt\")\n                if not _check_directory_structure(local_path, record_file):\n                    if verbose:\n                        print(f\"Recopy from {src} to {local_path} due to missing files or directories.\")\n                    shutil.rmtree(local_path, ignore_errors=True)\n                    copy(src, local_path)\n                    _record_directory_structure(local_path)\n        return local_path\n    else:\n        return src\n\n\ndef local_mkdir_safe(path):\n    \"\"\"Thread-safe directory creation function that ensures the directory is created\n    even if multiple processes attempt to create it simultaneously.\n\n    Args:\n        path (str): The path to create a directory at.\n    \"\"\"\n\n    from 
filelock import FileLock\n\n    if not os.path.isabs(path):\n        working_dir = os.getcwd()\n        path = os.path.join(working_dir, path)\n\n    # Using hash value of path as lock file name to avoid long file name\n    lock_filename = f\"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock\"\n    lock_path = os.path.join(tempfile.gettempdir(), lock_filename)\n\n    try:\n        with FileLock(lock_path, timeout=60):  # Add timeout\n            # make a new dir\n            os.makedirs(path, exist_ok=True)\n    except Exception as e:\n        print(f\"Warning: Failed to acquire lock for {path}: {e}\")\n        # Even if the lock is not acquired, try to create the directory\n        os.makedirs(path, exist_ok=True)\n\n    return path\n"
  },
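  {
    "path": "verl_distillation/examples/fs_copy_to_local_example.py",
    "content": "# Illustrative sketch for the fs helpers; the HDFS path is a placeholder and\n# requires a working hdfs_io backend, while local paths pass through unchanged.\nfrom verl.utils.fs import copy_to_local, is_non_local\n\nsrc = 'hdfs://namenode/user/someone/checkpoints/global_step_100'  # placeholder\nprint(is_non_local(src))  # True only for hdfs:// paths\n\n# Downloads into an MD5-keyed cache under the system tempdir, guarded by a file\n# lock; on reuse it re-copies only if the recorded directory structure is stale.\nlocal_path = copy_to_local(src, verbose=True)\nprint(local_path)\n\n# Passing use_shm=True additionally stages the files under /dev/shm/verl-cache/\n# so repeated loads of the same checkpoint avoid disk reads:\n# local_path = copy_to_local(src, use_shm=True)\n"
  },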
  {
    "path": "verl_distillation/verl/utils/fsdp_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport itertools\nimport json\nimport math\nimport os\nfrom abc import ABC\nfrom collections import OrderedDict\nfrom contextlib import contextmanager, nullcontext\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom packaging import version\nfrom torch.distributed import DeviceMesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp._runtime_utils import _lazy_init\nfrom torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\nfrom transformers.trainer_pt_utils import get_module_class_from_name\n\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.model import check_exclude_modules, check_target_modules\n\nif version.parse(torch.__version__) >= version.parse(\"2.6\"):\n    from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\n    from torch.distributed.tensor import Shard\n\n    fully_shard_module = torch.distributed.fsdp._fully_shard._fully_shard\nelif version.parse(torch.__version__) >= version.parse(\"2.4\"):\n    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\n\n    fully_shard_module = torch.distributed._composable.fsdp.fully_shard\nelse:\n    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy, fully_shard_module = None, None, None, None, None\n\n\ndef init_fn(x: torch.nn.Module):\n    if torch.distributed.get_rank() != 0:\n        x = x.to_empty(device=get_device_id(), recurse=False)\n        get_torch_device().empty_cache()\n    return x\n\n\ndef get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):\n    from accelerate import init_empty_weights\n\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    if use_meta_tensor:\n        if mesh is None:\n            init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights\n        else:\n            init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights\n    else:\n        init_context = cpu_init_weights\n    return init_context\n\n\n# Copyright 2020-present the HuggingFace Inc. team.\n# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py\ndef get_fsdp_wrap_policy(module, config=None, is_lora=False):\n    \"\"\"Get FSDP wrap policy for the module.\n\n    Args:\n        module: The module to get wrap policy for\n        config: Configuration for wrap policy\n        is_lora: Whether to enable lambda policy for LoRA modules\n    \"\"\"\n    if config is None:\n        config = {}\n\n    # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. 
We will remove this\n    # once all configs in verl have been migrated from OmegaConf to dataclasses.\n    def _get_attr(attr_name, default_value=None):\n        if hasattr(config, \"get\"):\n            return config.get(attr_name, default_value)\n        else:\n            return config.__getattribute__(attr_name)\n\n    if _get_attr(\"disable\", False):\n        return None\n\n    default_transformer_cls_names_to_wrap = getattr(module, \"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = _get_attr(\n        \"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap\n    )\n    min_num_params = _get_attr(\"min_num_params\", 0)\n    auto_wrap_policy = None\n\n    policies = []\n\n    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy\n\n    # Add lambda policy for LoRA modules if is_lora is True\n    if is_lora:\n\n        def lambda_policy_fn(module):\n            return bool(\n                len(list(module.named_children())) == 0\n                and getattr(module, \"weight\", None) is not None\n                and module.weight.requires_grad\n            )\n\n        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)\n        policies.append(lambda_policy)\n\n    if min_num_params > 0:\n        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)\n        policies.append(size_policy)\n    elif fsdp_transformer_layer_cls_to_wrap is not None:\n        transformer_cls_to_wrap = set()\n        for layer_class in fsdp_transformer_layer_cls_to_wrap:\n            transformer_cls = get_module_class_from_name(module, layer_class)\n            if transformer_cls is None:\n                raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n            else:\n                transformer_cls_to_wrap.add(transformer_cls)\n\n        transformer_policy = functools.partial(\n            transformer_auto_wrap_policy,\n            transformer_layer_cls=transformer_cls_to_wrap,\n        )\n        policies.append(transformer_policy)\n\n    if len(policies) > 0:\n        auto_wrap_policy = functools.partial(_or_policy, policies=policies)\n\n    return auto_wrap_policy\n\n\n@torch.no_grad()\ndef offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):\n    if fsdp_version(model) == 2:\n        offload_fsdp2_model_to_cpu(model, empty_cache)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model offloading to CPU\"\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        assert (\n            flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()\n            and id(flat_param.data) != id(flat_param._local_shard)\n            and flat_param.data.size() == flat_param._local_shard.size()\n        )\n        handle.flat_param_to(torch.device(\"cpu\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n        assert id(flat_param._local_shard) != id(flat_param.data)\n    if empty_cache:\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):\n    model.cpu()\n    if empty_cache:\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_fsdp_model_to_gpu(model: 
FSDP):\n    if fsdp_version(model) == 2:\n        load_fsdp2_model_to_gpu(model)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model loading to GPU\"\n    device_id = get_device_id()\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        handle.flat_param_to(torch.device(f\"{get_device_name()}:{device_id}\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n\n\n@torch.no_grad()\ndef load_fsdp2_model_to_gpu(model):\n    device = get_device_id()\n    model.to(device)\n\n\n@torch.no_grad()\ndef offload_fsdp_optimizer(optimizer):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(\"cpu\", non_blocking=True)\n\n\n@torch.no_grad()\ndef load_fsdp_optimizer(optimizer, device_id):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(device_id, non_blocking=True)\n\n\n@contextmanager\ndef meta_device_init():\n    \"\"\"\n    Create model parameters with meta device.\n\n    Note buffers in model will still be initialized in default device (e.g., CPU),\n    since the buffers can be non-persistent and filled with expected values that can\n    NOT be captured in meta device.\n    \"\"\"\n    device = torch.device(\"meta\")\n    old_register_parameter = nn.Module.register_parameter\n    registered = set()\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        # we will skip register shared parameters as it\n        # is already registered previously\n        if param is not None and param not in registered:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n            registered.add(module._parameters[name])\n\n    try:\n        nn.Module.register_parameter = register_empty_parameter\n        yield\n    finally:\n        registered.clear()\n        nn.Module.register_parameter = old_register_parameter\n\n\ndef parallel_load_safetensors(filepath):\n    \"\"\"\n    Parallel load safetensors from huggingface checkpoint\n\n    Huggingface checkpoint contains:\n\n    - config.json: a json file for model configuration\n    - model.safetensors.index.json: a json file for safetensors (parameters & buffers) index\n    - model-xxxxx-of-xxxxx.safetensors: a binary file for safetensors (parameters & buffers) chunks\n\n    Or (when model is small),\n\n    - model.safetensors: a binary file for all parameters and buffers\n\n    Each rank will own a part of model chunks and load them directly into GPU memory.\n    \"\"\"\n    from safetensors.torch import load_file\n\n    safetensors2param = {}\n\n    index_file = 
os.path.join(filepath, \"model.safetensors.index.json\")\n    if os.path.exists(index_file):\n        index = json.load(open(index_file, \"rb\"))\n        for param_name, filename in index[\"weight_map\"].items():\n            safetensors2param.setdefault(filename, []).append(param_name)\n    else:\n        # in this case, the model is small and we can load it all at once\n        param_file = os.path.join(filepath, \"model.safetensors\")\n        assert os.path.exists(param_file), f\"Cannot find {param_file}\"\n        states = load_file(param_file)\n        for param_name in states:\n            safetensors2param.setdefault(\"model.safetensors\", []).append(param_name)\n        del states\n\n    total_files = len(safetensors2param)\n    ckpt_chunks = sorted(safetensors2param.keys())\n    world_size = dist.get_world_size()\n    size = int(math.ceil(total_files / world_size))\n    ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)]\n\n    shard_states = {}\n    device = get_device_id()\n    for rank, files in enumerate(ckpt_chunks):\n        if rank == dist.get_rank():\n            for file in files:\n                file = os.path.join(filepath, file)\n                states = load_file(file, device=device)\n                # print(f\"rank {rank} loading {file}...\")\n                shard_states.update(states)\n        else:\n            for file in files:\n                for param_name in safetensors2param[file]:\n                    shard_states[param_name] = rank\n    return shard_states\n\n\ndef parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]):\n    \"\"\"\n    Generate a function to initialize sub-modules in the `module` with `shard_states`\n    from huggingface checkpoint.\n\n    Args:\n        module (torch.nn.Module): the global module to be initialized\n        shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint\n\n    Returns:\n        init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`\n    \"\"\"\n\n    state2fqn = {}\n    for name, state in itertools.chain(\n        module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)\n    ):\n        state2fqn.setdefault(state, []).append(name)\n    # remove standalone parameters and buffers\n    shared = {s for s, names in state2fqn.items() if len(names) > 1}\n    materialized_states = {}\n\n    @torch.no_grad()\n    def create_and_sync_state(param_name, state, is_param):\n        assert param_name in shard_states, f\"{param_name} not loaded\"\n        device = get_device_id()\n        if is_param:\n            param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)\n        else:  # buffer\n            param = torch.empty_like(state.data, device=device)\n        loaded = shard_states[param_name]\n        if isinstance(loaded, torch.nn.Parameter | torch.Tensor):\n            # NOTE: loaded.dtype can be different with param.dtype\n            param.data.copy_(loaded.data)\n            dist.broadcast(param.data, src=dist.get_rank())\n        else:\n            assert isinstance(loaded, int)  # the rank that holds the state\n            dist.broadcast(param.data, src=loaded)\n        shard_states.pop(param_name)\n        del loaded\n        return param\n\n    def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):\n        param_and_buffers = 
tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))\n        # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])\n        for name, state in param_and_buffers:\n            if not state.is_meta:\n                continue\n            is_param = name in sub_mod._parameters\n            fqn = state2fqn[state].pop(0)\n            # non-persistent buffers will not be saved in state dict, we can safely skip them\n            if (not is_param) and fqn not in shard_states:\n                if state.is_meta:\n                    raise RuntimeError(\n                        f\"found a non-persistent buffer ({fqn}) initialized on the meta device. Such buffers are \"\n                        f\"not saved in the checkpoint, so the user must ensure they are initialized on a CPU / GPU \"\n                        f\"device.\"\n                    )\n                continue\n            # for shared parameters, we get the state from the first time it is created\n            if state in shared:\n                if state not in materialized_states:\n                    materialized_states[state] = create_and_sync_state(fqn, state, is_param)\n                else:\n                    if fqn in shard_states:\n                        shard_states.pop(fqn)\n                materialize_state = materialized_states[state]\n            # for non-shared parameters, we create the state directly\n            else:\n                materialize_state = create_and_sync_state(fqn, state, is_param)\n            if is_param:\n                sub_mod._parameters[name] = materialize_state\n            else:\n                sub_mod._buffers[name] = materialize_state\n        if recurse:\n            for module in sub_mod.children():\n                init_fn(module, recurse=True)\n\n        # for debug\n        # if len(shard_states) == 0: print(\"clear\")\n        return sub_mod\n\n    return init_fn\n\n\ndef fsdp_version(model):\n    if isinstance(model, FSDP):\n        return 1\n    elif isinstance(model, FSDPModule):\n        return 2\n    else:\n        return 0\n\n\ndef get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):\n    if fsdp_version(model) == 1:\n        return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)\n    else:\n        return nullcontext()\n\n\ndef get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True):\n    \"\"\"\n    Get the full state dict from an FSDP model.\n\n    Args:\n        model (torch.nn.Module): The FSDP model to get state dict from\n        offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True.\n        rank0_only (bool, optional): Whether to only get state dict on rank 0. 
Defaults to True.\n\n    Returns:\n        dict: The full state dict of the model\n\n    Raises:\n        NotImplementedError: If the FSDP version is unknown\n    \"\"\"\n    if fsdp_version(model) == 1:\n        from torch.distributed.fsdp import FullStateDictConfig, StateDictType\n\n        state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)\n        with get_fsdp_state_ctx(\n            model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None\n        ):\n            state_dict = model.state_dict()\n        return state_dict\n    elif fsdp_version(model) == 2:\n        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict\n\n        state_dict_config = StateDictOptions(\n            full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only\n        )\n        state_dict = get_model_state_dict(model, options=state_dict_config)\n        return state_dict\n    else:\n        raise NotImplementedError(f\"Unknown FSDP version {fsdp_version(model)}\")\n\n\ndef fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None):\n    \"\"\"\n    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the\n    parameters from rank 0 to all other ranks. This function modifies the model in-place.\n\n    Args:\n        model (`torch.nn.Module`): The model to load the state dict into\n        full_state (`dict`): The full state dict to load, can only be on rank 0\n    \"\"\"\n\n    if version.parse(torch.__version__) >= version.parse(\"2.7.0\"):\n        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n    else:\n        # official torch 2.6.0 set_model_state_dict API leads to OOM\n        # use torch 2.7.0 copy from verl/third_party/torch/distributed/checkpoint\n        from verl.third_party.torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n\n    # To broadcast, it needs to be instantiated in the GPU.\n    if dist.get_rank() == 0:\n        model = model.to(device=get_device_id(), non_blocking=True)\n    else:\n        model = model.to_empty(device=get_device_id())\n\n    cpu_offload = cpu_offload is not None\n    options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True)\n    set_model_state_dict(model, full_state, options=options)\n\n    # rotary_emb is not in state_dict, so we need to broadcast it manually\n    for name, buf in model.named_buffers():\n        dist.broadcast(buf, src=0)\n\n    if cpu_offload:\n        model.to(\"cpu\", non_blocking=True)\n        for buf in model.buffers():\n            buf.data = buf.data.to(get_device_id())\n\n\n@contextmanager\ndef maybe_patch_fsdp_module(model):\n    if fully_shard_module is None:\n        yield\n        return\n\n    orig_fsdp_module = fully_shard_module.FSDPModule\n\n    class FSDPModuleABC(ABC, orig_fsdp_module):\n        pass\n\n    try:\n        if isinstance(model, ABC):\n            fully_shard_module.FSDPModule = FSDPModuleABC\n        yield\n    finally:\n        fully_shard_module.FSDPModule = orig_fsdp_module\n\n\ndef apply_fsdp2(model, fsdp_kwargs, config):\n    \"\"\"model: AutoModelForCausalLM\"\"\"\n    assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n\n    default_transformer_cls_names_to_wrap = getattr(model, 
\"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = config.get(\"wrap_policy\", {}).get(\n        \"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap\n    )\n\n    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):\n        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]\n\n    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None\n\n    modules = []\n    for name, module in model.named_modules():\n        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (\n            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings\n        ):\n            modules.append(module)\n\n    for idx, module in enumerate(modules):\n        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:\n        #     print(f\"wrap module {module.__class__.__name__}\")\n        with maybe_patch_fsdp_module(module):\n            fully_shard(module, **fsdp_kwargs)\n\n    # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:\n    #     print(f\"wrap module {model.__class__.__name__}\")\n    with maybe_patch_fsdp_module(model):\n        fully_shard(model, **fsdp_kwargs)  # fsdp2 will not reshard_after_forward for root module\n\n\ndef get_shard_placement_fn(fsdp_size):\n    \"\"\"Choose the dimension that can divide fsdp_size to avoid padding\"\"\"\n\n    def shard_placement_fn(param):\n        shape = list(param.shape)\n        for i in range(len(shape)):\n            if shape[i] % fsdp_size == 0:\n                return Shard(i)\n        return Shard(0)\n\n    return shard_placement_fn\n\n\ndef fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None):\n    \"\"\"torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor\"\"\"\n    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm\n\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    else:\n        # prevent generators from being exhausted\n        parameters = list(parameters)\n    grads = [p.grad for p in parameters if p.grad is not None]\n    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)\n    total_norm = total_norm.to(get_device_id(), non_blocking=True)\n    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)\n    return total_norm\n\n\ndef layered_summon_lora_params(fsdp_module) -> OrderedDict:\n    from peft.utils.save_and_load import get_peft_model_state_dict\n\n    def __prefix_submodules(module, prefix):\n        for name, submodule in module.named_modules():\n            if name.startswith(prefix) and \".\" not in name[len(prefix) :]:\n                yield name, submodule\n\n    lora_params = OrderedDict()\n    prefix_list = [\n        # fsdp\n        \"_fsdp_wrapped_module.base_model.model.\",\n        \"_fsdp_wrapped_module.base_model.model.model.\",\n        \"_fsdp_wrapped_module.base_model.model.model.layers.\",\n        \"_fsdp_wrapped_module.base_model.model.model.language_model.layers.\",\n        # fsdp2\n        \"base_model.model.\",\n        \"base_model.model.model.\",\n        \"base_model.model.model.layers.\",\n        \"base_model.model.model.language_model.layers.\",\n    ]\n    peft_model = getattr(fsdp_module, \"_fsdp_wrapped_module\", fsdp_module)\n    for prefix in prefix_list:\n        for name, submodule in __prefix_submodules(fsdp_module, prefix):\n            prefix = 
name.replace(\"_fsdp_wrapped_module.base_model.model.\", \"base_model.model.\")\n            if name.endswith(\".model\") or name.endswith(\".layers\"):\n                continue\n            if fsdp_version(submodule) > 0:\n                with FSDP.summon_full_params(submodule, writeback=False):\n                    sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())\n                    sub_lora_params = {\n                        f\"{prefix}.{name}\": param.full_tensor().detach().cpu()\n                        if hasattr(param, \"full_tensor\")\n                        else param.detach().cpu()\n                        for name, param in sub_lora_params.items()\n                    }\n                    lora_params.update(sub_lora_params)\n                    submodule._is_root = False\n                get_torch_device().empty_cache()\n    return lora_params\n\n\ndef collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:\n    \"\"\"\n    collect lora params or full params if base model is not ready in vllm\n    work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)\n    \"\"\"\n    from peft.utils.save_and_load import get_peft_model_state_dict\n\n    lora_params = OrderedDict()\n    peft_model = getattr(module, \"_fsdp_wrapped_module\", module)\n    if fsdp_version(module) > 0:\n        if layered_summon:\n            if not base_sync_done:\n                raise ValueError(\n                    \"To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let \"\n                    \"rollout.load_format=safetensors\"\n                )\n            lora_params = layered_summon_lora_params(module)\n        else:\n            with FSDP.summon_full_params(module, writeback=False):\n                if base_sync_done:\n                    lora_params = get_peft_model_state_dict(peft_model)\n                    lora_params = {\n                        name: param.full_tensor().detach().cpu()\n                        if hasattr(param, \"full_tensor\")\n                        else param.detach().cpu()\n                        for name, param in lora_params.items()\n                    }\n                else:\n                    model = peft_model.base_model.model\n                    orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n                    model = model.to(\"cpu\")\n                    for name, param in model.state_dict().items():\n                        if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                            continue\n                        name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                        lora_params[name] = (\n                            param.full_tensor().detach().cpu()\n                            if hasattr(param, \"full_tensor\")\n                            else param.detach().cpu()\n                        )\n                    model = model.to(orig_dev)\n            get_torch_device().empty_cache()\n    else:\n        if base_sync_done:\n            lora_params = get_peft_model_state_dict(peft_model)\n        else:\n            model = peft_model.base_model.model\n            orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n            model = model.to(\"cpu\")\n            for name, param in model.state_dict().items():\n                if any(x in name for x in 
[\"_flat_param\", \"lora_\"]):\n                    continue\n                name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                lora_params[name] = param.detach().cpu()\n            model = model.to(orig_dev)\n    return lora_params\n\n\ndef replace_lora_wrapper(k, peft_config):\n    \"\"\"Replace LoRA parameter keys with base layer equivalents.\n\n    Transforms LoRA parameter names to their corresponding base layer\n    names for proper weight loading in vLLM when base model sync is not done.\n\n    Args:\n        k (str): Original parameter key name.\n\n    Returns:\n        str: Transformed parameter key for base layer.\n    \"\"\"\n    stacked_params = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n    if k.endswith(\".weight\"):\n        module_k = k[: -len(\".weight\")]\n        if check_exclude_modules(peft_config, module_k):\n            return k\n        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k):\n            return f\"{module_k}.base_layer.weight\"\n    if k.endswith(\".bias\"):\n        module_k = k[: -len(\".bias\")]\n        if check_exclude_modules(peft_config, module_k):\n            return k\n        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k):\n            return f\"{module_k}.base_layer.bias\"\n    return k\n"
  },
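A quick behavioral sketch of `get_shard_placement_fn` from the file above. This is illustrative and not part of the repo; it assumes the module is importable as `verl.utils.fsdp_utils` and that `Shard` is exposed at `torch.distributed.tensor` (PyTorch >= 2.4; older builds keep it under `torch.distributed._tensor`).

```python
# Illustrative sketch (not part of the repo): get_shard_placement_fn picks,
# per parameter, the first dimension that fsdp_size divides evenly, so FSDP2
# can shard without padding. Import paths are the assumptions noted above.
import torch
from torch.distributed.tensor import Shard

from verl.utils.fsdp_utils import get_shard_placement_fn  # assumed module path

placement_fn = get_shard_placement_fn(fsdp_size=8)

# dim 0 (4096) is divisible by 8 -> Shard(0), no padding needed
assert placement_fn(torch.empty(4096, 11008)) == Shard(0)

# dim 0 (11) is not divisible, dim 1 (4096) is -> Shard(1)
assert placement_fn(torch.empty(11, 4096)) == Shard(1)

# nothing divides evenly -> fall back to Shard(0) and let FSDP pad
assert placement_fn(torch.empty(3, 5)) == Shard(0)
```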
  {
    "path": "verl_distillation/verl/utils/groupwise.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nGroup-wise helpers for RL training utilities.\n\nPublic API:\n    - as_torch_index(index, device=None) -> torch.LongTensor\n    - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g)\n\nDefault device policy:\n    - If `device` is None:\n        * In pytest (detected by env \"PYTEST_CURRENT_TEST\"): use CPU.\n        * Else if CUDA is available: use CUDA.\n        * Else: use CPU.\n    - You can override via env \"VERL_FORCE_DEVICE\" (e.g., \"cuda:0\" / \"cpu\").\n\nNotes:\n- as_torch_index: canonicalizes arbitrary group labels to a contiguous 1-D torch.long\n  tensor in range [0..G-1]. Robust to torch/numpy/list/tuple, ints/floats/bools,\n  numeric strings, UUIDs, mixed object arrays. Near-integer floats (|x-round(x)|<=1e-6)\n  are rounded; otherwise factorization is applied.\n- group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance\n  (denominator max(count-1, 1)). Singleton groups fallback to mean=0, std=1 for\n  compatibility with common “native” conventions.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Any, Optional\n\nimport numpy as np\nimport torch\n\nfrom verl.utils.device import get_torch_device\n\n__all__ = [\"as_torch_index\", \"group_mean_std\"]\n\n\ndef _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:\n    \"\"\"\n    Resolve device according to policy described in the module docstring.\n    Priority:\n      1) explicit argument\n      2) VERL_FORCE_DEVICE env\n      3) pytest detection -> cpu\n      4) cuda if available, else cpu\n    \"\"\"\n    if explicit is not None:\n        return torch.device(explicit)\n\n    forced = os.getenv(\"VERL_FORCE_DEVICE\")\n    if forced:\n        return torch.device(forced)\n\n    # Heuristic: pytest sets PYTEST_CURRENT_TEST\n    if \"PYTEST_CURRENT_TEST\" in os.environ:\n        return torch.device(\"cpu\")\n\n    return get_torch_device()\n\n\ndef _to_1d_numpy_object_array(x: Any) -> np.ndarray:\n    \"\"\"Best-effort: convert arbitrary input into a 1-D numpy array; fallback to object dtype.\"\"\"\n    try:\n        arr = np.asarray(x)\n    except Exception:\n        try:\n            arr = np.array(list(x), dtype=object)\n        except Exception:\n            arr = np.array([x], dtype=object)\n    if arr.ndim != 1:\n        arr = arr.reshape(-1)\n    return arr\n\n\ndef as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:\n    \"\"\"\n    Convert arbitrary group labels to a contiguous 1-D torch.long tensor (0..G-1).\n\n    Args:\n        index: Any iterable of labels or tensor/ndarray.\n        device: Target device; if None, resolved via _resolve_device().\n\n    Returns:\n        torch.LongTensor with shape (N,)\n    \"\"\"\n    target = _resolve_device(device)\n\n    # 
---------- Fast path: torch.Tensor ----------\n    if isinstance(index, torch.Tensor):\n        t = index.reshape(-1)\n        if t.dtype in (\n            torch.int64,\n            torch.int32,\n            torch.int16,\n            torch.int8,\n            getattr(torch, \"uint8\", torch.uint8),\n            torch.bool,\n        ):\n            return t.to(device=target, dtype=torch.long)\n\n        if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16):\n            t64 = t.to(dtype=torch.float64)\n            rounded = torch.round(t64)\n            if torch.allclose(t64, rounded, rtol=0.0, atol=1e-6):\n                return rounded.to(device=target, dtype=torch.long)\n            arr = np.array([str(x.item()) for x in t], dtype=object)\n        else:\n            arr = np.array([str(x.item()) if hasattr(x, \"item\") else str(x) for x in t], dtype=object)\n\n    else:\n        # ---------- Non-torch: go through numpy ----------\n        arr = _to_1d_numpy_object_array(index)\n\n        # Pure integers (incl. bool)\n        if arr.dtype != object and np.issubdtype(arr.dtype, np.integer):\n            return torch.from_numpy(arr.astype(np.int64, copy=False)).to(device=target)\n\n        # Floats nearly equal to integers\n        if arr.dtype != object and np.issubdtype(arr.dtype, np.floating):\n            arr64 = arr.astype(np.float64, copy=False)\n            rounded = np.rint(arr64)\n            if np.allclose(arr64, rounded, rtol=0.0, atol=1e-6):\n                return torch.from_numpy(rounded.astype(np.int64)).to(device=target)\n            # fall through\n\n        # Try numeric string coercion\n        try:\n            coerced = arr.astype(np.int64)\n            return torch.from_numpy(coerced).to(device=target)\n        except Exception:\n            pass\n\n        if arr.dtype != object:\n            arr = arr.astype(object)\n\n    # ---------- Factorization (UUIDs / mixed types / arbitrary labels) ----------\n    try:\n        _, inv = np.unique(arr, return_inverse=True)\n    except Exception:\n        sarr = np.array([str(x) for x in arr], dtype=object)\n        _, inv = np.unique(sarr, return_inverse=True)\n\n    inv = inv.astype(np.int64, copy=False)\n    return torch.from_numpy(inv).to(device=target)\n\n\n@torch.no_grad()\ndef group_mean_std(\n    scores: torch.Tensor,\n    gidx: torch.Tensor,\n    eps: float = 1e-6,\n    device: torch.device | str | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute per-group mean/std/count in pure PyTorch.\n\n    mean_g = sum / count\n    std_g  = sqrt( max( (sum2 - sum^2/count) / max(count-1, 1), eps ) )\n\n    Singleton groups fallback to mean=0, std=1.\n\n    Args:\n        scores: (N,) float tensor.\n        gidx  : (N,) long/int tensor with group indices (0..G-1).\n        eps   : Numerical floor for variance.\n        device: Target device; if None, resolved via _resolve_device().\n\n    Returns:\n        mean_g: (G,) float32\n        std_g : (G,) float32\n        count : (G,) float32\n    \"\"\"\n    target = _resolve_device(device)\n\n    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)\n    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)\n\n    if scores.numel() != gidx.numel():\n        raise ValueError(f\"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}\")\n\n    G = int(torch.max(gidx).item()) + 1 if gidx.numel() > 0 else 0\n    if G == 0:\n        # Return empty tensors on the selected device\n        empty = 
torch.empty(0, device=target, dtype=torch.float32)\n        return empty, empty, empty\n\n    ones = torch.ones_like(scores, dtype=torch.float32)\n\n    count = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, ones)\n    s1 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores)\n    s2 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores * scores)\n\n    mean = s1 / count.clamp_min(1.0)\n    var_num = s2 - (s1 * s1) / count.clamp_min(1.0)\n    denom = (count - 1.0).clamp_min(1.0)\n    var = var_num / denom\n    std = torch.sqrt(torch.clamp(var, min=eps))\n\n    # Singleton groups: mean=0, std=1\n    single = count <= 1.0\n    if torch.any(single):\n        mean = mean.clone()\n        std = std.clone()\n        mean[single] = 0.0\n        std[single] = 1.0\n\n    return mean, std, count\n"
  },
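A minimal usage sketch of the two public helpers in `verl/utils/groupwise.py` (device pinned to CPU for determinism; the expected values follow the documented semantics):

```python
import torch

from verl.utils.groupwise import as_torch_index, group_mean_std

# Arbitrary labels (here strings) are factorized to contiguous ids 0..G-1.
gidx = as_torch_index(["a", "b", "a", "c", "b"], device="cpu")
print(gidx)  # tensor([0, 1, 0, 2, 1])

scores = torch.tensor([1.0, 2.0, 3.0, 5.0, 4.0])
mean, std, count = group_mean_std(scores, gidx, device="cpu")
# group "a" = {1, 3}: mean 2.0, Bessel-corrected std sqrt(2) ~= 1.4142
# group "b" = {2, 4}: mean 3.0, std sqrt(2)
# group "c" = {5} is a singleton: mean 0.0, std 1.0 by the documented fallback
print(mean, std, count)
```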
  {
    "path": "verl_distillation/verl/utils/hdfs_io.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n_HDFS_PREFIX = \"hdfs://\"\n\n_HDFS_BIN_PATH = shutil.which(\"hdfs\")\n\n\ndef exists(path: str, **kwargs) -> bool:\n    r\"\"\"Works like os.path.exists() but supports hdfs.\n\n    Test whether a path exists. Returns False for broken symbolic links.\n\n    Args:\n        path (str): path to test\n\n    Returns:\n        bool: True if the path exists, False otherwise\n    \"\"\"\n    if _is_non_local(path):\n        return _exists(path, **kwargs)\n    return os.path.exists(path)\n\n\ndef _exists(file_path: str):\n    \"\"\"hdfs capable to check whether a file_path is exists\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        return _run_cmd(_hdfs_cmd(f\"-test -e {file_path}\")) == 0\n    return os.path.exists(file_path)\n\n\ndef makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:\n    r\"\"\"Works like os.makedirs() but supports hdfs.\n\n    Super-mkdir; create a leaf directory and all intermediate ones.  Works like\n    mkdir, except that any intermediate path segment (not just the rightmost)\n    will be created if it does not exist. If the target directory already\n    exists, raise an OSError if exist_ok is False. Otherwise no exception is\n    raised.  This is recursive.\n\n    Args:\n        name (str): directory to create\n        mode (int): file mode bits\n        exist_ok (bool): if True, do not raise an exception if the directory already exists\n        kwargs: keyword arguments for hdfs\n\n    \"\"\"\n    if _is_non_local(name):\n        # TODO(haibin.lin):\n        # - handle OSError for hdfs(?)\n        # - support exist_ok for hdfs(?)\n        _mkdir(name, **kwargs)\n    else:\n        os.makedirs(name, mode=mode, exist_ok=exist_ok)\n\n\ndef _mkdir(file_path: str) -> bool:\n    \"\"\"hdfs mkdir\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        _run_cmd(_hdfs_cmd(f\"-mkdir -p {file_path}\"))\n    else:\n        os.makedirs(file_path, exist_ok=True)\n    return True\n\n\ndef copy(src: str, dst: str, **kwargs) -> bool:\n    r\"\"\"Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs.\n\n    Copy data and mode bits (\"cp src dst\"). 
Return the file's destination.\n    The destination may be a directory.\n    If source and destination are the same file, a SameFileError will be\n    raised.\n\n    Args:\n        src (str): source file path\n        dst (str): destination file path\n        kwargs: keyword arguments for hdfs copy\n\n    Returns:\n        str | bool: destination path for local copies; for copies involving\n        hdfs, a bool indicating success.\n\n    \"\"\"\n    if _is_non_local(src) or _is_non_local(dst):\n        # TODO(haibin.lin):\n        # - handle SameFileError for hdfs files(?)\n        # - return file destination for hdfs files\n        return _copy(src, dst)\n    else:\n        if os.path.isdir(src):\n            return shutil.copytree(src, dst, **kwargs)\n        else:\n            return shutil.copy(src, dst, **kwargs)\n\n\ndef _copy(from_path: str, to_path: str, timeout: int = None) -> bool:\n    if to_path.startswith(\"hdfs\"):\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(_hdfs_cmd(f\"-cp -f {from_path} {to_path}\"), timeout=timeout)\n        else:\n            returncode = _run_cmd(_hdfs_cmd(f\"-put -f {from_path} {to_path}\"), timeout=timeout)\n    else:\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(\n                _hdfs_cmd(\n                    f\"-get \\\n                {from_path} {to_path}\"\n                ),\n                timeout=timeout,\n            )\n        else:\n            try:\n                shutil.copy(from_path, to_path)\n                returncode = 0\n            except shutil.SameFileError:\n                returncode = 0\n            except Exception as e:\n                logger.warning(f\"copy {from_path} {to_path} failed: {e}\")\n                returncode = -1\n    return returncode == 0\n\n\ndef _run_cmd(cmd: str, timeout=None):\n    return os.system(cmd)\n\n\ndef _hdfs_cmd(cmd: str) -> str:\n    return f\"{_HDFS_BIN_PATH} dfs {cmd}\"\n\n\ndef _is_non_local(path: str):\n    return path.startswith(_HDFS_PREFIX)\n"
  },
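A short sketch of the dispatch behavior in `hdfs_io.py`: local paths go through `os`/`shutil`, while `hdfs://` paths shell out to the `hdfs` binary discovered via `shutil.which`. The HDFS paths below are placeholders.

```python
from verl.utils import hdfs_io  # module path follows the file path above

# Local path: plain os.makedirs / os.path.exists underneath.
hdfs_io.makedirs("/tmp/ckpt", exist_ok=True)

# hdfs:// path: runs `hdfs dfs -test -e ...` (requires `hdfs` on PATH).
if hdfs_io.exists("hdfs://namenode/user/me/ckpt"):
    # hdfs -> local copy dispatches to `hdfs dfs -get ...`.
    hdfs_io.copy("hdfs://namenode/user/me/ckpt/model.pt", "/tmp/ckpt")
```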
  {
    "path": "verl_distillation/verl/utils/import_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to check if packages are available.\nWe assume package availability won't change during runtime.\n\"\"\"\n\nimport importlib\nimport importlib.util\nimport os\nimport warnings\nfrom functools import cache, wraps\nfrom typing import Optional\n\n\n@cache\ndef is_megatron_core_available():\n    try:\n        mcore_spec = importlib.util.find_spec(\"megatron.core\")\n    except ModuleNotFoundError:\n        mcore_spec = None\n    return mcore_spec is not None\n\n\n@cache\ndef is_vllm_available():\n    try:\n        vllm_spec = importlib.util.find_spec(\"vllm\")\n    except ModuleNotFoundError:\n        vllm_spec = None\n    return vllm_spec is not None\n\n\n@cache\ndef is_sglang_available():\n    try:\n        sglang_spec = importlib.util.find_spec(\"sglang\")\n    except ModuleNotFoundError:\n        sglang_spec = None\n    return sglang_spec is not None\n\n\n@cache\ndef is_nvtx_available():\n    try:\n        nvtx_spec = importlib.util.find_spec(\"nvtx\")\n    except ModuleNotFoundError:\n        nvtx_spec = None\n    return nvtx_spec is not None\n\n\n@cache\ndef is_trl_available():\n    try:\n        trl_spec = importlib.util.find_spec(\"trl\")\n    except ModuleNotFoundError:\n        trl_spec = None\n    return trl_spec is not None\n\n\ndef import_external_libs(external_libs=None):\n    if external_libs is None:\n        return\n    if not isinstance(external_libs, list):\n        external_libs = [external_libs]\n    import importlib\n\n    for external_lib in external_libs:\n        importlib.import_module(external_lib)\n\n\ndef load_extern_type(file_path: Optional[str], type_name: Optional[str]) -> type:\n    \"\"\"Load a external data type based on the file path and type name\"\"\"\n    if not file_path:\n        return None\n\n    if file_path.startswith(\"pkg://\"):\n        # pkg://verl.utils.dataset.rl_dataset\n        # pkg://verl/utils/dataset/rl_dataset\n        module_name = file_path[6:].replace(\"/\", \".\")\n        module = importlib.import_module(module_name)\n\n    else:\n        # file://verl/utils/dataset/rl_dataset\n        # file:///path/to/verl/utils/dataset/rl_dataset.py\n        # or without file:// prefix\n        if file_path.startswith(\"file://\"):\n            file_path = file_path[7:]\n\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Custom type file '{file_path}' not found.\")\n\n        spec = importlib.util.spec_from_file_location(\"custom_module\", file_path)\n        module = importlib.util.module_from_spec(spec)\n        try:\n            spec.loader.exec_module(module)\n        except Exception as e:\n            raise RuntimeError(f\"Error loading module from '{file_path}'\") from e\n\n    if not hasattr(module, type_name):\n        raise AttributeError(f\"Custom type '{type_name}' not found in '{file_path}'.\")\n\n    return getattr(module, type_name)\n\n\ndef 
_get_qualified_name(func):\n    \"\"\"Get full qualified name including module and class (if any).\"\"\"\n    module = func.__module__\n    qualname = func.__qualname__\n    return f\"{module}.{qualname}\"\n\n\ndef deprecated(replacement: str = \"\"):\n    \"\"\"Decorator to mark functions or classes as deprecated.\"\"\"\n\n    def decorator(obj):\n        qualified_name = _get_qualified_name(obj)\n\n        if isinstance(obj, type):\n            original_init = obj.__init__\n\n            @wraps(original_init)\n            def wrapped_init(self, *args, **kwargs):\n                msg = f\"Warning: Class '{qualified_name}' is deprecated.\"\n                if replacement:\n                    msg += f\" Please use '{replacement}' instead.\"\n                warnings.warn(msg, category=FutureWarning, stacklevel=2)\n                return original_init(self, *args, **kwargs)\n\n            obj.__init__ = wrapped_init\n            return obj\n\n        else:\n\n            @wraps(obj)\n            def wrapped(*args, **kwargs):\n                msg = f\"Warning: Function '{qualified_name}' is deprecated.\"\n                if replacement:\n                    msg += f\" Please use '{replacement}' instead.\"\n                warnings.warn(msg, category=FutureWarning, stacklevel=2)\n                return obj(*args, **kwargs)\n\n            return wrapped\n\n    return decorator\n"
  },
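A minimal sketch of the `deprecated` decorator defined in `import_utils.py`; the replacement path passed to it is illustrative only.

```python
import warnings

from verl.utils.import_utils import deprecated

@deprecated(replacement="verl.utils.new_math.fast_add")  # illustrative target
def slow_add(a, b):
    return a + b

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert slow_add(1, 2) == 3  # still calls through to the original

# A FutureWarning like "Warning: Function '__main__.slow_add' is deprecated.
# Please use 'verl.utils.new_math.fast_add' instead." was emitted.
assert caught and issubclass(caught[0].category, FutureWarning)
```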
  {
    "path": "verl_distillation/verl/utils/kernel/__init__.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
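The next entry implements the fused linear cross entropy with token entropy as Triton kernels that never materialize the full logits matrix. As a reference for the quantities its forward pass appears to compute (per-token label log-probability and token entropy, per the epilogue math `entropy = log(accu) + max - entropy_b` and `logprob = logit_label - logsumexp`), here is a naive PyTorch sketch; it is illustrative only and deliberately materializes the full `(num_tokens, vocab_size)` logits that the kernels avoid.

```python
# Naive reference (sketch, not the kernels' streaming algorithm).
import torch

def naive_forward(hidden, weight, labels, temperature=1.0):
    logits = (hidden.float() @ weight.float().T) / temperature  # (M, V), full materialization
    logp = torch.log_softmax(logits, dim=-1)
    # reduction="none": per-token log-probability of the label
    token_logprobs = logp.gather(-1, labels[:, None]).squeeze(-1)
    # token entropy: H = logsumexp(logits) - E_p[logits]
    entropy = -(logp.exp() * logp).sum(-1)
    return token_logprobs, entropy

hidden = torch.randn(4, 128)
weight = torch.randn(1000, 128)
labels = torch.randint(0, 1000, (4,))
token_logprobs, entropy = naive_forward(hidden, weight, labels)
```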
  {
    "path": "verl_distillation/verl/utils/kernel/kernels.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplementations of the linear cross entropy with token entropy kernel.\n\"\"\"\n\nimport typing\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.distributed as dist\n\ntry:\n    import triton\n    import triton.language as tl\n\n    HAVE_TRITON = True\nexcept ImportError:\n    HAVE_TRITON = False\n\nfrom verl.utils.device import get_torch_device\n\nif not HAVE_TRITON:\n    from contextlib import contextmanager\n    from unittest.mock import MagicMock\n\n    @contextmanager\n    def null_decorator(*args, **kwargs):\n        if len(kwargs) == 0 and len(args) == 1 and callable(args[0]):\n            return args[0]\n        else:\n\n            def inner(func):\n                return func\n\n            return inner\n\n    triton = MagicMock()\n    triton.jit = null_decorator\n    triton.autotune = null_decorator\n    tl = MagicMock()\n\n\n@dataclass\nclass EntropyReductionEnum:\n    \"\"\"\n    Enum for the reduction method of cross entropy.\n    \"\"\"\n\n    _None = 0\n    _Sum = 1\n    _Mean = 2\n\n\ndef get_entropy_reduction_enum_number(reduction: str) -> int:\n    \"\"\"\n    Get the enum number for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if reduction == \"none\":\n        _enum = EntropyReductionEnum._None\n    elif reduction == \"sum\":\n        _enum = EntropyReductionEnum._Sum\n    elif reduction == \"mean\":\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n    return _enum\n\n\ndef get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum:\n    \"\"\"\n    Get the enum for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if ce_reduction == 0:\n        _enum = EntropyReductionEnum._None\n    elif ce_reduction == 1:\n        _enum = EntropyReductionEnum._Sum\n    elif ce_reduction == 2:\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid ce_reduction: {ce_reduction}\")\n    return _enum\n\n\n@dataclass\nclass BackwardEnum:\n    \"\"\"\n    Enum for the backward method.\n    \"\"\"\n\n    _Total_Fuse_MN = (\n        0  # Fuse 
d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight\n    )\n    _Total_Separate = 1  # Store d_logits, no special requirements for d_hidden & d_weight\n    _Split_Dlogits_N = 2  # split d_logits along its N dimension, aka. vocab_size\n    _Split_Dlogits_M = 3  # split d_logits along its M dimension, aka. num_tokens\n\n\n@dataclass\nclass Config:\n    \"\"\"Configuration for efficient entropy kernel operations.\n\n    Args:\n        _backward (BackwardEnum): Backward computation method. Defaults to BackwardEnum._Split_Dlogits_N.\n        _use_triton (bool): Whether to use Triton kernels for computation. Defaults to True.\n    \"\"\"\n\n    _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N\n    _use_triton: bool = True\n\n\n_config = Config()\n\n\ndef set_backward_method(backward_method: BackwardEnum):\n    \"\"\"\n    Set the backward method.\n    \"\"\"\n    global _config\n    _config._backward = backward_method\n\n\n@triton.autotune(\n    configs=[triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32}, num_stages=3, num_warps=8)],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_kernel_general_mainloop(\n    rank,\n    hidden_ptr,\n    weight_ptr,\n    labels_ptr,\n    num_tokens,\n    hidden_size,\n    vocab_size,\n    vocab_per_split,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    rcp_temperature: tl.float32,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    \"\"\"\n    forward mainloop\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    pid_m = pid % num_pid_m\n    pid_n = pid // num_pid_m\n\n    if pid_m == 0 and pid_n == 0:\n        tl.store(global_logprobs_scalar_ptr, 0.0)\n\n    # create pointers for the first blocks of hidden\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n\n    # load labels for this block\n    labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens)\n\n    # traverse over N dimension\n    # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _max = tl.full((BLOCK_SIZE_M,), -float(\"inf\"), dtype=tl.float32)\n    _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for n in range(0, num_pid_n):\n        offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n        weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over K dimension\n        logits = 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            # load the next block of hidden and weight\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n                other=0.0,\n            )\n            # _weight = tl.load(weight_ptrs,\n            #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min(\n            #                       (pid_n + 1) * vocab_per_split, vocab_size))),\n            #                   other=0.0)\n\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K)\n                & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))),\n                other=0.0,\n            )\n\n            # GEMM\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            # advance the ptrs to the next K block\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        # reset hidden_ptrs for next iteration\n        hidden_ptrs -= hidden_size * stride_hidden_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        # update global maximum\n        _max_old = _max\n        m_pid_n = tl.max(logits, axis=1)\n        _max = tl.maximum(_max_old, m_pid_n)\n\n        exp_logits = tl.exp(logits - _max[:, None])\n        coeff = tl.exp(_max_old - _max)\n        _accu = coeff * _accu + tl.sum(exp_logits, axis=1)\n\n        _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1)\n\n        label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n        _logprobs += tl.sum(logits * label_mask, axis=1)\n\n    # store maximum\n    offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_max_n = pid_n\n    maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m\n    tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store entropy\n    accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m\n    tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits))\n    entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m\n    tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store logprobs\n    vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size\n    vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size\n    mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx)\n    mask &= offs_am < num_tokens\n    global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs\n    # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask)\n    tl.store(global_logprobs_ptrs, _logprobs, mask=mask)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue(\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    num_tokens,\n    num_splits,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    global_accu_ptr,\n   
 stride_global_accu: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    global_entropy_ptr,\n    stride_global_entropy: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    \"\"\"\n    forward epilogue\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n\n\n        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n\n        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n\n        _entropy_b = tl.load(\n            entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0\n        )\n\n        # local reduction\n        _max_old = global_max\n        _local_max = tl.max(_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        _scale = tl.exp(_max - global_max[:, None])\n        _coeff = tl.exp(_max_old - global_max)\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    maximum_ptrs = global_max_ptr + offs_m * stride_global_max\n    tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens)\n\n    # store entropy_b\n    global_entropy_b = tl.fdiv(global_entropy_b, global_accu)  # entropy_b\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n    # store entropy\n    global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu\n    tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens)\n    global_entropy = tl.log(global_accu) + global_max - global_entropy_b  # entropy_a\n    global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy\n    tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens)\n    # update logprobs\n    global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs\n    global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens)\n    global_logprobs = global_max + tl.log(global_accu) - global_logprobs\n\n    global_logprobs = -1 * global_logprobs\n    if reduction == 0:\n        tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0)\n        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n    elif reduction == 2:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32)\n        
tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue_tp(\n    num_tokens,\n    num_splits,\n    reduced_max_ptr,\n    stride_reduced_max_m: tl.int64,\n    stride_reduced_max_n: tl.int64,\n    original_max_ptr,\n    stride_original_max_m: tl.int64,\n    stride_original_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    global_accu_ptr,\n    stride_global_accu: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        _reduced_max = tl.load(\n            reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _original_max = tl.load(\n            original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _accu = tl.load(\n            accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n\n        # local reduce-max\n        _max_old = global_max\n        _local_max = tl.max(_reduced_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        # update accumulate\n        _coeff = tl.exp(_max_old - global_max)\n        _scale = tl.exp(_original_max - global_max[:, None])\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n\n        # update entropy_b\n        _entropy_b = tl.load(\n            entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens)\n    tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens)\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16})], key=[\"num_tokens\"])\n@triton.jit\ndef efficient_entropy_triton_epilogue_tp_update(\n    num_tokens,\n    logprobs_ptr,\n    stride_logprobs: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accumulate_ptr,\n    stride_accumulate: tl.int64,\n    entropy_b_ptr,\n    
stride_entropy_b: tl.int64,\n    entropy_ptr,\n    stride_entropy: tl.int64,\n    logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens)\n    accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens)\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens)\n    entropy_b = tl.fdiv(entropy_b, accumulate)\n    tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens)\n\n    entropy = tl.log(accumulate) + maximum - entropy_b\n    tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens)\n\n    logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens)\n    logprobs = maximum + tl.log(accumulate) - logprobs\n\n    logprobs = -1 * logprobs\n    if reduction == 0:\n        tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        logprobs_scalar = tl.sum(logprobs, axis=0)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n    elif reduction == 2:\n        logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n\n\n_dedicated_stream, _dedicated_events = None, None\n\n\ndef efficient_entropy_forward(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    forward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    if dist_process_group is not None and not hasattr(efficient_entropy_forward, \"_initialized\"):\n        global _dedicated_stream, _dedicated_events\n        _dedicated_stream = get_torch_device().Stream(hidden.device)\n        _dedicated_events = [get_torch_device().Event() for _ in range(2)]\n        efficient_entropy_forward._initialized = True\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        if dist_process_group is None:\n            logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n        else:\n            logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32)\n    elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean):\n        logprobs = torch.empty((), device=hidden.device, dtype=torch.float32)\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n\n    entropy = 
torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n    assert logprobs.is_contiguous() and entropy.is_contiguous()\n\n    maximum = torch.empty_like(entropy)\n    accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32)\n    accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens)\n    accumulate = accumulate_and_entropy_b_view[0, :]\n    entropy_b = accumulate_and_entropy_b_view[1, :]\n    assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous()\n\n    vocab_per_split = 1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        _logprobs = logprobs\n    else:\n        _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n\n    assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous()\n    assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda\n\n    if _config._use_triton:\n        # 1D kernel launch, then split the tile\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * num_splits,)\n\n        efficient_entropy_kernel_general_mainloop[mainloop_grid](\n            _rank,\n            hidden,\n            weight,\n            labels,\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            vocab_per_split,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight.stride(0),\n            weight.stride(1),\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            1.0 / temperature,\n        )\n    else:\n        raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    # reduction on maximum and maximum_indices\n    def epilogue_grid(meta):\n        return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]),)\n\n    if dist_process_group is None:\n        efficient_entropy_triton_kernel_epilogue[epilogue_grid](\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            num_tokens,\n            num_splits,\n            maximum,\n            maximum.stride(0),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            accumulate,\n            accumulate.stride(0),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n    else:\n        # tensor-parallel\n        _max_backup = _max.clone()\n        dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group)\n\n        get_torch_device().current_stream().record_event(_dedicated_events[0])\n        with 
get_torch_device().stream(_dedicated_stream):\n            _dedicated_stream.wait_event(_dedicated_events[0])\n            dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group)\n            _dedicated_stream.record_event(_dedicated_events[1])\n\n        efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid](\n            num_tokens,\n            num_splits,\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _max_backup,\n            _max_backup.stride(0),\n            _max_backup.stride(1),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n        )\n        get_torch_device().current_stream().wait_event(_dedicated_events[1])\n\n        dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group)\n\n        # update logprobs & entropy\n        efficient_entropy_triton_epilogue_tp_update[epilogue_grid](\n            num_tokens,\n            _logprobs,\n            _logprobs.stride(0),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n\n    return (logprobs, entropy, maximum, accumulate, entropy_b)\n\n\n# NOTE: merge d_weight & d_hidden here, split along M & N\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        )\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_mainloop_MN(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward mainloop, where d_logits & d_hidden & d_weight are fused\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m 
- first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k\n    # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n\n    d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += 
d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # loop for d_weight & d_hidden\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits)\n        # tl.atomic_add(d_weight_ptrs,\n        #               _d_weight,\n        #               mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size))\n        _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32))\n        tl.atomic_add(\n            d_weight_ptrs,\n            _d_weight,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n        )\n\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32))\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n        _d_hidden = tl.dot(d_logits, _weight.to(tl.float32))\n        tl.atomic_add(\n            d_hidden_ptrs,\n            _d_hidden,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n        )\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k\n        d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_hidden(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_hidden\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    pid_m = pid % num_pid_m\n    pid_k = pid // num_pid_m\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < 
num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n    # iterate over vocab_size\n    d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n    for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)):\n        offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over hidden_size to get logits\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n        mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        # scale d_logits\n        d_logits *= rcp_temperature\n\n        # calculate d_hidden\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k)\n        _weight = tl.load(\n            weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0\n        )\n        d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden)\n\n    # write back\n    tl.store(\n        d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k,\n        d_hidden,\n        mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", 
\"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_weight(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    pid_n = pid % num_pid_n\n    pid_k = pid // num_pid_n\n\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)\n    for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)):\n        offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n        maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0)\n        accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n        accu_rcp = tl.fdiv(1.0, accu)\n        d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n        if reduction == 0:\n            d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n        elif reduction == 1:\n            d_logprobs = tl.load(d_logprobs_ptr)\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        else:\n            d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        d_logprobs = -1 * d_logprobs\n\n        entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n        labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n      
  mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        d_logits *= rcp_temperature\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k)\n        _hidden = tl.load(\n            hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0\n        )\n        d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight)\n\n    # write back\n    tl.store(\n        d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k,\n        d_weight,\n        mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n# NOTE: split tile from d_logits' perspective\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_logits\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + 
offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # store d_logits\n    d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n\n    tl.store(\n        d_logits_ptrs,\n        d_logits,  # will be implicitly converted to d_logits_ptrs.dtype.element_ty\n        mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits_split_N(\n    split_idx: int,\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    vocab_per_split: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: 
tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n    entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n    vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size)\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound),\n            other=0.0,\n        )\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n    logits *= rcp_temperature\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    d_logits *= 
rcp_temperature\n\n    # filter d_logits with mask\n    result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split)\n\n    tl.store(\n        d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask\n    )\n\n\ndef efficient_entropy_backward(\n    dlogprobs: torch.Tensor,\n    dentropy: torch.Tensor,\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    maximum: torch.Tensor,\n    acc: torch.Tensor,\n    entropy_b: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    should_return_fp32_grad: bool = False,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    backward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        assert dlogprobs.shape == (num_tokens,)\n    else:\n        assert dlogprobs.dim() == 0\n\n    assert dlogprobs.is_contiguous() and dentropy.is_contiguous()\n    assert dlogprobs.is_cuda and dentropy.is_cuda\n    assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device\n    assert dentropy.shape == (num_tokens,)\n\n    d_hidden, d_weight = None, None\n    if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad:\n        d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device)\n        d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device)\n    else:\n        d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device)\n        d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device)\n    assert d_hidden.is_contiguous() and d_weight.is_contiguous()\n\n    assert maximum.is_contiguous() and acc.is_contiguous()\n    assert maximum.device == hidden.device and acc.device == hidden.device\n    assert maximum.shape == labels.shape == acc.shape\n    assert maximum.is_cuda and acc.is_cuda\n\n    vocab_per_split = 1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    assert entropy_b.is_contiguous() and entropy_b.is_cuda\n    assert entropy_b.shape == (num_tokens,)\n\n    if _config._backward == BackwardEnum._Total_Fuse_MN:\n        # --- Triton doesn't materialize d_logits at all. 
Split tiles at the perspective of d_logits.\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n        efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid](\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            _rank,\n            hidden,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight,\n            weight.stride(0),\n            weight.stride(1),\n            labels,\n            labels.stride(0),\n            maximum,\n            maximum.stride(0),\n            acc,\n            acc.stride(0),\n            dentropy,\n            dentropy.stride(0),\n            dlogprobs,\n            dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n            REDUCTION,\n            entropy_b,\n            entropy_b.stride(0),\n            d_hidden,\n            d_hidden.stride(0),\n            d_hidden.stride(1),\n            d_weight,\n            d_weight.stride(0),\n            d_weight.stride(1),\n            1.0 / temperature,\n        )\n\n    elif _config._backward == BackwardEnum._Total_Separate:\n        _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        if _config._use_triton:\n\n            def d_logits_grid(meta):\n                return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n            efficient_entropy_backward_kernel_general_d_logits[d_logits_grid](\n                num_tokens,\n                hidden_size,\n                vocab_size,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            torch.matmul(_d_logits, weight, out=d_hidden)\n            torch.matmul(_d_logits.T, hidden, out=d_weight)\n        else:\n            raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_N:\n        vocab_per_split = 9504\n        num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n        _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        def d_logits_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_per_split, meta[\"BLOCK_SIZE_N\"]),)\n\n        for split_idx in range(num_splits):\n            efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid](\n                split_idx,\n                num_tokens,\n                hidden_size,\n                vocab_size,\n             
   vocab_per_split,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            if split_idx == (num_splits - 1):\n                vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split\n                _d_logits = _d_logits[:, :vocab_right_bound].contiguous()\n\n            if split_idx == 0:\n                torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden\n                )\n            else:\n                d_hidden += torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n                )\n            torch.matmul(\n                _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n            )\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_M:\n        raise NotImplementedError(\"BackwardEnum._Split_Dlogits_M is not implemented yet\")\n\n    return d_hidden, d_weight\n"
  },
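The three backward strategies dispatched above (`_Total_Fuse_MN`, `_Total_Separate`, `_Split_Dlogits_N`) compute the same per-token gradient and differ only in tiling and in whether `d_logits` is materialized. As a numerical cross-check, here is a minimal single-rank PyTorch sketch of that math, assuming `reduction="none"` (so `dlogprobs` is per-token) and no vocabulary sharding; `reference_backward` is an illustrative name, not a function in this repo.

```python
import torch

def reference_backward(hidden, weight, labels, d_logprobs, d_entropy, temperature=1.0):
    """Unfused reference for the fused logprob/entropy backward (single rank)."""
    # The kernels scale logits by rcp_temperature before the softmax statistics.
    logits = (hidden.float() @ weight.float().T) / temperature           # (M, V)
    p = torch.softmax(logits, dim=-1)                                    # exp_logits * accu_rcp
    b = (p * logits).sum(-1, keepdim=True)                               # entropy_b = E_p[logits]
    onehot = torch.nn.functional.one_hot(labels, logits.shape[-1]).float()
    # d(logprob of label)/d(logits) = onehot - p, hence the negated d_logprobs in-kernel.
    d_logits = -d_logprobs[:, None] * (p - onehot)
    # d(entropy)/d(logits) = -p * (logits - b).
    d_logits = d_logits + d_entropy[:, None] * (-p) * (logits - b)
    d_logits = d_logits / temperature                                    # final rcp_temperature scale
    return d_logits @ weight.float(), d_logits.T @ hidden.float()        # d_hidden, d_weight
```

Comparing these outputs against `efficient_entropy_backward` (up to accumulation-order tolerance) is a cheap way to validate new kernel tilings.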
  {
    "path": "verl_distillation/verl/utils/kernel/linear_cross_entropy.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport typing\n\nimport torch\nimport torch.distributed as dist\n\n\nclass LinearCrossEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden: torch.Tensor,\n        weight: torch.Tensor,\n        labels: torch.Tensor,\n        temperature: typing.Optional[float] = 1.0,\n        reduction: typing.Optional[str] = \"none\",\n        dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"_summary_\n\n        Args:\n            ctx (_type_): _description_\n            hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size)\n            weight (torch.Tensor): (vocab_size, hidden_size)\n            labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, )\n            temperature (typing.Optional[float], optional): _description_. Defaults to 1.0.\n            reduction (typing.Optional[str], optional): _description_. Defaults to \"none\".\n            dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None.\n\n        Returns:\n            typing.List[torch.Tensor]: _description_\n        \"\"\"\n\n        assert isinstance(temperature, float), f\"temperature must be a float, but got {type(temperature)}\"\n        assert isinstance(reduction, str), f\"reduction must be a str, but got {type(reduction)}\"\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-forward\"):\n            from . 
import kernels\n\n            REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower())\n\n            original_hidden_shape = hidden.shape\n            if len(hidden.shape) != 2:\n                hidden = hidden.view(-1, hidden.shape[-1])  # (batch_size * num_tokens, hidden_size)\n            if len(labels.shape) != 1:\n                labels = labels.view(-1)\n\n            logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(\n                hidden, weight, labels, REDUCTION, temperature, dist_process_group\n            )\n\n            ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b)\n            ctx.original_hidden_shape = original_hidden_shape\n            ctx.REDUCTION = REDUCTION\n            ctx.dist_process_group = dist_process_group\n            ctx.should_return_fp32_grad = False\n            ctx.temperature = temperature\n        return logprobs, entropy\n\n    @staticmethod\n    def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]:\n        from . import kernels\n\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-backward\"):\n            (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors\n            REDUCTION = ctx.REDUCTION\n            dist_process_group = ctx.dist_process_group\n            should_return_fp32_grad = ctx.should_return_fp32_grad\n            temperature = ctx.temperature\n\n            d_hidden, d_weight = kernels.efficient_entropy_backward(\n                dlogprobs,\n                dentropy,\n                hidden,\n                weight,\n                labels,\n                _maximum,\n                _accumulate,\n                _entropy_b,\n                REDUCTION,\n                should_return_fp32_grad,\n                temperature,\n                dist_process_group,\n            )\n            d_hidden = d_hidden.view(ctx.original_hidden_shape)\n\n        return (d_hidden, d_weight, None, None, None, None)\n\n\nlinear_cross_entropy = LinearCrossEntropy.apply\n"
  },
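A shape-level usage sketch of the `linear_cross_entropy` wrapper above. It assumes a CUDA device with Triton available and the `verl` package importable; sizes are toy but respect the backward's `hidden_size % 128 == 0` assertion.

```python
import torch

from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

hidden = torch.randn(4, 128, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(32000, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 32000, (4, 128), device="cuda")

# forward flattens to (batch * seq,); reduction="none" keeps per-token outputs
logprobs, entropy = linear_cross_entropy(hidden, weight, labels, 1.0, "none")

loss = -logprobs.mean() - 0.01 * entropy.mean()   # CE with a small entropy bonus
loss.backward()                                   # routes through efficient_entropy_backward
assert hidden.grad.shape == hidden.shape and weight.grad.shape == weight.shape
```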
  {
    "path": "verl_distillation/verl/utils/logger/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .aggregate_logger import (\n    DecoratorLoggerBase,\n    LocalLogger,\n    log_with_rank,\n    print_rank_0,\n    print_with_rank,\n    print_with_rank_and_timer,\n)\n\n__all__ = [\n    \"LocalLogger\",\n    \"DecoratorLoggerBase\",\n    \"print_rank_0\",\n    \"print_with_rank\",\n    \"print_with_rank_and_timer\",\n    \"log_with_rank\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/utils/logger/aggregate_logger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA Ray logger will receive logging info from different processes.\n\"\"\"\n\nimport datetime\nimport logging\nimport numbers\nimport pprint\n\nimport torch\n\n\ndef concat_dict_to_str(dict: dict, step):\n    output = [f\"step:{step}\"]\n    for k, v in dict.items():\n        if isinstance(v, numbers.Number):\n            output.append(f\"{k}:{pprint.pformat(v)}\")\n    output_str = \" - \".join(output)\n    return output_str\n\n\nclass LocalLogger:\n    \"\"\"\n    A local logger that logs messages to the console.\n\n    Args:\n        print_to_console (bool): Whether to print to the console.\n    \"\"\"\n\n    def __init__(self, print_to_console=True):\n        self.print_to_console = print_to_console\n\n    def flush(self):\n        pass\n\n    def log(self, data, step):\n        if self.print_to_console:\n            print(concat_dict_to_str(data, step=step), flush=True)\n\n\nclass DecoratorLoggerBase:\n    \"\"\"\n    Base class for all decorators that log messages.\n\n    Args:\n        role (str): The role (the name) of the logger.\n        logger (logging.Logger): The logger instance to use for logging.\n        level (int): The logging level.\n        rank (int): The rank of the process.\n        log_only_rank_0 (bool): If True, only log for rank 0.\n    \"\"\"\n\n    def __init__(\n        self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True\n    ):\n        self.role = role\n        self.logger = logger\n        self.level = level\n        self.rank = rank\n        self.log_only_rank_0 = log_only_rank_0\n        self.logging_function = self.log_by_logging\n        if logger is None:\n            self.logging_function = self.log_by_print\n\n    def log_by_print(self, log_str):\n        if not self.log_only_rank_0 or self.rank == 0:\n            print(f\"{self.role} {log_str}\", flush=True)\n\n    def log_by_logging(self, log_str):\n        if self.logger is None:\n            raise ValueError(\"Logger is not initialized\")\n        if not self.log_only_rank_0 or self.rank == 0:\n            self.logger.log(self.level, f\"{self.role} {log_str}\")\n\n\ndef print_rank_0(message):\n    \"\"\"If distributed is initialized, print only on rank 0.\"\"\"\n    if torch.distributed.is_initialized():\n        if torch.distributed.get_rank() == 0:\n            print(message, flush=True)\n    else:\n        print(message, flush=True)\n\n\ndef print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Print a message with rank information.\n    This function prints the message only if `log_only_rank_0` is False or if the rank is 0.\n\n    Args:\n        message (str): _description_\n        rank (int, optional): _description_. Defaults to 0.\n        log_only_rank_0 (bool, optional): _description_. 
Defaults to False.\n    \"\"\"\n    if not log_only_rank_0 or rank == 0:\n        print(f\"[Rank {rank}] {message}\", flush=True)\n\n\ndef print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Print a message with rank information and a timestamp.\n    This function prints the message only if `log_only_rank_0` is False or if the rank is 0.\n\n    Args:\n        message (str): _description_\n        rank (int, optional): _description_. Defaults to 0.\n        log_only_rank_0 (bool, optional): _description_. Defaults to False.\n    \"\"\"\n    now = datetime.datetime.now()\n    message = f\"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [Rank {rank}] {message}\"\n    if not log_only_rank_0 or rank == 0:\n        print(message, flush=True)\n\n\ndef log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Log a message with rank information using a logger.\n    This function logs the message only if `log_only_rank_0` is False or if the rank is 0.\n    Args:\n        message (str): The message to log.\n        rank (int): The rank of the process.\n        logger (logging.Logger): The logger instance to use for logging.\n        level (int, optional): The logging level. Defaults to logging.INFO.\n        log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False.\n    \"\"\"\n    if not log_only_rank_0 or rank == 0:\n        logger.log(level, f\"[Rank {rank}] {message}\")\n"
  },
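A brief usage sketch of the helpers exported above, assuming `verl` is importable; the expected console format follows `concat_dict_to_str`.

```python
import logging

from verl.utils.logger import LocalLogger, log_with_rank, print_rank_0

tracker = LocalLogger(print_to_console=True)
tracker.log({"loss": 0.42, "lr": 1e-5}, step=100)
# prints: step:100 - loss:0.42 - lr:1e-05   (non-numeric values are skipped)

print_rank_0("printed once under torch.distributed, everywhere otherwise")

logger = logging.getLogger(__name__)
log_with_rank("checkpoint saved", rank=0, logger=logger, log_only_rank_0=True)
```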
  {
    "path": "verl_distillation/verl/utils/logging_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\n\n\ndef set_basic_config(level):\n    \"\"\"\n    This function sets the global logging format and level. It will be called when import verl\n    \"\"\"\n    logging.basicConfig(format=\"%(levelname)s:%(asctime)s:%(message)s\", level=level)\n\n\ndef log_to_file(string):\n    print(string)\n    if os.path.isdir(\"logs\"):\n        with open(f\"logs/log_{torch.distributed.get_rank()}\", \"a+\") as f:\n            f.write(string + \"\\n\")\n"
  },
  {
    "path": "verl_distillation/verl/utils/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/utils/megatron/dist_checkpointing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import dist_checkpointing, mpu\nfrom megatron.core.dist_checkpointing.serialization import (\n    get_default_load_sharded_strategy,\n    get_default_save_sharded_strategy,\n)\nfrom megatron.core.dist_checkpointing.strategies.fully_parallel import (\n    FullyParallelLoadStrategyWrapper,\n    FullyParallelSaveStrategyWrapper,\n)\n\n\ndef save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save=False):\n    validate_sharding_integrity = True\n    # Get checkpointing strategies\n    save_strategy = get_default_save_sharded_strategy(\"torch_dist\")\n    save_strategy = FullyParallelSaveStrategyWrapper(\n        save_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Save model sharded state dicts\n    async_save_request = dist_checkpointing.save(\n        sharded_state_dict,\n        ckpt_path,\n        sharded_strategy=save_strategy,\n        async_sharded_save=async_save,\n        validate_access_integrity=validate_sharding_integrity,\n    )\n\n    return async_save_request\n\n\ndef load_dist_checkpointing(sharded_state_dict, ckpt_dir):\n    # Get checkpointing strategies\n    load_strategy = get_default_load_sharded_strategy(ckpt_dir)\n    load_strategy = FullyParallelLoadStrategyWrapper(\n        load_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Load model sharded state dicts\n    state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy)\n\n    return state_dict\n"
  },
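A hedged sketch of how the two helpers above pair up. It assumes Megatron parallel state is already initialized and that `model` is a Megatron module exposing `sharded_state_dict()`; the checkpoint path is illustrative.

```python
# Illustrative only: requires an initialized Megatron distributed environment.
ckpt_dir = "/tmp/ckpt/iter_0001000"  # hypothetical path

# Synchronous save; with async_save=True the returned request object must be
# finalized by the caller's async-checkpointing machinery.
save_dist_checkpointing(model.sharded_state_dict(), ckpt_dir, async_save=False)

# Load uses a sharded state dict as the template describing this rank's shards.
state_dict = load_dist_checkpointing(model.sharded_state_dict(), ckpt_dir)
model.load_state_dict(state_dict)
```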
  {
    "path": "verl_distillation/verl/utils/megatron/memory.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom verl.utils.device import get_device_id\n\n\nclass MemoryBuffer:\n    def __init__(self, numel, numel_padded, dtype):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n"
  },
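`MemoryBuffer.get` returns a reshaped *view* into the flat allocation, so writes through the returned tensor mutate the shared storage. A CPU-only illustration of the same semantics (the real class allocates on the current accelerator via `get_device_id()`):

```python
import torch

flat = torch.zeros(16, dtype=torch.float32)   # stands in for MemoryBuffer.data

# Equivalent of buffer.get(torch.Size((2, 3)), start_index=4)
shape, start_index = torch.Size((2, 3)), 4
view = flat[start_index : start_index + shape.numel()].view(shape)

view.fill_(1.0)
assert flat[4:10].eq(1.0).all()               # the view aliases the flat buffer
```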
  {
    "path": "verl_distillation/verl/utils/megatron/optimizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core.optimizer import OptimizerConfig\nfrom megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native\nfrom megatron.core.optimizer_param_scheduler import OptimizerParamScheduler\n\nfrom verl.utils.logger import print_rank_0\n\n\ndef init_megatron_optim_config(optim_config: dict) -> OptimizerConfig:\n    optim_args = {\n        \"optimizer\": optim_config.optimizer,\n        \"lr\": optim_config.lr,\n        \"min_lr\": optim_config.min_lr,\n        \"clip_grad\": optim_config.clip_grad,\n        \"weight_decay\": optim_config.weight_decay,\n        \"bf16\": True,\n        \"params_dtype\": torch.bfloat16,\n        \"use_distributed_optimizer\": True,\n    }\n\n    override_config = optim_config.get(\"override_optimizer_config\", {})\n    if override_config:\n        for k, v in override_config.items():\n            optim_args[k] = v\n\n    print_rank_0(f\"optimizer config after override: {optim_args}\")\n\n    config = OptimizerConfig(**optim_args)\n    return config\n\n\ndef get_megatron_optimizer(\n    model,\n    config: OptimizerConfig,\n    no_weight_decay_cond=None,\n    scale_lr_cond=None,\n    lr_mult=1.0,\n):\n    # Base optimizer.\n    return get_megatron_optimizer_native(\n        config=config,\n        model_chunks=model,\n        no_weight_decay_cond=no_weight_decay_cond,\n        scale_lr_cond=scale_lr_cond,\n        lr_mult=lr_mult,\n    )\n\n\ndef get_megatron_optimizer_param_scheduler(\n    optimizer,\n    config,\n):\n    \"\"\"\n    Get the optimizer parameter scheduler for Megatron.\n    \"\"\"\n    lr_decay_steps = config.lr_decay_steps\n    lr_warmup_steps = config.lr_warmup_steps\n    if config.get(\"lr_decay_steps\", None) is None:\n        lr_decay_steps = config.total_training_steps\n    wsd_decay_steps = None\n    if config.get(\"lr_wsd_decay_steps\", None) is not None:\n        wsd_decay_steps = config.lr_wsd_decay_steps\n    if config.get(\"lr_warmup_steps_ratio\", None) is not None and (\n        config.get(\"lr_warmup_steps\", None) is None or config.lr_warmup_steps <= 0\n    ):\n        lr_warmup_steps = int(config.lr_warmup_steps_ratio * lr_decay_steps)\n\n    opt_param_scheduler = OptimizerParamScheduler(\n        optimizer,\n        init_lr=config.lr_warmup_init,\n        max_lr=config.lr,\n        min_lr=config.min_lr,\n        lr_warmup_steps=lr_warmup_steps,\n        lr_decay_steps=lr_decay_steps,\n        lr_decay_style=config.lr_decay_style,\n        start_wd=config.weight_decay,\n        end_wd=config.weight_decay,\n        wd_incr_steps=config.total_training_steps,\n        wd_incr_style=config.weight_decay_incr_style,\n        use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler,\n        override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler),\n        
wsd_decay_steps=wsd_decay_steps,\n        lr_wsd_decay_style=config.lr_wsd_decay_style,\n    )\n\n    return opt_param_scheduler\n\n\ndef get_megatron_last_lr(optimizer):\n    \"\"\"\n    Get the current learning rate from the optimizer's first parameter group.\n    \"\"\"\n    return optimizer.param_groups[0][\"lr\"]\n"
  },
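A worked example of the warmup fallback in `get_megatron_optimizer_param_scheduler` above: when `lr_warmup_steps` is unset (or non-positive) and `lr_warmup_steps_ratio` is provided, warmup is derived from the decay horizon, which itself falls back to `total_training_steps`. Numbers are illustrative.

```python
total_training_steps = 10_000
lr_decay_steps = total_training_steps   # used when config.lr_decay_steps is None
lr_warmup_steps_ratio = 0.03            # warm up over 3% of the decay horizon

lr_warmup_steps = int(lr_warmup_steps_ratio * lr_decay_steps)
assert lr_warmup_steps == 300
```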
  {
    "path": "verl_distillation/verl/utils/megatron/pipeline_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\n\nfrom .sequence_parallel import pad_to_sequence_parallel\n\n\ndef compute_transformers_input_shapes(batches, meta_info):\n    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron\n\n    # pre-compute input shapes for each micro-batch at each pp stage\n    input_shapes = []\n    for model_inputs in batches:\n        input_ids = model_inputs[\"input_ids\"]\n        attention_mask = model_inputs[\"attention_mask\"]\n        input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0]  # (total_nnz, 1)\n        if meta_info[\"sequence_parallel\"]:\n            input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)\n            # compute shapes for model_inputs\n            input_shapes.append(\n                torch.Size(\n                    [\n                        input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(),\n                        1,\n                        meta_info[\"hidden_size\"],\n                    ]\n                )\n            )\n        else:\n            # compute shapes for model_inputs\n            input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info[\"hidden_size\"]]))\n    return input_shapes\n\n\ndef make_batch_generator(batches, vpp_size):\n    \"\"\"\n    Creates a batch generator suitable for Megatron pipeline parallelism,\n    handling virtual pipeline parallelism (VPP).\n\n    If VPP is used (vpp_size > 1), it duplicates the batch iterator for each\n    virtual pipeline stage. Otherwise, it returns a single iterator.\n\n    Args:\n        batches: An iterable (e.g., list) of micro-batches.\n        vpp_size (int): The virtual pipeline model parallel size.\n\n    Returns:\n        An iterator or a list of iterators over the micro-batches.\n    \"\"\"\n    if vpp_size > 1:\n        # has vpp\n        batch_generator = [batches] * vpp_size  # number of vpp chunks\n        batch_generator = [iter(b) for b in batch_generator]\n    else:\n        # no vpp\n        batch_generator = iter(batches)\n    return batch_generator\n"
  },
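A minimal check of `make_batch_generator`'s two return shapes (importing the module pulls in `megatron.core`, so this assumes Megatron is installed):

```python
from verl.utils.megatron.pipeline_parallel import make_batch_generator

batches = [{"input_ids": i} for i in range(4)]

gen = make_batch_generator(batches, vpp_size=1)    # single iterator, no VPP
assert next(gen) == {"input_ids": 0}

gens = make_batch_generator(batches, vpp_size=2)   # one iterator per VPP chunk
assert isinstance(gens, list) and len(gens) == 2
# each virtual stage independently replays the same micro-batches
assert next(gens[0]) == next(gens[1]) == {"input_ids": 0}
```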
  {
    "path": "verl_distillation/verl/utils/megatron/sequence_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import parallel_state as mpu\n\n\ndef mark_parameter_as_sequence_parallel(parameter):\n    parameter.sequence_parallel = True\n\n\ndef is_sequence_parallel_param(param):\n    return hasattr(param, \"sequence_parallel\") and param.sequence_parallel\n\n\ndef pad_to_sequence_parallel(unpad_tokens: torch.Tensor):\n    \"\"\"pad the tokens such that the total length is a multiple of sp world size\n\n    Args:\n        unpad_tokens: (total_nnz, ...). Tokens after removing padding\n\n    Returns:\n        the padded tokens: (total_nnz + pad_size,...)\n\n    \"\"\"\n    total_nnz = unpad_tokens.shape[0]\n    sp_world_size = mpu.get_tensor_model_parallel_world_size()\n\n    pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size\n\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim()} is not supported\")\n\n    return unpad_tokens\n"
  },
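A standalone illustration of the padding rule in `pad_to_sequence_parallel`, with the SP world size hard-coded instead of read from `mpu` so the snippet runs without Megatron:

```python
import torch
import torch.nn.functional as F

sp_world_size = 8                 # stand-in for mpu.get_tensor_model_parallel_world_size()
unpad_tokens = torch.arange(10)   # total_nnz = 10

total_nnz = unpad_tokens.shape[0]
pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size

padded = F.pad(unpad_tokens, (0, pad_size))   # 1-D case: zero-pad on the right
assert padded.shape[0] == 16 and padded[-pad_size:].eq(0).all()
```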
  {
    "path": "verl_distillation/verl/utils/megatron/tensor_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for using tensor_parallel in megatron\n\"\"\"\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import parallel_state as mpu\nfrom torch.nn import init\n\nif TYPE_CHECKING:\n    from megatron.core import ModelParallelConfig\n\n\ndef update_kwargs_with_config(dictionary: dict, config: \"ModelParallelConfig\"):\n    dictionary[\"config\"] = config\n    return dictionary\n\n\ndef get_default_kwargs_for_model_parallel_config():\n    model_parallel_config_kwargs = {\n        \"params_dtype\": torch.float32,\n        \"use_cpu_initialization\": False,\n        \"perform_initialization\": True,\n        \"gradient_accumulation_fusion\": False,\n        \"sequence_parallel\": False,\n    }\n    return model_parallel_config_kwargs\n\n\ndef get_default_model_parallel_config():\n    from megatron.core import ModelParallelConfig\n\n    return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config())\n\n\ndef get_common_default_kwargs_for_parallel_linear():\n    default_model_parallel_config = get_default_model_parallel_config()\n    common_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"stride\": 1,\n        \"keep_master_weight_for_test\": False,\n        \"config\": default_model_parallel_config,\n    }\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_column_parallel_linear():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    column_parallel_config_kwargs = {\n        \"async_tensor_model_parallel_allreduce\": False,\n    }\n    model_parallel_config_kwargs.update(column_parallel_config_kwargs)\n    column_default_kwargs = {\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    common_default_kwargs.update(column_default_kwargs)\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_row_parallel_linear():\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_parallel_embedding():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    embedding_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    return embedding_default_kwargs\n\n\ndef is_tensor_parallel_param(param):\n    return hasattr(param, \"tensor_model_parallel\") and param.tensor_model_parallel\n\n\ndef get_tensor_parallel_partition_dim(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_dim\n\n\ndef 
get_tensor_parallel_partition_stride(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_stride\n\n\nclass _VocabParallelEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n        @torch.compile(dynamic=True)\n        def mul_reduce(a, b):\n            return (a * b).sum(dim=-1, keepdim=True)\n\n        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values\n        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())\n        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max\n        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()\n        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)\n        dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())\n        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)\n        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)\n        dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())\n        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits\n        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)\n        return entropy.squeeze(dim=-1)\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:\n        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors\n        # reuse softmax_logits as grad\n        vocab_parallel_logits.sub_(sum_softmax_times_logits)\n        softmax_logits.mul_(vocab_parallel_logits)\n        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))\n        # recover vocab_parallel_logits\n        vocab_parallel_logits.add_(sum_softmax_times_logits)\n        softmax_logits.mul_(-1)\n        return softmax_logits\n\n\ndef vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute entropy when the logits are sharded in tp ranks\n\n    Args:\n        vocab_parallel_logits: (total_nnz, vocab_size // tp_size)\n\n    Returns: (total_nnz,)\n\n    \"\"\"\n    return _VocabParallelEntropy.apply(vocab_parallel_logits)\n\n\ndef vocab_parallel_log_probs_from_logits(logits, labels):\n    \"\"\"TODO(zhangchi.usc1992): We may change the implementation later\"\"\"\n    from megatron.core import tensor_parallel\n\n    return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)\n\n\ndef vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now split across the tensor parallel\n    region.\n    This will further reduce the peak memory usage during training.\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size // tp_size]\n        response_length: int\n\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(\n        logits=logits_rmpad, labels=input_ids_rmpad_rolled\n    )  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n"
  },
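A single-rank sanity check of the identity `_VocabParallelEntropy` is built on, H(p) = logsumexp(z) - Σ_j p_j z_j, which is what lets the sharded implementation compute entropy from three all-reduced per-token statistics instead of materializing log-probabilities:

```python
import torch

z = torch.randn(5, 32000)                      # unsharded logits
p = torch.softmax(z, dim=-1)

entropy_ref = -(p * p.log()).sum(-1)           # definition: -sum p log p
entropy_id = torch.logsumexp(z, dim=-1) - (p * z).sum(-1)
torch.testing.assert_close(entropy_ref, entropy_id)
```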
  {
    "path": "verl_distillation/verl/utils/megatron_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pretrain utilities.\"\"\"\n\nimport gc\nimport inspect\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel\nfrom megatron.core.distributed import DistributedDataParallel as DDP\nfrom megatron.core.distributed import DistributedDataParallelConfig\nfrom megatron.core.enums import ModelType\nfrom megatron.core.optimizer import ChainedOptimizer\nfrom megatron.core.transformer import TransformerConfig\nfrom megatron.core.transformer.module import Float16Module\nfrom megatron.core.utils import get_attr_wrapped_model\nfrom transformers import PretrainedConfig\n\nimport verl.utils.megatron.tensor_parallel as tp_utils\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.fs import local_mkdir_safe\nfrom verl.utils.model import normalize_model_name\nfrom verl.utils.torch_dtypes import PrecisionType\n\n\ndef get_model_config(model):\n    return get_attr_wrapped_model(model, \"config\", allow_none=False)\n\n\ndef get_model(\n    model_provider_func,\n    model_type=ModelType.encoder_or_decoder,\n    wrap_with_ddp=True,\n    use_distributed_optimizer=True,\n    transformer_config=None,\n    override_ddp_config=None,\n):\n    \"\"\"Build the model.\"\"\"\n    # Build model.\n    if (\n        mpu.get_pipeline_model_parallel_world_size() > 1\n        and mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n    ):\n        assert model_type != ModelType.encoder_and_decoder, (\n            \"Interleaved schedule not supported for model with both encoder and decoder\"\n        )\n        model = []\n        has_vp_stage = inspect.signature(mpu.is_pipeline_first_stage).parameters.get(\"vp_stage\", None) is not None\n        for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):\n            mpu.set_virtual_pipeline_model_parallel_rank(i)\n            # Set pre_process and post_process only after virtual rank is set.\n            extra_kwargs = {} if not has_vp_stage else {\"ignore_virtual\": False, \"vp_stage\": i}\n            pre_process = mpu.is_pipeline_first_stage(**extra_kwargs)\n            post_process = mpu.is_pipeline_last_stage(**extra_kwargs)\n            this_model = model_provider_func(pre_process=pre_process, post_process=post_process, vp_stage=i)\n            this_model.model_type = model_type\n            model.append(this_model)\n        mpu.set_virtual_pipeline_model_parallel_rank(0)\n    else:\n        pre_process = mpu.is_pipeline_first_stage()\n        post_process = mpu.is_pipeline_last_stage()\n        add_encoder = True\n        add_decoder = True\n        if model_type 
== ModelType.encoder_and_decoder:\n            if mpu.get_pipeline_model_parallel_world_size() > 1:\n                assert mpu.get_pipeline_model_parallel_split_rank() is not None, (\n                    \"Split rank needs to be specified for model with both encoder and decoder\"\n                )\n                rank = mpu.get_pipeline_model_parallel_rank()\n                split_rank = mpu.get_pipeline_model_parallel_split_rank()\n                world_size = mpu.get_pipeline_model_parallel_world_size()\n                pre_process = rank == 0 or rank == split_rank\n                post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))\n                add_encoder = mpu.is_pipeline_stage_before_split()\n                add_decoder = mpu.is_pipeline_stage_after_split()\n            model = model_provider_func(\n                pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder\n            )\n        else:\n            model = model_provider_func(pre_process=pre_process, post_process=post_process)\n        model.model_type = model_type\n\n    if not isinstance(model, list):\n        model = [model]\n\n    # Set tensor model parallel attributes if not set.\n    # Only parameters that are already tensor model parallel have these\n    # attributes set for them. We should make sure the default attributes\n    # are set for all params so the optimizer can use them.\n    for model_module in model:\n        for param in model_module.parameters():\n            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)\n\n    # Print number of parameters.\n    if mpu.get_data_parallel_rank() == 0:\n        print(\n            \" > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}\".format(\n                mpu.get_tensor_model_parallel_rank(),\n                mpu.get_pipeline_model_parallel_rank(),\n                sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),\n            ),\n            flush=True,\n        )\n\n    # GPU allocation.\n    if transformer_config is None or (not transformer_config.use_cpu_initialization):\n        for model_module in model:\n            model_module.to(f\"{get_device_name()}:{get_device_id()}\")\n\n    # Fp16 conversion.\n    config: TransformerConfig = get_model_config(model[0])\n    config.fp8 = None\n    tfconfig: TransformerConfig = model[0].config\n    if config.fp16 or config.bf16:  # the ModelParallelConfig in GPTModel\n        model = [Float16Module(config, model_module) for model_module in model]\n\n    if wrap_with_ddp:\n        ddp_models = []\n        ddp_config_dict = {\n            \"use_distributed_optimizer\": use_distributed_optimizer,\n            \"grad_reduce_in_fp32\": True,\n            \"overlap_grad_reduce\": False,\n        }\n        if override_ddp_config is not None:\n            ddp_config_dict.update(override_ddp_config)\n        ddp_config = DistributedDataParallelConfig(**ddp_config_dict)\n        for model_chunk_idx, model_chunk in enumerate(model):\n            ddp_model = DDP(\n                config=tfconfig,\n                module=model_chunk,\n                disable_bucketing=(model_chunk_idx > 0),\n                ddp_config=ddp_config,\n            )\n            ddp_models.append(ddp_model)\n        model = ddp_models\n        # Broadcast params from data parallel src rank to other data parallel ranks.\n        for 
model_module in model:\n            model_module.broadcast_params()\n    return model\n\n\n@dataclass\nclass McoreModuleWrapperConfig:\n    \"\"\"Configuration for Mcore module wrapper.\"\"\"\n\n    is_value_model: bool = False\n    share_embeddings_and_output_weights: bool = False\n    wrap_with_ddp: bool = True\n    use_distributed_optimizer: bool = True\n\n\ndef make_megatron_module(\n    wrap_config: McoreModuleWrapperConfig,\n    tf_config: TransformerConfig,\n    hf_config: PretrainedConfig,\n    bridge: Any = None,\n    override_model_config: dict[str, Any] = None,\n    override_ddp_config: dict[str, Any] = None,\n):\n    if override_model_config is None:\n        override_model_config = {}\n\n    if bridge is not None:\n        from verl.models.mcore.mbridge import freeze_moe_router, make_value_model\n\n        post_model_creation_callbacks = []\n        if wrap_config.is_value_model:\n            post_model_creation_callbacks.append(make_value_model)\n        if override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False):\n            post_model_creation_callbacks.append(freeze_moe_router)\n        return bridge.get_model(\n            post_model_creation_callbacks=post_model_creation_callbacks,\n            wrap_with_ddp=wrap_config.wrap_with_ddp,\n        )\n    else:\n\n        def megatron_model_provider(pre_process, post_process, vp_stage=None):\n            from verl.models.mcore import init_mcore_model\n\n            parallel_model = init_mcore_model(\n                tf_config,\n                hf_config,\n                pre_process,\n                post_process,\n                share_embeddings_and_output_weights=wrap_config.share_embeddings_and_output_weights,\n                value=wrap_config.is_value_model,\n                freeze_moe_router=override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False),\n                vp_stage=vp_stage,\n            )\n            parallel_model.to(get_device_name())\n            return parallel_model\n\n        return get_model(\n            megatron_model_provider,\n            wrap_with_ddp=wrap_config.wrap_with_ddp,\n            use_distributed_optimizer=wrap_config.use_distributed_optimizer,\n            override_ddp_config=override_ddp_config,\n        )\n\n\nALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)\n\n\ndef unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):\n    return_list = True\n    if not isinstance(model, list):\n        model = [model]\n        return_list = False\n    unwrapped_model = []\n    for model_module in model:\n        while isinstance(model_module, module_instances):\n            model_module = model_module.module\n        unwrapped_model.append(model_module)\n    if not return_list:\n        return unwrapped_model[0]\n    return unwrapped_model\n\n\ndef convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:\n    \"\"\"[Deprecated] Convert a HuggingFace model config into a Megatron ``TransformerConfig``.\n\n    Args:\n        hf_config (PretrainedConfig): HuggingFace model configuration.\n        megatron_config: Megatron settings providing ``params_dtype`` and parallelism sizes.\n\n    Returns:\n        TransformerConfig: the converted transformer configuration.\n    \"\"\"\n\n    warnings.warn(\"[deprecated] use config converter for more model support\", stacklevel=2)\n    print(f\"megatron config {megatron_config}\")\n    dt = PrecisionType.to_dtype(megatron_config.params_dtype)\n    print(f\"pipeline_dtype={dt}\")\n    qkv_bias = True if \"Qwen2ForCausalLM\" in hf_config.architectures else getattr(hf_config, \"attention_bias\", False)\n   
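 # NOTE: Qwen2 models use QKV bias but do not expose attention_bias in their HF config, so they are special-cased here; other architectures fall back to attention_bias.\n   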
 overlap_p2p_comm = (\n        mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n        and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    )\n    batch_p2p_comm = False\n    transformer_config = TransformerConfig(\n        num_layers=hf_config.num_hidden_layers,\n        hidden_size=hf_config.hidden_size,\n        num_attention_heads=hf_config.num_attention_heads,\n        num_query_groups=hf_config.num_key_value_heads,\n        ffn_hidden_size=hf_config.intermediate_size,\n        #    max_position_embeddings=hf_config.max_position_embeddings,\n        activation_func=F.silu,\n        normalization=\"RMSNorm\",\n        #    rotary_percent=False, # default,\n        gated_linear_unit=True,  # for llama\n        use_cpu_initialization=True,\n        apply_residual_connection_post_layernorm=False,  # TODO: check what this means\n        add_bias_linear=False,\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        overlap_p2p_comm=overlap_p2p_comm,\n        batch_p2p_comm=batch_p2p_comm,\n        pipeline_dtype=dt,\n        params_dtype=dt,\n        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,\n        variable_seq_lengths=True,\n        masked_softmax_fusion=True,\n        moe_token_dispatcher_type=\"alltoall\",\n        attention_dropout=hf_config.attention_dropout,\n        hidden_dropout=getattr(hf_config, \"hidden_dropout\", 0.0),\n        add_qkv_bias=qkv_bias,\n        bf16=dt is torch.bfloat16,\n    )\n\n    return transformer_config\n\n\ndef mcore_model_parallel_config(\n    sequence_parallel: bool,\n    params_dtype: torch.dtype,\n) -> ModelParallelConfig:\n    # WARNING: Code should not reach this point. This function is deprecated and will be removed.\n    # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.\n    warnings.warn(\n        \"Code should not reach this point. This function is deprecated and will be removed. 
Please use \"\n        \"hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n    return ModelParallelConfig(\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        sequence_parallel=sequence_parallel,\n        params_dtype=params_dtype,\n        pipeline_dtype=params_dtype,\n        bf16=True,\n        fp16=False,\n        timers=None,\n    )\n\n\n@torch.no_grad()\ndef offload_megatron_model_to_cpu(models):\n    \"\"\"\n    In megatron, the model and optimizer storage are:\n    - bf16 parameter data chunked in model parallel group\n    - fp32 grad chunked in model parallel group\n    - fp32 main_parameter chunked in model and dp group\n    - fp32 optimizer state chunked in model and dp group\n    \"\"\"\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # offload parameters\n                    if buffer.param_data.storage().size() > 0:\n                        buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory()\n                        buffer.param_data_size = buffer.param_data.storage().size()\n                        buffer.param_data.storage().resize_(0)\n\n                    assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size()\n\n                    if buffer.grad_data.storage().size() > 0:\n                        # if the grad_data size is already zero, we assume that it is already offloaded\n                        buffer.grad_data_size = buffer.grad_data.storage().size()\n                        buffer.grad_data.storage().resize_(0)\n        else:\n            # we need this for ref module\n            for _, param in model_chunk.named_parameters():\n                param.data = param.data.to(\"cpu\", non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(\"cpu\", non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_megatron_model_to_gpu(models, load_grad=True):\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # sometimes, we don't want to load grad for pure inference\n                    if load_grad:\n                        buffer.grad_data.storage().resize_(buffer.grad_data_size)\n                        buffer.grad_data.zero_()\n\n                    if buffer.param_data.storage().size() == 0:\n                        buffer.param_data.storage().resize_(buffer.param_data_size)\n                        # copy data from cpu to cuda\n                        buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)\n        else:\n            # we need this for ref module\n            device_id = get_device_id()\n            for _, param in model_chunk.named_parameters():\n             
   param.data = param.data.to(device_id, non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(device_id, non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_megatron_copy_params(optimizers):\n    \"\"\"\n    Offload optimizer parameters to CPU. Supports both Megatron optimizers\n    and `ChainedOptimizer`, which wraps a list of underlying optimizers.\n\n    Args:\n        optimizers: The optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def offload_tensor_to_cpu(tensor):\n        if tensor is None:\n            return\n        tensor.data = tensor.data.to(\"cpu\", non_blocking=True)\n\n    def offload_group_to_cpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        offload_tensor_to_cpu(param)\n                else:\n                    offload_tensor_to_cpu(param_group)\n        else:\n            offload_tensor_to_cpu(group)\n\n    # Offload all parameter groups to CPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef load_megatron_copy_params(optimizers):\n    \"\"\"\n    Load optimizer parameters back to GPU. Handles ChainedOptimizer.\n\n    Args:\n        optimizers: Optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def load_tensor_to_gpu(tensor):\n        if tensor is None:\n            return\n        device_id = get_device_id()\n        tensor.data = tensor.data.to(device_id, non_blocking=True)\n\n    def load_group_to_gpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        load_tensor_to_gpu(param)\n                else:\n                    load_tensor_to_gpu(param_group)\n        else:\n            load_tensor_to_gpu(group)\n\n    # Load all parameter groups to GPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            load_group_to_gpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef offload_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        offload_megatron_copy_params(_opt)\n        ## worker may hold zero parameter when enabling custom pipeline layout\n        if _opt.optimizer is not None:\n            opt_state_dict_values = _opt.optimizer.state.values()\n            for v in opt_state_dict_values:\n                if \"exp_avg\" in v:\n                    v[\"exp_avg\"] = v[\"exp_avg\"].to(\"cpu\", non_blocking=True)\n                if \"exp_avg_sq\" in v:\n                    v[\"exp_avg_sq\"] = 
v[\"exp_avg_sq\"].to(\"cpu\", non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        load_megatron_copy_params(_opt)\n        ## worker may hold zero parameter when enabling custom pipeline layout\n        if _opt.optimizer is not None:\n            # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu\n            if hasattr(_opt.optimizer, \"_move_new_state_to_right_device\"):\n                _opt.optimizer._move_new_state_to_right_device()\n            else:\n                opt_state_dict_values = _opt.optimizer.state.values()\n                for v in opt_state_dict_values:\n                    if \"exp_avg\" in v:\n                        v[\"exp_avg\"] = v[\"exp_avg\"].to(get_device_id(), non_blocking=True)\n                    if \"exp_avg_sq\" in v:\n                        v[\"exp_avg_sq\"] = v[\"exp_avg_sq\"].to(get_device_id(), non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\ndef get_dist_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"dist_ckpt\"))\n    return os.path.join(checkpoint_path, \"dist_ckpt\")\n\n\ndef get_hf_model_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"huggingface\"))\n    return os.path.join(checkpoint_path, \"huggingface\")\n\n\ndef get_transformer_config_checkpoint_path(checkpoint_path):\n    os.makedirs(checkpoint_path, exist_ok=True)\n    return os.path.join(checkpoint_path, \"transformer_config.json\")\n\n\ndef convert_megatron_model_to_transformers_model(\n    name,\n    param,\n    config: PretrainedConfig,\n    tp_size: int,\n    num_query_groups: int,\n    convert_qkv_gate_up_by_trunk_concat=False,\n):\n    \"\"\"Convert megatron model to transformers model.\"\"\"\n    new_params = {}\n\n    def convert_qkv_shard(full_tensor, q_name, k_name, v_name):\n        nonlocal config\n        nonlocal tp_size\n        nonlocal num_query_groups\n\n        q_shard_list = []\n        k_shard_list = []\n        v_shard_list = []\n        hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    
k_shard_list.append(k_part)\n                    v_shard_list.append(v_part)\n        else:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_shard_list.append(k_part)\n                        v_shard_list.append(v_part)\n\n        new_params[q_name] = torch.cat(q_shard_list, dim=0)\n        new_params[k_name] = torch.cat(k_shard_list, dim=0)\n        new_params[v_name] = torch.cat(v_shard_list, dim=0)\n\n    def convert_gate_up_shard(full_tensor, gate_name, up_name):\n        nonlocal config\n        nonlocal tp_size\n\n        intermediate_size_tp = config.intermediate_size // tp_size\n        gate_weight_list = []\n        up_weight_list = []\n        for i in range(tp_size):\n            gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n            gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n            up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n            gate_weight_list.append(gate_weight_tp)\n            up_weight_list.append(up_weight_tp)\n\n        new_params[gate_name] = torch.cat(gate_weight_list, dim=0)\n        new_params[up_name] = torch.cat(up_weight_list, dim=0)\n\n    if name == \"embedding.word_embeddings.weight\":\n        new_params[\"model.embed_tokens.weight\"] = param\n    elif \"self_attention\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_proj\":\n            new_params[f\"model.layers.{layer_number}.self_attn.o_proj.weight\"] = param\n        elif component == \"linear_qkv\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.input_layernorm.weight\"] = param\n            else:\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_qkv_shard(\n                        param,\n                        f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}\"] = param\n        elif component == \"q_layernorm\" or component == \"k_layernorm\":\n            hf_component = component.replace(\"layer\", \"\")\n            
new_params[f\"model.layers.{layer_number}.self_attn.{hf_component}.weight\"] = param\n        else:\n            assert isinstance(param, list) and len(param) == 3\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\"] = param[1]\n            new_params[f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\"] = param[2]\n    elif \"mlp\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_fc1\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.post_attention_layernorm.weight\"] = param\n            elif param_type == \"weight\":\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_gate_up_shard(\n                        param,\n                        f\"model.layers.{layer_number}.mlp.gate_proj.weight\",\n                        f\"model.layers.{layer_number}.mlp.up_proj.weight\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.mlp.gate_up_proj.weight\"] = param\n        elif component == \"linear_fc1\" and isinstance(param, list):\n            assert len(param) == 2\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.mlp.gate_proj.weight\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.mlp.up_proj.weight\"] = param[1]\n        elif component == \"linear_fc2\":\n            new_params[f\"model.layers.{layer_number}.mlp.down_proj.weight\"] = param\n    elif name == \"decoder.final_layernorm.weight\":\n        new_params[\"model.norm.weight\"] = param\n    elif name == \"output_layer.weight\":\n        new_params[\"lm_head.weight\"] = param\n    else:\n        raise ValueError(f\"Unknown param name: {name}\")\n    return new_params.keys(), new_params.values()\n\n\ndef broadcast_from_megatron_pp(tensor: torch.Tensor):\n    # tensor is not None only in one of the pp ranks\n    if tensor is not None:\n        shape = tensor.shape\n        dtype = tensor.dtype\n        tensor_parallel = getattr(tensor, \"tensor_model_parallel\", None)\n        partition_dim = getattr(tensor, \"partition_dim\", None)\n        tensor_spec = (shape, dtype, tensor_parallel, partition_dim)\n    else:\n        tensor_spec = None\n    tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(\n        object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group()\n    )\n    # find the src rank\n    target_tensor_spec = None\n    src_rank = None\n    for rank, tensor_spec in enumerate(tensor_spec_output):\n        if tensor_spec is not None:\n            if target_tensor_spec is None:\n                target_tensor_spec = tensor_spec\n            else:\n                raise ValueError(\"A tensor exists on two pp ranks\")\n            src_rank = rank\n    assert target_tensor_spec is not None\n    if tensor is None:\n        tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id())\n        if target_tensor_spec[2] is not None:\n            
tensor.tensor_model_parallel = target_tensor_spec[2]\n        if target_tensor_spec[3] is not None:\n            tensor.partition_dim = target_tensor_spec[3]\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n    torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group())\n    return tensor\n\n\ndef broadcast_str_from_megatron_pp(obj: Any):\n    obj_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group())\n\n    src_rank = None\n    target_obj = None\n    for rank, item in enumerate(obj_output):\n        if item is not None:\n            if target_obj is not None:\n                raise ValueError(\"An object exists on two pp ranks\")\n            target_obj = item\n            src_rank = rank\n\n    assert target_obj is not None, \"No valid object found to broadcast.\"\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n\n    obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group())\n    obj_output[0] = target_obj\n    torch.distributed.broadcast_object_list(\n        object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group()\n    )\n\n    return obj_output[0]\n\n\ndef default_tp_concat_fn(\n    layer_name_mapping,\n    name,\n    train_params,\n    infer_params,\n    model_config,\n    hf_config=None,\n    convert_qkv_gate_up_by_simple_split=False,\n):\n    \"\"\"\n    name: name of the parameter\n    train_params: training parameters\n    infer_params (Iterable[torch.Tensor]): an iterable of parameters all-gathered from the micro DP group\n    model_config: huggingface model_config\n    TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model\n    definition so that it is model-agnostic. If the model doesn't implement this function,\n    we can throw an error to force the user to disable the TP HybridEngine.\n    \"\"\"\n    from megatron.core import mpu\n\n    train_tp_size = mpu.get_tensor_model_parallel_world_size()\n    if layer_name_mapping.get(\"qkv_layer_name\") in name and \"layer_norm\" not in name:\n        # if the tensor is qkv, for each param on tp, split into q, k, v\n        # concat q, k, v separately.\n        q_lst = []\n        k_lst = []\n        v_lst = []\n        num_attention_heads = model_config.num_attention_heads\n        num_key_value_heads = model_config.num_key_value_heads\n        if \"vision_model\" in name:\n            num_attention_heads = hf_config.vision_config.num_heads\n            num_key_value_heads = hf_config.vision_config.num_heads\n        assert num_attention_heads % num_key_value_heads == 0\n        num_q_per_kv = num_attention_heads // num_key_value_heads\n        assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, (\n            f\"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}\"\n        )\n        kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)\n        for infer_param in infer_params:\n            num_query_groups_per_partition = num_key_value_heads // train_tp_size\n            for chunk in infer_param.chunk(num_query_groups_per_partition):\n                split_size = [\n                    kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,\n                    kv_size_per_tp // num_query_groups_per_partition,\n                    kv_size_per_tp // num_query_groups_per_partition,\n                ]\n                q, k, v = chunk.split(split_size)\n                q_lst.append(q)\n                k_lst.append(k)\n                v_lst.append(v)\n        q = torch.cat(q_lst, dim=0)\n        k = torch.cat(k_lst, dim=0)\n        v = torch.cat(v_lst, dim=0)\n        infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v]\n\n    elif (\n        layer_name_mapping.get(\"gate_proj_layer_name\") in name\n        and \"layer_norm\" not in name\n        and \"vision_model.projection\" not in name\n    ):\n        # if the tensor is the fused gate/up projection, split each TP shard into gate and up\n        gate_lst = []\n        up_lst = []\n        for infer_param in infer_params:\n            gate, up = infer_param.chunk(2)\n            gate_lst.append(gate)\n            up_lst.append(up)\n        gate = torch.cat(gate_lst, dim=0)\n        up = torch.cat(up_lst, dim=0)\n        infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up]\n\n    elif \"mlp.experts.linear_fc2.weight\" in name:  # moe\n        infer_params = torch.cat(infer_params, dim=1)\n\n    else:\n        # concat tensor\n        infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params))\n\n    return infer_params\n\n\ndef per_tensor_generator(\n    actor_module,\n    model_config,\n    weight_converter,\n    transformer_config,\n    layer_name_mapping,\n    convert_qkv_gate_up_by_simple_split=True,\n):\n    from megatron.core import parallel_state as mpu\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    ep_size = mpu.get_expert_model_parallel_world_size()\n    etp_size = mpu.get_expert_tensor_parallel_world_size()\n    ep_group = mpu.get_expert_model_parallel_group()\n    etp_group = 
mpu.get_expert_tensor_parallel_group()\n    vpp_size = len(actor_module)\n    all_gather_group = mpu.get_tensor_model_parallel_group()\n    all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)\n\n    def tensor_generator():\n        for scan_vpp_idx in range(vpp_size):\n            existing_keys = set()\n            model = unwrap_model(actor_module[scan_vpp_idx])\n            for name, param in model.named_parameters():\n                existing_keys.add(name)\n                yield name, param\n            # NOTE: there is a bug in megatron GPTModel:\n            # \"decoder.layers[n].mlp.router.expert_bias\" is not registered in named_parameters(), but it does\n            # appear in state_dict(). For now we patch it by also yielding those extra keys.\n            extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n            for name in extra_keys:\n                yield name, model.state_dict()[name].to(get_device_id())\n\n    # first, make every rank aware of the full model parameter layout\n    meta_info = []\n    for scan_vpp_idx in range(vpp_size):\n        existing_keys = set()\n        model = unwrap_model(actor_module[scan_vpp_idx])\n        for idx, (name, _) in enumerate(model.named_parameters()):\n            existing_keys.add(name)\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n        extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n        for name in extra_keys:\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n\n    obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(\n        object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()\n    )\n    layer_list_meta = [item for sublist in obj_spec_output for item in sublist]\n\n    gen_func = tensor_generator()\n\n    # lazily load each tensor of the full model\n    for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:\n        if model_config.tie_word_embeddings and (\"output_layers\" in name):\n            warnings.warn(\n                \"Model ties word embedding and output weights; skipping output layer conversion\", stacklevel=2\n            )\n            continue\n\n        if cur_pp_rank == pp_rank:\n            try:\n                cur_name, cur_tensor = next(gen_func)\n            except StopIteration:\n                cur_name, cur_tensor = None, None\n            cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config)\n        else:\n            cur_tensor, cur_name = None, None\n\n        # pp broadcast model tensor and name\n        cur_name = broadcast_str_from_megatron_pp(cur_name)\n        broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor)\n\n        # (xya): this is a hack to fix the name of the parameters\n        while cur_name.startswith(\"module.\"):\n            cur_name = cur_name[len(\"module.\") :]\n\n        # EP\n        if \".mlp.experts.linear_fc\" in cur_name and ep_size > 1:\n            num_experts = weight_converter.mcore_config.num_moe_experts\n            num_experts_per_rank = num_experts // ep_size\n            infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)]\n            torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group)\n\n            name_prefix, local_expert_id = cur_name.split(\".weight\")\n          
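  # expert weight names end with \".weight<local_expert_id>\"; recover the integer suffix\n          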
  local_expert_id = int(local_expert_id)\n            global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)]\n            global_expert_names = [f\"{name_prefix}.weight{expert_id}\" for expert_id in global_expert_ids]\n\n            for name, param in zip(global_expert_names, infer_params, strict=True):\n                if etp_size > 1:\n                    # gather etp\n                    etp_params = [torch.empty_like(param) for _ in range(etp_size)]\n                    torch.distributed.all_gather(etp_params, param, group=etp_group)\n                    params = etp_params\n                else:\n                    params = [param]\n\n                merge_params = default_tp_concat_fn(\n                    layer_name_mapping,\n                    name,\n                    broad_pp_tensor,\n                    params,\n                    model_config,\n                    weight_converter.hf_config,\n                    convert_qkv_gate_up_by_simple_split,\n                )\n                if not isinstance(merge_params, list):\n                    merge_params = [merge_params]\n                converted_names, converted_params = weight_converter.convert_param(name, merge_params)\n\n                yield from zip(converted_names, [param.detach() for param in converted_params], strict=True)\n            continue\n\n        # tp all gather\n        if tp_utils.is_tensor_parallel_param(broad_pp_tensor):\n            # allocate a new tensor with proper size\n            if all_gather_group_size <= 1:\n                infer_params = [broad_pp_tensor]\n            else:\n                infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)]\n                torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group())\n            infer_params = default_tp_concat_fn(\n                layer_name_mapping,\n                cur_name,\n                broad_pp_tensor,\n                infer_params,\n                model_config,\n                weight_converter.hf_config,\n                convert_qkv_gate_up_by_simple_split,\n            )\n        else:\n            infer_params = broad_pp_tensor\n\n        if not isinstance(infer_params, list):\n            infer_params = [infer_params]\n        converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)\n\n        yield from zip(converted_names, [param.detach() for param in converted_params], strict=True)\n\n\ndef get_transformer_layer_offset(pipeline_rank, vp_stage, config: TransformerConfig):\n    \"\"\"\n    Get the index offset of any pipeline stage, given the level of pipelining.\n\n    Make pipeline_rank and vp_stage as two arguments to make it more flexible,\n    which is able to fetch layer offset for any pipeline stage.\n    The original function only returns the layer offset for current pipeline stage.\n\n    Extension to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset\n    \"\"\"\n\n    has_vp_stage = (\n        inspect.signature(parallel_state.is_pipeline_first_stage).parameters.get(\"vp_stage\", None) is not None\n    )\n    extra_kwargs = {} if not has_vp_stage else {\"ignore_virtual\": False, \"vp_stage\": vp_stage}\n\n    if config.pipeline_model_parallel_size > 1:\n        if hasattr(config, \"pipeline_model_parallel_layout\") and config.pipeline_model_parallel_layout:\n            from 
megatron.core.transformer.enums import LayerType\n\n            offset = config.pipeline_model_parallel_layout.get_layer_offset(\n                layer_type=LayerType.decoder, vp_stage=vp_stage\n            )\n        elif (\n            config.num_layers_in_first_pipeline_stage is not None\n            or config.num_layers_in_last_pipeline_stage is not None\n        ):\n            # Calculate number of pipeline stages to distribute the remaining Transformer\n            # layers after deducting the Transformer layers in the first or the last stages\n            middle_pipeline_stages = config.pipeline_model_parallel_size\n            middle_pipeline_stages -= sum(\n                [\n                    1 if x is not None else 0\n                    for x in (\n                        config.num_layers_in_first_pipeline_stage,\n                        config.num_layers_in_last_pipeline_stage,\n                    )\n                ]\n            )\n\n            # Calculate layers to distribute in each pipeline stage. If the\n            # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage\n            # are not set, we will not enable uneven pipeline. All layers will be treated\n            # as middle layers.\n            num_layers_in_first_pipeline_stage = (\n                0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage\n            )\n            num_layers_in_last_pipeline_stage = (\n                0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage\n            )\n\n            middle_num_layers = (\n                config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage\n            )\n\n            if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:\n                assert vp_stage is not None, \"vp_stage must be provided if virtual pipeline model parallel size is set\"\n\n                # Calculate number of layers in each virtual model chunk\n                # If the num_layers_in_first_pipeline_stage and\n                # num_layers_in_last_pipeline_stage are not set, all pipeline stages\n                # will be treated as middle pipeline stages in the calculation\n                num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (\n                    0\n                    if config.num_layers_in_first_pipeline_stage is None\n                    else config.num_layers_in_first_pipeline_stage // vp_size\n                )\n\n                num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (\n                    0\n                    if config.num_layers_in_last_pipeline_stage is None\n                    else config.num_layers_in_last_pipeline_stage // vp_size\n                )\n\n                num_layers_per_virtual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size\n\n                # First stage + middle stage + last stage\n                total_virtual_chunks = (\n                    num_layers_per_virtual_model_chunk_in_first_pipeline_stage\n                    + num_layers_per_virtual_model_chunk_in_middle_pipeline_stage\n                    + num_layers_per_virtual_model_chunk_in_last_pipeline_stage\n                )\n\n                # Calculate the layer offset with interleaved uneven pipeline parallelism\n                if pipeline_rank == 0:\n                    offset = vp_stage * total_virtual_chunks\n                else:\n        
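            # interleaved uneven PP: offset past earlier virtual chunks, the first stage's layers, and preceding middle-stage layers\n        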
            offset = (\n                        vp_stage * total_virtual_chunks\n                        + num_layers_per_virtual_model_chunk_in_first_pipeline_stage\n                        + (pipeline_rank - 1)\n                        * (num_layers_per_virtual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages)\n                    )\n            else:\n                if middle_pipeline_stages > 0:\n                    num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages\n                else:\n                    num_layers_per_pipeline_rank = 0\n\n                middle_pipeline_rank = (\n                    pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1\n                )\n\n                if pipeline_rank == 0:\n                    offset = 0\n                else:\n                    offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage\n        else:\n            num_layers = config.num_layers\n\n            # Increase the number of layers by one if we include the embedding (or loss)\n            # layer into the pipeline parallelism partition and placement\n            if config.account_for_embedding_in_pipeline_split:\n                num_layers += 1\n\n            if config.account_for_loss_in_pipeline_split:\n                num_layers += 1\n\n            num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size\n\n            if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:\n                assert vp_stage is not None, \"vp_stage must be provided if virtual pipeline model parallel size is set\"\n\n                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size\n                total_virtual_chunks = num_layers // vp_size\n                offset = vp_stage * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage(\n                    **extra_kwargs\n                ):\n                    offset -= 1\n            else:\n                offset = pipeline_rank * num_layers_per_pipeline_rank\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage(\n                    **extra_kwargs\n                ):\n                    offset -= 1\n    else:\n        offset = 0\n    return offset\n"
  },
  {
    "path": "verl_distillation/verl/utils/memory_buffer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains utilities to manipulate torch memory buffers\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nfrom torch import nn\n\nfrom verl.utils.device import get_device_name\n\n\nclass MemoryBuffer:\n    \"\"\"\n    A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying\n    memory. It must have a unique type to support this behavior.\n    \"\"\"\n\n    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        if source is not None:\n            self.data = source\n        else:\n            self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n\n\ndef calc_padded_numel(shape: torch.Size, dtype: torch.dtype):\n    \"\"\"for cuda memory alignment, make sure alignment by 128-bits\"\"\"\n    align_numel = 128 // torch.finfo(dtype).bits\n    numel = shape.numel()\n    return (numel + align_numel - 1) // align_numel * align_numel\n\n\ndef get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]:\n    \"\"\"\n    Return a dictionary containing name to a shape and dtype.\n    \"\"\"\n    weight_buffer_meta = {}\n    for name, param in sorted(module.named_parameters()):\n        weight_buffer_meta[name] = {\"shape\": param.shape, \"dtype\": param.dtype}\n    return weight_buffer_meta\n\n\ndef build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]:\n    \"\"\"Build the memory buffer given weight_buffer_meta\n\n    Args:\n        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors\n\n    Returns: a large memory buffer for each dtype that can hold all the tensors\n\n    \"\"\"\n    memory_buffers = {}\n    total_numel_map = {}  # map from dtype to the total numel\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        assert isinstance(shape, torch.Size)\n        assert isinstance(dtype, torch.dtype)\n\n        if dtype not in total_numel_map:\n            total_numel_map[dtype] = 0\n\n        total_numel_map[dtype] += calc_padded_numel(shape, dtype)\n\n    for dtype, 
total_numel in total_numel_map.items():\n        memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)\n\n    return memory_buffers\n\n\ndef build_memory_reference_from_module(\n    module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True\n):\n    start_index = {}\n    for dtype in memory_buffers:\n        start_index[dtype] = 0\n    for name, param in sorted(module.named_parameters()):\n        memory_buffer = memory_buffers[param.dtype]\n        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])\n        # need to increment start_index\n        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)\n        if maintain_weight:\n            buffer.copy_(param.data)\n        param.data = buffer\n\n\ndef build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]):\n    \"\"\"Build the memory references. The memory buffers are built using the build_memory_buffer API.\n    This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.\n\n    Args:\n        weight_buffer_meta: mapping from parameter name to a dictionary containing its shape and dtype\n        memory_buffers: per-dtype memory buffers built by the build_memory_buffer API\n\n    Returns:\n        dict mapping each parameter name to a tensor view into the corresponding memory buffer\n\n    \"\"\"\n    start_idx = {}\n    weight_buffers = {}\n    for dtype in memory_buffers:\n        start_idx[dtype] = 0\n\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])\n        start_idx[dtype] += calc_padded_numel(shape, dtype)\n        weight_buffers[name] = buffer\n\n    return weight_buffers\n\n\nclass MemoryBufferModuleWrapper:\n    \"\"\"\n    Note that MemoryBufferModuleWrapper is deliberately not an nn.Module:\n    - wrapping the module would change parameter names in checkpoints\n    \"\"\"\n\n    def __init__(self, module: nn.Module):\n        super().__init__()\n        self.module = module\n        self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)\n        self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)\n        build_memory_reference_from_module(self.module, self.memory_buffers)\n\n    def get_memory_buffers(self):\n        return self.memory_buffers\n\n    def get_weight_buffer_meta(self):\n        return self.weight_buffer_meta\n\n\nclass MegatronMemoryBufferForRollout:\n    \"\"\"\n    We assume that\n    - inference engine has tp + dp\n    - actor has tp + pp + dp\n    - the tp between inference engine and actor should be the same\n    - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer\n    - weight_buffers: contains a list of weight_buffers, each is a dict from name to param\n    - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that\n        the named_parameters may not be directly compatible with the inference engine. The user has to take\n        care of layout mismatches (e.g. QKV transpose).\n    - Note that weight_buffers, named_parameters and memory_buffers share the same underlying GPU memory.\n    - When doing weight sync, the data is transferred via the memory buffers.\n    \"\"\"\n\n    def __init__(self, transform_memory_param_fn):\n        self._memory_buffers = []\n        self._weight_buffers = []\n        self._named_parameters = {}\n        self.transform_memory_param_fn = transform_memory_param_fn\n\n    def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]):\n        \"\"\"\n        Initialize the weight buffers. The layout is derived from the actor: for each pp model chunk we\n        construct one large buffer per dtype in the weight buffer.\n\n        Args:\n            weight_buffer_meta_pp: one entry per pp model chunk; each entry maps parameter names to their\n                shape and dtype\n\n        Returns: None\n\n        \"\"\"\n        self.weight_buffer_meta_pp = weight_buffer_meta_pp\n\n        for weight_buffer_meta in self.weight_buffer_meta_pp:\n            memory_buffer = build_memory_buffer(weight_buffer_meta)\n            self._memory_buffers.append(memory_buffer)\n            self._weight_buffers.append(None)\n\n    def build_memory_reference(self):\n        for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):\n            self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])\n        self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)\n\n    @property\n    def named_parameters(self):\n        return self._named_parameters\n\n    @property\n    def weight_buffers(self):\n        return self._weight_buffers\n\n    @property\n    def memory_buffers(self):\n        return self._memory_buffers\n"
  },
  {
    "path": "verl_distillation/verl/utils/memory_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport inspect\nimport logging\nimport os\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport torch\n\nfrom verl.utils.device import get_torch_device, is_cuda_available\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:\n    \"\"\"\n    More aggressive GPU memory cleanup function, tries to release PyTorch reserved\n    but unallocated memory.\n\n    Args:\n        force_sync: Whether to force device synchronization\n        max_retries: Maximum number of retries\n    \"\"\"\n    device = get_torch_device()\n    if not device.is_available():\n        return\n\n    for attempt in range(max_retries):\n        # Record memory status before cleanup\n        before_reserved = device.memory_reserved()\n        before_allocated = device.memory_allocated()\n\n        # Run garbage collection\n        gc.collect()\n\n        # Clear PyTorch cache\n        device.empty_cache()\n\n        # Force synchronization (optional)\n        if force_sync:\n            device.synchronize()\n\n        # Record memory status after cleanup\n        after_reserved = device.memory_reserved()\n        after_allocated = device.memory_allocated()\n\n        # Calculate freed memory\n        reserved_freed = before_reserved - after_reserved\n        allocated_freed = before_allocated - after_allocated\n\n        logger.info(\n            f\"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, \"\n            f\"{allocated_freed / 1024**3:.2f} GB allocated\"\n        )\n\n        # Stop retrying if little memory was freed\n        if reserved_freed < 1024**3:  # less than 1GB\n            break\n\n\ndef reset_memory_stats() -> None:\n    \"\"\"Reset GPU memory statistics\"\"\"\n    if get_torch_device().is_available():\n        device = get_torch_device()\n        device.reset_peak_memory_stats()\n        device.reset_accumulated_memory_stats()\n\n\ndef get_memory_info() -> dict:\n    \"\"\"Get detailed GPU memory information\"\"\"\n    if not get_torch_device().is_available():\n        return {}\n\n    device = get_torch_device()\n    device_id = device.current_device()\n\n    return {\n        \"total_memory_gb\": device.get_device_properties(device_id).total_memory / 1024**3,\n        \"reserved_memory_gb\": device.memory_reserved() / 1024**3,\n        \"allocated_memory_gb\": device.memory_allocated() / 1024**3,\n        \"cached_memory_gb\": (device.memory_reserved() - device.memory_allocated()) / 1024**3,\n        \"max_memory_allocated_gb\": device.max_memory_allocated() / 1024**3,\n        \"max_memory_reserved_gb\": device.max_memory_reserved() / 1024**3,\n    }\n\n\ndef log_memory_usage(stage: str = \"current\") -> None:\n    
\"\"\"Log GPU memory usage\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    info = get_memory_info()\n    logger.info(\n        f\"Memory usage [{stage}]: \"\n        f\"Total: {info['total_memory_gb']:.2f} GB, \"\n        f\"Allocated: {info['allocated_memory_gb']:.2f} GB, \"\n        f\"Reserved: {info['reserved_memory_gb']:.2f} GB, \"\n        f\"Cached: {info['cached_memory_gb']:.2f} GB\"\n    )\n\n\ndef optimize_memory_for_inference() -> None:\n    \"\"\"Optimize GPU memory usage for inference\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    # Set a more aggressive memory allocation policy\n    get_torch_device().set_per_process_memory_fraction(0.95)  # Use 95% of GPU memory\n\n    # Clear cache\n    aggressive_empty_cache(force_sync=True)\n\n    logger.info(\"Optimized GPU memory usage for inference\")\n\n\ndef optimize_memory_for_training() -> None:\n    \"\"\"Optimize GPU memory usage for training\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    # Set a moderate memory allocation policy\n    get_torch_device().set_per_process_memory_fraction(0.9)  # Use 90% of GPU memory\n\n    # Clear cache\n    aggressive_empty_cache(force_sync=False)\n\n    logger.info(\"Optimized GPU memory usage for training\")\n\n\ndef enable_memory_visualize(\n    trace_alloc_max_entries: int = 200_000,\n    stack_depth: int = 32,\n    context: str = \"all\",\n    stacks: str = \"all\",\n    devices=None,\n    record_context: bool = True,\n):\n    \"\"\"\n    Enables memory history recording for CUDA allocations. This function\n    should be called before any large-scale CUDA allocations. For DDP or\n    multi-process setups, it must be called on each rank.\n\n    Args:\n        trace_alloc_max_entries (int): Maximum number of allocation entries\n            to record.\n        stack_depth (int): The depth of the call stack to capture for each\n            allocation. (Supported by some PyTorch versions).\n        context (str): The type of memory events to record.\n            'alloc': records only allocation events.\n            'state': records memory state changes.\n            'all': records both.\n        stacks (str): The type of call stacks to record.\n            'python': records Python stacks.\n            'cpp': records C++ stacks (available in some versions).\n            'all': records both.\n        devices (Union[int, list[int], None]): The device for which to enable\n            memory history. `None` enables it for the current default device.\n        record_context (bool): Whether to record context information for\n            allocations. 
Required by older PyTorch versions.\n    \"\"\"\n    # Memory history recording is CUDA-specific functionality\n    if not is_cuda_available:\n        logger.warning(\"[memory_visualize] Memory history recording is only available on CUDA devices\")\n        return\n\n    f = get_torch_device().memory._record_memory_history\n    params = set(inspect.signature(f).parameters.keys())\n\n    def _one_call(dev_kw=None):\n        kwargs = {}\n        if \"context\" in params:\n            kwargs[\"context\"] = context\n        if \"stacks\" in params:\n            kwargs[\"stacks\"] = stacks\n        if \"max_entries\" in params:\n            kwargs[\"max_entries\"] = trace_alloc_max_entries\n        elif \"trace_alloc_max_entries\" in params:\n            kwargs[\"trace_alloc_max_entries\"] = trace_alloc_max_entries\n        if \"stack_depth\" in params:\n            kwargs[\"stack_depth\"] = stack_depth\n        if dev_kw is not None:\n            if \"device\" in params:\n                kwargs[\"device\"] = dev_kw\n            elif \"devices\" in params:\n                kwargs[\"devices\"] = dev_kw if isinstance(dev_kw, list) else [dev_kw]\n        if \"record_context\" in params:\n            kwargs[\"record_context\"] = record_context\n\n        try:\n            f(**kwargs)\n            return \"native\", kwargs\n        except TypeError:\n            try:\n                if \"trace_alloc_max_entries\" in params and \"record_context\" in params:\n                    f(enabled=True, trace_alloc_max_entries=trace_alloc_max_entries, record_context=True)\n                    return \"legacy\", {\n                        \"enabled\": True,\n                        \"trace_alloc_max_entries\": trace_alloc_max_entries,\n                        \"record_context\": True,\n                    }\n                else:\n                    f(enabled=True)\n                    return \"legacy-min\", {\"enabled\": True}\n            except Exception:\n                raise\n\n    if devices is None or isinstance(devices, str | int | torch.device):\n        mode, used = _one_call(devices if devices is not None else None)\n    else:\n        mode, used = \"multi-device\", {}\n        for d in list(devices):\n            _mode, _used = _one_call(d)\n            used[f\"dev{d}\"] = _used\n\n    device = get_torch_device()\n    if device.is_available():\n        device.reset_peak_memory_stats()\n        device.synchronize()\n\n    rank = int(os.environ.get(\"RANK\", \"0\") or 0)\n    logger.info(f\"[memory_visualize][rank {rank}] recording enabled ({mode}); args={used}\")\n\n\nclass MemorySnapshotSampler:\n    \"\"\"\n    A utility class that dumps GPU memory snapshots.\n    This is useful for monitoring memory usage over a long-running process.\n\n    The dumped files can be visualized with https://docs.pytorch.org/memory_viz\n\n    Args:\n        out_dir (str): The directory where the snapshots will be saved.\n        tag (str): A tag for the snapshot filenames.\n    \"\"\"\n\n    def __init__(self, out_dir: str = \"./mem_snapshots\", tag: str = \"periodic\"):\n        self.out_dir = out_dir\n        self.tag = tag\n\n    def dump_memory_snapshot(self, out_dir: str = \"./mem_snapshots\", tag: str = \"snapshot\", sub_dir: str = None):\n        \"\"\"\n        Generates a memory snapshot and saves it as a pickle file in a specified directory.\n        The files are organized by timestamp in subdirectories, with all ranks' files\n        placed in the same timestamp subdirectory.\n\n        Args:\n          
  out_dir (str): The directory where the snapshot file will be saved.\n                The directory is created if it does not exist.\n            tag (str): A string tag to prepend to the filename for easier identification.\n            sub_dir (str): A subdirectory to place the snapshot file in.\n        \"\"\"\n        if sub_dir is None:\n            timestamp = datetime.now().strftime(\"%Y%m%d-%H%M\")\n            out_path = Path(out_dir) / timestamp\n        else:\n            out_path = Path(out_dir) / sub_dir\n        out_path.mkdir(parents=True, exist_ok=True)\n\n        # get the GPU rank on the current process\n        rank = os.environ.get(\"RANK\", \"0\")\n        pid = os.getpid()\n        # todo(chenyang): check whether we need to sync all ranks before dump\n        fname = f\"{tag}_rank{rank}_pid{pid}.pickle\"\n        path = out_path / fname\n\n        device = get_torch_device()\n        if not device.is_available():\n            logger.warning(\"[memory_visualize] memory snapshots are only available on CUDA devices.\")\n            return\n        try:\n            device.synchronize()\n            # Memory snapshot is CUDA-specific functionality\n            device.memory._dump_snapshot(str(path))\n            logger.info(f\"[memory_visualize] dumped: {path}\")\n        except Exception as e:\n            logger.warning(f\"[memory_visualize] dump failed: {e}\")\n"
  },
  {
    "path": "verl_distillation/verl/utils/metric/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .utils import reduce_metrics\n\n__all__ = [\"reduce_metrics\"]\n"
  },
  {
    "path": "verl_distillation/verl/utils/metric/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMetrics utils.\n\"\"\"\n\nfrom typing import Any\n\nimport numpy as np\n\n\ndef reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:\n    \"\"\"\n    Reduces a dictionary of metric lists by computing the mean, max, or min of each list.\n    The reduce operation is determined by the key name:\n    - If the key contains \"max\", np.max is used\n    - If the key contains \"min\", np.min is used\n    - Otherwise, np.mean is used\n\n    Args:\n        metrics: A dictionary mapping metric names to lists of metric values.\n\n    Returns:\n        A dictionary with the same keys but with each list replaced by its reduced value.\n\n    Example:\n        >>> metrics = {\n        ...     \"loss\": [1.0, 2.0, 3.0],\n        ...     \"accuracy\": [0.8, 0.9, 0.7],\n        ...     \"max_reward\": [5.0, 8.0, 6.0],\n        ...     \"min_error\": [0.1, 0.05, 0.2]\n        ... }\n        >>> reduce_metrics(metrics)\n        {\"loss\": 2.0, \"accuracy\": 0.8, \"max_reward\": 8.0, \"min_error\": 0.05}\n    \"\"\"\n    for key, val in metrics.items():\n        if \"max\" in key:\n            metrics[key] = np.max(val)\n        elif \"min\" in key:\n            metrics[key] = np.min(val)\n        else:\n            metrics[key] = np.mean(val)\n    return metrics\n"
  },
  {
    "path": "verl_distillation/verl/utils/model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to create common models from huggingface\n\"\"\"\n\nimport json\nimport os\nimport re\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom transformers import (\n    AutoConfig,\n    AutoModel,\n    AutoModelForCausalLM,\n    AutoModelForSequenceClassification,\n    AutoModelForTokenClassification,\n    AutoModelForVision2Seq,\n    GenerationConfig,\n    MistralForSequenceClassification,\n    PretrainedConfig,\n    PreTrainedModel,\n)\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom verl.models.registry import ModelRegistry\nfrom verl.utils.import_utils import is_trl_available\n\n\nclass LambdaLayer(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, *args, **kwargs):\n        return self.fn(*args, **kwargs)\n\n\ndef squeeze(x):\n    return torch.squeeze(x, dim=-1)\n\n\ndef update_model_config(module_config, override_config_kwargs):\n    \"\"\"Update the module config with the override_config_kwargs.\n    Args:\n        module_config: The module config from Huggingface Transformers.\n        override_config_kwargs: The kwargs to override the module config.\n    \"\"\"\n    for key, val in override_config_kwargs.items():\n        if isinstance(val, dict):\n            update_model_config(getattr(module_config, key), val)\n        else:\n            setattr(module_config, key, val)\n\n\ndef get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict:\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    assert isinstance(override_config_kwargs, dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    )\n    module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)\n    update_model_config(module_config, override_config_kwargs)\n\n    return module_config\n\n\ndef get_generation_config(\n    model: str,\n    trust_remote_code: bool = False,\n) -> Optional[GenerationConfig]:\n    try:\n        return GenerationConfig.from_pretrained(model)\n    except OSError:  # Not found\n        try:\n            config = get_huggingface_actor_config(\n                model,\n                trust_remote_code=trust_remote_code,\n            )\n            return GenerationConfig.from_model_config(config)\n        except OSError:  # Not found\n            return None\n\n\ndef create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"\n\n    Args:\n        model_name:\n        override_config_kwargs:\n\n    Returns:\n\n    \"\"\"\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    assert 
isinstance(override_config_kwargs, dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    )\n    module_config = get_huggingface_actor_config(\n        model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get(\"trust_remote_code\", False)\n    )\n    module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)\n    return module\n\n\ndef create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"Create a critic module by replacing the actor's lm_head with a scalar value head.\n\n    Args:\n        model_name: Huggingface model name or local path.\n        override_config_kwargs: Optional dict of config fields to override.\n        automodel_kwargs: Optional kwargs forwarded to the underlying actor constructor.\n\n    Returns:\n        nn.Module: The instantiated critic module.\n    \"\"\"\n    critic_module: nn.Module = create_huggingface_actor(\n        model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs\n    )\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    torch_dtype = automodel_kwargs.get(\"torch_dtype\", torch.float32)\n    critic_module.lm_head = nn.Sequential(\n        nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze)\n    )\n    return critic_module\n\n\ndef get_model_size(model: nn.Module, scale=\"auto\"):\n    n_params = sum(p.numel() for p in model.parameters())\n\n    if scale == \"auto\":\n        if n_params > 1e9:\n            scale = \"B\"\n        elif n_params > 1e6:\n            scale = \"M\"\n        elif n_params > 1e3:\n            scale = \"K\"\n        else:\n            scale = \"\"\n\n    if scale == \"B\":\n        n_params = n_params / 1e9\n    elif scale == \"M\":\n        n_params = n_params / 1e6\n    elif scale == \"K\":\n        n_params = n_params / 1e3\n    elif scale == \"\":\n        pass\n    else:\n        raise NotImplementedError(f\"Unknown scale {scale}\")\n\n    return n_params, scale\n\n\ndef print_model_size(model: nn.Module, name: str = None):\n    n_params, scale = get_model_size(model, scale=\"auto\")\n    if name is None:\n        name = model.__class__.__name__\n    print(f\"{name} contains {n_params:.2f}{scale} parameters\")\n\n\ndef create_random_mask(\n    input_ids: torch.Tensor,\n    max_ratio_of_valid_token: float,\n    max_ratio_of_left_padding: float,\n    min_ratio_of_valid_token: float = 0,\n):\n    \"\"\"Create a random mask given input_ids. 
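For example, with seq_len=8 a row might be [0, 0, 1, 1, 1, 0, 0, 0]: two left-padding zeros, three valid tokens, then right padding. 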
Support left padding and right padding.\n    Process:\n    - Sample valid token length\n    - Sample left_padding length\n    - Generate padding\n\n    Args:\n        input_ids:\n            shape (batch_size, seq_len)\n\n    Returns:\n        An int64 mask of shape (batch_size, seq_len).\n    \"\"\"\n    assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0\n    assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0\n    assert min_ratio_of_valid_token <= max_ratio_of_valid_token\n\n    batch_size, sequence_length = input_ids.shape\n    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)\n    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))\n    max_left_padding = int(sequence_length * max_ratio_of_left_padding)\n    assert max_num_valid_tokens + max_left_padding <= sequence_length\n    assert max_num_valid_tokens > 0 and max_num_valid_tokens <= sequence_length\n    masks = torch.ones_like(input_ids, dtype=torch.int64)\n    # TODO: we can make this faster\n    for i in range(batch_size):\n        num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)\n        num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)\n\n        for index in range(num_left_padding):\n            masks[i, index] = 0\n\n        for index in range(num_left_padding + num_valid, sequence_length):\n            masks[i, index] = 0\n    return masks\n\n\ndef compute_position_id_with_mask(mask):\n    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)\n\n\ndef convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel):\n    # convert state dict keys: https://github.com/huggingface/transformers/pull/38385\n    if not hasattr(model, \"_checkpoint_conversion_mapping\"):\n        return state_dict\n\n    reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()}\n    original_weights = {}\n    for key, value in state_dict.items():\n        for pattern, replacement in reverse_key_mapping.items():\n            replacement = replacement.lstrip(\"^\")  # strip off un-needed chars and patterns\n            replacement = re.sub(r\"\\(.*\\)\", \"\", replacement)\n            key, n_replace = re.subn(pattern, replacement, key)\n            # Early exit of the loop\n            if n_replace > 0:\n                break\n\n        original_weights[key] = value\n\n    return original_weights\n\n\ndef check_exclude_modules(config, key: str) -> bool:\n    \"\"\"\n    A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config.\n    Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py\n\n    Args:\n        config (`LoraConfig` | `LycorisConfig`): A config to match exclude modules from\n        key (`str`): A key to search any matches in config\n\n    Returns:\n        True if key matches any exclude modules from config, False if no match is found\n    \"\"\"\n    if hasattr(config, \"exclude_modules\") and config.exclude_modules:\n        if isinstance(config.exclude_modules, str):\n            if re.fullmatch(config.exclude_modules, key):\n                return True\n        elif key in config.exclude_modules:\n            return True\n        elif any(key.endswith(f\".{exclude_key}\") for exclude_key in config.exclude_modules):\n            return True\n    return False\n\n\ndef check_target_modules(config, key: str) -> bool:\n    \"\"\"\n    A helper 
method to check if the passed module's key name matches any of the target modules in the adapter_config.\n    Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py\n\n    Args:\n        config (`LoraConfig` | `LycorisConfig`): A config to match target modules from\n        key (`str`): A key to search any matches in config\n\n    Returns:\n        True or a regex match object if key matches any target modules from config, a falsy value otherwise\n    \"\"\"\n    if isinstance(config.target_modules, str):\n        target_module_found = re.fullmatch(config.target_modules, key)\n    elif key in config.target_modules:\n        # this module is specified directly in target_modules\n        target_module_found = True\n    else:\n        target_module_found = any(key.endswith(f\".{target_key}\") for target_key in config.target_modules)\n\n        layer_indexes = getattr(config, \"layers_to_transform\", None)\n        layers_pattern = getattr(config, \"layers_pattern\", None)\n\n        is_using_layer_indexes = layer_indexes is not None and (\n            len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True\n        )\n        if is_using_layer_indexes and target_module_found:\n            layer_index = None\n            # TODO: It's still unclear how empty layers_pattern (None, [], or \"\") should behave\n            # For now, empty layers_pattern means any layer pattern is ok\n            if layers_pattern is None or len(layers_pattern) == 0:\n                layer_index = re.match(r\".*\\.[^.]*\\.(\\d+)\\.\", key)\n            else:\n                layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern\n                for pattern in layers_pattern:\n                    layer_index = re.match(rf\".*\\.{pattern}\\.(\\d+)\\.\", key)\n                    if layer_index is not None:\n                        break\n\n            if layer_index is None:\n                target_module_found = False\n            else:\n                layer_index = int(layer_index.group(1))\n                if isinstance(layer_indexes, int):\n                    target_module_found = layer_index == layer_indexes\n                else:\n                    target_module_found = layer_index in layer_indexes\n\n    return target_module_found\n\n\ndef normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name=\"layers\"):\n    \"\"\"\n    Transform a parameter name from a model chunk in a given pp stage into the corresponding name\n    in the inference engine.\n    \"\"\"\n    from verl.utils.megatron_utils import get_transformer_layer_offset\n\n    layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config)\n\n    if layer_name in name:  # belongs to an intermediate layer\n        split_name = name.split(\".\")\n        # find the layer number that follows layer_name\n        for i, part in enumerate(split_name):\n            if part == layer_name:\n                break\n        layer_num_idx = i + 1\n        # check the name\n        assert len(split_name) >= layer_num_idx + 1, f\"split_name = {split_name}\"\n        assert split_name[layer_num_idx].isdigit(), f\"split_name = {split_name}\"\n        # shift the local layer number by the pp/vpp layer offset\n        split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset)\n        name = \".\".join(split_name)  # weight name in inference_tp_model\n    return name\n\n\ndef normalize_pp_vpp_params(params, num_hidden_layers, layer_name=\"layers\"):\n    \"\"\"\n    Normalize the pp vpp 
params into complete named parameters.\n    This is useful when gathering parameters from pp ranks to pass to a model without pp.\n\n    params: Iterable[List[Dict[str, param]]]\n        params contains one list per pp rank, each holding the named_parameters of its vpp chunks.\n    output: Dict[str, param]\n\n    \"\"\"\n    pp_size = len(params)\n    for pp_rank in range(len(params)):\n        vpp_size = len(params[pp_rank])\n        for vpp_rank in range(vpp_size):\n            for name, param in params[pp_rank][vpp_rank].items():\n                # NOTE: this call still passes the legacy (pp_size, vpp_size, num_hidden_layers)\n                # arguments, while normalize_model_name above now expects a transformer_config;\n                # reconcile the signatures before relying on this helper.\n                normalized_name = normalize_model_name(\n                    name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name\n                )\n                yield normalized_name, param\n\n\ndef get_parallel_model_from_config(\n    config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from megatron.core import ModelParallelConfig\n\n    assert isinstance(megatron_config, ModelParallelConfig)\n    model_class = _get_parallel_model_architecture_from_config(config, value)\n\n    model = model_class(\n        config,\n        megatron_config,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n    )\n    return model\n\n\ndef _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]:\n    architectures = getattr(config, \"architectures\", [])\n    for arch in architectures:\n        model_cls = ModelRegistry.load_model_cls(arch, value)\n        print(\"after load model cls\")\n        if model_cls is not None:\n            return model_cls\n    raise ValueError(\n        f\"Model architectures {architectures} are not supported for now. 
Supported architectures: \"\n        f\"{ModelRegistry.get_supported_archs()}\"\n    )\n\n\ndef _load_hf_model(config, model_config, is_value_model):\n    \"\"\"Helper function containing the HF model loading logic.\"\"\"\n    from accelerate import init_empty_weights\n    from megatron.core import parallel_state as mpu\n\n    from verl.models.mcore.saver import _megatron_calc_global_rank\n\n    assert hasattr(model_config, \"architectures\"), \"architectures cannot be empty when loading weights!\"\n    architectures = getattr(model_config, \"architectures\", [])\n\n    # get auto class\n    auto_cls = get_hf_auto_model_class(model_config)\n\n    if config.model.path.startswith(\"hdfs:\"):\n        from verl.utils.fs import copy_to_local\n\n        print(f\"start download from {config.model.path}\")\n        local_model_path = copy_to_local(src=config.model.path, use_shm=config.model.get(\"use_shm\", False))\n        print(\"finish download\")\n    else:\n        local_model_path = config.model.path\n        print(f\"load from local dir {local_model_path}\")\n\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank())\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights\n    with init_context(), warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        # TODO: find a better way to load the mistral7b-rm lm_head\n        if \"mistral7b-rm\" in config.model.path:\n            model = MistralForSequenceClassification.from_pretrained(\n                local_model_path,\n                torch_dtype=\"auto\",\n                # device_map=\"auto\",  # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )  # use score head instead of lm_head\n            state_dict = model.state_dict()\n            state_dict[\"lm_head.weight\"] = state_dict[\"score.weight\"]\n            state_dict[\"model.embed_tokens.weight\"] = state_dict[\"model.embed_tokens.weight\"][\n                :32000\n            ]  # workaround, 32001 -> 32000\n            is_value_model = True\n        else:\n            model = auto_cls.from_pretrained(\n                local_model_path,\n                torch_dtype=\"auto\",\n                # device_map=\"auto\", # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )\n            state_dict = model.state_dict()\n\n    return architectures, model, state_dict, is_value_model\n\n\ndef get_hf_model_path(config):\n    if config.model.path.startswith(\"hdfs:\"):\n        from verl.utils.fs import copy_to_local\n\n        local_model_path = copy_to_local(src=config.model.path, use_shm=config.model.get(\"use_shm\", False))\n    else:\n        local_model_path = config.model.path\n    return local_model_path\n\n\ndef load_megatron_model_weights(config, model_config, parallel_model, params_dtype, is_value_model=False):\n    \"\"\"Load weights for verl customized model.\"\"\"\n    architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model)\n\n    from verl.models.weight_loader_registry import get_weight_loader\n\n    print(f\"before weight loader: architectures = {architectures}...\")\n    for arch in architectures:\n        print(f\"call weight loader arch = {arch}, model config = {model.config}\")\n        
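# get_weight_loader maps the HF architecture string to a per-architecture loader that\n        # copies the full HF state dict into the wrapped Megatron parallel model chunks\n        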
weight_loader = get_weight_loader(arch)\n        weight_loader(\n            state_dict=state_dict,\n            wrapped_models=parallel_model,\n            config=model.config,\n            params_dtype=params_dtype,\n            is_value_model=is_value_model,\n            tie_word_embeddings=model_config.tie_word_embeddings,\n        )\n    return model.config\n\n\ndef load_megatron_gptmodel_weights(config, model_config, parallel_model, params_dtype, is_value_model=False):\n    \"\"\"Load weights for mcore GPT model.\"\"\"\n    _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model)\n\n    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    load_state_dict_to_megatron_gptmodel(\n        state_dict=state_dict,\n        wrapped_models=parallel_model,\n        config=model.config,\n        params_dtype=params_dtype,\n        is_value_model=is_value_model,\n    )\n    del state_dict, model\n\n\n# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp\ndef pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):\n    \"\"\"Pad the tokens such that the total length is a multiple of size.\n    This function is useful when applying sequence parallelism and context parallelism.\n\n    Args:\n        unpad_tokens: (total_nnz, ...). Tokens after removing padding\n        cu_seqlens: (batch_size + 1,)\n        max_seqlen_in_batch: int\n\n    Returns:\n        The padded tokens, the updated cu_seqlens, and the updated max_seqlen_in_batch.\n    \"\"\"\n    F = nn.functional\n\n    total_nnz = unpad_tokens.shape[0]\n\n    pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size\n\n    # we emulate appending one extra sequence of length pad_size to the batch\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim} is not supported\")\n\n        cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])\n        max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)\n\n    return unpad_tokens, cu_seqlens, max_seqlen_in_batch\n\n\ndef load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):\n    from megatron.core import dist_checkpointing\n    from megatron.core.dist_checkpointing.serialization import StrictHandling\n\n    from verl.utils.megatron_utils import unwrap_model\n\n    # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED\n    strict = StrictHandling.ASSUME_OK_UNEXPECTED\n    for model in parallel_model:\n        ssd = unwrap_model(model).sharded_state_dict()\n        if is_value_model:\n            for k in list(ssd.keys()):\n                if \"output_layer\" in k:\n                    ssd.pop(k)\n        dist_checkpointing.load(ssd, dist_weight_path, strict=strict)\n\n    return\n\n\ndef get_parallel_gptmodel_from_config(\n    tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec\n    from megatron.core.models.gpt.gpt_model import GPTModel\n\n    use_te = True\n    assert tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n    transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te)\n    rope_scaling_args = {}\n    if 
hf_config.rope_scaling is not None:\n        assert hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n        rope_scaling_args[\"seq_len_interpolation_factor\"] = hf_config.rope_scaling[\"factor\"]\n    parallel_model = GPTModel(\n        config=tfconfig,\n        transformer_layer_spec=transformer_layer_spec,\n        vocab_size=hf_config.vocab_size,\n        max_sequence_length=hf_config.max_position_embeddings,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        position_embedding_type=\"rope\",\n        rotary_base=hf_config.rope_theta,\n        **rope_scaling_args,\n    )\n    # # for layer in parallel_model.decoder.layers:\n    # layer.self_attention.core_attention.flash_attention.softmax_scale = None\n    if post_process and value:\n        from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n        parallel_model.output_layer = LinearForLastLayer(\n            input_size=tfconfig.hidden_size, output_size=1, config=tfconfig\n        )\n    return parallel_model\n\n\ndef patch_valuehead_model(model) -> None:\n    from types import MethodType\n\n    from transformers import PreTrainedModel\n    from trl import AutoModelForCausalLMWithValueHead\n\n    def tie_weights(self: \"AutoModelForCausalLMWithValueHead\") -> None:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            self.pretrained_model.tie_weights()\n\n    def get_input_embeddings(self: \"AutoModelForCausalLMWithValueHead\") -> torch.nn.Module:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            return self.pretrained_model.get_input_embeddings()\n\n    def get_output_embeddings(self: \"AutoModelForCausalLMWithValueHead\") -> torch.nn.Module:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            return self.pretrained_model.get_output_embeddings()\n\n    def can_generate(self):\n        return False\n\n    ignore_modules = [name for name, _ in model.named_parameters() if \"pretrained_model\" in name]\n    model._keys_to_ignore_on_save = ignore_modules\n    model.tie_weights = MethodType(tie_weights, model)\n    model.get_input_embeddings = MethodType(get_input_embeddings, model)\n    model.get_output_embeddings = MethodType(get_output_embeddings, model)\n    model.can_generate = MethodType(can_generate, model)\n    model._no_split_modules = getattr(model.pretrained_model, \"_no_split_modules\", [])\n\n\ndef load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):\n    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq\n\n    try:\n        model = AutoModelForTokenClassification.from_pretrained(\n            pretrained_model_name_or_path=local_path,\n            torch_dtype=torch_dtype,\n            config=model_config,\n            attn_implementation=\"flash_attention_2\",\n            trust_remote_code=trust_remote_code,\n        )\n        return model\n    except BaseException as e:\n        if not is_trl_available():\n            raise RuntimeError(\n                f\"model({local_path}) is not a value head model, please install trl to make it valid\"\n            ) from e\n\n    assert is_trl_available()\n\n    from trl import AutoModelForCausalLMWithValueHead\n\n    if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():\n        module_class = AutoModelForVision2Seq\n    
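# configs without a Vision2Seq mapping fall back to plain causal-LM loading\n    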
else:\n        module_class = AutoModelForCausalLM\n    ori_model = module_class.from_pretrained(\n        pretrained_model_name_or_path=local_path,\n        torch_dtype=torch_dtype,\n        config=model_config,\n        attn_implementation=\"flash_attention_2\",\n        trust_remote_code=trust_remote_code,\n    )\n    model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)\n    patch_valuehead_model(model)\n    return model\n\n\n_architecture_to_auto_class = {\n    \"ForCausalLM\": AutoModelForCausalLM,\n    \"ForVision2Seq\": AutoModelForVision2Seq,\n    \"ForTokenClassification\": AutoModelForTokenClassification,\n    \"ForSequenceClassification\": AutoModelForSequenceClassification,\n}\n\n\ndef get_hf_auto_model_class(hf_config):\n    has_remote_code = hasattr(hf_config, \"auto_map\") and any(\n        hf_config.architectures[0] in val for val in hf_config.auto_map.values()\n    )\n    if has_remote_code:\n        auto_class = next(k for k, v in hf_config.auto_map.items() if hf_config.architectures[0] in v)\n        match auto_class:\n            case \"AutoModelForVision2Seq\":\n                actor_module_class = AutoModelForVision2Seq\n            case \"AutoModelForCausalLM\":\n                actor_module_class = AutoModelForCausalLM\n            case _:\n                actor_module_class = AutoModel\n    else:\n        actor_module_class = AutoModel\n        for key, cls in _architecture_to_auto_class.items():\n            if key in hf_config.architectures[0]:\n                actor_module_class = cls\n                break\n\n    return actor_module_class\n\n\ndef extract_multi_modal_inputs(\n    batch_data: list[dict[str, torch.Tensor]],\n    indices: Optional[list[int]] = None,\n) -> dict[str, torch.Tensor | list[torch.Tensor]]:\n    \"\"\"\n    Extract and process multi-modal inputs from a batch.\n\n    Args:\n        batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs\n        indices (Optional[list[int]]): If provided, only extract inputs at these indices\n\n    Returns:\n        dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption\n\n    \"\"\"\n    multi_modal_inputs = {}\n    multi_modal_inputs_collected = {}\n    has_image_bound = False\n\n    selected_batch_data = batch_data\n    if indices is not None:\n        selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)]\n\n    for inputs in selected_batch_data:\n        if \"image_bound\" in inputs:\n            has_image_bound = True\n        for key, value in inputs.items():\n            if value is not None:\n                if key not in multi_modal_inputs_collected:\n                    multi_modal_inputs_collected[key] = []\n                multi_modal_inputs_collected[key].append(value)\n\n    for key, values in multi_modal_inputs_collected.items():\n        if has_image_bound:  # minicpm-o logic\n            multi_modal_inputs[key] = values\n        else:\n            multi_modal_inputs[key] = torch.cat(values, dim=0)\n\n    return multi_modal_inputs\n\n\ndef get_lora_rank_from_adapter(adapter_path: str | os.PathLike) -> int:\n    \"\"\"\n    Extract LoRA rank from adapter configuration file.\n\n    Args:\n        adapter_path: Path to LoRA adapter directory\n\n    Returns:\n        LoRA rank value from adapter_config.json\n\n    Raises:\n        FileNotFoundError: If adapter path or config file doesn't exist\n        ValueError: If config file is invalid or missing rank\n    \"\"\"\n  
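  # normalize user-supplied paths (\"~\", relative segments) before probing the filesystem\n  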
  adapter_path = os.path.abspath(os.path.expanduser(str(adapter_path)))\n\n    if not os.path.exists(adapter_path):\n        raise FileNotFoundError(f\"LoRA adapter path not found: {adapter_path}\")\n\n    config_path = os.path.join(adapter_path, \"adapter_config.json\")\n    if not os.path.exists(config_path):\n        raise FileNotFoundError(f\"adapter_config.json not found in {adapter_path}\")\n\n    try:\n        with open(config_path, encoding=\"utf-8\") as f:\n            config = json.load(f)\n            if \"r\" not in config:\n                raise ValueError(f\"LoRA rank 'r' not found in {config_path}\")\n            return int(config[\"r\"])\n    except json.JSONDecodeError as e:\n        raise ValueError(f\"Invalid JSON in {config_path}: {e}\") from e\n    except (KeyError, ValueError) as e:\n        raise ValueError(f\"Cannot parse LoRA rank from {config_path}: {e}\") from e\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "verl_distillation/verl/utils/net_utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport ipaddress\n\n\ndef is_ipv4(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv4 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv4 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv4Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n\n\ndef is_ipv6(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv6 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv6 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv6Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n"
  },
  {
    "path": "verl_distillation/verl/utils/npu_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\nclass IndexFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, indices):\n        ctx.save_for_backward(indices)\n        assert input.ndim >= 2\n        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]\n        second_dim = other_shape.numel()\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        # return input[indices]\n        return torch.gather(rearrange(input, \"b ... -> b (...)\"), 0, repeat(indices, \"z -> z d\", d=second_dim)).reshape(\n            -1, *other_shape\n        )\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        assert grad_output.ndim >= 2\n        other_shape = grad_output.shape[1:]\n        grad_output = rearrange(grad_output, \"b ... -> b (...)\")\n        grad_input = torch.zeros(\n            [ctx.first_axis_dim, grad_output.shape[1]],\n            device=grad_output.device,\n            dtype=grad_output.dtype,\n        )\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        # grad_input[indices] = grad_output\n        grad_input.scatter_(0, repeat(indices, \"z -> z d\", d=grad_output.shape[1]), grad_output)\n        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None\n\n\nindex_first_axis = IndexFirstAxis.apply\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\nclass IndexPutFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, values, indices, first_axis_dim):\n        ctx.save_for_backward(indices)\n        assert indices.ndim == 1\n        assert values.ndim >= 2\n        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        output[indices] = values\n        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        grad_values = grad_output[indices]\n        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))\n        return grad_values, None, None\n\n\nindex_put_first_axis = IndexPutFirstAxis.apply\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in 
the attention_mask (i.e., the selected, non-padding tokens).\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Return:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    # dim = hidden_states.shape[-1]\n    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)\n    # output[indices] = hidden_states\n    output = index_put_first_axis(hidden_states, indices, batch * seqlen)\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    \"\"\"\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.\n        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.\n        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.\n    \"\"\"\n    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim\n    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,\n    # so we write custom forward and backward to make it a bit faster.\n    return (\n        index_first_axis(rearrange(hidden_states, \"b s ... -> (b s) ...\"), indices),\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch,\n    )\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom ..device import is_npu_available\nfrom ..import_utils import is_nvtx_available\nfrom .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom .profile import DistProfiler, DistProfilerExtension, ProfilerConfig\n\n# Select marker implementations by availability, but keep DistProfiler as our dispatcher\nif is_nvtx_available():\n    from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer\nelif is_npu_available:\n    from .mstx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer\nelse:\n    from .performance import marked_timer\n    from .profile import mark_annotate, mark_end_range, mark_start_range\n\n__all__ = [\n    \"GPUMemoryLogger\",\n    \"log_gpu_memory_usage\",\n    \"mark_start_range\",\n    \"mark_end_range\",\n    \"mark_annotate\",\n    \"DistProfiler\",\n    \"DistProfilerExtension\",\n    \"ProfilerConfig\",\n    \"simple_timer\",\n    \"marked_timer\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom omegaconf import MISSING\n\nfrom verl.base_config import BaseConfig\n\n\n@dataclass\nclass NsightToolConfig(BaseConfig):\n    \"\"\"Nsight tool config.\"\"\"\n\n    \"True for each task has its own database, False for all tasks in one training step share one database.\"\n    discrete: bool = False\n\n    def __post_init__(self) -> None:\n        pass\n\n\n@dataclass\nclass TorchProfilerToolConfig(BaseConfig):\n    \"\"\"Torch profiler tool config.\n\n    Args:\n        step_start (int): Start step in update_policy.\n        step_end (int): End step.\n    \"\"\"\n\n    step_start: int = -1\n    step_end: int = -1\n\n    def __post_init__(self) -> None:\n        \"\"\"config validation logics go here\"\"\"\n        warnings.warn(\"Torch profiler tool config is not fully supported now.\", stacklevel=1)\n        assert isinstance(self.step_start, int), f\"Profiler step_start must be of type int, got {type(self.step_start)}\"\n\n\n@dataclass\nclass TorchMemoryToolConfig(BaseConfig):\n    \"\"\"Torch memory profiler tool config.\n\n    Args:\n        trace_alloc_max_entries (int): Maximum number of memory allocation entries to track.\n        stack_depth (int): Stack trace depth for memory allocations.\n    \"\"\"\n\n    trace_alloc_max_entries: int = 100_000\n    stack_depth: int = 32\n\n    def __post_init__(self) -> None:\n        \"\"\"config validation logics go here\"\"\"\n        assert isinstance(self.trace_alloc_max_entries, int), (\n            f\"trace_alloc_max_entries must be int, got {type(self.trace_alloc_max_entries)}\"\n        )\n        assert isinstance(self.stack_depth, int), f\"stack_depth must be int, got {type(self.stack_depth)}\"\n        assert self.trace_alloc_max_entries > 0, (\n            f\"trace_alloc_max_entries must be positive, got {self.trace_alloc_max_entries}\"\n        )\n        assert self.stack_depth > 0, f\"stack_depth must be positive, got {self.stack_depth}\"\n\n\n@dataclass\nclass NPUToolConfig(NsightToolConfig):\n    \"\"\"NPU profiler too; config.\"\"\"\n\n    # options: npu, cpu, memory, shapes, module, stack\n    contents: list[str] = field(default_factory=list)\n\n    # Collection level, optional values: level_none, level0, level1, level2.\n    level: str = \"level1\"\n\n    # Whether to automatically parse the data.\n    analysis: bool = False\n\n    def __post_init__(self) -> None:\n        \"\"\"config validation logics go here\"\"\"\n        assert isinstance(self.contents, list), f\"Profiler contents must be of type list, got {type(self.contents)}\"\n        assert isinstance(self.level, str), f\"Profiler level must be of type str, got {type(self.level)}\"\n        assert isinstance(self.analysis, bool), f\"Profiler analysis must be of type bool, got {type(self.analysis)}\"\n        for content in self.contents:\n       
     assert content in [\"npu\", \"cpu\", \"memory\", \"shapes\", \"module\", \"stack\"], (\n                f\"Profiler contents only supports npu, cpu, memory, shapes, module, stack, but got {content}\"\n            )\n        assert self.level in [\"level_none\", \"level0\", \"level1\", \"level2\"], (\n            f\"Profiler level only supports level0, 1, 2, and level_none, but got {self.level}\"\n        )\n\n\n@dataclass\nclass ProfilerConfig(BaseConfig):\n    \"\"\"Worker profiler config.\n\n    The inheritance from BaseConfig provides an omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        tool (Optional[str]): The profiling tool to use.\n        enable (bool): Whether profiling is enabled.\n        all_ranks (bool): Whether to profile all ranks.\n        ranks (list[int]): The ranks that will be profiled. Defaults to [].\n        save_path (Optional[str]): Where to save the profiling results.\n        tool_config (Any): Per-tool configuration (one of the tool configs above).\n        global_tool_config (Any): Global tool configuration for all profiling tools.\n    \"\"\"\n\n    tool: Optional[str] = MISSING\n    enable: bool = False\n    all_ranks: bool = False\n    ranks: list[int] = field(default_factory=list)\n    save_path: Optional[str] = MISSING\n    tool_config: Any = MISSING  # Just a placeholder, will use configs above directly\n    global_tool_config: Optional[Any] = None  # Global tool configuration for all profiling tools\n\n    def union(self, other: \"ProfilerConfig\") -> \"ProfilerConfig\":\n        assert self.tool == other.tool, f\"Cannot union ProfilerConfig with different tools: {self.tool} vs {other.tool}\"\n        return ProfilerConfig(\n            tool=self.tool,\n            enable=self.enable or other.enable,\n            all_ranks=self.all_ranks or other.all_ranks,\n            ranks=list(set(self.ranks or []) | set(other.ranks or [])),\n            save_path=self.save_path,\n            tool_config=self.tool_config,\n            global_tool_config=self.global_tool_config or other.global_tool_config,\n        )\n\n    def intersect(self, other: \"ProfilerConfig\") -> \"ProfilerConfig\":\n        assert self.tool == other.tool, (\n            f\"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}\"\n        )\n        return ProfilerConfig(\n            tool=self.tool,\n            enable=self.enable and other.enable,\n            all_ranks=self.all_ranks and other.all_ranks,\n            ranks=list(set(self.ranks or []) & set(other.ranks or [])),\n            save_path=self.save_path,\n            tool_config=self.tool_config,\n            global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,\n        )\n\n    def __post_init__(self) -> None:\n        \"\"\"Config validation logic goes here.\"\"\"\n        assert isinstance(self.ranks, set | list | tuple), (\n            f\"Profiler ranks must be a set, list, or tuple, got {type(self.ranks)}\"\n        )\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/empty_annotations.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    pass\n\n\ndef mark_end_range(range_id: str) -> None:\n    pass\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    def decorator(func):\n        return func\n\n    return decorator\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/mstx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Inspired from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py\nimport functools\nimport logging\nimport os\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Optional\n\nimport torch_npu\nfrom torch_npu.npu import mstx\n\nfrom .config import NPUToolConfig\nfrom .profile import DistProfiler, ProfilerConfig\n\n\ndef mark_start_range(message: Optional[str] = None) -> None:\n    \"\"\"Start a mark range in the profiler.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n    \"\"\"\n    return mstx.range_start(message=message)\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a mark range in the profiler.\n\n    Args:\n        range_id (str):\n            The id of the mark range to end.\n    \"\"\"\n    return mstx.range_end(range_id)\n\n\ndef mark_annotate(message: Optional[str] = None) -> Callable:\n    \"\"\"Decorate a function to annotate a mark range along with the function life cycle.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. 
Defaults to None.\n    \"\"\"\n\n    def decorator(func):\n        profile_message = message or func.__name__\n        return mstx.mstx_range(profile_message)(func)\n\n    return decorator\n\n\n@contextmanager\ndef marked_timer(name: str, timing_raw: dict[str, float], *args: Any, **kwargs: Any) -> None:\n    \"\"\"Context manager for timing with MSTX markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds MSTX markers for profiling.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    if args:\n        logging.warning(f\"Args are not supported in mstx_profile, but received: {args}\")\n    if kwargs:\n        logging.warning(f\"Kwargs are not supported in mstx_profile, but received: {kwargs}\")\n    mark_range = mark_start_range(message=name)\n    from .performance import _timer\n\n    yield from _timer(name, timing_raw)\n    mark_end_range(mark_range)\n\n\ndef get_npu_profiler(\n    contents: list[str],\n    profile_level: str,\n    profile_save_path: str,\n    analysis: bool,\n    role: Optional[str] = None,\n    profile_step: Optional[str] = None,\n):\n    \"\"\"Generate and return an NPU profiler object.\n\n    Args:\n        contents (list[str]):\n            A list of options to control the collection content,\n            such as npu, cpu, memory, shapes, module, stack.\n        profile_level (str):\n            The collection level, which can be set to level_none,\n            level0, level1 and level2.\n        profile_save_path (str):\n            The path to save the collected data.\n        analysis (bool):\n            Whether to enable automatic data parsing.\n        role (str, optional):\n            The role of the current data collection. Defaults to None.\n        profile_step (str, optional):\n            The current training step. 
Defaults to None.\n    \"\"\"\n    if profile_level == \"level_none\":\n        level = torch_npu.profiler.ProfilerLevel.Level_none\n    elif profile_level == \"level0\":\n        level = torch_npu.profiler.ProfilerLevel.Level0\n    elif profile_level == \"level1\":\n        level = torch_npu.profiler.ProfilerLevel.Level1\n    elif profile_level == \"level2\":\n        level = torch_npu.profiler.ProfilerLevel.Level2\n    else:\n        raise ValueError(f\"level only supports level0, 1, 2, and level_none, but got {profile_level}\")\n\n    if profile_step:\n        profile_save_path = os.path.join(profile_save_path, profile_step)\n    if role:\n        profile_save_path = os.path.join(profile_save_path, role)\n\n    experimental_config = torch_npu.profiler._ExperimentalConfig(\n        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n        profiler_level=level,\n        export_type=torch_npu.profiler.ExportType.Text,\n        data_simplification=True,\n        msprof_tx=True,\n    )\n\n    activities = []\n    if contents is None or \"npu\" in contents:\n        activities.append(torch_npu.profiler.ProfilerActivity.NPU)\n    if contents is None or \"cpu\" in contents:\n        activities.append(torch_npu.profiler.ProfilerActivity.CPU)\n\n    prof = torch_npu.profiler.profile(\n        with_modules=contents is None or \"module\" in contents,\n        with_stack=contents is None or \"stack\" in contents,\n        record_shapes=contents is None or \"shapes\" in contents,\n        profile_memory=contents is None or \"memory\" in contents,\n        activities=activities,\n        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=analysis),\n        experimental_config=experimental_config,\n    )\n    return prof\n\n\nclass NPUProfiler(DistProfiler):\n    \"\"\"\n    NPU profiler. Initialized in a worker to control the NPU profiler.\n    \"\"\"\n\n    _define_count = 0\n\n    def __init__(self, rank: int, config: ProfilerConfig, tool_config: NPUToolConfig, **kwargs):\n        \"\"\"Initialize the NPUProfiler.\n\n        Args:\n            rank (int): The rank of the current process.\n            config (Optional[ProfilerConfig]): Configuration for the profiler. 
If None, a default configuration is used.\n            tool_config (NPUToolConfig): The config to control npu profiler behavior.\n        \"\"\"\n        if not config:\n            config = ProfilerConfig(ranks=[], enable=False)\n        if not tool_config:\n            assert not config.enable, \"tool_config must be set when profiler is enabled\"\n        self.enable: bool = config.enable\n        if not config.enable:\n            return\n        self.this_step: bool = False\n        self.discrete: bool = tool_config.discrete\n        self.this_rank: bool = False\n        self.profile_npu = None\n        self.profile_contents = tool_config.contents\n        self.profile_level = tool_config.level\n        self.profile_save_path = config.save_path\n        self.analysis = tool_config.analysis\n        if config.all_ranks:\n            self.this_rank = True\n        elif config.ranks:\n            self.this_rank = rank in config.ranks\n\n    def start(self, **kwargs):\n        role, profile_step = kwargs.get(\"role\", None), kwargs.get(\"profile_step\", None)\n        profile_step = str(profile_step) if profile_step is not None else None\n        if self.enable and self.this_rank:\n            self.this_step = True\n            if not self.discrete and NPUProfiler._define_count == 0:\n                self.profile_npu = get_npu_profiler(\n                    contents=self.profile_contents,\n                    profile_level=self.profile_level,\n                    profile_save_path=self.profile_save_path,\n                    analysis=self.analysis,\n                    role=role,\n                    profile_step=profile_step,\n                )\n                self.profile_npu.start()\n                NPUProfiler._define_count += 1\n\n    def stop(self):\n        if self.enable and self.this_rank:\n            self.this_step = False\n            if not self.discrete and NPUProfiler._define_count == 1:\n                self.profile_npu.step()\n                self.profile_npu.stop()\n                NPUProfiler._define_count -= 1\n\n    def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:\n        \"\"\"Decorate a Worker member function to profile the current rank in the current training step.\n\n        Requires the target function to be a member function of a Worker,\n        which has a member field `profiler` with NPUProfiler type.\n\n        Args:\n            message (str, optional):\n                The message to be displayed in the profiler. Defaults to None.\n            role (str, optional):\n                The role of the current data collection. 
Defaults to None.\n        \"\"\"\n\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(*args, **kwargs_inner):\n                # Skip profiling entirely unless this rank is being profiled in the current step.\n                if not (self.enable and self.this_step):\n                    return func(*args, **kwargs_inner)\n\n                profile_name = message or func.__name__\n\n                # In discrete mode, each annotated task gets its own profiler session.\n                if self.discrete:\n                    profile_npu = get_npu_profiler(\n                        contents=self.profile_contents,\n                        profile_level=self.profile_level,\n                        profile_save_path=self.profile_save_path,\n                        analysis=self.analysis,\n                        role=role,\n                    )\n                    profile_npu.start()\n                mark_range = mark_start_range(message=profile_name)\n\n                result = func(*args, **kwargs_inner)\n\n                mark_end_range(mark_range)\n                if self.discrete:\n                    profile_npu.step()\n                    profile_npu.stop()\n\n                return result\n\n            return wrapper\n\n        return decorator\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/nvtx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nfrom contextlib import contextmanager\nfrom typing import Callable, Optional\n\nimport nvtx\nimport torch\n\nfrom .config import NsightToolConfig\nfrom .profile import DistProfiler, ProfilerConfig\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    \"\"\"Start a mark range in the profiler.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n        color (str, optional):\n            The color of the range. Defaults to None.\n        domain (str, optional):\n            The domain of the range. Defaults to None.\n        category (str, optional):\n            The category of the range. Defaults to None.\n    \"\"\"\n    return nvtx.start_range(message=message, color=color, domain=domain, category=category)\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a mark range in the profiler.\n\n    Args:\n        range_id (str):\n            The id of the mark range to end.\n    \"\"\"\n    return nvtx.end_range(range_id)\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    \"\"\"Decorate a function to annotate a mark range along with the function life cycle.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n        color (str, optional):\n            The color of the range. Defaults to None.\n        domain (str, optional):\n            The domain of the range. Defaults to None.\n        category (str, optional):\n            The category of the range. Defaults to None.\n    \"\"\"\n\n    def decorator(func):\n        profile_message = message or func.__name__\n        return nvtx.annotate(profile_message, color=color, domain=domain, category=category)(func)\n\n    return decorator\n\n\n@contextmanager\ndef marked_timer(\n    name: str,\n    timing_raw: dict[str, float],\n    color: str = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n):\n    \"\"\"Context manager for timing with NVTX markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds NVTX markers for profiling.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n        color (Optional[str]): Color for the NVTX marker. Defaults to None.\n        domain (Optional[str]): Domain for the NVTX marker. Defaults to None.\n        category (Optional[str]): Category for the NVTX marker. 
Defaults to None.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    mark_range = mark_start_range(message=name, color=color, domain=domain, category=category)\n    from .performance import _timer\n\n    yield from _timer(name, timing_raw)\n    mark_end_range(mark_range)\n\n\nclass NsightSystemsProfiler(DistProfiler):\n    \"\"\"Nsight system profiler. Installed in a worker to control the Nsight system profiler.\"\"\"\n\n    def __init__(self, rank: int, config: Optional[ProfilerConfig], tool_config: Optional[NsightToolConfig], **kwargs):\n        \"\"\"Initialize the NsightSystemsProfiler.\n\n        Args:\n            rank (int): The rank of the current process.\n            config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used.\n            tool_config (Optional[NsightToolConfig]): Tool-specific configuration controlling the Nsight profiler behavior.\n        \"\"\"\n        # If no configuration is provided, create a default ProfilerConfig with an empty list of ranks\n        if not config:\n            config = ProfilerConfig(ranks=[])\n        if not tool_config:\n            assert not config.enable, \"tool_config must be provided when profiler is enabled\"\n        self.enable = config.enable\n        if not config.enable:\n            return\n        self.this_step: bool = False\n        self.discrete: bool = tool_config.discrete\n        self.this_rank: bool = False\n        if config.all_ranks:\n            self.this_rank = True\n        elif config.ranks:\n            self.this_rank = rank in config.ranks\n\n    def start(self, **kwargs):\n        if self.enable and self.this_rank:\n            self.this_step = True\n            if not self.discrete:\n                torch.cuda.profiler.start()\n\n    def stop(self):\n        if self.enable and self.this_rank:\n            self.this_step = False\n            if not self.discrete:\n                torch.cuda.profiler.stop()\n\n    def annotate(\n        self,\n        message: Optional[str] = None,\n        color: Optional[str] = None,\n        domain: Optional[str] = None,\n        category: Optional[str] = None,\n        **kwargs_outer,\n    ) -> Callable:\n        \"\"\"Decorate a Worker member function to profile the current rank in the current training step.\n\n        Requires the target function to be a member function of a Worker, which has a member field `profiler` with\n        NsightSystemsProfiler type.\n\n        Args:\n            message (str, optional):\n                The message to be displayed in the profiler. Defaults to None.\n            color (str, optional):\n                The color of the range. Defaults to None.\n            domain (str, optional):\n                The domain of the range. Defaults to None.\n            category (str, optional):\n                The category of the range. 
Defaults to None.\n        \"\"\"\n\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(*args, **kwargs_inner):\n                if not self.enable:\n                    return func(*args, **kwargs_inner)\n\n                profile_name = message or func.__name__\n\n                if self.this_step:\n                    if self.discrete:\n                        torch.cuda.profiler.start()\n                    mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)\n\n                result = func(*args, **kwargs_inner)\n\n                if self.this_step:\n                    mark_end_range(mark_range)\n                    if self.discrete:\n                        torch.cuda.profiler.stop()\n\n                return result\n\n            return wrapper\n\n        return decorator\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/performance.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport datetime\nimport inspect\nimport logging\nfrom contextlib import contextmanager\nfrom typing import Any, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom codetiming import Timer\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import DecoratorLoggerBase\n\n\ndef _get_current_mem_info(unit: str = \"GB\", precision: int = 2) -> tuple[str]:\n    \"\"\"Get current memory usage.\n\n    Note that CPU device memory info is always 0.\n\n    Args:\n        unit (str, optional): The unit of memory measurement. Defaults to \"GB\".\n        precision (int, optional): The number of decimal places to round memory values. Defaults to 2.\n\n    Returns:\n        tuple[str]: A tuple containing memory allocated, memory reserved, memory used, and memory total\n        in the specified unit.\n    \"\"\"\n    assert unit in [\"GB\", \"MB\", \"KB\"]\n    device = get_torch_device()\n    # torch.cpu.memory_allocated() does not exist\n    if device == torch.cpu:\n        return \"0.00\", \"0.00\", \"0.00\", \"0.00\"\n\n    divisor = 1024**3 if unit == \"GB\" else 1024**2 if unit == \"MB\" else 1024\n    mem_allocated = get_torch_device().memory_allocated()\n    mem_reserved = get_torch_device().memory_reserved()\n    # use get_torch_device().mem_get_info to profile device memory\n    # since vllm's sleep mode works below pytorch\n    # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119\n    mem_free, mem_total = get_torch_device().mem_get_info()\n    mem_used = mem_total - mem_free\n    mem_allocated = f\"{mem_allocated / divisor:.{precision}f}\"\n    mem_reserved = f\"{mem_reserved / divisor:.{precision}f}\"\n    mem_used = f\"{mem_used / divisor:.{precision}f}\"\n    mem_total = f\"{mem_total / divisor:.{precision}f}\"\n    return mem_allocated, mem_reserved, mem_used, mem_total\n\n\ndef log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):\n    \"\"\"Log GPU memory usage information.\n\n    Args:\n        head (str): A descriptive header for the memory usage log message.\n        logger (logging.Logger, optional): Logger instance to use for logging. If None, prints to stdout.\n        level: Logging level to use. Defaults to logging.DEBUG.\n        rank (int): The rank of the process to log memory for. 
Defaults to 0.\n    \"\"\"\n    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n\n        if logger is None:\n            print(message)\n        else:\n            logger.log(msg=message, level=level)\n\n\nclass GPUMemoryLogger(DecoratorLoggerBase):\n    \"\"\"A decorator class to log GPU memory usage.\n\n    Example:\n        >>> from verl.utils.profiler.performance import GPUMemoryLogger\n        >>> @GPUMemoryLogger(role=\"actor\")\n        ... def update_actor(self, batch):\n        ...     # real actor update logics\n        ...     return\n    \"\"\"\n\n    def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True):\n        if dist.is_initialized() and dist.get_world_size() > 1:\n            rank = dist.get_rank()\n        else:\n            rank = 0\n        super().__init__(role, logger, level, rank, log_only_rank_0)\n\n    def __call__(self, decorated_function: callable):\n        def f(*args, **kwargs):\n            return self.log(decorated_function, *args, **kwargs)\n\n        return f\n\n    def log(self, func, *args, **kwargs):\n        name = func.__name__\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n        self.logging_function(message)\n\n        output = func(*args, **kwargs)\n\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n\n        self.logging_function(message)\n        return output\n\n\ndef log_print(ctn: Any):\n    current_time = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n\n    frame = inspect.currentframe().f_back\n    function_name = frame.f_code.co_name\n    line_number = frame.f_lineno\n    file_name = frame.f_code.co_filename.split(\"/\")[-1]\n    print(f\"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}\")\n\n\ndef _timer(name: str, timing_raw: dict[str, float]):\n    \"\"\"Inner function that handles the core timing logic.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n    \"\"\"\n    with Timer(name=name, logger=None) as timer:\n        yield\n    if name not in timing_raw:\n        timing_raw[name] = 0\n    timing_raw[name] += timer.last\n\n\n@contextmanager\ndef simple_timer(name: str, timing_raw: dict[str, float]):\n    \"\"\"Context manager for basic timing without NVTX markers.\n\n    This utility function measures the execution time of code within its context\n    and accumulates the timing information in the provided dictionary.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n\n    Yields:\n        None: This is a 
context manager that yields control back to the code block.\n    \"\"\"\n    yield from _timer(name, timing_raw)\n\n\n@contextmanager\ndef marked_timer(\n    name: str,\n    timing_raw: dict[str, float],\n    color: str = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n):\n    \"\"\"Context manager for timing with platform markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds platform markers for profiling.\n    This function is the default implementation, used when no hardware profiler is available.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n        color (Optional[str]): Color for the marker. Defaults to None.\n        domain (Optional[str]): Domain for the marker. Defaults to None.\n        category (Optional[str]): Category for the marker. Defaults to None.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    yield from _timer(name, timing_raw)\n\n\ndef reduce_timing(\n    timing_raw: dict[str, float], reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.AVG\n) -> dict[str, float]:\n    \"\"\"Reduce timing information across all processes.\n\n    This function uses distributed communication to reduce the timing information\n    from all processes in a distributed environment (averaging by default).\n\n    Args:\n        timing_raw (Dict[str, float]): Dictionary containing timing information.\n        reduce_op (torch.distributed.ReduceOp): The reduction operation to apply. Defaults to AVG.\n\n    Returns:\n        Dict[str, float]: Reduced timing information.\n    \"\"\"\n    if not dist.is_initialized():\n        return timing_raw\n\n    key_list, timing_list = [], []\n    for key in sorted(timing_raw.keys()):\n        key_list.append(key)\n        timing_list.append(timing_raw[key])\n    timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id())\n    torch.distributed.all_reduce(timing_list, op=reduce_op)\n    timing_list = [tensor.item() for tensor in timing_list.to(\"cpu\")]\n    timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))}\n    return timing_generate\n\n\ndef topk_reduce_ratio_min_max(timing: float, k: int = 10) -> tuple[float, float, float]:\n    \"\"\"Calculate the fraction of ranks whose timing falls in the top k percent, and the min/max timing across all ranks.\"\"\"\n    if not dist.is_initialized():\n        return -1.0, -1.0, -1.0\n\n    world_size = dist.get_world_size()\n    timing_tensor = torch.tensor(timing, dtype=torch.float32, device=get_device_id())\n    tensor_list = [torch.zeros(1, dtype=torch.float32, device=get_device_id()) for _ in range(world_size)]\n    torch.distributed.all_gather(tensor_list, timing_tensor)\n    tensor_stack = torch.stack(tensor_list)\n    timing_min = tensor_stack.min().cpu().item()\n    timing_max = tensor_stack.max().cpu().item()\n    top_k_percentile = torch.quantile(tensor_stack, 1 - k / 100)\n    tail_ratio = torch.mean((tensor_stack > top_k_percentile).float()).cpu().item()\n    return tail_ratio, timing_min, timing_max\n"
  },
  {
    "path": "verl_distillation/verl/utils/profiler/profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport os\nfrom typing import Callable, Optional\n\nimport torch\nimport torch.distributed\n\nfrom ..memory_utils import MemorySnapshotSampler, enable_memory_visualize\nfrom .config import ProfilerConfig, TorchMemoryToolConfig, TorchProfilerToolConfig\n\n\nclass Profiler:\n    \"\"\"A PyTorch profiler wrapper class for collecting performance metrics.\n\n    TODO(haibin.lin): this should implement the DistProfiler interface, and the config should be unified.\n\n    This profiler provides a convenient interface for profiling PyTorch operations,\n    with support for:\n\n    - CPU and CUDA activity profiling\n    - Configurable profiling schedule (wait/warmup/active steps)\n    - Multi-rank profiling support\n    - Chrome trace export\n\n    Args:\n        config: Configuration object containing profiling parameters\n    \"\"\"\n\n    def __init__(self, config: ProfilerConfig, tool_config: Optional[TorchProfilerToolConfig] = None):\n        # note : if we do not set use_profile, it will be set as None, so that all function will be skip\n        if not config:\n            config = ProfilerConfig(ranks=[], enable=False)\n        if not tool_config:\n            assert not config.enable, \"tool_config must be provided when profiler is enabled\"\n        self.prof = None\n        self.saved = False\n        self.enable = config.enable\n        if not config.enable:\n            return\n        self.config = config\n        self.tool_config = tool_config\n        self.rank = torch.distributed.get_rank()\n        # we need to validate the config before using the profiler\n        self._validate()\n        if self.rank in self.config.profile_ranks:\n            print(f\"[Profiler] Profiler init for rank {self.rank}\")\n\n            self.prof = torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU,\n                    torch.profiler.ProfilerActivity.CUDA,\n                ],\n                schedule=torch.profiler.schedule(\n                    wait=max(self.tool_config.step_start - 1, 0),\n                    warmup=1 if self.tool_config.step_start > 0 else 0,\n                    active=self.tool_config.step_end - self.tool_config.step_start,\n                    repeat=1,\n                ),\n                record_shapes=True,\n                with_stack=True,\n            )\n\n    def _validate(self):\n        if self.enable:\n            if self.config.profile_ranks is None:\n                print(\"[WARNING] Profile ranks is not set, default to rank 0\")\n                self.config.profile_ranks = [0]\n            assert self.tool_config.step_start >= 0, \"[ERROR] Profile step start must be greater than 0\"\n            assert self.tool_config.step_end >= 0, \"[ERROR] Profile step end must be greater than 0\"\n            assert self.tool_config.step_start < self.tool_config.step_end, (\n     
           \"[ERROR] Profile step start must be less than step end\"\n            )\n\n    def check(self):\n        return self.prof is not None and self.enable\n\n    def start(self):\n        if self.check():\n            print(f\"[Profiler] started for rank {self.rank}\")\n            self.prof.start()\n\n    def step(self):\n        if self.check():\n            self.prof.step()\n\n    def stop(self):\n        if self.check():\n            print(f\"[Profiler] stopped for rank {self.rank}\")\n            self.prof.stop()\n\n    def save(self):\n        if self.prof is not None and not self.saved:\n            if not os.path.exists(self.config.save_path):\n                os.makedirs(self.config.save_path)\n            save_file_name = f\"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json\"\n            print(f\"[Profiler] Saving trace to {self.config.save_path + save_file_name}\")\n            self.prof.export_chrome_trace(self.config.save_path + save_file_name)\n            self.enable = False\n            self.saved = True\n\n    def stop_and_save(self):\n        if self.check():\n            self.stop()\n            self.save()\n\n    def stop_trace(self):\n        if self.check():\n            print(f\"[Profiler] Trace stopped for rank {self.rank}\")\n            self.enable = False\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    \"\"\"Start a profiling range marker (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the range marker.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n    \"\"\"\n    pass\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a profiling range marker (no-op implementation).\n\n    Args:\n        range_id (str): Identifier of the range to end.\n    \"\"\"\n    pass\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    \"\"\"Decorator to annotate a function with profiling markers (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the annotation.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n\n    Returns:\n        Callable: Decorator function that returns the original function unchanged.\n    \"\"\"\n\n    def decorator(func):\n        return func\n\n    return decorator\n\n\nclass DistProfiler:\n    \"\"\"A dispatcher that delegates to specific profilers based on config.tool.\n\n    Supported tools:\n    - nsys: NsightSystemsProfiler\n    - npu: NPUProfiler (Ascend)\n    - torch: PyTorch torch.profiler wrapper\n    - torch_memory: Torch CUDA memory snapshot dump\n    \"\"\"\n\n    def __init__(\n        self, rank: int, config: Optional[ProfilerConfig] = None, tool_config: Optional[object] = None, **kwargs\n    ):\n        # Default config\n        if not config:\n            config = ProfilerConfig(ranks=[], enable=False)\n\n        self._impl = None\n        self._tool = getattr(config, \"tool\", None)\n\n        # Normalize rank selection\n        self._this_rank = False\n        
if config.all_ranks:\n            self._this_rank = True\n        elif config.ranks:\n            self._this_rank = rank in config.ranks\n        else:\n            # default rank 0 if enabled but ranks unspecified\n            self._this_rank = (rank == 0) if config.enable else False\n\n        # Lazy import to avoid circular deps\n        if self._tool == \"nsys\":\n            from .nvtx_profile import NsightSystemsProfiler as _Nsight\n\n            self._impl = _Nsight(rank=rank, config=config, tool_config=tool_config, **kwargs)\n        elif self._tool == \"npu\":\n            from .mstx_profile import NPUProfiler as _Npu\n\n            self._impl = _Npu(rank=rank, config=config, tool_config=tool_config, **kwargs)\n        elif self._tool == \"torch\":\n            # Use the torch profiler wrapper defined above\n            self._impl = Profiler(config=config, tool_config=tool_config)\n        elif self._tool == \"torch_memory\":\n            self._impl = TorchMemoryProfiler(rank=rank, config=config, tool_config=tool_config)\n        else:\n            # Fallback to a no-op impl\n            self._impl = _NoOpProfiler()\n\n    def start(self, **kwargs):\n        return getattr(self._impl, \"start\", lambda **_: None)(**kwargs)\n\n    def stop(self):\n        return getattr(self._impl, \"stop\", lambda: None)()\n\n    @classmethod\n    def annotate(\n        cls,\n        message: Optional[str] = None,\n        color: Optional[str] = None,\n        domain: Optional[str] = None,\n        category: Optional[str] = None,\n        **kwargs_outer,\n    ) -> Callable:\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(self_instance, *args, **kwargs_inner):\n                profiler = getattr(self_instance, \"profiler\", None)\n                if not profiler:\n                    return func(self_instance, *args, **kwargs_inner)\n\n                impl = profiler._impl\n                if hasattr(impl, \"annotate\"):\n                    try:\n                        actual_decorator = impl.annotate(\n                            message=message, color=color, domain=domain, category=category, **kwargs_outer\n                        )\n\n                        return actual_decorator(func)(self_instance, *args, **kwargs_inner)\n                    except Exception:\n                        return func(self_instance, *args, **kwargs_inner)\n                return func(self_instance, *args, **kwargs_inner)\n\n            return wrapper\n\n        return decorator\n\n\nclass _NoOpProfiler:\n    def start(self, **kwargs):\n        return\n\n    def stop(self):\n        return\n\n\nclass TorchMemoryProfiler:\n    \"\"\"Profiler that dumps CUDA memory snapshots at step boundaries.\n\n    Behavior:\n    - On first construction (per process), enable memory history recording if CUDA is available\n    - On start(step=X), remember sub_dir for this step\n    - On stop(), dump a memory snapshot into config.save_path under the remembered sub_dir\n    \"\"\"\n\n    _memory_history_enabled: bool = False\n\n    def __init__(\n        self, rank: int, config: Optional[ProfilerConfig], tool_config: Optional[TorchMemoryToolConfig] = None\n    ):\n        # Always respond to explicit start/stop calls for torch_memory tool,\n        # regardless of per-role enable flag, to align with global step control.\n        self.enable = True\n        if not config:\n            config = ProfilerConfig(ranks=[])\n        self.config = config\n        self.rank = rank\n        
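# Per-step bookkeeping: start(profile_step=N) records the step's sub_dir, and stop() dumps the snapshot there.\n        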
self.this_step = False\n        self.sub_dir = None\n        self.sampler = MemorySnapshotSampler()\n\n        # Get parameters from tool_config, with fallback to defaults\n        if tool_config:\n            trace_alloc_max_entries = tool_config.trace_alloc_max_entries\n            stack_depth = tool_config.stack_depth\n        else:\n            trace_alloc_max_entries = 100_000\n            stack_depth = 32\n\n        # Best-effort enable memory history once\n        if not TorchMemoryProfiler._memory_history_enabled:\n            try:\n                enable_memory_visualize(trace_alloc_max_entries=trace_alloc_max_entries, stack_depth=stack_depth)\n            except Exception:\n                # silently ignore if not supported\n                pass\n            TorchMemoryProfiler._memory_history_enabled = True\n\n    def start(self, **kwargs):\n        if not self.enable:\n            return\n        if not self._should_profile_this_rank():\n            return\n        profile_step = kwargs.get(\"profile_step\", None)\n        # Keep ranks aligned under same folder name\n        self.sub_dir = f\"step{profile_step}\" if profile_step is not None else None\n        self.this_step = True\n\n    def stop(self):\n        if not self.enable or not self.this_step:\n            return\n        self.this_step = False\n        if not self._should_profile_this_rank():\n            return\n        out_dir = self.config.save_path or \"outputs/profile\"\n        tag = \"torch_memory\"\n        # Dump snapshot; all ranks write into same sub_dir\n        try:\n            self.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=self.sub_dir)\n        except Exception:\n            pass\n\n    def _should_profile_this_rank(self) -> bool:\n        if self.config.all_ranks:\n            return True\n        if self.config.ranks:\n            return self.rank in self.config.ranks\n        # default rank 0\n        return self.rank == 0\n\n\nclass DistProfilerExtension:\n    \"\"\"An extension class for DistProfiler that provides distributed profiling capabilities.\n\n    It is intended for workers in verl that single controller invokes.\n\n    This class wraps a DistProfiler instance and provides methods to start/stop profiling\n    that can be dispatched across multiple ranks in a distributed training environment.\n\n    Args:\n        profiler (DistProfiler): The base distributed profiler instance to extend\n    \"\"\"\n\n    def __init__(self, profiler: DistProfiler):\n        self.profiler = profiler\n\n    from verl.single_controller.base.decorator import Dispatch, register\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def start_profile(self, **kwargs) -> None:\n        \"\"\"Start profiling for the current rank in the current training step.\"\"\"\n        self.profiler.start(**kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def stop_profile(self) -> None:\n        \"\"\"Stop profiling for the current rank in the current training step.\"\"\"\n        self.profiler.stop()\n"
  },
  {
    "path": "verl_distillation/verl/utils/py_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small python utility functions\n\"\"\"\n\nimport importlib\nimport multiprocessing\nimport os\nimport queue  # Import the queue module for exception type hint\nimport signal\nfrom contextlib import contextmanager\nfrom functools import wraps\nfrom types import SimpleNamespace\nfrom typing import Any, Callable, Iterator, Optional\n\n\n# --- Top-level helper for multiprocessing timeout ---\n# This function MUST be defined at the top level to be pickleable\ndef _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]):\n    \"\"\"\n    Internal wrapper function executed in the child process.\n    Calls the original target function and puts the result or exception into the queue.\n    \"\"\"\n    try:\n        result = target_func(*args, **kwargs)\n        mp_queue.put((True, result))  # Indicate success and put result\n    except Exception as e:\n        # Ensure the exception is pickleable for the queue\n        try:\n            import pickle\n\n            pickle.dumps(e)  # Test if the exception is pickleable\n            mp_queue.put((False, e))  # Indicate failure and put exception\n        except (pickle.PicklingError, TypeError):\n            # Fallback if the original exception cannot be pickled\n            mp_queue.put((False, RuntimeError(f\"Original exception type {type(e).__name__} not pickleable: {e}\")))\n\n\n# Renamed the function from timeout to timeout_limit\ndef timeout_limit(seconds: float, use_signals: bool = False):\n    \"\"\"\n    Decorator to add a timeout to a function.\n\n    Args:\n        seconds: The timeout duration in seconds.\n        use_signals: (Deprecated)  This is deprecated because signals only work reliably in the main thread\n                     and can cause issues in multiprocessing or multithreading contexts.\n                     Defaults to False, which uses the more robust multiprocessing approach.\n\n    Returns:\n        A decorated function with timeout.\n\n    Raises:\n        TimeoutError: If the function execution exceeds the specified time.\n        RuntimeError: If the child process exits with an error (multiprocessing mode).\n        NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX).\n    \"\"\"\n\n    def decorator(func):\n        if use_signals:\n            if os.name != \"posix\":\n                raise NotImplementedError(f\"Unsupported OS: {os.name}\")\n            # Issue deprecation warning if use_signals is explicitly True\n            print(\n                \"WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \\\n                Signals are unreliable outside the main thread. 
\\\n                Please use the default multiprocessing-based timeout (use_signals=False).\"\n            )\n\n            @wraps(func)\n            def wrapper_signal(*args, **kwargs):\n                def handler(signum, frame):\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (signal)!\")\n\n                old_handler = signal.getsignal(signal.SIGALRM)\n                signal.signal(signal.SIGALRM, handler)\n                # Use setitimer for float seconds support, alarm only supports integers\n                signal.setitimer(signal.ITIMER_REAL, seconds)\n\n                try:\n                    result = func(*args, **kwargs)\n                finally:\n                    # Reset timer and handler\n                    signal.setitimer(signal.ITIMER_REAL, 0)\n                    signal.signal(signal.SIGALRM, old_handler)\n                return result\n\n            return wrapper_signal\n        else:\n            # --- Multiprocessing based timeout (existing logic) ---\n            @wraps(func)\n            def wrapper_mp(*args, **kwargs):\n                q = multiprocessing.Queue(maxsize=1)\n                process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs))\n                process.start()\n                process.join(timeout=seconds)\n\n                if process.is_alive():\n                    process.terminate()\n                    process.join(timeout=0.5)  # Give it a moment to terminate\n                    if process.is_alive():\n                        print(f\"Warning: Process {process.pid} did not terminate gracefully after timeout.\")\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!\")\n\n                try:\n                    success, result_or_exc = q.get(timeout=0.1)  # Small timeout for queue read\n                    if success:\n                        return result_or_exc\n                    else:\n                        raise result_or_exc  # Reraise exception from child\n                except queue.Empty as err:\n                    exitcode = process.exitcode\n                    if exitcode is not None and exitcode != 0:\n                        raise RuntimeError(\n                            f\"Child process exited with error (exitcode: {exitcode}) before returning result.\"\n                        ) from err\n                    else:\n                        # Should have timed out if queue is empty after join unless process died unexpectedly\n                        # Update function name in error message if needed (optional but good practice)\n                        raise TimeoutError(\n                            f\"Operation timed out or process finished unexpectedly without result \"\n                            f\"(exitcode: {exitcode}).\"\n                        ) from err\n                finally:\n                    q.close()\n                    q.join_thread()\n\n            return wrapper_mp\n\n    return decorator\n\n\ndef union_two_dict(dict1: dict, dict2: dict):\n    \"\"\"Union two dict. 
\n    Raises an AssertionError if the two dicts hold different values for the same key.\n\n    Args:\n        dict1: The dict that is updated in place.\n        dict2: The dict whose items are merged into dict1.\n\n    Returns:\n        dict1, updated in place with the items of dict2.\n    \"\"\"\n    for key, val in dict2.items():\n        if key in dict1:\n            assert dict2[key] == dict1[key], f\"{key} in dict1 and dict2 does not hold the same value\"\n        dict1[key] = val\n\n    return dict1\n\n\ndef append_to_dict(data: dict, new_data: dict, prefix: str = \"\"):\n    \"\"\"Append values from new_data to lists in data.\n\n    For each key in new_data, this function appends the corresponding value to a list\n    stored under the same key in data. If the key doesn't exist in data, a new list is created.\n\n    Args:\n        data (Dict): The target dictionary containing lists as values.\n        new_data (Dict): The source dictionary with values to append.\n        prefix (str): Optional prefix prepended to each key of new_data before appending. Defaults to \"\".\n\n    Returns:\n        None: The function modifies data in-place.\n    \"\"\"\n    for key, val in new_data.items():\n        new_key = f\"{prefix}{key}\"\n        if new_key not in data:\n            data[new_key] = []\n        data[new_key].append(val)\n\n\nclass NestedNamespace(SimpleNamespace):\n    \"\"\"A nested version of SimpleNamespace that recursively converts dictionaries to namespaces.\n\n    This class allows for dot notation access to nested dictionary structures by recursively\n    converting dictionaries to NestedNamespace objects.\n\n    Example:\n        config_dict = {\"a\": 1, \"b\": {\"c\": 2, \"d\": 3}}\n        config = NestedNamespace(config_dict)\n        # Access with: config.a, config.b.c, config.b.d\n\n    Args:\n        dictionary: The dictionary to convert to a nested namespace.\n        **kwargs: Additional attributes to set on the namespace.\n    \"\"\"\n\n    def __init__(self, dictionary, **kwargs):\n        super().__init__(**kwargs)\n        for key, value in dictionary.items():\n            if isinstance(value, dict):\n                self.__setattr__(key, NestedNamespace(value))\n            else:\n                self.__setattr__(key, value)\n\n\nclass DynamicEnumMeta(type):\n    def __iter__(cls) -> Iterator[Any]:\n        return iter(cls._registry.values())\n\n    def __contains__(cls, item: Any) -> bool:\n        # allow `name in EnumClass` or `member in EnumClass`\n        if isinstance(item, str):\n            return item in cls._registry\n        return item in cls._registry.values()\n\n    def __getitem__(cls, name: str) -> Any:\n        return cls._registry[name]\n\n    def __reduce_ex__(cls, protocol):\n        # Always load the existing module and grab the class\n        return getattr, (importlib.import_module(cls.__module__), cls.__name__)\n\n    def names(cls):\n        return list(cls._registry.keys())\n\n    def values(cls):\n        return list(cls._registry.values())\n\n\nclass DynamicEnum(metaclass=DynamicEnumMeta):\n    _registry: dict[str, \"DynamicEnum\"] = {}\n    _next_value: int = 0\n\n    def __init__(self, name: str, value: int):\n        self.name = name\n        self.value = value\n\n    def __repr__(self):\n        return f\"<{self.__class__.__name__}.{self.name}: {self.value}>\"\n\n    def __reduce_ex__(self, protocol):\n        \"\"\"\n        Unpickle via: getattr(import_module(module).Dispatch, 'ONE_TO_ALL')\n        so the existing class is reused instead of re-executed.\n        \"\"\"\n        module = importlib.import_module(self.__class__.__module__)\n        enum_cls = getattr(module, self.__class__.__name__)\n        return getattr, (enum_cls, self.name)\n\n    @classmethod\n    def 
register(cls, name: str) -> \"DynamicEnum\":\n        key = name.upper()\n        if key in cls._registry:\n            raise ValueError(f\"{key} already registered\")\n        member = cls(key, cls._next_value)\n        cls._registry[key] = member\n        setattr(cls, key, member)\n        cls._next_value += 1\n        return member\n\n    @classmethod\n    def remove(cls, name: str):\n        key = name.upper()\n        member = cls._registry.pop(key)\n        delattr(cls, key)\n        return member\n\n    @classmethod\n    def from_name(cls, name: str) -> Optional[\"DynamicEnum\"]:\n        return cls._registry.get(name.upper())\n\n\n@contextmanager\ndef temp_env_var(key: str, value: str):\n    \"\"\"Context manager for temporarily setting an environment variable.\n\n    This context manager ensures that environment variables are properly set and restored,\n    even if an exception occurs during the execution of the code block.\n\n    Args:\n        key: Environment variable name to set\n        value: Value to set the environment variable to\n\n    Yields:\n        None\n\n    Example:\n        >>> with temp_env_var(\"MY_VAR\", \"test_value\"):\n        ...     # MY_VAR is set to \"test_value\"\n        ...     do_something()\n        ... # MY_VAR is restored to its original value or removed if it didn't exist\n    \"\"\"\n    original = os.environ.get(key)\n    os.environ[key] = value\n    try:\n        yield\n    finally:\n        if original is None:\n            os.environ.pop(key, None)\n        else:\n            os.environ[key] = original\n\n\ndef convert_to_regular_types(obj):\n    \"\"\"Convert Hydra configs and other special types to regular Python types.\"\"\"\n    from omegaconf import DictConfig, ListConfig\n\n    if isinstance(obj, ListConfig | DictConfig):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)\n    elif isinstance(obj, list | tuple):\n        return [convert_to_regular_types(x) for x in obj]\n    elif isinstance(obj, dict):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()}\n    return obj\n"
  },
  {
    "path": "verl_distillation/verl/utils/ray_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains commonly used utilities for ray\n\"\"\"\n\nimport asyncio\nimport concurrent.futures\nimport os\nfrom typing import Any, Optional\n\nimport ray\n\n\ndef ray_noset_visible_devices(env_vars=os.environ):\n    # Refer to\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103\n    # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98\n    NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [\n        \"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES\",\n        \"RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES\",\n        \"RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS\",\n        \"RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR\",\n    ]\n    return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)\n\n\ndef parallel_put(data_list: list[Any], max_workers: Optional[int] = None):\n    \"\"\"\n    Puts a list of data into the Ray object store in parallel using a thread pool.\n\n    Args:\n        data_list (List[Any]): A list of Python objects to be put into the Ray object store.\n        max_workers (int, optional): The maximum number of worker threads to use.\n                                     Defaults to min(len(data_list), 16).\n\n    Returns:\n        List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list,\n                             maintaining the original order.\n    \"\"\"\n    assert len(data_list) > 0, \"data_list must not be empty\"\n\n    def put_data(index, data):\n        return index, ray.put(data)\n\n    if max_workers is None:\n        max_workers = min(len(data_list), 16)\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n        
data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)]\n        res_lst = []\n        for future in concurrent.futures.as_completed(data_list_f):\n            res_lst.append(future.result())\n\n        # reorder based on index\n        output = [None for _ in range(len(data_list))]\n        for res in res_lst:\n            index, data_ref = res\n            output[index] = data_ref\n\n    return output\n\n\ndef get_event_loop():\n    try:\n        loop = asyncio.get_event_loop()\n    except RuntimeError:\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n\n    return loop\n"
  },
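A minimal usage sketch for `parallel_put` above, assuming a local Ray session and the package layout implied by these paths; the payloads are illustrative.

```python
# Minimal sketch: put several objects into the Ray object store concurrently
# and read them back. Payloads here are made up for demonstration.
import ray

from verl.utils.ray_utils import parallel_put

ray.init(ignore_reinit_error=True)

payloads = [list(range(1_000)) for _ in range(8)]  # hypothetical data
refs = parallel_put(payloads)        # one ObjectRef per payload, order preserved
assert ray.get(refs) == payloads     # round-trips through the object store

ray.shutdown()
```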
  {
    "path": "verl_distillation/verl/utils/rendezvous/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/utils/rendezvous/ray_backend.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport time\n\nimport ray\nfrom cupy.cuda.nccl import NcclCommunicator, get_unique_id\nfrom ray.util import list_named_actors\n\n\n@ray.remote\nclass NCCLIDStore:\n    def __init__(self, nccl_id):\n        self._nccl_id = nccl_id\n\n    def get(self):\n        return self._nccl_id\n\n\ndef get_nccl_id_store_by_name(name):\n    all_actors = list_named_actors(all_namespaces=True)\n    matched_actors = [actor for actor in all_actors if actor.get(\"name\", None) == name]\n    if len(matched_actors) == 1:\n        actor = matched_actors[0]\n        return ray.get_actor(**actor)\n    elif len(matched_actors) > 1:\n        logging.warning(\"multiple actors with same name found: %s\", matched_actors)\n    elif len(matched_actors) == 0:\n        logging.info(\"failed to get any actor named %s\", name)\n    return None\n\n\ndef create_nccl_communicator_in_ray(\n    rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5\n):\n    if rank == 0:\n        nccl_id = get_unique_id()\n        nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)\n\n        assert ray.get(nccl_id_store.get.remote()) == nccl_id\n        communicator = NcclCommunicator(\n            ndev=world_size,\n            commId=nccl_id,\n            rank=0,\n        )\n        return communicator\n    else:\n        for i in range(max_retries):\n            nccl_id_store = get_nccl_id_store_by_name(group_name)\n            if nccl_id_store is not None:\n                logging.info(\"nccl_id_store %s got\", group_name)\n                nccl_id = ray.get(nccl_id_store.get.remote())\n                logging.info(\"nccl id for %s got: %s\", group_name, nccl_id)\n                communicator = NcclCommunicator(\n                    ndev=world_size,\n                    commId=nccl_id,\n                    rank=rank,\n                )\n                return communicator\n            logging.info(\"failed to get nccl_id for %d time, sleep for %d seconds\", i + 1, interval_s)\n            time.sleep(interval_s)\n"
  },
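A sketch of how the rendezvous above might be driven from Ray actors; the `NcclWorker` actor, group name, and two-GPU layout are assumptions for illustration, and `cupy` built with NCCL support plus at least two GPUs are required.

```python
# Illustrative sketch: two GPU actors joining one NCCL group by name.
# Rank 0 publishes the unique id through the named NCCLIDStore actor;
# the other rank polls for it via get_nccl_id_store_by_name.
import ray

from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray


@ray.remote(num_gpus=1)
class NcclWorker:  # hypothetical actor for demonstration
    def join(self, rank: int, world_size: int) -> bool:
        comm = create_nccl_communicator_in_ray(rank, world_size, group_name="demo_nccl_group")
        return comm is not None


workers = [NcclWorker.remote() for _ in range(2)]
# Both calls run concurrently; NCCL communicator construction blocks until all ranks join.
assert all(ray.get([w.join.remote(i, 2) for i, w in enumerate(workers)]))
```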
  {
    "path": "verl_distillation/verl/utils/reward_score/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# from . import gsm8k, math, prime_math, prime_code\n\nfrom verl.utils.import_utils import deprecated\n\n\ndef default_compute_score(\n    data_source,\n    solution_str,\n    ground_truth,\n    extra_info=None,\n    sandbox_fusion_url=None,\n    concurrent_semaphore=None,\n    memory_limit_mb=None,\n    **kwargs,\n):\n    \"\"\"Compute the score for a given solution based on the data source.\n\n    Args:\n        data_source (str): The source dataset identifier which determines the scoring method.\n        solution_str (str): The solution string to be evaluated.\n        ground_truth (str): The ground truth answer for comparison.\n        extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.\n\n    Returns:\n        float: The computed score as a floating point number. If the result is a dictionary,\n               it returns the dictionary instead.\n\n    Raises:\n        NotImplementedError: If the reward function is not implemented for the given data source.\n    \"\"\"\n    if data_source == \"openai/gsm8k\":\n        from . import gsm8k\n\n        res = gsm8k.compute_score(solution_str, ground_truth)\n    elif data_source in [\"lighteval/MATH\", \"DigitalLearningGmbH/MATH-lighteval\", \"HuggingFaceH4/MATH-500\"]:\n        from . import math_reward\n\n        res = math_reward.compute_score(solution_str, ground_truth)\n        # [Optional] Math-Verify Integration\n        # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).\n        # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.\n        # To use it, override the `compute_score` function with the following implementation:\n\n        # from . import math_verify\n        # res = math_verify.compute_score(solution_str, ground_truth)\n    elif data_source in [\"math_dapo\", \"math\", \"math_dapo_reasoning\"] or data_source.startswith(\"aime\"):\n        from . import math_dapo\n\n        res = math_dapo.compute_score(solution_str, ground_truth)\n    elif data_source in [\n        \"numina_aops_forum\",\n        \"numina_synthetic_math\",\n        \"numina_amc_aime\",\n        \"numina_synthetic_amc\",\n        \"numina_cn_k12\",\n        \"numina_olympiads\",\n    ]:\n        from . import prime_math\n\n        res = prime_math.compute_score(solution_str, ground_truth)\n    elif data_source in [\"codecontests\", \"apps\", \"codeforces\", \"taco\"]:\n        # Use the passed sandbox_fusion_url if available\n        if sandbox_fusion_url:\n            from . 
import sandbox_fusion\n\n            # Pass the URL directly, ground_truth likely contains test cases here\n            res = sandbox_fusion.compute_score(\n                sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True\n            )\n        else:\n            # If no sandbox URL is provided, fall back to prime_code or raise error\n            from . import prime_code\n\n            # Assuming prime_code doesn't need the URL\n            res = prime_code.compute_score(solution_str, ground_truth, continuous=True)\n    elif data_source in [\"hiyouga/geometry3k\"]:\n        from . import geo3k\n\n        res = geo3k.compute_score(solution_str, ground_truth)\n    elif data_source in [\n        \"searchR1_nq\",\n        \"searchR1_triviaqa\",\n        \"searchR1_popqa\",\n        \"searchR1_hotpotqa\",\n        \"searchR1_2wikimultihopqa\",\n        \"searchR1_musique\",\n        \"searchR1_bamboogle\",\n    ]:\n        from . import search_r1_like_qa_em\n\n        res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)\n\n    else:\n        raise NotImplementedError(f\"Reward function is not implemented for {data_source=}\")\n\n    if isinstance(res, dict):\n        return res\n    elif isinstance(res, int | float | bool):\n        return float(res)\n    else:\n        return float(res[0])\n\n\n@deprecated(\"verl.utils.reward_score.default_compute_score\")\ndef _default_compute_score(\n    data_source,\n    solution_str,\n    ground_truth,\n    extra_info=None,\n    sandbox_fusion_url=None,\n    concurrent_semaphore=None,\n    memory_limit_mb=None,\n):\n    \"\"\"\n    Legacy function API to be deprecated. Please use `default_compute_score` instead.\n    \"\"\"\n    return default_compute_score(\n        data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb\n    )\n\n\n__all__ = [\"default_compute_score\"]\n"
  },
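A minimal usage sketch for the `default_compute_score` dispatcher above, assuming the `verl` package layout implied by these paths; the solution string is fabricated.

```python
# Scoring a GSM8K-style completion: the data_source selects the gsm8k scorer,
# which extracts the answer after the '#### ' marker and compares it.
from verl.utils.reward_score import default_compute_score

score = default_compute_score(
    data_source="openai/gsm8k",
    solution_str="The answer is #### 42",
    ground_truth="42",
)
print(score)  # 1.0 when the extracted answer matches the ground truth
```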
  {
    "path": "verl_distillation/verl/utils/reward_score/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport re\n\nfrom mathruler.grader import extract_boxed_content, grade_answer\n\n\ndef format_reward(predict_str: str) -> float:\n    pattern = re.compile(r\"<think>.*</think>.*\\\\boxed\\{.*\\}.*\", re.DOTALL)\n    match_result = re.fullmatch(pattern, predict_str)\n    return 1.0 if match_result else 0.0\n\n\ndef acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> float:\n    if use_boxed:\n        answer = extract_boxed_content(predict_str)\n    else:\n        answer = predict_str\n    return 1.0 if grade_answer(answer, ground_truth) else 0.0\n\n\ndef compute_score(predict_str: str, ground_truth: str, use_boxed: bool = True, format_score: float = 0.1) -> float:\n    return (1.0 - format_score) * acc_reward(predict_str, ground_truth, use_boxed) + format_score * format_reward(\n        predict_str\n    )\n"
  },
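Since `compute_score` blends accuracy and format as `(1 - format_score) * acc + format_score * format`, a small worked example may help; it assumes `mathruler` is installed, and the answer strings are invented.

```python
from verl.utils.reward_score.geo3k import compute_score

# Correct boxed answer wrapped in a <think> block: acc = 1.0, format = 1.0
full = "<think>reasoning</think> The result is \\boxed{12}"
print(compute_score(full, "12"))   # 0.9 * 1.0 + 0.1 * 1.0 = 1.0

# Correct answer but missing the <think> wrapper: acc = 1.0, format = 0.0
plain = "\\boxed{12}"
print(compute_score(plain, "12"))  # 0.9 * 1.0 + 0.1 * 0.0 = 0.9
```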
  {
    "path": "verl_distillation/verl/utils/reward_score/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n_SOLUTION_CLIP_CHARS = 300\n\n\ndef extract_solution(solution_str, method=\"strict\"):\n    assert method in [\"strict\", \"flexible\"]\n\n    # Optimization: Regular expression matching on very long strings can be slow.\n    # For math problems, the final answer is usually at the end.\n    # We only match on the last 300 characters, which is a safe approximation for 300 tokens.\n    if len(solution_str) > _SOLUTION_CLIP_CHARS:\n        solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]\n\n    if method == \"strict\":\n        # this also tests the formatting of the model\n        solutions = re.findall(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        if len(solutions) == 0:\n            final_answer = None\n        else:\n            # take the last solution\n            final_answer = solutions[-1].replace(\",\", \"\").replace(\"$\", \"\")\n    elif method == \"flexible\":\n        answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        final_answer = None\n        if len(answer) == 0:\n            # no reward is there is no answer\n            pass\n        else:\n            invalid_str = [\"\", \".\"]\n            # find the last number that is not '.'\n            for final_answer in reversed(answer):\n                if final_answer not in invalid_str:\n                    break\n    return final_answer\n\n\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\n    \"\"\"The scoring function for GSM8k.\n\n    Reference: Trung, Luong, et al. \"Reft: Reasoning with reinforced fine-tuning.\" Proceedings of the 62nd Annual\n    Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.\n\n    Args:\n        solution_str: the solution text\n        ground_truth: the ground truth\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\n        format_score: the score for the format\n        score: the score for the correct answer\n    \"\"\"\n    answer = extract_solution(solution_str=solution_str, method=method)\n    if answer is None:\n        return 0\n    else:\n        if answer == ground_truth:\n            return score\n        else:\n            return format_score\n"
  },
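A quick sketch of the two extraction modes above, with made-up response strings; the import path follows the file layout shown here.

```python
from verl.utils.reward_score.gsm8k import compute_score, extract_solution

response = "Step 1: 3 * 4 = 12. Step 2: 12 + 5 = 17.\n#### 17"

# Strict mode requires the '#### <number>' marker; flexible mode falls back
# to the last number found anywhere in the string.
assert extract_solution(response, method="strict") == "17"
assert extract_solution("the total is 17", method="flexible") == "17"

print(compute_score(response, "17"))  # 1.0
```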
  {
    "path": "verl_distillation/verl/utils/reward_score/math_batch.py",
    "content": "# Copyright 2025 Individual Contributor: Mert Unsal\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .math_reward import compute_score\n\n\ndef compute_score_batched(data_sources, solution_strs, ground_truths, extra_infos):\n    \"\"\"\n    This is a demonstration of how the batched reward function should look like.\n    Typically, you want to use batched reward to speed up the process with parallelization\n    \"\"\"\n    return [\n        compute_score(solution_str, ground_truth)\n        for solution_str, ground_truth in zip(solution_strs, ground_truths, strict=True)\n    ]\n"
  },
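A hedged usage sketch for the batched scorer above; note that `data_sources` and `extra_infos` are accepted for API compatibility but unused by this demo implementation, and the solution strings are invented.

```python
from verl.utils.reward_score.math_batch import compute_score_batched

scores = compute_score_batched(
    data_sources=["lighteval/MATH"] * 2,
    solution_strs=["the answer is \\boxed{4}", "the answer is \\boxed{5}"],
    ground_truths=["4", "9"],
    extra_infos=[None, None],
)
print(scores)  # [1.0, 0.0]: only the first boxed answer matches
```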
  {
    "path": "verl_distillation/verl/utils/reward_score/math_dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\nimport re\nfrom typing import Optional\n\n\ndef last_boxed_only_string(string: str) -> Optional[str]:\n    \"\"\"Extract the last LaTeX boxed expression from a string.\n\n    Args:\n        string: Input string containing LaTeX code\n\n    Returns:\n        The last boxed expression or None if not found\n    \"\"\"\n    idx = string.rfind(\"\\\\boxed{\")\n    if idx < 0:\n        return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None\n\n\ndef remove_boxed(s: str) -> str:\n    \"\"\"Remove the LaTeX boxed command from a string.\n\n    Args:\n        s: String with format \"\\\\boxed{content}\"\n\n    Returns:\n        The content inside the boxed command\n    \"\"\"\n    left = \"\\\\boxed{\"\n    assert s[: len(left)] == left, f\"box error: {s}\"\n    assert s[-1] == \"}\", f\"box error: {s}\"\n    return s[len(left) : -1]\n\n\n# Constants for normalization\nSUBSTITUTIONS = [\n    (\"an \", \"\"),\n    (\"a \", \"\"),\n    (\".$\", \"$\"),\n    (\"\\\\$\", \"\"),\n    (r\"\\ \", \"\"),\n    (\" \", \"\"),\n    (\"mbox\", \"text\"),\n    (\",\\\\text{and}\", \",\"),\n    (\"\\\\text{and}\", \",\"),\n    (\"\\\\text{m}\", \"\\\\text{}\"),\n]\n\nREMOVED_EXPRESSIONS = [\n    \"square\",\n    \"ways\",\n    \"integers\",\n    \"dollars\",\n    \"mph\",\n    \"inches\",\n    \"hours\",\n    \"km\",\n    \"units\",\n    \"\\\\ldots\",\n    \"sue\",\n    \"points\",\n    \"feet\",\n    \"minutes\",\n    \"digits\",\n    \"cents\",\n    \"degrees\",\n    \"cm\",\n    \"gm\",\n    \"pounds\",\n    \"meters\",\n    \"meals\",\n    \"edges\",\n    \"students\",\n    \"childrentickets\",\n    \"multiples\",\n    \"\\\\text{s}\",\n    \"\\\\text{.}\",\n    \"\\\\text{\\ns}\",\n    \"\\\\text{}^2\",\n    \"\\\\text{}^3\",\n    \"\\\\text{\\n}\",\n    \"\\\\text{}\",\n    r\"\\mathrm{th}\",\n    r\"^\\circ\",\n    r\"^{\\circ}\",\n    r\"\\;\",\n    r\",\\!\",\n    \"{,}\",\n    '\"',\n    \"\\\\dots\",\n]\n\n\ndef normalize_final_answer(final_answer: str) -> str:\n    \"\"\"Normalize a final answer to a quantitative reasoning question.\n\n    Args:\n        final_answer: The answer string to normalize\n\n    Returns:\n        Normalized answer string\n    \"\"\"\n    final_answer = final_answer.split(\"=\")[-1]\n\n    # Apply substitutions and removals\n    for before, after in 
SUBSTITUTIONS:\n        final_answer = final_answer.replace(before, after)\n    for expr in REMOVED_EXPRESSIONS:\n        final_answer = final_answer.replace(expr, \"\")\n\n    # Extract and normalize LaTeX math\n    final_answer = re.sub(r\"(.*?)(\\$)(.*?)(\\$)(.*)\", \"$\\\\3$\", final_answer)\n    final_answer = re.sub(r\"(\\\\text\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\textbf\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\overline\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\boxed\\{)(.*)(\\})\", \"\\\\2\", final_answer)\n\n    # Normalize shorthand TeX:\n    #  \\fracab -> \\frac{a}{b}\n    #  \\frac{abc}{bef} -> \\frac{abc}{bef}\n    #  \\fracabc -> \\frac{a}{b}c\n    #  \\sqrta -> \\sqrt{a}\n    #  \\sqrtab -> sqrt{a}b\n    final_answer = re.sub(r\"(frac)([^{])(.)\", \"frac{\\\\2}{\\\\3}\", final_answer)\n    final_answer = re.sub(r\"(sqrt)([^{])\", \"sqrt{\\\\2}\", final_answer)\n    final_answer = final_answer.replace(\"$\", \"\")\n\n    # Normalize numbers\n    if final_answer.replace(\",\", \"\").isdigit():\n        final_answer = final_answer.replace(\",\", \"\")\n\n    return final_answer.strip()\n\n\ndef is_correct_minerva(\n    solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r\"(?i)Answer\\s*:\\s*([^\\n]+)\"\n) -> tuple[bool, str]:\n    \"\"\"Check if the solution is correct according to Minerva criteria.\n\n    Args:\n        solution_str: The solution string to check\n        gt: The ground truth answer\n        gt_need_extract: Whether the ground truth needs extraction\n        answer_pattern: Regex pattern to extract the answer\n\n    Returns:\n        Tuple of (is_correct, normalized_prediction)\n    \"\"\"\n    # Extract answer from solution\n    match = re.findall(answer_pattern, solution_str)\n    extracted_answer = match[-1] if match else \"[INVALID]\"\n    pred = normalize_final_answer(extracted_answer)\n\n    # Process ground truth\n    if gt_need_extract:\n        gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt)))\n    else:\n        gt = normalize_final_answer(gt)\n\n    return (pred == gt), pred\n\n\ndef is_correct_strict_box(\n    pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None\n) -> tuple[int, Optional[str]]:\n    \"\"\"Check if the prediction is correct using strict boxed answer criteria.\n\n    Args:\n        pred: The prediction string\n        gt: The ground truth answer\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Tuple of (score, extracted_prediction)\n    \"\"\"\n    # Extract the relevant part of the prediction\n    if pause_tokens_index is not None:\n        assert len(pause_tokens_index) == 4\n        pred = pred[pause_tokens_index[-1] - 100 :]\n    else:\n        pred = pred[-100:]\n\n    # Extract and check the boxed answer\n    boxed_pred = last_boxed_only_string(pred)\n    extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None\n\n    return 1 if (extracted_pred == gt) else -1, extracted_pred\n\n\ndef verify(\n    solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None\n) -> bool:\n    \"\"\"Verify if the solution is correct.\n\n    Args:\n        solution_str: The solution string to verify\n        answer: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n    
    Tuple of (is_correct, extracted_prediction)\n    \"\"\"\n    if strict_box_verify:\n        correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index)\n        return correct == 1, pred\n\n    correct, pred = is_correct_minerva(solution_str, answer)\n    return correct, pred\n\n\ndef compute_score(\n    solution_str: str,\n    ground_truth: str,\n    strict_box_verify: bool = False,\n    pause_tokens_index: Optional[list[int]] = None,\n) -> dict:\n    \"\"\"Compute the reward score for a solution.\n\n    Args:\n        solution_str: The solution string\n        ground_truth: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Dict with the reward \"score\" (1.0 for correct, -1.0 for incorrect),\n        the accuracy \"acc\", and the extracted prediction \"pred\"\n    \"\"\"\n    # Limit solution length for efficiency\n    solution_str = solution_str[-300:]  # The longest answer in MATH-500 has 159 characters\n\n    # Verify the solution\n    correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index)\n\n    reward = 1.0 if correct else -1.0\n    acc = correct\n\n    return {\n        \"score\": reward,\n        \"acc\": acc,\n        \"pred\": pred,\n    }\n"
  },
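A short illustration of the normalization and scoring path above; the inputs are invented, and the expected outputs follow from the substitution rules as written.

```python
from verl.utils.reward_score.math_dapo import compute_score, normalize_final_answer

# Normalization strips text wrappers, units, and shorthand TeX.
print(normalize_final_answer("x = \\text{5 degrees}"))  # '5'
print(normalize_final_answer("$\\frac12$"))             # '\\frac{1}{2}'

# Default (Minerva) path: extract after 'Answer:', normalize both sides, compare.
result = compute_score("Answer: \\frac12", "\\frac{1}{2}")
print(result)  # {'score': 1.0, 'acc': True, 'pred': '\\frac{1}{2}'}
```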
  {
    "path": "verl_distillation/verl/utils/reward_score/math_reward.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\n\ndef compute_score(solution_str, ground_truth) -> float:\n    retval = 0.0\n    try:\n        string_in_last_boxed = last_boxed_only_string(solution_str)\n        if string_in_last_boxed is not None:\n            answer = remove_boxed(string_in_last_boxed)\n            if is_equiv(answer, ground_truth):\n                retval = 1.0\n    except Exception as e:\n        print(e)\n\n    return retval\n\n\n# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py\ndef is_equiv(str1, str2, verbose=False):\n    if str1 is None and str2 is None:\n        print(\"WARNING: Both None\")\n        return True\n    if str1 is None or str2 is None:\n        return False\n\n    try:\n        ss1 = strip_string(str1)\n        ss2 = strip_string(str2)\n        if verbose:\n            print(ss1, ss2)\n        return ss1 == ss2\n    except Exception:\n        return str1 == str2\n\n\ndef remove_boxed(s):\n    if \"\\\\boxed \" in s:\n        left = \"\\\\boxed \"\n        assert s[: len(left)] == left\n        return s[len(left) :]\n\n    left = \"\\\\boxed{\"\n\n    assert s[: len(left)] == left\n    assert s[-1] == \"}\"\n\n    return s[len(left) : -1]\n\n\ndef last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if \"\\\\boxed \" in string:\n        return \"\\\\boxed \" + string.split(\"\\\\boxed \")[-1].split(\"$\")[0]\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1]\n\n    return retval\n\n\ndef fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                
    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:\n        return string\n\n\ndef remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\\\\\%\", \"\")\n    string = string.replace(\"\\\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. 
Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = fix_a_slash_b(string)\n\n    return string\n"
  },
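A small sketch of the string-normalization equivalence above; the inputs are invented.

```python
from verl.utils.reward_score.math_reward import compute_score, is_equiv

# strip_string-based equivalence tolerates formatting differences:
# 0.5 is rewritten as \frac{1}{2}, and tfrac/dfrac collapse to frac.
print(is_equiv("\\frac{1}{2}", "0.5"))  # True

# The last boxed expression is extracted, unboxed, and compared;
# the ground truth '1/2' is normalized to \frac{1}{2} as well.
print(compute_score("so the answer is \\boxed{\\tfrac{1}{2}}", "1/2"))  # 1.0
```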
  {
    "path": "verl_distillation/verl/utils/reward_score/math_verify.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from math_verify.errors import TimeoutException\n    from math_verify.metric import math_metric\n    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig\nexcept ImportError:\n    print(\"To use Math-Verify, please install it first by running `pip install math-verify`.\")\n\n\ndef compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> bool:\n    verify_func = math_metric(\n        gold_extraction_target=(LatexExtractionConfig(),),\n        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),\n    )\n    ret_score = 0.0\n\n    # Wrap the ground truth in \\boxed{} format for verification\n    ground_truth_boxed = \"\\\\boxed{\" + ground_truth + \"}\"\n    try:\n        ret_score, _ = verify_func([ground_truth_boxed], [model_output])\n    except Exception:\n        pass\n    except TimeoutException:\n        ret_score = timeout_score\n\n    return ret_score\n"
  },
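A usage sketch for the wrapper above, assuming `math-verify` is installed; since it checks mathematical equivalence rather than string equality, 0.5 should match 1/2.

```python
from verl.utils.reward_score.math_verify import compute_score

# The model output carries a boxed decimal; the ground truth is a fraction.
# math-verify parses both sides symbolically, so they should be judged equal.
print(compute_score("The final answer is \\boxed{0.5}", "\\frac{1}{2}"))  # 1.0
```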
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_code/README.md",
    "content": "## LiveCodeBench\n\n### Introduction\n[LiveCodeBench](https://github.com/LiveCodeBench/LiveCodeBench) provides holistic and contamination-free evaluation of coding capabilities of LLMs. Particularly, LiveCodeBench continuously collects new problems over time from contests across three competition platforms -- LeetCode, AtCoder, and CodeForces. \n\n### How to reproduce\nOur evaluation is grounded on the version found in LiveCodeBench.\n> **Installation**\n```bash\n# Make sure the CUDA version > 12.0.\npip install -r requirements.txt\npip install flash-attn --no-build-isolation\n```\n\n### Acknowleage\nThank you to the [LiveCodeBench](https://livecodebench.github.io/leaderboard.html) team for their contributions to the open-source community."
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_code/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport traceback\n\nfrom .utils import check_correctness as apps_check_correctness\n\n\ndef compute_score(completion, test_cases, continuous=False):\n    # try to get code solution from completion. if the completion is pure code, this will not take effect.\n    solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    try:\n        try:\n            if not isinstance(test_cases, dict):\n                test_cases = json.loads(test_cases)\n        except Exception as e:\n            print(f\"Error:{e}\")\n\n        # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.\n        try:\n            res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False)\n            metadata = dict(enumerate(metadata))[0]\n            success = all(map(lambda x: x is True, res))\n            if success:\n                return success, metadata\n        except Exception:\n            pass\n\n        test_cases_list = []\n        inputs = test_cases[\"inputs\"]\n        outputs = test_cases[\"outputs\"]\n        for i in range(len(inputs)):\n            test_cases_list.append({\"inputs\": [inputs[i]], \"outputs\": [outputs[i]]})\n\n        if continuous:\n            # per sample test: if continuous score is needed, test first 10 samples regardless of failures\n            # do not test all samples cuz some problems have enormous test cases\n            metadata_list = []\n            res_list = []\n            for test_case_id, test_case in enumerate(test_cases_list):\n                res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=10, debug=False)\n                try:\n                    metadata = dict(enumerate(metadata))[0]  # metadata can be empty occasionally\n                except Exception:\n                    metadata = {}\n                metadata[\"test_case\"] = {}\n                metadata[\"test_case\"][\"input\"] = str(test_case[\"inputs\"][0])\n                metadata[\"test_case\"][\"output\"] = str(test_case[\"outputs\"][0])\n                metadata[\"test_case\"][\"res\"] = str(res)\n                metadata_list.append(metadata)\n                res_list.extend(res)\n\n                if test_case_id >= 9:\n                    break\n            res_count = len(res_list) if len(res_list) > 0 else 1\n            success = sum(map(lambda x: x is True, res_list)) / res_count\n    except Exception:\n        traceback.print_exc(10)\n        success = False\n        metadata_list = None\n    return success, metadata_list\n"
  },
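A hedged usage sketch for the checker above; the toy completion and test cases are invented, and running it requires the `pyext`-based sandbox in `testing_util.py`.

```python
# Scoring a plain-code completion against stdin/stdout test cases.
# The payload mirrors the APPS-style dict the checker expects: no 'fn_name'
# key means the solution is run in standard-input mode.
from verl.utils.reward_score.prime_code import compute_score

completion = "a, b = map(int, input().split())\nprint(a + b)\n"
test_cases = {"inputs": ["1 2\n", "3 4\n"], "outputs": ["3\n", "7\n"]}

success, metadata = compute_score(completion, test_cases, continuous=True)
print(success)  # fraction of per-sample tests passed, or True on a full pass
```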
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_code/testing_util.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nimport faulthandler\nimport json\nimport platform\n\n# to run the solution files we're using a timing based approach\nimport signal\nimport sys\nimport traceback\n\n# used for debugging to time steps\nfrom datetime import datetime\nfrom enum import Enum\n\n# for capturing the stdout\nfrom io import StringIO\n\n# used for testing the code that reads from input\nfrom unittest.mock import mock_open, patch\n\nimport numpy as np\nfrom pyext import RuntimeModule\n\n\ndef truncatefn(s, length=300):\n    assert isinstance(s, str)\n    if len(s) <= length:\n        return s\n\n    return s[: length // 2] + \"...(truncated) ...\" + s[-length // 2 :]\n\n\nclass CODE_TYPE(Enum):\n    call_based = 0\n    standard_input = 1\n\n\n# used to capture stdout as a list\n# from https://stackoverflow.com/a/16571630/6416660\n# alternative use redirect_stdout() from contextlib\nclass Capturing(list):\n    def __enter__(self):\n        self._stdout = sys.stdout\n        sys.stdout = self._stringio = StringIO()\n        # Make closing the StringIO a no-op\n        self._stringio.close = lambda x: 1\n        return self\n\n    def __exit__(self, *args):\n        self.append(self._stringio.getvalue())\n        del self._stringio  # free up some memory\n        sys.stdout = self._stdout\n\n\ndef only_int_check(val):\n    return isinstance(val, int)\n\n\ndef string_int_check(val):\n    return isinstance(val, str) and val.isdigit()\n\n\ndef combined_int_check(val):\n    return only_int_check(val) or string_int_check(val)\n\n\ndef clean_traceback(error_traceback):\n    file_start = error_traceback.find('File \"<string>\"')\n    # print(file_start)\n    error_traceback = \"Traceback (most recent call last):\\n  \" + error_traceback[file_start:]\n    return error_traceback\n\n\ndef run_test(in_outs, test=None, debug=False, timeout=15):\n    \"\"\"\n    if test(generated_code) is not None it'll try to run the code.\n    otherwise it'll just return an input and output pair.\n    \"\"\"\n    # Disable functionalities that can make destructive changes to the test.\n    reliability_guard()\n\n    if debug:\n        print(f\"start = {datetime.now().time()}\")\n\n    if in_outs:\n        if in_outs.get(\"fn_name\") is None:\n            which_type = CODE_TYPE.standard_input  # Standard input\n            method_name = None\n        else:\n            which_type = CODE_TYPE.call_based  # Call-based\n            method_name = in_outs[\"fn_name\"]\n\n    if debug:\n        print(f\"loaded input_output = {datetime.now().time()}\")\n\n    if test is None:\n        raise AssertionError(\"should not happen: test code is none\")\n    elif test is not None:\n        results = []\n        sol = \"from string import *\\nfrom re import *\\nfrom datetime import *\\nfrom collections import *\\nfrom heapq import *\\nfrom bisect import *\\nfrom copy import *\\nfrom math import *\\nfrom random import *\\nfrom 
statistics import *\\nfrom itertools import *\\nfrom functools import *\\nfrom operator import *\\nfrom io import *\\nfrom sys import *\\nfrom json import *\\nfrom builtins import *\\nfrom typing import *\\nimport string\\nimport re\\nimport datetime\\nimport collections\\nimport heapq\\nimport bisect\\nimport copy\\nimport math\\nimport random\\nimport statistics\\nimport itertools\\nimport functools\\nimport operator\\nimport io\\nimport sys\\nimport json\\nsys.setrecursionlimit(6*10**5)\\n\"  # noqa: E501\n        if debug:\n            print(f\"loading test code = {datetime.now().time()}\")\n\n        if which_type == CODE_TYPE.call_based:\n            sol += test\n            if debug:\n                print(f\"sol = {sol}\")\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol if \"class Solution\" not in test else tmp_sol.Solution()\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n                if debug:\n                    print(f\"type 0 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n\n        elif which_type == CODE_TYPE.standard_input:\n            # sol\n            # if code has if __name__ == \"__main__\": then remove it\n            try:\n                astree = ast.parse(test)\n                last_block = astree.body[-1]\n                if isinstance(last_block, ast.If):\n                    condition = last_block.test\n                    if ast.unparse(condition).strip() == \"__name__ == '__main__'\":\n                        test = ast.unparse(astree.body[:-1]) + \"\\n\" + ast.unparse(last_block.body)\n            except Exception:\n                pass\n\n            tmp_test = test.split(\"\\n\")\n\n            new_test = []\n            for x in tmp_test:\n                if (not x.startswith(\"from \")) and (not x.startswith(\"import \")):\n                    new_test.append(\"\\t\" + x + \"\\n\")\n                else:\n                    new_test.append(x + \"\\n\")\n            tmp_test = new_test\n\n            new_test = \"\"\n            started = False\n            for i in tmp_test:\n                if i.startswith(\"\\t\") and not started:\n                    new_test += \"stdin = sys.stdin\\nstdout = sys.stdout\\n\"\n                    new_test += \"def code():\\n\"\n                    new_test += i\n                    started = True\n                elif started and ((i.startswith(\"from \")) or (i.startswith(\"import \"))):\n                    new_test += \"\\t\" + i\n                else:\n                    new_test += i\n            tmp_test = new_test\n\n            sol += tmp_test\n            if debug:\n                print(f\"sol = {sol}\")\n            method_name = \"code\"\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n              
  if debug:\n                    print(f\"type 1 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n        if debug:\n            print(f\"get method = {datetime.now().time()}\")\n\n        try:\n            method = getattr(tmp, method_name)  # get_attr second arg must be str\n        except Exception:\n            signal.alarm(0)\n            error_traceback = traceback.format_exc()\n            error_info = sys.exc_info()\n            print(f\"unable to get function error = {error_info}\")\n            results.append(-2)\n            return results, {\n                \"error\": repr(error_info),\n                # \"error_code\": -1,\n                # \"error_message\": \"Unable to extract code\",\n                \"traceback\": clean_traceback(error_traceback),\n            }\n\n        for index, inputs in enumerate(in_outs[\"inputs\"]):\n            raw_inputs = inputs\n            raw_outputs = in_outs[\"outputs\"][index]\n            if which_type == CODE_TYPE.call_based:\n                inputs = [json.loads(line) for line in inputs.split(\"\\n\")]\n                in_outs[\"outputs\"][index] = json.loads(in_outs[\"outputs\"][index])\n\n                truncate_line_size = 300 // (raw_inputs.count(\"\\n\") + 1)\n                raw_inputs = \"\\n\".join(\n                    [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split(\"\\n\")]\n                )\n                raw_outputs = truncatefn(raw_outputs, 200)\n            else:\n                raw_inputs = truncatefn(raw_inputs)\n                raw_outputs = truncatefn(raw_outputs, 200)\n            # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)\n            try:\n                if isinstance(inputs[0], dict):\n                    inputs = [{int(k): v for k, v in inputs[0].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index][0], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index][0].items()}]\n            except Exception:\n                pass\n\n            if debug:\n                print(\n                    f\"time: {datetime.now().time()} testing index = {index}  inputs = {inputs}, {type(inputs)}. 
\"\n                    f\"type = {which_type}\"\n                )\n            if which_type == CODE_TYPE.call_based:  # Call-based\n                signal.alarm(timeout)\n                faulthandler.enable()\n                try:\n                    output = method(*inputs)\n                    raw_true_output = output\n\n                    raw_true_output_copy = json.dumps(output)\n                    raw_true_output_copy = truncatefn(raw_true_output_copy, 200)\n\n                    # ground truth sequences are not tuples\n                    if isinstance(output, tuple):\n                        output = list(output)\n\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                    if isinstance(in_outs[\"outputs\"][index], list) and in_outs[\"outputs\"][index]:\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index][0])\n\n                    # ground truth sequences are not tuples\n                    try:\n                        if isinstance(output[0], tuple):\n                            tmp_result = tmp_result or ([list(x) for x in output] == in_outs[\"outputs\"][index][0])\n                    except Exception:\n                        pass\n                    results.append(tmp_result)\n                    if tmp_result is not True:\n                        return results, {\n                            \"output\": raw_true_output_copy,\n                            \"expected\": raw_outputs,\n                            \"inputs\": raw_inputs,\n                            # \"error_code\": -2,\n                            \"error_message\": \"Wrong Answer\",\n                        }\n                    # reset the alarm\n                    signal.alarm(0)\n                except Exception as e:\n                    signal.alarm(0)\n                    error_traceback = traceback.format_exc()\n                    faulthandler.disable()\n                    if debug:\n                        print(f\"Standard input runtime error or time limit exceeded error = {e}\")\n                    results.append(-1)\n                    return results, {\n                        \"error\": repr(e),\n                        \"traceback\": clean_traceback(error_traceback),\n                    }\n                faulthandler.disable()\n                signal.alarm(0)\n                if debug:\n                    print(\n                        f\"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                        f\"{type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                    )\n            elif which_type == CODE_TYPE.standard_input:  # Standard input\n                faulthandler.enable()\n                passed = False\n\n                if isinstance(inputs, list):\n                    inputs = \"\\n\".join(inputs)\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    in_outs[\"outputs\"][index] = \"\\n\".join(in_outs[\"outputs\"][index])\n\n                signal.alarm(timeout)\n                with Capturing() as output:\n                    try:\n                        call_method(method, inputs)\n                        # reset the alarm\n                        signal.alarm(0)\n                        passed = True\n                    except Exception as e:\n                        # runtime error or took too long\n                        signal.alarm(0)\n                        error_traceback = 
traceback.format_exc()\n                        print(f\"Call-based runtime error or time limit exceeded error = {repr(e)}{e}\")\n                        results.append(-1)\n                        return results, {\n                            \"error\": repr(e),\n                            \"traceback\": clean_traceback(error_traceback),\n                        }\n                    signal.alarm(0)\n                raw_true_output = output[0]\n                raw_true_output_copy = truncatefn(raw_true_output, 200)\n                output = raw_true_output.splitlines()\n                if not passed:\n                    if debug:\n                        nl = \"\\n\"\n                        if not isinstance(inputs, list):\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                                f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                                f\"{output == [in_outs['outputs'][index]]}\"\n                            )\n                        else:\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                                f\"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                            )\n                    continue\n\n                if passed and debug:\n                    print(f\"==> output = {output}, test outputs = {in_outs['outputs'][index]}\")\n\n                if custom_compare_(output, in_outs[\"outputs\"][index]):\n                    tmp_result = True\n                    results.append(tmp_result)\n                    continue\n\n                # ground truth sequences are expressed as lists not tuples\n                if isinstance(output, tuple):\n                    output = list(output)\n\n                tmp_result = False\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                        if isinstance(output[0], str):\n                            tmp_result = tmp_result or ([e.strip() for e in output] == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check1 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                # try one more time without \\n\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = i.split(\"\\n\")\n                        in_outs[\"outputs\"][index][tmp_index] = [\n                            x.strip() for x in in_outs[\"outputs\"][index][tmp_index] if x\n                        ]\n                else:\n                    in_outs[\"outputs\"][index] = in_outs[\"outputs\"][index].split(\"\\n\")\n                    in_outs[\"outputs\"][index] = list(filter(len, in_outs[\"outputs\"][index]))\n                    in_outs[\"outputs\"][index] = list(map(lambda x: x.strip(), in_outs[\"outputs\"][index]))\n\n                try:\n                    
tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check2 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    output = list(filter(len, output))\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                            f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                            f\"{output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n                    else:\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                            f\"{type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n\n                if debug:\n                    print(f\"{tmp_result=} @a\")\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check3 exception = {e}\")\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @b\")\n\n                try:\n                    all_ints = all(\n                        combined_int_check(e1) and combined_int_check(e2)\n                        for e1, e2 in zip(output, in_outs[\"outputs\"][index], strict=True)\n                    )\n                    if not all_ints:\n                        if debug:\n                            print(\n                                [\n                                    combined_int_check(e1) and combined_int_check(e2)\n                                    for e1, e2 in zip(output, in_outs[\"outputs\"][index], strict=True)\n                                ]\n                            )\n                        output_float = [float(e) for e in output]\n                        gt_float = [float(e) for e in in_outs[\"outputs\"][index]]\n                        tmp_result = tmp_result or (\n                            (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)\n                        )\n                except Exception:\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @c\")\n\n                try:\n                    if isinstance(output[0], list):\n                        all_ints = all(\n                            combined_int_check(e1) and combined_int_check(e2)\n                            for e1, e2 in zip(output[0], in_outs[\"outputs\"][index], strict=True)\n                        )\n                        if not all_ints:\n                         
   output_float = [float(e) for e in output[0]]\n                            gt_float = [float(e) for e in in_outs[\"outputs\"][index][0]]\n                            tmp_result = tmp_result or (\n                                (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)\n                            )\n                except Exception:\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @d\")\n                # try by converting the stuff into split up list\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = set(i.split())\n                else:\n                    in_outs[\"outputs\"][index] = set(in_outs[\"outputs\"][index].split())\n\n                if debug:\n                    print(f\"{tmp_result=} @e\")\n\n                try:\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check4 exception = {e}\")\n                    continue\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @f\")\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = i.split()\n                    output = list(filter(len, output))\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = set(i)\n                else:\n                    output = output.split()\n                    output = list(filter(len, output))\n                    output = set(output)\n\n                if debug:\n                    print(f\"{tmp_result=} @g\")\n\n                if tmp_result is True and debug:\n                    print(\"PASSED\")\n\n                results.append(tmp_result)\n                if tmp_result is not True:\n                    return results, {\n                        \"output\": raw_true_output_copy,\n                        \"expected\": raw_outputs,\n                        \"inputs\": raw_inputs,\n                        # \"error_code\": -2,\n                        \"error_message\": \"Wrong Answer\",\n                    }\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                            f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                            f\"{output == [in_outs['outputs'][index]]}\"\n                        )\n                    else:\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                            f\"{type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                        )\n\n                    print(f\"results = {results}\")\n\n    return results, {}\n\n\ndef 
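_demo_comparison_cascade():\n    \"\"\"Illustrative sketch only -- a hypothetical helper that is NOT part of the\n    original checker and is never called by run_test. It replays the comparison\n    cascade used above (exact match -> element-wise float closeness -> whitespace\n    token sets) on a toy output/expected pair, relying on the module-level numpy\n    import that run_test already uses.\"\"\"\n    output, expected = [\"1.0 2.0\"], [\"1 2\"]\n    exact = output == expected  # False: plain equality is the first, strictest check\n    floats = np.allclose(\n        [float(t) for t in output[0].split()], [float(t) for t in expected[0].split()]\n    )  # True: the numeric fallback tolerates \"1.0\" vs \"1\" formatting differences\n    tokens = set(output[0].split()) == set(expected[0].split())  # False: token sets still differ\n    return exact, floats, tokens\n\n\n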
def custom_compare_(output, ground_truth):\n    if isinstance(output, list):\n        output_1 = \"\\n\".join(output)\n        if stripped_string_compare(output_1, ground_truth):\n            return True\n\n    if isinstance(output, list):\n        output_2 = [o.lstrip().rstrip() for o in output]\n        output_2 = \"\\n\".join(output_2)\n        if stripped_string_compare(output_2, ground_truth):\n            return True\n\n    return False\n\n\ndef stripped_string_compare(s1, s2):\n    s1 = s1.lstrip().rstrip()\n    s2 = s2.lstrip().rstrip()\n    return s1 == s2\n\n\ndef call_method(method, inputs):\n    if isinstance(inputs, list):\n        inputs = \"\\n\".join(inputs)\n\n    inputs_line_iterator = iter(inputs.split(\"\\n\"))\n\n    # sys.setrecursionlimit(10000)\n\n    # @patch('builtins.input', side_effect=inputs.split(\"\\n\"))\n    @patch(\"builtins.open\", mock_open(read_data=inputs))\n    @patch(\"sys.stdin\", StringIO(inputs))\n    @patch(\"sys.stdin.readline\", lambda *args: next(inputs_line_iterator))\n    @patch(\"sys.stdin.readlines\", lambda *args: inputs.split(\"\\n\"))\n    @patch(\"sys.stdin.read\", lambda *args: inputs)\n    # @patch('sys.stdout.write', print)\n    def _inner_call_method(_method):\n        try:\n            return _method()\n        except SystemExit:\n            pass\n        finally:\n            pass\n\n    return _inner_call_method(method)\n\n\ndef reliability_guard(maximum_memory_bytes=None):\n    \"\"\"\n    This disables various destructive functions and prevents the generated code\n    from interfering with the test (e.g. fork bomb, killing other processes,\n    removing filesystem files, etc.)\n    WARNING\n    This function is NOT a security sandbox. Untrusted code, including\n    model-generated code, should not be blindly executed outside of one. See the\n    Codex paper for more information about OpenAI's code sandbox, and proceed\n    with caution.\n    \"\"\"\n\n    if maximum_memory_bytes is not None:\n        import resource\n\n        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))\n        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))\n        if platform.uname().system != \"Darwin\":\n            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))\n\n    faulthandler.disable()\n\n    import builtins\n\n    builtins.exit = None\n    builtins.quit = None\n\n    import os\n\n    os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n    os.kill = None\n    os.system = None  # avoid interfering with REPL-based evaluation\n    os.putenv = None\n    os.remove = None\n    os.removedirs = None\n    os.rmdir = None\n    os.fchdir = None\n    os.setuid = None\n    os.fork = None\n    os.forkpty = None\n    os.killpg = None\n    os.rename = None\n    os.renames = None\n    os.truncate = None\n    os.replace = None\n    os.unlink = None\n    os.fchmod = None\n    os.fchown = None\n    os.chmod = None\n    os.chown = None\n    os.chroot = None\n    os.lchflags = None\n    os.lchmod = None\n    os.lchown = None\n    os.getcwd = None\n    os.chdir = None\n\n    import shutil\n\n    shutil.rmtree = None\n    shutil.move = None\n    shutil.chown = None\n\n    import subprocess\n\n    subprocess.Popen = None  # type: ignore\n\n    __builtins__[\"help\"] = None\n\n    import sys\n\n    sys.modules[\"ipdb\"] = None\n    sys.modules[\"joblib\"] = None\n    sys.modules[\"resource\"] = None\n    sys.modules[\"psutil\"] = None\n    sys.modules[\"tkinter\"] = None\n\n    # Disable some modules that can be destructive\n    for mod in [\"subprocess\", \"ctypes\"]:\n        sys.modules[mod] = None\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_code/utils.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py\n\nimport multiprocessing\nimport os\nimport sys\nimport traceback\nfrom typing import Optional\n\nfrom .testing_util import run_test\n\n\ndef _temp_run(sample, generation, debug, result, metadata_list, timeout):\n    with open(os.devnull, \"w\") as devnull:\n        sys.stdout = devnull\n        sys.stderr = devnull\n        try:\n            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)\n            result.append(res)\n            metadata_list.append(metadata)\n        except Exception:\n            # print(e) # some tracebacks are extremely long.\n            traceback.print_exc(10)\n            result.append([-1 for i in range(len(sample[\"inputs\"]))])\n            metadata_list.append({})\n\n\ndef check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):\n    \"\"\"Check correctness of code generation with a global timeout.\n    The global timeout is to catch some extreme/rare cases not handled by the timeouts\n    inside `run_test`\"\"\"\n\n    manager = multiprocessing.Manager()\n    result = manager.list()\n    metadata_list = manager.list()\n    p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))\n    p.start()\n    p.join(timeout=timeout + 1)\n    if p.is_alive():\n        p.kill()\n        # p.terminate()\n    if not result:\n        # consider that all tests failed\n        result = [[-1 for i in range(len(in_outs[\"inputs\"]))]]\n        if debug:\n            print(\"global timeout\")\n    return result[0], metadata_list\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_math/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAnswer checker API that uses sympy to simplify expressions and check for equality.\n\nCall grade_answer(given_answer: str, ground_truth: str).\n\nFROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py\n\"\"\"\n\nimport contextlib\nimport math\nimport re\n\nimport sympy\nfrom pylatexenc import latex2text\nfrom sympy.parsing import sympy_parser\n\nfrom verl.utils.py_functional import timeout_limit\n\nfrom . import math_normalize\nfrom .grader import math_equal\n\n# import math_normalize\n# from grader import math_equal\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [r\"\\^[0-9]+\\^\", r\"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" \\\\frac\")  # Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x: str) -> bool:\n    try:\n        x = _strip_properly_formatted_commas(x)\n        x = float(x)\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _str_to_int(x: str) -> bool:\n    x = x.replace(\",\", \"\")\n    x = float(x)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evalable\n    e.g. 
7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(r\"([0-9]) +([0-9])\")\n    step = p1.sub(r\"\\1+\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(r\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(r\"\\1\\3\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(r\"^\\\\text\\{(?P<text>.+?)\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n        \"liter\",\n    ]:\n        expr = re.sub(f\"{unit}(es)?(s)? *(\\\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(r\"\\^ *\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        with contextlib.suppress(Exception):\n            expr = _parse_latex(expr)\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = set([x for x in expr if x.isalpha()])\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES)\n\n\n@timeout_limit(seconds=10)\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if (\n        len(expr) > 2\n        and expr[0] in TUPLE_CHARS\n        and expr[-1] in TUPLE_CHARS\n        and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])\n    ):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef grade_answer(given_answer: str, ground_truth: str) -> bool:\n    \"\"\"\n    The answer will be considered correct if:\n    (a) it normalizes to the same string as the ground truth answer\n    OR\n    (b) sympy can simplify the difference between the expressions to 0\n    \"\"\"\n    if given_answer is None:\n        return False\n\n    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)\n    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    if (\n        len(ground_truth_elems) > 1\n        and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1])\n        or len(ground_truth_elems) != len(given_elems)\n    ):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in 
zip(ground_truth_elems, given_elems, strict=True):\n            if _is_frac(ground_truth_elem) and _is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match\n                # (no sympy.simplify)\n                is_correct = False\n            else:\n                try:\n                    is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n                except Exception as e:\n                    # if there's an error, we'll just say it's not correct\n                    is_correct = False\n                    print(f\"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}\")\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef remove_boxed(s):\n    left = \"\\\\boxed{\"\n    try:\n        assert s[: len(left)] == left\n        assert s[-1] == \"}\"\n        return s[len(left) : -1]\n    except Exception:\n        return None\n\n\ndef _last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    left_brace_idx = None\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n            if left_brace_idx is None:\n                left_brace_idx = i\n        elif string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n\n        i += 1\n\n    if left_brace_idx is None or right_brace_idx is None:\n        return None\n\n    return string[left_brace_idx + 1 : right_brace_idx].strip()\n\n\ndef match_answer(response):\n    is_matched = False\n    for ans_marker in [\"answer:\", \"answer is\", \"answers are\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-1]\n\n    for ans_marker in [\"is answer\", \"is the answer\", \"are answers\", \"are the answers\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[:ans_idx].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-1]\n\n    # Find boxed\n    ans_boxed = _last_boxed_only_string(response)\n    if ans_boxed:\n        is_matched = True\n        response = ans_boxed\n\n    if \". \" in response:\n        dot_idx = response.lower().rfind(\". \")\n        if dot_idx != -1:\n            response = response[:dot_idx].strip()\n\n    for ans_marker in [\"be \", \"is \", \"are \", \"=\", \": \", \"get \", \"be\\n\", \"is\\n\", \"are\\n\", \":\\n\", \"get\\n\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-1]\n\n    is_matched = is_matched if any([c.isdigit() for c in response]) else False  # answer must have a digit\n    # Grade\n    return is_matched, response\n\n\ndef compute_score(model_output: str, ground_truth: str) -> tuple[bool, bool, str]:\n    model_output = str(model_output)\n    ground_truth = str(ground_truth)\n\n    is_matched, extracted_model_output = match_answer(model_output)\n    format_correctness = \"Step 2:\" in model_output and \"\\\\box\" in model_output\n\n    # grade simple algebra questions. if succeeded, return; otherwise, proceed to more complex grading\n    if grade_answer(extracted_model_output, ground_truth):\n        return True, True, extracted_model_output\n\n    try:\n        if \"\\\\pi\" in extracted_model_output or \"\\\\pi\" in ground_truth:\n            equivs = []\n            for pi in [math.pi, 3.14]:\n                equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi))\n            is_correct = any(equivs)\n        else:\n            is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)\n    except Exception:\n        is_correct = False\n\n    return is_correct, format_correctness, extracted_model_output\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_math/grader.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) Microsoft Corporation.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE\n\n# Copyright (c) 2023 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:\n- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py\n- https://github.com/microsoft/ProphetNet/tree/master/CRITIC\n- https://github.com/openai/prm800k\n\"\"\"\n\nimport contextlib\nimport math\nimport re\nfrom math import isclose\n\n# sympy related\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n# verl related\nfrom verl.utils.py_functional import timeout_limit\n\n\ndef is_digit(s):\n    try:\n        if \"{,}\" in str(s):\n            num = float(str(s).replace(\"{,}\", \"\"))\n            return True, num\n\n        num = float(str(s).replace(\",\", \"\"))\n        return True, num\n    except ValueError:\n        return False, None\n\n\ndef normalize(answer, pi) -> str:\n    # checking if answer is $<number> and removing $ in that case to compare\n    if isinstance(answer, str) and bool(re.match(r\"\\$\\d+(\\.\\d+)?\", answer)):\n        return answer[1:]\n\n    # checking if answer is <number>% or <number>\\\\% and removing %\n    if isinstance(answer, str) and (\n        bool(re.match(r\"^\\d+(\\.\\d+)?%$\", answer)) or bool(re.match(r\"^\\d+(\\.\\d+)?\\\\%$\", answer))\n    ):\n        return answer.replace(\"\\\\%\", \"\").replace(\"%\", \"\")\n\n    # handle base\n    answer = handle_base(answer)\n\n    # handle pi\n    answer = handle_pi(answer, pi)\n\n    return answer\n\n\ndef handle_base(x) -> str:\n    if 
isinstance(x, str) and \"_\" in x:\n        # Due to base\n        x = x.split(\"_\")[0]\n        x = float(x)\n        return int(x)\n    return x\n\n\ndef handle_pi(string, pi):\n    if isinstance(string, str) and \"\\\\pi\" in string:\n        # Find the first occurrence of \"\\pi\"\n        idx = string.find(\"\\\\pi\")\n\n        # Iterate over the string and find all occurrences of \"\\pi\" with a valid previous character\n        while idx != -1:\n            if idx > 0 and string[idx - 1].isdigit():\n                # Replace \"\\pi\" with \"*math.pi\" if the previous character is a digit\n                string = string[:idx] + f\"*{pi}\" + string[idx + 3 :]\n            else:\n                # Replace \"\\pi\" with \"1*math.pi\" if the previous character is not a digit\n                string = string[:idx] + f\"1*{pi}\" + string[idx + 3 :]\n\n            # Find the next occurrence of \"\\pi\"\n            idx = string.find(\"\\\\pi\", idx + 1)\n\n        # Evaluate the expression using eval() function\n        with contextlib.suppress(Exception):\n            string = eval(string)\n\n    return string\n\n\ndef math_equal(\n    prediction: bool | float | str,\n    reference: float | str,\n    include_percentage: bool = True,\n    tolerance: float = 1e-4,\n    timeout: float = 10.0,\n    pi: float = math.pi,\n) -> bool:\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n\n    prediction = normalize(prediction, pi)\n    reference = normalize(reference, pi)\n\n    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases\n        prediction = prediction[:1000]\n\n    # 0. string comparison\n    if isinstance(prediction, str) and isinstance(reference, str):\n        if prediction.strip().lower() == reference.strip().lower():\n            return True\n        if prediction.replace(\" \", \"\") == reference.replace(\" \", \"\"):\n            return True\n\n    try:  # 1. numerical equal\n        if is_digit(prediction)[0] and is_digit(reference)[0]:\n            prediction = is_digit(prediction)[1]\n            reference = is_digit(reference)[1]\n            # number questions\n            gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]\n            for item in gt_result:\n                try:\n                    if isclose(item, prediction, rel_tol=tolerance):\n                        return True\n                except Exception:\n                    continue\n            return False\n    except Exception:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    ## deal with [], (), {}\n    prediction = format_intervals(prediction)\n\n    pred_str, ref_str = prediction, reference\n    if (prediction.startswith(\"[\") and prediction.endswith(\"]\") and not reference.startswith(\"(\")) or (\n        prediction.startswith(\"(\") and prediction.endswith(\")\") and not reference.startswith(\"[\")\n    ):\n        pred_str = pred_str.strip(\"[]()\")\n        ref_str = ref_str.strip(\"[]()\")\n    for s in [\"{\", \"}\", \"(\", \")\"]:\n        ref_str = ref_str.replace(s, \"\")\n        pred_str = pred_str.replace(s, \"\")\n    if pred_str == ref_str:\n        return True\n\n    ## [a, b] vs. 
[c, d], return a==c and b==d\n    if (\n        prediction\n        and reference\n        and prediction[0] in \"([\"\n        and prediction[-1] in \")]\"\n        and prediction[0] == reference[0]\n        and prediction[-1] == reference[-1]\n    ):\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    if \",\" in prediction and \",\" in reference:\n        pred_parts = [item.strip() for item in prediction.split(\",\")]\n        ref_parts = [item.strip() for item in reference.split(\",\")]\n\n        if len(pred_parts) == len(ref_parts):\n            return bool(\n                all(\n                    [\n                        math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)\n                        for i in range(len(pred_parts))\n                    ]\n                )\n            )\n\n    # if we have point == tuple of values\n    if prediction.startswith(\"Point\") and reference[0] == \"(\" and reference[-1] == \")\":\n        pred_parts = prediction[prediction.find(\"(\") + 1 : -1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=False)\n            ]\n        ):\n            return True\n\n    # if reference is a matrix\n    if r\"\\begin{pmatrix}\" in reference and prediction.startswith(\"Matrix\"):\n        try:\n            pred_matrix = parse_expr(prediction)\n            ref_matrix_items = reference.split()[1:-1:2]\n            if len(pred_matrix) == len(ref_matrix_items) and all(\n                [\n                    math_equal(pred, ref, include_percentage, tolerance)\n                    for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)\n                ]\n            ):\n                return True\n        except Exception:\n            pass\n    elif r\"\\begin{pmatrix}\" in reference and prediction.startswith(\"[\") and prediction.endswith(\"]\"):\n        if isinstance(eval(prediction), list):\n            try:\n                pred_matrix = eval(prediction)\n                # ref_matrix_items = reference.split()[1:-1:2]\n                ref_matrix_items = (\n                    reference.removeprefix(r\"\\\\begin{pmatrix}\")\n                    .removeprefix(r\"\\begin{pmatrix}\")\n                    .removesuffix(r\"\\\\end{pmatrix}\")\n                    .removesuffix(r\"\\end{pmatrix}\")\n                )\n                ref_matrix_items = ref_matrix_items.split(\"\\\\\")\n                ref_matrix_items = [row.split(\"&\") if \"&\" in row else row for row in ref_matrix_items]\n                if len(pred_matrix) == len(ref_matrix_items) and all(\n                    [\n                        math_equal(pred, ref, include_percentage, tolerance)\n                        for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)\n                    ]\n                ):\n                    return True\n            except Exception:\n                pass\n\n    return symbolic_equal(prediction, reference, tolerance, timeout)\n\n\ndef 
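_demo_math_equal():\n    \"\"\"Illustrative sketch only -- a hypothetical helper that is NOT part of the\n    original grader and is never called. It lists a few pairs that math_equal\n    above is designed to accept via its different branches.\"\"\"\n    checks = [\n        math_equal(\"0.5\", \"50\"),  # numeric branch: include_percentage also tries reference / 100\n        math_equal(\"1/2\", \"0.5\"),  # symbolic branch: sympy simplifies 1/2 - 0.5 to 0\n        math_equal(\"(1, 2)\", \"(1.0, 2.0)\"),  # tuple branch: element-wise comparison\n    ]\n    return all(checks)\n\n\n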
def symbolic_equal(a, b, tolerance, timeout=10.0):\n    def _parse(s):\n        for f in [parse_expr, parse_latex]:\n            try:\n                with timeout_limit(seconds=timeout):\n                    return f(s)\n            except TimeoutError:\n                print(f\"Parsing timed out for {s}\")\n                continue\n            except Exception:\n                continue\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if simplify(a - b) == 0:\n                return True\n    except TimeoutError:\n        print(f\"Simplification timed out for {a} - {b}\")\n        pass\n    except Exception:\n        pass\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if isclose(N(a), N(b), rel_tol=tolerance):\n                return True\n    except TimeoutError:\n        print(f\"Numerical evaluation timed out for {a}, {b}\")\n        pass\n    except Exception:\n        pass\n    return False\n\n\ndef format_intervals(prediction):\n    patterns = {\n        \"Interval(\": r\"^Interval\\((.*)\\)$\",\n        \"Interval.Ropen(\": r\"^Interval\\.Ropen\\((.*)\\)$\",\n        \"Interval.Lopen(\": r\"^Interval\\.Lopen\\((.*)\\)$\",\n        \"Interval.open(\": r\"^Interval\\.open\\((.*)\\)$\",\n    }\n\n    for key, pattern in patterns.items():\n        match = re.match(pattern, prediction)\n        if match:\n            inner_content = match.group(1)\n\n            if key == \"Interval(\":  # Interval(a, b) == [a, b]\n                return f\"[{inner_content}]\"\n            elif key == \"Interval.Ropen(\":  # Interval.Ropen(a, b) == [a, b)\n                return f\"[{inner_content})\"\n            elif key == \"Interval.Lopen(\":  # Interval.Lopen(a, b) == (a, b]\n                return f\"({inner_content}]\"\n            elif key == \"Interval.open(\":  # Interval.open(a, b) == (a, b)\n                return f\"({inner_content})\"\n\n    return prediction\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/prime_math/math_normalize.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence).\n\nFrom: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py\n\"\"\"\n\nimport re\nfrom typing import Optional\n\n\ndef normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(r\"^\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:\n        return answer\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef _fix_a_slash_b(string):\n    if 
len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:\n        return string\n\n\ndef _remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\\\\\%\", \"\")\n    string = string.replace(\"\\\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/sandbox_fusion/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport logging\nimport traceback\n\nfrom .utils import check_correctness\n\n\"\"\"\nVerify code correctness using the Sandbox Fusion (https://github.com/bytedance/SandboxFusion).\nYou can either deploy the sandbox_fusion service yourself or use the\nFaaS service provided by public cloud, eg: volcengine.com.\n\"\"\"\nlogger = logging.getLogger(__name__)\n\n\ndef compute_score(\n    sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, completion, test_cases, continuous=False, timeout=10\n):\n    \"\"\"\n    Computes the code score using the remote sandbox API.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox_fusion service, eg: \"https://<your service endpoint>/run_code\"\n\n        completion: The completion string containing the code.\n        test_cases: JSON string or dictionary containing \"inputs\" and \"outputs\".\n        continuous: Whether to compute a continuous score (based on the first N test cases).\n        timeout: Timeout for each test case.\n\n    Returns:\n        A tuple (score, metadata_list).\n        score: Float score (0.0 to 1.0).\n        metadata_list: List containing execution metadata for each test case.\n    \"\"\"\n    solution = completion\n    if \"```python\" in completion:\n        solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    elif \"```\" in completion:\n        # Handle cases like ```\\ncode\\n```\n        parts = completion.split(\"```\")\n        if len(parts) >= 2:\n            solution = parts[1]\n            # Remove potential language specifier like 'python\\n'\n            if \"\\n\" in solution:\n                first_line, rest = solution.split(\"\\n\", 1)\n                if first_line.strip().isalpha():  # Simple check for language name\n                    solution = rest\n    else:\n        return 0.0, [{\"error\": \"Invalid completion (missing code block)\"}]\n\n    try:\n        if not isinstance(test_cases, dict):\n            try:\n                test_cases = json.loads(test_cases)\n            except json.JSONDecodeError as e:\n                logger.error(f\"Failed to parse test_cases JSON: {e}\")\n                return 0.0, [{\"error\": \"Invalid test_cases JSON format\"}]\n\n        if test_cases is not None and \"assert_case\" in test_cases and isinstance(test_cases.get(\"assert_case\"), list):\n            assert_cases = test_cases.get(\"assert_case\")\n            test_cases.setdefault(\"inputs\", [\"\" for _ in assert_cases])\n            test_cases.setdefault(\"outputs\", [None for _ in assert_cases])\n        elif not test_cases or \"inputs\" not in test_cases or \"outputs\" not in test_cases:\n            logger.error(\"Invalid test_cases structure.\")\n            return 0.0, [{\"error\": \"Invalid test_cases structure (missing inputs/outputs)\"}]\n\n        # Check all test cases\n        # Note: The return value of check_correctness might need 
adaptation here\n        # Assume check_correctness returns (results_list, metadata_list)\n        # results_list contains True, False, or error codes (-1, -2, -3, etc.)\n        res_list, metadata_list = check_correctness(\n            sandbox_fusion_url=sandbox_fusion_url,\n            in_outs=test_cases,\n            generation=solution,\n            timeout=timeout,\n            concurrent_semaphore=concurrent_semaphore,\n            memory_limit_mb=memory_limit_mb,\n        )\n\n        # Calculate score\n        if not res_list:  # If there are no results (e.g., invalid input)\n            return 0.0, metadata_list\n\n        if continuous:\n            # Calculate pass rate for the first N (e.g., 10) test cases\n            num_to_consider = min(len(res_list), 10)\n            if num_to_consider == 0:\n                score = 0.0\n            else:\n                passed_count = sum(1 for r in res_list[:num_to_consider] if r is True)\n                score = passed_count / num_to_consider\n            # Return all metadata, even if score is based on the first N\n            final_metadata = metadata_list\n        else:\n            # Calculate pass rate for all test cases\n            passed_count = sum(1 for r in res_list if r is True)\n            total_cases = len(res_list)\n            score = passed_count / total_cases if total_cases > 0 else 0.0\n            final_metadata = metadata_list\n\n    except Exception as e:\n        logger.error(f\"Error during compute_score: {e}\")\n        traceback.print_exc()\n        score = 0.0\n        # Try to return partial metadata if available, otherwise return error info\n        final_metadata = metadata_list if \"metadata_list\" in locals() else [{\"error\": f\"Unhandled exception: {e}\"}]\n\n    # Ensure float and list are returned\n    return float(score), final_metadata if isinstance(final_metadata, list) else [final_metadata]\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/sandbox_fusion/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport concurrent.futures  # <-- Import concurrent.futures\nimport json\nimport logging\nimport os\nimport threading\nimport time\nimport traceback\nimport uuid\nfrom typing import Any, Optional\n\nimport requests\n\nDEFAULT_TIMEOUT = 10  # Default compile and run timeout\nMAX_RETRIES = 3\nINITIAL_RETRY_DELAY = 1\nAPI_TIMEOUT = 10\n\nlogger = logging.getLogger(__name__)\n\n# Define supported languages list (optional, for documentation or validation)\nSUPPORTED_LANGUAGES = [\n    \"python\",\n    \"cpp\",\n    \"nodejs\",\n    \"go\",\n    \"go_test\",\n    \"java\",\n    \"php\",\n    \"csharp\",\n    \"bash\",\n    \"typescript\",\n    \"sql\",\n    \"rust\",\n    \"cuda\",\n    \"lua\",\n    \"R\",\n    \"perl\",\n    \"D_ut\",\n    \"ruby\",\n    \"scala\",\n    \"julia\",\n    \"pytest\",\n    \"junit\",\n    \"kotlin_script\",\n    \"jest\",\n    \"verilog\",\n    \"python_gpu\",\n    \"lean\",\n    \"swift\",\n    \"racket\",\n]\n\n\ndef call_sandbox_api(\n    sandbox_fusion_url: str,\n    code: str,\n    stdin: Optional[str],\n    compile_timeout: int,\n    run_timeout: int,\n    memory_limit_mb: int,\n    language: str = \"python\",\n) -> tuple[Optional[dict[str, Any]], Optional[str]]:  # <-- Remove request_id parameter\n    \"\"\"\n    Calls the remote sandbox API to execute code with retry logic for Gateway Timeout,\n    using increasing delay between retries. Logs internal calls with a unique ID.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        code: The code string to execute.\n        stdin: The standard input string.\n        compile_timeout: Compile timeout in seconds.\n        run_timeout: Run timeout in seconds.\n        language: The programming language of the code (e.g., \"python\", \"cpp\", \"java\"). 
Defaults to \"python\".\n\n    Returns:\n        A tuple (response_json, error_message).\n        If successful, response_json is the API's returned JSON object, error_message is None.\n        If failed after retries, response_json is None, error_message contains the error information.\n    \"\"\"\n    request_id = str(uuid.uuid4())  # <-- Generate request_id internally\n    log_prefix = f\"[Request ID: {request_id}] \"  # <-- Create log prefix\n\n    if language not in SUPPORTED_LANGUAGES:\n        error_msg = f\"{log_prefix}Unsupported language: {language}\"\n        logger.error(error_msg)\n        return None, error_msg\n\n    payload = json.dumps(\n        {\n            \"compile_timeout\": compile_timeout,\n            \"run_timeout\": run_timeout,\n            \"code\": code,\n            \"stdin\": stdin,\n            \"memory_limit_MB\": memory_limit_mb,\n            \"language\": language,  # Use the passed language parameter\n            \"files\": {},\n            \"fetch_files\": [],\n        }\n    )\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n    # Calculate a reasonable request timeout based on compile/run timeouts plus a buffer\n    request_timeout = compile_timeout + run_timeout + API_TIMEOUT\n\n    last_error = None  # Store the last error encountered\n\n    for attempt in range(MAX_RETRIES):\n        try:\n            logger.info(\n                f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling sandbox API at {sandbox_fusion_url}\"\n            )  # <-- Use internal log_prefix\n            response = requests.post(\n                sandbox_fusion_url,\n                headers=headers,\n                data=payload,\n                timeout=request_timeout,  # Use the calculated timeout\n            )\n\n            # Check for Gateway Timeout (504) specifically for retrying\n            if response.status_code == 504:\n                last_error = (\n                    f\"{log_prefix}API Request Error: Gateway Timeout (504) on attempt \"\n                    f\"{attempt + 1}/{MAX_RETRIES}\"\n                )  # <-- Use internal log_prefix\n                logger.warning(last_error)\n                if attempt < MAX_RETRIES - 1:  # Don't sleep after the last attempt\n                    # Calculate increasing delay (e.g., 1s, 2s, 4s, ...) 
or (1s, 2s, 3s, ...)\n                    # Simple linear increase: delay = INITIAL_RETRY_DELAY * (attempt + 1)\n                    # Exponential backoff: delay = INITIAL_RETRY_DELAY * (2 ** attempt)\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)  # Using linear increase for simplicity\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")  # <-- Use internal log_prefix\n                    time.sleep(delay)\n                continue  # Go to the next retry attempt\n\n            # Check for other HTTP errors (e.g., 4xx, other 5xx)\n            response.raise_for_status()\n\n            # If successful (status code 2xx)\n            logger.info(\n                f\"{log_prefix}Sandbox API call successful on attempt {attempt + 1}\"\n            )  # <-- Use internal log_prefix\n            return response.json(), None\n\n        except requests.exceptions.RequestException as e:\n            last_error = f\"{log_prefix}API Request Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on non-504 request errors\n        except json.JSONDecodeError as e:\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on JSON decode errors\n        except Exception as e:\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on other unexpected errors\n\n    # If loop finishes without returning success, return the last recorded error\n    logger.error(f\"{log_prefix}Sandbox API call failed. Last error: {last_error}\")  # <-- Use internal log_prefix\n    # Return the error message without the prefix, as the caller doesn't need the internal ID\n    # Ensure API call failure returns error message, leading to -1 in check_correctness\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\n\n\ndef _process_single_case(\n    case_index: int,\n    stdin_data: Any,\n    expected_output: Any,\n    sandbox_fusion_url: str,\n    generation: str,\n    timeout: int,\n    memory_limit_mb: int,\n    language: str,\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\n    fn_name: Optional[str] = None,\n) -> tuple[int, dict[str, Any]]:\n    \"\"\"Helper function to process a single test case.\"\"\"\n    api_response = None\n    error_msg = None\n    logger.info(f\"Processing test case {case_index + 1}.\")\n\n    current_generation_code = generation\n\n    if fn_name and language == \"python\":\n        # Wrapper assumes stdin_data is a JSON string for function arguments.\n        wrapper_code = f\"\"\"\nimport traceback\nfrom string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\n\n# === User's Original Code START ===\n{generation}\n# === 
User's Original Code END ===\n\n_SANDBOX_FN_NAME = \"{fn_name}\"\n\ndef _execute_user_function():\n    # --- Input Parsing ---\n    _raw_input_str = sys.stdin.read()\n    _args = []\n    if _raw_input_str.strip(): # If there's input\n        try:\n            _args = [json.loads(line) for line in _raw_input_str.split('\\\\n')]\n        except json.JSONDecodeError as _je:\n            sys.stderr.write(f\"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\\\nInput was: \"\n                              f\"{{_raw_input_str[:200]}}\\\\n\")\n            return None, True # result, error_occurred\n\n    # --- Function Location and Execution ---\n    try:\n        _target_callable = None\n        # Try global scope first\n        if _SANDBOX_FN_NAME in globals():\n            _target_callable = globals()[_SANDBOX_FN_NAME]\n        # Else, if 'Solution' class exists, try to get its method\n        elif 'Solution' in globals():\n            _Solution_class = globals()['Solution']\n            # Attempt to instantiate and get method.\n            # Errors (e.g., Solution not a class, instantiation fails, method missing)\n            # will be caught by the broad except block below.\n            _solution_instance = _Solution_class()\n            _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)\n\n        if not _target_callable:\n            sys.stderr.write(f\"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\\\n\")\n            return None, True # result, error_occurred\n\n        _fn_result = _target_callable(*_args)\n        return _fn_result, False # result, no_error\n    except Exception: # Catches errors from Solution instantiation, getattr, or function call\n        sys.stderr.write(f\"Error during setup or execution of '{{_SANDBOX_FN_NAME}}':\\\\n{{traceback.format_exc()}}\\\\n\")\n        return None, True # result, error_occurred\n\nif __name__ == '__main__':\n    _result, _error_occurred = _execute_user_function()\n\n    if not _error_occurred:\n        # Serialize result to stdout\n        if isinstance(_result, (dict, list, tuple)) or _result is None or isinstance(_result, bool):\n            print(json.dumps(_result))\n        elif isinstance(_result, (int, float, str)):\n            print(str(_result)) # Ensure string conversion for print\n        else:\n            # For other types, default to string representation.\n            print(str(_result))\n    # Optional: To explicitly exit with an error code if the sandbox relies on it\n    # else:\n    #    sys.exit(1)\n\"\"\"\n        current_generation_code = wrapper_code\n\n    stdin = None if stdin_data is None else str(stdin_data)\n    try:\n        if concurrent_semaphore:\n            # logger.debug(f\"Case {case_index + 1}: Attempting to acquire semaphore.\")\n            with concurrent_semaphore:\n                # logger.debug(f\"Case {case_index + 1}: Semaphore acquired. 
Calling API.\")\n                api_response, error_msg = call_sandbox_api(\n                    sandbox_fusion_url=sandbox_fusion_url,\n                    code=current_generation_code,\n                    stdin=stdin,\n                    compile_timeout=timeout,\n                    run_timeout=timeout,\n                    memory_limit_mb=memory_limit_mb,\n                    language=language,\n                )\n            # logger.debug(f\"Case {case_index + 1}: Semaphore released.\")\n        else:\n            api_response, error_msg = call_sandbox_api(\n                sandbox_fusion_url=sandbox_fusion_url,\n                code=current_generation_code,\n                stdin=stdin,\n                compile_timeout=timeout,\n                run_timeout=timeout,\n                memory_limit_mb=memory_limit_mb,\n                language=language,\n            )\n    except Exception as e:\n        error_msg = f\"API Request Exception during check_correctness for case {case_index + 1}: {e}\"\n        logger.error(f\"Case {case_index + 1}: {error_msg}\")\n        traceback.print_exc()\n\n    metadata = {\n        \"case_index\": case_index,\n        \"input\": stdin,\n        \"expected_output\": str(expected_output) if expected_output else None,\n        \"api_request_error\": error_msg,\n        \"api_response\": None,\n        \"status\": \"unknown\",\n        \"stdout\": None,\n        \"stderr\": None,\n        \"exit_code\": None,\n        \"duration\": None,\n        \"compile_duration\": None,\n        \"compile_stderr\": None,\n        \"api_status\": None,\n        \"compile_status\": None,\n        \"run_status\": None,\n    }\n    result_status = -1  # Default error: API request error or unknown sandbox error\n\n    if error_msg:\n        metadata[\"status\"] = \"api_error\"\n        result_status = -1  # API request itself failed (includes timeout after retries)\n        logger.error(f\"Case {case_index}: API error occurred: {error_msg}\")\n        # Log code and input only on error for brevity\n        generation_to_log = generation[:200] + \"...\" if len(generation) > 200 else generation\n        logger.error(f\"Case {case_index}: code: {generation_to_log}\")\n        logger.error(f\"Case {case_index}: input: {stdin}\")\n    elif api_response:\n        # --- Add debug logging ---\n        logger.debug(f\"Case {case_index}: API Response: {api_response}\")\n        metadata[\"api_response\"] = api_response\n        metadata[\"api_status\"] = api_response.get(\"status\")\n        compile_result = api_response.get(\"compile_result\")\n        run_result = api_response.get(\"run_result\")\n\n        # Extract compile information\n        if compile_result:\n            metadata[\"compile_status\"] = compile_result.get(\"status\")\n            metadata[\"compile_duration\"] = compile_result.get(\"execution_time\")\n            metadata[\"compile_stderr\"] = compile_result.get(\"stderr\")\n\n        # Extract run information\n        if run_result:\n            metadata[\"run_status\"] = run_result.get(\"status\")\n            metadata[\"stdout\"] = run_result.get(\"stdout\")\n            metadata[\"stderr\"] = run_result.get(\"stderr\")  # stderr during runtime\n            metadata[\"exit_code\"] = run_result.get(\"return_code\")\n            metadata[\"duration\"] = run_result.get(\"execution_time\")\n\n        # --- Determine status based on API response ---\n        api_status = metadata[\"api_status\"]\n\n        if api_status == \"SandboxError\":\n            
metadata[\"status\"] = \"sandbox_error\"\n            result_status = -1  # Internal sandbox error\n        elif api_status == \"Failed\":\n            # --- Add debug logging ---\n            logger.debug(f\"API returned Failed status. Response: {api_response}\")\n            logger.debug(f\"Compile Result: {compile_result}\")\n            logger.debug(f\"Run Result: {run_result}\")\n            # --- Check the logic here ---\n            # Compile failed or timed out\n            is_compile_error = compile_result and (\n                metadata[\"compile_status\"] in [\"Error\", \"TimeLimitExceeded\"]\n                or (metadata[\"compile_status\"] == \"Finished\" and compile_result.get(\"return_code\") != 0)\n            )\n            if is_compile_error:\n                # Differentiate between compile_error and compile_timeout based on specific status\n                if metadata[\"compile_status\"] == \"TimeLimitExceeded\":\n                    metadata[\"status\"] = \"compile_timeout\"\n                else:  # Includes Error and Finished but return_code != 0 cases\n                    metadata[\"status\"] = \"compile_error\"\n                result_status = -4\n            # Run failed or timed out\n            elif run_result:\n                # Modified condition: Check for TimeLimitExceeded OR (Finished with non-zero exit code) OR Error status\n                is_runtime_error = (\n                    metadata[\"run_status\"] == \"TimeLimitExceeded\"\n                    or metadata[\"run_status\"] == \"Error\"\n                    or (metadata[\"run_status\"] == \"Finished\" and run_result.get(\"return_code\") != 0)\n                )\n                if is_runtime_error:\n                    if metadata[\"run_status\"] == \"TimeLimitExceeded\":\n                        metadata[\"status\"] = \"timeout\"  # Runtime timeout\n                        result_status = -3\n                    else:  # Includes Error and Finished with non-zero return_code\n                        metadata[\"status\"] = \"runtime_error\"\n                        result_status = -2\n                else:\n                    # Other Failed status with run_result, classify as unknown failure\n                    logger.warning(f\"Unknown run_status '{metadata['run_status']}' or state within Failed API status.\")\n                    metadata[\"status\"] = \"unknown_failure\"\n                    result_status = -1  # Default to -1\n            else:\n                # Status is Failed but neither a clear compile error nor run_result exists\n                logger.warning(\"API status Failed but cannot determine specific error type (compile/run).\")\n                metadata[\"status\"] = \"unknown_failure_state\"\n                result_status = -1  # Default to -1\n        elif api_status == \"Success\":\n            # Run completed successfully, now check the answer\n            if run_result and metadata[\"run_status\"] == \"Finished\":\n                actual_output = metadata[\"stdout\"] if metadata[\"stdout\"] is not None else \"\"\n                # Note: Output might contain trailing newlines, need normalization\n                if expected_output is None or str(actual_output).rstrip(\"\\n\") == str(expected_output).rstrip(\"\\n\"):\n                    result_status = True\n                    metadata[\"status\"] = \"success\"\n                else:\n                    result_status = False\n                    metadata[\"status\"] = \"wrong_answer\"\n            else:\n                # 
Status is Success but run_result status is not Finished, this is unexpected\n                metadata[\"status\"] = \"unexpected_success_state\"\n                result_status = -1  # Classify as unknown error\n        else:\n            # API returned an unknown top-level status\n            logger.warning(f\"Unknown API status received: {api_status}\")\n            metadata[\"status\"] = f\"unknown_api_status_{api_status}\"\n            result_status = -1  # Default to -1\n    else:  # api_response is None and no error_msg (Should not happen with current call_sandbox_api logic)\n        metadata[\"status\"] = \"unknown_api_state\"\n        result_status = -1\n        logger.error(f\"Case {case_index}: Unknown API state (no response and no error message).\")\n    return result_status, metadata\n\n\ndef check_correctness(\n    sandbox_fusion_url: str,\n    in_outs: Optional[dict],\n    generation: str,\n    timeout: int = DEFAULT_TIMEOUT,\n    memory_limit_mb: int = 1024,\n    language: str = \"python\",\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\n) -> tuple[list[Any], list[dict[str, Any]]]:\n    \"\"\"\n    Checks the correctness of code generation using the remote sandbox API,\n    processing test cases concurrently.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        in_outs: Dictionary containing \"inputs\" and \"outputs\" lists.\n        generation: The generated code string.\n        timeout: Timeout for each test case (used for both compile and run).\n        memory_limit_mb: Memory limit in MB for each sandbox execution.\n        language: The programming language of the code.\n        concurrent_semaphore: Optional semaphore bounding concurrent sandbox API calls.\n\n    Returns:\n        A tuple (results, metadata_list).\n        results: A list containing the test result for each input/output pair\n                 (True/False/-1 api/sandbox err, -2 runtime err, -3 timeout, -4 compile err).\n                 Results are ordered corresponding to the inputs.\n        metadata_list: A list containing metadata dictionaries for each test case,\n                       ordered corresponding to the inputs.\n    \"\"\"\n    logger.info(\"Starting correctness check for generation.\")\n\n    if not in_outs or \"inputs\" not in in_outs or \"outputs\" not in in_outs:\n        logger.warning(\"Invalid in_outs format provided.\")\n        return [-1], [{\"error\": \"Invalid input/output data\"}]\n\n    inputs = in_outs[\"inputs\"]\n    expected_outputs = in_outs[\"outputs\"]\n    fn_name = in_outs.get(\"fn_name\")\n    num_cases = len(inputs)\n    assert_cases = in_outs.get(\"assert_case\", [\"\"] * num_cases)  # Default to empty strings if not provided\n    results = [None] * num_cases  # Initialize with placeholders\n    metadata_list = [None] * num_cases  # Initialize with placeholders\n\n    if num_cases == 0:\n        logger.warning(\"Empty inputs provided.\")\n        return [], []\n\n    if len(inputs) != len(expected_outputs):\n        logger.warning(f\"Mismatch between number of inputs ({len(inputs)}) and outputs ({len(expected_outputs)}).\")\n        # Return error based on the number of inputs provided\n        return [-1] * num_cases, [{\"error\": \"Input/output count mismatch\", \"case_index\": i} for i in range(num_cases)]\n\n    # Each provided \"assert_case\" snippet is appended to the generated code before execution\n    if len(assert_cases) != num_cases:\n        logger.warning(\n            f\"Mismatch between number of assert cases ({len(assert_cases)}) and inputs/outputs ({num_cases}).\"\n        )\n        return [-1] * num_cases, [{\"error\": \"Assert case count mismatch\", \"case_index\": i} for i in range(num_cases)]
\n\n    first_compile_error_index = -1\n\n    # The executor only caps local threads; concurrency against the sandbox is still\n    # bounded by concurrent_semaphore when one is provided. os.cpu_count() can return\n    # None, so it is guarded before multiplying.\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max(32, (os.cpu_count() or 1) * 5)) as executor:\n        # Submit all tasks, passing the concurrent_semaphore to _process_single_case\n        future_to_index = {\n            executor.submit(\n                _process_single_case,\n                i,\n                stdin_data,\n                expected_outputs[i],\n                sandbox_fusion_url,\n                generation + \"\\n\\n\" + assert_cases[i],  # Append assert case to generation\n                timeout,\n                memory_limit_mb,\n                language,\n                concurrent_semaphore,\n                fn_name,\n            ): i\n            for i, stdin_data in enumerate(inputs)\n        }\n\n        # Process results as they complete\n        for future in concurrent.futures.as_completed(future_to_index):\n            index = future_to_index[future]\n            try:\n                result_status, metadata = future.result()\n                results[index] = result_status\n                metadata_list[index] = metadata\n\n                # Check for compile error (-4)\n                if result_status == -4:\n                    if first_compile_error_index == -1 or index < first_compile_error_index:\n                        first_compile_error_index = index\n                    # Optimization: could potentially cancel futures for index > first_compile_error_index\n                    # However, cancellation is not guaranteed. Post-processing is safer.\n\n            except Exception as exc:\n                logger.error(f\"Test case {index} generated an exception: {exc}\")\n                traceback.print_exc()\n                results[index] = -1  # Mark as API/internal error\n                metadata_list[index] = {\n                    \"case_index\": index,\n                    \"input\": str(inputs[index]),\n                    \"expected_output\": str(expected_outputs[index]) if expected_outputs[index] else None,\n                    \"api_request_error\": f\"Internal execution error: {exc}\",\n                    \"status\": \"internal_error\",\n                }\n\n    # Post-processing for compile errors\n    if first_compile_error_index != -1:\n        logger.warning(\n            f\"Compile error detected in case {first_compile_error_index}. 
Marking subsequent cases as compile errors.\"\n        )\n        for i in range(first_compile_error_index + 1, num_cases):\n            # Only update if not already processed (though it should be None or have a result)\n            if results[i] != -4:  # Avoid overwriting if it somehow already got -4\n                results[i] = -4\n                # Update or create metadata for skipped cases due to compile error\n                if metadata_list[i] is None:  # If future failed before returning metadata\n                    metadata_list[i] = {\n                        \"case_index\": i,\n                        \"input\": str(inputs[i]),\n                        \"expected_output\": str(expected_outputs[i]) if expected_outputs[i] else None,\n                        \"api_request_error\": None,\n                        \"status\": \"compile_error_skipped\",  # Indicate skipped due to prior compile error\n                    }\n                else:  # If future completed but result is overridden\n                    metadata_list[i][\"status\"] = \"compile_error_skipped\"\n\n    logger.info(f\"Correctness check finished. Results: {results}\")\n    return results, metadata_list\n"
  },
  {
    "path": "verl_distillation/verl/utils/reward_score/search_r1_like_qa_em.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n# Copyright 2025 Search-R1 Contributors\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py\r\n\r\nimport random\r\nimport re\r\nimport string\r\n\r\n\r\ndef normalize_answer(s):\r\n    def remove_articles(text):\r\n        return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\r\n\r\n    def white_space_fix(text):\r\n        return \" \".join(text.split())\r\n\r\n    def remove_punc(text):\r\n        exclude = set(string.punctuation)\r\n        return \"\".join(ch for ch in text if ch not in exclude)\r\n\r\n    def lower(text):\r\n        return text.lower()\r\n\r\n    return white_space_fix(remove_articles(remove_punc(lower(s))))\r\n\r\n\r\ndef em_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer == normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef subem_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer in normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef extract_solution(solution_str):\r\n    \"\"\"Extract the equation from the solution string.\"\"\"\r\n    # Remove everything before the first \"Assistant:\"\r\n    # if \"Assistant:\" in solution_str:\r\n    #     solution_str = solution_str.split(\"Assistant:\", 1)[1]\r\n    # elif \"<|im_start|>assistant\" in solution_str:\r\n    #     solution_str = solution_str.split(\"<|im_start|>assistant\", 1)[1]\r\n    # else:\r\n    #     return None\r\n    # solution_str = solution_str.split('\\n')[-1]\r\n\r\n    answer_pattern = r\"<answer>(.*?)</answer>\"\r\n    match = re.finditer(answer_pattern, solution_str, re.DOTALL)\r\n    matches = list(match)\r\n\r\n    # If there are 0  matches, return None\r\n    if len(matches) < 1:\r\n        return None\r\n\r\n    # If there are 2 or more matches, return the last one\r\n    return matches[-1].group(1).strip()\r\n\r\n\r\ndef count_answer_tags(text):\r\n    opening_tags = text.count(\"<answer>\")\r\n    closing_tags = text.count(\"</answer>\")\r\n\r\n    return opening_tags, closing_tags\r\n\r\n\r\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n  
      ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    open_count, close_count = count_answer_tags(solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        if answer is not None:\r\n            print(f\"Extracted answer: {answer}\")\r\n        else:\r\n            print(\"Extracted answer: None!\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if em_check(answer, ground_truth[\"target\"]):\r\n            if open_count > 10 or close_count > 10:  # penalize responses that spam <answer> tags\r\n                return score / 4\r\n            return score\r\n        else:\r\n            return format_score\r\n\r\n\r\ndef compute_score_subem(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for substring exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n        ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        print(f\"Extracted answer: {answer}\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if subem_check(answer, ground_truth[\"target\"]):\r\n            return score\r\n        else:\r\n            return format_score\r\n"
  },
  {
    "path": "verl_distillation/verl/utils/rollout_skip.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom pathlib import Path\n\nfrom verl.protocol import DataProto\n\n\nclass RolloutSkip:\n    \"\"\"\n    RolloutSkip skips sequence generation during rollout by attempting to load previously dumped data.\n    If no dumped data is found, it generates new sequences and saves them to disk.\n\n    Args:\n        config: The configuration object containing rollout settings.\n        rollout_wg: The worker group that handles the rollout process.\n\n    Note:\n        When rollout.n or rollout.gen_batch_size differ from previous runs,\n        new sequences will be generated and saved with different filenames.\n    \"\"\"\n\n    print_mark = \"[RolloutSkip()]\"\n\n    def __init__(self, config, rollout_wg):\n        self.rollout_config = config.actor_rollout_ref.rollout\n        self.exp_name = config.data.get(\"experiment_name\", \"\")\n        self.project_name = config.data.get(\"project_name\", \"\")\n\n        self.n = int(self.rollout_config.get(\"n\", 0))\n        self.gbs = int(config.data.get(\"gen_batch_size\", config.data.get(\"train_batch_size\", 0)))\n\n        self.dumped_dir = Path(self.rollout_config.get(\"skip_dump_dir\", \"/tmp/verl/rollout_dump\"))\n        self.dumped_dir.mkdir(parents=True, exist_ok=True)\n\n        # Check if path is in Ray temporary directory\n        if str(self.dumped_dir.absolute()).startswith(\"/tmp/ray/session\"):\n            print(\n                f\"\\033[33m{self.print_mark} Warning: \\nUsing dump path \",\n                f\"'{self.dumped_dir.absolute()}' is not recommended \",\n                \"as it's located in /tmp/ray/session*\\033[0m\",\n                flush=True,\n            )\n\n        print(\n            f\"{self.print_mark} Rollout skip dump path set to: \",\n            f\"{self.dumped_dir.absolute()}\",\n            flush=True,\n        )\n\n        self._rollout_wg = rollout_wg\n\n    @property\n    def curr_path_dump(self):\n        return self.dumped_dir.joinpath(f\"{self.exp_name}_{self.project_name}_GBS{self.gbs}__N{self.n}\").absolute()\n\n    def wrap_generate_sequences(self):\n        try:\n            self._rollout_wg.generate_sequences = wrap_generate_sequences(self, self._rollout_wg)\n            print(\n                f\"{self.print_mark} Successfully patched `actor_rollout_wg.generate_sequences()`\",\n                flush=True,\n            )\n        except Exception as e:\n            raise RuntimeError(\n                \"{self.print_mark} Failed to patch `actor_rollout_wg.generate_sequences()`\",\n                flush=True,\n            ) from e\n\n    def try_load(self):\n        if not self.curr_path_dump.exists():\n            print(\n                f\"{self.print_mark} No data dump found at {self.curr_path_dump}.\",\n                \"The trainer will generate and automatically dump the data for this first run.\",\n                flush=True,\n            )\n            return None\n\n 
        try:\n            # * Load\n            ret_batch = DataProto.load_from_disk(self.curr_path_dump)\n            print(\n                f\"\\033[32m{self.print_mark} Successfully loaded pre-generated data from {self.curr_path_dump}\\033[0m\",\n                flush=True,\n            )\n            return ret_batch\n        except Exception as e:\n            print(\n                f\"\\033[31m{self.print_mark} Failed to load pre-generated data from {self.curr_path_dump}\",\n                f\"Error: {str(e)}\\033[0m\",\n                flush=True,\n            )\n            return None\n\n    def dump(self, outputs: DataProto):\n        try:\n            outputs.save_to_disk(self.curr_path_dump)\n            print(\n                f\"\\033[32m{self.print_mark} Successfully dumped data to {self.curr_path_dump}\\033[0m\",\n                flush=True,\n            )\n        except Exception as e:\n            print(\n                f\"\\033[31m{self.print_mark} Failed to dump data to {self.curr_path_dump}: {e}\\033[0m\",\n                flush=True,\n            )\n\n\ndef wrap_generate_sequences(rolloutskip: RolloutSkip, rollout_wg):\n    generate_sequences = rollout_wg.generate_sequences\n\n    def wrap_fn(batch, **kwargs):\n        gen_batch_output = rolloutskip.try_load()\n\n        if gen_batch_output is None:\n            # * 1. Generation\n            gen_batch_output = generate_sequences(batch, **kwargs)\n            # * 2. Dump\n            rolloutskip.dump(gen_batch_output)\n        return gen_batch_output\n\n    return wrap_fn\n"
  },
  {
    "path": "verl_distillation/verl/utils/rollout_trace.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport contextlib\nimport functools\nimport inspect\nimport os\nfrom typing import Optional\n\n\nclass RolloutTraceConfig:\n    \"\"\"Configuration for rollout tracing with various backends.\n\n    Singleton configuration class for managing rollout trace settings across different\n    tracing backends like Weave and MLflow.\n\n    Args:\n        backend (Optional[str]): Tracing backend to use ('weave', 'mlflow', or None).\n        client (Optional[object]): Client instance for the selected backend.\n        token2text (bool): Whether to convert tokens to text in traces. Defaults to False.\n        project_name (str): Name of the project for tracing.\n        experiment_name (str): Name of the experiment for tracing.\n    \"\"\"\n\n    _instance: Optional[\"RolloutTraceConfig\"] = None\n    backend: Optional[str] = None\n    client: Optional[object] = None\n    token2text: bool = False\n    _initialized: bool = False\n    project_name: str = None\n    experiment_name: str = None\n\n    def __new__(cls, *args, **kwargs):\n        if cls._instance is None:\n            cls._instance = super().__new__(cls)\n            cls._instance._initialized = False\n        return cls._instance\n\n    @classmethod\n    def get_instance(cls) -> \"RolloutTraceConfig\":\n        if cls._instance is None:\n            cls._instance = cls()\n        return cls._instance\n\n    @classmethod\n    def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False):\n        config = cls.get_instance()\n        if config._initialized:\n            return\n\n        config.backend = backend\n        config.token2text = token2text\n        config.project_name = project_name\n        config.experiment_name = experiment_name\n\n        if backend == \"weave\":\n            import weave\n\n            config.client = weave.init(project_name)\n        elif backend == \"mlflow\":\n            import mlflow\n\n            mlflow.config.enable_async_logging()\n            config.client = mlflow\n\n            MLFLOW_TRACKING_URI = os.environ.get(\"MLFLOW_TRACKING_URI\", \"sqlite:////tmp/mlruns.db\")\n            mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n\n            mlflow.set_experiment(project_name)\n        else:\n            config.client = None\n\n        config._initialized = True\n\n    @classmethod\n    def get_backend(cls) -> Optional[str]:\n        return cls.get_instance().backend\n\n    @classmethod\n    def get_client(cls) -> Optional[object]:\n        return cls.get_instance().client\n\n    @classmethod\n    def enable_token2text(cls) -> Optional[bool]:\n        return cls.get_instance().token2text\n\n    @classmethod\n    def reset(cls):\n        cls._instance = None\n\n\n@contextlib.contextmanager\ndef rollout_trace_attr(sample_index=None, step=None, rollout_n=None, name=\"rollout_trace\", validate=False):\n    \"\"\"A context manager to 
add attributes to a trace for the configured backend.\"\"\"\n    backend = RolloutTraceConfig.get_backend()\n    attributes = {}\n    if backend:\n        if sample_index is not None:\n            attributes[\"sample_index\"] = sample_index\n        if step is not None:\n            attributes[\"step\"] = step\n        if rollout_n is not None:\n            attributes[\"rollout_n\"] = rollout_n\n        attributes[\"validate\"] = validate\n        attributes[\"experiment_name\"] = RolloutTraceConfig.get_instance().experiment_name\n\n    if not attributes or backend is None:\n        yield\n        return\n\n    if backend == \"weave\":\n        import weave\n\n        with weave.attributes(attributes):\n            yield\n    elif backend == \"mlflow\":\n        import mlflow\n\n        with mlflow.start_span(name=name) as span:\n            trace_id = span.trace_id\n            for key, value in attributes.items():\n                mlflow.set_trace_tag(trace_id, str(key), str(value))\n            yield\n    else:\n        yield\n\n\ndef rollout_trace_op(func):\n    @functools.wraps(func)\n    async def async_wrapper(self, *args, **kwargs):\n        backend = RolloutTraceConfig.get_backend()\n        enable_token2text = RolloutTraceConfig.enable_token2text()\n        if backend is None:\n            return await func(self, *args, **kwargs)\n\n        sig = inspect.signature(func)\n        bound_args = sig.bind(self, *args, **kwargs)\n        bound_args.apply_defaults()\n        inputs = dict(bound_args.arguments)\n        del inputs[\"self\"]\n\n        async def add_token2text(self, result):\n            if hasattr(result, \"prompt_ids\") and hasattr(self, \"tokenizer\") and hasattr(self.tokenizer, \"decode\"):\n                _result = vars(result)\n                loop = asyncio.get_running_loop()\n                if hasattr(result, \"prompt_ids\"):\n                    prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids)\n                    _result[\"prompt_text\"] = prompt_text\n\n                if hasattr(result, \"response_ids\"):\n                    response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids)\n                    _result[\"response_text\"] = response_text\n                return _result\n            return result\n\n        if backend == \"weave\":\n            tracer = RolloutTraceConfig.get_client()\n            from weave.trace.context import call_context\n\n            cur_attributes = {**call_context.call_attributes.get()}\n            call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes)\n            try:\n                result = await func(self, *args, **kwargs)\n\n                if enable_token2text:\n                    _result = await add_token2text(self, result)\n                    tracer.finish_call(call, output=_result)\n                else:\n                    tracer.finish_call(call, output=result)\n\n                return result\n\n            except Exception as e:\n                tracer.finish_call(call, exception=e)\n                raise e\n        elif backend == \"mlflow\":\n            import mlflow\n\n            with mlflow.start_span(name=func.__qualname__) as span:\n                span.set_inputs(inputs)\n                result = await func(self, *args, **kwargs)\n                if enable_token2text:\n                    _result = await add_token2text(self, result)\n                    span.set_outputs(_result)\n            
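    # Outputs are attached to the span here; the result itself is returned once the span closes.\n            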
    else:\n                    span.set_outputs(result)\n\n            return result\n\n        else:\n            return await func(self, *args, **kwargs)\n\n    @functools.wraps(func)\n    def wrapper(self, *args, **kwargs):\n        backend = RolloutTraceConfig.get_backend()\n        if backend is None:\n            return func(self, *args, **kwargs)\n\n        sig = inspect.signature(func)\n        bound_args = sig.bind(self, *args, **kwargs)\n        bound_args.apply_defaults()\n        inputs = dict(bound_args.arguments)\n        del inputs[\"self\"]\n\n        if backend == \"weave\":\n            tracer = RolloutTraceConfig.get_client()\n            from weave.trace.context import call_context\n\n            cur_attributes = {**call_context.call_attributes.get()}\n            call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes)\n            try:\n                result = func(self, *args, **kwargs)\n                tracer.finish_call(call, output=result)\n                return result\n            except Exception as e:\n                tracer.finish_call(call, exception=e)\n                raise e\n        elif backend == \"mlflow\":\n            import mlflow\n\n            return mlflow.trace(func)(self, *args, **kwargs)\n        else:\n            return func(self, *args, **kwargs)\n\n    return async_wrapper if inspect.iscoroutinefunction(func) else wrapper\n"
  },
  {
    "path": "verl_distillation/verl/utils/seqlen_balancing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport heapq\nfrom itertools import chain\n\nimport torch\nfrom torch import distributed as dist\n\nfrom verl.protocol import DataProto\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.device import get_device_name\n\n\ndef calculate_workload(seqlen_list: list[int]):\n    \"\"\"\n    Calculate the workload for a dense transformer block based on sequence length.\n    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2\n    Hardcodes the constants by a 7B model (hidden_size=4096),\n    so the FLOPs are propotional to (6 * 4096 * seqlen + seqlen^2).\n    \"\"\"\n    return 24576 * seqlen_list + seqlen_list**2\n\n\ndef karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    # see: https://en.wikipedia.org/wiki/Largest_differencing_method\n    class Set:\n        def __init__(self) -> None:\n            self.sum = 0\n            self.items = []\n\n        def add(self, idx: int, val: int):\n            self.items.append((idx, val))\n            self.sum += val\n\n        def merge(self, other):\n            for idx, val in other.items:\n                self.items.append((idx, val))\n                self.sum += val\n\n        def __lt__(self, other):\n            if self.sum != other.sum:\n                return self.sum < other.sum\n            if len(self.items) != len(other.items):\n                return len(self.items) < len(other.items)\n            return self.items < other.items\n\n    class State:\n        def __init__(self, items: list[tuple[int, int]], k: int) -> None:\n            self.k = k\n            # sets should always be decreasing order\n            self.sets = [Set() for _ in range(k)]\n            assert len(items) in [1, k], f\"{len(items)} not in [1, {k}]\"\n            for i, (idx, seqlen) in enumerate(items):\n                self.sets[i].add(idx=idx, val=seqlen)\n            self.sets = sorted(self.sets, reverse=True)\n\n        def get_partitions(self):\n            partitions = []\n            for i in range(len(self.sets)):\n                cur_partition = []\n                for idx, _ in self.sets[i].items:\n                    cur_partition.append(idx)\n                partitions.append(cur_partition)\n            return partitions\n\n        def merge(self, other):\n            for i in range(self.k):\n                self.sets[i].merge(other.sets[self.k - 1 - i])\n            self.sets = sorted(self.sets, reverse=True)\n\n        @property\n        def spread(self) -> int:\n            return self.sets[0].sum - self.sets[-1].sum\n\n        def __lt__(self, other):\n            # least heap, let the state with largest spread to be popped first,\n            # if the spread is the same, let the state who has the largest set\n            # to be popped first.\n            if self.spread != other.spread:\n                return self.spread > other.spread\n            return 
self.sets[0] > other.sets[0]\n\n        def __repr__(self) -> str:\n            repr_str = \"[\"\n            for i in range(self.k):\n                if i > 0:\n                    repr_str += \",\"\n                repr_str += \"{\"\n                for j, (_, seqlen) in enumerate(self.sets[i].items):\n                    if j > 0:\n                        repr_str += \",\"\n                    repr_str += str(seqlen)\n                repr_str += \"}\"\n            repr_str += \"]\"\n            return repr_str\n\n    sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])\n    states_pq = []\n    if equal_size:\n        assert len(seqlen_list) % k_partitions == 0, f\"{len(seqlen_list)} % {k_partitions} != 0\"\n        for offset in range(0, len(sorted_seqlen_list), k_partitions):\n            items = []\n            for i in range(k_partitions):\n                seqlen, idx = sorted_seqlen_list[offset + i]\n                items.append((idx, seqlen))\n            heapq.heappush(states_pq, State(items=items, k=k_partitions))\n    else:\n        for seqlen, idx in sorted_seqlen_list:\n            heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))\n\n    while len(states_pq) > 1:\n        state0 = heapq.heappop(states_pq)\n        state1 = heapq.heappop(states_pq)\n        # merge states\n        state0.merge(state1)\n        heapq.heappush(states_pq, state0)\n\n    final_state = states_pq[0]\n    partitions = final_state.get_partitions()\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    bias = sum(seqlen_list) + 1 if equal_size else 0\n    sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]\n    partitions = [[] for _ in range(k_partitions)]\n    partition_sums = [0 for _ in range(k_partitions)]\n    for seqlen, i in sorted_seqlen:\n        min_idx = None\n        for j in range(k_partitions):\n            if min_idx is None or partition_sums[j] < partition_sums[min_idx]:\n                min_idx = j\n        partitions[min_idx].append(i)\n        partition_sums[min_idx] += seqlen\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    \"\"\"\n    Calculates partitions of indices from seqlen_list such that the sum of sequence lengths\n    in each partition is balanced. 
Uses the Karmarkar-Karp differencing method.\n\n    This is useful for balancing workload across devices or batches, especially when\n    dealing with variable sequence lengths.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        k_partitions (int): The desired number of partitions.\n        equal_size (bool): If True, ensures that each partition has the same number of items.\n                           Requires len(seqlen_list) to be divisible by k_partitions.\n                           If False, partitions can have varying numbers of items, focusing\n                           only on balancing the sum of sequence lengths.\n\n    Returns:\n        List[List[int]]: A list containing k_partitions lists. Each inner list contains the\n                         original indices of the items assigned to that partition. The indices\n                         within each partition list are sorted.\n\n    Raises:\n        AssertionError: If len(seqlen_list) < k_partitions.\n        AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions.\n        AssertionError: If any resulting partition is empty.\n    \"\"\"\n    assert len(seqlen_list) >= k_partitions, f\"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]\"\n\n    def _check_and_sort_partitions(partitions):\n        assert len(partitions) == k_partitions, f\"{len(partitions)} != {k_partitions}\"\n        seen_idx = set()\n        sorted_partitions = [None] * k_partitions\n        for i, partition in enumerate(partitions):\n            assert len(partition) > 0, f\"the {i}-th partition is empty\"\n            for idx in partition:\n                seen_idx.add(idx)\n            sorted_partitions[i] = sorted(partition)\n        assert seen_idx == set(range(len(seqlen_list)))\n        return sorted_partitions\n\n    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)\n    return _check_and_sort_partitions(partitions)\n\n\ndef log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix):\n    \"\"\"\n    Calculate and log metrics related to sequence length imbalance before and after partitioning.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        partitions (List[List[int]]): A list of partitions, where each inner list contains indices\n                                      from seqlen_list assigned to that partition.\n        prefix (str): A prefix to be added to each metric key in the returned dictionary.\n\n    Returns:\n        dict: A dictionary containing metrics related to sequence length imbalance.\n    \"\"\"\n    # Get the number of partitions\n    k_partition = len(partitions)\n    # assert len(seqlen_list) % k_partition == 0\n    batch_size = len(seqlen_list) // k_partition\n    min_sum_seqlen = None\n    max_sum_seqlen = None\n    total_sum_seqlen = 0\n\n    # Iterate over each batch of sequence lengths\n    for offset in range(0, len(seqlen_list), batch_size):\n        cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])\n        if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:\n            min_sum_seqlen = cur_sum_seqlen\n        if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:\n            max_sum_seqlen = cur_sum_seqlen\n        total_sum_seqlen += cur_sum_seqlen\n\n    balanced_sum_seqlen_list = []\n    for partition in partitions:\n        cur_sum_seqlen_balanced = 
sum([seqlen_list[i] for i in partition])\n        balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)\n    # print(\"balanced_sum_seqlen_list: \", balanced_sum_seqlen_list)\n    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)\n    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)\n\n    return {\n        f\"{prefix}/min\": min_sum_seqlen,\n        f\"{prefix}/max\": max_sum_seqlen,\n        f\"{prefix}/minmax_diff\": max_sum_seqlen - min_sum_seqlen,\n        f\"{prefix}/balanced_min\": min_sum_seqlen_balanced,\n        f\"{prefix}/balanced_max\": max_sum_seqlen_balanced,\n        f\"{prefix}/mean\": total_sum_seqlen / len(partitions),\n    }\n\n\ndef ceildiv(a, b):\n    return -(a // -b)\n\n\ndef roundup_divisible(a, b):\n    return ((a + b - 1) // b) * b\n\n\ndef rearrange_micro_batches(\n    batch,\n    max_token_len,\n    dp_group=None,\n    num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n):\n    \"\"\"\n    Split a batch into micro-batches by total token count, with optional DP sync and padding.\n\n    Args:\n        batch (TensorDict): must include \"attention_mask\" (B*S); other fields are sliced similarly.\n        max_token_len (int): max sum of attention_mask per micro-batch.\n        dp_group (optional): torch.distributed group for data-parallel sync.\n        num_batches_divided_by (optional): virtual pipeline parallel size, for megatron.\n        same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count.\n        min_num_micro_batch (int, optional): force at least this many splits (pads empty ones).\n        use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches\n\n    Returns:\n        List[TensorDict]: the micro-batches.\n        List[List[int]]: index lists mapping each micro-batch back to original positions.\n    \"\"\"\n    # this is per local micro_bsz\n    input_ids = batch[\"input_ids\"]\n    if input_ids.is_nested:\n        seq_len_effective: torch.Tensor = input_ids.offsets().diff()\n        max_seq_len = max(seq_len_effective)\n    else:\n        max_seq_len = batch[\"attention_mask\"].shape[-1]\n        seq_len_effective: torch.Tensor = batch[\"attention_mask\"].sum(dim=1)\n\n    assert max_token_len >= max_seq_len, (\n        f\"max_token_len must be greater than the sequence length. 
Got {max_token_len=} and {max_seq_len=}\"\n    )\n    total_seqlen = seq_len_effective.sum().item()\n    # NOTE: num_microbatches <= batch_size, so take the min of this two.\n    num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))\n    if min_num_micro_batch is not None:\n        # used to support pp\n        num_micro_batches = max(min_num_micro_batch, num_micro_batches)\n    if dist.is_initialized() and same_micro_num_in_dp:\n        num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())\n        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)\n        num_micro_batches = num_micro_batches.cpu().item()\n    if num_batches_divided_by is not None:\n        num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)\n\n    assert num_micro_batches <= len(seq_len_effective)\n\n    workloads = calculate_workload(seq_len_effective)\n    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)\n\n    if use_dynamic_bsz_balance:\n        # Use the sum of squared sequence lengths to approximate attention computation workload\n        micro_bsz_idx.sort(\n            key=lambda partition: (\n                sum(workloads[idx] for idx in partition),\n                partition[0] if partition else 0,\n            ),\n            reverse=True,\n        )\n        # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.\n        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]\n\n    micro_batches = []\n\n    for partition in micro_bsz_idx:\n        curr_micro_batch = tu.index_select_tensor_dict(batch, partition)\n        micro_batches.append(curr_micro_batch)\n\n    return micro_batches, micro_bsz_idx\n\n\ndef get_reverse_idx(idx_map):\n    \"\"\"\n    Build the inverse of an index mapping.\n\n    Args:\n        idx_map (Sequence[int]): Sequence where idx_map[i] = j.\n\n    Returns:\n        List[int]: Inverse mapping list such that output[j] = i for each i.\n    \"\"\"\n    reverse_idx_map = copy.deepcopy(idx_map)\n\n    for i, idx in enumerate(idx_map):\n        reverse_idx_map[idx] = i\n\n    return reverse_idx_map\n\n\ndef prepare_dynamic_batch(\n    data: DataProto,\n    max_token_len: int,\n    dp_group=None,\n    num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n) -> tuple[list[DataProto], list[list[int]]]:\n    \"\"\"\n    Prepare a batch for dynamic batching.\n\n    Args:\n        data (DataProto): The input data.\n        max_token_len (int): The maximum token length for dynamic batching.\n\n    Returns:\n        Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects\n        and a list of index lists.\n    \"\"\"\n    batch, batch_idx_list = rearrange_micro_batches(\n        data.batch,\n        max_token_len=max_token_len,\n        dp_group=dp_group,\n        num_batches_divided_by=num_batches_divided_by,\n        same_micro_num_in_dp=same_micro_num_in_dp,\n        min_num_micro_batch=min_num_micro_batch,\n        use_dynamic_bsz_balance=use_dynamic_bsz_balance,\n    )\n    micro_batches = []\n    for i, batch_idx in enumerate(batch_idx_list):\n        tensors = dict(batch[i])\n        non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()}\n        meta_info = copy.deepcopy(data.meta_info)\n        
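# meta_info is deep-copied so that mutating one micro-batch's metadata cannot leak into others.\n        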
micro_batches.append(DataProto.from_dict(tensors, non_tensors, meta_info=meta_info))\n\n    return micro_batches, batch_idx_list\n\n\ndef restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor:\n    \"\"\"\n    Restore a batch from dynamic batching.\n\n    Args:\n        data (torch.Tensor): The input data.\n        batch_idx_list (List[List[int]]): The list of index lists.\n\n    Returns:\n        torch.Tensor: The restored data.\n    \"\"\"\n    indices = list(chain.from_iterable(batch_idx_list))\n    batch_size = data.shape[0]\n    assert len(indices) == batch_size, f\"{len(indices)} vs. {batch_size}\"\n    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n\n    if data.is_nested:\n        tensors = [data[i] for i in revert_indices]\n        reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)\n    else:\n        reverted_data = data[revert_indices]\n\n    return reverted_data\n"
  },
  {
    "path": "verl_distillation/verl/utils/tensordict_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom typing import Iterator\n\nimport torch\nfrom tensordict import TensorDict\nfrom tensordict.tensorclass import NonTensorData, NonTensorStack\n\n\ndef assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):\n    for key, val in non_tensor_dict.items():\n        assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)\n    return tensor_dict\n\n\ndef assign_non_tensor_data(tensor_dict: TensorDict, key, val):\n    tensor_dict[key] = NonTensorData(val)\n\n\ndef assign_non_tensor(tensordict: TensorDict, **kwargs):\n    for key, val in kwargs.items():\n        assign_non_tensor_data(tensor_dict=tensordict, key=key, val=val)\n    return tensordict\n\n\ndef unwrap_non_tensor_data(data):\n    if isinstance(data, NonTensorData):\n        return data.data\n    return data\n\n\ndef get_non_tensor_data(data: TensorDict, key: str, default):\n    output = data.get(key, default)\n    return unwrap_non_tensor_data(output)\n\n\ndef get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:\n    \"\"\"\n\n    Args:\n        data_dict:\n        meta_info:\n\n    Returns:\n\n    \"\"\"\n    if non_tensor_dict is None:\n        non_tensor_dict = {}\n\n    batch_size = None\n\n    for key, val in tensor_dict.items():\n        if isinstance(val, list):\n            for v in val:\n                assert not isinstance(v, torch.Tensor), (\n                    \"Passing a list makes the data NonTensorStack, \"\n                    \"which doesn't support torch.Tensor. Please convert to numpy first\"\n                )\n        assert isinstance(val, torch.Tensor | list)\n\n        if batch_size is None:\n            batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)\n        else:\n            val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)\n            assert val_batch_size == batch_size, (\n                f\"Batch size of tensor {key} is not consistent with other tensors. 
\"\n                f\"Expected {batch_size}, got {val_batch_size}\"\n            )\n\n    if batch_size is None:\n        batch_size = []\n    else:\n        batch_size = [batch_size]\n\n    for key, val in non_tensor_dict.items():\n        assert key not in tensor_dict\n        tensor_dict[key] = NonTensorData(val)\n\n    return TensorDict(source=tensor_dict, batch_size=batch_size)\n\n\ndef index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict:\n    \"\"\"Index a tensor dict with a tensor of indices.\"\"\"\n    if isinstance(indices, list):\n        indices = torch.tensor(indices)\n\n    assert indices.dim() == 1, \"indices must be a 1D tensor\"\n\n    data_dict = {}\n    batch_size = indices.shape[0]\n\n    if batch is not None:\n        for key, tensor in batch.items():\n            if isinstance(tensor, torch.Tensor) and not tensor.is_nested:\n                data_dict[key] = tensor[indices]\n            elif isinstance(tensor, torch.Tensor) and tensor.is_nested:\n                data_dict[key] = torch.nested.as_nested_tensor([tensor[idx] for idx in indices], layout=torch.jagged)\n            else:\n                # This handles NonTensorStack (indexable by batch dim) and NonTensorData (scalar metadata).\n                if tensor.shape:\n                    data_dict[key] = tensor[indices]\n                else:\n                    data_dict[key] = tensor\n        selected_batch = TensorDict(source=data_dict, batch_size=batch_size)\n    else:\n        selected_batch = None\n\n    return selected_batch\n\n\ndef union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n    \"\"\"Union two tensordicts.\"\"\"\n    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (\n        f\"Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n    )\n    for key in tensor_dict2.keys():\n        if key not in tensor_dict1.keys():\n            tensor_dict1[key] = tensor_dict2[key]\n        else:\n            if isinstance(tensor_dict2[key], torch.Tensor):\n                assert tensor_dict1[key].equal(tensor_dict2[key]), (\n                    f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n                )\n            else:\n                # non-tensor\n                assert tensor_dict1[key] == tensor_dict2[key], (\n                    f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n                )\n\n    return tensor_dict1\n\n\ndef make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):\n    from torch.utils.data import DataLoader\n\n    assert tensordict.batch_size[0] % mini_batch_size == 0, f\"{tensordict.batch_size[0]} % {mini_batch_size} != 0\"\n    # we can directly create a dataloader from TensorDict\n    if dataloader_kwargs is None:\n        dataloader_kwargs = {}\n\n    if seed is not None:\n        generator = torch.Generator()\n        generator.manual_seed(seed)\n    else:\n        generator = None\n\n    assert isinstance(dataloader_kwargs, dict)\n    train_dataloader = DataLoader(\n        dataset=tensordict, batch_size=mini_batch_size, collate_fn=lambda x: x, generator=generator, **dataloader_kwargs\n    )\n\n    def get_data():\n        for _ in range(epochs):\n            yield from train_dataloader\n\n    return iter(get_data())\n\n\ndef assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):\n    assert set(tensordict1.keys()) == set(tensordict2.keys())\n\n    for key in tensordict1.keys():\n        val = tensordict1[key]\n        val2 = tensordict2[key]\n\n        assert type(val) is type(val2), f\"The type of {key} must be the same. Got {type(val)} vs {type(val2)}\"\n\n        if isinstance(val, torch.Tensor):\n            if val.is_nested:\n                assert val.is_nested and val2.is_nested, (\n                    f\"Both tensors must be nested tensors. {val.is_nested=}, {val2.is_nested=}\"\n                )\n                t1, t2 = val.unbind(), val2.unbind()\n                assert len(t1) == len(t2), f\"Nested tensor should have the same lengths. {len(t1)=} vs {len(t2)=}\"\n                for c1, c2 in zip(t1, t2, strict=True):\n                    assert torch.equal(c1, c2), f\"Nested tensor components have different values. 
{c1=} vs {c2=}\"\n            else:\n                assert torch.all(torch.eq(val, val2)).item()\n        else:\n            assert val == val2\n\n\ndef pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:\n    tensor_output = {}\n    non_tensor_output = {}\n    for key in keys:\n        output = tensordict.get(key)\n        if isinstance(output, torch.Tensor):\n            tensor_output[key] = tensordict.pop(key)\n        elif isinstance(output, NonTensorStack):\n            tensor_output[key] = tensordict.pop(key).tolist()\n        else:\n            assert isinstance(output, NonTensorData)\n            non_tensor_output[key] = tensordict.pop(key)\n\n    return get_tensordict(tensor_output, non_tensor_output)\n\n\ndef pad_to_divisor(data: TensorDict, size_divisor: int):\n    \"\"\"Pad a TensorDict so its length is divisible by size_divisor\n\n    Args:\n        data (TensorDict): the TensorDict to pad\n        size_divisor (int): size divisor\n\n    Returns:\n        data (TensorDict): the padded TensorDict\n        pad_size (int): the number of repeated rows appended\n    \"\"\"\n    assert isinstance(data, TensorDict), \"data must be a TensorDict\"\n    if len(data) % size_divisor != 0:\n        pad_size = size_divisor - len(data) % size_divisor\n        padding_protos = []\n        remaining_pad = pad_size\n        while remaining_pad > 0:\n            take_size = min(remaining_pad, len(data))\n            padding_protos.append(data[:take_size])\n            remaining_pad -= take_size\n        data_padded = torch.cat([data] + padding_protos)\n    else:\n        if len(data) == 0:\n            logging.warning(\"padding an empty TensorDict, no change made\")\n        pad_size = 0\n        data_padded = data\n    return data_padded, pad_size\n\n\ndef unpad(data: TensorDict, pad_size):\n    \"\"\"Unpad the TensorDict by dropping the last pad_size rows, i.e. `data[:-pad_size]`\"\"\"\n    if pad_size != 0:\n        data = data[:-pad_size]\n    return data\n"
  },
  {
    "path": "verl_distillation/verl/utils/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Utils for tokenization.\"\"\"\n\nimport warnings\n\n__all__ = [\"hf_tokenizer\", \"hf_processor\"]\n\n\ndef set_pad_token_id(tokenizer):\n    \"\"\"Set pad_token_id to eos_token_id if it is None.\n\n    Args:\n        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.\n\n    \"\"\"\n    if tokenizer.pad_token_id is None:\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n        warnings.warn(f\"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}\", stacklevel=1)\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n        warnings.warn(f\"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}\", stacklevel=1)\n\n\ndef hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):\n    \"\"\"Create a huggingface pretrained tokenizer which correctly handles eos and pad tokens.\n\n    Args:\n\n        name_or_path (str): The name or path of the tokenizer.\n        correct_pad_token (bool): Whether to correct the pad token id.\n        correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.\n\n    Returns:\n\n        transformers.PreTrainedTokenizer: The pretrained tokenizer.\n\n    \"\"\"\n    from transformers import AutoTokenizer\n\n    if correct_gemma2 and isinstance(name_or_path, str) and \"gemma-2-2b-it\" in name_or_path:\n        # the EOS token in gemma2 is ambiguous, which may worsen RL performance.\n        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a\n        warnings.warn(\n            \"Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.\", stacklevel=1\n        )\n        kwargs[\"eos_token\"] = \"<end_of_turn>\"\n        kwargs[\"eos_token_id\"] = 107\n    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)\n    if correct_pad_token:\n        set_pad_token_id(tokenizer)\n    return tokenizer\n\n\ndef hf_processor(name_or_path, **kwargs):\n    \"\"\"Create a huggingface processor to process multimodal data.\n\n    Args:\n        name_or_path (str): The name of the processor.\n\n    Returns:\n        transformers.ProcessorMixin: The pretrained processor.\n    \"\"\"\n    from transformers import AutoProcessor\n\n    try:\n        processor = AutoProcessor.from_pretrained(name_or_path, **kwargs)\n    except Exception as e:\n        processor = None\n        # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid\n        # silent failure\n        warnings.warn(f\"Failed to create processor: {e}. 
This may affect multimodal processing\", stacklevel=1)\n    # Avoid returning a tokenizer loaded in place of a processor, see:\n    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344\n    if processor is not None and \"Processor\" not in processor.__class__.__name__:\n        processor = None\n    return processor\n"
  },
  {
    "path": "verl_distillation/verl/utils/torch_dtypes.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAdapted from Cruise.\n\"\"\"\n\nimport torch\n\nHALF_LIST = [16, \"16\", \"fp16\", \"float16\", torch.float16]\nFLOAT_LIST = [32, \"32\", \"fp32\", \"float32\", torch.float32]\nBFLOAT_LIST = [\"bf16\", \"bfloat16\", torch.bfloat16]\n\n\nclass PrecisionType:\n    \"\"\"Type of precision used.\n\n    >>> PrecisionType.HALF == \"16\"\n    True\n    >>> PrecisionType.HALF in (16, \"16\")\n    True\n    \"\"\"\n\n    HALF = \"16\"\n    FLOAT = \"32\"\n    FULL = \"64\"\n    BFLOAT = \"bf16\"\n    MIXED = \"mixed\"\n\n    @staticmethod\n    def supported_type(precision: str | int) -> bool:\n        return any(x == precision for x in PrecisionType.supported_types())\n\n    @staticmethod\n    def supported_types() -> list[str]:\n        return [PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.FULL, PrecisionType.BFLOAT, PrecisionType.MIXED]\n\n    @staticmethod\n    def is_fp16(precision):\n        return precision in HALF_LIST\n\n    @staticmethod\n    def is_fp32(precision):\n        return precision in FLOAT_LIST\n\n    @staticmethod\n    def is_bf16(precision):\n        return precision in BFLOAT_LIST\n\n    @staticmethod\n    def to_dtype(precision):\n        if precision in HALF_LIST:\n            return torch.float16\n        elif precision in FLOAT_LIST:\n            return torch.float32\n        elif precision in BFLOAT_LIST:\n            return torch.bfloat16\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n\n    @staticmethod\n    def to_str(precision):\n        if precision == torch.float16:\n            return \"fp16\"\n        elif precision == torch.float32:\n            return \"fp32\"\n        elif precision == torch.bfloat16:\n            return \"bf16\"\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n"
  },
  {
    "path": "verl_distillation/verl/utils/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small torch utilities\n\"\"\"\n\nimport math\nfrom contextlib import contextmanager\nfrom typing import Optional\n\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils.device import get_device_name, get_torch_device\n\ntry:\n    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss\n\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True\nexcept ImportError:\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False\n\n\ntry:\n    import torch_npu\n\n    NPU_CROSS_ENTROPY_LOSS_AVAILABLE = hasattr(torch_npu, \"npu_cross_entropy_loss\")\nexcept ImportError:\n    NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False\n\n\ndef gather_from_labels(data, label):\n    \"\"\"Gather the label from data. The value in label should be [0, vocab_size)\n\n    Args:\n        data: (..., vocab_size)\n        label (torch.IntTensor) : (...,)\n\n    Returns:\n\n    \"\"\"\n\n    output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)\n    return output\n\n\ndef logprobs_from_logits(logits, labels, inplace_backward=True):\n    \"\"\"\n    Compute per-token log-probabilities for the given labels.\n\n    Uses a Flash-Attention–based cross-entropy (if available) for efficient backward,\n    otherwise falls back to a standard log-softmax+gather approach.\n\n    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591\n\n    Args:\n        logits (Tensor): Model outputs of shape (..., vocab_size).\n        labels (LongTensor): True class indices of shape matching logits[..., :-1].\n        inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place.\n\n    Returns:\n        Tensor: Log-probabilities of the target labels, shape logits.shape[:-1].\n    \"\"\"\n    if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:\n        batch_dim = logits.shape[:-1]\n        last_dim = logits.shape[-1]\n        logits = logits.reshape(-1, last_dim)\n        labels = labels.reshape(-1)\n        output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward)\n        output = output.view(*batch_dim)\n    elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE:\n        output = logprobs_from_logits_torch_npu(logits, labels)\n    else:\n        output = logprobs_from_logits_v2(logits, labels)\n    return output\n\n\ndef logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):\n    output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)\n    assert isinstance(output, tuple), (\n        \"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses].\"\n    )\n    return -output[0]\n\n\ndef logprobs_from_logits_torch_npu(logits, labels):\n    batch_dim = 
logits.shape[:-1]\n    logits = logits.reshape(-1, logits.shape[-1])\n    loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction=\"none\")\n    return -loss.view(*batch_dim)\n\n\ndef logprobs_from_logits_naive(logits, labels):\n    logp = F.log_softmax(logits, dim=-1)\n    logpy = gather_from_labels(logp, labels)\n    return logpy\n\n\ndef logprobs_from_logits_v2(logits: torch.FloatTensor, labels):\n    \"\"\"\n    A memory efficient implementation of logprobs_from_logits\n    \"\"\"\n    if logits.dtype in [torch.float32, torch.float64]:\n        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)\n        # loop to reduce peak mem consumption\n        logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])\n        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n    else:\n        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n        logprobs_labels = []\n        for row_logits, row_labels in zip(logits, labels, strict=True):  # loop to reduce peak mem consumption\n            row_logprobs = F.log_softmax(row_logits, dim=-1)\n            row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n            logprobs_labels.append(row_logprobs_labels)\n        logprobs_labels = torch.stack(logprobs_labels)\n    return logprobs_labels\n\n\ndef clip_by_value(x, tensor_min, tensor_max):\n    \"\"\"\n    Tensor extension to torch.clamp\n    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713\n    \"\"\"\n    clipped = torch.max(torch.min(x, tensor_max), tensor_min)\n    return clipped\n\n\ndef entropy_from_logits(logits: torch.Tensor):\n    \"\"\"Calculate entropy from logits.\"\"\"\n    pd = torch.nn.functional.softmax(logits, dim=-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)\n    return entropy\n\n\ndef entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048):\n    \"\"\"Memory-efficient entropy calculation with chunking.\"\"\"\n    entropy = torch.zeros(logits.shape[0], device=logits.device)\n    for i in range(0, logits.shape[0], chunk_size):\n        logits_chunk = logits[i : i + chunk_size].float()\n        pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1)\n        entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1)\n        entropy[i : i + chunk_size] = entropy_chunk\n    return entropy\n\n\ndef masked_sum(values, mask, axis=None):\n    \"\"\"Compute the sum of `values` over elements selected by `mask`.\"\"\"\n    # If NaNs exist out of mask, replace NaNs in values with a value that\n    # won't affect the sum (e.g., 0 for masked regions)\n    valid_values = torch.where(mask.bool(), values, 0.0)\n    return (valid_values * mask).sum(axis=axis)\n\n\ndef masked_mean(values, mask, axis=None):\n    \"\"\"\n    Compute the mean of `values` over elements selected by `mask`.\n\n    Args:\n        values (Tensor): Input tensor.\n        mask (Tensor): Boolean or numeric mask of the same shape as `values`.\n        axis (int or tuple of int, optional): Dimension(s) along which to compute the mean.\n            Defaults to None (over all elements).\n\n    Returns:\n        Tensor: Masked mean, with shape equal to `values` reduced over `axis`.\n    \"\"\"\n    s = masked_sum(values, mask, axis)\n    return s / (mask.sum(axis=axis) + 1e-8)\n\n\n
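# Variance of the masked elements; with unbiased=True, Bessel's correction is applied (needs at least two masked elements).\ndef 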
masked_var(values, mask, unbiased=True):\n    \"\"\"Compute variance of tensor with masked values.\"\"\"\n    mean = masked_mean(values, mask)\n    centered_values = values - mean\n    variance = masked_mean(centered_values**2, mask)\n    if unbiased:\n        mask_sum = mask.sum()\n        if mask_sum == 0:\n            raise ValueError(\"At least one element in the mask has to be 1.\")\n        # note that if mask_sum == 1, then there is a division by zero issue\n        # to avoid it you just need to use a larger minibatch_size\n        if mask_sum == 1:\n            raise ValueError(\"The sum of the mask is one, which can cause a division by zero.\")\n        bessel_correction = mask_sum / (mask_sum - 1)\n        variance = variance * bessel_correction\n    return variance\n\n\ndef masked_whiten(values, mask, shift_mean=True):\n    \"\"\"\n    Whiten `values` by normalizing with mean and variance computed over `mask`.\n\n    Args:\n        values (torch.Tensor): Input tensor.\n        mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats.\n        shift_mean (bool): If True (default), output is zero-mean;\n                           if False, the original mean is re-added after scaling.\n\n    Returns:\n        torch.Tensor: Whitened tensor of same shape as `values`.\n    \"\"\"\n    mean, var = masked_mean(values, mask), masked_var(values, mask)\n    whitened = (values - mean) * torch.rsqrt(var + 1e-8)\n    if not shift_mean:\n        whitened += mean\n    return whitened\n\n\ndef get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):\n    \"\"\"\n    end of sentence token can be int or list: 1 or [1, 2]\n    e.g.\n    response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],\n                                [78, 0, 76, 2, 1, 0, 0],\n                                [23, 98, 1, 0, 0, 0, 0],\n                                [33, 3, 98, 45, 1, 0, 0]])\n    #eos_token=1\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    #eos_token=[1,2]\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    \"\"\"\n    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()\n    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)\n\n\ndef compute_grad_norm(model: nn.Module):\n    total_grad_square = 0\n    for param in model.parameters():\n        if param.grad is not None:\n            total_grad_square += torch.sum(torch.square(param.grad.detach())).item()\n    return total_grad_square\n\n\ndef broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group):\n    \"\"\"\n    TODO: optimize this. 
Technically, we only need one broadcast\n    \"\"\"\n\n    for key in tensors.sorted_keys:\n        torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)\n\n\ndef allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0):\n    \"\"\"\n    TODO: optimize this.\n    - We can use async ops\n    - We can use only one allgather\n    Args:\n        tensors:\n        size:\n        group:\n\n    Returns:\n\n    \"\"\"\n    if isinstance(tensors, TensorDict):\n        is_tensor_dict = True\n        tensors_as_dict = tensors.to_dict()\n    else:\n        tensors_as_dict = tensors\n        is_tensor_dict = False\n\n    output = {}\n    sorted_keys = sorted(tensors_as_dict.keys())\n    for key in sorted_keys:\n        val = tensors_as_dict[key]\n        output[key] = [torch.empty_like(val) for _ in range(size)]\n        torch.distributed.all_gather(output[key], val, group=group, async_op=False)\n        output[key] = torch.cat(output[key], dim=dim)\n\n    if is_tensor_dict:\n        output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)\n\n    return output\n\n\ndef split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]:\n    assert tensors.batch_size[0] % batch_size == 0, (\n        f\"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}\"\n    )\n    return tensors.split(batch_size)\n\n\ndef pad_2d_list_to_length(response, pad_token_id, max_length=None):\n    \"\"\"\n    pad a 2D list (e.g. responses, logprobs) to a 2D tensor.\n    \"\"\"\n    response_length = max(len(sub_list) for sub_list in response)\n    target_length = max_length if max_length is not None and max_length > response_length else response_length\n    padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]\n    tensor = torch.tensor(padded_response)\n    return tensor\n\n\ndef pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):\n    \"\"\"\n    pad a 2D tensors (e.g. 
responses, logprobs) in the last dim to max_seq_length.\n    input shape: [bs, seq_length]\n    output shape: [bs, max_seq_length]\n    \"\"\"\n    if tensors.shape[-1] >= max_seq_len:\n        return tensors\n    # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad\n    pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])\n    return F.pad(tensors, pad_tuple, \"constant\", pad_token_id)\n\n\ndef postprocess_data(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    max_length: int,\n    pad_token_id: int,\n    left_pad=True,\n    truncation=\"error\",\n):\n    \"\"\"Process tokenizer outputs to consistent shapes via padding/truncation.\n\n    Args:\n        input_ids: Token indices [batch_size, seq_len]\n        attention_mask: Mask [batch_size, seq_len]\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: \"left\", \"right\", \"middle\" or \"error\"\n\n    Returns:\n        (input_ids, attention_mask) padded/truncated to max_length\n    \"\"\"\n    assert truncation in [\"left\", \"right\", \"middle\", \"error\"]\n    assert input_ids.ndim == 2\n\n    sequence_length = input_ids.shape[-1]\n    if sequence_length < max_length:\n        input_ids = pad_sequence_to_length(\n            input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad\n        )\n        attention_mask = pad_sequence_to_length(\n            attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad\n        )\n    elif sequence_length > max_length:\n        if truncation == \"left\":\n            # actually, left truncation may not be reasonable\n            input_ids = input_ids[:, -max_length:]\n            attention_mask = attention_mask[:, -max_length:]\n        elif truncation == \"right\":\n            input_ids = input_ids[:, :max_length]\n            attention_mask = attention_mask[:, :max_length]\n        elif truncation == \"middle\":\n            left_half = max_length // 2\n            right_half = max_length - left_half\n            input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1)\n            attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1)\n        elif truncation == \"error\":\n            raise NotImplementedError(f\"{sequence_length=} is larger than {max_length=}\")\n        else:\n            raise NotImplementedError(f\"Unknown truncation method {truncation}\")\n\n    return input_ids, attention_mask\n\n\ndef tokenize_and_postprocess_data(\n    prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation=\"error\"\n):\n    \"\"\"Tokenize text and process outputs to consistent tensor shapes.\n\n    Args:\n        prompt: Input text to tokenize\n        tokenizer: HuggingFace tokenizer instance\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: Truncation strategy (\"left\"/\"right\"/\"error\")\n\n    Returns:\n        Tuple of (input_ids, attention_mask) from postprocess_data\n    \"\"\"\n    input_data = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False)\n    input_ids = input_data[\"input_ids\"]\n    attention_mask = input_data[\"attention_mask\"]\n\n    return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, 
left_pad, truncation)\n\n\ndef remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):\n    \"\"\"Remove the pad token.\n\n    Args:\n        input_ids shape: [bs, seq_length]\n        attention_mask shape: [bs, seq_length]\n    Returns:\n        no_padding_batch(List[List[int]]): contains the rmpad token ids per query.\n    \"\"\"\n    no_padding_batch = []\n    for ids, mask in zip(input_ids, attention_mask, strict=True):\n        no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())\n    return no_padding_batch\n\n\ndef log_probs_from_logits_response(input_ids, logits, response_length):\n    \"\"\"Compute the response log_probs from full logits. Note that logits = model(input_ids)\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        logits: [batch_size, seqlen, vocab_size]\n\n    Returns:\n        response_log_prob:\n    \"\"\"\n    response_logits = logits[:, -response_length - 1 : -1]\n    response = input_ids[:, -response_length:]\n    response_log_prob = logprobs_from_logits(logits=response_logits, labels=response)\n    return response_log_prob\n\n\ndef log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad logits and pad input. Note that\n    logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between\n    logits and input_ids.\n    This function computes logprobs_from_logits in rmpad mode, because doing so on padded logits is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size]\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad input_ids and logits. Note that\n    logits_rmpad = model(input_ids_rmpad). 
For each sentence, there is a shift between\n    logits and input_ids.\n    This function computes logprobs_from_logits in rmpad mode, because doing so on padded logits is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids_rmpad: [1, total_nnz]\n        logits_rmpad: [total_nnz, vocab_size]\n        indices: [total_nnz]\n        batch_size: int\n        seqlen: int\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input\n\n    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # transpose back to [total_nnz, 1]\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef post_process_logits(input_ids, logits, temperature, top_k, top_p):\n    if temperature != 1.0:\n        logits = logits.div_(temperature)  # inplace operation to avoid OOM\n    # TODO: add them back\n    # if top_k is not None and top_k > 0:\n    #     logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits)\n    # if top_p is not None and top_p < 1.0 and top_p > 0.0:\n    #     logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits)\n    return logits\n\n\n\"\"\"\nOptimizer related\n\"\"\"\n\n\ndef get_cosine_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n    init_lr_ratio: float = None,\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function from the\n    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n    initial lr set in the optimizer.\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum lr ratio w.r.t the maximum.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0\n            following a half-cosine).\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n        init_lr_ratio (:obj:`float`, `optional`, defaults to None):\n            The initial lr ratio w.r.t the maximum.\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    min_lr_ratio = 0.0 if min_lr_ratio is None else min_lr_ratio\n    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0\n    coef = (1 - min_lr_ratio) * 0.5\n    intercept = (1 + min_lr_ratio) * 0.5\n\n    init_lr_ratio = 0.0 if init_lr_ratio is None else init_lr_ratio\n    assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n  
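          # linear warmup: scale from init_lr_ratio up to 1.0 over num_warmup_steps\n  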
          return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))\n        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)\n        return max(min_lr_ratio, x * coef + intercept)\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_constant_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a constant LR schedule with a linear warmup phase.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value.\n        last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1.\n\n    Returns:\n        LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant.\n    \"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1.0, num_warmup_steps))\n        return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):\n    # create causal mask\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask(\n            input_shape,\n            inputs_embeds.dtype,\n            device=inputs_embeds.device,\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n            inputs_embeds.device\n        )\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\ndef get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    
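# prepend a leading zero so cu_seqlens[i] gives the start offset of sequence i in the packed layout used by varlen attention kernels\n    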
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef get_wsd_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n    stable_ratio: float = 0.9,\n):\n    \"\"\"\n    Create a Warmup-Stable-Decay learning rate scheduler.\n\n    The schedule follows three phases:\n    1. Warmup: Learning rate increases linearly from 0 to the initial LR\n    2. Stable: Learning rate remains constant at the initial LR\n    3. Decay: Learning rate decreases following a cosine curve to min_lr_ratio * initial LR\n\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum learning rate ratio w.r.t the initial learning rate.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule during decay phase.\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n        stable_ratio (:obj:`float`, `optional`, defaults to 0.9):\n            The ratio of non-warmup steps that should maintain a constant learning rate.\n            Set to 0.0 to behave exactly like cosine schedule.\n\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    remaining_steps = max(0, num_training_steps - num_warmup_steps)\n    num_stable_steps = int(remaining_steps * stable_ratio)\n    num_decay_steps = remaining_steps - num_stable_steps\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        if current_step < num_warmup_steps + num_stable_steps:\n            return 1.0\n        if current_step < num_training_steps:\n            progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))\n            value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n            return (1.0 - min_lr_ratio) * value + min_lr_ratio\n        return min_lr_ratio\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\n@contextmanager\ndef check_device_is_available():\n    \"\"\"\n    Some modules must be imported after CUDA is initialized. 
Such as sglang's sharding manager.\n\n    This context manager checks if CUDA is available and raises an error if it is not.\n    \"\"\"\n    if not get_torch_device().is_available():\n        raise RuntimeError(\"Device {} must be initialized before importing this module.\".format(get_device_name()))\n\n    yield\n\n\ndef distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True):\n    \"\"\"Compute distributed statistics across all processes.\n\n    Args:\n        local_tensor: Tensor containing local values\n        compute_max: Include maximum value calculation\n        compute_min: Include minimum value calculation\n        compute_std: Include standard deviation calculation\n\n    Returns:\n        Tuple containing (mean, max, min, std) in this order. None for disabled metrics.\n    \"\"\"\n    # Sum the local tensor across all processes\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name())\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n\n    if compute_max:\n        local_max = torch.max(local_tensor)\n        torch.distributed.all_reduce(local_max, op=torch.distributed.ReduceOp.MAX)\n    else:\n        local_max = None\n\n    if compute_min:\n        local_min = torch.min(local_tensor)\n        torch.distributed.all_reduce(local_min, op=torch.distributed.ReduceOp.MIN)\n    else:\n        local_min = None\n\n    if compute_std:\n        square_diff = torch.sum(torch.pow(local_tensor - global_mean, 2))\n        torch.distributed.all_reduce(square_diff, op=torch.distributed.ReduceOp.SUM)\n        global_std = torch.sqrt(square_diff / (local_num - 1))\n    else:\n        global_std = None\n\n    return global_mean, local_max, local_min, global_std\n\n\ndef distributed_masked_mean(local_tensor, local_mask):\n    \"\"\"Compute global mean of non-masked elements across distributed processes.\n\n    Args:\n        local_tensor (torch.Tensor): Input tensor with local values\n        local_mask (torch.Tensor): Binary mask (1=valid, 0=ignore) matching local_tensor shape\n\n    Returns:\n        torch.Tensor: Global mean of all valid elements across processes\n    \"\"\"\n    local_tensor = local_tensor * local_mask\n\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.sum(local_mask)\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n    return global_mean\n"
  },
  {
    "path": "verl_distillation/verl/utils/tracking.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA unified tracking interface that supports logging data to different backend\n\"\"\"\n\nimport dataclasses\nimport json\nimport os\nfrom enum import Enum\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any\n\n\nclass Tracking:\n    \"\"\"A unified tracking interface for logging experiment data to multiple backends.\n\n    This class provides a centralized way to log experiment metrics, parameters, and artifacts\n    to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console.\n\n    Attributes:\n        supported_backend: List of supported tracking backends.\n        logger: Dictionary of initialized logger instances for each backend.\n    \"\"\"\n\n    supported_backend = [\n        \"wandb\",\n        \"mlflow\",\n        \"swanlab\",\n        \"vemlp_wandb\",\n        \"tensorboard\",\n        \"console\",\n        \"clearml\",\n        \"trackio\",\n        \"file\",\n    ]\n\n    def __init__(self, project_name, experiment_name, default_backend: str | list[str] = \"console\", config=None):\n        if isinstance(default_backend, str):\n            default_backend = [default_backend]\n        for backend in default_backend:\n            if backend == \"tracking\":\n                import warnings\n\n                warnings.warn(\"`tracking` logger is deprecated. 
use `wandb` instead.\", DeprecationWarning, stacklevel=2)\n            else:\n                assert backend in self.supported_backend, f\"{backend} is not supported\"\n\n        self.logger = {}\n\n        if \"tracking\" in default_backend or \"wandb\" in default_backend:\n            import os\n\n            import wandb\n\n            settings = None\n            if config and config[\"trainer\"].get(\"wandb_proxy\", None):\n                settings = wandb.Settings(https_proxy=config[\"trainer\"][\"wandb_proxy\"])\n            entity = os.environ.get(\"WANDB_ENTITY\", None)\n            wandb.init(project=project_name, name=experiment_name, entity=entity, config=config, settings=settings)\n            self.logger[\"wandb\"] = wandb\n\n        if \"trackio\" in default_backend:\n            import trackio\n\n            trackio.init(project=project_name, name=experiment_name, config=config)\n            self.logger[\"trackio\"] = trackio\n\n        if \"mlflow\" in default_backend:\n            import os\n\n            import mlflow\n\n            MLFLOW_TRACKING_URI = os.environ.get(\"MLFLOW_TRACKING_URI\", \"sqlite:////tmp/mlruns.db\")\n            mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n\n            # Project_name is actually experiment_name in MLFlow\n            # If experiment does not exist, will create a new experiment\n            experiment = mlflow.set_experiment(project_name)\n            mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)\n            mlflow.log_params(_compute_mlflow_params_from_objects(config))\n            self.logger[\"mlflow\"] = _MlflowLoggingAdapter()\n\n        if \"swanlab\" in default_backend:\n            import os\n\n            import swanlab\n\n            SWANLAB_API_KEY = os.environ.get(\"SWANLAB_API_KEY\", None)\n            SWANLAB_LOG_DIR = os.environ.get(\"SWANLAB_LOG_DIR\", \"swanlog\")\n            SWANLAB_MODE = os.environ.get(\"SWANLAB_MODE\", \"cloud\")\n            if SWANLAB_API_KEY:\n                swanlab.login(SWANLAB_API_KEY)  # NOTE: previous login information will be overwritten\n\n            if config is None:\n                config = {}  # make sure config is not None, otherwise **config will raise error\n            swanlab.init(\n                project=project_name,\n                experiment_name=experiment_name,\n                config={\"FRAMEWORK\": \"verl\", **config},\n                logdir=SWANLAB_LOG_DIR,\n                mode=SWANLAB_MODE,\n            )\n            self.logger[\"swanlab\"] = swanlab\n\n        if \"vemlp_wandb\" in default_backend:\n            import os\n\n            import volcengine_ml_platform\n            from volcengine_ml_platform import wandb as vemlp_wandb\n\n            volcengine_ml_platform.init(\n                ak=os.environ[\"VOLC_ACCESS_KEY_ID\"],\n                sk=os.environ[\"VOLC_SECRET_ACCESS_KEY\"],\n                region=os.environ[\"MLP_TRACKING_REGION\"],\n            )\n\n            vemlp_wandb.init(\n                project=project_name,\n                name=experiment_name,\n                config=config,\n                sync_tensorboard=True,\n            )\n            self.logger[\"vemlp_wandb\"] = vemlp_wandb\n\n        if \"tensorboard\" in default_backend:\n            self.logger[\"tensorboard\"] = _TensorboardAdapter(project_name, experiment_name)\n\n        if \"console\" in default_backend:\n            from verl.utils.logger import LocalLogger\n\n            self.console_logger = 
LocalLogger(print_to_console=True)\n            self.logger[\"console\"] = self.console_logger\n\n        if \"clearml\" in default_backend:\n            self.logger[\"clearml\"] = ClearMLLogger(project_name, experiment_name, config)\n\n        if \"file\" in default_backend:\n            self.logger[\"file\"] = FileLogger(project_name, experiment_name)\n\n    def log(self, data, step, backend=None):\n        for default_backend, logger_instance in self.logger.items():\n            if backend is None or default_backend in backend:\n                logger_instance.log(data=data, step=step)\n\n    def __del__(self):\n        if \"wandb\" in self.logger:\n            self.logger[\"wandb\"].finish(exit_code=0)\n        if \"swanlab\" in self.logger:\n            self.logger[\"swanlab\"].finish()\n        if \"vemlp_wandb\" in self.logger:\n            self.logger[\"vemlp_wandb\"].finish(exit_code=0)\n        if \"tensorboard\" in self.logger:\n            self.logger[\"tensorboard\"].finish()\n        if \"clearml\" in self.logger:\n            self.logger[\"clearml\"].finish()\n        if \"trackio\" in self.logger:\n            self.logger[\"trackio\"].finish()\n        if \"file\" in self.logger:\n            self.logger[\"file\"].finish()\n\n\nclass ClearMLLogger:\n    def __init__(self, project_name: str, experiment_name: str, config):\n        self.project_name = project_name\n        self.experiment_name = experiment_name\n\n        import clearml\n\n        self._task: clearml.Task = clearml.Task.init(\n            task_name=experiment_name,\n            project_name=project_name,\n            continue_last_task=True,\n            output_uri=False,\n        )\n\n        self._task.connect_configuration(config, name=\"Hyperparameters\")\n\n    def _get_logger(self):\n        return self._task.get_logger()\n\n    def log(self, data, step):\n        import numpy as np\n        import pandas as pd\n\n        # logs = self._rewrite_logs(data)\n        logger = self._get_logger()\n        for k, v in data.items():\n            title, series = k.split(\"/\", 1)\n\n            if isinstance(v, int | float | np.floating | np.integer):\n                logger.report_scalar(\n                    title=title,\n                    series=series,\n                    value=v,\n                    iteration=step,\n                )\n            elif isinstance(v, pd.DataFrame):\n                logger.report_table(\n                    title=title,\n                    series=series,\n                    table_plot=v,\n                    iteration=step,\n                )\n            else:\n                logger.warning(\n                    f'Trainer is attempting to log a value of \"{v}\" of type {type(v)} for key \"{k}\". This '\n                    f\"invocation of ClearML logger's function is incorrect so this attribute was dropped. 
\"\n                )\n\n    def finish(self):\n        self._task.close()\n\n\nclass FileLogger:\n    def __init__(self, project_name: str, experiment_name: str):\n        self.project_name = project_name\n        self.experiment_name = experiment_name\n\n        self.filepath = os.getenv(\"VERL_FILE_LOGGER_PATH\", None)\n        if self.filepath is None:\n            root_path = os.path.expanduser(os.getenv(\"VERL_FILE_LOGGER_ROOT\", \".\"))\n            directory = os.path.join(root_path, self.project_name)\n            os.makedirs(directory, exist_ok=True)\n            self.filepath = os.path.join(directory, f\"{self.experiment_name}.jsonl\")\n            print(f\"Creating file logger at {self.filepath}\")\n        self.fp = open(self.filepath, \"w\")\n\n    def log(self, data, step):\n        data = {\"step\": step, \"data\": data}\n        self.fp.write(json.dumps(data) + \"\\n\")\n\n    def finish(self):\n        self.fp.close()\n\n\nclass _TensorboardAdapter:\n    def __init__(self, project_name, experiment_name):\n        import os\n\n        from torch.utils.tensorboard import SummaryWriter\n\n        tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", f\"tensorboard_log/{project_name}/{experiment_name}\")\n        os.makedirs(tensorboard_dir, exist_ok=True)\n        print(f\"Saving tensorboard log to {tensorboard_dir}.\")\n        self.writer = SummaryWriter(tensorboard_dir)\n\n    def log(self, data, step):\n        for key in data:\n            self.writer.add_scalar(key, data[key], step)\n\n    def finish(self):\n        self.writer.close()\n\n\nclass _MlflowLoggingAdapter:\n    def __init__(self):\n        import logging\n        import re\n\n        self.logger = logging.getLogger(__name__)\n        # MLflow metric key validation logic:\n        # https://github.com/mlflow/mlflow/blob/master/mlflow/utils/validation.py#L157C12-L157C44\n        # Only characters allowed: slashes, alphanumerics, underscores, periods, dashes, colons,\n        # and spaces.\n        self._invalid_chars_pattern = re.compile(\n            r\"[^/\\w.\\- :]\"\n        )  # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces.\n\n    def log(self, data, step):\n        import mlflow\n\n        def sanitize_key(key):\n            # First replace @ with _at_ for backward compatibility\n            sanitized = key.replace(\"@\", \"_at_\")\n            # Then replace any other invalid characters with _\n            sanitized = self._invalid_chars_pattern.sub(\"_\", sanitized)\n            if sanitized != key:\n                self.logger.warning(\n                    \"[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.\", key, sanitized\n                )\n            return sanitized\n\n        results = {sanitize_key(k): v for k, v in data.items()}\n        mlflow.log_metrics(metrics=results, step=step)\n\n\ndef _compute_mlflow_params_from_objects(params) -> dict[str, Any]:\n    if params is None:\n        return {}\n\n    return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep=\"/\")\n\n\ndef _transform_params_to_json_serializable(x, convert_list_to_dict: bool):\n    _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)\n\n    if dataclasses.is_dataclass(x):\n        return _transform(dataclasses.asdict(x))\n    if isinstance(x, dict):\n        return {k: _transform(v) for k, v in x.items()}\n    if isinstance(x, list):\n        if convert_list_to_dict:\n      
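      # represent lists as index-keyed dicts (plus list_len) so they survive flattening into scalar MLflow params\n      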
      return {\"list_len\": len(x)} | {f\"{i}\": _transform(v) for i, v in enumerate(x)}\n        else:\n            return [_transform(v) for v in x]\n    if isinstance(x, Path):\n        return str(x)\n    if isinstance(x, Enum):\n        return x.value\n\n    return x\n\n\ndef _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]:\n    import pandas as pd\n\n    ans = pd.json_normalize(raw, sep=sep).to_dict(orient=\"records\")[0]\n    assert isinstance(ans, dict)\n    return ans\n\n\n@dataclasses.dataclass\nclass ValidationGenerationsLogger:\n    project_name: str = None\n    experiment_name: str = None\n\n    def log(self, loggers, samples, step):\n        if \"wandb\" in loggers:\n            self.log_generations_to_wandb(samples, step)\n        if \"swanlab\" in loggers:\n            self.log_generations_to_swanlab(samples, step)\n        if \"mlflow\" in loggers:\n            self.log_generations_to_mlflow(samples, step)\n\n        if \"clearml\" in loggers:\n            self.log_generations_to_clearml(samples, step)\n        if \"tensorboard\" in loggers:\n            self.log_generations_to_tensorboard(samples, step)\n\n        if \"vemlp_wandb\" in loggers:\n            self.log_generations_to_vemlp_wandb(samples, step)\n\n    def log_generations_to_vemlp_wandb(self, samples, step):\n        from volcengine_ml_platform import wandb as vemlp_wandb\n\n        self._log_generations_to_wandb(samples, step, vemlp_wandb)\n\n    def log_generations_to_wandb(self, samples, step):\n        import wandb\n\n        self._log_generations_to_wandb(samples, step, wandb)\n\n    def _log_generations_to_wandb(self, samples, step, wandb):\n        \"\"\"Log samples to wandb as a table\"\"\"\n\n        # Create column names for all samples\n        columns = [\"step\"] + sum(\n            [[f\"input_{i + 1}\", f\"output_{i + 1}\", f\"score_{i + 1}\"] for i in range(len(samples))], []\n        )\n\n        if not hasattr(self, \"validation_table\"):\n            # Initialize the table on first call\n            self.validation_table = wandb.Table(columns=columns)\n\n        # Create a new table with same columns and existing data\n        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737\n        new_table = wandb.Table(columns=columns, data=self.validation_table.data)\n\n        # Add new row with all data\n        row_data = []\n        row_data.append(step)\n        for sample in samples:\n            row_data.extend(sample)\n\n        new_table.add_data(*row_data)\n\n        # Update reference and log\n        wandb.log({\"val/generations\": new_table}, step=step)\n        self.validation_table = new_table\n\n    def log_generations_to_swanlab(self, samples, step):\n        \"\"\"Log samples to swanlab as text\"\"\"\n        import swanlab\n\n        swanlab_table = swanlab.echarts.Table()\n\n        # Create column names\n        headers = [\"step\", \"input\", \"output\", \"score\"]\n\n        swanlab_row_list = [[step, *sample] for sample in samples]\n        swanlab_table.add(headers=headers, rows=swanlab_row_list)\n\n        # Log to swanlab\n        swanlab.log({\"val/generations\": swanlab_table}, step=step)\n\n    def log_generations_to_mlflow(self, samples, step):\n        \"\"\"Log validation generation to mlflow as artifacts\"\"\"\n        # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact\n\n        import json\n        import tempfile\n\n        import mlflow\n\n        try:\n 
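           # dump the samples to a temporary JSON file and attach it to the current run as an MLflow artifact\n 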
           with tempfile.TemporaryDirectory() as tmp_dir:\n                validation_gen_step_file = Path(tmp_dir, f\"val_step{step}.json\")\n                row_data = []\n                for sample in samples:\n                    data = {\"input\": sample[0], \"output\": sample[1], \"score\": sample[2]}\n                    row_data.append(data)\n                with open(validation_gen_step_file, \"w\") as file:\n                    json.dump(row_data, file)\n                mlflow.log_artifact(validation_gen_step_file)\n        except Exception as e:\n            print(f\"WARNING: saving validation generations to mlflow failed with error {e}\")\n\n    def log_generations_to_clearml(self, samples, step):\n        \"\"\"Log validation generations to clearml as a table\"\"\"\n\n        import clearml\n        import pandas as pd\n\n        task: clearml.Task | None = clearml.Task.current_task()\n        if task is None:\n            return\n\n        table = [\n            {\n                \"step\": step,\n                \"input\": sample[0],\n                \"output\": sample[1],\n                \"score\": sample[2],\n            }\n            for sample in samples\n        ]\n\n        logger = task.get_logger()\n        logger.report_table(\n            series=\"Validation generations\",\n            title=\"Validation\",\n            table_plot=pd.DataFrame.from_records(table),\n            iteration=step,\n        )\n\n    def log_generations_to_tensorboard(self, samples, step):\n        \"\"\"Log samples to tensorboard as text\"\"\"\n        # Initialize the tensorboard writer on first use\n        if not hasattr(self, \"writer\"):\n            from torch.utils.tensorboard import SummaryWriter\n\n            # Use the same directory structure as _TensorboardAdapter\n            if self.project_name and self.experiment_name:\n                default_dir = os.path.join(\"tensorboard_log\", self.project_name, self.experiment_name)\n            else:\n                default_dir = \"tensorboard_log\"\n\n            tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", default_dir)\n            os.makedirs(tensorboard_dir, exist_ok=True)\n            self.writer = SummaryWriter(log_dir=tensorboard_dir)\n\n        # Format the samples data into readable text\n        text_content = f\"**Generation Results - Step {step}**\\n\\n\"\n\n        for i, sample in enumerate(samples):\n            text_content += f\"### Sample {i + 1}\\n\"\n\n            # Assuming sample contains [input, output, score]\n            if len(sample) >= 3:\n                input_text, output_text, score = sample[0], sample[1], sample[2]\n\n                text_content += f\"**Input:** {input_text}\\n\\n\"\n                text_content += f\"**Output:** {output_text}\\n\\n\"\n                text_content += f\"**Score:** {score}\\n\\n\"\n            else:\n                # Handle cases where the sample format might be different\n                text_content += f\"**Data:** {sample}\\n\\n\"\n\n            text_content += \"---\\n\\n\"\n\n        # Log to tensorboard as text\n        self.writer.add_text(\"val/generations\", text_content, step)\n        # Flush to ensure data is written\n        self.writer.flush()\n
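\n\n# Illustrative usage sketch (editorial addition, not exercised elsewhere in this\n# file; the sample triples below are hypothetical):\n#\n#     val_logger = ValidationGenerationsLogger(project_name=\"demo\", experiment_name=\"run0\")\n#     samples = [(\"What is 2+2?\", \"4\", 1.0), (\"Capital of France?\", \"Paris\", 1.0)]\n#     val_logger.log(loggers=[\"tensorboard\"], samples=samples, step=100)\n#\n# Every backend consumes the same per-sample [input, output, score] layout.\n"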
  },
  {
    "path": "verl_distillation/verl/utils/transferqueue_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport inspect\nimport os\nimport threading\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom tensordict import TensorDict\n\ntry:\n    from transfer_queue import (\n        AsyncTransferQueueClient,\n        BatchMeta,\n        ZMQServerInfo,\n    )\n\nexcept ImportError:\n    # TODO: Use a hacky workaround for ImportError since\n    # transfer_queue isn't a default verl dependency.\n    class BatchMeta:\n        pass\n\n\nfrom verl.protocol import DataProto\n\n_TRANSFER_QUEUE_CLIENT = None\n_VAL_TRANSFER_QUEUE_CLIENT = None\n\nis_transferqueue_enabled = os.environ.get(\"TRANSFER_QUEUE_ENABLE\", False)\n\n\ndef create_transferqueue_client(\n    client_id: str,\n    controller_infos: dict[Any, \"ZMQServerInfo\"],\n    storage_infos: dict[Any, \"ZMQServerInfo\"],\n) -> None:\n    global _TRANSFER_QUEUE_CLIENT\n    global _VAL_TRANSFER_QUEUE_CLIENT\n    if \"val\" in client_id:\n        _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)\n    else:\n        _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)\n\n\ndef get_transferqueue_client() -> \"AsyncTransferQueueClient\":\n    return _TRANSFER_QUEUE_CLIENT\n\n\ndef get_val_transferqueue_client() -> \"AsyncTransferQueueClient\":\n    return _VAL_TRANSFER_QUEUE_CLIENT\n\n\ndef _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:\n    # Use a temporary event loop in a new thread because event\n    # loop may already exist in server mode\n    tmp_event_loop = asyncio.new_event_loop()\n    thread = threading.Thread(\n        target=tmp_event_loop.run_forever,\n        name=\"batchmeta dataproto converter\",\n        daemon=True,\n    )\n\n    def run_coroutine(coroutine):\n        if not thread.is_alive():\n            thread.start()\n        future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop)\n        return future.result()\n\n    async def stop_loop():\n        tmp_event_loop.stop()\n\n    try:\n        return run_coroutine(async_func(*args, **kwargs))\n    finally:\n        if thread.is_alive():\n            asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop)\n            thread.join()\n\n\ndef _find_batchmeta(*args, **kwargs):\n    for arg in args:\n        if isinstance(arg, BatchMeta):\n            return arg\n    for v in kwargs.values():\n        if isinstance(v, BatchMeta):\n            return v\n    return None\n\n\nasync def _async_batchmeta_to_dataproto(batchmeta: \"BatchMeta\") -> DataProto:\n    if batchmeta.samples == [] or batchmeta.samples is None:\n        return DataProto(\n            batch=TensorDict({}, batch_size=(0,)),\n            non_tensor_batch={},\n            meta_info=batchmeta.extra_info.copy(),\n        )\n\n    if batchmeta.extra_info.get(\"validate\", False):\n        tensordict = await 
\n    else:\n        tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)\n    return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy())\n\n\ndef _batchmeta_to_dataproto(batchmeta: \"BatchMeta\") -> DataProto:\n    return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta)\n\n\nasync def _async_update_batchmeta_with_output(output: DataProto, batchmeta: \"BatchMeta\") -> None:\n    for k, v in output.meta_info.items():\n        batchmeta.set_extra_info(k, v)\n\n    if len(output) > 0:\n        tensordict = output.to_tensordict()\n        # pop meta_info\n        for key in output.meta_info.keys():\n            tensordict.pop(key)\n        batchmeta.add_fields(tensordict)\n        if batchmeta.extra_info.get(\"validate\", False):\n            await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)\n        else:\n            await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)\n\n\ndef _update_batchmeta_with_output(output: DataProto, batchmeta: \"BatchMeta\") -> None:\n    _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta)\n\n\ndef tqbridge(put_data: bool = True):\n    \"\"\"Creates a decorator for bridging BatchMeta and DataProto.\n\n    This decorator automatically handles conversions between `BatchMeta` and\n    `DataProto` in function parameters, and decides whether to sync the function\n    output back to `BatchMeta` based on the `put_data` configuration. It supports\n    both synchronous and asynchronous functions (async def); the module-level\n    `is_transferqueue_enabled` flag controls whether the bridging logic is active\n    (when disabled, the original function is simply called as-is).\n\n    Args:\n        put_data: Whether to put the DataProto into storage after the function\n                  returns. If True, after function execution the output will be\n                  written back to `BatchMeta` and the `BatchMeta` will be returned;\n                  if False, the function output will be returned directly.\n                  Defaults to True.\n\n    Returns:\n        A decorator used to wrap target functions (synchronous or asynchronous).\n    \"\"\"\n\n    def decorator(func):\n        @wraps(func)\n        def inner(*args, **kwargs):\n            batchmeta = _find_batchmeta(*args, **kwargs)\n            if batchmeta is None:\n                return func(*args, **kwargs)\n            else:\n                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]\n                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}\n                output = func(*args, **kwargs)\n                if put_data:\n                    _update_batchmeta_with_output(output, batchmeta)\n                    return batchmeta\n                else:\n                    return output\n\n        @wraps(func)\n        async def async_inner(*args, **kwargs):\n            batchmeta = _find_batchmeta(*args, **kwargs)\n            if batchmeta is None:\n                return await func(*args, **kwargs)\n            else:\n                args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]\n                kwargs = {\n                    k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v\n                    for k, v in kwargs.items()\n                }\n                output = await func(*args, **kwargs)\n                if put_data:\n                    await _async_update_batchmeta_with_output(output, batchmeta)\n                    return batchmeta\n                return output\n\n        @wraps(func)\n        def dummy_inner(*args, **kwargs):\n            return func(*args, **kwargs)\n\n        @wraps(func)\n        async def dummy_async_inner(*args, **kwargs):\n            return await func(*args, **kwargs)\n\n        wrapper_inner = inner if is_transferqueue_enabled else dummy_inner\n        wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner\n\n        wrapper = wrapper_async_inner if inspect.iscoroutinefunction(func) else wrapper_inner\n        return wrapper\n\n    return decorator\n
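\n\n# Illustrative usage sketch (editorial addition; `compute_values` is hypothetical):\n#\n#     @tqbridge(put_data=True)\n#     def compute_values(data: DataProto) -> DataProto:\n#         ...  # ordinary DataProto-in / DataProto-out logic\n#\n# With TRANSFER_QUEUE_ENABLE unset (or not \"1\"/\"true\"), compute_values runs exactly\n# as written. When the transfer queue is enabled and a caller passes a BatchMeta\n# instead, the wrapper fetches the backing data, runs the function on a DataProto,\n# writes the output back to storage, and returns the updated BatchMeta.\n"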
  },
  {
    "path": "verl_distillation/verl/utils/transformers_compat.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCompatibility utilities for different versions of transformers library.\n\"\"\"\n\nimport importlib.metadata\nfrom functools import lru_cache\nfrom typing import Optional\n\nfrom packaging import version\n\n# Handle version compatibility for flash_attn_supports_top_left_mask\n# This function was added in newer versions of transformers\ntry:\n    from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask\nexcept ImportError:\n    # For older versions of transformers that don't have this function\n    # Default to False as a safe fallback for older versions\n    def flash_attn_supports_top_left_mask():\n        \"\"\"Fallback implementation for older transformers versions.\n        Returns False to disable features that require this function.\n        \"\"\"\n        return False\n\n\n@lru_cache\ndef is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:\n    try:\n        # Get the installed version of the transformers library\n        transformers_version_str = importlib.metadata.version(\"transformers\")\n    except importlib.metadata.PackageNotFoundError as e:\n        raise ModuleNotFoundError(\"The `transformers` package is not installed.\") from e\n\n    transformers_version = version.parse(transformers_version_str)\n\n    lower_bound_check = True\n    if min_version is not None:\n        lower_bound_check = version.parse(min_version) <= transformers_version\n\n    upper_bound_check = True\n    if max_version is not None:\n        upper_bound_check = transformers_version <= version.parse(max_version)\n\n    return lower_bound_check and upper_bound_check\n"
  },
  {
    "path": "verl_distillation/verl/utils/ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for DeepSpeed Ulysses Sequence Parallelism.\nDeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509\nInspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py\n\"\"\"\n\nfrom typing import Any, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\n_ULYSSES_SEQUENCE_PARALLEL_GROUP = None\n\n\ndef set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):\n    \"\"\"\n    Set ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group\n\n\ndef get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:\n    \"\"\"\n    Get ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    return _ULYSSES_SEQUENCE_PARALLEL_GROUP\n\n\ndef get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel world size.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_world_size(group) if group else 1\n\n\ndef get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel rank.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_rank(group) if group else 0\n\n\ndef gather_seq_scatter_heads(\n    x: Tensor,\n    seq_dim: int,\n    head_dim: int,\n    unpadded_dim_size: int = 0,\n    group: ProcessGroup = None,\n) -> Tensor:\n    \"\"\"\n    A func to sync embedding input with alltoall in sequence parallel\n    gather sequence dimension and scatter head dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    x = SeqAllToAll.apply(group, x, head_dim, seq_dim)\n    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:\n        padding_size = x.size(seq_dim) - unpadded_dim_size\n        x = _unpad_tensor(x, seq_dim, padding_size)\n    return x\n\n\ndef gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:\n    \"\"\"\n    A func to sync attention result with alltoall in sequence parallel\n    gather head dimension and scatter seq dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq, h/n, ...] 
 -> [bsz, seq/n, h, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    dim_size = x.size(seq_dim)\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    if dim_size % sp_world != 0:\n        padding_size = sp_world - (dim_size % sp_world)\n        x = _pad_tensor(x, seq_dim, padding_size)\n    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)\n\n\ndef _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    shape = list(x.shape)\n    shape[dim] = padding_size\n    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)\n    return torch.cat([x, pad], dim=dim)\n\n\ndef _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(0, -padding_size)\n    return x[tuple(slc)]\n\n\ndef slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group)\n    sp_rank = get_ulysses_sequence_parallel_rank()\n    dim_size = x.size(dim)\n    # pad before slice\n    if padding and dim_size % sp_world_size:\n        padding_size = sp_world_size - (dim_size % sp_world_size)\n        x = _pad_tensor(x, dim, padding_size)\n    # slice the input tensor\n    parts = x.size(dim) // sp_world_size\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)\n    return x[tuple(slc)].contiguous()\n\n\ndef all_to_all_tensor(\n    local_input: Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n    group: Optional[dist.ProcessGroup] = None,\n    async_op: bool = False,\n):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    seq_world_size = dist.get_world_size(group)\n    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]\n    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]\n    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)\n    if async_op:\n\n        def wait():\n            comm.wait()\n            return torch.cat(output_list, dim=gather_dim).contiguous()\n\n        return wait\n    return torch.cat(output_list, dim=gather_dim).contiguous()\n\n\ndef all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group=group)\n    output_shape = list(local_tensor.shape)\n    output_shape[0] = output_shape[0] * sp_world_size\n    output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)\n    dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)\n    return output\n\n\nclass SeqAllToAll(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_input: Tensor,\n        scatter_dim: int,\n        gather_dim: int,\n        async_op: bool = False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.scatter_dim = scatter_dim\n        ctx.gather_dim = gather_dim\n        ctx.async_op = async_op\n        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)\n\n    @staticmethod\n    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None, None]:\n        input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0]\n        # One gradient per forward input: (group, local_input, scatter_dim, gather_dim, async_op)\n        return (\n            None,\n            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),\n            None,\n            None,\n            None,\n        )\n\n\nclass Gather(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_tensor: Tensor,\n        gather_dim: int,\n        grad_scaler: bool = True,\n        async_op=False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.gather_dim = gather_dim\n        ctx.grad_scaler = grad_scaler\n        ctx.async_op = async_op\n\n        sp_world_size = dist.get_world_size(group=group)\n        ctx.sp_world_size = sp_world_size\n\n        sp_rank = dist.get_rank(group=group)\n        ctx.sp_rank = sp_rank\n\n        local_shape = list(local_tensor.size())\n        split_size = local_shape[0]\n        part_size = local_shape[gather_dim]  # store original size\n        ctx.part_size = part_size\n\n        output = all_gather_tensor(local_tensor, group, async_op)\n        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)\n\n    @staticmethod\n    def backward(ctx: Any, grad_output: Tensor) -> Any:\n        if ctx.grad_scaler:\n            grad_output = grad_output * ctx.sp_world_size\n        # One gradient per forward input: (group, local_tensor, gather_dim, grad_scaler, async_op)\n        return (\n            None,\n            grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(),\n            None,\n            None,\n            None,\n        )\n\n\ndef gather_outpus_and_unpad(*args, **kwargs):\n    raise RuntimeError(\n        \"please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad\"\n    )\n\n\ndef gather_outputs_and_unpad(\n    x: Tensor,\n    gather_dim: int,\n    unpad_dim: int = None,\n    padding_size: int = 0,\n    grad_scaler: bool = True,\n    group: Optional[dist.ProcessGroup] = None,\n):\n    \"\"\"\n    Gather a tensor across a process group and optionally unpad its padded elements.\n\n    Args:\n        x (Tensor): Input tensor to gather.\n        gather_dim (int): Dimension along which to gather across ranks.\n        unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding.\n        padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0.\n        grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True.\n        group (ProcessGroup, optional): Process group for gathering. If None, uses\n            `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged.
\n\n    Returns:\n        Tensor: The gathered tensor, with padding removed if requested.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if group is None:\n        return x\n    x = Gather.apply(group, x, gather_dim, grad_scaler)\n    if unpad_dim is not None:\n        assert isinstance(padding_size, int), \"padding size is not given or is not an integer\"\n        if padding_size == 0:\n            return x\n        x = _unpad_tensor(x, unpad_dim, padding_size)\n    return x\n\n\ndef ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1):\n    if position_ids_rmpad is not None:\n        assert position_ids_rmpad.size(-2) == 1\n        assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1)\n    if sp_size <= 1:\n        return input_ids_rmpad, position_ids_rmpad, 0\n    _, total_seq_len = input_ids_rmpad.shape\n    pad_size = (sp_size - total_seq_len % sp_size) % sp_size\n    if pad_size > 0:\n        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)\n        if position_ids_rmpad is not None:\n            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)\n            if position_ids_rmpad.dim() == 3:\n                pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(position_ids_rmpad.size(0), 1, 1)\n            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef ulysses_pad_and_slice_inputs(\n    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1\n):\n    \"\"\"\n    Pad input_ids and position_ids so the sequence length is divisible by sp_size,\n    then slice out the shard belonging to the current rank.\n\n    Note that both input_ids_rmpad and position_ids_rmpad will be padded and sliced.\n\n    This is the pre-forward utility for ulysses sequence parallelism.\n\n    Args:\n        input_ids_rmpad: shape of [bsz, seqlen]\n        position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1\n        sp_size (int): ulysses sequence parallelism size\n\n    Returns:\n        torch.Tensor: padded and sliced input_ids\n        torch.Tensor: padded and sliced position_ids\n        int: pad size\n    \"\"\"\n    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size)\n    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)\n    if position_ids_rmpad is not None:\n        position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef validate_ulysses_config(num_heads, ulysses_sequence_size):\n    if ulysses_sequence_size > 1:\n        assert num_heads % ulysses_sequence_size == 0, (\n            f\"num_heads ({num_heads}) must be divisible by ulysses sequence size ({ulysses_sequence_size})\"\n        )\n
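\n\n# Illustrative shape walk-through (editorial addition; requires an initialized\n# ulysses process group, e.g. of world size 4):\n#\n#     input_ids = torch.randint(0, 100, (1, 10))  # packed batch, total_seq_len=10\n#     ids, pos, pad = ulysses_pad_and_slice_inputs(input_ids, None, sp_size=4)\n#     # pad == 2 (10 is padded up to 12), and each rank keeps a slice of\n#     # length 12 / 4 == 3, i.e. ids.shape == (1, 3).\n"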
  },
  {
    "path": "verl_distillation/verl/utils/vllm/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .utils import TensorLoRARequest, VLLMHijack, is_version_ge\n\n# The contents of vllm/patch.py should not be imported here, because the contents of\n# patch.py should be imported after the vllm LLM instance is created. Therefore,\n# wait until you actually start using it before importing the contents of\n# patch.py separately.\n\n__all__ = [\n    \"TensorLoRARequest\",\n    \"VLLMHijack\",\n    \"is_version_ge\",\n]\n"
  },
  {
    "path": "verl_distillation/verl/utils/vllm/patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering\n# unsupported issues.\nSUPPORTED_MOE_MODELS = []\n\ntry:\n    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)\n    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.mixtral import MixtralForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen3MoeLLMForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration\n\n    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)\nexcept ImportError:\n    pass\n\n\ndef patch_vllm_moe_model_weight_loader(model):\n    # this is a work around to load the weight of vllm fused moe model\n    # it is from a bug from vllm 0.8.2\n    # all the weights are supposed to have a weight_loader, but the moe weights\n    # do not have a weight_loader, so we need to patch it\n    # (True, 'model.embed_tokens.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.bias')\n    # (True, 'model.layers.0.self_attn.o_proj.weight')\n    # (True, 'model.layers.0.mlp.gate.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')\n    # (False, 'model.layers.0.mlp.shared_expert_gate.weight')   use default\n    # (False, 'model.layers.0.input_layernorm.weight')          use default\n    # (False, 'model.layers.0.post_attention_layernorm.weight') use default\n    # (False, 'model.layers.0.mlp.experts.w13_weight')          use mlp.experts.weight_loader\n    # (False, 'model.layers.0.mlp.experts.w2_weight')          use mlp.experts.weight_loader\n\n    # Early return if no MOE models are supported\n    if not SUPPORTED_MOE_MODELS:\n        return\n\n    original_model_type = type(model)\n\n    # Define MLP attribute mapping for different model types\n    MLP_ATTR_MAPPING = {}\n    try:\n        from vllm.model_executor.models.mixtral import MixtralForCausalLM\n\n        MLP_ATTR_MAPPING[MixtralForCausalLM] = \"block_sparse_moe\"\n    except ImportError:\n        pass\n\n    DEFAULT_MLP_ATTR = \"mlp\"\n\n    # Get inner model (either model.model or 
\n    inner_model = getattr(model, \"model\", None) or getattr(model, \"language_model\", None)\n    if inner_model is None:\n        raise ValueError(\"The provided model does not have a valid 'model' or 'language_model' attribute.\")\n\n    if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)) and not isinstance(inner_model, tuple(SUPPORTED_MOE_MODELS)):\n        return\n\n    # TODO(@leisuzz): class Qwen3MoeLLMForCausalLM is not available if the vLLM version < 0.11.0;\n    # update this 'if statement' to use 'isinstance' once verl commonly uses vLLM >= 0.11.0\n    if type(inner_model).__name__ == \"Qwen3MoeLLMForCausalLM\":\n        inner_model = inner_model.model  # Reassign inner_model in Qwen3-vl\n\n    for layer in inner_model.layers:\n        mlp_attr = MLP_ATTR_MAPPING.get(original_model_type, DEFAULT_MLP_ATTR)\n\n        mlp = getattr(layer, mlp_attr, None)\n        if not mlp:\n            continue\n\n        experts = getattr(mlp, \"experts\", None)\n        if not experts or not hasattr(experts, \"weight_loader\"):\n            continue\n\n        # Patch the weight loaders\n        for name, param in mlp.named_parameters():\n            if \"w13_weight\" in name or \"w2_weight\" in name:\n                param.weight_loader = experts.weight_loader\n
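\n\n# Illustrative call site (editorial addition; how to reach the loaded model\n# instance is engine-specific and varies across vllm versions):\n#\n#     llm = LLM(model=\"Qwen/Qwen3-30B-A3B\", ...)\n#     moe_model = ...  # obtain the model instance from the engine/worker\n#     patch_vllm_moe_model_weight_loader(moe_model)\n#\n# After patching, the fused expert parameters (w13_weight / w2_weight) route\n# through mlp.experts.weight_loader during weight sync instead of having no\n# loader at all.\n"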
  },
  {
    "path": "verl_distillation/verl/utils/vllm/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom msgspec import field\nfrom packaging import version as vs\nfrom vllm.lora.models import LoRAModel\nfrom vllm.lora.request import LoRARequest\nfrom vllm.lora.utils import get_adapter_absolute_path\nfrom vllm.lora.worker_manager import LRUCacheWorkerLoRAManager\n\nfrom verl.third_party.vllm import get_version\n\n\nclass TensorLoRARequest(LoRARequest):\n    peft_config: dict = field(default=None)\n    lora_tensors: dict = field(default=None)\n\n\nclass VLLMHijack:\n    @staticmethod\n    def hijack():\n        def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:\n            \"\"\"\n            based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors\n\n            Reason:\n            VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.\n            To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to\n            load memory-based LoRA tensors.\n            \"\"\"\n            try:\n                supported_lora_modules = self._adapter_manager.supported_lora_modules\n                packed_modules_mapping = self._adapter_manager.packed_modules_mapping\n                expected_lora_modules: list[str] = []\n                for module in supported_lora_modules:\n                    if module in packed_modules_mapping:\n                        expected_lora_modules.extend(packed_modules_mapping[module])\n                    else:\n                        expected_lora_modules.append(module)\n\n                expected_lora_modules = list(set(expected_lora_modules))\n\n                lora_tensors = None\n                from vllm.lora.peft_helper import PEFTHelper\n\n                if isinstance(lora_request, TensorLoRARequest):\n                    peft_config = lora_request.peft_config\n                    lora_tensors = lora_request.lora_tensors\n                    peft_helper = PEFTHelper.from_dict(peft_config)\n                else:\n                    lora_path = get_adapter_absolute_path(lora_request.lora_path)\n\n                    peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)\n\n                # Validates the LoRA configuration against requirements before\n                # loading weights, throwing an exception if validation fails.\n                peft_helper.validate_legal(self.lora_config)\n\n                # For some models like Qwen2VL, we need to use hf_to_vllm_mapper\n                # to ensure correct loading of lora weights.\n                model = self._adapter_manager.model\n                hf_to_vllm_mapper = None\n                if hasattr(model, \"hf_to_vllm_mapper\") and model.hf_to_vllm_mapper is not None:\n                    hf_to_vllm_mapper = model.hf_to_vllm_mapper\n\n                if isinstance(lora_request, 
\n                    lora = self._lora_model_cls.from_lora_tensors(\n                        lora_model_id=lora_request.lora_int_id,\n                        tensors=lora_tensors,\n                        peft_helper=peft_helper,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        embeddings=None,\n                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,\n                        embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n                else:\n                    lora = self._lora_model_cls.from_local_checkpoint(\n                        lora_path,\n                        expected_lora_modules,\n                        peft_helper=peft_helper,\n                        lora_model_id=lora_request.lora_int_id,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,\n                        embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n            except Exception:\n                # A bare raise preserves the original traceback.\n                raise\n\n            if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:\n                raise ValueError(\n                    f\"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size \"\n                    f\"{self.lora_config.lora_extra_vocab_size}.\"\n                )\n            return lora\n\n        def do_hijack(target_cls, target_method_name, hooking_method):\n            setattr(target_cls, target_method_name, hooking_method)\n\n        do_hijack(LRUCacheWorkerLoRAManager, \"_load_adapter\", hijack__load_adapter)\n\n\ndef is_version_ge(pkg: str = \"vllm\", minver: str = \"0.7.3\"):\n    \"\"\"Check whether the package version is greater than or equal to the minimum version.\"\"\"\n    return vs.parse(get_version(pkg)) >= vs.parse(minver)\n
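\n\n# Illustrative flow (editorial addition; the engine call and tensor contents are\n# hypothetical, and the exact engine API differs across vllm versions):\n#\n#     VLLMHijack.hijack()  # install the patched _load_adapter once per process\n#     req = TensorLoRARequest(\n#         lora_name=\"actor_lora\", lora_int_id=1, lora_path=\"dummy\",\n#         peft_config=peft_config_dict, lora_tensors=lora_state_dict,\n#     )\n#     llm_engine.add_lora(req)\n#\n# The patched loader detects TensorLoRARequest and builds the LoRAModel from the\n# in-memory tensors instead of reading an adapter directory from disk.\n"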
  },
  {
    "path": "verl_distillation/verl/version/version",
    "content": "0.7.0.dev\n"
  },
  {
    "path": "verl_distillation/verl/workers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/workers/actor/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOActor\nfrom .dp_actor import DataParallelPPOActor\n\n__all__ = [\"BasePPOActor\", \"DataParallelPPOActor\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/actor/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for Actor\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\nfrom verl import DataProto\n\n__all__ = [\"BasePPOActor\"]\n\n\nclass BasePPOActor(ABC):\n    def __init__(self, config):\n        \"\"\"The base class for PPO actor\n\n        Args:\n            config (DictConfig): a config passed to the PPOActor. We expect the type to be\n                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.\n        \"\"\"\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_log_prob(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute logits given a batch of data.\n\n        Args:\n            data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,\n                ```attention_mask``` and ```position_ids```.\n\n        Returns:\n            DataProto: a DataProto containing the key ```log_probs```\n\n\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def update_policy(self, data: DataProto) -> dict:\n        \"\"\"Update the policy with an iterator of DataProto\n\n        Args:\n            data (DataProto): an iterator over the DataProto that returns by\n                ```make_minibatch_iterator```\n\n        Returns:\n            Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model\n            such as ```loss```, ```grad_norm```, etc,.\n\n        \"\"\"\n        pass\n"
  },
  {
    "path": "verl_distillation/verl/workers/actor/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSingle Process Actor\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.tensor import DTensor\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty\nfrom verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input\nfrom verl.utils.device import get_device_id, get_device_name\nfrom verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch\nfrom verl.utils.torch_functional import logprobs_from_logits\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs\nfrom verl.workers.actor import BasePPOActor\nfrom verl.workers.config import ActorConfig\n\n__all__ = [\"DataParallelPPOActor\"]\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass DataParallelPPOActor(BasePPOActor):\n    \"\"\"FSDP DataParallel PPO Actor or Ref worker\n\n    Args:\n        config (ActorConfig): Actor config\n        actor_module (nn.Module): Actor or ref module\n        actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. 
Defaults to None.\n    \"\"\"\n\n    def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):\n        \"\"\"When optimizer is None, it is Reference Policy\"\"\"\n        super().__init__(config)\n        self.actor_module = actor_module\n        self.actor_optimizer = actor_optimizer\n        role = \"Ref\" if actor_optimizer is None else \"Actor\"\n\n        self.use_remove_padding = self.config.get(\"use_remove_padding\", False)\n        if torch.distributed.get_rank() == 0:\n            print(f\"{role} use_remove_padding={self.use_remove_padding}\")\n        self.use_fused_kernels = self.config.get(\"use_fused_kernels\", False)\n        if torch.distributed.get_rank() == 0:\n            print(f\"{role} use_fused_kernels={self.use_fused_kernels}\")\n\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1\n\n        if self.config.entropy_from_logits_with_chunking:\n            entropy_from_logits = verl_F.entropy_from_logits_with_chunking\n        else:\n            entropy_from_logits = verl_F.entropy_from_logits\n\n        self.compute_entropy_from_logits = (\n            torch.compile(entropy_from_logits, dynamic=True)\n            if self.config.get(\"use_torch_compile\", True)  # use torch compile by default\n            else entropy_from_logits\n        )\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(\n        self, micro_batch, temperature, calculate_entropy=False\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns:\n            entropy: # (bs, response_len)\n            log_probs: # (bs, response_len)\n        \"\"\"\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            from verl.utils.model import extract_multi_modal_inputs\n\n            multi_modal_inputs = extract_multi_modal_inputs(micro_batch[\"multi_modal_inputs\"])\n\n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            entropy = None\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)  # (bsz, 4, seqlen) -> (4, bsz, seqlen)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                if \"image_bound\" in multi_modal_inputs:\n                    from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo\n\n                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(\n                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs\n                    )\n\n                # for computing the log_prob\n                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n                # pad and slice the inputs if sp > 1\n                if self.use_ulysses_sp:\n                    is_vlm_model = hasattr(\n                        getattr(self.actor_module, \"module\", self.actor_module).config, \"vision_config\"\n                    )\n                    if is_vlm_model:\n                        # vlm model's inputs will be sliced after embedding\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    else:\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad_rolled,\n                        position_ids_rmpad=None,\n                        sp_size=self.ulysses_sequence_parallel_size,\n                    )\n\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n\n                output = self.actor_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent the model from thinking we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)\n                    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)\n\n                else:\n                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)\n                    logits_rmpad.div_(temperature)\n\n                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n                    inplace_backward = True\n                    if calculate_entropy:\n                        inplace_backward = False\n                    log_probs = logprobs_from_logits(\n                        logits=logits_rmpad,\n                        labels=input_ids_rmpad_rolled,\n                        inplace_backward=inplace_backward,\n                    )\n\n                    # compute entropy\n                    if calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n                            entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)\n                        else:\n                            entropy_rmpad = torch.utils.checkpoint.checkpoint(\n                                self.compute_entropy_from_logits, logits_rmpad\n                            )\n\n                # gather log_prob if sp > 1\n                if self.use_ulysses_sp:\n                    # gather and unpad for the ulysses sp\n                    log_probs = gather_outputs_and_unpad(\n                        log_probs,\n                        gather_dim=0,\n                        unpad_dim=0,\n                        padding_size=pad_size,\n                    )\n                    if calculate_entropy:\n                        entropy_rmpad = gather_outputs_and_unpad(\n                            entropy_rmpad,\n                            gather_dim=0,\n                            unpad_dim=0,\n                            padding_size=pad_size,\n                        )\n                # pad back to (bsz, seqlen)\n                if calculate_entropy:\n                    full_entropy = pad_input(\n                        hidden_states=entropy_rmpad.unsqueeze(-1),\n                        indices=indices,\n                        batch=batch_size,\n                        seqlen=seqlen,\n                    )\n                full_log_probs = pad_input(\n                    hidden_states=log_probs.unsqueeze(-1),\n                    indices=indices,\n                    batch=batch_size,\n                    seqlen=seqlen,\n                )\n\n                # only return response part:\n                if calculate_entropy:\n                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n            else:  # not using rmpad and no ulysses sp\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n\n                output = self.actor_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent the model from thinking we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs[:, -response_length - 1 : -1]\n                    entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n                else:\n                    logits = output.logits\n\n                    logits.div_(temperature)\n                    logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)\n                    log_probs = logprobs_from_logits(logits, micro_batch[\"responses\"])\n                    if calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n                            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                        else:\n                            entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)\n\n            return entropy, log_probs\n\n    def _optimizer_step(self):\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.actor_module, FSDP):\n            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)\n        elif isinstance(self.actor_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n\n        if isinstance(grad_norm, DTensor):\n            grad_norm = grad_norm.full_tensor()\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}\")\n            self.actor_optimizer.zero_grad()\n        else:\n            self.actor_optimizer.step()\n        return grad_norm\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def compute_log_prob(\n        self, data: DataProto, calculate_entropy=False, mask_special_token=False\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.
\n\n        Returns:\n            tuple[torch.Tensor, torch.Tensor]: the log_prob tensor and the entropy tensor\n            (the entropy tensor is None unless ``calculate_entropy=True``)\n        \"\"\"\n        # set to eval\n        self.actor_module.eval()\n\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        if mask_special_token:\n            select_keys.append(\"distill_special_token_mask\")\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        # Replace distill special tokens with EOS (151645, the Qwen-family <|im_end|> id);\n        # the tokens behind the first distill special token will be masked.\n        if mask_special_token:\n            distill_special_token_mask = torch.zeros_like(data.batch[\"attention_mask\"])\n            distill_special_token_mask[:, -len(data.batch[\"distill_special_token_mask\"][0]) :] = data.batch[\"distill_special_token_mask\"]\n            data.batch[\"input_ids\"][distill_special_token_mask == 1] = 151645\n\n        if use_dynamic_bsz:\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)\n        else:\n            micro_batches = data.split(micro_batch_size)\n\n        log_probs_lst = []\n        entropy_lst = []\n        for micro_batch in micro_batches:\n            micro_batch = micro_batch.to(get_device_id())\n            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n            with torch.no_grad():\n                entropy, log_probs = self._forward_micro_batch(\n                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy\n                )\n            log_probs_lst.append(log_probs)\n            if calculate_entropy:\n                entropy_lst.append(entropy)\n\n        log_probs = torch.concat(log_probs_lst, dim=0)\n        entropys = None\n        if calculate_entropy:\n            entropys = torch.concat(entropy_lst, dim=0)\n\n        if use_dynamic_bsz:\n            log_probs = restore_dynamic_batch(log_probs, batch_idx_list)\n            if calculate_entropy:\n                entropys = restore_dynamic_batch(entropys, batch_idx_list)\n\n        if mask_special_token:\n            log_probs[data.batch[\"distill_special_token_mask\"] == 1] = self.config.ref_log_prob_replace_val\n        return log_probs, entropys\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def update_policy(self, data: DataProto):\n        # make sure we are in training mode\n        self.actor_module.train()\n\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n\n        select_keys = [\n            \"responses\",\n            \"response_mask\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        # Include pre-computed IS weights if present in batch\n        # Weights are computed centrally in 
trainer and added to batch when algorithm.rollout_is=True\n        if \"rollout_is_weights\" in data.batch.keys():\n            select_keys.append(\"rollout_is_weights\")\n\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. https://arxiv.org/abs/1707.06347\n        mini_batches = data.split(self.config.ppo_mini_batch_size)\n\n        on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1\n\n        metrics = {}\n        for _ in range(self.config.ppo_epochs):\n            for batch_idx, mini_batch in enumerate(mini_batches):\n                if self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.actor_optimizer.zero_grad()\n\n                for micro_batch in micro_batches:\n                    micro_batch = micro_batch.to(get_device_id())\n                    micro_batch_metrics = {}\n                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n                    response_mask = model_inputs[\"response_mask\"]\n                    old_log_prob = model_inputs[\"old_log_probs\"]\n                    advantages = model_inputs[\"advantages\"]\n\n                    entropy_coeff = self.config.entropy_coeff\n                    loss_agg_mode = self.config.loss_agg_mode\n\n                    if self.config.use_dynamic_bsz:\n                        loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size\n                    else:\n                        loss_scale_factor = 1 / self.gradient_accumulation\n\n                    # all return: (bsz, response_length)\n                    calculate_entropy = False\n                    if entropy_coeff != 0:\n                        calculate_entropy = True\n                    entropy, log_prob = self._forward_micro_batch(\n                        model_inputs, temperature=temperature, calculate_entropy=calculate_entropy\n                    )\n\n                    # for fully_async_policy recipe\n                    if hasattr(self.config, \"use_rollout_log_probs\") and self.config.use_rollout_log_probs:\n                        old_log_prob = model_inputs[\"old_log_probs\"]\n                    else:\n                        if on_policy:\n                            old_log_prob = log_prob.detach()\n                        else:\n                            old_log_prob = model_inputs[\"old_log_probs\"]\n\n                    loss_mode = self.config.policy_loss.get(\"loss_mode\", \"vanilla\")\n                    # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla\n\n                    # Extract pre-computed rollout importance sampling weights if present\n                    # Weights are computed centrally in trainer and added 
when algorithm.rollout_is=True\n                    rollout_is_weights = model_inputs.get(\"rollout_is_weights\", None)\n\n                    # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics\n                    # are computed centrally in ray_trainer.py for consistency and efficiency.\n                    # This ensures metrics are computed uniformly across all batches at the trainer level\n                    # and avoids redundant computation across workers and micro-batches.\n\n                    # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg\n                    # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov\n                    policy_loss_fn = get_policy_loss_fn(loss_mode)\n\n                    # Compute policy loss (all functions return 4 values)\n                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        advantages=advantages,\n                        response_mask=response_mask,\n                        loss_agg_mode=loss_agg_mode,\n                        config=self.config,\n                        rollout_is_weights=rollout_is_weights,\n                    )\n\n                    if entropy_coeff != 0:\n                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        # compute policy loss\n                        policy_loss = pg_loss - entropy_loss * entropy_coeff\n                    else:\n                        policy_loss = pg_loss\n\n                    if self.config.use_kl_loss:\n                        ref_log_prob = model_inputs[\"ref_log_prob\"]\n                        # compute kl loss\n                        kld = kl_penalty(\n                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type\n                        )\n                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                        micro_batch_metrics[\"actor/kl_loss\"] = kl_loss.detach().item() * loss_scale_factor\n                        micro_batch_metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                    # loss_scale_factor already accounts for dynamic vs. fixed micro-batching (see above)\n                    loss = policy_loss * loss_scale_factor\n                    loss.backward()\n\n                    micro_batch_metrics.update(\n                        {\n                            \"actor/pg_loss\": pg_loss.detach().item() * loss_scale_factor,\n                            \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                            \"actor/ppo_kl\": ppo_kl.detach().item(),\n                            \"actor/pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n                        }\n                    )\n                    append_to_dict(metrics, micro_batch_metrics)\n\n                grad_norm = self._optimizer_step()\n                mini_batch_metrics = {\"actor/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, mini_batch_metrics)\n        self.actor_optimizer.zero_grad()\n        return metrics\n"
  },
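  {
    "path": "verl_distillation/examples/sketch_distill_token_masking.py",
    "content": "\"\"\"Minimal, self-contained sketch (illustrative only, not part of the verl source tree) of the\ndistill-special-token masking in DataParallelPPOActor.compute_log_prob (dp_actor.py).\n\nIt shows how a response-aligned ``distill_special_token_mask`` is (1) right-aligned onto the full\n[prompt + response] sequence, (2) used to overwrite the marked input ids with the Qwen EOS id\n(151645, <|im_end|>) before the forward pass, and (3) used again to replace the log-probs at those\npositions with the fixed sentinel ``ref_log_prob_replace_val``. All tensors below are toy data.\n\"\"\"\n\nimport torch\n\nEOS_TOKEN_ID = 151645  # <|im_end|> in the Qwen tokenizer family\nREF_LOG_PROB_REPLACE_VAL = -10.0  # default of ActorConfig.ref_log_prob_replace_val\n\nprompt_len, response_len = 4, 3\ninput_ids = torch.arange(prompt_len + response_len).unsqueeze(0)  # [1, 7]\nattention_mask = torch.ones_like(input_ids)\n\n# Response-aligned mask: mark the 2nd response token as a distill special token.\nresponse_special_mask = torch.tensor([[0, 1, 0]])\n\n# (1) Right-align the response-level mask onto the full sequence.\nfull_mask = torch.zeros_like(attention_mask)\nfull_mask[:, -response_special_mask.shape[1]:] = response_special_mask\n\n# (2) Replace the marked positions with EOS before computing log-probs.\ninput_ids[full_mask == 1] = EOS_TOKEN_ID\n\n# Stand-in for the per-token response log-probs from the forward pass.\nlog_probs = torch.randn(1, response_len)\n\n# (3) Overwrite the log-probs at special-token positions with the sentinel,\n# mirroring ``log_probs[mask == 1] = self.config.ref_log_prob_replace_val``.\nlog_probs[response_special_mask == 1] = REF_LOG_PROB_REPLACE_VAL\nprint(input_ids)\nprint(log_probs)\n"
  },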
  {
    "path": "verl_distillation/verl/workers/actor/megatron_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Actor.\nIn megatron actor, the differences are:\n1. We only make minibatch\n\nNote that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer\n\"\"\"\n\nimport itertools\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Iterable\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.distributed import finalize_model_grads\n\n# from megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits\nfrom verl.utils.megatron_utils import get_model_config\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.profiler.profile import Profiler\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor\nfrom verl.workers.actor import BasePPOActor\n\n__all__ = [\"MegatronPPOActor\"]\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MegatronPPOActor(BasePPOActor):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        actor_module: nn.ModuleList,\n        actor_optimizer: DistributedOptimizer,\n    ):\n        \"\"\"MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.\n\n        Args:\n            config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain\n\n                ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.\n\n                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.\n\n                ``ppo_epochs``: number of epochs to update the actor using the batch data.\n\n                ``shuffle``: whether to shuffle the data after each ppo epoch.\n\n                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.\n\n                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.\n            model_config (OmegaConf): model configuration. 
It must contain ``model_config.vocab_size`` and\n                ``model_config.hidden_size``\n            hf_config (PretrainedConfig): huggingface config\n            tf_config (TransformerConfig): mcore transformer config\n            actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this\n                pp stage.\n                each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for\n                more details.\n                The actor module has some constraints to follow in order to use the updating logic implemented here\n\n                1. It must implement unpad_input before any computation and pad_input after all the computation.\n                Removing padding is an\n                optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn\n                (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).\n\n                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],\n                where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size\n                of the hidden state is [total_nnz // tp, 1, hidden_size].\n            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron.\n                It implements\n                zero1 optimizer that shards the optimizer state across dp ranks.\n\n        >>> from megatron.training import get_model\n        >>> from megatron.optimizer import get_megatron_optimizer\n        >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)\n        >>> actor_module = nn.ModuleList(actor_module)\n        >>> actor_optimizer = get_megatron_optimizer(actor_module)\n        >>> actor = MegatronPPOActor(config=config,\n        >>>                          model_config=actor_model_config,\n        >>>                          hf_config=hf_config,\n        >>>                          tf_config=tf_config,\n        >>>                          actor_module=actor_module,\n        >>>                          actor_optimizer=actor_optimizer)\n        \"\"\"\n        super().__init__(config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.actor_module = actor_module\n        self.actor_optimizer: DistributedOptimizer = actor_optimizer\n        self.use_torch_profiler = self.config.profiler.get(\"tool\") == \"torch\"\n        if self.use_torch_profiler:\n            self.prof = Profiler(\n                self.config.profiler, tool_config=self.config.profiler.get(\"tool_config\", {}).get(\"torch\", {})\n            )\n        else:\n            self.prof = None\n        self.use_fused_kernels = self.config.get(\"use_fused_kernels\", False)\n        if self.use_fused_kernels and not getattr(self.config, \"overlap_moe_expert_parallel_comm\", False):\n            # do not patch if overlap_moe_expert_parallel_comm is enabled\n            from verl.models.mcore.model_forward_fused import patch_fused_forward\n\n            for model in self.actor_module:\n                patch_fused_forward(model)\n\n        self.optimizer_step_args = OmegaConf.create(\n            {\n                \"skip_grad\": None,\n                \"overlap_dp_param_comm\": False,\n                \"overlap_dp_grad_comm\": False,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_parallel\": self.tf_config.sequence_parallel,\n                \"DDP_impl\": \"local\",\n                \"layernorm_allreduce_bucket_threshold\": 0,\n                \"pipeline_model_parallel_split_rank\": None,\n                \"reduce_grads_use_alltoall\": False,\n            }\n        )\n\n        config = get_model_config(self.actor_module[0])\n        print(config)\n        config.finalize_model_grads_func = finalize_model_grads\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.get(\"ulysses_sequence_parallel_size\", 1) == 1\n        if config.get(\"shuffle\", False):\n            assert config.data_loader_seed is not None, \"If shuffling the dataloader, the seed must be set manually\"\n        if config.megatron.tensor_model_parallel_size == 1:\n            print(\"[Warning] actor tp size == 1, setting sequence_parallel to False\")\n            config.megatron.sequence_parallel = False\n        self.config = config\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``: tensor of shape [batch_size, response_length]. 
torch.int64.\n\n        Returns:\n            tuple[torch.Tensor, torch.Tensor]: the log_prob tensor and the entropy tensor\n        \"\"\"\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n        else:\n            assert micro_batch_size is not None, (\n                \"micro batch size is needed for forward compute when use_dynamic_bsz is False\"\n            )\n\n        def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):\n            response = data[\"responses\"]\n            response_length = response.size(1)\n            log_probs = output[\"log_probs\"][:, -response_length - 1 : -1].contiguous()\n            return {\"log_probs\": log_probs}\n\n        # We recompute old_log_prob by default here.\n        # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be\n        # handled by user outside\n        recompute_old_log_prob = self.config.get(\"recompute_old_log_prob\", True)\n\n        entropys = torch.Tensor()\n        if recompute_old_log_prob:\n            select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n            batch = data.select(batch_keys=select_keys).batch\n            input_ids = batch[\"input_ids\"]\n            batch_size = input_ids.size(0)\n            response = batch[\"responses\"]\n            response_length = response.size(1)\n            with torch.no_grad():\n                output = self.forward_backward_batch(\n                    data,\n                    forward_only=True,\n                    post_process_fn=compute_logprobs_fn,\n                    calculate_entropy=calculate_entropy,\n                    use_dynamic_bsz=use_dynamic_bsz,\n                    micro_batch_size=micro_batch_size,\n                    max_token_len=max_token_len,\n                )\n                if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                    # only available on the last pp rank; it is present on every tp rank\n                    if calculate_entropy:\n                        log_probs = [o[0][\"log_probs\"] for o in output[\"output\"]]  # (bs, seq_size)\n                    else:\n                        log_probs = [o[\"log_probs\"] for o in output[\"output\"]]  # (bs, seq_size)\n                    log_probs = torch.cat(log_probs, dim=0).to(torch.float32)\n                    if use_dynamic_bsz:\n                        indices = output[\"indices\"]\n                        indices = list(itertools.chain.from_iterable(indices))\n                        assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. 
{log_probs.size()}\"\n                        revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                        log_probs = log_probs[revert_indices]\n                else:\n                    log_probs = torch.empty(\n                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                    )\n                log_probs = log_probs.to(get_device_id())\n                # broadcast across pp ranks\n                torch.distributed.broadcast(\n                    tensor=log_probs,\n                    src=mpu.get_pipeline_model_parallel_last_rank(),\n                    group=mpu.get_pipeline_model_parallel_group(),\n                    async_op=False,\n                )\n                log_probs = log_probs.to(\"cpu\")\n                if calculate_entropy:\n                    # Note that o[0] is metrics, o[1] is entropy\n                    if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                        entropys = torch.cat([o[1] for o in output[\"output\"]], dim=0)\n                        entropys = entropys.to(torch.float32)\n                        if use_dynamic_bsz:\n                            indices = output[\"indices\"]\n                            indices = list(itertools.chain.from_iterable(indices))\n                            assert len(indices) == entropys.size(0), f\"{len(indices)} vs. {entropys.size()}\"\n                            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                            entropys = entropys[revert_indices]\n                    else:\n                        entropys = torch.empty(\n                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                        )\n                    # broadcast across pp ranks\n                    entropys = entropys.to(get_device_id())\n                    torch.distributed.broadcast(\n                        tensor=entropys,\n                        src=mpu.get_pipeline_model_parallel_last_rank(),\n                        group=mpu.get_pipeline_model_parallel_group(),\n                        async_op=False,\n                    )\n                    entropys = entropys.to(\"cpu\")\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return log_probs, entropys\n\n    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:\n        \"\"\"Make minibatch iterator for updating the actor\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where\n                ``sequence_length = prompt_length + response_length``\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64\n\n                ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that\n                responses = input_ids[:, -response_length:]\n\n                ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability\n                of responses.\n\n                ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of\n                responses.\n                See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n\n        Returns:\n            Iterable[DataProto]: an iterator over minibatches of size ``ppo_mini_batch_size``, repeated for\n            ``ppo_epochs`` epochs.\n        \"\"\"\n        select_keys = [\n            \"responses\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"response_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        # Include pre-computed IS weights if present in batch\n        # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True\n        if \"rollout_is_weights\" in data.batch.keys():\n            select_keys.append(\"rollout_is_weights\")\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            data = data.select(select_keys, [\"multi_modal_inputs\"])\n        else:\n            data = data.select(batch_keys=select_keys)\n        return data.make_iterator(\n            mini_batch_size=self.config.ppo_mini_batch_size,\n            epochs=self.config.ppo_epochs,\n            seed=self.config.data_loader_seed,\n            dataloader_kwargs={\"shuffle\": self.config.shuffle},\n        )\n\n    def forward_backward_batch(\n        self,\n        data: DataProto,\n        forward_only=False,\n        post_process_fn=None,\n        calculate_entropy=False,\n        use_dynamic_bsz=False,\n        micro_batch_size=None,\n        max_token_len=None,\n        mini_batch_size=None,\n    ):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        data.to(get_device_id())\n        data.batch = data.batch.contiguous()\n        mini_batch = data\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n        mini_batch.to(\"cpu\")\n        # split into micro-batches\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in mini_batch.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            mini_batch.batch[\"multi_modal_inputs\"] = mini_batch.non_tensor_batch[\"multi_modal_inputs\"]\n            mini_batch.batch[\"multi_modal_inputs_idx\"] = torch.Tensor(\n                list(range(len(mini_batch.non_tensor_batch[\"multi_modal_inputs\"])))\n            ).to(torch.int64)\n\n        if mini_batch.batch[\"position_ids\"].dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            mini_batch.batch[\"position_ids\"] = mini_batch.batch[\"position_ids\"][\n                :, 0\n            ]  # mcore patch recompute qwen2vl's pos ids during forward\n\n        indices = None\n        temperature = data.meta_info[\"temperature\"]\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = 
self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        # compute input shapes for pp stages\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data, meta_info):\n            # For memory efficiency\n            # We move calculation of entropy to compute_log_probs, forward_only == True\n            log_probs = None\n            entropy = None\n            if isinstance(output, dict):\n                log_probs = output[\"log_probs\"]\n                if \"entropy\" in output:\n                    entropy = output[\"entropy\"]\n            else:\n                assert isinstance(output, torch.Tensor)\n                log_probs = output\n\n            device = log_probs.device\n            metrics = {}\n            if forward_only:\n                if post_process_fn is None:\n                    pass\n                    # metrics[\"logits\"] = output\n                else:\n                    stats = post_process_fn(output, data)\n                    metrics.update(stats)\n                if not calculate_entropy:\n                    return torch.tensor(1.0, device=device), metrics\n\n            responses = data[\"responses\"]\n            response_length = responses.size(1)\n            response_mask = data[\"response_mask\"].to(bool)\n            loss_agg_mode = self.config.loss_agg_mode\n            # compute policy loss\n            log_prob = log_probs[:, -response_length - 1 : -1].contiguous()\n            ret_entropy = None\n            stats = {}\n            if not forward_only:\n                old_log_prob = data[\"old_log_probs\"]\n                advantages = data[\"advantages\"]\n\n                entropy_coeff = self.config.entropy_coeff\n                loss_agg_mode = self.config.loss_agg_mode\n\n                loss_mode = self.config.policy_loss.get(\"loss_mode\", \"vanilla\")\n\n                policy_loss_fn = get_policy_loss_fn(loss_mode)\n\n                # Extract pre-computed rollout importance sampling weights if present\n                # Weights are computed centrally in trainer and added when algorithm.rollout_is=True\n                rollout_is_weights = data.get(\"rollout_is_weights\", None)\n\n                # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) 
and IS weight metrics\n                # are computed centrally in ray_trainer.py for consistency and efficiency.\n                # This ensures metrics are computed uniformly across all batches at the trainer level\n                # and avoids redundant computation across workers and micro-batches.\n                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n                    old_log_prob=old_log_prob,\n                    log_prob=log_prob,\n                    advantages=advantages,\n                    response_mask=response_mask,\n                    loss_agg_mode=loss_agg_mode,\n                    config=self.config,\n                    rollout_is_weights=rollout_is_weights,\n                )\n\n                stats.update(\n                    {\n                        \"actor/pg_loss\": pg_loss.detach().item(),\n                        \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                        \"actor/ppo_kl\": ppo_kl.detach().item(),\n                        \"actor/pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n                    }\n                )\n                policy_loss = pg_loss\n\n            if calculate_entropy:\n                entropy = output[\"entropy\"][:, -response_length - 1 : -1].contiguous()\n                if not forward_only:\n                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n                    entropy_coeff = meta_info[\"entropy_coeff\"]\n                    policy_loss = pg_loss - entropy_coeff * entropy_loss\n                else:\n                    ret_entropy = entropy\n\n            if forward_only:\n                policy_loss = torch.tensor(1.0, device=device)\n            else:\n                if self.config.use_kl_loss:\n                    ref_log_prob = data[\"ref_log_prob\"]\n                    # compute kl loss\n                    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)\n                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)\n\n                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                    metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n                    metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                # return loss and stats\n\n            append_to_dict(metrics, stats)\n            return policy_loss, [metrics, ret_entropy]\n\n        def forward_step(batch_iter, model, return_schedule_plan: bool = False):\n            \"\"\"\n            Args:\n                batch_iter: the batch iterator\n                model: the model\n                return_schedule_plan: whether to return the schedule plan, for 1f1b overlap\n            \"\"\"\n            if return_schedule_plan:\n                assert self.tf_config.overlap_moe_expert_parallel_comm, (\n                    \"overlap_moe_expert_parallel_comm must be enabled to return the schedule plan\"\n                )\n                # TODO: Fix this\n                assert not calculate_entropy, \"calculate_entropy must be disabled to return the schedule plan\"\n                from megatron.core.models.gpt.gpt_model import GPTModel\n\n                assert isinstance(model, GPTModel), \"model must be a GPTModel\"\n                assert self.use_fused_kernels, \"use_fused_kernels must be enabled to return the schedule plan\"\n                # TODO: support VLM with 
MoE\n                from verl.models.mcore.model_forward_1f1b_overlap import gptmodel_forward_1f1b_overlap\n\n            batch = next(batch_iter)\n            batch = batch.to(get_device_id())\n            batch = batch.contiguous()\n\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"].to(bool)\n            position_ids = batch[\"position_ids\"]\n\n            multi_modal_inputs = {}\n            if \"multi_modal_inputs\" in batch:\n                from verl.utils.model import extract_multi_modal_inputs\n\n                indices = batch.get(\"multi_modal_inputs_idx\", None)\n                multi_modal_inputs = extract_multi_modal_inputs(batch[\"multi_modal_inputs\"], indices)\n            responses = batch[\"responses\"]\n            response_length = responses.size(1)\n            label = position_ids.clone()\n            label[:, -response_length - 1 : -1] = responses\n            label_mask = attention_mask.clone()\n            label_mask[:, : -response_length - 1] = False\n            label_mask[:, -1] = False\n\n            from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn\n\n            if self.use_fused_kernels:\n                forward_fn = get_mcore_forward_fused_fn(self.hf_config)\n                if return_schedule_plan:\n                    forward_fn = gptmodel_forward_1f1b_overlap\n                # return dict of [logits, entropy]\n                output = forward_fn(\n                    model=model,\n                    input_ids=input_ids,\n                    position_ids=position_ids,\n                    attention_mask=attention_mask,\n                    labels=label,\n                    labels_mask=label_mask,\n                    temperature=temperature,\n                    multi_modal_inputs=multi_modal_inputs,\n                )\n            else:\n                forward_fn = get_mcore_forward_fn(self.hf_config)\n\n                def logits_processor(logits, label, label_mask):\n                    assert logits.shape[:2] == label.shape[:2]\n                    assert label.shape == label_mask.shape\n                    logits.div_(temperature)\n                    ret = {}\n                    if calculate_entropy:\n                        logits_bak = logits.clone()\n                        logger.warning_once(\n                            \"For memory-efficient computation, enable fused kernels via \"\n                            \"`actor_rollout_ref.model.use_fused_kernels=True`. 
\"\n                            \"The current `clone()` operation ensures correctness but increases memory usage.\"\n                        )\n                        entropy = vocab_parallel_entropy(logits)\n                        ret[\"entropy\"] = entropy\n                    else:\n                        logits_bak = logits\n                    log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)\n                    log_probs = log_probs.masked_fill(~label_mask, 0.0)\n                    ret[\"log_probs\"] = log_probs\n                    return ret\n\n                logits_processor_args = {\"label\": label, \"label_mask\": label_mask}\n                output = forward_fn(\n                    model=model,\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    multi_modal_inputs=multi_modal_inputs,\n                    logits_processor=logits_processor,\n                    logits_processor_args=logits_processor_args,\n                )\n\n            if forward_only:\n                meta_info = None\n            else:\n                clip_ratio_c = self.config.get(\"clip_ratio_c\", 3.0)\n                meta_info = {\n                    \"clip_ratio\": self.config.clip_ratio,\n                    \"entropy_coeff\": self.config.entropy_coeff,\n                    \"clip_ratio_c\": clip_ratio_c,\n                }\n            return output, partial(loss_func, data=batch, meta_info=meta_info)\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.actor_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=forward_only,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.actor_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=forward_only,\n            )\n        # loss_reduces contains the stats returned from loss_func\n\n        if self.has_multi_modal_inputs:\n            data.batch.pop(\"multi_modal_inputs\")\n            data.batch.pop(\"multi_modal_inputs_idx\")\n            data.non_tensor_batch.pop(\"multi_modal_inputs\")\n\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def update_policy(self, dataloader: Iterable[DataProto]) -> dict:\n        \"\"\"Update the policy with an iterator of DataProto\n\n        Args:\n            dataloader (Iterable[DataProto]): an iterator over the DataProto 
returned by ``make_minibatch_iterator``.\n                The keys of each data batch are described in ``make_minibatch_iterator``.\n\n        Returns:\n            Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage\n            and users have to combine the output in each dp rank manually.\n\n        \"\"\"\n        metrics = {}\n        if self.use_torch_profiler and self.prof and self.prof.enable:\n            self.prof.start()\n        for data in dataloader:\n            self.actor_optimizer.zero_grad()\n            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n            for chunk in self.actor_module:\n                # if use distributed optimizer, zero grad buffer will be handled by optimizer\n                chunk.zero_grad_buffer()\n\n            calculate_entropy = self.config.entropy_coeff != 0\n            if data.meta_info.get(\"micro_batch_size\", None) is not None:\n                micro_batch_size = data.meta_info[\"micro_batch_size\"]\n            else:\n                micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n            max_token_len = None\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n            metric_micro_batch = self.forward_backward_batch(\n                data,\n                calculate_entropy=calculate_entropy,\n                use_dynamic_bsz=self.config.use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=self.config.ppo_mini_batch_size,\n            )\n            metric_micro_batch = metric_micro_batch[\"output\"]\n            for metric in metric_micro_batch:\n                # Note that metric[0] is the metrics dict and metric[1] is the entropy\n                append_to_dict(metrics, metric[0])  # append the metric from this micro-batch to global metrics.\n\n            update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()\n            append_to_dict(metrics, {\"actor/grad_norm\": grad_norm})\n\n            if update_successful:\n                # allgather is already executed in optimizer.step in new megatron\n                pass\n            else:\n                raise NotImplementedError\n            if self.use_torch_profiler and self.prof and self.prof.enable:\n                self.prof.step()\n        # add empty cache after each compute\n        if self.use_torch_profiler and self.prof and self.prof.enable:\n            self.prof.stop_and_save()\n            self.prof.stop_trace()\n        get_torch_device().empty_cache()\n        return metrics\n"
  },
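  {
    "path": "verl_distillation/examples/sketch_response_logprob_slice.py",
    "content": "\"\"\"Illustrative sketch (not part of the verl source tree) of the off-by-one slice used throughout\nmegatron_actor.py, e.g. ``log_probs[:, -response_length - 1 : -1]``.\n\nFor a causal LM, the output at position t scores the token at position t + 1, so the log-probs of an\nL-token response that ends the sequence live at positions [-L - 1, -2] of the per-position tensor.\nThe toy check below verifies the slice picks exactly one score per response token.\n\"\"\"\n\nimport torch\n\nbatch, prompt_len, response_len = 2, 5, 3\nseq_len = prompt_len + response_len\n\n# Fake per-position 'log-prob of the next token', shaped like the model output.\nper_position = torch.arange(batch * seq_len, dtype=torch.float32).reshape(batch, seq_len)\n\n# The slice used by megatron_actor: one entry per response token.\nresponse_logprob = per_position[:, -response_len - 1 : -1]\nassert response_logprob.shape == (batch, response_len)\n\n# Position prompt_len - 1 predicts the first response token, ..., position seq_len - 2 predicts the\n# last one; the final position would predict a token beyond the sequence and is dropped by the -1.\nassert torch.equal(response_logprob, per_position[:, prompt_len - 1 : seq_len - 1])\nprint(response_logprob)\n"
  },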
  {
    "path": "verl_distillation/verl/workers/config/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom . import actor, critic, engine, model, optimizer, reward_model, rollout\nfrom .actor import *  # noqa: F401\nfrom .critic import *  # noqa: F401\nfrom .engine import *  # noqa: F401\nfrom .model import *  # noqa: F401\nfrom .optimizer import *  # noqa: F401\nfrom .reward_model import *  # noqa: F401\nfrom .rollout import *  # noqa: F401\n\n__all__ = (\n    actor.__all__\n    + critic.__all__\n    + reward_model.__all__\n    + engine.__all__\n    + optimizer.__all__\n    + rollout.__all__\n    + model.__all__\n)\n"
  },
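  {
    "path": "verl_distillation/examples/sketch_config_imports.py",
    "content": "\"\"\"Tiny usage sketch (not part of the verl source tree): the wildcard re-exports in\nverl/workers/config/__init__.py let worker configs be imported from the package root instead of the\nindividual submodules. Assumes the bundled verl package is importable.\"\"\"\n\nfrom verl.workers.config import FSDPActorConfig, PolicyLossConfig\nfrom verl.workers.config.actor import FSDPActorConfig as FSDPActorConfigDirect\n\n# Both import paths resolve to the same class object.\nassert FSDPActorConfig is FSDPActorConfigDirect\n\ncfg = FSDPActorConfig(ppo_micro_batch_size_per_gpu=2)\nprint(cfg.strategy, cfg.policy_loss.loss_mode)  # fsdp vanilla\n\nloss_cfg = PolicyLossConfig(loss_mode=\"kl-cov\")\nprint(loss_cfg.loss_mode, loss_cfg.ppo_kl_coef)\n"
  },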
  {
    "path": "verl_distillation/verl/workers/config/actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom omegaconf import MISSING\n\nfrom verl.base_config import BaseConfig\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.utils.profiler.config import ProfilerConfig\n\nfrom .engine import FSDPEngineConfig, McoreEngineConfig\nfrom .model import HFModelConfig\nfrom .optimizer import OptimizerConfig\n\n__all__ = [\"PolicyLossConfig\", \"ActorConfig\", \"FSDPActorConfig\", \"McoreActorConfig\"]\n\n\n@dataclass\nclass PolicyLossConfig(BaseConfig):\n    \"\"\"Configuration for policy loss computation.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        loss_mode (str): Loss function mode. Options: 'vanilla', 'clip-cov', 'kl-cov', 'gpg'.\n        clip_cov_ratio (float): Ratio of tokens to be clipped for clip-cov loss.\n        clip_cov_lb (float): Lower bound for clip-cov loss.\n        clip_cov_ub (float): Upper bound for clip-cov loss.\n        kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss.\n        ppo_kl_coef (float): KL divergence penalty coefficient.\n    \"\"\"\n\n    loss_mode: str = \"vanilla\"\n    clip_cov_ratio: float = 0.0002\n    clip_cov_lb: float = 1.0\n    clip_cov_ub: float = 5.0\n    kl_cov_ratio: float = 0.0002\n    ppo_kl_coef: float = 0.1\n\n\n@dataclass\nclass ActorConfig(BaseConfig):\n    \"\"\"Configuration for actor model training.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        strategy (str): Training strategy. Must be specified.\n        ppo_mini_batch_size (int): Mini-batch size for PPO training.\n        ppo_micro_batch_size (Optional[int]): Micro-batch size for PPO training.\n            If None, uses ppo_micro_batch_size_per_gpu.\n        ppo_micro_batch_size_per_gpu (Optional[int]): Micro-batch size per GPU for PPO training.\n        use_dynamic_bsz (bool): Whether to use dynamic batch sizing.\n        ppo_max_token_len_per_gpu (int): Maximum token length per GPU for PPO training.\n        clip_ratio (float): PPO clipping ratio for policy loss.\n        clip_ratio_low (float): Lower bound for PPO clipping ratio.\n        clip_ratio_high (float): Upper bound for PPO clipping ratio.\n        policy_loss (PolicyLossConfig): Configuration for policy loss computation.\n        clip_ratio_c (float): Clipping ratio for critic loss.\n        loss_agg_mode (str): Loss aggregation mode. 
Options: 'token-mean', 'seq-mean-token-sum', 'seq-mean-token-mean', 'seq-mean-token-sum-norm'.\n        entropy_coeff (float): Entropy coefficient for regularization.\n        use_kl_loss (bool): Whether to use KL divergence loss.\n        use_torch_compile (bool): Whether to use torch.compile for optimization.\n        kl_loss_coef (float): KL divergence loss coefficient.\n        kl_loss_type (str): Type of KL loss to use.\n        ppo_epochs (int): Number of PPO epochs per training step.\n        shuffle (bool): Whether to shuffle data during training.\n        checkpoint (CheckpointConfig): Configuration for checkpointing.\n        optim (OptimizerConfig): Configuration for optimizer.\n        use_fused_kernels (bool): Whether to use custom fused kernels (e.g., FlashAttention, fused MLP).\n    \"\"\"\n\n    _mutable_fields = BaseConfig._mutable_fields | {\n        \"ppo_mini_batch_size\",\n        \"ppo_micro_batch_size\",\n        \"ppo_micro_batch_size_per_gpu\",\n        \"ppo_infer_micro_batch_size_per_gpu\",\n    }\n\n    strategy: str = MISSING\n    ppo_mini_batch_size: int = 256\n    ppo_micro_batch_size: Optional[int] = None  # deprecate\n    ppo_micro_batch_size_per_gpu: Optional[int] = None\n    ppo_infer_micro_batch_size_per_gpu: Optional[int] = None\n    use_dynamic_bsz: bool = False\n    ppo_max_token_len_per_gpu: int = 16384\n    ppo_infer_max_token_len_per_gpu: int = 16384\n    clip_ratio: float = 0.2\n    clip_ratio_low: float = 0.2\n    clip_ratio_high: float = 0.2\n    freeze_vision_tower: bool = False\n    policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)\n    clip_ratio_c: float = 3.0\n    loss_agg_mode: str = \"token-mean\"\n    entropy_coeff: float = 0\n    use_kl_loss: bool = False\n    use_torch_compile: bool = True\n    kl_loss_coef: float = 0.001\n    kl_loss_type: str = \"low_var_kl\"\n    ppo_epochs: int = 1\n    shuffle: bool = False\n    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)\n    optim: OptimizerConfig = field(default_factory=OptimizerConfig)\n    use_fused_kernels: bool = False\n    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)\n    engine: BaseConfig = field(default_factory=BaseConfig)\n    data_loader_seed: int = 1\n    rollout_n: int = 1  # must be overridden by the sampling config\n    ref_log_prob_replace_val: float = -10.0\n    model_config: HFModelConfig = field(default_factory=BaseConfig)\n\n    def __post_init__(self):\n        \"\"\"Validate actor configuration parameters.\"\"\"\n        assert self.strategy != MISSING\n        assert self.rollout_n != MISSING\n        if not self.use_dynamic_bsz:\n            if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None:\n                raise ValueError(\n                    \"[actor] You have set both 'actor.ppo_micro_batch_size' AND 'actor.ppo_micro_batch_size_per_gpu'. 
\"\n                    \"Please remove 'actor.ppo_micro_batch_size' because only '*_ppo_micro_batch_size_per_gpu' is \"\n                    \"supported (the former is deprecated).\"\n                )\n            else:\n                assert not (self.ppo_micro_batch_size is None and self.ppo_micro_batch_size_per_gpu is None), (\n                    \"[actor] Please set at least one of 'actor.ppo_micro_batch_size' or \"\n                    \"'actor.ppo_micro_batch_size_per_gpu' if use_dynamic_bsz is not enabled.\"\n                )\n\n        valid_loss_agg_modes = [\n            \"token-mean\",\n            \"seq-mean-token-sum\",\n            \"seq-mean-token-mean\",\n            \"seq-mean-token-sum-norm\",\n        ]\n        if self.loss_agg_mode not in valid_loss_agg_modes:\n            raise ValueError(f\"Invalid loss_agg_mode: {self.loss_agg_mode}\")\n\n    def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):\n        \"\"\"Validate actor configuration with runtime parameters.\"\"\"\n        if not self.use_dynamic_bsz:\n            if train_batch_size < self.ppo_mini_batch_size:\n                raise ValueError(\n                    f\"train_batch_size ({train_batch_size}) must be >= \"\n                    f\"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})\"\n                )\n\n            sp_size = getattr(self, \"ulysses_sequence_parallel_size\", 1)\n            if self.ppo_micro_batch_size is not None:\n                if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:\n                    raise ValueError(\n                        f\"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by \"\n                        f\"ppo_micro_batch_size ({self.ppo_micro_batch_size})\"\n                    )\n                if self.ppo_micro_batch_size * sp_size < n_gpus:\n                    raise ValueError(\n                        f\"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * \"\n                        f\"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})\"\n                    )\n\n    @staticmethod\n    def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n        \"\"\"Validate mutually exclusive micro batch size configuration options.\"\"\"\n        param = \"ppo_micro_batch_size\"\n        param_per_gpu = f\"{param}_per_gpu\"\n\n        if mbs is None and mbs_per_gpu is None:\n            raise ValueError(f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\")\n\n        if mbs is not None and mbs_per_gpu is not None:\n            raise ValueError(\n                f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove \"\n                f\"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).\"\n            )\n\n\n@dataclass\nclass McoreActorConfig(ActorConfig):\n    \"\"\"Configuration for Megatron actor models.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        strategy (str): Training strategy set to 'megatron' for Megatron parallelism.\n        data_loader_seed (Optional[int]): Seed for data loader. 
If None, uses global seed.\n        load_weight (bool): Whether to load model weights from checkpoint.\n        megatron (dict[str, Any]): Configuration for Megatron parallelism settings.\n        profile (dict[str, Any]): Configuration for profiling settings.\n    \"\"\"\n\n    strategy: str = \"megatron\"\n    data_loader_seed: Optional[int] = None\n    load_weight: bool = True\n    megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)\n    profile: dict[str, Any] = field(default_factory=dict)\n    use_rollout_log_probs: bool = False\n\n\n@dataclass\nclass FSDPActorConfig(ActorConfig):\n    \"\"\"Configuration for FSDP actor models.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        strategy (str): Training strategy set to 'fsdp' for Fully Sharded Data Parallel.\n        grad_clip (float): Gradient clipping threshold.\n        ulysses_sequence_parallel_size (int): Ulysses sequence parallel size for long sequences.\n        entropy_from_logits_with_chunking (bool): Whether to compute entropy from logits\n            with chunking for memory efficiency.\n        entropy_checkpointing (bool): Whether to use gradient checkpointing for entropy computation.\n        fsdp_config (dict[str, Any]): Configuration for FSDP settings.\n        use_remove_padding (bool): Whether to remove padding tokens in inputs during training\n    \"\"\"\n\n    strategy: str = \"fsdp\"\n    grad_clip: float = 1.0\n    ulysses_sequence_parallel_size: int = 1\n    entropy_from_logits_with_chunking: bool = False\n    entropy_checkpointing: bool = False\n    fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)\n    use_remove_padding: bool = False\n    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)\n    use_rollout_log_probs: bool = False\n\n    def __post_init__(self):\n        \"\"\"Validate FSDP actor configuration parameters.\"\"\"\n        super().__post_init__()\n\n    def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):\n        \"\"\"Validate FSDP actor configuration with runtime parameters.\"\"\"\n        super().validate(n_gpus, train_batch_size, model_config)\n\n        if self.strategy in {\"fsdp\", \"fsdp2\"} and self.ulysses_sequence_parallel_size > 1:\n            if model_config and not model_config.get(\"use_remove_padding\", False):\n                raise ValueError(\n                    \"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`.\"\n                )\n"
  },
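  {
    "path": "verl_distillation/examples/sketch_actor_config_validation.py",
    "content": "\"\"\"Illustrative sketch (not part of the verl source tree) of the batch-size validation rules in\nActorConfig (verl/workers/config/actor.py). With ``use_dynamic_bsz=False``, at most one of\n``ppo_micro_batch_size`` (deprecated, global) and ``ppo_micro_batch_size_per_gpu`` may be set, and\n``validate`` additionally checks divisibility and GPU coverage at runtime. Assumes the bundled verl\npackage is importable.\"\"\"\n\nfrom verl.workers.config import FSDPActorConfig\n\n# Setting both micro-batch knobs is rejected in __post_init__.\ntry:\n    FSDPActorConfig(ppo_micro_batch_size=8, ppo_micro_batch_size_per_gpu=2)\nexcept ValueError as err:\n    print(\"rejected:\", err)\n\n# Mini batch 256 with the (deprecated) global micro batch 8 on 4 GPUs passes:\n# 256 % 8 == 0 and 8 * ulysses_sequence_parallel_size >= 4.\ncfg = FSDPActorConfig(ppo_micro_batch_size=8)\ncfg.validate(n_gpus=4, train_batch_size=512)\nprint(\"ok:\", cfg.ppo_mini_batch_size, cfg.ppo_micro_batch_size)\n"
  },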
  {
    "path": "verl_distillation/verl/workers/config/critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom omegaconf import MISSING\n\nfrom verl.base_config import BaseConfig\nfrom verl.trainer.config import BaseModelConfig, CheckpointConfig\nfrom verl.utils.profiler import ProfilerConfig\n\nfrom .engine import FSDPEngineConfig, McoreEngineConfig\nfrom .model import HFModelConfig\nfrom .optimizer import OptimizerConfig\n\n__all__ = [\"CriticConfig\", \"FSDPCriticConfig\", \"McoreCriticConfig\", \"FSDPCriticModelCfg\"]\n\n\n@dataclass\nclass CriticConfig(BaseConfig):\n    \"\"\"Configuration for critic model training.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).\n        ppo_micro_batch_size_per_gpu (int): Local per-GPU micro batch size.\n        rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).\n        optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc.\n        model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc.\n        ppo_mini_batch_size (int): PPO mini-batch size per update.\n        ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).\n        use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.\n        ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.\n        forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.\n        ppo_epochs (int): Number of PPO epochs per batch.\n        shuffle (bool): Shuffle training data across PPO epochs.\n        cliprange_value (float): PPO value function clipping range.\n        loss_agg_mode (str): Loss aggregation mode.\n        checkpoint (Dict[str, Any]): Checkpoint configuration.\n        profiler (Dict[str, Any]): Profiler configuration.\n        enable (Optional[bool]): Whether to enable the critic.\n    \"\"\"\n\n    _mutable_fields = BaseConfig._mutable_fields | {\n        \"ppo_micro_batch_size_per_gpu\",\n        \"ppo_mini_batch_size\",\n        \"ppo_micro_batch_size\",\n        \"model_config\",\n    }\n\n    strategy: str = MISSING\n    ppo_micro_batch_size_per_gpu: Optional[int] = None\n    enable: Optional[bool] = None\n    rollout_n: int = 1\n    ppo_mini_batch_size: int = 1\n    use_dynamic_bsz: bool = False\n    ppo_max_token_len_per_gpu: int = 32768\n    # deprecate this\n    forward_max_token_len_per_gpu: int = 32768\n    ppo_infer_micro_batch_size_per_gpu: Optional[int] = None\n    ppo_infer_max_token_len_per_gpu: int = 32768\n    ppo_epochs: int = 1\n    data_loader_seed: int = 1\n    shuffle: bool = True\n    cliprange_value: float = 0.5\n    loss_agg_mode: str = \"token-mean\"\n    ppo_micro_batch_size: Optional[int] = None\n    engine: BaseConfig = 
field(default_factory=BaseConfig)\n    optim: OptimizerConfig = field(default_factory=OptimizerConfig)\n    # deprecate model to favor model_config\n    model: BaseModelConfig = field(default_factory=BaseModelConfig)\n    model_config: HFModelConfig = None\n    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)\n    profiler: ProfilerConfig = field(default_factory=ProfilerConfig)\n\n    def __post_init__(self):\n        \"\"\"Validate critic configuration parameters.\"\"\"\n        assert self.strategy != MISSING\n\n        if self.model_config is None:\n            warnings.warn(\"using model in Critic Config is deprecated, please use model_config instead\", stacklevel=2)\n            self.model_config = self.model\n\n        if not self.use_dynamic_bsz:\n            self._check_mutually_exclusive(self.ppo_micro_batch_size, self.ppo_micro_batch_size_per_gpu, \"critic\")\n\n            if self.ppo_micro_batch_size is not None:\n                if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:\n                    raise ValueError(\n                        f\"[critic] ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by \"\n                        f\"ppo_micro_batch_size ({self.ppo_micro_batch_size})\"\n                    )\n\n    def validate(self, n_gpus: int, train_batch_size: int):\n        \"\"\"Validate critic configuration with runtime parameters.\n\n        Args:\n            n_gpus: Total number of GPUs available\n            train_batch_size: Training batch size from data config\n        \"\"\"\n        if not self.use_dynamic_bsz:\n            if train_batch_size < self.ppo_mini_batch_size:\n                raise ValueError(\n                    f\"train_batch_size ({train_batch_size}) must be >= \"\n                    f\"critic.ppo_mini_batch_size ({self.ppo_mini_batch_size})\"\n                )\n\n    @staticmethod\n    def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n        \"\"\"Validate mutually exclusive micro batch size configuration options.\n\n        Ensures that users don't set both deprecated micro_batch_size and\n        the new micro_batch_size_per_gpu parameters simultaneously.\n\n        Args:\n            mbs: Deprecated micro batch size parameter value.\n            mbs_per_gpu: New micro batch size per GPU parameter value.\n            name (str): Configuration section name for error messages.\n\n        Raises:\n            ValueError: If both parameters are set or neither is set.\n        \"\"\"\n        param = \"micro_batch_size\"\n        param_per_gpu = f\"{param}_per_gpu\"\n\n        if mbs is None and mbs_per_gpu is None:\n            raise ValueError(f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\")\n\n        if mbs is not None and mbs_per_gpu is not None:\n            raise ValueError(\n                f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. 
Please remove \"\n                f\"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).\"\n            )\n\n\n@dataclass\nclass McoreCriticConfig(CriticConfig):\n    \"\"\"Configuration for Megatron-based critic model training.\n\n    The inheritance from CriticConfig provides all base critic configuration plus Megatron-specific settings.\n\n    Args:\n        nccl_timeout (int): NCCL timeout in seconds for distributed operations.\n        megatron (Dict[str, Any]): Megatron-specific parallelism settings.\n        load_weight (bool): Whether to load initial weights.\n        data_loader_seed (Optional[int]): Seed for data loader.\n    \"\"\"\n\n    strategy: str = \"megatron\"\n    nccl_timeout: int = 600\n    megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig)\n    load_weight: bool = True\n    data_loader_seed: Optional[int] = None\n\n    def validate(self, n_gpus: int, train_batch_size: int):\n        \"\"\"Validate Megatron critic configuration with runtime parameters.\"\"\"\n        super().validate(n_gpus, train_batch_size)\n\n\n@dataclass\nclass FSDPCriticConfig(CriticConfig):\n    \"\"\"Configuration for FSDP-based critic model training.\n\n    The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.\n\n    Args:\n        forward_micro_batch_size (int): Forward-only batch size during inference (global).\n        forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).\n        ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.\n        grad_clip (float): Gradient clipping for critic updates.\n    \"\"\"\n\n    _mutable_fields = CriticConfig._mutable_fields | {\n        \"forward_micro_batch_size\",\n        \"forward_micro_batch_size_per_gpu\",\n    }\n\n    strategy: str = \"fsdp\"\n    forward_micro_batch_size: int = 1\n    forward_micro_batch_size_per_gpu: int = 1\n    ulysses_sequence_parallel_size: int = 1\n    grad_clip: float = 1.0\n\n    def __post_init__(self):\n        \"\"\"Validate FSDP critic configuration parameters.\"\"\"\n        super().__post_init__()\n\n        if self.strategy in {\"fsdp\", \"fsdp2\"}:\n            if self.ulysses_sequence_parallel_size > 1:\n                if not self.model.get(\"use_remove_padding\", False):\n                    raise ValueError(\n                        \"When using sequence parallelism for critic, you must enable `use_remove_padding`.\"\n                    )\n\n    def validate(self, n_gpus: int, train_batch_size: int):\n        \"\"\"Validate FSDP critic configuration with runtime parameters.\"\"\"\n        super().validate(n_gpus, train_batch_size)\n\n        if not self.use_dynamic_bsz:\n            sp_size = self.ulysses_sequence_parallel_size\n            if self.ppo_micro_batch_size is not None:\n                if self.ppo_micro_batch_size * sp_size < n_gpus:\n                    raise ValueError(\n                        f\"critic.ppo_micro_batch_size ({self.ppo_micro_batch_size}) * \"\n                        f\"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})\"\n                    )\n\n\n@dataclass\nclass FSDPCriticModelCfg(BaseModelConfig):\n    \"\"\"FSDP-enabled critic model configuration.\n    Inherits base critic settings and adds distributed-memory and LoRA options.\n\n    Args:\n        use_shm (bool): Whether to use shared memory for loading the model.\n        enable_activation_offload (bool): 
Offload activations to CPU to reduce GPU memory usage.\n        use_remove_padding (bool): Use remove-padding optimization (saves compute).\n        enable_gradient_checkpointing (bool): Enable gradient checkpointing for memory efficiency.\n        fsdp_config (FSDPEngineConfig): FSDP-specific configuration block.\n        lora_rank (int): Set to positive value to enable LoRA (e.g., 32).\n        lora_alpha (int): LoRA scaling factor.\n        target_modules (Union[str, List[str]]): LoRA target modules: \"all-linear\" or list of layer names.\n    \"\"\"\n\n    use_shm: bool = False\n    enable_activation_offload: bool = False\n    use_remove_padding: bool = False\n    enable_gradient_checkpointing: bool = True\n    fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)\n    lora_rank: int = 0\n    lora_alpha: int = 16\n    target_modules: str | list[str] = \"all-linear\"\n"
  },
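A minimal usage sketch (not part of the dumped repo) for the critic config family above; it assumes the `verl.workers.config.critic` module implied by the file path is importable, and exercises the mutually-exclusive micro-batch validation in `__post_init__`.

```python
# Hypothetical usage sketch for the CriticConfig family above.
from verl.workers.config.critic import FSDPCriticConfig

# With use_dynamic_bsz=False, exactly one of ppo_micro_batch_size /
# ppo_micro_batch_size_per_gpu must be provided.
cfg = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=2, ppo_mini_batch_size=8)
cfg.validate(n_gpus=8, train_batch_size=64)  # train_batch_size must be >= ppo_mini_batch_size

try:
    FSDPCriticConfig()  # neither micro-batch variant set
except ValueError as err:
    print(err)  # "[critic] Please set at least one of ..."
```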
  {
    "path": "verl_distillation/verl/workers/config/engine.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom verl.base_config import BaseConfig\n\n__all__ = [\"FSDPEngineConfig\", \"McoreEngineConfig\"]\n\n\n@dataclass\nclass McoreEngineConfig(BaseConfig):\n    \"\"\"Configuration for Megatron parallelism.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        param_offload (bool): Whether to offload parameters to CPU.\n        grad_offload (bool): Whether to offload gradients to CPU.\n        optimizer_offload (bool): Whether to offload optimizer states to CPU.\n        tensor_model_parallel_size (int): Tensor model parallel size.\n        expert_model_parallel_size (int): Expert model parallel size for MoE models.\n        expert_tensor_parallel_size (Optional[int]): Expert tensor parallel size for MoE models.\n        pipeline_model_parallel_size (int): Pipeline model parallel size.\n        virtual_pipeline_model_parallel_size (Optional[int]): Virtual pipeline model parallel size\n            for interleaved scheduling.\n        context_parallel_size (int): Context parallel size for long sequences.\n        sequence_parallel (bool): Whether to enable sequence parallelism.\n        use_distributed_optimizer (bool): Whether to use distributed optimizer.\n        use_dist_checkpointing (bool): Whether to use distributed checkpointing.\n        dist_checkpointing_path (Optional[str]): Path for distributed checkpointing.\n        seed (int): Random seed for reproducibility.\n        override_ddp_config (dict[str, Any]): Override configuration for DDP.\n        override_transformer_config (dict[str, Any]): Override configuration for transformer.\n        use_mbridge (bool): Whether to use MBridge for communication.\n    \"\"\"\n\n    # sequence_parallel is not listed as a frozen field for auto-correction purpose\n    _mutable_fields = BaseConfig._mutable_fields | {\"sequence_parallel\"}\n\n    param_offload: bool = False\n    grad_offload: bool = False\n    optimizer_offload: bool = False\n    tensor_model_parallel_size: int = 1\n    expert_model_parallel_size: int = 1\n    expert_tensor_parallel_size: Optional[int] = None\n    pipeline_model_parallel_size: int = 1\n    virtual_pipeline_model_parallel_size: Optional[int] = None\n    context_parallel_size: int = 1\n    sequence_parallel: bool = True\n    use_distributed_optimizer: bool = True\n    use_dist_checkpointing: bool = False\n    dist_checkpointing_path: Optional[str] = None\n    seed: int = 42\n    override_ddp_config: dict[str, Any] = field(default_factory=dict)\n    override_transformer_config: dict[str, Any] = field(default_factory=dict)\n    override_mcore_model_config: dict[str, Any] = field(default_factory=dict)\n    use_mbridge: bool = False\n    forward_only: bool = False\n    strategy: str = \"megatron\"\n\n    def __post_init__(self) 
-> None:\n        \"\"\"Config validation logic goes here.\"\"\"\n        assert self.strategy == \"megatron\"\n        if self.tensor_model_parallel_size == 1:\n            warnings.warn(\"setting sequence_parallel to False as TP size is 1\", stacklevel=2)\n            self.sequence_parallel = False\n\n\n@dataclass\nclass FSDPEngineConfig(BaseConfig):\n    \"\"\"Configuration for FSDP (Fully Sharded Data Parallel).\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy.\n        param_offload (bool): Whether to offload parameters to CPU, default False\n        optimizer_offload (bool): Whether to offload optimizer states to CPU, default False\n        offload_policy (bool): Whether to offload policy model parameters, default False\n        reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True\n        fsdp_size (int): FSDP group size. -1 means use all available GPUs.\n        forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False\n        model_dtype (str): Model data type used to initialize the transformers model. default \"fp32\"\n        use_orig_params (bool): Whether to use original parameters when initializing FSDP1, default False\n        mixed_precision (Optional[dict[str, Any]]): Mixed precision configuration for FSDP, default None\n    \"\"\"\n\n    wrap_policy: dict[str, Any] = field(default_factory=dict)\n    param_offload: bool = False\n    optimizer_offload: bool = False\n    offload_policy: bool = False\n    reshard_after_forward: bool = True\n    fsdp_size: int = -1\n    forward_prefetch: bool = False\n    model_dtype: str = \"fp32\"\n    use_orig_params: bool = False\n    mixed_precision: Optional[dict[str, Any]] = None\n    ulysses_sequence_parallel_size: int = 1\n    entropy_from_logits_with_chunking: bool = False\n    use_torch_compile: bool = True\n    entropy_checkpointing: bool = False\n    forward_only: bool = False\n    strategy: str = \"fsdp\"\n\n    def __post_init__(self):\n        assert self.strategy in [\"fsdp\", \"fsdp2\"], f\"strategy {self.strategy} not supported\"\n"
  },
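A small sketch (assuming the `verl.workers.config.engine` module above is importable) of the auto-correction in `McoreEngineConfig.__post_init__`: Megatron sequence parallelism requires TP > 1, so the config disables it otherwise.

```python
# Hypothetical usage sketch for the engine configs above.
from verl.workers.config.engine import FSDPEngineConfig, McoreEngineConfig

mcore = McoreEngineConfig(tensor_model_parallel_size=1)
assert mcore.sequence_parallel is False  # __post_init__ warned and flipped it off

mcore_tp2 = McoreEngineConfig(tensor_model_parallel_size=2)
assert mcore_tp2.sequence_parallel is True  # kept, since TP > 1

fsdp = FSDPEngineConfig(param_offload=True, fsdp_size=-1)  # -1: shard across all GPUs
```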
  {
    "path": "verl_distillation/verl/workers/config/model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom omegaconf import MISSING\nfrom transformers import AutoConfig\n\nfrom verl.base_config import BaseConfig\nfrom verl.utils import hf_processor, hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import get_generation_config, update_model_config\n\n__all__ = [\"HFModelConfig\"]\n\n\n@dataclass\nclass HFModelConfig(BaseConfig):\n    # note that we separate model_path, model_config_path and tokenizer_path in case they are different\n    _mutable_fields = {\n        \"hf_config_path\",\n        \"tokenizer_path\",\n        \"hf_config\",\n        \"generation_config\",\n        \"tokenizer\",\n        \"processor\",\n        \"local_path\",\n        \"architectures\",\n        \"local_hf_config_path\",\n        \"local_tokenizer_path\",\n    }\n\n    path: str = MISSING\n    local_path: Optional[str] = None\n    hf_config_path: Optional[str] = None\n    local_hf_config_path: Optional[str] = None\n    tokenizer_path: Optional[str] = None\n    local_tokenizer_path: Optional[str] = None\n\n    # whether to load tokenizer. This is useful when we only want to load model config\n    load_tokenizer: bool = True\n\n    hf_config: Any = None\n    generation_config: Any = None\n    tokenizer: Any = None\n    processor: Any = None\n\n    # whether to use shared memory\n    use_shm: bool = False\n    trust_remote_code: bool = False\n\n    # custom chat template for the model\n    custom_chat_template: Optional[str] = None\n\n    external_lib: Optional[str] = None\n\n    override_config: dict = field(default_factory=dict)\n\n    enable_gradient_checkpointing: bool = True\n    enable_activation_offload: bool = False\n\n    use_remove_padding: bool = False\n\n    # lora related. 
We may set up a separate config later\n    lora_rank: int = 0\n    lora_alpha: int = 16\n    target_modules: Optional[str] = \"all-linear\"\n\n    exclude_modules: Optional[str] = None\n\n    # path to pre-trained LoRA adapter to load for continued training\n    lora_adapter_path: Optional[str] = None\n    use_liger: bool = False\n\n    use_fused_kernels: bool = False\n    fused_kernel_options: dict = field(default_factory=dict)\n\n    architectures: Optional[list[str]] = None\n\n    def __post_init__(self):\n        import_external_libs(self.external_lib)\n\n        if self.hf_config_path is None:\n            self.hf_config_path = self.path\n        if self.tokenizer_path is None:\n            self.tokenizer_path = self.path\n\n        self.local_path = copy_to_local(self.path, use_shm=self.use_shm)\n\n        # construct tokenizer\n        if self.load_tokenizer:\n            self.local_tokenizer_path = copy_to_local(self.tokenizer_path, use_shm=self.use_shm)\n            self.tokenizer = hf_tokenizer(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)\n            self.processor = hf_processor(self.local_tokenizer_path, trust_remote_code=self.trust_remote_code)\n\n        if self.custom_chat_template is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.custom_chat_template\n\n        self.local_hf_config_path = copy_to_local(self.hf_config_path, use_shm=self.use_shm)\n        self.generation_config = get_generation_config(\n            self.local_hf_config_path, trust_remote_code=self.trust_remote_code\n        )\n\n        # construct hf_config\n        attn_implementation = self.override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        self.hf_config = AutoConfig.from_pretrained(\n            self.local_hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation\n        )\n\n        override_config_kwargs = {}\n\n        if self.tokenizer is not None:\n            override_config_kwargs.update(\n                {\n                    \"bos_token_id\": self.tokenizer.bos_token_id,\n                    \"eos_token_id\": self.tokenizer.eos_token_id,\n                    \"pad_token_id\": self.tokenizer.pad_token_id,\n                }\n            )\n\n        # TODO: (vermouth1992). self.config.model in megatron differs from that of fsdp in the override_config.\n        override_config = (\n            self.override_config[\"model_config\"] if \"model_config\" in self.override_config else self.override_config\n        )\n        override_config_kwargs.update(override_config)\n        update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)\n\n        self.share_embeddings_and_output_weights = getattr(self.hf_config, \"tie_word_embeddings\", False)\n\n        # get model architectures\n        self.architectures = getattr(self.hf_config, \"architectures\", None)\n        assert self.architectures is not None and len(self.architectures) == 1, (\n            \"Expected exactly one architecture, got {}\".format(self.architectures)\n        )\n\n        # per model patch\n        if getattr(self.hf_config, \"model_type\", None) == \"kimi_vl\":\n            self.hf_config.text_config.topk_method = \"greedy\"\n\n    def get_processor(self):\n        return self.processor if self.processor is not None else self.tokenizer\n"
  },
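Because `HFModelConfig.__post_init__` eagerly copies the checkpoint, builds the tokenizer/processor, and loads the HF config, instantiating it requires a real model directory. A sketch under that assumption (the path is a placeholder, not from the repo):

```python
# Hypothetical usage sketch for HFModelConfig above.
from verl.workers.config.model import HFModelConfig

model_cfg = HFModelConfig(
    path="/path/to/local/checkpoint",  # placeholder: any local HF checkpoint
    load_tokenizer=True,               # set False to load only the model config
    override_config={"attn_implementation": "eager"},  # overrides the flash_attention_2 default
)
print(model_cfg.architectures)    # exactly one architecture is asserted in __post_init__
print(model_cfg.get_processor())  # processor when multimodal, otherwise the tokenizer
```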
  {
    "path": "verl_distillation/verl/workers/config/optimizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nfrom omegaconf import MISSING\n\nfrom verl.base_config import BaseConfig\n\n__all__ = [\"OptimizerConfig\", \"FSDPOptimizerConfig\", \"McoreOptimizerConfig\", \"build_optimizer\"]\n\n\n@dataclass\nclass OptimizerConfig(BaseConfig):\n    \"\"\"Base optimizer configuration.\n\n    Args:\n        lr (float): learning rate. Must be specified.\n        lr_warmup_steps_ratio (float): Warmup steps ratio; total steps will be injected at runtime.\n        total_training_steps (int): Total training steps (must be overridden at runtime).\n        weight_decay (float): Weight decay factor.\n        lr_warmup_steps (Optional[int]): Number of warmup steps; None delegates to lr_warmup_steps_ratio.\n    \"\"\"\n\n    _mutable_fields = {\"clip_grad\", \"total_training_steps\", \"lr_warmup_steps\"}\n\n    lr: float = 1e-3\n    lr_warmup_steps_ratio: float = 0.0\n    total_training_steps: int = -1\n    weight_decay: float = 0.01\n    lr_warmup_steps: Optional[int] = -1\n    betas: tuple[float, float] = (0.9, 0.999)\n    clip_grad: float = 1.0\n    # deprecate grad_clip\n    grad_clip: Optional[float] = None\n\n    def __post_init__(self):\n        assert self.lr != MISSING\n        if self.grad_clip is not None:\n            warnings.warn(\"`grad_clip` is deprecated, use `clip_grad` instead.\", DeprecationWarning, stacklevel=2)\n            self.clip_grad = self.grad_clip\n\n\n@dataclass\nclass FSDPOptimizerConfig(OptimizerConfig):\n    \"\"\"FSDP optimizer configuration extending base OptimizerConfig.\n\n    Args:\n        optimizer (str): Optimizer class name (e.g., \"AdamW\", \"AdamW8bit\", \"_AdamW\").\n        optimizer_impl (str): Module path to import optimizer from (e.g., \"torch.optim\", \"torchao.optim\",\n            \"bitsandbytes.optim\").\n        lr (float): Learning rate.\n        min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.\n        lr_scheduler_type (str): LR scheduler type: \"constant\" or \"cosine\".\n        num_cycles (float): Number of cosine cycles in LR schedule.\n    \"\"\"\n\n    _mutable_fields = OptimizerConfig._mutable_fields.copy()\n    _mutable_fields.add(\"lr_scheduler_type\")\n\n    optimizer: str = \"AdamW\"\n    optimizer_impl: str = \"torch.optim\"\n    min_lr_ratio: Optional[float] = None\n    # deprecate warmup_style\n    warmup_style: Optional[str] = None\n    lr_scheduler_type: str = \"constant\"\n    num_cycles: float = 0.5\n    override_optimizer_config: Optional[dict] = None\n\n    def __post_init__(self):\n        if self.warmup_style is not None:\n            assert self.warmup_style in [\"constant\", \"cosine\"]\n            warnings.warn(\n                \"`warmup_style` is deprecated, use `lr_scheduler_type` instead.\", DeprecationWarning, stacklevel=2\n            )\n            self.lr_scheduler_type = self.warmup_style\n        
assert self.lr_scheduler_type in [\"constant\", \"cosine\"]\n        return super().__post_init__()\n\n\n@dataclass\nclass McoreOptimizerConfig(OptimizerConfig):\n    \"\"\"Mcore optimizer configuration extending base OptimizerConfig.\n\n    Args:\n        optimizer (str): Optimizer name; default is \"adam\".\n        lr (float): Learning rate.\n        clip_grad (float): Gradient clipping norm.\n        lr_warmup_init (float): Initial learning rate for warmup; defaults to 0.0.\n        lr_decay_steps (Optional[int]): Number of decay steps.\n        lr_decay_style (str): LR decay style: \"constant\", \"linear\", \"cosine\", or \"inverse_square_root\".\n        min_lr (float): Minimum learning rate.\n        weight_decay_incr_style (str): Weight decay increment style: \"constant\" or \"cosine\".\n        lr_wsd_decay_style (str): Decay style for the warmup-stable-decay (WSD) schedule: \"constant\", \"exponential\", or \"cosine\".\n        lr_wsd_decay_steps (Optional[int]): Number of decay steps for the warmup-stable-decay (WSD) schedule.\n        use_checkpoint_opt_param_scheduler (bool): Whether to use checkpoint optimizer parameter scheduler.\n    \"\"\"\n\n    optimizer: str = \"adam\"\n    lr_warmup_init: float = 0.0\n    lr_decay_steps: Optional[int] = None\n    lr_decay_style: str = \"linear\"\n    min_lr: float = 0.0\n    weight_decay_incr_style: str = \"constant\"\n    lr_wsd_decay_style: str = \"exponential\"\n    lr_wsd_decay_steps: Optional[int] = None\n    use_checkpoint_opt_param_scheduler: bool = False\n    override_optimizer_config: Optional[dict] = None\n\n\ndef build_optimizer(parameters, config: FSDPOptimizerConfig):\n    \"\"\"Build an optimizer based on the configuration.\n\n    Dynamically imports and instantiates an optimizer class from the specified module.\n\n    Args:\n        parameters: Model parameters to optimize\n        config: FSDPOptimizerConfig with optimizer settings\n\n    Returns:\n        Optimizer instance\n\n    Examples:\n        # PyTorch AdamW\n        config.optimizer_impl = \"torch.optim\"\n        config.optimizer = \"AdamW\"\n\n        # TorchAO AdamW with bf16 stochastic rounding\n        config.optimizer_impl = \"torchao.optim\"\n        config.optimizer = \"_AdamW\"\n        config.override_optimizer_config = {\"bf16_stochastic_round\": True}\n\n        # BitsAndBytes AdamW 8bit\n        config.optimizer_impl = \"bitsandbytes.optim\"\n        config.optimizer = \"AdamW8bit\"\n    \"\"\"\n    import importlib\n\n    optimizer_args = {\n        \"lr\": config.lr,\n        \"weight_decay\": config.weight_decay,\n    }\n\n    optimizer_name_lower = config.optimizer.lower()\n    if \"adam\" in optimizer_name_lower or \"ademamix\" in optimizer_name_lower:\n        optimizer_args[\"betas\"] = config.betas\n\n    if config.override_optimizer_config is not None:\n        optimizer_args.update(config.override_optimizer_config)\n\n    try:\n        module = importlib.import_module(config.optimizer_impl)\n        optimizer_cls = getattr(module, config.optimizer)\n    except ImportError as e:\n        raise ImportError(\n            f\"Failed to import module '{config.optimizer_impl}'. Make sure the package is installed. Error: {e}\"\n        ) from e\n    except AttributeError as e:\n        raise AttributeError(\n            f\"Optimizer '{config.optimizer}' not found in module '{config.optimizer_impl}'. \"\n            f\"Available optimizers: {dir(module)}\"\n        ) from e\n\n    return optimizer_cls(parameters, **optimizer_args)\n"
  },
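The `build_optimizer` docstring above already lists the supported combinations; the sketch below wires one of them end to end with a toy module (nothing beyond `torch.optim` is required):

```python
# Hypothetical usage sketch for build_optimizer above.
import torch.nn as nn

from verl.workers.config.optimizer import FSDPOptimizerConfig, build_optimizer

config = FSDPOptimizerConfig(
    lr=1e-5,
    weight_decay=0.01,
    optimizer_impl="torch.optim",
    optimizer="AdamW",
    lr_scheduler_type="cosine",
)
model = nn.Linear(16, 4)
optimizer = build_optimizer(model.parameters(), config)
print(type(optimizer).__name__)  # AdamW; betas were forwarded since "adam" is in the name
```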
  {
    "path": "verl_distillation/verl/workers/config/reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom verl.base_config import BaseConfig\n\nfrom .model import HFModelConfig\nfrom .rollout import RolloutConfig\n\n__all__ = [\"SandboxFusionConfig\", \"RewardModelConfig\"]\n\n\n@dataclass\nclass SandboxFusionConfig(BaseConfig):\n    \"\"\"Configuration for cloud/local sandbox fusion.\n\n    Args:\n        url (Optional[str]): Cloud/local function URL for sandbox execution.\n        max_concurrent (int): Max concurrent requests allowed to sandbox.\n        memory_limit_mb (int): Max memory limit for each sandbox process in MB.\n    \"\"\"\n\n    url: Optional[str] = None\n    max_concurrent: int = 64\n    memory_limit_mb: int = 1024\n\n\n@dataclass\nclass RewardModelConfig(BaseConfig):\n    _mutable_fields = BaseConfig._mutable_fields\n\n    reward_manager: str = \"naive\"\n\n    enable: bool = False\n    enable_resource_pool: bool = False\n    n_gpus_per_node: int = 0\n    nnodes: int = 0\n\n    # reward model args\n    rollout: RolloutConfig = field(default_factory=RolloutConfig)\n    model: HFModelConfig = field(default_factory=HFModelConfig)\n    sandbox_fusion: SandboxFusionConfig = field(default_factory=SandboxFusionConfig)\n"
  },
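A brief sketch (module path assumed from the file above; the URL is a placeholder). Note that `RewardModelConfig` embeds an `HFModelConfig`, whose `path` must point at a real checkpoint before its default factory can run:

```python
# Hypothetical usage sketch for SandboxFusionConfig above.
from verl.workers.config.reward_model import SandboxFusionConfig

sandbox = SandboxFusionConfig(
    url="http://127.0.0.1:8080/run_code",  # placeholder sandbox-fusion endpoint
    max_concurrent=32,                     # cap on in-flight sandbox requests
    memory_limit_mb=2048,                  # per-process memory ceiling
)
```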
  {
    "path": "verl_distillation/verl/workers/config/rollout.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom omegaconf import MISSING\n\nfrom verl.base_config import BaseConfig\nfrom verl.utils.profiler import ProfilerConfig\n\n__all__ = [\n    \"SamplingConfig\",\n    \"MultiTurnConfig\",\n    \"CustomAsyncServerConfig\",\n    \"AgentLoopConfig\",\n    \"TraceConfig\",\n    \"ServerConfig\",\n    \"RolloutConfig\",\n]\n\n\n@dataclass\nclass SamplingConfig(BaseConfig):\n    temperature: float = 1.0\n    top_k: int = -1\n    top_p: float = 1.0\n    do_sample: bool = True\n    n: int = 1\n\n\n@dataclass\nclass MultiTurnConfig(BaseConfig):\n    _mutable_fields = {\"max_assistant_turns\", \"max_user_turns\"}\n\n    enable: bool = False\n    max_assistant_turns: Optional[int] = None\n    tool_config_path: Optional[str] = None\n    max_user_turns: Optional[int] = None\n    max_parallel_calls: int = 1\n    max_tool_response_length: int = 256\n    tool_response_truncate_side: str = \"middle\"\n    interaction_config_path: Optional[str] = None\n    use_inference_chat_template: bool = False\n    tokenization_sanity_check_mode: str = \"strict\"\n    format: str = \"hermes\"\n    num_repeat_rollouts: Optional[int] = None\n\n\n@dataclass\nclass CustomAsyncServerConfig(BaseConfig):\n    path: Optional[str] = None\n    name: Optional[str] = None\n\n\n@dataclass\nclass AgentLoopConfig(BaseConfig):\n    num_workers: int = 8\n    default_agent_loop: str = \"single_turn_agent\"\n    agent_loop_config_path: Optional[str] = None\n    custom_async_server: CustomAsyncServerConfig = field(default_factory=CustomAsyncServerConfig)\n\n\n@dataclass\nclass TraceConfig(BaseConfig):\n    backend: Optional[str] = None\n    token2text: bool = False\n\n\n@dataclass\nclass ServerConfig(BaseConfig):\n    \"\"\"\n    Configuration for SGLang server when running in server mode\n    \"\"\"\n\n    timeout: float = 60.0\n    max_attempts: int = 3\n    retry_delay: float = 2.0\n    max_connections: int = 1000\n    max_start_wait_time: float = 300.0\n\n\n@dataclass\nclass RolloutConfig(BaseConfig):\n    _mutable_fields = {\"max_model_len\", \"load_format\"}\n\n    name: Optional[str] = MISSING\n    mode: str = \"sync\"\n    skip_tokenizer_init: bool = True\n\n    temperature: float = 1.0\n    top_k: int = -1\n    top_p: float = 1.0\n    do_sample: bool = True\n    n: int = 1\n\n    # Early termination threshold for multi-turn rollout in sglang.\n    # Abort remaining requests when (1 - over_sample_rate) * total_requests are completed.\n    over_sample_rate: float = 0.0\n\n    prompt_length: int = 512\n    response_length: int = 512\n\n    dtype: str = \"bfloat16\"\n    gpu_memory_utilization: float = 0.5\n    ignore_eos: bool = False\n    enforce_eager: bool = True\n    cudagraph_capture_sizes: Optional[list] = None\n    free_cache_engine: bool = True\n    data_parallel_size: int = 1\n    expert_parallel_size: int = 1\n    
tensor_model_parallel_size: int = 2\n    pipeline_model_parallel_size: int = 1\n    max_num_batched_tokens: int = 8192\n\n    # TODO: enable train_kwargs\n    # train_sampling_config: SamplingConfig = field(default_factory=SamplingConfig)\n\n    val_kwargs: SamplingConfig = field(default_factory=SamplingConfig)\n\n    max_model_len: Optional[int] = None\n    max_num_seqs: int = 1024\n\n    # note that the logprob computation should belong to the actor\n    log_prob_micro_batch_size: Optional[int] = None\n    log_prob_micro_batch_size_per_gpu: Optional[int] = None\n    log_prob_use_dynamic_bsz: bool = False\n    log_prob_max_token_len_per_gpu: int = 16384\n\n    disable_log_stats: bool = True\n\n    multi_stage_wake_up: bool = False\n    engine_kwargs: dict = field(default_factory=dict)\n\n    calculate_log_probs: bool = False\n\n    extend_vocab_start_token: Optional[int] = None\n\n    mask_response_if_have_extend_token: bool = False\n\n    agent: AgentLoopConfig = field(default_factory=AgentLoopConfig)\n\n    trace: TraceConfig = field(default_factory=TraceConfig)\n\n    multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)\n\n    # Server configuration for sglang server mode\n    server: ServerConfig = field(default_factory=ServerConfig)\n\n    update_weights_bucket_megabytes: int = 512\n\n    skip_rollout: bool = False\n\n    skip_dump_dir: str = \"/tmp/rollout_dump\"\n\n    profiler: Optional[ProfilerConfig] = None\n\n    enable_chunked_prefill: bool = True\n\n    enable_prefix_caching: bool = True\n\n    load_format: str = \"dummy\"\n\n    layered_summon: bool = False\n\n    layer_name_map: dict = field(default_factory=dict)\n\n    sglang_engine_mode: str = \"local\"\n\n    limit_images: Optional[int] = None\n\n    skip_tokenizer_init: bool = False\n\n    def __post_init__(self):\n        \"\"\"Validate the rollout config\"\"\"\n        if self.expert_parallel_size > 1:\n            assert self.expert_parallel_size == (self.tensor_model_parallel_size * self.data_parallel_size), (\n                \"expert_parallel_size must be equal to tensor_model_parallel_size * data_parallel_size\"\n            )\n\n        if self.pipeline_model_parallel_size > 1:\n            if self.name == \"vllm\" or self.name == \"sglang\":\n                raise NotImplementedError(\n                    f\"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet.\"\n                )\n"
  },
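A sketch of the parallelism invariant enforced by `RolloutConfig.__post_init__` (module path assumed from the file above): expert parallelism must exactly cover the TP x DP grid.

```python
# Hypothetical usage sketch for RolloutConfig above.
from verl.workers.config.rollout import RolloutConfig

ok = RolloutConfig(
    name="vllm",
    tensor_model_parallel_size=2,
    data_parallel_size=2,
    expert_parallel_size=4,  # must equal TP (2) * DP (2)
)

try:
    RolloutConfig(name="vllm", tensor_model_parallel_size=2, expert_parallel_size=3)
except AssertionError as err:
    print(err)  # expert_parallel_size must be equal to TP * DP
```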
  {
    "path": "verl_distillation/verl/workers/critic/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOCritic\nfrom .dp_critic import DataParallelPPOCritic\n\n__all__ = [\"BasePPOCritic\", \"DataParallelPPOCritic\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/critic/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nBase class for a critic\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\nfrom verl import DataProto\n\n__all__ = [\"BasePPOCritic\"]\n\n\nclass BasePPOCritic(ABC):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_values(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute values\"\"\"\n        pass\n\n    @abstractmethod\n    def update_critic(self, data: DataProto):\n        \"\"\"Update the critic\"\"\"\n        pass\n"
  },
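A toy subclass (not in the repo) showing the minimal surface `BasePPOCritic` demands; the constant-zero values are purely illustrative.

```python
# Hypothetical subclass sketch for BasePPOCritic above.
import torch

from verl import DataProto
from verl.workers.critic import BasePPOCritic


class ZeroCritic(BasePPOCritic):
    """Toy critic that assigns a value of 0 to every response token."""

    def compute_values(self, data: DataProto) -> torch.Tensor:
        responses = data.batch["responses"]
        return torch.zeros_like(responses, dtype=torch.float32)

    def update_critic(self, data: DataProto):
        return {}  # nothing to learn


critic = ZeroCritic(config=None)
```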
  {
    "path": "verl_distillation/verl/workers/critic/dp_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom torch import nn, optim\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom verl import DataProto\nfrom verl.trainer.ppo import core_algos\nfrom verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input\nfrom verl.utils.device import get_device_id, get_device_name\nfrom verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\nfrom verl.workers.critic import BasePPOCritic\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass DataParallelPPOCritic(BasePPOCritic):\n    def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):\n        super().__init__(config=config)\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        print(f\"Critic use_remove_padding={self.use_remove_padding}\")\n\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(self, micro_batch):\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            from verl.utils.model import extract_multi_modal_inputs\n\n            multi_modal_inputs = extract_multi_modal_inputs(micro_batch[\"multi_modal_inputs\"])\n\n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... 
-> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.critic_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n\n                if hasattr(self.critic_module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    values_rmpad = output[2].squeeze(0).unsqueeze(-1)\n                else:\n                    values_rmpad = output.logits\n                    values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    values_rmpad = gather_outputs_and_unpad(\n                        values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)\n                values = values[:, -response_length - 1 : -1]\n            else:\n                output = self.critic_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n                if hasattr(self.critic_module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    values = output[2]\n                else:\n                    values = output.logits\n                values = values[:, -response_length - 1 : -1].squeeze(-1)\n            return values\n\n    def _optimizer_step(self):\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.critic_module, FSDP):\n            grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)\n        elif isinstance(self.critic_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.critic_optimizer.zero_grad()\n        else:\n            self.critic_optimizer.step()\n        return grad_norm\n\n    @GPUMemoryLogger(role=\"dp critic\", 
logger=logger)\n    def compute_values(self, data: DataProto) -> torch.Tensor:\n        self.critic_module.eval()\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        select_keys = (\n            [\"responses\", \"input_ids\", \"response_mask\", \"attention_mask\", \"position_ids\"]\n            if \"response_mask\" in data.batch\n            else [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        )\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        if use_dynamic_bsz:\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)\n        else:\n            micro_batches = data.split(micro_batch_size)\n\n        values_lst = []\n        for micro_batch in micro_batches:\n            micro_batch = micro_batch.to(get_device_id())\n            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n            with torch.no_grad():\n                values = self._forward_micro_batch(model_inputs)\n            values_lst.append(values)\n        values = torch.concat(values_lst, dim=0)\n\n        if use_dynamic_bsz:\n            values = restore_dynamic_batch(values, batch_idx_list)\n\n        if \"response_mask\" in data.batch:\n            response_mask = data.batch[\"response_mask\"]\n            response_mask = response_mask.to(values.device)\n            values = values * response_mask  # Only action tokens have values\n        return values\n\n    @GPUMemoryLogger(role=\"dp critic\", logger=logger)\n    def update_critic(self, data: DataProto):\n        # make sure we are in training mode\n        self.critic_module.train()\n        metrics = {}\n\n        select_keys = [\"input_ids\", \"responses\", \"response_mask\", \"attention_mask\", \"position_ids\", \"values\", \"returns\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        # Split to make a minibatch iterator for updating the critic\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        mini_batches = data.split(self.config.ppo_mini_batch_size)\n\n        for _ in range(self.config.ppo_epochs):\n            for batch_idx, mini_batch in enumerate(mini_batches):\n                if self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.critic_optimizer.zero_grad()\n\n                for micro_batch in micro_batches:\n                    micro_batch = micro_batch.to(get_device_id())\n                    micro_batch_metrics = {}\n                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n                    response_mask = model_inputs[\"response_mask\"]\n                    values = model_inputs[\"values\"]\n                    returns = model_inputs[\"returns\"]\n\n                    vpreds = self._forward_micro_batch(model_inputs)\n                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                        vpreds=vpreds,\n                        values=values,\n                        returns=returns,\n                        response_mask=response_mask,\n                        cliprange_value=self.config.cliprange_value,\n                        loss_agg_mode=self.config.loss_agg_mode,\n                    )\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size\n                        loss = vf_loss * loss_scale_factor\n                    else:\n                        loss_scale_factor = 1 / self.gradient_accumulation\n                        loss = vf_loss * loss_scale_factor\n\n                    loss.backward()\n\n                    micro_batch_metrics.update(\n                        {\n                            \"critic/vf_loss\": vf_loss.detach().item() * loss_scale_factor,\n                            \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                            \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n                        }\n                    )\n\n                    append_to_dict(metrics, micro_batch_metrics)\n\n                grad_norm = self._optimizer_step()\n                mini_batch_metrics = {\"critic/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, mini_batch_metrics)\n        self.critic_optimizer.zero_grad()\n        return metrics\n"
  },
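The `[:, -response_length - 1 : -1]` slice used in both branches of `_forward_micro_batch` deserves a concrete picture: the value for response token t is read from the position of the token preceding it (the state before the action), so the window is shifted left by one. A self-contained illustration with plain tensors:

```python
# Standalone illustration of the value-slicing convention used above.
import torch

prompt_length, response_length = 3, 4
seqlen = prompt_length + response_length

# pretend per-position value predictions for a batch of one (position index as value)
values_full = torch.arange(seqlen, dtype=torch.float32).unsqueeze(0)  # (1, 7)

values = values_full[:, -response_length - 1 : -1]
# positions 2..5: the last prompt token through the second-to-last response token,
# i.e. one step before each of the 4 response tokens
print(values)  # tensor([[2., 3., 4., 5.]])
```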
  {
    "path": "verl_distillation/verl/workers/critic/megatron_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Iterable\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.optimizer import DistributedOptimizer, OptimizerConfig\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.trainer.ppo import core_algos\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor, masked_mean\nfrom verl.workers.critic import BasePPOCritic\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MegatronPPOCritic(BasePPOCritic):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        critic_module: nn.ModuleList,\n        critic_optimizer: DistributedOptimizer,\n        critic_optimizer_config: OptimizerConfig,\n    ):\n        super().__init__(config=config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config  # huggingface config\n        self.tf_config = tf_config  # mcore transformer config\n\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.critic_optimizer_config = critic_optimizer_config\n\n        # we create a separate nametuple for optimizer step so that global args won't affect it.\n        self.optimizer_step_args = OmegaConf.create(\n            {\n                \"skip_grad\": None,\n                \"overlap_dp_param_comm\": False,\n                \"overlap_dp_grad_comm\": False,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_parallel\": self.tf_config.sequence_parallel,\n                \"DDP_impl\": \"local\",\n                \"layernorm_allreduce_bucket_threshold\": 0,\n                \"pipeline_model_parallel_split_rank\": None,\n                \"reduce_grads_use_alltoall\": False,\n            }\n        )\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.get(\"ulysses_sequence_parallel_size\", 1) == 1\n        if config.shuffle:\n            assert config.data_loader_seed is not None, \"If shuffle dataloader, seed must be manually set\"\n        self.config = config\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def compute_values(self, 
data: DataProto) -> torch.Tensor:\n        responses = data.batch[\"responses\"]\n        attention_mask = data.batch[\"attention_mask\"]\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n        response_length = responses.size(1)\n        with torch.no_grad():\n            output = self.forward_backward_batch(\n                data=data,\n                forward_only=True,\n                use_dynamic_bsz=use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=None,\n            )\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                # only available on the last pp rank; every tp rank holds the same output\n                values = [o[\"vpreds\"] for o in output[\"output\"]]  # list of (bs, seqlen) value predictions\n                values = torch.cat(values, dim=0).to(torch.float32)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == values.size(0), f\"{len(indices)} vs. {values.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    values = values[revert_indices]\n            else:\n                values = torch.empty_like(attention_mask, dtype=torch.float32)\n\n            # each tp rank should contain the same values\n            values = values[\n                :, -response_length - 1 : -1\n            ]  # Values are predicted at the ends of prefixes, e.g., the last prompt token\n            response_mask = attention_mask[:, -response_length:]\n            values = values * response_mask  # Only action tokens have values\n            values = values.contiguous()\n\n            # sync among pp ranks\n            values = values.to(get_device_id())\n            torch.distributed.broadcast(\n                tensor=values,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n            )\n            values = values.to(\"cpu\")\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return values\n\n    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"values\", \"returns\"]\n        data = data.select(batch_keys=select_keys)\n        return data.make_iterator(\n            mini_batch_size=self.config.ppo_mini_batch_size,\n            epochs=self.config.ppo_epochs,\n            seed=self.config.data_loader_seed,\n            dataloader_kwargs={\"shuffle\": self.config.shuffle},\n        )\n\n    def forward_backward_batch(\n        self,\n        data: DataProto,\n        forward_only=False,\n        use_dynamic_bsz=False,\n        micro_batch_size=None,\n        max_token_len=None,\n        mini_batch_size=None,\n    ):\n        # broadcast from last pp rank to 
all other pp ranks\n        data.to(get_device_id())\n        mini_batch = data\n        mini_batch.batch = mini_batch.batch.contiguous()\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n        mini_batch.to(\"cpu\")\n        # split into micro-batches\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"the number of micro_batches ({len(micro_batches)}) must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size must be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data, meta_info):\n            nonlocal use_dynamic_bsz\n\n            if forward_only:\n                return torch.tensor(1.0, device=output.device), {\"vpreds\": output}\n\n            responses = data[\"responses\"]\n            attention_mask = data[\"attention_mask\"]\n            values = data[\"values\"]\n            returns = data[\"returns\"]\n            response_length = responses.size(1)\n\n            response_mask = attention_mask[:, -response_length:]\n\n            cliprange_value = self.config.cliprange_value\n\n            vpreds = output  # (bs, sequence_length)\n            vpreds = vpreds[:, -response_length - 1 : -1]\n\n            vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                vpreds=vpreds,\n                values=values,\n                returns=returns,\n                response_mask=response_mask,\n                cliprange_value=cliprange_value,\n                loss_agg_mode=self.config.loss_agg_mode,\n            )\n\n            stats = {\n                \"critic/vf_loss\": vf_loss.detach().item(),\n                \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n            }\n\n            return vf_loss, stats\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            batch = 
            batch = batch.to(get_device_id())\n            batch = batch.contiguous()\n\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from verl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                {},  # multi_modal_inputs\n                value_model=True,\n            )\n\n            return output, partial(loss_func, data=batch, meta_info={})\n\n        # wrap the micro-batches into one iterator per virtual pipeline stage (model chunk)\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # not used when input_shapes is set\n                micro_batch_size=1,  # not used when input_shapes is set\n                forward_only=forward_only,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # used when pp = 1\n                micro_batch_size=1,  # used when pp = 1\n                forward_only=forward_only,\n            )\n        # losses_reduced contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def update_critic(self, dataloader: Iterable[DataProto]):\n        metrics = {}\n\n        for data in dataloader:\n            self.critic_optimizer.zero_grad()\n            # assumes use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n            for chunk in self.critic_module:\n                chunk.zero_grad_buffer()\n\n            micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n            max_token_len = None\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n            metric_micro_batch = self.forward_backward_batch(\n                data,\n                forward_only=False,\n                use_dynamic_bsz=self.config.use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=self.config.ppo_mini_batch_size,\n            )\n            metric_micro_batch = metric_micro_batch[\"output\"]\n            update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step()\n            learning_rate = self.critic_optimizer.param_groups[-1][\"lr\"]\n            # use a fresh name to avoid shadowing the loop variable `data`\n            opt_metrics = {\"critic/grad_norm\": grad_norm, \"critic/lr\": learning_rate}\n            append_to_dict(metrics, opt_metrics)\n\n            if update_successful:\n                # allgather is already executed inside optimizer.step in recent Megatron\n                pass\n            else:\n                raise NotImplementedError\n\n            for metric in metric_micro_batch:\n                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.\n\n        # release cached device memory after each update\n        get_torch_device().empty_cache()\n        return metrics\n"
  },
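The critic's `loss_func` above delegates to `core_algos.compute_value_loss`, which applies the standard PPO-style value clipping. A minimal self-contained sketch of that computation (hypothetical helper names, token-mean aggregation only; the real helper also supports other `loss_agg_mode`s):

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of x over positions where mask == 1."""
    return (x * mask).sum() / mask.sum().clamp(min=1)


def compute_value_loss_sketch(vpreds, values, returns, response_mask, cliprange_value):
    """PPO-style clipped value loss over response tokens (token-mean only).

    vpreds:  new value predictions from the critic   (bs, response_length)
    values:  old values recorded during the rollout  (bs, response_length)
    returns: empirical returns                       (bs, response_length)
    """
    # Keep the new predictions within cliprange_value of the rollout-time values.
    vpreds_clipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpreds_clipped - returns) ** 2
    # Pessimistic elementwise max, halved and averaged over unmasked tokens.
    vf_loss = 0.5 * masked_mean(torch.maximum(vf_losses1, vf_losses2), response_mask)
    # Fraction of tokens where the clipped branch dominated (useful for logging).
    vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac
```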
  {
    "path": "verl_distillation/verl/workers/engine/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .base import BaseEngine, EngineRegistry\nfrom .fsdp import FSDPEngine, FSDPEngineWithLMHead\n\n__all__ = [\"BaseEngine\", \"EngineRegistry\", \"FSDPEngine\", \"FSDPEngineWithLMHead\"]\n\n# Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected\ntry:\n    from .mindspeed import MindspeedEngineWithLMHead\n\n    __all__ += [\"MindspeedEngineWithLMHead\"]\nexcept ImportError:\n    MindspeedEngineWithLMHead = None\n\ntry:\n    from .megatron import MegatronEngine, MegatronEngineWithLMHead\n\n    __all__ += [\"MegatronEngine\", \"MegatronEngineWithLMHead\"]\nexcept ImportError:\n    MegatronEngine = None\n    MegatronEngineWithLMHead = None\n"
  },
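The guarded imports in this `__init__.py` let the package degrade gracefully when optional backends are absent (Mindspeed is imported first so its monkey patches land before Megatron's). A minimal sketch of the same pattern, with placeholder class names standing in for the real engines:

```python
__all__ = ["CoreEngine"]


class CoreEngine:
    """Stand-in for the always-available engines (here: the FSDP family)."""


try:
    # Heavyweight optional dependency; an ImportError keeps the package usable
    # on machines where it is not installed.
    import megatron.core  # noqa: F401

    class MegatronBackedEngine(CoreEngine):
        """Defined and exported only when Megatron imports cleanly."""

    __all__ += ["MegatronBackedEngine"]
except ImportError:
    MegatronBackedEngine = None  # callers may check for None before use
```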
  {
    "path": "verl_distillation/verl/workers/engine/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe abstract base class defining the interface for model training engines.\n\"\"\"\n\nfrom typing import Any, Callable, Optional\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.utils.device import get_device_name\n\n\nclass BaseEngine:\n    \"\"\"\n    Abstract base class defining the interface for model training engines. Interface is subject to\n    change before release.\n\n    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.\n    \"\"\"\n\n    def initialize(self):\n        \"\"\"\n        Instantiate or load the model, optimizer, and learning rate scheduler.\n\n        Should prepare all components necessary for training or evaluation.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into training mode.\n\n        Usage:\n            with engine.train_mode():\n                # runs in training mode\n        \"\"\"\n        raise NotImplementedError\n\n    def eval_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into evaluation mode.\n\n        Usage:\n            with engine.eval_mode():\n                # runs in evaluation mode\n        \"\"\"\n        raise NotImplementedError\n\n    def optimizer_zero_grad(self):\n        \"\"\"\n        Zero the gradients of the optimizer.\n        \"\"\"\n        raise NotImplementedError\n\n    def optimizer_step(self):\n        \"\"\"\n        Perform an optimization step using the optimizer.\n        \"\"\"\n        raise NotImplementedError\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance the learning rate scheduler by one step.\n\n        Returns:\n            current_lr (float or list[float]): Updated learning rate(s).\n        \"\"\"\n        raise NotImplementedError\n\n    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:\n        \"\"\"\n        Perform a forward pass and optionally a backward pass on a batch of data.\n\n        Args:\n            data: The input data for the forward pass, typically containing tensors and metadata.\n            loss_function: The loss function to optimize. See `verl.workers.roles.utils.losses` for examples.\n            forward_only: If True, perform only the forward pass. 
If False, perform forward and backward pass.\n\n        Returns:\n            Any: The output of the forward pass, which can be used for loss computation or other purposes.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_batch(self, data: TensorDict, loss_function: Callable) -> Any:\n        \"\"\"\n        Perform a training step on a batch of data.\n\n        Args:\n            data: The input data for training, typically containing tensors and metadata.\n            loss_function: A function that computes the loss and metrics given a batch and predictions.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the batch.\n        \"\"\"\n        self.optimizer_zero_grad()\n        outputs = self.forward_backward_batch(data, loss_function, forward_only=False)\n        grad_norm = self.optimizer_step()\n        if self.is_mp_src_rank_with_outputs():\n            outputs[\"metrics\"][\"grad_norm\"] = grad_norm\n        return outputs\n\n    def infer_batch(self, data: TensorDict, loss_function: Optional[Callable] = None) -> Any:\n        \"\"\"\n        Perform inference on a batch of data.\n\n        Args:\n            data: The input data for inference, typically containing tensors and metadata.\n\n        Returns:\n            Any: The output of the inference, which can be used for predictions or other purposes.\n        \"\"\"\n        with torch.no_grad():\n            outputs = self.forward_backward_batch(data, loss_function, forward_only=True)\n        return outputs\n\n    def get_per_tensor_param(self):\n        raise NotImplementedError\n\n    def get_data_parallel_size(self):\n        raise NotImplementedError\n\n    def get_data_parallel_rank(self):\n        raise NotImplementedError\n\n    def get_data_parallel_group(self):\n        raise NotImplementedError\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move model parameters, optimizer states, or both to the specified device.\n\n        Args:\n            device: Target device identifier.\n            model: If True, move the model.\n            optimizer: If True, move the optimizer states.\n        \"\"\"\n        raise NotImplementedError\n\n    def save_checkpoint(\n        self,\n        local_path: str,\n        hdfs_path: Optional[str] = None,\n        global_step: int = 0,\n        max_ckpt_to_keep: Optional[int] = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Save model, optimizer, and scheduler states to a checkpoint.\n\n        Args:\n            local_path: Local filesystem path to save checkpoint.\n            hdfs_path: Optional HDFS path to copy checkpoint.\n            global_step: Integer training step number for naming.\n            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.\n            **kwargs: Arbitrary keyword arguments.\n        \"\"\"\n        raise NotImplementedError\n\n    def load_checkpoint(\n        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs\n    ) -> None:\n        \"\"\"\n        Load model, optimizer, and scheduler states from a checkpoint.\n\n        Args:\n            local_path: Local filesystem path of the checkpoint.\n            hdfs_path: Optional HDFS path where checkpoint is stored.\n            del_local_after_load: Whether to delete local copy after loading.\n            **kwargs: Arbitrary keyword arguments.\n        \"\"\"\n        raise 
NotImplementedError\n\n    def is_mp_src_rank_with_outputs(self):\n        \"\"\"\n        Whether the current rank is the first rank in the model parallel group that holds the model outputs.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass EngineRegistry:\n    \"\"\"\n    A registry for managing and instantiating different types of training engines.\n\n    This class uses a dictionary to store engine classes, mapping a string key to each class.\n    It provides a decorator `register` to add new engines to the registry and a `new` method\n    to create an instance of a registered engine.\n    \"\"\"\n\n    _engines = {}\n\n    @classmethod\n    def register(cls, model_type: str, backend: list[str] | str, device: list[str] | str = \"cuda\"):\n        \"\"\"\n        A class method decorator that registers an engine class with a given key.\n\n        This allows for dynamic instantiation of engine classes by their registered key.\n\n        Args:\n            model_type (str): The type of the model.\n            backend (list[str] | str): The backend(s) to use for the model type.\n            device (list[str] | str): The device type(s) (e.g., \"cuda\", \"npu\", \"cpu\") this engine supports,\n                default is \"cuda\".\n\n        Returns:\n            A decorator function that takes an engine class and registers it.\n        \"\"\"\n\n        def decorator(engine_class):\n            assert issubclass(engine_class, BaseEngine)\n            if model_type not in cls._engines:\n                cls._engines[model_type] = {}\n\n            backends = backend if isinstance(backend, list) else [backend]\n            devices = device if isinstance(device, list) else [device]\n            for current_backend in backends:\n                for current_device in devices:\n                    if current_backend not in cls._engines[model_type]:\n                        cls._engines[model_type][current_backend] = {}\n                    if current_device not in cls._engines[model_type][current_backend]:\n                        cls._engines[model_type][current_backend][current_device] = engine_class\n\n            return engine_class\n\n        return decorator\n\n    @classmethod\n    def get_engine_cls(cls, model_type: str, backend: str):\n        assert model_type in cls._engines, f\"Unknown model_type: {model_type}\"\n        assert backend in cls._engines[model_type], f\"Unknown backend: {backend}\"\n        device = get_device_name()\n        assert device in cls._engines[model_type][backend], (\n            f\"Unknown device: {device} for model_type: {model_type} and backend: {backend}\"\n        )\n        return cls._engines[model_type][backend][device]\n\n    @classmethod\n    def new(cls, model_type, backend, *args, **kwargs):\n        \"\"\"\n        Create a training engine instance for the given model type and backend.\n        Args:\n            model_type: The registered model type (e.g., \"language_model\").\n            backend: The registered backend (e.g., \"fsdp\", \"megatron\").\n            *args: Positional arguments forwarded to the engine constructor.\n            **kwargs: Keyword arguments forwarded to the engine constructor.\n        Returns:\n            engine: An instance of the matching engine class.\n        Raises:\n            AssertionError: If no engine is registered for the given model type, backend, and current device.\n        \"\"\"\n        engine_cls = cls.get_engine_cls(model_type, backend)\n        return engine_cls(*args, **kwargs)\n"
  },
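To make the registry contract concrete, here is a hypothetical toy engine registered and instantiated through `EngineRegistry`; `ToyEngine` is illustrative only, and the import path is assumed from this repo's layout:

```python
import torch

from verl.workers.engine import BaseEngine, EngineRegistry


# Register for all device names so the example resolves regardless of where it runs.
@EngineRegistry.register(model_type="toy_model", backend="inproc", device=["cpu", "cuda", "npu"])
class ToyEngine(BaseEngine):
    def __init__(self, hidden: int = 4):
        self.hidden = hidden

    def initialize(self):
        self.module = torch.nn.Linear(self.hidden, self.hidden)


# Lookup is keyed by (model_type, backend) plus the current device name,
# and constructor arguments are forwarded by EngineRegistry.new.
engine = EngineRegistry.new("toy_model", "inproc", hidden=8)
engine.initialize()
```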
  {
    "path": "verl_distillation/verl/workers/engine/fsdp/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .transformer_impl import FSDPEngine, FSDPEngineWithLMHead\n\n__all__ = [\"FSDPEngine\", \"FSDPEngineWithLMHead\"]\n"
  },
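For orientation, a sketch of how a worker might drive the exported engines through the `BaseEngine` interface; `dataloader` and `loss_fn` are placeholders, with `loss_fn` following the `loss_function(model_output=..., data=..., dp_group=...)` convention used by `forward_step` in the FSDP implementation below:

```python
def run_epoch(engine, dataloader, loss_fn):
    """Sketch of one training epoch against the BaseEngine interface."""
    all_metrics = []
    with engine.train_mode():  # reloads offloaded params/optimizer onto the device
        for batch in dataloader:  # each batch is a TensorDict
            out = engine.train_batch(batch, loss_fn)  # zero_grad -> fwd/bwd -> clip -> step
            lr = engine.lr_scheduler_step()
            if engine.is_mp_src_rank_with_outputs():  # metrics are populated on the src rank
                metrics = dict(out["metrics"])  # includes grad_norm added by train_batch
                metrics["lr"] = lr
                all_metrics.append(metrics)
    return all_metrics
```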
  {
    "path": "verl_distillation/verl/workers/engine/fsdp/transformer_impl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP)\n\"\"\"\n\nimport gc\nimport logging\nimport os\nimport warnings\nfrom contextlib import nullcontext\nfrom typing import Callable, Optional\n\nimport torch\nimport torch.distributed\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom tensordict import TensorDict\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.tensor import DTensor\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.dataset.dataset_utils import DatasetPadMode\nfrom verl.utils.debug import log_gpu_memory_usage\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_torch_device,\n)\nfrom verl.utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    FSDPModule,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    collect_lora_params,\n    fsdp2_clip_grad_norm_,\n    fsdp2_load_full_state_dict,\n    fsdp_version,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n    replace_lora_wrapper,\n)\nfrom verl.utils.model import convert_weight_keys\nfrom verl.utils.py_functional import convert_to_regular_types\nfrom verl.utils.torch_functional import logprobs_from_logits\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs\nfrom verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig, HFModelConfig\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nfrom ..base import BaseEngine, EngineRegistry\nfrom ..utils import postprocess_batch_func, prepare_micro_batches\nfrom .utils import create_device_mesh, get_sharding_strategy\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n\nclass FSDPEngine(BaseEngine):\n    \"\"\"\n    Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).\n\n    Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_config: HFModelConfig,\n        engine_config: FSDPEngineConfig,\n        optimizer_config: FSDPOptimizerConfig,\n        checkpoint_config: CheckpointConfig,\n    ):\n        \"\"\"\n        Initialize the FSDPEngine.\n\n        Sets up distributed device 
meshes, LoRA, and offload policies based on config.\n\n        Args:\n            model_config: HuggingFace model configuration (HFModelConfig).\n            engine_config: FSDP engine configuration (FSDPEngineConfig).\n            optimizer_config: Optimizer configuration (FSDPOptimizerConfig).\n            checkpoint_config: Checkpoint configuration (CheckpointConfig).\n        \"\"\"\n        super().__init__()\n\n        self.model_config = model_config\n        self.engine_config = engine_config\n        self.optimizer_config = optimizer_config\n        self.checkpoint_config = checkpoint_config\n\n        self.mode = None\n\n        self.rank = torch.distributed.get_rank()\n\n        self.use_remove_padding = self.model_config.use_remove_padding\n\n        # build device mesh for Ulysses Sequence Parallel\n        self._init_device_mesh()\n\n        # set FSDP offload params\n        self._is_offload_param = self.engine_config.param_offload\n        self._is_offload_optimizer = self.engine_config.optimizer_offload\n        self._is_lora = self.model_config.lora_rank > 0\n\n        if self.engine_config.entropy_from_logits_with_chunking:\n            entropy_from_logits = verl_F.entropy_from_logits_with_chunking\n        else:\n            entropy_from_logits = verl_F.entropy_from_logits\n\n        self.compute_entropy_from_logits = (\n            torch.compile(entropy_from_logits, dynamic=True)\n            if self.engine_config.use_torch_compile  # use torch.compile by default\n            else entropy_from_logits\n        )\n\n    def is_mp_src_rank_with_outputs(self):\n        if self.ulysses_device_mesh is not None:\n            is_collect = self.ulysses_device_mesh[\"sp\"].get_local_rank() == 0\n        else:\n            is_collect = True\n        return is_collect\n\n    def initialize(self):\n        \"\"\"\n        Build the model, optimizer, and learning rate scheduler under FSDP.\n\n        Applies device, dtype, and precision configurations, including mixed precision.\n        Sets up the checkpoint manager.\n        \"\"\"\n        # This is used to import external_lib into the huggingface systems\n        self._build_model_optimizer()\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n            log_gpu_memory_usage(\"After offload model during init\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.optimizer)\n            log_gpu_memory_usage(\"After offload optimizer during init\", logger=logger)\n\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.module,\n            optimizer=self.optimizer,\n            lr_scheduler=self.lr_scheduler,\n            processing_class=self.model_config.get_processor(),\n            checkpoint_contents=self.checkpoint_config,\n        )\n\n    def _init_device_mesh(self):\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.engine_config.fsdp_size\n\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size\n        dp_size = self.get_data_parallel_size()\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp_size, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self.use_ulysses_sp = self.ulysses_sequence_parallel_size
> 1\n\n    def _build_module(self):\n        from verl.utils.model import get_hf_auto_model_class\n        from verl.utils.torch_dtypes import PrecisionType\n\n        torch_dtype = self.engine_config.model_dtype\n\n        if torch_dtype is None:\n            # for training we force torch_dtype to fp32; forward-only runs default to bf16\n            torch_dtype = torch.float32 if not self.engine_config.forward_only else torch.bfloat16\n\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not self.model_config.hf_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n\n            auto_class = get_hf_auto_model_class(hf_config=self.model_config.hf_config)\n\n            module = auto_class.from_pretrained(\n                pretrained_model_name_or_path=self.model_config.local_path,\n                torch_dtype=torch_dtype,\n                config=self.model_config.hf_config,\n                trust_remote_code=self.model_config.trust_remote_code,\n            )\n\n            use_liger = self.model_config.use_liger\n            # Apply Liger kernel to the model if use_liger is set to True\n            if use_liger:\n                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n\n                _apply_liger_kernel_to_instance(model=module)\n\n            fused_kernel_options = self.model_config.fused_kernel_options\n            fused_kernels_backend = (\n                fused_kernel_options.get(\"impl_backend\", None) if fused_kernel_options is not None else None\n            )\n\n            use_fused_kernels = self.model_config.use_fused_kernels\n            apply_monkey_patch(\n                model=module,\n                use_remove_padding=self.use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_fused_kernels=use_fused_kernels,\n                fused_kernels_backend=fused_kernels_backend,\n            )\n\n            # some parameters may not be in torch_dtype\n            module.to(torch_dtype)\n\n            if self.model_config.enable_gradient_checkpointing:\n                module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n        return module\n\n    def _build_lora_module(self, module):\n        module.enable_input_require_grads()\n\n        lora_adapter_path = getattr(self.model_config, \"lora_adapter_path\", None)\n        if lora_adapter_path is not None:\n            from peft import PeftModel\n\n            from verl.utils.fs import copy_to_local\n\n            print(f\"Loading pre-trained LoRA adapter from: {lora_adapter_path}\")\n            # Copy adapter to local if needed\n            local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.model_config.use_shm)\n\n            module = PeftModel.from_pretrained(module, local_adapter_path, is_trainable=True)\n            peft_config = module.peft_config[\"default\"]\n            # Ensure task_type is TaskType enum, not string\n            if isinstance(peft_config.task_type, str):\n                peft_config.task_type = TaskType.CAUSAL_LM\n        else:\n            # Convert config to regular Python types before creating PEFT model\n            lora_config = {\n                \"task_type\": TaskType.CAUSAL_LM,\n                \"r\": self.model_config.lora_rank,\n
\"lora_alpha\": self.model_config.lora_alpha,\n                \"target_modules\": convert_to_regular_types(self.model_config.target_modules),\n                \"exclude_modules\": convert_to_regular_types(self.model_config.exclude_modules),\n                \"bias\": \"none\",\n            }\n            module = get_peft_model(module, LoraConfig(**lora_config))\n\n        return module\n\n    def _build_fsdp_module(self, module):\n        # TODO(ziheng): need to improve\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        mixed_precision_config = self.engine_config.mixed_precision\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=module,\n            config=self.engine_config.wrap_policy,\n            is_lora=self.model_config.lora_rank > 0,\n        )\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation\n        if self.engine_config.strategy == \"fsdp\":\n            # cpu_offload:\n            # - actor: None\n            # - critic: None\n            # - ref: CPUOffload(offload_params=True)\n\n            # We force reference policy to use CPUOffload to save memory.\n            # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation\n            cpu_offload = None\n            if self.engine_config.forward_only:\n                cpu_offload = CPUOffload(offload_params=True)\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n\n            module = FSDP(\n                module,\n                param_init_fn=init_fn,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                forward_prefetch=self.engine_config.forward_prefetch,\n                use_orig_params=self.engine_config.use_orig_params,\n                cpu_offload=cpu_offload,\n            )\n        elif self.engine_config.strategy == \"fsdp2\":\n            # - actor: offload_policy\n            # - critic: offload_policy\n            # - ref: CPUOffloadPolicy(pin_memory=True)\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            offload_policy = None\n            if self.engine_config.offload_policy or self.engine_config.forward_only:\n                
self._is_offload_param = False\n                self._is_offload_optimizer = False\n                offload_policy = CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": offload_policy,\n                \"reshard_after_forward\": self.engine_config.reshard_after_forward,\n            }\n            full_state = module.state_dict()\n            apply_fsdp2(module, fsdp_kwargs, self.engine_config)\n            fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)\n        else:\n            raise NotImplementedError(f\"Unknown strategy {self.engine_config.strategy}\")\n\n        if self.model_config.enable_activation_offload:\n            enable_gradient_checkpointing = self.model_config.enable_gradient_checkpointing\n            enable_activation_offloading(module, self.engine_config.strategy, enable_gradient_checkpointing)\n\n        if torch.distributed.get_world_size() == 1 and fsdp_version(module) == 1:\n            FSDP.set_state_dict_type(\n                module,\n                state_dict_type=StateDictType.FULL_STATE_DICT,\n                state_dict_config=FullStateDictConfig(),\n            )\n        elif fsdp_version(module) == 1:\n            FSDP.set_state_dict_type(\n                module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        return module\n\n    def _build_optimizer(self, module):\n        from verl.workers.config.optimizer import build_optimizer\n\n        optimizer = build_optimizer(module.parameters(), self.optimizer_config)\n\n        return optimizer\n\n    def _build_lr_scheduler(self, optimizer):\n        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n        optim_config = self.optimizer_config\n\n        total_steps = optim_config.total_training_steps\n        num_warmup_steps = optim_config.lr_warmup_steps\n        lr_scheduler_type = optim_config.lr_scheduler_type\n        min_lr_ratio = optim_config.min_lr_ratio\n        num_cycles = optim_config.num_cycles\n        if num_warmup_steps <= 0:\n            num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        if self.rank == 0:\n            print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        if lr_scheduler_type == \"constant\":\n            lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)\n        elif lr_scheduler_type == \"cosine\":\n            lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=optimizer,\n                num_warmup_steps=num_warmup_steps,\n                num_training_steps=total_steps,\n                min_lr_ratio=min_lr_ratio,\n                num_cycles=num_cycles,\n            )\n        else:\n            raise NotImplementedError(f\"LR scheduler type {lr_scheduler_type} is not supported\")\n        return lr_scheduler\n\n    def _build_model_optimizer(self):\n        from verl.utils.model import print_model_size\n\n        # Load base model with specified configuration and dtype\n        module = self._build_module()\n        # Apply LoRA adapters if low-rank adaptation is enabled\n        if self._is_lora:\n            module = 
self._build_lora_module(module)\n\n        # Synchronize all distributed processes before proceeding\n        torch.distributed.barrier()\n        if self.rank == 0:\n            print_model_size(module)\n        log_gpu_memory_usage(\"After init model from HF AutoModel\", logger=logger)\n\n        # Wrap model with FSDP for distributed training (sharding, mixed precision, etc.)\n        log_gpu_memory_usage(\"Before FSDP\", logger=None)\n        module = self._build_fsdp_module(module)\n        log_gpu_memory_usage(\"After FSDP\", logger=None)\n\n        if not self.engine_config.forward_only:\n            # Initialize optimizer with model parameters and config settings\n            optimizer = self._build_optimizer(module)\n            # Create learning rate scheduler with warmup and decay settings\n            lr_scheduler = self._build_lr_scheduler(optimizer)\n        else:\n            optimizer = None\n            lr_scheduler = None\n\n        self.module = module\n        self.optimizer = optimizer\n        self.lr_scheduler = lr_scheduler\n\n    def train_mode(self):\n        \"\"\"\n        Return a context manager that switches to training mode with FSDP-specific handling.\n\n        Includes parameter and optimizer offload entry/exit.\n        \"\"\"\n        return EngineTrainModeCtx(self)\n\n    def eval_mode(self):\n        \"\"\"\n        Return a context manager that switches to evaluation mode with FSDP-specific handling.\n\n        Includes parameter offload entry/exit and resharding of the root module on exit.\n        \"\"\"\n        return EngineEvalModeCtx(self)\n\n    def get_data_parallel_rank(self):\n        if self.ulysses_device_mesh is not None:\n            return self.ulysses_device_mesh[\"dp\"].get_local_rank()\n        else:\n            return torch.distributed.get_rank()\n\n    def get_data_parallel_size(self):\n        return torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n\n    def get_data_parallel_group(self):\n        if self.ulysses_device_mesh is not None:\n            return self.ulysses_device_mesh.get_group(mesh_dim=\"dp\")\n        else:\n            return torch.distributed.group.WORLD\n\n    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> list[TensorDict]:\n        # note that the global batch size should include the data across all DP ranks\n        tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)\n\n        # compute num_tokens in global batch for loss normalization\n        batch_num_tokens = data[\"loss_mask\"].sum().to(get_device_id())\n        torch.distributed.all_reduce(\n            batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()\n        )\n        tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())\n        tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())\n\n        micro_batches, indices = prepare_micro_batches(\n            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True\n        )\n\n        output_lst = []\n\n        ctx = torch.no_grad() if forward_only else nullcontext()\n\n        for micro_batch in micro_batches:\n            with ctx:\n                loss, meta_info = self.forward_step(micro_batch, loss_function=loss_function, forward_only=forward_only)\n\n                if not forward_only:\n                    loss.backward()\n\n            output_lst.append(meta_info)\n\n        # postprocess and return\n        return postprocess_batch_func(output_lst=output_lst, 
indices=indices, data=data)\n\n    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):\n        raise NotImplementedError(\"forward_step must be implemented in subclass\")\n\n    def optimizer_zero_grad(self):\n        \"\"\"\n        Zero the gradients of the optimizer.\n        \"\"\"\n        self.optimizer.zero_grad()\n\n    def optimizer_step(self):\n        \"\"\"\n        Clip gradients, skip the update if the norm is non-finite, and step the optimizer.\n\n        Returns:\n            grad_norm (float): Norm of gradients before clipping.\n        \"\"\"\n        assert self.optimizer_config.clip_grad is not None\n\n        if isinstance(self.module, FSDP):\n            grad_norm = self.module.clip_grad_norm_(self.optimizer_config.clip_grad)\n        elif isinstance(self.module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.optimizer_config.clip_grad)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(\n                self.module.parameters(), max_norm=self.optimizer_config.clip_grad\n            )\n\n        if isinstance(grad_norm, DTensor):\n            grad_norm = grad_norm.full_tensor()\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.optimizer.zero_grad()\n        else:\n            self.optimizer.step()\n        return grad_norm.item()\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance the FSDP scheduler and return the updated learning rate.\n        \"\"\"\n        self.lr_scheduler.step()\n        lr = self.lr_scheduler.get_last_lr()[0]  # only return the first group\n        return lr\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move FSDP model and/or optimizer to CPU or GPU with offload support.\n        \"\"\"\n        if self.engine_config.forward_only:\n            # force cpu_offload\n            return\n\n        device_name = get_device_name()\n\n        assert device in (device_name, \"cpu\")\n        if device == device_name:\n            if not self.engine_config.param_offload:\n                if model:\n                    load_fsdp_model_to_gpu(self.module)\n                if optimizer and self.optimizer is not None:\n                    load_fsdp_optimizer(self.optimizer, device)\n            gc.collect()\n        elif device == \"cpu\":\n            if not self.engine_config.param_offload:\n                if model:\n                    offload_fsdp_model_to_cpu(self.module)\n                if optimizer and self.optimizer is not None:\n                    offload_fsdp_optimizer(self.optimizer)\n        else:\n            raise ValueError(f\"Invalid device type: {device}\")\n\n    def save_checkpoint(\n        self,\n        local_path: str,\n        hdfs_path: Optional[str] = None,\n        global_step: int = 0,\n        max_ckpt_to_keep: Optional[int] = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Save FSDP checkpoint, handling parameter offload as needed.\n        \"\"\"\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n
            offload_fsdp_model_to_cpu(self.module)\n\n    def load_checkpoint(\n        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs\n    ) -> None:\n        \"\"\"\n        Load FSDP checkpoint, restoring parameters and optimizer state.\n        \"\"\"\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.optimizer)\n\n    def get_per_tensor_param(self, layered_summon=False, base_sync_done=False):\n        log_gpu_memory_usage(\"Before load_fsdp_model_to_gpu\", logger=logger)\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        log_gpu_memory_usage(\"After load_fsdp_model_to_gpu\", logger=logger)\n\n        peft_config = None\n        peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n        if hasattr(peft_model, \"peft_config\"):  # LoRA\n            peft_config = peft_model.peft_config.get(\"default\", None)\n            params = collect_lora_params(\n                module=self.module,\n                layered_summon=layered_summon,\n                base_sync_done=base_sync_done,\n            )\n            if not base_sync_done:\n                params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}\n        else:\n            params = self.module.state_dict()\n\n        params = convert_weight_keys(params, getattr(self.module, \"_fsdp_wrapped_module\", self.module))\n\n        log_gpu_memory_usage(\"Before offload_fsdp_model_to_cpu\", logger=logger)\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        log_gpu_memory_usage(\"After offload_fsdp_model_to_cpu\", logger=logger)\n\n        if peft_config is not None and base_sync_done:\n            per_tensor_param = params\n        else:\n            device = get_device_id()  # used when fsdp2 sets cpu_offload_policy\n            per_tensor_param = (\n                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)\n                for name, param in params.items()\n            )\n        return per_tensor_param\n\n\nclass EngineEvalModeCtx:\n    def __init__(self, engine: FSDPEngine):\n        self.engine = engine\n\n    def __enter__(self):\n        self.engine.mode = \"eval\"\n        if self.engine._is_offload_param:\n            load_fsdp_model_to_gpu(self.engine.module)\n\n        self.engine.ulysses_sharding_manager.__enter__()\n        self.engine.module.eval()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # manually reshard the root FSDP module, which is kept unsharded after forward\n        if self.engine.engine_config.fsdp_size > 1:\n            if fsdp_version(self.engine.module) == 1:\n                self.engine.module._handle.reshard(True)\n            elif fsdp_version(self.engine.module) == 2:\n                self.engine.module.reshard()\n\n        if self.engine._is_offload_param:\n            offload_fsdp_model_to_cpu(self.engine.module)\n        
self.engine.mode = None\n\n\nclass EngineTrainModeCtx:\n    def __init__(self, engine: FSDPEngine):\n        self.engine = engine\n\n    def __enter__(self):\n        self.engine.mode = \"train\"\n        if self.engine._is_offload_param:\n            load_fsdp_model_to_gpu(self.engine.module)\n        if self.engine._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.engine.optimizer, device_id=get_torch_device().current_device())\n\n        self.engine.ulysses_sharding_manager.__enter__()\n        self.engine.module.train()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)\n        self.engine.optimizer_zero_grad()\n\n        if self.engine._is_offload_param:\n            offload_fsdp_model_to_cpu(self.engine.module)\n        if self.engine._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.engine.optimizer)\n        self.engine.mode = None\n\n\n@EngineRegistry.register(model_type=\"language_model\", backend=[\"fsdp\", \"fsdp2\"], device=[\"cuda\", \"npu\"])\nclass FSDPEngineWithLMHead(FSDPEngine):\n    def prepare_model_inputs(self, micro_batch: TensorDict):\n        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key=\"use_remove_padding\", default=True)\n        pad_mode = tu.get_non_tensor_data(data=micro_batch, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key=\"use_fused_kernels\", default=False)\n        temperature = micro_batch[\"temperature\"]\n\n        assert pad_mode == DatasetPadMode.NO_PADDING, f\"pad_mode {pad_mode} not supported\"\n\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            from verl.utils.model import extract_multi_modal_inputs\n\n            multi_modal_inputs = extract_multi_modal_inputs(micro_batch[\"multi_modal_inputs\"])\n\n        input_ids = micro_batch[\"input_ids\"]\n        position_ids = micro_batch[\"position_ids\"]\n\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)\n\n        # args used to get outputs\n        output_args = {}\n\n        if use_remove_padding:\n            if pad_mode == DatasetPadMode.NO_PADDING:\n                input_ids_rmpad = input_ids.values().unsqueeze(0)  # (1, total_nnz)\n                position_ids_rmpad = position_ids.values().unsqueeze(0)  # (1, total_nnz)\n            else:\n                raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n            # for compute the log_prob\n            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n            # pad and slice the inputs if sp > 1\n            if self.use_ulysses_sp:\n                is_vlm_model = hasattr(getattr(self.module, \"module\", self.module).config, \"vision_config\")\n                if is_vlm_model:\n                    # vlm model's inputs will be sliced after embedding\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(\n                        input_ids_rmpad,\n                        position_ids_rmpad=position_ids_rmpad,\n                        sp_size=self.ulysses_sequence_parallel_size,\n                    )\n                else:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        
input_ids_rmpad,\n                        position_ids_rmpad=position_ids_rmpad,\n                        sp_size=self.ulysses_sequence_parallel_size,\n                    )\n                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad_rolled,\n                    position_ids_rmpad=None,\n                    sp_size=self.ulysses_sequence_parallel_size,\n                )\n\n                output_args[\"pad_size\"] = pad_size\n\n            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n            output_args[\"input_ids_rmpad_rolled\"] = input_ids_rmpad_rolled\n\n            # only pass input_ids and position_ids to enable flash_attn_varlen\n\n            model_inputs = {\n                \"input_ids\": input_ids_rmpad,\n                \"attention_mask\": None,\n                \"position_ids\": position_ids_rmpad,\n            }\n\n        else:\n            if pad_mode == DatasetPadMode.NO_PADDING:\n                input_ids = micro_batch[\"input_ids\"]\n                position_ids = micro_batch[\"position_ids\"]\n                loss_mask = micro_batch[\"loss_mask\"]\n\n                pad_token_id = tu.get_non_tensor_data(data=micro_batch, key=\"pad_token_id\", default=0)\n                batch_size = micro_batch.batch_size[0]\n                seq_len_effective = input_ids.offsets().diff()\n                max_seq_len = max(seq_len_effective)\n\n                input_ids_rmpad_rolled = torch.roll(input_ids.values(), shifts=-1, dims=0)\n                output_args[\"input_ids_rmpad_rolled\"] = input_ids_rmpad_rolled\n\n                input_ids = torch.nested.to_padded_tensor(\n                    input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)\n                )\n\n                position_ids = torch.nested.to_padded_tensor(\n                    position_ids, padding=0, output_size=(batch_size, max_seq_len)\n                )\n\n                attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]\n                attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged)\n                attention_mask = torch.nested.to_padded_tensor(\n                    attention_mask, padding=0, output_size=(batch_size, max_seq_len)\n                )\n\n                model_inputs = {\n                    \"input_ids\": input_ids,\n                    \"attention_mask\": attention_mask,\n                    \"position_ids\": position_ids,\n                }\n            else:\n                raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        extra_args = {}\n        if use_fused_kernels:\n            extra_args[\"temperature\"] = temperature\n            extra_args[\"return_dict\"] = True\n\n        model_inputs.update(multi_modal_inputs)\n        model_inputs.update(extra_args)\n\n        return model_inputs, output_args\n\n    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):\n        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key=\"use_remove_padding\", default=True)\n        pad_mode = tu.get_non_tensor_data(data=micro_batch, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key=\"use_fused_kernels\", default=False)\n        temperature = micro_batch[\"temperature\"]\n        calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key=\"calculate_entropy\", 
default=False)\n\n        model_output = {}\n\n        input_ids = micro_batch[\"input_ids\"]\n        if use_remove_padding:\n            input_ids_rmpad_rolled = output_args[\"input_ids_rmpad_rolled\"]\n\n            if use_fused_kernels:\n                log_probs = output.log_probs.squeeze(0)  # (total_nnz,)\n                entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)\n            else:\n                logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)\n                logits_rmpad.div_(temperature)\n\n                # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n                inplace_backward = True\n                if calculate_entropy:\n                    inplace_backward = False\n                log_probs = logprobs_from_logits(\n                    logits=logits_rmpad,\n                    labels=input_ids_rmpad_rolled,\n                    inplace_backward=inplace_backward,\n                )\n\n                # compute entropy\n                if calculate_entropy:\n                    if not self.engine_config.entropy_checkpointing:\n                        entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)\n                    else:\n                        entropy_rmpad = torch.utils.checkpoint.checkpoint(\n                            self.compute_entropy_from_logits, logits_rmpad\n                        )\n\n            # gather log_prob if sp > 1\n            if self.use_ulysses_sp:\n                pad_size = output_args[\"pad_size\"]\n\n                # gather and unpad for the ulysses sp\n                log_probs = gather_outputs_and_unpad(\n                    log_probs,\n                    gather_dim=0,\n                    unpad_dim=0,\n                    padding_size=pad_size,\n                )\n                if calculate_entropy:\n                    entropy_rmpad = gather_outputs_and_unpad(\n                        entropy_rmpad,\n                        gather_dim=0,\n                        unpad_dim=0,\n                        padding_size=pad_size,\n                    )\n\n            if pad_mode == DatasetPadMode.NO_PADDING:\n                cu_seqlens = input_ids.offsets()\n                # (bsz, j1), for each sample, is the length of each sample: [real_prompt length + real_response length]\n                log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)\n                if calculate_entropy:\n                    entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)\n            else:\n                raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        else:  # not using rmpad and no ulysses sp\n            response_length = tu.get_non_tensor_data(data=micro_batch, key=\"max_response_length\", default=1024)\n            if use_fused_kernels:\n                log_probs = output.log_probs[:, -response_length - 1 : -1]\n                entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n            else:\n                logits = output.logits\n                logits.div_(temperature)\n\n                if calculate_entropy:\n                    if not self.engine_config.entropy_checkpointing:\n                        entropy = verl_F.entropy_from_logits(logits)\n                    else:\n                        entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)\n\n                if pad_mode == 
DatasetPadMode.NO_PADDING:\n                    cu_seqlens = input_ids.offsets()\n                    seq_lengths = cu_seqlens.diff()\n                    starts = torch.zeros_like(seq_lengths, dtype=torch.int64)\n                    logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged)\n                    logits_rmpad = torch.cat([t for t in logits.unbind()])\n                    input_ids_rmpad_rolled = output_args[\"input_ids_rmpad_rolled\"]\n                    log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)\n                    # (bsz, j1), for each sample, length of each sample: [real_prompt_length + real_response_length]\n                    log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)\n                    if calculate_entropy:\n                        entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)\n                        entropy_rmpad = torch.cat([t for t in entropy.unbind()])\n                        entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)\n                else:\n                    raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        model_output[\"log_probs\"] = log_probs\n        if calculate_entropy:\n            model_output[\"entropy\"] = entropy\n\n        return model_output\n\n    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):\n        device_name = get_device_name()\n        # NOTE: micro_batch is rebound to its on-device copy; the host copy is not used afterwards\n        micro_batch = micro_batch.to(get_device_id())\n        model_inputs, output_args = self.prepare_model_inputs(micro_batch=micro_batch)\n\n        with torch.autocast(device_type=device_name, dtype=torch.bfloat16):\n            raw_output = self.module(\n                **model_inputs,\n                use_cache=False,\n            )  # use_cache=False prevents the model from assuming we are generating\n\n            model_output = self.prepare_model_outputs(\n                output=raw_output, output_args=output_args, micro_batch=micro_batch\n            )\n\n            if loss_function is not None:\n                loss, metrics = loss_function(\n                    model_output=model_output, data=micro_batch, dp_group=self.get_data_parallel_group()\n                )\n            else:\n                assert forward_only, \"forward_only must be True when loss_function is None\"\n                loss = torch.tensor(1.0, device=device_name)\n                metrics = {}\n\n            output = {\n                \"model_output\": model_output,\n                \"loss\": loss,\n                \"metrics\": metrics,\n            }\n\n            return loss, output\n\n\n@EngineRegistry.register(model_type=\"value_model\", backend=[\"fsdp\", \"fsdp2\"], device=[\"cuda\", \"npu\"])\nclass FSDPEngineWithValueHead(FSDPEngineWithLMHead):\n    \"\"\"\n    The only difference between the critic and the actor is how the raw model output is processed.\n    \"\"\"\n\n    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):\n        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key=\"use_remove_padding\", default=True)\n        pad_mode = tu.get_non_tensor_data(data=micro_batch, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n\n        if use_remove_padding:\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n\n            if hasattr(self.module, \"v_head\"):\n
                # For trl.AutoModelForCausalLMWithValueHead\n                values_rmpad = output[2].squeeze(0).unsqueeze(-1)\n            else:\n                values_rmpad = output.logits\n                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz, 1)\n                # the critic model arch is like Qwen3ForTokenClassification with num_labels=1,\n                # so we squeeze the last dimension here to get the value for each token\n                values_rmpad = values_rmpad.squeeze(-1)\n\n            # gather output if sp > 1\n            if self.use_ulysses_sp:\n                pad_size = output_args[\"pad_size\"]\n                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n            if pad_mode == DatasetPadMode.NO_PADDING:\n                cu_seqlens = input_ids.offsets()\n                # (bsz, j1); each sample has length real_prompt_length + real_response_length\n                values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)\n            else:\n                raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        else:\n            # input_ids is only bound in the remove-padding branch above, but offsets()\n            # below needs it here as well, so bind it explicitly\n            input_ids = micro_batch[\"input_ids\"]\n\n            if hasattr(self.module, \"v_head\"):\n                # For trl.AutoModelForCausalLMWithValueHead\n                values = output[2]\n            else:\n                values = output.logits\n\n            if pad_mode == DatasetPadMode.NO_PADDING:\n                cu_seqlens = input_ids.offsets()\n                seq_lengths = cu_seqlens.diff()\n                starts = torch.zeros_like(seq_lengths, dtype=torch.int64)\n                values = torch.nested.narrow(values, 1, starts, seq_lengths, layout=torch.jagged)\n                values_rmpad = torch.cat([t for t in values.unbind()])\n                # (bsz, j1); each sample has length real_prompt_length + real_response_length\n                values = torch.nested.nested_tensor_from_jagged(values_rmpad, cu_seqlens)\n            else:\n                raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        return {\"values\": values}\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/fsdp/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl.utils.device import get_device_name\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    \"\"\"\n    Create a device mesh for distributed training based on the world size and FSDP size.\n\n    Args:\n        world_size (int): Total number of processes in the distributed training setup.\n        fsdp_size (int): Size of the Fully Sharded Data Parallel (FSDP) group.\n\n    Returns:\n        torch.distributed.device_mesh.DeviceMesh: The initialized device mesh.\n    \"\"\"\n    device_name = get_device_name()\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    \"\"\"\n    Determine the appropriate sharding strategy based on the number of dimensions of the device mesh.\n\n    Args:\n        device_mesh (torch.distributed.device_mesh.DeviceMesh): The device mesh used for distributed training.\n\n    Returns:\n        torch.distributed.fsdp.ShardingStrategy: The sharding strategy to be used with FSDP.\n\n    Raises:\n        NotImplementedError: If the number of dimensions of the device mesh is neither 1 nor 2.\n    \"\"\"\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2\")\n    return sharding_strategy\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .transformer_impl import MegatronEngine, MegatronEngineWithLMHead\n\n__all__ = [\"MegatronEngine\", \"MegatronEngineWithLMHead\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/megatron/transformer_impl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Any, Callable, Iterator, Optional\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom tensordict import TensorDict\n\nfrom verl.models.mcore import get_mcore_weight_converter\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager\nfrom verl.utils.dataset.dataset_utils import DatasetPadMode\nfrom verl.utils.device import get_device_id, get_device_name\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits\nfrom verl.utils.megatron_utils import (\n    load_megatron_model_to_gpu,\n    load_megatron_optimizer,\n    offload_megatron_model_to_cpu,\n    offload_megatron_optimizer,\n    per_tensor_generator,\n)\nfrom verl.utils.model import load_mcore_dist_weights, load_megatron_gptmodel_weights\nfrom verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig\n\nfrom ..base import BaseEngine, EngineRegistry\nfrom ..utils import postprocess_batch_func, prepare_micro_batches\nfrom .utils import set_random_seed\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MegatronEngine(BaseEngine):\n    def __init__(\n        self,\n        model_config: HFModelConfig,\n        engine_config: McoreEngineConfig,\n        optimizer_config: McoreOptimizerConfig,\n        checkpoint_config: CheckpointConfig,\n    ):\n        super().__init__()\n\n        self.model_config = model_config\n        self.engine_config = engine_config\n        self.optimizer_config = optimizer_config\n        self.checkpoint_config = checkpoint_config\n\n        self._init_device_mesh()\n\n        set_random_seed(seed=self.engine_config.seed)\n\n        self._is_offload_param = self.engine_config.param_offload\n        self._is_offload_grad = self.engine_config.grad_offload\n        self._is_offload_optimizer = self.engine_config.optimizer_offload\n\n        self.mode = None\n\n        self.layer_name_mapping = {\n            \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n            \"gate_proj_layer_name\": \"linear_fc1.\",\n        }\n        self.weight_converter = None\n\n    def _init_device_mesh(self):\n        mpu.initialize_model_parallel(\n            tensor_model_parallel_size=self.engine_config.tensor_model_parallel_size,\n            pipeline_model_parallel_size=self.engine_config.pipeline_model_parallel_size,\n            virtual_pipeline_model_parallel_size=self.engine_config.virtual_pipeline_model_parallel_size,\n            
pipeline_model_parallel_split_rank=None,\n            use_sharp=False,\n            context_parallel_size=self.engine_config.context_parallel_size,\n            expert_model_parallel_size=self.engine_config.expert_model_parallel_size,\n            expert_tensor_parallel_size=self.engine_config.expert_tensor_parallel_size,\n            nccl_communicator_config_path=None,\n        )\n\n    def _build_tf_config(self):\n        from verl.models.mcore import hf_to_mcore_config\n        from verl.utils.torch_dtypes import PrecisionType\n\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        tf_config = hf_to_mcore_config(\n            self.model_config.hf_config, self.dtype, **self.engine_config.override_transformer_config\n        )\n\n        use_mbridge = self.engine_config.use_mbridge\n        if use_mbridge:\n            from verl.models.mcore.mbridge import AutoBridge\n\n            bridge = AutoBridge.from_config(self.model_config.hf_config)\n            bridge.set_extra_args(**self.engine_config.override_transformer_config)\n            tf_config = bridge.config\n            self.bridge = bridge\n        else:\n            self.bridge = None\n\n        if not self.bridge:\n            self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)\n\n        if torch.distributed.get_rank() == 0:\n            print(f\"TF config: {tf_config}\")\n        self.tf_config = tf_config\n\n    def _build_megatron_module(self):\n        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n        from verl.utils.model import print_model_size\n\n        # TODO: add more cases\n        is_value_model = (\n            \"ForTokenClassification\" in self.model_config.architectures[0]\n            or \"ForSequenceClassification\" in self.model_config.architectures[0]\n        )\n\n        self.is_value_model = is_value_model\n\n        wrap_with_ddp = not self.engine_config.forward_only\n\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=is_value_model,  # True when building a critic/value model, False for the actor\n            share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,\n            wrap_with_ddp=wrap_with_ddp,\n            use_distributed_optimizer=self.engine_config.use_distributed_optimizer,\n        )\n        module = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.model_config.hf_config,\n            bridge=self.bridge,\n            override_model_config=self.engine_config.override_mcore_model_config,\n            override_ddp_config=self.engine_config.override_ddp_config,\n        )\n        print(f\"number of vpp module chunks: {len(module)}\")\n\n        if self.engine_config.use_dist_checkpointing:\n            load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)\n        else:\n            if self.bridge is not None:\n                self.bridge.load_weights(module, self.model_config.local_path)\n            else:\n                # (vermouth1992) this is a workaround to be compatible with the old API\n                tmp_config = OmegaConf.create(\n                    {\"model\": {\"path\": self.model_config.local_path, \"use_shm\": self.model_config.use_shm}}\n                )\n\n                load_megatron_gptmodel_weights(\n
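                    # tmp_config mimics the legacy config layout so the old loader can resolve model.path and model.use_shm\n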
                    tmp_config,\n                    self.model_config.hf_config,\n                    module,\n                    params_dtype=self.dtype,\n                    is_value_model=is_value_model,\n                )\n\n        if torch.distributed.get_rank() == 0:\n            print_model_size(module[0])\n\n        return module\n\n    def _build_optimizer(self):\n        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config\n\n        optim_config_megatron = init_megatron_optim_config(self.optimizer_config)\n        optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)\n        return optimizer\n\n    def _build_lr_scheduler(self):\n        from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler\n\n        optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n            optimizer=self.optimizer, config=self.optimizer_config\n        )\n        return optimizer_scheduler\n\n    def is_mp_src_rank_with_outputs(self):\n        return (\n            mpu.get_tensor_model_parallel_rank() == 0\n            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1\n            and mpu.get_context_parallel_rank() == 0\n        )\n\n    def initialize(self):\n        self._build_tf_config()\n\n        self.module = self._build_megatron_module()\n\n        if not self.engine_config.forward_only:\n            self.optimizer = self._build_optimizer()\n            self.lr_scheduler = self._build_lr_scheduler()\n        else:\n            self.optimizer = None\n            self.lr_scheduler = None\n\n        tmp_config = OmegaConf.create({\"model\": {\"path\": self.model_config.local_path}})\n\n        role = \"actor\" if not self.is_value_model else \"critic\"\n\n        self.checkpoint_manager = MegatronCheckpointManager(\n            config=tmp_config,\n            checkpoint_config=self.checkpoint_config,\n            model_config=self.model_config.hf_config,\n            transformer_config=self.tf_config,\n            role=role,\n            model=self.module,\n            arch=self.model_config.architectures[0],\n            hf_config=self.model_config.hf_config,\n            param_dtype=self.param_dtype,\n            share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,\n            processing_class=self.model_config.get_processor(),\n            optimizer=self.optimizer,\n            optimizer_scheduler=self.lr_scheduler,\n            use_distributed_optimizer=self.engine_config.use_distributed_optimizer,\n            use_checkpoint_opt_param_scheduler=self.optimizer_config.use_checkpoint_opt_param_scheduler,\n            bridge=self.bridge,\n            use_dist_checkpointing=self.engine_config.use_dist_checkpointing,\n        )\n\n    def train_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into training mode.\n\n        Usage:\n            with engine.train_mode():\n                # runs in training mode\n        \"\"\"\n        return EngineTrainModeCtx(self)\n\n    def eval_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into evaluation mode.\n\n        Usage:\n            with engine.eval_mode():\n                # runs in evaluation mode\n        \"\"\"\n        return EngineEvalModeCtx(self)\n\n    def optimizer_zero_grad(self):\n        \"\"\"\n        Zero out gradients of all parameters before starting a new backward
 pass.\n        \"\"\"\n        self.optimizer.zero_grad()\n        # assumes use_contiguous_buffers_in_local_ddp is enabled and overlap_dp_param_comm is not\n        for chunk in self.module:\n            # if the distributed optimizer is used, the grad buffer is zeroed by the optimizer itself\n            chunk.zero_grad_buffer()\n\n    def optimizer_step(self):\n        \"\"\"\n        Perform an optimization step to update model parameters based on accumulated gradients.\n\n        Returns:\n            grad_norm (float): The norm of the gradients before clipping or update.\n        \"\"\"\n        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step()\n\n        if update_successful:\n            # the param all-gather is already executed inside optimizer.step in recent Megatron\n            pass\n        else:\n            raise NotImplementedError(\"Megatron optimizer step failed. This should not happen\")\n\n        return grad_norm\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance the learning rate scheduler by one step.\n\n        Returns:\n            current_lr (float or list[float]): Updated learning rate(s).\n        \"\"\"\n        from verl.utils.megatron.optimizer import get_megatron_last_lr\n\n        self.lr_scheduler.step(1)\n        return get_megatron_last_lr(self.optimizer)\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move model parameters, optimizer states, or both to the specified device.\n\n        Args:\n            device: Target device identifier.\n            model: If True, move the model.\n            optimizer: If True, move the optimizer states.\n        \"\"\"\n        device_name = get_device_name()\n\n        assert device in (device_name, \"cpu\")\n        if device == device_name:\n            if not self.engine_config.param_offload:\n                if model:\n                    load_megatron_model_to_gpu(self.module, load_grad=True)\n                if optimizer and self.optimizer is not None:\n                    load_megatron_optimizer(self.optimizer, device)\n        elif device == \"cpu\":\n            if not self.engine_config.param_offload:\n                if model:\n                    offload_megatron_model_to_cpu(self.module)\n                if optimizer and self.optimizer is not None:\n                    offload_megatron_optimizer(self.optimizer)\n        else:\n            raise ValueError(f\"Invalid device type: {device}\")\n\n    def get_data_parallel_rank(self):\n        return mpu.get_data_parallel_rank()\n\n    def get_data_parallel_size(self):\n        return mpu.get_data_parallel_world_size()\n\n    def get_data_parallel_group(self):\n        return mpu.get_data_parallel_group()\n\n    def save_checkpoint(\n        self,\n        local_path: str,\n        hdfs_path: Optional[str] = None,\n        global_step: int = 0,\n        max_ckpt_to_keep: Optional[int] = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Save model, optimizer, and scheduler states to a checkpoint.\n\n        Args:\n            local_path: Local filesystem path to save checkpoint.\n            hdfs_path: Optional HDFS path to copy checkpoint.\n            global_step: Integer training step number for naming.\n            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.\n        \"\"\"\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.module, load_grad=True)\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path,
 hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.module)\n\n    def load_checkpoint(\n        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs\n    ) -> None:\n        \"\"\"\n        Load model, optimizer, and scheduler states from a checkpoint.\n\n        Args:\n            local_path: Local filesystem path of the checkpoint.\n            hdfs_path: Optional HDFS path where checkpoint is stored.\n            del_local_after_load: Whether to delete local copy after loading.\n        \"\"\"\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.module)\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.optimizer)\n\n    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:\n        tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)\n\n        # compute num_tokens in the global batch for loss normalization\n        batch_num_tokens = data[\"loss_mask\"].sum().to(get_device_id())\n        torch.distributed.all_reduce(\n            batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()\n        )\n        tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())\n        tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())\n\n        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n        if vpp_size is not None and vpp_size > 1:\n            num_batches_divided_by = self.tf_config.microbatch_group_size_per_vp_stage\n        else:\n            num_batches_divided_by = None\n\n        micro_batches, indices = prepare_micro_batches(\n            data=data,\n            dp_group=self.get_data_parallel_group(),\n            num_batches_divided_by=num_batches_divided_by,\n            same_micro_num_in_dp=False,\n            min_num_micro_batch=None,\n        )\n\n        if num_batches_divided_by is not None:\n            assert len(micro_batches) % num_batches_divided_by == 0, (\n                f\"the number of micro batches {len(micro_batches)} must be divisible by num_batches_divided_by \"\n                f\"{num_batches_divided_by} for the megatron backend\"\n            )\n\n        # number of micro-batches for the pp schedule\n        n_micro_batch = len(micro_batches)\n\n        for micro_batch in micro_batches:\n            tu.assign_non_tensor(micro_batch, num_micro_batch=n_micro_batch)\n\n        forward_backward_func = get_forward_backward_func()\n\n        postprocess_micro_batch_func = partial(\n            self.postprocess_micro_batch_func,\n            forward_only=forward_only,\n            loss_function=loss_function,\n        )\n\n        tu.assign_non_tensor(data, num_micro_batch=n_micro_batch)\n\n        forward_step = partial(self.forward_step, postprocess_micro_batch_func=postprocess_micro_batch_func)\n\n        # wrap the micro-batches for consumption by the pp schedule\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n
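        # with interleaved vpp, forward_backward_func expects one data iterator per model chunk,\n        # which is what make_batch_generator builds above\n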
        losses_reduced = forward_backward_func(\n            forward_step_func=forward_step,\n            data_iterator=batch_generator,\n            model=self.module,\n            num_microbatches=n_micro_batch,\n            seq_length=1,  # the communication shape is obtained via p2p comm\n            micro_batch_size=1,  # the communication shape is obtained via p2p comm\n            forward_only=forward_only,\n        )\n        # losses_reduced contains the stats returned from loss_func\n        if mpu.is_pipeline_last_stage(ignore_virtual=True):\n            return postprocess_batch_func(output_lst=losses_reduced, indices=indices, data=data)\n        else:\n            return {}\n\n    def get_per_tensor_param(self):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.module, load_grad=False)\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.module,\n                self.model_config.hf_config,\n                self.weight_converter,\n                self.tf_config,\n                self.layer_name_mapping,\n            )\n        return per_tensor_param\n\n    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):\n        raise NotImplementedError(\"forward_step must be implemented in subclass\")\n\n    def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):\n        raise NotImplementedError(\"postprocess_micro_batch_func must be implemented in subclass\")\n\n\nclass EngineEvalModeCtx:\n    def __init__(self, engine: MegatronEngine):\n        self.engine = engine\n\n    def __enter__(self):\n        assert isinstance(self.engine, MegatronEngine)\n\n        self.engine.mode = \"eval\"\n        if self.engine._is_offload_param:\n            load_megatron_model_to_gpu(self.engine.module, load_grad=True)\n\n        # the mcore module is a list of model chunks, one per vpp stage\n        for module in self.engine.module:\n            module.eval()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.engine._is_offload_param:\n            offload_megatron_model_to_cpu(self.engine.module)\n        self.engine.mode = None\n\n\nclass EngineTrainModeCtx:\n    def __init__(self, engine: MegatronEngine):\n        self.engine = engine\n\n    def __enter__(self):\n        assert isinstance(self.engine, MegatronEngine)\n\n        self.engine.mode = \"train\"\n        if self.engine._is_offload_param:\n            load_megatron_model_to_gpu(self.engine.module, load_grad=True)\n        if self.engine._is_offload_optimizer:\n            load_megatron_optimizer(optimizer=self.engine.optimizer)\n\n        # the mcore module is a list of model chunks, one per vpp stage\n        for module in self.engine.module:\n            module.train()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.engine._is_offload_param:\n            offload_megatron_model_to_cpu(self.engine.module)\n        if self.engine._is_offload_optimizer:\n            offload_megatron_optimizer(optimizer=self.engine.optimizer)\n        self.engine.mode = None\n\n\n@EngineRegistry.register(model_type=\"language_model\", backend=\"megatron\")\nclass MegatronEngineWithLMHead(MegatronEngine):\n    def prepare_model_inputs(self, batch: TensorDict):\n        batch = batch.to(get_device_id())\n
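        # contiguous() materializes the TensorDict so the per-field reads below hit dense tensors\n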
        batch = batch.contiguous()\n        input_ids = batch[\"input_ids\"]\n        loss_mask = batch[\"loss_mask\"].to(bool)\n        position_ids = batch[\"position_ids\"]\n\n        # process vlm inputs\n        has_multi_modal_inputs = \"multi_modal_inputs\" in batch.keys()\n        if has_multi_modal_inputs:\n            # index tensor used later to gather the matching multi-modal inputs\n            batch[\"multi_modal_inputs_idx\"] = torch.Tensor(list(range(len(batch[\"multi_modal_inputs\"])))).to(\n                torch.int64\n            )\n\n        if batch[\"position_ids\"].dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            batch[\"position_ids\"] = batch[\"position_ids\"][\n                :, 0\n            ]  # the mcore patch recomputes qwen2vl's position ids during forward\n\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in batch:\n            from verl.utils.model import extract_multi_modal_inputs\n\n            indices = batch.get(\"multi_modal_inputs_idx\", None)\n            multi_modal_inputs = extract_multi_modal_inputs(batch[\"multi_modal_inputs\"], indices)\n\n        return {\n            \"input_ids\": input_ids,\n            \"loss_mask\": loss_mask,\n            \"position_ids\": position_ids,\n            \"multi_modal_inputs\": multi_modal_inputs,\n        }\n\n    def prepare_model_outputs(self, output: dict, data: TensorDict):\n        calculate_entropy = tu.get_non_tensor_data(data, key=\"calculate_entropy\", default=False)\n\n        log_prob = output[\"log_probs\"]\n        model_output = {\"log_probs\": log_prob}\n        if calculate_entropy:\n            entropy = output[\"entropy\"]\n            model_output[\"entropy\"] = entropy\n\n        return model_output\n\n    def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micro_batch_func):\n        batch: TensorDict = next(batch_iter)\n        batch = batch.to(get_device_id())\n        use_fused_kernels = tu.get_non_tensor_data(batch, key=\"use_fused_kernels\", default=False)\n        calculate_entropy = tu.get_non_tensor_data(batch, key=\"calculate_entropy\", default=False)\n        pad_mode = tu.get_non_tensor_data(batch, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n        temperature = batch[\"temperature\"]\n\n        model_inputs = self.prepare_model_inputs(batch)\n        input_ids = model_inputs[\"input_ids\"]\n        multi_modal_inputs = model_inputs[\"multi_modal_inputs\"]\n\n        if pad_mode == DatasetPadMode.NO_PADDING:\n            label = input_ids.clone()\n        else:\n            raise NotImplementedError(f\"Pad mode {pad_mode} is not supported for megatron engine\")\n\n        from verl.models.mcore import get_mcore_forward_no_padding_fn\n\n        if use_fused_kernels:\n            raise NotImplementedError(\"Fused kernels are not supported for megatron engine\")\n\n        forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)\n\n        def logits_processor(logits, label):\n            assert logits.shape[:2] == label.shape[:2]\n            logits.div_(temperature)\n            ret = {}\n            if calculate_entropy:\n                logits_bak = logits.clone()\n                if torch.distributed.get_rank() == 0:\n                    logger.warning_once(\n                        \"For memory-efficient computation, enable fused kernels via \"\n                        \"`actor_rollout_ref.model.use_fused_kernels=True`. \"\n                        \"The current `clone()` operation ensures correctness but increases memory usage.\"\n                    )\n                entropy = vocab_parallel_entropy(logits)\n                ret[\"entropy\"] = entropy\n            else:\n                logits_bak = logits\n\n            # Create the final labels for next-token prediction.\n            # The `label` tensor starts as a clone of `input_ids`. `torch.roll` is not applied\n            # earlier because `input_ids` is a nested tensor, which is incompatible with the operation.\n            # The `preprocess_packed_seqs_no_padding` function unnests and flattens the tensor\n            # into `input_ids_rmpad` (shape: [1, total_seqlen]).\n            # Now, on this simple, unpadded tensor, we can perform the standard left shift\n            # to align the target token `t+1` with the prediction for token `t`.\n            label = torch.roll(label, shifts=-1, dims=1)\n\n            log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)\n            ret[\"log_probs\"] = log_probs\n            return ret\n\n        logits_processor_args = {\"label\": label}\n\n        output = forward_fn(\n            model,\n            input_ids,\n            multi_modal_inputs,\n            logits_processor=logits_processor,\n            logits_processor_args=logits_processor_args,\n        )\n\n        return output, partial(postprocess_micro_batch_func, data=batch)\n\n    def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):\n        # For memory efficiency, entropy computation is deferred to compute_log_probs (forward_only == True)\n        device = data[\"input_ids\"].device\n        model_output = self.prepare_model_outputs(output, data)\n\n        if loss_function is not None:\n            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())\n            # scale loss by num_micro_batch because megatron will scale loss\n            # by n_micro_batch and cp size inside pp schedule\n            loss = loss * data[\"num_micro_batch\"] / mpu.get_context_parallel_world_size()\n        else:\n            assert forward_only, \"forward_only must be True when loss_function is None\"\n            loss = torch.tensor(1.0, device=device)\n            metrics = {}\n\n        output = {\n            \"model_output\": model_output,\n            \"loss\": loss,\n            \"metrics\": metrics,\n        }\n\n        # return loss and stats\n        return loss, output\n\n\n@EngineRegistry.register(model_type=\"value_model\", backend=\"megatron\")\nclass MegatronEngineWithValueHead(MegatronEngineWithLMHead):\n    # for value head\n    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):\n        batch: TensorDict = next(batch_iter)\n        batch = batch.to(get_device_id())\n        model_inputs = self.prepare_model_inputs(batch)\n        input_ids = model_inputs[\"input_ids\"]\n        multi_modal_inputs = model_inputs[\"multi_modal_inputs\"]\n\n        from verl.models.mcore import get_mcore_forward_no_padding_fn\n\n        forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)\n\n        output = forward_fn(\n            model,\n            input_ids,\n            multi_modal_inputs,\n            value_model=True,\n        )\n\n        return output, partial(postprocess_micro_batch_func, data=batch)\n\n    def prepare_model_outputs(self, output: dict | torch.Tensor, data: 
TensorDict):\n        return {\"values\": output}\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/megatron/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom verl.utils.device import get_torch_device\n\n\ndef set_random_seed(seed):\n    import random\n\n    import numpy as np\n    import torch\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    if get_torch_device().device_count() > 0:\n        from megatron.core import tensor_parallel\n\n        tensor_parallel.model_parallel_cuda_manual_seed(seed)\n    # FIXME: torch.cumsum does not support deterministic mode (used in the vllm sampler),\n    # https://github.com/pytorch/pytorch/issues/89492\n    # torch.use_deterministic_algorithms(True, warn_only=True)\n    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/mindspeed/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .transformer_impl import MindspeedEngineWithLMHead\n\n__all__ = [\"MindspeedEngineWithLMHead\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/mindspeed/transformer_impl.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom mindspeed.megatron_adaptor import repatch\n\nfrom verl.trainer.config import CheckpointConfig\nfrom verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig\n\nfrom ..base import EngineRegistry\nfrom ..megatron import MegatronEngineWithLMHead\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@EngineRegistry.register(model_type=\"language_model\", backend=\"megatron\", device=\"npu\")\nclass MindspeedEngineWithLMHead(MegatronEngineWithLMHead):\n    def __init__(\n        self,\n        model_config: HFModelConfig,\n        engine_config: McoreEngineConfig,\n        optimizer_config: McoreOptimizerConfig,\n        checkpoint_config: CheckpointConfig,\n    ):\n        super().__init__(model_config, engine_config, optimizer_config, checkpoint_config)\n\n        repatch_config = {\"use_flash_attn\": True}\n        if self.engine_config.context_parallel_size > 1:\n            repatch_config[\"context_parallel_size\"] = self.engine_config.context_parallel_size\n\n        repatch(repatch_config)\n"
  },
  {
    "path": "verl_distillation/verl/workers/engine/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.dataset.dataset_utils import DatasetPadMode\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import rearrange_micro_batches, restore_dynamic_batch\n\n\ndef prepare_micro_batches(\n    data: TensorDict,\n    dp_group=None,\n    num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n):\n    \"\"\"\n    Prepare micro batches from data.\n    \"\"\"\n    use_dynamic_bsz = tu.get_non_tensor_data(data=data, key=\"use_dynamic_bsz\", default=True)\n    sp_size = tu.get_non_tensor_data(data=data, key=\"sp_size\", default=1)\n\n    if use_dynamic_bsz:\n        assert \"max_token_len_per_gpu\" in data.keys(), \"max_token_len_per_gpu must be set when use_dynamic_bsz is True\"\n        max_token_len_per_gpu = data[\"max_token_len_per_gpu\"]\n        max_token_len = max_token_len_per_gpu * sp_size\n        micro_batches, batch_idx_list = rearrange_micro_batches(\n            data,\n            max_token_len=max_token_len,\n            dp_group=dp_group,\n            num_batches_divided_by=num_batches_divided_by,\n            same_micro_num_in_dp=same_micro_num_in_dp,\n            min_num_micro_batch=min_num_micro_batch,\n            use_dynamic_bsz_balance=use_dynamic_bsz_balance,\n        )\n    else:\n        micro_batch_size_per_gpu = data[\"micro_batch_size_per_gpu\"]\n        micro_batches = data.split(micro_batch_size_per_gpu)\n        batch_idx_list = None\n    return micro_batches, batch_idx_list\n\n\ndef postprocess_batch_func(output_lst, indices, data: TensorDict):\n    \"\"\"Postprocess the outputs of a forward_backward_batch call.\n\n    output_lst is a list of dicts, one per micro-batch; each dict contains\n    1. model_output, 2. loss, 3. metrics. Model outputs are concatenated across\n    micro-batches and, when dynamic batch sizing is used, restored to the\n    original sample order. This runs only on the last pp stage (other pp ranks\n    return early upstream), while the result is available on every tp rank.\n    \"\"\"\n\n    use_dynamic_bsz = tu.get_non_tensor_data(data=data, key=\"use_dynamic_bsz\", default=True)\n    pad_mode = tu.get_non_tensor_data(data=data, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n    assert pad_mode == DatasetPadMode.NO_PADDING, \"postprocess_batch_func only supports the NO_PADDING pad_mode\"\n\n    model_output = {}\n    losses = []\n    aggregated_metrics = {}\n\n    # model output\n    for o in output_lst:\n        if \"model_output\" in o:\n            for key, val in o[\"model_output\"].items():\n                if key not in model_output:\n                    model_output[key] = []\n                model_output[key].append(val)\n\n    # concat results from micro batches\n    for key, val in model_output.items():\n        if pad_mode == DatasetPadMode.NO_PADDING:\n            tensors = [tensor for nt in model_output[key] for tensor in nt.unbind()]\n            model_output[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)\n        else:\n            raise NotImplementedError(f\"pad_mode {pad_mode} not implemented\")\n\n        # undo the dynamic-bsz reordering\n        if use_dynamic_bsz:\n            model_output[key] = restore_dynamic_batch(model_output[key], indices)\n\n    # loss\n    for o in output_lst:\n        if \"loss\" in o:\n            losses.append(o[\"loss\"])\n\n    # metrics\n    for o in output_lst:\n        if \"metrics\" in o:\n            metrics = o[\"metrics\"]\n            append_to_dict(aggregated_metrics, metrics)\n\n    output = {\n        \"model_output\": model_output,\n        \"loss\": losses,\n        \"metrics\": aggregated_metrics,\n    }\n\n    return output\n"
  },
  {
    "path": "verl_distillation/verl/workers/fsdp_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport datetime\nimport json\nimport logging\nimport os\nimport warnings\nfrom dataclasses import asdict\nfrom typing import Any, Optional\n\nimport numpy as np\nimport psutil\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nfrom codetiming import Timer\nfrom omegaconf import DictConfig, OmegaConf, open_dict\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom safetensors.torch import save_file\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import (FullStateDictConfig,\n                                        ShardedStateDictConfig, StateDictType)\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import (\n    Dispatch, make_nd_compute_dataproto_dispatch_fn, register)\nfrom verl.utils import hf_processor, hf_tokenizer\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.device import (get_device_id, get_device_name,\n                               get_nccl_backend, get_torch_device,\n                               set_expandable_segments)\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (CPUOffloadPolicy, MixedPrecisionPolicy,\n                                   apply_fsdp2, collect_lora_params,\n                                   fsdp2_load_full_state_dict, fsdp_version,\n                                   get_fsdp_wrap_policy,\n                                   get_init_weight_context_manager,\n                                   get_shard_placement_fn, init_fn,\n                                   layered_summon_lora_params,\n                                   load_fsdp_model_to_gpu, load_fsdp_optimizer,\n                                   offload_fsdp_model_to_cpu,\n                                   offload_fsdp_optimizer,\n                                   replace_lora_wrapper)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.memory_utils import aggressive_empty_cache\nfrom verl.utils.model import compute_position_id_with_mask, convert_weight_keys\nfrom verl.utils.profiler import (DistProfiler, DistProfilerExtension,\n                                 ProfilerConfig, log_gpu_memory_usage,\n                                 simple_timer)\nfrom 
verl.utils.profiler.performance import (reduce_timing,\n                                             topk_reduce_ratio_min_max)\nfrom verl.utils.py_functional import convert_to_regular_types\nfrom verl.utils.ray_utils import get_event_loop\nfrom verl.workers.config import (FSDPCriticConfig, FSDPEngineConfig,\n                                 HFModelConfig, RolloutConfig)\nfrom verl.workers.config.optimizer import build_optimizer\nfrom verl.workers.rollout import get_rollout_class\nfrom verl.workers.sharding_manager.fsdp_ulysses import \\\n    FSDPUlyssesShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2\")\n    return sharding_strategy\n\n\ndef get_vl_model_vision_tower(vl_model_instance):\n    \"\"\"\n    Util to extract Vision Tower from a VL model instance\n    \"\"\"\n    if hasattr(vl_model_instance, \"model\") and hasattr(vl_model_instance.model, \"visual\"):\n        # transformers >= 4.52.0\n        return vl_model_instance.model.visual\n    elif hasattr(vl_model_instance, \"visual\"):\n        # transformers < 4.52.0\n        return vl_model_instance.visual\n    return None\n\n\nclass ActorRolloutRefWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: DictConfig, role: str, **kwargs):\n        Worker.__init__(self)\n\n        self.config = config\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ.get(\"RANK\", 0))\n            world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n            torch.distributed.init_process_group(\n                backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\",\n                rank=rank,\n                world_size=world_size,\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n\n        # build device mesh for FSDP\n        world_size = torch.distributed.get_world_size()\n        # TODO(sgm): support FSDP hybrid shard for larger model\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)\n\n        # build device mesh for Ulysses Sequence Parallel\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.actor.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if 
self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        # create training dispatch\n        if self.ulysses_device_mesh is not None:\n            is_collect = self.ulysses_device_mesh[\"sp\"].get_local_rank() == 0\n            self._register_dispatch_collect_info(\n                \"actor\", dp_rank=self.ulysses_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n            )\n        else:\n            self._register_dispatch_collect_info(\"actor\", dp_rank=self.rank, is_collect=True)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self._lora_rank = self.config.model.get(\"lora_rank\", 0)\n        self._is_lora = self.config.model.get(\"lora_adapter_path\") is not None or self._lora_rank > 0\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n        self.use_orig_params = self.config.actor.fsdp_config.get(\"use_orig_params\", False)\n\n        # TODO(haibin.lin):\n        # As of now the type of config is DictConfig; if we assign config.profiler a ProfilerConfig,\n        # it will actually convert the ProfilerConfig dataclass back to a DictConfig.\n        # We can still use ProfilerConfig for testing purposes (tests/utils/test_nvtx_profile.py)\n        # as it provides a DictConfig-like interface\n        # The benefit of creating the dataclass config is to perform validation during __post_init__\n        if self._is_actor:\n            omega_profiler_config = config.actor.get(\"profiler\", {})\n        elif self._is_rollout:\n            # NOTE: In colocation mode, the rollout config may not take effect (it follows the actor config).\n            # This is kept for extensibility in AsyncRL cases\n            omega_profiler_config = config.rollout.get(\"profiler\", {})\n        elif self._is_ref:\n            omega_profiler_config = config.ref.get(\"profiler\", {})\n        else:\n            raise ValueError(\n                f\"Invalid role {self.role}, should be one of \"\n                \"['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']\"\n            )\n        # omega_profiler_config is a DictConfig\n        # profiler_config is a ProfilerConfig dataclass\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)\n        )\n\n        self._is_offload_param = False\n        self._is_offload_optimizer = False\n        if self._is_actor:\n            self._is_offload_param = self.config.actor.fsdp_config.get(\"param_offload\", False)\n            self._is_offload_optimizer = 
self.config.actor.fsdp_config.get(\"optimizer_offload\", False)\n        elif self._is_ref:\n            # TODO: it seems that manual offload is slowly than FSDP offload\n            self._is_offload_param = self.config.ref.fsdp_config.get(\"param_offload\", False)\n\n        # normalize config\n        if self._is_actor:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            assert self.config.actor.ppo_mini_batch_size > 0, (\n                f\"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after \"\n                f\"normalization\"\n            )\n            # micro bsz\n            if self.config.actor.ppo_micro_batch_size is not None:\n                self.config.actor.ppo_micro_batch_size //= (\n                    self.device_mesh.size() // self.ulysses_sequence_parallel_size\n                )\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n\n            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:\n                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, (\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by \"\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n                )\n                assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, (\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than \"\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n                )\n\n        # normalize rollout config\n        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:\n            self.config.rollout.log_prob_micro_batch_size //= (\n                self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n        # normalize ref config\n        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:\n            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n\n    def _build_model_optimizer(\n        self,\n        model_path,\n        fsdp_config: FSDPEngineConfig,\n        optim_config,\n        override_model_config,\n        use_remove_padding=False,\n        use_fused_kernels=False,\n        enable_gradient_checkpointing=False,\n        trust_remote_code=False,\n        use_liger=False,\n        role=\"actor\",\n        enable_activation_offload=False,\n    ):\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n        from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,\n                                  AutoModelForImageTextToText,\n                                  AutoModelForVision2Seq)\n\n        from verl.utils.model import (get_generation_config, print_model_size,\n                                      update_model_config)\n        from verl.utils.torch_dtypes import 
\n\n    def _build_model_optimizer(\n        self,\n        model_path,\n        fsdp_config: FSDPEngineConfig,\n        optim_config,\n        override_model_config,\n        use_remove_padding=False,\n        use_fused_kernels=False,\n        enable_gradient_checkpointing=False,\n        trust_remote_code=False,\n        use_liger=False,\n        role=\"actor\",\n        enable_activation_offload=False,\n    ):\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n        from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,\n                                  AutoModelForImageTextToText,\n                                  AutoModelForVision2Seq)\n\n        from verl.utils.model import (get_generation_config, print_model_size,\n                                      update_model_config)\n        from verl.utils.torch_dtypes import PrecisionType\n\n        assert role in [\"actor\", \"ref\"]\n\n        log_gpu_memory_usage(f\"Before init {role} from HF AutoModel\", logger=logger)\n        local_path = model_path\n\n        # note that we have to create the model in fp32. Otherwise, the optimizer is in bf16, which is incorrect\n        # TODO(zhangchi.usc1992): 1. support creating from a randomly initialized model. 2. Support init with FSDP directly\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        torch_dtype = fsdp_config.get(\"model_dtype\", None)\n        if torch_dtype is None:\n            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16\n        else:\n            torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        # override model kwargs\n        attn_implementation = override_model_config.get(\"attn_implementation\", \"flash_attention_2\")\n        actor_model_config = AutoConfig.from_pretrained(\n            local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation\n        )\n        # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53\n        # which will be patched by _ulysses_flash_attention_forward, but erroneously misses position_ids\n        # Maybe support Ulysses in VisionAttention in the future and remove this patch\n        if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, \"vision_config\"):\n            actor_model_config.vision_config._attn_implementation = \"eager\"\n\n        # patch for kimi-vl\n        if getattr(actor_model_config, \"model_type\", None) == \"kimi_vl\":\n            actor_model_config.text_config.topk_method = \"greedy\"\n\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config)\n        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)\n        if self.rank == 0:\n            print(f\"Model config after override: {actor_model_config}\")\n\n        # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            has_remote_code = hasattr(actor_model_config, \"auto_map\") and any(\n                actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values()\n            )\n            if has_remote_code:\n                auto_class = next(\n                    k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v\n                )\n                match auto_class:\n                    case 
\"AutoModelForVision2Seq\":\n                        actor_module_class = AutoModelForVision2Seq\n                    case \"AutoModelForCausalLM\":\n                        actor_module_class = AutoModelForCausalLM\n                    case \"AutoModelForImageTextToText\":\n                        actor_module_class = AutoModelForImageTextToText\n                    case _:\n                        actor_module_class = AutoModel\n            else:\n                if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():\n                    actor_module_class = AutoModelForVision2Seq\n                elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys():\n                    actor_module_class = AutoModelForCausalLM\n                elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys():\n                    actor_module_class = AutoModelForImageTextToText\n                else:\n                    actor_module_class = AutoModel\n\n            actor_module = actor_module_class.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                torch_dtype=torch_dtype,\n                config=actor_model_config,\n                trust_remote_code=trust_remote_code,\n                attn_implementation=attn_implementation,\n            )\n\n            # Apply Liger kernel to the model if use_liger is set to True\n            if use_liger:\n                from liger_kernel.transformers.monkey_patch import \\\n                    _apply_liger_kernel_to_instance\n\n                _apply_liger_kernel_to_instance(model=actor_module)\n\n            fused_kernel_options = self.config.model.get(\"fused_kernel_options\", None)\n            fused_kernels_backend = (\n                fused_kernel_options.get(\"impl_backend\", None) if fused_kernel_options is not None else None\n            )\n\n            apply_monkey_patch(\n                model=actor_module,\n                use_remove_padding=use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_fused_kernels=use_fused_kernels,\n                fused_kernels_backend=fused_kernels_backend,\n            )\n\n            # some parameters may not in torch_dtype. 
TODO(zhangchi.usc1992) remove this after we switch to fsdp2\n            actor_module.to(torch_dtype)\n\n            if enable_gradient_checkpointing:\n                actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        if self._is_lora:\n            print(\"Applying LoRA to actor module\")\n            actor_module.enable_input_require_grads()\n\n            lora_adapter_path = self.config.model.get(\"lora_adapter_path\")\n            if lora_adapter_path is not None:\n                from peft import PeftModel\n\n                print(f\"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}\")\n\n                # Copy adapter to local if needed\n                local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get(\"use_shm\", False))\n\n                actor_module = PeftModel.from_pretrained(actor_module, local_adapter_path, is_trainable=True)\n                peft_config = actor_module.peft_config[\"default\"]\n                # Ensure task_type is TaskType enum, not string\n                if isinstance(peft_config.task_type, str):\n                    peft_config.task_type = TaskType.CAUSAL_LM\n\n            else:\n                # Convert config to regular Python types before creating PEFT model\n                lora_config = {\n                    \"task_type\": TaskType.CAUSAL_LM,\n                    \"r\": self.config.model.lora_rank,\n                    \"lora_alpha\": self.config.model.lora_alpha,\n                    \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                    \"exclude_modules\": convert_to_regular_types(self.config.model.exclude_modules),\n                    \"bias\": \"none\",\n                }\n                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))\n\n        self.use_orig_params = fsdp_config.get(\"use_orig_params\", False)\n        if self.config.actor.get(\"freeze_vision_tower\", False):\n            vision_tower = get_vl_model_vision_tower(actor_module)\n            if vision_tower is not None:\n                vision_tower.requires_grad_(False)\n                self.use_orig_params = True\n                if self.rank == 0:\n                    print(\"[actor model] Vision tower is set to not trainable.\")\n            else:\n                if self.rank == 0:\n                    print(\"[actor model] No vision tower found.\")\n\n        torch.distributed.barrier()\n\n        if self.rank == 0:\n            print_model_size(actor_module)\n\n        log_gpu_memory_usage(f\"After init {role} from HF AutoModel\", logger=logger)\n\n        # We wrap FSDP for rollout as well\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=actor_module,\n            
config=fsdp_config.get(\"wrap_policy\", None),\n            is_lora=self._is_lora,\n        )\n\n        if self._is_rollout and self.config.rollout.name == \"hf\":\n            # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma\n            auto_wrap_policy = None\n\n        if self.rank == 0:\n            print(f\"wrap_policy: {auto_wrap_policy}\")\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # TODO: add transformer policy\n        # We force reference policy to use CPUOffload to save memory.\n        # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation\n        cpu_offload = None if role == \"actor\" else CPUOffload(offload_params=True)\n        fsdp_strategy = self.config.actor.strategy\n        if fsdp_strategy == \"fsdp\":\n            actor_module_fsdp = FSDP(\n                actor_module,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                use_orig_params=self.use_orig_params,\n                forward_prefetch=fsdp_config.get(\"forward_prefetch\", False),\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            if role == \"actor\" and fsdp_config.offload_policy:\n                cpu_offload = CPUOffloadPolicy(pin_memory=True)\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n            else:\n                cpu_offload = None if role == \"actor\" else CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n                \"shard_placement_fn\": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),\n            }\n            full_state = actor_module.state_dict()\n            apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)\n            actor_module_fsdp = actor_module\n        else:\n            raise NotImplementedError(f\"not implement {fsdp_strategy}\")\n\n        if enable_activation_offload:\n            enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)\n\n        log_gpu_memory_usage(f\"After {role} FSDP init\", logger=logger)\n\n        # TODO: add more optimizer args into config\n        if role == \"actor\" and optim_config is not None:\n            from verl.utils.torch_functional import (\n                get_constant_schedule_with_warmup,\n                get_cosine_schedule_with_warmup)\n\n            actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config)\n\n            total_steps = optim_config.get(\"total_training_steps\", 0)\n 
           num_warmup_steps = int(optim_config.get(\"lr_warmup_steps\", -1))\n            lr_scheduler_type = optim_config.get(\"lr_scheduler_type\", \"constant\")\n            min_lr_ratio = optim_config.get(\"min_lr_ratio\", 0.0)\n            num_cycles = optim_config.get(\"num_cycles\", 0.5)\n            if num_warmup_steps < 0:\n                num_warmup_steps_ratio = optim_config.get(\"lr_warmup_steps_ratio\", 0.0)\n                num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n            if self.rank == 0:\n                print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n            if lr_scheduler_type == \"constant\":\n                actor_lr_scheduler = get_constant_schedule_with_warmup(\n                    optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps\n                )\n            elif lr_scheduler_type == \"cosine\":\n                actor_lr_scheduler = get_cosine_schedule_with_warmup(\n                    optimizer=actor_optimizer,\n                    num_warmup_steps=num_warmup_steps,\n                    num_training_steps=total_steps,\n                    min_lr_ratio=min_lr_ratio,\n                    num_cycles=num_cycles,\n                )\n            else:\n                raise NotImplementedError(f\"LR scheduler type {lr_scheduler_type} is not supported\")\n\n            log_gpu_memory_usage(f\"After {role} optimizer init\", logger=logger)\n        else:\n            actor_optimizer = None\n            actor_lr_scheduler = None\n\n        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n\n        # 1. parse rollout and huggingface model config\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)\n        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)\n        self.model_config = model_config\n\n        # 2. build rollout device mesh\n        infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size\n        infer_pp = self.config.rollout.pipeline_model_parallel_size\n        infer_world_size = infer_tp * infer_pp\n        dp = self.world_size // infer_world_size\n        assert self.world_size % infer_world_size == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}\"\n        )\n        rollout_device_mesh = init_device_mesh(\n            device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=[\"dp\", \"infer_tp\", \"infer_pp\"]\n        )\n        rollout_name = self.config.rollout.name\n\n        if rollout_name == \"hf\":\n            self._register_dispatch_collect_info(\"rollout\", dp_rank=self.rank, is_collect=True)\n        else:\n            is_collect = (\n                rollout_device_mesh[\"infer_tp\"].get_local_rank() == 0\n                and rollout_device_mesh[\"infer_pp\"].get_local_rank() == 0\n            )\n            self._register_dispatch_collect_info(\n                \"rollout\", dp_rank=rollout_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n            )\n\n        # 3. 
init trainer and rollout random states\n        self.torch_random_states = get_torch_device().get_rng_state()\n        gen_dp_rank = rollout_device_mesh[\"dp\"].get_local_rank()\n        get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n        # 4. build rollout model\n        log_gpu_memory_usage(f\"Before building {self.config.rollout.name} rollout\", logger=logger)\n        self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(\n            config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh\n        )\n        log_gpu_memory_usage(f\"After building {self.config.rollout.name} rollout\", logger=logger)\n\n        # Full params\n        if torch.distributed.get_world_size() == 1 and fsdp_version(self.actor_module_fsdp) == 1:\n            FSDP.set_state_dict_type(\n                self.actor_module_fsdp,\n                state_dict_type=StateDictType.FULL_STATE_DICT,\n                state_dict_config=FullStateDictConfig(),\n            )\n        elif fsdp_version(self.actor_module_fsdp) == 1:\n            FSDP.set_state_dict_type(\n                self.actor_module_fsdp,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        # used for LoRA\n        self.base_sync_done: bool = \"dummy\" not in self.config.rollout.load_format\n        self.layered_summon = self.config.rollout.get(\"layered_summon\", False)\n\n        # 5. switch to trainer mode\n        # NOTE: It's critical that the hybrid engine starts in trainer mode so that the checkpoint can be loaded.\n        # For sync mode, we directly switch to trainer mode here.\n        # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.\n        if rollout_config.mode == \"sync\" and self._is_actor:\n            loop = get_event_loop()\n            loop.run_until_complete(self.trainer_mode())
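\n\n    # NOTE (descriptive comment, added for exposition): rollout_mode() and trainer_mode()\n    # below implement the hybrid-engine context switch. rollout_mode() gathers the FSDP\n    # (or LoRA) weights, pushes them into the inference engine, and swaps in the generation\n    # RNG state; trainer_mode() releases the inference engine's weights/KV cache and swaps\n    # the trainer RNG state back, keeping sampling seeds identical across the tp ranks of\n    # each dp group.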
\n\n    async def rollout_mode(self):\n        \"\"\"Context-switch the hybrid engine to rollout mode.\"\"\"\n        aggressive_empty_cache(force_sync=True)\n\n        log_gpu_memory_usage(\"Before load_fsdp_model_to_gpu\", logger=logger)\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        log_gpu_memory_usage(\"After load_fsdp_model_to_gpu\", logger=logger)\n\n        peft_config = None\n        peft_model = getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n        if hasattr(peft_model, \"peft_config\"):  # LoRA\n            peft_config = peft_model.peft_config.get(\"default\", None)\n            params = collect_lora_params(\n                module=self.actor_module_fsdp,\n                layered_summon=self.config.rollout.get(\"layered_summon\", False),\n                base_sync_done=self.base_sync_done,\n            )\n            if not self.base_sync_done:\n                params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}\n        else:\n            params = self.actor_module_fsdp.state_dict()\n\n        params = convert_weight_keys(\n            params, getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n        )\n\n        # Special handling for LoRA with sleep_level=2:\n        # When sleep_level=2, base model weights are destroyed during each sleep cycle.\n        # We separately collect and update LoRA weights and base model weights through their respective interfaces.\n        # Here: params contains LoRA weights, base_model_params contains base model weights.\n        if peft_config is not None and getattr(self.rollout, \"sleep_level\", None) == 2:\n            base_model_params = collect_lora_params(\n                module=self.actor_module_fsdp,\n                layered_summon=self.layered_summon,\n                base_sync_done=False,\n            )\n            base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}\n            base_model_params = convert_weight_keys(\n                base_model_params, getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n            )\n\n        log_gpu_memory_usage(\"Before offload_fsdp_model_to_cpu\", logger=logger)\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n        log_gpu_memory_usage(\"After offload_fsdp_model_to_cpu\", logger=logger)\n\n        set_expandable_segments(False)\n\n        device = get_device_id()  # used for DTensor materialization below (also when fsdp2 sets cpu_offload_policy)\n        if peft_config is not None and self.base_sync_done:\n            per_tensor_param = params.items() if isinstance(params, dict) else params  # Fixed: handle dict case\n        else:\n            per_tensor_param = (\n                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)\n                for name, param in params.items()\n            )\n\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.resume(tags=[\"weights\"])\n        log_gpu_memory_usage(\"After resume weights\", logger=logger)\n\n        if peft_config is not None and getattr(self.rollout, \"sleep_level\", None) == 2:\n            per_tensor_base_params = (\n                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)\n                for name, param in base_model_params.items()\n            )\n            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)\n            del base_model_params, per_tensor_base_params\n\n        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)\n        log_gpu_memory_usage(\"After update_weights\", logger=logger)\n        del params, per_tensor_param\n        aggressive_empty_cache(force_sync=True)\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.resume(tags=[\"kv_cache\"])\n        log_gpu_memory_usage(\"After resume kv_cache\", logger=logger)\n\n        self.base_sync_done = True\n        # important: need to manually set the random states of each tp to be identical.\n        self.torch_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    async def trainer_mode(self):\n        \"\"\"Context-switch the hybrid engine to trainer mode.\"\"\"\n        if self.config.rollout.free_cache_engine:\n            log_gpu_memory_usage(\"Before rollout offload\", logger=logger)\n            await self.rollout.release()\n            log_gpu_memory_usage(\"After rollout offload\", logger=logger)\n\n        self.actor_module_fsdp.train()\n\n        # add empty cache after each compute\n        aggressive_empty_cache(force_sync=True)\n\n        
set_expandable_segments(True)\n\n        # restore random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from verl.workers.actor import DataParallelPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        use_shm = self.config.model.get(\"use_shm\", False)\n        use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            if self._is_actor:\n                optim_config = self.config.actor.optim\n                fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config)\n            else:\n                optim_config = None\n                fsdp_config = FSDPEngineConfig()\n\n            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n            (\n                self.actor_module_fsdp,\n                self.actor_optimizer,\n                self.actor_lr_scheduler,\n                self.actor_model_config,\n            ) = self._build_model_optimizer(\n                model_path=local_path,\n                fsdp_config=fsdp_config,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"actor\",\n                enable_activation_offload=self.config.model.get(\"enable_activation_offload\", False),\n            )\n\n            # get the original unwrapped module\n            if fsdp_version(self.actor_module_fsdp) == 1:\n                self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_param:\n                offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n                log_gpu_memory_usage(\"After offload actor model during init\", logger=logger)\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        if self._is_actor:\n            actor_cfg = omega_conf_to_dataclass(self.config.actor)\n            self.actor = DataParallelPPOActor(\n                config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self._build_rollout(trust_remote_code=self.config.model.get(\"trust_remote_code\", False))\n\n        if self._is_ref:\n            ref_model_path = self.config.model.path\n            ref_model = self.config.ref.get(\"model\", None)\n            if ref_model is not None:\n                ref_model_path = ref_model.get(\"path\", self.config.model.path)\n\n       
     if self.rank == 0:\n                print(\"reference model:\", ref_model_path)\n            local_path = copy_to_local(ref_model_path, use_shm=use_shm)\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=local_path,\n                fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config),\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n        if not self._is_actor and self._is_rollout:\n            # If ActorRolloutRefWorker is initialized as a standalone rollout,\n            # create a checkpoint manager for the FSDP model to allow loading FSDP checkpoints for rollout.\n\n            checkpoint_contents = OmegaConf.create({\"load_contents\": [\"model\"], \"save_contents\": []})\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=None,\n                lr_scheduler=None,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=checkpoint_contents,\n            )\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @DistProfiler.annotate(color=\"red\", role=\"actor_update\")\n    def update_actor(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())\n\n        with self.ulysses_sharding_manager:\n            data = data.to(\"cpu\")  # data will be moved to device with each micro batch in actor.update_policy\n\n            # perform training\n            with Timer(name=\"update_policy\", logger=None) as timer:\n                metrics = self.actor.update_policy(data=data)\n            delta_time = timer.last\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = (\n                estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size\n            )\n            metrics[\"perf/max_memory_allocated_gb\"] = 
get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n            lr = self.actor_lr_scheduler.get_last_lr()[0]\n            metrics[\"actor/lr\"] = lr.item() if torch.is_tensor(lr) else lr\n            self.actor_lr_scheduler.step()\n\n            # TODO: here, we should return all metrics\n            output = DataProto(meta_info={\"metrics\": metrics})\n\n            output = output.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n            log_gpu_memory_usage(\"After offload actor model during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"rollout\"))\n    @DistProfiler.annotate(color=\"red\", role=\"rollout_generate\")\n    def generate_sequences(self, prompts: DataProto):\n        # Support all hardware\n        assert self._is_rollout\n        prompts = prompts.to(get_device_id())\n\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id\n            if self.generation_config is not None\n            else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id\n            if self.generation_config is not None\n            else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n\n        timing_generate = {}\n        if self._is_actor:  # For rollout only, we do not switch context.\n            loop = get_event_loop()\n            loop.run_until_complete(self.rollout_mode())\n            log_gpu_memory_usage(\"After switch to rollout mode\", logger=logger)\n\n        with simple_timer(\"generate_sequences\", timing_generate):\n            output = self.rollout.generate_sequences(prompts=prompts)\n\n        if self._is_actor:\n            loop.run_until_complete(self.trainer_mode())\n            log_gpu_memory_usage(\"After switch to trainer mode\", logger=logger)\n\n        # We calculate the average timing across all ranks\n        # to make sure meta_info[\"timing\"] is the same\n        timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max(\n            timing_generate[\"generate_sequences\"]\n        )\n        timing_generate = reduce_timing(timing_generate)\n        timing_generate.update(\n            {\n                \"generation_timing/max\": timing_generate_max,\n                \"generation_timing/min\": timing_generate_min,\n                \"generation_timing/topk_ratio\": timing_generate_topk_ratio,\n            }\n        )\n        output.meta_info[\"timing\"] = timing_generate\n        output = output.to(\"cpu\")\n\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output
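\n\n    # NOTE (descriptive comment, added for exposition): compute_log_prob() below recomputes\n    # old_log_probs under the training engine rather than trusting the rollout engine's\n    # values (see the HybridEngine comment inside). With LoRA, the same method also doubles\n    # as the reference policy: disabling the adapter yields base-model log-probs, which is\n    # how compute_ref_log_prob() reuses it.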
\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @DistProfiler.annotate(color=\"blue\", role=\"actor_compute_log_prob\")\n    def compute_log_prob(self, data: DataProto):\n        # when is_lora is True, we use the actor without lora applied to calculate the log_prob,\n        # which is mostly used for ref log_prob calculation\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        # Support all hardware\n        from contextlib import nullcontext\n\n        is_lora = data.meta_info.pop(\"is_lora\", False)\n        adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        # perform recompute log_prob\n        with self.ulysses_sharding_manager:\n            with adapter_ctx:\n                output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n            output = DataProto.from_dict(\n                tensors={\"old_log_probs\": output, \"entropys\": entropys},\n                meta_info={\"temperature\": self.config.rollout.temperature},\n            )\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:\n            self.actor.actor_module._handle.reshard(True)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n            log_gpu_memory_usage(\"After offload actor model during compute_log_prob\", logger=logger)\n\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @DistProfiler.annotate(color=\"olive\", role=\"ref_compute_log_prob\")\n    def compute_ref_log_prob(self, data: DataProto):\n        if self._is_lora:\n            # if _is_lora, the actor without lora applied is the ref\n            data.meta_info[\"is_lora\"] = True\n            data = self.compute_log_prob(data)\n            # this old_log_probs is in fact ref_log_prob\n            data = DataProto.from_dict(tensors={\"ref_log_prob\": data.batch[\"old_log_probs\"]})\n            return data\n        assert self._is_ref\n        # otherwise, the worker has a standalone ref model\n\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        with self.ulysses_sharding_manager:\n            data = data.to(\"cpu\")  # data will be moved to device with each micro batch in ref.compute_log_prob\n            output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False, mask_special_token=True)\n            output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1:\n            if fsdp_version(self.ref_policy.actor_module) == 1:\n                self.ref_policy.actor_module._handle.reshard(True)\n            
elif fsdp_version(self.ref_policy.actor_module) == 2:\n                self.ref_policy.actor_module.reshard()\n\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        from verl.utils.logger import log_with_rank\n\n        # only support save and load ckpt for actor\n        assert self._is_actor\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        dist.barrier()\n\n        if self._is_lora and hasattr(getattr(self, \"actor_module\", self.actor_module_fsdp), \"peft_config\"):\n            lora_save_path = os.path.join(local_path, \"lora_adapter\")\n            peft_model = getattr(self, \"actor_module\", self.actor_module_fsdp)\n            peft_config = {}\n            if dist.get_rank() == 0:\n                os.makedirs(lora_save_path, exist_ok=True)\n                peft_config = asdict(peft_model.peft_config.get(\"default\", {}))\n                peft_config[\"task_type\"] = peft_config[\"task_type\"].value\n                peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value\n                peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n            try:\n                if fsdp_version(self.actor_module_fsdp) > 0:\n                    self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())\n                    lora_params = layered_summon_lora_params(self.actor_module_fsdp)\n                    if dist.get_rank() == 0:\n                        save_file(lora_params, os.path.join(lora_save_path, \"adapter_model.safetensors\"))\n                        with open(os.path.join(lora_save_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n                            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n            except Exception as e:\n                log_with_rank(\n                    f\"Save LoRA Adapter Error ({e})\", rank=dist.get_rank(), logger=logger, log_only_rank_0=True\n                )\n\n            dist.barrier()\n            log_with_rank(\n                f\"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}\",\n                rank=dist.get_rank(),\n                logger=logger,\n                log_only_rank_0=True,\n            )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        assert self._is_actor or (not self._is_actor and self._is_rollout), (\n            f\"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got \"\n            f\"{self._is_actor} and {self._is_rollout}\"\n        )\n\n        # No checkpoint to load, just offload the model and optimizer to CPU\n        if local_path is None:\n            if self._is_offload_param:\n                offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(self.actor_optimizer)\n            return\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.load_checkpoint(\n            
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.actor_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def start_profile(self, **kwargs) -> None:\n        \"\"\"Start profiling for the current rank in the current training step.\"\"\"\n        self.profiler.start(**kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def stop_profile(self) -> None:\n        \"\"\"Stop profiling for the current rank in the current training step.\"\"\"\n        self.profiler.stop()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def dump_memory_snapshot(self, tag: str = \"manual\", sub_dir: str = None) -> None:\n        \"\"\"Manually trigger a CUDA memory snapshot dump on all ranks.\"\"\"\n        # Memory snapshot is now handled by the profiler system\n        # This method is kept for backward compatibility but delegates to profiler\n        if hasattr(self, \"profiler\") and hasattr(self.profiler, \"_impl\"):\n            try:\n                # Try to use the profiler's memory snapshot functionality\n                if hasattr(self.profiler._impl, \"sampler\"):\n                    out_dir = OmegaConf.select(self.config, \"actor.profiler.save_path\") or \".\"\n                    self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir)\n            except Exception:\n                # silently ignore if profiler doesn't support memory snapshots\n                pass\n\n\nclass CriticWorker(Worker, DistProfilerExtension):\n    def __init__(self, config: FSDPCriticConfig):\n        Worker.__init__(self)\n        omega_profiler_config = config.get(\"profiler\", {})\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)\n        )\n        import torch.distributed\n\n        self.config = config\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n        self.config: FSDPCriticConfig = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = 
init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        # create training dispatch\n        if self.ulysses_device_mesh is not None:\n            is_collect = self.ulysses_device_mesh[\"sp\"].get_local_rank() == 0\n            self._register_dispatch_collect_info(\n                \"critic\", dp_rank=self.ulysses_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n            )\n        else:\n            self._register_dispatch_collect_info(\"critic\", dp_rank=self.rank, is_collect=True)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n        if self.config.ppo_micro_batch_size is not None:\n            self.config.ppo_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.forward_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size\n\n        if self.config.ppo_micro_batch_size_per_gpu is not None:\n            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n        self._is_lora = (\n            self.config.model.get(\"lora_adapter_path\") is not None or self.config.model.get(\"lora_rank\", 0) > 0\n        )\n        self.use_orig_params = self.config.model.fsdp_config.get(\"use_orig_params\", False)\n\n    def _build_critic_model_optimizer(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import MixedPrecision\n\n        from verl.utils.model import load_valuehead_model, print_model_size\n        from verl.utils.torch_dtypes import PrecisionType\n\n        use_shm = config.model.get(\"use_shm\", False)\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n        # note that the tokenizer between actor and critic may be different, so we override the tokenizer info\n        # with the actor's info; the critic may be a randomly initialized model of any architecture and need not\n        # match the actor.\n\n        tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)\n        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n        self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n        override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        if self.rank == 0:\n            print(f\"Critic overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig\n\n        # override model kwargs\n        attn_implementation = override_config.get(\"attn_implementation\", \"flash_attention_2\")\n        critic_model_config = AutoConfig.from_pretrained(\n            local_path,\n            attn_implementation=attn_implementation,\n            trust_remote_code=config.model.get(\"trust_remote_code\", False),\n        )\n        # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53\n        # which will be patched by _ulysses_flash_attention_forward, but erroneously misses position_ids\n        # Maybe support Ulysses in VisionAttention in the future and remove this patch\n        if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, \"vision_config\"):\n            critic_model_config.vision_config._attn_implementation = \"eager\"
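\n\n        # NOTE (descriptive comment, added for exposition): the critic is built as a value\n        # model -- num_labels=1 below makes load_valuehead_model() attach a single regression\n        # head, i.e. one scalar value estimate per token position.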
\n        critic_model_config.num_labels = 1\n        # patch for kimi-vl\n        if getattr(critic_model_config, \"model_type\", None) == \"kimi_vl\":\n            critic_model_config.text_config.topk_method = \"greedy\"\n\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            critic_model_config.classifier_dropout = 0.0\n            critic_model_config.hidden_dropout = \"0\"\n            critic_model_config.summary_dropout_prob = 0.0\n\n            critic_module = load_valuehead_model(\n                local_path,\n                torch_dtype,\n                critic_model_config,\n                config.model.get(\"trust_remote_code\", False),\n            )\n\n            use_remove_padding = config.model.get(\"use_remove_padding\", False)\n\n            apply_monkey_patch(\n                model=critic_module,\n                use_remove_padding=use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            # some parameters may not be in torch_dtype\n            critic_module.to(torch_dtype)\n\n            if 
config.model.get(\"enable_gradient_checkpointing\", False):\n                critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        if self._is_lora:\n            print(\"Applying LoRA to critic module\")\n            critic_module.enable_input_require_grads()\n\n            # Check if we should load a pre-trained LoRA adapter\n            lora_adapter_path = self.config.model.get(\"lora_adapter_path\")\n            if lora_adapter_path is not None:\n                from peft import PeftModel\n\n                print(f\"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}\")\n\n                # Copy adapter to local if needed\n                local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get(\"use_shm\", False))\n\n                critic_module = PeftModel.from_pretrained(critic_module, local_adapter_path, is_trainable=True)\n                peft_config = critic_module.peft_config[\"default\"]\n                # Ensure task_type is TaskType enum, not string\n                if isinstance(peft_config.task_type, str):\n                    peft_config.task_type = TaskType.CAUSAL_LM\n\n            else:\n                # Convert config to regular Python types before creating PEFT model\n                lora_config = {\n                    \"task_type\": TaskType.CAUSAL_LM,\n                    \"r\": self.config.model.lora_rank,\n                    \"lora_alpha\": self.config.model.lora_alpha,\n                    \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                    \"bias\": \"none\",\n                }\n                critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))\n\n        if self.rank == 0:\n            print_model_size(critic_module)\n\n        self.critic_model_config = critic_model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=critic_module,\n            config=self.config.model.fsdp_config.wrap_policy,\n            is_lora=self._is_lora,\n        )\n\n        log_gpu_memory_usage(\"Before critic FSDP\", logger=None)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        self.use_orig_params = fsdp_config.get(\"use_orig_params\", False)\n        if self.config.model.get(\"freeze_vision_tower\", False):\n            vision_tower = get_vl_model_vision_tower(critic_module)\n            if vision_tower is not None:\n                vision_tower.requires_grad_(False)\n                self.use_orig_params = True\n                if self.rank == 0:\n                    print(\"[critic model] Vision tower is set to not trainable.\")\n            else:\n                if 
self.rank == 0:\n                    print(\"[critic model] No vision tower found.\")\n\n        # Note: We forcibly turn off CPUOffload for the critic because it causes incorrect results when using grad accumulation\n        if config.strategy == \"fsdp\":\n            critic_module = FSDP(\n                critic_module,\n                param_init_fn=init_fn,\n                use_orig_params=self.use_orig_params,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,\n                device_mesh=self.device_mesh,\n                cpu_offload=None,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            offload_policy = None\n            if fsdp_config.offload_policy:\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n                offload_policy = CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": offload_policy,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n                \"shard_placement_fn\": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),\n            }\n            full_state = critic_module.state_dict()\n            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)\n        else:\n            raise NotImplementedError(f\"Unknown strategy {config.strategy}\")\n\n        if config.model.get(\"enable_activation_offload\", False):\n            enable_gradient_checkpointing = config.model.get(\"enable_gradient_checkpointing\", False)\n            enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing)\n\n        log_gpu_memory_usage(\"After critic FSDP\", logger=None)\n\n        critic_optimizer = build_optimizer(critic_module.parameters(), config.optim)\n\n        total_steps = config.optim.get(\"total_training_steps\", 0)\n        num_warmup_steps = int(config.optim.get(\"lr_warmup_steps\", -1))\n\n        lr_scheduler_type = config.optim.get(\"lr_scheduler_type\", \"constant\")\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.optim.get(\"lr_warmup_steps_ratio\", 0.0)\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        if self.rank == 0:\n            print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        from verl.utils.torch_functional import (\n            get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup)\n\n        if lr_scheduler_type == \"constant\":\n            critic_lr_scheduler = get_constant_schedule_with_warmup(\n                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps\n            )\n        elif lr_scheduler_type == \"cosine\":\n            min_lr_ratio = config.optim.get(\"min_lr_ratio\", 0.0)\n            num_cycles = 
config.optim.get(\"num_cycles\", 0.5)\n            critic_lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=critic_optimizer,\n                num_warmup_steps=num_warmup_steps,\n                num_training_steps=total_steps,\n                min_lr_ratio=min_lr_ratio,\n                num_cycles=num_cycles,\n            )\n        else:\n            raise NotImplementedError(f\"LR scheduler type {lr_scheduler_type} is not supported\")\n\n        return critic_module, critic_optimizer, critic_lr_scheduler\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from verl.workers.critic import DataParallelPPOCritic\n\n        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(\n            self.config\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n            log_gpu_memory_usage(\"After offload critic model during init\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n            log_gpu_memory_usage(\"After offload critic optimizer during init\", logger=logger)\n\n        self.critic = DataParallelPPOCritic(\n            config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer\n        )\n\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.critic_module,\n            optimizer=self.critic_optimizer,\n            lr_scheduler=self.critic_lr_scheduler,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            checkpoint_config=self.config.checkpoint,\n        )\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"cyan\")\n    def compute_values(self, data: DataProto):\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n        micro_batch_size = self.config.forward_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = data.to(\"cpu\")  # data will to device with each micro batch on critic.compute_values\n            values = self.critic.compute_values(data=data)\n            output = DataProto.from_dict(tensors={\"values\": values})\n\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"pink\")\n    def update_critic(self, data: DataProto):\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = 
data.to(\"cpu\")  # data will to device with each micro batch on critic.update_critic\n            with Timer(name=\"update_critic\", logger=None) as timer:\n                metrics = self.critic.update_critic(data=data)\n            delta_time = timer.last\n\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n\n            lr = self.critic_lr_scheduler.get_last_lr()[0]\n            metrics[\"critic/lr\"] = lr\n            self.critic_lr_scheduler.step()\n\n            output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.critic_optimizer)\n\n\n# TODO(sgm): we may need to extract it to dp_reward_model.py\nclass RewardModelWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.\n    \"\"\"\n\n    def __init__(self, config):\n        Worker.__init__(self)\n\n        omega_profiler_config = config.get(\"profiler\", {})\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self,\n            DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config),\n        )\n\n        import torch.distributed\n\n        self.config = config\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                
timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # create training dispatch\n        if self.ulysses_device_mesh is not None:\n            is_collect = self.ulysses_device_mesh[\"sp\"].get_local_rank() == 0\n            self._register_dispatch_collect_info(\n                \"reward\", dp_rank=self.ulysses_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n            )\n        else:\n            self._register_dispatch_collect_info(\"reward\", dp_rank=self.rank, is_collect=True)\n\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= torch.distributed.get_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_model(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import CPUOffload\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        use_shm = config.model.get(\"use_shm\", False)\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n\n        if self.config.model.input_tokenizer is None:\n            self._do_switch_chat_template = False\n        else:\n            self._do_switch_chat_template = True\n            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)\n            self.input_tokenizer = hf_tokenizer(\n                input_tokenizer_local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False)\n            )\n            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        trust_remote_code = config.model.get(\"trust_remote_code\", False)\n        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        model_config.num_labels = 1\n\n        # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            reward_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                config=model_config,\n                torch_dtype=torch.bfloat16,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            apply_monkey_patch(\n                model=reward_module,\n                use_remove_padding=config.model.get(\"use_remove_padding\", False),\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            reward_module.to(torch.bfloat16)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        if config.strategy == \"fsdp\":\n            reward_module = FSDP(\n                reward_module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                sync_module_states=True,\n                cpu_offload=CPUOffload(offload_params=True),\n                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,\n                device_mesh=self.device_mesh,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            cpu_offload = CPUOffloadPolicy(pin_memory=True)\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": config.model.fsdp_config.reshard_after_forward,\n                \"shard_placement_fn\": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),\n            }\n            full_state = reward_module.state_dict()\n            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)\n            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)\n        else:\n            raise NotImplementedError(f\"Unknown strategy: {config.strategy}\")\n        return reward_module\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        self.reward_module = self._build_model(config=self.config)\n\n    def _forward_micro_batch(self, micro_batch):\n        from verl.utils.attention_utils import (index_first_axis, pad_input,\n                                                rearrange, unpad_input)\n        from verl.utils.ulysses import (gather_outputs_and_unpad,\n                                        ulysses_pad_and_slice_inputs)\n\n        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = 
input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.reward_module(\n                    input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False\n                )\n                reward_rmpad = output.logits\n                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    reward_rmpad = gather_outputs_and_unpad(\n                        reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)\n            else:\n                output = self.reward_module(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                rm_score = output.logits  # (batch_size, seq_len, 1)\n                rm_score = rm_score.squeeze(-1)\n\n            # extract the result of the last valid token\n            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]\n            return rm_score\n\n    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):\n        batch_size = data.batch.batch_size[0]\n        # expand as token_level_reward\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        response_length = data.batch[\"responses\"].shape[-1]\n        if position_ids.dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            position_ids = position_ids[:, 0, :]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n        token_level_scores = 
torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)\n        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores\n\n        # select the response part\n        token_level_scores = token_level_scores[:, -response_length:]\n\n        return token_level_scores\n\n    def _switch_chat_template(self, data: DataProto):\n        src_max_length = data.batch[\"attention_mask\"].shape[-1]\n\n        src_tokenizer = self.input_tokenizer\n        target_tokenizer = self.tokenizer\n\n        rm_input_ids = []\n        rm_attention_mask = []\n\n        for i in range(data.batch.batch_size[0]):\n            if not isinstance(data.non_tensor_batch[\"raw_prompt\"][i], list | np.ndarray):\n                raise TypeError(\n                    f\"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}\"\n                )\n\n            # extract raw prompt\n            chat: list = list(data.non_tensor_batch[\"raw_prompt\"][i])\n\n            # extract response\n            response_ids = data.batch[\"responses\"][i]\n            response_length = response_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][i][-response_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            response = src_tokenizer.decode(valid_response_ids)\n            # remove bos and eos\n            response = response.replace(src_tokenizer.eos_token, \"\")\n\n            chat.append({\"role\": \"assistant\", \"content\": response})\n\n            prompt_with_chat_template = target_tokenizer.apply_chat_template(\n                chat, add_generation_prompt=False, tokenize=False\n            )\n            if self.rank == 0 and i == 0:\n                # for debugging purpose\n                print(f\"Switch template. 
chat: {prompt_with_chat_template}\")\n\n            # the maximum length is actually determined by the reward model itself\n            max_length = self.config.get(\"max_length\", src_max_length)\n            if max_length is None:\n                max_length = src_max_length\n\n            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids, attention_mask = verl_F.postprocess_data(\n                input_ids=model_inputs[\"input_ids\"],\n                attention_mask=model_inputs[\"attention_mask\"],\n                max_length=max_length,\n                pad_token_id=target_tokenizer.pad_token_id,\n                left_pad=False,  # right padding\n                truncation=self.config.get(\"truncation\", \"right\"),\n            )  # truncate from the right\n\n            rm_input_ids.append(input_ids)\n            rm_attention_mask.append(attention_mask)\n\n        rm_input_ids = torch.cat(rm_input_ids, dim=0)\n        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)\n\n        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)\n\n        rm_inputs = {\"input_ids\": rm_input_ids, \"attention_mask\": rm_attention_mask, \"position_ids\": rm_position_ids}\n\n        return DataProto.from_dict(rm_inputs)\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"reward\"))\n    @DistProfiler.annotate(color=\"brown\")\n    def compute_rm_score(self, data: DataProto):\n        import itertools\n\n        from verl.utils.seqlen_balancing import (get_reverse_idx,\n                                                 rearrange_micro_batches)\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._do_switch_chat_template:\n            rm_data = self._switch_chat_template(data)\n        else:\n            rm_input_ids = data.batch[\"input_ids\"]\n            rm_attention_mask = data.batch[\"attention_mask\"]\n            rm_position_ids = data.batch[\"position_ids\"]\n            rm_inputs = {\n                \"input_ids\": rm_input_ids,\n                \"attention_mask\": rm_attention_mask,\n                \"position_ids\": rm_position_ids,\n            }\n            rm_data = DataProto.from_dict(rm_inputs)\n\n        # Support all hardwares\n        rm_data = rm_data.to(get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            use_dynamic_bsz = self.config.use_dynamic_bsz\n            if use_dynamic_bsz:\n                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)\n            else:\n                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)\n            output = []\n            for micro_batch in micro_batches:\n                rm_score = self._forward_micro_batch(micro_batch)\n                output.append(rm_score)\n            scores = torch.cat(output, dim=0)  # (batch_size)\n\n            if use_dynamic_bsz:\n                indices = list(itertools.chain.from_iterable(indices))\n                assert len(indices) == scores.size(0), f\"{len(indices)} vs. 
{scores.size()}\"\n                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                scores = scores[revert_indices]\n\n            token_level_scores = self._expand_to_token_level(data, scores)\n            # Note that this is only the scores, may not be the final rewards used to train RL\n            output = DataProto.from_dict(tensors={\"rm_scores\": token_level_scores})\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:\n            self.reward_module._handle.reshard(True)\n\n        output = output.to(\"cpu\")\n        return output\n\n\n# ================================= Async related workers =================================\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def wake_up(self):\n        await self.rollout_mode()\n        return True\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def sleep(self):\n        await self.trainer_mode()\n        return True\n\n    # ============================ vLLM related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def get_zeromq_address(self):\n        return self.rollout.get_zeromq_address()\n\n    # ============================ SGLang related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def generate(\n        self,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ) -> list[int]:\n        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)\n        return ret\n"
  },
  {
    "path": "verl_distillation/verl/workers/megatron_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport datetime\nimport logging\nimport os\nimport time\nfrom typing import Any, Optional\n\nimport psutil\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom omegaconf import DictConfig, OmegaConf\n\ntry:\n    from mindspeed.megatron_adaptor import repatch\nexcept ImportError:\n    repatch = None\n\nfrom megatron.core import parallel_state as mpu\n\nfrom verl import DataProto\nfrom verl.models.mcore import get_mcore_weight_converter\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_nccl_backend,\n    get_torch_device,\n    set_expandable_segments,\n)\nfrom verl.utils.distributed import set_numa_affinity\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.megatron_utils import (\n    load_megatron_model_to_gpu,\n    load_megatron_optimizer,\n    offload_megatron_model_to_cpu,\n    offload_megatron_optimizer,\n    per_tensor_generator,\n)\nfrom verl.utils.memory_utils import aggressive_empty_cache\nfrom verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights\nfrom verl.utils.profiler import (\n    DistProfiler,\n    DistProfilerExtension,\n    GPUMemoryLogger,\n    ProfilerConfig,\n    log_gpu_memory_usage,\n    simple_timer,\n)\nfrom verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max\nfrom verl.utils.ray_utils import get_event_loop\nfrom verl.workers.actor.megatron_actor import MegatronPPOActor\nfrom verl.workers.config import HFModelConfig, McoreCriticConfig, RolloutConfig\nfrom verl.workers.critic.megatron_critic import MegatronPPOCritic\nfrom verl.workers.reward_model.megatron.reward_model import MegatronRewardModel\nfrom verl.workers.rollout import get_rollout_class\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef set_random_seed(seed):\n    import random\n\n    import numpy as np\n    import torch\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    if get_torch_device().device_count() > 0:\n        from megatron.core import tensor_parallel\n\n        tensor_parallel.model_parallel_cuda_manual_seed(seed)\n    # FIXME: torch cumsum not support deterministic (used in vllm sampler),\n    # https://github.com/pytorch/pytorch/issues/89492\n    # torch.use_deterministic_algorithms(True, warn_only=True)\n    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n\n\nclass MegatronWorker(Worker):\n    
def _init_hf_config_and_tf_config(\n        self,\n        model_path,\n        tokenizer_or_path,\n        dtype,\n        override_model_config,\n        override_transformer_config,\n        trust_remote_code=False,\n        use_mbridge=False,\n    ):\n        from transformers import AutoConfig\n\n        from verl.models.mcore import hf_to_mcore_config\n        from verl.utils import hf_processor, hf_tokenizer\n        from verl.utils.fs import copy_to_local\n        from verl.utils.model import update_model_config\n\n        # Step 1: initialize the tokenizer\n        self.local_path = copy_to_local(model_path)\n        if tokenizer_or_path is None:\n            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)\n            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)\n        elif isinstance(tokenizer_or_path, str):\n            self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)\n            self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)\n        else:\n            self.tokenizer = tokenizer_or_path\n            self.processor = tokenizer_or_path\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        # Step 2: get the hf\n        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)\n\n        # Step 3: override the hf config\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config.get(\"model_config\", {}))\n        self.share_embeddings_and_output_weights = getattr(hf_config, \"tie_word_embeddings\", False)\n        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)\n        self.architectures = getattr(hf_config, \"architectures\", None)\n        if self.rank == 0:\n            print(f\"Model config after override: {hf_config}\")\n\n        from verl.models.mcore.config_converter import mapping_string_to_attn_backend\n\n        # todo: remove this line after mcore adopt mbridge 0.15, now for compatibility\n        override_transformer_config = mapping_string_to_attn_backend(override_transformer_config)\n\n        if use_mbridge:\n            from verl.models.mcore.mbridge import AutoBridge\n\n            bridge = AutoBridge.from_config(hf_config)\n            bridge.set_extra_args(**override_transformer_config)\n            tf_config = bridge.config\n            self.bridge = bridge\n        else:\n            tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)\n            self.bridge = None\n\n        print(f\"TF config: {tf_config}\")\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n\n\nclass ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: DictConfig, role: 
str, **kwargs):\n        Worker.__init__(self)\n        self.config = config\n        if repatch is not None:\n            # NPU MindSpeed patch, will be refactored with MindSpeedEngine.\n            repatch(self.config.actor.megatron.get(\"override_transformer_config\", {}))\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize different parallel strategy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            set_numa_affinity()\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,\n                use_sharp=False,\n                context_parallel_size=self.config.actor.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        is_collect = (\n            mpu.get_tensor_model_parallel_rank() == 0\n            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1\n            and mpu.get_context_parallel_rank() == 0\n        )\n        self._register_dispatch_collect_info(\n            mesh_name=\"actor\", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect\n        )\n\n        set_random_seed(seed=self.config.actor.megatron.seed)\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n\n        if self._is_actor:\n            omega_profiler_config = config.actor.get(\"profiler\", {})\n        elif self._is_rollout:\n            # NOTE: In colocation mode, rollout config may not take effect (follow the actor config)\n            # This is for extendability in AsyncRL cases\n            omega_profiler_config = config.rollout.get(\"profiler\", {})\n        elif self._is_ref:\n            omega_profiler_config = config.ref.get(\"profiler\", {})\n        else:\n            raise ValueError(\n                f\"Invalid role {self.role}, should be one of \"\n                \"['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']\"\n            )\n        # 
omega_profiler_config is DictConfig\n        # profiler_config is a ProfilerConfig dataclass\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)\n        )\n\n        # TODO(sgm): Currently, we only support reference model param offload\n        # will support other offload later\n        self._is_offload_param = False\n        self._is_offload_grad = False\n        self._is_offload_optimizer = False\n\n        # normalize config\n        if self._is_actor and self._is_rollout:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n            if self.config.actor.get(\"ppo_micro_batch_size\", None):\n                self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n                self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n\n            self._is_offload_param = self.config.actor.megatron.get(\"param_offload\", False)\n            self._is_offload_grad = self.config.actor.megatron.get(\"grad_offload\", False)\n            self._is_offload_optimizer = self.config.actor.megatron.get(\"optimizer_offload\", False)\n        elif self._is_ref:\n            if self.config.ref.get(\"log_prob_micro_batch_size\", None):\n                self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n            else:\n                assert self.config.ref.get(\"log_prob_micro_batch_size_per_gpu\", None) is not None, (\n                    \"Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and \"\n                    \"`log_prob_micro_batch_size` should not be None at the same time.\"\n                )\n            self._ref_is_offload_param = self.config.ref.megatron.get(\"param_offload\", False)\n\n    def _build_model_optimizer(\n        self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config=None\n    ):\n        from verl.utils.megatron.optimizer import (\n            get_megatron_optimizer,\n            get_megatron_optimizer_param_scheduler,\n            init_megatron_optim_config,\n        )\n        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n        from verl.utils.model import get_generation_config, print_model_size\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            model_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            
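# use_mbridge selects the mbridge AutoBridge path in _init_hf_config_and_tf_config for config conversion and weight export\n            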
self.config.actor.megatron.use_mbridge,\n        )\n        self.generation_config = get_generation_config(self.local_path)\n\n        if self._is_actor or self._is_rollout:\n            wrap_config = McoreModuleWrapperConfig(\n                is_value_model=False,  # actor is not value model\n                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                wrap_with_ddp=True,\n                use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n            )\n            actor_module = make_megatron_module(\n                wrap_config=wrap_config,\n                tf_config=self.tf_config,\n                hf_config=self.hf_config,\n                bridge=self.bridge,\n                override_model_config=override_model_config,\n                override_ddp_config=override_ddp_config,\n            )\n            print(f\"actor_module: {len(actor_module)}\")\n            if self.config.actor.load_weight:\n                if self.config.actor.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(\n                        actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False\n                    )\n                else:\n                    if self.bridge is not None:\n                        local_model_path = get_hf_model_path(self.config)\n                        self.bridge.load_weights(actor_module, local_model_path)\n                    else:\n                        load_megatron_gptmodel_weights(\n                            self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False\n                        )\n\n            if self.rank == 0:\n                print_model_size(actor_module[0])\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n        elif self._is_ref:\n            wrap_config = McoreModuleWrapperConfig(\n                is_value_model=False,  # ref is not value model\n                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                wrap_with_ddp=False,\n                use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer,\n            )\n            ref_module = make_megatron_module(\n                wrap_config=wrap_config,\n                tf_config=self.tf_config,\n                hf_config=self.hf_config,\n                bridge=self.bridge,\n                override_model_config=override_model_config,\n            )\n            if self.config.ref.load_weight:  # should align with the actor:\n                assert self.config.actor.load_weight == self.config.ref.load_weight\n                print(\"load ref weight start\")\n                if self.config.ref.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(\n                        ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False\n                    )\n                else:\n                    if self.bridge is not None:\n                        local_model_path = get_hf_model_path(self.config)\n                        self.bridge.load_weights(ref_module, local_model_path)\n                    else:\n                        load_megatron_gptmodel_weights(\n                            self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False\n                        )\n            log_gpu_memory_usage(\"After ref module init\", logger=logger)\n            
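# the reference policy is frozen, so no optimizer or lr scheduler is built for it\n            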
return ref_module, self.hf_config\n\n        # TODO: add more optimizer args into config\n        if self._is_actor:\n            optim_config_megatron = init_megatron_optim_config(optim_config)\n            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)\n            actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n                optimizer=actor_optimizer, config=optim_config\n            )\n        else:\n            optim_config = None\n            actor_optimizer = None\n            actor_optimizer_scheduler = None\n\n        log_gpu_memory_usage(\"After actor optimizer init\", logger=logger)\n\n        return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n\n        # 1. parse rollout and huggingface model config\n        rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)\n        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)\n\n        # 2. build rollout device mesh\n        infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size\n        infer_pp = self.config.rollout.pipeline_model_parallel_size\n        infer_world_size = infer_tp * infer_pp\n        dp = self.world_size // infer_world_size\n        assert self.world_size % infer_world_size == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}\"\n        )\n        rollout_device_mesh = init_device_mesh(\n            get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=[\"dp\", \"infer_tp\", \"infer_pp\"]\n        )\n\n        is_collect = (\n            rollout_device_mesh[\"infer_tp\"].get_local_rank() == 0\n            and rollout_device_mesh[\"infer_pp\"].get_local_rank() == 0\n        )\n        self._register_dispatch_collect_info(\n            \"rollout\", dp_rank=rollout_device_mesh[\"dp\"].get_local_rank(), is_collect=is_collect\n        )\n\n        # 3. init trainer and rollout random states\n        self.torch_random_states = get_torch_device().get_rng_state()\n        gen_dp_rank = rollout_device_mesh[\"dp\"].get_local_rank()\n        get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n        # 4. build rollout model\n        log_gpu_memory_usage(f\"Before building {self.config.rollout.name} rollout\", logger=logger)\n        self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(\n            config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh\n        )\n        log_gpu_memory_usage(f\"After building {self.config.rollout.name} rollout\", logger=logger)\n\n        # 5. 
switch to trainer mode\n        # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.\n        # For sync mode, we directly switch to trainer mode here.\n        # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.\n        if rollout_config.mode == \"sync\" and self._is_actor:\n            loop = get_event_loop()\n            loop.run_until_complete(self.trainer_mode())\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        if self._is_actor:\n            override_transformer_config = OmegaConf.to_container(\n                OmegaConf.create(self.config.actor.megatron.get(\"override_transformer_config\", {}))\n            )\n            override_ddp_config = OmegaConf.to_container(\n                OmegaConf.create(self.config.actor.megatron.get(\"override_ddp_config\", {}))\n            )\n        elif self._is_ref:\n            override_transformer_config = OmegaConf.to_container(\n                OmegaConf.create(self.config.ref.megatron.get(\"override_transformer_config\", {}))\n            )\n        else:\n            override_transformer_config = {}\n        self.param_dtype = torch.bfloat16\n        log_gpu_memory_usage(\"Before init actor model and optimizer\", logger=logger)\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        if self._is_actor:\n            # we need the model for actor and rollout\n            optim_config = self.config.actor.optim if self._is_actor else None\n            (\n                self.actor_module,\n                self.actor_optimizer,\n                self.actor_optimizer_scheduler,\n                self.actor_model_config,\n                self.actor_optim_config,\n            ) = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n                override_ddp_config=override_ddp_config,\n            )\n            if self._is_offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n                log_gpu_memory_usage(\"After offload actor params and grad during init\", logger=logger)\n            if self._is_offload_optimizer:\n                offload_megatron_optimizer(self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        if self._is_actor:\n            actor_cfg = omega_conf_to_dataclass(self.config.actor)\n            self.actor = MegatronPPOActor(\n                config=actor_cfg,\n                model_config=self.actor_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                actor_module=self.actor_module,\n                actor_optimizer=self.actor_optimizer,\n            )\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n\n        if self._is_rollout:\n            
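# build the rollout engine (vLLM/SGLang) after the actor so its weights can be synced from the trained module in rollout_mode\n            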
self._build_rollout(trust_remote_code=self.config.model.get(\"trust_remote_code\", False))\n            log_gpu_memory_usage(\"After rollout init\", logger=logger)\n\n        if self._is_ref:\n            self.ref_module, self.ref_model_config = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=None,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n            )\n            log_gpu_memory_usage(\"After ref model init\", logger=logger)\n            self.ref_policy = MegatronPPOActor(\n                config=self.config.ref,\n                model_config=self.ref_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                actor_module=self.ref_module,\n                actor_optimizer=None,\n            )\n            if self._ref_is_offload_param:\n                offload_megatron_model_to_cpu(self.ref_module)\n                log_gpu_memory_usage(\"After offload ref params during init\", logger=logger)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_mananager = MegatronCheckpointManager(\n                config=self.config,\n                checkpoint_config=self.config.actor.checkpoint,\n                model_config=self.actor_model_config,\n                transformer_config=self.tf_config,\n                role=\"actor\",\n                model=self.actor_module,\n                arch=self.architectures[0],\n                hf_config=self.hf_config,\n                param_dtype=self.param_dtype,\n                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                optimizer=self.actor_optimizer,\n                optimizer_scheduler=self.actor_optimizer_scheduler,\n                use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n                use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,\n                bridge=self.bridge,\n                use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,\n            )\n\n            self.layer_name_mapping = {\n                \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n                \"gate_proj_layer_name\": \"linear_fc1.\",\n            }\n            self.weight_converter = None\n            if not self.config.actor.megatron.use_mbridge:\n                self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)\n\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After init_model finish\", logger=logger)\n\n    async def rollout_mode(self):\n        \"\"\"Context switch hybridengine to rollout mode.\"\"\"\n        aggressive_empty_cache(force_sync=True)\n\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)\n            log_gpu_memory_usage(\"After load actor params during rollout_mode\", logger=logger)\n\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.actor.actor_module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.actor.actor_module,\n                self.actor_model_config,\n                
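# weight_converter remaps mcore parameter names and shards into HF-format tensors consumed by the rollout engine\n                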
self.weight_converter,\n                self.tf_config,\n                self.layer_name_mapping,\n            )\n\n        set_expandable_segments(False)\n\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.resume(tags=[\"weights\"])\n        await self.rollout.update_weights(per_tensor_param)\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor.actor_module)\n        aggressive_empty_cache(force_sync=True)\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.resume(tags=[\"kv_cache\"])\n\n        # important: need to manually set the random states of each tp to be identical.\n        self.torch_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    async def trainer_mode(self):\n        \"\"\"Context switch hybridengine to trainer mode.\"\"\"\n        if self.config.rollout.free_cache_engine:\n            log_gpu_memory_usage(\"Before rollout offload\", logger=logger)\n            await self.rollout.release()\n            log_gpu_memory_usage(\"After rollout offload\", logger=logger)\n\n        for model in self.actor.actor_module:\n            model.train()\n        # add empty cache after each compute\n        aggressive_empty_cache(force_sync=True)\n\n        # FIXME(@wuxibin): megatron+sglang failed with `expandable_segments:True` in ci,\n        # can't reproduce it in dev environment, temporary disable it.\n        # https://github.com/volcengine/verl/actions/runs/17382936845/job/49344264323?pr=3285\n        if os.environ.get(\"MEGATRON_CI_DISABLE_EXPANDABLE_SEGMENTS\", \"0\") == \"0\":\n            set_expandable_segments(True)\n\n        # restore random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @GPUMemoryLogger(role=\"update_actor\", logger=logger)\n    @DistProfiler.annotate(color=\"red\")\n    def update_actor(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n            log_gpu_memory_usage(\"After load actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After load actor optimizer during update_actor\", logger=logger)\n\n        micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        dataloader = self.actor.make_minibatch_iterator(data=data)\n        with Timer(name=\"update_policy\", logger=None) as timer:\n            metrics = self.actor.update_policy(dataloader=dataloader)\n        delta_time = timer.last\n        global_num_tokens = data.meta_info[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size\n        metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n        metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n        metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n 
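       # log the optimizer lr and advance the Megatron lr scheduler by one step per policy update\n 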
       from verl.utils.megatron.optimizer import get_megatron_last_lr\n\n        metrics[\"actor/lr\"] = get_megatron_last_lr(self.actor_optimizer)\n        self.actor_optimizer_scheduler.step(1)\n\n        # TODO: here, we should return all metrics\n        output = DataProto(meta_info={\"metrics\": metrics})\n        output = output.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        aggressive_empty_cache(force_sync=True)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"rollout\"))\n    @GPUMemoryLogger(role=\"generate_sequences\", logger=logger)\n    @DistProfiler.annotate(color=\"red\")\n    def generate_sequences(self, prompts: DataProto):\n        assert self._is_rollout\n        prompts = prompts.to(get_device_name())\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id\n            if self.generation_config is not None\n            else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id\n            if self.generation_config is not None\n            else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n        timing_generate = {}\n        if self._is_actor:  # For rollout only, we do not switch context.\n            loop = get_event_loop()\n            loop.run_until_complete(self.rollout_mode())\n            log_gpu_memory_usage(\"After switch to rollout mode\", logger=logger)\n\n        with simple_timer(\"generate_sequences\", timing_generate):\n            output = self.rollout.generate_sequences(prompts=prompts)\n\n        if self._is_actor:\n            loop.run_until_complete(self.trainer_mode())\n            log_gpu_memory_usage(\"After switch to trainer mode\", logger=logger)\n\n        # We calculate the average timing across all ranks\n        # to make sure meta_info[\"timing\"] is the same\n        timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max(\n            timing_generate[\"generate_sequences\"]\n        )\n        timing_generate = reduce_timing(timing_generate)\n        timing_generate.update(\n            {\n                \"generation_timing/max\": timing_generate_max,\n                \"generation_timing/min\": timing_generate_min,\n                \"generation_timing/topk_ratio\": timing_generate_topk_ratio,\n            }\n        )\n        output.meta_info[\"timing\"] = timing_generate\n        output = output.to(\"cpu\")\n        # clear kv cache\n        aggressive_empty_cache(force_sync=True)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @GPUMemoryLogger(role=\"compute_ref_log_prob\", logger=logger)\n    @DistProfiler.annotate(color=\"olive\")\n    def compute_ref_log_prob(self, data: DataProto):\n        assert self._is_ref\n        if self._ref_is_offload_param:\n            load_megatron_model_to_gpu(self.ref_module, load_grad=False)\n            log_gpu_memory_usage(\"After 
load ref params and grad during compute_ref_log_prob\", logger=logger)\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)\n        output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n        output = output.to(\"cpu\")\n        if self._ref_is_offload_param:\n            offload_megatron_model_to_cpu(self.ref_module)\n            log_gpu_memory_usage(\"After offload ref params and grad during compute_ref_log_prob\", logger=logger)\n        aggressive_empty_cache(force_sync=True)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @GPUMemoryLogger(role=\"compute_log_prob\", logger=logger)\n    @DistProfiler.annotate(color=\"blue\")\n    def compute_log_prob(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n            log_gpu_memory_usage(\"After load actor params and grad during compute_log_prob\", logger=logger)\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n        output = DataProto.from_dict(\n            tensors={\"old_log_probs\": output, \"entropys\": entropys},\n            meta_info={\"temperature\": self.config.rollout.temperature},\n        )\n        output = output.to(\"cpu\")\n        # clear kv cache\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during compute_log_prob\", logger=logger)\n        aggressive_empty_cache(force_sync=True)\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):\n        # No checkpoint to load, just offload the model and optimizer to CPU\n        if checkpoint_path is None:\n            if self._is_offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n            if self._is_offload_optimizer:\n                offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor params and optimizer during load_checkpoint\", logger=logger)\n            return\n\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.load_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        if 
self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_pretrained_model(self, checkpoint_path, del_local_after_load=True):\n        pass\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.save_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def wake_up(self):\n        await self.rollout_mode()\n        return True\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def sleep(self):\n        await self.trainer_mode()\n        return True\n\n    # ============================ vLLM related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def get_zeromq_address(self):\n        return self.rollout.get_zeromq_address()\n\n    # ============================ SGLang related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def generate(\n        self,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ) -> list[int]:\n        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)\n        return ret\n\n\nclass CriticWorker(MegatronWorker, DistProfilerExtension):\n    def __init__(self, config: McoreCriticConfig):\n        Worker.__init__(self)\n\n        omega_profiler_config = config.get(\"profiler\", {})\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)\n        )\n        self.config: McoreCriticConfig = config\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize different parallel strategy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            set_numa_affinity()\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        is_collect = (\n            mpu.get_tensor_model_parallel_rank() == 0\n            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1\n            and mpu.get_context_parallel_rank() == 0\n        )\n        self._register_dispatch_collect_info(\n            mesh_name=\"critic\", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect\n        )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.megatron.param_offload\n        self._is_offload_optimizer = self.config.megatron.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n        if self.config.get(\"ppo_micro_batch_size\", None):\n            self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n\n        # TODO(sgm): support critic model offload\n\n    def _build_critic_model_optimizer(\n        self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config\n    ):\n        from verl.utils.megatron.optimizer import (\n            get_megatron_optimizer,\n            get_megatron_optimizer_param_scheduler,\n            init_megatron_optim_config,\n        )\n        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n        from verl.utils.model import print_model_size\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            self.config.model.tokenizer_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            self.config.megatron.use_mbridge,\n        )\n\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=True,  # critic is value model\n            share_embeddings_and_output_weights=False,\n            wrap_with_ddp=True,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n        )\n        
critic_module = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n            override_ddp_config=override_ddp_config,\n        )\n        # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp).\n        # but here, we do not use pp (vpp) yet. For simplicity, we remove the list\n        # critic_module = nn.ModuleList(critic_module)\n\n        if self.config.load_weight:\n            t0 = time.time()\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(\n                    critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True\n                )\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(critic_module, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True\n                    )\n            t1 = time.time()\n            if torch.distributed.get_rank() == 0:\n                print(f\"critic load_weight time: {t1 - t0}\")\n        if self.rank == 0:\n            print_model_size(critic_module[0])\n\n        # TODO: add more optimizer args into config\n        optim_config_megatron = init_megatron_optim_config(optim_config)\n        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)\n        critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n            optimizer=critic_optimizer, config=optim_config\n        )\n        get_torch_device().empty_cache()\n        return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # create critic\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        override_transformer_config = OmegaConf.to_container(\n            OmegaConf.create(self.config.megatron.get(\"override_transformer_config\", {}))\n        )\n        override_ddp_config = OmegaConf.to_container(\n            OmegaConf.create(self.config.megatron.get(\"override_ddp_config\", {}))\n        )\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        (\n            self.critic_module,\n            self.critic_optimizer,\n            self.critic_optimizer_scheduler,\n            self.critic_model_config,\n            critic_optimizer_config,\n        ) = self._build_critic_model_optimizer(\n            model_path=self.config.model.path,\n            optim_config=self.config.optim,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n            override_ddp_config=override_ddp_config,\n        )\n        if 
self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n        self.critic = MegatronPPOCritic(\n            config=self.config,\n            model_config=self.critic_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            critic_module=self.critic_module,\n            critic_optimizer=self.critic_optimizer,\n            critic_optimizer_config=critic_optimizer_config,\n        )\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_mananager = MegatronCheckpointManager(\n            config=self.config,\n            checkpoint_config=self.config.checkpoint,\n            model_config=self.critic_model_config,\n            transformer_config=self.tf_config,\n            role=\"critic\",\n            model=self.critic_module,\n            arch=self.architectures[0],\n            hf_config=self.hf_config,\n            param_dtype=self.param_dtype,\n            share_embeddings_and_output_weights=False,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            optimizer=self.critic_optimizer,\n            optimizer_scheduler=self.critic_optimizer_scheduler,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n            use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,\n            bridge=self.bridge,\n            use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,\n        )\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"cyan\")\n    def compute_values(self, data: DataProto):\n        micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data = data.to(get_device_id())\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        values = self.critic.compute_values(data=data)\n        output = DataProto.from_dict(tensors={\"values\": values})\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"pink\")\n    def update_critic(self, data: DataProto):\n        data = data.to(get_device_id())\n\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.critic_optimizer)\n\n        dataloader = self.critic.make_minibatch_iterator(data)\n        with Timer(name=\"update_critic\", logger=None) as timer:\n            metrics = self.critic.update_critic(dataloader=dataloader)\n        delta_time = timer.last\n        global_num_tokens = data.meta_info[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n        from verl.utils.megatron.optimizer 
import get_megatron_last_lr\n\n        metrics[\"critic/lr\"] = get_megatron_last_lr(self.critic_optimizer)\n        self.critic_optimizer_scheduler.step(1)\n\n        output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.load_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.save_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n\n\nclass RewardModelWorker(MegatronWorker, DistProfilerExtension):\n    \"\"\"\n    Note that we only implement the reward model that is a subclass of AutoModelForSequenceClassification.\n    \"\"\"\n\n    def __init__(self, config):\n        Worker.__init__(self)\n\n        omega_profiler_config = config.get(\"profiler\", {})\n        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)\n        if omega_profiler_config.get(\"tool\", None) in [\"npu\", \"nsys\", \"torch\", \"torch_memory\"]:\n            tool_config = omega_conf_to_dataclass(\n                omega_profiler_config.get(\"tool_config\", {}).get(omega_profiler_config.get(\"tool\"))\n            )\n        else:\n            tool_config = None\n        DistProfilerExtension.__init__(\n            self,\n            DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config),\n        )\n        self.config = config\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize different parallel strategy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. 
and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            set_numa_affinity()\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        is_collect = (\n            mpu.get_tensor_model_parallel_rank() == 0\n            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1\n            and mpu.get_context_parallel_rank() == 0\n        )\n        self._register_dispatch_collect_info(\n            mesh_name=\"reward\", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect\n        )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config):\n        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            tokenizer,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            self.config.megatron.use_mbridge,\n        )\n\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=True,  # reward model is value model\n            share_embeddings_and_output_weights=False,\n            wrap_with_ddp=False,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n        )\n        reward_model = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n        )\n\n        if self.config.load_weight:\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True)\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(reward_model, local_model_path)\n                else:\n               
     load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True\n                    )\n\n        get_torch_device().empty_cache()\n        return reward_model, self.hf_config\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # create reward model\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get(\"override_config\", {})))\n        override_transformer_config = OmegaConf.to_container(\n            OmegaConf.create(self.config.megatron.get(\"override_transformer_config\", {}))\n        )\n\n        use_shm = self.config.model.get(\"use_shm\", False)\n        sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm)\n        sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path)\n        rm_tokenizer_path = self.config.model.get(\"rm_tokenizer\", None)\n        rm_tokenizer = None\n        if rm_tokenizer_path is not None:\n            rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm)\n            rm_tokenizer = hf_tokenizer(\n                rm_tokenizer_local_path, trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        reward_model_module, reward_model_config = self._build_rm_model(\n            model_path=self.config.model.path,\n            tokenizer=rm_tokenizer,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n        )\n        # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel\n        # should be implemented in workers\n        self.rm = MegatronRewardModel(\n            config=self.config,\n            reward_model_module=reward_model_module,\n            model_config=reward_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            sft_tokenizer=sft_tokenizer,\n            rm_tokenizer=rm_tokenizer,\n        )\n\n    # TODO: the reward model should use its own tokenizer instead of the sft tokenizer;\n    # the input_ids, responses, attention_mask and position_ids may be different!\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"reward\"))\n    @DistProfiler.annotate(color=\"brown\")\n    def compute_rm_score(self, data: DataProto):\n        data.meta_info[\"micro_batch_size\"] = self.config.micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data = data.to(get_device_id())\n        output = self.rm.compute_reward(data)\n        output = output.to(\"cpu\")\n        return output\n"
  },
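The `rollout_mode()`/`trainer_mode()` switch above keeps two independent torch RNG streams so that generation sampling stays reproducible no matter how much randomness training consumes in between. Below is a minimal, self-contained sketch of that save/restore pattern; the `RngSwapper` name and the seed are assumptions, while the workers do the same thing with `get_torch_device().get_rng_state()`/`set_rng_state()`.

```python
import torch


class RngSwapper:
    """Keep separate RNG streams for generation and training (sketch)."""

    def __init__(self, gen_seed: int = 1234):  # assumed generation seed
        saved = torch.get_rng_state()
        torch.manual_seed(gen_seed)             # seed a dedicated gen stream
        self.gen_states = torch.get_rng_state()
        torch.set_rng_state(saved)

    def to_rollout(self):
        # mirrors rollout_mode(): stash trainer RNG, install gen RNG
        self.trainer_states = torch.get_rng_state()
        torch.set_rng_state(self.gen_states)

    def to_trainer(self):
        # mirrors trainer_mode(): stash gen RNG, restore trainer RNG
        self.gen_states = torch.get_rng_state()
        torch.set_rng_state(self.trainer_states)


swap = RngSwapper()
swap.to_rollout()
first = torch.rand(2)    # generation-side sampling
swap.to_trainer()
_ = torch.rand(1000)     # training consumes RNG freely
swap.to_rollout()
second = torch.rand(2)   # continues the generation stream, unperturbed
```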
  {
    "path": "verl_distillation/verl/workers/reward_manager/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import get_reward_manager_cls, register  # noqa: I001\nfrom .batch import BatchRewardManager\nfrom .dapo import DAPORewardManager\nfrom .naive import NaiveRewardManager\nfrom .prime import PrimeRewardManager\n\n# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies\n__all__ = [\n    \"BatchRewardManager\",\n    \"DAPORewardManager\",\n    \"NaiveRewardManager\",\n    \"PrimeRewardManager\",\n    \"register\",\n    \"get_reward_manager_cls\",\n]\n"
  },
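The re-exports above exist so a trainer can resolve a reward manager by its registered name. A hedged usage sketch (the tokenizer checkpoint is an arbitrary assumption):

```python
from transformers import AutoTokenizer

from verl.workers.reward_manager import get_reward_manager_cls

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")  # assumed checkpoint
manager_cls = get_reward_manager_cls("naive")  # resolves NaiveRewardManager
reward_fn = manager_cls(tokenizer=tokenizer, num_examine=1)
# reward_fn(data) returns a [batch_size, response_len] reward tensor,
# or a dict with "reward_tensor"/"reward_extra_info" when return_dict=True.
```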
  {
    "path": "verl_distillation/verl/workers/reward_manager/abstract.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Callable\n\nimport torch\n\nfrom verl.protocol import DataProto\n\nRawRewardFn = Callable[..., Any]\n\n\nclass AbstractRewardManager(ABC):\n    @abstractmethod\n    def __init__(\n        self,\n        tokenizer: Any,\n        num_examine: int,\n        compute_score: RawRewardFn | None,\n        reward_fn_key: str = \"data_source\",\n        **kwargs: Any,\n    ):\n        pass\n\n    @abstractmethod\n    def __call__(\n        self,\n        data: DataProto,\n        return_dict: bool = False,\n    ) -> torch.Tensor | dict[str, Any]:\n        pass\n"
  },
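For orientation, here is a minimal custom manager written against the interface above; the length-based scoring is purely illustrative, not a recipe from the repo:

```python
import torch

from verl.protocol import DataProto
from verl.workers.reward_manager.abstract import AbstractRewardManager


class LengthRewardManager(AbstractRewardManager):
    """Toy manager: shorter responses score higher (illustrative only)."""

    def __init__(self, tokenizer, num_examine, compute_score=None,
                 reward_fn_key="data_source", **kwargs):
        self.tokenizer = tokenizer
        self.num_examine = num_examine

    def __call__(self, data: DataProto, return_dict: bool = False):
        responses = data.batch["responses"]
        prompt_len = data.batch["prompts"].shape[-1]
        lengths = data.batch["attention_mask"][:, prompt_len:].sum(dim=-1)
        reward = torch.zeros_like(responses, dtype=torch.float32)
        for i in range(len(data)):
            # sparse reward placed on the last valid response token
            reward[i, lengths[i] - 1] = 1.0 / max(lengths[i].item(), 1)
        if return_dict:
            return {"reward_tensor": reward, "reward_extra_info": {}}
        return reward
```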
  {
    "path": "verl_distillation/verl/workers/reward_manager/batch.py",
    "content": "# Copyright 2025 Individual Contributor: Mert Unsal\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\nfrom typing import Any\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.workers.reward_manager import register\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn\n\n\n@register(\"batch\")\nclass BatchRewardManager(AbstractRewardManager):\n    \"\"\"\n    A batch reward manager that computes rewards for a batch of data.\n\n    Args:\n        tokenizer (Tokenizer): The tokenizer to use for decoding the responses.\n        num_examine (int): The number of responses to examine.\n        compute_score (callable): The function to compute the rewards.\n        reward_fn_key (str): The key to use for the reward function.\n        reward_kwargs (dict): The keyword arguments to pass to the reward function.\n    \"\"\"\n\n    def __init__(\n        self, tokenizer, num_examine, compute_score: RawRewardFn, reward_fn_key=\"data_source\", **reward_kwargs\n    ):\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine\n        self.compute_score = compute_score\n        self.reward_fn_key = reward_fn_key\n        self.reward_kwargs = reward_kwargs\n\n    def verify(self, data):\n        prompt_ids = data.batch[\"prompts\"]\n        response_ids = data.batch[\"responses\"]\n        attention_mask = data.batch[\"attention_mask\"]\n\n        prompt_len = prompt_ids.shape[-1]\n        valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1)\n\n        responses_str = []\n        for i in range(len(data)):\n            valid_len = valid_response_lengths[i]\n            valid_response_ids = response_ids[i][:valid_len]\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n            responses_str.append(response_str)\n\n        ground_truths = [item.non_tensor_batch[\"reward_model\"].get(\"ground_truth\", None) for item in data]\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n        rollout_reward_scores = data.non_tensor_batch.get(\"reward_scores\", [{} for _ in range(len(data))])\n        extras = data.non_tensor_batch.get(\"extra_info\", [{} for _ in range(len(data))])\n\n        for i in range(len(data)):\n            extras[i][\"rollout_reward_scores\"] = rollout_reward_scores[i]\n\n        scores = self.compute_score(\n            data_sources=data_sources,\n            solution_strs=responses_str,\n            ground_truths=ground_truths,\n            extra_infos=extras,\n            **self.reward_kwargs,\n        )\n\n        return scores\n\n    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                reward_extra_keys = data.meta_info.get(\"reward_extra_keys\", [])\n                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}\n                return {\"reward_tensor\": data.batch[\"rm_scores\"], \"reward_extra_info\": reward_extra_info}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n        prompt_ids = data.batch[\"prompts\"]\n        prompt_len = prompt_ids.shape[-1]\n        attention_mask = data.batch[\"attention_mask\"]\n        valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1)\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n\n        scores = self.verify(data)\n        rewards = []\n        already_printed: dict[str, Any] = {}\n\n        for i in range(len(data)):\n            length = valid_response_lengths[i].item()\n            score = scores[i]\n\n            if isinstance(score, dict):\n                reward = score[\"score\"]\n                for key, value in score.items():\n                    reward_extra_info[key].append(value)\n            else:\n                reward = score\n\n            rewards.append(reward)\n            reward_tensor[i, length - 1] = reward\n\n            data_source = data_sources[i]\n            if already_printed.get(data_source, 0) < self.num_examine:\n                response_str = self.tokenizer.decode(data.batch[\"responses\"][i][:length], skip_special_tokens=True)\n                prompt_str = self.tokenizer.decode(data.batch[\"prompts\"][i], skip_special_tokens=True)\n                ground_truth = data[i].non_tensor_batch[\"reward_model\"].get(\"ground_truth\", None)\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                print(\"[score]\", scores[i])\n                already_printed[data_source] = already_printed.get(data_source, 0) + 1\n\n        data.batch[\"acc\"] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device)\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor, \"reward_extra_info\": reward_extra_info}\n        else:\n            return reward_tensor\n"
  },
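`BatchRewardManager.verify()` calls `compute_score` once with whole-batch keyword lists rather than once per sample. A compatible scorer therefore looks like the following sketch (exact-match scoring is illustrative):

```python
def batch_exact_match_score(data_sources, solution_strs, ground_truths,
                            extra_infos=None, **kwargs):
    """Return one score (or score dict) per sample, a batch at a time."""
    scores = []
    for solution, truth in zip(solution_strs, ground_truths):
        hit = float(truth is not None and solution.strip() == str(truth).strip())
        # returning a dict lets extra keys flow into reward_extra_info
        scores.append({"score": hit, "acc": hit})
    return scores
```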
  {
    "path": "verl_distillation/verl/workers/reward_manager/dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager\n\n\n@register(\"dapo\")\nclass DAPORewardManager(AbstractRewardManager):\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        num_examine,\n        compute_score=None,\n        reward_fn_key=\"data_source\",\n        max_resp_len=None,\n        overlong_buffer_cfg=None,\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n        self.overlong_buffer_cfg = overlong_buffer_cfg\n        self.max_resp_len = max_resp_len\n\n        if self.overlong_buffer_cfg is not None:\n            assert self.max_resp_len is not None, (\n                f\"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None\"\n            )\n            assert self.max_resp_len >= self.overlong_buffer_cfg.len, (\n                \"max_resp_len must be larger than overlong_buffer.len\"\n            )\n\n    def __call__(self, data: DataProto, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                reward_extra_keys = data.meta_info.get(\"reward_extra_keys\", [])\n                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}\n                return {\"reward_tensor\": data.batch[\"rm_scores\"], \"reward_extra_info\": reward_extra_info}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n            eos_token = self.tokenizer.eos_token\n            if response_str.endswith(eos_token):\n                response_str = response_str[: -len(eos_token)]\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n\n            data_source = data_item.non_tensor_batch[self.reward_fn_key]\n\n            extra_info = data_item.non_tensor_batch.get(\"extra_info\", {})\n\n            rollout_reward_scores = data_item.non_tensor_batch.get(\"reward_scores\", {})\n\n            extra_info[\"rollout_reward_scores\"] = rollout_reward_scores\n\n            result = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            score: float\n            if isinstance(result, dict):\n                score = result[\"score\"]\n                # Store the information including original reward\n                for key, value in result.items():\n                    reward_extra_info[key].append(value)\n            else:\n                score = result\n                reward_extra_info[\"acc\"].append(score)\n\n            reward = score\n\n            # overlong_buffer_cfg may be None (see __init__), so guard before reading .enable\n            if self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable:\n                overlong_buffer_len = self.overlong_buffer_cfg.len\n                expected_len = self.max_resp_len - overlong_buffer_len\n                exceed_len = valid_response_length - expected_len\n                overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n                reward += overlong_reward\n                if self.overlong_buffer_cfg.log:\n                    reward_extra_info[\"overlong_reward\"].append(overlong_reward)\n                    reward_extra_info[\"overlong\"].append(overlong_reward < 0)\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                
already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                if isinstance(result, dict):\n                    for key, value in result.items():\n                        print(f\"[{key}]\", value)\n                else:\n                    print(\"[score]\", score)\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
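The overlong buffer applies a soft linear penalty over the last `overlong_buffer.len` tokens of the response budget. A worked numeric example of the formula above:

```python
# Assumed config: max_resp_len=4096, overlong_buffer.len=512, penalty_factor=1.0.
max_resp_len, overlong_buffer_len, penalty_factor = 4096, 512, 1.0
expected_len = max_resp_len - overlong_buffer_len  # 3584: penalty-free region

for valid_response_length in (3000, 3584, 3840, 4096):
    exceed_len = valid_response_length - expected_len
    overlong_reward = min(-exceed_len / overlong_buffer_len * penalty_factor, 0)
    print(valid_response_length, overlong_reward)
# 3000 -> 0 (inside budget), 3584 -> 0, 3840 -> -0.5, 4096 -> -1.0
```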
  {
    "path": "verl_distillation/verl/workers/reward_manager/naive.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\nfrom typing import Any\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager\n\n\n@register(\"naive\")\nclass NaiveRewardManager(AbstractRewardManager):\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key=\"data_source\") -> None:\n        \"\"\"\n        Initialize the NaiveRewardManager instance.\n\n        Args:\n            tokenizer: The tokenizer used to decode token IDs into text.\n            num_examine: The number of batches of decoded responses to print to the console for debugging purpose.\n            compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.\n            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to\n                \"data_source\".\n        \"\"\"\n        self.tokenizer = tokenizer  # Store the tokenizer for decoding token IDs\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key  # Store the key for accessing the data source\n\n    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                reward_extra_keys = data.meta_info.get(\"reward_extra_keys\", [])\n                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}\n                return {\"reward_tensor\": data.batch[\"rm_scores\"], \"reward_extra_info\": reward_extra_info}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n            data_source = data_item.non_tensor_batch[self.reward_fn_key]\n            extra_info = data_item.non_tensor_batch.get(\"extra_info\", {})\n            num_turns = data_item.non_tensor_batch.get(\"__num_turns__\", None)\n            rollout_reward_scores = data_item.non_tensor_batch.get(\"reward_scores\", {})\n            extra_info[\"num_turns\"] = num_turns\n            extra_info[\"rollout_reward_scores\"] = rollout_reward_scores\n\n            score = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            if isinstance(score, dict):\n                reward = score[\"score\"]\n                # Store the information including original reward\n                for key, value in score.items():\n                    reward_extra_info[key].append(value)\n            else:\n                reward = score\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                if isinstance(score, dict):\n                    for key, value in score.items():\n                        print(f\"[{key}]\", value)\n                else:\n                    print(\"[score]\", score)\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
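These managers place the sequence-level reward on the last valid response token, located by summing the response slice of the attention mask. A tiny tensor demo of that indexing (toy mask; prompts are left-padded in this codebase, responses right-padded):

```python
import torch

prompt_length = 4
# [prompt: left-padded | response: right-padded]
attention_mask = torch.tensor([0, 0, 1, 1,   # 2 valid prompt tokens
                               1, 1, 1, 0])  # 3 valid response tokens
valid_response_length = attention_mask[prompt_length:].sum()  # tensor(3)

reward_tensor = torch.zeros(attention_mask.numel() - prompt_length)
reward_tensor[valid_response_length - 1] = 1.0  # reward lands at index 2
print(reward_tensor)  # tensor([0., 0., 1., 0.])
```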
  {
    "path": "verl_distillation/verl/workers/reward_manager/prime.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nfrom concurrent.futures import ProcessPoolExecutor\nfrom functools import partial\nfrom typing import Any, Callable, Optional\n\nimport psutil\nimport torch\nfrom transformers import PreTrainedTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager\n\n\nasync def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):\n    loop = asyncio.get_running_loop()\n    try:\n        # Ensure process_completion is called properly\n        future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info))\n        return await asyncio.wait_for(future, timeout=timeout)\n    except asyncio.TimeoutError:\n        print(f\"[Timeout] Task timeout: {completion}\")\n        return None  # Default value for timed-out rows\n    except Exception as e:\n        print(f\"[Error] Task failed: {e}, completion: {completion[:80]}\")\n        return None  # Default value for failed rows\n\n\nasync def parallel_compute_score_async(\n    evaluation_func, completions, references, tasks, extra_info=None, num_processes=64\n):\n    if extra_info is None:\n        extra_info = [None] * len(tasks)\n    scores = []\n    with ProcessPoolExecutor(max_workers=num_processes) as executor:\n        # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the\n        # exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed.\n        try:\n            # Create tasks for all rows\n            tasks_async = [\n                single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=300.0)\n                for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True)\n            ]\n            results = await asyncio.gather(*tasks_async, return_exceptions=False)\n        except Exception as e:\n            print(f\"[Exception] async gather failed: {e}\")\n            raise\n        finally:\n            terminated_count = 0\n            for pid, proc in executor._processes.items():\n                try:\n                    p = psutil.Process(pid)\n                    p.terminate()\n                    try:\n                        p.wait(timeout=5)\n                    except psutil.TimeoutExpired:\n                        p.kill()\n                    terminated_count += 1\n                except Exception:\n                    pass\n            print(f\"[Shutdown] {terminated_count} subprocess(es) terminated.\")\n\n    # Process results\n    for result, completion, reference, task in zip(results, completions, references, tasks, strict=True):\n        if isinstance(result, Exception) or result is None:\n            # 
Handle failed or timed-out tasks\n            scores.append(0.0)\n        elif isinstance(result, int | float | bool):\n            scores.append(float(result))\n        else:\n            scores.append(float(result[0]))\n    return scores\n\n\ndef run_reward_scoring(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64):\n    loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(loop)\n    try:\n        return loop.run_until_complete(\n            parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes)\n        )\n    finally:\n        loop.close()\n\n\n@register(\"prime\")\nclass PrimeRewardManager(AbstractRewardManager):\n    \"\"\"\n    The Reward Manager used in https://github.com/PRIME-RL/PRIME\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        num_examine: int,\n        compute_score: Optional[Callable] = None,\n        reward_fn_key: str = \"data_source\",\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n\n    def verify(self, data):\n        \"\"\"\n        verify the batch and save as ``acc`` tensor\n        \"\"\"\n        # batched scoring\n        prompt_ids = data.batch[\"prompts\"]\n\n        response_ids = data.batch[\"responses\"]\n        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)\n        ground_truth = [data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"] for data_item in data]\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n        extra_info = data.non_tensor_batch.get(\"extra_info\", None)\n\n        assert len(sequences_str) == len(ground_truth) == len(data_sources)\n        try:\n            scores = run_reward_scoring(\n                self.compute_score,\n                completions=sequences_str,\n                references=ground_truth,\n                tasks=data_sources,\n                extra_info=extra_info,\n                num_processes=64,\n            )\n        except asyncio.TimeoutError:\n            print(\"[Timeout] Global reward scoring timed out. Setting all as 0.\")\n            scores = [0.0 for _ in range(len(sequences_str))]\n        except Exception as e:\n            print(f\"[Error] Unexpected error during scoring. Setting all as 0. {e}\")\n            scores = [0.0 for _ in range(len(sequences_str))]\n        data.batch[\"acc\"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)\n        return scores\n\n    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                reward_extra_keys = data.meta_info.get(\"reward_extra_keys\", [])\n                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}\n                return {\"reward_tensor\": data.batch[\"rm_scores\"], \"reward_extra_info\": reward_extra_info}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n\n        already_print_data_sources = {}\n\n        # batched scoring\n        prompt_ids = data.batch[\"prompts\"]\n        prompt_length = prompt_ids.shape[-1]\n\n        response_ids = data.batch[\"responses\"]\n        valid_response_length = data.batch[\"attention_mask\"][:, prompt_length:].sum(dim=-1)\n        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)\n        data_sources = data.non_tensor_batch[\"data_source\"]\n\n        scores = self.verify(data)\n\n        for i in range(len(data)):\n            data_source = data_sources[i]\n            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(sequences_str)\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor}\n        else:\n            return reward_tensor\n"
  },
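The concurrency pattern in prime.py, a process pool plus a per-row `asyncio.wait_for`, keeps one pathological sample (e.g., an infinite loop in a sandboxed checker) from stalling the whole batch. A self-contained sketch under stated assumptions (toy scorer, small pool, short timeout):

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def toy_score(solution: str, truth: str) -> float:
    return float(solution == truth)  # stand-in for a slow or fragile checker


async def score_all(pairs, timeout=5.0):
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=4) as executor:

        async def one(solution, truth):
            try:
                fut = loop.run_in_executor(executor, partial(toy_score, solution, truth))
                return await asyncio.wait_for(fut, timeout=timeout)
            except asyncio.TimeoutError:
                return 0.0  # timed-out rows fall back to zero reward

        return await asyncio.gather(*[one(s, t) for s, t in pairs])


if __name__ == "__main__":
    print(asyncio.run(score_all([("a", "a"), ("a", "b")])))  # [1.0, 0.0]
```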
  {
    "path": "verl_distillation/verl/workers/reward_manager/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable\n\nfrom verl.workers.reward_manager.abstract import AbstractRewardManager\n\n__all__ = [\"register\", \"get_reward_manager_cls\"]\n\nREWARD_MANAGER_REGISTRY: dict[str, type[AbstractRewardManager]] = {}\n\n\ndef register(name: str) -> Callable[[type[AbstractRewardManager]], type[AbstractRewardManager]]:\n    \"\"\"Decorator to register a reward manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward manager.\n    \"\"\"\n\n    def decorator(cls: type[AbstractRewardManager]) -> type[AbstractRewardManager]:\n        if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:\n            raise ValueError(\n                f\"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}\"\n            )\n        REWARD_MANAGER_REGISTRY[name] = cls\n        return cls\n\n    return decorator\n\n\ndef get_reward_manager_cls(name: str) -> type[AbstractRewardManager]:\n    \"\"\"Get the reward manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward manager.\n\n    Returns:\n        `(type)`: The reward manager class.\n    \"\"\"\n    if name not in REWARD_MANAGER_REGISTRY:\n        raise ValueError(f\"Unknown reward manager: {name}\")\n    return REWARD_MANAGER_REGISTRY[name]\n"
  },
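  {
    "path": "verl_distillation/docs/examples/reward_manager_registry_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# how the reward-manager registry in verl/workers/reward_manager/registry.py\n# is meant to be used. The name \"demo\" and the DemoRewardManager class are\n# made up for this example; a real manager should subclass\n# AbstractRewardManager and implement the scoring logic in __call__.\nfrom verl.workers.reward_manager.registry import get_reward_manager_cls, register\n\n\n@register(\"demo\")\nclass DemoRewardManager:\n    def __init__(self, tokenizer, num_examine, **kwargs):\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine\n\n    def __call__(self, data, return_dict=False):\n        raise NotImplementedError(\"scoring logic goes here\")\n\n\nif __name__ == \"__main__\":\n    # Lookup by name returns the registered class; unknown names raise ValueError.\n    assert get_reward_manager_cls(\"demo\") is DemoRewardManager\n"
  },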
  {
    "path": "verl_distillation/verl/workers/reward_model/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPORewardModel\n\n__all__ = [\"BasePPORewardModel\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/reward_model/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for reward model\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.workers.config import HFModelConfig, RewardModelConfig\n\n__all__ = [\"BasePPORewardModel\"]\n\n\nclass BasePPORewardModel(ABC):\n    \"\"\"base class for reward model\"\"\"\n\n    def __init__(\n        self,\n        config: RewardModelConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        self.config = config\n        self.model_config = model_config\n        self.device_mesh = device_mesh\n\n    @abstractmethod\n    def compute_reward(self, data: DataProto) -> DataProto:\n        \"\"\"Computing reward given input_ids. The transformers should output a tensor with shape\n           [batch_size, sequence_length], and the value at [EOS] mask should be gathered.\n\n        Args:\n            data: must contain keys \"input_ids\", \"attention_mask\" and \"position_ids\".\n                - input_ids: [batch_size, sequence_length]\n                - attention_mask: [batch_size, sequence_length]\n                - position_ids: [batch_size, sequence_length]\n\n        Returns: a data pass protocol containing \"reward\". Only the [EOS] position contains the reward.\n            Other position should have zero reward. Note that this may change in the future if we use\n            dense reward. So, we leave the interface for general case.\n            - reward: [batch_size, sequence_length].\n\n        \"\"\"\n        pass\n"
  },
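  {
    "path": "verl_distillation/docs/examples/base_reward_model_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# a minimal BasePPORewardModel subclass following the contract documented in\n# base.py -- a reward tensor shaped like the input, zero everywhere except the\n# last valid ([EOS]) position. ConstantRewardModel and the constant score of\n# 1.0 are made up for this example; the Megatron implementation returns the\n# same layout under the key \"rm_scores\".\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.workers.reward_model import BasePPORewardModel\n\n\nclass ConstantRewardModel(BasePPORewardModel):\n    @torch.no_grad()\n    def compute_reward(self, data: DataProto) -> DataProto:\n        attention_mask = data.batch[\"attention_mask\"]  # (bs, seq_len)\n        reward = torch.zeros_like(attention_mask, dtype=torch.float32)\n        # cumsum().argmax() points at the last non-pad position of each row.\n        last_valid = attention_mask.cumsum(dim=-1).argmax(dim=-1)  # (bs,)\n        reward[torch.arange(reward.size(0), device=reward.device), last_valid] = 1.0\n        return DataProto(batch=TensorDict({\"reward\": reward}, batch_size=reward.size(0)))\n"
  },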
  {
    "path": "verl_distillation/verl/workers/reward_model/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .reward_model import MegatronRewardModel\n\n__all__ = [\"MegatronRewardModel\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/reward_model/megatron/reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Reward Model.\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length\nfrom verl.workers.reward_model import BasePPORewardModel\n\n\nclass MegatronRewardModel(BasePPORewardModel):\n    def __init__(\n        self,\n        config,\n        model_config,\n        reward_model_module: torch.nn.ModuleList,\n        hf_config,\n        tf_config,\n        sft_tokenizer=None,\n        rm_tokenizer=None,\n    ):\n        self.config = config\n        self.reward_model_module = reward_model_module\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.model_config = model_config\n        self.device = \"cuda\"\n        self.sft_tokenizer = sft_tokenizer\n        self.rm_tokenizer = rm_tokenizer\n        self.use_different_tokenizer = rm_tokenizer is not None\n\n        print(f\"MegatronRewardModel.config: {self.config}\")\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n\n    def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:\n        assert self.use_different_tokenizer, \"re-encode need rm tokenizer not be None!\"\n        # need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids\n        # 1. remove pad for each sequence\n        # 2. decode by sft_tokenizer, remove sft system prompts\n        # 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids\n        # 4. generate attention_mask and position_ids\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len)\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        ori_values = {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n        _, ori_seqlen = input_ids.size(0), input_ids.size(1)\n        input_ids_for_rm = []\n        attention_mask_for_rm = []\n        position_ids_for_rm = []\n        print_decode = True\n        ori_seqlen = ori_seqlen + 128\n        for id, mask in zip(input_ids, attention_mask, strict=True):\n            # 1. remove pad for each sequence\n            non_zero_indices = torch.nonzero(mask).view(-1)\n            begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item()\n            valid_id = id[begin_pos : end_pos + 1]\n            # 2. 
decode by sft_tokenizer, remove sft system prompts\n            decode_result = self.sft_tokenizer.decode(valid_id)\n            # workaround\n            decode_with_rm_chat = (\n                decode_result.replace(\"<|user|>\\n\", \"[INST] \")\n                .replace(\"</s>\\n<|assistant|>\\n\", \" [/INST]\")\n                .replace(\"</s> \\n<|assistant|>\\n\", \" [/INST]\")\n                + \"</s>\"\n            )\n            if print_decode and torch.distributed.get_rank() == 0:\n                # only print first decode result\n                print(\n                    f\"device {get_device_id()}: sft decode result:\\n{decode_result}\\n \\\n                        \\ndevice {get_device_id()}: sft decode result with \\\n                        rm chat template:\\n{decode_with_rm_chat}\\n\\n\"\n                )\n                print_decode = False\n            # 3. encode by rm_tokenizer\n            rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors=\"pt\")[\"input_ids\"][0].to(\n                input_ids.device\n            )\n            # 4. generate attention_mask and position_ids\n            rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device)\n            cur_seqlen = rm_input_ids.shape[-1]\n            # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128)\n            if cur_seqlen > ori_seqlen:\n                print(f\"warning: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}\")\n                rm_input_ids = rm_input_ids[:ori_seqlen]\n                rm_attention_mask = rm_attention_mask[:ori_seqlen]\n            else:\n                # right padding\n                rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)\n                rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)\n            rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)\n            input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))\n            attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0))\n            position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0))\n        input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0)\n        attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0)\n        position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0)\n\n        # (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change\n        # NOTE(gh): need to restore the original values after computing the reward!\n        data.batch[\"input_ids\"] = input_ids_for_rm\n        data.batch[\"attention_mask\"] = attention_mask_for_rm\n        data.batch[\"position_ids\"] = position_ids_for_rm\n\n        return data, ori_values\n\n    @torch.no_grad()\n    def compute_reward(self, data: DataProto) -> DataProto:\n        if self.config.megatron.param_offload:\n            self.load_params_to_cuda()\n\n        if self.use_different_tokenizer:\n            data, ori_values = self.re_encode_by_rm_tokenizer(data)\n\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len')\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro 
batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"use_dynamic_bsz is True, but max_token_len is None!\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n\n        responses = data.batch[\"responses\"]\n        batch_size = responses.size(0)\n        response_length = responses.size(1)\n\n        with torch.no_grad():\n            output = self.forward_batch(\n                data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len\n            )\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                logits = torch.cat(output[\"output\"], dim=0)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == logits.size(0), f\"{len(indices)} vs. {logits.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    logits = logits[revert_indices]\n            else:\n                logits = torch.empty(\n                    (input_ids.shape[0], input_ids.shape[1]),\n                    device=input_ids.device,\n                )\n            logits = logits.to(torch.float32)\n\n            # broadcast across pp ranks\n            torch.distributed.broadcast(\n                tensor=logits,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n                async_op=False,\n            )\n\n        # (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen')\n        token_level_rewards = logits\n        # find the last token reward\n        ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1)  # (bs, 1)\n        rewards = torch.gather(token_level_rewards, dim=1, index=ends)  # (bs, 1)\n\n        if self.use_different_tokenizer:\n            data.batch.update(ori_values)\n            input_ids = ori_values[\"input_ids\"]\n            attention_mask = ori_values[\"attention_mask\"]\n            position_ids = ori_values[\"position_ids\"]\n\n        token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1])  # (bs, ori_seqlen)\n\n        # assign last valid token reward to ori position\n        if position_ids.dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            position_ids = position_ids[:, 0, :]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bs,)\n        eos_mask = torch.zeros_like(attention_mask)\n        eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0\n\n        token_level_rewards = token_level_rewards * eos_mask\n        token_level_rewards = token_level_rewards[:, -response_length:]\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n        else:\n            # add empty cache after each compute\n            get_torch_device().empty_cache()\n\n        batch = TensorDict({\"rm_scores\": token_level_rewards}, batch_size=input_ids.shape[0])\n\n        return DataProto(batch=batch)\n\n    def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        mini_batch = data\n        mini_batch.batch = mini_batch.batch.contiguous()\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in mini_batch.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            mini_batch.batch[\"multi_modal_inputs\"] = mini_batch.non_tensor_batch[\"multi_modal_inputs\"]\n            mini_batch.batch[\"multi_modal_inputs_idx\"] = torch.Tensor(\n                list(range(len(mini_batch.non_tensor_batch[\"multi_modal_inputs\"])))\n            ).to(torch.int64)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        # compute input shapes for pp stages\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output):\n            return torch.tensor(1.0, device=output.device), output\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from verl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            multi_modal_inputs = {}\n            if \"multi_modal_inputs\" in batch:\n                from verl.utils.model import extract_multi_modal_inputs\n\n                indices = batch.get(\"multi_modal_inputs_idx\", None)\n                multi_modal_inputs = 
extract_multi_modal_inputs(batch[\"multi_modal_inputs\"], indices)\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                multi_modal_inputs,\n                value_model=True,\n            )\n\n            return output, loss_func\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=True,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=True,\n            )\n\n        if self.has_multi_modal_inputs:\n            data.batch.pop(\"multi_modal_inputs\")\n            data.batch.pop(\"multi_modal_inputs_idx\")\n            data.non_tensor_batch.pop(\"multi_modal_inputs\")\n        # losses_reduced contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    def offload_params_to_cpu(self):\n        if self.device in [\"cuda\", \"npu\"]:\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(\"cpu\", non_blocking=True)\n            self.device = \"cpu\"\n            get_torch_device().empty_cache()\n\n    def load_params_to_cuda(self):\n        if self.device == \"cpu\":\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(get_device_id(), non_blocking=True)\n            self.device = get_device_name()\n"
  },
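  {
    "path": "verl_distillation/docs/examples/eos_reward_gather_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# demonstrates, on toy tensors, the last-valid-token trick used by\n# MegatronRewardModel.compute_reward: cumsum(attention_mask).argmax(-1) points\n# at the final non-pad position, where the scalar reward is gathered.\nimport torch\n\n# Two right-padded sequences of length 6, with 4 and 2 valid tokens.\nattention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])\ntoken_level_scores = torch.arange(12, dtype=torch.float32).reshape(2, 6)\n\n# The cumulative sum is flat over padding, so argmax returns the first\n# occurrence of the maximum, i.e. the last valid token of each row.\nends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1)  # [[3], [1]]\nrewards = torch.gather(token_level_scores, dim=1, index=ends)  # [[3.], [7.]]\n\nprint(ends.squeeze(-1).tolist())  # [3, 1]\nprint(rewards.squeeze(-1).tolist())  # [3.0, 7.0]\n"
  },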
  {
    "path": "verl_distillation/verl/workers/roles/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .actor import ActorWorker\nfrom .critic import CriticWorker\n\ntry:\n    from .reward_model import RewardModelWorker\nexcept ImportError:\n    RewardModelWorker = None\n\n__all__ = [\"CriticWorker\", \"ActorWorker\"]\n\nif RewardModelWorker is not None:\n    __all__.append(\"RewardModelWorker\")\n"
  },
  {
    "path": "verl_distillation/verl/workers/roles/actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom functools import partial\n\nimport psutil\nfrom codetiming import Timer\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_torch_device,\n)\nfrom verl.utils.distributed import initialize_global_process_group_ray\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.profiler import DistProfiler, DistProfilerExtension\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.workers.config import ActorConfig\nfrom verl.workers.roles.utils.losses import ppo_loss\nfrom verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n\nclass ActorWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: ActorConfig):\n        self.config = config\n        Worker.__init__(self)\n        self.profiler_config = self.config.profiler\n        tool_config = self.profiler_config.tool_config\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=tool_config)\n        )\n\n        initialize_global_process_group_ray(timeout_second=None)\n\n        self.loss_fn = partial(ppo_loss, config=self.config)\n\n    def _build_engine(self):\n        self.model_config = self.config.model_config\n        self.engine_config = self.config.engine\n        self.optimizer_config = self.config.optim\n        self.checkpoint_config = self.config.checkpoint\n\n        from verl.workers.engine import BaseEngine, EngineRegistry\n\n        self.engine: BaseEngine = EngineRegistry.new(\n            model_type=\"language_model\",\n            backend=self.config.strategy,\n            model_config=self.model_config,\n            engine_config=self.engine_config,\n            optimizer_config=self.optimizer_config,\n            checkpoint_config=self.checkpoint_config,\n        )\n\n        # build dispatch info\n        self._register_dispatch_collect_info(\n            mesh_name=\"actor\",\n            dp_rank=self.engine.get_data_parallel_rank(),\n            is_collect=self.engine.is_mp_src_rank_with_outputs(),\n        )\n\n        # aggregate with bon sampling\n        self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n\n        assert self.ppo_mini_batch_size % self.engine.get_data_parallel_size() == 0, (\n            f\"{self.ppo_mini_batch_size=} is not divisible by 
{self.engine.get_data_parallel_size()=}\"\n        )\n        self.ppo_mini_batch_size_per_dp = self.ppo_mini_batch_size // self.engine.get_data_parallel_size()\n\n        # setup flops counter\n        self.flops_counter = FlopsCounter(self.model_config.hf_config)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        self._build_engine()\n        self.engine.initialize()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_loss_fn(self, loss_fn):\n        self.loss_fn = loss_fn\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @DistProfiler.annotate(color=\"blue\", role=\"actor_compute_log_prob\")\n    def compute_log_prob(self, data: DataProto):\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data.meta_info[\"use_fused_kernels\"] = self.config.use_fused_kernels\n        data.meta_info[\"calculate_entropy\"] = True\n        if self.config.use_dynamic_bsz:\n            data.meta_info[\"max_token_len_per_gpu\"] = self.config.ppo_infer_max_token_len_per_gpu\n        else:\n            data.meta_info[\"micro_batch_size_per_gpu\"] = self.config.ppo_infer_micro_batch_size_per_gpu\n\n        with self.engine.eval_mode():\n            # TODO: make worker API to accept TensorDict as well\n            data = data.to_tensordict()\n            data = left_right_2_no_padding(data)\n            output = self.engine.infer_batch(data)\n\n        if self.engine.is_mp_src_rank_with_outputs():\n            output = output[\"model_output\"]\n            log_probs = output[\"log_probs\"]\n            log_probs = no_padding_2_padding(log_probs, data)  # (bsz, response_length)\n\n            entropy = output[\"entropy\"]\n            if entropy is not None:\n                entropy = no_padding_2_padding(entropy, data)  # (bsz, response_length)\n\n            # in megatron, only the last pp stage contains valid data and is returned to the single controller\n            output = DataProto.from_dict(\n                tensors={\"old_log_probs\": log_probs.float(), \"entropy\": entropy.float()},\n            )\n            output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"actor\"))\n    @DistProfiler.annotate(color=\"red\", role=\"actor_update\")\n    def update_actor(self, data: DataProto):\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data.meta_info[\"use_fused_kernels\"] = self.config.use_fused_kernels\n        data.meta_info[\"calculate_entropy\"] = self.config.entropy_coeff != 0.0\n        if self.config.use_dynamic_bsz:\n            data.meta_info[\"max_token_len_per_gpu\"] = self.config.ppo_max_token_len_per_gpu\n        else:\n            data.meta_info[\"micro_batch_size_per_gpu\"] = self.config.ppo_micro_batch_size_per_gpu\n\n        metrics = {}\n        # Support all hardware backends\n        data = data.to(get_device_id())\n        # perform forward computation\n        with self.engine.train_mode():\n            dataloader = data.make_iterator(\n                mini_batch_size=self.ppo_mini_batch_size_per_dp,\n                epochs=self.config.ppo_epochs,\n                seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),\n                dataloader_kwargs={\"shuffle\": self.config.shuffle},\n            )\n            with Timer(name=\"update_policy\", logger=None) as timer:\n                for batch_idx, mini_batch in enumerate(dataloader):\n          
          mini_batch.meta_info[\"global_batch_size\"] = self.config.ppo_mini_batch_size\n                    # TODO: make worker API to accept TensorDict as well\n                    mini_batch = mini_batch.to_tensordict()\n                    mini_batch = left_right_2_no_padding(mini_batch)\n                    output = self.engine.train_batch(mini_batch, self.loss_fn)\n                    mini_batch_metrics = output.get(\"metrics\", {})\n                    append_to_dict(metrics, mini_batch_metrics, prefix=\"actor/\")\n\n            delta_time = timer.last\n\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n            lr = self.engine.lr_scheduler_step()\n            metrics[\"actor/lr\"] = lr\n\n            output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)\n"
  },
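  {
    "path": "verl_distillation/docs/examples/actor_minibatch_math_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# the mini-batch bookkeeping from ActorWorker._build_engine on concrete\n# numbers (all values below are made up). With best-of-n sampling the\n# effective mini-batch is ppo_mini_batch_size * rollout_n, and it must split\n# evenly across data-parallel ranks.\nppo_mini_batch_size = 8\nrollout_n = 4  # sampled responses per prompt\ndp_size = 4  # data-parallel world size\n\neffective = ppo_mini_batch_size * rollout_n  # 32 sequences per PPO step\nassert effective % dp_size == 0, f\"{effective=} is not divisible by {dp_size=}\"\nppo_mini_batch_size_per_dp = effective // dp_size\nprint(ppo_mini_batch_size_per_dp)  # 8 sequences per rank\n"
  },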
  {
    "path": "verl_distillation/verl/workers/roles/critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\nimport os\nimport warnings\nfrom functools import partial\n\nimport psutil\nfrom codetiming import Timer\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_torch_device,\n)\nfrom verl.utils.distributed import initialize_global_process_group_ray\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.profiler import DistProfiler, DistProfilerExtension\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.workers.config import CriticConfig\nfrom verl.workers.roles.utils.losses import value_loss\nfrom verl.workers.roles.utils.padding import left_right_2_no_padding, no_padding_2_padding\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n\nclass CriticWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: CriticConfig):\n        self.config = config\n        Worker.__init__(self)\n        self.profiler_config = self.config.profiler\n        tool_config = self.profiler_config.tool_config\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=tool_config)\n        )\n\n        initialize_global_process_group_ray(timeout_second=None)\n\n        self.loss_fn = partial(value_loss, config=self.config)\n\n    def _build_engine(self):\n        from copy import copy, deepcopy\n\n        self.model_config = copy(self.config.model_config)\n        self.model_config.hf_config = deepcopy(self.config.model_config.hf_config)\n        self.engine_config = self.config.engine\n        self.optimizer_config = self.config.optim\n        self.checkpoint_config = self.config.checkpoint\n\n        from verl.workers.engine import BaseEngine, EngineRegistry\n\n        # replace AutoModelForSequenceClassification to AutoModelForTokenClassification\n        hf_config = self.model_config.hf_config\n\n        arch = hf_config.architectures[0]\n        # This logic assumes the critic is a token classification model.\n        # If the provided model is a CausalLM, we adapt it.\n        if \"ForCausalLM\" in arch:\n            model_name = arch.split(\"ForCausalLM\")[0]\n            new_arch = f\"{model_name}ForTokenClassification\"\n            warnings.warn(f\"Implicitly changing critic architecture from '{arch}' to '{new_arch}'\", stacklevel=2)\n            hf_config.architectures[0] = new_arch\n        elif \"ForTokenClassification\" not in arch and 
\"ForSequenceClassification\" not in arch:\n            raise ValueError(\n                f\"Unsupported critic architecture: {arch}. \"\n                f\"Critic worker expects an architecture suitable for value function estimation, \"\n                f\"such as '...ForTokenClassification' or '...ForSequenceClassification'.\"\n            )\n\n        # make sure output dropout is 0\n        hf_config.classifier_dropout = 0\n\n        self.engine: BaseEngine = EngineRegistry.new(\n            model_type=\"value_model\",\n            backend=self.config.strategy,\n            model_config=self.model_config,\n            engine_config=self.engine_config,\n            optimizer_config=self.optimizer_config,\n            checkpoint_config=self.checkpoint_config,\n        )\n\n        # build dispatch info\n        self._register_dispatch_collect_info(\n            mesh_name=\"critic\",\n            dp_rank=self.engine.get_data_parallel_rank(),\n            is_collect=self.engine.is_mp_src_rank_with_outputs(),\n        )\n\n        # aggregate with bon sampling\n        self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n\n        assert self.ppo_mini_batch_size % self.engine.get_data_parallel_size() == 0, (\n            f\"{self.ppo_mini_batch_size=} is not divisible by {self.engine.get_data_parallel_size()=}\"\n        )\n        self.ppo_mini_batch_size_per_dp = self.ppo_mini_batch_size // self.engine.get_data_parallel_size()\n\n        # setup flops counter\n        self.flops_counter = FlopsCounter(self.model_config.hf_config)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        self._build_engine()\n        self.engine.initialize()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_loss_fn(self, loss_fn):\n        self.loss_fn = loss_fn\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"blue\", role=\"critic_compute_values\")\n    def compute_values(self, data: DataProto):\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        if self.config.use_dynamic_bsz:\n            data.meta_info[\"max_token_len_per_gpu\"] = self.config.ppo_infer_max_token_len_per_gpu\n        else:\n            data.meta_info[\"micro_batch_size_per_gpu\"] = self.config.ppo_infer_micro_batch_size_per_gpu\n\n        with self.engine.eval_mode():\n            # TODO: make worker API to accept TensorDict as well\n            data = data.to_tensordict()\n            data = left_right_2_no_padding(data)\n            output = self.engine.infer_batch(data)\n\n        if self.engine.is_mp_src_rank_with_outputs():\n            # in megatron, only last pp contains valid data and returned to the single controller\n            output = output[\"model_output\"]\n            values = output[\"values\"]\n            values = no_padding_2_padding(values, data)  # (bsz, response_length)\n\n            output = DataProto.from_dict(\n                tensors={\"values\": values.float()},\n            )\n            output = output.to(\"cpu\")\n\n        return output\n\n    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name=\"critic\"))\n    @DistProfiler.annotate(color=\"red\", role=\"critic_update\")\n    def update_critic(self, data: DataProto):\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        if self.config.use_dynamic_bsz:\n            data.meta_info[\"max_token_len_per_gpu\"] = 
self.config.ppo_max_token_len_per_gpu\n        else:\n            data.meta_info[\"micro_batch_size_per_gpu\"] = self.config.ppo_micro_batch_size_per_gpu\n\n        metrics = {}\n        # Support all hardware backends\n        data = data.to(get_device_id())\n        # perform forward computation\n        with self.engine.train_mode():\n            dataloader = data.make_iterator(\n                mini_batch_size=self.ppo_mini_batch_size_per_dp,\n                epochs=self.config.ppo_epochs,\n                seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),\n                dataloader_kwargs={\"shuffle\": self.config.shuffle},\n            )\n            with Timer(name=\"update_critic\", logger=None) as timer:\n                for batch_idx, mini_batch in enumerate(dataloader):\n                    mini_batch.meta_info[\"global_batch_size\"] = self.config.ppo_mini_batch_size\n                    # TODO: make worker API to accept TensorDict as well\n                    mini_batch = mini_batch.to_tensordict()\n                    mini_batch = left_right_2_no_padding(mini_batch)\n                    output = self.engine.train_batch(mini_batch, self.loss_fn)\n                    mini_batch_metrics = output.get(\"metrics\", {})\n                    append_to_dict(metrics, mini_batch_metrics, prefix=\"critic/\")\n\n            delta_time = timer.last\n\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n            lr = self.engine.lr_scheduler_step()\n            metrics[\"critic/lr\"] = lr\n\n            output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)\n"
  },
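  {
    "path": "verl_distillation/docs/examples/critic_arch_remap_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# isolates the architecture-rewrite rule from CriticWorker._build_engine so\n# it can be read and tested on its own. remap_critic_architecture is a\n# hypothetical helper name, not a verl API.\ndef remap_critic_architecture(arch: str) -> str:\n    \"\"\"Map a CausalLM architecture onto its TokenClassification variant.\"\"\"\n    if \"ForCausalLM\" in arch:\n        return arch.split(\"ForCausalLM\")[0] + \"ForTokenClassification\"\n    if \"ForTokenClassification\" in arch or \"ForSequenceClassification\" in arch:\n        return arch\n    raise ValueError(f\"Unsupported critic architecture: {arch}\")\n\n\nif __name__ == \"__main__\":\n    assert remap_critic_architecture(\"Qwen2ForCausalLM\") == \"Qwen2ForTokenClassification\"\n    assert remap_critic_architecture(\"LlamaForSequenceClassification\") == \"LlamaForSequenceClassification\"\n"
  },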
  {
    "path": "verl_distillation/verl/workers/roles/hybrid_engine.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/workers/roles/utils/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/workers/roles/utils/losses.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.dataset.dataset_utils import DatasetPadMode\nfrom verl.utils.torch_functional import masked_mean, masked_sum\nfrom verl.workers.config import ActorConfig, CriticConfig\nfrom verl.workers.roles.utils.padding import no_padding_2_padding\n\n\ndef sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):\n    pad_mode = tu.get_non_tensor_data(data=data, key=\"pad_mode\", default=DatasetPadMode.NO_PADDING)\n    dp_size = data[\"dp_size\"]\n    batch_num_tokens = data[\"batch_num_tokens\"]\n\n    log_prob = model_output[\"log_probs\"]\n\n    if pad_mode == DatasetPadMode.NO_PADDING:\n        # log_prob and loss mask are nested tensors of shape [bsz, j1]\n        # for each sample, loss mask shape is [1, prompt_length + response_length]\n        loss_mask = data[\"loss_mask\"]\n\n        log_prob_flatten = log_prob.values()\n        loss_mask_flatten = loss_mask.values()\n\n        # left-shift the loss mask by one token to align with log_prob\n        loss_mask_flatten = torch.roll(loss_mask_flatten, shifts=-1, dims=0)\n\n        # NOTE: loss is averaged over all tokens in the batch across all data parallel groups,\n        # For FSDP backend, the loss is directly used for backward; while for Megatron backend,\n        # the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.\n        loss = -masked_sum(log_prob_flatten, loss_mask_flatten) / batch_num_tokens * dp_size\n    else:\n        response_mask = data[\"response_mask\"].to(bool)\n        loss = -masked_sum(log_prob, response_mask) / batch_num_tokens * dp_size\n\n    return loss, {\"loss\": loss.detach().item()}\n\n\ndef ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):\n    log_prob = model_output[\"log_probs\"]\n    entropy = model_output.get(\"entropy\", None)\n\n    log_prob = no_padding_2_padding(log_prob, data)  # (bsz, response_length)\n    if entropy is not None:\n        entropy = no_padding_2_padding(entropy, data)  # (bsz, response_length)\n\n    metrics = {}\n\n    response_mask = data[\"response_mask\"].to(bool)\n    # compute policy loss\n    old_log_prob = data[\"old_log_probs\"]\n    advantages = data[\"advantages\"]\n\n    loss_agg_mode = config.loss_agg_mode\n\n    loss_mode = config.policy_loss.get(\"loss_mode\", \"vanilla\")\n\n    policy_loss_fn = get_policy_loss_fn(loss_mode)\n    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n        old_log_prob=old_log_prob,\n        log_prob=log_prob,\n        advantages=advantages,\n        response_mask=response_mask,\n        loss_agg_mode=loss_agg_mode,\n        config=config,\n    )\n\n    metrics.update(\n        {\n            \"pg_loss\": 
pg_loss.detach().item(),\n            \"pg_clipfrac\": pg_clipfrac.detach().item(),\n            \"ppo_kl\": ppo_kl.detach().item(),\n            \"pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n        }\n    )\n    policy_loss = pg_loss\n\n    # add entropy loss\n    if entropy is not None:\n        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n        entropy_coeff = config.entropy_coeff\n        policy_loss -= entropy_coeff * entropy_loss\n\n    # add kl loss\n    if config.use_kl_loss:\n        ref_log_prob = data[\"ref_log_prob\"]\n        # compute kl loss\n        kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=config.kl_loss_type)\n        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=config.loss_agg_mode)\n\n        policy_loss += kl_loss * config.kl_loss_coef\n        metrics[\"kl_loss\"] = kl_loss.detach().item()\n        metrics[\"kl_coef\"] = config.kl_loss_coef\n\n    return policy_loss, metrics\n\n\ndef value_loss(config: CriticConfig, model_output, data: TensorDict, dp_group=None):\n    vpreds = model_output[\"values\"]\n    vpreds = no_padding_2_padding(vpreds, data)  # (bsz, response_length)\n\n    values = data[\"values\"]\n    returns = data[\"returns\"]\n    response_mask = data[\"response_mask\"].to(bool)\n\n    vf_loss, vf_clipfrac = compute_value_loss(\n        vpreds=vpreds,\n        values=values,\n        returns=returns,\n        response_mask=response_mask,\n        cliprange_value=config.cliprange_value,\n        loss_agg_mode=config.loss_agg_mode,\n    )\n\n    metrics = {}\n\n    metrics.update(\n        {\n            \"critic/vf_loss\": vf_loss.detach().item(),\n            \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n            \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n        }\n    )\n\n    return vf_loss, metrics\n"
  },
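  {
    "path": "verl_distillation/docs/examples/sft_loss_mask_shift_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# why sft_loss left-shifts the loss mask by one token in the NO_PADDING\n# (packed) layout. The log-prob at position t scores the prediction of token\n# t+1, so a mask marking response tokens must be rolled back by one slot to\n# line up with the log-prob sequence. The toy mask below is made up.\nimport torch\n\n# Packed sequence: 3 prompt tokens followed by 3 response tokens.\nloss_mask = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.bool)\n\n# torch.roll(-1) moves each entry one slot earlier (the first entry wraps to\n# the end), matching the shift applied to loss_mask_flatten in sft_loss.\nshifted = torch.roll(loss_mask, shifts=-1, dims=0)\nprint(shifted.int().tolist())  # [0, 0, 1, 1, 1, 0]\n"
  },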
  {
    "path": "verl_distillation/verl/workers/roles/utils/padding.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.utils import tensordict_utils as tu\nfrom verl.utils.device import (\n    is_cuda_available,\n    is_npu_available,\n)\n\nif is_cuda_available:\n    from flash_attn.bert_padding import pad_input, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import pad_input, unpad_input\n\n\ndef left_right_2_no_padding(data: TensorDict) -> TensorDict:\n    \"\"\"\n    Convert TensorDict from left-right padding to no-padding format.\n\n    Args:\n        data: TensorDict with \"input_ids\", \"attention_mask\", \"response_mask\", \"position_ids\"\n\n    Returns:\n        data: TensorDict with\n        - Tensor includes NestedTensors like \"input_ids\", \"loss_mask\", \"position_ids\"\n        - NonTensorData includes \"max_seq_len\", \"max_response_len\", \"indices\"\n\n    Note:\n    1. the return input_ids/position_ids/loss_mask are nested tensor.\n    2. we will remove \"attention_mask\", \"response\" in the return data, but \"response_mask\" is kept.\n    \"\"\"\n    assert \"input_ids\" in data, \"input_ids is required in left-right padding data\"\n    assert \"attention_mask\" in data, \"attention_mask is required in left-right padding data\"\n    assert \"response_mask\" in data, \"response_mask is required in left-right padding data\"\n    assert \"position_ids\" in data, \"position_ids is required in left-right padding data\"\n\n    input_ids = data.pop(\"input_ids\")\n    attention_mask = data.pop(\"attention_mask\")\n    response_mask = data[\"response_mask\"]\n    if \"responses\" in data:\n        _ = data.pop(\"responses\")\n\n    max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1]\n    tu.assign_non_tensor_data(data, \"max_seq_len\", max_seq_len)\n    tu.assign_non_tensor_data(data, \"max_response_len\", max_response_len)\n\n    input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)\n    tu.assign_non_tensor_data(data, \"indices\", indices)\n\n    input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)\n\n    seq_lens = cu_seqlens.diff().tolist()\n    response_lens = response_mask.sum(dim=1).tolist()\n\n    position_ids_list = []\n    loss_mask_list = []\n    for seq_len, response_len in zip(seq_lens, response_lens, strict=False):\n        position_ids_list.append(torch.arange(seq_len, device=input_ids.device))\n        loss_mask = torch.zeros(seq_len, dtype=torch.bool, device=input_ids.device)\n        assert seq_len >= response_len, f\"{seq_len=} is less than {response_len=}\"\n        loss_mask[-response_len:] = 1\n        loss_mask_list.append(loss_mask)\n\n    position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)\n    loss_mask_nested = torch.nested.as_nested_tensor(loss_mask_list, layout=torch.jagged)\n\n    
data[\"input_ids\"] = input_ids_nested\n    data[\"position_ids\"] = position_ids_nested\n    data[\"loss_mask\"] = loss_mask_nested\n\n    return data\n\n\ndef no_padding_2_padding(nested_tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:\n    \"\"\"\n    Convert NestedTensor from no-padding to right padding format.\n\n    Args:\n        nested_tensor: NestedTensor with no-padding format\n        data: TensorDict with\n        - Tensor includes NestedTensors like \"input_ids\", \"loss_mask\", \"position_ids\"\n        - NonTensorData includes \"max_seq_len\", \"max_response_len\", \"indices\"\n\n    Returns:\n        values: regular tensor right padded to max_response_len\n    \"\"\"\n    assert \"indices\" in data, \"indices is required in left-right padding data\"\n    assert \"max_seq_len\" in data, \"max_seq_len is required in left-right padding data\"\n    assert \"max_response_len\" in data, \"max_response_len is required in left-right padding data\"\n\n    indices = tu.get_non_tensor_data(data=data, key=\"indices\", default=None)\n    max_seq_len = tu.get_non_tensor_data(data=data, key=\"max_seq_len\", default=2048)\n    max_response_len = tu.get_non_tensor_data(data=data, key=\"max_response_len\", default=1024)\n    batch_size = nested_tensor.size(0)\n\n    values = nested_tensor.values()\n    full_values = pad_input(\n        hidden_states=values.unsqueeze(-1),\n        indices=indices,\n        batch=batch_size,\n        seqlen=max_seq_len,\n    )\n    values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1]  # (bsz, response_length)\n\n    return values\n"
  },
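  {
    "path": "verl_distillation/docs/examples/padding_roundtrip_sketch.py",
    "content": "# Illustrative sketch (hypothetical file, not part of the verl source tree):\n# a pure-torch mock of the unpad/pad round trip that padding.py delegates to\n# flash_attn.bert_padding.unpad_input / pad_input, so the index bookkeeping\n# can be inspected without a flash-attn install. Pad id 0 is assumed here.\nimport torch\n\ninput_ids = torch.tensor([[0, 0, 11, 12, 13], [0, 0, 0, 21, 22]])  # left-padded\nattention_mask = input_ids.ne(0)\n\n# unpad: keep only valid tokens, remembering their flat positions (\"indices\").\nindices = attention_mask.flatten().nonzero(as_tuple=False).squeeze(-1)\npacked = input_ids.flatten()[indices]  # tensor([11, 12, 13, 21, 22])\ncu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), attention_mask.sum(dim=1).cumsum(0)])  # [0, 3, 5]\n\n# pad: scatter the packed values back into a zero tensor of the padded shape.\nrestored = torch.zeros_like(input_ids).flatten()\nrestored[indices] = packed\nrestored = restored.reshape(input_ids.shape)\nassert torch.equal(restored, input_ids)\n"
  },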
  {
    "path": "verl_distillation/verl/workers/rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BaseRollout, get_rollout_class\nfrom .hf_rollout import HFRollout\nfrom .naive import NaiveRollout\n\n__all__ = [\"BaseRollout\", \"NaiveRollout\", \"HFRollout\", \"get_rollout_class\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nfrom abc import ABC, abstractmethod\nfrom typing import Generator\n\nimport torch\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.workers.config import HFModelConfig, RolloutConfig\n\n__all__ = [\"BaseRollout\"]\n\n\nclass BaseRollout(ABC):\n    \"\"\"Base class for rollout.\"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        self.config = config\n        self.model_config = model_config\n        self.device_mesh = device_mesh\n\n    @abstractmethod\n    async def resume(self, tags: list[str]):\n        \"\"\"Resume rollout weights or kv cache in GPU memory.\n\n        Args:\n            tags: weights or kv_cache.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    async def update_weights(\n        self,\n        weights: Generator[tuple[str, torch.Tensor], None, None],\n        **kwargs,\n    ):\n        \"\"\"Update the weights of the rollout model.\n\n        Args:\n            weights: A generator that yields the name of the weight tensor and the tensor itself.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    async def release(self):\n        \"\"\"Release weights and kv cache in GPU memory.\"\"\"\n        pass\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Batch generate sequences in sync mode.\n\n        Args:\n            prompts: The input prompts.\n\n        Returns:\n            The output sequences.\n        \"\"\"\n        raise NotImplementedError\n\n\n_ROLLOUT_REGISTRY = {\n    (\"vllm\", \"sync\"): \"verl.workers.rollout.vllm_rollout.vLLMRollout\",\n    (\"vllm\", \"async\"): \"verl.workers.rollout.vllm_rollout.vLLMAsyncRollout\",\n    (\"sglang\", \"sync\"): \"verl.workers.rollout.sglang_rollout.sglang_rollout.SGLangRollout\",\n    (\"sglang\", \"async\"): \"verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter\",\n}\n\n\ndef get_rollout_class(rollout_name: str, mode: str) -> type[BaseRollout]:\n    \"\"\"Get the rollout class by name.\n\n    Args:\n        rollout_name: The name of the rollout.\n        mode: The mode of the rollout, sync: spmd mode, async: server mode.\n\n    Returns:\n        The rollout class.\n    \"\"\"\n    assert (rollout_name, mode) in _ROLLOUT_REGISTRY, f\"Rollout {rollout_name} with mode {mode} not found\"\n    fqdn = _ROLLOUT_REGISTRY[(rollout_name, mode)]\n    module_name, class_name = fqdn.rsplit(\".\", 1)\n    rollout_module = importlib.import_module(module_name)\n    return getattr(rollout_module, class_name)\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/hf_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRollout with huggingface models.\nTODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single\nGPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model\nto perform generation.\n\"\"\"\n\nimport contextlib\n\nimport torch\nimport torch.distributed\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import GenerationConfig\n\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.torch_functional import get_response_mask\n\nfrom .base import BaseRollout\n\n__all__ = [\"HFRollout\"]\n\n\nclass HFRollout(BaseRollout):\n    def __init__(self, module: nn.Module, config):\n        super().__init__()\n        self.config = config\n        self.module = module\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        batch_size = prompts.batch.batch_size[0]\n        num_chunks = max(batch_size // self.config.get(\"micro_batch_size\", batch_size), 1)\n        batch_prompts = prompts.chunk(chunks=num_chunks)\n        output = [self._generate_minibatch(p) for p in batch_prompts]\n        output = DataProto.concat(output)\n        return output\n\n    @torch.no_grad()\n    def _generate_minibatch(self, prompts: DataProto) -> DataProto:\n        # make sampling args can be overridden by inputs\n        do_sample = prompts.meta_info.get(\"do_sample\", self.config.do_sample)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n\n        temperature = prompts.meta_info.get(\"temperature\", self.config.temperature)\n        response_length = prompts.meta_info.get(\"response_length\", self.config.response_length)\n        top_p = prompts.meta_info.get(\"top_p\", self.config.get(\"top_p\", 1.0))\n        top_k = max(0, prompts.meta_info.get(\"top_k\", self.config.get(\"top_k\", 0)))  # to be compatible with vllm\n\n        if not do_sample:\n            # do_sample==False -> greedy decoding\n            kwargs = {\n                \"do_sample\": False,\n                \"num_beams\": 1,\n            }\n        elif is_validate:\n            # do validate and do sample -> use val_kwargs\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_k\": max(0, self.config.val_kwargs.top_k),  # to be compatible with vllm\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"num_return_sequences\": 1,  # if validate, already repeat in ray_trainer\n            }\n        else:\n            # do_sample -> use rollout config\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_p\": 
top_p,\n                \"top_k\": top_k,\n                \"temperature\": temperature,\n                # already repeat in ray_trainer\n                # https://github.com/volcengine/verl/blob/2fdfbdcba6f2e076f64bc47922d8fe6cf7dc7da5/verl/trainer/ppo/ray_trainer.py#L1117\n                \"num_return_sequences\": 1,\n            }\n\n        # build the generation config according to the generation mode\n        generation_config = GenerationConfig(**kwargs)\n\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        prompt_length = idx.size(1)\n        attention_mask = prompts.batch[\"attention_mask\"]  # left-padded attention_mask\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n        pad_token_id = prompts.meta_info[\"pad_token_id\"]\n\n        self.module.eval()\n        param_ctx = contextlib.nullcontext()\n\n        if isinstance(self.module, FSDP):\n            # recurse needs to be set to False according to https://github.com/pytorch/pytorch/issues/100069\n            param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)\n        with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n            output = self.module.generate(\n                input_ids=idx,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                do_sample=do_sample,\n                max_new_tokens=response_length,\n                eos_token_id=eos_token_id,\n                pad_token_id=pad_token_id,\n                generation_config=generation_config,\n                output_scores=False,  # this is potentially very large\n                return_dict_in_generate=True,\n                use_cache=True,\n            )\n\n        # TODO: filter out the seq with no answers like ds-chat\n        seq = output.sequences\n        generated_batch_size = seq.size(0)  # bs * num_return_sequences\n\n        # huggingface generate stops once every sequence in the batch reaches [EOS],\n        # so we have to pad the output back to response_length\n        sequence_length = prompt_length + self.config.response_length\n        delta_length = sequence_length - seq.shape[1]\n\n        if delta_length > 0:\n            delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype)\n            delta_tokens = pad_token_id * delta_tokens\n            seq = torch.cat((seq, delta_tokens), dim=1)\n        assert seq.shape[1] == sequence_length\n\n        # make the necessary repetitions if num_return_sequences > 1\n        num_return_sequences = kwargs.get(\"num_return_sequences\", 1)\n        if num_return_sequences > 1:\n            position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0)\n            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)\n\n        prompt = seq[:, :prompt_length]  # (generated_batch_size, prompt_length)\n        response = seq[:, prompt_length:]  # (generated_batch_size, response_length)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1)\n\n        response_position_ids = position_ids[:, -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n
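\n        # NOTE: get_response_mask keeps response tokens up to and including the first EOS and\n        # zeroes out the padded tail, so padding is excluded from downstream log-prob/loss computation.\n        response_attention_mask = 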
get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompt,\n                \"responses\": response,\n                \"input_ids\": seq,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=generated_batch_size,\n        )\n\n        # empty cache before computing old_log_prob\n        get_torch_device().empty_cache()\n\n        self.module.train()\n        return DataProto(batch=batch)\n
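\n\n# Usage sketch (illustrative comment only; assumes an FSDP-wrapped HF causal LM and a rollout config\n# exposing micro_batch_size, do_sample, temperature, top_p/top_k, response_length and val_kwargs):\n#   rollout = HFRollout(module=fsdp_module, config=rollout_config)\n#   gen_batch = rollout.generate_sequences(prompts)  # DataProto with prompts/responses/input_ids, masks\n"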
  },
  {
    "path": "verl_distillation/verl/workers/rollout/naive/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .naive_rollout import NaiveRollout\n\n__all__ = [\"NaiveRollout\"]\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/naive/naive_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn single GPU rollout, the sequences are generated directly by sampling from the model.\nThe output will contain\n1. output_ids\n2. attention_masks (left padding)\n3. eos_masks\n4. log_probs\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom tensordict import TensorDict\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.utils.torch_functional import logprobs_from_logits\n\nfrom ..base import BaseRollout\n\n__all__ = [\"NaiveRollout\"]\n\n\nclass NaiveRollout(BaseRollout):\n    def __init__(self, module: nn.Module, config):\n        \"\"\"A naive rollout. It requires the module to be compatible with huggingface APIs. That is:\n        The module should define __call__ to receive input_ids, attention_mask and position_ids.\n        It outputs a structure that contains logits field.\n\n        Args:\n            module: module here follows huggingface APIs\n            config: DictConfig\n        \"\"\"\n        super().__init__()\n        self.config = config\n        self.module = module\n\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Generate sequences\"\"\"\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        attention_mask = prompts.batch[\"attention_mask\"]  # left-padded attention_mask\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n        prompt_length = idx.size(1)\n\n        self.module.eval()\n\n        prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)\n\n        logits_lst = []\n        for _ in range(self.config.response_length):\n            # if the sequence context is growing too long we must crop it at block_size\n            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\n            idx_cond = idx\n            # forward the model to get the logits for the index in the sequence\n            # we use huggingface APIs here\n            output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)\n            logits = output.logits\n            # pluck the logits at the final step and scale by desired temperature\n            logits = logits[:, -1, :] / self.config.temperature  # (bs, vocab_size)\n            # optionally crop the logits to only the top k options\n            if self.config.top_k is not None:\n                v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))\n                logits[logits < v[:, [-1]]] = -float(\"Inf\")\n            # apply softmax to convert logits to (normalized) probabilities\n            probs = F.softmax(logits, dim=-1)\n            # sample from the distribution\n            if 
self.config.do_sample:\n                idx_next = torch.multinomial(probs, num_samples=1)\n            else:\n                idx_next = torch.argmax(probs, dim=-1, keepdim=True)\n\n            attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)\n\n            # once a sequence has emitted any EOS token, keep masking all of its subsequent positions\n            for token_id in eos_token_id:\n                prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())\n            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)\n\n            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)\n\n            # append sampled index to the running sequence and continue\n            idx = torch.cat((idx, idx_next), dim=1)\n            logits_lst.append(logits)\n\n        logits = torch.stack(logits_lst, dim=1)  # (bs, response_length, vocab_size)\n        prompts = idx[:, :prompt_length]  # (bs, prompt_length)\n        response = idx[:, prompt_length:]  # (bs, response_length)\n        log_probs = logprobs_from_logits(logits=logits, labels=response)\n        batch = TensorDict(\n            {\n                \"input_ids\": prompts,\n                \"responses\": response,\n                \"sequences\": idx,\n                \"old_log_probs\": log_probs,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n\n        self.module.train()\n\n        return DataProto(batch=batch)\n
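\n\n# Usage sketch (illustrative comment only; NaiveRollout expects an HF-compatible module and a config\n# with response_length, temperature, top_k and do_sample):\n#   rollout = NaiveRollout(module=model, config=cfg)\n#   out = rollout.generate_sequences(prompts)  # prompts: DataProto with input_ids/attention_mask/position_ids\n"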
  },
  {
    "path": "verl_distillation/verl/workers/rollout/replica.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\nfrom enum import Enum\nfrom typing import Callable, Optional\n\nfrom pydantic import BaseModel\nfrom ray.actor import ActorHandle\n\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.trainer.ppo.ray_trainer import RayResourcePool, ResourcePoolManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\n\nlogger = logging.getLogger(__file__)\n\n\nclass TokenOutput(BaseModel):\n    token_ids: list[int]\n    \"\"\"response token ids\"\"\"\n    log_probs: Optional[list[float]] = None\n    \"\"\"logprobs of response token ids\"\"\"\n\n\nclass RolloutMode(Enum):\n    # Rollout engine and training engine(fsdp/megatron) fused in same process\n    # Rollout and trainer share GPUs, switch context with weight synchronization.\n    # Usage scenarios: on-policy training.\n    HYBRID = \"hybrid\"\n\n    # Rollout engine colocated with hybrid engine in same ray placement group but in separate process.\n    # Rollout and hybrid processes share GPUs, switch context without weight synchronization.\n    # Usage scenarios: GRM (LLM as a judge).\n    COLOCATED = \"colocated\"\n\n    # Standalone rollout server with separate GPU resource, disaggregated architecture.\n    # Usage scenarios: off-policy training.\n    STANDALONE = \"standalone\"\n\n\nclass RolloutReplica(ABC):\n    \"\"\"Rollout replica is an individual server instance, which may be deployed on single or multiple nodes.\n    It is equivalent to launch server in each node with command line:\n\n    SGLang:\n    ```\n    python -m sglang.launch_server --node-rank 0 --nnode 2 ...\n    python -m sglang.launch_server --node-rank 1 --nnode 2 ...\n    ```\n\n    vLLM:\n    ```\n    vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 0 ...\n    vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 8 ...\n    ```\n\n    Args:\n        replica_rank: int, rank of this rollout replica.\n        config: RolloutConfig, full config.\n        gpus_per_node: int, number of gpus per node.\n    \"\"\"\n\n    def __init__(\n        self,\n        replica_rank: int,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        gpus_per_node: int = 8,\n        is_reward_model: bool = False,\n    ) -> None:\n        self.replica_rank = replica_rank\n        self.config = omega_conf_to_dataclass(config)\n        self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)\n\n        self.world_size = (\n            self.config.tensor_model_parallel_size\n            * self.config.data_parallel_size\n            * self.config.pipeline_model_parallel_size\n        )\n        self.gpus_per_node = min(gpus_per_node, self.world_size)\n        
        assert self.world_size % self.gpus_per_node == 0, (\n            f\"world_size {self.world_size} must be divisible by gpus_per_node {self.gpus_per_node}\"\n        )\n        self.nnodes = self.world_size // self.gpus_per_node\n        self.is_reward_model = is_reward_model\n\n        self.rollout_mode: RolloutMode = None\n        self.workers: list[ActorHandle] = []\n        self.resource_pool: RayResourcePool = None\n\n        self.servers: list[ActorHandle] = []\n        self._server_address: str = None\n        self._server_handle: ActorHandle = None\n\n    async def init_hybrid(self, worker_group: RayWorkerGroup):\n        \"\"\"Init hybrid rollout server, where the rollout engine and training engine (fsdp/megatron) are fused in the same process.\n\n        Args:\n            worker_group: RayWorkerGroup, fused workers where the training engine (fsdp/megatron) has been initialized.\n        \"\"\"\n        self.rollout_mode = RolloutMode.HYBRID\n        self.workers = worker_group.workers[\n            self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1)\n        ]\n        await self.launch_servers()\n\n    # TODO(@dyy): init with resource_pool?\n    async def init_colocated(self, worker_group: RayWorkerGroup):\n        \"\"\"Init colocated rollout server, where the rollout engine and hybrid engine are colocated in the same ray\n        placement group but in separate processes.\n\n        Args:\n            worker_group: RayWorkerGroup, worker group in the ray placement group where hybrid engine processes have been launched.\n        \"\"\"\n        self.rollout_mode = RolloutMode.COLOCATED\n        self.workers = worker_group.workers[\n            self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1)\n        ]\n        await self.launch_servers()\n\n    async def init_standalone(self):\n        \"\"\"Init standalone rollout server, create new resource pool for this rollout.\"\"\"\n        # create resource pool for this rollout\n        self.rollout_mode = RolloutMode.STANDALONE\n        resource_pool_name = (\n            f\"rollout_pool_reward_{self.replica_rank}\" if self.is_reward_model else f\"rollout_pool_{self.replica_rank}\"\n        )\n        resource_pool_spec = {\n            resource_pool_name: [self.gpus_per_node] * self.nnodes,\n        }\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None)\n        resource_pool_manager.create_resource_pool()\n        self.resource_pool = resource_pool_manager.resource_pool_dict[resource_pool_name]\n\n        # create worker group for this rollout\n        worker_group = RayWorkerGroup(\n            resource_pool=self.resource_pool,\n            ray_cls_with_init=self.get_ray_class_with_init_args(),\n            bin_pack=False,\n            name_prefix=f\"rollout_standalone_{self.replica_rank}\"\n            if not self.is_reward_model\n            else f\"rollout_reward_standalone_{self.replica_rank}\",\n        )\n        self.workers = worker_group.workers\n        await self.launch_servers()\n\n    @abstractmethod\n    def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:\n        \"\"\"Get rollout worker actor class for colocated and standalone mode.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def launch_servers(self):\n        \"\"\"Launch http server in each node.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def server_address(self) -> str:\n        \"\"\"Get rollout server address for OpenAI chat completion.\"\"\"
\n        return self._server_address\n\n    @property\n    def server_handle(self) -> ActorHandle:\n        \"\"\"Get rollout server handle for token-in-token-out generation.\"\"\"\n        return self._server_handle\n\n    async def wake_up(self):\n        \"\"\"Wake up each rollout server.\"\"\"\n        await asyncio.gather(*[server.wake_up.remote() for server in self.servers])\n\n    async def sleep(self):\n        \"\"\"Sleep each rollout server.\"\"\"\n        await asyncio.gather(*[server.sleep.remote() for server in self.servers])\n\n\nclass RolloutReplicaRegistry:\n    \"\"\"Factory for managing rollout replica implementations.\"\"\"\n\n    _registry: dict[str, Callable[[], type[RolloutReplica]]] = {}\n\n    @classmethod\n    def register(cls, name: str, loader: Callable[[], type[RolloutReplica]]) -> None:\n        \"\"\"Register a new rollout replica type.\"\"\"\n        cls._registry[name] = loader\n\n    @classmethod\n    def get(cls, name: str) -> type[RolloutReplica]:\n        \"\"\"Get a rollout replica class by name.\"\"\"\n        if name not in cls._registry:\n            raise ValueError(f\"Unknown rollout mode: {name}. Available: {list(cls._registry.keys())}\")\n        return cls._registry[name]()\n\n\n# Loader functions for built-in types\ndef _load_vllm():\n    from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMReplica\n\n    return vLLMReplica\n\n\ndef _load_sglang():\n    os.environ[\"SGLANG_USE_CPU_ENGINE\"] = \"1\"\n\n    # provide a minimal vllm stub so that importing sglang succeeds even when vllm is not installed\n    try:\n        import vllm  # noqa: F401\n    except ImportError:\n        import sys\n        from unittest.mock import Mock\n\n        mock_vllm = Mock()\n        mock_vllm._custom_ops = Mock()\n        mock_vllm._custom_ops.scaled_fp8_quant = Mock()\n        sys.modules[\"vllm\"] = mock_vllm\n        sys.modules[\"vllm._custom_ops\"] = mock_vllm._custom_ops\n\n    from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica\n\n    del os.environ[\"SGLANG_USE_CPU_ENGINE\"]\n    return SGLangReplica\n\n\n# Register built-in types\nRolloutReplicaRegistry.register(\"vllm\", _load_vllm)\nRolloutReplicaRegistry.register(\"sglang\", _load_sglang)\n\n\n# Original function for backward compatibility\ndef get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:\n    return RolloutReplicaRegistry.get(rollout)\n
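\n\n# Usage sketch (illustrative comment; assumes the corresponding inference engine is installed):\n#   replica_cls = get_rollout_replica_class(\"vllm\")  # resolved lazily -> vLLMReplica\n#   replica = replica_cls(replica_rank=0, config=rollout_config, model_config=model_config)\n#   await replica.init_standalone()  # or init_hybrid(worker_group) / init_colocated(worker_group)\n"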
  },
  {
    "path": "verl_distillation/verl/workers/rollout/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport difflib\nimport logging\nimport os\nfrom enum import Enum\nfrom typing import Any, Optional\n\nimport torch\nfrom pydantic import BaseModel, ConfigDict, model_validator\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema, ToolResponse\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nBASE_CHAT_HISTORY = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \"I am a user.\"},\n]\n\n\nclass FinishReasonTypeEnum(str, Enum):\n    \"\"\"The enum for finish reason type.\"\"\"\n\n    LENGTH = \"length\"\n    STOP = \"stop\"\n    TOOL_CALL = \"tool_calls\"\n\n    @classmethod\n    def from_str(cls, value: str) -> \"FinishReasonTypeEnum\":\n        if value == \"stop\":\n            return cls.STOP\n        elif value == \"length\":\n            return cls.LENGTH\n        elif value == \"tool_calls\":\n            return cls.TOOL_CALL\n        else:\n            raise ValueError(f\"Unsupported finish reason type: {value}\")\n\n\nclass Message(BaseModel):\n    role: str\n    content: str | dict[str, Any] | list[dict[str, Any]] | ToolResponse\n    tool_calls: Optional[list[OpenAIFunctionToolCall]] = None\n\n\nclass AsyncRolloutRequestStateEnum(str, Enum):\n    \"\"\"The enum for async rollout request state.\"\"\"\n\n    PENDING = \"pending\"\n    RUNNING = \"running\"\n    COMPLETED = \"completed\"\n    FAILED = \"failed\"\n    TOOL_CALLING = \"tool_calling\"\n    INTERACTING = \"interacting\"\n\n\nclass TokenizationSanityCheckModeEnum(str, Enum):\n    \"\"\"The enum for tokenization sanity check mode.\"\"\"\n\n    DISABLE = \"disable\"\n    STRICT = \"strict\"\n    IGNORE_STRIPPABLE = \"ignore_strippable\"\n\n\nclass AsyncRolloutRequest(BaseModel):\n    \"\"\"The data model for async rollout.\"\"\"\n\n    model_config = ConfigDict(arbitrary_types_allowed=True)\n\n    batch_data_id: int = 0\n    rollout_offset: int = 0\n    request_id: str\n    state: AsyncRolloutRequestStateEnum\n    messages: list[Message]\n    multi_modal_keys: Optional[list[str]] = None\n    multi_modal_data: Optional[dict[str, Any]] = None\n    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None\n    tool_schemas: Optional[list[OpenAIFunctionToolSchema]] = None\n    tools_kwargs: dict[str, Any] = {}\n    interaction_kwargs: dict[str, Any] = {}\n    input_ids: Optional[torch.Tensor] = None\n    prompt_ids: Optional[torch.Tensor] = None\n    response_ids: Optional[torch.Tensor] = None\n    attention_mask: Optional[torch.Tensor] = None\n    prompt_attention_mask: Optional[torch.Tensor] = None\n    response_attention_mask: Optional[torch.Tensor] = None\n    position_ids: 
Optional[torch.Tensor] = None\n    prompt_position_ids: Optional[torch.Tensor] = None\n    response_position_ids: Optional[torch.Tensor] = None\n    loss_mask: Optional[torch.Tensor] = None\n    prompt_loss_mask: Optional[torch.Tensor] = None\n    response_loss_mask: Optional[torch.Tensor] = None\n    reward_scores: dict[str, float]\n    max_prompt_len: int\n    max_response_len: int = 8192\n    max_model_len: int = 32768\n    metrics: dict[str, list[Any]] = {}\n    output_token_ids: torch.Tensor | None = None\n    rollout_log_probs: torch.Tensor | None = None\n\n    use_inference_chat_template: bool\n    tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum\n    generation_prompt_ids: Optional[torch.Tensor] = None\n    base_conv_wo_gen_prompt_end_pos: int\n    base_conv_with_gen_prompt_end_pos: int\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def initialize_request(cls, values):\n        if not (messages := values.get(\"messages\")):\n            raise ValueError(\"messages is required for AsyncRolloutRequest initialization\")\n        if not (max_prompt_len := values.get(\"max_prompt_len\")):\n            raise ValueError(\"max_prompt_len is required for AsyncRolloutRequest initialization\")\n        if not (processing_class := values.pop(\"processing_class\", None)):\n            raise ValueError(\"processing_class is required for AsyncRolloutRequest initialization\")\n\n        values[\"messages\"] = [Message.model_validate(msg) for msg in messages]\n\n        # If there is no multi_modal_keys, we assume the multi-modal data is image and video.\n        if not values.get(\"multi_modal_keys\"):\n            values[\"multi_modal_keys\"] = [\"image\", \"video\"]\n        if not values.get(\"multi_modal_data\"):\n            values[\"multi_modal_data\"] = {key: [] for key in values[\"multi_modal_keys\"]}\n        else:\n            # check if all multi_modal_keys are in multi_modal_data\n            for key in values[\"multi_modal_keys\"]:\n                if key not in values[\"multi_modal_data\"]:\n                    values[\"multi_modal_data\"][key] = []\n        if not values.get(\"multi_modal_inputs\"):\n            values[\"multi_modal_inputs\"] = {}\n\n        tools = (\n            [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get(\"tool_schemas\", [])) else None\n        )\n\n        multi_modal_data = values[\"multi_modal_data\"]\n        tokens_without_prompt = cls._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        )\n        if (\n            values.get(\"input_ids\") is None\n            or values.get(\"attention_mask\") is None\n            or values.get(\"position_ids\") is None\n        ):\n            tokenization_dict_with_prompt = cls._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n                return_dict=True,\n            )\n\n            values[\"input_ids\"], values[\"attention_mask\"] = (\n                tokenization_dict_with_prompt[\"input_ids\"],\n                tokenization_dict_with_prompt[\"attention_mask\"],\n            )\n            if values[\"input_ids\"].shape[-1] > max_prompt_len:\n                # Only log 
the warning to avoid truncating in the middle of generation prompt. Consider raising an\n                # error for this case in the future.\n                # Ensure batch_data_id exists with default value if not provided\n                if \"batch_data_id\" not in values:\n                    values[\"batch_data_id\"] = cls.model_fields[\"batch_data_id\"].default\n                logger.warning(\n                    f\"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} \"\n                    f\"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools.\"\n                )\n\n            # Process multi_modal_inputs\n            multi_modal_inputs = tokenization_dict_with_prompt.copy()\n            multi_modal_inputs.pop(\"input_ids\", None)\n            multi_modal_inputs.pop(\"attention_mask\", None)\n            values[\"multi_modal_inputs\"] = multi_modal_inputs\n\n            values[\"position_ids\"] = values[\"prompt_position_ids\"] = cls._get_position_ids(\n                processing_class, values[\"input_ids\"], values[\"attention_mask\"], multi_modal_inputs\n            )\n\n        values[\"prompt_ids\"], values[\"prompt_attention_mask\"] = values[\"input_ids\"], values[\"attention_mask\"]\n        values[\"loss_mask\"] = values[\"prompt_loss_mask\"] = torch.zeros_like(values[\"input_ids\"], dtype=torch.bool)\n        values[\"generation_prompt_ids\"] = values[\"input_ids\"][..., tokens_without_prompt.shape[-1] :]\n        values[\"base_conv_wo_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        ).shape[-1]\n\n        values[\"base_conv_with_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=True,\n            tokenize=True,\n        ).shape[-1]\n\n        return values\n\n    @staticmethod\n    def _handle_apply_chat_template(\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        messages: list[Message],\n        multi_modal_data: dict[str, Any],\n        tools: Optional[list[OpenAIFunctionToolSchema]] = None,\n        add_generation_prompt: bool = False,\n        tokenize: bool = False,\n        return_dict: bool = False,\n    ):\n        raw_prompt = processing_class.apply_chat_template(\n            messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False\n        )\n        if not tokenize:\n            return raw_prompt\n\n        if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast):\n            if any(len(values) > 0 for values in multi_modal_data.values()):\n                logger.warning(\n                    \"There is multi_modal_data but you are not using a processor. 
Multi-modal data will be ignored.\"\n                )\n            model_inputs = processing_class(text=[raw_prompt], return_tensors=\"pt\")\n        elif isinstance(processing_class, ProcessorMixin):\n            # When we update multi_model_keys, we also need to update this logic\n            images = images if len(images := multi_modal_data.get(\"image\", [])) > 0 else None\n            videos = videos if len(videos := multi_modal_data.get(\"video\", [])) > 0 else None\n            model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors=\"pt\")\n        else:\n            raise ValueError(f\"Unsupported processing class type: {type(processing_class)}\")\n\n        model_inputs = dict(model_inputs)\n        if return_dict:\n            return model_inputs\n        else:\n            return model_inputs[\"input_ids\"]\n\n    @staticmethod\n    def _get_position_ids(\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        # special case for qwen2vl\n        is_qwen2vl = (\n            hasattr(processing_class, \"image_processor\")\n            and \"Qwen2VLImageProcessor\" in processing_class.image_processor.__class__.__name__\n        )\n        if is_qwen2vl:\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            image_grid_thw = video_grid_thw = second_per_grid_ts = None\n            if multi_modal_inputs:\n                image_grid_thw = multi_modal_inputs.get(\"image_grid_thw\")\n                video_grid_thw = multi_modal_inputs.get(\"video_grid_thw\")\n                second_per_grid_ts = multi_modal_inputs.get(\"second_per_grid_ts\")\n\n            assert input_ids.dim() == 2 and input_ids.shape[0] == 1, (\n                f\"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}\"\n            )\n            assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, (\n                f\"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}\"\n            )\n            new_position_ids = get_rope_index(\n                processing_class,\n                input_ids=input_ids.squeeze(0),\n                image_grid_thw=image_grid_thw,\n                video_grid_thw=video_grid_thw,\n                second_per_grid_ts=second_per_grid_ts,\n                attention_mask=attention_mask.squeeze(0),\n            )\n            return new_position_ids  # (3, seq_len)\n        else:\n            return compute_position_id_with_mask(attention_mask)  # (1, seq_len)\n\n    def _update_input_ids(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        new_input_ids: torch.Tensor,\n        attention_mask: bool,\n        loss_mask: bool,\n        new_multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None,\n    ) -> None:\n        \"\"\"\n        Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner.\n        \"\"\"\n        self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1)\n        attention_mask = torch.ones_like(new_input_ids) * int(attention_mask)\n        self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1)\n        loss_mask = torch.ones_like(new_input_ids) * int(loss_mask)\n        self.loss_mask 
= torch.cat([self.loss_mask, loss_mask], dim=-1)\n\n        if new_multi_modal_inputs:\n            self._update_multi_modal_inputs(new_multi_modal_inputs)\n\n        new_position_ids = self._get_position_ids(\n            processing_class, new_input_ids, attention_mask, new_multi_modal_inputs\n        )\n\n        last_pos = self.position_ids[..., -1:]\n        new_position_ids = new_position_ids + (last_pos + 1)\n\n        self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None:\n        \"\"\"\n        Update the multi_modal_inputs of the request in additive manner.\n        \"\"\"\n        for key in new_multi_modal_inputs:\n            input_tensor = new_multi_modal_inputs[key]\n            self.multi_modal_inputs[key] = (\n                torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0)\n                if key in self.multi_modal_inputs\n                else input_tensor\n            )\n\n    def get_generation_prompt_ids(\n        self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin\n    ) -> list[int]:\n        \"\"\"\n        Get the generation prompt ids for rollout engine.\n\n        Because rollout engine(SGLang) requires the ids to be a list, we need to convert the tensor to a list.\n        \"\"\"\n        generation_prompt_ids = (\n            None\n            if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all()\n            else self.generation_prompt_ids\n        )\n        if generation_prompt_ids is not None:\n            self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False)\n\n        if self.use_inference_chat_template:\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            generation_prompt_ids = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n            )\n            return generation_prompt_ids.squeeze(0).tolist()\n        else:\n            return self.input_ids.squeeze(0).tolist()\n\n    def add_user_message(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        content: str,\n    ) -> None:\n        self.messages.append(Message(role=\"user\", content=content))\n        messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine\n        # Inference, it is pure text.\n        content_ids = self._handle_apply_chat_template(\n            processing_class, messages, multi_modal_data={}, tools=tools, 
add_generation_prompt=False, tokenize=True\n        )[..., self.base_conv_wo_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False)\n\n    def add_assistant_message(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        content: str,\n        content_ids: Optional[torch.Tensor] = None,\n        tool_calls: Optional[list[OpenAIFunctionToolCall]] = None,\n    ) -> None:\n        self.messages.append(Message(role=\"assistant\", content=content, tool_calls=tool_calls))\n        if content_ids is None:\n            messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n            # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine\n            # Inference, it is pure text.\n            content_ids = self._handle_apply_chat_template(\n                processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True\n            )[..., self.base_conv_with_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True)\n\n    def add_tool_response_messages(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        contents: list[ToolResponse],\n    ) -> None:\n        if not contents or all(content.is_empty() for content in contents):\n            return\n        # We also handle the case when tool returns image\n        # We require the processing of the image and video to be done at tool.execute() level\n        delta_multi_modal_data = {key: [] for key in self.multi_modal_keys}\n        for content in contents:\n            if content.is_text_only():\n                self.messages.append(Message(role=\"tool\", content=content.text))\n            else:\n                content_list = []\n                # When we update multi_model_keys, we also need to update this logic\n                if content.image:\n                    content_list.extend([{\"type\": \"image\"} for _ in content.image])\n                    delta_multi_modal_data[\"image\"].extend(content.image)\n                if content.video:\n                    content_list.extend([{\"type\": \"video\"} for _ in content.video])\n                    delta_multi_modal_data[\"video\"].extend(content.video)\n                if content.text:\n                    content_list.append({\"type\": \"text\", \"text\": content.text})\n                self.messages.append(Message(role=\"tool\", content=content_list))\n\n        messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        for key in self.multi_modal_keys:\n            if len(delta_multi_modal_data[key]) > 0:\n                self.multi_modal_data[key].extend(delta_multi_modal_data[key])\n\n        # We just passed the new multi-modal data to the chat template to update the input_ids.\n        content_info = self._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=delta_multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_dict=True,\n        )\n        content_ids = content_info[\"input_ids\"][..., 
self.base_conv_wo_gen_prompt_end_pos :]\n\n        # process multi_modal_inputs\n        multi_modal_inputs = content_info.copy()\n        multi_modal_inputs.pop(\"input_ids\", None)\n        multi_modal_inputs.pop(\"attention_mask\", None)\n\n        # chat templates include generation prompt tokens (e.g., \"<im_start>assistant\\n\")\n        # So when tool response is added, we need to explicitly remove these tokens.\n        self._remove_generation_prompt_ids_if_present()\n\n        self._update_input_ids(\n            processing_class,\n            content_ids,\n            attention_mask=True,\n            loss_mask=False,\n            new_multi_modal_inputs=multi_modal_inputs,\n        )\n\n    def update_metrics(self, metrics: Any, tool_id: str) -> None:\n        \"\"\"\n        metrics: should be a dict of tools_name -> Any\n        \"\"\"\n        if self.metrics.get(tool_id) is None:\n            self.metrics[tool_id] = []\n        self.metrics[tool_id].append(metrics)\n\n    def _get_prompt_diffs(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        full_prompt_ids: torch.Tensor,\n        current_prompt_ids: torch.Tensor,\n        diff_surrounding_chars: int = 10,\n    ) -> list[dict[str, Any]]:\n        \"\"\"Get differences between full prompt and current prompt with surrounding context.\n\n        This function helps debug tokenization mismatches by showing the differences between\n        full prompt and current prompt with surrounding context. Instead of just showing\n        the exact diff, it includes additional tokens before and after to help locate\n        the issue in the chat template.\n\n        For example, if the actual diff is a newline change from \"\\n\\n\" to \"\\n\", with\n        diff_surrounding_chars the output might look like:\n\n        full_prompt_chunk:    \"<|im_start|>assistant\\n\\nI think...\"\n        current_prompt_chunk: \"<|im_start|>assistant\\nI think...\"\n\n        This context makes it much easier to identify where in the chat template the\n        mismatch occurs.\n\n        Args:\n            processing_class: The processing class to use for decoding the token IDs\n            full_prompt_ids: Token IDs from applying chat template to all messages at once\n            current_prompt_ids: Token IDs from incremental chat template application\n            diff_surrounding_chars: Number of surrounding characters to include for context (default: 10)\n\n        Returns:\n            List of dicts containing the differing chunks with context and their indices\n        \"\"\"\n        full_prompt_ids = full_prompt_ids.squeeze(0)\n        current_prompt_ids = current_prompt_ids.squeeze(0)\n        full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False)\n        current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False)\n        s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False)\n        diffs = []\n        for tag, i1, i2, j1, j2 in s.get_opcodes():\n            if tag == \"equal\":\n                continue\n\n            # Get the surrounding context for better readability\n            start_i = max(0, i1 - diff_surrounding_chars)\n            end_i = min(len(full_prompt), i2 + diff_surrounding_chars)\n            start_j = max(0, j1 - diff_surrounding_chars)\n            end_j = min(len(current_prompt), j2 + diff_surrounding_chars)\n\n            diffs.append(\n                {\n                    
\"full_prompt_chunk\": full_prompt[start_i:end_i],\n                    \"current_prompt_chunk\": current_prompt[start_j:end_j],\n                    \"indices\": (start_i, end_i, start_j, end_j),\n                }\n            )\n        return diffs\n\n    def _remove_generation_prompt_ids_if_present(self) -> None:\n        \"\"\"\n        Remove generation prompt IDs from input tensors if they are present at the end.\n        \"\"\"\n        if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all():\n            self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]]\n            self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]]\n\n    def finalize(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        reward_scores: dict[str, list[float]],\n        finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP,\n    ) -> None:\n        self.state = AsyncRolloutRequestStateEnum.COMPLETED\n        self.reward_scores = reward_scores\n\n        # In case we failed to generate the assistant message and the generation prompt ids were already added to\n        # input_ids, remove them from the end of input_ids\n        self._remove_generation_prompt_ids_if_present()\n\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :]\n\n        if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE:\n            # When there is a diff, we log the diffs with diff_surrounding_chars context\n            diff_surrounding_chars = 10\n\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            full_prompt_info = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=False,\n                tokenize=True,\n                return_dict=True,\n            )\n            full_prompt_ids = full_prompt_info[\"input_ids\"]\n\n            # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict\n            # because np.array() only keeps the keys for BatchFeature.\n            full_prompt_multi_modal_inputs = full_prompt_info.copy()\n            full_prompt_multi_modal_inputs.pop(\"input_ids\", None)\n            full_prompt_multi_modal_inputs.pop(\"attention_mask\", None)\n\n            for multi_modal_inputs_key in self.multi_modal_inputs:\n                if multi_modal_inputs_key in full_prompt_multi_modal_inputs:\n                    if (\n                        not self.multi_modal_inputs[multi_modal_inputs_key]\n                        .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key])\n                        .all()\n                    ):\n                        logger.warning(\n                            f\"Multi-modal data {multi_modal_inputs_key} is not consistent. \"\n                            f\"This may lead to unexpected behavior during training. 
\"\n                            f\"Please review your multi_modal_inputs logic.\"\n                        )\n                else:\n                    logger.warning(\n                        f\"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. \"\n                        f\"This may lead to unexpected behavior during training.\"\n                        f\"Please review your multi_modal_inputs logic.\"\n                    )\n\n            if diffs := self._get_prompt_diffs(\n                processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars\n            ):\n                log_warning = False\n                if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT:\n                    log_warning = True\n                elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE:\n                    non_strippable_diffs_exist = any(\n                        d[\"full_prompt_chunk\"].strip() or d[\"current_prompt_chunk\"].strip() for d in diffs\n                    )\n                    if non_strippable_diffs_exist:\n                        log_warning = True\n\n                if log_warning:\n                    mode_str = f\" ({self.tokenization_sanity_check_mode.value})\"\n                    logger.warning(\n                        f\"Inconsistent training and inference tokenization detected{mode_str}. This may lead to \"\n                        f\"unexpected behavior during training. Please review your chat template to determine if this \"\n                        f\"is intentional. For more information, refer to the multiturn README.md.\"\n                    )\n                    logger.warning(\n                        f\"Showing {diff_surrounding_chars} characters before and after the diffs for context and \"\n                        f\"better readability.\"\n                    )\n                    diff_details_list = []\n                    for d in diffs:\n                        i1, i2, j1, j2 = d[\"indices\"]\n                        diff_details_list.append(\n                            f\"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | \"\n                            f\"current_prompt_chunk: {repr(d['current_prompt_chunk'])}\"\n                        )\n                    diff_details = \"\\n\".join(diff_details_list)\n                    logger.warning(f\"Found differences:\\n{diff_details}\")\n\n        if finish_reason_type == FinishReasonTypeEnum.STOP:\n            pass\n        elif finish_reason_type == FinishReasonTypeEnum.LENGTH:\n            pass\n        else:\n            raise ValueError(f\"Unsupported finalize finish reason type: {finish_reason_type}\")\n        self.truncate_output_ids(processing_class)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def truncate_output_ids(\n        self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin\n    ) -> None:\n        self.input_ids = self.input_ids[..., : self.max_model_len]\n        self.attention_mask = self.attention_mask[..., : 
self.max_model_len]\n        self.position_ids = self.position_ids[..., : self.max_model_len]\n        self.loss_mask = self.loss_mask[..., : self.max_model_len]\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len]\n        self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len]\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/sglang_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/sglang_rollout/async_sglang_server.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport dataclasses\nimport logging\nimport os\nfrom typing import Any, Optional\n\nimport ray\nimport sglang.srt.entrypoints.engine\nimport torch\nfrom ray.actor import ActorHandle\nfrom sglang.srt.entrypoints.http_server import (\n    ServerArgs,\n    _GlobalState,\n    _launch_subprocesses,\n    app,\n    set_global_state,\n)\nfrom sglang.srt.managers.io_struct import (\n    GenerateReqInput,\n    ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n)\nfrom sglang.srt.managers.tokenizer_manager import ServerStatus\n\nfrom verl.single_controller.ray import RayClassWithInitArgs\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter, _set_envs_and_config\nfrom verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address, run_unvicorn\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(logging.INFO)\n\n\n@ray.remote(num_cpus=1)\nclass SGLangHttpServer:\n    \"\"\"SGLang http server in single node, this is equivalent to launch server with command line:\n    ```\n    python -m sglang.launch_server --node-rank 0 --nnode 1 ...\n    ```\n\n    Args:\n        config (DictConfig): full config.\n        rollout_mode (RolloutMode): rollout mode.\n        replica_rank (int): replica rank, a replica may contain multiple nodes.\n        node_rank (int): node rank.\n        nnodes (int): number of nodes.\n        cuda_visible_devices (str): cuda visible devices.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        rollout_mode: RolloutMode,\n        workers: list[ActorHandle],\n        replica_rank: int,\n        node_rank: int,\n        nnodes: int,\n        cuda_visible_devices: str,\n    ):\n        print(f\"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}\")\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = cuda_visible_devices\n        assert torch.cuda.is_available(), \"SGLang http server should run on GPU node\"\n\n        self.config: RolloutConfig = omega_conf_to_dataclass(config)\n        self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)\n        self.config.max_model_len = self.config.prompt_length + self.config.response_length\n        self.rollout_mode = rollout_mode\n        self.workers = workers\n\n        self.replica_rank = replica_rank\n        self.node_rank = node_rank\n        self.nnodes = nnodes\n\n        if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == \"dummy\":\n            logger.warning(f\"rollout mode is {self.rollout_mode}, load_format is dummy, set 
to auto\")\n            self.config.load_format = \"auto\"\n\n        # used for http server\n        self._server_address = ray.util.get_node_ip_address().strip(\"[]\")\n        self._server_port = None\n\n        # used for NCCL process group\n        if self.node_rank == 0:\n            self._master_address = self._server_address\n            self._master_port, self._master_sock = get_free_port(self._server_address)\n            logger.info(\n                f\"SGLangHttpServer, replica_rank: {self.replica_rank}, \"\n                f\"master address: {self._master_address}, port: {self._master_port}\"\n            )\n        else:\n            self._master_address = None\n            self._master_port = None\n\n    def get_master_address(self):\n        \"\"\"Get master address and port for init NCCL process group.\"\"\"\n        return self._master_address, self._master_port\n\n    def get_server_address(self):\n        \"\"\"Get http server address and port.\"\"\"\n        assert self._server_port is not None, \"http server is not launched, port is None\"\n        return self._server_address, self._server_port\n\n    async def launch_server(self, master_address: str = None, master_port: int = None):\n        if self.node_rank != 0:\n            assert master_address and master_port, \"non-master node should provide master address and port\"\n            self._master_address = master_address\n            self._master_port = master_port\n\n        engine_kwargs = self.config.get(\"engine_kwargs\", {}).get(\"sglang\", {}) or {}\n        attention_backend = engine_kwargs.pop(\"attention_backend\", None)\n        dist_init_addr = (\n            f\"[{self._master_address}]:{self._master_port}\"\n            if is_valid_ipv6_address(self._master_address)\n            else f\"{self._master_address}:{self._master_port}\"\n        )\n\n        args = {\n            \"model_path\": self.model_config.local_path,\n            \"dtype\": self.config.dtype,\n            \"mem_fraction_static\": self.config.gpu_memory_utilization,\n            \"disable_cuda_graph\": self.config.enforce_eager,\n            \"enable_memory_saver\": True,\n            \"base_gpu_id\": 0,\n            \"gpu_id_step\": 1,\n            \"tp_size\": self.config.tensor_model_parallel_size,\n            \"dp_size\": self.config.data_parallel_size,\n            \"ep_size\": self.config.expert_parallel_size,\n            \"node_rank\": self.node_rank,\n            \"load_format\": self.config.load_format,\n            \"dist_init_addr\": dist_init_addr,\n            \"nnodes\": self.nnodes,\n            \"trust_remote_code\": self.model_config.trust_remote_code,\n            \"max_running_requests\": self.config.get(\"max_num_seqs\", None),\n            \"log_level\": \"error\",\n            \"mm_attention_backend\": \"fa3\",\n            \"attention_backend\": attention_backend if attention_backend is not None else \"fa3\",\n            \"skip_tokenizer_init\": self.config.skip_tokenizer_init,\n            **engine_kwargs,\n        }\n        # enable_weights_cpu_backup is supported in sglang>=0.5.3\n        if \"enable_weights_cpu_backup\" in [f.name for f in dataclasses.fields(ServerArgs)]:\n            enable_weights_cpu_backup = True if self.rollout_mode == RolloutMode.COLOCATED else False\n            args[\"enable_weights_cpu_backup\"] = enable_weights_cpu_backup\n\n        # NOTE: We can't directly call SGLang's launch_server since it's not an async function.\n        # 
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py\n        sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config\n        os.environ[\"SGLANG_BLOCK_NONZERO_RANK_CHILDREN\"] = \"0\"\n        server_args = ServerArgs(**args)\n        self.tokenizer_manager, self.template_manager, self.scheduler_info = _launch_subprocesses(\n            server_args=server_args\n        )\n\n        # In multi-node cases, non-zero rank nodes should not launch http server.\n        if self.node_rank > 0:\n            return\n\n        set_global_state(\n            _GlobalState(\n                tokenizer_manager=self.tokenizer_manager,\n                template_manager=self.template_manager,\n                scheduler_info=self.scheduler_info,\n            )\n        )\n        app.is_single_tokenizer_mode = True\n        self._server_port, self._server_task = await run_unvicorn(app, server_args, self._server_address)\n        self.tokenizer_manager.server_status = ServerStatus.Up\n\n    async def wake_up(self):\n        if self.rollout_mode == RolloutMode.HYBRID:\n            # Call all workers to switch between trainer mode and rollout mode.\n            await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])\n        elif self.rollout_mode == RolloutMode.COLOCATED:\n            # Directly call engine to wake up without sync weights.\n            # FIXME(@wuxibin): sglang seems resume with random weights.\n            obj = ResumeMemoryOccupationReqInput(tags=[\"kv_cache\", \"weights\"])\n            await self.tokenizer_manager.resume_memory_occupation(obj, None)\n            await self.tokenizer_manager.flush_cache()\n        elif self.rollout_mode == RolloutMode.STANDALONE:\n            logger.info(\"skip wake_up in standalone mode\")\n\n    async def sleep(self):\n        if self.rollout_mode == RolloutMode.HYBRID:\n            await asyncio.gather(*[worker.sleep.remote() for worker in self.workers])\n        elif self.rollout_mode == RolloutMode.COLOCATED:\n            obj = ReleaseMemoryOccupationReqInput(tags=[\"kv_cache\", \"weights\"])\n            await self.tokenizer_manager.release_memory_occupation(obj, None)\n        elif self.rollout_mode == RolloutMode.STANDALONE:\n            logger.info(\"skip sleep in standalone mode\")\n\n    async def generate(\n        self,\n        prompt_ids: torch.Tensor,\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ) -> TokenOutput:\n        \"\"\"Generate sequence with token-in-token-out.\"\"\"\n        # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready.\n        max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids) - 1)\n        sampling_params[\"max_new_tokens\"] = max_new_tokens\n        return_logprob = sampling_params.pop(\"logprobs\", False)\n\n        request = GenerateReqInput(\n            rid=request_id,\n            input_ids=prompt_ids,\n            sampling_params=sampling_params,\n            return_logprob=return_logprob,\n            image_data=image_data,\n        )\n        output = await self.tokenizer_manager.generate_request(request, None).__anext__()\n        if return_logprob:\n            output_token_logprobs = output[\"meta_info\"][\"output_token_logprobs\"]\n            log_probs, token_ids = zip(\n                *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True\n     
       )\n        else:\n            token_ids = output[\"output_ids\"]\n            log_probs = None\n        return TokenOutput(token_ids=token_ids, log_probs=log_probs)\n\n\n_rollout_worker_actor_cls = ray.remote(ServerAdapter)\n\n\nclass SGLangReplica(RolloutReplica):\n    def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:\n        \"\"\"Get rollout worker actor class for colocated and standalone mode.\"\"\"\n        worker_dict_cls = RayClassWithInitArgs(\n            cls=_rollout_worker_actor_cls,\n            config=self.config,\n            model_config=self.model_config,\n            device_mesh=None,\n        )\n        return worker_dict_cls\n\n    async def launch_servers(self):\n        \"\"\"Launch http server in each node.\"\"\"\n        assert len(self.workers) == self.world_size, (\n            f\"worker number {len(self.workers)} not equal to world size {self.world_size}\"\n        )\n\n        # get (node_id, CUDA_VISIBLE_DEVICES) of all workers\n        worker_infos = await asyncio.gather(\n            *[\n                worker.__ray_call__.remote(\n                    lambda self: (ray.get_runtime_context().get_node_id(), os.environ[\"CUDA_VISIBLE_DEVICES\"])\n                )\n                for worker in self.workers\n            ]\n        )\n        worker_cuda_visible_devices = [worker_info[1] for worker_info in worker_infos]\n        worker_node_ids = [worker_info[0] for worker_info in worker_infos]\n\n        # create server actor in each node with node affinity and cuda visible devices\n        for node_rank in range(self.nnodes):\n            workers = self.workers[node_rank * self.gpus_per_node : (node_rank + 1) * self.gpus_per_node]\n            node_cuda_visible_devices = \",\".join(\n                worker_cuda_visible_devices[node_rank * self.gpus_per_node : (node_rank + 1) * self.gpus_per_node]\n            )\n            node_id = worker_node_ids[node_rank * self.gpus_per_node]\n            name = (\n                f\"sglang_server_{self.replica_rank}_{node_rank}\"\n                if not self.is_reward_model\n                else f\"sglang_server_reward_{self.replica_rank}_{node_rank}\"\n            )\n            server = SGLangHttpServer.options(\n                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                    node_id=node_id,\n                    soft=False,\n                ),\n                runtime_env={\"env_vars\": {\"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\": \"1\"}},\n                name=name,\n            ).remote(\n                config=self.config,\n                model_config=self.model_config,\n                rollout_mode=self.rollout_mode,\n                workers=workers,\n                replica_rank=self.replica_rank,\n                node_rank=node_rank,\n                nnodes=self.nnodes,\n                cuda_visible_devices=node_cuda_visible_devices,\n            )\n            self.servers.append(server)\n\n        # launch http server in each node\n        master_address, master_port = await self.servers[0].get_master_address.remote()\n        await asyncio.gather(\n            *[\n                server.launch_server.remote(master_address=master_address, master_port=master_port)\n                for server in self.servers\n            ]\n        )\n\n        # get http server address from first server\n        server_address, server_port = await self.servers[0].get_server_address.remote()\n        self._server_handle = self.servers[0]\n        
self._server_address = (\n            f\"[{server_address}]:{server_port}\"\n            if is_valid_ipv6_address(server_address)\n            else f\"{server_address}:{server_port}\"\n        )\n"
  },
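The two-phase launch in `SGLangReplica.launch_servers` above is easy to miss inside the Ray plumbing: the rank-0 `SGLangHttpServer` reserves a rendezvous port at construction time (holding the socket so the port stays reserved), every other node is handed that address, and only then do all nodes call `launch_server` concurrently. Below is a minimal, Ray-free sketch of just that handshake; `MockNodeServer` is a hypothetical stand-in for `SGLangHttpServer` that models only the address exchange, not the actual SGLang subprocess launch.

```python
import asyncio
import socket
from typing import Optional


class MockNodeServer:
    """Hypothetical stand-in for SGLangHttpServer: one instance per node."""

    def __init__(self, node_rank: int):
        self.node_rank = node_rank
        self.master_address: Optional[str] = None
        self.master_port: Optional[int] = None
        if node_rank == 0:
            # Rank 0 reserves a free port for the NCCL-style rendezvous and
            # holds the socket so the port stays reserved until launch.
            self._sock = socket.socket()
            self._sock.bind(("127.0.0.1", 0))
            self.master_address, self.master_port = self._sock.getsockname()

    def get_master_address(self):
        return self.master_address, self.master_port

    async def launch_server(self, master_address=None, master_port=None):
        if self.node_rank != 0:
            # Non-master nodes receive the rendezvous address from rank 0.
            assert master_address and master_port
            self.master_address, self.master_port = master_address, master_port
        # The real server would run _launch_subprocesses(...) on every node
        # here, and additionally start the HTTP endpoint when node_rank == 0.
        await asyncio.sleep(0)
        print(f"node {self.node_rank} joined {self.master_address}:{self.master_port}")


async def main():
    servers = [MockNodeServer(rank) for rank in range(2)]
    # Phase 1: read the rendezvous address from the rank-0 server.
    addr, port = servers[0].get_master_address()
    # Phase 2: launch all nodes concurrently with the shared address.
    await asyncio.gather(*(s.launch_server(addr, port) for s in servers))


asyncio.run(main())
```

The ordering matters for the same reason it does in the real code: `dist_init_addr` must be known to every node's `ServerArgs` before any node enters `_launch_subprocesses`, so the address is fetched once from rank 0 and only then fanned out to all servers.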
  {
    "path": "verl_distillation/verl/workers/rollout/sglang_rollout/http_server_engine.py",
    "content": "# Copyright 2025 z.ai\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# This file is adapted from multiple sources:\n# 1. THUDM/slime project\n#    Original source: https://github.com/THUDM/slime/blob/main/slime/backends/sglang_utils/http_server_engine.py\n#    Copyright 2025 z.ai\n#    Licensed under the Apache License, Version 2.0\n# 2. SGLang project\n#    Original source: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server_engine.py\n#    Copyright 2023-2024 SGLang Team\n#    Licensed under the Apache License, Version 2.0\n#\n# Modifications made by z.ai and ModelBest Inc. include but are not limited to:\n# - Enhanced error handling and retry logic\n# - Added async support with connection pooling\n# - Extended functionality for distributed weight updates\n# - Improved logging and monitoring capabilities\n# - Additional configuration options and optimizations\n\n\"\"\"HTTP Server Engine Adapter for SGLang.\n\nThis module provides HTTP-based adapters for SGLang engines, allowing communication\nwith SGLang servers through HTTP requests instead of direct engine calls.\n\nClasses:\n    HttpServerAdapter: Synchronous HTTP adapter for SGLang engines\n    AsyncHttpServerAdapter: Asynchronous HTTP adapter for SGLang engines\n\nFunctions:\n    launch_server_process: Launch and initialize an SGLang HTTP server process\n\"\"\"\n\nimport asyncio\nimport logging\nimport multiprocessing\nimport os\nimport time\nfrom contextlib import asynccontextmanager\nfrom typing import Any, Callable, Optional\n\nimport aiohttp\nimport requests\nfrom sglang.srt.entrypoints.EngineBase import EngineBase\nfrom sglang.srt.entrypoints.http_server import launch_server\nfrom sglang.srt.managers.io_struct import (\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import kill_process_tree\n\n# Configure logger\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n# Default configuration constants\nDEFAULT_TIMEOUT = 60.0\nDEFAULT_MAX_ATTEMPTS = 3\nDEFAULT_RETRY_DELAY = 2.0\nDEFAULT_MAX_CONNECTIONS = 2000\nDEFAULT_MAX_WAIT_TIME = 300.0\n\n\ndef _read_response(response: requests.Response):\n    if response.status_code == 204 or not response.content:\n        return {}\n    try:\n        return response.json()\n    except ValueError:\n        return {\n            \"content_type\": response.headers.get(\"Content-Type\", \"\"),\n            \"text\": response.text,\n        }\n\n\nasync def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]:\n    if resp.status == 204 or (resp.content_length == 0):\n        return {}\n\n    try:\n        return await resp.json(content_type=None)\n    except Exception:\n        try:\n            text = await resp.text()\n        except Exception:\n            return {}\n        return {\n            \"content_type\": 
(resp.headers.get(\"Content-Type\") or \"\"),\n            \"text\": text,\n        }\n\n\ndef launch_server_process(\n    server_args: ServerArgs,\n    timeout: float = DEFAULT_TIMEOUT,\n    max_wait_time=DEFAULT_MAX_WAIT_TIME,\n    first_rank_in_node=False,\n) -> multiprocessing.Process:\n    \"\"\"Launch an SGLang HTTP server process and wait for it to be ready.\n\n    This function starts a new process running an SGLang HTTP server, then waits\n    for the server to become ready by polling its health endpoints. It ensures\n    the server is fully operational before returning.\n\n    Args:\n        server_args (ServerArgs): Server configuration arguments including host, port, and other settings\n        timeout (float, optional): Timeout for individual HTTP requests during health checks.\n            Defaults to DEFAULT_TIMEOUT.\n\n    Returns:\n        multiprocessing.Process: The launched multiprocessing.Process instance\n\n    Raises:\n        RuntimeError: If the server process terminates unexpectedly during startup or cache flush\n        TimeoutError: If server fails to become ready within reasonable time (300 seconds)\n        requests.RequestException: If health check requests fail repeatedly\n\n    Note:\n        This function will return immediately for non-master nodes (node_rank != 0),\n        but the process will still be started and returned.\n        This is for consistency; except for the process obtained by node_rank = 0,\n        other processes have no actual effect.\n    \"\"\"\n    p = multiprocessing.Process(target=launch_server, args=(server_args,))\n    if server_args.node_rank != 0 or not first_rank_in_node:\n        logger.info(f\"Server process started with PID {p.pid} for node rank {server_args.node_rank}\", flush=True)\n        return p\n\n    p.start()\n\n    base_url = server_args.url()\n    headers = {\n        \"Content-Type\": \"application/json; charset=utf-8\",\n        \"Authorization\": f\"Bearer {server_args.api_key}\",\n    }\n\n    # Health check with overall timeout\n    start_time = time.time()\n\n    with requests.Session() as session:\n        while time.time() - start_time < max_wait_time:\n            if not p.is_alive():\n                raise RuntimeError(\"Server process terminated unexpectedly during startup\")\n\n            try:\n                if server_args.is_embedding:\n                    response = session.get(f\"{base_url}/health\", headers=headers, timeout=timeout)\n                else:\n                    response = session.get(f\"{base_url}/health_generate\", headers=headers, timeout=timeout)\n                if response.status_code == 200:\n                    break\n            except requests.RequestException as e:\n                logger.debug(f\"Health check failed: {e}\")\n\n            time.sleep(2)\n        else:\n            p.terminate()\n            logger.error(f\"Server in {base_url} failed to become healthy within timeout period\")\n            raise TimeoutError(\"Server failed to become healthy within timeout period\")\n\n        # Ensure cache is ready\n        while time.time() - start_time < max_wait_time:\n            if not p.is_alive():\n                raise RuntimeError(\"Server process terminated unexpectedly during cache flush\")\n\n            try:\n                response = session.get(f\"{base_url}/flush_cache\", headers=headers, timeout=timeout)\n                if response.status_code == 200:\n                    break\n            except requests.RequestException as e:\n                
logger.debug(f\"Cache flush check failed: {e}\")\n\n            time.sleep(2)\n        else:\n            p.terminate()\n            raise TimeoutError(\"Server cache flush failed within timeout period\")\n\n    return p\n\n\nclass HttpServerAdapter(EngineBase):\n    \"\"\"HTTP-based adapter for SGLang engines.\n\n    This adapter allows interaction with SGLang engines through HTTP requests\n    instead of direct engine calls. It launches an HTTP server process and\n    provides methods to communicate with it via REST API calls.\n\n    You can use this class to launch a server from a HttpServerAdapter instance.\n    We recommend using this class only when you need to use http server.\n    Otherwise, you can use Engine directly.\n\n    Attributes:\n        router_ip (Optional[str]): IP address of the router for worker registration\n        router_port (Optional[int]): Port of the router for worker registration\n        server_args (ServerArgs): Server configuration arguments\n        node_rank (int): Rank of this node in distributed setup\n        process (multiprocessing.Process): The launched server process\n        timeout (float): HTTP request timeout in seconds\n        max_attempts (int): Maximum number of attempts for requests\n        retry_delay (float): Base delay between retries in seconds\n    \"\"\"\n\n    def __init__(\n        self,\n        router_ip: Optional[str] = None,\n        router_port: Optional[int] = None,\n        timeout: float = DEFAULT_TIMEOUT,\n        max_attempts: int = DEFAULT_MAX_ATTEMPTS,\n        retry_delay: float = DEFAULT_RETRY_DELAY,\n        first_rank_in_node: bool = False,\n        max_start_wait_time: float = DEFAULT_MAX_WAIT_TIME,\n        launch_server: bool = True,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"Initialize the HTTP server engine adapter.\n\n        Args:\n            router_ip (Optional[str], optional): IP address of router for worker registration.\n                Defaults to None.\n            router_port (Optional[int], optional): Port of router for worker registration.\n                Defaults to None.\n            timeout (float, optional): HTTP request timeout in seconds.\n                Defaults to DEFAULT_TIMEOUT.\n            max_attempts (int, optional): Maximum number of retry attempts for failed requests.\n                Defaults to DEFAULT_MAX_ATTEMPTS.\n            retry_delay (float, optional): Base delay between retries in seconds.\n                Defaults to DEFAULT_RETRY_DELAY.\n            launch_server (bool, optional): Whether to launch the server process.\n                Defaults to True.\n            **kwargs (Any): Additional arguments passed to ServerArgs\n\n        Note:\n            TODO: @ChangyiYang Enable SGLang router for this http server engine\n            If both router_ip and router_port are provided and this is the master node\n            (node_rank == 0), the adapter will automatically register with the router.\n        \"\"\"\n        self.router_ip: Optional[str] = router_ip\n        self.router_port: Optional[int] = router_port\n        self.timeout: float = timeout\n        self.max_attempts: int = max_attempts\n        self.retry_delay: float = retry_delay\n        self.server_args: ServerArgs = ServerArgs(**kwargs)\n        self.node_rank: int = self.server_args.node_rank\n        self.max_start_wait_time: float = max_start_wait_time\n\n        logger.info(\n            f\"Launch HttpServerAdapter at: {self.server_args.host}:{self.server_args.port} with 
{first_rank_in_node}\"\n        )\n        if launch_server:\n            self.process: multiprocessing.Process = launch_server_process(\n                self.server_args, self.timeout, self.max_start_wait_time, first_rank_in_node\n            )\n\n        if self.node_rank == 0 and self.router_ip and self.router_port:\n            self._register_with_router()\n\n    def _register_with_router(self) -> None:\n        \"\"\"Register worker with router with error handling.\n\n        This method attempts to register the current worker with a router service.\n        If registration fails, it logs an error but does not raise an exception,\n        allowing the server to continue operating without router integration.\n\n        Raises:\n            Does not raise exceptions - all errors are logged and handled gracefully.\n        \"\"\"\n        try:\n            url = f\"http://{self.router_ip}:{self.router_port}/add_worker\"\n            params = {\"url\": f\"http://{self.server_args.host}:{self.server_args.port}\"}\n            response = requests.post(url, params=params, timeout=self.timeout)\n            response.raise_for_status()\n            logger.info(\"Successfully registered with router\")\n        except Exception as e:\n            logger.error(f\"Failed to register with router: {e}\")\n            # Don't raise here - server can still work without router\n\n    def _make_request(\n        self,\n        endpoint: str,\n        payload: Optional[dict[str, Any]] = None,\n        method: str = \"POST\",\n        timeout: float = DEFAULT_TIMEOUT,\n        only_master: bool = True,\n    ) -> dict[str, Any]:\n        \"\"\"Make a HTTP request with retry logic and consistent error handling.\n\n        Args:\n            endpoint (str): The API endpoint to call (without leading slash)\n            payload (Optional[Dict[str, Any]], optional): The JSON payload to send.\n                Defaults to empty dict if None.\n            method (str, optional): HTTP method to use. 
Defaults to \"POST\".\n\n        Returns:\n            Dict[str, Any]: The JSON response from the server\n\n        Raises:\n            requests.HTTPError: If the HTTP request fails with a client/server error\n            RuntimeError: If all retry attempts are exhausted\n\n        Note:\n            - For non-master nodes (node_rank != 0), returns empty dict immediately\n            - Uses exponential backoff for retries\n            - Logs warnings for timeout and connection errors, errors for HTTP errors\n        \"\"\"\n        if only_master and self.node_rank != 0:\n            return {}\n\n        url = f\"http://{self.server_args.host}:{self.server_args.port}/{endpoint}\"\n\n        for attempt in range(self.max_attempts):\n            try:\n                if method.upper() == \"GET\":\n                    response = requests.get(url, timeout=self.timeout)\n                else:\n                    response = requests.post(url, json=payload or {}, timeout=self.timeout)\n\n                response.raise_for_status()\n                return _read_response(response)\n\n            except requests.exceptions.Timeout:\n                logger.warning(f\"Request to {endpoint} timed out (attempt {attempt + 1})\")\n            except requests.exceptions.ConnectionError:\n                logger.warning(f\"Connection error for {endpoint} (attempt {attempt + 1})\")\n            except requests.exceptions.HTTPError as e:\n                logger.error(f\"HTTP error for {endpoint}: {e}\")\n                raise\n            except Exception as e:\n                logger.error(f\"Unexpected error for {endpoint}: {e}\")\n                if attempt == self.max_attempts - 1:\n                    raise\n\n            if attempt < self.max_attempts - 1:\n                time.sleep(self.retry_delay * (2**attempt))\n\n        raise RuntimeError(f\"Failed to complete request to {endpoint} after {self.max_attempts} attempts\")\n\n    def update_weights_from_tensor(self, req: UpdateWeightsFromTensorReqInput) -> dict[str, Any]:\n        \"\"\"Update model weights from tensor data.\n\n        The HTTP server will only post meta data, and the real weights will be\n        copied directly from GPUs.\n\n        Args:\n            serialized_named_tensors (List[str]): List of serialized tensor data\n            load_format (Optional[str], optional): Format specification for loading weights.\n                Defaults to None.\n            flush_cache (bool, optional): Whether to flush cache after updating weights.\n                Defaults to False.\n\n        Returns:\n            Dict[str, Any]: Server response containing update status\n\n        Note:\n            The model should be on GPUs rather than CPU for this functionality to work properly.\n            If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.\n        \"\"\"\n        import base64\n\n        named_tensors = req.serialized_named_tensors\n        load_format = req.load_format\n        flush_cache = req.flush_cache\n\n        if named_tensors:\n            serialized_named_tensors = [\n                base64.b64encode(named_tensor).decode(\"utf-8\") for named_tensor in named_tensors\n            ]\n        else:\n            serialized_named_tensors = []\n\n        return self._make_request(\n            \"update_weights_from_tensor\",\n            {\n                \"serialized_named_tensors\": serialized_named_tensors,\n                \"load_format\": load_format,\n                \"flush_cache\": 
flush_cache,\n            },\n        )\n\n    def shutdown(self) -> None:\n        \"\"\"Shutdown the HTTP server and clean up resources.\n\n        This method performs the following cleanup operations:\n        1. Unregisters the worker from the router (if configured)\n        2. Terminates the server process tree\n\n        All operations are performed with error handling to ensure graceful shutdown\n        even if individual steps fail.\n\n        Note:\n            This method should be called when the adapter is no longer needed\n            to ensure proper cleanup of resources and processes.\n        \"\"\"\n        # Unregister from router\n        if self.router_ip and self.router_port:\n            try:\n                url = f\"http://{self.router_ip}:{self.router_port}/remove_worker\"\n                params = {\"url\": f\"http://{self.server_args.host}:{self.server_args.port}\"}\n                requests.post(url, params=params, timeout=5.0)  # Short timeout for shutdown\n                logger.info(\"Successfully unregistered from router\")\n            except Exception as e:\n                logger.warning(f\"Failed to unregister from router: {e}\")\n\n        # Kill server process\n        if hasattr(self, \"process\") and self.process is not None:\n            try:\n                kill_process_tree(self.process.pid)\n                logger.info(\"Server process terminated\")\n            except Exception as e:\n                logger.error(f\"Failed to terminate server process: {e}\")\n\n    def generate(\n        self,\n        prompt: Optional[str] = None,\n        sampling_params: Optional[dict[str, Any]] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        return_logprob: bool = False,\n        logprob_start_len: Optional[int] = None,\n        top_logprobs_num: Optional[int] = None,\n        token_ids_logprob: Optional[list[int]] = None,\n        lora_path: Optional[str] = None,\n        custom_logit_processor: Optional[Callable] = None,\n    ) -> dict[str, Any]:\n        \"\"\"Generate text using the SGLang server.\n\n        Args:\n            prompt (Optional[str], optional): Text prompt for generation. Defaults to None.\n            sampling_params (Optional[Dict[str, Any]], optional): Parameters controlling\n                text generation sampling. Defaults to None.\n            input_ids (Optional[List[int]], optional): Alternative to prompt, direct token IDs input.\n                Defaults to None.\n            image_data (Optional[Any], optional): Image data for multimodal generation.\n                Defaults to None.\n            return_logprob (bool, optional): Whether to return log probabilities.\n                Defaults to False.\n            logprob_start_len (Optional[int], optional): Starting length for log probability calculation.\n                Defaults to None.\n            top_logprobs_num (Optional[int], optional): Number of top log probabilities to return.\n                Defaults to None.\n            token_ids_logprob (Optional[List[int]], optional): Specific token IDs for\n                log probability calculation. Defaults to None.\n            lora_path (Optional[str], optional): Path to LoRA adapter weights. 
Defaults to None.\n            custom_logit_processor (Optional[Callable], optional): Custom logit processing function.\n                Defaults to None.\n\n        Returns:\n            Dict[str, Any]: Generated text and associated metadata from the server\n\n        Note:\n            Either prompt or input_ids should be provided, but not both.\n            The response format depends on the server configuration and parameters.\n        \"\"\"\n        payload = {\n            \"text\": prompt,\n            \"sampling_params\": sampling_params,\n            \"input_ids\": input_ids,\n            \"image_data\": image_data,\n            \"return_logprob\": return_logprob,\n            \"logprob_start_len\": logprob_start_len,\n            \"top_logprobs_num\": top_logprobs_num,\n            \"token_ids_logprob\": token_ids_logprob,\n            \"lora_path\": lora_path,\n            \"custom_logit_processor\": custom_logit_processor,\n        }\n        # Filter out None values\n        payload = {k: v for k, v in payload.items() if v is not None}\n\n        return self._make_request(\"generate\", payload, only_master=False)\n\n    def reward_score(\n        self,\n        prompt: Optional[str] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        lora_path: Optional[str] = None,\n    ) -> dict[str, Any]:\n        assert self.server_args.is_embedding, \"Score is only supported for embedding models\"\n        payload = {\n            \"text\": prompt,\n            \"input_ids\": input_ids,\n            \"image_data\": image_data,\n            \"lora_path\": lora_path,\n        }\n        # Filter out None values\n        payload = {k: v for k, v in payload.items() if v is not None}\n\n        return self._make_request(\"classify\", payload, only_master=False)\n\n    def flush_cache(self) -> dict[str, Any]:\n        \"\"\"Flush the cache of the server.\n\n        This method repeatedly attempts to flush the server cache until successful.\n        The flush operation will not return status 200 when there are pending requests.\n\n        Returns:\n            Dict[str, Any]: Server response indicating cache flush status.\n                For non-master nodes, returns empty dict.\n\n        Note:\n            Uses retry logic with limited attempts (max_attempts * 2) to avoid infinite loops.\n            Each retry includes a delay to allow pending requests to complete.\n        \"\"\"\n        if self.node_rank != 0:\n            return {}\n\n        # Use retry logic with limited attempts to avoid infinite loops\n        for attempt in range(self.max_attempts * 2):  # Allow more retries for cache flush\n            try:\n                response = requests.get(\n                    f\"http://{self.server_args.host}:{self.server_args.port}/flush_cache\", timeout=self.timeout\n                )\n                if response.status_code == 200:\n                    return _read_response(response)\n            except Exception as e:\n                logger.warning(f\"Error flushing cache (attempt {attempt + 1}): {e}\")\n\n            time.sleep(self.retry_delay)\n\n        logger.error(\"Failed to flush cache after maximum attempts\")\n        return {}\n\n    def release_memory_occupation(self, tags: Optional[list[str]] = None) -> dict[str, Any]:\n        \"\"\"Release GPU memory occupation temporarily.\n\n        Args:\n            tags (Optional[List[str]], optional): List of tags to specify which memory to release.\n                If 
None, releases all memory. Defaults to None. [\"weights\", \"kv_cache\"]\n\n        Returns:\n            Dict[str, Any]: Server response indicating memory release status\n        \"\"\"\n        return self._make_request(\"release_memory_occupation\", {\"tags\": tags})\n\n    def resume_memory_occupation(self, tags: Optional[list[str]] = None) -> dict[str, Any]:\n        \"\"\"Resume GPU memory occupation.\n\n        Args:\n            tags (Optional[List[str]], optional): List of tags to specify which memory to resume.\n                If None, resumes all memory. Defaults to None. [\"weights\", \"kv_cache\"]\n\n        Returns:\n            Dict[str, Any]: Server response indicating memory resume status\n        \"\"\"\n        return self._make_request(\"resume_memory_occupation\", {\"tags\": tags})\n\n    def abort_request(self, rid: str = \"\", abort_all: bool = False) -> dict[str, Any]:\n        \"\"\"Abort a request.\n\n        Args:\n            rid (str): The ID of the request to abort\n            abort_all (bool, optional): Whether to abort all requests. Defaults to False.\n\n        Returns:\n            Dict[str, Any]: Server response indicating abort status\n        \"\"\"\n        return self._make_request(\"abort_request\", {\"rid\": rid, \"abort_all\": abort_all})\n\n\nclass AsyncHttpServerAdapter(HttpServerAdapter):\n    \"\"\"Asynchronous HTTP-based adapter for SGLang engines.\n\n    This class inherits from HttpServerAdapter and adds async capabilities\n    for non-blocking HTTP requests to the SGLang server. It provides the same\n    functionality as the synchronous version but with async/await support.\n\n    The async adapter is useful when you need to make multiple concurrent requests\n    or integrate with async frameworks. 
It uses aiohttp for efficient async HTTP\n    communication and maintains connection pooling for better performance.\n\n    Attributes:\n        max_connections (int): Maximum number of connections in the connection pool\n    \"\"\"\n\n    def __init__(\n        self,\n        router_ip: Optional[str] = None,\n        router_port: Optional[int] = None,\n        timeout: float = DEFAULT_TIMEOUT,\n        max_attempts: int = DEFAULT_MAX_ATTEMPTS,\n        retry_delay: float = DEFAULT_RETRY_DELAY,\n        max_connections: int = DEFAULT_MAX_CONNECTIONS,\n        first_rank_in_node: bool = False,\n        launch_server: bool = True,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"Initialize the async HTTP server engine adapter.\n\n        Args:\n            router_ip (Optional[str], optional): IP address of router for worker registration.\n                Defaults to None.\n            router_port (Optional[int], optional): Port of router for worker registration.\n                Defaults to None.\n            timeout (float, optional): HTTP request timeout in seconds.\n                Defaults to DEFAULT_TIMEOUT.\n            max_attempts (int, optional): Maximum number of retry attempts for failed requests.\n                Defaults to DEFAULT_MAX_ATTEMPTS.\n            retry_delay (float, optional): Base delay between retries in seconds.\n                Defaults to DEFAULT_RETRY_DELAY.\n            max_connections (int, optional): Maximum number of connections in the connection pool.\n                Defaults to DEFAULT_MAX_CONNECTIONS.\n            launch_server (bool, optional): Whether to launch the server process.\n                Defaults to True.\n            **kwargs (Any): Additional arguments passed to ServerArgs\n        \"\"\"\n        super().__init__(\n            router_ip,\n            router_port,\n            timeout,\n            max_attempts,\n            retry_delay,\n            first_rank_in_node,\n            launch_server=launch_server,\n            **kwargs,\n        )\n        self.max_connections: int = max_connections\n\n    @asynccontextmanager\n    async def _get_session(self) -> aiohttp.ClientSession:\n        \"\"\"Context manager for safe session access with proper connection pooling.\n\n        Yields:\n            aiohttp.ClientSession: Session instance for making HTTP requests\n\n        Note:\n            This method creates a new session for each request to avoid resource competition\n            while still maintaining proper connection pooling through the shared connector.\n        \"\"\"\n        # Create a new session for each request to avoid resource competition\n        connector = aiohttp.TCPConnector(\n            limit=self.max_connections,\n            limit_per_host=self.max_connections // 4,\n            ttl_dns_cache=300,\n            use_dns_cache=True,\n        )\n        timeout = aiohttp.ClientTimeout(total=self.timeout)\n        session = aiohttp.ClientSession(connector=connector, timeout=timeout)\n\n        try:\n            yield session\n        finally:\n            # Always close the session to free up resources\n            if not session.closed:\n                await session.close()\n\n    async def _make_async_request(\n        self,\n        endpoint: str,\n        payload: Optional[dict[str, Any]] = None,\n        method: str = \"POST\",\n        timeout: float = DEFAULT_TIMEOUT,\n        only_master: bool = True,\n    ) -> dict[str, Any]:\n        \"\"\"Make an async HTTP request with retry logic and consistent 
error handling.\n\n        Args:\n            endpoint (str): The API endpoint to call (without leading slash)\n            payload (Optional[Dict[str, Any]], optional): The JSON payload to send.\n                Defaults to empty dict if None.\n            method (str, optional): HTTP method to use. Defaults to \"POST\".\n\n        Returns:\n            Dict[str, Any]: The JSON response from the server\n\n        Raises:\n            aiohttp.ClientResponseError: If the HTTP request fails with a client/server error\n            RuntimeError: If all retry attempts are exhausted\n\n        Note:\n            - For non-master nodes (node_rank != 0), returns empty dict immediately\n            - Uses exponential backoff for retries\n            - Logs warnings for timeout and connection errors, errors for HTTP errors\n        \"\"\"\n        if only_master and self.node_rank != 0:\n            return {}\n\n        url = f\"http://{self.server_args.host}:{self.server_args.port}/{endpoint}\"\n\n        for attempt in range(self.max_attempts):\n            try:\n                async with self._get_session() as session:\n                    if method.upper() == \"GET\":\n                        async with session.get(url, timeout=timeout) as response:\n                            response.raise_for_status()\n                            return await _read_async_response(response)\n                    else:\n                        async with session.post(url, json=payload or {}, timeout=timeout) as response:\n                            response.raise_for_status()\n                            return await _read_async_response(response)\n\n            except asyncio.TimeoutError:\n                logger.warning(f\"Async request to {endpoint} timed out (attempt {attempt + 1})\")\n            except aiohttp.ClientConnectorError:\n                logger.warning(f\"Connection error for {endpoint} (attempt {attempt + 1})\")\n            except aiohttp.ClientResponseError as e:\n                logger.error(f\"HTTP error for {endpoint}: {e}\")\n                raise\n            except Exception as e:\n                logger.error(f\"Unexpected error for {endpoint}: {e}\")\n                if attempt == self.max_attempts - 1:\n                    raise\n\n            if attempt < self.max_attempts - 1:\n                await asyncio.sleep(self.retry_delay * (2**attempt))\n\n        raise RuntimeError(f\"Failed to complete async request to {endpoint} after {self.max_attempts} attempts\")\n\n    async def release_memory_occupation(self, tags: Optional[list[str]] = None) -> dict[str, Any]:\n        \"\"\"Release GPU memory occupation temporarily (async version).\n\n        Args:\n            tags (Optional[List[str]], optional): List of tags to specify which memory to release.\n                If None, releases all memory. Defaults to None. 
[\"weights\", \"kv_cache\"]\n\n        Returns:\n            Dict[str, Any]: Server response indicating memory release status\n        \"\"\"\n        return await self._make_async_request(\"release_memory_occupation\", {\"tags\": tags})\n\n    async def resume_memory_occupation(self, tags: Optional[list[str]] = None) -> dict[str, Any]:\n        \"\"\"Resume GPU memory occupation (async version).\n\n        Similar to AsyncEngine, this method handles first-time weight reloading\n        by calling release_memory_occupation if needed.\n\n        Args:\n            tags (Optional[List[str]], optional): List of tags to specify which memory to resume.\n                If None, resumes all memory. Defaults to None. [\"weights\", \"kv_cache\"]\n\n        Returns:\n            Dict[str, Any]: Server response indicating memory resume status\n        \"\"\"\n        return await self._make_async_request(\"resume_memory_occupation\", {\"tags\": tags})\n\n    async def update_weights_from_tensor(\n        self,\n        req: UpdateWeightsFromTensorReqInput,\n    ) -> dict[str, Any]:\n        \"\"\"Update model weights from tensor data asynchronously.\n\n        Args:\n            serialized_named_tensors (List[str]): List of serialized tensor data\n            load_format (Optional[str], optional): Format specification for loading weights.\n                Defaults to None.\n            flush_cache (bool, optional): Whether to flush cache after updating weights.\n                Defaults to True.\n\n        Returns:\n            Dict[str, Any]: Server response containing update status\n        \"\"\"\n        import base64\n\n        named_tensors = req.serialized_named_tensors\n        load_format = req.load_format\n        flush_cache = req.flush_cache\n\n        serialized_named_tensors = [base64.b64encode(named_tensor).decode(\"utf-8\") for named_tensor in named_tensors]\n        return await self._make_async_request(\n            \"update_weights_from_tensor\",\n            {\n                \"serialized_named_tensors\": serialized_named_tensors,\n                \"load_format\": load_format,\n                \"flush_cache\": flush_cache,\n            },\n        )\n\n    async def flush_cache(self) -> dict[str, Any]:\n        \"\"\"Flush the cache of the server asynchronously.\n\n        Similar to the sync version, this method retries until the cache\n        is successfully flushed. 
It uses async sleep between retries.\n\n        Returns:\n            Dict[str, Any]: Server response indicating cache flush status.\n                For non-master nodes, returns empty dict.\n\n        Note:\n            Uses retry logic with limited attempts (max_attempts * 4) to avoid infinite loops.\n            Each retry includes an async delay to allow pending requests to complete.\n        \"\"\"\n        if self.node_rank != 0:\n            return {}\n\n        # Use retry logic with limited attempts to avoid infinite loops\n        for attempt in range(self.max_attempts * 4):  # Allow more retries for cache flush\n            try:\n                async with self._get_session() as session:\n                    url = f\"http://{self.server_args.host}:{self.server_args.port}/flush_cache\"\n                    async with session.get(url) as response:\n                        if response.status == 200:\n                            return await _read_async_response(response)\n            except Exception as e:\n                logger.warning(f\"Error flushing cache (attempt {attempt + 1}): {e}\")\n\n            await asyncio.sleep(self.retry_delay)\n\n        logger.error(\"Failed to flush cache after maximum attempts\")\n        return {}\n\n    async def generate(\n        self,\n        prompt: Optional[str] = None,\n        sampling_params: Optional[dict[str, Any]] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        return_logprob: bool = False,\n        logprob_start_len: Optional[int] = None,\n        top_logprobs_num: Optional[int] = None,\n        token_ids_logprob: Optional[list[int]] = None,\n        lora_path: Optional[str] = None,\n        custom_logit_processor: Optional[Callable] = None,\n    ) -> dict[str, Any]:\n        \"\"\"Generate text using the SGLang server asynchronously.\"\"\"\n        logger.info(\"generate() started\")\n\n        payload = {\n            \"text\": prompt,\n            \"sampling_params\": sampling_params,\n            \"input_ids\": input_ids,\n            \"image_data\": image_data,\n            \"return_logprob\": return_logprob,\n            \"logprob_start_len\": logprob_start_len,\n            \"top_logprobs_num\": top_logprobs_num,\n            \"token_ids_logprob\": token_ids_logprob,\n            \"lora_path\": lora_path,\n            \"custom_logit_processor\": custom_logit_processor,\n        }\n\n        # Filter out None values\n        payload = {k: v for k, v in payload.items() if v is not None}\n\n        # Send request\n        response = await self._make_async_request(\"generate\", payload, timeout=self.timeout, only_master=False)\n\n        return response\n\n    async def async_generate(\n        self,\n        prompt: Optional[str] = None,\n        sampling_params: Optional[dict[str, Any]] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        return_logprob: bool = False,\n        logprob_start_len: Optional[int] = None,\n        top_logprobs_num: Optional[int] = None,\n        token_ids_logprob: Optional[list[int]] = None,\n        lora_path: Optional[str] = None,\n        custom_logit_processor: Optional[Callable] = None,\n    ) -> dict[str, Any]:\n        \"\"\"Async generate method that mirrors AsyncEngine.async_generate interface.\n\n        This method provides compatibility with AsyncEngine's async_generate method\n        by forwarding the call to the generate method. 
It ensures API consistency\n        between direct engine usage and HTTP-based engine usage.\n\n        Args:\n            prompt (Optional[str], optional): Text prompt for generation. Defaults to None.\n            sampling_params (Optional[Dict[str, Any]], optional): Parameters controlling\n                text generation sampling. Defaults to None.\n            input_ids (Optional[List[int]], optional): Alternative to prompt, direct token IDs input.\n                Defaults to None.\n            image_data (Optional[Any], optional): Image data for multimodal generation.\n                Defaults to None.\n            return_logprob (bool, optional): Whether to return log probabilities.\n                Defaults to False.\n            logprob_start_len (Optional[int], optional): Starting length for log probability calculation.\n                Defaults to None.\n            top_logprobs_num (Optional[int], optional): Number of top log probabilities to return.\n                Defaults to None.\n            token_ids_logprob (Optional[List[int]], optional): Specific token IDs for\n                log probability calculation. Defaults to None.\n            lora_path (Optional[str], optional): Path to LoRA adapter weights. Defaults to None.\n            custom_logit_processor (Optional[Callable], optional): Custom logit processing function.\n                Defaults to None.\n\n        Returns:\n            Dict[str, Any]: Generated text and associated metadata from the server\n\n        Note:\n            This method is provided for API compatibility with AsyncEngine.\n            It forwards all calls to the generate method.\n        \"\"\"\n        return await self.generate(\n            prompt=prompt,\n            sampling_params=sampling_params,\n            input_ids=input_ids,\n            image_data=image_data,\n            return_logprob=return_logprob,\n            logprob_start_len=logprob_start_len,\n            top_logprobs_num=top_logprobs_num,\n            token_ids_logprob=token_ids_logprob,\n            lora_path=lora_path,\n            custom_logit_processor=custom_logit_processor,\n        )\n\n    async def reward_score(\n        self,\n        prompt: Optional[str] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        lora_path: Optional[str] = None,\n    ) -> dict[str, Any]:\n        logger.info(\"reward_score() started\")\n        payload = {\n            \"text\": prompt,\n            \"input_ids\": input_ids,\n            \"image_data\": image_data,\n            \"lora_path\": lora_path,\n        }\n        # Filter out None values\n        payload = {k: v for k, v in payload.items() if v is not None}\n\n        # Send request\n        response = await self._make_async_request(\"classify\", payload, timeout=self.timeout, only_master=False)\n\n        return response\n\n    async def async_reward_score(\n        self,\n        prompt: Optional[str] = None,\n        input_ids: Optional[list[int]] = None,\n        image_data: Optional[Any] = None,\n        lora_path: Optional[str] = None,\n    ) -> dict[str, Any]:\n        return await self.reward_score(\n            prompt=prompt,\n            input_ids=input_ids,\n            image_data=image_data,\n            lora_path=lora_path,\n        )\n\n    async def abort_request(self, rid: str = \"\", abort_all: bool = False) -> dict[str, Any]:\n        \"\"\"Abort a request asynchronously.\n\n        Args:\n            rid (str): The ID of the request to abort\n      
      abort_all (bool, optional): Whether to abort all requests. Defaults to False.\n\n        Returns:\n            Dict[str, Any]: Server response indicating abort status\n        \"\"\"\n        return await self._make_async_request(\"abort_request\", {\"rid\": rid, \"abort_all\": abort_all})\n"
  },
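Both `_make_request` and `_make_async_request` above follow the same retry policy: up to `max_attempts` tries, warnings for timeouts and connection errors, an immediate re-raise on HTTP errors, and an exponentially growing pause of `retry_delay * 2**attempt` between tries. Here is a dependency-free sketch of that loop, assuming a generic `send` callable in place of the actual `requests`/`aiohttp` calls; `request_with_backoff` and `flaky` are illustrative names, not part of this module.

```python
import logging
import time

logger = logging.getLogger(__name__)


def request_with_backoff(send, max_attempts: int = 3, retry_delay: float = 2.0):
    """Retry send() with exponential backoff, mirroring _make_request (simplified)."""
    for attempt in range(max_attempts):
        try:
            return send()
        except TimeoutError:
            logger.warning("request timed out (attempt %d)", attempt + 1)
        except ConnectionError:
            logger.warning("connection error (attempt %d)", attempt + 1)
        if attempt < max_attempts - 1:
            # Delays of retry_delay, 2 * retry_delay, 4 * retry_delay, ...
            time.sleep(retry_delay * (2**attempt))
    raise RuntimeError(f"failed after {max_attempts} attempts")


# Toy usage: a flaky callable that succeeds on its third invocation.
calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return {"status": "ok"}


print(request_with_backoff(flaky, retry_delay=0.01))  # -> {'status': 'ok'}
```

Note that, as in the adapters above, HTTP-level errors would be re-raised immediately rather than retried, since a 4xx/5xx response usually signals a request problem that a retry will not fix.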
  {
    "path": "verl_distillation/verl/workers/rollout/sglang_rollout/sglang_rollout.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport multiprocessing as mp\nimport os\nfrom copy import deepcopy\nfrom json import JSONDecodeError\nfrom typing import Any, Generator, Optional\nfrom uuid import uuid4\n\nimport numpy as np\nimport ray\nimport sglang.srt.entrypoints.engine\nimport torch\nimport torch.distributed as dist\nfrom sglang.srt.managers.io_struct import (\n    ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.sampling.sampling_params import SamplingParams\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import (\n    assert_pkg_version,\n    get_open_port,\n    is_cuda,\n    set_prometheus_multiproc_dir,\n    set_ulimit,\n)\nfrom sglang.srt.weight_sync.utils import update_weights as sgl_update_weights\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.nn.utils.rnn import pad_sequence\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom verl import DataProto\nfrom verl.interactions.base import BaseInteraction\nfrom verl.interactions.utils.interaction_registry import initialize_interactions_from_config\nfrom verl.third_party.sglang import parallel_state as sglang_ps\nfrom verl.tools.base_tool import BaseTool\nfrom verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\nfrom verl.utils.device import get_visible_devices_keyword\nfrom verl.utils.net_utils import is_ipv6\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.torch_functional import get_response_mask, pad_sequence_to_length\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.base import BaseRollout\nfrom verl.workers.rollout.schemas import (\n    AsyncRolloutRequest,\n    AsyncRolloutRequestStateEnum,\n    FinishReasonTypeEnum,\n)\nfrom verl.workers.rollout.sglang_rollout.http_server_engine import AsyncHttpServerAdapter\nfrom verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj, get_named_tensor_buckets\nfrom verl.workers.rollout.utils import is_valid_ipv6_address\n\ntry:\n    from sglang.srt.function_call.function_call_parser import FunctionCallParser\nexcept ImportError:\n    from sglang.srt.function_call_parser import FunctionCallParser\n\ntry:\n    from sglang.srt.entrypoints.openai.protocol import Tool\nexcept ImportError:\n    from sglang.srt.openai_api.protocol import Tool\n\n# compatible with sglang 0.5.3\ntry:\n    from sglang.srt.utils import get_ip\nexcept ImportError:\n    from sglang.srt.utils import get_local_ip_auto as get_ip\n\nlogger = 
logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723\ndef _set_envs_and_config(server_args: ServerArgs):\n    # Set global environments\n    os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n    os.environ[\"NCCL_CUMEM_ENABLE\"] = \"0\"\n    os.environ[\"NCCL_NVLS_ENABLE\"] = str(int(server_args.enable_nccl_nvls))\n    os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n    os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"4\"\n    os.environ[\"CUDA_MODULE_LOADING\"] = \"AUTO\"\n\n    # Set prometheus env vars\n    if server_args.enable_metrics:\n        set_prometheus_multiproc_dir()\n\n    # Set ulimit\n    set_ulimit()\n\n    # Check flashinfer version\n    if server_args.attention_backend == \"flashinfer\":\n        assert_pkg_version(\n            \"flashinfer_python\",\n            \"0.2.5\",\n            \"Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.\",\n        )\n    if is_cuda():\n        assert_pkg_version(\n            \"sgl-kernel\",\n            \"0.1.1\",\n            \"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`\",\n        )\n\n    # Set mp start method\n    mp.set_start_method(\"spawn\", force=True)\n\n\nsglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config\n\n\n# because chatCompletion is an async method, it makes the whole ray actor be an async actor\n# which can not call loop.run_until_complete. So we need to make the engine to be an async class\nclass AsyncEngine(sglang.srt.entrypoints.engine.Engine):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    async def release_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Release GPU occupation temporarily.\"\"\"\n        if tags is None:\n            obj = ReleaseMemoryOccupationReqInput()\n        else:\n            obj = ReleaseMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.release_memory_occupation(obj, None)\n\n    async def resume_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Resume GPU occupation.\"\"\"\n        if tags is None:\n            obj = ResumeMemoryOccupationReqInput()\n        else:\n            obj = ResumeMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.resume_memory_occupation(obj, None)\n\n    async def update_weights_from_tensor(self, update_weights_request: UpdateWeightsFromTensorReqInput):\n        return await self.tokenizer_manager.update_weights_from_tensor(update_weights_request, None)\n\n    async def flush_cache(self):\n        return await self.tokenizer_manager.flush_cache()\n\n    async def abort_request(self, rid: str = \"\", abort_all: bool = False):\n        \"\"\"Abort a specific request or all requests.\n\n        Args:\n            rid: The request ID to abort. If empty and abort_all is False, no action is taken.\n            abort_all: If True, abort all running requests regardless of rid.\n        \"\"\"\n        return self.tokenizer_manager.abort_request(rid=rid, abort_all=abort_all)\n\n\n# NOTE(sgm): add for verl. 
We can optimize it by making\n#  the dataloader yield List[int] without padding.\ndef _pre_process_inputs(\n    pad_token_id,\n    prompt_token_ids: torch.Tensor,\n) -> torch.Tensor:\n    # remove the left padding in the prompt token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    return prompt_token_ids[non_pad_index:]\n\n\ndef _extract_logprob_from_output(output):\n    \"\"\"\n    extract log_prob from single sglang inference output\n    \"\"\"\n\n    def _map_each_response(resp):\n        input_token_logprobs = resp[\"meta_info\"][\"input_token_logprobs\"]\n        log_probs, output_token_ids = zip(\n            *[(log_prob, token_ids) for log_prob, token_ids, _ in input_token_logprobs[1:]], strict=False\n        )\n        return torch.tensor(output_token_ids), torch.tensor(log_probs)\n\n    output_token_ids, log_probs = _map_each_response(output)\n    return output_token_ids, log_probs\n\n\n# NOTE(linjunrong): adhoc\ndef _post_process_outputs(processing_class, output):\n    try:\n        # This is when processing_class is a processor\n        tokenizer = processing_class.tokenizer\n    except AttributeError:\n        try:\n            # This is when processing_class is a tokenizer\n            tokenizer = processing_class\n        except AttributeError as e:\n            raise ValueError(f\"Cannot get tokenizer from processing_class {processing_class}\") from e\n\n    def _map_each_response(resp):\n        output_token_logprobs = resp[\"meta_info\"][\"output_token_logprobs\"]\n        log_probs, output_token_ids = zip(\n            *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True\n        )\n        return torch.tensor(output_token_ids), torch.tensor(log_probs)\n\n    out_map = map(lambda x: _map_each_response(x), output)\n    batched_output_token_ids = []\n    batched_logprobs = []\n    for output_token_ids, log_probs in out_map:\n        batched_output_token_ids.append(output_token_ids)\n        batched_logprobs.append(log_probs)\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)\n    if len(batched_logprobs) > 0:\n        batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id)\n    return batched_output_token_ids, batched_logprobs\n\n\ndef get_tool_call_parser_type(\n    processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n) -> str:\n    items = FunctionCallParser.ToolCallParserEnum.items()\n    if \"gpt-oss\" in getattr(processing_class, \"name_or_path\", \"\").lower():\n        logger.debug(f\"gpt-oss model detected from name_or_path: {processing_class.name_or_path}\")\n        logger.debug(\"Using 'gpt-oss' tool call parser.\")\n        return \"gpt-oss\"\n    for parser_type, parser_cls in items:\n        parser = parser_cls()\n        try:\n            # This is when processing_class is a tokenizer\n            tokenizer_vocab = processing_class.get_vocab()\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                tokenizer_vocab = processing_class.tokenizer.get_vocab()\n            except AttributeError as e:\n                raise ValueError(f\"Cannot get vocab from processing_class {processing_class}\") from e\n\n        if parser.bot_token.strip() in tokenizer_vocab and 
(\n            parser.eot_token == \"\" or parser.eot_token.strip() in tokenizer_vocab\n        ):\n            return parser_type\n    else:\n        raise ValueError(f\"No tool call parser found for processing_class {processing_class}\")\n\n\nclass SGLangRollout(BaseRollout):\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        super().__init__(config, model_config, device_mesh)\n\n        actor_module = model_config.local_path\n        processing_class = model_config.get_processor()\n        model_hf_config = model_config.hf_config\n        trust_remote_code = model_config.trust_remote_code\n        port = None\n        kwargs = {}\n\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n\n        (\n            self._tool_schemas,\n            self._tool_map,\n            self._tool_call_parser_type,\n            self._sgl_tools,\n            self._function_call_parser,\n        ) = self._initialize_tools(config, processing_class)\n        self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config)\n\n        # If `free_cache_engine` is turned on, the SGLang engine's KV cache\n        # will be freed after each `generate_sequences` call.\n        logger.info(\n            f\"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: \"\n            f\"{self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: \"\n            f\"{self._function_call_parser}\"\n        )\n\n        self._init_distributed_env(device_mesh_cpu=None, **kwargs)\n\n        self._verify_config(model_hf_config=model_hf_config)\n        # initialize the inference engine\n        self._init_inference_engine(trust_remote_code, actor_module, port)\n\n        self._init_sampling_params(**kwargs)\n\n        self.processing_class = processing_class\n        try:\n            # This is when processing_class is a tokenizer\n            self.pad_token_id = self.processing_class.pad_token_id\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                self.pad_token_id = self.processing_class.tokenizer.pad_token_id\n            except AttributeError as e:\n                raise ValueError(f\"Cannot get pad_token_id from processing_class {self.processing_class}\") from e\n\n    def _init_distributed_env(self, device_mesh_cpu, **kwargs):\n        self._device_mesh_cpu = device_mesh_cpu\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n        self.tensor_parallel_size = self.config.get(\"tensor_model_parallel_size\", 1)\n        assert self.tensor_parallel_size <= dist.get_world_size(), (\n            \"tensor parallel size should be less than or equal to the world size\"\n        )\n        self.train_tp = kwargs.get(\"train_tp\", None)\n        if self.train_tp is not None:\n            # deployed with megatron\n            os.environ[\"CUDA_TIMER_STREAM_KAFKA_ENABLE\"] = \"0\"\n            os.environ[\"MEGATRON_IMPORT_TIMERS\"] = \"0\"\n            train_tp = kwargs.get(\"train_tp\", None)\n            num_tp_per_train_tp = train_tp // self.tensor_parallel_size\n            sglang_ps.initialize_parallel_state(\n                tensor_model_parallel_size=self.tensor_parallel_size,\n                num_tp_per_train_tp=num_tp_per_train_tp,\n            )\n\n        tp_size = self.tensor_parallel_size\n        world_size = 
int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n\n        # init device mesh\n        if self._device_mesh_cpu is None:\n            device_mesh_kwargs = dict(\n                mesh_shape=(world_size // tp_size, tp_size, 1),\n                mesh_dim_names=[\"dp\", \"tp\", \"pp\"],\n            )\n\n            self._device_mesh_cpu = init_device_mesh(\"cpu\", **device_mesh_kwargs)\n\n        self._rank = self._device_mesh_cpu.get_rank()\n        self._tp_rank = self._device_mesh_cpu[\"tp\"].get_local_rank()\n        self._tp_size = self._device_mesh_cpu[\"tp\"].size()\n        if self._rank == 0:\n            logger.info(f\"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}\")\n        # get tp_rank of this process in this tp group\n        visible_devices = [None] * self._device_mesh_cpu.size(1)\n        devices_keyword = get_visible_devices_keyword()\n        torch.distributed.all_gather_object(\n            visible_devices, os.environ[devices_keyword], self._device_mesh_cpu.get_group(\"tp\")\n        )\n        self.visible_devices_set = set(\",\".join(visible_devices).split(\",\"))\n        os.environ[devices_keyword] = \",\".join(sorted(list(self.visible_devices_set), key=int))\n\n    def _verify_config(self, model_hf_config):\n        if not self.config.get(\"max_model_len\", None):\n            self.config.max_model_len = self.config.prompt_length + self.config.response_length\n        assert (\n            self.config.max_model_len >= self.config.prompt_length + self.config.response_length\n        ), f\"\"\"max_model_len should be greater than total sequence length (prompt_length + response_length): \n            {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}\"\"\"\n        max_position_embeddings = None\n        if hasattr(model_hf_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"llm_config\") and hasattr(model_hf_config.llm_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"text_config\") and hasattr(\n            model_hf_config.text_config, \"max_position_embeddings\"\n        ):\n            max_position_embeddings = model_hf_config.text_config.max_position_embeddings\n        if max_position_embeddings is None:\n            raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, (\n                \"model context length should be greater than total sequence length\"\n            )\n        else:\n            # handle type where there's a length extend factor\n            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support\n            # for using yarn as an example\n            rope_scaling_factor = rope_scaling_config.get(\"factor\", 1.0)\n\n            assert (\n                model_hf_config.max_position_embeddings * rope_scaling_factor\n                >= self.config.prompt_length + self.config.response_length\n            ), (\n                f\"model context length should be greater than total sequence length, \"\n                f\"got rope_scaling_factor={rope_scaling_factor} and \"\n                
f\"max_position_embeddings={model_hf_config.max_position_embeddings}\"\n            )\n\n        # currently max_assistant_turns stand for max number of tool calls\n        if self.config.multi_turn.max_assistant_turns is None:\n            self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3\n        if self.config.multi_turn.max_user_turns is None:\n            self.config.multi_turn.max_user_turns = self.config.max_model_len // 3\n\n    def _init_inference_engine(self, trust_remote_code, actor_module, port):\n        # initialize the inference engine\n        nnodes = -(-self._tp_size // len(self.visible_devices_set))\n        if nnodes > 1:\n            ip = get_ip()\n            port = get_open_port() if port is None else port\n            [ip, port] = broadcast_pyobj(\n                [ip, port],\n                rank=self._rank,\n                dist_group=self._device_mesh_cpu.get_group(\"tp\"),\n                src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n                force_cpu_device=False,\n            )\n            dist_init_addr = f\"[{ip}]:{port}\" if is_ipv6(ip) else f\"{ip}:{port}\"\n        else:\n            dist_init_addr = None\n\n        load_format = \"dummy\" if self.config.load_format.startswith(\"dummy\") else self.config.load_format\n        tp_size_per_node = self._tp_size // nnodes\n        node_rank = self._tp_rank // tp_size_per_node\n        first_rank_in_node = self._tp_rank % tp_size_per_node == 0\n        engine_kwargs = self.config.get(\"engine_kwargs\", {}).get(\"sglang\", {}) or {}\n        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}\n\n        # attention backend will be changed to fa3 if not specified\n        attention_backend = engine_kwargs.pop(\"attention_backend\", None)\n        max_running_requests = self.config.get(\"max_num_seqs\", None)\n\n        try:\n            is_server_mode = self.config.sglang_rollout_mode == \"server\"\n        except Exception:\n            is_server_mode = False\n        effective_first = first_rank_in_node or is_server_mode\n\n        if self.config.mode == \"async\" and not self.config.skip_tokenizer_init:\n            raise ValueError(\"async mode requires skip_tokenizer_init to be True\")\n        backend = attention_backend if attention_backend is not None else \"fa3\"\n        sglang_port = int(os.getenv(\"SGLANG_PORT\", \"30000\")) + (dist.get_rank() * 2)\n        if effective_first:\n            os.environ[\"SGLANG_BLOCK_NONZERO_RANK_CHILDREN\"] = \"0\"\n            args = {\n                \"model_path\": actor_module,\n                \"dtype\": self.config.dtype,\n                \"mem_fraction_static\": self.config.gpu_memory_utilization,\n                \"enable_memory_saver\": True,\n                \"base_gpu_id\": 0,\n                \"gpu_id_step\": 1,\n                \"tp_size\": self._tp_size,\n                \"node_rank\": node_rank,\n                \"load_format\": load_format,\n                \"dist_init_addr\": dist_init_addr,\n                \"nnodes\": nnodes,\n                \"trust_remote_code\": trust_remote_code,\n                \"max_running_requests\": max_running_requests,\n                # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new\n                # when random.seed is being set during training\n                \"port\": sglang_port,\n                \"nccl_port\": sglang_port + 1,\n                # NOTE(Chenyang): if you want to debug the SGLang engine 
output,\n                # enable the following parameters; otherwise they will make the engine run too slow\n                \"log_level\": \"info\",\n                # \"log_level\": \"error\",\n                # log_requests=True,\n                # log_requests_level=2,\n                # NOTE(Chenyang): turn on max_running_requests to set the max concurrent running requests\n                # max_running_requests=1,\n                \"mm_attention_backend\": backend,\n                \"attention_backend\": backend,\n                # In async mode, we want token-in-token-out.\n                \"skip_tokenizer_init\": self.config.skip_tokenizer_init,\n                \"dist_timeout\": 1800,\n            }\n\n            if is_server_mode:\n                # add server-specific args\n                args[\"first_rank_in_node\"] = first_rank_in_node\n                args[\"timeout\"] = self.config.server[\"timeout\"]\n                args[\"max_attempts\"] = self.config.server[\"max_attempts\"]\n                args[\"retry_delay\"] = self.config.server[\"retry_delay\"]\n                args[\"max_connections\"] = self.config.server[\"max_connections\"]\n                args[\"max_start_wait_time\"] = self.config.server[\"max_start_wait_time\"]\n                self._engine = AsyncHttpServerAdapter(**args)\n            else:\n                self._engine = AsyncEngine(**args)\n        else:\n            self._engine = None\n\n        self.sharding_manager = None\n        self.is_sleep = True\n\n    def _init_sampling_params(self, **kwargs):\n        kwargs = dict(\n            n=1,\n            max_new_tokens=self.config.response_length,\n            presence_penalty=0.0,\n            frequency_penalty=0.0,\n            repetition_penalty=self.config.get(\"repetition_penalty\", 1.0),\n        )\n        # support adding any sampling params from the config file\n        for k in self.config.keys():\n            if hasattr(SamplingParams(), str(k)) or \"stop\" in str(k):\n                kwargs[k] = self.config.get(k)\n        kwargs[\"n\"] = 1  # already repeated in ray_trainer\n        self.sampling_params = kwargs\n\n    def _initialize_tools(self, config, processing_class):\n        \"\"\"Initialize tools from configuration.\n\n        Args:\n            config: Configuration object containing tool-related settings,\n                    specifically `config.multi_turn.tool_config_path`.\n            processing_class: The tokenizer or processor instance used for parsing tool calls from\n                       the model's generated text.\n\n        Returns:\n            tuple: A tuple containing:\n                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas\n                  defining each tool's capabilities.\n                - tool_map (dict[str, BaseTool]): A dictionary mapping tool\n                  names to their executable `BaseTool` objects.\n                - tool_call_parser_type (str): The identifier for the specific\n                  parser type (e.g., 'json_mode', 'tool_code') used to extract\n                  tool calls.\n                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool\n                  definitions optimized for SGLang's internal engine.\n                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):\n                  The active parser instance responsible for extracting\n                  structured tool calls from model outputs.\n        \"\"\"\n        if config.multi_turn.tool_config_path is 
None:\n            return [], {}, None, [], None\n\n        tools_config_file = config.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tools_config_file)\n\n        logger.info(f\"Initialize tools from configuration: tool_list: {tool_list}\")\n        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]\n        tool_map = {tool.name: tool for tool in tool_list}\n        tool_call_parser_type = get_tool_call_parser_type(processing_class)\n        sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]\n        function_call_parser = FunctionCallParser(\n            sgl_tools,\n            tool_call_parser_type,\n        )\n\n        return (\n            tool_schemas,\n            tool_map,\n            tool_call_parser_type,\n            sgl_tools,\n            function_call_parser,\n        )\n\n    def _initialize_interactions(self, config):\n        \"\"\"Initialize interactions from configuration.\n\n        Returns:\n            dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.\n        \"\"\"\n        if config.multi_turn.interaction_config_path is None:\n            return {}\n\n        interaction_config_file = config.multi_turn.interaction_config_path\n        interaction_map = initialize_interactions_from_config(interaction_config_file)\n\n        logger.info(f\"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}\")\n        return interaction_map\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generate sequences for a batch of prompts.\n\n        Args:\n            prompts (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids, including response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        if self.config.multi_turn.enable:\n            return self._req_level_generate_sequences(prompts, **kwargs)\n        return self._batch_level_generate_sequences(prompts, **kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generates single-turn sequences for a batch of prompts.\n        For single-turn generation, all prompts are processed in one request.\n        `_batch_level_generate_sequences` involves:\n        1.  Extracting and pre-processing prompt token IDs from the input\n            `prompts`. 
This includes handling padding and preparing raw\n            token ID lists.\n        2.  Preparing inputs for the SGLang engine, including multi-modal\n            data if present.\n        3.  Invoking the SGLang engine (`self._engine.async_generate`,\n            an async coroutine) with the batch of processed inputs and\n            specified sampling parameters on the master TP rank.\n        4.  Broadcasting the results from the master TP rank to all\n            other TP ranks.\n        5.  Post-processing the engine's output to format the generated\n            token IDs and (if applicable) log probabilities.\n        6.  Constructing the final sequences by concatenating original\n            prompts with the generated responses.\n        7.  Updating attention masks and position IDs to reflect the full\n            concatenated sequences.\n        8.  If `self.config.free_cache_engine` is true, the SGLang engine's\n            KV cache is flushed after generation on the master TP rank.\n        Args:\n            prompts: A `DataProto` object containing the batch of\n              input prompts, including tensor data (like `input_ids`,\n              `attention_mask`) and meta-information (like `eos_token_id`,\n              `do_sample`).\n            **kwargs: Additional keyword arguments that can override the\n              default sampling parameters (e.g., `temperature`, `top_p`,\n              `max_new_tokens`). These are temporarily applied using\n              `update_sampling_params`.\n        Returns:\n            DataProto: A `DataProto` object containing the batch of\n              generated sequences. This includes tensors for `prompts`\n              (original input IDs), `responses` (generated token IDs),\n              `input_ids` (concatenated prompt and response),\n              `attention_mask`, and `position_ids` for the full\n              sequences.\n        Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer.\n        Thus we do not need to repeat the prompts here and set the sampling parameter n to 1.\n        \"\"\"\n        # input ids: (bs, prompt_length), left-padded\n        idx = prompts.batch[\"input_ids\"]\n        # attention_mask: (bs, seq_length), left-padded\n        attention_mask = prompts.batch[\"attention_mask\"]\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to generate attention mask for the\n        # response based on EOS token position\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n\n        # Extract non-tensor data\n        non_tensor_batch = prompts.non_tensor_batch\n        if \"raw_prompt_ids\" not in non_tensor_batch:\n            non_tensor_batch[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)],\n                dtype=object,\n            )\n\n        if \"multi_modal_data\" in non_tensor_batch:\n            sglang_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                non_tensor_batch.pop(\"raw_prompt_ids\"),\n                non_tensor_batch.pop(\"multi_modal_data\"),\n                strict=True,\n            ):\n                sglang_inputs.append(\n                    {\n                        \"prompt_token_ids\": raw_prompt_ids,\n                        \"multi_modal_data\": multi_modal_data,\n                        \"image_data\": (\n                            
multi_modal_data.get(\"image\", None) if isinstance(multi_modal_data, dict) else None\n                        ),\n                    }\n                )\n        else:\n            sglang_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop(\"raw_prompt_ids\")\n            ]\n\n        for input_data in sglang_inputs:\n            # Ensure token IDs are lists or numpy arrays\n            if not isinstance(input_data[\"prompt_token_ids\"], list | np.ndarray):\n                raise TypeError(\n                    f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\"\n                )\n\n            input_data[\"prompt_token_ids\"] = list(input_data[\"prompt_token_ids\"])\n\n        # Extract token IDs and image data for SGLang Engine\n        idx_list = [input_data[\"prompt_token_ids\"] for input_data in sglang_inputs]\n        image_list = [input_data.get(\"image_data\", None) for input_data in sglang_inputs]\n\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n\n        if self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            output = loop.run_until_complete(\n                self._engine.async_generate(\n                    prompt=None,  # because we have already convert it to prompt token id\n                    sampling_params=request_sampling_params,\n                    return_logprob=True,\n                    input_ids=idx_list,\n                    image_data=image_list,\n                )\n            )\n        else:\n            output = None\n\n        # Most naive implementation, can extract tensor and send via gloo if too slow\n        dist.barrier()\n\n        # Because the logic below requires GPU memory proportional to the batch size, so free cache first to avoid OOM\n        if self._engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self._engine.flush_cache())\n\n        [output] = broadcast_pyobj(\n            data=[output],\n            rank=self._rank,\n            
dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        out = _post_process_outputs(self.processing_class, output)\n\n        response = out[0].to(idx.device)\n        rollout_log_probs = None\n        if self.config.calculate_log_probs:\n            rollout_log_probs = out[1].to(idx.device)\n\n        if response.shape[1] < self.config.response_length:\n            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)\n            if self.config.calculate_log_probs:\n                rollout_log_probs = pad_sequence_to_length(\n                    rollout_log_probs, self.config.response_length, self.pad_token_id\n                )\n\n        seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)\n        if position_ids.dim() == 3:  # qwen2vl mrope (batch size, 4, seq len)\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        # all the tp ranks should contain the same data here. 
data in all ranks are valid\n        batch = TensorDict(\n            {\n                \"prompts\": idx,\n                \"responses\": response,\n                \"input_ids\": seq,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n        if self.config.calculate_log_probs:\n            # we will recompute old log prob with actor\n            batch[\"rollout_log_probs\"] = rollout_log_probs\n\n        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n    async def _async_rollout_a_request(\n        self,\n        req: AsyncRolloutRequest,\n        do_sample: bool = True,\n        is_validate: bool = False,\n        **kwargs,\n    ) -> AsyncRolloutRequest:\n        assert self._tp_rank == 0, \"only the master process can call this function\"\n        _req = deepcopy(req)\n        finish_reason_type = None\n        output = None\n\n        current_turns = 0\n        user_turns = 0\n        user_turn_rewards = []\n\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n\n        while current_turns < self.config.multi_turn.max_assistant_turns:\n            if _req.state == AsyncRolloutRequestStateEnum.PENDING:\n                await self._handle_pending_state(_req)\n                _req.state = AsyncRolloutRequestStateEnum.RUNNING\n            elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:\n                if _req.messages[-1].tool_calls is not None:\n                    parsed_tool_calls = _req.messages[-1].tool_calls\n                    if self.config.skip_tokenizer_init:\n                        _req.messages[-1].tool_calls = None\n                    tool_call_results = await asyncio.gather(\n                        *[\n                            self._tool_map[tool_call.function.name].execute(\n                                _req.request_id,\n                                tool_call.function.arguments,\n                                **_req.tools_kwargs.get(tool_call.function.name, {}).get(\"execute_kwargs\", {}),\n                            )\n                            for tool_call in parsed_tool_calls\n                        ]\n                    )\n            
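        # Each tool result is a (response, reward, metrics) triple: responses extend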
\n                    # the dialogue as tool messages, and metrics are recorded per tool below.\n                    _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results])\n                    for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results, strict=True):\n                        _req.update_metrics(metrics, tool_call.function.name)\n                    if _req.input_ids.size(-1) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    _req.state = AsyncRolloutRequestStateEnum.RUNNING\n                else:\n                    raise ValueError(f\"Unexpected tool calling last message state: {_req.messages[-1]}\")\n            elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:\n                # Only continue the conversation if the prompt length is not greater than max_model_len - 1,\n                # since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra\n                # token accounts for the EOS token).\n                prompt_length = len(_req.get_generation_prompt_ids(self.processing_class))\n\n                if prompt_length + 1 >= self.config.max_model_len:\n                    finish_reason_type = FinishReasonTypeEnum.LENGTH\n                    break\n\n                # Video support is not implemented yet\n                image_data = (\n                    _req.multi_modal_data[\"image\"]\n                    if _req.multi_modal_data and \"image\" in _req.multi_modal_data\n                    else None\n                )\n                video_data = (\n                    _req.multi_modal_data[\"video\"]\n                    if _req.multi_modal_data and \"video\" in _req.multi_modal_data\n                    else None\n                )\n                if video_data:\n                    logger.warning(\n                        \"video support is not implemented yet, current length of video data is %d\", len(video_data)\n                    )\n\n                output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data)\n                if self.config.skip_tokenizer_init:\n                    content_ids = output[\"output_ids\"]\n                    content = self.processing_class.decode(content_ids, skip_special_tokens=True)\n                    content_ids = torch.tensor(\n                        content_ids, dtype=_req.input_ids.dtype, device=_req.input_ids.device\n                    ).unsqueeze(0)\n                else:\n                    content_ids = None\n                    content = output[\"text\"]\n\n                finish_reason_type = FinishReasonTypeEnum.from_str(output[\"meta_info\"][\"finish_reason\"][\"type\"])\n                current_turns += 1\n                if finish_reason_type == FinishReasonTypeEnum.LENGTH:\n                    _req.add_assistant_message(self.processing_class, content=content, content_ids=content_ids)\n                    break\n                else:\n                    if self._function_call_parser and self._function_call_parser.has_tool_call(content):\n                        finish_reason_type = FinishReasonTypeEnum.TOOL_CALL\n                        _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING\n                        try:\n                            normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)\n                        except JSONDecodeError:\n                            normed_content = content
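\n                            # Tool-call JSON failed to decode; fall back to treating the output as plain text.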
\n                            tool_calls = []\n                        except AttributeError:\n                            normed_content = content\n                            tool_calls = []\n                        parsed_tool_calls = []\n                        for tool_call in tool_calls:\n                            function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(\n                                OpenAIFunctionParsedSchema(\n                                    name=tool_call.name,\n                                    arguments=tool_call.parameters,\n                                )\n                            )\n                            # Drop the tool call if its arguments fail to decode\n                            if has_decode_error:\n                                continue\n                            parsed_tool_calls.append(\n                                OpenAIFunctionToolCall(\n                                    id=str(tool_call.tool_index),\n                                    function=function,\n                                )\n                            )\n                        if len(parsed_tool_calls) > 0:\n                            _req.add_assistant_message(\n                                # since the content is updated, we pass content rather than content_ids\n                                self.processing_class,\n                                content=normed_content,\n                                tool_calls=parsed_tool_calls,\n                            )\n                        else:\n                            _req.add_assistant_message(self.processing_class, content=content, content_ids=content_ids)\n                            finish_reason_type = FinishReasonTypeEnum.STOP\n                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                            break\n                    else:\n                        _req.add_assistant_message(\n                            self.processing_class,\n                            content=content,\n                            content_ids=content_ids,\n                        )\n                        if (\n                            _req.interaction_kwargs\n                            and self.interaction_map\n                            and user_turns < self.config.multi_turn.max_user_turns\n                            and current_turns < self.config.multi_turn.max_assistant_turns\n                        ):\n                            _req.state = AsyncRolloutRequestStateEnum.INTERACTING\n                        else:\n                            # Add ending condition\n                            finish_reason_type = FinishReasonTypeEnum.STOP\n                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                            break\n            elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:\n                user_turns += 1\n                messages = [{\"role\": x.role, \"content\": x.content} for x in _req.messages]\n\n                # Get interaction by name from interaction_kwargs\n                interaction_name = _req.interaction_kwargs.get(\n                    \"name\", \"gsm8k\"\n                )  # Default to gsm8k for backward compatibility\n                if interaction_name not in self.interaction_map:\n                    raise ValueError(\n                        f\"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: \"\n                        f\"{list(self.interaction_map.keys())}\"\n                    )\n\n                interaction = self.interaction_map[interaction_name]\n\n                should_terminate_sequence, content, reward, metrics = await interaction.generate_response(\n                    _req.request_id, messages, **_req.interaction_kwargs\n                )\n                user_turn_rewards.append(reward)\n                # Add turn check\n                if (\n                    should_terminate_sequence\n                    or user_turns > self.config.multi_turn.max_user_turns\n                    or current_turns > self.config.multi_turn.max_assistant_turns\n                ):\n                    finish_reason_type = FinishReasonTypeEnum.STOP\n                    _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                    break\n                else:\n                    _req.add_user_message(self.processing_class, content)\n                    if _req.input_ids.size(-1) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    else:\n                        _req.state = AsyncRolloutRequestStateEnum.RUNNING\n\n        if current_turns >= self.config.multi_turn.max_assistant_turns:\n            finish_reason_type = FinishReasonTypeEnum.STOP\n\n        # Calculate the reward for each tool\n        async def calc_reward_and_release_fn(name: str, tool: BaseTool):\n            reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get(\"calc_reward_kwargs\", {}))\n            await tool.release(_req.request_id, **_req.tools_kwargs[name].get(\"release_kwargs\", {}))\n            return name, reward\n\n        tool_reward_tasks = []\n        for name in _req.tools_kwargs.keys():\n            tool = self._tool_map[name]\n            tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))\n        tool_reward_scores = await asyncio.gather(*tool_reward_tasks)\n        tool_reward_scores = dict(tool_reward_scores)\n        all_rewards = {**tool_reward_scores, **{\"user_turn_rewards\": user_turn_rewards}}\n        _req.finalize(self.processing_class, all_rewards, finish_reason_type)\n\n        if self.config.calculate_log_probs:\n            debug_sampling_params = {**self.sampling_params}\n            debug_sampling_params[\"max_new_tokens\"] = 0\n            output = await self._engine.async_generate(\n                prompt=None,\n                input_ids=_req.input_ids,\n                sampling_params=debug_sampling_params,\n                return_logprob=True,\n                logprob_start_len=0,\n            )\n            # len(input_token_logprobs) = len(input_tokens)-1，because logprob of 1st token is None\n            _req.output_token_ids, _req.rollout_log_probs = _extract_logprob_from_output(output)\n        return _req\n\n    async def _handle_engine_call(\n        self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class)\n        return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data)\n\n    async def _handle_engine_generate(\n        self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        max_new_tokens = min(self.config.response_length, self.config.max_model_len - 
len(generation_prompt_ids) - 1)\n\n        kwargs = sampling_params.copy()\n        kwargs[\"max_new_tokens\"] = max_new_tokens\n        kwargs[\"n\"] = 1  # group size is supported in preprocess\n        return_logprob = kwargs.pop(\"logprobs\", False)\n\n        output = await self._engine.async_generate(\n            input_ids=generation_prompt_ids,\n            sampling_params=kwargs,\n            return_logprob=return_logprob,\n            image_data=image_data,\n        )\n        return output\n\n    async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest:\n        if _req.tool_schemas is not None:\n            tool_creation_coroutines = []\n            for tool_schema in _req.tool_schemas:\n                tool = self._tool_map[tool_schema.function.name]\n                create_kwargs = _req.tools_kwargs[tool.name].get(\"create_kwargs\", {})\n                tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))\n            tool_creation_results = await asyncio.gather(*tool_creation_coroutines)\n            _req.add_tool_response_messages(\n                self.processing_class, [tool_result for _, tool_result in tool_creation_results]\n            )\n        if _req.interaction_kwargs and self.interaction_map:\n            interaction_kwargs = _req.interaction_kwargs\n            # Get interaction by name from interaction_kwargs\n            interaction_name = interaction_kwargs.get(\"name\", \"gsm8k\")  # Default to gsm8k for backward compatibility\n            if interaction_name not in self.interaction_map:\n                raise ValueError(\n                    f\"Interaction '{interaction_name}' not found in interaction_map. Available interactions: \"\n                    f\"{list(self.interaction_map.keys())}\"\n                )\n\n            interaction = self.interaction_map[interaction_name]\n            await interaction.start_interaction(_req.request_id, **interaction_kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generates multi-turn sequences for a batch of prompts.\n        For multi-turn generation, each prompt is processed separately via\n        `_req_level_generate_sequences` for better tool calling control.\n        Note that in multi-turn generation, we repeat the prompts for rollout.n times in ray_trainer.\n        Thus we do not need to repeat the prompts here and set the sampling parameter n to 1.\n        \"\"\"\n        # Async rollout with tools support\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n        tgt_device = prompts.batch[\"input_ids\"].device\n\n        if self._tp_rank == 0:\n            req_list = self._preprocess_prompt_to_async_rollout_requests(\n                prompts,\n            )\n\n            # distinguish training and validation\n            if is_validate:\n                # Validation mode: process all requests without abort\n                loop = asyncio.get_event_loop()\n                output_req_list = loop.run_until_complete(\n                    asyncio.gather(\n                        *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],\n                    )\n                )\n            else:\n                # add progress monitoring and abort function\n                total_requests = 
len(req_list)\n                target_completion = int(total_requests * (1 - self.config.get(\"over_sample_rate\", 0.0)))\n                # abort the remaining requests once target_completion requests have completed\n\n                completed_count = 0\n                aborted_requests = []\n                all_tasks = []\n\n                async def rollout_a_request_with_cancellation_handler(req):\n                    try:\n                        result = await self._async_rollout_a_request(req, do_sample, is_validate, **kwargs)\n                        return result\n                    except asyncio.CancelledError:\n                        # request is cancelled, return padding\n                        logger.info(f\"Request {req.request_id} was cancelled, creating padding\")\n                        aborted_requests.append(req.request_id)\n                        return self._create_padding_request(req)\n\n                async def run_with_cancellation():\n                    nonlocal all_tasks\n                    nonlocal completed_count\n                    all_tasks = [\n                        asyncio.create_task(rollout_a_request_with_cancellation_handler(req)) for req in req_list\n                    ]\n\n                    # Wait for target_completion tasks to complete\n                    try:\n                        for completed_task in asyncio.as_completed(all_tasks):\n                            await completed_task\n                            completed_count += 1\n                            if completed_count >= target_completion:\n                                break\n                    finally:\n                        # Cancel remaining tasks\n                        for t in all_tasks:\n                            if not t.done():\n                                t.cancel()\n\n                        # Wait for all tasks to finish (including cancelled ones)\n                        final_results = await asyncio.gather(*all_tasks, return_exceptions=True)\n                        # Abort all requests in SGLang engine\n                        await self._engine.abort_request(abort_all=True)\n                    return final_results\n\n                loop = asyncio.get_event_loop()\n                output_req_list = loop.run_until_complete(run_with_cancellation())\n\n            sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))\n        else:\n            sorted_output_req_list = None\n\n        dist.barrier()\n\n        # The logic below requires GPU memory proportional to the batch size, so free the cache first to avoid OOM\n        if self._engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self._engine.flush_cache())\n\n        [sorted_output_req_list] = broadcast_pyobj(\n            data=[sorted_output_req_list],\n            rank=self._rank,\n            dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        # Construct the batch data\n        prompt_ids, response_ids = [], []\n        prompt_attention_mask, response_attention_mask = [], []\n        prompt_position_ids, response_position_ids = [], []\n        response_loss_mask = []\n        messages = []\n        reward_scores = []\n        multi_modal_inputs = []\n        request_ids = []\n        if self.config.calculate_log_probs:\n            output_logprobs = []\n            
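# Filled per request below; both lists are padded to the response length afterwards.\n            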
rollout_output_token_ids = []\n\n        for req in sorted_output_req_list:\n            assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f\"Request {req.request_id} is not completed\"\n            assert (\n                req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), f\"\"\"Request {req.request_id} has different length of \n                {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, \n                {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}\"\"\"\n            error_message_lines = [\n                f\"\"\"Request {req.request_id} has input_ids length {req.input_ids.shape[-1]}\n                    greater than max_model_len {self.config.max_model_len}\"\"\",\n                f\"Decoded input_ids: {self.processing_class.decode(req.input_ids.squeeze(0))}\",\n                f\"Decoded prompt_ids: {self.processing_class.decode(req.prompt_ids.squeeze(0))}\",\n                f\"Decoded response_ids: {self.processing_class.decode(req.response_ids.squeeze(0))}\",\n                f\"Messages: {req.messages}\",\n                f\"Max model length: {req.max_model_len}\",\n            ]\n            error_message = \"\\n\".join(error_message_lines)\n            assert req.input_ids.shape[-1] <= self.config.max_model_len, error_message\n\n            prompt_ids.append(req.prompt_ids.to(tgt_device).squeeze(0))\n            response_ids.append(req.response_ids.to(tgt_device).squeeze(0))\n            if req.response_ids.shape[-1] > self.config.response_length:\n                logger.warning(\n                    f\"\"\"{req.request_id=} has response_ids length {req.response_ids.shape[-1]} \n                    greater than max_response_len {self.config.response_length},\\n{req=}\"\"\"\n                )\n            prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0))\n            response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0))\n            prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0))\n            response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0))\n            response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0))\n            messages.append({\"messages\": req.messages})\n            reward_scores.append(req.reward_scores)\n            multi_modal_inputs.append(req.multi_modal_inputs)\n            request_ids.append(req.request_id)\n            if self.config.calculate_log_probs:\n                # extract output log_probs\n                output_logprobs.append(req.rollout_log_probs[-len(req.response_ids) :])\n                rollout_output_token_ids.append(req.output_token_ids[-len(req.response_ids) :])\n\n        prompt_ids = pad_sequence(\n            prompt_ids,\n            batch_first=True,\n            padding_value=self.pad_token_id,\n            padding_side=\"left\",\n        )\n        if prompt_ids.shape[-1] < self.config.prompt_length:\n            prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)\n        response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)\n        if response_ids.shape[-1] < self.config.response_length:\n            response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)\n        
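# Masks and position ids follow the same padding convention: prompts are left-padded, responses right-padded.\n        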
prompt_attention_mask = pad_sequence(\n            prompt_attention_mask,\n            batch_first=True,\n            padding_value=0,\n            padding_side=\"left\",\n        )\n        if prompt_attention_mask.shape[-1] < self.config.prompt_length:\n            prompt_attention_mask = pad_sequence_to_length(\n                prompt_attention_mask, self.config.prompt_length, 0, left_pad=True\n            )\n        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)\n        if response_attention_mask.shape[-1] < self.config.response_length:\n            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)\n\n        # padding prompt_position_ids\n        if prompt_position_ids[0].dim() == 2:\n            # if prompt_position_ids is a 2D tensor\n            # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len)\n            transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids]\n            prompt_position_ids = pad_sequence(\n                transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n            prompt_position_ids = prompt_position_ids.transpose(1, 2)\n        else:\n            prompt_position_ids = pad_sequence(\n                prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n        if prompt_position_ids.shape[-1] < self.config.prompt_length:\n            prompt_position_ids = pad_sequence_to_length(\n                prompt_position_ids, self.config.prompt_length, 0, left_pad=True\n            )\n\n        # padding response_position_ids\n        if response_position_ids[0].dim() == 2:\n            # if response_position_ids is a 2D tensor\n            # e.g. 
from qwen2vl, response_position_ids.shape = (3, seq_len)\n            transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids]\n            response_position_ids = pad_sequence(\n                transposed_response_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n            response_position_ids = response_position_ids.transpose(1, 2)\n        else:\n            response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0)\n        if response_position_ids.shape[-1] < self.config.response_length:\n            response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0)\n\n        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)\n        if response_loss_mask.shape[1] < self.config.response_length:\n            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)\n        if self.config.calculate_log_probs:\n            output_logprobs = pad_sequence(output_logprobs, padding_value=0.0, batch_first=True)\n            output_logprobs = pad_sequence_to_length(\n                output_logprobs, pad_token_id=0.0, max_seq_len=response_ids.shape[-1]\n            ).to(tgt_device)\n            rollout_output_token_ids = pad_sequence(\n                rollout_output_token_ids, padding_value=self.pad_token_id, batch_first=True\n            )\n            rollout_output_token_ids = pad_sequence_to_length(\n                rollout_output_token_ids, pad_token_id=self.pad_token_id, max_seq_len=response_ids.shape[-1]\n            ).to(tgt_device)\n\n        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)\n        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)\n        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)\n\n        # Construct the batch data\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,\n                \"responses\": response_ids,\n                \"response_mask\": response_loss_mask,\n                \"input_ids\": input_ids,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=len(sorted_output_req_list),\n        )\n        if self.config.calculate_log_probs:\n            batch[\"rollout_log_probs\"] = output_logprobs\n            batch[\"rollout_output_token_ids\"] = rollout_output_token_ids\n\n        non_tensor_batch = {\n            \"messages\": np.array(messages),\n            \"reward_scores\": np.array(reward_scores),\n            \"request_id\": np.array(request_ids),\n        }\n\n        is_multimodal = isinstance(self.processing_class, ProcessorMixin) and (\n            hasattr(self.processing_class, \"image_processor\") or hasattr(self.model_hf_config, \"vision_config\")\n        )\n\n        if is_multimodal:\n            non_tensor_batch[\"multi_modal_inputs\"] = np.array(multi_modal_inputs, dtype=object)\n\n        return DataProto(\n            batch=batch,\n            non_tensor_batch=non_tensor_batch,\n        )\n\n    def _create_padding_request(self, original_req: AsyncRolloutRequest) -> AsyncRolloutRequest:\n        # create a padding request to replace the aborted request\n        # the padding request has the following characteristics:\n        # 1. 
state is COMPLETED, but contains empty response\n        # 2. response_loss_mask is all 0, ensuring it is ignored in loss calculation\n        # 3. keep the original request structure, but the content is empty\n        # create padding response_ids (all pad_token_id)\n        padding_response_length = self.config.response_length\n        device = original_req.input_ids.device if original_req.input_ids is not None else \"cpu\"\n        padding_response_ids = torch.full(\n            (1, padding_response_length),\n            self.pad_token_id,\n            dtype=torch.long,\n            device=device,\n        )\n\n        # create padding attention_mask (all 0)\n        padding_response_attention_mask = torch.zeros(\n            (1, padding_response_length),\n            dtype=torch.long,\n            device=device,\n        )\n\n        # create padding position_ids\n        if original_req.position_ids is not None:\n            first_dim = 1\n            # if position_ids is a 2D tensor (e.g. qwen2vl)\n            if original_req.position_ids.dim() == 2:\n                first_dim = original_req.position_ids.shape[0]\n            padding_response_position_ids = torch.zeros(\n                (first_dim, padding_response_length),\n                dtype=torch.long,\n                device=device,\n            )\n        else:\n            padding_response_position_ids = None\n\n        # create padding prompt_attention_mask (all 0)\n        padding_prompt_attention_mask = torch.zeros(\n            (1, original_req.prompt_attention_mask.shape[-1]),\n            dtype=torch.long,\n            device=device,\n        )\n\n        # create padding loss_mask (all 0, ensuring it is ignored)\n        padding_response_loss_mask = torch.zeros(\n            (1, padding_response_length),\n            dtype=torch.long,\n            device=device,\n        )\n\n        padding_req = original_req.model_copy(deep=True)\n        padding_req.state = AsyncRolloutRequestStateEnum.COMPLETED\n        padding_req.response_ids = padding_response_ids\n        padding_req.prompt_attention_mask = padding_prompt_attention_mask\n        padding_req.response_attention_mask = padding_response_attention_mask\n        padding_req.response_position_ids = padding_response_position_ids\n        padding_req.response_loss_mask = padding_response_loss_mask\n        padding_req.reward_scores = {}\n        padding_req.metrics = {}\n        padding_req.output_token_ids = None\n        padding_req.rollout_log_probs = None\n        return padding_req\n\n    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]:\n        assert \"raw_prompt\" in prompts.non_tensor_batch, (\n            \"need data.return_raw_chat=True, since there is no official way to parse messages\"\n        )\n        logger.info(\n            \"n is deprecated for SGLang rollout since ray ppo trainer will repeat the prompts for rollout.n times\"\n        )\n        req_list = []\n        multi_modal_data_list = prompts.non_tensor_batch.get(\n            \"multi_modal_data\", [None] * len(prompts.non_tensor_batch[\"raw_prompt\"])\n        )\n\n        for data_idx, (raw_prompt, multi_modal_data) in enumerate(\n            zip(prompts.non_tensor_batch[\"raw_prompt\"], multi_modal_data_list, strict=True)\n        ):\n            if self._tool_schemas:\n                _tools_kwargs = prompts.non_tensor_batch[\"tools_kwargs\"][data_idx]\n                _tool_schemas = [self._tool_map[k].get_openai_tool_schema() 
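# one OpenAI-format schema per tool key requested by this sample\n                    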
for k in _tools_kwargs.keys()]\n                _input_ids = None\n                _attention_mask = None\n            else:\n                _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch[\"input_ids\"][data_idx])\n                _attention_mask = _pre_process_inputs(0, prompts.batch[\"attention_mask\"][data_idx])\n                _tools_kwargs = {}\n                _tool_schemas = None\n\n            if self.interaction_map:\n                _interaction_kwargs = prompts.non_tensor_batch[\"interaction_kwargs\"][data_idx]\n            else:\n                _interaction_kwargs = {}\n\n            if not isinstance(raw_prompt, list | np.ndarray):\n                raise TypeError(f\"raw_prompt must be a list or numpy array, got {type(raw_prompt)}\")\n\n            req = AsyncRolloutRequest(\n                batch_data_id=data_idx,\n                rollout_offset=0,\n                request_id=str(uuid4()),\n                state=AsyncRolloutRequestStateEnum.PENDING,\n                messages=list(raw_prompt),\n                multi_modal_data=multi_modal_data,\n                tool_schemas=_tool_schemas,\n                tools_kwargs=_tools_kwargs,\n                interaction_kwargs=_interaction_kwargs,\n                input_ids=_input_ids,\n                response_ids=None,\n                attention_mask=_attention_mask,\n                response_attention_mask=None,\n                response_position_ids=None,\n                response_loss_mask=None,\n                reward_scores={},\n                max_prompt_len=self.config.prompt_length,\n                max_response_len=self.config.response_length,\n                max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),\n                use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,\n                tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode,\n                processing_class=self.processing_class,\n            )\n            error_message = f\"\"\"Request {req.request_id} has mismatched lengths: \n            input_ids={req.input_ids.shape[-1]}, \n            attention_mask={req.attention_mask.shape[-1]}, \n            position_ids={req.position_ids.shape[-1]}, \n            loss_mask={req.loss_mask.shape[-1]}\"\"\"\n            assert (\n                req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), error_message\n            req_list.append(req)\n\n        return req_list\n\n    async def resume(self, tags: list[str]):\n        \"\"\"Resume rollout weights or kv cache in GPU memory.\n\n        Args:\n            tags: memory tags to resume, e.g. [\"weights\", \"kv_cache\"].\n        \"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.config.free_cache_engine:\n            await self._engine.resume_memory_occupation(tags=tags)\n\n    async def release(self):\n        \"\"\"Release weights and kv cache in GPU memory.\"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.config.free_cache_engine:\n            await self._engine.release_memory_occupation(tags=[\"kv_cache\", \"weights\"])\n\n    async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):\n        \"\"\"\n        Update model weights using tensor buckets, similar to THUDM/slime's implementation.\n\n        Notes:\n          - For the 
best performance of \`rebuild_cuda_tensor\`, it is recommended to:\n              1. Enable \`RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\`.\n              2. Manually set \`CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\`\n            when using Tensor Parallelism (TP >= 8).\n          - See reference implementations in SLIME:\n            - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452\n            - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39\n        \"\"\"\n        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20\n        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):\n            await sgl_update_weights(\n                engine=self._engine,\n                params_batch=params_batch,\n                device_mesh_key=\"infer_tp\",\n                device_mesh=self.device_mesh,\n            )\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self._engine.flush_cache()\n\n\nclass ServerAdapter(BaseRollout):\n    \"\"\"SGLang server adapter used in native http server mode; serves as an http client that asks the SGLang server\n    to resume/release/update weights and kv_cache.\n\n    - hybrid mode: resides in each hybrid worker to sync weights between the training engine and the SGLang server.\n    - standalone/colocated mode: a dummy placeholder that occupies the GPU to prevent ray from scheduling a new GPU actor onto it.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        super().__init__(config, model_config, device_mesh)\n        self._engine: AsyncHttpServerAdapter = None\n\n        rank = int(os.environ[\"RANK\"])\n        local_world_size = int(os.environ[\"RAY_LOCAL_WORLD_SIZE\"])\n        rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size\n        self.replica_rank = rank // rollout_world_size\n        self.rollout_rank = rank % rollout_world_size\n        self.node_rank = self.rollout_rank // local_world_size\n        self.local_rank = self.rollout_rank % local_world_size\n\n    async def _init_server_adapter(self):\n        if self._engine is not None:\n            return\n\n        # Lazy init http server adapter because http server is launched after hybrid engine.\n        self.server_actor = ray.get_actor(f\"sglang_server_{self.replica_rank}_{self.node_rank}\")\n        server_address, server_port = await self.server_actor.get_server_address.remote()\n        logger.debug(\n            f\"replica_rank={self.replica_rank} node_rank={self.node_rank}, \"\n            f\"server address: {server_address}, port: {server_port}\"\n        )\n        host = f\"[{server_address}]\" if is_valid_ipv6_address(server_address) else server_address\n        self._engine = AsyncHttpServerAdapter(\n            model_path=self.model_config.local_path, host=host, port=server_port, launch_server=False\n        )\n\n    async def resume(self, tags: list[str]):\n        \"\"\"Resume rollout weights or kv cache in GPU memory.\n\n        Args:\n            tags: memory tags to resume, e.g. [\"weights\", \"kv_cache\"].\n        \"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.config.free_cache_engine:\n            await self._init_server_adapter()\n            await self._engine.resume_memory_occupation(tags=tags)\n\n    async def 
release(self):\n        \"\"\"Release weights and kv cache in GPU memory.\"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.config.free_cache_engine:\n            await self._init_server_adapter()\n            await self._engine.release_memory_occupation(tags=[\"kv_cache\", \"weights\"])\n\n    async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):\n        \"\"\"\n        Update model weights using tensor buckets, similar to THUDM/slime's implementation.\n\n        Notes:\n          - For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n              1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.\n              2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n            when using Tensor Parallelism (TP >= 8).\n          - See reference implementations in SLIME:\n            - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452\n            - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39\n        \"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self._init_server_adapter()\n\n        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20\n        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):\n            await sgl_update_weights(\n                engine=self._engine,\n                params_batch=params_batch,\n                device_mesh_key=\"infer_tp\",\n                device_mesh=self.device_mesh,\n            )\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self._engine.flush_cache()\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/sglang_rollout/utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nfrom typing import Any, Iterator, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_name\n\n\ndef broadcast_pyobj(\n    data: list[Any],\n    rank: int,\n    dist_group: Optional[torch.distributed.ProcessGroup] = None,\n    src: int = 0,\n    force_cpu_device: bool = False,\n):\n    \"\"\"from https://github.com/sgl-project/sglang/blob/844e2f227ab0cce6ef818a719170ce37b9eb1e1b/python/sglang/srt/utils.py#L905\n\n    Broadcast inputs from src rank to all other ranks with torch.dist backend.\n    The `rank` here refer to the source rank on global process group (regardless\n    of dist_group argument).\n    \"\"\"\n    device = torch.device(get_device_name() if not force_cpu_device else \"cpu\")\n\n    if rank == src:\n        if len(data) == 0:\n            tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n        else:\n            serialized_data = pickle.dumps(data)\n            size = len(serialized_data)\n\n            tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device)\n            tensor_size = torch.tensor([size], dtype=torch.long, device=device)\n\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n            dist.broadcast(tensor_data, src=src, group=dist_group)\n        return data\n    else:\n        tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n        dist.broadcast(tensor_size, src=src, group=dist_group)\n        size = tensor_size.item()\n\n        if size == 0:\n            return []\n\n        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)\n        dist.broadcast(tensor_data, src=src, group=dist_group)\n\n        serialized_data = bytes(tensor_data.cpu().numpy())\n        data = pickle.loads(serialized_data)\n        return data\n\n\ndef get_named_tensor_buckets(\n    iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int\n) -> Iterator[list[tuple[str, torch.Tensor]]]:\n    \"\"\"\n    Group tensors into buckets based on a specified size in megabytes.\n\n    Args:\n        iterable: An iterator of tuples containing tensor names and tensors.\n        bucket_bytes: The maximum size of each bucket in bytes.\n\n    Yields:\n        Lists of tuples, where each tuple contains a tensor name and its corresponding tensor.\n\n    Example:\n        >>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))]\n        >>> for bucket in get_named_tensor_buckets(tensors, bucket_size_mb=10):\n        ...     
print(bucket)\n        [('tensor1', tensor(...)), ('tensor2', tensor(...))]\n\n    \"\"\"\n    if bucket_bytes <= 0:\n        raise ValueError(f\"bucket_bytes must be greater than 0, got {bucket_bytes}\")\n\n    current_bucket = []\n    current_size = 0\n    for name, tensor in iterable:\n        tensor_size = tensor.element_size() * tensor.numel()\n        if current_size + tensor_size > bucket_bytes:\n            if current_bucket:\n                yield current_bucket\n            current_bucket = [(name, tensor)]\n            current_size = tensor_size\n        else:\n            current_bucket.append((name, tensor))\n            current_size += tensor_size\n\n    if current_bucket:\n        yield current_bucket\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport numpy as np\nimport torch\n\n__all__ = [\"HybridEngineBaseTokenizer\"]\n\n\nclass HybridEngineBaseTokenizer(ABC):\n    \"\"\"the tokenizer property and function name should align with HF's to meet vllm requirement\"\"\"\n\n    @property\n    @abstractmethod\n    def vocab_size(self):\n        \"\"\"\n        `int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def pad_token_id(self):\n        \"\"\"\n        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def eos_token_id(self):\n        \"\"\"\n        `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been\n        set.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def all_special_ids(self) -> list[int]:\n        \"\"\"\n        `List[int]`: List the ids of the special tokens(`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def all_special_tokens(self) -> list[str]:\n        \"\"\"\n        `List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).\n\n        Convert tokens of `tokenizers.AddedToken` type to string.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def encode(self, text):\n        \"\"\"\n        Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.\n\n        Args:\n            text (`str`, `List[str]` or `List[int]`):\n                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the\n                `tokenize` method) or a list of integers.\n\n            text_pair (`str`, `List[str]` or `List[int]`, *optional*):\n                Optional second sequence to be encoded. 
This can be a string, a list of strings (tokenized string using\n                the \`tokenize\` method) or a list of integers.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def decode(\n        self,\n        token_ids: int | list[int] | np.ndarray | torch.Tensor,\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = None,\n        **kwargs,\n    ) -> str:\n        \"\"\"\n        Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special\n        tokens and clean up tokenization spaces.\n\n        Similar to doing \`self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))\`.\n\n        Args:\n            token_ids (\`Union[int, List[int], np.ndarray, torch.Tensor]\`):\n                List of tokenized input ids. Can be obtained using the \`__call__\` method.\n            skip_special_tokens (\`bool\`, *optional*, defaults to \`False\`):\n                Whether or not to remove special tokens in the decoding.\n            clean_up_tokenization_spaces (\`bool\`, *optional*):\n                Whether or not to clean up the tokenization spaces. If \`None\`, will default to\n                \`self.clean_up_tokenization_spaces\`.\n            kwargs (additional keyword arguments, *optional*):\n                Will be passed to the underlying model specific decode method.\n\n        Returns:\n            \`str\`: The decoded sentence.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:\n        \"\"\"\n        Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary and\n        added tokens.\n\n        Args:\n            ids (\`int\` or \`List[int]\`):\n                The token id (or token ids) to convert to tokens.\n            skip_special_tokens (\`bool\`, *optional*, defaults to \`False\`):\n                Whether or not to remove special tokens in the decoding.\n\n        Returns:\n            \`str\` or \`List[str]\`: The decoded token(s).\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_added_vocab(self) -> dict[str, int]:\n        \"\"\"\n        Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from\n        the fast call because for now we always add the tokens even if they are already in the vocabulary. This is\n        something we should change.\n\n        Returns:\n            \`Dict[str, int]\`: The added tokens.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def convert_tokens_to_string(self, tokens: list[str]) -> str:\n        \"\"\"\n        Converts a sequence of tokens into a single string. The simplest way to do it is \`\" \".join(tokens)\`, but we\n        often want to remove sub-word tokenization artifacts at the same time.\n\n        Args:\n            tokens (\`List[str]\`): The tokens to join into a string.\n\n        Returns:\n            \`str\`: The joined tokens.\n        \"\"\"\n        pass\n\n    @property\n    def is_fast(self):\n        return False\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport ipaddress\nimport logging\nimport os\nimport socket\n\nimport uvicorn\nfrom fastapi import FastAPI\n\nlogger = logging.getLogger(__file__)\n\n\ndef is_valid_ipv6_address(address: str) -> bool:\n    try:\n        ipaddress.IPv6Address(address)\n        return True\n    except ValueError:\n        return False\n\n\ndef get_free_port(address: str) -> tuple[int, socket.socket]:\n    family = socket.AF_INET\n    if is_valid_ipv6_address(address):\n        family = socket.AF_INET6\n\n    sock = socket.socket(family=family, type=socket.SOCK_STREAM)\n    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)\n    sock.bind((address, 0))\n\n    port = sock.getsockname()[1]\n    return port, sock\n\n\nasync def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) -> tuple[int, asyncio.Task]:\n    server_port, server_task = None, None\n\n    for i in range(max_retries):\n        try:\n            server_port, sock = get_free_port(server_address)\n            app.server_args = server_args\n            config = uvicorn.Config(app, host=server_address, port=server_port, log_level=\"warning\")\n            server = uvicorn.Server(config)\n            server.should_exit = True\n            await server.serve()\n            server_task = asyncio.create_task(server.main_loop())\n            break\n        except (OSError, SystemExit) as e:\n            logger.error(f\"Failed to start HTTP server on port {server_port} at try {i}, error: {e}\")\n    else:\n        logger.error(f\"Failed to start HTTP server after {max_retries} retries, exiting...\")\n        os._exit(-1)\n\n    logger.info(f\"HTTP server started on port {server_port}\")\n    return server_port, server_task\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/vllm_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout  # noqa: F401\n\n\ndef get_version(pkg):\n    try:\n        return version(pkg)\n    except PackageNotFoundError:\n        return None\n\n\nvllm_package_name = \"vllm\"\nvllm_package_version = get_version(vllm_package_name)\nif vllm_package_version is None:\n    raise PackageNotFoundError(\n        \"To use vllm rollout, please ensure the 'vllm' package is properly installed. See \"\n        \"https://verl.readthedocs.io/en/latest/start/install.html for more details\"\n    )\n\nif \"ROCM_PATH\" in os.environ:\n    import re\n\n    match = re.match(r\"(\\d+\\.\\d+\\.?\\d*)\", vllm_package_version)\n    if match:\n        vllm_package_version = match.group(1)\n    else:\n        raise ValueError(f\"Warning: Could not parse version format: {vllm_package_version}\")\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/vllm_rollout/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# magic numbers that ensure we are using the same LoRA adapter during the rollout and training process\nVLLM_LORA_INT_ID = 123\nVLLM_LORA_NAME = \"123\"\nVLLM_LORA_PATH = \"simon_lora_path\"\n\n\ndef get_vllm_max_lora_rank(lora_rank: int):\n    \"\"\"\n    For vLLM, the smallest `max_lora_rank` is 8, and allowed values are (8, 16, 32, 64, 128, 256, 320, 512)\n    This function automatically adjusts the `max_lora_rank` to the nearest allowed value.\n\n    Reference: https://github.com/vllm-project/vllm/blob/8a297115e2367d463b781adb86b55ac740594cf6/vllm/config/lora.py#L27\n    \"\"\"\n    assert lora_rank > 0, f\"lora_rank must be greater than 0 to invoke this function, get {lora_rank}\"\n    vllm_max_lora_ranks = [8, 16, 32, 64, 128, 256, 320, 512]\n    for rank in vllm_max_lora_ranks:\n        if lora_rank <= rank:\n            return rank\n\n    raise ValueError(f\"lora_rank must be less than or equal to {vllm_max_lora_ranks[-1]}, but got {lora_rank}\")\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/vllm_rollout/vllm_async_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport asyncio\nimport json\nimport logging\nimport os\nimport pickle\nfrom pprint import pprint\nfrom typing import Any, Callable, Optional\n\nimport numpy as np\nimport ray\nimport vllm.entrypoints.cli.serve\nimport zmq\nfrom ray.actor import ActorHandle\nfrom vllm import SamplingParams\nfrom vllm.engine.arg_utils import AsyncEngineArgs\nfrom vllm.entrypoints.openai.api_server import (\n    build_app,\n    init_app_state,\n)\nfrom vllm.inputs import TokensPrompt\nfrom vllm.lora.request import LoRARequest\nfrom vllm.outputs import RequestOutput\nfrom vllm.usage.usage_lib import UsageContext\nfrom vllm.utils import FlexibleArgumentParser, get_tcp_uri\nfrom vllm.v1.engine.async_llm import AsyncLLM\nfrom vllm.v1.engine.core import EngineCoreProc\nfrom vllm.v1.engine.utils import CoreEngineProcManager\nfrom vllm.v1.executor.abstract import Executor\n\nfrom verl.single_controller.ray import RayClassWithInitArgs\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig\nfrom verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput\nfrom verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address, run_unvicorn\nfrom verl.workers.rollout.vllm_rollout import vLLMAsyncRollout\nfrom verl.workers.rollout.vllm_rollout.utils import (\n    VLLM_LORA_INT_ID,\n    VLLM_LORA_NAME,\n    VLLM_LORA_PATH,\n    get_vllm_max_lora_rank,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(logging.INFO)\n\n\nclass ExternalZeroMQDistributedExecutor(Executor):\n    \"\"\"An executor that engines are launched by external ray actors.\"\"\"\n\n    uses_ray: bool = False\n\n    def _init_executor(self) -> None:\n        dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local\n        tp_size = self.vllm_config.parallel_config.tensor_parallel_size\n\n        addresses = os.environ[\"VERL_VLLM_ZMQ_ADDRESSES\"].split(\",\")\n        addresses = addresses[dp_rank_local * tp_size : (dp_rank_local + 1) * tp_size]\n        self.context = zmq.Context()\n        self.sockets = []\n        for address in addresses:\n            socket = self.context.socket(zmq.REQ)\n            if address.startswith(\"tcp://[\"):\n                socket.setsockopt(zmq.IPV6, 1)\n            socket.connect(address)\n            self.sockets.append(socket)\n\n        kwargs = dict(\n            vllm_config=self.vllm_config,\n            local_rank=None,\n            rank=None,\n            distributed_init_method=\"env://\",\n            is_driver_worker=True,\n        )\n        self.collective_rpc(\"init_worker\", args=([kwargs],))\n        self.collective_rpc(\"init_device\")\n        self.collective_rpc(\"load_model\")\n\n    def collective_rpc(\n        self,\n        method: str | Callable,\n        timeout: Optional[float] = None,\n        args: tuple = (),\n        kwargs: 
Optional[dict[str, Any]] = None,\n        **kwargs_extra: Any,\n    ) -> list[Any]:\n        if isinstance(method, str):\n            sent_method = method\n        else:\n            sent_method = pickle.dumps(method)\n        del method\n\n        message = pickle.dumps((sent_method, args, kwargs or {}))\n        for socket in self.sockets:\n            socket.send(message, zmq.DONTWAIT)\n\n        outputs = []\n        for socket in self.sockets:\n            outputs.append(pickle.loads(socket.recv()))\n\n        for output in outputs:\n            if isinstance(output, Exception):\n                raise output\n        return outputs\n\n    def check_health(self):\n        return\n\n\nclass vLLMHttpServerBase:\n    \"\"\"vLLM http server in single node, this is equivalent to launch server with command line:\n    ```\n    vllm serve --tensor-parallel-size=8 ...\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        rollout_mode: RolloutMode,\n        workers: list[ActorHandle],\n        replica_rank: int,\n        node_rank: int,\n        gpus_per_node: int,\n        nnodes: int,\n    ):\n        \"\"\"\n        Args:\n            config (RolloutConfig): full config.\n            model_config (HFModelConfig): model config.\n            rollout_mode (RolloutMode): rollout mode.\n            replica_rank (int): replica rank, a replica may contain multiple nodes.\n            node_rank (int): node rank.\n            gpus_per_node (int): number of gpus per node.\n            nnodes (int): number of nodes.\n        \"\"\"\n        super().__init__()\n\n        self.config: RolloutConfig = omega_conf_to_dataclass(config)\n        self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)\n        self.config.max_model_len = self.config.prompt_length + self.config.response_length\n        self.rollout_mode = rollout_mode\n        self.workers = workers\n\n        self.replica_rank = replica_rank\n        self.node_rank = node_rank\n        self.gpus_per_node = gpus_per_node\n        self.nnodes = nnodes\n\n        if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == \"dummy\":\n            logger.warning(f\"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto\")\n            self.config.load_format = \"auto\"\n\n        # used for http server\n        self._server_address = ray.util.get_node_ip_address().strip(\"[]\")\n        self._server_port = None\n\n        # used for data parallel: --data-parallel-address, --data-parallel-rpc-port\n        if self.node_rank == 0:\n            self._master_address = self._server_address\n            self._master_port, self._master_sock = get_free_port(self._server_address)\n            self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address)\n            logger.info(\n                f\"vLLMHttpServer, replica_rank: {self.replica_rank}, master address: {self._master_address}, \"\n                f\"master port: {self._master_port}, data parallel master port: {self._dp_master_port}\"\n            )\n        else:\n            self._master_address = None\n            self._master_port = None\n\n    def get_master_address(self):\n        \"\"\"Get master address and port for data parallel.\"\"\"\n        return self._master_address, self._master_port\n\n    def get_server_address(self):\n        \"\"\"Get http server address and port.\"\"\"\n        assert self._server_port 
is not None, \"http server is not launched, port is None\"\n        return self._server_address, self._server_port\n\n    async def launch_server(self, master_address: str = None, master_port: int = None):\n        if self.node_rank != 0:\n            assert master_address and master_port, \"non-master node should provide master address and port\"\n            self._master_address = master_address\n            self._master_port = master_port\n\n        # 1. setup vllm serve cli args\n        engine_kwargs = self.config.get(\"engine_kwargs\", {}).get(\"vllm\", {}) or {}\n        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}\n        if self.config.get(\"limit_images\", None):  # support for multi-image data\n            engine_kwargs[\"limit_mm_per_prompt\"] = {\"image\": self.config.get(\"limit_images\")}\n        if self.config.cudagraph_capture_sizes:\n            engine_kwargs[\"cuda_graph_sizes\"] = self.config.cudagraph_capture_sizes\n\n        # Override default generation config from hugging face model config,\n        # user can still override them by passing kwargs in each request.\n        override_generation_config = dict(\n            temperature=self.config.temperature,\n            top_k=self.config.top_k,\n            top_p=self.config.top_p,\n            repetition_penalty=1.0,\n            max_new_tokens=self.config.response_length,\n        )\n        logger.info(f\"override_generation_config: {override_generation_config}\")\n\n        args = {\n            \"dtype\": self.config.dtype,\n            \"load_format\": self.config.load_format,\n            \"skip_tokenizer_init\": False,\n            # \"trust_remote_code\": True,\n            \"max_model_len\": self.config.max_model_len,\n            \"max_num_seqs\": self.config.max_num_seqs,\n            \"enable_chunked_prefill\": self.config.enable_chunked_prefill,\n            \"max_num_batched_tokens\": self.config.max_num_batched_tokens,\n            \"enable_prefix_caching\": self.config.enable_prefix_caching,\n            \"enable_sleep_mode\": True,\n            \"disable_custom_all_reduce\": True,\n            \"enforce_eager\": self.config.enforce_eager,\n            \"gpu_memory_utilization\": self.config.gpu_memory_utilization,\n            \"disable_log_stats\": self.config.disable_log_stats,\n            \"tensor_parallel_size\": self.config.tensor_model_parallel_size,\n            \"seed\": self.config.get(\"seed\", 0),\n            \"override_generation_config\": json.dumps(override_generation_config),\n            **engine_kwargs,\n        }\n        if self.config.expert_parallel_size > 1:\n            assert self.gpus_per_node % self.config.tensor_model_parallel_size == 0, (\n                \"gpus_per_node should be divisible by tensor_model_parallel_size\"\n            )\n            data_parallel_size_local = self.gpus_per_node // self.config.tensor_model_parallel_size\n            assert len(self.workers) == data_parallel_size_local * self.config.tensor_model_parallel_size, (\n                f\"num workers ({len(self.workers)}) should be equal to dp_size_local \"\n                f\"({data_parallel_size_local}) * tp_size ({self.config.tensor_model_parallel_size})\"\n            )\n\n            args.update(\n                {\n                    \"enable_expert_parallel\": self.config.expert_parallel_size > 1,\n                    \"data_parallel_size\": self.config.data_parallel_size,\n                    \"data_parallel_size_local\": data_parallel_size_local,\n  
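                  # each node hosts dp_size_local engines, so node_rank * dp_size_local\n                    # below is this node's first data-parallel rank in the global DP group\n  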
                  \"data_parallel_start_rank\": self.node_rank * data_parallel_size_local,\n                    \"data_parallel_address\": self._master_address,\n                    \"data_parallel_rpc_port\": self._master_port,\n                }\n            )\n\n        # update lora-related args\n        if self.model_config.lora_rank > 0:\n            args.update(\n                {\n                    \"enable_lora\": True,\n                    \"max_loras\": 1,\n                    \"max_lora_rank\": get_vllm_max_lora_rank(self.model_config.lora_rank),\n                }\n            )\n\n        server_args = [\"serve\", self.model_config.local_path]\n        for k, v in args.items():\n            if isinstance(v, bool):\n                if v:\n                    server_args.append(f\"--{k}\")\n            else:\n                server_args.append(f\"--{k}\")\n                server_args.append(str(v))\n\n        if self.replica_rank == 0:\n            pprint(server_args)\n\n        CMD_MODULES = [vllm.entrypoints.cli.serve]\n        parser = FlexibleArgumentParser(description=\"vLLM CLI\")\n        subparsers = parser.add_subparsers(required=False, dest=\"subparser\")\n        cmds = {}\n        for cmd_module in CMD_MODULES:\n            new_cmds = cmd_module.cmd_init()\n            for cmd in new_cmds:\n                cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)\n                cmds[cmd.name] = cmd\n        server_args = parser.parse_args(args=server_args)\n        server_args.model = server_args.model_tag\n        if server_args.subparser in cmds:\n            cmds[server_args.subparser].validate(server_args)\n\n        # 2. setup distributed executor backend\n        distributed_executor_backend = ExternalZeroMQDistributedExecutor if len(self.workers) > 0 else None\n        server_args.distributed_executor_backend = distributed_executor_backend\n\n        zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in self.workers])\n        logger.info(\n            f\"replica_rank={self.replica_rank}, node_rank={self.node_rank}, nnodes={self.nnodes}, \"\n            f\"get worker zmq addresses: {zmq_addresses}\"\n        )\n        os.environ[\"VERL_VLLM_ZMQ_ADDRESSES\"] = \",\".join(zmq_addresses)\n\n        # 3. 
launch server\n        if self.node_rank == 0:\n            await self.run_server(server_args)\n        else:\n            await self.run_headless(server_args)\n\n    async def run_server(self, args: argparse.Namespace):\n        engine_args = AsyncEngineArgs.from_cli_args(args)\n        usage_context = UsageContext.OPENAI_API_SERVER\n        vllm_config = engine_args.create_engine_config(usage_context=usage_context)\n        vllm_config.parallel_config.data_parallel_master_port = self._dp_master_port\n\n        engine_client = AsyncLLM.from_vllm_config(\n            vllm_config=vllm_config,\n            usage_context=usage_context,\n            disable_log_requests=engine_args.disable_log_requests,\n            disable_log_stats=engine_args.disable_log_stats,\n        )\n\n        # Don't keep the dummy data in memory\n        await engine_client.reset_mm_cache()\n\n        app = build_app(args)\n        await init_app_state(engine_client, vllm_config, app.state, args)\n        if self.replica_rank == 0 and self.node_rank == 0:\n            logger.info(f\"Initializing a V1 LLM engine with config: {vllm_config}\")\n\n        self.engine = engine_client\n        self._server_port, self._server_task = await run_unvicorn(app, args, self._server_address)\n\n    async def run_headless(self, args: argparse.Namespace):\n        # Create the EngineConfig.\n        engine_args = vllm.AsyncEngineArgs.from_cli_args(args)\n        usage_context = UsageContext.OPENAI_API_SERVER\n        vllm_config = engine_args.create_engine_config(usage_context=usage_context, headless=True)\n\n        parallel_config = vllm_config.parallel_config\n        local_engine_count = parallel_config.data_parallel_size_local\n\n        host = parallel_config.data_parallel_master_ip\n        port = engine_args.data_parallel_rpc_port  # add to config too\n        handshake_address = get_tcp_uri(host, port)\n\n        # Create the engines.\n        self.engine_manager = CoreEngineProcManager(\n            target_fn=EngineCoreProc.run_engine_core,\n            local_engine_count=local_engine_count,\n            start_index=vllm_config.parallel_config.data_parallel_rank,\n            local_start_index=0,\n            vllm_config=vllm_config,\n            local_client=False,\n            handshake_address=handshake_address,\n            executor_class=Executor.get_class(vllm_config),\n            log_stats=not engine_args.disable_log_stats,\n        )\n\n    async def generate(\n        self,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n        request_id: str,\n        image_data: Optional[list[Any]] = None,\n    ) -> TokenOutput:\n        \"\"\"Generate sequence with token-in-token-out.\"\"\"\n        # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready.\n        max_tokens = self.config.max_model_len - len(prompt_ids)\n        sampling_params[\"logprobs\"] = 0 if sampling_params.pop(\"logprobs\", False) else None\n        sampling_params.setdefault(\"repetition_penalty\", self.config.get(\"repetition_penalty\", 1.0))\n        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)\n        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)\n        prompt = TokensPrompt(\n            prompt_token_ids=prompt_ids, multi_modal_data={\"image\": image_data} if image_data else None\n        )\n\n        # Add lora request\n        lora_request = None\n        if self.model_config.lora_rank > 0:\n            # Make sure we 
also check that the lora is already loaded in the engine\n            lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras()\n            if lora_loaded:\n                lora_request = LoRARequest(\n                    lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH\n                )\n\n        generator = self.engine.generate(\n            prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request\n        )\n\n        # Get final response\n        final_res: Optional[RequestOutput] = None\n        async for output in generator:\n            final_res = output\n        assert final_res is not None\n\n        token_ids = final_res.outputs[0].token_ids\n        log_probs = None\n        if sampling_params.logprobs is not None:\n            log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(final_res.outputs[0].logprobs)]\n        return TokenOutput(token_ids=token_ids, log_probs=log_probs)\n\n    async def wake_up(self):\n        if self.rollout_mode == RolloutMode.HYBRID:\n            # Call all workers to switch between trainer mode and rollout mode.\n            await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])\n        elif self.rollout_mode == RolloutMode.COLOCATED:\n            # Directly call engine to wake up without sync weights.\n            if self.node_rank == 0:\n                await self.engine.wake_up(tags=[\"kv_cache\", \"weights\"])\n        elif self.rollout_mode == RolloutMode.STANDALONE:\n            logger.info(\"skip wake_up in standalone mode\")\n\n    async def sleep(self):\n        if self.rollout_mode == RolloutMode.HYBRID:\n            if self.node_rank == 0:\n                await self.engine.reset_prefix_cache()\n            await asyncio.gather(*[worker.sleep.remote() for worker in self.workers])\n        elif self.rollout_mode == RolloutMode.COLOCATED:\n            if self.node_rank == 0:\n                await self.engine.reset_prefix_cache()\n                await self.engine.sleep(level=1)\n        elif self.rollout_mode == RolloutMode.STANDALONE:\n            logger.info(\"skip sleep in standalone mode\")\n\n    async def wait_for_requests_to_drain(self):\n        await self.engine.wait_for_requests_to_drain()\n\n\n@ray.remote(num_cpus=1)\nclass vLLMHttpServer(vLLMHttpServerBase):\n    \"\"\"vLLM http server in single node, this is equivalent to launch server with command line:\n    ```\n    vllm serve --tensor-parallel-size=8 ...\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig | RewardModelConfig,\n        model_config: HFModelConfig,\n        rollout_mode: RolloutMode,\n        workers: list[ActorHandle],\n        replica_rank: int,\n        node_rank: int,\n        gpus_per_node: int,\n        nnodes: int,\n    ):\n        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)\n\n\n_rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout)\n\n\nclass vLLMReplica(RolloutReplica):\n    def __init__(\n        self,\n        replica_rank: int,\n        config: RolloutConfig | RewardModelConfig,\n        model_config: HFModelConfig,\n        gpus_per_node: int = 8,\n        is_reward_model: bool = False,\n    ):\n        super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)\n        self.server_class = vLLMHttpServer\n\n    def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:\n        \"\"\"Get 
rollout worker actor class for colocated and standalone mode.\"\"\"\n        worker_dict_cls = RayClassWithInitArgs(\n            cls=_rollout_worker_actor_cls,\n            config=self.config,\n            model_config=self.model_config,\n            device_mesh=None,\n        )\n        return worker_dict_cls\n\n    async def launch_servers(self):\n        \"\"\"Launch http server in each node.\"\"\"\n        assert len(self.workers) == self.world_size, (\n            f\"worker number {len(self.workers)} not equal to world size {self.world_size}\"\n        )\n\n        # get node_id of all workers\n        worker_node_ids = await asyncio.gather(\n            *[\n                worker.__ray_call__.remote(lambda self: ray.get_runtime_context().get_node_id())\n                for worker in self.workers\n            ]\n        )\n\n        # For non-data parallel case, there's only one server whether it's single or multi nodes.\n        nnodes, gpus_per_node = self.nnodes, self.gpus_per_node\n        if self.config.data_parallel_size == 1:\n            nnodes = 1\n            gpus_per_node = self.world_size\n\n        # create server actor in each node with node affinity\n        for node_rank in range(nnodes):\n            workers = self.workers[node_rank * gpus_per_node : (node_rank + 1) * gpus_per_node]\n            node_id = worker_node_ids[node_rank * gpus_per_node]\n            name = (\n                f\"vllm_server_{self.replica_rank}_{node_rank}\"\n                if not self.is_reward_model\n                else f\"vllm_server_reward_{self.replica_rank}_{node_rank}\"\n            )\n            server = self.server_class.options(\n                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                    node_id=node_id,\n                    soft=False,\n                ),\n                name=name,\n            ).remote(\n                config=self.config,\n                model_config=self.model_config,\n                rollout_mode=self.rollout_mode,\n                workers=workers,\n                replica_rank=self.replica_rank,\n                node_rank=node_rank,\n                gpus_per_node=gpus_per_node,\n                nnodes=nnodes,\n            )\n            self.servers.append(server)\n\n        # launch http server in each node\n        master_address, master_port = await self.servers[0].get_master_address.remote()\n        await asyncio.gather(\n            *[\n                server.launch_server.remote(master_address=master_address, master_port=master_port)\n                for server in self.servers\n            ]\n        )\n\n        # get http server address from first server\n        server_address, server_port = await self.servers[0].get_server_address.remote()\n        self._server_handle = self.servers[0]\n        self._server_address = (\n            f\"[{server_address}]:{server_port}\"\n            if is_valid_ipv6_address(server_address)\n            else f\"{server_address}:{server_port}\"\n        )\n\n    async def sleep(self):\n        \"\"\"Sleep each rollout server.\"\"\"\n        # Drain DP engines for safe sleep.\n        await self.servers[0].wait_for_requests_to_drain.remote()\n        await asyncio.gather(*[server.sleep.remote() for server in self.servers])\n\n\ndef _qwen2_5_vl_dedup_image_tokens(prompt_ids: list[int], processor):\n    \"\"\"Deduplicate consecutive image tokens in prompt_ids for Qwen2.5-VL, since vLLM will replicate the\n    <|image_pad|> token by image_data.\n\n    For 
example,\n    ```\n    <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|>\n    =>\n    <|vision_start|><|image_pad|><|vision_end|>\n    ```\n    \"\"\"\n    if processor is not None and \"Qwen2VLImageProcessor\" in processor.image_processor.__class__.__name__:\n        prompt_ids = np.array(prompt_ids)\n\n        # Create a mask where True indicates elements to keep\n        mask = np.ones(len(prompt_ids), dtype=bool)\n\n        # Find where the array equals the value\n        is_value = prompt_ids == processor.image_token_id\n\n        # Find consecutive duplicates by checking if previous element is also the value\n        mask[1:] &= ~(is_value[1:] & is_value[:-1])\n\n        return prompt_ids[mask].tolist()\n    else:\n        return prompt_ids\n"
  },
  {
    "path": "verl_distillation/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe vllm_rollout that can be applied in different backend\nWhen working with FSDP:\n- Use DTensor weight loader (recommended) or HF weight loader\n- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM\nWhen working with Megatron:\n- Use Megatron weight loader\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank\n  to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\nimport asyncio\nimport getpass\nimport inspect\nimport logging\nimport os\nimport pickle\nimport time\nfrom contextlib import contextmanager\nfrom dataclasses import asdict\nfrom types import MethodType\nfrom typing import Any, Generator\n\nimport numpy as np\nimport ray\nimport torch\nimport torch.distributed\nimport zmq\nimport zmq.asyncio\nfrom filelock import FileLock\nfrom omegaconf import ListConfig\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom vllm import LLM, SamplingParams\nfrom vllm.config import CompilationConfig, LoRAConfig\nfrom vllm.lora.request import LoRARequest\n\ntry:\n    # https://github.com/vllm-project/vllm/commit/96b9aa5aa076e64c68765232aec343e4d0006e2a\n    from vllm.config import CompilationMode\n\n    _use_compilation_mode = True\nexcept ImportError:\n    from vllm.config import CompilationLevel\n\n    _use_compilation_mode = False\n\ntry:\n    from vllm.worker.worker_base import WorkerWrapperBase\nexcept ModuleNotFoundError:\n    # https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b\n    from vllm.v1.worker.worker_base import WorkerWrapperBase\n\nfrom verl import DataProto\nfrom verl.third_party.vllm import VLLM_SLEEP_LEVEL\nfrom verl.utils.device import is_npu_available\nfrom verl.utils.distributed import initialize_global_process_group_ray\nfrom verl.utils.model import get_lora_rank_from_adapter\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.ray_utils import ray_noset_visible_devices\nfrom verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length\nfrom verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge\nfrom verl.workers.config import HFModelConfig, RolloutConfig\nfrom verl.workers.rollout.base import BaseRollout\nfrom verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address\nfrom verl.workers.rollout.vllm_rollout.utils import (\n    VLLM_LORA_INT_ID,\n    VLLM_LORA_NAME,\n    VLLM_LORA_PATH,\n    get_vllm_max_lora_rank,\n)\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n# TODO\n# 1. support pp in vllm\n# 2. passing tokenizer is not necessary? 
no encoding/decoding is happening here\n# 3. simplify init logic\n\n\n# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.\ndef _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:\n    # remove the left padding in the prompt token_id\n    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id\n    # is not None else self.llm_engine.tokenizer.eos_token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    token_ids = prompt_token_ids[non_pad_index:].tolist()\n    return token_ids\n\n\nif is_version_ge(pkg=\"vllm\", minver=\"0.7.3\"):\n    VLLMHijack.hijack()\n\n\nclass vLLMRollout(BaseRollout):\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        super().__init__(config, model_config, device_mesh)\n\n        if config.layered_summon:\n            self.sleep_level = 1\n        else:\n            self.sleep_level = VLLM_SLEEP_LEVEL\n\n        model_path = model_config.local_path\n        tokenizer = model_config.tokenizer\n        model_hf_config = model_config.hf_config\n        trust_remote_code = model_config.trust_remote_code\n\n        lora_adapter_path = getattr(model_config, \"lora_adapter_path\", None)\n        if lora_adapter_path is not None:\n            lora_rank = get_lora_rank_from_adapter(lora_adapter_path)\n        else:\n            lora_rank = model_config.lora_rank\n\n        self.lora_kwargs = (\n            {\"enable_lora\": True, \"max_loras\": 1, \"max_lora_rank\": get_vllm_max_lora_rank(lora_rank)}\n            if model_config.lora_rank > 0\n            else {}\n        )\n\n        tensor_parallel_size = self.config.get(\"tensor_model_parallel_size\", 1)\n        assert tensor_parallel_size <= torch.distributed.get_world_size(), (\n            \"tensor parallel size should be less than or equal to the world size\"\n        )\n        max_num_batched_tokens = self.config.get(\"max_num_batched_tokens\", 8192)\n\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            max_position_embeddings = None\n            if hasattr(model_hf_config, \"max_position_embeddings\"):\n                max_position_embeddings = model_hf_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"llm_config\") and hasattr(\n                model_hf_config.llm_config, \"max_position_embeddings\"\n            ):\n                max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"text_config\") and hasattr(\n                model_hf_config.text_config, \"max_position_embeddings\"\n            ):\n                max_position_embeddings = model_hf_config.text_config.max_position_embeddings\n            if max_position_embeddings is None:\n                raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n            assert max_position_embeddings >= config.prompt_length + config.response_length, (\n                \"model context length should be at least the total sequence length\"\n            )\n        else:\n            # handle the case where there is a length extension factor\n            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support\n            # for using yarn as an example\n            
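# Illustrative example (not from this repo, following the Qwen docs linked above): a YaRN\n            # config such as rope_scaling = {\"rope_type\": \"yarn\", \"factor\": 4.0,\n            # \"original_max_position_embeddings\": 32768} extends the usable context to roughly\n            # factor * original_max_position_embeddings tokens, which is what the assert below checks.\n            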
rope_scaling_factor = rope_scaling_config.get(\"factor\", 1.0)\n\n            assert (\n                model_hf_config.max_position_embeddings * rope_scaling_factor\n                >= config.prompt_length + config.response_length\n            ), (\n                \"model context length should be at least the total sequence length, \"\n                + f\"got rope_scaling_factor={rope_scaling_factor} and \"\n                + f\"max_position_embeddings={model_hf_config.max_position_embeddings}\"\n            )\n\n        max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)\n\n        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:\n            raise ValueError(\n                \"Chunked prefill is enabled but max_num_batched_tokens is smaller than \"\n                \"max_model_len; please increase max_num_batched_tokens or disable chunked prefill\"\n            )\n\n        load_format = \"dummy\" if config.load_format.startswith(\"dummy\") else config.load_format\n\n        # copy it to avoid silently modifying the engine config\n        engine_kwargs = config.get(\"engine_kwargs\", {}).get(\"vllm\", {}) or {}\n\n        # For each vLLM engine parameter,\n        # - `None` means not setting it, so we pop it, and leave it to vLLM default value\n        #    (which can vary across different vLLM versions);\n        # - Otherwise it's the desired value we want to explicitly set.\n        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}\n        if config.get(\"limit_images\", None):  # support for multi-image data\n            engine_kwargs[\"limit_mm_per_prompt\"] = {\"image\": config.get(\"limit_images\")}\n\n        compilation_config = {}\n\n        cudagraph_capture_sizes = config.get(\"cudagraph_capture_sizes\")\n        # enforce_eager must be False to use cudagraph\n        if not config.enforce_eager and cudagraph_capture_sizes:\n            if isinstance(cudagraph_capture_sizes, ListConfig):\n                compilation_args = {\"cudagraph_capture_sizes\": cudagraph_capture_sizes}\n                if _use_compilation_mode:\n                    compilation_args[\"mode\"] = CompilationMode.VLLM_COMPILE\n                else:\n                    compilation_args[\"level\"] = CompilationLevel.PIECEWISE\n                compilation_config[\"compilation_config\"] = CompilationConfig(**compilation_args)\n            else:\n                logger.warning(f\"cudagraph_capture_sizes must be a list, but got {cudagraph_capture_sizes}\")\n\n        self.inference_engine = LLM(\n            model=model_path,\n            enable_sleep_mode=config.free_cache_engine,\n            tensor_parallel_size=tensor_parallel_size,\n            distributed_executor_backend=\"external_launcher\",\n            dtype=config.dtype,\n            enforce_eager=config.enforce_eager,\n            gpu_memory_utilization=config.gpu_memory_utilization,\n            disable_custom_all_reduce=True,\n            skip_tokenizer_init=False,\n            max_model_len=max_model_len,\n            max_num_seqs=config.max_num_seqs,\n            load_format=load_format,\n            disable_log_stats=config.disable_log_stats,\n            max_num_batched_tokens=max_num_batched_tokens,\n            enable_chunked_prefill=config.enable_chunked_prefill,\n            enable_prefix_caching=config.enable_prefix_caching,\n            trust_remote_code=trust_remote_code,\n            seed=config.get(\"seed\", 0),\n            
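# Note (editorial): compilation_config and lora_kwargs are empty dicts unless\n            # cudagraph capture sizes / LoRA are configured above, so by default these\n            # unpackings add no extra engine arguments.\n            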
**compilation_config,\n            **self.lora_kwargs,\n            **engine_kwargs,\n        )\n\n        kwargs = dict(\n            n=1,\n            logprobs=0,  # can be set to 0 and let the actor recompute\n            max_tokens=config.response_length,\n            repetition_penalty=config.get(\"repetition_penalty\", 1.0),\n        )\n\n        kwargs[\"detokenize\"] = False\n\n        # support adding any sampling params from the config file\n        for k in config.keys():\n            if hasattr(SamplingParams(), str(k)) and k != \"seed\":\n                kwargs[k] = config.get(k)\n        kwargs[\"n\"] = 1  # already repeated in ray_trainer\n        print(f\"kwargs: {kwargs}\")\n        self.sampling_params = SamplingParams(**kwargs)\n\n        self.pad_token_id = tokenizer.pad_token_id\n\n    @contextmanager\n    def update_sampling_params(self, **kwargs):\n        # update sampling params\n        old_sampling_params_args = {}\n        if kwargs:\n            for key, value in kwargs.items():\n                if hasattr(self.sampling_params, key):\n                    old_value = getattr(self.sampling_params, key)\n                    old_sampling_params_args[key] = old_value\n                    setattr(self.sampling_params, key, value)\n        yield\n        # roll back to previous sampling params\n        for key, value in old_sampling_params_args.items():\n            setattr(self.sampling_params, key, value)\n\n    @GPUMemoryLogger(role=\"vllm rollout spmd\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generate sequences for a batch of prompts.\n\n        Args:\n            prompts (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        # left-padded attention_mask\n        attention_mask = prompts.batch[\"attention_mask\"]\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n\n        non_tensor_batch = prompts.non_tensor_batch\n        if \"raw_prompt_ids\" not in non_tensor_batch:\n            non_tensor_batch[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object\n            )\n\n        if batch_size != 
len(non_tensor_batch[\"raw_prompt_ids\"]):\n            raise RuntimeError(\"vllm sharding manager is not work properly.\")\n\n        if \"multi_modal_data\" in non_tensor_batch:\n            vllm_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                non_tensor_batch.pop(\"raw_prompt_ids\"), non_tensor_batch.pop(\"multi_modal_data\"), strict=True\n            ):\n                vllm_inputs.append({\"prompt_token_ids\": raw_prompt_ids, \"multi_modal_data\": multi_modal_data})\n        else:\n            vllm_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop(\"raw_prompt_ids\")\n            ]\n\n        for input_data in vllm_inputs:\n            # Ensure token IDs are lists or numpy arrays\n            if not isinstance(input_data[\"prompt_token_ids\"], list | np.ndarray):\n                raise TypeError(\n                    f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\"\n                )\n\n            input_data[\"prompt_token_ids\"] = list(input_data[\"prompt_token_ids\"])\n\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n        if not do_sample:\n            kwargs = {\n                \"best_of\": 1,\n                \"top_p\": 1.0,\n                \"top_k\": -1,\n                \"min_p\": 0.0,\n                \"temperature\": 0,\n                \"n\": 1,  # if greedy, only 1 response\n            }\n        elif is_validate:\n            # TODO: try **\n            kwargs = {\n                \"top_k\": self.config.val_kwargs.top_k,\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"n\": 1,  # if validate, already repeat in ray_trainer\n            }\n\n        lora_requests = None\n        if self.lora_kwargs:\n            lora_int_ids = list(self.inference_engine.llm_engine.list_loras())\n            if len(lora_int_ids) > 0:\n                lora_int_id = lora_int_ids[0]\n                lora_requests = [\n                    LoRARequest(lora_name=f\"{lora_int_id}\", lora_int_id=lora_int_id, lora_path=\"/simon-stub-path\")\n                ] * batch_size\n\n        # users can customize different sampling_params at different run\n        with self.update_sampling_params(**kwargs):\n            outputs = self.inference_engine.generate(\n                prompts=vllm_inputs,  # because we have already convert it to prompt token id\n                sampling_params=self.sampling_params,\n                lora_request=lora_requests,\n                use_tqdm=False,\n            )\n\n            # TODO(sgm): disable logprob when recompute_log_prob is enable\n            # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)\n\n            response = []\n            rollout_log_probs = []\n            for output in outputs:\n                for sample_id in range(len(output.outputs)):\n                    response_ids = output.outputs[sample_id].token_ids\n                    response.append(response_ids)\n                    if self.config.calculate_log_probs:\n                        curr_log_prob = []\n                        for i, logprob in enumerate(output.outputs[sample_id].logprobs):\n                            curr_log_prob.append(logprob[response_ids[i]].logprob)\n                        rollout_log_probs.append(curr_log_prob)\n\n        
response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(\n                idx.device\n            )\n            if self.config.calculate_log_probs:\n                rollout_log_probs = pad_2d_list_to_length(\n                    rollout_log_probs, -1, max_length=self.config.response_length\n                ).to(idx.device)\n                rollout_log_probs = rollout_log_probs.to(torch.float32)\n\n            seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)\n        if position_ids.dim() == 3:  # qwen2vl mrope (batch size, 4, seq len)\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        # all the tp ranks should contain the same data here. data in all ranks are valid\n        batch = TensorDict(\n            {\n                \"prompts\": idx,\n                \"responses\": response,\n                \"input_ids\": seq,  # here input_ids becomes the whole sequence\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n        if self.config.calculate_log_probs:\n            # we will recompute the old log prob with the actor\n            batch[\"rollout_log_probs\"] = rollout_log_probs\n\n        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n    async def resume(self, tags: list[str]):\n        \"\"\"Resume rollout weights or kv cache in GPU memory.\n\n        Args:\n            tags: weights or kv_cache.\n        \"\"\"\n        if not self.config.free_cache_engine:\n            return\n\n        if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n            self.inference_engine.wake_up(tags=tags)\n        else:\n            self.inference_engine.wake_up()\n\n    async def release(self):\n        \"\"\"Release weights and kv cache in GPU memory.\"\"\"\n        self.inference_engine.reset_prefix_cache()\n\n        if not self.config.free_cache_engine:\n            return\n\n        self.inference_engine.sleep(level=self.sleep_level)\n\n    async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):\n        \"\"\"Update the weights of the rollout model.\n\n        Args:\n            weights: A generator that yields the name of the weight tensor and the tensor itself.\n        \"\"\"\n        peft_config, base_sync_done = kwargs.get(\"peft_config\", None), kwargs.get(\"base_sync_done\", False)\n        if peft_config and base_sync_done:\n            lora_int_id = int(time.time_ns() % 0x7FFFFFFF)\n            # materialize the generator once so it can be both loaded and counted\n            lora_tensors = dict(weights)\n            lora_request = TensorLoRARequest(\n                lora_name=f\"{lora_int_id}\",\n                lora_int_id=lora_int_id,\n                lora_path=\"simon_lora_path\",\n                peft_config=asdict(peft_config),\n                lora_tensors=lora_tensors,\n            )\n            self.inference_engine.llm_engine.add_lora(lora_request)\n            logger.info(f\"vLLM load weights, loaded_params: {len(lora_tensors)}\")\n        else:\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n            patch_vllm_moe_model_weight_loader(model)\n            model.load_weights(weights)\n\n\n# https://github.com/vllm-project/vllm/issues/13175\ndef _monkey_patch_compute_logits(model, vocab_size: int):\n    original_compute_logits = model.compute_logits\n\n    def compute_logits(\n        self,\n        *args,\n        **kwargs,\n    ) -> torch.Tensor:\n        logits = original_compute_logits(*args, **kwargs)\n        logits[..., vocab_size:] = float(\"-inf\")\n        return logits\n\n    model.compute_logits = MethodType(compute_logits, model)\n\n\nclass vLLMAsyncRollout(BaseRollout):\n    \"\"\"vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase, which is the engine in a single worker process.\"\"\"\n\n    def __init__(\n        self,\n        config: RolloutConfig,\n        model_config: HFModelConfig,\n        device_mesh: DeviceMesh,\n    ):\n        super().__init__(config, model_config, device_mesh)\n        self.tokenizer = model_config.tokenizer\n        self.inference_engine: WorkerWrapperBase = None\n        self.address = self._init_zeromq()\n        self.lora_config = (\n            {\"max_loras\": 1, \"max_lora_rank\": get_vllm_max_lora_rank(model_config.lora_rank)}\n            if model_config.lora_rank > 0\n            else {}\n        )\n\n        # https://github.com/vllm-project/vllm/issues/25171\n        if config.layered_summon or config.expert_parallel_size > 1:\n            self.sleep_level = 1\n        else:\n            self.sleep_level = VLLM_SLEEP_LEVEL\n\n    def _init_zeromq(self) -> str:\n        tensor_parallel_size = self.config.tensor_model_parallel_size\n\n        # single node: ipc, multiple nodes: tcp\n        local_world_size = int(os.environ[\"RAY_LOCAL_WORLD_SIZE\"])\n        socket_type = \"ipc\" if tensor_parallel_size <= local_world_size else \"tcp\"\n\n        # File lock to prevent multiple workers from listening on the same port\n        with FileLock(f\"/tmp/verl_vllm_zmq_{getpass.getuser()}.lock\"):\n            context = zmq.asyncio.Context()\n            self.socket = context.socket(zmq.REP)\n            if socket_type == \"ipc\":\n                pid = os.getpid()\n                address = f\"ipc:///tmp/verl_vllm_zmq_{pid}_{getpass.getuser()}.ipc\"\n            else:\n                ip = ray.util.get_node_ip_address().strip(\"[]\")\n                port, sock = get_free_port(ip)\n                if is_valid_ipv6_address(ip):\n                    address = f\"tcp://[{ip}]:{port}\"\n                    self.socket.setsockopt(zmq.IPV6, 1)\n                else:\n                    address = f\"tcp://{ip}:{port}\"\n            self.socket.bind(address)\n\n        loop = asyncio.get_running_loop()\n        self.zmq_loop_task = loop.create_task(self._loop_forever())\n\n        return address\n\n    async def _loop_forever(self):\n        while True:\n            try:\n                message = await self.socket.recv()\n                method, args, kwargs = 
pickle.loads(message)\n                result = await self._execute_method(method, *args, **kwargs)\n                await self.socket.send(pickle.dumps(result))\n            except Exception as e:\n                logger.exception(f\"vLLMAsyncRollout _loop_forever error: {e}\")\n                await self.socket.send(pickle.dumps(e))\n                break\n\n    def _init_worker(self, all_kwargs: list[dict[str, Any]]):\n        \"\"\"Initialize worker engine.\"\"\"\n        if not torch.distributed.is_initialized():\n            initialize_global_process_group_ray()\n        all_kwargs[0][\"rank\"] = int(os.environ[\"RANK\"])\n        device_name = \"NPU\" if is_npu_available else \"GPU\"\n        all_kwargs[0][\"local_rank\"] = (\n            0\n            if not ray_noset_visible_devices()\n            else int(ray.get_runtime_context().get_accelerator_ids()[device_name][0])\n        )\n        self.vllm_config = all_kwargs[0][\"vllm_config\"]\n        if self.lora_config:\n            lora_dtype = getattr(torch, self.config.dtype)\n            self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config)\n        self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)\n        self.inference_engine.init_worker(all_kwargs)\n\n    def _load_model(self, *args, **kwargs):\n        self.inference_engine.load_model(*args, **kwargs)\n        _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))\n\n    async def _execute_method(self, method: str | bytes, *args, **kwargs):\n        if method == \"init_worker\":\n            return self._init_worker(*args, **kwargs)\n        elif method == \"load_model\":\n            return self._load_model(*args, **kwargs)\n        elif method == \"sleep\" or method == \"wake_up\":\n            raise ValueError(\"wake_up and sleep should not be called through ZeroMQ\")\n        else:\n            return self.inference_engine.execute_method(method, *args, **kwargs)\n\n    async def resume(self, tags: list[str]):\n        \"\"\"Resume rollout weights or kv cache in GPU memory.\n\n        Args:\n            tags: weights or kv_cache.\n        \"\"\"\n        if self.config.free_cache_engine:\n            self.inference_engine.wake_up(tags=tags)\n\n    async def release(self):\n        \"\"\"Release weights and kv cache in GPU memory.\"\"\"\n        if self.config.free_cache_engine:\n            self.inference_engine.sleep(level=self.sleep_level)\n\n    async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):\n        \"\"\"Update the weights of the rollout model.\n\n        Args:\n            weights: A generator that yields the name of the weight tensor and the tensor itself.\n        \"\"\"\n        peft_config, base_sync_done = kwargs.get(\"peft_config\", None), kwargs.get(\"base_sync_done\", False)\n        if peft_config and base_sync_done:\n            # In async mode, make sure the old lora is removed before adding the new one\n            self.inference_engine.worker.remove_lora(VLLM_LORA_INT_ID)\n            # materialize the generator once so it can be both loaded and counted\n            lora_tensors = dict(weights)\n            lora_request = TensorLoRARequest(\n                lora_name=VLLM_LORA_NAME,\n                lora_int_id=VLLM_LORA_INT_ID,\n                lora_path=VLLM_LORA_PATH,\n                peft_config=asdict(peft_config),\n                lora_tensors=lora_tensors,\n            )\n            self.inference_engine.worker.add_lora(lora_request)\n            logger.info(f\"vLLM load weights, loaded_params: {len(lora_tensors)}\")\n     
   else:\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            model = self.inference_engine.worker.model_runner.model\n            patch_vllm_moe_model_weight_loader(model)\n            model.load_weights(weights)\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Batch generate sequences in sync mode.\"\"\"\n        raise NotImplementedError\n\n    # ==================== server mode public methods ====================\n\n    def get_zeromq_address(self):\n        return self.address\n"
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSharding manager to implement HybridEngine\n\"\"\"\n\nfrom verl import DataProto\n\n\nclass BaseShardingManager:\n    def __init__(self):\n        self.timing = {}\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        pass\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        return data\n"
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/fsdp_sglang.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport logging\nimport os\n\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.weight_sync.utils import update_weights as sgl_update_weights\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.utils.device import get_device_id, get_torch_device, set_expandable_segments\nfrom verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu\nfrom verl.utils.import_utils import deprecated\nfrom verl.utils.memory_utils import aggressive_empty_cache\nfrom verl.utils.model import convert_weight_keys\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\nfrom .base import BaseShardingManager\n\n# from vllm.distributed import parallel_state as sglang_ps\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@deprecated()\nclass FSDPSGLangShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(\n        self,\n        module: FSDP,\n        inference_engine: Engine,\n        model_config,\n        rollout_config,\n        full_params: bool = False,\n        device_mesh: DeviceMesh = None,\n        offload_param: bool = False,\n        multi_stage_wake_up: bool = False,\n    ):\n        self.module = module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.device_mesh = device_mesh\n        self.offload_param = offload_param\n        self.multi_stage_wake_up = multi_stage_wake_up\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()\n            )\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = 
get_torch_device().get_rng_state()\n        # get a random rng states\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager enter\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.wake_up())\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager exit\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.sleep())\n\n    async def update_weights(self, params):\n        named_tensors = [(k, v) for k, v in params.items()]\n        update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20\n        for params_batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):\n            await sgl_update_weights(\n                engine=self.inference_engine,\n                params_batch=params_batch,\n                device_mesh_key=\"infer_tp\",\n                device_mesh=self.device_mesh,\n            )\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self.inference_engine.flush_cache()\n\n    async def release_memory(self):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            if self.multi_stage_wake_up:\n                await self.inference_engine.release_memory_occupation(tags=[\"kv_cache\", \"weights\"])\n            else:\n                await self.inference_engine.release_memory_occupation()\n            log_gpu_memory_usage(\"After release memory occupation in sharding manager\", logger=logger)\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager enter\", logger=logger)\n    async def wake_up(self):\n        aggressive_empty_cache(force_sync=True)\n\n        log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n        params = self.module.state_dict()\n        log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n        device = get_device_id()  # used when fsdp2 set cpu_offload_policy\n        params = {\n            k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()\n        }\n\n        # convert weight keys to match the model config\n        params = convert_weight_keys(params, getattr(self.module, \"_fsdp_wrapped_module\", self.module))\n\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n\n        log_gpu_memory_usage(\"After offload_param in sharding manager memory\", logger=logger)\n\n        # sglang need to set _set_allocator_settings to False\n        logger.debug(\"fsdp sglang sharding_manager _set_allocator_settings to False\")\n        # Note(chenyang): SGLang is using torch memory pool to manage memory\n        # which is incompatible with expandable segments\n        
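# Editor's note (illustrative): \"expandable segments\" is the PyTorch CUDA\n        # caching-allocator option PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True.\n        # It is disabled here while SGLang manages GPU memory through its own pool,\n        # and re-enabled in sleep() below.\n        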
set_expandable_segments(False)\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            if self.multi_stage_wake_up:\n                await self.inference_engine.resume_memory_occupation(tags=[\"weights\"])\n                log_gpu_memory_usage(\"Before resume SGLang weights in sharding manager\", logger=logger)\n            else:\n                await self.inference_engine.resume_memory_occupation()\n                log_gpu_memory_usage(\"Before resume SGLang weights + kv_cache in sharding manager\", logger=logger)\n\n        # Copy, not share memory\n        await self.update_weights(params)\n        log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n\n        del params\n        aggressive_empty_cache(force_sync=True)\n        log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n        if (\n            self.multi_stage_wake_up\n            and self.rollout_config.free_cache_engine\n            and self.device_mesh[\"infer_tp\"].get_local_rank() == 0\n        ):\n            await self.inference_engine.resume_memory_occupation(tags=[\"kv_cache\"])\n            log_gpu_memory_usage(\"After resume SGLang kv_cache in sharding manager\", logger=logger)\n\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager exit\", logger=logger)\n    async def sleep(self):\n        if self.rollout_config.free_cache_engine:\n            log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n            await self.release_memory()\n            log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        aggressive_empty_cache(force_sync=True)\n\n        # always set _set_allocator_settings to True when using sglang\n        # it is required by fsdp2 to avoid oom\n        logger.debug(\"fsdp sglang sharding_manager _set_allocator_settings to True\")\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = self.device_mesh[\"infer_tp\"].get_group()\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/fsdp_ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains a resharding manager that binds weights from FSDP zero3 to XPerfGPT\n\"\"\"\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group\n\nfrom .base import BaseShardingManager\n\n\nclass FSDPUlyssesShardingManager(BaseShardingManager):\n    \"\"\"\n    Sharding manager to support data resharding when using FSDP + Ulysses\n    \"\"\"\n\n    def __init__(self, device_mesh: DeviceMesh):\n        super().__init__()\n        self.device_mesh = device_mesh\n        self.seed_offset = 12345\n\n    def __enter__(self):\n        if self.device_mesh is not None:\n            # We have a global SP group\n            # so we have to change to use model-specific sp group\n            self.prev_sp_group = get_ulysses_sequence_parallel_group()\n            set_ulysses_sequence_parallel_group(self.device_mesh[\"sp\"].get_group())\n            # TODO: check how to set seed for each model\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        # restore random states\n        if self.device_mesh is not None:\n            # revert to previous sp group\n            set_ulysses_sequence_parallel_group(self.prev_sp_group)\n            # TODO: check how to set seed for each model\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"\n        AllGather data from sp region\n        This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE\n        In Ulysses, we need to make sure the same data is used across a SP group\n        \"\"\"\n        if self.device_mesh is not None:\n            group = self.device_mesh[\"sp\"].get_group()\n\n            all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"\n        Split the data to follow FSDP partition\n        \"\"\"\n        if self.device_mesh is not None:\n            sp_size = self.device_mesh[\"sp\"].size()\n            sp_rank = self.device_mesh[\"sp\"].get_local_rank()\n            data = data.chunk(chunks=sp_size)[sp_rank]\n        return data\n"
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/fsdp_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport logging\nimport os\nimport time\nfrom collections import OrderedDict\n\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom dataclasses import asdict\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device, set_expandable_segments\nfrom verl.utils.fsdp_utils import (\n    fsdp_version,\n    layered_summon_lora_params,\n    load_fsdp_model_to_gpu,\n    offload_fsdp_model_to_cpu,\n)\nfrom verl.utils.import_utils import deprecated\nfrom verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@deprecated()\nclass FSDPVLLMShardingManager(BaseShardingManager):\n    \"\"\"Sharding manager for FSDP models with vLLM inference engine integration.\n\n    Manages parameter synchronization between FSDP training models and vLLM\n    inference engines, handling both full parameters and LoRA adapters with\n    efficient memory management and device placement.\n    \"\"\"\n\n    @check_device_is_available()\n    def __init__(\n        self,\n        module: FSDP,\n        inference_engine: LLM,\n        model_config,\n        rollout_config,\n        full_params: bool = False,\n        device_mesh: DeviceMesh = None,\n        offload_param: bool = False,\n        load_format: str = \"dummy_hf\",\n        layered_summon: bool = True,\n    ):\n        self.module = module\n        # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model\n        self.inference_engine = inference_engine\n        # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if\n        # inference_engine else None\n\n        self.model_runner = (\n            self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner\n            if self.inference_engine\n            else None\n        )\n\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.device_mesh = 
device_mesh\n        self.offload_param = offload_param\n        self.load_format = load_format\n        self.layered_summon = layered_summon\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()\n            )\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a random rng state\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n        self.base_sync_done: bool = \"dummy\" not in load_format\n        if is_version_ge(pkg=\"vllm\", minver=\"0.7.3\"):\n            VLLMHijack.hijack()\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        def __collect_lora_params() -> OrderedDict:\n            \"\"\"\n            Collect LoRA params, or full params if the base model is not yet loaded in vLLM.\n            Works when isinstance(self.module._fsdp_wrapped_module, PeftModel).\n            \"\"\"\n            from peft.utils.save_and_load import get_peft_model_state_dict\n\n            lora_params = OrderedDict()\n            peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n            if fsdp_version(self.module) > 0:\n                if self.layered_summon:\n                    if not self.base_sync_done:\n                        raise ValueError(\n                            \"To use layered_summon, you must make sure the base model is preloaded in vLLM, \"\n                            \"e.g. set rollout.load_format=safetensors\"\n                        )\n                    lora_params = layered_summon_lora_params(self.module)\n                else:\n                    with FSDP.summon_full_params(self.module, writeback=False):\n                        if self.base_sync_done:\n                            lora_params = get_peft_model_state_dict(peft_model)\n                            lora_params = {\n                                name: param.full_tensor().detach().cpu()\n                                if hasattr(param, \"full_tensor\")\n                                else param.detach().cpu()\n                                for name, param in lora_params.items()\n                            }\n                        else:\n                            model = peft_model.base_model.model\n                            orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n                            model = model.to(\"cpu\")\n                            for name, param in model.state_dict().items():\n                                if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                                    continue\n                                name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                                lora_params[name] = (\n                                    param.full_tensor().detach().cpu()\n                                    if hasattr(param, \"full_tensor\")\n                                    else param.detach().cpu()\n                                )\n                            model = model.to(orig_dev)\n                    get_torch_device().empty_cache()\n            else:\n                if self.base_sync_done:\n                    lora_params = get_peft_model_state_dict(peft_model)\n                else:\n                    model = peft_model.base_model.model\n                    orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n                    model = model.to(\"cpu\")\n                    for name, param in model.state_dict().items():\n                        if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                            continue\n                        name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                        lora_params[name] = param.detach().cpu()\n                    model = model.to(orig_dev)\n            return lora_params\n\n        # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and\n        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.\n        # Out of vllm scope, we should avoid emptying the cache to let pytorch use caching memory\n        # to speed up memory allocations.\n        #\n        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management\n        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            get_torch_device().empty_cache()\n\n            log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n            if self.offload_param:\n                load_fsdp_model_to_gpu(self.module)\n\n            peft_config = None\n            peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n            
          if hasattr(peft_model, \"peft_config\"):\n                peft_config = peft_model.peft_config.get(\"default\", None)\n                params = __collect_lora_params()\n            else:\n                params = self.module.state_dict()\n            params = convert_weight_keys(params, getattr(self.module, \"_fsdp_wrapped_module\", self.module))\n\n            if self.offload_param:\n                offload_fsdp_model_to_cpu(self.module)\n            log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n\n            # vllm need to set _set_allocator_settings to False\n            logger.debug(\"fsdp vllm sharding_manager _set_allocator_settings to False\")\n            set_expandable_segments(False)\n\n            if self.rollout_config.free_cache_engine:\n                if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n                    self.inference_engine.wake_up(tags=[\"weights\"])\n                else:\n                    self.inference_engine.wake_up()\n\n            # update model params\n            self.update_params(params, peft_config=peft_config)\n            log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n            del params\n            get_torch_device().empty_cache()\n\n            if (\n                self.rollout_config.free_cache_engine\n                and \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters\n            ):\n                self.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n            log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n            # important: need to manually set the random states of each tp to be identical.\n            if self.device_mesh is not None:\n                self.torch_random_states = get_torch_device().get_rng_state()\n                get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.rollout_config.free_cache_engine:\n            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        # _set_allocator_settings to True is required by fsdp2 to avoid oom\n        logger.debug(\"fsdp vllm sharding_manager _set_allocator_settings to True\")\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n   
        return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n\n    def update_params(self, updated_params, peft_config=None):\n        \"\"\"Update model parameters in the vLLM inference engine.\n\n        Synchronizes parameters from the FSDP training model to the vLLM inference\n        engine, handling both full model parameters and LoRA adapters with proper\n        device placement and memory management.\n\n        Args:\n            updated_params (dict): Dictionary of parameter names to tensor values.\n            peft_config (optional): PEFT configuration for LoRA adapters.\n        \"\"\"\n        model = self.model_runner.model\n        if peft_config:\n            if self.base_sync_done:\n                lora_int_id = int(time.time_ns() % 0x7FFFFFFF)\n                lora_request = TensorLoRARequest(\n                    lora_name=f\"{lora_int_id}\",\n                    lora_int_id=lora_int_id,\n                    lora_path=\"simon_lora_path\",\n                    peft_config=asdict(peft_config),\n                    lora_tensors=updated_params,\n                )\n                self.inference_engine.llm_engine.add_lora(lora_request)\n                logger.info(f\"vLLM load weights, loaded_params: {len(updated_params)}\")\n                return\n            else:\n\n                def replace_lora_wrapper(k):\n                    \"\"\"Replace LoRA parameter keys with base layer equivalents.\n\n                    Transforms LoRA parameter names to their corresponding base layer\n                    names for proper weight loading in vLLM when base model sync is not done.\n\n                    Args:\n                        k (str): Original parameter key name.\n\n                    Returns:\n                        str: Transformed parameter key for base layer.\n                    \"\"\"\n                    stacked_params = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n                    if k.endswith(\".weight\"):\n                        module_k = k[: -len(\".weight\")]\n                        if check_exclude_modules(peft_config, module_k):\n                            return k\n                        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(\n                            peft_config, module_k\n                        ):\n                            return f\"{module_k}.base_layer.weight\"\n                    if k.endswith(\".bias\"):\n                        module_k = k[: -len(\".bias\")]\n                        if check_exclude_modules(peft_config, module_k):\n                            return k\n                        elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(\n                            peft_config, module_k\n                        ):\n                            return f\"{module_k}.base_layer.bias\"\n                    return k\n\n                updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}\n\n        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n        patch_vllm_moe_model_weight_loader(model)\n        device = get_device_id()  # used when fsdp2 set cpu_offload_policy\n        loaded_params = model.load_weights(\n            (\n                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)\n                for name, param in updated_params.items()\n            )\n        )\n\n        
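# Once the base weights have been loaded into the engine, later syncs that carry a\n        # peft_config can take the add_lora fast path above and ship only the LoRA tensors.\n        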
self.base_sync_done = True\n        logger.info(f\"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}\")\n
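\n\n# ---------------------------------------------------------------------------\n# Illustrative usage sketch (not part of verl's public API; `sharding_manager`\n# and `rollout` are hypothetical names): the sharding manager is a context\n# manager that syncs FSDP weights into vLLM on entry and puts the engine back\n# to sleep on exit.\n#\n#     with sharding_manager:                                 # __enter__: FSDP -> vLLM weight sync\n#         batch = sharding_manager.preprocess_data(batch)    # all-gather input across tp ranks\n#         output = rollout.generate_sequences(batch)         # hypothetical rollout call\n#         output = sharding_manager.postprocess_data(output) # keep this tp rank's chunk\n#     # __exit__: vLLM sleeps, caches are emptied, module returns to train mode\n"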
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/megatron_sglang.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\n\nfrom omegaconf import DictConfig\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.weight_sync.utils import update_weights as sgl_update_weights\nfrom torch import nn\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl.protocol import DataProto, all_gather_data_proto\nfrom verl.utils.device import get_torch_device, set_expandable_segments\nfrom verl.utils.import_utils import deprecated\nfrom verl.utils.megatron_utils import (\n    load_megatron_model_to_gpu,\n    offload_megatron_model_to_cpu,\n    per_tensor_generator,\n)\nfrom verl.utils.memory_utils import aggressive_empty_cache\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all \n  the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\n\n@deprecated()\nclass MegatronSGLangShardingManager(BaseShardingManager):\n    \"\"\"A sharding manager for Megatron-style training & inference with SGLang.\n\n    This class manages the sharding of model parameters between training and inference\n    phases in a Megatron-style parallel setup. 
It handles:\n    - Loading/offloading parameters between CPU/GPU\n    - Updating inference engine weights\n    - Managing random states for reproducibility\n    - Data preprocessing for distributed inference\n\n    Args:\n        actor_module (nn.ModuleList): The actor model modules\n        inference_engine (Engine): The SGLang inference engine\n        model_config: Configuration for the actor's model\n        rollout_config: Configuration for rollout generation\n        transformer_config: Transformer-specific configuration\n        layer_name_mapping: Mapping between layer names and parameters\n        weight_converter: Utility for converting weights between formats\n        device_mesh (DeviceMesh | None): PyTorch device mesh for distributed training\n        offload_param (bool): Whether to offload parameters to CPU when not in use\n    \"\"\"\n\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: Engine,\n        model_config: DictConfig,\n        rollout_config: DictConfig,\n        transformer_config,\n        layer_name_mapping,\n        weight_converter,\n        device_mesh: DeviceMesh | None = None,\n        offload_param: bool = False,\n        bridge=None,\n    ):\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.transformer_config = transformer_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.device_mesh = device_mesh\n        self.bridge = bridge\n        self.offload_param = offload_param\n\n        if self.device_mesh is not None:\n            self.infer_tp_size = self.device_mesh[\"infer_tp\"].mesh.size()[0]\n        else:\n            self.infer_tp_size = self.inference_engine._tp_size\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a random rng states\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager enter\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.wake_up())\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager exit\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.sleep())\n\n    async def update_weights(self, params):\n        \"\"\"\n        Update model weights using tensor buckets, similar to THUDM/slime's implementation.\n\n        Notes:\n          - For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n              1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.\n              2. 
Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n            when using Tensor Parallelism (TP >= 8).\n          - See reference implementations in SLIME:\n            - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452\n            - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39\n        \"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.resume_memory_occupation()\n        named_tensors = params\n\n        update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20\n        for params_batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):\n            await sgl_update_weights(\n                engine=self.inference_engine,\n                params_batch=params_batch,\n                device_mesh_key=\"infer_tp\",\n                device_mesh=self.device_mesh,\n            )\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self.inference_engine.flush_cache()\n\n    async def release_memory(self):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.release_memory_occupation()\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager enter\", logger=logger)\n    async def wake_up(self):\n        aggressive_empty_cache(force_sync=True)\n\n        if self.offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.actor_module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.actor_module,\n                self.model_config,\n                self.weight_converter,\n                self.transformer_config,\n                self.layer_name_mapping,\n            )\n\n        set_expandable_segments(False)\n\n        await self.update_weights(per_tensor_param)\n        if self.offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        aggressive_empty_cache(force_sync=True)\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager exit\", logger=logger)\n    async def sleep(self):\n        if self.rollout_config.free_cache_engine:\n            log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n            await self.release_memory()\n            log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        for model in self.actor_module:\n            model.train()\n        # add empty cache after each compute\n        aggressive_empty_cache(force_sync=True)\n\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"megatron sglang sharding_manager\", logger=logger)\n    def preprocess_data(self, 
data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        all_gather_data_proto(data, self.device_mesh[\"infer_tp\"].get_group())\n        return data\n\n    @GPUMemoryLogger(role=\"megatron sglang sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh[\"infer_tp\"].get_local_rank()]\n
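\n\n# ---------------------------------------------------------------------------\n# Illustrative sketch of the size-capped bucketing that update_weights above\n# relies on. The real helper is get_named_tensor_buckets (imported at the top\n# of this file); this hypothetical version only shows the idea of grouping\n# named tensors into batches of at most max_bytes each:\n#\n#     def bucket_named_tensors(named_tensors, max_bytes):\n#         bucket, size = [], 0\n#         for name, tensor in named_tensors:\n#             nbytes = tensor.element_size() * tensor.numel()\n#             if bucket and size + nbytes > max_bytes:\n#                 yield bucket\n#                 bucket, size = [], 0\n#             bucket.append((name, tensor))\n#             size += nbytes\n#         if bucket:\n#             yield bucket\n#\n# update_weights_bucket_megabytes is converted to bytes with `<< 20` before\n# the buckets are streamed to the SGLang engine one batch at a time.\n"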
  },
  {
    "path": "verl_distillation/verl/workers/sharding_manager/megatron_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport inspect\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom omegaconf import DictConfig\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.models.mcore.weight_converter import McoreToHFWeightConverterBase\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.device import get_torch_device, set_expandable_segments\nfrom verl.utils.import_utils import deprecated\nfrom verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator\nfrom verl.utils.memory_utils import aggressive_empty_cache\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage\nfrom verl.utils.profiler.performance import simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank \n   to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. 
pp is treated as additional dp\n- After inference, all the parameters that don't belong to this pp rank are freed.\n\"\"\"\n\n\n@deprecated()\nclass MegatronVLLMShardingManager(BaseShardingManager):\n    \"\"\"A sharding manager that bridges Megatron-LM training with vLLM inference.\n\n    This class handles the parameter sharding and communication between:\n    - Megatron-LM's tensor/expert parallel training setup\n    - vLLM's tensor parallel inference setup\n\n    Key responsibilities:\n    - Manages parameter broadcasting between training and inference configurations\n    - Handles weight conversion between Megatron and HuggingFace formats\n    - Coordinates memory management between training and inference phases\n    - Maintains random state consistency across different parallel groups\n\n    Args:\n        actor_module (nn.ModuleList): The Megatron-LM model being trained\n        inference_engine (LLM): The vLLM inference engine\n        model_config: Configuration for the actor's model\n        transformer_config: Transformer-specific configuration for the model\n        rollout_config: Configuration for rollout\n        layer_name_mapping: Mapping between Megatron and HF layer names\n        weight_converter (McoreToHFWeightConverterBase): Converts weights between formats\n        device_mesh: Device mesh for parallel operations\n        offload_param (bool): Whether to offload parameters when not in use\n    \"\"\"\n\n    @check_device_is_available()\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: LLM,\n        model_config: DictConfig,\n        transformer_config,\n        rollout_config: DictConfig,\n        layer_name_mapping,\n        weight_converter: McoreToHFWeightConverterBase,\n        device_mesh,\n        offload_param: bool = True,\n        bridge=None,\n    ):\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.offload_param = offload_param\n\n        # For AsyncLLM, inference_engine and model_runner are lazily initialized in vLLMAsyncRollout.load_model\n        self.model_runner = (\n            self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner\n            if self.inference_engine\n            else None\n        )\n\n        self.model_config = model_config\n        self.transformer_config = transformer_config\n        self.rollout_config = rollout_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.bridge = bridge\n        # initialize groups for vllm inference\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n\n        self.device_mesh = device_mesh\n        self.infer_tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.infer_tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        self.train_tp_size = mpu.get_tensor_model_parallel_world_size()\n        self.train_tp_rank = mpu.get_tensor_model_parallel_rank()\n        self.train_tp_group = mpu.get_tensor_model_parallel_group()\n        self.train_ep_size = mpu.get_expert_model_parallel_world_size()\n        self.train_ep_rank = mpu.get_expert_model_parallel_rank()\n        self.train_ep_group = mpu.get_expert_model_parallel_group()\n        self.train_etp_size = mpu.get_expert_tensor_parallel_world_size()\n        self.train_etp_rank = mpu.get_expert_tensor_parallel_rank()\n        self.train_etp_group = 
mpu.get_expert_tensor_parallel_group()\n        self.need_tp_reshard = self.train_tp_size != self.infer_tp_size\n        self.train_tp_larger = self.train_tp_size > self.infer_tp_size\n\n        self.torch_random_states = get_torch_device().get_rng_state()\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            aggressive_empty_cache(force_sync=True)\n\n            log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n            if self.offload_param:\n                load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n\n            set_expandable_segments(False)\n\n            if self.rollout_config.free_cache_engine:\n                if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n                    self.inference_engine.wake_up(tags=[\"weights\"])\n                else:\n                    self.inference_engine.wake_up()\n            if self.bridge is not None:\n                per_tensor_param = self.bridge.export_weights(self.actor_module)\n            else:\n                per_tensor_param = per_tensor_generator(\n                    self.actor_module,\n                    self.model_config,\n                    self.weight_converter,\n                    self.transformer_config,\n                    self.layer_name_mapping,\n                )\n            model = self.model_runner.model\n            from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader\n\n            patch_vllm_moe_model_weight_loader(model)\n            loaded_params = model.load_weights(per_tensor_param)\n            info = f\"vLLM load weights, loaded_params: {len(loaded_params)}\"\n            logger.info(info)\n\n            if self.offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n            aggressive_empty_cache(force_sync=True)\n\n            if (\n                self.rollout_config.free_cache_engine\n                and \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters\n            ):\n                self.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n            # important: need to manually set the random states of each tp to be identical.\n            if self.device_mesh is not None:\n                self.torch_random_states = get_torch_device().get_rng_state()\n                get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.rollout_config.free_cache_engine:\n            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)\n        for model in self.actor_module:\n            model.train()\n\n        aggressive_empty_cache(force_sync=True)\n\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n  
          get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        return data.chunk(chunks=self.infer_tp_size)[self.infer_tp_rank]\n
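\n\n# ---------------------------------------------------------------------------\n# Illustrative summary of the staged wake-up protocol in __enter__ above\n# (comments only; the control flow is already implemented by the class):\n#\n#     1. wake_up(tags=[\"weights\"])            # newer vLLM: restore weight buffers only\n#     2. model.load_weights(per_tensor_param) # stream Megatron params into the engine\n#     3. offload_megatron_model_to_cpu(...)   # optionally free the training copy\n#     4. wake_up(tags=[\"kv_cache\"])           # re-allocate the KV cache after weights land\n#\n# The inspect.signature checks exist because older vLLM versions expose a\n# wake_up() without the `tags` parameter; in that case everything is woken in\n# a single step before the weight sync.\n"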
  },
  {
    "path": "verl_rl/CONTRIBUTING.md",
    "content": "# Contributing to verl\n\nThank you for considering a contribution to verl! We welcome contributions of any kind - bug fixes, enhancements, documentation improvements, or even just feedback. Whether you're an experienced developer or this is your first open-source project, your help is invaluable.\n\nYour support can take many forms:\n- Report issues or unexpected behaviors.\n- Suggest or implement new features.\n- Improve or expand documentation.\n- Review pull requests and assist other contributors.\n- Spread the word: share verl in blog posts, social media, or give the repo a ⭐.\n\n## Finding Issues to Contribute\n\nLooking for ways to dive in? Check out these issues:\n- [Good first issues](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22)\n- [Call for contribution](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22call%20for%20contribution%22)\nFurthermore, you can learn the development plan and roadmap via [RFC](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3ARFC) and [Roadmap](https://github.com/volcengine/verl/issues?q=state%3Aopen%20label%3A%22roadmap%22).\n\n\n## Developing\n\n- **Python-only**: install verl via `pip install -e .[test,vllm]` or `pip install -e .[test,sglang]` and iterate quickly. For full dependency setup, check out the verl [installation doc](https://verl.readthedocs.io/en/latest/start/install.html).\n\n## Code Linting and Formatting\n\nWe rely on pre-commit to keep our code consistent. To set it up:\n\n```bash\npip install pre-commit\npre-commit install\n# for staged changes\npre-commit run\n# for all files in the repo\npre-commit run --all-files\n# run a specific hook with pre-commit\n# pre-commit run --all-files --show-diff-on-failure --color=always <hood-id>\npre-commit run --all-files --show-diff-on-failure --color=always ruff\npre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg\n```\n\n## Testing\n\nOur test suites run on GitHub Actions. Check these workflows for details:\n- [GPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/gpu_unit_tests.yml)\n- [CPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/cpu_unit_tests.yml)\n- [vLLM tests](https://github.com/volcengine/verl/blob/main/.github/workflows/vllm.yml)\n- [SGLang tests](https://github.com/volcengine/verl/blob/main/.github/workflows/sgl.yml)\n\n### Adding CI tests\n\nIf possible, please add CI test(s) for your new feature:\n\n1. Find the most relevant workflow yml file, which usually corresponds to a `hydra` default config (e.g. `ppo_trainer`, `ppo_megatron_trainer`, `sft_trainer`, etc).\n2. Add related path patterns to the `paths` section if not already included.\n3. Minimize the workload of the test script(s) (see existing scripts for examples).\n\n## Building the Docs\n```\n# Ensure verl is on your PYTHONPATH, e.g.:\npip install -e .[test]\n\n# Install documentation dependencies\npip install -r requirements-docs.txt\n\n# Generate HTML docs\nmake clean\nmake html\n\n# Preview locally\npython -m http.server -d _build/html/\n```\nOpen your browser at http://localhost:8000 to explore the docs.\n\n## Pull Requests & Code Reviews\n\nThanks for submitting a PR! 
To streamline reviews:\n- Follow our Pull Request Template for title format and checklist.\n- Adhere to our pre-commit lint rules and ensure all checks pass.\n- Update docs for any user-facing changes.\n- Add or update tests in the CI workflows, or explain why tests aren't applicable.\n\n## License\n\nSee the [LICENSE](https://github.com/volcengine/verl/blob/main/LICENSE) file for full details.\n\n## Thank You\n\nWe appreciate your contributions to verl. Your efforts help make the project stronger and more user-friendly. Happy coding!\n\n"
  },
  {
    "path": "verl_rl/LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "verl_rl/README.md",
    "content": "# OneRec RL Training\n\nReinforcement learning training for OneRec recommendation model based on verl framework.\n\n\n## Installation\n\n### 1. Configure hostfile (multi-node)\n\n```bash\ncat > /etc/mpi/hostfile << EOF\n192.168.1.100 slots=8\n192.168.1.101 slots=8\n192.168.1.102 slots=8\nEOF\n```\n\nNote: `slots=N` specifies the number of GPUs available on each node.\n\n### 2. Install dependencies\n\n```bash\n# Single node\nbash deploy_env.sh\n\n# Multi-node\nbash deploy_env.sh --all-nodes\n```\n\n### 3. Start Ray cluster\n\n```bash\nbash init_ray_cluster.sh\n```\n\n## Quick Start\n\n### Data Format\n\nWe use SFT data from five `*_rec` tasks: `video_rec`, `interactive_rec`, `label_cond_rec`, `ad_rec`, `product_rec`.\n\nSee [data/README.md](../data/README.md) for detailed data format specification.\n\n### Start Training\n\n```bash\ncd verl_rl\n\nexport BASE_MODEL=\"/path/to/your/model\"\n\nbash recipe/onerec/run_grpo.sh 2>&1 | tee logs/train_$(date +%Y%m%d_%H%M%S).log\n```\n\n## Configuration\n\n### Model\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `BASE_MODEL` | - | Model path |\n| `ROLLOUT_TP_SIZE` | 1 | Tensor parallel size |\n\n### Training\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `LEARNING_RATE` | 2e-6 | Learning rate |\n| `KL_LOSS_COEF` | 0.001 | KL loss coefficient |\n| `TEMPERATURE` | 1 | Sampling temperature |\n\n### Rollout\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `ROLLOUT_N` | 1 | Samples per prompt |\n| `STAGE2_BEAM_SIZE` | 32 | Beam search width |\n| `RESPONSE_LENGTH` | 2048 | Max response length |\n| `STAGE1_MAX_TOKENS` | 1024 | Stage 1 max tokens |\n| `STAGE2_NUM_TOKENS` | 3 | Stage 2 tokens |\n\n### Think Mode\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `ENABLE_THINK` | False | Enable think mode |\n| `ENABLE_NONTHINK` | False | Enable non-think mode |\n| `USE_FORCE_PREFIX` | False | Force prefix |\n\n## Directory Structure\n\n```\nverl_rl/\n├── deploy_env.sh          # Environment deployment\n├── init_ray.sh            # Single node Ray init\n├── init_ray_cluster.sh    # Multi-node Ray cluster\n├── requirements.txt       # Dependencies\n├── recipe/\n│   └── onerec/\n│       ├── run_grpo.sh    # Training script\n│       └── onerec_recipe.py\n└── verl/                  # verl core code\n```\n"
  },
  {
    "path": "verl_rl/README_ORIGINAL.md",
    "content": "<div align=\"center\">\n 👋 Hi, everyone! \n    verl is a RL training library initiated by <b>ByteDance Seed team</b> and maintained by the verl community.\n    <br>\n    <br>\n</div>\n\n<div align=\"center\">\n\n<a href=\"https://deepwiki.com/volcengine/verl\"><img src=\"https://devin.ai/assets/deepwiki-badge.png\" alt=\"Ask DeepWiki.com\" style=\"height:20px;\"></a>\n[![GitHub Repo stars](https://img.shields.io/github/stars/volcengine/verl)](https://github.com/volcengine/verl/stargazers)\n[![Twitter](https://img.shields.io/twitter/follow/verl_project)](https://twitter.com/verl_project)\n<a href=\"https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA\"><img src=\"https://img.shields.io/badge/Slack-verl-blueviolet?logo=slack&amp\"></a>\n<a href=\"https://arxiv.org/pdf/2409.19256\"><img src=\"https://img.shields.io/static/v1?label=EuroSys&message=Paper&color=red\"></a>\n[![Documentation](https://img.shields.io/badge/documentation-blue)](https://verl.readthedocs.io/en/latest/)\n<a href=\"https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG\"><img src=\"https://img.shields.io/badge/微信-green?logo=wechat&amp\"></a>\n\n</div>\n\n![seed logo](https://github.com/user-attachments/assets/c42e675e-497c-4508-8bb9-093ad4d1f216)\n\n<h1 style=\"text-align: center;\">verl: Volcano Engine Reinforcement Learning for LLMs</h1>\n\nverl is a flexible, efficient and production-ready RL training library for large language models (LLMs).\n\nverl is the open-source version of **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper.\n\nverl is flexible and easy to use with:\n\n- **Easy extension of diverse RL algorithms**: The hybrid-controller programming model enables flexible representation and efficient execution of complex post-training dataflows. Build RL dataflows such as GRPO, PPO in a few lines of code.\n\n- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as FSDP, Megatron-LM, vLLM, SGLang, etc\n\n- **Flexible device mapping**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.\n\n- Ready integration with popular HuggingFace models\n\nverl is fast with:\n\n- **State-of-the-art throughput**: SOTA LLM training and inference engine integrations and SOTA RL throughput.\n\n- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.\n\n</p>\n\n## News\n- [2025/07] The first verl meetup will be held at ICML Vancouver on July 16th! Please [join us](https://lu.ma/0ek2nyao) if you are at ICML! (onsite only)\n- [2025/07] verl keynote at [AWS AI Hours Singapore](https://pages.awscloud.com/aws-ai-hours-sg.html#agenda) on 7/8, verl & verl-agent project updates at [Agent for SWE meetup](https://lu.ma/e498qhsi) by LF AI & Data Singapore on 7/11.\n- [2025/06] verl with Megatron backend enables large MoE models such as [DeepSeek-671b and Qwen3-236b](https://verl.readthedocs.io/en/latest/perf/dpsk.html).\n- [2025/06] verl team will provide latest project updates at [PyTorch Day China](https://www.lfasiallc.com/pytorch-day-china/) on June 7th. 
Meet our dev team in Beijing!\n- [2025/04] [Seed-Thinking-v1.5](https://github.com/ByteDance-Seed/Seed-Thinking-v1.5/blob/main/seed-thinking-v1.5.pdf) tech report is released! Trained with verl, Seed-Thinking-v1.5 achieves 86.7 on AIME 2024, 55.0 on Codeforces and 77.3 on GPQA, demonstrating excellent reasoning abilities in STEM and coding. Beyond reasoning tasks, the method demonstrates notable generalization across diverse domains.\n- [2025/03] [DAPO](https://dapo-sia.github.io/) is the open-sourced SOTA RL algorithm that achieves 50 points on AIME 2024 based on the Qwen2.5-32B pre-trained model, surpassing the previous SOTA achieved by DeepSeek's GRPO (DeepSeek-R1-Zero-Qwen-32B). DAPO's training is fully powered by verl and the reproduction code is available in `recipe/dapo` now.\n<details><summary> more... </summary>\n<ul>\n  <li> [2025/04] [VAPO](https://arxiv.org/pdf/2504.05118) (value-based augmented PPO) paper covers our latest RL method for reasoning models. Trained from Qwen-32B-base model, VAPO achieves 60.4 on AIME 2024, outperforming DAPO-32B.</li>\n  <li>[2025/05] [PF-PPO](https://arxiv.org/abs/2409.06957), accepted to ICML 2025, is now supported in verl! PF-PPO enhances policy learning efficiency and robustness by filtering potentially noisy reward signals and reusing high-quality experiences via a replay buffer.</li>\n  <li>[2025/04] We will give a tutorial about latest post-training techniques and programming guide for verl at [ICLR 2025 Expo](https://iclr.cc/virtual/2025/calendar?filter_events=Expo+Talk+Panel&filter_rooms=), [SCI-FM workshop](https://open-foundation-model.github.io/) and [LMSys afterparty](https://lu.ma/d23nyynm). Talk materials available [here](https://github.com/eric-haibin-lin/verl-community/tree/main/iclr25). </li>\n  <li>[2025/03] verl v0.3.0.post1 is released! See [release note](https://github.com/volcengine/verl/releases/) for details. It achieves [~1.4x speedup](https://tongyx361.github.io/blogs/posts/verl-intro/#/verl-flexible-and-efficient-rl-for-llms) compared to prev versions.</li>\n  <li>[2025/05] verl will be presented at [A2M Shanghai](https://a2m.msup.com.cn/home/?aid=4488&city=shanghai) on 5/16 - 5/17.</li>\n  <li>[2025/05] verl will be presented at [GOSIM x PyTorch Day 2025](https://paris2025.gosim.org/). See you in Paris! </li>\n  <li>[2025/03] We introduced the programming model of verl at the [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg) and [verl intro and updates](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/verl-lmsys-meetup.pdf) at the [SGLang-LMSYS Org Meetup](https://lu.ma/ntjrr7ig) in Sunnyvale mid-March.</li>\n  <li>[2025/03] We will present verl(HybridFlow) at EuroSys 2025. See you in Rotterdam!</li>\n  <li>[2025/02] verl v0.2.0.post2 is released!</li>\n  <li>[2025/02] We presented verl in the <a href=\"https://lu.ma/ji7atxux\">Bytedance/NVIDIA/Anyscale Ray Meetup</a>. See you in San Jose!</li>\n  <li>[2025/01] [Doubao-1.5-pro](https://team.doubao.com/zh/special/doubao_1_5_pro) is released with SOTA-level performance on LLM & VLM. The RL scaling preview model is trained using verl, reaching OpenAI O1-level performance on math benchmarks (70.0 pass@1 on AIME).</li>\n  <li>[2024/12] verl is presented at Ray Forward 2024. 
Slides available <a href=\"https://github.com/eric-haibin-lin/verl-community/blob/main/slides/Ray_Forward_2024_%E5%B7%AB%E9%94%A1%E6%96%8C.pdf\">here</a></li>\n  <li>[2024/12] The team presented <a href=\"https://neurips.cc/Expo/Conferences/2024/workshop/100677\">Post-training LLMs: From Algorithms to Infrastructure</a> at NeurIPS 2024. <a href=\"https://github.com/eric-haibin-lin/verl-data/tree/neurips\">Slides</a> and <a href=\"https://neurips.cc/Expo/Conferences/2024/workshop/100677\">video</a> available.</li>\n  <li>[2024/10] verl is presented at Ray Summit. <a href=\"https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37\">Youtube video</a> available.</li>\n  <li>[2024/08] HybridFlow (verl) is accepted to EuroSys 2025.</li>\n</ul>   \n</details>\n\n## Key Features\n\n- **FSDP**, **FSDP2** and **Megatron-LM** for training.\n- **vLLM**, **SGLang** and **HF Transformers** for rollout generation.\n- Compatible with Hugging Face Transformers and Modelscope Hub: [Qwen-3](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-8b.sh), Qwen-2.5, Llama3.1, Gemma2, DeepSeek-LLM, etc\n- Supervised fine-tuning.\n- Reinforcement learning with [PPO](examples/ppo_trainer/), [GRPO](examples/grpo_trainer/), [ReMax](examples/remax_trainer/), [REINFORCE++](https://verl.readthedocs.io/en/latest/examples/config.html#algorithm), [RLOO](examples/rloo_trainer/), [PRIME](recipe/prime/), [DAPO](recipe/dapo/), [DrGRPO](recipe/drgrpo), [KL_Cov & Clip_Cov](recipe/entropy) etc.\n  - Support model-based reward and function-based reward (verifiable reward) for math, [coding](https://github.com/volcengine/verl/tree/main/recipe/dapo), etc\n  - Support vision-language models (VLMs) and [multi-modal RL](examples/grpo_trainer/run_qwen2_5_vl-7b.sh) with Qwen2.5-vl, Kimi-VL\n  - [Multi-turn with tool calling](https://github.com/volcengine/verl/tree/main/examples/sglang_multiturn)\n- LLM alignment recipes such as [Self-play preference optimization (SPPO)](https://github.com/volcengine/verl/tree/main/recipe/sppo)\n- Flash attention 2, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [sequence parallelism](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh).\n- Scales up to 671B models and hundreds of GPUs with [expert parallelism](https://github.com/volcengine/verl/pull/1467)\n- Multi-gpu [LoRA RL](https://verl.readthedocs.io/en/latest/advance/ppo_lora.html) support to save memory.\n- Experiment tracking with wandb, swanlab, mlflow and tensorboard.\n\n## Upcoming Features and Changes\n\n- Q3 Roadmap https://github.com/volcengine/verl/issues/2388\n- DeepSeek 671b optimizations with Megatron https://github.com/volcengine/verl/issues/1033\n- Multi-turn rollout and tools using optimizations https://github.com/volcengine/verl/issues/1882\n- [Agent integration](https://github.com/volcengine/verl/tree/main/verl/experimental/agent_loop)\n- Async and off-policy architecture https://github.com/volcengine/verl/pull/2231\n- List of breaking changes since v0.4 https://github.com/volcengine/verl/discussions/2270\n\n## Getting Started\n\n<a href=\"https://verl.readthedocs.io/en/latest/index.html\"><b>Documentation</b></a>\n\n**Quickstart:**\n\n- [Installation](https://verl.readthedocs.io/en/latest/start/install.html)\n- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)\n- [Programming 
Guide](https://verl.readthedocs.io/en/latest/hybrid_flow.html) & [Tech Talk](https://hcqnc.xetlk.com/sl/3vACOK) (in Chinese)\n- [PPO in verl](https://verl.readthedocs.io/en/latest/algo/ppo.html)\n- [GRPO in verl](https://verl.readthedocs.io/en/latest/algo/grpo.html)\n\n**Running a PPO example step-by-step:**\n\n- [Prepare Data for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)\n- [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)\n- [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)\n- [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)\n\n**Reproducible algorithm baselines:**\n\n- [RL performance on coding, math](https://verl.readthedocs.io/en/latest/algo/baseline.html)\n\n**For code explanation and advanced usage (extension):**\n\n- PPO Trainer and Workers\n  - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)\n  - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html)\n  - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html)\n\n- Advanced Usage and Extension\n  - [Add Models with the FSDP Backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)\n  - [Add Models with the Megatron-LM Backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)\n  - [Multi-turn Rollout Support](https://verl.readthedocs.io/en/latest/sglang_multiturn/multiturn.html)\n  - [Search Tool Integration](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html)\n  - [Sandbox Fusion Integration](https://verl.readthedocs.io/en/latest/examples/sandbox_fusion_example.html)\n  - [Deployment using Separate GPU Resources](https://github.com/volcengine/verl/tree/main/examples/split_placement)\n  - [Extend to Other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html)\n  - [Ray API design tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html)\n\n**Blogs from the community**\n\n- [When Reasoning Models Break Tokenization: The Hidden Complexity of Multiturn Training](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/fast_tokenization/multiturn_tokenization_and_masking.md)\n- [verl deployment on AWS SageMaker](https://medium.com/@kaige.yang0110/run-verl-on-sagemaker-using-4x8-l40s-gpus-8e6d5c3c61d3)\n- [verl x SGLang Multi-turn Code Walkthrough](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/code-walk-through/readme_EN.md)\n- [Optimizing SGLang Memory Usage in verl](https://hebiao064.github.io/rl-memory-management)\n- [SGLang, verl, OpenBMB and Tsinghua University: Pioneering End-to-End Multi-Turn RLHF](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/verl-multiturn-rollout-Release.md)\n- [Reinforcement Learning from Human Feedback on AMD GPUs with verl and ROCm Integration](https://rocm.blogs.amd.com/artificial-intelligence/verl-large-scale/README.html)\n- [veMLP x verl ：玩转强化学习训练](https://mp.weixin.qq.com/s/7nbqxk4knMGd-hQE9ls2tA)\n- [使用 verl 进行 GRPO 分布式强化学习训练最佳实践](https://www.volcengine.com/docs/6459/1463942)\n- [HybridFlow verl 原文浅析](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/readme.md)\n- [最高提升 20 倍吞吐量！豆包大模型团队发布全新 RLHF 
框架，现已开源！](https://team.doubao.com/en/blog/%E6%9C%80%E9%AB%98%E6%8F%90%E5%8D%8720%E5%80%8D%E5%90%9E%E5%90%90%E9%87%8F-%E8%B1%86%E5%8C%85%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%9B%A2%E9%98%9F%E5%8F%91%E5%B8%83%E5%85%A8%E6%96%B0-rlhf-%E6%A1%86%E6%9E%B6-%E7%8E%B0%E5%B7%B2%E5%BC%80%E6%BA%90)\n\n## Performance Tuning Guide\n\nPerformance is essential for on-policy RL algorithms. We have written a detailed [performance tuning guide](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) to help you optimize performance.\n\n## Upgrade to vLLM >= v0.8.2\n\nverl now supports vLLM>=0.8.2 when using FSDP as the training backend. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/README_vllm0.8.md) for the installation guide and more information. Please avoid vllm 0.7.x, which contains bugs that may lead to OOMs and unexpected errors.\n\n## Use Latest SGLang\n\nSGLang is fully supported with verl, and SGLang RL Group is working extensively on building unique features, including multi-turn agentic RL, VLM RLHF, server-based RL, and partial rollout. Please refer to [this document](https://verl.readthedocs.io/en/latest/workers/sglang_worker.html) for the installation guide and more information.\n\n## Upgrade to FSDP2\n\nverl is fully embracing FSDP2! FSDP2 is recommended by the torch distributed team, providing better throughput and memory usage, and is composable with other features (e.g. torch.compile). To enable FSDP2, simply use verl main and set the following options:\n```\nactor_rollout_ref.ref.strategy=fsdp2\nactor_rollout_ref.actor.strategy=fsdp2\ncritic.strategy=fsdp2 \nreward_model.strategy=fsdp2 \n```\nFurthermore, FSDP2 cpu offloading is compatible with gradient accumulation. You can turn it on to save memory with `actor_rollout_ref.actor.fsdp_config.offload_policy=True`. For more details, see https://github.com/volcengine/verl/pull/1026\n\n## AMD Support (ROCm Kernel)\n\nverl now supports FSDP as the training engine (Megatron support coming soon) and integrates with both vLLM and SGLang as inference engines. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/amd_tutorial/amd_build_dockerfile_page.rst) for the installation guide and more information, and [this document](https://github.com/volcengine/verl/blob/main/docs/amd_tutorial/amd_vllm_page.rst) for the vLLM performance tuning for ROCm.\n\n\n## Citation and acknowledgement\n\nIf you find the project helpful, please cite:\n\n- [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)\n- [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf)\n\n```bibtex\n@article{sheng2024hybridflow,\n  title   = {HybridFlow: A Flexible and Efficient RLHF Framework},\n  author  = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},\n  year    = {2024},\n  journal = {arXiv preprint arXiv: 2409.19256}\n}\n```\n\nverl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. 
The project is adopted and contributed to by Bytedance, Anyscale, LMSys.org, [Alibaba Qwen team](https://github.com/QwenLM/), Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, University of Hong Kong, ke.com, [All Hands AI](https://www.all-hands.dev/), [ModelBest](http://modelbest.cn/), JD AI Lab, Microsoft Research, [StepFun](https://www.stepfun.com/), Amazon, LinkedIn, Meituan, [Camel-AI](https://www.camel-ai.org/), [OpenManus](https://github.com/OpenManus), Xiaomi, NVIDIA research, [Baichuan](https://www.baichuan-ai.com/home), [RedNote](https://www.xiaohongshu.com/), [SwissAI](https://www.swiss-ai.org/), [Moonshot AI (Kimi)](https://www.moonshot-ai.com/), Baidu, Snowflake, Skywork.ai, JetBrains, [IceSword Lab](https://www.iceswordlab.com), and many more.\n\n## Awesome work using verl\n\n- [TinyZero](https://github.com/Jiayi-Pan/TinyZero): a reproduction of **DeepSeek R1 Zero** recipe for reasoning tasks ![GitHub Repo stars](https://img.shields.io/github/stars/Jiayi-Pan/TinyZero)\n- [SkyThought](https://github.com/NovaSky-AI/SkyThought): RL training for Sky-T1-7B by NovaSky AI team. ![GitHub Repo stars](https://img.shields.io/github/stars/NovaSky-AI/SkyThought)\n- [simpleRL-reason](https://github.com/hkust-nlp/simpleRL-reason): SimpleRL-Zoo: Investigating and Taming Zero Reinforcement Learning for Open Base Models in the Wild ![GitHub Repo stars](https://img.shields.io/github/stars/hkust-nlp/simpleRL-reason)\n- [Easy-R1](https://github.com/hiyouga/EasyR1): **Multi-modal** RL training framework ![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/EasyR1)\n- [OpenManus-RL](https://github.com/OpenManus/OpenManus-RL): LLM Agents RL tuning framework for multiple agent environments. ![GitHub Repo stars](https://img.shields.io/github/stars/OpenManus/OpenManus-RL)\n- [rllm](https://github.com/agentica-project/rllm): async RL training with [verl-pipeline](https://github.com/agentica-project/verl-pipeline) ![GitHub Repo stars](https://img.shields.io/github/stars/agentica-project/rllm)\n- [RAGEN](https://github.com/ZihanWang314/ragen): a general-purpose reasoning **agent** training framework ![GitHub Repo stars](https://img.shields.io/github/stars/ZihanWang314/ragen)\n- [Search-R1](https://github.com/PeterGriffinJin/Search-R1): RL with reasoning and **searching (tool-call)** interleaved LLMs ![GitHub Repo stars](https://img.shields.io/github/stars/PeterGriffinJin/Search-R1)\n- [ReSearch](https://github.com/Agent-RL/ReSearch): Learning to **Re**ason with **Search** for LLMs via Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/Agent-RL/ReSearch)\n- [Skywork-OR1](https://github.com/SkyworkAI/Skywork-OR1): Skywork open reasoner series ![GitHub Repo stars](https://img.shields.io/github/stars/SkyworkAI/Skywork-OR1)\n- [ToRL](https://github.com/GAIR-NLP/ToRL): Scaling tool-integrated RL ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/ToRL)\n- [Absolute Zero Reasoner](https://github.com/LeapLabTHU/Absolute-Zero-Reasoner): [A no human curated data self-play framework for reasoning](https://arxiv.org/abs/2505.03335) ![GitHub Repo stars](https://img.shields.io/github/stars/LeapLabTHU/Absolute-Zero-Reasoner)\n- [verl-agent](https://github.com/langfengQ/verl-agent): A scalable training framework for **long-horizon LLM/VLM agents**, along with a new algorithm **GiGPO** ![GitHub Repo stars](https://img.shields.io/github/stars/langfengQ/verl-agent)\n- [RL-Factory](https://github.com/Simple-Efficient/RL-Factory): An easy and efficient RL 
post-training framework for Agentic Learning ![GitHub Repo stars](https://img.shields.io/github/stars/Simple-Efficient/RL-Factory)\n- [ReTool](https://retool-rl.github.io/): ReTool: reinforcement learning for strategic tool use in LLMs. Code release is in progress...\n- [verl-tool](https://github.com/TIGER-AI-Lab/verl-tool): A unified and easy-to-extend tool-agent training framework based on verl ![GitHub Repo stars](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)\n- [PRIME](https://github.com/PRIME-RL/PRIME): Process reinforcement through implicit rewards ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/PRIME)\n- [MemAgent](https://github.com/BytedTsinghua-SIA/MemAgent): MemAgent: Reshaping Long-Context LLM with Multi-Conv RL based Memory Agent ![GitHub Repo stars](https://img.shields.io/github/stars/BytedTsinghua-SIA/MemAgent)\n- [POLARIS](https://github.com/ChenxinAn-fdu/POLARIS): A post-training recipe for scaling RL on advanced reasoning models ![GitHub Repo stars](https://img.shields.io/github/stars/ChenxinAn-fdu/POLARIS)\n- [GUI-R1](https://github.com/ritzz-ai/GUI-R1): **GUI-R1**: A Generalist R1-style Vision-Language Action Model For **GUI Agents** ![GitHub Repo stars](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)\n- [DeepRetrieval](https://github.com/pat-jj/DeepRetrieval): RL Training of **Search Agent** with **Search/Retrieval Outcome** ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/DeepRetrieval)\n- [Code-R1](https://github.com/ganler/code-r1): Reproducing R1 for **Code** with Reliable Rewards ![GitHub Repo stars](https://img.shields.io/github/stars/ganler/code-r1)\n- [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling deep research via reinforcement learning in real-world environments ![GitHub Repo stars](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)\n- [VAGEN](https://github.com/RAGEN-AI/VAGEN): Training VLM agents with multi-turn reinforcement learning ![GitHub Repo stars](https://img.shields.io/github/stars/RAGEN-AI/VAGEN)\n- [RM-R1](https://arxiv.org/abs/2505.02387): RL training of reasoning reward models ![GitHub Repo stars](https://img.shields.io/github/stars/RM-R1-UIUC/RM-R1)\n- [LUFFY](https://arxiv.org/pdf/2504.14945): Learning to Reason under Off-Policy Guidance ![GitHub Repo stars](https://img.shields.io/github/stars/ElliottYan/LUFFY)\n- [DeepMath](https://github.com/zwhe99/DeepMath): DeepMath-103K data and series models for math reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/zwhe99/DeepMath)\n- [Entropy Mechanism of RL](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL): The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/Entropy-Mechanism-of-RL)\n- [LLaSA-TTS-GRPO](https://github.com/channel-io/ch-tts-llasa-rl-grpo): TTS fine-tuning with GRPO optimization based on LLASA models ![GitHub Repo stars](https://img.shields.io/github/stars/channel-io/ch-tts-llasa-rl-grpo)\n- [PF-PPO](https://arxiv.org/abs/2409.06957): Policy Filtration for PPO based on the reliability of reward signals for more efficient and robust RLHF.\n- [RACRO](https://github.com/gyhdog99/RACRO2): Build multi-modal reasoning models by decoupling reasoning into query-conditioned captioning and text-only reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/gyhdog99/RACRO2)\n\nand many more awesome projects listed in [recipe](recipe/README.md).\n\n## Contribution Guide\n\nSee the [contribution 
guide](CONTRIBUTING.md)\n\n## About [ByteDance Seed Team](https://team.doubao.com/)\n\nFounded in 2023, the ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society. You can get to know ByteDance Seed better through the following channels👇\n<div>\n  <a href=\"https://team.doubao.com/\">\n    <img src=\"https://img.shields.io/badge/Website-%231e37ff?style=for-the-badge&logo=bytedance&logoColor=white\"></a>\n  <a href=\"https://github.com/user-attachments/assets/469535a8-42f2-4797-acdf-4f7a1d4a0c3e\">\n    <img src=\"https://img.shields.io/badge/WeChat-07C160?style=for-the-badge&logo=wechat&logoColor=white\"></a>\n  <a href=\"https://www.xiaohongshu.com/user/profile/668e7e15000000000303157d?xsec_token=ABl2-aqekpytY6A8TuxjrwnZskU-6BsMRE_ufQQaSAvjc%3D&xsec_source=pc_search\">\n    <img src=\"https://img.shields.io/badge/Xiaohongshu-%23FF2442?style=for-the-badge&logo=xiaohongshu&logoColor=white\"></a>\n  <a href=\"https://www.zhihu.com/org/dou-bao-da-mo-xing-tuan-dui/\">\n    <img src=\"https://img.shields.io/badge/zhihu-%230084FF?style=for-the-badge&logo=zhihu&logoColor=white\"></a>\n</div>\n\n---\n\nWe are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in RL for agents.\n"
  },
  {
    "path": "verl_rl/deploy_env.sh",
    "content": "#!/bin/bash\n# Multi-node Environment Deployment Script\n# Usage: bash deploy_env.sh [--all-nodes]\n\nset -e\n\nSCRIPT_DIR=$(cd $(dirname $0); pwd)\nPROJECT_DIR=${SCRIPT_DIR}\n\n# Configuration\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"verl\"}\nPYTHON_VERSION=${PYTHON_VERSION:-\"3.10\"}\nHOSTFILE=${HOSTFILE:-\"/etc/mpi/hostfile\"}\n\n# Colors\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nlog_info() { echo -e \"${GREEN}[INFO]${NC} $1\"; }\nlog_warn() { echo -e \"${YELLOW}[WARN]${NC} $1\"; }\nlog_error() { echo -e \"${RED}[ERROR]${NC} $1\"; }\n\n# Initialize conda\ninit_conda() {\n    for conda_sh in /root/anaconda3/etc/profile.d/conda.sh \\\n                    /root/miniconda3/etc/profile.d/conda.sh \\\n                    $HOME/anaconda3/etc/profile.d/conda.sh \\\n                    $HOME/miniconda3/etc/profile.d/conda.sh \\\n                    /opt/conda/etc/profile.d/conda.sh; do\n        [ -f \"$conda_sh\" ] && source \"$conda_sh\" && return 0\n    done\n    command -v conda &>/dev/null\n}\n\n# Setup proxy\nsetup_proxy() {\n    log_info \"Setting up proxy...\"\n    unset -v http_proxy https_proxy no_proxy\n    export http_proxy=http://oversea-squid2.ko.txyun:11080\n    export https_proxy=http://oversea-squid2.ko.txyun:11080\n    export no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com\n}\n\n# Install on local node\ninstall_local() {\n    log_info \"Installing environment...\"\n\n    # Setup proxy first\n    setup_proxy\n\n    if ! init_conda; then\n        log_error \"Conda not found.\"\n        exit 1\n    fi\n\n    # Configure conda for stability\n    conda config --set remote_read_timeout_secs 600\n    conda config --set remote_connect_timeout_secs 60\n    conda config --set remote_max_retries 10\n    conda config --set show_channel_urls yes\n\n    # Accept TOS for Anaconda channels\n    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null || true\n    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r 2>/dev/null || true\n\n    # Create or activate conda env\n    if conda env list | grep -q \"^${CONDA_ENV_NAME} \"; then\n        log_warn \"Environment '${CONDA_ENV_NAME}' exists, activating...\"\n    else\n        log_info \"Creating environment '${CONDA_ENV_NAME}'...\"\n        conda create -n ${CONDA_ENV_NAME} python=${PYTHON_VERSION} -y\n    fi\n\n    source $(conda info --base)/etc/profile.d/conda.sh\n    conda activate ${CONDA_ENV_NAME}\n\n    log_info \"Installing torch...\"\n    pip install torch==2.6.0\n\n    # Install requirements\n    log_info \"Installing requirements.txt...\"\n    pip install -r ${PROJECT_DIR}/requirements.txt\n\n    # Install flash-attn separately\n    log_info \"Installing flash-attn...\"\n    pip install flash-attn==2.7.4.post1 --no-build-isolation\n\n    # Install verl package\n    log_info \"Installing verl package...\"\n    cd ${PROJECT_DIR}\n    pip install -e .\n\n    log_info \"Done!\"\n}\n\n# Deploy to all nodes\ndeploy_all_nodes() {\n    [ ! 
-f \"${HOSTFILE}\" ] && log_error \"Hostfile not found: ${HOSTFILE}\" && exit 1\n\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n    log_info \"Deploying to: ${ALL_NODES}\"\n\n    mkdir -p ./logs/deploy\n    for node in ${ALL_NODES}; do\n        log_info \"Deploying to ${node}...\"\n        ssh -n ${node} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/deploy_env.sh\" \\\n            > \"./logs/deploy/deploy_${node}.log\" 2>&1 &\n    done\n\n    wait\n    log_info \"Deployment completed! Check logs in ./logs/deploy/\"\n}\n\n# Main\ncase \"${1}\" in\n    --all-nodes) deploy_all_nodes ;;\n    *) install_local ;;\nesac\n"
  },
  {
    "path": "verl_rl/docker/Apptainerfile.rocm",
    "content": "Bootstrap: docker\n\n# Support - Traing: fsdp; Inference: vllm\n# FROM: rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n# Support - Traing: fsdp; Inference: vllm, sglang\nFROM lmsysorg/sglang:v0.4.5-rocm630\n\n%environment\n    export PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n    export HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\n    export CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    export CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n\n%post\n    # Create source directory\n    mkdir -p /opt/src\n\n    # Uninstall and reinstall vllm\n    pip uninstall -y vllm\n    cd /opt/src\n    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git\n    cd vllm\n    MAX_JOBS=$(nproc) python3 setup.py install\n    cd /opt\n    rm -rf /opt/src/vllm\n\n    # Install dependencies\n    pip install \"tensordict<0.6\" --no-deps\n    pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        \"ray[data,train,tune,serve]\" \\\n        torchdata \\\n        transformers \\\n        wandb \\\n        orjson \\\n        pybind11\n\n    # Clone and install verl from GitHub\n    cd /opt\n    git clone https://github.com/volcengine/verl.git\n    cd verl\n    # Uncomment to use a specific version\n    # git checkout v0.3.0.post0\n    pip install -e . --no-deps\n\n    # Install torch_memory_saver\n    pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps"
  },
  {
    "path": "verl_rl/docker/Dockerfile.extention.awsefa",
    "content": "# Base Image support aws EFA\n# Build Image with frameworks based on this\nFROM verlai/verl:app-verl0.5-sglang0.4.6.post5-mcore0.12.2\n\n# For aws instances with EFA net interface (Sagemaker AI Pod)\n#     install EFA driver:\n######## AWS EFA ############\nENV NCCL_VERSION=2.25.1-1\nENV DEBIAN_FRONTEND=noninteractive\nENV EFA_INSTALLER_VERSION=1.40.0\nENV AWS_OFI_NCCL_VERSION=1.14.2\nENV FI_EFA_SET_CUDA_SYNC_MEMOPS=0\nENV FI_PROVIDER=efa\n\nRUN apt update && apt install -y linux-image-generic libhwloc-dev\n\nRUN cd /tmp && \\\n    curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz  && \\\n    tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \\\n    cd aws-efa-installer && \\\n    ./efa_installer.sh -y -g --skip-kmod --skip-limit-conf --no-verify && \\\n    ldconfig && \\\n    rm -rf /tmp/aws-efa-installer /var/lib/apt/lists/*\n\n# NCCL EFA Plugin\nRUN cd /tmp && \\\n    curl -LO https://github.com/aws/aws-ofi-nccl/archive/refs/tags/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    tar -xzf /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    rm /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \\\n    mv aws-ofi-nccl-${AWS_OFI_NCCL_VERSION} aws-ofi-nccl && \\\n    cd /tmp/aws-ofi-nccl && \\\n    ./autogen.sh && \\\n    ./configure --prefix=/opt/amazon/efa \\\n    --with-libfabric=/opt/amazon/efa \\\n    --with-cuda=/usr/local/cuda \\\n    --enable-platform-aws \\\n    --with-mpi=/opt/amazon/openmpi && \\\n    make -j$(nproc) install && \\\n    rm -rf /tmp/aws-ofi/nccl\n\n# NCCL\nRUN echo \"/usr/local/lib\"      >> /etc/ld.so.conf.d/local.conf && \\\n    echo \"/opt/amazon/openmpi/lib\" >> /etc/ld.so.conf.d/efa.conf && \\\n    ldconfig\n\nENV OMPI_MCA_pml=^cm,ucx            \\\n    OMPI_MCA_btl=tcp,self           \\\n    OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent \\\n    OPAL_PREFIX=/opt/amazon/openmpi \\\n    NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent  \\\n    FI_EFA_USE_HUGE_PAGE=0\n\n# docker build -t verl:awsefa --label \"commit=$(git rev-parse --short HEAD)\" .\n# on aws:\n# docker run --ipc=host --privileged --name verldev --gpus all --network=host --shm-size=1800gb -itd verl:awsefa\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.ngc.vllm",
    "content": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:ngc-th2.4.0-cu124-vllm0.6.3-ray2.4-te1.7-v0.0.6\" -f docker/Dockerfile.ngc.vllm . --builder cloud-verlai-verl-builder --progress=plain --push\nFROM nvcr.io/nvidia/pytorch:24.05-py3\n\n# uninstall nv-pytorch fork\nRUN pip3 uninstall pytorch-quantization \\\n    pytorch-triton \\\n    torch \\\n    torch-tensorrt \\\n    torchvision \\\n    xgboost transformer_engine flash_attn \\\n    apex megatron-core -y\n\nRUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124\n\n# =============== Megatron dependencies (optional) =================\n# install apex, set MAX_JOBS to avoid OOMs\nRUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \\\n    --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" \\\n    git+https://github.com/NVIDIA/apex\n# =============== End of Megatron dependencies (optional) =================\n\nRUN pip3 install --no-cache-dir \\\n    accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    numpy \\\n    'pandas' \\\n    'peft' \\\n    'pyarrow>=15.0.0' \\\n    'pybind11' \\\n    'pylatexenc' \\\n    'ray>=2.10' \\\n    'tensordict<0.6' \\\n    'transformers' \\\n    'vllm==0.6.3.post1' \\\n    'wandb'\n\n# full dependencies\nRUN pip3 install pytest pre-commit py-spy pyext liger-kernel\n\n# =============== Megatron dependencies (optional) =================\n# install Transformer Engine, which requires FA 2.5.8. Do it in a separate step for docker cache\nRUN MAX_JOBS=4 NINJA_FLAGS=\"-j4\" pip3 install flash-attn==2.5.8 --no-cache-dir --no-build-isolation\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" TE_BUILD_WITH_NINJA=0 pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0\n# =============== End of Megatron dependencies (optional) =================\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.ngc.vllm0.8",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Install torch-2.6.0+cu124 + vllm-0.8.3\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --no-cache-dir \"vllm==0.8.3\" \"torch==2.6.0\" \"torchvision==0.21.0\" \"torchaudio==2.6.0\" \"tensordict==0.6.2\" torchdata \\\n    \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install verl\nRUN pip install --no-cache-dir verl[vllm] -U\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.ngc.vllm0.8.sagemaker",
    "content": "# Using a pre-built image from AWS DLC which contains the current version of python (3.10) and supported cuda version (12.1)\nFROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.1.0-transformers4.36.0-gpu-py310-cu121-ubuntu20.04\n\n# uninstall nv-pytorch fork\nRUN pip3 uninstall -y pytorch-quantization \\\n    pytorch-triton torch torch-tensorrt torchvision \\\n    xgboost transformer_engine flash_attn apex megatron-core\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Install torch-2.6.0 + vllm-0.8.2\nRUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata==0.11.0 \\\n    transformers>=4.49.0 accelerate datasets peft hf-transfer \\\n    ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest pre-commit py-spy pyext ruff\n\n# Install flash_attn-2.7.4.post1\nRUN pip uninstall -y transformer-engine flash-attn && \\\n    pip install flash-attn==2.7.4.post1 --no-build-isolation\n\n# Fix cv2\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \\\n    pip install --no-cache-dir --upgrade optree>=0.13.0\n\n# Install verl\nRUN pip install --no-cache-dir verl[vllm] -U\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.rocm",
    "content": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247\"\nFROM \"rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04\"\n\nSHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\nENV MAX_JOBS=512\n\nENV PATH=\"/usr/local/python3.12/bin:$PATH\"\nRUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n    ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n############################################\n############################################\nRUN apt-get update\nRUN apt-get install -y pkg-config liblzma-dev\n############################################\n############################################\n\n\n###########################################\n##########Install TransformerEngine########\n###########################################\nWORKDIR /workspace/\n# transformer-engine install\n# https://github.com/ROCm/TransformerEngine\n\nRUN rm -rf TransformerEngine \nRUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\nWORKDIR /workspace/TransformerEngine\nRUN git checkout 236178e5\n# git checkout bb061ade\n# git checkout 864405c\n\nENV NVTE_FRAMEWORK=pytorch \nENV NVTE_ROCM_ARCH=gfx942 \nENV NVTE_USE_HIPBLASLT=1\nENV NVTE_USE_ROCM=1  \n\n# export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\nENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n\n\n# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS)\n\nRUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n####################################################################################\n################Install vllm - sglang require vllm 0.6.7 dependency#################\n####################################################################################\n#### Require vllm 0.6.7 - checkout 113274a0\nWORKDIR /workspace/\nRUN rm -rf vllm\nRUN pip uninstall -y vllm\n# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\nRUN git clone https://github.com/ROCm/vllm.git\n# git clone https://github.com/vllm-project/vllm.git\nWORKDIR /workspace/vllm\nRUN git checkout 113274a0\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n#ENV MAX_JOBS=512\nENV MAX_JOBS=${MAX_JOBS}\nRUN pip install \"boto3>=1.26.0\"\nRUN pip install setuptools_scm\n# will add src into py. 
so the repo can be deleted afterwards\nRUN python3 setup.py install\nWORKDIR /workspace/\n####################################################################################\n####################################################################################\n####################################################################################\n\n\n\n###########################################\n############Docker build workaround########\n###########################################\nRUN pip install setuptools==75.8.0\n###########################################\n###########################################\n###########################################\n\n\n\n###########################################\n############Build sglang###################\n###########################################\n# Set environment variables\nENV BASE_DIR=/sgl-workspace\nENV BUILD_TYPE=all\nENV SGL_REPO=https://github.com/sgl-project/sglang\nENV SGL_BRANCH=v0.4.6.post5\nENV TRITON_REPO=https://github.com/ROCm/triton.git\nENV TRITON_COMMIT=improve_fa_decode_3.0.0\nENV AITER_REPO=https://github.com/ROCm/aiter.git\nENV AITER_COMMIT=v0.1.2\n# v0.1.2 version - commit id: 9d11f47\n# ENV AITER_COMMIT=9d11f47\n\nENV HIP_FORCE_DEV_KERNARG=1\nENV HSA_NO_SCRATCH_RECLAIM=1\nENV SGLANG_SET_CPU_AFFINITY=1\nENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\nENV NCCL_MIN_NCHANNELS=112\nENV MOE_PADDING=1\nENV VLLM_FP8_PADDING=1\nENV VLLM_FP8_ACT_PADDING=1\nENV VLLM_FP8_WEIGHT_PADDING=1\nENV VLLM_FP8_REDUCE_CONV=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\nENV AMDGPU_TARGETS=gfx942\nENV ROCM_ARCH=gfx942\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n# Switch to working directory\nWORKDIR /sgl-workspace\n\n# Clean and create directory\nRUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n# Clone and build sglang\nRUN git clone ${SGL_REPO} \\\n    && cd sglang \\\n    && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n    && cd sgl-kernel \\\n    && rm -f pyproject.toml \\\n    && mv pyproject_rocm.toml pyproject.toml \\\n    && python setup_rocm.py install \\\n    && cd .. 
\\\n    && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n         python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n       else \\\n         python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n       fi \\\n    && cd /sgl-workspace \\\n    && cp -r /sgl-workspace/sglang /sglang \\\n    && python -m pip cache purge\n\n# Install common Python packages\nRUN pip install IPython orjson python-multipart torchao pybind11\n\n# Rebuild Triton\nRUN pip uninstall -y triton || true \\\n    && git clone ${TRITON_REPO} \\\n    && cd triton \\\n    && git checkout ${TRITON_COMMIT} \\\n    && cd python \\\n    && python3 setup.py install \\\n    && cd /sgl-workspace\n\n\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n# Build aiter\n#version: Commit 9d11f47\n    # && git checkout ${AITER_COMMIT} \\\nRUN pip uninstall -y aiter || true\nRUN git clone ${AITER_REPO} \\\n    && cd aiter \\\n    && git checkout ${AITER_COMMIT} \\\n    && git submodule sync \\\n    && git submodule update --init --recursive \\\n    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n    && cd /sgl-workspace\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n\n# Copy MI300X config \nRUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n         /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n         -type f -name '*MI300X*' | \\\n         xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n# Environment setup complete.\nRUN echo \"Environment setup complete.\"\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n\n###########################################\n###############vllm v0.8.5#################\n###########################################\n# ENV GITHUB_USERNAME=yushengsu-thu\n# ENV GITHUB_MAIL=yushengsu@gmail.com\n\n# RUN git config --global user.name \"${GITHUB_USERNAME}\" \\\n#     && git config --global user.email \"${GITHUB_MAIL}\" \n\nWORKDIR /workspace/\n\nENV VLLM_TARGET_DEVICE=rocm \nENV ROCM_PATH=/opt/rocm \nENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n\n# Find the repo path in: DockerFile/Dockerfile.rocm_yang\n# RUN git clone https://github.com/RLFoundation/vllm-patch.git\nRUN pip uninstall -y vllm || true\nRUN rm -rf vllm-patch\nRUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n    && cd vllm-patch \\\n    && git checkout v0.8.5-sleep-numa \\\n    && rm -rf build/ dist/ *.egg-info \\\n    && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n    && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n    # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n#########################################\n#### Install megatron-core###############\n#########################################\nRUN pip uninstall -y megatron-core && \\\n    git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git 
&& \\\n    cd Megatron-LM-amd_version && \\\n    pip install -vvv -e . && \\\n    cd /workspace/\n#########################################\n#########################################\n#########################################\n\n\n\n\n#######################################\n################apex###################\n#######################################\nWORKDIR /workspace/\nRUN pip uninstall -y apex && \\\n    git clone https://github.com/ROCm/apex.git && \\\n    cd apex && \\\n    python setup.py install && \\\n    cd /workspace/ \n#######################################\n#######################################\n#######################################\n\n\n\n\n################################################################################\n###########################Add torch_memory_saver###############################\n################################################################################\n# Set environment variables\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nRUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n################################################################################\n################################################################################\n################################################################################\n\n\n\n########################################\n######Install ray#######################\n########################################\n# need to add this patch: https://github.com/ray-project/ray/pull/53531/files\nRUN pip uninstall ray -y\nRUN pip install \"ray[data,train,tune,serve]>=2.47.0\" \n########################################\n########################################\n########################################\n\n\n\n##########################################\n#######Install other dependencies#########\n##########################################\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    torchdata \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nWORKDIR /workspace/\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n##########################################\n##########################################\n##########################################\n\n\n\nWORKDIR /workspace/\n\nCMD [\"/usr/bin/bash\"]\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.rocm_verl-0.3.0.post1",
    "content": "#  Build the docker in the repo dir:\n# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .\n# docker images # you can find your built docker\n\n\n# Support - Traing: fsdp; Inference: vllm\n# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n# Support - Traing: fsdp; Inference: vllm, sglang\nFROM lmsysorg/sglang:v0.4.6.post5-rocm630\n\n# Set working directory\n# WORKDIR $PWD/app\n\n# Set environment variables\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n\n# Install vllm\nRUN pip uninstall -y vllm && \\\n    rm -rf vllm && \\\n    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \\\n    cd vllm && \\\n    MAX_JOBS=$(nproc) python3 setup.py install && \\\n    cd .. && \\\n    rm -rf vllm\n\n# Copy the entire project directory\nCOPY . .\n\n# Install dependencies\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    \"ray[data,train,tune,serve]<2.45.0\" \\\n    torchdata \\\n    transformers \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n\n# Install torch_memory_saver\nRUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.rocm_verl-0.4.1",
    "content": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247\"\nFROM \"rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04\"\n\nSHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\nENV MAX_JOBS=512\n\nENV PATH=\"/usr/local/python3.12/bin:$PATH\"\nRUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n    ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n############################################\n############################################\nRUN apt-get update\nRUN apt-get install -y pkg-config liblzma-dev\n############################################\n############################################\n\n\n###########################################\n##########Install TransformerEngine########\n###########################################\nWORKDIR /workspace/\n# transformer-engine install\n# https://github.com/ROCm/TransformerEngine\n\nRUN rm -rf TransformerEngine \nRUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\nWORKDIR /workspace/TransformerEngine\nRUN git checkout 236178e5\n# git checkout bb061ade\n# git checkout 864405c\n\nENV NVTE_FRAMEWORK=pytorch \nENV NVTE_ROCM_ARCH=gfx942 \nENV NVTE_USE_HIPBLASLT=1\nENV NVTE_USE_ROCM=1  \n\n# export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\nENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n\n\n# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS)\n\nRUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n####################################################################################\n################Install vllm - sglang require vllm 0.6.7 dependency#################\n####################################################################################\n#### Require vllm 0.6.7 - checkout 113274a0\nWORKDIR /workspace/\nRUN rm -rf vllm\nRUN pip uninstall -y vllm\n# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\nRUN git clone https://github.com/ROCm/vllm.git\n# git clone https://github.com/vllm-project/vllm.git\nWORKDIR /workspace/vllm\nRUN git checkout 113274a0\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n#ENV MAX_JOBS=512\nENV MAX_JOBS=${MAX_JOBS}\nRUN pip install \"boto3>=1.26.0\"\nRUN pip install setuptools_scm\n# will add src into py. 
so the repo can be deleted afterwards\nRUN python3 setup.py install\nWORKDIR /workspace/\n####################################################################################\n####################################################################################\n####################################################################################\n\n\n\n###########################################\n############Docker build workaround########\n###########################################\nRUN pip install setuptools==75.8.0\n###########################################\n###########################################\n###########################################\n\n\n\n###########################################\n############Build sglang###################\n###########################################\n# Set environment variables\nENV BASE_DIR=/sgl-workspace\nENV BUILD_TYPE=all\nENV SGL_REPO=https://github.com/sgl-project/sglang\nENV SGL_BRANCH=v0.4.6.post5\nENV TRITON_REPO=https://github.com/ROCm/triton.git\nENV TRITON_COMMIT=improve_fa_decode_3.0.0\nENV AITER_REPO=https://github.com/ROCm/aiter.git\nENV AITER_COMMIT=v0.1.2\n# v0.1.2 version - commit id: 9d11f47\n# ENV AITER_COMMIT=9d11f47\n\nENV HIP_FORCE_DEV_KERNARG=1\nENV HSA_NO_SCRATCH_RECLAIM=1\nENV SGLANG_SET_CPU_AFFINITY=1\nENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\nENV NCCL_MIN_NCHANNELS=112\nENV MOE_PADDING=1\nENV VLLM_FP8_PADDING=1\nENV VLLM_FP8_ACT_PADDING=1\nENV VLLM_FP8_WEIGHT_PADDING=1\nENV VLLM_FP8_REDUCE_CONV=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE=1\nENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\nENV AMDGPU_TARGETS=gfx942\nENV ROCM_ARCH=gfx942\nENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n# Switch to working directory\nWORKDIR /sgl-workspace\n\n# Clean and create directory\nRUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n# Clone and build sglang\nRUN git clone ${SGL_REPO} \\\n    && cd sglang \\\n    && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n    && cd sgl-kernel \\\n    && rm -f pyproject.toml \\\n    && mv pyproject_rocm.toml pyproject.toml \\\n    && python setup_rocm.py install \\\n    && cd .. 
\\\n    && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n         python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n       else \\\n         python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n       fi \\\n    && cd /sgl-workspace \\\n    && cp -r /sgl-workspace/sglang /sglang \\\n    && python -m pip cache purge\n\n# Install common Python packages\nRUN pip install IPython orjson python-multipart torchao pybind11\n\n# Rebuild Triton\nRUN pip uninstall -y triton || true \\\n    && git clone ${TRITON_REPO} \\\n    && cd triton \\\n    && git checkout ${TRITON_COMMIT} \\\n    && cd python \\\n    && python3 setup.py install \\\n    && cd /sgl-workspace\n\n\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n# ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n# Build aiter\n#version: Commit 9d11f47\n    # && git checkout ${AITER_COMMIT} \\\nRUN pip uninstall -y aiter || true\nRUN git clone ${AITER_REPO} \\\n    && cd aiter \\\n    && git checkout ${AITER_COMMIT} \\\n    && git submodule sync \\\n    && git submodule update --init --recursive \\\n    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n    && cd /sgl-workspace\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n    # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \\\n\n# Copy MI300X config \nRUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n         /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n         -type f -name '*MI300X*' | \\\n         xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n# Environment setup complete.\nRUN echo \"Environment setup complete.\"\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n\n\n###########################################\n###############vllm v0.8.5#################\n###########################################\n# ENV GITHUB_USERNAME=yushengsu-thu\n# ENV GITHUB_MAIL=yushengsu@gmail.com\n\n# RUN git config --global user.name \"${GITHUB_USERNAME}\" \\\n#     && git config --global user.email \"${GITHUB_MAIL}\" \n\nWORKDIR /workspace/\n\nENV VLLM_TARGET_DEVICE=rocm \nENV ROCM_PATH=/opt/rocm \nENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n\n# Find the repo path in: DockerFile/Dockerfile.rocm_yang\n# RUN git clone https://github.com/RLFoundation/vllm-patch.git\nRUN pip uninstall -y vllm || true\nRUN rm -rf vllm-patch\nRUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n    && cd vllm-patch \\\n    && git checkout v0.8.5-sleep-numa \\\n    && rm -rf build/ dist/ *.egg-info \\\n    && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n    && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n    # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n\nWORKDIR /workspace/\n###########################################\n###########################################\n###########################################\n\n\n\n\n#########################################\n#### Install megatron-core###############\n#########################################\nRUN pip uninstall -y megatron-core && \\\n    git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git 
&& \\\n    cd Megatron-LM-amd_version && \\\n    pip install -vvv -e . && \\\n    cd /workspace/\n#########################################\n#########################################\n#########################################\n\n\n\n\n#######################################\n################apex###################\n#######################################\nWORKDIR /workspace/\nRUN pip uninstall -y apex && \\\n    git clone https://github.com/ROCm/apex.git && \\\n    cd apex && \\\n    python setup.py install && \\\n    cd /workspace/ \n#######################################\n#######################################\n#######################################\n\n\n\n\n################################################################################\n###########################Add torch_memory_saver###############################\n################################################################################\n# Set environment variables\nENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\nENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\nRUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n################################################################################\n################################################################################\n################################################################################\n\n\n\n########################################\n######Install ray#######################\n########################################\n# need to add this patch: https://github.com/ray-project/ray/pull/53531/files\nRUN pip uninstall ray -y\nRUN pip install \"ray[data,train,tune,serve]>=2.47.0\" \n########################################\n########################################\n########################################\n\n\n\n##########################################\n#######Install other dependencies#########\n##########################################\nRUN pip install \"tensordict==0.6.2\" --no-deps && \\\n    pip install accelerate \\\n    codetiming \\\n    datasets \\\n    dill \\\n    hydra-core \\\n    liger-kernel \\\n    numpy \\\n    pandas \\\n    peft \\\n    \"pyarrow>=15.0.0\" \\\n    pylatexenc \\\n    torchdata \\\n    wandb \\\n    orjson \\\n    pybind11\n    \nWORKDIR /workspace/\nRUN git clone https://github.com/volcengine/verl.git && \\\n    cd verl && \\\n    pip install -e . \n##########################################\n##########################################\n##########################################\n\n\n\nWORKDIR /workspace/\n\nCMD [\"/usr/bin/bash\"]\nCMD [\"/usr/bin/bash\"]\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.sglang",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.ustc.edu.cn/ubuntu/\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini && \\\n    apt-get clean\n\n# Change pip source\nARG PIP_INDEX=https://mirrors.aliyun.com/pypi/simple/\n\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip uninstall -y cuda-python && pip install \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Install torch-2.6.0\nRUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \\\n    transformers>=4.49.0 accelerate datasets peft hf_transfer \\\n    ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel \\\n    pytest pre-commit py-spy pyext\n\n# Install flash_attn-2.7.4.post1\nRUN pip uninstall -y transformer-engine flash-attn && \\\n    wget -v https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix cv2\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.vemlp.vllm.te",
    "content": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:$TAG\" -f docker/$FILE .\n\n# the one in docker.io is an alias for the one veturbo\n# FROM vemlp-cn-beijing.cr.volces.com/veturbo/pytorch:2.4-cu124\nFROM docker.io/haibinlin/verl:v0.0.5-th2.4.0-cu124-base\n\n# only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed\n# unset for now\nRUN pip3 config unset global.index-url\n\n# transformers 4.47.0 contains the following bug:\n# AttributeError: 'Gemma2Attention' object has no attribute '_flash_attn_uses_top_left_mask'\nRUN pip3 install --no-cache-dir \\\n    torch==2.4.0 \\\n    accelerate \\\n    codetiming \\\n    dill \\\n    hydra-core \\\n    numpy \\\n    pybind11 \\\n    tensordict \\\n    \"transformers <= 4.46.0\"\n\nRUN pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation\n\n# vllm depends on ray\nRUN pip3 install --no-cache-dir vllm==0.6.3 ray==2.10\n\n# install apex\nRUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \\\n    --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" \\\n    git+https://github.com/NVIDIA/apex\n\n# install Transformer Engine\n# - flash-attn pinned to 2.5.3 by TransformerEngine, switch to eric-haibin-lin/TransformerEngine.git@v1.7.0 to relax version req\n# - install with: MAX_JOBS=1 NINJA_FLAGS=\"-j1\" TE_BUILD_WITH_NINJA=0 to avoid OOM\n# - cudnn is required by TransformerEngine\n# RUN CUDNN_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cudnn \\\n#     pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" pip3 install flash-attn==2.5.3 --no-cache-dir --no-build-isolation\nRUN MAX_JOBS=1 NINJA_FLAGS=\"-j1\" pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7\n"
  },
  {
    "path": "verl_rl/docker/Dockerfile.vllm.sglang.megatron.deepseek",
    "content": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Reinstall CUDA 12.4\nRUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \\\n    mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600\n\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cuda-toolkit-12-4 && \\\n    rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    update-alternatives --set cuda /usr/local/cuda-12.4 && \\\n    rm -rf /usr/local/cuda-12.6\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\n# Install sglang-0.4.6.post1 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install --resume-retries 999 torch-memory-saver --no-cache-dir\n\nRUN pip install --resume-retries 999 --no-cache-dir \"vllm==0.8.5.post1\" \"torch==2.6.0\" \"torchvision==0.21.0\" \"torchaudio==2.6.0\" \"tensordict==0.6.2\" torchdata\n\nRUN pip install --resume-retries 999 --no-cache-dir \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install 
flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install Apex\nRUN git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix opencv\nRUN pip install opencv-python\n\nRUN pip install opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\n# Install verl\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\nRUN apt-get update && \\\n    apt-get install -y aria2 libfreeimage3 libfreeimage-dev zlib1g"
  },
  {
    "path": "verl_rl/docker/README.md",
    "content": "# Dockerfiles of verl\n\nWe provide pre-built Docker images for quick setup. And from this version, we utilize a new image release hierarchy for productivity and stability.\n\nThe image types are divided into three large categories:\n\n- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA.\n- **Application Image**: Stable version with inference and training frameworks installed.\n- **Preview Image**: Unstable version with the latest frameworks and features.\n\nThe first two types of images are hosted on dockerhub [verlai/verl](https://hub.docker.com/r/verlai/verl) repository, while the preview images are hosted on community repository.\n\n> The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``.\n\n## Base Image\n\nThe stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``.\n\nThe base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions.\n\nThe update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages.\n\n## Application Image\n\nFrom this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer.\n\nThere are four types of application images available:\n\n- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2``, with Deep-EP support: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2-deepep``.\n- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2`` (need vLLM support, but can have some package conflicts), with Deep-EP support: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2-deepep``.\n- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.2-te2.2``\n- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.2-te2.2``\n\nFor Megatron 0.13.0, we offer preview images, to use latest codes, just replace ``mcore0.12.2`` with ``mcore0.13.0-preview`` in the above image tag.\n\nThe latest vLLM support is coming soon.\n\nDocker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details.\n\nApplication images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. 
\n\n## Community Image\n\nFor vLLM with FSDP, please refer to the [hiyouga/verl](https://hub.docker.com/r/hiyouga/verl) repository; the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.\n\nFor SGLang with FSDP, please refer to the [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository; the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5``, which is provided by the SGLang RL Group.\n\nSee the files under ``docker/`` for the NGC-based images, or use them as a starting point to build your own.\n\nNote that for AWS instances with the EFA network interface (SageMaker AI Pod), you need to install the EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``.\n\n## Installation from Docker\n\nAfter pulling the desired Docker image and installing the desired inference and training frameworks, you can run it with the following steps:\n\n1. Launch the desired Docker image and attach to it:\n\n```sh\ndocker create --runtime=nvidia --gpus all --net=host --shm-size=\"10g\" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl <image:tag> sleep infinity\ndocker start verl\ndocker exec -it verl bash\n```\n\n2. If you use the provided images, you only need to install verl itself, without its dependencies:\n\n```sh\n# install the nightly version (recommended)\ngit clone https://github.com/volcengine/verl && cd verl\npip3 install --no-deps -e .\n```\n\n[Optional] If you want to switch between different frameworks, you can install verl with the following commands:\n\n```sh\n# install the nightly version (recommended)\ngit clone https://github.com/volcengine/verl && cd verl\npip3 install -e .[vllm]\npip3 install -e .[sglang]\n```\n"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.6.post5 and torch-memory-saver\nRUN pip install --resume-retries 999 \"sglang[all]==0.4.6.post5\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\n\n# Some sglang operations in 0.4.6.post5 require vllm\n# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Fix for transformers 4.53.0\nRUN pip3 install --no-cache-dir \"transformers[hf_xet]<4.52.0\"\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. 
Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.6.0+cu124 + vllm-0.8.5.post1\n# torch-2.6.0+cu124: cxx11abi=False\n# torch-2.6.0+cu126: cxx11abi=True\n# see https://github.com/flashinfer-ai/flashinfer/issues/911\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1\n\n# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nRUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-v2-cu124-cudnn9.8-torch2.6-fa2.8.0-te2.3\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\n# Reinstall CUDA 12.4\nRUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \\\n    mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600\n\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cuda-toolkit-12-4 && \\\n    rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \\\n    update-alternatives --set cuda /usr/local/cuda-12.4 && \\\n    rm -rf /usr/local/cuda-12.6\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install flash-attn-2.7.4.post1 (cxx11abi=False)\nRUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Fix 
packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0 && \\\n    dpkg -i ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\n# Fix opencv\nRUN pip install --resume-retries 999 --no-cache-dir opencv-python\n\nRUN pip install --resume-retries 999 --no-cache-dir opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\nRUN pip install --resume-retries 999 --no-cache-dir cuda-bindings\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\nRUN apt-get update && \\\n    apt-get install -y libfreeimage3 libfreeimage-dev zlib1g htop\n\n"
  },
  {
    "path": "verl_rl/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md",
    "content": "# verl image with verl v0.4.x\n\n## Important packages version\n\n```txt\ncuda==12.4\ncudnn==9.8.0\ntorch==2.6.0\nflash_attn=2.7.4\nsglang==0.4.6.post5\nvllm==0.8.5.post1\nvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image: \n    - `verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4`\n- App image:\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2`: SGLang requires vLLM in 0.4.6.post5 version, vLLM can have some package conflicts with SGLang\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2-deepep`: Built with deepep\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2`\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2-deepep`: Built with deepep\n- Preview image:\n    - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.13.0-te2.2-preview`\n    - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.13.0-te2.2-preview`"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.sglang.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.52.3\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.vllm.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4\n\n# Define environments\nENV MAX_JOBS=32\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install torch-2.7.0+cu126 + vllm-0.9.1\nRUN pip install --resume-retries 999 --no-cache-dir vllm==0.9.1\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.base.torch2.7.0",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0\n\n# Install flash-attn-2.7.4.post1, although built with torch2.6, it is compatible with torch2.7\n# https://github.com/Dao-AILab/flash-attention/issues/1644#issuecomment-2899396361\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" 
--config-settings \"--build-option=--cuda_ext\" --resume-retries 999 git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.base.torch2.7.1",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1\n\n# Install flash-attn-2.7.4.post1, although built with torch2.6, it is compatible with torch2.7\n# https://github.com/Dao-AILab/flash-attention/issues/1644#issuecomment-2899396361\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" 
--config-settings \"--build-option=--cuda_ext\" --resume-retries 999 git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.52.3\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7-fa2.7.4/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0    ##\nsglang==0.4.8\nvllm==0.8.5.post1\nvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image:\n    - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.0-fa2.7.4`: We offer a base image with deep ep built in, for vllm\n    - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4`: We offer a base image with deep ep built in, for sglang\n- App image:\n    - `verlai/verl:app-verl0.5-vllm0.9.1-mcore0.12.2-te2.2`\n    - `verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.2-te2.2`"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --upgrade pip setuptools packaging\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:24.08-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1\n\n# Install flash-attn-2.8.0.post2 (cxx11abi=True)\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" --resume-retries 999 git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c 
--always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.53\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pyext pre-commit ruff\n\n# Install DeepEP\n## the dependency of IBGDA\nRUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so\n\n## Clone and build deepep and deepep-nvshmem\nRUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \\\n    git clone https://github.com/deepseek-ai/DeepEP.git  && \\\n    cd DeepEP && git checkout a84a248\n\n# Prepare nvshmem\nRUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \\\n    tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \\\n    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch\n\nENV CUDA_HOME=/usr/local/cuda\n### Set MPI environment variables. Having errors when not set.\nENV CPATH=/usr/local/mpi/include:$CPATH\nENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH\nENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH\nENV GDRCOPY_HOME=/workspace/gdrcopy\n\n## Build deepep-nvshmem\nRUN cd deepep-nvshmem && \\\n    NVSHMEM_SHMEM_SUPPORT=0 \\\n    NVSHMEM_UCX_SUPPORT=0 \\\n    NVSHMEM_USE_NCCL=0 \\\n    NVSHMEM_MPI_SUPPORT=0 \\\n    NVSHMEM_IBGDA_SUPPORT=1 \\\n    NVSHMEM_PMIX_SUPPORT=0 \\\n    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \\\n    NVSHMEM_USE_GDRCOPY=1 \\\n    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install\n\nENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install\nENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH\nENV PATH=$NVSHMEM_DIR/bin:$PATH\n\n## Build deepep\nRUN cd DeepEP && \\\n    python setup.py install\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_rl/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0    ##\nsglang==0.4.8\nvllm==0.8.5.post1\nvidia-cudnn-cu12==9.8.0.87\ntransformer_engine==2.3\nmegatron.core==core_v0.12.2\n# Preview\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\n```\n\n## Target\n\n- Base image:\n    - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with deep ep built in\n- App image:\n    - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.12.2`\n    - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.13.0-preview`\n- vllm temporarily not support latest version"
  },
  {
    "path": "verl_rl/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron",
    "content": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n\n# Define environments\nENV MAX_JOBS=8\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Install sglang-0.4.8 and torch-memory-saver\n# Install FlashInfer Python package\nRUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1\nRUN pip install --resume-retries 999  --no-cache-dir \"sglang[all]==0.4.8\" && pip install torch-memory-saver --no-cache-dir\n\n# Fix packages\nRUN pip install --no-cache-dir \"tensordict==0.6.2\" \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pre-commit ruff\n\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --resume-retries 999 --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\nRUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0\n\n# Install mbridge\nRUN pip3 install --no-cache-dir mbridge"
  },
  {
    "path": "verl_rl/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base",
    "content": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6\n# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html\nFROM nvcr.io/nvidia/pytorch:25.02-py3\n\n# Define environments\nENV MAX_JOBS=16\nENV VLLM_WORKER_MULTIPROC_METHOD=spawn\nENV DEBIAN_FRONTEND=noninteractive\nENV NODE_OPTIONS=\"\"\nENV PIP_ROOT_USER_ACTION=ignore\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\n# Define installation arguments\nARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/\nARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple\n\n# Set apt source\nRUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \\\n    { \\\n    echo \"deb ${APT_SOURCE} jammy main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-updates main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-backports main restricted universe multiverse\"; \\\n    echo \"deb ${APT_SOURCE} jammy-security main restricted universe multiverse\"; \\\n    } > /etc/apt/sources.list\n\n# Install systemctl\nRUN apt-get update && \\\n    apt-get install -y -o Dpkg::Options::=\"--force-confdef\" systemd && \\\n    apt-get clean\n\n# Install tini\nRUN apt-get update && \\\n    apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \\\n    apt-get clean\n\n# Change pip source\nRUN pip config set global.index-url \"${PIP_INDEX}\" && \\\n    pip config set global.extra-index-url \"${PIP_INDEX}\" && \\\n    python -m pip install --upgrade pip\n\n# Uninstall nv-pytorch fork\nRUN pip uninstall -y torch torchvision torchaudio \\\n    pytorch-quantization pytorch-triton torch-tensorrt \\\n    xgboost transformer_engine flash_attn apex megatron-core grpcio\n\nRUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128\n\n# Install flash-attn-2.8.0.post2 (cxx11abi=True)\nRUN ABI_FLAG=$(python -c \"import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')\") && \\\n    URL=\"https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl\" && \\\n    FILE=\"flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl\" && \\\n    wget -nv \"${URL}\" && \\\n    pip install --no-cache-dir \"${FILE}\"\n\n# Fix packages\nRUN pip uninstall -y pynvml nvidia-ml-py && \\\n    pip install --no-cache-dir --upgrade \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n# Install cudnn\nRUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \\\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \\\n    apt-get update && \\\n    apt-get -y install cudnn-cuda-12 && \\\n    rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n\n# Install Apex\nRUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" --resume-retries 999 
git+https://github.com/NVIDIA/apex.git\n\n# Profiling tools\nRUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    apt-get update && apt-get install -y libxcb-cursor0\n\nRUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \\\n    rm -rf /usr/local/cuda/bin/nsys && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys  /usr/local/cuda/bin/nsys && \\\n    rm -rf /usr/local/cuda/bin/nsys-ui && \\\n    ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \\\n    rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb\n\nRUN pip install --resume-retries 999 --no-cache-dir \"tensordict==0.6.2\" torchdata \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=19.0.1\" pandas cuda-bindings \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \\\n    pytest py-spy pre-commit ruff\n\n# Reset pip config\nRUN pip config unset global.index-url && \\\n    pip config unset global.extra-index-url\n\n"
  },
  {
    "path": "verl_rl/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md",
    "content": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.8\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0    ##\nsglang==0.4.8\ntransformer_engine==2.5\nmegatron.core==core_r0.13.0\nvidia-cudnn-cu12==9.8.0.87\n```\n\n## Target\n\n- Base image:\n    - `verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with flash infer 0.2.6.post1 built in\n- App image:\n    - `verlai/verl:app-verl0.5-preview-sglang0.4.8-mcore0.13.0-preview`\n- vllm temporarily not support latest version\n\n## !!!Notice!!!\n\n- pyext is lack of maintainace and cannot work with python 3.12, consider using replacement and deprecating this package."
  },
  {
    "path": "verl_rl/docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = verl\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "verl_rl/docs/README.md",
    "content": "# verl documentations\n\n## Build the docs\n\n```bash\n# If you want to view auto-generated API docstring, please make sure verl is available in python path. For instance, install verl via:\n# pip install .. -e[test]\n\n# Install dependencies needed for building docs.\npip install -r requirements-docs.txt\n\n# Build the docs.\nmake clean\nmake html\n```\n\n## Open the docs with your browser\n\n```bash\npython -m http.server -d _build/html/\n```\nLaunch your browser and navigate to http://localhost:8000 to view the documentation. Alternatively you could drag the file `_build/html/index.html` to your local browser and view directly.\n"
  },
  {
    "path": "verl_rl/docs/README_vllm0.7.md",
    "content": "# Upgrading to vllm >= 0.7\n\nNote: verl+vllm 0.8.3 is now stable. Please see ``docs/README_vllm0.8.md`` for upgrade guide.\n\n## Installation\n\nNote: At time of writing, verl+vllm 0.7.x supports **FSDP** for training and **vLLM** for rollout.\n\n```\n# Create the conda environment\nconda create -n verl python==3.10\nconda activate verl\n\n# Install verl\ngit clone https://github.com/volcengine/verl.git\ncd verl\npip3 install -e .\n\n# Install the latest stable version of vLLM\npip3 install vllm==0.7.3 \n\n# Install flash-attn\npip3 install flash-attn --no-build-isolation\n\n```\n\nNote that if you are installing lower versions of vLLM (0.7.0, 0.7.1, 0.7.2), you need to make some tiny patches manually on vllm (/path/to/site-packages/vllm after installation) after the above steps:\n\n- vllm/distributed/parallel_state.py: Remove the assertion below:\n\n```\nif (world_size\n        != tensor_model_parallel_size * pipeline_model_parallel_size):\n    raise RuntimeError(\n        f\"world_size ({world_size}) is not equal to \"\n        f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n        f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\")\n\n```\n\n- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ[\"LOCAL_RANK\"])`\n- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` in `pt_weights_iterator`\n\n## Features\n\n### Use cuda graph\n\nAfter installation, examples using FSDP as training backends can be used. By default, the `enforce_eager` is set to True, which disables the cuda graph. To enjoy cuda graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script:\n\n```\nactor_rollout_ref.rollout.enforce_eager=False \\\nactor_rollout_ref.rollout.free_cache_engine=True \\\n\n```\n\nFor a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 85 seconds with vLLM0.7.0. By enabling the cudagraph, the generation duration is further reduced to 62 seconds.\n\n**Note:** Currently, if the `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue on the stability of rollout generation time (Some iterations would see generation time bursts) using vLLM's V0 Engine.\n\n### Use vLLM V1 Engine\n\nUsing the vLLM V1 engine can avoid instability issues and achieve additional performance improvements. To use the V1 engine, you can first uninstall the previously installed vLLM and then follow the steps below to install the newer version.\n\n```\ngit clone https://github.com/vllm-project/vllm.git\ncd vllm\ngit checkout 2275784\nsed -i \"903a\\    data_parallel_size = world_size // pipeline_model_parallel_size // tensor_model_parallel_size\" ./vllm/distributed/parallel_state.py\nVLLM_USE_PRECOMPILED=1 pip install --editable .\n```\n\nThen you can enable the V1 engine by setting `export VLLM_USE_V1=1`. In some benchmark tests, the V1 engine demonstrates a 1.5x speed improvement over the vLLM V0 engine.\nThe stable support of the vLLM V1 engine is available on verl main.\n"
  },
  {
    "path": "verl_rl/docs/README_vllm0.8.md",
    "content": "# Upgrading to vLLM >= 0.8\n\nLast updated: 05/04/2025.\n\n## Installation\n\nNote: This version of verl+vLLM 0.8+ supports **FSDP** for training and **vLLM** for rollout.\n\n```bash\n# Create the conda environment\nconda create -n verl python==3.10\nconda activate verl\n\n# Install verl\ngit clone https://github.com/volcengine/verl.git\ncd verl\npip3 install -e .\n\n# Install the latest stable version of vLLM\npip3 install vllm==0.8.3\n\n# Install flash-attn\npip3 install flash-attn --no-build-isolation\n\n```\n\nWe have a pre-built docker image for verl+vLLM 0.8.3. You can direct import it with the following command:\n\n```bash\ndocker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0\n```\n\n## Features\n\nvLLM 0.8+ supports cuda graph and V1 engine by default in verl. To enable these features, remember to add the following lines to the bash script:\n\n```bash\nactor_rollout_ref.rollout.enforce_eager=False \\\nactor_rollout_ref.rollout.free_cache_engine=True \\\n```\n\nand also **remove** the environment variable if it exists:\n\n## Notes\n\nWhen you just directly upgrade vllm>=0.8, some dependency packages may undergo version changes. If you encounter the following problems:\n\n```bash\nin <module> from torch.multiprocessing.reductions import ForkingPickler ImportError: cannot import name 'ForkingPickler' from 'torch.multiprocessing.reductions' (/opt/conda/lib/python3.11/site-packages/torch/multiprocessing/reductions.py)\n```\n\nYou need to upgrade `tensordict` to version 0.6.2 using the command `pip install tensordict==0.6.2`.\n"
  },
  {
    "path": "verl_rl/docs/_static/js/runllm-widget.js",
    "content": "document.addEventListener(\"DOMContentLoaded\", function () {\n    var script = document.createElement(\"script\");\n    script.type = \"module\";\n    script.id = \"runllm-widget-script\";\n    script.src = \"https://widget.runllm.com\";\n    script.setAttribute(\"version\", \"stable\");\n    script.setAttribute(\"crossorigin\", \"true\");\n    script.setAttribute(\"runllm-keyboard-shortcut\", \"Mod+j\");\n    script.setAttribute(\"runllm-name\", \"verl Chatbot\");\n    script.setAttribute(\"runllm-position\", \"TOP_RIGHT\");\n    script.setAttribute(\"runllm-assistant-id\", \"679\");\n    script.async = true;\n    document.head.appendChild(script);\n  });"
  },
  {
    "path": "verl_rl/docs/advance/agent_loop.rst",
    "content": "Agent Loop\n==========\n\nLast updated: 07/17/2025.\n\n.. versionadded:: 0.4.2\n   [status: alpha]\n\n.. warning::\n   Agent Loop is ready for use, but the API may change in future releaes.\n\nAgent Loop is designed as general interface for multi-turn rollout and agentic reinforcement learning.\n\n**Design goal**:\n\n- Plugable user defined agent loop\n- Provide standard request generate api with different inference frameworks\n- Provide request level load balance between multiple inference servers\n\n**Non-goal**:\n\n- How tool is defined and how to call tool\n\nIn high level overview, agent loop is given a prompt, run user defined loop: call LLM generate api, call tools, ...\nand return the final output. The final output is then calculated reward and used as trajectory for RL training.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_overview.svg?raw=true\n\n\nAPI Design\n----------\n\n``AgentLoopBase`` class is the abstraction of agent loop, and ``run`` method is the only interface that user need to implement.\nThe run method, given prompt messages in format: [{\"role\": \"user\"}, {\"content\": \"...\"}], and additional sampling params,\ncould do whatever user wants, such as\n\n- call LLM generate api\n- call tools: web search, database query, code sandbox, ...\n- environment interaction\n- reflection\n- ...\n\n.. code:: python\n\n   class AgentLoopBase(ABC):\n       @abstractmethod\n       async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:\n           \"\"\"Run agent loop to interact with LLM server and environment.\n\n           Args:\n               messages (List[Dict[str, Any]]): Input messages.\n               sampling_params (Dict[str, Any]): LLM sampling params.\n\n           Returns:\n               AgentLoopOutput: Agent loop output.\n           \"\"\"\n           raise NotImplementedError\n\nAfter running user defined loop, run method should return ``AgentLoopOutput``, including prompt token ids,\nresponse token ids, and response mask.\n\n.. code:: python\n\n   class AgentLoopOutput(BaseModel):\n       \"\"\"Agent loop output.\"\"\"\n\n       prompt_ids: list[int]\n       \"\"\"Prompt token ids.\"\"\"\n       response_ids: list[int]\n       \"\"\"Response token ids including LLM generated token, tool response token.\"\"\"\n       response_mask: list[int]\n       \"\"\"Response mask, 1 for LLM generated token, 0 for tool response token.\"\"\"\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_output.svg?raw=true\n\n.. note:: AgentLoopOutput only output one trajectory for a given prompt, multiple trajectories output is still under discussion.\n\nArchitecture Design\n-------------------\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop_architecture.png?raw=true\n\nA single PPO step contain two phase: rollout and train. In rollout phase:\n\n1. PPOTrainer sample a batch from dataset and call ``AgentLoopManager.generate_sequences``.\n2. AgentLoopManager ``wake_up`` all async LLM server instances, which will sync weights between inference engine(vLLM/SGLang) and training engine(FSDP/Megatron-LM).\n3. AgentLoopManager split batch into chunks and send each chunk to ``AgentLoopWorker``.\n4. AgentLoopWorker receive chunk and for each prompt, spawn a user defined ``AgentLoopBase`` instance, run ``run`` coroutine until end and get ``AgentLoopOutput``.\n\n.. 
tip::\n   AgentLoopWorker schedules multiple coroutines concurrently. If number of AgentLoopWorker equals batch_size, then each worker is response for one prompt.\n\nIn agent loop, when user need LLM generate response:\n\n5. Call ``AsyncLLMServerManager.generate`` with prompt_ids.\n6. AsyncLLMServerManager select a server instance with least request in first turn and send request to it. (In following turns, the request will be sent to the same server instance).\n7. AsyncLLMServer receive a request, issue ipc/rpc with model_runner, and generate response. (There's slight differences between vLLM and SGLang, see below).\n\nWhen all prompts in all AgentLoopWorker finish, AgentLoopManager gather results and return to PPOTrainer.\n\n8. AgentLoopManager ``sleep`` all server instances, which will free kv cache and offload weights to CPU memory.\n\nAsyncLLMServer\n~~~~~~~~~~~~~~\n\nAsyncLLMServer is the abstraction of LLM server with two types of generation api:\n\n- `OpenAI chat completion <https://platform.openai.com/docs/api-reference/chat>`_: generate response for the given chat conversation.\n- Token in token out: generate response ids for the given token ids.\n\nWe have officially supported vLLM and SGLang AsyncLLMServer, both of them implement the two api and are well tested.\nOther inference engine should be easy to plug-in by implement the ``AsyncServerBase`` class.\n\n.. code:: python\n\n   class AsyncServerBase(ABC):\n       @abstractmethod\n       async def chat_completion(self, raw_request: Request) -> JSONResponse:\n           \"\"\"OpenAI chat completion API.\n\n           Args:\n               raw_request (Request): raw json request\n           \n           Returns:\n               JSONResponse: json response\n\n           API reference: https://platform.openai.com/docs/api-reference/chat/create\n           \"\"\"\n           raise NotImplementedError\n\n       @abstractmethod\n       async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n           \"\"\"Generate response ids given prompt ids.\n\n           Args:\n               prompt_ids (List[int]): prompt ids\n               sampling_params (Dict[str, Any]): sampling params\n               request_id (str): request id\n\n           Returns:\n               List[int]: response ids\n           \"\"\"\n           raise NotImplementedError\n\n\nChat completion vs Token in token out\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. warning::\n   The following conclusion is based on our recent experience and is still open to investigation and discussion.\n\nAlmost all agent frameworks (LangGraph, CrewAI, LlamaIndex, etc) call LLM with OpenAI chat completion api, and \nkeep chat history as messages. So user may expect that we should use the chat completion api in multi-turn rollout.\n\nBut based on our recent experience on single-turn training on DAPO and multi-turn training on `retool <https://github.com/volcengine/verl/tree/main/recipe/retool>`_,\nwe found the token_ids from apply the final messages may not equal to the token_ids by concat prompt_ids and response_ids in each turn.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/multi_turn.png?raw=true\n\n**Where does this inconsistency happened?**\n\nFirst, the tool parser may alter the content. For example\n\n.. code:: json\n\n   {\"role\": \"assistant\", \"content\": \"Let me call a <tool_call>...</tool_call> and get the result\"}\n\nAfter tool_calls extraction, the messages is like this:\n\n.. 
code:: json\n\n   {\"role\": \"assistant\", \"content\": \"Let me call a and get the result\", \"tool_calls\": [{\"name\": \"foo\", \"arguments\": \"{}\"}]}\n\nEncode the extracted message back is not equal to the original LLM generated response_ids.\n\nSecond,  the `decode-encode` may also lead to inconsistency: `Agent-R1 issue#30 <https://github.com/0russwest0/Agent-R1/issues/30#issuecomment-2826155367>`_.\n\n**What is the impact of this inconsistency?**\n\nThis inconsistency is not a big problem for serving/agent system, but is critical to RL training.\nIt causes the trajectory deviate from the policy model distribution. We have observed that apply_chat_template\nto the final chat history messages make PPO training not even converged in single-turn.\n\nvLLM\n^^^^\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/async_vllm.png?raw=true\n\nFor vLLM, the Async LLM Engine is running in same process as the server, and ModelRunner is running in same process as FSDP/Megatron-LM workers.\nAsync LLM Engine communicate with ModelRunner through ZeroMQ. When server receive a request, it directly call engine to generate response_ids.\n\nSGLang\n^^^^^^\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/async_sglang.png?raw=true\n\nFor SGLang, the Async LLM Engine is running in same process as FSDP/Megatron-LM worker-0, and it spawn multiple subprocesses as ModelRunner.\nAlso, Async LLM Engine communicate with ModelRunner through ZeroMQ. When server receive a request, it remote call the worker-0 and get response_ids.\n\nAsyncLLMServerManager\n~~~~~~~~~~~~~~~~~~~~~\n\nAsyncLLMServerManager serve as proxy to multiple AsyncLLMServer instances, provides:\n\n- load balance: select a server instance with least request in first turn and send request to it.\n- sticky session: bind request_id to server instance, so that the same request_id will be sent to the same server instance in following turns.\n\nAsyncLLMServerManager is passed to ``AgentLoopBase.__init__``, whenever user want to interact with LLM in agent loop,\nthey can call ``AsyncLLMServerManager.generate`` to generate response_ids.\n\n.. code:: python\n\n   class AsyncLLMServerManager:\n       async def generate(\n           self,\n           request_id,\n           *,\n           prompt_ids: list[int],\n           sampling_params: dict[str, Any],\n       ) -> list[int]:\n           \"\"\"Generate tokens from prompt ids.\n\n           Args:\n               request_id (str): request id for sticky session.\n               prompt_ids (List[int]): List of prompt token ids.\n               sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.\n\n           Returns:\n               List[int]: List of generated token ids.\n           \"\"\"\n           ...\n\nNext\n----\n\n- :doc:`Agentic RL Training<../start/agentic_rl>`: Quick start agentic RL training with gsm8k dataset.\n- `LangGraph MathExpression <https://github.com/volcengine/verl/tree/main/recipe/langgraph_agent/example>`_: Demonstrate how to use LangGraph to build agent loop.\n- `Retool <https://github.com/volcengine/verl/tree/main/recipe/retool>`_: End-to-end retool paper reproduction using tool agent.\n"
  },
  {
    "path": "verl_rl/docs/advance/checkpoint.rst",
    "content": ".. _checkpoint-page:\n\nUsing Checkpoints to Support Fault Tolerance Training\n=====================================================\n\nLast updated: 06/25/2025.\n\nThere could be training errors or machine failure during the whole RLHF training process, \nso it is recommended to enable checkpoints to minimize your loss.\n\nThe API Interface has already been listed in :ref:`config-explain-page`,\nand we will not repeat them. But there are still some technique details\nwe hope to clarify.\n\n.. note:: \n\n    Notice that the ``checkpoint.contents`` field has no effect to FSDP checkpoint except ``hf_model``, \n    the other 3 fields are binded together to save and load. We recommend to include ``model``, ``optimizer`` and ``extra`` all.\n\nCheckpoint Saving Directory Structure\n-------------------------------------\n\nCommonly, we use the ``default_local_dir`` declared in ``ppo_trainer.yaml`` or ``ppo_megatron_trainer.yml``\nto work as preffix when saving checkpoints, which is ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``.\n\nSo the inner checkpoint structure of **FSDP** is like:\n\n.. code::\n\n    checkpoints/${trainer.project_name}/${trainer.experiment_name}\n    ├── global_steps_${i}\n    │   ├── actor\n    │   │   ├── huggingface      # default save config and tokenizer, save huggingface model if include ``hf_model`` in checkpoint.contents\n    │   │   └── fsdp_config.json # FSDP config file, including world_size and fsdp version\n    │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   ├── critic\n    │   │   ├── huggingface\n    │   │   └── fsdp_config.json\n    │   │   ├── model_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt\n    │   │   └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\n    └── latest_checkpointed_iteration.txt\n\nAll model shards, optimizers and extra states are stored together, in a sharded and distributed way.\n\nWhile **Megatron** current checkpoint structure is:\n\n.. code::\n\n    checkpoints/${trainer.project_name}/${trainer.experiment_name}\n    ├── global_steps_${i}\n    │   ├── actor\n    │   │   ├── huggingface     # default save config and tokenizer, save huggingface model if include ``hf_mode`` in checkpoint.contents\n    │   │   └── dist_ckpt       # save sharded model/optimizer/rng_states, naming the same as Megatron\n    │   └── critic\n    │   │   ├── huggingface\n    │   │   └── dist_ckpt\n    └── latest_checkpointed_iteration.txt\n\nConvert FSDP and Megatron Checkpoints to HuggingFace Format Model\n-----------------------------------------------------------------\n\nWe provide a tool to convert the FSDP and Megatron checkpoints to HuggingFace format model.\nThe tool is located in ``verl/model_merger``. For older versions of verl that don't include fsdp_config.json in checkpoints, you can use the legacy model merger located at ``verl/scripts/legacy_model_merger.py``.\n\nThe script supports two main sub-commands: `merge` (to convert and save checkpoints) and `test` (to validate merged checkpoints against a reference model).\nThe arguments for the `merge` sub-command are as follows:\n\n.. 
code:: bash\n\n    usage: python -m verl.model_merger merge [-h] --backend {fsdp,megatron} [--local_dir LOCAL_DIR] [--tie-word-embedding] [--is-value-model] [--use_cpu_initialization] [--target_dir TARGET_DIR]\n                         [--hf_upload_path HF_UPLOAD_PATH] [--private]\n\n    options:\n    -h, --help            show this help message and exit\n    --backend {fsdp,megatron}\n                            The backend of the model\n    --local_dir LOCAL_DIR\n                            Path to the saved model checkpoints\n    --tie-word-embedding  Whether to tie word embedding weights (currently only Megatron supported)\n    --is-value-model      Whether the model is a value model (currently only Megatron supported)\n    --use_cpu_initialization\n                            Whether to use CPU initialization for the model. This is useful for large models that cannot fit into GPU memory during initialization.\n    --target_dir TARGET_DIR\n                            Directory to save the merged huggingface model\n    --hf_upload_path HF_UPLOAD_PATH\n                            Hugging Face repository ID to upload the model\n    --private             Whether to upload the model to a private Hugging Face repository\n\nExample usage for merging Megatron checkpoints:\n\n.. code:: bash\n\n    python -m verl.model_merger merge \\\n        --backend megatron \\\n        --tie-word-embedding \\\n        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\nExample usage for distributed merging Megatron checkpoints:\n\n.. code:: bash\n\n    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \\\n        --backend megatron \\\n        --tie-word-embedding \\\n        --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\nExample usage for merging FSDP checkpoints:\n\n.. code:: bash\n\n    python -m verl.model_merger merge \\\n        --backend fsdp \\\n        --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n        --target_dir /path/to/merged_hf_model\n\n\nMegatron Merger details\n-----------------------\n\nCurrent implement of decoder layers uses ``nn.ModuleList`` to store the layers, \nand thus the model layers on every PP rank and VPP rank starts their index from 0.\n\nThere are 3 ways to correct this behavior:\n\n1. Modify the decoder layer's state_dict, add ``offset`` to each layer's index, thus rewrite ``nn.ModuleList`` implementation.\n2. Modify the layer index when saving checkpoint and recover them when loading checkpoint.\n3. The Checkpoint merger do this work, calculate the actual ``offset`` from ``state_dict`` only, a little complex.\n\nCurrent implementation use solution 2.\n\n\nHuggingFace to Megatron DistCheckpoint details\n----------------------------------------------\n\nIf your model is quite huge, we recommend you to use Megatron dist-checkpoint to load the model.\nMegatron dist-checkpoint supports loading with different kinds of model parallelism,\nand it is much faster than the original checkpoint loading.\n\nTo convert original HuggingFace model to Megatron dist-checkpoint,\nyou can use the ``scripts/converter_hf_to_mcore.py`` script. Large MoE models are temporarily supported with CPU initialization,\nwhich is a little slower. 
While we are working on a better solution to support large models.\n\nExample command to convert the model is as follows:\n\n.. code:: bash\n\n    python scripts/converter_hf_to_mcore.py \\\n        --hf_model_path Qwen/Qwen1.5-MoE-A2.7B-Chat \\\n        --output_path /mnt/disk/Qwen/Qwen1.5-MoE-A2.7B-Chat \\\n        --use_cpu_initialization    # Only work for MoE models\n\n\nExample command to distributed convert the huge model like deepseekv3 671B is as follows:\n\n.. code:: bash\n\n    torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \\\n        --hf_model_path deepseek-ai/DeepSeek-V3 \\\n        --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \\\n        --use_cpu_initialization    # Only work for MoE models\n\nOriginal Checkpoint Utils\n-------------------------\n\nOriginal Checkpoint Utils refer to original checkpoint implementation in ``verl/models/[model]/megatron/checkpoint_utils``.\n\nWe only need ``[model]_loader.py`` in original checkpoint utils now, since we get rid of storing ``hf_model`` every time (which is not recommended for large model training, try only saving sharded models if you can).\n\n.. note:: \n\n    Note that ``[model]_loader`` only support environments where **storage clusters are able to connect with every calculation nodes**. \n    Because it utilizes **sharded load way to minimize the loading checkpoint overhead**. \n    Every rank loads its own data from ``state_dict`` which can be accessed by all of them.\n    While there is also no need to broadcast among DP ranks, since the saved state_dict is only produced by DP rank 0.\n\n    For users who can **only place the huggingface model on one device**, we keep the original costly implementation in ``[model]_loader_deprecated``. In this implementation, rank 0 broadcast all weights to each tp and pp rank, and then dp rank 0 broadcast to all dp ranks. There may be at risks of OOM.\n\n    To use deprecated loader, change the import package of ``load_state_dict_to_megatron_llama``.\n"
  },
  {
    "path": "verl_rl/docs/advance/dpo_extension.rst",
    "content": "Extend to other RL(HF) algorithms\n=================================\n\nLast updated: 02/25/2025.\n\nWe already implemented the complete training pipeline of the PPO\nalgorithms. To extend to other algorithms, we analyze the high-level\nprinciple to use verl and provide a tutorial to implement the DPO\nalgorithm. Users can follow the similar paradigm to extend to other RL algorithms.\n\n.. note:: **Key ideas**: Single process drives multi-process computation and data communication.\n\nOverall Approach\n----------------\n\nStep 1: Consider what multi-machine multi-GPU computations are needed\nfor each model, such as ``generate_sequence`` , ``compute_log_prob`` and\n``update_policy`` in the actor_rollout model. Implement distributed\nsingle-process-multiple-data (SPMD) computation and encapsulate them\ninto APIs\n\nStep 2: Based on different distributed scenarios, including FSDP and 3D\nparallelism in Megatron-LM, implement single-process control of data\ninteraction among multi-process computations.\n\nStep 3: Utilize the encapsulated APIs to implement the control flow\n\nExample: Online DPO\n-------------------\n\nWe use verl to implement a simple online DPO algorithm. The algorithm\nflow of Online DPO is as follows:\n\n1. There is a prompt (rollout) generator which has the same weight as\n   the actor model. After a batch of prompts are fed into the generator,\n   it generates N responses for each prompt.\n2. Send all the prompts + responses to a verifier for scoring, which can\n   be reward model or a rule-based function. Then sort them in pairs to\n   form a training batch.\n3. Use this training batch to train the actor model using DPO. During\n   the process, a reference policy is needed.\n\nStep 1: What are the multi-machine multi-GPU computations\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**Sample Generator**\n\nImplementation details:\n\n.. code:: python\n\n   from verl.single_controller.base import Worker\n   from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool\n   import ray\n\n   @ray.remote\n   class SampleGenerator(Worker):\n       def __init__(self, config):\n           super().__init__()\n           self.config = config\n           \n       def generate_sequences(self, data):\n           pass\n\nHere, ``SampleGenerator`` can be viewed as a multi-process pulled up by\n``torchrun``, with each process running the same code (SPMD).\n``SampleGenerator`` needs to implement a ``generate_sequences`` API for\nthe control flow to call. The implementation details inside can use any\ninference engine including vllm, sglang and huggingface. Users can\nlargely reuse the code in\nverl/verl/workers/rollout/vllm_rollout/vllm_rollout.py and we won't\ngo into details here.\n\n**ReferencePolicy inference**\n\nAPI: compute reference log probability\n\n.. code:: python\n\n   from verl.single_controller.base import Worker\n   import ray\n\n   @ray.remote\n   class ReferencePolicy(Worker):\n       def __init__(self):\n           super().__init__()\n           self.model = Model()\n           \n       def infer(self, data):\n           return self.model(data)\n\n**Actor update**\n\nAPI: Update actor model parameters\n\n.. 
code:: python\n\n   from verl.single_controller.base import Worker\n   import ray\n\n   @ray.remote\n   class DPOActor(Worker):\n       def __init__(self):\n           super().__init__()\n           self.model = Model()\n           self.model = FSDP(self.model)  # or other distributed strategy\n           self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)\n           self.loss_fn = xxx\n           \n       def update(self, data):\n           self.optimizer.zero_grad()\n           logits = self.model(data)\n           loss = self.loss_fn(logits)\n           loss.backward()\n           self.optimizer.step()\n\n**Notes: How to distinguish between control processes and distributed computation processes**\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n- Control processes are generally functions directly decorated with\n  ``@ray.remote``\n- Computation processes are all wrapped into a ``RayWorkerGroup``.\n\nUsers can reuse most of the distribtued computation logics implemented\nin PPO algorithm, including FSDP and Megatron-LM backend in\nverl/verl/trainer/ppo.\n\nStep 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**The core problem to solve here is how a single process sends data to\nmultiple processes, drives multi-process computation, and how the\ncontrol process obtains the results of multi-process computation.**\nFirst, we initialize the multi-process ``WorkerGroup`` in the control\nprocess.\n\n.. code:: python\n\n   @ray.remote(num_cpus=1)\n   def main_task(config):\n       # construct SampleGenerator\n       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs\n       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)\n       # put SampleGenerator onto resource pool\n       worker_group = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct reference policy\n\nAs we can see, in the control process, multiple processes are wrapped\ninto a ``RayWorkerGroup``. Inside this ``WorkerGroup``, there is a\n``self._workers`` member, where each worker is a RayActor\n(https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator.\nray_trainer.md also provide an implementation of\n``MegatronRayWorkerGroup``.\n\nAssuming the model is distributed using FSDP, and there is a batch of\ndata on the control process, for data parallelism, the underlying\ncalling process is:\n\n.. code:: python\n\n   data = xxx\n   data_list = data.chunk(dp_size)\n\n   output = []\n   for d in data_list:\n       # worker_group._workers[i] is a SampleGenerator\n       output.append(worker_group._workers[i].generate_sequences.remote(d))\n\n   output = ray.get(output)\n   output = torch.cat(output)\n\nSingle process calling multiple processes involves the following 3\nsteps:\n\n1. Split the data into DP parts on the control process.\n2. Send the data to remote, call the remote computation through RPC, and\n   utilize multi-process computation.\n3. Obtain the computation results of each worker on the control process\n   and merge them.\n\nFrequently calling these 3 steps on the controller process greatly hurts\ncode readability. **In verl, we have abstracted and encapsulated these 3\nsteps, so that the worker's method + dispatch + collect can be\nregistered into the worker_group**\n\n.. 
code:: python\n\n   from verl.single_controller.base.decorator import register\n\n   def dispatch_data(worker_group, data):\n       return data.chunk(worker_group.world_size)\n       \n   def collect_data(worker_group, data):\n       return torch.cat(data)\n\n   dispatch_mode = {\n       'dispatch_fn': dispatch_data,\n       'collect_fn': collect_data\n   }\n\n   @register(dispatch_mode=dispatch_mode)\n   def generate_sequences(self, data):\n       pass\n\nIn this way, we can directly call the method inside the worker through\nthe ``worker_group`` on the control (driver) process (which is a single\nprocess):\n\n.. code:: python\n\n   output = worker_group.generate_sequences(data)\n\nThis single line includes data splitting, data distribution and\ncomputation, and data collection.\n\nFurthermore, the model parallelism size of each model is usually fixed,\nincluding dp, tp, pp. So for these common distributed scenarios, we have\npre-implemented specific dispatch and collect methods,in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations.\n\n.. code:: python\n\n   from verl.single_controller.base.decorator import register, Dispatch\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(self, data: DataProto) -> DataProto:\n       pass\n\nHere it requires the data interface to be ``DataProto``. Definition of\n``DataProto`` is in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_.\n\nStep 3: Main training loop\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWith the above training flows, we can implement the algorithm's control\nflow. It is recommended that ``main_task`` is also a ray remote process.\n\n.. code:: python\n\n   @ray.remote(num_cpus=1)\n   def main_task(config):\n       # construct SampleGenerator\n       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs\n       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) \n       # put SampleGenerator onto resource pool\n       sample_gen = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct reference policy\n       ray_cls = RayClassWithInitArgs(ReferencePolicy)\n       ref_policy = RayWorkerGroup(resource_pool, ray_cls)\n       \n       # construct actor\n       ray_cls = RayClassWithInitArgs(DPOActor)  \n       dpo_policy = RayWorkerGroup(resource_pool, ray_cls)\n       \n       dataloader = DataLoader()\n       \n       for data in dataloader:\n           # generate data\n           data = sample_gen.generate_sequences(data)\n           # generate scores for each data \n           data = generate_scores(data)\n           # generate pairwise data using scores\n           data = generate_pairwise_data(data)\n           # generate ref_log_prob\n           data.batch['ref_log_prob'] = ref_policy.infer(data)\n           # update using dpo\n           dpo_policy.update(data)\n           # logging\n\nHere, different ``WorkerGroups`` can be placed in the same resource pool or\nin different resource pools using ``create_colocated_worker_cls``\nsimilar as in `ray_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py>`_.\n"
  },
  {
    "path": "verl_rl/docs/advance/fsdp_extension.rst",
    "content": "\nAdd models with the FSDP backend\n==================================\n\nLast updated: 02/09/2025.\n\nModel\n--------------------------\n\nIn principle, our FSDP backend can support any HF model and we can\nsychronoize the actor model weight with vLLM using `hf_weight_loader.py` under `third_party/vllm`.\nHowever, ``hf_weight_loader`` is will gather the full state_dict of a\nmodel during synchronization, which may cause OOM. We suggest using\n``dtensor_weight_loader`` which gather the full model parameter layer by\nlayer to reduce the peak memory usage. We already support dtensor weight\nloader for the models below in `dtensor_weight_loader.py` under `third_party/vllm`:\n\n- ``GPT2LMHeadModel``\n- ``LlamaForCausalLM``\n- ``LLaMAForCausalLM``\n- ``MistralForCausalLM``\n- ``InternLMForCausalLM``\n- ``AquilaModel``\n- ``AquilaForCausalLM``\n- ``Phi3ForCausalLM``\n- ``GemmaForCausalLM``\n- ``Gemma2ForCausalLM``\n- ``GPTBigCodeForCausalLM``\n- ``Starcoder2ForCausalLM``\n- ``Qwen2ForCausalLM``\n- ``DeepseekV2ForCausalLM``\n\nTo implement ``dtensor_weight_loader`` of a model that's supported in\nvLLM, follow the guide of gemma model below:\n\n1. Copy the\n   ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` from the vllm model class\n   to ``dtensor_weight_loaders.py``\n2. Modify the arguments to\n   ``(actor_weights: Dict, vllm_model: nn.Module)``\n3. Replace the ``self`` to ``vllm_model``\n4. Add the\n   ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)``\n   before each ``param = params_dict[name]`` and modify the following\n   weight loading using ``local_loaded_weight``.\n5. Register the implemented dtensor weight loader to ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``.\n\n.. 
code-block:: diff\n\n    - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n    + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            (\"qkv_proj\", \"q_proj\", \"q\"),\n            (\"qkv_proj\", \"k_proj\", \"k\"),\n            (\"qkv_proj\", \"v_proj\", \"v\"),\n            (\"gate_up_proj\", \"gate_proj\", 0),\n            (\"gate_up_proj\", \"up_proj\", 1),\n        ]\n    -   params_dict = dict(self.named_parameters())\n    +   params_dict = dict(vllm_model.named_parameters())\n        loaded_params = set()\n    -   for name, loaded_weight in weights:\n    +   for name, loaded_weight in actor_weights.items():\n            for (param_name, shard_name, shard_id) in stacked_params_mapping:\n                if shard_name not in name:\n                    continue\n                name = name.replace(shard_name, param_name)\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n    +           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)\n                param = params_dict[name]\n                weight_loader = param.weight_loader\n    -           weight_loader(param, loaded_weight, shard_id)\n    +           weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)\n                break\n            else:\n                # lm_head is not used in vllm as it is tied with embed_token.\n                # To prevent errors, skip loading lm_head.weight.\n                if \"lm_head.weight\" in name:\n                    continue\n                # Skip loading extra bias for GPTQ models.\n                if name.endswith(\".bias\") and name not in params_dict:\n                    continue\n    +           local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)\n                param = params_dict[name]\n                weight_loader = getattr(param, \"weight_loader\",\n                                        default_weight_loader)\n    -           weight_loader(param, loaded_weight)\n    +           weight_loader(param, local_loaded_weight.to(dtype=param.dtype))\n            loaded_params.add(name)\n        unloaded_params = params_dict.keys() - loaded_params\n        if unloaded_params:\n            raise RuntimeError(\n                \"Some weights are not initialized from checkpoints: \"\n                f\"{unloaded_params}\")"
  },
  {
    "path": "verl_rl/docs/advance/megatron_extension.rst",
    "content": "Add models with the Megatron-LM backend\n=========================================\n\nLast updated: 04/25/2025.\n\nModel\n-----------\n\n\nIf use latest verl, we have direct support of ``GPTModel`` for Megatron backend. \nYou can use the similar way of using Megatron to pretrain custom models. \nWe list the steps here:\n\n1. Find `model_initializer.py <https://github.com/volcengine/verl/blob/main/verl/models/mcore/model_initializer.py>`_\n2. If your model is configurable by ``TransformerLayerSpec`` , you can\n   directly use ``GPTModel``. Otherwise, Please implement a new\n   ``ModelLayerSpec`` and ``ModelLayer`` here.\n3. Use the right ``LayerSpec`` , ``TransformerConfig`` and ``HuggingfaceConfig`` \n   as arguments to initialize the GPTModel.\n4. Return the model at last.\n"
  },
  {
    "path": "verl_rl/docs/advance/one_step_off.md",
    "content": "# Recipe: One Step Off Policy Async Trainer\n\n**Author:**  `https://github.com/meituan-search`\n\nLast updated: 07/17/2025.\n\n## Introduction\n\n### Background\n\nThe current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic\nworkflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest\nmodel, and the model is updated after training completes. While this approach aligns with off-policy reinforcement\nlearning and stabilizes RL training, but it suffers from severe efficiency issues.\nModel updates must wait for the longest output in the generation phase to complete.\nDuring the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization.\nThe more severe the long-tail problem in sample generation, the lower the overall training efficiency.\nFor example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time,\nand increasing resources does not reduce the Rollout duration.\n\n![DAPO 32B Math Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png)\n> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361\n\n### Solution\n\nWe have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the\ngeneration and training processes, utilizing samples generated in the previous step for current training.\nIt also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically\nassigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time\nduring long-tail sample generation. Throughout this process, generation and training parameters maintain a one-step off\npolicy.\n\n![One Step Off Policy Diagram](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)\n> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](\n> https://arxiv.org/abs/2505.24298)\n\nOur core contributions include:\n\n1. **Parallel Generation and Training**:  \n   Samples for the next batch are asynchronously generated while the current batch is being trained.\n\n2. **Resource Isolation**:  \n   Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources\n   automatically assigned to training.\n\n3. 
**NCCL Parameter Synchronization**:  \n   Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.\n\n### Experimental Results\n\n- **Machine Configuration**: 2 nodes with 16 H20 GPUs each\n   - Generation: 4 GPUs\n   - Training: 12 GPUs\n- **Model**: Qwen2.5-Math-7B\n- **Rollout Configuration**:\n- **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens\n- **Algorithm**: DAPO\n- **Rollout Engine**: vLLM\n\n| training mode          | engine        | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time    | acc/best@32/mean | acc/maj@32/mean |\n|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|\n| colocate sync          | VLLM+FSDP2    | 749  | 321 | -             | 247                | 88           | 286          | 19h18m        | 0.5948           | 0.417           |\n| one-step-overlap async | VLLM+FSDP2    | 520  | -   | 45            | 458                | 108          | 337          | 15h34m（+23%）  | 0.6165           | 0.494           |\n| colocate sync          | VLLM+Megatron | 699  | 207 | -             | 162                | 119          | 344          | 18h21m        | 0.605            | 0.4217          |\n| one-step-overlap async | VLLM+Megatron | 566  | -   | 59            | 501                | 120          | 347          | 13h06m (+40%) | 0.6569           | 0.4038          |\n\n* colocate sync: step ≈ gen + old_log_prob + update_actor\n* one-step-overlap async: step ≈ wait_prev_gen + old_log_prob + update_actor\n\n![One Step Off Megatron Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)\n\n> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg\n\n## Implementation\n\n### One Step Off Policy Async Pipline\n\nOur implemented **One Step Off Policy Async Pipeline** integrates seamlessly into existing training logic at minimal\ncost,\neliminating the need for additional sample storage management. 
The core mechanism uses `async_gen_next_batch`\nfor asynchronous rollout generation while maintaining continuous operation during epoch transitions\nvia `create_continuous_iterator`.\n\n```python\n# iterator generator, simplify one-step integration of the training process\ndef _create_continuous_iterator(self):\n   for epoch in range(self.config.trainer.total_epochs):\n      iterator = iter(self.train_dataloader)\n      for batch_dict in iterator:\n         yield epoch, batch_dict\n\n\n# read next batch samples, parameters sync and launch asyn gen_seq\ndef _async_gen_next_batch(self, continuous_iterator):\n   # read train_data\n   try:\n      epoch, batch_dict = next(continuous_iterator)\n   except StopIteration:\n      return None\n   batch = DataProto.from_single_dict(batch_dict)\n   gen_batch = batch_pocess(batch)\n   # sync weights from actor to rollout\n   self.sync_rollout_weights()\n   # async generation\n   gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n   # future encapsulated\n   return GenerationBatchFuture(epoch, batch, gen_batch_output)\n\n\ncontinuous_iterator = self._create_continuous_iterator()\n# run rollout first to achieve one-step-off\nbatch_data_future = self._async_gen_next_batch(continuous_iterator)\n\nwhile batch_data_future is not None:\n   # wait for the gen_seq result from the previous step\n   batch = batch_data_future.get()\n   # launch the next async call to generate sequences\n   batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n   # compute advantages \n   batch = critic.compute_values(batch)\n   batch = reference.compute_log_prob(batch)\n   batch = reward.compute_reward(batch)\n   batch = compute_advantages(batch)\n\n   # model update\n   critic_metrics = critic.update_critic(batch)\n   actor_metrics = actor.update_actor(batch)\n```\n\n### Parameter Synchronization\n\nThe exciting point is that our nccl based weights updating for rollout model has great performance.\nAt most of time, the latency is under 300ms, which is negligible for RLHF.\n\n> **sync_rollout_weights**：The time for synchronizing parameters from actor to rollout is extremely fast and can almost\n> be ignored because it is implemented with nccl.\n\n```python\nclass ActorRolloutRefWorker:\n   # actor acquires the meta-info of model parameters for parameter sync\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def get_actor_weights_info(self):\n      params = self._get_actor_params()\n      ret = []\n      for key, tensor in params.items():\n         ret.append((key, tensor.size(), tensor.dtype))\n      self._weights_info = ret\n      return ret\n\n   # rollout sets the meta-info of model parameters for parameter sync\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def set_actor_weights_info(self, weights_info):\n      self._weights_info = weights_info\n\n\nclass AsyncRayPPOTrainer(RayPPOTrainer):\n   def init_workers(self):\n\n\n...\n# rollout obtains the meta-info of model parameters from the actor for parameter sync\nweights_info = self.actor_wg.get_actor_weights_info()[0]\nself.rollout_wg.set_actor_weights_info(weights_info)\n\n# Create an actor-rollout communication group for parameter sync\nactor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers\ncollective.create_collective_group(\n   actor_rollout_workers,\n   len(actor_rollout_workers),\n   list(range(0, len(actor_rollout_workers))),\n   backend=\"nccl\",\n   group_name=\"actor_rollout\"\n)\n```\n\n```python\n# drive process call the actor and rollout respectively 
With the communication group in place, each training step triggers a tensor-by-tensor broadcast from the actor to the rollout before generation is launched:\n\n```python\n# the driver process calls the actor and rollout respectively to sync parameters via NCCL\ndef sync_rollout_weights(self):\n   self.actor_wg.sync_rollout_weights()\n   ray.get(self.rollout_wg.sync_rollout_weights())\n\n\n# FSDP model parameter sync\n@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\ndef sync_rollout_weights(self):\n   params = self._get_actor_params() if self._is_actor else None\n   if self._is_rollout:\n      inference_model = (\n         self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n      )\n      patch_vllm_moe_model_weight_loader(inference_model)\n   # model parameters are broadcast tensor-by-tensor from the actor to the rollout\n   for key, shape, dtype in self._weights_info:\n      tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n      if self._is_actor:\n         assert key in params\n         origin_data = params[key]\n         if hasattr(origin_data, \"full_tensor\"):\n            origin_data = origin_data.full_tensor()\n         if torch.distributed.get_rank() == 0:\n            tensor.copy_(origin_data)\n      from ray.util.collective import collective\n\n      collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n      if self._is_rollout:\n         inference_model.load_weights([(key, tensor)])\n```\n\n## Usage\n\n### FSDP2 Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False);\n# trainer.* allocates the training GPUs, rollout.* allocates the rollout GPUs\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Megatron Configuration Example\n\n```shell\n# actor and rollout are placed separately (hybrid_engine=False);\n# trainer.* allocates the training GPUs, rollout.* allocates the rollout GPUs\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    actor_rollout_ref.hybrid_engine=False \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n
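Before launching, it can help to sanity-check the resource plan against the guidelines in the next section. Below is a minimal standalone sketch (a hypothetical helper, not part of the recipe) that encodes those relationships:\n\n```python\ndef check_one_step_off_plan(trainer_nnodes, trainer_gpus, rollout_nnodes, rollout_gpus,\n                            rollout_n, train_batch_size, physical_gpus_per_node=8):\n    # validate a resource plan and return the number of physical nodes it needs\n    train_world = trainer_nnodes * trainer_gpus\n    # batch distribution: rollout.n divides the training world size, or the total\n    # trajectory count (rollout.n * train_batch_size) is divisible by it\n    if train_world % rollout_n != 0 and (rollout_n * train_batch_size) % train_world != 0:\n        raise ValueError(\"training samples cannot be evenly distributed across training GPUs\")\n    # trainer and rollout share nodes only if they fit on one node side by side\n    if trainer_gpus + rollout_gpus <= physical_gpus_per_node:\n        return max(trainer_nnodes, rollout_nnodes)\n    return trainer_nnodes + rollout_nnodes\n\n# the FSDP2 example above: 6 training GPUs and 2 rollout GPUs on one 8-GPU node\nprint(check_one_step_off_plan(1, 6, 1, 2, rollout_n=2, train_batch_size=512))  # -> 1\n```\n\n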
### Configuration Guidelines\n\n1. **Card Number Relationships**  \n   Maintain either of these relationships for optimal batch distribution:\n   - `actor_rollout_ref.rollout.n` should be an integer divisor of:  \n     `trainer.n_gpus_per_node * trainer.nnodes`\n   - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by:  \n     `trainer.n_gpus_per_node * trainer.nnodes`\n\n   > Rationale: This ensures that training samples can be evenly distributed across the training GPUs when only part of the resources is used for generation.\n\n2. **Dynamic Resource Tuning**  \n   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on phase durations:\n   - **Ideal state**: Rollout and training phases have comparable durations\n   - **Diagnostic metrics**:\n      - Monitor the `wait_prev_gen` duration\n      - Analyze the `sequence_length` distribution\n   - **Adjustment strategy**:\n      - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources\n      - High `wait_prev_gen` + long-tail sequences → Optimize the stopping criteria (a resource increase won't help)\n\n   > **wait_prev_gen**: The time spent waiting for the previous rollout to finish (the part that is not fully overlapped with training).\n\n   **Resource Configuration Strategies:**\n   - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios, keeping the number of nodes equal so that training and rollout can share nodes.\n      - Configure `trainer.nnodes = rollout.nnodes` with `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`, and control the rollout resource allocation by adjusting `n_gpus_per_node`.\n   - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes, keeping the number of GPUs per node equal so that the training and rollout parallelism can be scaled independently.\n      - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node`, and control the rollout resource allocation by adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.\n\n   > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The actual calculation depends on GPU capacity:\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`, the required node count is `max(trainer.nnodes, rollout.nnodes)`\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`, the required node count is `trainer.nnodes + rollout.nnodes`\n\n## Functional Support\n\n| Category           | Supported                                                                                                         |\n|--------------------|-------------------------------------------------------------------------------------------------------------------|\n| train engine       | FSDP2 <br/> Megatron                                                                                              |\n| rollout engine     | vLLM                                                                                                              |\n| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE <br/> GPG |\n| Reward             | all                                                                                                               |\n"
  },
  {
    "path": "verl_rl/docs/advance/placement.rst",
    "content": "Ray API Design Tutorial\n=======================================\n\nLast updated: 10/30/2024.\n\nWe provide a tutorial for our Ray API design, including:\n\n- Ray basic concepts\n- Resource Pool and RayWorkerGroup\n- Data Dispatch, Execution and Collection\n- Initialize the RayWorkerGroup and execute the distributed computation in the given Resource Pool\n\nSee details in `tutorial.ipynb <https://github.com/volcengine/verl/blob/main/examples/ray/tutorial.ipynb>`_."
  },
  {
    "path": "verl_rl/docs/advance/ppo_lora.rst",
    "content": "RL(HF) algorithms with LoRA Support\n===========================================\n\nLast updated: 06/05/2025.\n\nWe support LoRA (Low-Rank Adaptation) for reinforcement learning algorithms such as PPO, GRPO, and others.\n\nLoRA is a parameter-efficient fine-tuning technique that injects trainable low-rank matrices into pre-trained weights (typically linear layers). This reduces the memory footprint and compute cost, making it possible to fine-tune large models with limited hardware.\n\nThe benefits this brings include:\n\n- reinforcement learning with very large models (e.g. 70B+) on modest hardware (e.g. 8x80GB GPUs),\n- larger batch sizes due to reduced memory usage,\n- simpler model transfer and deployment, as only the LoRA adapters need to be saved,\n- combination with techniques like `SLoRA <https://arxiv.org/abs/2311.03285>`_ or `CCoE <https://arxiv.org/abs/2407.11686>`_ to serve multiple LoRA adapters efficiently\n\nThis guide explains how to enable LoRA in RL training and configure the related parameters.\n\nUsage Guide\n------------------------\n1. LoRA is available in `verl.trainer.ppo.ray_trainer.RayPPOTrainer`. Examples are provided via the `verl.trainer.main_ppo` entry point.\n\n2. Currently, LoRA is supported via Hugging Face PEFT, only with the fsdp/fsdp2 and vllm backends (sglang support coming soon).\n\n- `strategy=fsdp` or `strategy=fsdp2`\n- `rollout.name=vllm`\n\n3. Required configurations for LoRA:\n\n- `actor_rollout_ref.model.lora_rank`: int, set to a reasonable value greater than 0 (e.g., 8, 16, 32, 64)\n- `actor_rollout_ref.model.lora_alpha`: float, the alpha term in LoRA\n- `actor_rollout_ref.rollout.load_format=\"safetensors\"`: required. This enables vLLM to load the base model.\n- `actor_rollout_ref.model.target_modules`: the target modules for LoRA. Typically set to \"all-linear\".\n\n4. Recommended options:\n\n- `actor_rollout_ref.model.use_shm=True`: preload the model into `/dev/shm` to improve model loading speed.\n- `actor_rollout_ref.rollout.layered_summon=True`: this enables the actor model to gather the FSDP shards layer by layer when synchronizing the LoRA adapter to vLLM, thereby reducing peak GPU memory. Recommended if the model is very large (70B+) or GPU memory is limited (< 48GB).\n\n\nBest Practices and Notes\n-------------------------\n\n1. **Learning rate**: it is recommended to increase the learning rate by an order of magnitude.\n\n2. **LoRA Rank**:\n\n- Too small a rank can hurt convergence.\n- LoRA rank recommendation from @thelongestusernameofall:\n\n  - A very small lora_rank can lead to slower convergence or worse training performance. It is recommended to set lora_rank to be >= 32. Tests have shown that for a 0.5B model, with lora_rank=32, the training convergence speed and final performance are almost identical to non-LoRA training.\n  - For a 32B model, with lora_rank=128, the training convergence speed and final performance are also almost identical to non-LoRA training.\n  - More comprehensive reference results are coming soon.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/f2b80b8b26829124dd393b7a795a0640eff11644/docs/lora.jpg?raw=true\n\n3. Reference configuration for RL training with the Qwen2.5-72B model using 8 x 80GB GPUs (increase lora_rank if needed):\n\n.. 
code-block::\n\n    data.train_batch_size=64 \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=8 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=64 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n\nExample Script\n-------------------\n\nFor an end-to-end example, refer to the script below:\n\nexamples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh\n"
  },
  {
    "path": "verl_rl/docs/advance/rollout_trace.rst",
    "content": "Trace Function Usage Instructions\n========================================\n\nLast updated: 07/10/2025.\n\nApplicable Scenarios\n--------------------\n\nAgentic RL involves multiple turns of conversation, tool invocations, and user interactions during the rollout process. During model training, it is necessary to track function calls, inputs, and outputs to understand how data flows through the application. By recording the inputs, outputs, and timestamps of function calls, the Trace feature makes it possible, even in complex multi-turn conversations, to see how data is transformed at each interaction and how the final output is produced, which helps in understanding how the model processes data and in optimizing training results.\n\nThe Trace feature integrates commonly used agent tracing tools; wandb weave and mlflow are already supported. Users can choose the appropriate tool according to their needs and preferences. Below, we introduce the usage of each tool.\n\n\nTrace Parameter Configuration\n-----------------------------\n\n- ``actor_rollout_ref.rollout.trace.backend=mlflow|weave`` # the trace backend type\n- ``actor_rollout_ref.rollout.trace.token2text=True`` # to show decoded text in the trace view\n\n\nGlossary\n--------\n\n+----------------+------------------------------------------------------------------------------------------------------+\n| Object         | Explanation                                                                                          |\n+================+======================================================================================================+\n| trajectory     | A complete multi-turn conversation includes:                                                         |\n|                | 1. LLM output at least once                                                                          |\n|                | 2. Tool Call                                                                                         |\n+----------------+------------------------------------------------------------------------------------------------------+\n| step           | The training step corresponds to the global_steps variable in the trainer                            |\n+----------------+------------------------------------------------------------------------------------------------------+\n| sample_index   | The identifier of the sample, defined in the extra_info.index of the dataset. It is usually a number,|\n|                | but may also be a uuid in some cases.                                                                |\n+----------------+------------------------------------------------------------------------------------------------------+\n| rollout_n      | In the GRPO algorithm, each sample is rolled out n times. rollout_n represents the serial number of  |\n|                | the rollout.                                                                                         |\n+----------------+------------------------------------------------------------------------------------------------------+\n| validate       | Whether the test dataset is used for evaluation.                                                     |\n+----------------+------------------------------------------------------------------------------------------------------+\n\nRollout trace functions\n-----------------------\n\nThere are two functions used for tracing:\n\n1. 
``rollout_trace_op``: a decorator used to mark functions for tracing. By default, only a few methods have it; you can add it to more functions to trace more information.\n2. ``rollout_trace_attr``: marks the entry point of a trajectory and attaches some metadata to the trace. If you add a new type of agent, you may need to add it to enable tracing.\n\n\nUsage of wandb weave\n--------------------\n\n1.1 Basic Configuration\n~~~~~~~~~~~~~~~~~~~~~~~\n\n1. Set the ``WANDB_API_KEY`` environment variable\n2. Configuration Parameters\n\n   1. ``actor_rollout_ref.rollout.trace.backend=weave``\n   2. ``trainer.logger=['console', 'wandb']``: This item is optional. Trace and logger are independent functions. When using Weave, it is recommended to also enable the wandb logger to have both functions in one system.\n   3. ``trainer.project_name=$project_name``\n   4. ``trainer.experiment_name=$experiment_name``\n   5. ``actor_rollout_ref.rollout.mode=async``: Since trace is mainly used for agentic RL, the agent loop needs to be enabled via async mode for either vllm or sglang.\n\nNote:\nThe Weave Free Plan comes with a default monthly network traffic allowance of 1GB. During training, the amount of trace data generated is substantial, reaching dozens of gigabytes per day, so it is necessary to select an appropriate wandb plan.\n\n\n1.2 View Trace Logs\n~~~~~~~~~~~~~~~~~~~\n\nAfter starting the training, you can see the WEAVE sidebar on the project page. Click Traces to view it.\n\nEach trace entry corresponds to a trajectory. You can filter and select the trajectories you need to view by step, sample_index, rollout_n, and experiment_name.\n\nAfter enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the input and output content.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_list.png?raw=true\n\n1.3 Compare Trace Logs\n~~~~~~~~~~~~~~~~~~~~~~\n\nIn Weave, you can select multiple trace items and compare the differences among them.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_compare.png?raw=true\n\nUsage of mlflow\n---------------\n\n1. Basic Configuration\n~~~~~~~~~~~~~~~~~~~~~~\n\n1. Set the ``MLFLOW_TRACKING_URI`` environment variable, which can be:\n\n   1. HTTP and HTTPS URLs corresponding to online services\n   2. Local files or directories, such as ``sqlite:////tmp/mlruns.db``, indicating that data is stored in ``/tmp/mlruns.db``. When using local files, it is necessary to initialize the file first (e.g., start the UI: ``mlflow ui --backend-store-uri sqlite:////tmp/mlruns.db``) to avoid conflicts when multiple workers create files simultaneously.\n\n2. Configuration Parameters\n\n   1. ``actor_rollout_ref.rollout.trace.backend=mlflow``\n   2. ``trainer.logger=['console', 'mlflow']``. This item is optional. Trace and logger are independent functions. When using mlflow, it is recommended to also enable the mlflow logger to have both functions in one system.\n   3. ``trainer.project_name=$project_name``\n   4. ``trainer.experiment_name=$experiment_name``\n\n\n2. View Log\n~~~~~~~~~~~\n\nSince ``trainer.project_name`` corresponds to Experiments in mlflow, in the mlflow view, you need to select the corresponding project name, then click the \"Traces\" tab to view traces. 
``trainer.experiment_name`` corresponds to the ``experiment_name`` tag, and tags such as step, sample_index, and rollout_n can be used for filtering and viewing.\n\nFor example, searching for ``\"tags.step = '1'\"`` displays all trajectories of step 1.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_list.png?raw=true\n\nOpening one of the trajectories allows you to view each function call within it.\n\nAfter enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the content.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_view.png?raw=true\n\nNote:\n\n1. mlflow does not support comparing multiple traces\n2. rollout_trace cannot associate the mlflow trace with the run, so the trace content cannot be seen in the mlflow run logs.\n"
  },
  {
    "path": "verl_rl/docs/advance/rope.rst",
    "content": "RoPE Scaling override\n=======================================\n\nLast updated: 05/14/2025.\n\nSome models such as `Qwen/Qwen2.5-7B-Instruct <https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts>`_ support RoPE Scaling but don't have it defined in their config.json file.\nFor example, this model supports this configuration:\n\n.. code:: python\n\n    {\n        ...,\n        \"rope_scaling\": {\n            \"factor\": 4.0,\n            \"original_max_position_embeddings\": 32768,\n            \"type\": \"yarn\"\n        }\n    }\n\n\n\nIn order to support a longer context for such models, you must override the model configs when starting the trainer.\n\nPPO example:\n\n.. code:: bash\n\n    +actor_rollout_ref.model.override_config.rope_scaling.type=yarn \\\n    +actor_rollout_ref.model.override_config.rope_scaling.factor=4.0 \\\n    +actor_rollout_ref.model.override_config.rope_scaling.original_max_position_embeddings=32768 \\\n\n\nAnd for the critic model\n\n.. code:: bash\n\n    +critic.model.override_config.rope_scaling.type=yarn \\\n    +critic.model.override_config.rope_scaling.factor=4.0 \\\n    +critic.model.override_config.rope_scaling.original_max_position_embeddings=32768 \\\n"
  },
  {
    "path": "verl_rl/docs/algo/baseline.md",
    "content": "# Algorithm Baselines\n\nLast updated: 06/18/2025.\n\n## Math-related datasets\n\n### GSM8k\n\nAssuming the GSM8k/MATH datasets are preprocessed via:\n\n```bash\npython3 examples/data_preprocess/*.py\n```\n\nRefer to the table below to reproduce RL training from different pre-trained checkpoints. Unless specified otherwise, the scores below are on the GSM8k dataset. More comprehensive benchmark results are available in the recipe folder.\n\n| Hardware    | Model                            | Method            | Test score   | Details |\n|-------------|----------------------------------|-------------------|--------------|---------|\n| NVIDIA GPU  | google/gemma-2-2b-it             | hf checkpoint     | 23.9         | [Huggingface](https://huggingface.co/google/gemma-2-2b-it#benchmark-results) |\n| NVIDIA GPU  | google/gemma-2-2b-it             | SFT               | 52.06        | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-sft-0.411.log) |\n| NVIDIA GPU  | google/gemma-2-2b-it             | SFT + PPO         | 64.02        | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-ppo-bsz512_4-prompt1024-resp-512-0.640.log), [wandb](https://api.wandb.ai/links/verl-team/h7ux8602) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | hf checkpoint     | 36.4         | [Qwen blog](https://qwenlm.github.io/blog/qwen2.5-llm/) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | PPO               | 56.7         | [command and log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | PRIME             | 58.7         | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen.sh), [wandb](https://api.wandb.ai/links/zefan-wang-thu-tsinghua-university/rxd1btvb) |\n| NVIDIA GPU  | Qwen/Qwen2.5-0.5B-Instruct       | GRPO-LoRA         | 54.3         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.543.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-1.5B-Instruct       | GRPO-LoRA         | 77.9         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-1.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.779.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-3B-Instruct         | GRPO-LoRA         | 86.1         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-3B-bsz64_2-prompt512-resp1024-lorarank32-score0.861.log)|\n| NVIDIA GPU  | deepseek-ai/deepseek-llm-7b-chat | PPO (Megatron)    | 69.5 [1]     | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/deepseek-llm-7b-chat-megatron-bsz256_4-prompt512-resp512-0.695.log), [wandb](https://wandb.ai/verl-team/verl_megatron_gsm8k_examples/runs/10fetyr3) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO              | 89           | [script](https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO (FSDP2)      | 89.8         | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GRPO (Megatron)   | 89.6         | 
[log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b_math_megatron.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | ReMax             | 97           | [script](https://github.com/eric-haibin-lin/verl/blob/main/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh), [wandb](https://wandb.ai/liziniu1997/verl_remax_example_gsm8k/runs/vxl10pln) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | SPPO              | 65.6 (MATH)  | [SPPO script](https://github.com/volcengine/verl/tree/main/recipe/sppo/README.md) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | GRPO-LoRA         | 93.4         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-7B-bsz64_8-prompt512-resp1024-lorarank32-score0.934.log)|\n| NVIDIA GPU  | Mixtral-8x22B-Instruct-v0.1      | Instruct model    | 83.7         | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) |\n| NVIDIA GPU  | Mixtral-8x22B-Instruct-v0.1      | RLOO (Megatron)   | 92.3         | [wandb](https://api.wandb.ai/links/ppo_dev/sbuiuf2d) |\n| NVIDIA GPU  | Qwen/Qwen2.5-7B-Instruct         | SPIN              | 92           | [script](https://github.com/volcengine/verl/tree/main/recipe/spin/README.md) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GPG               | 88           | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/ab86c4va) |\n| NVIDIA GPU  | Qwen/Qwen2-7B-Instruct           | GPG (Megatron)    | 88           | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math_megatron.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/yy8bheu8) |\n| NVIDIA GPU  | Qwen/Qwen2.5-VL-7B-Instruct      | GRPO (Megatron)   | 65.4 (GEO3k) | [script](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh), [wandb](https://api.wandb.ai/links/megatron-core-moe-dev/1yngvkek) |\n| AMD MI300   | deepseek-ai/deepseek-llm-7b-chat | PPO               | 70.5 [1]     | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/ppo_run_deepseek7b_llm.log) |\n| AMD MI300   | deepseek-ai/deepseek-llm-7b-chat | GRPO              | 71.4 [1]     | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/grpo_run_deepseek7b_llm.log) |\n| NVIDIA GPU  | Qwen/Qwen2.5-14B-Instruct         | GRPO-LoRA         | 94.6         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-14B-bsz64_8-prompt512-resp1024-lorarank32-score0.946.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-32B-Instruct         | GRPO-LoRA         | 95.8         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-32B-bsz64_8-prompt512-resp1024-lorarank32-score0.958.log)|\n| NVIDIA GPU  | Qwen/Qwen2.5-72B-Instruct         | GRPO-LoRA         | 96.0         | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-72B-bs64_8-prompt512-resp1024-lorarank32-score0.960.log)|\n\n### DAPO math-17k\n\n- Training DAPO math-17k dataset: https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k\n- Testing: AIME'24: https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024\n\nNote:\n- For Qwen/Qwen2.5-Math-7B, we directly modify the max_position_embeddings to 32768 without observing performance degradation in order to train longer response length.\n\n| Hardware    | 
Model                            | Method            | Test score   | Details |\n|-------------|----------------------------------|-------------------|--------------|---------|\n| NVIDIA GPU  | Qwen/Qwen2.5-Math-7B (32k)       | DAPO              | 36.3         | [command](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_7b_math.sh), [logs](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361)|\n\n\n\n## Coding-related datasets\n\nUnless specified otherwise, the scores below are on LeetCode.\n\n| Hardware    | Model                            | Method            | Test score   | Details |\n|-------------|----------------------------------|-------------------|--------------|---------|\n| NVIDIA GPU  | PRIME-RL/Eurus-2-7B-SFT          | PRIME             | 36.1         | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen_code.sh), [swanlab](https://swanlab.cn/@wangzefan/prime_example/runs/7f541qhspgmy8nmhdlx35/chart) |\n\n\n### Notes\n\n[1] During evaluation, we have only extracted answers following the format `\"####\"`. A more flexible answer extraction, longer response length, and better prompt engineering may lead to a higher score.\n\n[2] The default value of `actor_rollout_ref.actor.entropy_coeff` is set to `0.0` since verl 0.3.x on 2025-05-30, which is different from previous versions.\n"
  },
  {
    "path": "verl_rl/docs/algo/dapo.md",
    "content": "# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)\n\nLast updated: 06/19/2025.\n\n> Open-Source Algorithm Implementation & Experiment Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)\n\n🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)\n\n> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to the Qwen2.5-32B base model outperforms the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** fewer training steps.\n>\n> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)\n\n## Quickstart\n\n1. Prepare the datasets **on the Ray cluster**:\n\n```bash\nbash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default\n```\n\n2. Submit the job to the Ray cluster **from any machine**:\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"http://${RAY_IP:-localhost}:8265\" # The Ray cluster address to connect to\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml\nexport RUNTIME_ENV=\"./recipe/dapo/runtime_env.yaml\" # This sets environment variables for the Ray cluster\nbash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts\n```\n\n## Reproduction Runs\n\n| Setup                                        | AIME 2024 Acc. 
| Hardware  | Image                                                                | Commit                                                                                       | Environment Variables                                                                                                             | Training Script                                                                                                                                             | Training Record                                                                           |\n| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |\n| DAPO                                         | 52%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh)             | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Dynamic Sampling                    | 50%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Token-level Loss & Dynamic Sampling | 44%            | 16x8xH20  | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix`                    | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n\n> [!IMPORTANT]\n>\n> **📢 Call for Contribution!**\n>\n> Welcome to submit your reproduction runs and setups!\n\n## Configuration\n\n### Separated Clip Epsilons (-> Clip-Higher)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.28\n```\n\n`clip_ratio_low` and `clip_ratio_high` specify the $\\varepsilon_{\\text {low }}$ and $\\varepsilon_{\\text {high }}$ in the DAPO objective.\n\nCore relevant 
code:\n\n```python\npg_losses1 = -advantages * ratio\npg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\npg_losses = torch.maximum(pg_losses1, pg_losses2)\n```\n\n### Dynamic Sampling (with Group Filtering)\n\nAn example configuration:\n\n```yaml\ndata:\n  gen_batch_size: 1536\n  train_batch_size: 512\nalgorithm:\n  filter_groups:\n    enable: True\n    metric: acc # score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 10 # Non-positive values mean no upper limit\n```\n\nSetting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0.\n\nThe trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size`, or until it reaches the upper limit specified by `max_num_gen_batches`.\n\nCore relevant code:\n\n```python\nprompt_bsz = self.config.data.train_batch_size\nif num_prompt_in_batch < prompt_bsz:\n    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')\n    num_gen_batches += 1\n    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')\n        continue\n    else:\n        raise ValueError(\n            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'\n        )\nelse:\n    # Align the batch\n    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n    batch = batch[:traj_bsz]\n```\n\n### Flexible Loss Aggregation Mode (-> Token-level Loss)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n```\n\nSetting `loss_agg_mode` to `token-mean` will average the (policy gradient) loss across all the tokens in all the sequences in a mini-batch.\n\nCore relevant code:\n\n```python\nif loss_agg_mode == \"token-mean\":\n    loss = verl_F.masked_mean(loss_mat, loss_mask)\nelif loss_agg_mode == \"seq-mean-token-sum\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n    loss = torch.mean(seq_losses)  # seq-mean\nelif loss_agg_mode == \"seq-mean-token-mean\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n    loss = torch.mean(seq_losses)  # seq-mean\nelse:\n    raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n```\n\n### Overlong Reward Shaping\n\nAn example configuration:\n\n```yaml\ndata:\n  max_response_length: 20480 # 16384 + 4096\nreward_model:\n  overlong_buffer:\n    enable: True\n    len: 4096\n    penalty_factor: 1.0\n```\n\nSetting `overlong_buffer.enable` to `True` will penalize outputs whose lengths are overlong but still within the hard context limit.\n\nSpecifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` as the length of the output exceeds `max_response_length` by `0` to `overlong_buffer.len` tokens.\n\nCore relevant code:\n\n```python\nif self.overlong_buffer_cfg.enable:\n    overlong_buffer_len = self.overlong_buffer_cfg.len\n    expected_len = self.max_resp_len - overlong_buffer_len\n    exceed_len = valid_response_length - expected_len\n    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n    overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n    reward += overlong_reward\n```\n\n## FAQ\n\n### Where is the \"Overlong Filtering\" in the paper?\n\nMost experiments in the paper, including the best-performing one, are run without Overlong Filtering, because it largely overlaps with Overlong Reward Shaping in terms of properly learning from the longest outputs, so we don't implement it here.\n\n### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)?\n\n[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features.\n\n[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, and it will be maintained with new features.\n\n### Why can't I produce similar results after modifications?\n\nToday's RL infrastructure still has inherent robustness issues, which we are working hard to improve.\n\nWe strongly recommend modifying only one thing at a time.\n\nWe also list some known problems here:\n\n1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation.\n"
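As a quick sanity check of the shaping above, here is a standalone sketch (not the verl implementation) with the example configuration plugged in:\n\n```python\nmax_resp_len, buffer_len, penalty_factor = 20480, 4096, 1.0\nexpected_len = max_resp_len - buffer_len  # 16384: the longest penalty-free response\n\ndef overlong_penalty(response_len):\n    # ramps linearly from 0 to -penalty_factor across the buffer zone\n    exceed_len = response_len - expected_len\n    return min(-exceed_len / buffer_len * penalty_factor, 0)\n\nprint(overlong_penalty(16000))  # 0: within the penalty-free zone\nprint(overlong_penalty(18432))  # -0.5: 2048 tokens into the 4096-token buffer\nprint(overlong_penalty(20480))  # -1.0: at the hard context limit\n```\n\n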
  },
  {
    "path": "verl_rl/docs/algo/entropy.md",
    "content": "# Recipe: Entropy Mechanism\n\nLast updated: 06/27/2025.\n\n\n<div align=\"center\">\n\n  The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.\n\n[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617)  [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)\n\n\n<div align=\"center\" style=\"font-family: Arial, sans-serif;\">\n  <p>\n    <a href=\"#🎉news\" style=\"text-decoration: none; font-weight: bold;\">🎉 News</a> •\n    <a href=\"#✨getting-started\" style=\"text-decoration: none; font-weight: bold;\">✨ Getting Started</a> •\n    <a href=\"#📖introduction\" style=\"text-decoration: none; font-weight: bold;\">📖 Introduction</a>\n  </p>\n  <p>\n    <a href=\"#🎈citation\" style=\"text-decoration: none; font-weight: bold;\">🎈 Citation</a> •\n    <a href=\"#🌻acknowledgement\" style=\"text-decoration: none; font-weight: bold;\">🌻 Acknowledgement</a> •\n    <a href=\"#📬Contact\" style=\"text-decoration: none; font-weight: bold;\">📬 Contact</a> •\n    <a href=\"#📈star-history\" style=\"text-decoration: none; font-weight: bold;\">📈 Star History</a>\n  </p>\n</div>\n\n</div>\n\n\n## 🎉News\n\n- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).\n- **[2025/05/29]** Released our paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate entropy collapse.\n\n\n\n## ✨Getting started\n\nAfter preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/7b_kl_cov.sh\n```\n\nWhile for training Qwen2.5-32B on multiple nodes, you can run the following commands:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/32b_kl_cov.sh\n```\n\n## 📖Introduction\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nThis paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R = -a\exp(H) + b$, showing that performance is bottlenecked by entropy exhaustion. 
\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nTheoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in policy gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose Clip-Cov and KL-Cov, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse and improve performance.\n\n## 📃Evaluation\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\n\nOur method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning a better policy through RL.\n\n| **Method**        | **AIME24** | **AIME25** |  **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |\n| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |\n| *Qwen2.5-7B*      |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.2 |        9.6 |     58.7 |         78.8 |          27.9 |              40.7 |        36.7 |     38.6 |\n| w. Clip-higher    |       18.1 |       11.5 |     56.6 |         79.2 |          29.8 |              43.3 |        40.4 |     38.8 |\n| w. **`CLIP-Cov`** |       22.1 |   **15.8** |     58.2 |         80.4 |      **30.5** |          **44.1** |    **41.1** |     40.4 |\n| w. **`KL-Cov`**   |   **22.6** |       12.9 | **61.4** |     **80.8** |          29.1 |              42.6 |        38.2 | **40.6** |\n| *Qwen2.5-32B*     |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.8 |       16.2 |     69.7 |         84.2 |          35.2 |              43.6 |        45.5 |     45.8 |\n| w. Clip-higher    |       35.6 |       22.3 |     69.5 |         77.2 |          35.1 |              42.5 |        43.0 |     47.2 |\n| w. **`CLIP-Cov`** |       32.3 |       22.7 |     67.2 |     **87.0** |      **42.0** |          **57.2** |        46.0 |     50.3 |\n| w. **`KL-Cov`**   |   **36.8** |   **30.8** | **74.5** |         84.6 |          39.1 |              49.0 |    **46.3** | **52.2** |\n\nOur two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. 
Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively.\n\n\n## 🎈Citation\nIf you find this paper or repo helpful, please cite us.\n\n```bibtex\n@article{cui2025entropy,\n  title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},\n  author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},\n  journal={arXiv preprint arXiv:2505.22617},\n  year={2025}\n}\n```\n## 🌻Acknowledgement\nWe implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!\n\n## 📬 Contact\n\nFor questions, discussion, or collaboration opportunities, feel free to contact:\n- Ganqu Cui: cuiganqu@pjlab.org.cn\n- Yuchen Zhang: yuchen.zhang2003@gmail.com\n- Jiacheng Chen: jackchan9345@gmail.com\n- Ning Ding: ningding.cs@gmail.com\n\n"
  },
  {
    "path": "verl_rl/docs/algo/gpg.md",
    "content": "# GPG: Group Policy Gradient\n\nLast updated: 07/03/2025.\n\nGroup Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective: no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning](https://arxiv.org/abs/2504.02546).\n\n## Key Components\n- Uses a corrected advantage function to improve policy gradient accuracy and training efficiency.\n- By eliminating the critic and reference models and avoiding KL divergence constraints, GPG significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO).\n\n## Configuration\nTo configure GPG within the framework, use the following YAML settings.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg\nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"gpg\"\n```\n\n## Advanced Extensions\nGPG is a simple and strong baseline for model reasoning. Although it avoids using a KL loss in its original form, you can still use a KL loss to further improve performance.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg\nactor_rollout_ref:\n  actor:\n    use_kl_loss: True # enable kl regularization\n    kl_loss_coef: 0.01\n    policy_loss:\n      loss_mode: \"gpg\"\n```"
  },
  {
    "path": "verl_rl/docs/algo/grpo.md",
    "content": "# Group Relative Policy Optimization (GRPO)\n\nLast updated: 05/31/2025.\n\nIn reinforcement learning, classic algorithms like PPO rely on a \"critic\" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive.\n\nGRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows:\n- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a \"group\" of outputs.\n- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality.\n- Baseline Calculation: The average reward of the group serves as a baseline.\n- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones.\n\nThis approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300)\n\n## Key Components\n\n- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic)\n- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group.\n- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, and their values should not change algorithmic/convergence behavior.\n\nAlthough many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without the critic).\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `actor_rollout_ref.rollout.n`: For each prompt, sample n times. Defaults to 1. For GRPO, please set it to a value larger than 1 for group sampling.\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout_ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for the actor\n\n- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Defaults to 0.2\n\n- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead\n\n- `actor_rollout_ref.actor.loss_agg_mode`: Default is \"token-mean\". Options include \"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. All GRPO example scripts provided in verl use the default configuration \"token-mean\" for loss aggregation instead.\n\nInstead of adding a KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss:\n\n- `actor_rollout_ref.actor.use_kl_loss`: To use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO.\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html\n\n
## Advanced Extensions\n\n### DrGRPO\n\n[Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's an optimization bias in GRPO, which leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias.\n\nConfigure the following to enable DrGRPO, with all other parameters the same as GRPO's:\n\n- `actor_rollout_ref.actor.loss_agg_mode`: \"seq-mean-token-sum-norm\", which turns off seq-dim averaging\n- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO\n- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm\n\n## Reference Example\n\nQwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log)\n\n```bash\nbash examples/grpo_trainer/run_qwen3-8b.sh\n```\n\nFor more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html\n"
  },
  {
    "path": "verl_rl/docs/algo/opo.md",
    "content": "# On-Policy RL with Optimal Reward Baseline (OPO)\n\nLast updated: 06/02/2025.\n\nLoose on-policy constraints and suboptimal baselines in reinforcement learning often lead to training instability such as large policy shifts and entropy collapse. OPO addresses these challenges by using exact on-policy training with the theoretically optimal reward baseline for advantage estimation. It achieves lower policy shifts and higher output entropy, encouraging more diverse and less repetitive responses.\n\nOPO uses group sampling to generate multiple outputs for each input, like GRPO. Unlike group-based algorithms which typically use the mean reward of a group as its baseline, OPO employs a theoretically optimal baseline: the length-weighted reward of the group. It also omits the standard deviation normalization. By adopting these two key components, OPO enables the training of a single policy model with the objective of maximizing only the expected reward. For more details, refer to the original paper [On-Policy RL with Optimal Reward Baseline](https://arxiv.org/pdf/2505.23585).\n\n## Key Components\n\n- Exact On-Policy Training: always generates responses from the current policy, without using any pre-generated or off-policy data.\n- Optimal Reward Baseline: uses the length-weighted reward of the group as the baseline for normalizing the rewards.\n\n## Configuration\n\nTo configure OPO within the framework, use the following YAML settings. These parameters are crucial for enabling exact on-policy training and activating the optimal reward baseline.\n\n```yaml\nalgorithm:\n  adv_estimator: opo  # Use OPO for optimal reward baseline\ndata:\n  train_batch_size: 1024\nactor_rollout_ref:\n  actor:\n    ppo_mini_batch_size: 1024 # ppo_mini_batch_size should equal train_batch_size to enable exact on-policy training\n    entropy_coeff: 0 # disable entropy regularization\n    use_kl_loss: False # disable kl regularization\n    kl_loss_coef: 0\n```\n\n## Advanced Extensions\n\nOPO can also be extended to other algorithms like RLOO and Reinforce++. One only needs to adjust their configurations to enable exact on-policy training and to incorporate the optimal length-weighted reward baseline, with minimal modifications to their advantage estimation functions.\n"
  },
  {
    "path": "verl_rl/docs/algo/ppo.md",
    "content": "# Proximal Policy Optimization (PPO)\n\nLast updated: 06/19/2025.\n\nProximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning.\n\nTraditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from:\n\n- High variance and sample inefficiency.\n- Instability due to large policy updates.\n\nPPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives.\n\nFor more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347).\n\n## Key Components\n\n- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model.\n\n- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias.\n\n- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nMost critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor\n\n- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs`\n\n- `algorithm.gemma`: discount factor\n\n- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator\n\n- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo\n\n## Advanced Extensions\n\n### KL Divergence Control\n\nOptions to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. 
For more technical details, see [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155).\n\nOptions to use KL loss for KL divergence control:\n\n- `actor_rollout_ref.actor.use_kl_loss`: Whether to use a KL loss in the actor. When used, KL is not applied in the reward function. Default is False\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the KL divergence between actor and reference policy. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\nOptions to use KL penalty in the reward:\n\n- `algorithm.use_kl_in_reward`: Whether to enable the in-reward KL penalty. Default is False.\n\n- `algorithm.kl_penalty`: Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. This defines the way to calculate the KL divergence between actor and reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html\n\n- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of the in-reward KL penalty. Default is 0.001.\n- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n- `algorithm.kl_ctrl.horizon`: See the source code of AdaptiveKLController for details.\n- `algorithm.kl_ctrl.target_kl`: See the source code of AdaptiveKLController for details.\n\n### Dual-clip PPO\n\nDual-clip PPO extends the clipped surrogate objective with an additional lower bound: when the advantage is negative, the objective is clipped at `clip_ratio_c * advantage`, so that multiplying a very large policy ratio by the negative advantage cannot produce an arbitrarily large update.\n\n![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139)\n\n- `actor_rollout_ref.actor.clip_ratio_c`: the lower-bound constant for Dual-clip PPO, defaults to 3.0\n\n## Reference Example\n\nQwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log)\n\n```bash\nbash run_gemma.sh \\\n  trainer.n_gpus_per_node=1 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  trainer.logger=console \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  data.train_batch_size=256 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=2 \\\n  critic.ppo_micro_batch_size=2\n```\n\nReference performance with verl v0.2:\n\n| Model                          | Method           | Score | Link                                                                                           |\n|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------|\n| Qwen/Qwen2.5-0.5B-Instruct     | pretrained model | 36.4  | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/)                                        |\n| Qwen/Qwen2.5-0.5B-Instruct     | PPO              | 56.7  | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n"
  },
  {
    "path": "verl_rl/docs/algo/spin.md",
    "content": "# Recipe: Self-Play Fine-Tuning (SPIN)\n\nLast updated: 05/31/2025.\n\n`verl` provides a recipe inspired by the paper **\"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.\n\n**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:\n\n1.  **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations.\n2.  **Two-Player Game Setup:** A game involving two players acted by a single LLM.\n3.  **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.\n\nPaper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\\*, [Yihe Deng](https://github.com/uclaml/SPIN)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]\n\nverl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n---\n\n## Key Function (compute_online_dpo_loss) and Related works\nSPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). \n\nThis `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.\n\nSpecifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. 
By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets.\n\n**Reference Papers:**\n* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024)\n* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023)\n* [Some things are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023)\n* [Iterative preference learning from human feedback: Bridging theory and practice for RLHF under KL-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)\n* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)\n* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)\n\n\n## Our Online DPO Implementation\n\nOur `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include:\n\n* **No Critic:** Unlike PPO, we omit the value function critic.\n* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for the DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.\n* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).\n* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences.\n* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.\n\n---\n## Algorithm\n\nThis recipe implements an online algorithm adapted to the `verl` reinforcement learning framework, which provides an alternative to PPO for fine-tuning language models.\n\n**Online Loop:** Instead of maximizing a scalar reward signal as in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training:\n\n1.  **Generation:** The current model generates multiple responses for each prompt in a batch.\n2.  **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problems.)\n3.  **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model (see the sketch below).\n\n
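The following sketch shows the shape of such a DPO-style objective on one batch of online preference pairs (an illustration with hypothetical argument names; the recipe's actual implementation is `compute_online_dpo_loss` in `core_algos.py`):\n\n```python\nimport torch\nimport torch.nn.functional as F\n\ndef dpo_style_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):\n    # Sequence-level log-probs come from the actor and the (periodically refreshed) reference model.\n    pi_logratio = chosen_logps - rejected_logps\n    ref_logratio = ref_chosen_logps - ref_rejected_logps\n    # Maximize the margin by which the actor prefers the chosen response, relative to the reference.\n    return -F.logsigmoid(beta * (pi_logratio - ref_logratio)).mean()\n```\n\n**Connection with SPIN:**\nInstead of relying on a fixed target data distribution, the preference labeling in step 2 dynamically changes the target data distribution during training (here, rule-based ranking on the math problems, selecting the better response). This explores the direction mentioned in SPIN's paper Section 7 about \"dynamically changing target data distribution\" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling.\n\n---\n\n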
## Reproduce the Experiment (Example Setup)\n\nThe following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct.\n\n1.  **Setup Environment (Example using Docker):**\n    ```bash\n    # Start a container with GPU access and shared memory\n    docker run -it --name spin_test --gpus all \\\n        --shm-size=32g \\\n        --ipc=host \\\n        -v /path/to/host/.cache:/root/.cache \\\n        -e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \\\n        lmsysorg/sglang:latest \\\n        /bin/bash\n\n    # Inside the container or on your host machine:\n    # Ensure /tmp is writable\n    mkdir -p /tmp\n    chmod 1777 /tmp\n\n    # Install Python 3.10 (if not present) and venv\n    sudo apt update\n    sudo apt install -y python3.10 python3.10-venv tmux\n    python3 -m ensurepip --upgrade\n\n    # Create and activate a virtual environment\n    python3 -m venv ~/.python/spin_env\n    source ~/.python/spin_env/bin/activate\n\n    # Install uv (fast package installer)\n    python3 -m pip install uv\n    ```\n\n2.  **Install verl and Dependencies:**\n    ```bash\n    # Clone the verl repository\n    cd ~\n    git clone git@github.com:volcengine/verl.git && cd verl\n\n    # Install flash-attn (handle potential build issues)\n    python3 -m uv pip install wheel packaging\n    python3 -m uv pip install flash-attn --no-build-isolation --no-deps\n\n    # Install verl with sglang extras\n    python3 -m uv pip install -e \".[sglang]\"\n    ```\n    *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.*\n\n3.  **Login & Download Data/Model:**\n    ```bash\n    # Login to Weights & Biases (optional, for logging)\n    export WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n    # wandb login\n\n    # Download the GSM8K dataset\n    python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\n    # Download the base model (Example: Qwen2.5-3B-Instruct)\n    huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct\n    ```\n\n4.  **Configure:**\n    * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node).\n    * Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`.\n\n5.  **Run Training:**\n    ```bash\n    # Set CUDA visible devices (adjust based on your hardware and config)\n    export CUDA_VISIBLE_DEVICES=0,1,2,3\n\n    # Launch the training script (run_spin.sh here, or a custom script)\n    # Ensure the script points to the correct config and main entry point\n    bash recipe/spin/run_spin.sh\n    ```\n\n---\n\n## Configuration\n\n* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).\n* Key configuration sections:\n    * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths.\n    * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler).\n    * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function.\n    * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.\n    * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).\n\n---\n\n## Key Files\n\n* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.\n* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.\n* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP.\n* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.\n* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`.\n* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.\n* `run_spin.sh` (or similar): Example bash script for launching a training run.\n* `README.md`: This file.\n\n---\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO):\n\n* [Zixiang Chen](https://sites.google.com/view/zxchen)\n* [Yuhao Yang](https://github.com/yhyang201)\n* [Yifan Zhang](https://github.com/yifanzhang-pro)\n* [Yongan Xiang](https://github.com/BearBiscuit05)\n* [Junrong Lin](https://github.com/ocss884)\n* [Yuxuan Tong](https://github.com/tongyx361)\n* [Guangming Shen](https://github.com/PeterSH6)\n* [Biao He](https://www.linkedin.com/in/biao-he/)\n* [Qingquan Song](https://qingquansong.github.io/)\n* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)\n* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_rl/docs/algo/sppo.md",
    "content": "# Recipe: Self-Play Preference Optimization (SPPO)\n\nLast updated: 05/28/2025.\n\nverl provides a community recipe implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets.\n\nPaper Authors: [Yue Wu](https://yuewu.us/)\\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\nverl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)]\n\n## Reproduce the Experiment\n\nWe evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework.\n\n```\ngit clone git@github.com:volcengine/verl.git\ncd verl\npython3 -m uv pip install -e \".[sglang]\"\n\nexport WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n\npython3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct\n\nexport CUDA_VISIBLE_DEVICES=0,1,2,3\nbash recipe/sppo/run_qwen2.5-7b_rm.sh\n```\n\nNote that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running:\n\n```bash\npython3 -m uv pip install wheel\npython3 -m uv pip install packaging\npython3 -m uv pip install flash-attn --no-build-isolation --no-deps\n```\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from:\n\n- [Yue Wu](https://yuewu.us/)\n- [Chendong Wang](https://cdwang96.github.io/)\n- [Yifan Zhang](https://github.com/yifanzhang-pro)\n- [Yongan Xiang](https://github.com/BearBiscuit05)\n- [Junrong Lin](https://github.com/ocss884)\n- [Yuxuan Tong](https://github.com/tongyx361)\n- [Guangming Shen](https://github.com/PeterSH6)\n- [Biao He](https://www.linkedin.com/in/biao-he/)\n- [Qingquan Song](https://qingquansong.github.io/)\n- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_rl/docs/amd_tutorial/amd_build_dockerfile_page.rst",
    "content": "Getting started with AMD (ROCM Kernel)\n=====================================================\n\nLast updated: 07/06/2025.\n\nAuthor: `Yusheng Su <https://yushengsu-thu.github.io/>`_\n\nSetup\n-----\n\nIf you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` or ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training.\n\n\ndocker/Dockerfile.rocm\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    FROM \"rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04\"\n\n    SHELL [\"/bin/bash\", \"-ceuxo\", \"pipefail\"]\n\n    ENV MAX_JOBS=512\n\n    ENV PATH=\"/usr/local/python3.12/bin:$PATH\"\n    RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \\\n        ln -sf /usr/bin/pip3.12 /usr/bin/pip\n\n    ############################################\n    RUN apt-get update\n    RUN apt-get install -y pkg-config liblzma-dev\n    ############################################\n\n    ###########################################\n    ##########Install TransformerEngine########\n    ###########################################\n    WORKDIR /workspace/\n    # transformer-engine install\n    # https://github.com/ROCm/TransformerEngine\n    RUN rm -rf TransformerEngine \n    RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git\n    WORKDIR /workspace/TransformerEngine\n    git checkout 236178e5\n    # git checkout bb061ade\n    # git checkout 864405c\n    ENV NVTE_FRAMEWORK=pytorch \n    ENV NVTE_ROCM_ARCH=gfx942 \n    ENV NVTE_USE_HIPBLASLT=1\n    ENV NVTE_USE_ROCM=1  \n    # export CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}\"\n    ENV CMAKE_PREFIX_PATH=\"/opt/rocm:/opt/rocm/hip:/usr/local:/usr\"\n    RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv \n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n\n    ####################################################################################\n    ################Install vllm - sglang require vllm 0.6.7 dependency#################\n    ####################################################################################\n    #### Require vllm 0.6.7 - checkout 113274a0\n    WORKDIR /workspace/\n    RUN rm -rf vllm\n    RUN pip uninstall -y vllm\n    # Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html\n    RUN git clone https://github.com/ROCm/vllm.git\n    # git clone https://github.com/vllm-project/vllm.git\n    WORKDIR /workspace/vllm\n    RUN git checkout 113274a0\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n    #ENV MAX_JOBS=512\n    ENV MAX_JOBS=${MAX_JOBS}\n    RUN pip install \"boto3>=1.26.0\"\n    RUN pip install setuptools_scm\n    # will add src into py. 
You can delete the repo\n    RUN python3 setup.py install\n    WORKDIR /workspace/\n    ####################################################################################\n    ####################################################################################\n    ####################################################################################\n\n\n\n    ###########################################\n    ############For hack docker################\n    ###########################################\n    RUN pip install setuptools==75.8.0\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n    ###########################################\n    ############build sgalng###################\n    ###########################################\n    # Set environment variables\n    ENV BASE_DIR=/sgl-workspace\n    ENV BUILD_TYPE=all\n    ENV SGL_REPO=https://github.com/sgl-project/sglang\n    ENV SGL_BRANCH=v0.4.6.post5\n    ENV TRITON_REPO=https://github.com/ROCm/triton.git\n    ENV TRITON_COMMIT=improve_fa_decode_3.0.0\n    ENV AITER_REPO=https://github.com/ROCm/aiter.git\n    ENV AITER_COMMIT=v0.1.2\n    # v0.1.2 version - commit id: 9d11f47\n    # ENV AITER_COMMIT=9d11f47\n    ENV HIP_FORCE_DEV_KERNARG=1\n    ENV HSA_NO_SCRATCH_RECLAIM=1\n    ENV SGLANG_SET_CPU_AFFINITY=1\n    ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1\n    ENV NCCL_MIN_NCHANNELS=112\n    ENV MOE_PADDING=1\n    ENV VLLM_FP8_PADDING=1\n    ENV VLLM_FP8_ACT_PADDING=1\n    ENV VLLM_FP8_WEIGHT_PADDING=1\n    ENV VLLM_FP8_REDUCE_CONV=1\n    ENV TORCHINDUCTOR_MAX_AUTOTUNE=1\n    ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1\n    ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n    ENV AMDGPU_TARGETS=gfx942\n    ENV ROCM_ARCH=gfx942\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n    # Switch to working directory\n    WORKDIR /sgl-workspace\n    # Clean and create directory\n    RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace\n\n    # Clone and build sglang\n    RUN git clone ${SGL_REPO} \\\n        && cd sglang \\\n        && git checkout ${SGL_BRANCH} || echo \"Using default branch\" \\\n        && cd sgl-kernel \\\n        && rm -f pyproject.toml \\\n        && mv pyproject_rocm.toml pyproject.toml \\\n        && python setup_rocm.py install \\\n        && cd .. 
\\\n        && if [ \"$BUILD_TYPE\" = \"srt\" ]; then \\\n            python -m pip --no-cache-dir install -e \"python[srt_hip]\"; \\\n        else \\\n            python -m pip --no-cache-dir install -e \"python[all_hip]\"; \\\n        fi \\\n        && cd /sgl-workspace \\\n        && cp -r /sgl-workspace/sglang /sglang \\\n        && python -m pip cache purge\n\n    # Install common Python packages\n    RUN pip install IPython orjson python-multipart torchao pybind11\n    # Rebuild Triton\n    RUN pip uninstall -y triton || true \\\n        && git clone ${TRITON_REPO} \\\n        && cd triton \\\n        && git checkout ${TRITON_COMMIT} \\\n        && cd python \\\n        && python3 setup.py install \\\n        && cd /sgl-workspace\n    # ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1\"\n    # ENV HIPCC_COMPILE_FLAGS_APPEND=\"--offload-arch=gfx942\"\n\n    # Build aiter\n    #version: Commit 9d11f47\n        # && git checkout ${AITER_COMMIT} \\\n    RUN pip uninstall -y aiter || true\n    RUN git clone ${AITER_REPO} \\\n        && cd aiter \\\n        && git checkout ${AITER_COMMIT} \\\n        && git submodule sync \\\n        && git submodule update --init --recursive \\\n        && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \\\n        && cd /sgl-workspace\n\n    # Copy MI300X config \n    RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \\\n            /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \\\n            -type f -name '*MI300X*' | \\\n            xargs -I {} sh -c 'vf_config=$(echo \"$1\" | sed \"s/MI300X/MI300X_VF/\"); cp \"$1\" \"$vf_config\"' -- {}\n\n    # Environment setup complete.\n    RUN echo \"Environment setup complete.\"\n\n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n\n\n    ###########################################\n    ###############vllm v0.8.5#################\n    ###########################################\n    WORKDIR /workspace/\n\n    ENV VLLM_TARGET_DEVICE=rocm \n    ENV ROCM_PATH=/opt/rocm \n    ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev\n    # Find the repo path in: DockerFile/Dockerfile.rocm_yang\n    # RUN git clone https://github.com/RLFoundation/vllm-patch.git\n    RUN pip uninstall -y vllm || true\n    RUN rm -rf vllm-patch\n    RUN git clone https://github.com/RLFoundation/vllm-patch.git \\\n        && cd vllm-patch \\\n        && git checkout v0.8.5-sleep-numa \\\n        && rm -rf build/ dist/ *.egg-info \\\n        && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \\\n        && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py install\n        # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" MAX_JOBS=${MAX_JOBS} python3 setup.py develop\n    WORKDIR /workspace/\n    ###########################################\n    ###########################################\n    ###########################################\n\n\n\n\n    #########################################\n    #### Install megatron-core###############\n    #########################################\n    RUN pip uninstall -y megatron-core && \\\n        git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \\\n        cd Megatron-LM-amd_version && \\\n        pip install -vvv -e . 
&& \\\n        cd /workspace/\n    #########################################\n    #########################################\n    #########################################\n\n\n\n\n    #######################################\n    ################apex###################\n    #######################################\n    WORKDIR /workspace/\n    RUN pip uninstall -y apex && \\\n        git clone https://github.com/ROCm/apex.git && \\\n        cd apex && \\\n        python setup.py install && \\\n        cd /workspace/\n    #######################################\n    #######################################\n    #######################################\n\n\n    ################################################################################\n    ###########################Add torch_memory_saver###############################\n    ################################################################################\n    # Set environment variables\n    ENV HIPCC_COMPILE_FLAGS_APPEND=\"--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__\"\n    ENV CFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    ENV CXXFLAGS=\"-D__HIP_PLATFORM_AMD__\"\n    RUN pip install \"git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa\"\n    ################################################################################\n    ################################################################################\n    ################################################################################\n\n\n\n    ########################################\n    ######Install ray#######################\n    ########################################\n    # need to add this patch: https://github.com/ray-project/ray/pull/53531/files\n    RUN pip uninstall ray -y\n    RUN pip install \"ray[data,train,tune,serve]>=2.47.0\"\n    ########################################\n    ########################################\n    ########################################\n\n\n    ##########################################\n    #######Install other dependencies#########\n    ##########################################\n    RUN pip install \"tensordict==0.6.2\" --no-deps && \\\n        pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        torchdata \\\n        wandb \\\n        orjson \\\n        pybind11\n\n    WORKDIR /workspace/\n    RUN git clone https://github.com/volcengine/verl.git && \\\n        cd verl && \\\n        pip install -e .\n    ##########################################\n    ##########################################\n    ##########################################\n\n    WORKDIR /workspace/\n    CMD [\"/usr/bin/bash\"]\n\n\nBuild the image\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    docker build -f docker/Dockerfile.rocm -t verl-rocm .\n\nPull the image (optional)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAlternatively, you can pull a pre-built image from this DockerHub account: `RLSys Foundation <https://hub.docker.com/u/yushengsuthu>`_\n\n
.. code-block:: bash\n\n    docker pull yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4\n\n    docker tag yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 verl-rocm:latest\n\nRun the container\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    docker run --rm -it \\\n      --device /dev/dri \\\n      --device /dev/kfd \\\n      -p 8265:8265 \\\n      --group-add video \\\n      --cap-add SYS_PTRACE \\\n      --security-opt seccomp=unconfined \\\n      --privileged \\\n      -v $HOME/.ssh:/root/.ssh \\\n      -v $HOME:$HOME \\\n      --shm-size 128G \\\n      -w $PWD \\\n      verl-rocm \\\n      /bin/bash\n\nOptional: running without root and with user permissions\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIf you do not want to run in root mode and prefer to run as your own user, add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` to the above ``docker run`` command.\n\nExample\n-------\n\nDue to the special settings of AMD (ROCm) torch:\n\n1. If your ``ray>=2.45.0`` (default), you need to set ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting Ray in verl's RLHF training and apply this `patch <https://github.com/ray-project/ray/pull/53531/files>`_.\n2. If your ``ray<2.45.0``, you need to set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` when starting Ray in verl's RLHF training.\n\nThe inference engine ``$ENGINE`` can be ``vllm`` or ``sglang``. We choose ``vllm`` as the default in the following examples.\n\n\n\nPPO\n~~~\n\n.. code-block:: bash\n\n    YOUR_PROJECT_NAME=r1-verl-ppo-upstream\n    YOUR_RUN_NAME=r1-training_ppo-upstream\n    # export HYDRA_FULL_ERROR=1\n\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n    GPUS_PER_NODE=8\n    MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n    python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k\n    python3 -c \"import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')\"\n    ENGINE=vllm #sglang\n\n    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n     data.train_files=data/gsm8k/train.parquet \\\n     data.val_files=data/gsm8k/test.parquet \\\n     data.train_batch_size=256 \\\n     data.val_batch_size=1312 \\\n     data.max_prompt_length=512 \\\n     data.max_response_length=256 \\\n     actor_rollout_ref.model.path=$MODEL_PATH \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n     actor_rollout_ref.rollout.name=$ENGINE \\\n     actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n     critic.optim.lr=1e-5 \\\n     critic.model.path=$MODEL_PATH \\\n     critic.ppo_micro_batch_size_per_gpu=4 \\\n     algorithm.kl_ctrl.kl_coef=0.001 \\\n     trainer.logger=console \\\n     trainer.project_name=$YOUR_PROJECT_NAME \\\n     trainer.experiment_name=$YOUR_RUN_NAME \\\n     trainer.val_before_train=False \\\n     trainer.n_gpus_per_node=$GPUS_PER_NODE \\\n     trainer.nnodes=1 \\\n     trainer.save_freq=10 \\\n     trainer.test_freq=10 \\\n     trainer.total_epochs=15 #2>&1 | tee verl_demo.log\n\nGRPO\n~~~~\n\n.. code-block:: bash\n\n    YOUR_PROJECT_NAME=r1-verl-grpo-upstream\n    YOUR_RUN_NAME=r1-training_grpo-upstream\n    # export HYDRA_FULL_ERROR=1\n    # export FSDP_VERBOSE=1\n\n    #export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n    GPUS_PER_NODE=8\n    MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n    # MODEL_PATH=Qwen/Qwen2-7B-Instruct\n    python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k\n    python3 -c \"import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')\"\n    ENGINE=vllm #sglang\n\n    python3 -m verl.trainer.main_ppo \\\n        algorithm.adv_estimator=grpo \\\n        data.train_files=data/gsm8k/train.parquet \\\n        data.val_files=data/gsm8k/test.parquet \\\n        data.train_batch_size=1024 \\\n        data.val_batch_size=1312 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.use_dynamic_bsz=True \\\n        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n        actor_rollout_ref.actor.use_kl_loss=True \\\n        actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=$ENGINE \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.rollout.n=5 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.project_name=$YOUR_PROJECT_NAME \\\n        trainer.experiment_name=$YOUR_RUN_NAME \\\n        trainer.n_gpus_per_node=$GPUS_PER_NODE \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\n\nMulti-node training: slurm with Docker/Podman container\n---------------------------------------------------------------------------------------\n\nIf you want to run multi-node training with slurm, you can use the following script.\n\n.. note::\n    1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later.\n    2. If you want to use ``podman``, just replace ``docker`` with ``podman`` in the following script.\n\nThe script includes the following steps:\n\n1. SLURM Configuration\n2. Environment Setup\n3. Docker/Podman Container Setup\n4. Ray Cluster Initialization\n5. Data Preprocessing\n6. Model Setup\n7. Training Launch\n\n\nslurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\n\n
.. code-block:: bash\n\n    #!/bin/bash\n\n    #SBATCH --job-name=verl-ray-on-slurm\n    #SBATCH --nodes=2\n    #SBATCH --ntasks-per-node=2\n    #SBATCH --mem=200G\n    #SBATCH --time=30-00:00:00\n    #SBATCH --gpus-per-node=8\n    #SBATCH --cpus-per-task=28\n    #SBATCH --output=../verl_log/slurm-%j.out\n    #SBATCH --error=../verl_log/slurm-%j.err\n    #SBATCH --nodelist=gpu-[0,1]\n\n\n    # load necessary modules\n    # [Cluster]: use docker; pull the base image first\n    # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n\n    ##########################################################################\n    ###The following settings vary across projects and clusters###############\n    ##########################################################################\n\n    ### Project\n    CONTAINER_NAME=\"multinode_verl_training\"\n    IMG=\"verl.rocm\"\n    DOCKERFILE=\"docker/Dockerfile.rocm\"\n    # echo $PWD\n    verl_workdir=\"${HOME}/projects/verl_upstream\"\n    export TRANSFORMERS_CACHE=\"${HOME}/.cache/huggingface\"\n    export HF_HOME=$TRANSFORMERS_CACHE\n\n    ### Cluster Network Setting\n    export NCCL_DEBUG=TRACE\n    export GPU_MAX_HW_QUEUES=2\n    export TORCH_NCCL_HIGH_PRIORITY=1\n    export NCCL_CHECKS_DISABLE=1\n    # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7\n    export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_CROSS_NIC=0\n    export CUDA_DEVICE_MAX_CONNECTIONS=1\n    export NCCL_PROTO=Simple\n    export RCCL_MSCCL_ENABLE=0\n    export TOKENIZERS_PARALLELISM=false\n    export HSA_NO_SCRATCH_RECLAIM=1\n    ##########################################################################\n\n    ## Assign the GPUs to use\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n    ### ROCm-specific Ray settings for the training script\n    # [ray] < 2.45.0\n    #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1\n\n    # [ray] >= 2.45.0\n    export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794\n\n\n    # Build and launch the Docker container\n    srun bash -c \"\n        # Exit on any error\n        set -e\n\n        # Clean up dangling images (images with <none> tag)\n        docker image prune -f\n\n        # Need to pull the docker image first\n        docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n
        if ! docker images --format \"{{.Repository}}:{{.Tag}}\" | grep -q \"${IMG}\"; then\n            echo \\\"Building ${IMG} image...\\\"\n            docker build -f \\\"${DOCKERFILE}\\\" -t \\\"${IMG}\\\" .\n        else\n            echo \\\"${IMG} image already exists, skipping build\\\"\n        fi\n\n        # Removing old container if exists\n        docker rm \\\"${CONTAINER_NAME}\\\" 2>/dev/null || true\n\n        # Checking network devices\n        ibdev2netdev\n\n        # Launch the docker\n        docker run --rm -d \\\n        -e HYDRA_FULL_ERROR=1 \\\n        -e RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 \\\n        -e RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 \\\n        -e NCCL_DEBUG=${NCCL_DEBUG} \\\n        -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \\\n        -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \\\n        -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \\\n        -e NCCL_IB_HCA=${NCCL_IB_HCA} \\\n        -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \\\n        -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \\\n        -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \\\n        -e NCCL_PROTO=${NCCL_PROTO} \\\n        -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \\\n        -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \\\n        -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \\\n        -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \\\n        -e HF_HOME=${HF_HOME} \\\n        --network host \\\n        --device /dev/dri \\\n        --device /dev/kfd \\\n        --device /dev/infiniband \\\n        --group-add video \\\n        --cap-add SYS_PTRACE \\\n        --security-opt seccomp=unconfined \\\n        --privileged \\\n        -v \\${HOME}:\\${HOME} \\\n        -v \\${HOME}/.ssh:/root/.ssh \\\n        -w \"${verl_workdir}\" \\\n        --shm-size 128G \\\n        --name \\\"${CONTAINER_NAME}\\\" \\\n        \\\"${IMG}\\\" \\\n        tail -f /dev/null\n\n        echo \\\"Container setup completed\\\"\n    \"\n        # (Optional): If you do not want to run in root mode and prefer your own user,\n        # add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` to the above docker launch command.\n\n\n\n\n\n    ### Launch Ray on the nodes before training\n\n    # Getting the node names\n    nodes_array=($(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | tr '\\n' ' '))\n\n    head_node=${nodes_array[0]}\n    head_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n    # if we detect a space character in the head node IP, we'll\n    # convert it to an ipv4 address. This step is optional.\n    if [[ \"$head_node_ip\" == *\" \"* ]]; then\n        IFS=' ' read -ra ADDR <<<\"$head_node_ip\"\n    if [[ ${#ADDR[0]} -gt 16 ]]; then\n        head_node_ip=${ADDR[1]}\n    else\n        head_node_ip=${ADDR[0]}\n    fi\n        echo \"
IPV6 address detected. We use the IPv4 address $head_node_ip\"\n    fi\n\n    port=6379\n    ip_head=$head_node_ip:$port\n    export ip_head\n    echo \"IP Head: $ip_head\"\n\n    # make sure we set environment variables before Ray initialization\n\n    # Print out all env variables\n    printenv\n\n    echo \"Starting HEAD at $head_node\"\n    srun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n            ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n            --dashboard-port=8266 \\\n            --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    # optional, though may be useful in certain versions of Ray < 1.0.\n    sleep 10\n\n    # number of nodes other than the head node\n    worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n    for ((i = 1; i <= worker_num; i++)); do\n        node_i=${nodes_array[$i]}\n        echo \"Debug: Starting worker on node_i = ${node_i}\"\n        if [ -z \"$node_i\" ]; then\n            echo \"Error: Empty node name for worker $i\"\n            continue\n        fi\n        echo \"Starting WORKER $i at $node_i\"\n        srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n            docker exec \"${CONTAINER_NAME}\" \\\n                ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n        sleep 5\n    done\n\n\n\n\n    # Ray initialization test (check whether the above execution raised any errors)\n    echo \"Testing Ray initialization in the slurm nodes...\"\n    docker exec \"${CONTAINER_NAME}\" python3 -c '\n    import ray\n    try:\n        ray.init(address=\"auto\")\n        print(\"\\n=== Ray Cluster Status ===\")\n        print(f\"Number of nodes: {len(ray.nodes())}\")\n        for node in ray.nodes():\n            print(\"Node: {}, Status: {}\".format(node[\"NodeManagerHostname\"], node[\"Alive\"]))\n            # print(f\"Node: {node}\")\n        ray.shutdown()\n        print(\"Ray initialization successful!\")\n    except Exception as e:\n        print(f\"Ray initialization failed: {str(e)}\")\n    '\n    echo \"=== Ray test completed ===\"\n    ######\n\n\n\n    # Run data preprocessing\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/gsm8k.py\" \"--local_dir\" \"../data/gsm8k\"\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/math_dataset.py\" \"--local_dir\" \"../data/math\"\n\n    train_files=\"../data/gsm8k/train.parquet\"\n    val_files=\"../data/gsm8k/test.parquet\"\n\n    # Download and test the model\n    echo \"Loading model...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')\"\n    MODEL_PATH=\"Qwen/Qwen2-7B-Instruct\"\n\n    echo \"== Data and model loading Done ==\"\n\n    echo \"Start to train...\"\n\n    PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n        python3 -m verl.trainer.main_ppo \\\n        data.train_files=$train_files \\\n        data.val_files=$val_files \\\n        data.train_batch_size=1024 \\\n        data.max_prompt_length=1024 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=$MODEL_PATH \\\n        critic.model.enable_gradient_checkpointing=False \\\n        critic.ppo_micro_batch_size_per_gpu=8 \\\n        critic.model.fsdp_config.param_offload=False \\\n        critic.model.fsdp_config.optimizer_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.0001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger='[\"console\",\"wandb\"]' \\\n        trainer.project_name='verl_example' \\\n        trainer.experiment_name='Qwen2-7B-Instruct_function_rm' \\\n        trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=${SLURM_NNODES} \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\nRun slurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\n\nSimply ``sbatch`` your slurm_script.sh:\n\n.. code-block:: bash\n\n    sbatch slurm_script.sh\n\n"
  },
  {
    "path": "verl_rl/docs/amd_tutorial/amd_vllm_page.rst",
    "content": "verl performance tuning for AMD (ROCm Kernel)\n=====================================================\n\nLast updated: 04/25/2025.\n\nAuthor: `Yang Wang <https://github.com/YangWang92/>`_\n\nPatch vLLM to Enable Sleep Mode for AMD GPUs\n--------------------------------------------------------------\n\nBy default, verl requires vLLM to enable sleep mode, which allows vLLM to offload GPU memory to CPU memory after rollout. However, this feature is still under review by the vLLM community.\n\nTo enable vLLM's sleep mode, you can first use community patched code (from `this pull request <https://github.com/vllm-project/vllm/pull/12695>`_) to build vLLM from the source code in the corresponding pull request. After the patch merged in vLLM main branch, you can directly install vLLM from the latest version.\n\n1. Clone the vLLM repository and build it with the following commands:\n\n.. code-block:: bash\n\n    git clone -b sleep_amd https://github.com/HollowMan6/vllm.git\n    cd vllm\n    sudo ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so\n    VLLM_TARGET_DEVICE=rocm ROCM_PATH=/opt/rocm/ VLLM_GPU_LANG=HIP SETUPTOOLS_SCM_PRETEND_VERSION=0.8.4.dev python3 setup.py develop\n\n2. Additionally, make sure to use the ROCm version in your Docker image lager than or equal to ROCm 6.3.4, and we recommend to use ROCm 6.4.0 for better performance (see `this comment <https://github.com/vllm-project/vllm/pull/12695#issuecomment-2637839574>`_).\n\nAfter the upgrade, you can verify whether sleep mode is enabled by running the following test code (from `this comment <https://github.com/vllm-project/vllm/pull/12695#issuecomment-2637839574>`_).\n\n.. code-block:: python\n\n\timport torch\n\tfrom vllm import LLM\n\n\tllm = LLM(model=\"meta-llama/Llama-3.1-8B-Instruct\", enable_sleep_mode=True)\n\n\tdef run_inference(prompt):\n\t\toutputs = llm.generate(prompt)\n\t\tfor output in outputs:\n\t\t\tprompt = output.prompt\n\t\t\tgenerated_text = output.outputs[0].text\n\t\t\tprint(f\"Prompt: {prompt!r}, Generated text: {generated_text!r}\")\n\n\n\tprint(\"CUDA Memory Usage (after inference):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\trun_inference(\"San Francisco is\")\n\tllm.sleep()\n\n\tprint(\"CUDA Memory Usage (after sleep):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\tllm.wake_up()\n\n\tprint(\"CUDA Memory Usage (after wakeup):\")\n\ttorch.cuda.empty_cache()\n\tprint(f\"{torch.cuda.memory_allocated()=}\")\n\n\trun_inference(\"Paris is\")\n\nIf sleep mode is enabled, you should see the memory usage reduce after sleep.\n\nAfter applying the vLLM patch and completing the installation, you can enable sleep mode in verl to reduce memory overhead. This allows verl to offload unused GPU memory during rollout, significantly lowering the memory footprint during long-context training or multi-node reinforcement learning.\n\n\nEnable CUDA Graph and Bypass ROCm-related issues\n--------------------------------------------------------------\n\nDue to potential issues with CUDA graph capture in ROCm, we’ve found that vLLM’s CUDA graph feature cannot be enabled on multiple nodes in verl on AMD platforms with vLLM V1 mode. This leads to significantly slower rollout performance.\n\nOur investigation shows that ROCm may trigger an unexpected crash when attempting to capture large batches with CUDA graph. 
One workaround is to patch the LLM configuration (from `this commit <https://github.com/volcengine/verl/blob/v0.3.0.rc0/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L100-L115>`_):\n\n.. code-block:: python\n\n    self.inference_engine = LLM(\n        model=model_path,\n        enable_sleep_mode=True,\n        tensor_parallel_size=tensor_parallel_size,\n        distributed_executor_backend=\"external_launcher\",\n        dtype=config.dtype,\n        enforce_eager=config.enforce_eager,\n        gpu_memory_utilization=config.gpu_memory_utilization,\n        disable_custom_all_reduce=True,\n        disable_mm_preprocessor_cache=True,\n        limit_mm_per_prompt=limit_mm_per_prompt,\n        skip_tokenizer_init=False,\n        max_model_len=max_model_len,\n        load_format=load_format,\n        disable_log_stats=config.disable_log_stats,\n        max_num_batched_tokens=max_num_batched_tokens,\n        enable_chunked_prefill=config.enable_chunked_prefill,\n        enable_prefix_caching=True,\n        trust_remote_code=trust_remote_code,\n        # enable the compilation config to bypass OOM on ROCm;\n        # adjust the capture sizes to your GPU memory size\n        compilation_config={\"cudagraph_capture_sizes\": [1, 2, 4, 8, 16, 32, 64]},\n        seed=config.get('seed', 0),\n    )\n\nThen, you can enable CUDA graph by turning off enforce-eager in the rollout configuration (see `this page <https://github.com/volcengine/verl/blob/v0.3.0.rc0/docs/README_vllm0.8.md>`_):\n\n.. code-block:: bash\n\n\tactor_rollout_ref.rollout.enforce_eager=False\n"
  },
  {
    "path": "verl_rl/docs/api/data.rst",
    "content": "Data interface\n=========================\n\nLast updated: 05/19/2025 (API docstrings are auto-generated).\n\nDataProto is the interface for data exchange.\n\nThe :class:`verl.DataProto` class contains two key members:\n\n- batch: a :class:`tensordict.TensorDict` object for the actual data\n- meta_info: a :class:`Dict` with additional meta information\n\nTensorDict\n~~~~~~~~~~~~\n\n:attr:`DataProto.batch` is built on top of :class:`tensordict`, a project in the PyTorch ecosystem.\nA TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size.\n\n.. code-block:: python\n\n    >>> import torch\n    >>> from tensordict import TensorDict\n    >>> tensordict = TensorDict({\"zeros\": torch.zeros(2, 3, 4), \"ones\": torch.ones(2, 3, 5)}, batch_size=[2,])\n    >>> tensordict[\"twos\"] = 2 * torch.ones(2, 5, 6)\n    >>> zeros = tensordict[\"zeros\"]\n    >>> tensordict\n    TensorDict(\n    fields={\n        ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),\n        twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),\n        zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},\n    batch_size=torch.Size([2]),\n    device=None,\n    is_shared=False)\n\nOne can also index a tensordict along its batch_size. The contents of the TensorDict can be manipulated collectively as well.\n\n.. code-block:: python\n\n    >>> tensordict[..., :1]\n    TensorDict(\n    fields={\n        ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),\n        twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),\n        zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},\n    batch_size=torch.Size([1]),\n    device=None,\n    is_shared=False)\n    >>> tensordict = tensordict.to(\"cuda:0\")\n    >>> tensordict = tensordict.reshape(6)\n\nFor more about :class:`tensordict.TensorDict` usage, see the official tensordict_ documentation.\n\n.. _tensordict: https://pytorch.org/tensordict/overview.html\n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass::  verl.DataProto\n   :members: to, select, union, make_iterator, concat\n"
  },
  {
    "path": "verl_rl/docs/api/single_controller.rst",
    "content": "Single Controller interface\n============================\n\nLast updated: 05/27/2025 (API docstrings are auto-generated).\n\nThe Single Controller provides a unified interface for managing distributed workers\nusing Ray or other backends and executing functions across them.\nIt simplifies the process of dispatching tasks and collecting results, particularly \nwhen dealing with data parallelism or model parallelism. \n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass:: verl.single_controller.Worker\n   :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank\n\n.. autoclass:: verl.single_controller.WorkerGroup\n   :members: __init__,  world_size\n\n.. autoclass:: verl.single_controller.ClassWithInitArgs\n   :members: __init__, __call__\n\n.. autoclass:: verl.single_controller.ResourcePool\n   :members: __init__, world_size, local_world_size_list, local_rank_list\n\n.. autoclass:: verl.single_controller.ray.RayWorkerGroup\n   :members: __init__\n\n.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls"
  },
  {
    "path": "verl_rl/docs/api/trainer.rst",
    "content": "Trainer Interface\n================================\n\nLast updated: 06/08/2025 (API docstrings are auto-generated).\n\nTrainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged.\n\n.. autosummary::\n   :nosignatures:\n\n   verl.trainer.ppo.ray_trainer.RayPPOTrainer\n\n\nCore APIs\n~~~~~~~~~~~~~~~~~\n\n.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer\n   :members: __init__, init_workers, fit\n\n.. automodule:: verl.utils.tokenizer\n   :members: hf_tokenizer\n\n.. automodule:: verl.trainer.ppo.core_algos\n   :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty\n\n.. automodule:: verl.trainer.ppo.reward\n   :members: load_reward_manager, compute_reward, compute_reward_async\n\n.. autoclass:: verl.workers.reward_manager.NaiveRewardManager\n\n.. autoclass:: verl.workers.reward_manager.DAPORewardManager\n"
  },
  {
    "path": "verl_rl/docs/api/utils.rst",
    "content": "Utilities\n============\n\nLast updated: 05/19/2025 (API docstrings are auto-generated).\n\nThis section documents the utility functions and classes in the VERL library.\n\nPython Functional Utilities\n------------------------------\n\n.. automodule:: verl.utils.py_functional\n   :members: append_to_dict\n\nFile System Utilities\n------------------------\n\n.. automodule:: verl.utils.fs\n   :members: copy_to_local\n\nTracking Utilities\n---------------------\n\n.. automodule:: verl.utils.tracking\n   :members: Tracking\n\nMetrics Utilities\n---------------------\n\n.. automodule::  verl.utils.metric\n   :members: reduce_metrics\n\nCheckpoint Management\n------------------------\n\n.. automodule:: verl.utils.checkpoint.checkpoint_manager\n   :members: find_latest_ckpt_path\n\n.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager\n   :members: FSDPCheckpointManager\n\nDataset Utilities\n---------------------\n\n.. automodule:: verl.utils.dataset.rl_dataset\n   :members: RLHFDataset, collate_fn\n\nTorch Functional Utilities\n-----------------------------\n\n.. automodule:: verl.utils.torch_functional\n   :members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits\n\nSequence Length Balancing\n----------------------------\n\n.. automodule:: verl.utils.seqlen_balancing\n   :members: get_reverse_idx, rearrange_micro_batches\n\nUlysses Utilities\n--------------------\n\n.. automodule:: verl.utils.ulysses\n   :members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\nFSDP Utilities\n------------------\n\n.. automodule:: verl.utils.fsdp_utils\n   :members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer,\n\nDebug Utilities\n-------------------\n\n.. automodule:: verl.utils.profiler\n   :members: log_gpu_memory_usage, GPUMemoryLogger\n\n"
  },
  {
    "path": "verl_rl/docs/ascend_tutorial/ascend_profiling.rst",
    "content": "在昇腾设备上基于FSDP后端进行数据采集\n====================================\n\nLast updated: 07/14/2025.\n\n这是一份在昇腾设备上基于FSDP后端使用GRPO或DAPO算法进行数据采集的教程。\n\n配置\n----\n\n复用verl/trainer/config/ppo_trainer.yaml中的配置项控制采集的模式和步数，\n通过verl/trainer/config/npu_profile/npu_profile.yaml中的配置项控制例如采集等级等参数。\n\n全局采集控制\n~~~~~~~~~~~~\n\n通过 ppo_trainer.yaml 中的参数控制采集步数和模式：\n\n-  trainer.profile_steps：\n   该参数可以设置为一个包含采集步数的列表，例如[2，\n   4]， 意味着将会采集第二步和第四步。如果该参数为null，则代表不进行采集\n-  actor_rollout_ref.profiler：\n   控制采集的ranks和模式\n\n   -  all_ranks：设为True代表对所有rank进行采集\n   -  ranks：当all_ranks不为True时，\n      通过ranks参数控制需要采集的rank，该参数设置为一个包含采集rank的列表， 例如[0，\n      1]\n   -  discrete：\n      控制采集的模式。当该参数设置为False，代表采集端到端的数据；当该参数设置为True，代表采用离散模式分训练阶段采集数据\n\n通过 npu_profile.yaml 中的参数控制具体采集行为：\n\n-  save_path：采集数据的存放路径\n-  level：采集等级，可选项为level_none、level0、level1和level2\n\n   -  level_none：不采集所有Level层级控制的数据，即关闭profiler_level\n   -  level0：采集上层应用数据、底层NPU数据以及NPU上执行的算子信息\n   -  level1：在level0的基础上多采集CANN层AscendCL数据和NPU上执行的AI\n      Core性能指标信息\n   -  level2：在level1的基础上多采集CANN层Runtime数据以及AI CPU\n\n-  record_shapes：是否记录张量形状\n-  with_memory：是否启用内存分析\n-  with_npu：是否采集device侧性能数据\n-  with_cpu：是否采集host侧性能数据\n-  with_module：是否记录框架层python调用栈信息\n-  with_stack：是否记录算子调用栈信息\n-  analysis：是否自动解析数据\n\n示例\n----\n\n禁用采集\n~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: null # disable profile\n\n端到端采集\n~~~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: [1, 2, 5]\n       actor_rollout_ref:\n            profiler:\n                discrete: False\n                all_ranks: True\n\n\n离散模式采集\n~~~~~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: [1, 2, 5]\n       actor_rollout_ref:\n            profiler:\n                discrete: True\n                all_ranks: False\n                ranks: [0, 1]\n\n\n可视化\n------\n\n采集后的数据存放在用户设置的save_path下，可通过 `MindStudio Insight <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html>`_ 工具进行可视化。\n\n如果analysis参数设置为False，采集之后需要进行离线解析：\n\n.. code:: python\n\n    import torch_npu\n    # profiler_path请设置为\"localhost.localdomain_<PID>_<timestamp>_ascend_pt\"目录的上一级目录\n    torch_npu.profiler.profiler.analyse(profiler_path=profiler_path)"
  },
  {
    "path": "verl_rl/docs/ascend_tutorial/ascend_profiling_en.rst",
    "content": "Data collection based on FSDP (Fully Sharded Data Parallel) backend on Ascend devices(NPU)\n==========================================================================================\n\nLast updated: 07/14/2025.\n\nThis is a tutorial for data collection using the GRPO or DAPO algorithm\nbased on FSDP on Ascend devices.\n\nConfiguration\n-------------\n\nReuse the configuration items in\nverl/trainer/config/ppo_trainer.yaml to control the collection mode\nand steps, you can also manage the collection behaviors such as\ncollection level via verl/trainer/config/npu_profile/npu_profile.yaml.\n\nGlobal collection control\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nUse parameters in ppo_trainer.yaml to control the collection mode\nand steps.\n\n-  trainer.profile_steps: This parameter can be set as a list that has\n   collection steps, such as [2, 4], which means it will collect steps 2\n   and 4. If set to null, no collection occurs.\n-  actor_rollout_ref.profiler: Control the ranks and mode of profiling\n\n   -  all_ranks: Collects data from all ranks when set to true.\n   -  ranks: This parameter specifies which ranks to collect (e.g., [0,\n      1]) when all_ranks is False.\n   -  discrete: Controls the collection mode. If False, end-to-end data\n      is collected; if True, data is collected in discrete phases during\n      training.\n\nUse parameters in npu_profile.yaml to control collection behavior:\n\n-  save_path: Storage path for collected data.\n-  level: Collection level—options are level_none, level0, level1, and\n   level2\n\n   -  level_none: Disables all level-based data collection (turns off\n      profiler_level).\n   -  level0: Collect high-level application data, underlying NPU data,\n      and operator execution details on NPU.\n   -  level1: Extends level0 by adding CANN-layer AscendCL data and AI\n      Core performance metrics on NPU.\n   -  level2: Extends level1 by adding CANN-layer Runtime data and AI\n      CPU metrics.\n\n-  record_shapes: Whether to record tensor shapes.\n-  with_memory: Whether to enable memory analysis.\n-  with_npu: Whether to collect device-side performance data.\n-  with_cpu: Whether to collect host-side performance data.\n-  with_module: Whether to record framework-layer Python call stack\n   information.\n-  with_stack: Whether to record operator call stack information.\n-  analysis: Enables automatic data parsing.\n\nExamples\n--------\n\nDisabling collection\n~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: null # disable profile\n\nEnd-to-End collection\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: [1, 2, 5]\n       actor_rollout_ref:\n            profiler:\n                discrete: False\n                all_ranks: True\n\n\nDiscrete Mode Collection\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n       trainer:\n           profile_steps: [1, 2, 5]\n       actor_rollout_ref:\n            profiler:\n                discrete: True\n                all_ranks: False\n                ranks: [0, 1]\n\n\nVisualization\n-------------\n\nCollected data is stored in the user-defined save_path and can be\nvisualized by using the `MindStudio Insight <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html>`_ tool.\n\nIf the analysis parameter is set to False, offline parsing is required after data collection:\n\n.. 
code:: python\n\n    import torch_npu\n\n    # Set profiler_path to the parent directory of the \"localhost.localdomain_<PID>_<timestamp>_ascend_pt\" folder,\n    # i.e. the save_path configured in npu_profile.yaml\n    profiler_path = \"<save_path>\"\n    torch_npu.profiler.profiler.analyse(profiler_path=profiler_path)"
  },
  {
    "path": "verl_rl/docs/ascend_tutorial/ascend_quick_start.rst",
    "content": "verl x Ascend\n===================================\n\nLast updated: 06/17/2025.\n\n我们在 verl 上增加对华为昇腾设备的支持。\n\n硬件支持\n-----------------------------------\n\nAtlas 200T A2 Box16\n\nAtlas 900 A2 PODc\n\n\n安装\n-----------------------------------\n\n基础环境准备\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+-----------+-------------+\n| software  | version     |\n+-----------+-------------+\n| Python    | == 3.10     |\n+-----------+-------------+\n| CANN      | == 8.1.RC1  |\n+-----------+-------------+\n| torch     | == 2.5.1    |\n+-----------+-------------+\n| torch_npu | == 2.5.1.RC1|\n+-----------+-------------+\n\n\nvllm & vllm-ascend\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n为了能够在 verl 中正常使用 vllm，需使用以下命令编译安装 vllm 和 vllm-ascend。请注意根据机器类型区分安装方式。\n\n.. code-block:: bash\n    \n    # vllm\n    git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git\n    cd vllm\n    pip install -r requirements-build.txt\n\n    # for Atlas 200T A2 Box16\n    VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/\n    \n    # for Atlas 900 A2 PODc\n    VLLM_TARGET_DEVICE=empty pip install -e .\n\n.. code-block:: bash\n    \n    # vllm-ascend\n    git clone -b v0.7.3.post1 --depth 1 https://github.com/vllm-project/vllm-ascend.git\n    cd vllm-ascend\n    export COMPILE_CUSTOM_KERNELS=1\n    python setup.py install\n\n安装verl\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: bash\n\n    git clone https://github.com/volcengine/verl.git\n    cd verl\n    pip install -r requirements-npu.txt\n    pip install -e .\n\n其他三方库说明\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+--------------+---------------+\n| software     | description   |\n+--------------+---------------+\n| transformers | v4.52.4       |\n+--------------+---------------+\n| flash_attn   | not supported |\n+--------------+---------------+\n| liger-kernel | not supported |\n+--------------+---------------+\n| tensordict   | 0.8.3 (ARM)   |\n+--------------+---------------+\n\n1. 支持通过 transformers 使能 --flash_attention_2， transformers 需大于等于 4.52.0版本。\n2. 不支持通过 flash_attn 使能 flash attention 加速。\n3. 不支持 liger-kernel 使能。\n4. 针对 ARM 服务器，tensordict 要求 0.8.3，可在依赖安装完成后再手动安装 tensordict。\n5. 针对 x86 服务器，需要安装 cpu 版本的 torchvision。\n\n.. code-block:: bash\n\n    pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu\n\n\n快速开始\n-----------------------------------\n正式使用前，建议您通过对Qwen2.5-0.5B GRPO的训练尝试以检验环境准备和安装的正确性。\n\n1.下载数据集并将数据集预处理为parquet格式，以便包含计算RL奖励所需的必要字段\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\n2.执行训练\n\n.. 
code-block:: bash\n\n    set -x\n\n    export VLLM_ATTENTION_BACKEND=XFORMERS\n\n    python3 -m verl.trainer.main_ppo \\\n        algorithm.adv_estimator=grpo \\\n        data.train_files=$HOME/data/gsm8k/train.parquet \\\n        data.val_files=$HOME/data/gsm8k/test.parquet \\\n        data.train_batch_size=128 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=128 \\\n        data.filter_overlong_prompts=True \\\n        data.truncation='error' \\\n        actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=5e-7 \\\n        actor_rollout_ref.model.use_remove_padding=False \\\n        actor_rollout_ref.actor.entropy_coeff=0.001 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \\\n        actor_rollout_ref.actor.use_kl_loss=True \\\n        actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n        actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n        actor_rollout_ref.rollout.n=5 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.project_name='verl_grpo_example_gsm8k' \\\n        trainer.experiment_name='qwen2_7b_function_rm' \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=5 \\\n        trainer.total_epochs=1 \\\n        trainer.device=npu $@\n\n\n支持现状\n-----------------------------------\n\n+-----------+-------------------------+-------------+-------------------+----------------------+\n| algorithm |         model           | rewards mae |  throughput ratio |        hardware      |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   GRPO    | Qwen2.5-7B-instruct     |    0.38%    |        0.588      |  Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   GRPO    | Qwen2.5-32B-instruct    |    0.30%    |        0.685      |  Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   GRPO    | Qwen2.5-VL-3B-instruct  |    3.14%    |        0.470      |  Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   GRPO    | Qwen2.5-VL-7B-instruct  |    3.30%    |        0.380      |  Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   GRPO    | Qwen2.5-VL-32B-instruct |    0.79%    |        0.568      |  Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|   DAPO    | Qwen2.5-7B-instruct     |    3.83%    |        pending    |  
Atlas 200T A2 Box16 |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n|  SFT-PEFT | Qwen2.5-0.5B-instruct   |    0.06%    |        0.305      |  Atlas 900 A2 PODc   |\n+-----------+-------------------------+-------------+-------------------+----------------------+\n\n精度对比说明\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n对于 SFT 类算法，我们期望在相同配置下华为昇腾设备与 A100 的 loss 平均绝对误差<= 2%。计算方式如下图。更多信息请参考 `精度计算说明 <https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html>`_。\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/loss_comparison.png?raw=true\n   :alt: loss_comparison\n\n根据经验，对于 GRPO 等 RL 类算法，我们期望在相同配置下华为昇腾设备与 A100 的 rewards 平均绝对误差<= 4%，计算方式参考上图。\n\n\n吞吐对比说明\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nAscend npu 和 A100 分别取日志中前4个 step 的 \"perf/throughput\" 做平均， throughput ratio = npu 平均值 / A100 平均值。 \n\n\n\n计划\n-----------------------------------\n\n查看 `roadmap <https://github.com/volcengine/verl/discussions/900>`_ 获取更多特性的支持进度。\n\n\n\n声明\n-----------------------------------\nverl中提供的ascend支持代码皆为参考样例，商业使用请通过官方正式途径沟通，谢谢。\n"
  },
  {
    "path": "verl_rl/docs/conf.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n# Configuration file for the Sphinx documentation builder.\r\n#\r\n# This file only contains a selection of the most common options. For a full\r\n# list see the documentation:\r\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\r\n\r\n# -- Path setup --------------------------------------------------------------\r\n\r\n# If extensions (or modules to document with autodoc) are in another directory,\r\n# add these directories to sys.path here. If the directory is relative to the\r\n# documentation root, use os.path.abspath to make it absolute, like shown here.\r\n#\r\n# import os\r\n# import sys\r\n# sys.path.insert(0, os.path.abspath('.'))\r\n\r\n\r\n# -- Project information -----------------------------------------------------\r\n\r\nproject = \"verl\"\r\ncopyright = \"2024 ByteDance Seed Foundation MLSys Team\"\r\nauthor = \"Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin\"\r\n\r\n\r\n# -- General configuration ---------------------------------------------------\r\n# The master toctree document.\r\nmaster_doc = \"index\"\r\n\r\n# Add any Sphinx extension module names here, as strings. They can be\r\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\r\n# ones.\r\nextensions = [\r\n    \"myst_parser\",\r\n    \"sphinx.ext.autodoc\",\r\n    \"sphinx.ext.autosummary\",\r\n    \"sphinx.ext.autosectionlabel\",\r\n    \"sphinx.ext.napoleon\",\r\n    \"sphinx.ext.viewcode\",\r\n]\r\n# Use Google style docstrings instead of NumPy docstrings.\r\nnapoleon_google_docstring = True\r\nnapoleon_numpy_docstring = False\r\n\r\n# The suffix(es) of source filenames.\r\n# You can specify multiple suffix as a list of string:\r\nsource_suffix = {\r\n    \".rst\": \"restructuredtext\",\r\n    \".md\": \"markdown\",\r\n}\r\n\r\n# Add any paths that contain templates here, relative to this directory.\r\ntemplates_path = [\"_templates\"]\r\n\r\n# The language for content autogenerated by Sphinx. Refer to documentation\r\n# for a list of supported languages.\r\n#\r\n# This is also used if you do content translation via gettext catalogs.\r\n# Usually you set \"language\" from the command line for these cases.\r\nlanguage = \"en\"\r\n\r\n# List of patterns, relative to source directory, that match files and\r\n# directories to ignore when looking for source files.\r\n# This pattern also affects html_static_path and html_extra_path.\r\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\r\n\r\n\r\n# -- Options for HTML output -------------------------------------------------\r\n\r\n# The theme to use for HTML and HTML Help pages.  See the documentation for\r\n# a list of builtin themes.\r\n#\r\nhtml_theme = \"sphinx_rtd_theme\"\r\n\r\n# Add any paths that contain custom static files (such as style sheets) here,\r\n# relative to this directory. 
They are copied after the builtin static files,\r\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\r\nhtml_static_path = [\"_static\"]\r\n\r\n# Add the JavaScript file\r\nhtml_js_files = [\r\n    \"js/runllm-widget.js\",\r\n]\r\n\r\nexclude_patterns += [\"README.md\", \"README_vllm0.7.md\"]\r\n\r\nsuppress_warnings = [\"ref.duplicate\", \"ref.myst\"]\r\n"
  },
  {
    "path": "verl_rl/docs/examples/config.rst",
    "content": ".. _config-explain-page:\n\nConfig Explanation\n===================\n\nLast updated: 06/18/2025.\n\nppo_trainer.yaml for RL FSDP Backend\n-------------------------------------\n\nData\n~~~~\n\n.. code:: yaml\n\n   data:\n     tokenizer: null\n     train_files: ~/data/rlhf/gsm8k/train.parquet\n     val_files: ~/data/rlhf/gsm8k/test.parquet\n     prompt_key: prompt\n     max_prompt_length: 512\n     max_response_length: 512\n     train_batch_size: 1024\n     return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n     return_raw_chat: False\n     return_full_prompt: False\n     shuffle: True\n     filter_overlong_prompts: False\n     filter_overlong_prompts_workers: 1\n     truncation: error\n     image_key: images\n     trust_remote_code: True\n     custom_cls:\n        path: null\n        name: null\n\n- ``data.train_files``: Training set parquet. Can be a list or a single\n  file. The program will read all files into memory, so it can't be too\n  large (< 100GB). The path can be either local path or HDFS path. For\n  HDFS path, we provide utils to download it to DRAM and convert the\n  HDFS path to local path.\n- ``data.val_files``: Validation parquet. Can be a list or a single\n  file.\n- ``data.prompt_key``: The field in the dataset where the prompt is\n  located. Default is 'prompt'.\n- ``data.max_prompt_length``: Maximum prompt length. All prompts will be\n  left-padded to this length. An error will be reported if the length is\n  too long\n- ``data.max_response_length``: Maximum response length. Rollout in RL\n  algorithms (e.g. PPO) generates up to this length\n- ``data.train_batch_size``: Batch size sampled for one training\n  iteration of different RL algorithms.\n- ``data.return_raw_input_ids``: Whether to return the original\n  input_ids without adding chat template. This is mainly used to\n  accommodate situations where the reward model's chat template differs\n  from the policy. It needs to be decoded first, then apply the RM's\n  chat template. If using a model-based RM, and the policy and RM\n  chat_templates are different, this flag needs to be set\n- ``data.return_raw_chat``: Whether to return the original chat (prompt)\n  without applying chat template.\n- ``data.return_full_prompt``: Whether to return the full prompt with chat template\n- ``data.shuffle``: Whether to shuffle the data in the dataloader.\n- ``data.filter_overlong_prompts``: Default don't filter.\n- ``data.filter_overlong_prompts_workers``: For large-scale dataset, filtering\n  overlong prompts could be timeconsuming. You cat set the ``filter_overlong_prompts_workers``\n  to use multiprocessing for speed up. Default to 1.\n- ``data.truncation``: Truncate the input_ids or prompt length if they\n  exceed max_prompt_length. Default is 'error', not allow exceed the\n  max_prompt_length. The users should increase the max_prompt_length if\n  throwing the error. You can also set ``left``, ``right`` and ``middle``. \n  When ``middle`` is selected, the logic splits the allowed max length roughly in half \n  and keeps the head and tail of the sequence, effectively discarding the middle section.\n- ``data.image_key``: The field in the multi-modal dataset where the image is\n  located. Default is 'images'.\n- ``data.trust_remote_code``: If the remote tokenizer has python file, we can use this field to allow \n  using remote tokenizer. 
For example: moonshotai/Moonlight-16B-A3B-Instruct\n\nCustomized Dataset\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nCustomized dataset extension is implemented for the SFT trainer and can be extended to other trainers with similar changes.\n\n.. code:: yaml\n\n   custom_cls:\n     path: null\n     name: null\n\n- ``data.custom_cls.path``: The path to the file containing your customized dataset class. If not specified, the pre-implemented dataset will be used.\n- ``data.custom_cls.name``: The name of the dataset class within the specified file.\n\nActor/Rollout/Reference Policy\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   actor_rollout_ref:\n    hybrid_engine: True\n    model:\n      path: ~/models/deepseek-llm-7b-chat\n      external_lib: null\n      override_config:\n        model_config: {}\n        moe_config:  # Megatron only, can adjust moe configuration\n          freeze_moe_router: False  # Megatron only, can freeze moe router (no grad)\n      enable_gradient_checkpointing: False\n      enable_activation_offload: False\n      trust_remote_code: False\n      use_remove_padding: False\n    actor:\n      strategy: fsdp  # This is for backward-compatibility\n      ppo_mini_batch_size: 256\n      ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n      ppo_micro_batch_size_per_gpu: 8\n      use_dynamic_bsz: False\n      ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n      grad_clip: 1.0\n      clip_ratio: 0.2\n      entropy_coeff: 0.0\n      use_kl_loss: False # True for GRPO\n      use_torch_compile: True # False to disable torch compile\n      kl_loss_coef: 0.001 # for grpo\n      kl_loss_type: low_var_kl # for grpo\n      ppo_epochs: 1\n      data_loader_seed: null\n      shuffle: False\n      ulysses_sequence_parallel_size: 1 # sp size\n      optim:\n        lr: 1e-6\n        lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n        lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n        min_lr_ratio: 0.0   # only used with cosine lr scheduler, default to 0.0\n        num_cycles: 0.5     # only used with cosine lr scheduler, default to 0.5\n        warmup_style: constant  # select from constant/cosine\n        total_training_steps: -1  # must be overridden by the program\n      fsdp_config:\n        wrap_policy:\n          # transformer_layer_cls_to_wrap: None\n          min_num_params: 0\n        param_offload: False\n        optimizer_offload: False\n        fsdp_size: -1\n      checkpoint:\n        # What to include in saved checkpoints\n        # with 'hf_model' you can save the whole model in hf format; by default only the sharded model checkpoint is saved to save space\n        save_contents: ['model', 'optimizer', 'extra']\n        # For more flexibility, you can specify the contents to load from the checkpoint.\n        load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n    ref:\n      fsdp_config:\n        param_offload: False\n        wrap_policy:\n          # transformer_layer_cls_to_wrap: None\n          min_num_params: 0\n      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n      log_prob_micro_batch_size_per_gpu: 16\n      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n      ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size\n    rollout:\n      name: vllm\n      temperature: 1.0\n      top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n      top_p: 1\n      prompt_length: ${data.max_prompt_length}  # not used for opensource\n      response_length: ${data.max_response_length}\n      # for vllm rollout\n      dtype: bfloat16 # should align with FSDP\n      gpu_memory_utilization: 0.5\n      ignore_eos: False\n      enforce_eager: True\n      free_cache_engine: True\n      load_format: dummy_dtensor\n      tensor_model_parallel_size: 2\n      max_num_batched_tokens: 8192\n      max_num_seqs: 1024\n      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n      log_prob_micro_batch_size_per_gpu: 16\n      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n      # for hf rollout\n      do_sample: True\n      engine_kwargs: # inference engine parameters\n        vllm:\n          swap_space: null # null means \"use the engine default value\" (usually 4 GB), setting it to, e.g., 32 means 32 GB\n          disable_mm_preprocessor_cache: False # disable preprocessor cache for multimodal models\n        sglang:\n          attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla\n\n      n: 1 # for each prompt, sample n responses (i.e. num sample times). 
set it to values > 1 for grpo, rloo\n      val_kwargs:\n        # sampling parameters for validation\n        top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n        top_p: 1.0\n        temperature: 0\n        n: 1\n        do_sample: False # default greedy for validation\n\n      agent:\n        custom_async_server: # Use custom async server implementation for rollout\n          path: null\n          name: null\n\n**Common config for actor, rollout and reference model**\n\n- ``actor_rollout_ref.hybrid_engine``: Whether it's a hybrid engine;\n  currently only the hybrid engine is supported.\n- ``actor_rollout_ref.model.path``: Huggingface model path. This can be\n  either a local path or an HDFS path. For an HDFS path, we provide utils to\n  download it to DRAM and convert the HDFS path to a local path.\n- ``actor_rollout_ref.model.external_lib``: Additional Python packages\n  that need to be imported. Used to register models or tokenizers into\n  the Huggingface system.\n- ``actor_rollout_ref.model.override_config``: Used to override some of\n  the model's original configurations, mainly dropout.\n- ``actor_rollout_ref.model.enable_gradient_checkpointing``: FSDP only; decides\n  whether to enable gradient checkpointing for the actor.\n  Megatron uses recompute options in ``override_transformer_config`` to set this.\n- ``actor_rollout_ref.model.enable_activation_offload``: Whether to enable\n  activation offloading for the actor.\n- ``actor_rollout_ref.model.trust_remote_code``: Whether to enable loading\n  a remote code model.\n- ``actor_rollout_ref.model.use_fused_kernels``: Whether to use fused\n  kernels in the model. If set to True, the following parameter takes effect.\n\n  - ``actor_rollout_ref.model.fused_kernel_options.impl_backend``: The\n    implementation backend for fused kernels. Options: \"triton\" or\n    \"torch\". Default is \"torch\". In Megatron, only \"triton\" is supported as the\n    implementation backend, so this option is not needed there.\n\n- ``actor_rollout_ref.model.use_remove_padding``: Whether to use remove\n  padding in the model. If set to True, the model will remove padding\n  tokens in the input_ids and response_ids. This helps a lot in improving model running efficiency.\n\n**Actor model**\n\n- ``actor_rollout_ref.actor.strategy``: fsdp or megatron. In this\n  example, we use the fsdp backend.\n\n- ``actor_rollout_ref.actor.ppo_mini_batch_size``: One sample is split\n  into multiple sub-batches with batch_size=ppo_mini_batch_size for PPO\n  updates. The ppo_mini_batch_size is a global size across all workers/GPUs.\n\n- ``actor_rollout_ref.actor.ppo_micro_batch_size``: [Will be deprecated, use ppo_micro_batch_size_per_gpu]\n  Similar to gradient accumulation, the micro_batch_size_per_gpu for one forward pass,\n  trading speed for GPU memory. The value represents the global view.\n\n- ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``: Similar to gradient\n  accumulation, the micro_batch_size_per_gpu for one forward pass, trading speed\n  for GPU memory. The value represents the local size per GPU.\n\n- ``actor_rollout_ref.actor.grad_clip``: Gradient clipping for actor\n  updates.\n- ``actor_rollout_ref.actor.use_kl_loss``: Whether to use KL loss in the actor. When used, KL is not applied in the reward function.\n\n- ``actor_rollout_ref.actor.clip_ratio``: PPO clip ratio.\n\n- ``actor_rollout_ref.actor.use_torch_compile``: Whether to use torch compile in the actor.\n\n- ``actor_rollout_ref.actor.entropy_coeff``: The weight of entropy when\n  calculating PPO loss. 
The default value has been 0.0 since v0.3.x.\n\n- ``actor_rollout_ref.actor.ppo_epochs``: Number of epochs for PPO\n  updates on one set of sampled data.\n\n- ``actor_rollout_ref.actor.data_loader_seed``: Since torch 2.6.0, the Megatron backend can get a wrong seed generated by PyTorch\n  between cp ranks, causing misalignment between the data on these ranks, so we manually set the seed to avoid hanging\n  issues. If ``actor_rollout_ref.actor.shuffle`` is not null, this must be set.\n\n- ``actor_rollout_ref.actor.shuffle``: Whether to shuffle data when\n  there are multiple epochs.\n\n- ``actor_rollout_ref.actor.optim``: Actor's optimizer parameters.\n\n- ``actor_rollout_ref.actor.fsdp_config``: FSDP config for actor\n  training.\n\n  - ``wrap_policy``: FSDP wrap policy. By default, it uses Huggingface's\n    wrap policy, i.e., wrapping by DecoderLayer.\n\n    - No need to set transformer_layer_cls_to_wrap, so we comment it.\n\n  - ``*_offload``: Whether to enable parameter, gradient and optimizer\n    offload.\n\n    - Trading speed for GPU memory.\n\n- ``actor_rollout_ref.actor.use_kl_loss``: Whether to enable KL loss. Default is False.\n\n- ``actor_rollout_ref.actor.kl_loss_coef``: The coefficient of KL loss. Default is 0.001.\n\n- ``actor_rollout_ref.actor.kl_loss_type``: Support ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. How to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html\n\n- ``actor_rollout_ref.actor.checkpoint``: The configuration of the checkpoint function in the actor.\n\n  - ``save_contents``: The contents to save in the checkpoint. By default, we save the model, optimizer and extra information in the checkpoint.\n    The extra information currently includes RNG states; lr_scheduler is supported with FSDP, and Megatron opt_param_scheduler support is coming soon.\n    We do not store hf_model in the checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert the checkpoint format to hf format.\n\n  - ``load_contents``: The contents to load from the checkpoint; you can specify different checkpoint loading contents. By default, it is the same as ``save_contents``.\n\n**Reference Model**\n\nThe reference model is enabled when ``actor.use_kl_loss`` and/or ``algorithm.use_kl_in_reward`` is True.\n\n- ``actor_rollout_ref.ref``: FSDP config, same as the actor. **For models\n  larger than 7B, it's recommended to turn on offload for ref by\n  default.**\n\n- ``actor_rollout_ref.ref.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n  The batch size for one forward pass in the computation of ``ref_log_prob``. The value represents the global size.\n\n- ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``: The batch size\n  for one forward pass in the computation of ``ref_log_prob``. The value represents the local size per GPU.\n\n**Rollout Model**\n\n- ``actor_rollout_ref.rollout.name``: hf/vllm/sglang.\n\n- Rollout (auto-regressive) parameters. The key should be equal to the\n  property name in vLLM's ``SamplingParams``.\n\n  - ``temperature``, ``top_k``, ``top_p`` and others: Sampling\n    parameters in ``SamplingParams``.\n\n- ``actor_rollout_ref.rollout.dtype``: Rollout model parameters type. 
This should be aligned with\n  the actor model parameter type in the FSDP/Megatron backend.\n\n- ``actor_rollout_ref.rollout.gpu_memory_utilization``:\n\n  - For vLLM v0.7.0 and later: The fraction of **total** GPU memory to be used for the vLLM instance.\n  - For SGLang: Corresponding to ``mem_fraction_static``, the fraction of the free GPU memory used for **static** memory like model weights and KV cache.\n\n- ``actor_rollout_ref.rollout.tensor_model_parallel_size``: TP size for rollout. Only effective\n  for vllm.\n\n- ``actor_rollout_ref.rollout.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n  The batch size for one forward pass in the computation of ``log_prob``. The value represents the global size.\n\n- ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``: Micro batch size per GPU (the batch size for\n  one forward pass) for recalculating ``log_prob``. The value represents the local size per GPU.\n\n- ``actor_rollout_ref.rollout.do_sample``: Whether to sample during training rollout. If set to False, the rollout model\n  will perform greedy sampling.\n\n- ``actor_rollout_ref.rollout.val_kwargs``: Sampling parameters used specifically during validation.\n\n  - ``top_k``: Top-k sampling parameter. Default to -1 for vLLM rollout or 0 for HF rollout.\n  - ``top_p``: Top-p sampling parameter. Default is 1.0 (disabled).\n  - ``temperature``: Sampling temperature. Default is 0 (deterministic greedy).\n  - ``n``: Number of responses to generate during validation. Default is 1.\n  - ``do_sample``: Whether to use sampling during validation. Default is False for\n    deterministic outputs. When set to True, the rollout will use the ``actor_rollout_ref.rollout.val_kwargs`` parameters\n    (top_k, top_p, temperature) to control the sampling behavior.\n\n- ``actor_rollout_ref.rollout.engine_kwargs.vllm``: extra vllm engine args\n\n  - ``swap_space``: Swap space in GB used by the inference engine. A positive integer, e.g., ``32`` means 32 GB. ``null`` means using the engine default value (usually 4 GB for vLLM).\n  - ``disable_mm_preprocessor_cache``: Whether to disable the preprocessor cache for multimodal models.\n\n- ``actor_rollout_ref.rollout.engine_kwargs.sglang``: extra sglang engine args\n\n  - ``attention_backend``: The attention backend to use for the inference engine.\n\n    - ``null``: means using the engine default value (usually ``fa3`` for SGLang).\n    - ``flashinfer``: Use flashinfer attention backend.\n    - ``triton``: Use triton attention backend.\n    - ``flashmla``: Use flashmla attention backend.\n\n- ``actor_rollout_ref.rollout.ignore_eos``: Whether to ignore the EOS\n  token and continue generating tokens after the EOS token is generated.\n\n- ``actor_rollout_ref.rollout.free_cache_engine``: Offload the KVCache\n  after the rollout generation stage. Default is True. When set to True,\n  for vllm v0.5.4 and v0.6.3, we need to disable the usage of CUDAGraph\n  (set ``enforce_eager`` to True).\n\n- ``actor_rollout_ref.rollout.enforce_eager``: Whether to enforce eager execution\n  (disabling CUDAGraph) in vLLM generation. Default is True, i.e., CUDAGraph is disabled.\n\n- ``actor_rollout_ref.rollout.load_format``: Which weight loader to use\n  to load the actor model weights to the rollout model.\n\n  - ``auto``: Use Megatron weight loader.\n  - ``megatron``: Use Megatron weight loader. Deployed with Megatron\n    backend. 
The input model ``state_dict()`` is already partitioned\n    along the TP dimension and already gathered along the PP dimension. This\n    weight loader requires that the rollout model and actor model's\n    parameter shapes and names be identical.\n  - ``dtensor``: Default solution when using the Huggingface weight loader.\n    Deployed with the FSDP backend when the state_dict_type is\n    ``StateDictType.SHARDED_STATE_DICT``. We recommend using this weight\n    loader.\n  - ``hf``: Use the Huggingface weight loader. Deployed with the FSDP backend\n    when the state_dict_type is ``StateDictType.FULL_STATE_DICT``. This\n    solution doesn't need to rewrite the weight loader for each model\n    implemented in vLLM but it results in larger peak memory usage.\n  - ``dummy_hf``, ``dummy_megatron``, ``dummy_dtensor``: Random\n    initialization.\n\n.. note:: In this config field, users only need to select from ``dummy_megatron``, ``dummy_dtensor``, ``dummy_hf`` for rollout initialization and our hybrid engine will select the corresponding weight loader (i.e., ``megatron``, ``dtensor``, ``hf``) during actor/rollout weight synchronization.\n\n\nMegatron Optimizer and Optimizer Parameter Scheduler\n____________________________________________________\n\n.. code:: yaml\n\n    optim:\n      optimizer: adam\n      lr: 1e-6\n      clip_grad: 1.0\n      total_training_steps: -1  # must be overridden by the program\n      lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      lr_decay_steps: null\n      lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root\n      min_lr: 0.0 # minimum learning rate, default to 0.0\n      weight_decay: 0.01\n      weight_decay_incr_style: constant # select from constant/linear/cosine\n      lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n\n\nNote that there are some API differences between the Megatron optimizer and the FSDP optimizer.\n\n- The Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so ``warmup_style`` actually means the style of lr decay after warmup.\n- The Megatron optimizer also supports scheduling the weight decay (``weight_decay_incr_style``).\n- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint when resuming training.\n\nFor learning rate decay, Megatron pretraining's original default for ``lr_decay_style`` is ``linear``,\nmeaning that the learning rate is linearly decayed from the initial learning rate to ``min_lr`` within\n``lr_decay_steps``. However, in verl, to align with FSDP's default behavior, we set the default\n``lr_decay_style`` to ``constant``, meaning that the learning rate is kept constant after the warmup stage.\n\n\nCritic Model\n~~~~~~~~~~~~\n\nMost parameters for the Critic are similar to those of the Actor Model.\n\nReward Model\n~~~~~~~~~~~~\n\n.. 
code:: yaml\n\n   reward_model:\n     enable: False\n     model:\n       input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n       path: ~/models/Anomy-RM-v0.1\n       external_lib: ${actor_rollout_ref.model.external_lib}\n       trust_remote_code: False\n       fsdp_config:\n         min_num_params: 0\n         param_offload: False\n     micro_batch_size_per_gpu: 16\n     max_length: null\n     reward_manager: naive\n\n- ``reward_model.enable``: Whether to enable the reward model. If False, we\n  compute the reward only with the user-defined reward functions. In\n  the GSM8K and Math examples, we disable the reward model. For the RLHF alignment\n  example using full_hh_rlhf, we utilize the reward model to assess the\n  responses. If False, the following parameters are not effective.\n- ``reward_model.model``\n\n  - ``input_tokenizer``: Input tokenizer. If the reward model's chat\n    template is inconsistent with the policy's, we need to first decode to\n    plaintext, then apply the RM's chat_template, and then score with the RM. If the\n    chat_templates are consistent, it can be set to null.\n  - ``path``: RM's HDFS path or local path. Note that the RM only supports\n    AutoModelForSequenceClassification. Other model types need to define\n    their own RewardModelWorker and pass it from the code.\n  - ``trust_remote_code``: Whether to enable loading a remote code model,\n    default to False.\n- ``reward_model.reward_manager``: Reward Manager. This defines the mechanism\n  of computing rule-based rewards and handling different reward sources. Default\n  is ``naive``. If all verification functions are multiprocessing-safe, the reward\n  manager can be set to ``prime`` for parallel verification.\n\nCustomized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   custom_reward_function:\n     path: null\n     name: compute_score\n\n- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.\n- ``custom_reward_function.name`` (Optional): The name of the reward function within the specified file. Default is 'compute_score'.\n\nAlgorithm\n~~~~~~~~~\n\n.. code:: yaml\n\n   algorithm:\n     gamma: 1.0\n     lam: 1.0\n     adv_estimator: gae\n     use_kl_in_reward: False\n     kl_penalty: kl  # how to estimate kl divergence\n     kl_ctrl:\n       type: fixed\n       kl_coef: 0.005\n       horizon: 10000\n       target_kl: 0.1\n\n- ``gamma``: Discount factor.\n- ``lam``: Trade-off between bias and variance in the GAE estimator.\n- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``.\n- ``use_kl_in_reward``: Whether to enable the in-reward KL penalty. Default is False.\n- ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to\n  calculate the kl divergence between actor and reference policy. For\n  specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ .\n- ``kl_ctrl``: Config for the in-reward kl_penalty controller.\n\n  - ``kl_coef``: The (initial) coefficient of the in-reward kl_penalty. Default is 0.001.\n  - ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n  - ``horizon`` and ``target_kl``: See the source code of AdaptiveKLController for details.\n
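\nFor reference, with ``adv_estimator: gae`` the advantage is the standard generalized advantage estimate built from ``gamma`` and ``lam`` above (shown here for orientation; see ``core_algos.py`` for the exact implementation):\n\n.. math::\n\n   \\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t), \\qquad\n   \\hat{A}_t = \\sum_{l=0}^{T-t-1} (\\gamma \\lambda)^{l} \\, \\delta_{t+l}\n\nTrainer\n~~~~~~~\n\n.. 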
code:: yaml\n\n   trainer:\n     total_epochs: 30\n     project_name: verl_examples\n     experiment_name: gsm8k\n     logger: ['console', 'wandb']\n     log_val_generations: 0\n     nnodes: 1\n     n_gpus_per_node: 8\n     save_freq: -1\n     val_before_train: True\n     test_freq: 2\n     critic_warmup: 0\n     default_hdfs_dir: null # hdfs checkpoint path\n     default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path\n     resume_mode: auto # or disable or resume_path if resume_from_path is set\n     resume_from_path: null\n     remove_previous_ckpt_in_save: False\n     del_local_ckpt_after_load: False\n     ray_wait_register_center_timeout: 300\n\n- ``trainer.total_epochs``: Number of epochs in training.\n- ``trainer.project_name``: For wandb, swanlab, mlflow.\n- ``trainer.experiment_name``: For wandb, swanlab, mlflow.\n- ``trainer.logger``: Supports console and wandb, swanlab, mlflow, tensorboard.\n- ``trainer.log_val_generations``: The number of logged generations during validation (default ``0``).\n- ``trainer.nnodes``: Number of nodes used in the training.\n- ``trainer.n_gpus_per_node``: Number of GPUs per node.\n- ``trainer.save_freq``: The frequency (by iteration) to save checkpoints\n  of the actor and critic models.\n- ``trainer.val_before_train``: Whether to run validation before training.\n- ``trainer.test_freq``: The validation frequency (by iteration).\n- ``trainer.critic_warmup``: The number of iterations to train the critic\n  model before actual policy learning.\n- ``trainer.resume_mode``: The mode of resuming training. Supports\n  ``disable``, ``auto`` and ``resume_path``. If set to ``auto`` (the default), the\n  program will automatically resume from the latest checkpoint in the\n  ``default_local_dir``. If set to ``resume_path``, the program will resume\n  from the path specified in ``resume_from_path``.\n- ``trainer.resume_from_path``: The path to resume training from. Only\n  effective when ``resume_mode`` is set to ``resume_path``.\n- ``trainer.remove_previous_ckpt_in_save``: Whether to remove previous\n  checkpoints in the save directory. Default is False.\n- ``trainer.del_local_ckpt_after_load``: Whether to delete local\n  checkpoints after loading them. Default is False.\n- ``trainer.ray_wait_register_center_timeout``: The timeout for waiting\n  for the ray register center to be ready. Default is 300 seconds.\n\n\nThis figure illustrates how the configurations affect the training.\n\nhttps://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA\n\n.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d\n\n\nevaluation.yaml\n---------------\n\nData\n~~~~\n\n.. code:: yaml\n\n   data:\n     path: /tmp/math_Qwen2-7B-Instruct.parquet\n     prompt_key: prompt\n     response_key: responses\n     data_source_key: data_source\n     reward_model_key: reward_model\n\n- ``data.path``: Path to the dataset file (Parquet format).\n- ``data.prompt_key``: The field in the dataset where the prompt is located. Default is 'prompt'.\n- ``data.response_key``: The key that holds the generated responses. This should be a list of strings representing the responses. Default is 'responses'.\n- ``data.data_source_key``: This is used to separate metric calculations for different data sources, ensuring that metrics are calculated independently for each source.\n- ``data.reward_model_key``: The key that holds the reference answers. 
These reference answers typically serve as the ground truth or test cases for the task.\n\nCustomized Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n   custom_reward_function:\n     path: null\n     name: compute_score\n\n- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.\n- ``custom_reward_function.name`` (Optional): The name of the reward function within the specified file. Default is 'compute_score'.\n\nsft_trainer.yaml for SFT FSDP Backend\n--------------------------------------\n\n\nOptim\n~~~~~~~\n\n.. code:: yaml\n\n   optim:\n     lr: 1e-5\n     weight_decay: 0.01\n     warmup_steps_ratio: 0.1\n     clip_grad: 1.0\n     lr_scheduler: cosine\n\n- ``optim.lr``: Learning rate for the optimizer.\n- ``optim.weight_decay``: Weight decay for the optimizer.\n- ``optim.warmup_steps_ratio``: Ratio of warmup steps to total training steps.\n- ``optim.clip_grad``: Gradient clipping value.\n- ``optim.lr_scheduler``: Learning rate scheduler type. Options:\n\n  - ``cosine``: Cosine learning rate scheduler with warmup (default).\n  - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between the warmup and decay phases.\n\nModel\n~~~~~~~~~~~~\n\nMost parameters for the Model are similar to those of the Reward Model.\n\n.. code:: yaml\n\n   model:\n     partial_pretrain: ~/models/gemma-1.1-7b-it\n     fsdp_config:\n       model_dtype: fp32\n       wrap_policy:\n         min_num_params: 0\n       cpu_offload: False\n       offload_params: False\n     external_lib: null\n     enable_gradient_checkpointing: False\n     trust_remote_code: False\n     lora_rank: 0\n     lora_alpha: 16\n     target_modules: all-linear\n     use_liger: False\n\n- ``partial_pretrain``: HDFS path or local path for the pretrained model.\n- ``fsdp_config``\n\n  - ``model_dtype``: Model parameters type, default to ``fp32``.\n    Support: ``bf16``, ``fp16``, ``fp32``.\n  - ``cpu_offload``: Whether to enable CPU offloading for FSDP. If True,\n    offload_params will be used as an argument.\n  - ``offload_params``: Whether to offload parameters to CPU\n    when not involved in computation. If True, this offloads gradients\n    to CPU as well, meaning that the optimizer step runs on CPU.\n\n- ``lora_rank``: The rank of the LoRA model, default to 0. If ``lora_rank`` > 0,\n  we will train LoRA modules instead of tuning the full model.\n- ``lora_alpha``: The alpha parameter for LoRA scaling, default to 16.\n- ``target_modules``: The names of the modules to apply the adapter to,\n  default to ``all-linear``. See `peft docs <https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.target_modules>`_ for details.\n\n- ``use_liger``: Whether to enable the Liger kernel, default to False. If True,\n  we apply the Liger kernel to the model (depends on ``liger-kernel``).\n"
  },
  {
    "path": "verl_rl/docs/examples/gsm8k_example.rst",
    "content": "GSM8K Example\n=============\n\nLast updated: 03/25/2025.\n\nIntroduction\n------------\n\nIn this example, we train an LLM to tackle the GSM8k task.\n\nPaper: https://arxiv.org/pdf/2110.14168\n\nDataset: https://huggingface.co/datasets/gsm8k\n\nNote that the original paper mainly focuses on training a verifier (a\nreward model) to solve math problems via Best-of-N sampling. In this\nexample, we train an RLHF agent using a rule-based reward model.\n\nDataset Introduction\n--------------------\n\nGSM8k is a math problem dataset. The prompt is an elementary school\nproblem. The LLM model is required to answer the math problem.\n\nThe training set contains 7473 samples and the test set contains 1319\nsamples.\n\n**An example**\n\nPrompt\n\n   Katy makes coffee using teaspoons of sugar and cups of water in the\n   ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups\n   of water, calculate the number of teaspoonfuls of sugar she used.\n\nSolution\n\n   The total ratio representing the ingredients she used to make the\n   coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the\n   number of teaspoons she used is 7/20, she used 7/20\\ *120 =\n   <<7/20*\\ 120=42>>42 #### 42\n\nStep 1: Prepare dataset\n-----------------------\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 gsm8k.py --local_dir ~/data/gsm8k\n\nStep 2: Download Model\n----------------------\n\nThere're three ways to prepare the model checkpoints for post-training:\n\n- Download the required models from huggingface or modelscope\n\n.. code:: bash\n\n   huggingface-cli download deepseek-ai/deepseek-math-7b-instruct --local-dir ~/models/deepseek-math-7b-instruct --local-dir-use-symlinks False\n   # or\n   modelscope download --model deepseek-ai/deepseek-math-7b-instruct --local_dir ~/models/deepseek-math-7b-instruct\n\n- Already store your store model in the local directory or HDFS path.\n- Also, you can directly use the model name in huggingface (e.g.,\n  deepseek-ai/deepseek-math-7b-instruct) in\n  ``actor_rollout_ref.model.path`` and ``critic.model.path`` field in\n  the run script. You can also download models from modelscope by setting environmental variable ``VERL_USE_MODELSCOPE=True``.\n  See examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh for example.\n\nNoted that users should prepare checkpoints for actor, critic and reward\nmodel.\n\n[Optional] Step 3: SFT your Model\n---------------------------------\n\nWe provide a SFT Trainer using PyTorch FSDP in\n`fsdp_sft_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/fsdp_sft_trainer.py>`_. \nUsers can customize their own SFT\nscript using our FSDP SFT Trainer.\n\nWe also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft directory <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k/>`_.\n\n.. 
code:: shell\n\n   set -x\n\n   torchrun -m verl.trainer.fsdp_sft_trainer \\\n       data.train_files=$HOME/data/gsm8k/train.parquet \\\n       data.val_files=$HOME/data/gsm8k/test.parquet \\\n       data.prompt_key=question \\\n       data.response_key=answer \\\n       data.micro_batch_size_per_gpu=8 \\\n       model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \\\n       trainer.project_name=gsm8k-sft \\\n       trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \\\n       trainer.total_epochs=4 \\\n       trainer.logger='[\"console\",\"wandb\"]'\n\n\nIf you use AMD GPUs (ROCm kernel), you need to add the following environment variables to the run script:\n\n    .. code-block:: bash\n\n        export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n        export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n        export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\n\nStep 4: Perform PPO training with your model on GSM8K Dataset\n-------------------------------------------------------------\n\n- Prepare your own run.sh script. Here's an example for the GSM8k dataset\n  and the deepseek-llm-7b-chat model.\n- Users can replace ``data.train_files``, ``data.val_files``,\n  ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on\n  their environment.\n- See :doc:`config` for a detailed explanation of each config field.\n\n**Reward Model/Function**\n\nWe use a rule-based reward model. We force the model to produce a final\nanswer following four ``#`` characters (``####``), as shown in the solution. We extract the final\nanswer from both the solution and the model's output using regular\nexpression matching. We compare them and assign a reward of 1 for a correct\nanswer, 0.1 for an incorrect answer and 0 for no answer.\n
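\nAs a rough illustration of this scoring rule, a simplified reward function might look like the sketch below (the actual implementation lives in ``verl/utils/reward_score/gsm8k.py`` and differs in detail):\n\n.. code:: python\n\n   import re\n\n   def compute_score(solution_str, ground_truth):\n       \"\"\"Simplified GSM8K rule-based reward (illustration only).\"\"\"\n       # the final answer is expected after '####', e.g. '#### 42'\n       match = re.search('#### (-?[0-9.,]+)', solution_str)\n       if match is None:\n           return 0.0  # no parsable final answer\n       answer = match.group(1).replace(',', '')\n       return 1.0 if answer == ground_truth else 0.1\n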
\n**Training Script**\n\nThe training script examples for the FSDP and Megatron-LM backends are stored in the examples/ppo_trainer directory.\n\n.. code:: bash\n\n   cd ../ppo_trainer\n   bash run_deepseek7b_llm.sh\n\nThe script of run_deepseek7b_llm.sh:\n\n.. code:: bash\n\n   set -x\n\n   python3 -m verl.trainer.main_ppo \\\n      data.train_files=$HOME/data/gsm8k/train.parquet \\\n      data.val_files=$HOME/data/gsm8k/test.parquet \\\n      data.train_batch_size=1024 \\\n      data.max_prompt_length=512 \\\n      data.max_response_length=512 \\\n      actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n      actor_rollout_ref.actor.optim.lr=1e-6 \\\n      actor_rollout_ref.model.use_remove_padding=True \\\n      actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n      actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n      actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n      actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n      actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n      actor_rollout_ref.rollout.name=vllm \\\n      actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n      actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n      critic.optim.lr=1e-5 \\\n      critic.model.use_remove_padding=True \\\n      critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n      critic.model.enable_gradient_checkpointing=True \\\n      critic.ppo_micro_batch_size_per_gpu=32 \\\n      critic.model.fsdp_config.param_offload=False \\\n      critic.model.fsdp_config.optimizer_offload=False \\\n      algorithm.kl_ctrl.kl_coef=0.001 \\\n      trainer.critic_warmup=0 \\\n      trainer.logger='[\"console\",\"wandb\"]' \\\n      trainer.project_name='verl_example_gsm8k' \\\n      trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n      trainer.n_gpus_per_node=8 \\\n      trainer.nnodes=1 \\\n      trainer.save_freq=-1 \\\n      trainer.test_freq=1 \\\n      trainer.total_epochs=15 $@\n\n\nIf you use AMD GPUs (ROCm kernel), you need to add the following environment variables to the run script:\n\n    .. code-block:: bash\n\n        export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n        export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n        export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\nIf you encounter any issues using AMD GPUs to run verl, feel free to contact me - `Yusheng Su <https://yushengsu-thu.github.io/>`_."
  },
  {
    "path": "verl_rl/docs/examples/multi_modal_example.rst",
    "content": "Multi-Modal Example Architecture\n=================================\n\nLast updated: 04/28/2025.\n\nIntroduction\n------------\n\nNow, verl has supported multi-modal training. You can use fsdp and \nvllm/sglang to start a multi-modal RL task. Megatron supports is also \non the way.\n\nFollow the steps below to quickly start a multi-modal RL task.\n\nStep 1: Prepare dataset\n-----------------------\n\n.. code:: python\n\n    # it will be saved in the $HOME/data/geo3k folder\n    python examples/data_preprocess/geo3k.py\n\nStep 2: Download Model\n----------------------\n\n.. code:: bash\n\n    # download the model from huggingface\n    python3 -c \"import transformers; transformers.pipeline(model='Qwen/Qwen2.5-VL-7B-Instruct')\"\n\nStep 3: Perform GRPO training with multi-modal model on Geo3K Dataset\n---------------------------------------------------------------------\n\n.. code:: bash\n\n    # run the task\n    bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "verl_rl/docs/examples/ppo_code_architecture.rst",
    "content": "PPO Example Architecture\n========================\n\nLast updated: 02/17/2025.\n\nLet's start with the Proximal Policy Optimization algorithm, which is\nmost widely used algorithm in LLM post-training.\n\nThe main entry point of the PPO algorithm example is:\n`main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py>`_.\nIn this tutorial, we will go through the code architecture in `main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py>`_.\n\nDefine the data\n---------------\n\nUsers need to preprocess and store the dataset in parquet files.\nAnd we implement `RLHFDataset` to load and tokenize the parquet files.\n\nFor ``RLHFDataset`` (Default), at least 1 fields are required:\n\n- ``prompt``: Contains the string prompt\n\nWe already provide some examples of processing the datasets to parquet\nfiles in `data_preprocess directory <https://github.com/volcengine/verl/blob/main/examples/data_preprocess>`_. Currently, we support\npreprocess of GSM8k, MATH, Hellasage, Full_hh_rlhf datasets. See :doc:`../preparation/prepare_data` for\nmore information.\n\nDefine the reward functions for different datasets\n--------------------------------------------------\n\nIn this main entry point, the users only need to define their own reward\nfunction based on the datasets (or applications) utilized in PPO\ntraining.\n\nFor example, we already provide reward functions for `GSM8k <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_ \nand `MATH <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_\ndatasets in the ``_select_rm_score_fn``. In the ``RewardManager``, we\nwill compute the reward score based on the data_source to select\ncorresponding reward functions. For some RLHF datasets (e.g.,\nfull_hh_rlhf), the reward model is utilized to assess the responses\nwithout any reward functions. In this case, the ``RewardManager`` will\nreturn the ``rm_score`` computed by the reward model directly.\n\nSee `reward functions <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_ for detailed implementation.\n\nDefine worker classes\n---------------------\n\n.. 
code:: python\n\n   if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}: # for FSDP backend\n       assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n       from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n       from verl.single_controller.ray import RayWorkerGroup\n       ray_worker_group_cls = RayWorkerGroup\n\n   elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend\n       assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n       from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n       from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n       ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM\n\n   else:\n       raise NotImplementedError\n\n   from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n   role_worker_mapping = {\n       Role.ActorRollout: ActorRolloutRefWorker,\n       Role.Critic: CriticWorker,\n       Role.RefPolicy: ActorRolloutRefWorker\n   }\n\n   global_pool_id = 'global_pool'\n   resource_pool_spec = {\n       global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n   }\n   mapping = {\n       Role.ActorRollout: global_pool_id,\n       Role.Critic: global_pool_id,\n       Role.RefPolicy: global_pool_id,\n   }\n\nStep 1: Construct the mapping between roles and workers\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nA role represents a group of workers in the same process. We have\npre-defined several roles in `ray_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py#L38>`_.\n\n.. code:: python\n\n   class Role(Enum):\n       \"\"\"\n       To create more roles dynamically, you can subclass Role and add new members\n       \"\"\"\n       Actor = 0  # This worker only has Actor\n       Rollout = 1 # This worker only has Rollout\n       ActorRollout = 2 # This worker has both actor and rollout, it's a HybridEngine\n       Critic = 3 # This worker only has critic\n       RefPolicy = 4 # This worker only has reference policy\n       RewardModel = 5 # This worker only has reward model\n       ActorRolloutRef = 6 # This worker contains actor, rollout and reference policy simultaneously \n\nStep 2: Define the worker class corresponding to this role\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n- We have pre-implemented the ``ActorRolloutRefWorker``. Through\n  different configs, it can be a standalone actor, a standalone rollout,\n  an ActorRollout HybridEngine, or an ActorRolloutRef HybridEngine\n- We also pre-implemented workers for ``Actor``, ``Rollout``,\n  ``Critic``, ``Reward Model`` and ``Reference model`` on two different\n  backend: PyTorch FSDP\n  and Megatron-LM.\n  See `FSDP Workers <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ \n  and `Megatron-LM Workers <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py>`_\n  for more information.\n\nStep 3: Define resource pool id and resource pool spec\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n- Resource pool is a division of global GPU resources,\n  ``resource_pool_spec`` is a dict, mapping from id to # of GPUs\n\n  - In the above example, we defined a global resource pool:\n    global_pool_id, and then put all roles on this one resource pool\n    with all the GPUs in this post-training task. 
This refers to\n    *co-locate* placement, where all the models share the same set of\n    GPUs.\n\n- See resource pool and placement for advanced usage; a split-placement\n  sketch follows this list.\n
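\nFor illustration, a hypothetical split placement (instead of co-location) could map the critic onto its own pool. The pool names and the 4+4 GPU split per node below are made up for this sketch:\n\n.. code:: python\n\n   # hypothetical split placement: actor/rollout/ref on 4 GPUs per node,\n   # critic on the other 4, instead of one shared global pool\n   resource_pool_spec = {\n       'actor_pool': [4] * config.trainer.nnodes,\n       'critic_pool': [4] * config.trainer.nnodes,\n   }\n   mapping = {\n       Role.ActorRollout: 'actor_pool',\n       Role.RefPolicy: 'actor_pool',\n       Role.Critic: 'critic_pool',\n   }\n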
\nDefining reward model/function\n------------------------------\n\n.. code:: python\n\n   # we should adopt a multi-source reward function here\n   # - for rule-based rm, we directly call a reward score\n   # - for model-based rm, we call a model\n   # - for code related prompt, we send to a sandbox if there are test cases\n   # - finally, we combine all the rewards together\n   # - The reward type depends on the tag of the data\n   if config.reward_model.enable:\n       from verl.workers.fsdp_workers import RewardModelWorker\n       role_worker_mapping[Role.RewardModel] = RewardModelWorker\n       mapping[Role.RewardModel] = global_pool_id\n\n   reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)\n\n   # Note that we always use function-based RM for validation\n   val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)\n\n   resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\nSince not all tasks use model-based RM, users need to define here\nwhether it's a model-based RM or a function-based RM.\n\n- If it's a model-based RM, directly add the ``RewardModel`` role to the\n  resource mapping and add it to the resource pool mapping.\n\n  - Note that the pre-defined ``RewardModelWorker`` only supports models\n    with the structure of huggingface\n    ``AutoModelForSequenceClassification``. If it's not this model, you\n    need to define your own RewardModelWorker in `FSDP Workers <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ \n    and `Megatron-LM Workers <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py>`_.\n\n- If it's a function-based RM, users are required to specify the\n  reward function for each dataset.\n\n.. code:: python\n\n   def _select_rm_score_fn(data_source):\n       if data_source == 'openai/gsm8k':\n           return gsm8k.compute_score\n       elif data_source == 'lighteval/MATH':\n           return math.compute_score\n       else:\n           raise NotImplementedError\n\nSee the reward functions implemented in this `directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/>`_ \nfor more information.\n\nDefine, init and run the PPO Trainer\n------------------------------------\n\n.. code:: python\n\n   trainer = RayPPOTrainer(config=config,\n                           tokenizer=tokenizer,\n                           role_worker_mapping=role_worker_mapping,\n                           resource_pool_manager=resource_pool_manager,\n                           ray_worker_group_cls=ray_worker_group_cls,\n                           reward_fn=reward_fn,\n                           val_reward_fn=val_reward_fn)\n   trainer.init_workers()\n   trainer.fit()\n\n- We first initialize the ``RayPPOTrainer`` with the user config, tokenizer\n  and all the above worker mappings, resource pools, worker group classes and\n  reward functions\n- We then call ``trainer.init_workers()`` to initialize the models\n  on the allocated GPUs (in the resource pool)\n- The actual PPO training is executed in ``trainer.fit()``\n\nverl can be easily extended to other RL algorithms by reusing the Ray\nmodel workers, resource pool and reward functions. See :doc:`extension<../advance/dpo_extension>` for\nmore information.\n\nDetails of the ``RayPPOTrainer`` are discussed in :doc:`Ray Trainer<../workers/ray_trainer>`.\n"
  },
  {
    "path": "verl_rl/docs/examples/sandbox_fusion_example.rst",
    "content": "Sandbox Fusion Example\n============================\n\nLast updated: 06/27/2025.\n\nIntroduction\n------------\n\nSandbox Fusion is a remote code sandbox service that provides a secure environment for running and evaluating code generated by Large Language Models (LLMs). This example demonstrates how to train an LLM and use Sandbox Fusion to verify generated code, enhancing both security and performance.\n\nBy leveraging a remote code sandbox service with greater CPU resources for concurrent code verification, you can reduce the reward stage time by 10-30%, depending on the quality of the generated code.\n\nStep 1: Prepare the Dataset\n---------------------------\n\nWe use the Eurus-2-RL-Data dataset for training. This dataset combines math and code questions, making it suitable for LLM training tasks. You can download it from HuggingFace: `Eurus-2-RL-Data Dataset <https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data>`_.\n\nStep 2: Set Up the Sandbox Fusion Service\n-----------------------------------------\n\nSandbox Fusion is a remote code sandbox service designed to securely run and evaluate LLM-generated code. To use it:\n\n1. **Access Full Documentation**: For detailed setup instructions, refer to the `Sandbox Fusion Documentation <https://bytedance.github.io/SandboxFusion/>`_.\n2. **Deploy the Service**: Choose one of the following deployment methods:\n\n   - **Local Deployment**: Follow the guide `here <https://bytedance.github.io/SandboxFusion/docs/docs/get-started#local-deployment>`_.\n   - **FaaS Instance (Volcengine)**: Create an instance using the `Volcengine Documentation <https://www.volcengine.com/docs/6662/1539235>`_.\n\nAfter deployment, you will receive an API endpoint in the format: ``https://<ip-address-or-domain-name>/run_code``.\n\nStep 3: Configure the Training Script\n-------------------------------------\n\nTo integrate Sandbox Fusion into your training script, configure the following parameters:\n\n**Key Settings for Sandbox Fusion**\n\n- ``reward_model.sandbox_fusion.url='<API-endpoint>'``: Enable Sandbox Fusion by specifying the API endpoint (must end with ``/run_code``).\n- ``reward_model.sandbox_fusion.max_concurrent=256``: Set the maximum number of concurrent API requests to the Sandbox Fusion service.\n- ``reward_model.sandbox_fusion.memory_limit_mb=1024``: Set the memory limit (in MB) for each sandbox instance. Defaults to 1024MB if not specified.\n\n**Additional Optimization**\n\nTo further reduce code verification time, enable parallel processing with:  \n\n- ``reward_model.reward_manager=prime``: The Prime reward manager verifies code across multiple subprocesses concurrently.\n\n**Example Script**\n\nFor a practical implementation, refer to the example script:  \n\n``examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh``\n\nOnce you’ve set your API endpoint in the script, you can start the training job."
  },
  {
    "path": "verl_rl/docs/faq/faq.rst",
    "content": "Frequently Asked Questions\n====================================\n\nLast updated: 06/25/2025.\n\nRay related\n------------\n\nHow to add breakpoint for debugging with distributed Ray?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nPlease checkout the official debugging guide from Ray: https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html\n\n\n\"Unable to register worker with raylet\"\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe cause of this issue is due to some system setting, e.g., SLURM added some constraints on how the CPUs are shared on a node. \nWhile `ray.init()` tries to launch as many worker processes as the number of CPU cores of the machine,\nsome constraints of SLURM restricts the `core-workers` seeing the `raylet` process, leading to the problem.\n\nTo fix this issue, you can set the config term ``ray_init.num_cpus`` to a number allowed by your system.\n\nDistributed training\n------------------------\n\nHow to run multi-node post-training with Ray?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nYou can start a ray cluster and submit a ray job, following the official guide from Ray: https://docs.ray.io/en/latest/ray-core/starting-ray.html\n\nThen in the configuration, set the ``trainer.nnode`` config to the number of machines for your job.\n\nHow to use verl on a Slurm-managed cluster?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nRay provides users with `this <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ official\ntutorial to start a Ray cluster on top of Slurm. We have verified the :doc:`GSM8K example<../examples/gsm8k_example>`\non a Slurm cluster under a multi-node setting with the following steps.\n\n1. [Optional] If your cluster support `Apptainer or Singularity <https://apptainer.org/docs/user/main/>`_ and you wish\nto use it, convert verl's Docker image to an Apptainer image. Alternatively, set up the environment with the package\nmanager available on your cluster or use other container runtimes (e.g. through `Slurm's OCI support <https://slurm.schedmd.com/containers.html>`_) available to you.\n\n.. code:: bash\n\n    apptainer pull /your/dest/dir/vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3.sif docker://verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3\n\n2. Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.\n\n3. Modify `examples/slurm/ray_on_slurm.slurm <https://github.com/volcengine/verl/blob/main/examples/slurm/ray_on_slurm.slurm>`_ with your cluster's own information.\n\n4. Submit the job script to the Slurm cluster with `sbatch`.\n\nPlease note that Slurm cluster setup may vary. If you encounter any issues, please refer to Ray's\n`Slurm user guide <https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html>`_ for common caveats.\n\nIf you changed Slurm resource specifications, please make sure to update the environment variables in the job script if necessary.\n\n\nInstall related\n------------------------\n\nNotImplementedError: TensorDict does not support membership checks with the `in` keyword. 
\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nDetailed error information:\n\n.. code:: bash\n\n    NotImplementedError: TensorDict does not support membership checks with the `in` keyword. If you want to check if a particular key is in your TensorDict, please use `key in tensordict.keys()` instead.\n\nCause of the problem: there is no suitable version of the tensordict package for the linux-arm64 platform. You can confirm this as follows:\n\n.. code:: bash\n\n    pip install tensordict==0.6.2\n\nExample output:\n\n.. code:: bash\n\n    ERROR: Could not find a version that satisfies the requirement tensordict==0.6.2 (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.2b0, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.8.0, 0.8.1, 0.8.2, 0.8.3)\n    ERROR: No matching distribution found for tensordict==0.6.2\n\nSolution 1:\n  Install tensordict from source:\n\n.. code:: bash\n\n    pip uninstall tensordict\n    git clone https://github.com/pytorch/tensordict.git\n    cd tensordict/\n    git checkout v0.6.2\n    python setup.py develop\n    pip install -v -e .\n\nSolution 2:\n  Temporarily modify the code where the error occurs: replace ``tensordict_var`` with ``tensordict_var.keys()``.\n\n\nIllegal memory access\n---------------------------------\n\nIf you encounter an error message like ``CUDA error: an illegal memory access was encountered`` during rollout, please check the vLLM documentation for troubleshooting steps specific to your vLLM version.\n\nCheckpoints\n------------------------\n\nIf you want to convert a model checkpoint into the huggingface safetensors format, please refer to ``verl/model_merger``.\n\n\nTriton ``compile_module_from_src`` error\n------------------------------------------------\n\nIf you encounter a triton compilation error similar to the stacktrace below, please set the ``use_torch_compile`` flag according to\nhttps://verl.readthedocs.io/en/latest/examples/config.html to disable just-in-time compilation for fused kernels.\n\n.. 
code:: bash\n\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py\", line 345, in <lambda>\n    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/autotuner.py\", line 338, in run\n    return self.fn.run(*args, **kwargs)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py\", line 607, in run\n    device = driver.active.get_current_device()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 23, in __getattr__\n    self._initialize_obj()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 20, in _initialize_obj\n    self._obj = self._init_fn()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py\", line 9, in _create_driver\n    return actives[0]()\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 371, in __init__\n    self.utils = CudaUtils()  # TODO: make static\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 80, in __init__\n    mod = compile_module_from_src(Path(os.path.join(dirname, \"driver.c\")).read_text(), \"cuda_utils\")\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py\", line 57, in compile_module_from_src\n    so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/build.py\", line 48, in _build\n    ret = subprocess.check_call(cc_cmd)\n  File \"/data/lbh/conda_envs/verl/lib/python3.10/subprocess.py\", line 369, in check_call\n    raise CalledProcessError(retcode, cmd)\n\nWhat is the meaning of train batch size, mini batch size, and micro batch size?\n------------------------------------------------------------------------------------------\n\nThis figure illustrates the relationship between the different batch size configurations:\n\nhttps://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA\n\n.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d\n
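\nIn short: ``train_batch_size`` is the number of prompts sampled per rollout step, each ``ppo_mini_batch_size`` chunk of those samples triggers one optimizer update, and each mini-batch is further split into per-GPU micro-batches for gradient accumulation. The arithmetic can be sketched as below, with illustrative values rather than defaults (the exact splitting also depends on the data-parallel size):\n\n.. code:: python\n\n    train_batch_size = 1024        # prompts sampled per rollout step\n    ppo_mini_batch_size = 256      # samples per optimizer update (global)\n    micro_batch_size_per_gpu = 16  # samples per forward/backward on one GPU\n    n_gpus = 8\n\n    updates_per_step = train_batch_size // ppo_mini_batch_size               # 4\n    grad_accum = ppo_mini_batch_size // (micro_batch_size_per_gpu * n_gpus)  # 2\n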
\nHow to generate ray timeline to analyse performance of a training job?\n------------------------------------------------------------------------------------------\n\nTo generate the ray timeline file, you can set the config term ``ray_init.timeline_file`` to a JSON file path.\nFor example:\n\n.. code:: bash\n\n    ray_init.timeline_file=/tmp/ray_timeline.json\n\nThe file will be generated in the specified path at the end of a training job.\nYou can use tools like chrome://tracing or the Perfetto UI to view the ray timeline file.\n\nThis figure shows the ray timeline file generated from a training job on 1 node with 4 GPUs.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray_timeline.png?raw=true\n\nHow to set a proxy only for wandb?\n------------------------------------------------------------------------------------------\n\nIf you need a proxy to access wandb, you can add the config below to your training job script.\nCompared to using the global https_proxy environment variable, this approach won't affect other HTTP requests, such as those from the ChatCompletionScheduler.\n\n.. code:: bash\n\n  +trainer.wandb_proxy=http://<your proxy and port>\n\n"
  },
  {
    "path": "verl_rl/docs/hybrid_flow.rst",
    "content": "=========================================================\nHybridFlow Programming Guide\n=========================================================\n\nLast updated: 06/02/2025.\n\n.. _vermouth: https://github.com/vermouth1992\n\nAuthor: `Chi Zhang <https://github.com/vermouth1992>`_\n\nverl is an open source implementation of the paper `HybridFlow <https://arxiv.org/abs/2409.19256v2>`_ [1]_. In this section, we will introduce the basic concepts of HybridFlow, the motivation and how to program with verl APIs.\n\nMotivation and Design\n------------------------\nWe use dataflow to represent RL systems. [4]_.\n\nDataFlow\n~~~~~~~~~~~~~~~~~~~~\n\nDataflow is an abstraction of computations. Neural Network training is a typical dataflow. It can be represented by computational graph. \n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/dataflow.jpeg?raw=true\n   :alt: The dataflow graph from CS231n 2024 lecture 4\n\nThis figure [2]_ represents the computation graph of a polynomial function followed by a sigmoid function. In the data flow of neural network computation, each node represents an operator, and each edge represents the direction of forward/backward propagation. The computation graph determines the architecture of the neural network.\n\nRL as a dataflow problem\n++++++++++++++++++++++++++++++++++++++++++++++\n\nReinforcement learning (RL) training can also be represented as a dataflow. Below is the dataflow graph that represents the PPO algorithm used in RLHF [3]_:\n\n.. image:: https://picx.zhimg.com/70/v2-cb8ab5ee946a105aab6a563e92682ffa_1440w.avis?source=172ae18b&biz_tag=Post\n  :alt: PPO dataflow graph, credit to Zhihu 低级炼丹师\n\nHowever, the dataflow of RL has fundamental differences compared with dataflow of neural network training as follows:\n\n+--------------------------+--------------------------------------------------+---------------------+\n| Workload                 | Node                                             | Edge                |\n+--------------------------+--------------------------------------------------+---------------------+\n| Neural Network Training  | Operator (+/-/matmul/softmax)                    | Tensor movement     |\n+--------------------------+--------------------------------------------------+---------------------+\n| Reinforcement Learning   | High-level operators (rollout/model forward)     | Data Movement       |\n+--------------------------+--------------------------------------------------+---------------------+\n\nIn the case of tabular reinforcement learning, each operator is a simple scalar math operation (e.g., bellman update). In deep reinforcement learning(DRL), each operator is a high-level neural network computation such as model inference/update. This makes RL a two-level dataflow problem:\n\n- Control flow: defines how the high-level operators are executed (e.g., In PPO, we first perform rollout. Then, we perform advantage computation. Finally, we perform training). It expresses the **core logics of RL algorithms**.\n- Computation flow: defines the dataflow of **neural network computation** (e.g., model forward/backward/optimizer).\n\n\nDesign Choices\n~~~~~~~~~~~~~~~~~~~~\nThe model size used in DRL before the LLM era is typically small. Thus, the high-level neural network computation can be done in a single process. 
This enables embedding the computation flow inside the control flow as a single process.\n\nHowever, in the LLM era, the computation flow (e.g., training a neural network) becomes a multi-process program. This naturally leads to two design choices:\n\n1. Convert the control flow into a multi-process program as well, then colocate it with the computation flow (unified multi-controller)\n\n- Advantages:\n\n  - Achieves **optimal performance** under a fixed computation flow and control flow, as the communication overhead in both training and data transfer is minimized.\n\n- Disadvantages:\n\n  - The computation and/or control flow is **hard to reuse** from a software perspective, as the computation code is coupled with specific controller code. For example, the training loop of PPO is generic. Say we have a PPO training flow implemented with a specific computation flow such as FSDP. Neither the control flow nor the computation flow can be reused if we want to switch the computation flow from FSDP to Megatron, due to the coupling of control and computation flows.\n  - Requires more effort from the user under flexible and dynamic control flows, due to the multi-process nature of the program.\n\n2. Separate the flows: a single process for the control flow and multi-process for the computation flow\n\n- Advantages:\n\n  - The computation flow defined elsewhere can be **easily reused** after the decoupling.\n  - The controller runs on a single process, so implementing a new RL algorithm with a **different control flow is simple and easy**.\n\n- Disadvantages:\n\n  - Additional **data communication overhead** each time the controller process and computation processes interact, as the data has to be sent back and forth.\n\nverl adopts the latter strategy, with separate control flow and computation flow: it is designed to decouple the control flow of RL algorithms from the implementation of computation engines.\n\nOverall Execution Diagram\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nBelow is a simplified diagram denoting the execution of a reinforcement learning job. In the diagram, the controller runs on a single process, while the generator/actor workers and critic workers run on multiple processes, placed on specific resource groups. For rollout, the controller passes the data to the generator to perform sample generation. When the rollout is done, the data is passed back to the controller for the next step of the algorithm. Similar execution is done for the other workers. With the hybrid controller design, the data flow and computation are decoupled to provide both efficiency in computation and flexibility in defining algorithm training loops.\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/driver_worker.png?raw=true\n   :alt: The execution diagram\n\nCodebase walkthrough (PPO)\n------------------------------------------------\n\nEntry function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nCode: https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py\n\nIn this file, we define a remote function `main_task` that serves as the controller (driver) process, as shown in the figure above. We also define a ``RewardManager``, where users can customize their reward function based on the data source in the dataset. Note that `RewardManager` should return the final token-level reward that is optimized by RL algorithms, and that users can combine model-based rewards and rule-based rewards.\nThe ``main_task`` constructs a RayPPOTrainer instance and launches the fit. 
Note that ``main_task`` **runs as a single process**.\n\nWe highly recommend that ``main_task`` NOT be scheduled on the head of the ray cluster, because ``main_task`` will consume a lot of memory while the head usually has very few resources.\n\nRay trainer\n~~~~~~~~~~~~~~~~~~~~\nCode: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py\n\nThe RayPPOTrainer manages:\n\n- Worker and WorkerGroup construction\n- The main loop of the PPO algorithm\n\nNote that the fit function of RayPPOTrainer **runs as a single process**.\n\nWorker and WorkerGroup construction\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nEach WorkerGroup manages a list of workers that run remotely. Note that the worker group runs in the process of its constructor.\nEach worker inside the WorkerGroup runs on a GPU. The worker group serves as a proxy for the controller process to interact with a list of workers in order to perform certain computations. **To do so, we have to bind the methods of the worker to methods of the WorkerGroup and define the data dispatch and data collection**. This is done via a simple decorator that will be introduced in the Worker definition section.\n\nFor example, in PPO, we define 3 worker groups:\n\n- ActorRolloutRef: manages actor, rollout and reference policy. ActorRolloutRefWorker can be instantiated as a single actor, a single rollout, a single reference policy, a combined actor/rollout or a combined actor/rollout/ref. This design aims at maximum code reuse in various scenarios. The reason for colocating actor and rollout is fast weight transfer using nccl. The reason for colocating actor and reference is to implement an efficient LoRA PPO, as the reference policy is simply the base model in LoRA PPO. The colocation is done via ``verl.single_controller.ray.base.create_colocated_worker_cls``, which creates a single ray remote class exposing all class methods from these roles.\n- Critic: manages the critic model\n- Reward: manages the reward model\n\nThe worker group will be constructed on the resource pool it designates. The resource pool is a set of GPUs in the ray cluster.\n\nWorker definition\n~~~~~~~~~~~~~~~~~~~~\n\n.. _ActorRolloutRefWorker: https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py\n\nWe take `ActorRolloutRefWorker <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ as an example.\nThe APIs it should expose to the controller process are:\n\n- init_model: build the underlying model\n- generate_sequences: given prompts, generate responses\n- compute_log_prob: compute the log-probability of a generated sequence using the actor\n- compute_ref_log_prob: compute the log-probability of a generated sequence using the reference policy\n- save_checkpoint: save the checkpoint\n\nNote that these methods are defined in the worker and can only be invoked via remote calls. For example, if the controller process wants to initialize the model, it has to call\n\n.. code-block:: python\n\n   for worker in actor_rollout_ref_wg:\n       worker.init_model.remote()\n\nIf the controller process wants to generate sequences, it has to call\n\n.. 
code-block:: python\n\n   data = xxx\n   # split the data into dp chunks\n   data_dp_lst = data.split(dp_size)\n   output_dp_lst = []\n   for i, worker in enumerate(actor_rollout_ref_wg):\n       output_future = worker.generate_sequences.remote(data_dp_lst[i])\n       output_dp_lst.append(output_future)\n   output = torch.cat(ray.get(output_dp_lst), dim=0)\n\nWe observe that a controller-process call to a worker-group method can in general be divided into three parts:\n\n- Split the data into data-parallel chunks\n- Dispatch the corresponding chunk to each worker\n- Collect and concatenate the data when the computation finishes\n\nIn verl, we provide syntactic sugar that encapsulates the three steps in a single call from the controller process.\n\n.. code-block:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(data):\n       ...\n\n   # on the driver\n   output = actor_rollout_ref_wg.generate_sequences(data)\n\nWe decorate the method of the worker with ``register``, which explicitly defines how the input data should be split and dispatched to each worker, and how the output data should be collected and concatenated by the controller. For example, ``Dispatch.DP_COMPUTE_PROTO`` splits the input data into dp chunks, dispatches one chunk to each worker, then collects and concatenates the outputs. Note that this function requires the input and output to be a DataProto, defined here: https://github.com/volcengine/verl/blob/main/verl/protocol.py.\n\n\nPPO main loop\n~~~~~~~~~~~~~~~~~~~~\nWith the aforementioned APIs, we can implement the main loop of PPO as if it were a single-process program:\n\n.. code-block:: python\n\n   for prompt in dataloader:\n       output = actor_rollout_ref_wg.generate_sequences(prompt)\n       old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)\n       ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)\n       values = critic_wg.compute_values(output)\n       rewards = reward_wg.compute_scores(output)\n       # compute_advantages runs directly on the control process\n       advantages = compute_advantages(values, rewards)\n       output = output.union(old_log_prob)\n       output = output.union(ref_log_prob)\n       output = output.union(values)\n       output = output.union(rewards)\n       output = output.union(advantages)\n       # update actor and critic\n       actor_rollout_ref_wg.update_actor(output)\n       critic_wg.update_critic(output)\n\nTakeaways\n~~~~~~~~~~~~~~~~~~~~\n- This programming paradigm enables users to use different computation backends without modifying the control process.\n- This programming paradigm enables flexible placement (by changing the mapping of WorkerGroup and ResourcePool) without modifying the control process.\n
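\nTo make the decorator pattern concrete, a minimal custom worker might look like the sketch below. The class, its placeholder scoring logic, and the exact import paths are illustrative only and may differ across verl versions:\n\n.. code-block:: python\n\n   import torch\n\n   from verl import DataProto\n   from verl.single_controller.base import Worker\n   from verl.single_controller.base.decorator import Dispatch, register\n\n   class ScoringWorker(Worker):\n       \"\"\"Toy worker: the controller calls compute_scores on the full batch,\n       while each worker instance only sees its own data-parallel chunk.\"\"\"\n\n       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n       def compute_scores(self, data: DataProto) -> DataProto:\n           batch_size = data.batch['responses'].shape[0]\n           scores = torch.zeros(batch_size)  # placeholder scoring logic\n           return DataProto.from_dict(tensors={'scores': scores})\n\nOn the driver, ``scoring_wg.compute_scores(data)`` would then split ``data`` into dp chunks, dispatch them to the workers, and concatenate the returned scores.\n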
\nRepository organization\n------------------------------------------------\n\nImportant code files in the repository are organized as below:\n\n.. code-block:: bash\n\n   verl # the verl package\n     trainer\n       main_ppo.py  # the entrypoint for RL training\n       ppo\n         ray_trainer.py  # the training loop for RL algorithms such as PPO\n       fsdp_sft_trainer.py  # the SFT trainer with FSDP backend\n     config\n       generation.yaml  # configuration template for rollout\n       ppo_trainer.yaml  # configuration template for the RL trainer\n     workers\n       protocol.py  # the interface of DataProto\n       fsdp_workers.py   # the FSDP worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker\n       megatron_workers.py  # the Megatron worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker\n       actor\n         dp_actor.py  #  data parallel actor with FSDP backend\n         megatron_actor.py  # nD parallel actor with Megatron backend\n       critic\n         dp_critic.py  # data parallel critic with FSDP backend\n         megatron_critic.py  # nD parallel critic with Megatron backend\n       reward_model\n         megatron\n           reward_model.py  # reward model with Megatron backend\n       rollout\n         vllm\n           vllm_rollout.py  # rollout with vllm backend\n         hf_rollout.py  # rollout with huggingface TGI backend\n       sharding_manager\n         fsdp_ulysses.py  # data and model resharding when using FSDP + ulysses\n         fsdp_vllm.py  # data and model resharding when using FSDP + ulysses + vllm\n         megatron_vllm.py  # data and model resharding when using Megatron + vllm\n     utils\n       dataset  # datasets for SFT/RM/RL\n       reward_score  # function based reward\n         gsm8k.py  # reward function for gsm8k dataset\n         math.py  # reward function for math dataset\n       seqlen_balancing.py  # the sequence balance optimization\n     models\n       llama  # Megatron implementation for llama, deepseek, mistral, etc\n       transformers  # ulysses integration with transformer models such as llama, qwen, etc\n       weight_loader_registery.py  # registry of weight loaders for loading hf ckpt into Megatron\n     third_party\n       vllm  # adaptor for vllm's usage in RL\n         vllm_spmd  # vllm >= v0.7 adaptor\n   examples  # example scripts\n   tests  # integration and unit tests\n   .github  # the configuration of continuous integration tests\n\n\n.. [1] HybridFlow: A Flexible and Efficient RLHF Framework: https://arxiv.org/abs/2409.19256v2\n.. [2] Data flow graph credit to CS231n 2024 lecture 4: https://cs231n.stanford.edu/slides/2024/lecture_4.pdf\n.. [3] PPO dataflow graph credit to 低级炼丹师 from Zhihu: https://zhuanlan.zhihu.com/p/635757674\n.. [4] RLFlow\n"
  },
  {
    "path": "verl_rl/docs/index.rst",
    "content": "Welcome to verl's documentation!\n================================================\n\nverl is a flexible, efficient and production-ready RL training framework designed for large language models (LLMs) post-training. It is an open source implementation of the `HybridFlow <https://arxiv.org/pdf/2409.19256>`_ paper.\n\nverl is flexible and easy to use with:\n\n- **Easy extension of diverse RL algorithms**: The hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex Post-Training dataflows. Allowing users to build RL dataflows in a few lines of code.\n\n- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM, vLLM and SGLang. Moreover, users can easily extend to other LLM training and inference frameworks.\n\n- **Flexible device mapping and parallelism**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.\n\n- Ready integration with popular HuggingFace models\n\n\nverl is fast with:\n\n- **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, verl achieves high generation and training throughput.\n\n- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.\n\n--------------------------------------------\n\n.. _Contents:\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Quickstart\n\n   start/install\n   start/quickstart\n   start/multinode\n   start/ray_debug_tutorial\n   start/more_resources\n   start/agentic_rl\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Programming guide\n\n   hybrid_flow\n   single_controller\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Data Preparation\n\n   preparation/prepare_data\n   preparation/reward_function\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Configurations\n\n   examples/config\n\n.. toctree::\n   :maxdepth: 1\n   :caption: PPO Example\n\n   examples/ppo_code_architecture\n   examples/gsm8k_example\n   examples/multi_modal_example\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Algorithms\n\n   algo/ppo.md\n   algo/grpo.md\n   algo/dapo.md\n   algo/spin.md\n   algo/sppo.md\n   algo/entropy.md\n   algo/opo.md\n   algo/baseline.md\n   algo/gpg.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: PPO Trainer and Workers\n\n   workers/ray_trainer\n   workers/fsdp_workers\n   workers/megatron_workers\n   workers/sglang_worker\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Performance Tuning Guide\n\n   perf/dpsk.md\n   perf/perf_tuning\n   README_vllm0.8.md\n   perf/device_tuning\n   perf/nsight_profiling.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Adding new models\n\n   advance/fsdp_extension\n   advance/megatron_extension\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Advanced Features\n\n   advance/checkpoint\n   advance/rope\n   advance/ppo_lora.rst\n   sglang_multiturn/multiturn.rst\n   sglang_multiturn/interaction_system.rst\n   advance/placement\n   advance/dpo_extension\n   examples/sandbox_fusion_example\n   advance/rollout_trace.rst\n   advance/one_step_off\n   advance/agent_loop\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Hardware Support\n\n   amd_tutorial/amd_build_dockerfile_page.rst\n   amd_tutorial/amd_vllm_page.rst\n   ascend_tutorial/ascend_quick_start.rst\n   ascend_tutorial/ascend_profiling.rst\n   ascend_tutorial/ascend_profiling_en.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API References\n\n   api/data\n   api/single_controller.rst\n   api/trainer.rst\n   api/utils.rst\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: FAQ\n\n   faq/faq\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Development Notes\n\n   sglang_multiturn/sandbox_fusion.rst\n\nContribution\n-------------\n\nverl is free software; you can redistribute it and/or modify it under the terms\nof the Apache License 2.0. We welcome contributions.\nJoin us on `GitHub <https://github.com/volcengine/verl>`_, `Slack <https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA>`_ and `Wechat <https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG>`_ for discussions.\n\nContributions from the community are welcome! Please check out our `project roadmap <https://github.com/volcengine/verl/issues/710>`_ and `good first issues <https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22>`_ to see where you can contribute.\n\nCode Linting and Formatting\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWe use pre-commit to help improve code quality. To initialize pre-commit, run:\n\n.. code-block:: bash\n\n   pip install pre-commit\n   pre-commit install\n\nTo resolve CI errors locally, you can also manually run pre-commit by:\n\n.. code-block:: bash\n\n   pre-commit run\n\nAdding CI tests\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nIf possible, please add CI test(s) for your new feature:\n\n1. Find the most relevant workflow yml file, which usually corresponds to a ``hydra`` default config (e.g. ``ppo_trainer``, ``ppo_megatron_trainer``, ``sft_trainer``, etc).\n2. Add related path patterns to the ``paths`` section if not already included.\n3. Minimize the workload of the test script(s) (see existing scripts for examples).\n\nWe are HIRING! Send us an `email <mailto:haibin.lin@bytedance.com>`_ if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.\n"
  },
  {
    "path": "verl_rl/docs/perf/device_tuning.rst",
    "content": "Hardware Resource Needed for RL\n===============================\n\nLast updated: 06/25/2025.\n\nSince RL requires more resources compared to regular training, \ndetermining how much resources are needed to successfully run it before training \nis a relatively difficult task. To provide more people with reference points for \nresource selection when dealing with different models and tasks, this section is \nmainly dedicated to introducing the environmental requirements based on experiments \nwe have conducted.\n\nHowever, due to limited staff and equipment resources, we also hope for more \ncontributions from the open-source community. When submitting a PR, it is necessary \nto provide a script to be added to the example/tuning scripts.\n\nWe need two types of scripts: one is the configuration that can run with the **minimum \nresources(min)**, and the other is the configuration that runs with **recommended resources(recommended)**. For the former, \nit can be understood as a script that can run after applying all memory optimization techniques \n(e.g., offload, gradient checkpointing). For the latter, it can be understood as a script that \ncan run while avoiding operations that incur additional time overhead as much as possible (targetting best throughput).\n\nWhen defining script names, please follow this format: \n``[model]_[task]_[gpunums]_[device]_[train]_[infer].sh``. This will effectively improve \nthe script's recognizability. You can place the script under the ``examples/tuning/`` directory.\n\nIf you happen to have a configuration that has already been tested, we welcome you to submit \na PR and include a screenshot from Wandb or other verifiable evidence.\n\n----------------------------------------\n\n0.5B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-0.5B\n      - GRPO-LoRA\n      - 1*H100\n      - 116\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n1.5B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-1.5B\n      - GRPO-LoRA\n      - 1*H100\n      - 128\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n3B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2.5-3B\n      - GRPO-LoRA\n      - 1*H100\n      - 62\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n7B\n~~~\n\n.. 
list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-7B\n      - GRPO\n      - 2*H800\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-7b_grpo_2_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-7B\n      - GRPO-LoRA\n      - 1*H100\n      - 16\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n14B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-14B\n      - GRPO\n      - 4*H800\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-14b_grpo_4_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo_4_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-14B\n      - GRPO-LoRA\n      - 2*H100\n      - 116\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n32B\n~~~\n\n.. list-table::\n    :widths: auto\n    :header-rows: 1\n    \n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-32B\n      - GRPO\n      - 8*H20\n      - \\\n      - megatron\n      - vllm0.8.2\n      - `qwen2-32b_grpo_8_h20_megatron_vllm <https://github.com/volcengine/verl/tree/main/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-32B\n      - GRPO-LoRA\n      - 4*H100\n      - 180\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n70B\n~~~\n\n.. 
list-table::\n    :widths: auto\n    :header-rows: 1\n\n    * - Tag\n      - Model\n      - Task\n      - Resource\n      - MaxBatch\n      - Train\n      - Infer\n      - Link\n      - Contributor\n    * - MIN\n      - Qwen2-70B\n      - GRPO\n      - 32*H20\n      - \\\n      - fsdp\n      - vllm0.8.2\n      - `qwen2-70b_grpo_32_h20_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2-70B\n      - GRPO\n      - 32*H800\n      - \\\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-70b_grpo_32_h800_fsdp_vllm <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh>`_\n      - `Xiangyongan <xiangyongan@bytedance.com>`_\n    * - MIN\n      - Qwen2.5-72B\n      - GRPO-LoRA\n      - 8*H100\n      - 176\n      - fsdp\n      - vllm0.8.3\n      - `qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh <https://github.com/volcengine/verl/blob/main/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh>`_\n      - `SimonHuang <thelongestusernameofall@gmail.com>`_\n\n405B\n~~~~\n\n.. table::\n   :widths: auto\n\n   ====== ====== ====== ======== ======== ====== ====== ======\n   tag    model  task   resource MaxBatch train  infer  link\n   ====== ====== ====== ======== ======== ====== ====== ======\n   \\      \\      \\        \\        \\      \\      \\\n   ====== ====== ====== ======== ======== ====== ====== ======\n\n671B\n~~~~\n\n.. table::\n   :widths: auto\n\n   ====== ====== ====== ======== ======== ====== ====== ======\n   tag    model  task   resource MaxBatch train  infer  link\n   ====== ====== ====== ======== ======== ====== ====== ======\n   \\      \\      \\        \\        \\      \\      \\\n   ====== ====== ====== ======== ======== ====== ====== ======\n"
  },
  {
    "path": "verl_rl/docs/perf/dpsk.md",
    "content": "# Training DeepSeek 671b\n\nLast updated: 06/13/2025.\n\nverl integrates Megatron to support large MoE models such as `Qwen3-235B-A22B` and `deepseek-ai/DeepSeek-V3`. This is an ongoing community effort.\n\nIn the journey the community added the following features and optimizations that enable verl with larger models:\n- per tensor weight resharding between rollout and training\n- context parallelism and expert parallelism enabled via megatron\n- dynamic batch size (sequence balance) for megatron\n- reduced ray-related serialization overhead\n- optimizer offloading, recomputation, and efficient kernels\n- various debugging metrics and utils\n\nand the megatron backend now has a wider list of models supported:\n- DeepSeek-V3\n- Moonlight\n- Qwen3\n- Qwen2.5-VL (to be merged soon)\n- Qwen2\n- Mixtral\n\n## Getting Started\n\n### DeepSeek 671b\n\nThe recommended image with pre-built megatron dependency is `whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.2-te2.3-deepseekv3`, built with the Dockerfile in [docker/Dockerfile.vllm.sglang.megatron.deepseek](https://github.com/volcengine/verl/blob/main/docker/Dockerfile.vllm.sglang.megatron.deepseek).\n\nFor checkpoint loading, we rely on megatron dist-ckpt for resharding. A converted dist-ckpt for DeepSeek-V3 is available from [huggingface BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt](https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main).\n\nTo run end-to-end training on the DAPO dataset, run [recipe/dapo/test_dapo_dspk_671b_megatron.sh](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_dspk_671b_megatron.sh). It runs on 512 H20(96GB) GPUs with the following setup:\n- vllm rollout with TP=32, bfloat16\n- megatron training with attention DP, MoE EP=32, PP=16, bfloat16\n\nMTP is disabled during RL training.\n\n### Qwen3 236b\n\nFor Qwen3-236b, please refer to [examples/grpo_trainer/run_qwen3-236b_megatron.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-236b_megatron.sh), which runs on 128 H20(96GB) GPUs.\n\n## Upcoming Optimizations\n\nThe community continue to optimize large MoE models further, ongoing efforts include:\n- further optimizing memory consumption, and provide recommended/tuned configurations with various machine types\n- optimizing long context RL training performance\n- performance improvement with SGLang x Megatron\n\nWe invite the community to try and improve verl together. Get connected with us on [slack](https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA)/[wechat](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG)/[Github issues](https://github.com/volcengine/verl/issues/708)!\n\n## Acknowledgement\n@vermouth1992 @ISEEKYAN @ETOgaosion @yzlnew @ShareLer @BearBiscuit05 @ccclyu @ann-qin-lu @SwordFaith @zzong2006 @zhaochenyang20 @ocss884 @eric-haibin-lin\n"
  },
  {
    "path": "verl_rl/docs/perf/nsight_profiling.md",
    "content": "# NVIDIA Nsight Systems profiling in verl\n\nLast updated: 06/20/2025.\n\nThis guide explains how to use NVIDIA Nsight Systems for profiling verl training runs.\n\n## Configuration\n\nProfiling in verl can be configured through several parameters in the trainer configuration file (ppo_trainer.yaml or other files like dapo_trainer.yaml):\n\n### Prerequisites\n\nNsight Systems version is important, please reference `docker/Dockerfile.vllm.sglang.megatron` for the version we used.\n\n### Global profiling control\n\nverl has one single controller process and multiple worker processes. Both controller and worker processes can be profiled. Since the controller process can be executed in any nodes in the cluster, there is a message printed in the logging to indicate the controller process node hostname and process id.\n\nIn `trainer`, three new config entries control the profiler behaviors:\n\n* **`trainer.profile_steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling.\n\n\n* **`controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details.\n\n* **`worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: \"cudaProfilerApi\"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`.\n\n### Worker process profiling\n\nVerl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields:\n\n* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_<PID>.<RID>.nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID.\n\n* **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. In this case, each role's action occupies one `<RID>`.\n\n* **`actor_rollout_ref`**. This Worker can be configured to contain at most 3 roles and executes together. So `actor_rollout_ref` has a `profiler` config and all the inside roles inherit it.\n\n* **Verl collocate mode**. Verl can combine two Worker sub classes to one Worker Actor. In this case, the user should take care that the combined Workers have consistent `discrete`. 
Either way, the Nsight Systems profiler uses a `torch.cuda.profiler.start()` and `stop()` pair to dump a per-`<step>` database.\n\n### Where to find the profiling data\n\nBy default the `*.nsys-rep` files are saved in the directory `/tmp/ray/session_latest/logs/nsight/` on each node. According to the Ray manual, this default directory is not changeable, [\"however, Ray preserves the `--output` option of the default config\"](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html).\n\nThis may be inconvenient, but it is understandable: Ray may start hundreds of processes, and saving all the files to one central place would put heavy pressure on a network file system.\n\n## Usage Example\n\nTo enable profiling for specific components and steps, modify your ppo_trainer.yaml like this:\n\n### Disable profiler\n```yaml\n    trainer:\n        profile_steps: null # disable profile\n```\n\n### Enable profiler and one database for one training step\n```yaml\n    trainer:\n        profile_steps: [1, 2, 5]\n    actor_rollout_ref:\n        profiler:\n            discrete: False\n            all_ranks: False\n            ranks: [0, 1]\n    critic:\n        profiler:\n            discrete: False\n            all_ranks: False\n            ranks: [0, 1]\n    reward_model:\n        profiler:\n            discrete: False\n            all_ranks: False\n            ranks: [0, 1]\n```\n\n### Enable profiler and multiple databases for one training step\n```yaml\n    trainer:\n        profile_steps: [1, 2, 5]\n    actor_rollout_ref:\n        profiler:\n            discrete: True\n            all_ranks: False\n            ranks: [0, 1]\n    critic:\n        profiler:\n            discrete: True\n            all_ranks: False\n            ranks: [0, 1]\n    reward_model:\n        profiler:\n            discrete: True\n            all_ranks: False\n            ranks: [0, 1]\n```\n\n## Profiling Output\n\nWhen profiling is enabled, verl will generate Nsight Systems profiles for the specified components and steps. The profiles will include:\n\n- CUDA kernel execution\n- Memory operations\n- CPU-GPU synchronization\n- NVTX markers for key operations\n\nNsight Systems supports a multi-report view that opens multiple databases together. In this mode, different processes and steps can be aligned on one timeline for better analysis.\n"
  },
  {
    "path": "verl_rl/docs/perf/perf_tuning.rst",
    "content": "Performance Tuning Guide\n==============================\n\nLast updated: 07/17/2025.\n\nAuthor: `Guangming Sheng <https://github.com/PeterSH6>`_, `Jiali Zheng <https://github.com/CurryRice233>`_\n\nIn this section, we will discuss how to tune the performance of all the stages in verl, including:\n\n1. Rollout generation throughput.\n\n2. Enable ``use_remove_padding=True`` for sequence packing (i.e., data packing and remove padding).\n\n3. Batch size tuning for forward and backward computation\n\n4. Enable ``use_dynamic_bsz=True`` for higher throughput.\n\n5. Utilize Ulysses Sequence Parallel for Long Context Training\n\n6. LigerKernel for SFT performance optimization\n\n7. Forward prefetch in FSDP training backend\n\n8. Memory optimization for entropy calculation from logits\n\nRollout Generation Tuning\n--------------------------\n\nverl currently supports two rollout backends: vLLM and TGI (with SGLang support coming soon). \n\nBelow are key factors for tuning vLLM-based rollout. Before tuning, we recommend setting ``actor_rollout_ref.rollout.disable_log_stats=False`` so that rollout statistics are logged.\n\n- Increase ``gpu_memory_utilization``.\n\n  - For vLLM v0.7.0 and later, the vLLM instance will only use gpu_memory_utilization of the **total** memory.\n  - For SGLang, it's the fraction of the free GPU memory used for **static** memory like model weights and KV cache. However, the remaining (1-gpu_memory_utilization) will also be used during inference.\n\n  However, if model parameters and optimizer states are not offloaded, using too high a fraction can lead to OOM. \n  A value between 0.5 and 0.7 often strikes a good balance between high throughput and avoiding OOM.\n\n  Note: since the definition of ``gpu_memory_utilization`` varies across inference engines, a value that works well for one engine may cause OOM for another.\n\n- Adjust ``max_num_seqs`` or ``max_num_batched_tokens``.\n  If the GPU cache utilization is relatively low in the log, increase ``max_num_seqs`` or ``max_num_batched_tokens`` \n  can enlarge the effective batch size in the decoding stage, allowing more concurrent requests per batch. \n  We recommend setting ``max_num_batched_tokens > 2048`` for higher throughput.\n\n- Use a smaller ``tensor_parallel_size``. \n  When GPU resources allow, a smaller tensor parallel size spawns more vLLM replicas. \n  Data parallelism (DP) can yield higher throughput than tensor parallelism (TP), but also increases KVCache consumption. \n  Carefully balance the trade-off between more replicas and higher memory usage.\n  Our experiment in Sec. 8.4 of `HybridFlow paper <https://arxiv.org/pdf/2409.19256v2>`_ evaluate this trade-off.\n\nMore tuning details such as dealing with Preemption and Chunked-prefill\ncan be found in `vLLM official tuning guide <https://docs.vllm.ai/en/latest/performance/optimization.html>`_ \n\nFor optimal performance, we recommend using vLLM v0.8.3 or later. 
Enable remove padding (sequence packing)\n-----------------------------------------\n\nCurrently, for llama, mistral, gemma1 and qwen based models, users can enable `use_remove_padding=True` to utilize the\nsequence packing implementation provided by the transformers library.\n\nFor other models, the transformers library may also support it, but we haven't tested it yet.\nUsers can add the desired model config to the `test_transformer.py <https://github.com/volcengine/verl/blob/main/tests/models/test_transformer.py#L24>`_ file\nand test its functionality by running the following command:\n\n.. code-block:: bash\n\n  pytest -s tests/models/test_transformer.py\n\nIf the test passes, you can add your desired model into the model `registry.py <https://github.com/volcengine/verl/blob/main/verl/models/registry.py#L24>`_ file.\nThen you can enjoy the performance boost of sequence packing,\nand you are welcome to PR your tested model to verl!\n\n\nBatch Size Tuning\n-----------------\n\nTo achieve higher throughput in experience preparation (i.e., model fwd) and model update (i.e., actor/critic fwd/bwd),\nusers may need to tune the ``*micro_batch_size_per_gpu`` for the different computations.\n\nIn verl, the core principle for setting batch sizes is:\n\n- **Algorithmic metrics** (train batch size, PPO mini-batch size) are *global* (from a single-controller perspective),\n  normalized in each worker. See the `normalization code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py#L120-L122>`_.\n\n- **Performance-related parameters** (micro batch size, max token length for dynamic batch size) are *local* parameters that define the per-GPU data allocations.\n  See the `normalization code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py#L127>`_.\n\n.. note:: In your training script, please use ``*micro_batch_size_per_gpu`` instead of ``*micro_batch_size``,\n  so that you don't need to consider the normalization of ``micro_batch_size``; ``micro_batch_size`` will be deprecated.\n\n
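To make the global-vs-local distinction concrete, here is a minimal sketch of the normalization idea (illustrative numbers only, not the exact code in ``fsdp_workers.py``):\n\n.. code-block:: python\n\n   # Hypothetical values for illustration.\n   ppo_mini_batch_size = 256    # global: set from the single-controller perspective\n   dp_world_size = 8            # number of data-parallel GPUs\n\n   # Global (algorithmic) batch sizes are normalized in each worker ...\n   ppo_mini_batch_size_per_gpu = ppo_mini_batch_size // dp_world_size   # 32\n\n   # ... while *_micro_batch_size_per_gpu is already a local, per-GPU value\n   # that splits the per-GPU mini-batch into gradient-accumulation steps.\n   ppo_micro_batch_size_per_gpu = 4\n   grad_accum_steps = ppo_mini_batch_size_per_gpu // ppo_micro_batch_size_per_gpu   # 8\n\n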
Batch Size Tuning tips\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nUsers may need to tune ``*micro_batch_size_per_gpu`` to accelerate training. Here are some tips:\n\n1. **Enable gradient checkpointing**:\n   Set ``actor_rollout_ref.model.enable_gradient_checkpointing=True`` and ``critic.model.enable_gradient_checkpointing=True``.\n   This often allows for larger micro-batch sizes and will be beneficial for large mini-batch training.\n\n2. **Increase the micro batch size**:\n   Increase ``*micro_batch_size_per_gpu`` as much as possible, until it equals the normalized ``mini_batch_size``.\n\n3. **Use larger forward-only parameters**:\n   Forward-only parameters, such as ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``,\n   ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu`` and ``critic.forward_micro_batch_size_per_gpu``, can be larger (e.g., 2x) than training-related micro batch sizes,\n   such as ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`` and ``critic.ppo_micro_batch_size_per_gpu``.\n\n4. **Allow larger micro-batch sizes for Critic and Reward models**:\n   The micro batch size of the Critic and Reward models can be larger than that of the Actor model, because the actor model computes logits over a much larger vocabulary in its final layer.\n\n5. **Enable activation offloading**:\n   Set ``actor_rollout_ref.model.enable_activation_offload=True`` and ``critic.model.enable_activation_offload=True``.\n   This often works together with gradient checkpointing to allow larger micro-batch sizes, and it is currently only available in the FSDP backend.\n\nTuning for Dynamic Batch Size\n-----------------------------\n\nDynamic batch size is a technique that allows the model to process a similar number of tokens in each forward pass (with different actual batch sizes).\nThis can significantly improve training efficiency and reduce memory usage. A conceptual sketch follows the tips below.\n\nTo utilize this technique, users can set ``use_dynamic_bsz=True`` in the actor, ref, critic and reward models.\nWith ``use_dynamic_bsz=True``, users don't need to tune ``*micro_batch_size_per_gpu``.\nInstead, users should tune the following parameters:\n\n- ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``, ``critic.ppo_max_token_len_per_gpu``:\n  The maximum number of tokens to be processed in the fwd and bwd of ``update_policy`` and ``update_critic``.\n\n- ``actor_rollout_ref.ref.log_prob_max_token_len_per_gpu`` and ``actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu``:\n  The maximum number of tokens to be processed in the fwd computation of ``compute_log_prob`` and ``compute_ref_log_prob``.\n\n- ``critic.forward_micro_batch_size_per_gpu``, ``reward_model.forward_micro_batch_size_per_gpu``:\n  The maximum number of tokens to be processed in the fwd computation of ``compute_values`` and ``compute_rm_score``.\n\nDynamic Batch Size Tuning tips\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nHere are some tips for tuning the above parameters:\n\n1. **Increase** ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``:\n   Make it at least 2 x (max_prompt_length + max_response_length). We set it to 3x in `run_qwen2-7b_rm_seq_balance.sh <https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh#L25>`_.\n   Try to increase it further to get higher throughput.\n\n2. **Forward-only parameters can be larger**:\n   Similar to the non-dynamic-batch scenario, forward-only token limits can exceed those used in forward/backward operations.\n\n3. **Use larger limits for Critic and Reward models**:\n   Critic and Reward parameters can be set to at least 2× the Actor's limits. For instance, we set them to 4× here:\n   `run_qwen2-7b_rm_seq_balance.sh <https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh#L40>`_\n\n.. :math:`\\\\text{critic.ppo_max_token_len_per_gpu} = 2 \\\\times \\\\text{actor.ppo_max_token_len_per_gpu}`.\n\n
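Conceptually, dynamic batching greedily packs sequences into micro-batches under a per-GPU token budget. The sketch below illustrates the idea only; verl's actual implementation additionally balances sequence lengths across ranks:\n\n.. code-block:: python\n\n   def pack_by_token_budget(seq_lens, max_token_len_per_gpu):\n       \"\"\"Greedily group sequence indices so that each micro-batch\n       stays under the per-GPU token budget (illustrative only).\"\"\"\n       micro_batches, current, current_tokens = [], [], 0\n       for idx, n_tokens in enumerate(seq_lens):\n           if current and current_tokens + n_tokens > max_token_len_per_gpu:\n               micro_batches.append(current)\n               current, current_tokens = [], 0\n           current.append(idx)\n           current_tokens += n_tokens\n       if current:\n           micro_batches.append(current)\n       return micro_batches\n\n   # pack_by_token_budget([3000, 5000, 4000, 2000], 8192) -> [[0, 1], [2, 3]]\n\n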
Ulysses Sequence Parallel for Long Context Training\n----------------------------------------------------\n\nTo utilize this technique, users can set ``ulysses_sequence_parallel_size>1`` in the actor, ref, critic and reward models.\n\nDifferent models can use different ``ulysses_sequence_parallel_size`` values.\n\nTo train on long sequences (>32k), users may need to decrease ``*micro_batch_size_per_gpu`` and ``*max_token_len_per_gpu`` to avoid OOM.\n\nLigerKernel for SFT\n----------------------\n\nLigerKernel is a high-performance kernel for Supervised Fine-Tuning (SFT) that can improve training efficiency. To enable LigerKernel in your SFT training:\n\n1. Install liger-kernel via ``pip3 install liger-kernel``. In your SFT configuration file (e.g., ``verl/trainer/config/sft_trainer.yaml``), set the ``use_liger`` parameter:\n\n   .. code-block:: yaml\n\n      model:\n        use_liger: True  # Enable LigerKernel for SFT\n\n2. The default value is ``False``. Enable it only when you want to use LigerKernel's optimizations.\n\n3. LigerKernel is particularly useful for improving training performance in SFT scenarios.\n\nForward prefetch in FSDP training backend\n-----------------------------------------\n\nDuring the training phase, users can enable forward prefetching in FSDP by setting ``fsdp_config.forward_prefetch=True``, for example, ``actor_rollout_ref.actor.fsdp_config.forward_prefetch=True``. This configuration prefetches the next forward-pass all-gather operation before completing the current forward computation, overlapping communication with computation and improving efficiency. For further details, refer to the `FSDP forward_prefetch <https://docs.pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp>`_ documentation.\n\n.. note::\n    Backward prefetch is unsupported because the ``BACKWARD_POST`` policy may prefetch incorrectly in nested-module cases. For details, see the `FSDP documentation <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md?plain=1#L70>`_\n\nMigrating to FSDP2\n----------------------\n\nFSDP2 offers notable improvements over FSDP1. According to `PyTorch TorchTitan benchmarks <https://arxiv.org/abs/2410.06511v1>`_:\n\n- 7% lower GPU memory usage on average\n- 1.5% throughput improvement with BF16 training\n- Better composability with DTensor and per-parameter sharding\n\n**Enabling FSDP2 in VERL:**\n\n   .. code-block:: python\n\n    # Enable FSDP2 in actor configuration\n    actor_rollout_ref.actor.strategy=\"fsdp2\"\n\n.. note::\n   FSDP2 requires PyTorch 2.1+ and is recommended for models with transformer architecture.\n\nMemory optimization for entropy calculation from logits\n--------------------------------------------------------\n\nThe ``logits`` tensor (typically of shape ``[bsz*seq_len, voc]``) can consume significant memory. When using ``compute_entropy_from_logits``, memory usage reaches approximately ``[bsz*seq_len, voc] × (4 bytes (float32) + 2 bytes (autocast for softmax+logsumexp) + 1 byte (softmax output))``.\n\nTo reduce this memory peak, enable chunked computation by setting\n``actor_rollout_ref.ref.entropy_from_logits_with_chunking = True``.\nThis processes the tensor in chunks of shape ``[chunk_size, voc]`` (e.g., 2048) rather than the full sequence length, exclusively during the model's forward pass.\n\nAdditionally, during training, standard gradient checkpointing (``enable_gradient_checkpointing=True``) does not apply to entropy calculations. To reduce memory peaks in this context, set\n``actor_rollout_ref.actor.entropy_checkpointing = True``.\nThis enables recomputation specifically for the entropy calculation, lowering memory usage during training.\n"
  },
  {
    "path": "verl_rl/docs/preparation/prepare_data.rst",
    "content": "Prepare Data for Post-Training\n========================================\n\nLast updated: 02/09/2025.\n\nBefore starting the post-training job, we need to prepare the data for\nthe policy training. The data should be stored in the parquet format.\n\nWe provide several data preprocess scripts for different datasets,\nincluding GSM8K, MATH, HelloSwag, Full_hh_rlhf. To prepare other datasets, we need\nto follow the following steps: The data preprocess script can be divided\ninto two parts:\n\n1. The first part is the common part, which loads the dataset from\n   huggingface's ``datasets`` package. Then preprocess the datasets with\n   the ``make_map_fn`` and then store in the parquet format.\n\n.. code:: python\n\n   import re\n   import os\n   import datasets\n\n   from verl.utils.hdfs_io import copy, makedirs\n   import argparse\n\n   # To extract the solution for each prompts in the dataset\n   # def extract_solution(solution_str): \n   # ...\n\n\n   if __name__ == '__main__':\n       parser = argparse.ArgumentParser()\n       parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')\n       parser.add_argument('--hdfs_dir', default=None)\n\n       args = parser.parse_args()\n\n       num_few_shot = 5\n       data_source = 'openai/gsm8k'\n\n       dataset = datasets.load_dataset(data_source, 'main')\n\n       train_dataset = dataset['train']\n       test_dataset = dataset['test']\n\n           # Construct a `def make_map_fn(split)` for the corresponding datasets.\n       # ...\n           \n       train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)\n       test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)\n\n       local_dir = args.local_dir\n       hdfs_dir = args.hdfs_dir\n\n       train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))\n       test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))\n\n       makedirs(hdfs_dir)\n\n       copy(src=local_dir, dst=hdfs_dir)\n\n2. The users are required to implement the ``make_map_fn()`` function\n   (as well as the ``extract_solution``) on their own to support\n   different datasets or tasks.\n\nWe already implemented the data preprocess of GSM8k, MATH, Hellaswag and Full_hh_rlhf\ndatasets. And we take the GSM8k dataset as an example:\n\n**GSM8K**\n\nIn the ``make_map_fn``, each data field should consist of the following\n5 fields:\n\n1. ``data_source``: The name of the dataset. To index the corresponding\n   reward function in the ``RewardModule``\n2. ``prompt``: This field should be constructed in the format of\n   huggingface chat_template. The tokenizer in ``RLHFDataset`` will\n   apply chat template and tokenize the prompt.\n3. ``ability``: Define the task category.\n4. ``reward_model``: Currently, we only utilize the ``ground_truth``\n   field during evaluation. The ``ground_truth`` is computed by the\n   ``extract_solution`` function. **NOTED** that the implementation of\n   the corresponding reward function should align with this extracted\n   ``ground_truth``.\n5. ``extra_info``: Record some information of the current prompt. Not\n   use for now.\n\n.. 
code:: python\n\n   def extract_solution(solution_str):\n       solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str) # extract the solution after ####\n       assert solution is not None\n       final_solution = solution.group(0)\n       final_solution = final_solution.split('#### ')[1].replace(',', '')\n       return final_solution\n\n   instruction_following = \"Let's think step by step and output the final answer after \\\"####\\\".\"\n\n   # add a row to each data item that represents a unique id\n   def make_map_fn(split):\n\n       def process_fn(example, idx):\n           question = example.pop('question')\n\n           question = question + ' ' + instruction_following\n\n           answer = example.pop('answer')\n           solution = extract_solution(answer)\n           data = {\n               \"data_source\": data_source,\n               \"prompt\": [{\n                   \"role\": \"user\",\n                   \"content\": question\n               }],\n               \"ability\": \"math\",\n               \"reward_model\": {\n                   \"style\": \"rule\",\n                   \"ground_truth\": solution\n               },\n               \"extra_info\": {\n                   'split': split,\n                   'index': idx\n               }\n           }\n           return data\n\n       return process_fn\n"
  },
  {
    "path": "verl_rl/docs/preparation/reward_function.rst",
    "content": "Implement Reward Function for Dataset\n======================================\n\nLast updated: 06/02/2025.\n\nFor each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses.\nWe already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.\nYou can also use customized reward functions.\n\nCurrently, we support reward functions for GSM8k and MATH datasets. For RLHF datasets (e.g.,\nfull_hh_rlhf) and Code Generation (e.g., APPS), we utilize reward model\nand SandBox (will opensource soon) for evaluation respectively.\n\nRewardManager\n-------------\n\nIn the entrypoint of the PPO Post-Training script `main_ppo.py <https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py#L33>`_,\nwe implement a ``RewardManager`` that utilize pre-implemented reward functions to compute the scores for each response.\n\nIn the ``RewardManager``, we implemented a ``__call__`` function to\ncompute the score for each response. \nAll the reward functions are executed by ``compute_score_fn``.\nThe input is a ``DataProto``, which includes:\n\n- ``input_ids``, ``attention_mask``: ``input_ids`` and ``attention_mask`` after applying\n  chat_template, including prompt and response\n- ``responses``: response tokens\n- ``ground_truth``: The ground truth string of the current prompt.\n  Stored in ``non_tensor_batch`` in the ``DataProto``, which should be\n  preprocessed in the parquet files.\n- ``data_source``: The dataset name of the current prompt. Stored in\n  ``non_tensor_batch`` in the ``DataProto``, which should be\n  preprocessed in the parquet files.\n\nAfter detokenize the responses, the responses string and the ground\ntruth string will be input to the ``compute_score_fn`` to compute the\nscore for each response.\n\nReward Functions\n----------------\n\nPre-implemented\n~~~~~~~~~~~~~~~\n\nWe already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.\n\n- In the `GSM8k example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_, we\n  force the response to output the final answer after four ####, then\n  use string matching to compare with the ground truth. If completely\n  correct, score 1 point; if the format is correct, score 0.1 points; if\n  the format is incorrect, score 0 points.\n- In the `MATH example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_, we follow\n  the implementation in `lm-evaluation-harness repository <https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py>`_.\n\nCustomized\n~~~~~~~~~~\n\nYou can implement customized reward functions in a separate file and specify them using ``custom_reward_function.path`` and ``custom_reward_function.name``. For the set of them, please refer to :ref:`config-explain-page`.\n\nThe parameters of your reward function should be ``data_source``, ``solution_str``, ``ground_truth``, and ``extra_info``.\nFor example:\n\n.. 
If you are testing only a single customized reward function, you can simply name it 'compute_score' and leave ``custom_reward_function.name`` unset.\n\nTo run multiple tests with different customized reward functions, modify both ``custom_reward_function.path`` and ``custom_reward_function.name`` for each trial.\nFor instance, you might create a single `my_reward.py` file, as sketched above, and implement multiple reward functions within it. This way, for different trials, you only need to adjust ``custom_reward_function.name``, making it more convenient to conduct multiple tests within scripts.\n"
  },
  {
    "path": "verl_rl/docs/requirements-docs.txt",
    "content": "# markdown support\r\nrecommonmark\r\nmyst_parser\r\n# markdown table support\r\nsphinx-markdown-tables\r\n\r\n# theme default rtd\r\n\r\n# crate-docs-theme\r\nsphinx-rtd-theme\r\n\r\n# pin tokenizers version to avoid env_logger version req\r\ntokenizers==0.21\r\n"
  },
  {
    "path": "verl_rl/docs/sglang_multiturn/interaction_system.rst",
    "content": "Interaction System for Multi-turn RL Training\n=============================================\n\nLast updated: 06/25/2025.\n\nOverview\n--------\n\nThe verl interaction system enables dynamic, multi-turn conversational feedback during reinforcement learning training. This system allows models to engage in iterative problem-solving scenarios where interaction agents can provide corrective feedback, guidance, or evaluation based on the model's responses.\n\n**New in Multi-Interaction Support**: The system now supports multiple named interactions within a single training session, enabling sophisticated training scenarios where different samples can use different interaction strategies. This allows for curriculum learning, domain-specific feedback, and flexible agent switching at the sample level.\n\nKey features:\n\n- **Async-based Architecture**: Non-blocking interaction processing for distributed training\n- **Instance Management**: Stateful session handling with unique instance IDs for concurrent interactions\n- **SGLang Integration**: Seamless integration with SGLang rollout system for multi-turn conversations\n- **Configuration-driven**: Dynamic agent loading via YAML configuration files\n- **Multi-Interaction Support**: Registry system enabling multiple named interactions per rollout\n- **Sample-Level Selection**: Each sample can specify which interaction to use via configuration\n- **Reward Integration**: Turn-level scoring mechanism integrated with verl's reward system\n\nArchitecture\n------------\n\nThe interaction system follows a plugin-based architecture with clear separation of concerns:\n\n.. code-block::\n\n    Interaction Registry System\n         ↓\n    BaseInteraction (Abstract Interface)\n         ↓\n    Multiple Named Interactions (e.g., Gsm8kInteraction, CustomInteraction)\n         ↓\n    SGLang Rollout Integration (interaction_map)\n         ↓\n    Sample-Level Interaction Selection\n         ↓\n    Async Request Lifecycle Management\n\nCore Components\n~~~~~~~~~~~~~~~\n\n**Interaction Registry System**\n\nThe interaction registry system allows loading and managing multiple named interactions:\n\n.. code-block:: python\n\n    from verl.interactions.utils.interaction_registry import initialize_interactions_from_config\n    \n    # Load multiple interactions from config\n    interaction_map = initialize_interactions_from_config(\"config.yaml\")\n    \n    # Access specific interaction by name\n    gsm8k_interaction = interaction_map[\"gsm8k\"]\n    custom_interaction = interaction_map[\"custom_solver\"]\n\n**BaseInteraction Interface**\n\nAll interaction agents must implement the ``BaseInteraction`` abstract class:\n\n.. 
code-block:: python\n\n    from verl.interactions.base import BaseInteraction\n    from typing import Dict, Any, List, Tuple, Optional\n\n    class BaseInteraction:\n        def __init__(self, config: Dict[str, Any]):\n            self.config = config\n            self.name: str = config.get(\"name\", \"interaction_agent\")\n\n        async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:\n            \"\"\"Initialize interaction session, return instance_id\"\"\"\n\n        async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]:\n            \"\"\"Generate response, return (should_terminate, response, score, metadata)\"\"\"\n\n        async def calculate_score(self, instance_id: str, **kwargs) -> float:\n            \"\"\"Calculate turn-level score for RL training\"\"\"\n\n        async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n            \"\"\"Clean up resources\"\"\"\n\n**Request Lifecycle**\n\nThe interaction system integrates with SGLang's async rollout via state management:\n\n1. ``PENDING`` → Initialize interaction via ``start_interaction()``\n2. ``GENERATING`` → Model generates response\n3. ``INTERACTING`` → Process response via ``generate_response()``\n4. ``GENERATING`` → Continue if not terminated, otherwise ``COMPLETED``\n\nConfiguration\n-------------\n\n**Basic Setup**\n\nEnable interaction in your rollout configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            multi_turn:\n                enable: true\n                interaction_config_path: \"path/to/interaction_config.yaml\"\n                max_user_turns: 10\n                max_assistant_turns: 10\n\n**Interaction Configuration File**\n\nCreate an interaction configuration file (e.g., ``interaction_config.yaml``):\n\n**Single Interaction (Legacy Format)**\n\n.. code-block:: yaml\n\n    interaction:\n      - name: \"gsm8k\"\n        class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n\n**Multiple Interactions (New Format)**\n\n.. code-block:: yaml\n\n    interaction:\n      - name: \"gsm8k\"\n        class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n      - name: \"custom_solver\"\n        class_name: \"custom.interactions.CustomInteraction\"\n        config:\n          solver_type: \"advanced\"\n          timeout: 30\n      - name: \"code_verifier\"\n        class_name: \"verl.interactions.base.BaseInteraction\"\n        config:\n          verification_mode: \"strict\"\n\n**Automatic Name Generation**\n\nIf no ``name`` field is provided, the system will automatically generate one from the class name:\n\n.. code-block:: yaml\n\n    interaction:\n      - class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n        config: {}\n        # Automatically generates name: \"gsm8k\"\n\nThe system will dynamically load all specified interaction classes and make them available by name.\n\n
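At rollout time, each sample's interaction is resolved by name. The following sketch is illustrative only (``sample`` and ``messages`` are hypothetical variables, not verl's internal code):\n\n.. code-block:: python\n\n    # Illustrative: resolve and drive the per-sample interaction.\n    # Falls back to \"gsm8k\" when no name is given (backward compatibility).\n    kwargs = sample[\"interaction_kwargs\"]\n    interaction = interaction_map[kwargs.get(\"name\", \"gsm8k\")]\n\n    instance_id = await interaction.start_interaction(**kwargs)\n    should_terminate, response, score, metadata = await interaction.generate_response(\n        instance_id, messages\n    )\n    await interaction.finalize_interaction(instance_id)\n\n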
Implementation Example: GSM8K\n-----------------------------\n\nThe GSM8K interaction demonstrates a complete implementation for math problem-solving scenarios:\n\n.. code-block:: python\n\n    from verl.interactions.base import BaseInteraction\n    from verl.utils.reward_score import gsm8k\n    from uuid import uuid4\n\n    class Gsm8kInteraction(BaseInteraction):\n        def __init__(self, config: dict):\n            super().__init__(config)\n            self._instance_dict = {}\n\n        async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs):\n            if instance_id is None:\n                instance_id = str(uuid4())\n            self._instance_dict[instance_id] = {\n                \"response\": \"\",\n                \"ground_truth\": ground_truth,\n                \"reward\": 0.0,\n            }\n            return instance_id\n\n        async def generate_response(self, instance_id, messages, **kwargs):\n            # Extract the content of the last assistant message\n            content = \"\"\n            for item in reversed(messages):\n                if item.get(\"role\") == \"assistant\":\n                    content = item.get(\"content\", \"\")\n                    break\n\n            # Store the latest response; gsm8k.compute_score expects the \"#### <answer>\" format\n            self._instance_dict[instance_id][\"response\"] = content\n\n            reward = await self.calculate_score(instance_id)\n            if reward == 1.0:\n                return True, \"Your response is correct!\", 1.0, {}\n            else:\n                return False, \"Your response is incorrect! You need to reflect on your answer and try again.\", 0.0, {}\n\n        async def calculate_score(self, instance_id, **kwargs):\n            return gsm8k.compute_score(\n                self._instance_dict[instance_id][\"response\"],\n                self._instance_dict[instance_id][\"ground_truth\"],\n                method=\"strict\", format_score=0.0, score=1.0,\n            )\n\n        async def finalize_interaction(self, instance_id, **kwargs):\n            del self._instance_dict[instance_id]\n\nTraining Integration\n--------------------\n\n**Training Script Configuration**\n\nInclude the interaction configuration in your training command:\n\n.. code-block:: bash\n\n    python3 -m verl.trainer.main_ppo \\\n        --config-path=\"$CONFIG_PATH\" \\\n        --config-name='gsm8k_multiturn_grpo_w_interaction' \\\n        algorithm.adv_estimator=grpo \\\n        data.train_batch_size=512 \\\n        data.return_raw_chat=True \\\n        actor_rollout_ref.rollout.name=sglang \\\n        actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n        trainer.total_epochs=15\n\n**Data Requirements**\n\nEnsure your dataset includes interaction parameters with the ``name`` field for interaction selection:\n\n.. code-block:: python\n\n    # Dataset should include interaction_kwargs in non_tensor_batch\n    interaction_kwargs = [\n        {\"name\": \"gsm8k\", \"query\": \"What is 2+2?\", \"ground_truth\": \"4\"},\n        {\"name\": \"custom_solver\", \"query\": \"Solve: x^2 + 5x + 6 = 0\", \"ground_truth\": \"x = -2, -3\"},\n        {\"name\": \"gsm8k\", \"query\": \"What is 3+3?\", \"ground_truth\": \"6\"},\n    ]\n\n**Sample-Level Interaction Selection**\n\nEach sample can specify which interaction to use via the ``name`` field. This enables flexible training scenarios where different samples use different interaction strategies:\n\n.. 
code-block:: python\n\n    # Example: Math problems use GSM8K interaction, code problems use code verifier\n    data_samples = [\n        {\n            \"prompt\": \"What is 15% of 200?\",\n            \"interaction_kwargs\": {\n                \"name\": \"gsm8k\",\n                \"query\": \"What is 15% of 200?\", \n                \"ground_truth\": \"30\"\n            }\n        },\n        {\n            \"prompt\": \"Write a function to check if a number is prime\",\n            \"interaction_kwargs\": {\n                \"name\": \"code_verifier\",\n                \"code_type\": \"python\",\n                \"expected_behavior\": \"return True for prime numbers\"\n            }\n        }\n    ]\n\n**Backward Compatibility**\n\nIf no ``name`` field is provided in ``interaction_kwargs``, the system defaults to ``\"gsm8k\"`` for backward compatibility.\n\nBest Practices\n--------------\n\n**Resource Management**\n\n- Always implement proper cleanup in ``finalize_interaction()``\n- Use unique instance IDs to avoid conflicts in concurrent training\n- Handle edge cases like empty messages or malformed content\n\n**Performance Optimization**\n\n- Keep interaction logic lightweight to avoid blocking training\n- Use async/await properly to maintain non-blocking behavior\n- Consider caching expensive computations within interaction instances\n\n**Testing**\n\nComprehensive testing is essential for interaction systems:\n\n.. code-block:: python\n\n    import pytest\n    from unittest.mock import patch\n\n    @pytest.mark.asyncio\n    async def test_interaction_workflow():\n        interaction = YourInteraction({})\n        \n        # Test complete workflow\n        instance_id = await interaction.start_interaction(ground_truth=\"expected_answer\")\n        \n        messages = [{\"role\": \"user\", \"content\": \"user_content\"}, {\"role\": \"assistant\", \"content\": \"assistant_response\"}]\n        should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)\n        \n        assert should_terminate in [True, False]\n        assert isinstance(reward, float)\n        \n        await interaction.finalize_interaction(instance_id)\n\nAdvanced Usage\n--------------\n\n**Multi-Interaction Training Strategies**\n\nYou can design sophisticated training scenarios using multiple interactions:\n\n.. 
code-block:: python\n\n    # Example: Progressive difficulty with different interaction agents\n    class MathTrainingPipeline:\n        def create_interaction_config(self):\n            return {\n                \"interaction\": [\n                    {\n                        \"name\": \"basic_math\",\n                        \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                        \"config\": {\"difficulty\": \"easy\"}\n                    },\n                    {\n                        \"name\": \"advanced_math\", \n                        \"class_name\": \"custom.interactions.AdvancedMathInteraction\",\n                        \"config\": {\"difficulty\": \"hard\", \"allow_hints\": True}\n                    },\n                    {\n                        \"name\": \"competition_math\",\n                        \"class_name\": \"custom.interactions.CompetitionMathInteraction\", \n                        \"config\": {\"time_limit\": 300, \"show_steps\": False}\n                    }\n                ]\n            }\n    \n        def create_curriculum_data(self, epoch):\n            if epoch < 5:\n                return [{\"name\": \"basic_math\", ...} for _ in samples]\n            elif epoch < 10:\n                return [{\"name\": \"advanced_math\", ...} for _ in samples]\n            else:\n                return [{\"name\": \"competition_math\", ...} for _ in samples]\n\n**Custom Scoring Functions**\n\nYou can integrate custom reward functions:\n\n.. code-block:: python\n\n    async def calculate_score(self, instance_id, **kwargs):\n        response = self._instance_dict[instance_id][\"response\"]\n        ground_truth = self._instance_dict[instance_id][\"ground_truth\"]\n        \n        # Custom evaluation logic\n        if custom_evaluation_function(response, ground_truth):\n            return 1.0\n        else:\n            return 0.0\n\n**Multi-step Interactions**\n\nFor complex scenarios requiring multiple feedback rounds:\n\n.. code-block:: python\n\n    async def generate_response(self, instance_id, messages, **kwargs):\n        instance = self._instance_dict[instance_id]\n        instance[\"attempts\"] += 1\n        \n        # Evaluate current response\n        reward = await self.calculate_score(instance_id)\n        \n        if reward > 0.8:\n            return True, \"Excellent work!\", reward, {}\n        elif instance[\"attempts\"] < 3:\n            return False, \"Good attempt, but try to improve...\", reward, {}\n        else:\n            return True, \"Maximum attempts reached.\", reward, {}\n\nTroubleshooting\n---------------\n\n**Common Issues**\n\n1. **Instance ID Conflicts**: Ensure unique instance IDs across concurrent sessions\n2. **Memory Leaks**: Always call ``finalize_interaction()`` to clean up resources\n3. **Blocking Operations**: Keep interaction logic async and non-blocking\n4. **Configuration Errors**: Verify interaction config path and class name are correct\n5. **Interaction Name Conflicts**: Ensure all interactions have unique names in the configuration\n6. **Missing Interaction**: Verify the ``name`` field in ``interaction_kwargs`` matches available interactions\n7. **Backward Compatibility**: When migrating from single to multi-interaction, add ``name`` fields to existing data\n\n**Debugging**\n\nEnable debug logging to trace interaction flow:\n\n.. 
code-block:: bash\n\n    export VERL_LOGGING_LEVEL=DEBUG\n\n**Performance Monitoring**\n\nMonitor the interaction system's impact on training throughput and adjust accordingly.\n\nRelated Documentation\n---------------------\n\n- :doc:`multiturn`: Basic multi-turn rollout configuration\n- :doc:`sandbox_fusion`: Tool integration with SGLang\n- :doc:`search_tool_example`: Search tool implementation example"
  },
  {
    "path": "verl_rl/docs/sglang_multiturn/multiturn.rst",
    "content": "Multi-turn Rollout Support\n==========================\n\nLast updated: 06/27/2025.\n\nBasic Configuration\n~~~~~~~~~~~~~~~~~~~\n\nTo enable multi-turn rollout, make sure to configure the following fields in your rollout configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref: \n        rollout: \n            multi_turn: True\n            name: \"sglang\"\n\nThese configuration activates the sglang engine for multi-turn interaction during rollout.\n\nCustom Tool Configuration\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor custom environment interaction tools, you can implement your own tools based on ``verl.tools.base_tool.BaseTool``. Then, specify your tool configurations in a YAML file:\n\n.. code-block:: yaml\n\n    tools:\n      - class_name: \"\"\n        config: \n            type: native\n        tool_schema:\n\nYou may refer to GSM8KTool_example_configuration_, which is one example of the tool configurations. Its implementation can be found in gsm8k_tool.py_.\n\nFinally, set the ``tools_config_file`` in your rollout config:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            tool_kwargs:\n                tools_config_file: <path_to_tool_yaml_file>\n\nThis allows integration of customized tool behaviors during actor rollout steps.\n\nIf you want rollout with simulated interaction, you can set the ``interaction_config_file`` in your rollout config:\n\n.. code-block:: yaml\n\n    interaction:\n      - class_name: \"\"\n        config: {}\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            interaction_config_file: <path_to_interaction_yaml_file>\n\nIf your tool creates multi-modal inputs, you should return a list of multi-modal inputs in your tool.execute() implementation.\n\nImage and video should be processed before returning. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations:\n\n.. code-block:: python\n\n    async def execute(self, ...) -> Tuple[str | Dict[str, Any], float, dict]:\n        ...\n        from verl.utils.dataset.vision_utils import process_image, process_video\n\n        img1 = process_image(img1)\n        video1 = process_video(video1)\n\n        # due to the (image | video) key is (\"image\" | \"video\") instead of (\"images\" | \"videos\") in vllm, we need to use (\"image\" | \"video\") to specify list of images/videos\n        # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n        return {\"image\": [img1, ...], \"video\": [video1, ...], \"text\": \"...\"}, 0, {}\n\nremeber to set ``return_multi_modal_inputs: False`` in your dataset config in order to process the multi-modal inputs in the rollout correctly.\nRefer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details.\n\nMCP Tool Configuration\n~~~~~~~~~~~~~~~~~~~~~~\n\nFor MCP interaction tools, you can flexibly configure them using a YAML file. The typical setup is as follows:\n\n.. code-block:: yaml\n\n    tools:\n      - class_name: \"\"\n        config:\n            type: mcp\n        mcp:\n            mcp_servers_config_path: ./mcp_server.json\n            tool_selected_list: {}\n\nThe ``tool_selected_list`` field is optional and specifies which tools to use from the servers. If you want to enable all available tools, simply omit this attribute. Besides, ``mcp_servers_config_path`` points to a JSON file containing the MCP server configurations. For example:\n\n.. 
code-block:: json\n\n      {\n          \"mcpServers\": {\n              \"SSE Server\": {\n                  \"url\": \"your_server_url\",\n                  \"auth_token\": \"your_server_api_token\"\n              },\n              \"STDIO Server\": {\n                  \"command\": \"npx\",\n                  \"args\": [\"-y\", \"server-mcp@0.2.1\"],\n                  \"env\": {\n                    \"SERVER_API_KEY\": \"your_server_api_token\"\n                  }\n              }\n          }\n      }\n\nSince the content formats returned by the MCP server may vary, users can inherit from ``MCPBaseTool`` and override the ``_parse_tool_result`` method to implement custom parsing logic.\n\n.. code-block:: python\n\n   class MCPYourTool(MCPBaseTool):\n       def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n           super().__init__(config, tool_schema)\n\n       def _parse_tool_result(self, content: list) -> Tuple[str, dict]:\n           ...\n\nOverall, you may refer to mcp_search_tool.py_ and mcp_tool_config.yaml_ for custom implementation and configuration.\n\nMulti-turn Tokenization\n~~~~~~~~~~~~~~~~~~~~~~~\n\nTokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it is hard to identify which tokens belong to assistant messages. Since the token list is flat, it lacks direct alignment with the message roles.\n\nTo address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we:\n\n1. Apply the chat template to all prior messages (`messages[:i]`).\n2. Apply the chat template again including the latest message (`messages[:i+1]`).\n3. Tokenize only the *delta* between these two serialized message strings.\n\nThis ensures that only tokens generated by the assistant are included in the loss mask.\n\n.. code-block:: python\n\n   # When using a tokenizer\n   # Exclude the assistant prompt (e.g., \"<|im_start|>assistant\") from the loss by setting add_generation_prompt=True\n   prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)\n   curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)\n   new_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)\n   token_ids += new_token_ids\n   loss_mask += [1] * len(new_token_ids)  # Mask only the new assistant tokens\n\n.. code-block:: python\n\n   # When using a processor\n   # Exclude the assistant prompt (e.g., \"<|im_start|>assistant\") from the loss by setting add_generation_prompt=True\n   prev = processor.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)\n   prev_input_ids = processor(text=prev, images=images, videos=videos, return_tensors=\"pt\")[\"input_ids\"][0].tolist()\n   curr = processor.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)\n   curr_input_ids = processor(text=curr, images=images, videos=videos, return_tensors=\"pt\")[\"input_ids\"][0].tolist()\n   new_token_ids = curr_input_ids[len(prev_input_ids):]\n   token_ids += new_token_ids\n   loss_mask += [1] * len(new_token_ids)  # Mask only the new assistant tokens\n\n
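The delta rule relies on the serialized conversation growing strictly by appending text. A quick illustration with a hypothetical two-message conversation (the prefix assumption can break for templates that rewrite history; see Special Cases below):\n\n.. code-block:: python\n\n   msgs = [\n       {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n       {\"role\": \"assistant\", \"content\": \"2 + 2 = 4.\"},\n   ]\n   prev = tokenizer.apply_chat_template(msgs[:1], add_generation_prompt=True, tokenize=False)\n   curr = tokenizer.apply_chat_template(msgs[:2], add_generation_prompt=False, tokenize=False)\n   assert curr.startswith(prev)  # the template only appends text\n   delta = curr[len(prev):]      # the new assistant message plus end-of-turn markers\n\n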
While we have validated that this produces results consistent with full-message tokenization, future models' chat templates could break compatibility. To guard against silent inconsistencies, we compare the delta-based tokenization with the full-tokenization result by default at the end of each rollout.\n\nIf you see the following warning, you can check the mismatched substring in the log:\n\n.. code-block::\n\n    Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.\n\nThe tokenization sanity check mode can be configured using the ``actor_rollout_ref.rollout.multi_turn.tokenization_sanity_check_mode`` parameter, which accepts the following values:\n\n- ``strict`` (default): Performs strict comparison between delta-based and full tokenization results, raising warnings for any differences.\n\n- ``ignore_strippable``: Ignores differences in whitespace characters (``\\n``, ``\\t``, ``\\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable.\n\n- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training.\n\nExample configuration:\n\n.. code-block:: yaml\n\n    actor_rollout_ref:\n        rollout:\n            multi_turn:\n                tokenization_sanity_check_mode: \"ignore_strippable\"  # Choose from: \"disable\", \"ignore_strippable\", \"strict\"\n\nHandling Multi-Modal Inputs in Datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIf your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the ``return_multi_modal_inputs`` flag in your dataset config (used by RLHFDataset).\n\n- ``return_multi_modal_inputs: True`` (default): The dataset will pre-process and include a multi_modal_inputs dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors, etc.) as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch.\n\n- ``return_multi_modal_inputs: False``: The dataset will not include the multi_modal_inputs field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch.\n\n\nSpecial Cases\n^^^^^^^^^^^^^\n\nSome models (e.g., Qwen/QwQ-32B and the Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate.\n\nFor example, for the following conversation:\n\n.. code-block:: python\n\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n        {\"role\": \"assistant\", \"content\": \"<think>user asked about a simple math question.</think> 2 + 2 = 4.\"},\n        {\"role\": \"user\", \"content\": \"Explain why.\"},\n        {\"role\": \"assistant\", \"content\": \"<think>user wants to know the reasoning behind the answer. 
Search for a good explanation</think>\",\n         \"tool_calls\": [{\"id\": \"tool1\", \"type\": \"search\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}]},\n        {\"role\": \"tool\", \"content\": \"The sum of two and two is four because it is a basic arithmetic operation.\"},\n        {\"role\": \"assistant\", \"content\": \"<think>The tool provided a good explanation.</think>The sum of two and two is four because it is a basic arithmetic operation.\"}\n    ]\n\n1. Qwen/QwQ-32B will remove all reasoning content except the last assistant message after applying the chat template.\n\n.. code-block:: text\n\n    <|im_start|>system\n    You are a helpful assistant.<|im_end|>\n    <|im_start|>user\n    What is 2 + 2?<|im_end|>\n    <|im_start|>assistant\n     2 + 2 = 4.<|im_end|>\n    <|im_start|>user\n    Explain why.<|im_end|>\n    <|im_start|>assistant\n    <tool_call>\n    {\"name\": \"\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}\n    </tool_call><|im_end|>\n    <|im_start|>user\n    <tool_response>\n    The sum of two and two is four because it is a basic arithmetic operation.\n    </tool_response><|im_end|>\n    <|im_start|>assistant\n    <think>The tool provided a good explanation.</think> The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>\n\n2. The Qwen3 series will remove all reasoning content before the last user message.\n\n.. code-block:: text\n\n    <|im_start|>system\n    You are a helpful assistant.<|im_end|>\n    <|im_start|>user\n    What is 2 + 2?<|im_end|>\n    <|im_start|>assistant\n     2 + 2 = 4.<|im_end|>\n    <|im_start|>user\n    Explain why.<|im_end|>\n    <|im_start|>assistant\n    <think>\n    user wants to know the reasoning behind the answer. Search for a good explanation\n    </think>\n\n    <tool_call>\n    {\"name\": \"\", \"arguments\": {\"query\": \"Why is 2 + 2 = 4?\"}}\n    </tool_call><|im_end|>\n    <|im_start|>user\n    <tool_response>\n    The sum of two and two is four because it is a basic arithmetic operation.\n    </tool_response><|im_end|>\n    <|im_start|>assistant\n    <think>\n    The tool provided a good explanation.\n    </think>\n\n    The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>\n\nTo handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn't include assistant messages or reasoning content, it remains consistent across turns.\n\n.. code-block:: python\n\n    BASE_CHAT_HISTORY = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"I am a user.\"}\n    ]\n    prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False)\n    curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False)\n    new_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)\n    token_ids += new_token_ids\n    loss_mask += [1] * len(new_token_ids)  # Mask only the new assistant tokens\n\nThis method works well for the Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template. A fix_ has been proposed but not yet adopted. Until then, use the following command to download the fixed model revision:\n\n.. code-block:: bash\n\n    pip install huggingface_hub\n    huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81\n\n.. 
_fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81\n\nDiscrepancy Between Training and Inference Templates\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nAlthough the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not.\n\nThis mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout.\n\nHowever, this approach comes with trade-offs:\n\n1. Long reasoning content can easily exceed the model's context window, especially in multi-turn rollout.\n2. There is now a mismatch between the rollout and production environments: models will not see reasoning content from past turns if you use the default chat template in production.\n\nWe are still evaluating the impact of these issues. If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable:\n\n``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True``\n\nGSM8K Multi-turn Training Performance\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSee the training performance of multi-turn rollout on the GSM8K task HERE_.\n\n.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20\n\n.. _GSM8KTool_example_configuration: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\n\n.. _gsm8k_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/gsm8k_tool.py\n\n.. _mcp_search_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/mcp_search_tool.py\n\n.. _mcp_tool_config.yaml: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml\n\nInteraction System\n~~~~~~~~~~~~~~~~~~\n\nFor dynamic conversational feedback during RL training, see:\n\n.. toctree::\n   :maxdepth: 1\n\n   interaction_system\n\nSearch Tool Integration\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. toctree::\n   :maxdepth: 1\n\n   search_tool_example\n\nCode Walkthrough\n~~~~~~~~~~~~~~~~\n\nIf you want to learn about the code execution flow in more depth, please read https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/rlhf/verl/multi-turn/code-walk-through\n"
  },
  {
    "path": "verl_rl/docs/sglang_multiturn/sandbox_fusion.rst",
    "content": "===============================\nSandbox Fusion Tool Integration\n===============================\n\nLast updated: 06/10/2025.\n\nMotivations\n===========\n\n- As users of verl, we want to allow the model to call certain tools during Actor rollout, incorporating the results into the training process.\n- A colleague from ByteDance proposed a paper aimed at enhancing model capability through code execution tools.\n- We aim to support tool-calling capabilities of inference engines using `sandbox-fusion` as the code execution system, providing the community with a reimplementation of `retools`.\n\nReward Compute with Sandbox Fusion + FaaS Integration\n=====================================================\n\n- In current datasets and tasks, similar work already exists (e.g., Prime), which uses local processes as runners to execute model-generated code for reward computation.\n- On this basis, #1429 has advanced the design by integrating FaaS as the runner for reward computation.\n\nGoals\n=====\n\n- Adapt to the `sglang` tool-calling protocol and define tools for sandbox fusion.\n- Integrate with the `async-rollout` process, ensuring sandbox fusion tools follow asyncIO conventions.\n- Design and implement a basic rate limiter to prevent issues such as 429 errors.\n\nNon-Goals\n=========\n\n- Training effectiveness is out of scope.\n- Observability metrics are not considered.\n- Distributed failover and component fault tolerance are not addressed.\n\nDesign Details\n==============\n\nTool Schema Definition\n----------------------\n\n- Currently, only code execution is considered, requiring a `code` field in the JSON from the model.\n- Only Python code is supported for now, so no `language` parameter is defined.\n\n.. code-block:: python\n\n   OpenAIFunctionToolSchema(\n       type=\"function\",\n       function=OpenAIFunctionSchema(\n           name=\"code_interpreter\",\n           description=\"A tool for executing code.\",\n           parameters=OpenAIFunctionParametersSchema(\n               type=\"object\",\n               properties={\n                   \"code\": OpenAIFunctionPropertySchema(\n                       type=\"string\",\n                       description=\"The code to execute.\",\n                       enum=None,\n                   )\n               },\n               required=[\"code\"],\n           ),\n           strict=False,\n       )\n   )\n\nConfiguration Parameters\n--------------------------\n\n+----------------------------+--------------------------------------------------------------+\n| Parameter Name             | Description                                                  |\n+============================+==============================================================+\n| `num_workers`              | Number of worker threads/processes per DP to request runner. |\n+----------------------------+--------------------------------------------------------------+\n| `rate_limit`               | Global limit of concurrent code executions. Default: 10      |\n+----------------------------+--------------------------------------------------------------+\n| `default_timeout`          | Timeout (in seconds) for each code execution. Default: 30    |\n+----------------------------+--------------------------------------------------------------+\n| `default_language`         | Default programming language. 
Default: \"python\"              |\n+----------------------------+--------------------------------------------------------------+\n| `enable_global_rate_limit` | Whether to enable global rate limiting. Default: True        |\n+----------------------------+--------------------------------------------------------------+\n| `sandbox_fusion_url`       | URL for the veFaas sandbox execution service                 |\n+----------------------------+--------------------------------------------------------------+\n\nRate Limiting Design\n-----------------------\n\nObjective:\n\n- Limit the number of inflight requests using a token bucket model.\n\n- Ensure ordered submission to code runners to avoid starvation due to backoff.\n\nDesign Highlights:\n\n- Use Ray Global Actor as a singleton distributed counter at cluster level.\n  \n- Semaphore used for counting, with `acquire` and `release` in separate thread pools to preserve order.\n  \n- Use Ray’s cloud-pickle to serialize functions for decoupled `ExecutionWorker`.\n\n.. code-block:: python\n\n   @ray.remote(concurrency_groups={\"acquire\": 1,\"release\": 10})\n   class TokenBucketWorker:\n       def __init__(self, rate_limit: int):\n           self.rate_limit = rate_limit\n           self.current_count = 0\n           self._semaphore = threading.Semaphore(rate_limit)\n\n       @ray.method(concurrency_group=\"acquire\")\n       def acquire(self):\n           self._semaphore.acquire()\n           self.current_count += 1\n\n       @ray.method(concurrency_group=\"release\")\n       def release(self):\n           self._semaphore.release()\n           self.current_count -= 1\n\n       def get_current_count(self):\n           return self.current_count\n\n   class ExecutionWorker:\n       def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n           self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n       def _init_rate_limit(self, rate_limit):\n           return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n       def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n           with ExitStack() as stack:\n               stack.callback(self.rate_limit_worker.release.remote)\n               ray.get(self.rate_limit_worker.acquire.remote())\n               try:\n                   return fn(*fn_args, **fn_kwargs)\n               except Exception as e:\n                   logger.warning(f\"Error when executing code: {e}\")\n\n   def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode=PoolMode.ThreadMode):\n       if mode == PoolMode.ThreadMode:\n           return ray.remote(ExecutionWorker).options(max_concurrency=num_workers).remote(\n               enable_global_rate_limit=enable_global_rate_limit,\n               rate_limit=rate_limit\n           )\n       else:\n           raise NotImplementedError(\"Process mode is not implemented yet\")\n\nTool Implementation\n-------------------\n\n- Use `instance_id` to identify requests across multiple dialogue rounds.\n  \n- Use `execution_pool` to implement async invocation.\n  \n- Cleanup state after rollout completion.\n\n.. 
code-block:: python\n\n   class SandboxFusionTool(BaseTool):\n       def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n           ...\n           self.execution_pool = init_execution_pool(...)\n           ...\n\n       async def create(self, instance_id: Optional[str] = None, ...):\n           ...\n\n       async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n           code = parameters.get(\"code\", \"\")\n           timeout = parameters.get(\"timeout\", self.default_timeout)\n           language = parameters.get(\"language\", self.default_language)\n           if not isinstance(code, str):\n               code = str(code)\n\n           result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n           self._instance_dict[instance_id][\"reward\"].append(result.strip())\n\n           # (tool_response, tool_reward, metrics)\n           return result, result, {}\n\n       def execute_code(self, instance_id, code, timeout=30, language=\"python\"):\n           result_status, metadata = _process_single_case(0, None, None, self.sandbox_fusion_url, code, timeout, language)\n           # we always expect a finished run here, since there is no ground-truth answer to check\n           if metadata[\"run_status\"] == \"Finished\":\n               actual_output = metadata[\"stdout\"] if metadata[\"stdout\"] is not None else \"\"\n               return actual_output\n           else:\n               return \"no stdout here\"\n\n       async def calc_reward(self, instance_id: str, ...):\n           ...\n\n       async def release(self, instance_id: str, ...):\n           ...\n\nTest Plan\n=========\n\nUnit Tests\n----------\n\n- **test_tools_registration**: Test tool registration and initialization.\n- **test_rollout_req_creation**: Validate that `AsyncRolloutReq` is built correctly.\n- **test_over_size_case**: Ensure rollout terminates early when exceeding `max_seq_len`.\n- **test_tool_call_basic_case**: Mock `sglang` output, validate tool call and result.\n- **test_tool_call_batch_case**: Test batch processing of tool calls.\n- **test_basic_multi_process_init**: Validate Ray global actor behaves as singleton.\n- **TestSingleNodeRateLimiterCase**: Verify rate limiter works in single-node mode.\n- **test_rotten_execution**: Ensure rate limiter recovers from function errors.\n- **TestMultiNodeRateLimiterCase**: Verify behavior in multi-node environments.\n\ne2e Tests\n----------\n\nWe provide an e2e test script in the `tests/special_e2e` folder, named `tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh`.\n\nBy setting ``trainer.rollout_data_dir`` you can dump the rollout data to local disk. Here is a sample taken from the rollout data:\n\n.. code-block:: python\n\n   {\n     \"input\": \"\n     \n     system\\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. 
Put your final answer in the format of `#### <answer>`.\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\\n{\\\"type\\\": \\\"function\\\", \\\"function\\\": {\\\"name\\\": \\\"code_interpreter\\\", \\\"description\\\": \\\"A tool for executing code.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"object\\\", \\\"properties\\\": {\\\"code\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The code to execute.\\\", \\\"enum\\\": null}}, \\\"required\\\": [\\\"code\\\"]}, \\\"strict\\\": false}}\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call>\\n\n     \n     user\\nJohn gets a bonus that's the same percentage every year.  Last year he made $100,000 and got a $10,000 bonus.  This year he makes $200,000.  How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\\n\n     \n     assistant\\n\",\n     \"output\": \"assistant\\n<think>\\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\\n\\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\\n\\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\\n\\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\\n\\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\\n\n     \n     <tool_call>\\n{\\\"name\\\": \\\"code_interpreter\\\", \\\"arguments\\\": {\\\"code\\\": \\\"total_pay_this_year = 200000\\\\nbonus_percentage = 10 / 100\\\\nbonus_this_year = total_pay_this_year * bonus_percentage\\\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\\\nprint(total_income_this_year)\\\", \\\"executes\\\": \\\"True\\\"}}\\n</tool_call>\\n\n     \n     tool\\n220000.0\\n\\n\n     \n     assistant\\n<think>\\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\\n\\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\\n</think>\\n\\n#### 220000.0\",\n     \"score\": 0,\n     \"step\": 1\n   }\n\nhere is the readable format version:\n\n.. 
code-block:: python\n\n   [system]\n   \n   You are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### <answer>`.\n   \n   # Tools\n   \n   You may call one or more functions to assist with the user query.\n   \n   You are provided with function signatures within <tools></tools> XML tags:\n   <tools>\n   {\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n   </tools>\n   \n   For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n   <tool_call>\n   {\"name\": <function-name>, \"arguments\": <args-json-object>}\n   </tool_call>\n   \n   [user]\n   \n   John gets a bonus that's the same percentage every year.  Last year he made $100,000 and got a $10,000 bonus.  This year he makes $200,000.  How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n   \n   [assistant]\n   \n   <think>\n   Okay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n   \n   First, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n   \n   Last year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n   \n   10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n   \n   This year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n   <tool_call>\n   {\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n   </tool_call>\n   \n   [tool]\n   \n   220000.0\n   \n   [assistant]\n   \n   <think>\n   So the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n   \n   I need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. 
Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n   </think>\n   \n   #### 220000.0\n\n\nYou can also use the `RolloutViewer` TUI tool to view the dumped rollout data:\n\n\n.. code-block:: bash\n\n    python scripts/rollout_viewer.py ${trainer.rollout_data_dir}\n\n\n.. image:: https://github.com/user-attachments/assets/e34e5157-2880-4a21-afb2-73885d0dfb11\n   :alt: RolloutViewer screenshot"
  },
  {
    "path": "verl_rl/docs/sglang_multiturn/search_tool_example.rst",
    "content": "=======================\r\nSearch Tool Integration\r\n=======================\r\n\r\nLast updated: 05/30/2025.\r\n\r\nIntroduction\r\n------------\r\n- We have added a search tool calling function to Multi-Turn RL, enabling the model to initiate retrieval requests during Actor rollout and directly use retrieval results for training. **We support using a local dense retriever as the retrieval tool, as well as integrating with your own local retrieval engine.**\r\n\r\n\r\n\r\nQuick Reproduction\r\n------------------\r\n\r\nCreate a New Docker Container\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   docker run \\\r\n       -it \\\r\n       --shm-size 32g \\\r\n       --gpus all \\\r\n       -v {Huggingface-Cache-Path}:/root/.cache \\\r\n       --ipc=host \\\r\n       --network=host \\\r\n       --privileged \\\r\n       --name sglang_{your-name} \\\r\n       lmsysorg/sglang:dev \\\r\n       /bin/zsh\r\n\r\nIf you need to restart after exiting the container:\r\n\r\n.. code:: bash\r\n\r\n   docker start -i sglang_{your-name}\r\n\r\nUpdate Python and Configure the Virtual Environment using uv\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   apt update\r\n   apt install -y python3.10 python3.10-venv\r\n\r\n   # Create a virtual environment\r\n   python3 -m venv ~/.python/verl-multiturn-rollout\r\n\r\n   # Activate the virtual environment\r\n   source ~/.python/verl-multiturn-rollout/bin/activate\r\n\r\n   # Install uv\r\n   python3 -m pip install uv\r\n\r\nInstall verl Upstream\r\n~~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   cd ~\r\n   git clone https://github.com/volcengine/verl.git\r\n   cd verl\r\n\r\n   # Install verl\r\n   python3 -m uv pip install .\r\n   python3 -m uv pip install -r ./requirements_sglang.txt\r\n\r\n   # Manually install flash-attn\r\n   python3 -m uv pip install wheel\r\n   python3 -m uv pip install packaging\r\n   python3 -m uv pip install flash-attn --no-build-isolation --no-deps\r\n\r\nSet Up a Local Retrieval Engine\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nIf you are using your own local retrieval service, you can skip this\r\nstep. We chose the local dense retriever provided in the search-R1\r\nexample; detailed instructions are in the `searchR1\r\ndocs <https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/refs/heads/main/docs/retriever.md>`__.\r\nIn brief:\r\n\r\n-  The GPU version offers higher accuracy and speed; each GPU uses about\r\n   5–7 GB of memory.\r\n-  The CPU version can be used for simple testing but has lower\r\n   retrieval precision, which will degrade training performance. See the\r\n   `retriever\r\n   documentation <https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md>`__\r\n   in search-R1 for details.\r\n-  Recommend using Conda to install faiss-gpu=1.8.0; venv may cause errors.\r\n\r\n**Note**: To start both the training process and the local retrieval\r\nservice, we launch two separate Python environments. The training uses\r\nuv in the verl-multiturn-rollout environment, while the retriever uses\r\nconda to install ``faiss-gpu``.\r\n\r\n.. 
code:: bash\r\n\r\n   # Download the Miniconda installer script\r\n   wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh\r\n\r\n   # Install to $HOME/miniconda3 in batch mode\r\n   bash ~/miniconda.sh -b -p $HOME/miniconda3\r\n\r\n   # Activate conda (only in the current shell)\r\n   eval \"$($HOME/miniconda3/bin/conda shell.bash hook)\"\r\n\r\n   # (Optional) Add conda to your default shell startup\r\n   conda init\r\n\r\n   # Reload shell config\r\n   source ~/.bashrc\r\n\r\n   # Create and activate the retriever environment with Python 3.10\r\n   conda create -n retriever python=3.10 -y\r\n   conda activate retriever\r\n\r\n   # Install PyTorch (with GPU support) and related libraries\r\n   conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y\r\n\r\n   # Install other Python packages\r\n   pip install transformers datasets pyserini huggingface_hub\r\n\r\n   # Install the GPU version of faiss\r\n   conda install faiss-gpu=1.8.0 -c pytorch -c nvidia -y\r\n\r\n   # Install the API service framework\r\n   pip install uvicorn fastapi\r\n\r\nDownload the Indexing and Corpus\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\nThe local retrieval files are large—prepare sufficient disk space.\r\nDownloading is about 60–70 GB, and uncompressed takes about 132 GB:\r\n\r\n.. code:: bash\r\n\r\n   conda activate retriever\r\n\r\n   save_path=/the/path/to/save\r\n   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py --save_path $save_path\r\n   cat $save_path/part_* > $save_path/e5_Flat.index\r\n   gzip -d $save_path/wiki-18.jsonl.gz\r\n\r\nStart the Local flat e5 Retrieval Server\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n1. The first startup will download models and load the index.\r\n2. Apart from the download, startup takes about 1–2 minutes.\r\n3. After startup, each GPU uses about 5–7 GB of memory, leaving the rest\r\n   for multi-turn RL training.\r\n\r\n.. code:: bash\r\n\r\n   conda activate retriever\r\n\r\n   index_file=$save_path/e5_Flat.index\r\n   corpus_file=$save_path/wiki-18.jsonl\r\n   retriever_name=e5\r\n   retriever_path=intfloat/e5-base-v2\r\n\r\n   python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \\\r\n     --index_path $index_file \\\r\n     --corpus_path $corpus_file \\\r\n     --topk 3 \\\r\n     --retriever_name $retriever_name \\\r\n     --retriever_model $retriever_path \\\r\n     --faiss_gpu\r\n\r\nSet Up WANDB_API_KEY\r\n~~~~~~~~~~~~~~~~~~~~\r\n\r\n.. code:: bash\r\n\r\n   export WANDB_API_KEY={YOUR_WANDB_API_KEY}\r\n\r\n   # Define a timestamp function\r\n   function now() {\r\n       date '+%Y-%m-%d-%H-%M'\r\n   }\r\n\r\n**Preprocess the Dataset**\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n   **Note:** The following data processing and training commands must be\r\n   run in the verl-multiturn-rollout environment.\r\n\r\n.. code:: bash\r\n\r\n   python3 examples/data_preprocess/preprocess_search_r1_dataset.py\r\n\r\nTesting on 8 x H20\r\n~~~~~~~~~~~~~~~~~~\r\n\r\n.. 
code:: bash\r\n\r\n   # Ensure the now() function is defined\r\n   # Create a logs directory\r\n   mkdir -p logs\r\n\r\n   # Set GPUs and run with a suitable log path\r\n   export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\r\n\r\n   nohup bash examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh \\\r\n     trainer.experiment_name=qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn-$(now) \\\r\n     > logs/searchR1-like$(now).log 2>&1 &\r\n\r\nCustom Search Configuration\r\n---------------------------\r\n\r\nTo enable multi-turn reasoning, set the following fields in your config:\r\n\r\n.. code:: yaml\r\n\r\n   actor_rollout_ref:\r\n     rollout:\r\n       name: \"sglang\"\r\n       multi_turn:\r\n         enable: True\r\n\r\nYou must specify ``retrieval_service_url`` in ``examples/sglang_multiturn/config/tool_config/search_tool_config.yaml``, and properly configure concurrency. For more details on concurrency, refer to the Sandbox Fusion example:\r\n\r\n.. code:: yaml\r\n\r\n   tools:\r\n     - class_name: verl.tools.search_tool.SearchTool\r\n       config:\r\n         retrieval_service_url: http://127.0.0.1:8000/retrieve\r\n         num_workers: 120\r\n         rate_limit: 120\r\n         timeout: 30\r\n\r\nThe retriever input/output formats are as follows. If your service\r\nparameters match, only modify ``retrieval_service_url``. You can also\r\ncustomize in ``search_r1_like_utils.py``.\r\n\r\n.. code:: python\r\n\r\n   Input format:\r\n   {\r\n     \"queries\": [\"What is Python?\", \"Tell me about neural networks.\"],\r\n     \"topk\": 3,\r\n     \"return_scores\": true\r\n   }\r\n\r\n   Output format (when return_scores=True, similarity scores are returned):\r\n   {\r\n       \"result\": [\r\n           [   # Results for each query\r\n               {\r\n                   \"document\": doc, \"score\": score\r\n               },\r\n               # ... more documents\r\n           ],\r\n           # ... results for other queries\r\n       ]\r\n   }\r\n\r\nNotes\r\n-----\r\n\r\n1. The total training time is about 27 hours; meanwhile, the validation\r\n   dataset is very large (51 k), and each validation takes about 6000 s.\r\n   (Therefore, ``val_before_train=False`` by default)\r\n"
  },
  {
    "path": "verl_rl/docs/single_controller.rst",
    "content": "The Design of ``verl.single_controller``\n==============================================\n\nLast updated: 05/21/2025.\n\n**Author:**\\  `Wang Zhang <https://github.com/zw0610>`__\n\nPreface\n-------\n\nWe prepared this document for developers of ``verl``, particularly those\ninterested in understanding or contributing to the\n``verl.single_controller`` module. It is not intended for end users, but\nfor contributors seeking to understand the architectural rationale and\ninternal mechanics.\n\n--------------\n\nOrigin\n------\n\nThe ``single_controller`` module originated from a request I received —\nto adapt a toy single-process RLHF script into a distributed system with\nminimal changes, while maintaining ease of debugging.\n\nCommon practice — such as using PyTorch’s Distributed Data Parallel\n(DDP) — typically involves wrapping ``nn.Module`` and launching multiple\nprocesses that execute the same function under different ranks. However,\nthis approach presents two main limitations in the context of\ndistributed RLHF: - Difficulty representing multiple DAGs as required by\nPPO; - Difficulty inspecting intermediate tensors during training.\n\nTo maintain debuggability, we opted for a different approach — breaking\nthe training loop into well-defined stages like ``generate_sequences``,\n``compute_advantages``, and so on.\n\nWe selected `Ray <https://www.ray.io/>`__ as the initial backend for\n``verl`` due to its ability to expose Python class methods as RPC\nendpoints. However, Ray’s default model only supports **one method call,\none RPC**, while training LLMs typically requires coordination across\nmultiple processes.\n\nTo hide this multi-Ray actors invocation for a single method from users,\nwe introduced the following components:\n\n-  ``WorkerGroup`` – manages a group of remote workers and provides\n   a unified interface for multi-process distributed computation;\n-  ``ResourcePool`` – binds computational resources to worker\n   processes;\n-  ``ClassWithArgs`` – enables delayed remote instantiation with\n   specified initialization arguments.\n\n--------------\n\nA Running Example: ``generate_sequences``\n-----------------------------------------\n\nTo illustrate the design, we walk through how the ``generate_sequences``\nmethod in the ``ActorRolloutRefWorker`` class is registered and invoked\nacross distributed workers.\n\n--------------\n\nStep 1: Register with a Decorator\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe first step is to define the ``generate_sequences`` and decorate it\nwith ``@register`` as it will be called in driver script.\n\n**Source:**\n`fsdp_workers.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/workers/fsdp_workers.py#L528>`__\n\n.. code:: python\n\n   class ActorRolloutRefWorker(Worker):\n       ...\n       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n       def generate_sequences(self, prompts: DataProto):\n           prompts = prompts.to(torch.cuda.current_device())\n           ...\n\nThe ``@register`` decorator adds metadata to the ``generate_sequences``\nmethod. Currently, it doesn’t alter functionality, but attaches\nattributes via a magic key (``MAGIC_ATTR``):\n\n**Source:**\n`decorator.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L411>`__\n\n.. 
--------------\n\nStep 2: Binding During Initialization\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThese attached attributes are extracted and utilized when\n``ActorRolloutRefWorker``, wrapped in a ``RayClassWithArgs``, is passed\ninto a ``RayWorkerGroup``.\n\n**Source:**\n`main_generation.py <https://github.com/volcengine/verl/blob/4ae9a0fdab229f75f080e9478807783ed4c97154/verl/trainer/main_generation.py#L82>`__\n\n.. code:: python\n\n   ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role=\"rollout\")\n   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)\n   wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n\nDuring the\n`initialization <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L184>`__\nof ``RayWorkerGroup``, two key steps occur:\n\n1. Worker instances (Ray actors) are created:\n   `RayWorkerGroup._init_with_resource_pool <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L211>`__\n2. Methods decorated with ``@register`` are bound to ``RayWorkerGroup``:\n   `RayWorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L214>`__\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true\n   :alt: initialization_and_binding_of_worker_group\n\n   initialization_and_binding_of_worker_group\n\nThe binding procedure is the heart of ``verl.single_controller``.\n\n**Key function:**\n`WorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/worker_group.py#L143>`__\n\n.. code:: python\n\n   def _bind_worker_method(self, user_defined_cls, func_generator):\n       ...\n       for method_name in dir(user_defined_cls):\n           try:\n               method = getattr(user_defined_cls, method_name)\n               assert callable(method)\n           except Exception:\n               continue  # Skip properties\n           <<<to be continue 1>>>\n\nWhen a method has the ``MAGIC_ATTR``, the attributes set by\n``@register`` are extracted:\n\n.. 
code:: python\n\n           <<<continue 1>>>\n           if hasattr(method, MAGIC_ATTR):\n               attribute = getattr(method, MAGIC_ATTR)\n               dispatch_mode = attribute[\"dispatch_mode\"]\n               execute_mode = attribute[\"execute_mode\"]\n               blocking = attribute[\"blocking\"]\n\n               <<<to be continue 2>>>\n\nAs shown in the flow chart above, these attributes are fed into\n``func_generator``. However, ``func_generator`` takes ``method_name``,\n``dispatch_fn``, ``collect_fn``, ``execute_fn``, ``blocking``. We need\nto find the corresponding ``dispatch_fn`` and ``collect_fn`` associated\nwith the ``dispatch_mode`` (``DP_COMPUTE_PROTO``) from\n`DISPATCH_MODE_FN_REGISTRY <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L387>`__:\n\n.. code:: python\n\n   DISPATCH_MODE_FN_REGISTRY = {\n       Dispatch.ONE_TO_ALL: {\n           \"dispatch_fn\": dispatch_one_to_all,\n           \"collect_fn\": collect_all_to_all,\n       },\n       ...\n       Dispatch.DP_COMPUTE_PROTO: {\n           \"dispatch_fn\": dispatch_dp_compute_data_proto,\n           \"collect_fn\": collect_dp_compute_data_proto,\n       },\n       ...\n   }\n\nSimilarly, the ``execute_fn`` is selected by ``execute_mode`` and\nextracted by:\n\n.. code:: python\n\n               <<<continue 2>>>\n               # get execute_fn_name\n               execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)\n               wg_execute_fn_name = execute_mode[\"execute_fn_name\"]\n\n               # get execute_fn from string\n               try:\n                   execute_fn = getattr(self, wg_execute_fn_name)\n                   assert callable(execute_fn), \"execute_fn must be callable\"\n               except Exception:\n                   print(f\"execute_fn {wg_execute_fn_name} is invalid\")\n                   raise\n               <<<to be continue 3>>>\n\nIn this ``generate_sequences`` case:\n\n- ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO``\n- ``dispatch_fn = dispatch_dp_compute_data_proto``\n- ``collect_fn = collect_dp_compute_data_proto``\n- ``execute_fn = RayWorkerGroup.execute_all``\n\nONE_TO_ALL vs. DP_COMPUTE_PROTO\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n``dispatch_mode`` is associated with a ``dispatch_fn`` and a\n``collect_fn``. As the name implies, ``dispatch_fn`` processes the input\narguments in ``WorkerGroup`` and generates a batch (list) of input\narguments, each of which will be fed into a worker attached to the\n``WorkerGroup``.\n\n``dispatch_fn`` of ``ONE_TO_ALL`` is\n`dispatch_one_to_all <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L119>`__,\nwhich just duplicates all the input arguments into N replicas, where N\nequals the number of Workers attached to the ``worker_group``:\n\n.. code:: python\n\n   def dispatch_one_to_all(worker_group, *args, **kwargs):\n       args = tuple([arg] * worker_group.world_size for arg in args)\n       kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}\n       return args, kwargs\n\n
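For intuition, here is a toy sketch of what this duplication produces (``_FakeWorkerGroup`` is a stand-in defined only for this example):\n\n.. code:: python\n\n   class _FakeWorkerGroup:\n       world_size = 2\n\n   args, kwargs = dispatch_one_to_all(_FakeWorkerGroup(), \"prompt\", temperature=1.0)\n   print(args)    # (['prompt', 'prompt'],)\n   print(kwargs)  # {'temperature': [1.0, 1.0]}\n\n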
``dispatch_fn`` of ``DP_COMPUTE_PROTO`` is\n`dispatch_dp_compute_data_proto <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L350>`__,\nwhich uses ``DataProto.chunk`` to split a large ``DataProto`` into N\nsmaller ``DataProto``, where N equals the world_size (number of the\nworkers) of the ``worker_group``:\n\n.. code:: python\n\n   def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):\n       from verl.single_controller.base.worker_group import WorkerGroup\n\n       assert isinstance(worker_group, WorkerGroup)\n       # Note: enable auto padding for dp compute DataProto\n       splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(\n           worker_group.world_size,\n           *args,\n           **kwargs,\n       )\n       return splitted_args, splitted_kwargs\n\nThe ``collect_fn`` follows the same pattern: it processes a batch (list)\nof values returned by all workers of a ``WorkerGroup`` and merges them\ninto a list, as ``collect_all_to_all`` does, or into a large ``DataProto``,\nas ``collect_dp_compute_data_proto`` does.\n\nFinally, a new method is dynamically generated using ``func_generator``\nand added to the ``WorkerGroup`` instance:\n\n.. code:: python\n\n               <<<continue 3>>>\n               # bind a new method to the RayWorkerGroup\n               func = func_generator(\n                   self,\n                   method_name,\n                   dispatch_fn=dispatch_fn,\n                   collect_fn=collect_fn,\n                   execute_fn=execute_fn,\n                   blocking=blocking,\n               )\n\n               try:\n                   setattr(self, method_name, func)\n                   method_names.append(method_name)\n               except Exception as e:\n                   raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\nThis makes the method invocable via the ``WorkerGroup`` interface.\n\n--------------\n\nStep 3: Call Chain\n~~~~~~~~~~~~~~~~~~\n\nAll the machinery above ensures that distributed calls feel identical to\nsingle-process ones. In the original single-process script, the code\nlooks like:\n\n.. code:: python\n\n   rollout = Rollout()\n   rollout.generate_sequences(batch)\n\nWith ``verl``, the multiprocess program becomes:\n\n.. code:: python\n\n   rollout = RayWorkerGroup(resource_pool=RayResourcePool([4]), ray_cls_with_init=RayClassWithArgs(Rollout))\n   rollout.generate_sequences(batch)\n\n.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true\n   :alt: call_chain_of_generate_sequences\n\n   call_chain_of_generate_sequences\n\nBehind this simple call:\n\n- ``dispatch_fn`` splits input across workers\n- ``execute_fn`` performs the actual remote invocation\n- ``collect_fn`` gathers the results\n\nAll of this is abstracted away, enabling developers to write distributed\ncode with minimal changes to their existing logic.\n\n
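Conceptually, the generated method just composes these three functions. Below is a simplified sketch of what ``func_generator`` produces (the actual implementation handles more details):\n\n.. code:: python\n\n   def func_generator(worker_group, method_name, dispatch_fn, collect_fn, execute_fn, blocking):\n       def func(*args, **kwargs):\n           args, kwargs = dispatch_fn(worker_group, *args, **kwargs)  # split inputs per worker\n           output = execute_fn(method_name, *args, **kwargs)          # one RPC per worker\n           if blocking:\n               output = ray.get(output)                               # materialize futures\n           return collect_fn(worker_group, output)                    # merge the results\n       return func\n\n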
--------------\n\nBeyond RL Post-Training: Generalizing ``verl.single_controller``\n----------------------------------------------------------------\n\nThe ``verl.single_controller`` module generalizes well beyond\nreinforcement learning. It provides a clean abstraction to batch-process\nremote method calls, with automatic input/output handling.\n\nBy minimizing the gap between single-process and multi-process scripts,\n``verl.single_controller`` opens the door to distributed computing in\nbroader domains — not limited to RL post-training.\n\nWe hope this design inspires more examples and extensions from the\ncommunity.\n"
  },
  {
    "path": "verl_rl/docs/start/agentic_rl.rst",
    "content": "Agentic RL Training\n===================\n\nLast updated: 07/15/2025.\n\nOverview\n----------\nThe goal of Agentic RL is to improve the performance of backend models from reinforcement learning to the Agent. During the training process, a series of features are developed:\n\n1. Server-based asynchronous rollout\n2. Multi-turn conversations and tool calls\n3. LangGraph-based Agent\n\n\nThis document explains the system principles and usage involved to help users implement Agentic RL.\n\n\nServer-based Asynchronous Rollout\n---------------------------------\n\nSince Agents need to interact with the environment through various tool calls, in order to avoid GPU idling while waiting for tool call return results, an asyncio based co-routing mechanism is utilized to execute each rollout requests asynchronously, thereby improving training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives:\n\n1. Enabling load balancing mechanisms to balance loads across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, scheduling capabilities in stream mode (recipe\\stream_mode) are implemented as a recipe.\n2. Preventing agent specific features such as tracing from affecting the inference engine.\n\nSystem Architecture\n~~~~~~~~~~~~~~~~~~~\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true\n\nFor more detail on internal design, please refer to :doc:`Agent Loop<../advance/agent_loop>`.\n\nSystem Components\n~~~~~~~~~~~~~~~~~\n\n+--------------------------+----------------------------------------------------------------------------+\n| Component                | Role                                                                       |\n+==========================+============================================================================+\n| AgentLoop                | Client, implements Agent functions                                         |\n+--------------------------+----------------------------------------------------------------------------+\n| AsyncLLMServerManager    | Inference gateway, provides generate interface for AgentLoop               |\n+--------------------------+----------------------------------------------------------------------------+\n| AsyncServer              | Server, each instance is connected to one DP group of the inference engine |\n+--------------------------+----------------------------------------------------------------------------+\n\n**\"generate\" Interface**\n\nThe \"generate\" function based on ray actor is used between the Client and Server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible. For example, the token converted from \"<think>\" will be different from that generated by the LLM. During the training phase, it is necessary to strictly use the tokens generated by LLM inference to avoid inaccurate in computing advantage, which may affect model performance. 
**Inference Engine Adaptation**\n\nAsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM to hide underlying differences:\n\n1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to remotely call async_generate through a ray actor.\n2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be directly called in AsyncServer.\n\n\nUsage Example\n~~~~~~~~~~~~~\n\nFollow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints.\n\nThere are two options required to use the agent loop:\n\n- `data.return_raw_chat=True`\n- `actor_rollout_ref.rollout.mode=async`\n\nThis example uses the sglang inference engine by default; you can also modify rollout_name to use vllm.\n\n.. code-block:: bash\n\n    bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh\n\n\nMulti-turn Conversations and Tool Calls\n---------------------------------------\n\nFollow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare tool and configuration files.\n\nThe Tool Agent Loop has an additional requirement: adding an \"agent_name\" field to the dataset (see the preprocessing sketch at the end of this section). During rollout, it will choose tool_agent_loop or single_turn_agent (the default) based on this field.\n\nUsage Example\n~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    # install mlflow to view toolcall and llm trace\n    pip install mlflow\n\n    # This will download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the \"agent_name\" field.\n    python3 examples/data_preprocess/gsm8k_tool_agent_loop.py\n\n    # Start training with tool calls; mlflow-based tracing is enabled to help debug the rollout details\n    bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh\n\n    # When training is done, start an mlflow server to view the trace\n    mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db\n\n    # then you can open http://<your ip address>:5000 from a browser to view the trace\n\n\nNote: During training, because the model may sometimes fail to generate correct toolcall tags, an error message \"Failed to decode tool call\" will be output to the console; this does not indicate an abnormality in training.\n\n\nFollow :doc:`Rollout trace<../advance/rollout_trace>` to learn more about the trace feature.\n\n
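If you bring your own dataset, the extra field can be added during preprocessing. A hedged sketch (the value ``tool_agent`` and the paths are illustrative; check the preprocessing script above for the exact convention):\n\n.. code-block:: python\n\n    import os\n\n    import datasets\n\n    ds = datasets.load_dataset(\"openai/gsm8k\", \"main\", split=\"train\")\n    # \"agent_name\" tells the rollout which agent loop to use for each sample.\n    ds = ds.map(lambda example: {**example, \"agent_name\": \"tool_agent\"})\n    ds.to_parquet(os.path.expanduser(\"~/data/gsm8k/train.parquet\"))\n\n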
Agent Framework\n---------------\n\nSystem Architecture\n~~~~~~~~~~~~~~~~~~~\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true\n\nSystem Components\n~~~~~~~~~~~~~~~~~\n\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| Component                | Role                                                                                          |\n+==========================+===============================================================================================+\n| ChatModel                | LLM object of LangChain, used to adapt to the “generate” API provided by AsyncLLMServerManager|\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| ReactAgentLoop           | Agent adaptation layer, which by default supports a naive LangGraph Agent.                    |\n|                          | New classes can be derived to support user-defined Agents, and the run function needs to be   |\n|                          | implemented to complete Agent calls.                                                          |\n+--------------------------+-----------------------------------------------------------------------------------------------+\n| AsyncServer              | Server, each instance is connected to one DP group of the inference engine.                   |\n+--------------------------+-----------------------------------------------------------------------------------------------+\n\n\nFollow the doc \"recipe/langgraph_agent/example/README.md\" for more details."
  },
  {
    "path": "verl_rl/docs/start/install.rst",
    "content": "Installation\n============\n\nRequirements\n------------\n\n- **Python**: Version >= 3.9\n- **CUDA**: Version >= 12.1\n\nverl supports various backends. Currently, the following configurations are available:\n\n- **FSDP** and **Megatron-LM** (optional) for training.\n- **SGLang**, **vLLM** and **TGI** for rollout generation.\n\nChoices of Backend Engines\n----------------------------\n\n1. Training:\n\nWe recommend using **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using FSDP backend can be found in :doc:`FSDP Workers<../workers/fsdp_workers>`.\n\nFor users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support `Megatron-LM v0.12.2 <https://github.com/NVIDIA/Megatron-LM/tree/core_v0.12.2>`_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`.\n\n\n2. Inference:\n\nFor inference, vllm 0.8.3 and later versions have been tested for stability. We recommend turning on env var `VLLM_USE_V1=1` for optimal performance.\n\nFor SGLang, refer to the :doc:`SGLang Backend<../workers/sglang_worker>` for detailed installation and usage instructions. SGLang rollout is under extensive development and offers many advanced features and optimizations. We encourage users to report any issues or provide feedback via the `SGLang Issue Tracker <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/106>`_.\n\nFor huggingface TGI integration, it is usually used for debugging and single GPU exploration.\n\nInstall from docker image\n-------------------------\n\nWe provide pre-built Docker images for quick setup. And from this version,\nwe utilize a new image release hierarchy for productivity and stability.\n\nThe image types are divided into three large categories:\n\n- **Base Image**: Without inference and training frameworks, only basic dependencies are installed.\n  Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA.\n- **Application Image**: Stable version with inference and training frameworks installed.\n- **Community Image**: Unstable version with the latest frameworks and features.\n\nThe first two types of images are hosted on dockerhub `verlai/verl <https://hub.docker.com/r/verlai/verl>`_ repository, while the preview images are hosted on community repository.\n\n.. note::\n\n    The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``.\n\nBase Image\n::::::::::\n\nThe stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``.\n\nThe base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. 
From verl0.5, images are built with `Deep-EP <https://github.com/deepseek-ai/DeepEP>`_ for efficient EP communication.\n\nThe base image is not updated frequently, and the app image can be built on top of it without reinstalling base packages.\n\nApplication Image\n:::::::::::::::::\n\nFrom this version, we divide the images built for vLLM and SGLang, due to the divergence of dependent packages like FlashInfer.\n\nThere are four types of application images available:\n\n- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2``, with Deep-EP support: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2-deepep``.\n- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2`` (needs vLLM support, but may have some package conflicts), with Deep-EP support: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2-deepep``.\n- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.2-te2.2``\n- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.2-te2.2``\n\nThe latest vLLM support is coming soon.\n\nDocker images with Megatron backends can run post-training for large language models like ``Qwen/Qwen3-235B-A22B`` and ``deepseek-ai/DeepSeek-V3-0324``. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details.\n\nApplication images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks.\n\nCommunity Image\n:::::::::::::::\n\nCommunity images are provided by the community, including the latest versions of vLLM and SGLang, and may include experimental features or configurations. They also work for other hardware and platforms, such as AMD GPUs with ROCm, or AWS EFA and SageMaker.\n\nFor the latest vLLM with FSDP, please refer to the `hiyouga/verl <https://hub.docker.com/r/hiyouga/verl>`_ repository; the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.\n\nFor the latest SGLang with FSDP, please refer to the `ocss884/verl-sglang <https://hub.docker.com/r/ocss884/verl-sglang>`_ repository; the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5``, which is provided by the SGLang RL Group.\n\nSee files under ``docker/`` for NGC-based images or if you want to build your own.\n\nNote that for AWS instances with an EFA network interface (SageMaker AI Pod),\nyou need to install the EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``\n\nInstallation from Docker\n::::::::::::::::::::::::\n\nAfter pulling the desired Docker image and installing the desired inference and training frameworks, you can run it with the following steps:\n\n1. Launch the desired Docker image and attach into it:\n\n.. code:: bash\n\n    docker create --runtime=nvidia --gpus all --net=host --shm-size=\"10g\" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl <image:tag> sleep infinity\n    docker start verl\n    docker exec -it verl bash\n\n\n2. If you use the images provided, you only need to install verl itself without dependencies:\n\n.. 
code:: bash\n\n    # install the nightly version (recommended)\n    git clone https://github.com/volcengine/verl && cd verl\n    pip3 install --no-deps -e .\n\n[Optional] If you hope to switch between different frameworks, you can install verl with the following command:\n\n.. code:: bash\n\n    # install the nightly version (recommended)\n    git clone https://github.com/volcengine/verl && cd verl\n    pip3 install -e .[vllm]\n    pip3 install -e .[sglang]\n\n\nInstall from custom environment\n---------------------------------------------\n\nWe recommend using docker images for convenience. However, if your environment is not compatible with the docker image, you can also install verl in a python environment.\n\n\nPre-requisites\n::::::::::::::\n\nFor training and inference engines to utilize better and faster hardware support, CUDA/cuDNN and other dependencies are required,\nand some of these dependencies are easily overridden when installing other packages,\nso we put them in the :ref:`Post-installation` step.\n\n.. note::\n\n    The installation steps below are recommended configurations for the latest version of verl.\n    If you are trying to customize your own environment, please ignore the strict constraints.\n\nWe need to install the following pre-requisites:\n\n- **CUDA**: Version >= 12.4\n- **cuDNN**: Version >= 9.8.0\n- **Apex**\n\nCUDA 12.4 or above is recommended, matching the docker image;\nplease refer to `NVIDIA's official website <https://developer.nvidia.com/cuda-toolkit-archive>`_ for other versions of CUDA.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    wget https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb\n    dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb\n    cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/\n    apt-get update\n    apt-get -y install cuda-toolkit-12-4\n    update-alternatives --set cuda /usr/local/cuda-12.4\n\n\ncuDNN can be installed via the following command;\nplease refer to `NVIDIA's official website <https://developer.nvidia.com/rdp/cudnn-archive>`_ for other versions of cuDNN.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    wget https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n    dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb\n    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/\n    apt-get update\n    apt-get -y install cudnn-cuda-12\n\nNVIDIA Apex is required for Megatron-LM and FSDP training.\nYou can install it via the following command, but notice that this step can take a very long time.\nIt is recommended to set the ``MAX_JOBS`` environment variable to accelerate the installation process,\nbut do not set it too large, otherwise the memory will be overloaded and your machines may hang.\n\n.. code:: bash\n\n    # change directory to anywhere you like; the verl source code directory is not recommended\n    git clone https://github.com/NVIDIA/apex.git && \\\n    cd apex && \\\n    MAX_JOBS=32 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n
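After the build finishes, a quick sanity check can confirm the extensions are importable (a sketch; ``FusedLayerNorm`` requires the ``--cuda_ext`` build above):\n\n.. code:: python\n\n    # Verify that Apex and its fused CUDA kernels were built correctly.\n    import apex\n    from apex.normalization import FusedLayerNorm\n\n    print(\"Apex OK:\", apex.__file__)\n\n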
Install dependencies\n::::::::::::::::::::\n\n.. note::\n\n    We recommend using a fresh conda environment to install verl and its dependencies.\n\n    **Notice that the inference frameworks often strictly limit your pytorch version and will directly override your installed pytorch if you are not paying enough attention.**\n\n    As a countermeasure, it is recommended to install the inference frameworks first with the pytorch they need. For vLLM, if you hope to use your existing pytorch,\n    please follow their official instructions\n    `Use an existing PyTorch installation <https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#build-wheel-from-source>`_ .\n\n\n1. First of all, to manage the environment, we recommend using conda:\n\n.. code:: bash\n\n   conda create -n verl python==3.10\n   conda activate verl\n\n\n2. Then, execute the ``install.sh`` script that we provide in verl:\n\n.. code:: bash\n\n    # Make sure you have activated the verl conda env\n    # If you need to run with megatron\n    bash scripts/install_vllm_sglang_mcore.sh\n    # Or if you simply need to run with FSDP\n    USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh\n\n\nIf you encounter errors in this step, please check the script and manually follow the steps in the script.\n\n\nInstall verl\n::::::::::::\n\nFor installing the latest version of verl, the best way is to clone and\ninstall it from source. Then you can modify our code to customize your\nown post-training jobs.\n\n.. code:: bash\n\n   git clone https://github.com/volcengine/verl.git\n   cd verl\n   pip install --no-deps -e .\n\n\nPost-installation\n:::::::::::::::::\n\nPlease make sure that the installed packages are not overridden during the installation of other packages.\n\nThe packages worth checking are:\n\n- **torch** and the torch series\n- **vLLM**\n- **SGLang**\n- **pyarrow**\n- **tensordict**\n- **nvidia-cudnn-cu12**: for the Megatron backend\n\nIf you encounter package version issues while running verl, please update the outdated ones.\n\n\nInstall with AMD GPUs - ROCM kernel support\n------------------------------------------------------------------\n\nWhen you run on AMD GPUs (MI300) with the ROCm platform, you cannot use the previous quickstart to run verl. You should follow the steps below to build a docker image and run it.\nIf you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\nFind the docker for AMD ROCm: `docker/Dockerfile.rocm <https://github.com/volcengine/verl/blob/main/docker/Dockerfile.rocm>`_\n::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    #  Build the docker in the repo dir:\n    # docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .\n    # docker images # you can find your built docker\n    FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n    # Set working directory\n    # WORKDIR $PWD/app\n\n    # Set environment variables\n    ENV PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\"\n\n    # Install vllm\n    RUN pip uninstall -y vllm && \\\n        rm -rf vllm && \\\n        git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \\\n        cd vllm && \\\n        MAX_JOBS=$(nproc) python3 setup.py install && \\\n        cd .. && \\\n        rm -rf vllm\n\n    # Copy the entire project directory\n    COPY . 
.\n\n    # Install dependencies\n    RUN pip install \"tensordict<0.6\" --no-deps && \\\n        pip install accelerate \\\n        codetiming \\\n        datasets \\\n        dill \\\n        hydra-core \\\n        liger-kernel \\\n        numpy \\\n        pandas \\\n        datasets \\\n        peft \\\n        \"pyarrow>=15.0.0\" \\\n        pylatexenc \\\n        \"ray[data,train,tune,serve]\" \\\n        torchdata \\\n        transformers \\\n        wandb \\\n        orjson \\\n        pybind11 && \\\n        pip install -e . --no-deps\n\nBuild the image\n::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    docker build -t verl-rocm .\n\nLaunch the container\n::::::::::::::::::::::::::::\n\n.. code-block:: bash\n\n    docker run --rm -it \\\n      --device /dev/dri \\\n      --device /dev/kfd \\\n      -p 8265:8265 \\\n      --group-add video \\\n      --cap-add SYS_PTRACE \\\n      --security-opt seccomp=unconfined \\\n      --privileged \\\n      -v $HOME/.ssh:/root/.ssh \\\n      -v $HOME:$HOME \\\n      --shm-size 128G \\\n      -w $PWD \\\n      verl-rocm \\\n      /bin/bash\n\nIf you do not want to root mode and require assign yourself as the user,\nPlease add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script.\n\nverl with AMD GPUs currently supports FSDP as the training engine, vLLM and SGLang as the inference engine. We will support Megatron in the future.\n"
  },
  {
    "path": "verl_rl/docs/start/more_resources.rst",
    "content": "More Resources\n==============\n\nLast updated: 06/30/2025.\n\n- Introduction to verl (`Slides <https://tongyx361.github.io/blogs/posts/verl-intro>`_)\n- verl Code Walkthrough (`Slides <https://tongyx361.github.io/blogs/posts/verl-tutorial>`_, `Talk in Chinese <https://hcqnc.xetlk.com/sl/3vACOK>`_) \n"
  },
  {
    "path": "verl_rl/docs/start/multinode.rst",
    "content": "Multinode Training\n==================\n\nLast updated: 06/10/2025.\n\n.. _wuxibin89: https://github.com/wuxibin89\n\nAuthor: `Xibin Wu <https://github.com/wuxibin89>`_, `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\nManual\n------\n\nSet up multinode ray cluster\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n1. Start head node with ``ray start --head --dashboard-host=0.0.0.0``, there're 2 address you should care about:\n\n- GCS address: ``ray start --address=<address>``, where worker node should connect to.\n- Dashboard address: ``<address>:8265``, where you should submit job to the cluster.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/head.png?raw=true\n\n2. Start worker node with ``ray start --address=<address>`` you get above.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/worker.png?raw=true\n\n3. Now you should see the cluster have 2 nodes with ``ray status``.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/status.png?raw=true\n\n4. Additionally, you can access dashboard in the browser with the address you get above. \n\n*Firewall rules maybe need configure to access the dashboard, if there's any trouble, please contact your network administrator.*\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/overview.png?raw=true\n\nSubmit job to ray cluster\n~~~~~~~~~~~~~~~~~~~~~~~~~\n1. Submit ray job to cluster with the dashboard address you get above.\n\n.. code-block:: bash\n\n    ray job submit --address=\"http://127.0.0.1:8265\" \\\n        --runtime-env=verl/trainer/runtime_env.yaml \\\n        --no-wait \\\n        -- \\\n        python3 -m verl.trainer.main_ppo \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        ...\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/submit.png?raw=true\n\n2. Then you can check the job status with the following commands:\n\n- ray job list: list all jobs submitted to the cluster.\n- ray job logs <Submission ID>: query the logs of the job.\n- ray job status <Submission ID>: query the status of the job.\n- ray job stop <Submission ID>: request the job to be stopped.\n\n3. You can also access driver/task/actor logs in ``/tmp/ray/session_latest/logs/``, driver log is ``job-driver-raysubmit_<Submission ID>.log``.\n\n4. We strongly recommend you to view job detail from dashboard in multinode training, because it provide more structure way to view the job information.\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job.png?raw=true\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job_detail.png?raw=true\n\n\nSlurm\n-----\nTBD\n\ndstack\n------\n`dstackai/dstack <https://github.com/dstackai/dstack>`_ is an open-source container orchestrator that simplifies distributed training across cloud providers and on-premises environments\nwithout the need to use K8S or Slurm.\n\nPrerequisite\n~~~~~~~~~~~~\nOnce dstack is `installed <https://dstack.ai/docs/installation>`_, initialize the directory as a repo with ``dstack init``. \n\n.. code-block:: bash\n\n    mkdir myproject && cd myproject\n    dstack init\n\n**Create a fleet**\n\nBefore submitting distributed training jobs, create a `dstack` `fleet <https://dstack.ai/docs/concepts/fleets>`_.\n\nRun a Ray cluster task\n~~~~~~~~~~~~~~~~~~~~~~\n\nOnce the fleet is created, define a Ray cluster task, e.g. in ``ray-cluster.dstack.yml``:\n\n.. 
code-block:: yaml\n\n    type: task\n    name: ray-verl-cluster\n\n    nodes: 2\n\n    env:\n        - WANDB_API_KEY\n        - PYTHONUNBUFFERED=1\n        - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n    \n    image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2\n    commands:\n        - git clone https://github.com/volcengine/verl\n        - cd verl\n        - pip install --no-deps -e .\n        - pip install hf_transfer hf_xet\n        - |\n        if [ $DSTACK_NODE_RANK = 0 ]; then\n            python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n            python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-7B-Instruct')\" \n            ray start --head --port=6379;\n        else\n            ray start --address=$DSTACK_MASTER_NODE_IP:6379\n        fi\n\n    # Expose Ray dashboard port\n    ports:\n        - 8265\n\n    resources:\n        gpu: 80GB:8\n        shm_size: 128GB\n\n    # Save checkpoints on the instance\n    volumes:\n        - /checkpoints:/checkpoints\n\nNow, if you run this task via `dstack apply`, it will automatically forward the Ray's dashboard port to `localhost:8265`.\n\n.. code-block:: bash\n\n    dstack apply -f ray-cluster.dstack.yml\n\nAs long as the `dstack apply` is attached, you can use `localhost:8265` to submit Ray jobs for execution\n\nSubmit Ray jobs\n~~~~~~~~~~~~~~~\n\nBefore you can submit Ray jobs, ensure to install `ray` locally:\n   \n.. code-block:: shell\n\n    pip install ray\n\nNow you can submit the training job to the Ray cluster which is available at ``localhost:8265``:\n   \n.. code-block:: shell\n\n    $ RAY_ADDRESS=http://localhost:8265\n    $ ray job submit \\\n        -- python3 -m verl.trainer.main_ppo \\\n        data.train_files=/root/data/gsm8k/train.parquet \\\n        data.val_files=/root/data/gsm8k/test.parquet \\\n        data.train_batch_size=256 \\\n        data.max_prompt_length=512 \\\n        data.max_response_length=256 \\\n        actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        critic.optim.lr=1e-5 \\\n        critic.model.path=Qwen/Qwen2.5-7B-Instruct \\\n        critic.ppo_micro_batch_size_per_gpu=4 \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.project_name=ppo_training \\\n        trainer.experiment_name=qwen-2.5-7B \\\n        trainer.val_before_train=False \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        trainer.default_local_dir=/checkpoints \\\n        trainer.save_freq=10 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log \\\n        trainer.resume_mode=disable\n\n\nFor more details on how `dstack` works, check out its `documentation <https://dstack.ai/docs>`_.\n\nHow to debug?\n---------------------\n\n\nRay Distributed Debugger VSCode Extension (Recommended)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger <https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html>`_ VSCode extension. 
Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier.\n\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true\n      :alt: Ray Distributed Debugger VSCode extension screenshot\n\n2. Prerequisites.\n\n   Ensure the following are installed (see the extension README for more detail):\n\n   - Visual Studio Code  \n   - `ray[default]` >= 2.9.1  \n   - `debugpy` >= 1.8.0  \n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/c7098b755ff689859837773a916c857.png?raw=true\n      :alt: VSCode with Ray prerequisites\n\n3. Environment Variables.\n\n   To enable post‑mortem debugging, set:\n\n   .. code-block:: bash\n\n      export RAY_DEBUG_POST_MORTEM=1\n\n   .. admonition:: Note\n      :class: important\n\n      Be sure to remove any legacy flags before starting Ray:\n\n      - `RAY_DEBUG=legacy`  \n      - `--ray-debugger-external`\n\n4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information.\n\n\n   1. Insert `breakpoint()` calls into your remote functions.  \n   2. Submit your job to the cluster.  \n\n   The extension will detect active breakpoints and display them in VSCode.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true\n      :alt: Detected breakpoint in VSCode\n\n   **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`.\n\n5. Launching the Debugger.\n\n   Run your job directly from the command line (do not use a `launch.json`):\n\n   .. code-block:: bash\n\n      python job.py\n\n6. Attaching to a Breakpoint.\n\n Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true\n      :alt: Attaching VSCode debugger to Ray process\n\n7. Debugging With Multiple breakpoint().\n\n   For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint.\n\n   .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/6e83c910a62c82fecb89c6619e001cd.png?raw=true\n      :alt: Disconnecting and reconnecting the debugger\n\nLegacy Ray Debugger\n~~~~~~~~~~~~~~~~~~~\n1. Ray has a builtin legacy `debugger <https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/ray-debugging.html>`_ that allows you to debug your distributed applications. To enable debugger, start ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``.\n\n.. code-block:: bash\n\n    # start head node\n    RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external\n    # start worker node\n    RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external\n\n2. Set up breakpoint in your code, and submit job to cluster. Then run ``ray debug`` to wait breakpoint:\n\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true\n\n\nMulti-node training on AMD clusters\n---------------------------------------------------------------------------------------\n\nIf you want to run multi-node training with slurm with Docker/Podman container on AMD Cluster, you can use the following script. 
\n\nIf you encounter any issues using AMD GPUs to run verl, please contact `Yusheng Su <https://yushengsu-thu.github.io/>`_.\n\n.. note::\n    1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later.\n    2. If you want to use ``podman``, just replace ``docker`` with ``podman`` in the following script.\n\nThe script includes the following steps:\n\n1. SLURM Configuration\n2. Environment Setup\n3. Docker/Podman Container Setup\n4. Ray Cluster Initialization\n5. Data Preprocessing\n6. Model Setup\n7. Training Launch\n\n\nslurm_script.sh\n~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n    #!/bin/bash\n\n    #SBATCH --job-name=verl-ray-on-slurm\n    #SBATCH --nodes=2\n    #SBATCH --ntasks-per-node=2\n    #SBATCH --mem=200G\n    #SBATCH --time=30-00:00:00\n    #SBATCH --gpus-per-node=8\n    #SBATCH --cpus-per-task=28\n    #SBATCH --output=../verl_log/slurm-%j.out\n    #SBATCH --error=../verl_log/slurm-%j.err\n    #SBATCH --nodelist=gpu-[0,1]\n\n\n    # load necessary modules\n    ### Run this setup\n    # [Cluster]: Use docker\n    # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n\n\n    ##########################################################################\n    ### The following settings should be adjusted per project and cluster ###\n    ##########################################################################\n\n    ### Project\n    CONTAINER_NAME=\"multinode_verl_training\"\n    IMG=\"verl.rocm\"\n    DOCKERFILE=\"docker/Dockerfile.rocm\"\n    # echo $PWD\n    verl_workdir=\"${HOME}/projects/verl_upstream\"\n    export TRANSFORMERS_CACHE=\"${HOME}/.cache/huggingface\"\n    export HF_HOME=$TRANSFORMERS_CACHE\n\n    ### Cluster Network Setting\n    export NCCL_DEBUG=TRACE\n    export GPU_MAX_HW_QUEUES=2\n    export TORCH_NCCL_HIGH_PRIORITY=1\n    export NCCL_CHECKS_DISABLE=1\n    # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 \n    export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9\n    export NCCL_IB_GID_INDEX=3\n    export NCCL_CROSS_NIC=0\n    export CUDA_DEVICE_MAX_CONNECTIONS=1\n    export NCCL_PROTO=Simple\n    export RCCL_MSCCL_ENABLE=0\n    export TOKENIZERS_PARALLELISM=false\n    export HSA_NO_SCRATCH_RECLAIM=1\n    ##########################################################################\n\n    ### For rocm and training script\n    export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n    export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n    export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES\n\n\n    # Build and launch the Docker container\n    srun bash -c \"\n        # Exit on any error\n        set -e \n\n        # Clean up dangling images (images with <none> tag)\n        docker image prune -f\n\n        # Need to pull the docker first\n        docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4\n        \n        if ! docker images --format \"{{.Repository}}:{{.Tag}}\" | grep -q \"${IMG}\"; then\n            echo \\\"Building ${IMG} image...\\\"\n            docker build -f \\\"${DOCKERFILE}\\\" -t \\\"${IMG}\\\" .\n        else\n            echo \\\"${IMG} image already exists, skipping build\\\"\n        fi\n\n        # Remove old container if it exists\n        docker rm \\\"${CONTAINER_NAME}\\\" 2>/dev/null || true\n\n        # Check network devices\n        ibdev2netdev\n\n        # Launch the docker\n        docker run --rm -d \\\n        -e HYDRA_FULL_ERROR=1 \\\n        -e HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES} \\\n        -e ROCR_VISIBLE_DEVICES=${ROCR_VISIBLE_DEVICES} \\\n        -e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \\\n        -e NCCL_DEBUG=${NCCL_DEBUG} \\\n        -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \\\n        -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \\\n        -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \\\n        -e NCCL_IB_HCA=${NCCL_IB_HCA} \\\n        -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \\\n        -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \\\n        -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \\\n        -e NCCL_PROTO=${NCCL_PROTO} \\\n        -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \\\n        -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \\\n        -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \\\n        -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \\\n        -e HF_HOME=${HF_HOME} \\\n        --network host \\\n        --device /dev/dri \\\n        --device /dev/kfd \\\n        --device /dev/infiniband \\\n        --group-add video \\\n        --cap-add SYS_PTRACE \\\n        --security-opt seccomp=unconfined \\\n        --privileged \\\n        -v \\${HOME}:\\${HOME} \\\n        -v \\${HOME}/.ssh:/root/.ssh \\\n        -w \"${verl_workdir}\" \\\n        --shm-size 128G \\\n        --name \\\"${CONTAINER_NAME}\\\" \\\n        \\\"${IMG}\\\" \\\n        tail -f /dev/null\n\n        echo \\\"Container setup completed\\\"\n    \"\n        # (Optional): If you do not want to run in root mode and want to run as your own user,\n        # please add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` to the above docker launch script. \n\n\n\n\n\n    ### Ray launch the nodes before training\n\n    # Getting the node names\n    nodes_array=($(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | tr '\\n' ' '))\n\n    head_node=${nodes_array[0]}\n    head_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n    # if we detect a space character in the head node IP, we'll\n    # convert it to an ipv4 address. This step is optional.\n    if [[ \"$head_node_ip\" == *\" \"* ]]; then\n        IFS=' ' read -ra ADDR <<<\"$head_node_ip\"\n    if [[ ${#ADDR[0]} -gt 16 ]]; then\n        head_node_ip=${ADDR[1]}\n    else\n        head_node_ip=${ADDR[0]}\n    fi\n        echo \"IPV6 address detected. We split out the IPV4 address as $head_node_ip\"\n    fi\n\n    port=6379\n    ip_head=$head_node_ip:$port\n    export ip_head\n    echo \"IP Head: $ip_head\"\n\n    # make sure we set environment variables before Ray initialization\n\n    # Print out all env variables\n    printenv\n\n    echo \"Starting HEAD at $head_node\"\n    srun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n            ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n            --dashboard-port=8266 \\\n            --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    # optional, though may be useful in certain versions of Ray < 1.0.\n    sleep 10\n\n    # number of nodes other than the head node\n    worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n    for ((i = 1; i <= worker_num; i++)); do\n        node_i=${nodes_array[$i]}\n        echo \"Debug: Starting worker on node_i = ${node_i}\"\n        if [ -z \"$node_i\" ]; then\n            echo \"Error: Empty node name for worker $i\"\n            continue\n        fi\n        echo \"Starting WORKER $i at $node_i\"\n        srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n            docker exec \"${CONTAINER_NAME}\" \\\n                ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n        sleep 5\n    done\n\n\n\n\n    # Ray initialization test (check whether the above execution raised any errors)\n    echo \"Testing Ray initialization in the slurm nodes...\"\n    docker exec \"${CONTAINER_NAME}\" python3 -c '\n    import ray\n    try:\n        ray.init(address=\"auto\")\n        print(\"\\n=== Ray Cluster Status ===\")\n        print(f\"Number of nodes: {len(ray.nodes())}\")\n        for node in ray.nodes():\n            print(\"Node: {}, Status: {}\".format(node[\"NodeManagerHostname\"], node[\"Alive\"]))\n            # print(f\"Node: {node}\")\n        ray.shutdown()\n        print(\"Ray initialization successful!\")\n    except Exception as e:\n        print(f\"Ray initialization failed: {str(e)}\")\n    '\n    echo \"=== Ray test completed ===\"\n    ######\n\n\n\n    # Run data preprocessing\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/gsm8k.py\" \"--local_dir\" \"../data/gsm8k\"\n\n    echo \"Starting data preprocessing...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 \"examples/data_preprocess/math_dataset.py\" \"--local_dir\" \"../data/math\"\n\n    train_files=\"../data/gsm8k/train.parquet\"\n    val_files=\"../data/gsm8k/test.parquet\"\n\n    # Download and test the model\n    echo \"Loading model...\"\n    docker exec \"${CONTAINER_NAME}\" \\\n        python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')\"\n    MODEL_PATH=\"Qwen/Qwen2-7B-Instruct\"\n\n    echo \"== Data and model loading Done ==\"\n\n    echo \"Start to train...\"\n\n    PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w \"$head_node\" \\\n        docker exec \"${CONTAINER_NAME}\" \\\n        python3 -m verl.trainer.main_ppo \\\n        data.train_files=$train_files \\\n        data.val_files=$val_files \\\n        data.train_batch_size=1024 \\\n        data.max_prompt_length=1024 \\\n        data.max_response_length=1024 \\\n        actor_rollout_ref.model.path=$MODEL_PATH \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=$MODEL_PATH \\\n        critic.model.enable_gradient_checkpointing=False \\\n        critic.ppo_micro_batch_size_per_gpu=8 \\\n        critic.model.fsdp_config.param_offload=False \\\n        critic.model.fsdp_config.optimizer_offload=False \\\n        algorithm.kl_ctrl.kl_coef=0.0001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger='[\"console\",\"wandb\"]' \\\n        trainer.project_name='verl_example' \\\n        trainer.experiment_name='Qwen2-7B-Instruct_function_rm' \\\n        trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \\\n        trainer.val_before_train=False \\\n        trainer.nnodes=${SLURM_NNODES} \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15\n\n\nRun multi-node training with the above slurm_script.sh\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nJust ``sbatch`` your slurm_script.sh:\n\n.. code-block:: bash\n\n    sbatch slurm_script.sh\n\n"
  },
  {
    "path": "verl_rl/docs/start/quickstart.rst",
    "content": ".. _quickstart:\n\n=========================================================\nQuickstart: PPO training on GSM8K dataset\n=========================================================\n\nPost-train a LLM using GSM8K dataset.\n\nIntroduction\n------------\n\n.. _hf_dataset_gsm8k: https://huggingface.co/datasets/gsm8k\n\nIn this example, we train an LLM to tackle the `GSM8k <hf_dataset_gsm8k>`_ task with function-based rewards. [1]_\n\nPrerequisite:\n\n- the latest version of ``verl`` and its dependencies installed following the installation guide. Using the docker image is recommended.\n\n- a GPU with at least 24 GB HBM\n\n\nDataset Introduction\n--------------------\n\nGSM8k is a math problem dataset. The prompt is an elementary school\nproblem. The LLM model is asked to solve the math problem. Below is an example:\n\nPrompt\n\n   Katy makes coffee using teaspoons of sugar and cups of water in the\n   ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups\n   of water, calculate the number of teaspoonfuls of sugar she used.\n\nSolution\n\n   The total ratio representing the ingredients she used to make the\n   coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the\n   number of teaspoons she used is 7/20, she used 7/20\\ *120 =\n   <<7/20*\\ 120=42>>42 #### 42\n\nStep 1: Prepare the dataset\n----------------------------\n\nWe preprocess the dataset in parquet format so that (1) it contains necessary fields for computing RL rewards and (2) is faster to read.\n\n.. code-block:: bash\n\n   python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\nStep 2: Download a model for post-training\n-------------------------------------------\n\nIn this example, we start with the ``Qwen2.5-0.5B-Instruct`` model.\n\nIf you want to perform SFT before RL, refer to the :doc:`Complete GSM8K Example<../examples/gsm8k_example>`, the `sft directory <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k>`_ and `SFT Trainer <https://github.com/volcengine/verl/blob/main/verl/trainer/fsdp_sft_trainer.py>`_ for further details.\n\n.. code-block:: bash\n\n   python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')\"\n\nStep 3: Perform PPO training with the instruct model\n----------------------------------------------------------------------\n\n**Reward Model/Function**\n\nWe use a pre-defined rule-based reward model. We force the model to produce a final\nanswer following 4 “#” as shown in the solution. We extract the final\nanswer from both the solution and model's output using regular\nexpression matching. We assign a reward of 1 to correct\nanswer, 0.0 to incorrect answer and 0 to no answer. \n\nFor more details, please refer to `verl/utils/reward_score/gsm8k.py <https://github.com/volcengine/verl/blob/v0.4.1/verl/utils/reward_score/gsm8k.py>`_.\n\n**Training Script**\n\nNow let's run PPO training with the dataset and model above. [2]_\n\n\nSet the ``data.train_files`` ,\\ ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths.\nYou may set ``VERL_USE_MODELSCOPE=True`` to download models from `modelscope <https://www.modelscope.cn>`_ instead of `huggingface <https://huggingface.co>`_.\n\n.. 
code-block:: bash\n\n   PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=256 \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=10 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 2>&1 | tee verl_demo.log\n\nYou are expected to see the following logs, indicating training in progress. The key metric ``val/test_score/openai/gsm8k`` is computed every ``trainer.test_freq`` steps:\n\n.. code-block:: bash\n\n    step:0 - timing/gen:21.470 - timing/ref:4.360 - timing/values:5.800 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty_coeff:0.001 - timing/adv:0.109 - timing/update_critic:15.664 - critic/vf_loss:14.947 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.056 - critic/grad_norm:1023.278 - critic/lr(1e-4):0.100 - timing/update_actor:20.314 - actor/entropy_loss:0.433 - actor/pg_loss:-0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:1.992 - actor/lr(1e-4):0.010 - critic/score/mean:0.004 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.004 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.360 - critic/advantages/min:-2.280 - critic/returns/mean:0.003 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.045 - critic/values/max:9.500 - critic/values/min:-14.000 - response_length/mean:239.133 - response_length/max:256.000 - response_length/min:77.000 - prompt_length/mean:104.883 - prompt_length/max:175.000 - prompt_length/min:68.000\n    step:1 - timing/gen:23.020 - timing/ref:4.322 - timing/values:5.953 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty:0.001 - timing/adv:0.118 - timing/update_critic:15.646 - critic/vf_loss:18.472 - critic/vf_clipfrac:0.384 - critic/vpred_mean:1.038 - critic/grad_norm:942.924 - critic/lr(1e-4):0.100 - timing/update_actor:20.526 - actor/entropy_loss:0.440 - actor/pg_loss:0.000 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - actor/grad_norm:2.060 - actor/lr(1e-4):0.010 - critic/score/mean:0.000 - critic/score/max:0.000 - critic/score/min:0.000 - critic/rewards/mean:0.000 - critic/rewards/max:0.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:2.702 - critic/advantages/min:-2.616 - critic/returns/mean:0.000 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.280 - critic/values/max:11.000 - critic/values/min:-16.000 - response_length/mean:232.242 - response_length/max:256.000 - response_length/min:91.000 - prompt_length/mean:102.398 - prompt_length/max:185.000 - 
prompt_length/min:70.000\n\nCheck out the ``Algorithm Baselines`` page for full training and validation logs for reference.\n\nThe checkpoint is saved in the following directory by default: ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``. You can merge the saved checkpoints into a huggingface model using the ``verl.model_merger`` module, for example:\n\n.. code-block:: bash\n\n    python3 -m verl.model_merger merge \\\n        --backend fsdp \\\n        --local_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor \\\n        --target_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor/huggingface\n\nFor more details about checkpointing and model merging, please refer to :ref:`checkpoint-page`.\n\nTo enable ``wandb`` for experiment tracking, set the following configs:\n\n.. code-block:: bash\n\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=$YOUR_PROJECT_NAME \\\n    trainer.experiment_name=$YOUR_RUN_NAME \\\n\nIf you encounter out-of-memory issues with less than 32 GB of HBM, enabling the following configs would help:\n\n.. code-block:: bash\n\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    critic.ppo_micro_batch_size_per_gpu=1 \\\n\nFor the full set of configs, please refer to :ref:`config-explain-page` for detailed explanations and performance tuning.\n\n\n.. [1] The original paper (https://arxiv.org/pdf/2110.14168) mainly focuses on training a verifier (a reward model) to solve math problems via Best-of-N sampling. In this example, we train an RL agent using a rule-based reward model.\n.. [2] More training script examples for the FSDP and Megatron-LM backends are stored in the `examples/ppo_trainer <https://github.com/volcengine/verl/tree/main/examples/ppo_trainer>`_ directory.\n"
  },
  {
    "path": "verl_rl/docs/start/ray_debug_tutorial.rst",
    "content": "Ray Debug Tutorial\r\n==================\r\n\r\nLast updated: 04/23/2025\r\n\r\n\r\n.. _wuxibin89: https://github.com/wuxibin89\r\n\r\nAuthor: `Ao Shen <https://aoshen524.github.io/>`_.\r\n\r\nHow to debug?\r\n---------------------\r\n\r\n\r\nRay Distributed Debugger VSCode Extension (Recommended)\r\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\r\n\r\n1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger <https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html>`_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true\r\n      :alt: Ray Distributed Debugger VSCode extension screenshot\r\n\r\n2. Prerequisites.\r\n\r\n   Ensure the following are installed (see the extension README for more detail):\r\n\r\n   - Visual Studio Code  \r\n   - `ray[default]` >= 2.9.1  \r\n   - `debugpy` >= 1.8.0  \r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/readme.png?raw=true\r\n      :alt: VSCode with Ray prerequisites\r\n\r\n3. Environment Variables.\r\n\r\n   To enable post‑mortem debugging, set:\r\n\r\n   .. code-block:: bash\r\n\r\n      export RAY_DEBUG_POST_MORTEM=1\r\n\r\n   .. admonition:: Note\r\n      :class: important\r\n\r\n      Be sure to remove any legacy flags before starting Ray:\r\n\r\n      - `RAY_DEBUG=legacy`  \r\n      - `--ray-debugger-external`\r\n\r\n4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information.\r\n\r\n\r\n   1. Insert `breakpoint()` calls into your remote functions.  \r\n   2. Submit your job to the cluster.  \r\n\r\n   The extension will detect active breakpoints and display them in VSCode.\r\n\r\n   **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`.\r\n\r\n5. Launching the Debugger.\r\n\r\n   Run your job directly from the command line (do not use a `launch.json`):\r\n\r\n   .. code-block:: bash\r\n\r\n      python job.py\r\n\r\n6. Attaching to a Breakpoint.\r\n\r\n Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/launch.png?raw=true\r\n      :alt: Attaching VSCode debugger to Ray process\r\n\r\n7. Debugging With Multiple breakpoint().\r\n\r\n   For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint.\r\n\r\n   .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/disconnect.png?raw=true\r\n      :alt: Disconnecting and reconnecting the debugger\r\n\r\nLegacy Ray Debugger\r\n~~~~~~~~~~~~~~~~~~~\r\n1. Ray has a builtin legacy `debugger <https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/ray-debugging.html>`_ that allows you to debug your distributed applications. To enable debugger, start ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``.\r\n\r\n.. code-block:: bash\r\n\r\n    # start head node\r\n    RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external\r\n    # start worker node\r\n    RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external\r\n\r\n2. 
Set up breakpoint in your code, and submit job to cluster. Then run ``ray debug`` to wait breakpoint:\r\n\r\n.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true\r\n\r\n"
  },
  {
    "path": "verl_rl/docs/workers/fsdp_workers.rst",
    "content": "PyTorch FSDP Backend\n======================\n\nLast updated: 02/12/2025.\n\nWe support PyTorch FSDP Backend by implementing various workers for\nactor, critic, reference, rollout and reward models. We also implement\nthe ``FSDPVLLMShardingManager`` that reshard weight between FSDP and\nvLLM in `fsdp_vllm.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/fsdp_vllm.py>`_.\n\n**Pros**\n\n- Readily support various models.\n\n  - Users only need to implement the corresponding\n    ``dtensor_weight_loader`` for weight synchronization between FSDP\n    and vLLM. While for ``hf_weight_loader``, users can directly apply\n    any models supported both in HF and vLLM without any code change.\n\n- Easy to organize the forward and backward computation for each model.\n\n**Cons**\n\n- Poor scalability when it comes to large-scale models (e.g. Llama 70B\n  and 405B)\n- The resharding overhead between actor and rollout could be larger than\n  Megatron-LM backend.\n\nDue to the simplicity, we recommend using FSDP backend for algorithm\nresearch and prototyping.\n\nFSDP Workers\n--------------\n\nActorRolloutRefWorker\n^^^^^^^^^^^^^^^^^^^^^\n\nActor/Rollout HybridEngine\n''''''''''''''''''''''''''\n\n1. HybridEngine, Actor and Rollout initialization API.\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def init_model(self):\n\n``ONE_TO_ALL``: when calling the ``init_model`` function from the driver\nprocess, each worker (on a GPU) will execute the following model\ninitialization process.\n\nThe initialization details of HybridEngine, Actor and Rollout are\nhighlighted below:\n\n1. ``DataParallelPPOActor`` implements the simple PPO computation logics\n   when the model is built with FSDP, including compute log prob, model\n   update.\n2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM\n   Engine and make it executed under SPMD to fit into our\n   ``WorkerGroup`` design.\n3. ``FSDPVLLMShardingManager`` a context manager to perform actual\n   resharding between actor and rollout.\n\nSee `source code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_. for more information.\n\n1. Generate sequence and recompute log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def generate_sequences(self, prompts: DataProto):\n\n- ``Dispatch.DP_COMPUTE_PROTO``: The data will be dispatched and\n  collected along the DP dimension\n\n- In this function, the rollout model will perform auto-regressive\n  generation and the actor model will recompute the old log prob for the\n  generated response.\n\n3. Update actor model\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def update_actor(self, data: DataProto):\n\n- Update the actor model weight using PPO & entropy loss.\n\nReferenceModel\n''''''''''''''\n\n1. Reference model initialization\n\nThe reference model is initialized using the same function as the actor\nmodel without initializing the HybridEngine and Optimizer. Then the\nactor model is also wrapped by the ``DataParallelPPOActor``.\n\n2. Compute reference log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_ref_log_prob(self, data: DataProto):\n\n- In this function, the reference model will call the compute log prob\n  function in ``DataParallelPPOActor`` to compute the reference log\n  prob.\n\nCriticWorker and RewardWorker\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. 
Model initialization\n\nThis is quite similar to the reference model. The CriticWorker performs\nadditional initialization for the Optimizer.\n\n2. Compute Values for CriticWorker\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_values(self, data: DataProto):\n\n3. Update Critic\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def update_critic(self, data: DataProto):\n\n4. Compute Reward\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n   def compute_rm_score(self, data: DataProto):\n\n\nHybridShard\n------------\n\nWe do not yet support FSDP ``HybridShard``. To support it, we may need to\nconstruct a 2D device mesh and test the corresponding\n``dtensor_weight_loader`` and ``hf_weight_loader`` for each model.\n"
  },
  {
    "path": "verl_rl/docs/workers/megatron_workers.rst",
    "content": "Megatron-LM Backend\n===================\n\nLast updated: 06/24/2025.\n\nWe support Megatron Backend by implementing various workers for actor,\ncritic, reference, rollout and reward models. We also implement the\n``3DHybridEngine`` using Megatron-LM and vLLM/SGLang in\n`megatron_vllm.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/megatron_vllm.py>`_\nand `megatron_sglang.py <https://github.com/volcengine/verl/blob/main/verl/workers/sharding_manager/megatron_sglang.py>`_.\n\n**Pros**\n\n- Support 5D parallelism (TP, EP, CP, DP, PP) and sequence parallelism\n  for best scalablility and throughput.\n- 3D HybridEngine can significantly reduce peak memory usage and reduce\n  weight synchronize overhead between actor and rollout.\n\n**Cons**\n\n- Huggingface Models and Megatron checkpoints need tools for conversion.\n\n\nDevelopment Progress\n--------------------\n\n\nNote that [Deprecated] means that the feature is not supported in the latest\nversion of verl.\n[To-Optimize] means that the feature is implemented but not optimized yet.\n[WIP] means that the feature is working in progress.\n[In-Release] means that the feature is ready and in review process,\ncoming at any time.\n\n\n+---------------+-----------------------------------------------------------+\n| [Deprecated]  | Megatron 3D Parallelism with custom models                |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron 0.11.0 ``GPTModel`` support                      |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron GRPO support                                     |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron with vLLM 0.8.2, with per-tensor weights loading |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron with Context Parallel                            |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Qwen2MoE model support                                    |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Megatron dist Checkpoint                                  |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Huggingface and Megatron Checkpoint Converter             |\n+---------------+-----------------------------------------------------------+\n| [To-Optimize] | Efficient fused linear, entropy and cross entropy         |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron offload(param, grad, optimizer)                  |\n+---------------+-----------------------------------------------------------+\n| [Done]        | Megatron Profiler                                         |\n+---------------+-----------------------------------------------------------+\n| [In-Release]  | Megatron 0.12.0, TE 2.2 with vLLM 0.8.3 and Fused Attn    |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | Moonlight/DeepSeek-V3 model support                       |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | Expert Parallel support                                   |\n+---------------+-----------------------------------------------------------+\n| [WIP]         | 
Megatron support dynamic batch size                       |\n+---------------+-----------------------------------------------------------+\n| [To-Do]       | Performance tuning                                        |\n+---------------+-----------------------------------------------------------+\n| [MileStone]   | Runnable with DeepSeek-V3 671B post-training              |\n+---------------+-----------------------------------------------------------+\n\n\n\nUtils of Megatron Workers\n-------------------------\n\nMegatronWorker\n^^^^^^^^^^^^^^\n\n``MegatronWorker`` is the base class of different megatron worker\nclasses. In this class, ``get_megatron_global_info`` and\n``get_megatron_rank_info`` function to retrieve the 3D parallel world\nsize and rank of each ``Worker`` running on specific GPU. These information\nwill be used in transfer protocol for Megatron Backend.\n\nThe following ``Worker`` class for different models will be utilized to\nconstruct the ``WorkerGroup`` .\n\nWe implement various of APIs for each ``Worker`` class decorated by the\n``@register(dispatch_mode=)`` . These APIs can be called by the ray\ndriver process. The data can be correctly collect and dispatch following\nthe ``dispatch_mode`` on each function. The supported dispatch_model\n(i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_.\n\nActorRolloutRefWorker\n^^^^^^^^^^^^^^^^^^^^^\n\nThis class is implemented for Actor/Rollout HybridEngine or for the\nreference model to initialize their model and perform computation.\n\nActor/Rollout HybridEngine\n''''''''''''''''''''''''''\n\n1. HybridEngine, Actor and Rollout initialization API.\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n   def init_model(self):\n\n``ONE_TO_ALL``: when calling the ``init_model`` function from the driver\nprocess, each worker (on a GPU) will execute the following model\ninitialization process.\n\nThe initialization details of HybridEngine, Actor and Rollout are\nhighlighted below:\n\n1. ``MegatronPPOActor`` implements the simple PPO computation logics\n   when the model is built with Megatron, including compute log prob,\n   model update.\n2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM\n   Engine and make it executed under SPMD to fit into our\n   ``WorkerGroup`` design.\n3. ``MegatronVLLMShardingManager`` a context manager to perform actual\n   resharding between actor and rollout.\n\nSee `source code <https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py#L63>`_ for more information.\n\n.. 
.. code:: python\n\n   # build actor model\n   self.actor = MegatronPPOActor(config=self.config.actor,\n                                 model_config=self.actor_model_config,\n                                 megatron_config=megatron_config,\n                                 actor_module=self.actor_module,\n                                 actor_optimizer=self.actor_optimizer,\n                                 actor_optimizer_config=self.actor_optim_config)\n\n   # build rollout\n   # rollout initialization\n   rollout = vLLMRollout(actor_module=params,\n                        config=self.config.rollout,\n                        tokenizer=self.tokenizer,\n                        model_hf_config=self.actor_model_config,\n                        train_tp=mpu.get_tensor_model_parallel_world_size())\n   # perform weight resharding between actor and rollout\n   sharding_manager = MegatronVLLMShardingManager(module=self.hybrid_engine,\n                                                  inference_engine=rollout.inference_engine,\n                                                  model_config=self.actor_model_config,\n                                                  layer_name_mapping=layer_name_mapping)\n   ...\n\n2. Generate sequences and recompute log probs\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)\n   def generate_sequences(self, prompts: DataProto):\n\n- ``Dispatch.MEGATRON_PP_AS_DP_PROTO``: The PP dimension of the actor\n  model will be regarded as the DP dimension. The driver process will then\n  dispatch and collect the data according to this reorganization. This\n  is because, in HybridEngine, the actor weights, which usually apply\n  larger 3D parallel sizes, will be gathered along the PP dimension and\n  TP dimension. Therefore, the corresponding data should be dispatched\n  and collected through the 3D parallel group of the rollout model,\n  rather than the actor model. However, the world_size and rank\n  information can only be retrieved from ``get_megatron_global_info`` and\n  ``get_megatron_rank_info``, which record the 3D information for the\n  actor model. Moreover, the data resharding inside the TP dimension will be\n  processed within the HybridEngine.\n\n- In this function, the rollout model will perform auto-regressive\n  generation and the actor model will recompute the old log prob for the\n  generated response.\n\n3. Update actor model\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def update_actor(self, data: DataProto):\n\n- ``Dispatch.MEGATRON_COMPUTE_PROTO``: The user passes data partitioned\n  by the DP dimension. The data is dispatched to all tp/pp ranks within the\n  same dp group, and output data is ultimately collected only from tp=0 and\n  the last pp.\n- Update the actor model weights using the PPO & entropy loss.\n\n\n.. note:: \n\n   Currently, the training Tensor Parallel Size can be different from the inference\n   Tensor Parallel Size.\n\n\nReferenceModel\n''''''''''''''\n\n1. Reference model initialization\n\nThe reference model is initialized using the same function as the actor\nmodel, without initializing the HybridEngine and Optimizer. The reference\nmodel is then also wrapped by the ``MegatronPPOActor``.\n\n2. Compute reference log prob\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_ref_log_prob(self, data: DataProto):\n\n- In this function, the reference model will call the compute log prob\n  function in ``MegatronPPOActor`` to compute the reference log prob.\n\nCriticWorker and RewardWorker\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. Model initialization\n\nThis is quite similar to the reference model. The CriticWorker performs\nadditional initialization for the Optimizer.\n\n2. Compute Values for CriticWorker\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_values(self, data: DataProto):\n\n3. Update Critic\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def update_critic(self, data: DataProto):\n\n4. Compute Reward\n\n.. code:: python\n\n   @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n   def compute_rm_score(self, data: DataProto):\n\n\nUtils of Train Optimization\n---------------------------\n\nOffload\n^^^^^^^\nWhen resources are tight, the offload method can lower GPU memory \nusage, helping training and inference frameworks work well under verl. \nIt moves parameters, gradients, and optimizer states to CPU memory and only \nloads them back to the GPU when needed.\n\nIf you want to use offload, you can add the following parameters \nfor the actor and ref separately. \n\n.. code:: python\n\n   # For the actor\n   actor_rollout_ref.actor.megatron.param_offload=True \\\n   actor_rollout_ref.actor.megatron.grad_offload=True \\\n   actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n   # For the ref w/o grad and optimizer\n   actor_rollout_ref.ref.megatron.param_offload=True \\\n\n\nFor the critic, you can include these parameters.\n\n.. code:: python\n\n   # For the critic\n   critic.megatron.param_offload=True \\\n   critic.megatron.grad_offload=True \\\n   critic.megatron.optimizer_offload=True \\\n\nProfiler\n^^^^^^^^\n\nThe profiler is a tool that helps you understand the performance of your \nmodel. It can be used to profile the time spent on different operations \nand identify bottlenecks. You can get more information from \n`torch.profiler <https://pytorch.org/docs/stable/profiler.html>`_.\n\nIn verl, the profiler is currently only supported for the actor role in Megatron. You can set \nthe begin step and end step to profile; note that one step means one gradient update. The \nprofile results will be saved in the save_path. If you only want to profile on \nspecific ranks, you can set profile_ranks; by default, it will be [0].\n\n.. code:: python\n\n   actor_rollout_ref.actor.profile.use_profile=True \\\n   actor_rollout_ref.actor.profile.profile_ranks=[0] \\\n   actor_rollout_ref.actor.profile.step_start=0 \\\n   actor_rollout_ref.actor.profile.step_end=1 \\\n   actor_rollout_ref.actor.profile.save_path=\"./profile\"\n\n\nRelated MCore Document\n----------------------\n\nThere is also a detailed document on using MCore to train different\nkinds of models; please refer to the `MCore Document <https://github.com/volcengine/verl/blob/main/verl/models/mcore/readme.md>`_.\n"
  },
  {
    "path": "verl_rl/docs/workers/ray_trainer.rst",
    "content": "PPO Ray Trainer\n===============\n\nLast updated: 02/12/2025.\n\nWe implement the RayPPOTrainer, which is a trainer runs on the driver\nprocess on a single CPU/GPU node (default is CPU).\n\nThe PPORayTrainer include 3 core functions for data preparation,\nWorkerGroup initialization and PPO training loop.\n\nData Preparation\n----------------\n\nThe ``PPORayTrainer``, as a single process, is responsible for loading a\ncomplete batch of samples (prompts) from the dataset and then dispatch\nto different worker_groups running on different GPUs.\n\nTo generalize the data loading, we implement the ``RLHFDataset`` class\nto load the preprocessed parquet files, apply chat templates to the\nprompts, add padding, truncate prompts that exceed max prompt length and\nthen tokenize.\n\n.. code:: python\n\n   self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,\n                                       tokenizer=self.tokenizer,\n                                       config=self.config.data)\n\nThen, the dataloader will iterate the dataset under PPO mini batch size.\n\nWorkerGroup Initialization\n--------------------------\n\nWe first introduce a basic implementation of initializing the\n``WorkerGroup`` of the actor model on a given set of GPUs.\n\n.. code:: python\n\n   # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n   # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.\n   # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models\n   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n                                   use_gpu=True,\n                                   max_colocate_count=1)\n   # define actor rollout cls to be init on remote\n   actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)\n   # define actor_rollout worker group\n   actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,\n                                                       ray_cls_with_init=actor_rollout_cls,\n                                                       default_megatron_kwargs=config.actor_rollout.megatron)\n\nDifferent WorkerGroups, like ``actor_rollout_worker_group`` ,\n``critic_worker_group`` and ``ref_worker_group`` lies on a separate\nprocess in the above implementation.\n\nThe driver process can then call the distributed compute function within\nthe ``actor_rollout_worker_group`` and other roles to construct the RL\ntraining loop.\n\nFor models colocated in the same set of GPUs, we further provide a\nfine-grain optimization, which merge the ``worker_group`` of different roles\nin the same process. This optimization can save the redundant\nCUDA/distributed context in different processes.\n\n.. code:: python\n\n   # initialize WorkerGroup\n   # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n   # you should not use `create_colocated_worker_cls`. 
Instead, directly pass different resource pool to different worker groups.\n   # See TODO(url) for more information.\n   all_wg = {}\n   for resource_pool, class_dict in self.resource_pool_to_cls.items():\n       worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n       wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)\n       spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n       all_wg.update(spawn_wg)\n\n   if self.use_critic:\n       self.critic_wg = all_wg['critic']\n       self.critic_wg.init_model()\n\n   if self.use_reference_policy:\n       self.ref_policy_wg = all_wg['ref']\n       self.ref_policy_wg.init_model()\n\n   if self.use_rm:\n       self.rm_wg = all_wg['rm']\n       self.rm_wg.init_model()\n\n   # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n   self.actor_rollout_wg = all_wg['actor_rollout']\n   self.actor_rollout_wg.init_model()\n\n.. note:: For megatron backend, if we merge the ``worker_groups`` into the same processes, all the roles will utilize the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role in the same distributed context. If you want to use different 3D parallel size for different roles, please follow the similar architecture of the first code block to initialize each role's ``worker_group``\n\n\nPPO Training Loop\n-----------------\n\nWe implement the PPO training loop by calling the functions in\nworker_group of each role. The input and output data of each function is\na ``DataProto`` object implemented in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_. In the training\nloop, trainer will dispatch/collect the data to/from different GPUs\nfollowing the transfer protocols wrapped in the workers' functions. The\ncomputation of PPO micro batches is processed in ``update_actor`` and\n``update_critic`` functions.\n\nTo extend to other RLHF algorithms, such as DPO, GRPO, please refer to\n:doc:`../advance/dpo_extension`.\n\n.. 
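As a rough sketch of the ``DataProto`` pattern (illustrative shapes; the worker-group handle is assumed to be in scope, and ``from_single_dict`` / ``union`` are the same calls used in the loop below):\n\n.. code:: python\n\n   # Illustrative only: wrap a plain tensor dict into DataProto and merge worker output.\n   import torch\n   from verl import DataProto  # top-level export of the class defined in protocol.py\n\n   batch_dict = {\n       'input_ids': torch.randint(0, 100, (4, 8)),\n       'attention_mask': torch.ones(4, 8, dtype=torch.long),\n   }\n   batch = DataProto.from_single_dict(batch_dict)\n\n   # dispatched/collected according to the method's transfer protocol\n   gen_output = actor_rollout_wg.generate_sequences(batch)\n   # merge the generated fields back into the original batch\n   batch = batch.union(gen_output)\n\nTo extend to other RLHF algorithms, such as DPO, GRPO, please refer to\n:doc:`../advance/dpo_extension`.\n\n.. 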
code:: python\n\n   def fit(self):\n       \"\"\"\n       The training loop of PPO.\n       The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.\n       The lightweight advantage computation is done on the driver process.\n       \"\"\"\n       from verl.utils.tracking import Tracking\n       from omegaconf import OmegaConf\n\n       logger = Tracking(project_name=self.config.trainer.project_name,\n                           experiment_name=self.config.trainer.experiment_name,\n                           default_backend=self.config.trainer.logger,\n                           config=OmegaConf.to_container(self.config, resolve=True))\n\n       global_steps = 0\n\n       # perform validation before training\n       # currently, we only support validation using the reward_function.\n       if self.val_reward_fn is not None:\n           val_metrics = self._validate()\n           pprint(f'Initial validation metrics: {val_metrics}')\n\n       for epoch in range(self.config.trainer.total_epochs):\n           for batch_dict in self.train_dataloader:\n               metrics = {}\n\n               batch: DataProto = DataProto.from_single_dict(batch_dict)\n               # batch = batch.to('cuda')\n\n               # pop those keys for generation\n               gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])\n\n               # generate a batch\n               with Timer(name='gen', logger=None) as timer:\n                   gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n               metrics['timing/gen'] = timer.last\n\n               batch = batch.union(gen_batch_output)\n\n               if self.use_reference_policy:\n                   # compute reference log_prob\n                   with Timer(name='ref', logger=None) as timer:\n                       ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                       batch = batch.union(ref_log_prob)\n                   metrics['timing/ref'] = timer.last\n\n               # compute values\n               with Timer(name='values', logger=None) as timer:\n                   values = self.critic_wg.compute_values(batch)\n                   batch = batch.union(values)\n               metrics['timing/values'] = timer.last\n\n               with Timer(name='adv', logger=None) as timer:\n                   # compute scores. Support both model and function-based.\n                   # We first compute the scores using reward model. Then, we call reward_fn to combine\n                   # the results from reward model and rule-based results.\n                   if self.use_rm:\n                       # we first compute reward model score\n                       reward_tensor = self.rm_wg.compute_rm_score(batch)\n                       batch = batch.union(reward_tensor)\n\n                   # we combine with rule-based rm\n                   reward_tensor = self.reward_fn(batch)\n                   batch.batch['token_level_scores'] = reward_tensor\n\n                   # compute rewards. 
apply_kl_penalty if available\n                   batch, kl_metrics = apply_kl_penalty(batch,\n                                                           kl_ctrl=self.kl_ctrl_in_reward,\n                                                           kl_penalty=self.config.algorithm.kl_penalty)\n                   metrics.update(kl_metrics)\n\n                   # compute advantages, executed on the driver process\n                   batch = compute_advantage(batch,\n                                               self.config.algorithm.gamma,\n                                               self.config.algorithm.lam,\n                                               adv_estimator=self.config.algorithm.adv_estimator)\n               metrics['timing/adv'] = timer.last\n\n               # update critic\n               if self.use_critic:\n                   with Timer(name='update_critic', logger=None) as timer:\n                       critic_output = self.critic_wg.update_critic(batch)\n                   metrics['timing/update_critic'] = timer.last\n                   critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])\n                   metrics.update(critic_output_metrics)\n\n               # implement critic warmup\n               if self.config.trainer.critic_warmup <= global_steps:\n                   # update actor\n                   with Timer(name='update_actor', logger=None) as timer:\n                       actor_output = self.actor_rollout_wg.update_actor(batch)\n                   metrics['timing/update_actor'] = timer.last\n                   actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])\n                   metrics.update(actor_output_metrics)\n\n               # validate\n               if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:\n                   with Timer(name='testing', logger=None) as timer:\n                       val_metrics: dict = self._validate()\n                       val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}\n                   metrics['timing/testing'] = timer.last\n                   metrics.update(val_metrics)\n\n               # collect metrics\n               data_metrics = compute_data_metrics(batch=batch)\n               metrics.update(data_metrics)\n\n               # TODO: make a canonical logger that supports various backend\n               logger.log(data=metrics, step=global_steps)\n\n               if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:\n                   actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',\n                                                   f'global_step_{global_steps}')\n                   actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')\n                   self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)\n\n                   if self.use_critic:\n                       critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',\n                                                           f'global_step_{global_steps}')\n                       critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')\n                       self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)\n\n               global_steps += 1\n\n       # perform validation after training\n       if self.val_reward_fn is not None:\n           
val_metrics = self._validate()\n           pprint(f'Final validation metrics: {val_metrics}')\n
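\nFor reference, below is a minimal, generic sketch of the lightweight advantage computation performed on the driver. It is a plain GAE implementation under simplified assumptions (dense per-token rewards and no padding mask), not necessarily the behavior of verl's ``compute_advantage``.\n\n.. code:: python\n\n   import torch\n\n   def gae_advantage(rewards, values, gamma, lam):\n       \"\"\"Generalized Advantage Estimation over (batch, seq_len) tensors.\"\"\"\n       advantages = torch.zeros_like(rewards)\n       last_gae = 0.0\n       for t in reversed(range(rewards.shape[1])):\n           # bootstrap with the next value estimate; zero at the final step\n           next_value = values[:, t + 1] if t + 1 < rewards.shape[1] else 0.0\n           delta = rewards[:, t] + gamma * next_value - values[:, t]\n           last_gae = delta + gamma * lam * last_gae\n           advantages[:, t] = last_gae\n       returns = advantages + values\n       return advantages, returns\n"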
  },
  {
    "path": "verl_rl/docs/workers/sglang_worker.rst",
    "content": "SGLang Backend\n==============\n\nLast updated: 05/31/2025.\n\n**Authored By SGLang RL Team and listed alphabetically by last name**\n\n`Jingyi Chen <https://github.com/fzyzcjy>`_, `Yitong Guan <https://github.com/minleminzui>`_, `Zhuobin Huang <https://zobinhuang.github.io/sec_about/>`_, `Jiajun Li <https://github.com/guapisolo>`_, `Ji Li <https://github.com/GeLee-Q>`_, `Shenggui Li <https://franklee.xyz/about>`_, `Junrong Lin <https://github.com/ocss884>`_, `Xiang Long <https://github.com/SwordFaith>`_, `Rui Lu <https://scholar.google.com/citations?user=-MGuqDcAAAAJ>`_, `Jin Pan <https://jhinpan.github.io/>`_, `Shuai Shi <https://github.com/shuaills>`_, `Yushen Su <https://yushengsu-thu.github.io/>`_, `Xinyuan Tong <https://github.com/JustinTong0323>`_, `Chendong Wang <https://github.com/cedricbeta>`_, `Hanchen Zhang <https://scholar.google.com/citations?user=pGcJcagAAAAJ>`_, `Haoran Wang <https://ubecc.github.io/about/>`_, `Yongan Xiang <https://github.com/BearBiscuit05>`_, `Chengxing Xie <https://yitianlian.github.io/>`_, `Yuhao Yang <https://github.com/yhyang201>`_, `Jinwei Yao <https://kivi-yao.github.io/>`_, `Qiaolin Yu <https://github.com/Qiaolin-Yu>`_, `Yuzhen Zhou <https://github.com/zyzshishui>`_, `Chenyang Zhao <https://github.com/zhaochenyang20>`_\n\n\n\nIntroduction\n------------\n`SGLang <https://github.com/sgl-project/sglang>`_ is an open-source state-of-the-art inference service engine, fully adopted by xAI to support all inference needs of Grok during research and serving processes.\n\nCurrently, verl fully supports using SGLang as the inference engine during the rollout phase. As a rollout engine, SGLang provides the same feature coverage as vLLM., including memory saving and multi-node rollout features. After installing verl and SGLang, simply add ``actor_rollout_ref.rollout.name=sglang`` at startup script to seamlessly switch between the two inference frameworks.\n\nIn addition, the SGLang team is actively working on supporting features such as Multi-Turn Agentic RL, VLM RLHF, Server-Based RLHF, and Partial Rollout. You can track the related development progress in the `Tracking Roadmap <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/74>`_.\n\nInstallation\n------------\nPlease always follow the following command to install SGLang with verl. \n\n.. code-block:: bash\n    \n    pip install --upgrade pip\n    # Currently 0.4.6.post5, subject to updates at any time, please refer to the latest version specified in `setup.py`\n    pip install -e \".[sglang]\"\n\nYou can check the following dependencies are in your environment:\n\n.. note::\n\n    - **PyTorch**: 2.6.0+cu124\n    - **CUDA**: 12.4\n    - **flashinfer-python**: 0.2.5+cu124torch2.6\n    - **sgLang**: 0.4.6.post5\n    - **sgl-kernel**: 0.1.4\n\nUsing SGLang as the Inference Backend for PPO Training on a Single Machine\n-------------------------------------------------------------------------\nWe use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test.\n\n1. Run the following command to prepare the gsm8k dataset:\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py\n\n2. Run the following script to conduct a PPO experiment on a single machine with 4 GPUs:\n\n.. 
code-block:: bash\n\n    export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True\n    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n        data.train_files=$HOME/data/gsm8k/train.parquet \\\n        data.val_files=$HOME/data/gsm8k/test.parquet \\\n        data.train_batch_size=4096 \\\n        data.max_prompt_length=4096 \\\n        data.max_response_length=4096 \\\n        actor_rollout_ref.rollout.name=sglang \\\n        actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        critic.optim.lr=1e-5 \\\n        critic.model.path=Qwen/Qwen2-7B-Instruct \\\n        critic.ppo_micro_batch_size_per_gpu=4 \\\n        critic.model.fsdp_config.param_offload=True \\\n        critic.model.fsdp_config.optimizer_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.logger=console \\\n        trainer.val_before_train=False \\\n        trainer.n_gpus_per_node=4 \\\n        trainer.nnodes=1 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log\n\nWhy export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples.\n\n2. ``SGLangRollout`` initializes an ``Engine``, which in turn initializes a ``torch.distributed.DeviceMesh`` used to support Tensor Parallelism (TP).\n\n3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks.\n\nWhy might there be inconsistent GPU memory?\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n**1. Ray Distributed Actor loads the model at different times**\n\n``verl`` uses Ray-based multi-process, multi-GPU concurrent training. Each ``WorkerDict`` may be called at different times:\n\n.. code-block:: python\n\n    self.rollout = SGLangRollout(...)\n\nDifferent workers initialize the model at different times → different memory usage.\n\n**2. Delayed initialization causes memory bias**\n\nSome workers start model loading/inference (e.g., ``generate_sequences()``, ``compute_log_prob()``) earlier than others.  \nEarly workers already use up GPU memory → late workers still have empty memory → memory difference appears.\n\n**3. SGLang's TP init uses \"all-device broadcast\", but there's no uniform release timing**\n\nAlthough ``SGLangRollout`` may only involve a subset of the GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:\n\n- Non-rollout GPUs also join the communication.\n- Later on, ``DeviceMesh`` init will fail due to \"inconsistent memory\".\n\n**4. Different FSDP/TP loading behaviors also lead to mismatch**\n\nIf using:\n\n.. 
code-block:: bash\n\n    actor.fsdp_config.param_offload=True\n    ref.fsdp_config.param_offload=True\n\nThen some workers keep params on the CPU while others have already sharded theirs to the GPU → asymmetric memory layout.\n\nUsing SGLang as the Inference Backend for PPO Training Across Multiple Machines\n--------------------------------------------------------------------------------\nSGLang also supports running verl's Ray-based cross-machine inference in IPv4 and IPv6 scenarios. In the script below, we use TP=16 for cross-machine inference. Suppose we have two interconnected machines: node0 with IP 10.94.16.4 and node1 with IP 10.94.16.5.\n\n1. Start Ray on node0:\n\n.. code-block:: bash\n\n    ray start --head --dashboard-host=0.0.0.0\n\nYou will see the following prompt:\n\n.. code-block:: bash\n\n    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.\n\n    Local node IP: 10.94.16.4\n\n    --------------------\n    Ray runtime started.\n    --------------------\n\n    Next steps\n    To add another node to this Ray cluster, run\n        ray start --address='10.94.16.4:6379'\n\n2. Have node1 join the Ray cluster:\n\nRun the following command on node1:\n\n.. code-block:: bash\n\n    ray start --address='10.94.16.4:6379'\n\nRun the following command to confirm that the Ray cluster now has two nodes:\n\n.. code-block:: bash\n\n    ray status\n\nYou can see that the cluster has two nodes with 16 GPUs:\n\n.. code-block:: bash\n\n    ======== Autoscaler status: 2025-04-09 09:25:37.694016 ========\n    Node status\n    ---------------------------------------------------------------\n    Active:\n     1 node_ef382ffd687d8f6b060c1b68e63ada7341b936fe5b1901dd04de1027\n     1 node_1eb4d7d07e793114c23a89d1a41f1f76acf6ef5b35af844a4ee8e4ba\n    Pending:\n     (no pending nodes)\n    Recent failures:\n     (no failures)\n\n    Resources\n    ---------------------------------------------------------------\n    Usage:\n     0.0/360.0 CPU\n     0.0/16.0 GPU\n     0B/3.39TiB memory\n     0B/372.53GiB object_store_memory\n\n3. Run the following script to train meta-llama/Llama-3.1-8B-Instruct with TP=16 across 2 machines using 16 GPUs:\n\n.. 
code-block:: bash\n\n    DATA_DIR=$HOME/data/gsm8k\n\n    python3 -m verl.trainer.main_ppo \\\n        actor_rollout_ref.rollout.name=sglang \\\n        data.train_files=$DATA_DIR/train.parquet \\\n        data.val_files=$DATA_DIR/test.parquet \\\n        data.train_batch_size=4096 \\\n        data.max_prompt_length=4096 \\\n        data.max_response_length=4096 \\\n        actor_rollout_ref.model.path=meta-llama/Llama-3.1-8B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=1e-6 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n        actor_rollout_ref.rollout.free_cache_engine=True \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size=16 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        critic.optim.lr=1e-5 \\\n        critic.model.use_remove_padding=True \\\n        critic.model.path=meta-llama/Llama-3.1-8B-Instruct \\\n        critic.model.enable_gradient_checkpointing=True \\\n        critic.ppo_micro_batch_size=16 \\\n        critic.model.fsdp_config.param_offload=True \\\n        critic.model.fsdp_config.optimizer_offload=True \\\n        algorithm.kl_ctrl.kl_coef=0.001 \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=console \\\n        trainer.val_before_train=True \\\n        trainer.n_gpus_per_node=8 \\\n        trainer.nnodes=2 \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=10 \\\n        trainer.total_epochs=15 2>&1 | tee verl_demo.log\n
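\nTo build intuition for the memory-imbalance check described above, here is a standalone sketch of comparing free GPU memory across ranks. It only illustrates the idea and is not SGLang's actual implementation; the function name and tolerance are invented for this example.\n\n.. code-block:: python\n\n    import torch\n    import torch.distributed as dist\n\n    def check_memory_balance(tolerance: float = 0.1) -> None:\n        \"\"\"Raise if free GPU memory differs too much across ranks (illustrative only).\"\"\"\n        free, total = torch.cuda.mem_get_info()\n        gathered = [None] * dist.get_world_size()\n        dist.all_gather_object(gathered, free)\n        if (max(gathered) - min(gathered)) / total > tolerance:\n            raise RuntimeError(\n                f\"GPU memory imbalance across ranks: {gathered}; \"\n                \"set SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True to skip the real check.\"\n            )\n"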
  },
  {
    "path": "verl_rl/examples/data_preprocess/aime2024_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the DAPO-Math-17k dataset to multiturn format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/retool_aime2024\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_path = \"BytedTsinghua-SIA/AIME-2024\"\n    dataset = datasets.load_dataset(data_path, \"default\")\n\n    train_dataset = dataset[\"train\"]\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            orig_extra_info = example.pop(\"extra_info\")\n            extra_info = orig_extra_info.copy()\n            extra_info[\"need_tools_kwargs\"] = True\n            extra_info[\"tools_kwargs\"] = {\n                \"code_interpreter\": {\n                    \"create_kwargs\": {\n                        \"ground_truth\": example[\"reward_model\"][\"ground_truth\"],\n                    },\n                },\n            }\n            example[\"extra_info\"] = extra_info\n            return example\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/dapo_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the DAPO-Math-17k dataset to multiturn format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/retool_dapo\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_path = \"BytedTsinghua-SIA/DAPO-Math-17k\"\n    dataset = datasets.load_dataset(data_path, \"default\")\n\n    train_dataset = dataset[\"train\"]\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            orig_extra_info = example.pop(\"extra_info\")\n            extra_info = orig_extra_info.copy()\n            extra_info[\"need_tools_kwargs\"] = True\n            extra_info[\"tools_kwargs\"] = {\n                \"code_interpreter\": {\n                    \"create_kwargs\": {\n                        \"ground_truth\": example[\"reward_model\"][\"ground_truth\"],\n                    },\n                },\n            }\n            example[\"extra_info\"] = extra_info\n            return example\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/full_hh_rlhf.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n- Preprocess data and split the training set into 75% for training RM and 25% for validting RM.\n- All the training data is used to train SFT and RL.\n- Both chosen and rejected is used to train SFT\n\"\"\"\n\nimport argparse\nimport os\n\nimport pandas as pd\nfrom datasets import load_dataset\nfrom tqdm.auto import tqdm\n\nfrom verl.utils.fs import copy, makedirs\n\n\ndef generate_sft_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlh/sft\"):\n    dataset = load_dataset(\"Dahoas/full-hh-rlhf\")\n    output = {\"prompt\": [], \"response\": []}\n    for data in tqdm(dataset[\"train\"]):\n        # add chosen\n        output[\"prompt\"].append(data[\"prompt\"])\n        output[\"response\"].append(data[\"chosen\"])\n\n        # add rejection\n        output[\"prompt\"].append(data[\"prompt\"])\n        output[\"response\"].append(data[\"rejected\"])\n\n    df = pd.DataFrame(output)\n\n    local_dir = os.path.expanduser(local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    local_path = os.path.join(local_dir, \"train.parquet\")\n\n    df.to_parquet(path=local_path)\n\n    if target_hdfs_path_dir is not None:\n        hdfs_dir = target_hdfs_path_dir + \"/\" + \"train.parquet\"\n        makedirs(hdfs_dir)\n\n        copy(local_path, hdfs_dir)\n\n\ndef generate_rm_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlh/rm\"):\n    train_dataset = load_dataset(\"Dahoas/full-hh-rlhf\", split=\"train[:75%]\")\n    test_dataset = load_dataset(\"Dahoas/full-hh-rlhf\", split=\"train[-25%:]\")\n\n    local_dir = os.path.expanduser(local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    for dataset, name in zip([train_dataset, test_dataset], [\"train\", \"test\"], strict=True):\n        output = {\"prompt\": [], \"chosen\": [], \"rejected\": []}\n        for data in tqdm(dataset):\n            # add chosen\n            output[\"prompt\"].append(data[\"prompt\"])\n            output[\"chosen\"].append(data[\"chosen\"])\n            output[\"rejected\"].append(data[\"rejected\"])\n\n        df = pd.DataFrame(output)\n\n        local_path = os.path.join(local_dir, name + \".parquet\")\n\n        df.to_parquet(path=local_path)\n\n        if target_hdfs_path_dir is not None:\n            hdfs_dir = target_hdfs_path_dir + \"/\" + name + \".parquet\"\n            makedirs(hdfs_dir)\n\n            copy(local_path, hdfs_dir)\n\n\ndef generate_rl_dataset(target_hdfs_path_dir, local_dir=\"~/data/full_hh_rlhf/rl\"):\n    dataset = load_dataset(\"Dahoas/full-hh-rlhf\")\n    train_dataset = dataset[\"train\"]\n\n    data_source = \"Dahoas/full-hh-rlhf\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            prompt = example.pop(\"prompt\")\n            response = example.pop(\"response\")\n\n            data = {\n                \"data_source\": data_source,\n              
  \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n                \"ability\": \"alignment\",\n                \"reward_model\": {\n                    \"style\": \"model\",\n                    \"ground_truth\": response,  # should not be used\n                },\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    local_dir = os.path.expanduser(local_dir)\n    local_path = os.path.join(local_dir, \"train.parquet\")\n    train_dataset.to_parquet(local_path)\n\n    if target_hdfs_path_dir is not None:\n        hdfs_dir = target_hdfs_path_dir + \"/\" + \"train.parquet\"\n        makedirs(hdfs_dir)\n\n        copy(local_path, hdfs_dir)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--split\", type=str, choices=[\"sft\", \"rm\", \"rl\"], required=True)\n    parser.add_argument(\"--local_dir\", type=str, default=\"~/data/full_hh_rlhf\")\n    parser.add_argument(\"--hdfs_dir\", type=str, required=False, default=None)\n\n    args = parser.parse_args()\n\n    if args.split == \"sft\":\n        generate_sft_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))\n    elif args.split == \"rm\":\n        generate_rm_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))\n    elif args.split == \"rl\":\n        generate_rl_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the Geometry3k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/geo3k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"hiyouga/geometry3k\"\n\n    dataset = datasets.load_dataset(data_source)\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = (\n        r\"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. \"\n        r\"The reasoning process MUST BE enclosed within <think> </think> tags. \"\n        r\"The final answer MUST BE put in \\boxed{}.\"\n    )\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            problem = example.pop(\"problem\")\n            prompt = problem + \" \" + instruction_following\n            answer = example.pop(\"answer\")\n            images = example.pop(\"images\")\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    }\n                ],\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                    \"question\": problem,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=8)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=8)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/geo3k_multiturn_w_tool.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 Reallm Labs Ltd. or its affiliates\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPreprocess the Geometry3k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/geo3k_multiturn_w_tool\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    args = parser.parse_args()\n    data_source = \"hiyouga/geometry3k\"\n    dataset = datasets.load_dataset(data_source)\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n    instruction_following = (\n        r\"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. \"\n        r\"The reasoning process MUST BE enclosed within <think> </think> tags. \"\n        r\"The final answer MUST BE put in \\boxed{}.\"\n    )\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            problem = example.pop(\"problem\")\n            prompt = problem + \" \" + instruction_following\n            answer = example.pop(\"answer\")\n            images = example.pop(\"images\")\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_geo3k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    },\n                ],\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                    \"question\": problem,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_geo3k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": answer},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=8)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=8)\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/gsm8k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"openai/gsm8k\"\n\n    dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = 'Let\\'s think step by step and output the final answer after \"####\".'\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    }\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/gsm8k_multiturn_w_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/gsm8k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"openai/gsm8k\"\n    dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"You should rethinking carefully if user point out your answer is wrong. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"interaction_kwargs\": {\n                        \"name\": \"gsm8k\",\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/gsm8k_multiturn_w_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/gsm8k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"openai/gsm8k\"\n    dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_gsm8k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_gsm8k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": solution},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                    \"interaction_kwargs\": {\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/gsm8k_tool_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/gsm8k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"openai/gsm8k\"\n    dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer after `####`.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"agent_name\": \"tool_agent\",\n                \"prompt\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": (\n                            \"You are a math expert. You are given a question and you need to solve it step by step. \"\n                            \"Reasoning step by step before any tool call. \"\n                            \"You should use the `calc_gsm8k_reward` tool after step by step solving the question, \"\n                            \"before generate final answer at least once and refine your answer if necessary. 
\"\n                            \"Put your final answer in the format of `#### <answer>`.\"\n                        ),\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                    \"need_tools_kwargs\": True,\n                    \"tools_kwargs\": {\n                        \"calc_gsm8k_reward\": {\n                            \"create_kwargs\": {\"ground_truth\": solution},\n                            # \"execute_kwargs\": {},\n                            # \"calc_reward_kwargs\": {},\n                            # \"release_kwargs\": {},\n                        },\n                    },\n                    \"interaction_kwargs\": {\n                        \"query\": question,\n                        \"ground_truth\": solution,\n                    },\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/hellaswag.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess Hellaswag dataset.\n\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef preprocess(text):\n    text = text.strip()\n    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.\n    text = text.replace(\" [title]\", \". \")\n    text = re.sub(\"\\\\[.*?\\\\]\", \"\", text)\n    text = text.replace(\"  \", \" \")\n    return text\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"/opt/tiger/hellaswag\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"Rowan/hellaswag\"\n\n    dataset = datasets.load_dataset(data_source, trust_remote_code=True)\n\n    train_dataset = dataset[\"train\"]\n    val_dataset = dataset[\"validation\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction = \"Please complete the following sentence.\\n\"\n\n    def make_map_fn(split):\n        def process_fn(doc, idx):\n            ctx = doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize()\n            query = preprocess(doc[\"activity_label\"] + \": \" + ctx)\n            choices = [preprocess(ending) for ending in doc[\"endings\"]]\n            gold = int(doc[\"label\"])\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": query}],\n                \"ability\": \"nlp\",\n                \"reward_model\": {\n                    \"style\": \"model\",\n                    \"eval\": \"multiple_choice\",  # using loglikelihood\n                    \"ground_truth\": gold,\n                    \"choices\": choices,\n                },\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    # filter data that doesn't have a label\n    train_dataset = train_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n    val_dataset = val_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n    test_dataset = test_dataset.filter(lambda x: len(x[\"label\"]) > 0)\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    val_dataset = val_dataset.map(function=make_map_fn(\"validation\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    val_dataset.to_parquet(os.path.join(local_dir, \"validation.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/math_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the MATH-lighteval dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom verl.utils.hdfs_io import copy, makedirs\nfrom verl.utils.reward_score.math import last_boxed_only_string, remove_boxed\n\n\ndef extract_solution(solution_str):\n    return remove_boxed(last_boxed_only_string(solution_str))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/math\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    # 'lighteval/MATH' is no longer available on huggingface.\n    # Use mirror repo: DigitalLearningGmbH/MATH-lighteval\n    data_source = \"DigitalLearningGmbH/MATH-lighteval\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = datasets.load_dataset(data_source, trust_remote_code=True)\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer within \\\\boxed{}.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question = example.pop(\"problem\")\n\n            question = question + \" \" + instruction_following\n\n            answer = example.pop(\"solution\")\n            solution = extract_solution(answer)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": question}],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/multiturn.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCreate a simple multi-turn dataset for testing\n\"\"\"\n\nimport argparse\nimport os\n\nimport pandas as pd\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/multiturn\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    args = parser.parse_args()\n\n    # Create example conversations\n    conversations = []\n\n    # Conversation 1\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n                {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"},\n                {\"role\": \"user\", \"content\": \"And what about Germany?\"},\n                {\"role\": \"assistant\", \"content\": \"The capital of Germany is Berlin.\"},\n            ]\n        }\n    )\n\n    # Conversation 2\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Can you explain quantum computing?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": \"Quantum computing is a type of computing that uses quantum-mechanical phenomena, \"\n                    \"such as superposition and entanglement, to perform operations on data.\",\n                },\n                {\"role\": \"user\", \"content\": \"How is it different from classical computing?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": \"Classical computing uses bits that are either 0 or 1, while quantum computing uses \"\n                    \"quantum bits or qubits that can exist in multiple states simultaneously due to superposition.\",\n                },\n            ]\n        }\n    )\n\n    # Conversation 3\n    conversations.append(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Write a simple Python function to calculate factorial.\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": (\n                        \"```python\\ndef factorial(n):\\n    if n == 0 or n == 1:\\n        return 1\\n    else:\\n        \"\n                        \"return n * factorial(n-1)\\n```\\n\\nThis is a recursive function to calculate the \"\n                        \"factorial of a number.\"\n                    ),\n                },\n                {\"role\": \"user\", \"content\": \"Can you make it iterative instead?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": (\n 
                       \"```python\\ndef factorial(n):\\n    result = 1\\n    for i in range(1, n+1):\\n        \"\n                        \"result *= i\\n    return result\\n```\\n\\nThis is an iterative version of the factorial function.\"\n                    ),\n                },\n            ]\n        }\n    )\n\n    # Create train and test datasets\n    train_data = conversations[:2]  # First 2 conversations for training\n    test_data = conversations[2:]  # Last conversation for testing\n\n    # Create output directory\n    local_dir = os.path.expanduser(args.local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    # Save to parquet files\n    train_df = pd.DataFrame(train_data)\n    test_df = pd.DataFrame(test_data)\n\n    train_df.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_df.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    # Handle HDFS if specified\n    if args.hdfs_dir is not None:\n        try:\n            from verl.utils.hdfs_io import copy, makedirs\n\n            makedirs(args.hdfs_dir)\n            copy(src=local_dir, dst=args.hdfs_dir)\n        except ImportError:\n            print(\"Warning: HDFS support not available. Skipping HDFS copy.\")\n\n    # Print statistics\n    print(f\"Train dataset size: {len(train_df)}\")\n    print(f\"Test dataset size: {len(test_df)}\")\n    print(f\"Data saved to {local_dir}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/examples/data_preprocess/preprocess_search_r1_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport argparse\r\nimport logging\r\nimport os\r\nimport tempfile\r\n\r\nimport pandas as pd\r\nfrom huggingface_hub import hf_hub_download\r\nfrom huggingface_hub.utils import EntryNotFoundError\r\n\r\nfrom verl.utils.hdfs_io import copy, makedirs\r\n\r\n# Setup logging\r\nlogging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\r\nlogger = logging.getLogger(__name__)\r\n\r\n# Configuration constants\r\nDEFAULT_SYSTEM_CONTENT = \"You are a helpful and harmless assistant.\"\r\nDEFAULT_USER_CONTENT_PREFIX = (\r\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\r\n    \"first every time you get new information. After reasoning, if you find you lack \"\r\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\r\n    \"and it will return the top searched results between <tool_response> and \"\r\n    \"</tool_response>. You can search as many times as your want. If you find no \"\r\n    \"further external knowledge needed, you can directly provide the answer inside \"\r\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\r\n    \"<answer> Beijing </answer>. 
Question: \"\r\n)\r\n\r\n\r\ndef process_single_row(row, current_split_name, row_index):\r\n    \"\"\"\r\n    Process a single row of data for SearchR1-like format.\r\n\r\n    Args:\r\n        row: DataFrame row containing the original data\r\n        current_split_name: Name of the current split (train/test)\r\n        row_index: Index of the row in the DataFrame\r\n\r\n    Returns:\r\n        pd.Series: Processed row data in the required format\r\n    \"\"\"\r\n    question = row.get(\"question\", \"\")\r\n\r\n    # Build prompt structure\r\n    user_content = user_content_prefix.rstrip(\"\\n\") + question\r\n    prompt = [{\"role\": \"system\", \"content\": system_content}, {\"role\": \"user\", \"content\": user_content}]\r\n\r\n    # Extract ground truth from reward_model or fallback to golden_answers\r\n    reward_model_data = row.get(\"reward_model\")\r\n    if isinstance(reward_model_data, dict) and \"ground_truth\" in reward_model_data:\r\n        ground_truth = reward_model_data.get(\"ground_truth\")\r\n    else:\r\n        ground_truth = row.get(\"golden_answers\", [])\r\n\r\n    # Process data source\r\n    data_source_tagged = \"searchR1_\" + str(row.get(\"data_source\", \"\"))\r\n\r\n    # Build tools kwargs structure\r\n    tools_kwargs = {\r\n        \"search\": {\r\n            \"create_kwargs\": {\"ground_truth\": ground_truth, \"question\": question, \"data_source\": data_source_tagged}\r\n        }\r\n    }\r\n\r\n    # Build complete extra_info structure\r\n    extra_info = {\r\n        \"index\": row_index,\r\n        \"need_tools_kwargs\": True,\r\n        \"question\": question,\r\n        \"split\": current_split_name,\r\n        \"tools_kwargs\": tools_kwargs,\r\n    }\r\n\r\n    return pd.Series(\r\n        {\r\n            \"data_source\": data_source_tagged,\r\n            \"prompt\": prompt,\r\n            \"ability\": row.get(\"ability\"),\r\n            \"reward_model\": reward_model_data,\r\n            \"extra_info\": extra_info,\r\n            \"metadata\": row.get(\"metadata\"),\r\n        }\r\n    )\r\n\r\n\r\ndef main():\r\n    local_save_dir = os.path.expanduser(args.local_dir)\r\n    os.makedirs(local_save_dir, exist_ok=True)\r\n\r\n    processed_files = []\r\n\r\n    # Download and process files using temporary directory\r\n    with tempfile.TemporaryDirectory() as tmp_download_dir:\r\n        for split in [\"train\", \"test\"]:\r\n            parquet_filename = f\"{split}.parquet\"\r\n            logger.info(f\"Processing {split} split...\")\r\n\r\n            try:\r\n                # Download Parquet file from HuggingFace\r\n                logger.info(f\"Downloading {parquet_filename} from {args.hf_repo_id}\")\r\n                local_parquet_filepath = hf_hub_download(\r\n                    repo_id=args.hf_repo_id,\r\n                    filename=parquet_filename,\r\n                    repo_type=\"dataset\",\r\n                    local_dir=tmp_download_dir,\r\n                    local_dir_use_symlinks=False,\r\n                )\r\n\r\n                # Load and process Parquet file\r\n                df_raw = pd.read_parquet(local_parquet_filepath)\r\n                logger.info(f\"Loaded {len(df_raw)} rows from {parquet_filename}\")\r\n\r\n                def apply_process_row(row, split_name=split):\r\n                    return process_single_row(row, current_split_name=split_name, row_index=row.name)\r\n\r\n                df_processed = df_raw.apply(apply_process_row, axis=1)\r\n\r\n                # Save processed DataFrame\r\n        
        output_file_path = os.path.join(local_save_dir, f\"{split}.parquet\")\r\n                df_processed.to_parquet(output_file_path, index=False)\r\n                logger.info(f\"Saved {len(df_processed)} processed rows to {output_file_path}\")\r\n                processed_files.append(output_file_path)\r\n\r\n            except EntryNotFoundError:\r\n                logger.warning(f\"{parquet_filename} not found in repository {args.hf_repo_id}\")\r\n            except Exception as e:\r\n                logger.error(f\"Error processing {split} split: {e}\")\r\n\r\n    if not processed_files:\r\n        logger.warning(\"No data was processed or saved\")\r\n        return\r\n\r\n    logger.info(f\"Successfully processed {len(processed_files)} files to {local_save_dir}\")\r\n\r\n    # Copy to HDFS if specified\r\n    if args.hdfs_dir:\r\n        try:\r\n            makedirs(args.hdfs_dir)\r\n            copy(src=local_save_dir, dst=args.hdfs_dir)\r\n            logger.info(f\"Successfully copied files to HDFS: {args.hdfs_dir}\")\r\n        except Exception as e:\r\n            logger.error(f\"Error copying files to HDFS: {e}\")\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    parser = argparse.ArgumentParser(description=\"Download Search-R1 from HuggingFace, process, and save to Parquet.\")\r\n    parser.add_argument(\r\n        \"--hf_repo_id\", default=\"PeterJinGo/nq_hotpotqa_train\", help=\"HuggingFace dataset repository ID.\"\r\n    )\r\n    parser.add_argument(\r\n        \"--local_dir\",\r\n        default=\"~/data/searchR1_processed_direct\",\r\n        help=\"Local directory to save the processed Parquet files.\",\r\n    )\r\n    parser.add_argument(\"--hdfs_dir\", default=None, help=\"Optional HDFS directory to copy the Parquet files to.\")\r\n\r\n    args = parser.parse_args()\r\n\r\n    # System and user content configuration\r\n    system_content = DEFAULT_SYSTEM_CONTENT\r\n    user_content_prefix = DEFAULT_USER_CONTENT_PREFIX\r\n\r\n    main()\r\n"
  },
  {
    "path": "verl_rl/examples/generation/run_deepseek7b_mutli_node.sh",
    "content": "set -x\n\ndata_path=$HOME/data/rlhf/gsm8k/test.parquet\nsave_path=$HOME/data/rlhf/math/deepseek_v2_lite_gen_test.parquet\nmodel_path=deepseek-ai/deepseek-llm-7b-chat\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=2 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$data_path \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=$save_path \\\n    model.path=$model_path\\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=16 \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_rl/examples/generation/run_deepseek_v2_lite_math.sh",
    "content": "set -x\n\ndata_path=$HOME/data/gsm8k/test.parquet\nsave_path=$HOME/data/gsm8k/deepseek_v2_lite_gen_test.parquet\nmodel_path=deepseek-ai/deepseek-llm-7b-chat\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$data_path \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=$save_path \\\n    model.path=$model_path \\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=2 \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_rl/examples/gpg_trainer/gpg.md",
    "content": "# GPG: Group Policy Gradient\n\nGroup Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning\n](https://arxiv.org/abs/2504.02546).\n\n## Key Components\n- Use a corrected advantage function to improve policy gradient accuracy and training efficiency.\n- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO)\n\n## Configuration\nTo configure GPG within the framework, use the following YAML settings.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg \nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"gpg\"\n```\n\n## Advanced Extensions\nGPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance.\n\n```yaml\nalgorithm:\n  adv_estimator: gpg\nactor_rollout_ref:\n  actor:\n    use_kl_loss: True # enable kl regularization\n    kl_loss_coef: 0.01\n    policy_loss:\n      loss_mode: \"gpg\"\n```"
  },
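  {
    "path": "verl_rl/examples/gpg_trainer/gpg_advantage_sketch.py",
    "content": "\"\"\"Illustrative sketch only (hypothetical file, not part of verl).\n\nA minimal PyTorch version of the update described in gpg.md: a group-mean\nbaseline instead of a learned critic, and a plain policy-gradient loss with\nno clipping, no KL penalty, and no reference model. gpg.md notes that GPG\nuses a corrected advantage function; for brevity this sketch uses the plain\ngroup-mean baseline. Shapes and helper names are assumptions made for this\nexample, not verl's internal API.\n\"\"\"\n\nimport torch\n\n\ndef group_baseline_advantages(rewards: torch.Tensor) -> torch.Tensor:\n    # rewards: (num_prompts, group_size), one scalar reward per sampled response.\n    # The advantage of each response is its reward minus its group's mean reward.\n    return rewards - rewards.mean(dim=-1, keepdim=True)\n\n\ndef gpg_style_loss(log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:\n    # log_probs: (num_prompts, group_size), summed log-probability of each sampled\n    # response under the current policy. Vanilla policy gradient: advantage-weighted\n    # negative log-likelihood, with no surrogate objective.\n    return -(advantages.detach() * log_probs).mean()\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    # Two prompts, four sampled responses each, rule-based 0/1 rewards.\n    rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]])\n    log_probs = torch.randn(2, 4, requires_grad=True)\n    loss = gpg_style_loss(log_probs, group_baseline_advantages(rewards))\n    loss.backward()\n    print(f\"loss = {loss.item():.4f}\")\n"
  },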
  {
    "path": "verl_rl/examples/gpg_trainer/run_qwen2-7b_math.sh",
    "content": "set -x\n\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gpg \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_gpg_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh",
    "content": "set -x\n\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=gpg \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=gpg \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_gpg_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/README.md",
    "content": "# Group Relative Policy Optimization (GRPO)\n\nIn reinforcement learning, classic algorithms like PPO rely on a \"critic\" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. \n\nGRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows:\n- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a \"group\" of outputs.\n- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality.\n- Baseline Calculation: The average reward of the group serves as a baseline. \n- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones.\n\nThis approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300)\n\n## Key Components\n\n- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic)\n- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group.\n- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nDespite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic).\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling.\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers.\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor\n\n- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2\n\n- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead\n\n- `actor_rollout_ref.actor.loss_agg_mode`: Default is \"token-mean\". Options include \"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. 
All GRPO example scripts provided in verl uses the default configuration \"token-mean\" for loss aggregation instead.\n\nInstead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss:\n\n- `actor_rollout_ref.actor.use_kl_loss`: To use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO.\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html\n\n## Advanced Extensions\n\n### DrGRPO\n\nThe work [Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's optimization bias in GRPO, that leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization, which can inadvertently favor longer, less accurate responses. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias.\n\nConfigure the following to enable DrGRPO, with all other parameters the same as GRPO's:\n\n- `actor_rollout_ref.actor.loss_agg_mode`: \"seq-mean-token-sum-norm\", which turns off seq-dim averaging\n- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO\n- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm\n\n## Reference Example\n\nQwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log)\n\n```bash\nbash examples/grpo_trainer/run_qwen3-8b.sh\n```\n\nFor more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html\n"
  },
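  {
    "path": "verl_rl/examples/grpo_trainer/grpo_advantage_sketch.py",
    "content": "\"\"\"Illustrative sketch only (hypothetical file, not part of verl).\n\nA minimal PyTorch version of the group-relative advantage described in the\nGRPO README: each response's reward is normalized against its own group's\nmean and, optionally, its standard deviation. Disabling the std division\nmirrors the DrGRPO-style `algorithm.norm_adv_by_std_in_grpo=False` setting.\nShapes and helper names are assumptions made for this example.\n\"\"\"\n\nimport torch\n\n\ndef grpo_advantages(rewards: torch.Tensor, norm_by_std: bool = True, eps: float = 1e-6) -> torch.Tensor:\n    # rewards: (num_prompts, group_size); group_size corresponds to the\n    # actor_rollout_ref.rollout.n responses sampled per prompt.\n    centered = rewards - rewards.mean(dim=-1, keepdim=True)\n    if norm_by_std:\n        # Standard GRPO: scale by the group's reward standard deviation.\n        centered = centered / (rewards.std(dim=-1, keepdim=True) + eps)\n    return centered\n\n\nif __name__ == \"__main__\":\n    # One prompt, five sampled responses (rollout.n=5), rule-based 0/1 rewards.\n    rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0]])\n    print(grpo_advantages(rewards))  # GRPO: mean/std-normalized\n    print(grpo_advantages(rewards, norm_by_std=False))  # DrGRPO-style: no std norm\n"
  },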
  {
    "path": "verl_rl/examples/grpo_trainer/run_deepseek671b_math_megatron.sh",
    "content": "set -x\n\n# 0. download the config\n# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json`\n# remove the `quantization_config` in the `config.json`\n# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported\n\nhuggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json\n\n# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main\n# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path\nDIST_CKPT_PATH=\"<path_to_dist_ckpt>\"\nLLM=\"<path_to_dsv3_config>\"\n\n\n# 2. run the script\ngsm8k_train_path=/data/gsm8k/train.parquet\ngsm8k_test_path=/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nALL_OFFLOAD=${ALL_OFFLOAD:-True}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n# 512 H20(96GB)\nNODES=64\nPP=16\nTP=1\nEP=32\nETP=1\nINFER_TP=32\n# consider TP/ETP, and enable recompute if short of memory\n\n# full recompute\n# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n\nn_resp_per_prompt=4\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . 
--\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$LLM \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.temperature=1.0 \\\n    actor_rollout_ref.rollout.top_p=1.0 \\\n    actor_rollout_ref.rollout.top_k=-1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger='[\"console\",\"tensorboard\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='dsv3-32nodes' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    trainer.default_local_dir=$CKPT_DIR \\\n    trainer.val_before_train=False \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_deepseek7b_llm_math.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_math' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='deepseek_llm_7b_math_megatron' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_minicpmo2_6.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    data.trust_remote_code=True \\\n    data.custom_cls.path=recipe/minicpmo/rl_dataset.py \\\n    data.custom_cls.name=RLHFDataset \\\n    actor_rollout_ref.model.path=openbmb/MiniCPM-o-2_6 \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    +actor_rollout_ref.actor.fsdp_config.use_orig_params=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='minicpmo2_6_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_moonlight16b_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nHF_MODEL_PATH=moonshotai/Moonlight-16B-A3B\nDIST_CKPT_PATH=${DIST_CKPT_PATH}\n\ntrain_path=$HOME/data/gsm8k/train.parquet\ntest_path=$HOME/data/gsm8k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=192 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.trust_remote_code=True \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=3 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=3 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='moonlight_megatron_ep' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=3 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b_math.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nrollout_mode=\"sync\"\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nUSE_FUSED_KERNELS=True\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh",
    "content": "set -x\n\n\n# For async rollout mode, dataset should return raw chat.\nrollout_mode=\"async\"\nrollout_name=\"sglang\" # sglang or vllm\nif [ \"$rollout_mode\" = \"async\" ]; then\n    export VLLM_USE_V1=1\n    return_raw_chat=\"True\"\nfi\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$rollout_name \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_3b_grpo_lora' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6\\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_32b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh",
    "content": "set -x\n\n# profiling configuration\nPROFILE_STEPS=\"[2,4]\"\nPROFILE_RANKS_ALL=False\nDISCRETE=True\nPROFILE_RANKS=\"[1,2]\"\n\n# profiling NPU options\nSAVE_PATH=\"$HOME/profile_data\"\nLEVEL=\"level1\"\nWITH_MEMORY=False\nRECORD_SHAPES=False\nWITH_NPU=True\nWITH_CPU=True\nWITH_MODULE=False\nWITH_STACK=False\nANALYSIS=True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.profiler.discrete=$DISCRETE \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.npu_profile.options.save_path=$SAVE_PATH \\\n    trainer.npu_profile.options.level=$LEVEL \\\n    trainer.npu_profile.options.with_memory=$WITH_MEMORY \\\n    trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \\\n    trainer.npu_profile.options.with_npu=$WITH_NPU \\\n    trainer.npu_profile.options.with_cpu=$WITH_CPU \\\n    trainer.npu_profile.options.with_module=$WITH_MODULE \\\n    trainer.npu_profile.options.with_stack=$WITH_STACK \\\n    trainer.npu_profile.options.analysis=$ANALYSIS \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.profile_steps=$PROFILE_STEPS \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh",
    "content": "set -x\n\n# profiling configuration\nPROFILE_STEPS=\"[2,4]\"\nPROFILE_RANKS_ALL=True\nDISCRETE=False\n\n# profiling NPU options\nSAVE_PATH=\"$HOME/profile_data\"\nLEVEL=\"level1\"\nWITH_MEMORY=False\nRECORD_SHAPES=False\nWITH_NPU=True\nWITH_CPU=True\nWITH_MODULE=False\nWITH_STACK=False\nANALYSIS=True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.profiler.discrete=$DISCRETE \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.npu_profile.options.save_path=$SAVE_PATH \\\n    trainer.npu_profile.options.level=$LEVEL \\\n    trainer.npu_profile.options.with_memory=$WITH_MEMORY \\\n    trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \\\n    trainer.npu_profile.options.with_npu=$WITH_NPU \\\n    trainer.npu_profile.options.with_cpu=$WITH_CPU \\\n    trainer.npu_profile.options.with_module=$WITH_MODULE \\\n    trainer.npu_profile.options.with_stack=$WITH_STACK \\\n    trainer.npu_profile.options.analysis=$ANALYSIS \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.profile_steps=$PROFILE_STEPS \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-8 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh",
    "content": "set -x\nENGINE=${1:-vllm}\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\nHF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct\nDIST_CKPT_PATH=${DIST_CKPT_PATH}\n\n# convert HF model to meagatron format offlinely\n# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n\n# megatron tuning guide:\n# 1. recommend to offload all states by setting ALL_OFFLOAD=True\n# 2. enable dynamic batch size by setting actor_rollout_ref.actor.use_dynamic_bsz=True ref.log_prob_use_dynamic_bsz=True rollout.log_prob_use_dynamic_bsz=True\n# 3. set ppo_max_token_len_per_gpu and log_prob_max_token_len_per_gpu as large as possible for better MFU (limited by GPU memory). assure ppo_max_token_len_per_gpu > max_prompt_length+max_response_length, if sequence length is too long, you can increase the TP/PP size\n# 4. if memory is very limited, enable full recompute, but the mfu will be 30% lower\n#        full recompute settings:\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \\\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \\\n#        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \\\n\nALL_OFFLOAD=${ALL_OFFLOAD:-True}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n\ntrain_path=$HOME/data/geo3k/train.parquet\ntest_path=$HOME/data/geo3k/test.parquet\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_path\" \\\n    data.val_files=\"$test_path\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl-7b.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:\n# export VLLM_ATTENTION_BACKEND=XFORMERS\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.model.exclude_modules='.*visual.*' \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=6144 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=6144 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-32B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_32b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_3b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen3-236b_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# Note that we set the response length to 4k. This results in many truncations at the beginning.\n# So the training dynamic acts as using RL to compress the math capabilities of QWen3 236b into 4k response instead of verbose thinking.\n# We can achieve 0.5 on AIME'24 after 30 steps.\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-236b-megatron-0531a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 4))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=0.1\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=256\nn_resp_per_prompt=4\ntrain_prompt_mini_bsz=16\n\n# H20 GPUs\nNNODES=${NNODES:-32}\n\n# Paths\n\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n\nMODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B\nMCORE_MODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B_dist_ckpt_mcore/\n\n# convert QWen3-235b-A22b to dist ckpt of mcore. Conversion process will take about 4 hours\n# python scripts/converter_hf_to_mcore.py --hf_model_path $MODEL_PATH --output_path $MCORE_MODEL_PATH --use_cpu_initialization\nCKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name}\nTRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet\nTEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\noffload=True\ngen_tp=8\ntrain_tp=4\ntrain_ep=4\ntrain_pp=8\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \\\n    
actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=5 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=5 \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=20 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen3-8b.sh",
    "content": "# Tested successfully on the hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 image.\n# It outperforms the Qwen2 7B base model by two percentage points on the test set of GSM8K.\n\nset -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-8B \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen3_8b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@"
  },
  {
    "path": "verl_rl/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh",
    "content": "set -x\n\nHF_MODEL_PATH=Qwen/Qwen3-30B-A3B\nDIST_CKPT_PATH=${DIST_CKPT_PATH}\n\npython scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=64 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k_math' \\\n    trainer.experiment_name='qwen3_30b_moe_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/README.md",
    "content": "# Proximal Policy Optimization (PPO)\n\nProximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning.\n\nTraditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from:\n\n- High variance and sample inefficiency.\n- Instability due to large policy updates.\n\nPPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives.\n\nFor more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347).\n\n## Key Components\n\n- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model.\n\n- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias.\n\n- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates.\n\n## Configuration\n\nNote that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior.\n\nMost critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.\n\n![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d)\n\n- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n`\n\n- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers\n\n- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2\n\n- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor\n\n- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs`\n\n- `algorithm.gamma`: discount factor\n\n- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator\n\n- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo\n\n## Advanced Extensions\n\n### KL Divergence Control\n\nOptions to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. 
For more technical details, see [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)\n\nOptions to use KL loss for KL divergence control: \n\n- `actor_rollout_ref.actor.use_kl_loss`: to use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False\n\n- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.\n\n- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html\n\nOptions to use KL penalty in the reward:\n\n- `algorithm.use_kl_in_reward`: Whether to enable in-reward kl penalty. Default is False.\n\n- `algorithm.kl_penalty`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. This defines the way to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html\n\n- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.\n- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.\n- `algorithm.kl_ctrl.horizon`: See source code of AdaptiveKLController for details.\n- `algorithm.kl_ctrl.target_kl`: See source code of AdaptiveKLController for details.\n\n### Dual-clip PPO\n\nThe Dual-Clip PPO introduces a approach by applying a lower bound to the policy ratio when the advantage is less than zero, when multiplied by a large raito, does not exceed a specified lower bound.\n\n![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139)\n\n- `actor_rollout_ref.actor.clip_ratio_c`: lower bound of the value for Dual-clip PPO, defaults to 3.0\n\n## Reference Example\n\nQwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log)\n\n```bash\nbash run_gemma.sh\n  trainer.n_gpus_per_node=1 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  trainer.logger=console \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  data.train_batch_size=256 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=2 \\\n  critic.ppo_micro_batch_size=2\n```\n\nReference performance with verl v0.2:\n\n| Model                          | Method          | Score | Link                                                                                           |\n|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------|\n| Qwen/Qwen2.5-0.5B-Instruct     | pretrained model | 36.4  | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/)                                        |\n| Qwen/Qwen2.5-0.5B-Instruct     | PPO              | 56.7  | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) |\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.use_legacy_worker_impl=auto \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh",
    "content": "set -x\n\nVERL_USE_MODELSCOPE=True \\\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    algorithm.use_pf_ppo=True \\\n    algorithm.pf_ppo.reweight_method=pow \\  # [\"pow\", \"max_min\", \"max_random\"]\n    algorithm.pf_ppo.weight_pow=2.0 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    reward_model.sandbox_fusion.url='https://xxxxxxxxx.apigateway-cn-beijing.volceapi.com/run_code' \\\n    reward_model.sandbox_fusion.max_concurrent=128 \\\n    reward_model.reward_manager=prime \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/Eurus-2-RL-Data/train.parquet \\\n    data.val_files=$HOME/data/Eurus-2-RL-Data/validation.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_sandbox_fusion' \\\n    trainer.experiment_name='deepseek_llm_7b_function_sandbox_fusion' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    critic.optim.lr=1e-5 \\\n    critic.ulysses_sequence_parallel_size=2 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=64 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh",
    "content": "set -x\n\ntrain_files=$HOME/data/full_hh_rlhf/rl/train.parquet\ntest_files=$HOME/data/full_hh_rlhf/rl/train.parquet # no use\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=128 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    reward_model.enable=True \\\n    reward_model.megatron.tensor_model_parallel_size=4 \\\n    reward_model.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    reward_model.micro_batch_size_per_gpu=4 \\\n    reward_model.param_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_full_hh_rlhf_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_model_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh",
    "content": "set -x\n\n# Example runnable on H20 * 8\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh",
    "content": "set -x\n\n# Example runnable on H20 * 8\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=${train_files:-\"$gsm8k_train_path\"}\ntest_files=${test_files:-\"$gsm8k_test_path\"}\n\n# Nsight profiling configuration\nPROFILE_STEPS=\"[1,2,5]\" # or [] or null\nPROFILE_RANKS_ALL=False # or True\nPROFILE_RANKS=[0,4,8,12]\nDISCRETE=True  # or True\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.profiler.discrete=$DISCRETE \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.profiler.ranks=$PROFILE_RANKS \\\n    critic.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    critic.profiler.discrete=$DISCRETE \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='deepseek_llm_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=100 \\\n    trainer.total_training_steps=6 \\\n    trainer.profile_steps=$PROFILE_STEPS $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_gemma.sh",
    "content": "set -x\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=google/gemma-2-2b-it \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=False \\\n    critic.model.path=google/gemma-2-2b-it \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='gemma2b_function_rm' \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n\n# 0. download the model\nhuggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct\n\n# 1. convert the model to mcore format\n# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path\nHF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct\nDIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct\npython scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n\n# 2. run the script\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nALL_OFFLOAD=${ALL_OFFLOAD:-False}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\n\n\nNODES=4\nPP=2\nTP=8\nEP=8\nETP=1\nVLLM_TP=4\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . -- \npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.trust_remote_code=True \\\n    actor_rollout_ref.model.path=$LLM \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=$LLM \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    actor_rollout_ref.model.trust_remote_code=True \\\n    critic.model.trust_remote_code=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    critic.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    critic.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \\\n    critic.megatron.expert_model_parallel_size=$EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \\\n    critic.megatron.expert_tensor_parallel_size=$ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \\\n    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \\\n    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    critic.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    trainer.val_before_train=False \\\n    trainer.total_epochs=100 $@\n    "
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\n# 0. download the model\nhuggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat\n\n# 1. convert the model to mcore format\n# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path\nHF_MODEL_PATH=/data/models/Qwen/Qwen1.5-MoE-A2.7B-Chat\nDIST_CKPT_PATH=/data/mcore_ckpt/Qwen1.5-MoE-A2.7B-Chat\npython scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH\n\n# 2. run the script\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\ntrain_files=$gsm8k_train_path\ntest_files=$gsm8k_test_path\n\nNODES=4\nPP=2\nTP=4\nCP=1\nVLLM_TP=4\n\n# RAY_ADDRESS='auto' ray job submit --working-dir . -- \npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$HF_MODEL_PATH \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=$CP \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=$HF_MODEL_PATH \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    critic.megatron.tensor_model_parallel_size=$TP \\\n    critic.megatron.pipeline_model_parallel_size=$PP \\\n    critic.megatron.context_parallel_size=$CP \\\n    critic.megatron.use_dist_checkpointing=True \\\n    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_megatron_gsm8k_examples' \\\n    trainer.experiment_name='qwen1.5_moe_nochat' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=$NODES \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n    "
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh",
    "content": "set -x\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_ppo_gsm8k_math_examples' \\\n    trainer.experiment_name='qwen2_7b_megatron' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=100 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_rm.sh",
    "content": "# Discliamer: the model used in the script is only for academic purpose.\nset -x\n\n# Data preparation scripts are available in ``examples/data_preprocess``.\n# Example usage:\n#\n#   python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\n#   python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\n\n# prepare model ckpt\nhuggingface-cli download Qwen/Qwen2-7B-Instruct --local-dir $HOME/models/Qwen2-7B-Instruct &\nhuggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 &\nwait\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"$HOME/models/Qwen2-7B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.optim.lr_warmup_steps_ratio=0.05 \\\n    critic.model.path=\"$HOME/models/Qwen2-7B-Instruct\" \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=32 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=\"$HOME/models/FsfairX-LLaMA3-RM-v0.1\" \\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.val_before_train=False \\\n    trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nFUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=${train_files:-\"$gsm8k_train_path\"}\ntest_files=${test_files:-\"$gsm8k_test_path\"}\n\nPROFILE_STEPS=\"[1,2,5]\" # or [] or null\nPROFILE_RANKS_ALL=False # or True\nPROFILE_RANKS=[0,4,8,12]\nDISCRETE=True  # or True\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \\\n    actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    actor_rollout_ref.profiler.discrete=$DISCRETE \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_micro_batch_size_per_gpu=2 \\\n    critic.use_dynamic_bsz=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    critic.profiler.ranks=$PROFILE_RANKS \\\n    critic.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    critic.profiler.discrete=$DISCRETE \\\n    reward_model.enable=True \\\n    reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\\\n    reward_model.model.use_remove_padding=True \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.micro_batch_size_per_gpu=32 \\\n    reward_model.use_dynamic_bsz=True \\\n    reward_model.forward_max_token_len_per_gpu=98304 \\\n    reward_model.profiler.ranks=$PROFILE_RANKS \\\n    reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \\\n    reward_model.profiler.discrete=$DISCRETE \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=2 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=15 \\\n    trainer.total_training_steps=6 \\\n    trainer.profile_steps=$PROFILE_STEPS $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\n# For async rollout mode, dataset should return raw chat.\nrollout_mode=\"sync\"\nif [ \"$rollout_mode\" = \"async\" ]; then\n    return_raw_chat=\"True\"\nfi\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=$return_raw_chat \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=$rollout_mode \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=4096 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=512 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2-7B-Instruct \\\n    critic.model.enable_gradient_checkpointing=True \\\n    critic.ppo_max_token_len_per_gpu=98304 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ppo_trainer/run_qwen2.5-32b.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=True \\\n    critic.model.path=Qwen/Qwen2.5-32B-Instruct \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=8 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/ray/tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0ddc582b\",\n   \"metadata\": {},\n   \"source\": [\n    \"# VeRL Ray API Tutorial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"71fe3b94\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 1: Ray Basics\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 144,\n   \"id\": \"1347d381\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 145,\n   \"id\": \"e75b9d44\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import warnings\\n\",\n    \"\\n\",\n    \"import ray\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"warnings.filterwarnings(\\\"ignore\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 146,\n   \"id\": \"2e90ae00\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2024-11-01 17:27:19,132\\tINFO worker.py:1752 -- Started a local Ray instance.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"9cc9d2ccbdfb48918c8fd6cd13a0807a\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/html\": [\n       \"<div class=\\\"lm-Widget p-Widget lm-Panel p-Panel jp-Cell-outputWrapper\\\">\\n\",\n       \"    <div style=\\\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\\\">\\n\",\n       \"        <div class=\\\"jp-RenderedHTMLCommon\\\" style=\\\"display: flex; flex-direction: row;\\\">\\n\",\n       \"  <svg viewBox=\\\"0 0 567 224\\\" fill=\\\"none\\\" xmlns=\\\"http://www.w3.org/2000/svg\\\" style=\\\"height: 3em;\\\">\\n\",\n       \"    <g clip-path=\\\"url(#clip0_4338_178347)\\\">\\n\",\n       \"        <path d=\\\"M341.29 165.561H355.29L330.13 129.051C345.63 123.991 354.21 112.051 354.21 94.2307C354.21 71.3707 338.72 58.1807 311.88 58.1807H271V165.561H283.27V131.661H311.8C314.25 131.661 316.71 131.501 319.01 131.351L341.25 165.561H341.29ZM283.29 119.851V70.0007H311.82C331.3 70.0007 342.34 78.2907 342.34 94.5507C342.34 111.271 331.34 119.861 311.82 119.861L283.29 119.851ZM451.4 138.411L463.4 165.561H476.74L428.74 58.1807H416L367.83 165.561H380.83L392.83 138.411H451.4ZM446.19 126.601H398L422 72.1407L446.24 126.601H446.19ZM526.11 128.741L566.91 58.1807H554.35L519.99 114.181L485.17 58.1807H472.44L514.01 129.181V165.541H526.13V128.741H526.11Z\\\" fill=\\\"var(--jp-ui-font-color0)\\\"/>\\n\",\n       \"        <path d=\\\"M82.35 104.44C84.0187 97.8827 87.8248 92.0678 93.1671 87.9146C98.5094 83.7614 105.083 81.5067 111.85 81.5067C118.617 81.5067 125.191 83.7614 130.533 87.9146C135.875 92.0678 139.681 97.8827 141.35 104.44H163.75C164.476 101.562 165.622 98.8057 167.15 96.2605L127.45 56.5605C121.071 60.3522 113.526 61.6823 106.235 60.3005C98.9443 58.9187 92.4094 54.9203 87.8602 49.0574C83.3109 43.1946 81.0609 35.8714 81.5332 28.4656C82.0056 21.0599 85.1679 14.0819 90.4252 8.8446C95.6824 3.60726 102.672 0.471508 110.08 0.0272655C117.487 -0.416977 124.802 1.86091 130.647 6.4324C136.493 11.0039 140.467 17.5539 141.821 24.8501C143.175 32.1463 141.816 39.6859 138 46.0505L177.69 85.7505C182.31 82.9877 187.58 81.4995 192.962 81.4375C198.345 81.3755 203.648 82.742 208.33 85.3976C213.012 88.0532 216.907 91.9029 219.616 96.5544C222.326 
101.206 223.753 106.492 223.753 111.875C223.753 117.258 222.326 122.545 219.616 127.197C216.907 131.848 213.012 135.698 208.33 138.353C203.648 141.009 198.345 142.375 192.962 142.313C187.58 142.251 182.31 140.763 177.69 138L138 177.7C141.808 184.071 143.155 191.614 141.79 198.91C140.424 206.205 136.44 212.75 130.585 217.313C124.731 221.875 117.412 224.141 110.004 223.683C102.596 223.226 95.6103 220.077 90.3621 214.828C85.1139 209.58 81.9647 202.595 81.5072 195.187C81.0497 187.779 83.3154 180.459 87.878 174.605C92.4405 168.751 98.9853 164.766 106.281 163.401C113.576 162.035 121.119 163.383 127.49 167.19L167.19 127.49C165.664 124.941 164.518 122.182 163.79 119.3H141.39C139.721 125.858 135.915 131.673 130.573 135.826C125.231 139.98 118.657 142.234 111.89 142.234C105.123 142.234 98.5494 139.98 93.2071 135.826C87.8648 131.673 84.0587 125.858 82.39 119.3H60C58.1878 126.495 53.8086 132.78 47.6863 136.971C41.5641 141.163 34.1211 142.972 26.7579 142.059C19.3947 141.146 12.6191 137.574 7.70605 132.014C2.79302 126.454 0.0813599 119.29 0.0813599 111.87C0.0813599 104.451 2.79302 97.2871 7.70605 91.7272C12.6191 86.1673 19.3947 82.5947 26.7579 81.6817C34.1211 80.7686 41.5641 82.5781 47.6863 86.7696C53.8086 90.9611 58.1878 97.2456 60 104.44H82.35ZM100.86 204.32C103.407 206.868 106.759 208.453 110.345 208.806C113.93 209.159 117.527 208.258 120.522 206.256C123.517 204.254 125.725 201.276 126.771 197.828C127.816 194.38 127.633 190.677 126.253 187.349C124.874 184.021 122.383 181.274 119.205 179.577C116.027 177.88 112.359 177.337 108.826 178.042C105.293 178.746 102.113 180.654 99.8291 183.44C97.5451 186.226 96.2979 189.718 96.3 193.32C96.2985 195.364 96.7006 197.388 97.4831 199.275C98.2656 201.163 99.4132 202.877 100.86 204.32ZM204.32 122.88C206.868 120.333 208.453 116.981 208.806 113.396C209.159 109.811 208.258 106.214 206.256 103.219C204.254 100.223 201.275 98.0151 197.827 96.97C194.38 95.9249 190.676 96.1077 187.348 97.4873C184.02 98.8669 181.274 101.358 179.577 104.536C177.879 107.714 177.337 111.382 178.041 114.915C178.746 118.448 180.653 121.627 183.439 123.911C186.226 126.195 189.717 127.443 193.32 127.44C195.364 127.443 197.388 127.042 199.275 126.259C201.163 125.476 202.878 124.328 204.32 122.88ZM122.88 19.4205C120.333 16.8729 116.981 15.2876 113.395 14.9347C109.81 14.5817 106.213 15.483 103.218 17.4849C100.223 19.4868 98.0146 22.4654 96.9696 25.9131C95.9245 29.3608 96.1073 33.0642 97.4869 36.3922C98.8665 39.7202 101.358 42.4668 104.535 44.1639C107.713 45.861 111.381 46.4036 114.914 45.6992C118.447 44.9949 121.627 43.0871 123.911 40.301C126.195 37.515 127.442 34.0231 127.44 30.4205C127.44 28.3772 127.038 26.3539 126.255 24.4664C125.473 22.5788 124.326 20.8642 122.88 19.4205ZM19.42 100.86C16.8725 103.408 15.2872 106.76 14.9342 110.345C14.5813 113.93 15.4826 117.527 17.4844 120.522C19.4863 123.518 22.4649 125.726 25.9127 126.771C29.3604 127.816 33.0638 127.633 36.3918 126.254C39.7198 124.874 42.4664 122.383 44.1635 119.205C45.8606 116.027 46.4032 112.359 45.6988 108.826C44.9944 105.293 43.0866 102.114 40.3006 99.8296C37.5145 97.5455 34.0227 96.2983 30.42 96.3005C26.2938 96.3018 22.337 97.9421 19.42 100.86ZM100.86 100.86C98.3125 103.408 96.7272 106.76 96.3742 110.345C96.0213 113.93 96.9226 117.527 98.9244 120.522C100.926 123.518 103.905 125.726 107.353 126.771C110.8 127.816 114.504 127.633 117.832 126.254C121.16 124.874 123.906 122.383 125.604 119.205C127.301 116.027 127.843 112.359 127.139 108.826C126.434 105.293 124.527 102.114 121.741 99.8296C118.955 97.5455 115.463 96.2983 111.86 96.3005C109.817 
96.299 107.793 96.701 105.905 97.4835C104.018 98.2661 102.303 99.4136 100.86 100.86Z\\\" fill=\\\"#00AEEF\\\"/>\\n\",\n       \"    </g>\\n\",\n       \"    <defs>\\n\",\n       \"        <clipPath id=\\\"clip0_4338_178347\\\">\\n\",\n       \"            <rect width=\\\"566.93\\\" height=\\\"223.75\\\" fill=\\\"white\\\"/>\\n\",\n       \"        </clipPath>\\n\",\n       \"    </defs>\\n\",\n       \"  </svg>\\n\",\n       \"</div>\\n\",\n       \"\\n\",\n       \"        <table class=\\\"jp-RenderedHTMLCommon\\\" style=\\\"border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);\\\">\\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>Python version:</b></td>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>3.9.2</b></td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>Ray version:</b></td>\\n\",\n       \"        <td style=\\\"text-align: left\\\"><b>2.10.0</b></td>\\n\",\n       \"    </tr>\\n\",\n       \"    \\n\",\n       \"</table>\\n\",\n       \"\\n\",\n       \"    </div>\\n\",\n       \"</div>\\n\"\n      ],\n      \"text/plain\": [\n       \"RayContext(dashboard_url='', python_version='3.9.2', ray_version='2.10.0', ray_commit='09abba26b5bf2707639bb637c208d062a47b46f6')\"\n      ]\n     },\n     \"execution_count\": 146,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[36m(GPUAccumulator pid=224400)\\u001b[0m rank 0, value: tensor([1.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=225234)\\u001b[0m rank 2, value: tensor([3.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=225607)\\u001b[0m rank 0, value: tensor([2.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=226423)\\u001b[0m rank 1, value: tensor([3.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulator pid=226857)\\u001b[0m rank 3, value: tensor([6.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227475)\\u001b[0m 10\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227475)\\u001b[0m rank 0, value: tensor([10.], device='cuda:0')\\n\",\n      \"\\u001b[36m(GPUAccumulatorDecorator pid=227655)\\u001b[0m rank 1, value: tensor([11.], device='cuda:0')\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Build a local ray cluster. The head node and worker node are on this machine\\n\",\n    \"ray.init()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a127e4e4\",\n   \"metadata\": {},\n   \"source\": [\n    \"Implement an Accumulator class.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 147,\n   \"id\": \"20e7b9a3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class Accumulator:\\n\",\n    \"    def __init__(self):\\n\",\n    \"        self.value = 0\\n\",\n    \"\\n\",\n    \"    def add(self, x):\\n\",\n    \"        self.value += x\\n\",\n    \"\\n\",\n    \"    def get_value(self):\\n\",\n    \"        return self.value\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 148,\n   \"id\": \"3b80098c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Instantiate an accumulator. 
Accumulator can be viewed as a process, acting as an RPC service.\\n\",\n    \"accumulator = Accumulator.remote()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 149,\n   \"id\": \"b14b1009\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"value_ref = accumulator.get_value.remote()  # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\\n\",\n    \"# Get the value\\n\",\n    \"value = ray.get(value_ref)\\n\",\n    \"print(value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 150,\n   \"id\": \"513a84b3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"10\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Accumulate, then check the result.\\n\",\n    \"accumulator.add.remote(10)  # Similarly, the 'add' here will return immediately.\\n\",\n    \"new_value = ray.get(accumulator.get_value.remote())\\n\",\n    \"print(new_value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3c332fe0\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 2: Resource Pool and RayWorkerGroup\\n\",\n    \"In the previous example, it was a simple single-process worker. \\n\",\n    \"In this example, we implement a worker with a GPU and form a RayWorkerGroup. Within this RayWorkerGroup, we implement a simple operation of an accumulator.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 151,\n   \"id\": \"04229afb\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base import Worker\\n\",\n    \"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 152,\n   \"id\": \"0d0dbd58\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"resource_pool = RayResourcePool([4], use_gpu=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 153,\n   \"id\": \"68f6838a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class GPUAccumulator(Worker):\\n\",\n    \"    def __init__(self) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        # The initial value of each rank is the same as the rank\\n\",\n    \"        self.value = torch.zeros(size=(1,), device=\\\"cuda\\\") + self.rank\\n\",\n    \"\\n\",\n    \"    def add(self, x):\\n\",\n    \"        self.value += x\\n\",\n    \"        print(f\\\"rank {self.rank}, value: {self.value}\\\")\\n\",\n    \"        return self.value.cpu()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 154,\n   \"id\": \"23aad8fe\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([1.]), tensor([2.]), tensor([3.]), tensor([4.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\\n\",\n    \"class_with_args = 
RayClassWithInitArgs(cls=GPUAccumulator)\\n\",\n    \"worker_group = RayWorkerGroup(resource_pool, class_with_args)\\n\",\n    \"print(worker_group.execute_all_sync(\\\"add\\\", x=[1, 1, 1, 1]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e6705284\",\n   \"metadata\": {},\n   \"source\": [\n    \"The principle of parameter passing: the input is a list of length world_size, and each element of the list is dispatched to the corresponding worker in the RayWorkerGroup. \\n\",\n    \"The return value is also a list, with one entry per worker.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d25c2412\",\n   \"metadata\": {},\n   \"source\": [\n    \"### GPU Resource Sharing\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f74f6d24\",\n   \"metadata\": {},\n   \"source\": [\n    \"RayWorkerGroups mapped to the same resource pool share the GPUs. In this example, we use three resource pools: the first occupies 4 GPUs, the second also occupies 4 GPUs, and the last occupies all 8 GPUs. The first reuses the resource pool created above.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 155,\n   \"id\": \"49f9c06f\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create a new resource pool and then merge the newly created resource pool with the previous one.\\n\",\n    \"resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\\\"a\\\")\\n\",\n    \"resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 156,\n   \"id\": \"05c2e305\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Establish RayWorkerGroups on the new and the merged resource pools.\\n\",\n    \"worker_group_1 = RayWorkerGroup(resource_pool_1, class_with_args)\\n\",\n    \"worker_group_merge = RayWorkerGroup(resource_pool_merge, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 157,\n   \"id\": \"6b9b13f4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([2.]), tensor([3.]), tensor([4.]), tensor([5.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\\n\",\n    \"output_1 = worker_group_1.execute_all_sync(\\\"add\\\", x=[2, 2, 2, 2])\\n\",\n    \"print(output_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 158,\n   \"id\": \"d856d030\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([3.]), tensor([4.]), tensor([5.]), tensor([6.]), tensor([7.]), tensor([8.]), tensor([9.]), tensor([10.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\\n\",\n    \"output_merge = worker_group_merge.execute_all_sync(\\\"add\\\", x=[3, 3, 3, 3, 3, 3, 3, 3])\\n\",\n    \"print(output_merge)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 159,\n   \"id\": \"33a4628c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4 4 8\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(worker_group.world_size, worker_group_1.world_size, worker_group_merge.world_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3df19d13\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 3: Data Dispatch, Execution and Collection\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"acb22d9d\",\n   \"metadata\": {},\n   \"source\": [\n    \"In the above example, we used the `execute_all_sync` function in the RayWorkerGroup to dispatch data from the driver to each worker. This is inconvenient to code against. \\n\",\n    \"In this chapter, we use function decorators so that the RayWorkerGroup can directly call methods written in the Worker, greatly simplifying parameter passing.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 160,\n   \"id\": \"35237432\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base.decorator import Dispatch, Execute, register\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 161,\n   \"id\": \"88b8ba3b\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class GPUAccumulatorDecorator(Worker):\\n\",\n    \"    def __init__(self) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        # The initial value of each rank is the same as the rank\\n\",\n    \"        self.value = torch.zeros(size=(1,), device=\\\"cuda\\\") + self.rank\\n\",\n    \"\\n\",\n    \"    # map a single input to all the workers\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def add(self, x):\\n\",\n    \"        print(x)\\n\",\n    \"        self.value = self.value + x\\n\",\n    \"        print(f\\\"rank {self.rank}, value: {self.value}\\\")\\n\",\n    \"        return self.value.cpu()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 162,\n   \"id\": \"eddaa043\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"class_with_args = RayClassWithInitArgs(cls=GPUAccumulatorDecorator)\\n\",\n    \"gpu_accumulator_decorator = RayWorkerGroup(resource_pool_merge, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 163,\n   \"id\": \"10087c91\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[tensor([10.]), tensor([11.]), tensor([12.]), tensor([13.]), tensor([14.]), tensor([15.]), tensor([16.]), tensor([17.])]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# As we can see, 10 is automatically dispatched to each Worker in this RayWorkerGroup.\\n\",\n    \"print(gpu_accumulator_decorator.add(x=10))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"540ee6ad\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Custom Dispatch, Collection\\n\",\n    \"Users can customize the `dispatch` and `collect` functions: you only need to write the `dispatch_fn` and `collect_fn` functions yourself. 
We also support executing RPC only on rank_zero, with specific examples provided below.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 164,\n   \"id\": \"8e041270\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 165,\n   \"id\": \"43b5be31\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def two_to_all_dispatch_fn(worker_group, *args, **kwargs):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    for arg in args:\\n\",\n    \"        assert len(arg) == 2\\n\",\n    \"        for i in range(worker_group.world_size - 2):\\n\",\n    \"            arg.append(arg[i % 2])\\n\",\n    \"    for k, v in kwargs.items():\\n\",\n    \"        assert len(v) == 2\\n\",\n    \"        for i in range(worker_group.world_size - 2):\\n\",\n    \"            v.append(v[i % 2])\\n\",\n    \"    return args, kwargs\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@ray.remote\\n\",\n    \"class TestActor(Worker):\\n\",\n    \"    # TODO: pass *args and **kwargs is bug prone and not very convincing\\n\",\n    \"    def __init__(self, x) -> None:\\n\",\n    \"        super().__init__()\\n\",\n    \"        self._x = x\\n\",\n    \"\\n\",\n    \"    def foo(self, y):\\n\",\n    \"        return self._x + y\\n\",\n    \"\\n\",\n    \"    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\\n\",\n    \"    def foo_rank_zero(self, x, y):\\n\",\n    \"        return self._x + y + x\\n\",\n    \"\\n\",\n    \"    @register(dispatch_mode={\\\"dispatch_fn\\\": two_to_all_dispatch_fn, \\\"collect_fn\\\": collect_all_to_all})\\n\",\n    \"    def foo_custom(self, x, y):\\n\",\n    \"        return self._x + y + x\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 166,\n   \"id\": \"83ec6609\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\\n\",\n    \"worker_group = RayWorkerGroup(resource_pool, class_with_args)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 167,\n   \"id\": \"62c58d8a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\\n\",\n    \"assert output_ref == [8, 10, 8, 10]\\n\",\n    \"\\n\",\n    \"output_ref = worker_group.foo_rank_zero(x=1, y=2)\\n\",\n    \"assert output_ref == 5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 168,\n   \"id\": \"14689353\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"8\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(gpu_accumulator_decorator.world_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 169,\n   \"id\": \"2c80bbf4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Shutdown ray cluster\\n\",\n    \"ray.shutdown()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a5c8151c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chapter 4: NVMegatronRayWorkerGroup\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cd5680e9\",\n   \"metadata\": {},\n   \"source\": [\n    \"Due to a Ray issue, we can only support max_colocate_count=1 in RayResourcePool for now. \\n\",\n    \"This means that each GPU can only have one process.\\n\",\n    \"We will be able to support max_colocate_count > 1 once this pull request is applied: https://github.com/ray-project/ray/pull/44385\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"92724419\",\n   \"metadata\": {},\n   \"source\": [\n    \"Therefore, we need to restart Ray and initialize a new resource_pool to demonstrate the **NVMegatronRayWorkerGroup**.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9b038538\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Build a local ray cluster. The head node and worker node are on this machine\\n\",\n    \"ray.init()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ebfd8798\",\n   \"metadata\": {},\n   \"source\": [\n    \"Finally, we implement an `NVMegatronRayWorkerGroup`, within which we initialize Megatron and then run a tensor-parallel (tp) split Llama MLP layer. Here, we use a complex dispatch mode, `MEGATRON_COMPUTE`. This dispatch mode assumes that the user passes in data partitioned along the DP dimension. The data is dispatched to all tp/pp ranks within the same dp group, and output data is ultimately collected only from tp=0 and the last pp stage. In this way, the Megatron engine behind the RPC becomes transparent to users who only write code on the driver.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 171,\n   \"id\": \"5a032154\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/opt/tiger/Megatron-LM\\n\",\n      \"/opt/tiger/Megatron-LM/megatron/__init__.py\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import sys\\n\",\n    \"\\n\",\n    \"current_pythonpath = os.environ.get(\\\"PYTHONPATH\\\", \\\"\\\")\\n\",\n    \"\\n\",\n    \"new_path = \\\"/opt/tiger/Megatron-LM\\\"\\n\",\n    \"\\n\",\n    \"new_pythonpath = f\\\"{new_path}:{current_pythonpath}\\\" if current_pythonpath else new_path\\n\",\n    \"\\n\",\n    \"os.environ[\\\"PYTHONPATH\\\"] = new_pythonpath\\n\",\n    \"\\n\",\n    \"print(new_path)\\n\",\n    \"sys.path.append(new_path)\\n\",\n    \"\\n\",\n    \"import megatron\\n\",\n    \"\\n\",\n    \"print(megatron.__file__)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 172,\n   \"id\": \"8c84cd5a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from megatron.core import parallel_state as mpu\\n\",\n    \"from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"from verl.single_controller.base.decorator import Dispatch, Execute, register\\n\",\n    \"from verl.single_controller.base.megatron.worker import MegatronWorker\\n\",\n    \"from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\\n\",\n    \"from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 173,\n   \"id\": \"1b1debcc\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"resource_pool = RayResourcePool([4], use_gpu=True, max_colocate_count=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 174,\n   \"id\": \"bccbe081\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@ray.remote\\n\",\n    \"class MLPLayerWorker(MegatronWorker):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        rank = int(os.environ[\\\"LOCAL_RANK\\\"])\\n\",\n    \"        torch.distributed.init_process_group(backend=\\\"nccl\\\")\\n\",\n    \"        torch.cuda.set_device(rank)\\n\",\n    \"\\n\",\n    \"        mpu.initialize_model_parallel(\\n\",\n    \"            tensor_model_parallel_size=4,\\n\",\n    \"            pipeline_model_parallel_size=1,\\n\",\n    \"            virtual_pipeline_model_parallel_size=None,\\n\",\n    \"            pipeline_model_parallel_split_rank=None,\\n\",\n    \"            use_sharp=False,\\n\",\n    \"            context_parallel_size=1,\\n\",\n    \"            expert_model_parallel_size=1,\\n\",\n    \"            nccl_communicator_config_path=None,\\n\",\n    \"        )\\n\",\n    \"        from megatron.core import tensor_parallel\\n\",\n    \"\\n\",\n    \"        tensor_parallel.model_parallel_cuda_manual_seed(10)\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def init_model(self, config):\\n\",\n    \"        from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"        from verl.models.llama.megatron.layers import ParallelLlamaMLP\\n\",\n    \"        from verl.utils.megatron_utils import init_model_parallel_config\\n\",\n    \"\\n\",\n    \"        megatron_config = OmegaConf.create(\\n\",\n    \"            {\\n\",\n    \"                \\\"sequence_parallel\\\": False,\\n\",\n    \"                \\\"param_dtype\\\": \\\"fp32\\\",\\n\",\n    \"                \\\"tensor_model_parallel_size\\\": mpu.get_tensor_model_parallel_world_size(),\\n\",\n    \"                \\\"pipeline_model_parallel_rank\\\": mpu.get_pipeline_model_parallel_rank(),\\n\",\n    \"                \\\"pipeline_model_parallel_size\\\": mpu.get_pipeline_model_parallel_world_size(),\\n\",\n    \"                \\\"virtual_pipeline_model_parallel_rank\\\": mpu.get_virtual_pipeline_model_parallel_rank(),\\n\",\n    \"                \\\"virtual_pipeline_model_parallel_size\\\": mpu.get_virtual_pipeline_model_parallel_world_size(),\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        megatron_config = init_model_parallel_config(megatron_config)\\n\",\n    \"        self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.ONE_TO_ALL)\\n\",\n    \"    def get_weights(self):\\n\",\n    \"        output = {}\\n\",\n    \"        for key, val in self.parallel_layer.named_parameters():\\n\",\n    \"            output[key] = val\\n\",\n    \"        return output\\n\",\n    \"\\n\",\n    \"    @register(Dispatch.MEGATRON_COMPUTE)\\n\",\n    \"    def run_layer(self, x):\\n\",\n    \"        x = x.to(\\\"cuda\\\")\\n\",\n    \"        y = self.parallel_layer(x)\\n\",\n    \"        return y\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 175,\n   \"id\": \"a655271d\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\\n\",\n    \"layer_worker_group = NVMegatronRayWorkerGroup(\\n\",\n    \"    resource_pool=resource_pool,\\n\",\n    \"    ray_cls_with_init=layer_cls,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 176,\n   \"id\": \"f105ebee\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"4 4 1 1\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(layer_worker_group.world_size, layer_worker_group.tp_size, layer_worker_group.pp_size, layer_worker_group.dp_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 177,\n   \"id\": \"38655091\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"ffn_hidden_size = 11008\\n\",\n    \"batch_size = 16\\n\",\n    \"seq_len = 2048\\n\",\n    \"hidden_size = 4096\\n\",\n    \"\\n\",\n    \"config = OmegaConf.create(\\n\",\n    \"    {\\n\",\n    \"        \\\"hidden_size\\\": hidden_size,\\n\",\n    \"        \\\"intermediate_size\\\": ffn_hidden_size,\\n\",\n    \"        \\\"hidden_act\\\": \\\"silu\\\",\\n\",\n    \"        \\\"pretraining_tp\\\": 1,\\n\",\n    \"        \\\"tp\\\": layer_worker_group.tp_size,\\n\",\n    \"    }\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 178,\n   \"id\": \"a026efca\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"x = torch.rand(size=(seq_len, batch_size, hidden_size), dtype=torch.float32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 179,\n   \"id\": \"f5fcaf13\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[None, None, None, None]\"\n      ]\n     },\n     \"execution_count\": 179,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"layer_worker_group.init_model(config)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 180,\n   \"id\": \"3f5cc9b4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"torch.Size([2048, 16, 4096])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"output = layer_worker_group.run_layer(\\n\",\n    \"    [x]\\n\",\n    \")  # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\\n\",\n    \"print(output[0].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 181,\n   \"id\": \"49792210\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Shutdown ray cluster\\n\",\n    \"ray.shutdown()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
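The tutorial notebook above builds everything on one Ray primitive: an actor that runs as its own process and is driven through asynchronous RPC-style calls. As a quick companion to its Chapter 1, here is a condensed, self-contained sketch of that pattern that runs with plain Ray on CPU (no verl, GPUs, or cluster required); the class and values mirror the notebook's `Accumulator`:

```python
# Minimal standalone sketch of the notebook's Chapter 1 pattern:
# a Ray actor acting as an RPC-style service.
import ray

ray.init()  # start a local Ray cluster; head and worker on this machine


@ray.remote
class Accumulator:
    def __init__(self):
        self.value = 0

    def add(self, x):
        self.value += x

    def get_value(self):
        return self.value


acc = Accumulator.remote()               # spawn the actor process
acc.add.remote(10)                       # returns a future immediately; does not block
print(ray.get(acc.get_value.remote()))   # ray.get blocks until the result is ready -> 10

ray.shutdown()
```

The submit-then-`ray.get` asymmetry shown here is the same mechanism the notebook's `execute_all_sync` calls and dispatch decorators are built on.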
  {
    "path": "verl_rl/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=reinforce_plus_plus \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=1024 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=mse \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=reinforce_plus_plus_baseline \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=1024 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=mse \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh",
    "content": "set -x\n\nexport HF_DATASETS_OFFLINE=1\nexport TRANSFORMERS_OFFLINE=1\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=remax \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_remax_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_3b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=5 $@\n"
  },
  {
    "path": "verl_rl/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh",
    "content": "set -x\n\nexport HF_DATASETS_OFFLINE=1\nexport TRANSFORMERS_OFFLINE=1\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=remax \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_remax_example_gsm8k' \\\n    trainer.experiment_name='qwen2.5_7b_function_rm_kl1e-3' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=10 $@\n"
  },
  {
    "path": "verl_rl/examples/rloo_trainer/run_qwen2-7b.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=rloo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_rloo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_deepseek_6b7.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_deepseek_6b7.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \\\n    trainer.total_epochs=4 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@"
  },
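A note on the nested data keys used by the SFT scripts in this directory: my reading of `data.prompt_key=extra_info` combined with `data.prompt_dict_keys=['question']` (and the response counterparts) is that the trainer takes the `extra_info` column of each parquet row and then indexes into it with the listed keys. The snippet below illustrates that addressing scheme on a made-up row; it is an interpretation sketch, not verl's loader code:

```python
# Hypothetical illustration of how the prompt/response key flags above are read.
row = {"extra_info": {"question": "What is 2+2?", "answer": "4"}}

prompt_key, prompt_dict_keys = "extra_info", ["question"]    # data.prompt_key / data.prompt_dict_keys
response_key, response_dict_keys = "extra_info", ["answer"]  # data.response_key / data.response_dict_keys

prompt = row[prompt_key]
for key in prompt_dict_keys:
    prompt = prompt[key]      # -> "What is 2+2?"

response = row[response_key]
for key in response_dict_keys:
    response = response[key]  # -> "4"

print(prompt, "->", response)
```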
  {
    "path": "verl_rl/examples/sft/gsm8k/run_gemma_2b.sh",
    "content": "# Tested with 2 & 4 GPUs\n\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_gemma_2b.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=google/gemma-2b-it \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-gemma-2b-it \\\n    trainer.total_epochs=2 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_gemma_7b.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_gemma_7b.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=google/gemma-1.1-7b-it \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \\\n    trainer.total_epochs=4 \\\n    trainer.logger='[\"console\",\"wandb\"]' $@\n"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen2_5_05b_sft_peft_sp2_npu.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=64 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=2 $@ \\\n    model.lora_rank=32 \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.strategy=fsdp \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true \\\n    trainer.device=npu\n"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_qwen_05_peft.sh",
    "content": "# Tested with 2 & 4 GPUs\n\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_peft.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=1 $@ \\\n    model.lora_rank=32\\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear\n\n    # Or you can do this:\n    # model.target_modules=[q_proj,v_proj] \\\n"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_qwen_05_sp2.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \\\n    trainer.logger=console \\\n    trainer.total_training_steps=1 $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_rl/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh",
    "content": "set -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    model.use_liger=True \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \\\n    trainer.logger=console $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_rl/examples/sft/multiturn/run_qwen_05_sp2.sh",
    "content": "#!/bin/bash\nset -x\n\nif [ \"$#\" -lt 2 ]; then\n    echo \"Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]\"\n    exit 1\nfi\n\nnproc_per_node=$1\nsave_path=$2\n\n# Shift the arguments so $@ refers to the rest\nshift 2\n\ntorchrun --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/multiturn/train.parquet \\\n    data.val_files=$HOME/data/multiturn/test.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.micro_batch_size=4 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=multiturn-sft \\\n    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \\\n    trainer.logger=console \\\n    trainer.total_training_steps=1 $@ \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/README.md",
    "content": "# Multi-Turn Rollout Example (GSM8K)\n\nThis example demonstrates how to perform **multi-turn rollout** using SGLang with a tool-calling capable model (e.g., Qwen2.5-3B) on the GSM8K dataset.\n\n## Usage\n\n### Step 1: Download GSM8K Dataset\n\n```bash\ncd examples/data_preprocess\npython3 gsm8k_multiturn_w_tool.py\n```\n\nThis will download and preprocess the GSM8K dataset into ~/data/gsm8k/.\n\n### Step 2: Run Multi-Turn Rollout\n\nIf you have 8 GPUs\nUse the standard 8-GPU script:\n\n```bash\ncd your_verl_root_dir\nbash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh\n```\n\nIf you have only 4 GPUs\nUse the fallback 4-GPU script:\n\n```bash\ncd your_verl_root_dir\nbash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh \n```\n\n## Notes\n\n- The rollout supports multi-turn conversations with tool-calling capabilities.\n- Current tools are used for GSM8K answer evaluation.\n- Future versions may extend to search and code interpreter tools.\n"
  },
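After Step 1 of the README above, it can be worth confirming that the preprocessing script really wrote the parquet splits before launching an 8-GPU job. This optional check is hypothetical (the printed columns depend entirely on what the script produced):

```python
# Optional sanity check for the preprocessed GSM8K data (requires pandas + pyarrow).
import os

import pandas as pd

data_dir = os.path.expanduser("~/data/gsm8k")
for split in ("train.parquet", "test.parquet"):
    path = os.path.join(data_dir, split)
    df = pd.read_parquet(path)
    print(f"{split}: {df.shape[0]} rows, columns={list(df.columns)}")
```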
  {
    "path": "verl_rl/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 2048\n  max_response_length: 2048\n  train_batch_size: 256\n  return_raw_chat: True\n  return_multi_modal_inputs: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    custom_chat_template: \"{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}{%- for tool in tools %}{{- \\\"\\\\n\\\" }}{{- tool | tojson }}{%- endfor %}{{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last 
or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n{% endif %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}\"\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      # tool_config_path: \"./config/tool_config/gsm8k_tool_config.yaml\"\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 2048\n  max_response_length: 2048\n  train_batch_size: 256\n  return_raw_chat: True\n  return_multi_modal_inputs: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    custom_chat_template: \"{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}{%- for tool in tools %}{{- \\\"\\\\n\\\" }}{{- tool | tojson }}{%- endfor %}{{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if 
loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n{% endif %}{%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\\n{% endif %}{%- elif message.role == \\\"assistant\\\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}{{- tool_call.name }}{{- '\\\", \\\"arguments\\\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\\\n' }}{%- elif message.role == \\\"tool\\\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\\\n<tool_response>\\\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}{{- '<|im_end|>\\\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}\"\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      # tool_config_path: \"./config/tool_config/gsm8k_tool_config.yaml\"\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n"
  },
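For readers new to Hydra, the `defaults: [ppo_trainer, _self_]` list in the config above means the base `ppo_trainer` config (found via the `hydra.searchpath` entry) is loaded first, and this file's keys are then merged on top of it. The sketch below reproduces that layering with plain OmegaConf; the base values are made-up stand-ins for illustration, not verl's actual defaults:

```python
from omegaconf import OmegaConf

# Stand-in for the base ppo_trainer config (values invented for illustration).
base = OmegaConf.create(
    {
        "data": {"max_prompt_length": 512, "train_batch_size": 1024},
        "actor_rollout_ref": {"rollout": {"name": "vllm"}},
    }
)

# The keys declared in gsm8k_multiturn_grpo.yaml above.
override = OmegaConf.create(
    {
        "data": {
            "max_prompt_length": 1024,
            "max_response_length": 1024,
            "train_batch_size": 256,
            "return_raw_chat": True,
        },
        "actor_rollout_ref": {
            "hybrid_engine": True,
            "rollout": {
                "name": "sglang",
                "multi_turn": {"enable": True, "max_assistant_turns": 5},
            },
        },
    }
)

merged = OmegaConf.merge(base, override)      # later configs win, as in Hydra's defaults list
print(merged.actor_rollout_ref.rollout.name)  # -> sglang
```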
  {
    "path": "verl_rl/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_user_turns: 5\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n  \n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml",
    "content": "interaction:\n  - name: \"gsm8k\"\n    class_name: \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\"\n    config: {}"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 5\n      tool_config_path: \"./config/tool_config/sandbox_fusion_tool_config.yaml\"\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/search_multiturn_grpo.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  max_prompt_length: 1024\n  max_response_length: 1024\n  train_batch_size: 256\n  return_raw_chat: True\n  shuffle: False\n\nactor_rollout_ref:\n  hybrid_engine: True\n  rollout:\n    name: sglang\n    multi_turn:\n      enable: True\n      max_assistant_turns: 2\n      format: qwen\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.geo3k_tool.Geo3kTool\"\n    config: \n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"calc_geo3k_reward\"\n        description: \"A tool for calculating the reward of geo3k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)\"\n        parameters:\n          type: \"object\"\n          properties:\n            answer:\n              type: \"string\"\n              description: \"The model's answer to the geo3k problem, must be a digits\"\n          required: [\"answer\"]"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.gsm8k_tool.Gsm8kTool\"\n    config: \n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"calc_gsm8k_reward\"\n        description: \"A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)\"\n        parameters:\n          type: \"object\"\n          properties:\n            answer:\n              type: \"string\"\n              description: \"The model's answer to the GSM8K math problem, must be a digits\"\n          required: [\"answer\"]\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/mcp_server.json",
    "content": "{\n    \"mcpServers\": {\n        \"Tavily Expert\": {\n            \"url\": \"your_tavily_expert_url\",\n            \"auth_token\": \"your_tavily_api_token\"\n        }\n    }\n}"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml",
    "content": "tools:\n  - class_name: verl.tools.mcp_search_tool.MCPSearchTool\n    config:\n      rate_limit: 120\n      timeout: 120\n      type: mcp\n    mcp:\n      mcp_servers_config_path: ./mcp_server.json\n      # optional\n      tool_selected_list: \n        - tavily_search_tool"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml",
    "content": "tools:\n  - class_name: \"verl.tools.sandbox_fusion_tools.SandboxFusionTool\"\n    config: \n      sandbox_fusion_url: \"https://xxx.apigateway-cn-beijing.volceapi.com/run_code\"\n      num_workers: 10\n      enable_global_rate_limit: true\n      rate_limit: 10\n      default_timeout: 30\n      default_language: \"python\"\n      memory_limit_mb: 1024\n      type: native\n\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml",
    "content": "tools:\n  - class_name: verl.tools.search_tool.SearchTool\n    config:\n      retrieval_service_url: http://127.0.0.1:8000/retrieve\n      num_workers: 120\n      rate_limit: 120\n      timeout: 30\n      type: native\n    tool_schema:\n      type: function\n      function:\n        name: search\n        description: Searches the web for relevant information based on the given query.\n        parameters:\n          type: object\n          properties:\n            query_list:\n              type: array\n              item:\n                type: string\n              description: A list of fully-formed semantic queries. The tool will return search results for each query.\n          required: \n            - query_list"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \\\n    data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh",
    "content": "# run on 4xH100\n# make sure your current working directory is the root of the project\n\nset -x\nexport HYDRA_FULL_ERROR=1\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-async-sgl-multi-w-tool-verify-n16-4cards' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n    critic.ppo_max_token_len_per_gpu=8192 \\\n    critic.forward_max_token_len_per_gpu=8192 \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    $@"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n# this is a verification training script, the parallel setting should be tuned to your model\n\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_megatron_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.megatron.seed=42 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \\\n    data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.sampler.class_name=\"RandomCurriculumSampler\" \\\n    data.sampler.class_path=\"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\" \\\n    data.dataloader_num_workers=0 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.train_batch_size=256 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nTRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-512}\nMICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-8}\nOFFLOAD=${OFFLOAD:-False}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo_w_interaction' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=$TRAIN_BATCH_SIZE \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=$((1024 * 3)) \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    +actor_rollout_ref.model.enable_activation_offloading=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \\\n    +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=$OFFLOAD \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-0.5b_function_rm-gsm8k-sgl-multi-w-interaction-n8' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/train.parquet \\\n    data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh",
    "content": "# run on 4xH100\n# make sure your current working directory is the root of the project\n\nset -x\nexport HYDRA_FULL_ERROR=1\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_epochs=15 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \\\n    critic.ppo_max_token_len_per_gpu=8192 \\\n    critic.forward_max_token_len_per_gpu=8192 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml\" \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=1 \\\n    $@"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.rollout.trace.backend=mlflow \\\n    actor_rollout_ref.rollout.trace.token2text=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"mlflow\"]' \\\n    trainer.project_name='gsm8k_tool-agent' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    trainer.total_training_steps=2 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n# this is a verification training script, the parallel setting should be tuned to your model\n\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_megatron_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=/user/longxiang1/models/Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.megatron.seed=42 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=2 \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/train.parquet \\\n    data.val_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-4B \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=16 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=20 \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.total_epochs=15 $@\n\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Search-R1 Contributors\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py\n\n\nimport argparse\n\nfrom huggingface_hub import hf_hub_download\n\nparser = argparse.ArgumentParser(description=\"Download files from a Hugging Face dataset repository.\")\nparser.add_argument(\"--repo_id\", type=str, default=\"PeterJinGo/wiki-18-e5-index\", help=\"Hugging Face repository ID\")\nparser.add_argument(\"--save_path\", type=str, required=True, help=\"Local directory to save files\")\n\nargs = parser.parse_args()\n\nrepo_id = \"PeterJinGo/wiki-18-e5-index\"\nfor file in [\"part_aa\", \"part_ab\"]:\n    hf_hub_download(\n        repo_id=repo_id,\n        filename=file,  # e.g., \"e5_Flat.index\"\n        repo_type=\"dataset\",\n        local_dir=args.save_path,\n    )\n\nrepo_id = \"PeterJinGo/wiki-18-corpus\"\nhf_hub_download(\n    repo_id=repo_id,\n    filename=\"wiki-18.jsonl.gz\",\n    repo_type=\"dataset\",\n    local_dir=args.save_path,\n)\n"
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Search-R1 Contributors\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/retrieval_server.py\n\nimport argparse\nimport json\nimport warnings\nfrom typing import Optional\n\nimport datasets\nimport faiss\nimport numpy as np\nimport torch\nimport uvicorn\nfrom fastapi import FastAPI\nfrom pydantic import BaseModel\nfrom tqdm import tqdm\nfrom transformers import AutoModel, AutoTokenizer\n\n\ndef load_corpus(corpus_path: str):\n    corpus = datasets.load_dataset(\"json\", data_files=corpus_path, split=\"train\", num_proc=4)\n    return corpus\n\n\ndef load_docs(corpus, doc_idxs):\n    results = [corpus[int(idx)] for idx in doc_idxs]\n    return results\n\n\ndef load_model(model_path: str, use_fp16: bool = False):\n    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)\n    model.eval()\n    model.cuda()\n    if use_fp16:\n        model = model.half()\n    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)\n    return model, tokenizer\n\n\ndef pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method=\"mean\"):\n    if pooling_method == \"mean\":\n        last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)\n        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]\n    elif pooling_method == \"cls\":\n        return last_hidden_state[:, 0]\n    elif pooling_method == \"pooler\":\n        return pooler_output\n    else:\n        raise NotImplementedError(\"Pooling method not implemented!\")\n\n\nclass Encoder:\n    def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):\n        self.model_name = model_name\n        self.model_path = model_path\n        self.pooling_method = pooling_method\n        self.max_length = max_length\n        self.use_fp16 = use_fp16\n\n        self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)\n        self.model.eval()\n\n    @torch.no_grad()\n    def encode(self, query_list: list[str], is_query=True) -> np.ndarray:\n        # processing query for different encoders\n        if isinstance(query_list, str):\n            query_list = [query_list]\n\n        if \"e5\" in self.model_name.lower():\n            if is_query:\n                query_list = [f\"query: {query}\" for query in query_list]\n            else:\n                query_list = [f\"passage: {query}\" for query in query_list]\n\n        if \"bge\" in self.model_name.lower():\n            if is_query:\n                query_list = [\n                    f\"Represent this sentence for searching relevant passages: {query}\" for query in query_list\n                ]\n\n        inputs = self.tokenizer(\n            query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors=\"pt\"\n      
  )\n        inputs = {k: v.cuda() for k, v in inputs.items()}\n\n        if \"T5\" in type(self.model).__name__:\n            # T5-based retrieval model\n            decoder_input_ids = torch.zeros((inputs[\"input_ids\"].shape[0], 1), dtype=torch.long).to(\n                inputs[\"input_ids\"].device\n            )\n            output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True)\n            query_emb = output.last_hidden_state[:, 0, :]\n        else:\n            output = self.model(**inputs, return_dict=True)\n            query_emb = pooling(\n                output.pooler_output, output.last_hidden_state, inputs[\"attention_mask\"], self.pooling_method\n            )\n            if \"dpr\" not in self.model_name.lower():\n                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)\n\n        query_emb = query_emb.detach().cpu().numpy()\n        query_emb = query_emb.astype(np.float32, order=\"C\")\n\n        del inputs, output\n        torch.cuda.empty_cache()\n\n        return query_emb\n\n\nclass BaseRetriever:\n    def __init__(self, config):\n        self.config = config\n        self.retrieval_method = config.retrieval_method\n        self.topk = config.retrieval_topk\n\n        self.index_path = config.index_path\n        self.corpus_path = config.corpus_path\n\n    def _search(self, query: str, num: int, return_score: bool):\n        raise NotImplementedError\n\n    def _batch_search(self, query_list: list[str], num: int, return_score: bool):\n        raise NotImplementedError\n\n    def search(self, query: str, num: int = None, return_score: bool = False):\n        return self._search(query, num, return_score)\n\n    def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        return self._batch_search(query_list, num, return_score)\n\n\nclass BM25Retriever(BaseRetriever):\n    def __init__(self, config):\n        super().__init__(config)\n        from pyserini.search.lucene import LuceneSearcher\n\n        self.searcher = LuceneSearcher(self.index_path)\n        self.contain_doc = self._check_contain_doc()\n        if not self.contain_doc:\n            self.corpus = load_corpus(self.corpus_path)\n        self.max_process_num = 8\n\n    def _check_contain_doc(self):\n        return self.searcher.doc(0).raw() is not None\n\n    def _search(self, query: str, num: int = None, return_score: bool = False):\n        if num is None:\n            num = self.topk\n        hits = self.searcher.search(query, num)\n        if len(hits) < 1:\n            if return_score:\n                return [], []\n            else:\n                return []\n        scores = [hit.score for hit in hits]\n        if len(hits) < num:\n            warnings.warn(\"Not enough documents retrieved!\", stacklevel=2)\n        else:\n            hits = hits[:num]\n\n        if self.contain_doc:\n            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())[\"contents\"] for hit in hits]\n            results = [\n                {\n                    \"title\": content.split(\"\\n\")[0].strip('\"'),\n                    \"text\": \"\\n\".join(content.split(\"\\n\")[1:]),\n                    \"contents\": content,\n                }\n                for content in all_contents\n            ]\n        else:\n            results = load_docs(self.corpus, [hit.docid for hit in hits])\n\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n    def 
_batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        results = []\n        scores = []\n        for query in query_list:\n            item_result, item_score = self._search(query, num, True)\n            results.append(item_result)\n            scores.append(item_score)\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n\nclass DenseRetriever(BaseRetriever):\n    def __init__(self, config):\n        super().__init__(config)\n        self.index = faiss.read_index(self.index_path)\n        if config.faiss_gpu:\n            co = faiss.GpuMultipleClonerOptions()\n            co.useFloat16 = True\n            co.shard = True\n            self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)\n\n        self.corpus = load_corpus(self.corpus_path)\n        self.encoder = Encoder(\n            model_name=self.retrieval_method,\n            model_path=config.retrieval_model_path,\n            pooling_method=config.retrieval_pooling_method,\n            max_length=config.retrieval_query_max_length,\n            use_fp16=config.retrieval_use_fp16,\n        )\n        self.topk = config.retrieval_topk\n        self.batch_size = config.retrieval_batch_size\n\n    def _search(self, query: str, num: int = None, return_score: bool = False):\n        if num is None:\n            num = self.topk\n        query_emb = self.encoder.encode(query)\n        scores, idxs = self.index.search(query_emb, k=num)\n        idxs = idxs[0]\n        scores = scores[0]\n        results = load_docs(self.corpus, idxs)\n        if return_score:\n            return results, scores.tolist()\n        else:\n            return results\n\n    def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):\n        if isinstance(query_list, str):\n            query_list = [query_list]\n        if num is None:\n            num = self.topk\n\n        results = []\n        scores = []\n        for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc=\"Retrieval process: \"):\n            query_batch = query_list[start_idx : start_idx + self.batch_size]\n            batch_emb = self.encoder.encode(query_batch)\n            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)\n            batch_scores = batch_scores.tolist()\n            batch_idxs = batch_idxs.tolist()\n\n            # load_docs is not vectorized, but is a python list approach\n            flat_idxs = sum(batch_idxs, [])\n            batch_results = load_docs(self.corpus, flat_idxs)\n            # chunk them back\n            batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))]\n\n            results.extend(batch_results)\n            scores.extend(batch_scores)\n\n            del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results\n            torch.cuda.empty_cache()\n\n        if return_score:\n            return results, scores\n        else:\n            return results\n\n\ndef get_retriever(config):\n    if config.retrieval_method == \"bm25\":\n        return BM25Retriever(config)\n    else:\n        return DenseRetriever(config)\n\n\n#####################################\n# FastAPI server below\n#####################################\n\n\nclass Config:\n    \"\"\"\n    Minimal config class (simulating your argparse)\n    Replace this with your real arguments or load them dynamically.\n    \"\"\"\n\n    def __init__(\n        self,\n        
retrieval_method: str = \"bm25\",\n        retrieval_topk: int = 10,\n        index_path: str = \"./index/bm25\",\n        corpus_path: str = \"./data/corpus.jsonl\",\n        dataset_path: str = \"./data\",\n        data_split: str = \"train\",\n        faiss_gpu: bool = True,\n        retrieval_model_path: str = \"./model\",\n        retrieval_pooling_method: str = \"mean\",\n        retrieval_query_max_length: int = 256,\n        retrieval_use_fp16: bool = False,\n        retrieval_batch_size: int = 128,\n    ):\n        self.retrieval_method = retrieval_method\n        self.retrieval_topk = retrieval_topk\n        self.index_path = index_path\n        self.corpus_path = corpus_path\n        self.dataset_path = dataset_path\n        self.data_split = data_split\n        self.faiss_gpu = faiss_gpu\n        self.retrieval_model_path = retrieval_model_path\n        self.retrieval_pooling_method = retrieval_pooling_method\n        self.retrieval_query_max_length = retrieval_query_max_length\n        self.retrieval_use_fp16 = retrieval_use_fp16\n        self.retrieval_batch_size = retrieval_batch_size\n\n\nclass QueryRequest(BaseModel):\n    queries: list[str]\n    topk: Optional[int] = None\n    return_scores: bool = False\n\n\napp = FastAPI()\n\n\n@app.post(\"/retrieve\")\ndef retrieve_endpoint(request: QueryRequest):\n    \"\"\"\n    Endpoint that accepts queries and performs retrieval.\n\n    Input format:\n    {\n      \"queries\": [\"What is Python?\", \"Tell me about neural networks.\"],\n      \"topk\": 3,\n      \"return_scores\": true\n    }\n\n    Output format (when return_scores=True，similarity scores are returned):\n    {\n        \"result\": [\n            [   # Results for each query\n                {\n                    {\"document\": doc, \"score\": score}\n                },\n                # ... more documents\n            ],\n            # ... 
results for other queries\n        ]\n    }\n    \"\"\"\n    if not request.topk:\n        request.topk = config.retrieval_topk  # fallback to default\n\n    # Perform batch retrieval\n    results, scores = retriever.batch_search(\n        query_list=request.queries, num=request.topk, return_score=request.return_scores\n    )\n\n    # Format response\n    resp = []\n    for i, single_result in enumerate(results):\n        if request.return_scores:\n            # If scores are returned, combine them with results\n            combined = []\n            for doc, score in zip(single_result, scores[i], strict=True):\n                combined.append({\"document\": doc, \"score\": score})\n            resp.append(combined)\n        else:\n            resp.append(single_result)\n    return {\"result\": resp}\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Launch the local faiss retriever.\")\n    parser.add_argument(\n        \"--index_path\", type=str, default=\"/home/peterjin/mnt/index/wiki-18/e5_Flat.index\", help=\"Corpus indexing file.\"\n    )\n    parser.add_argument(\n        \"--corpus_path\",\n        type=str,\n        default=\"/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl\",\n        help=\"Local corpus file.\",\n    )\n    parser.add_argument(\"--topk\", type=int, default=3, help=\"Number of retrieved passages for one query.\")\n    parser.add_argument(\"--retriever_name\", type=str, default=\"e5\", help=\"Name of the retriever model.\")\n    parser.add_argument(\n        \"--retriever_model\", type=str, default=\"intfloat/e5-base-v2\", help=\"Path of the retriever model.\"\n    )\n    parser.add_argument(\"--faiss_gpu\", action=\"store_true\", help=\"Use GPU for computation\")\n\n    args = parser.parse_args()\n\n    # 1) Build a config (could also parse from arguments).\n    #    In real usage, you'd parse your CLI arguments or environment variables.\n    config = Config(\n        retrieval_method=args.retriever_name,  # or \"dense\"\n        index_path=args.index_path,\n        corpus_path=args.corpus_path,\n        retrieval_topk=args.topk,\n        faiss_gpu=args.faiss_gpu,\n        retrieval_model_path=args.retriever_model,\n        retrieval_pooling_method=\"mean\",\n        retrieval_query_max_length=256,\n        retrieval_use_fp16=True,\n        retrieval_batch_size=512,\n    )\n\n    # 2) Instantiate a global retriever so it is loaded once and reused.\n    retriever = get_retriever(config)\n\n    # 3) Launch the server. By default, it listens on http://127.0.0.1:8000\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n"
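\n# Example request against a locally running server (assumed defaults; see the\n# input/output formats documented in retrieve_endpoint above):\n#   curl -X POST http://127.0.0.1:8000/retrieve -H \"Content-Type: application/json\" -d '{\"queries\": [\"What is Python?\"], \"topk\": 3, \"return_scores\": true}'\n"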
  },
  {
    "path": "verl_rl/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh",
    "content": "# run on 8xH20\n# make sure your current working directory is the root of the project\n\nset -x\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\n\nTRAIN_DATA=\"$HOME/data/searchR1_processed_direct/train.parquet\"\nVAL_DATA=\"$HOME/data/searchR1_processed_direct/test.parquet\"\n\nTOOL_CONFIG=\"$CONFIG_PATH/tool_config/search_tool_config.yaml\"\n\n\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='search_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=512 \\\n    data.val_batch_size=256 \\\n    data.max_prompt_length=4096 \\\n    data.max_response_length=3000 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.max_model_len=15000 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='search_r1_like_async_rl' \\\n    trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=100 \\\n    trainer.test_freq=50 \\\n    data.train_files=\"$TRAIN_DATA\" \\\n    data.val_files=\"$VAL_DATA\"  \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$TOOL_CONFIG\" \\\n    trainer.total_epochs=1 $@\n\n"
  },
  {
    "path": "verl_rl/examples/slurm/ray_on_slurm.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=verl-ray-on-slurm\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=1\n#SBATCH --mem=200G\n#SBATCH --partition=your-partition\n#SBATCH --time=01:00:00\n#SBATCH --account=your-account\n#SBATCH --gpus-per-node=4\n#SBATCH --cpus-per-task=64\n#SBATCH --output=slurm-%j.out\n#SBATCH --error=slurm-%j.err\n\n# load necessary modules\n\n# replace these information with your own\nverl_workdir=/path/to/verl\ntrain_files=/path/to/gsm8k/train.parquet\nval_files=/path/to/gsm8k/test.parquet\napptainer_image_path=/path/to/verl-ngc.sif\n# replace these information with your own\n\n# Getting the node names\nnodes=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\")\nnodes_array=(\"$nodes\")\n\nhead_node=${nodes_array[0]}\nhead_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n# if we detect a space character in the head node IP, we'll\n# convert it to an ipv4 address. This step is optional.\nif [[ \"$head_node_ip\" == *\" \"* ]]; then\nIFS=' ' read -ra ADDR <<<\"$head_node_ip\"\nif [[ ${#ADDR[0]} -gt 16 ]]; then\n  head_node_ip=${ADDR[1]}\nelse\n  head_node_ip=${ADDR[0]}\nfi\necho \"IPV6 address detected. We split the IPV4 address as $head_node_ip\"\nfi\n\nport=6379\nip_head=$head_node_ip:$port\nexport ip_head\necho \"IP Head: $ip_head\"\n\n# make sure we set environment variables before Ray initialization\n\nprintenv\n\necho \"Starting HEAD at $head_node\"\nsrun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n    apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n        ray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n        --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n# optional, though may be useful in certain versions of Ray < 1.0.\nsleep 10\n\n# number of nodes other than the head node\nworker_num=$((SLURM_JOB_NUM_NODES - 1))\n\nfor ((i = 1; i <= worker_num; i++)); do\n    node_i=${nodes_array[$i]}\n    echo \"Starting WORKER $i at $node_i\"\n    srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n        apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n            ray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus \"${SLURM_GPUS_PER_NODE}\" --block &\n    sleep 5\ndone\n\nPYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w \"$head_node\" \\\n    apptainer run --nv --bind $verl_workdir $apptainer_image_path \\\n    python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$train_files \\\n    data.val_files=$val_files \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=256 \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    critic.ppo_micro_batch_size_per_gpu=4 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=\"${SLURM_GPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${SLURM_NNODES}\" \\\n    trainer.save_freq=10 \\\n    trainer.test_freq=10 
\\\n    trainer.total_epochs=15 2>&1 | tee verl_demo_slurm.log\n"
  },
  {
    "path": "verl_rl/examples/split_placement/README.md",
    "content": "# Split Placement Example\nHere we introduce how to run the naive implementation of the split placement of PPO algorithm.\nWe will release the complete version of flexible placement in the near future.\n\n For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example.\n\n### Step 1: Placing the models to different GPUs\nSpecify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs.\n```python\nactor_rollout_ref_pool_id = 'actor_rollout_ref_pool'\ncritic_pool_id = 'critic_pool'\nif config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:\n    resource_pool_spec = {\n        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n        critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n    }\nelse:\n    resource_pool_spec = {\n        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n        critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n    }\nprint(f'resource_pool_spec: {resource_pool_spec}')\nmapping = {\n    Role.ActorRollout: actor_rollout_ref_pool_id,\n    Role.Critic: critic_pool_id,\n    Role.RefPolicy: actor_rollout_ref_pool_id,\n}\nmapping[Role.RewardModel] = critic_pool_id\n```\n\n### Step 2: Make the models executed asynchronously\nBased on the model placement, we need to make the models executed asynchronously.\n\nTo do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations.\nFor example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py`\n\n```\n@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\ndef update_actor(self, data: DataProto):\n    ...\n\n@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\ndef update_critic(self, data: DataProto):\n    ...\n```\n\nWe can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we don't do this in this example.\n\n### Step 3: Execute these operation in parallel in the single controller process\nTo implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent  `futures` on the single controller process.\n\n```python\ncritic_output = critic_output.get()\nactor_output = actor_output.get()\n```\n\n### Step 4: Run the split placement example\n\n```\nbash run_deepseek7b_llm.sh\n```\n"
  },
  {
    "path": "verl_rl/examples/split_placement/config/ppo_trainer_split.yaml",
    "content": "# the ppo trainer split config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  prompt_key: prompt\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves\n  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n  return_raw_chat: False\n  return_full_prompt: False\n  shuffle: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    external_lib: null\n    override_config: { }\n    enable_gradient_checkpointing: True\n    use_remove_padding: False\n  actor:\n    strategy: fsdp  # This is for backward-compatibility\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: False\n    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n    grad_clip: 1.0\n    clip_ratio: 0.2\n    entropy_coeff: 0.0\n    use_kl_loss: False # True for GRPO\n    kl_loss_coef: 0.001 # for grpo\n    kl_loss_type: low_var_kl # for grpo\n    ppo_epochs: 1\n    shuffle: False\n    ulysses_sequence_parallel_size: 1 # sp size\n    optim:\n      lr: 1e-6\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      min_lr_ratio: null   # only useful for warmup with cosine\n      warmup_style: constant  # select from constant/cosine\n      total_training_steps: -1  # must be override by program\n    fsdp_config:\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n      param_offload: False\n      optimizer_offload: False\n      fsdp_size: -1\n  ref:\n    fsdp_config:\n      param_offload: False\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size\n  rollout:\n    name: vllm\n    temperature: 1.0\n    top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n    top_p: 1\n    prompt_length: ${data.max_prompt_length}  # not use for opensource\n    response_length: ${data.max_response_length}\n    # for vllm rollout\n    dtype: bfloat16 # should align with FSDP\n    gpu_memory_utilization: 0.5\n    ignore_eos: False\n    enforce_eager: True\n    free_cache_engine: True\n    load_format: dummy_dtensor\n    tensor_model_parallel_size: 2\n    max_num_batched_tokens: 8192\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: 
${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    disable_log_stats: True\n    enable_chunked_prefill: True # could get higher throughput\n    # for hf rollout\n    do_sample: True\n    # number of responses (i.e. num sample times)\n    n: 1 # > 1 for grpo\n\ncritic:\n  strategy: fsdp\n  optim:\n    lr: 1e-5\n    lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n    min_lr_ratio: null   # only useful for warmup with cosine\n    warmup_style: constant  # select from constant/cosine\n    total_training_steps: -1  # must be override by program\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    override_config: { }\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    enable_gradient_checkpointing: True\n    use_remove_padding: False\n    fsdp_config:\n      param_offload: False\n      optimizer_offload: False\n      wrap_policy:\n        # transformer_layer_cls_to_wrap: None\n        min_num_params: 0\n      fsdp_size: -1\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n  ppo_micro_batch_size_per_gpu: null\n  forward_micro_batch_size: ${critic.ppo_micro_batch_size}\n  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2\n  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}\n  ulysses_sequence_parallel_size: 1 # sp size\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n  grad_clip: 1.0\n  cliprange_value: 0.5\n\nreward_model:\n  enable: False\n  strategy: fsdp\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    use_remove_padding: False\n    fsdp_config:\n      min_num_params: 0\n      param_offload: False\n      fsdp_size: -1\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: null # set a number\n  max_length: null\n  ulysses_sequence_parallel_size: 1 # sp size\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n\nalgorithm:\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  use_kl_in_reward: False\n  kl_penalty: kl  # how to estimate kl divergence\n  kl_ctrl:\n    type: fixed\n    kl_coef: 0.001\n\ntrainer:\n  total_epochs: 30\n  total_training_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: [ 'console', 'wandb' ]\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  # auto: find the last ckpt to resume. If can't find, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n"
  },
  {
    "path": "verl_rl/examples/split_placement/main_ppo_split.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\nimport torch\nfrom split_monkey_patch import fit\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer\nfrom verl.utils.reward_score import gsm8k, math\n\n\ndef _select_rm_score_fn(data_source):\n    if data_source == \"openai/gsm8k\":\n        return gsm8k.compute_score\n    elif data_source == \"lighteval/MATH\":\n        return math.compute_score\n    else:\n        raise NotImplementedError\n\n\nclass RewardManager:\n    def __init__(self, tokenizer, num_examine) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n\n    def __call__(self, data: DataProto, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            sequences = torch.cat((valid_prompt_ids, valid_response_ids))\n            sequences_str = self.tokenizer.decode(sequences)\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n\n            # select rm_score\n            data_source = data_item.non_tensor_batch[\"data_source\"]\n            compute_score_fn = _select_rm_score_fn(data_source)\n\n            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)\n            reward_tensor[i, valid_response_length - 1] = score\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(sequences_str)\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor}\n        else:\n            return 
reward_tensor\n\n\n@hydra.main(config_path=\"config\", config_name=\"ppo_trainer_split\", version_base=None)\ndef main(config):\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}},\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    ray.get(main_task.remote(config))\n\n\n@ray.remote\ndef main_task(config):\n    # print initial config\n    from pprint import pprint\n\n    from omegaconf import OmegaConf\n\n    from verl.utils.fs import copy_to_local\n\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    # download the checkpoint from hdfs\n    local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n    # instantiate tokenizer\n    from verl.utils import hf_tokenizer\n\n    tokenizer = hf_tokenizer(local_path)\n\n    # define worker classes\n    if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n        assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n        ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n    else:\n        raise NotImplementedError\n\n    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n        Role.Critic: ray.remote(CriticWorker),\n    }\n\n    # NOTE: initialize two resource pools\n    actor_rollout_ref_pool_id = \"actor_rollout_ref_pool\"\n    critic_pool_id = \"critic_pool\"\n    if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:\n        resource_pool_spec = {\n            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n            critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,\n        }\n    else:\n        resource_pool_spec = {\n            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n            critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),\n        }\n    print(f\"resource_pool_spec: {resource_pool_spec}\")\n    mapping = {\n        Role.ActorRollout: actor_rollout_ref_pool_id,\n        Role.Critic: critic_pool_id,\n    }\n\n    # use reference model\n    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        mapping[Role.RefPolicy] = actor_rollout_ref_pool_id\n\n    # we should adopt a multi-source reward function here\n    # - for rule-based rm, we directly call a reward score\n    # - for model-based rm, we call a model\n    # - for code-related prompts, we send them to a sandbox if there are test cases\n    # - finally, we combine all the rewards together\n    # - The reward type depends on the tag of the data\n    if config.reward_model.enable:\n        if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n            from verl.workers.fsdp_workers import RewardModelWorker\n        elif config.reward_model.strategy == \"megatron\":\n            from verl.workers.megatron_workers import RewardModelWorker\n        else:\n            raise NotImplementedError\n        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n        mapping[Role.RewardModel] = critic_pool_id\n\n    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)\n\n    # Note that we always use function-based RM for validation\n    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)\n\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n    RayPPOTrainer.fit = fit\n    trainer = RayPPOTrainer(\n        config=config,\n        tokenizer=tokenizer,\n        role_worker_mapping=role_worker_mapping,\n        resource_pool_manager=resource_pool_manager,\n        ray_worker_group_cls=ray_worker_group_cls,\n        reward_fn=reward_fn,\n        val_reward_fn=val_reward_fn,\n    )\n    trainer.init_workers()\n    trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/examples/split_placement/run_deepseek7b_llm.sh",
    "content": "set -x\n\npython3 main_ppo_split.py \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    critic.optim.lr=1e-5 \\\n    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=8 \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_example_gsm8k' \\\n    trainer.experiment_name='deepseek_llm_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/split_placement/split_monkey_patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAn naive implementation of split placment example\n\"\"\"\n\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_data_metrics,\n    compute_timing_metrics,\n    marked_timer,\n)\nfrom verl.utils.metric import reduce_metrics\n\n\ndef fit(self):\n    \"\"\"\n    The training loop of PPO.\n    The driver process only need to call the compute functions of the worker group through RPC\n    to construct the PPO dataflow.\n    The light-weight advantage computation is done on the driver process.\n    \"\"\"\n    from omegaconf import OmegaConf\n\n    from verl.utils.tracking import Tracking\n\n    logger = Tracking(\n        project_name=self.config.trainer.project_name,\n        experiment_name=self.config.trainer.experiment_name,\n        default_backend=self.config.trainer.logger,\n        config=OmegaConf.to_container(self.config, resolve=True),\n    )\n\n    self.global_steps = 0\n\n    # load checkpoint before doing anything\n    self._load_checkpoint()\n\n    # perform validation before training\n    # currently, we only support validation using the reward_function.\n    if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n        val_metrics = self._validate()\n        pprint(f\"Initial validation metrics: {val_metrics}\")\n        logger.log(data=val_metrics, step=self.global_steps)\n        if self.config.trainer.get(\"val_only\", False):\n            return\n\n    # we start from step 1\n    self.global_steps += 1\n    last_val_metrics = None\n\n    for epoch in range(self.config.trainer.total_epochs):\n        for batch_dict in self.train_dataloader:\n            metrics = {}\n            timing_raw = {}\n\n            batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n            # pop those keys for generation\n            gen_batch = batch.pop(batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"])\n            is_last_step = self.global_steps >= self.total_training_steps\n\n            with marked_timer(\"step\", timing_raw):\n                # generate a batch\n                with marked_timer(\"gen\", timing_raw):\n                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                    timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                    gen_batch_output.meta_info.pop(\"timing\", None)\n\n                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                    with marked_timer(\"gen_max\", timing_raw):\n                        gen_baseline_batch = deepcopy(gen_batch)\n                        gen_baseline_batch.meta_info[\"do_sample\"] = False\n                        gen_baseline_output = 
self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                        batch = batch.union(gen_baseline_output)\n                        reward_baseline_tensor = self.reward_fn(batch)\n                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                        batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                        batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                        del gen_baseline_batch, gen_baseline_output\n\n                batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                )\n                # repeat to align with repeated responses in rollout\n                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                batch = batch.union(gen_batch_output)\n\n                # Balance the number of valid tokens across DP ranks.\n                # NOTE: This usually changes the order of data in the `batch`,\n                # which won't affect the advantage calculation (since it's based on uid),\n                # but might affect the loss calculation (due to the change of mini-batching).\n                # TODO: Decouple the DP balancing and mini-batching.\n                self._balance_batch(batch, metrics=metrics)\n\n                # compute global_valid tokens\n                batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                # recompute old_log_probs\n                with marked_timer(\"old_log_prob\", timing_raw):\n                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                    batch = batch.union(old_log_prob)\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with marked_timer(\"ref\", timing_raw):\n                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with marked_timer(\"values\", timing_raw):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with marked_timer(\"adv\", timing_raw):\n                    # compute scores. Support both model and function-based.\n                    # We first compute the scores using reward model. Then, we call reward_fn to combine\n                    # the results from reward model and rule-based results.\n                    if self.use_rm:\n                        # we first compute reward model score\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    # we combine with rule-based rm\n                    reward_tensor = self.reward_fn(batch)\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    # compute rewards. 
apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    # compute advantages, executed on the driver process\n                    norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                    batch = compute_advantage(\n                        batch,\n                        adv_estimator=self.config.algorithm.adv_estimator,\n                        gamma=self.config.algorithm.gamma,\n                        lam=self.config.algorithm.lam,\n                        num_repeat=self.config.actor_rollout_ref.rollout.n,\n                        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                    )\n\n                # implement critic warmup\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with marked_timer(\"update_actor_call\", timing_raw):\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                else:\n                    actor_output = None\n\n                # update critic\n                if self.use_critic:\n                    with marked_timer(\"update_critic_call\", timing_raw):\n                        critic_output = self.critic_wg.update_critic(batch)\n\n                    # NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class\n                    with marked_timer(\"update_actor_critic\", timing_raw):\n                        critic_output = critic_output.get()\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                if actor_output is not None:\n                    actor_output = actor_output.get()\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with marked_timer(\"save_checkpoint\", timing_raw):\n                        self._save_checkpoint()\n\n            # collect metrics\n            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n\n            # TODO: make a canonical logger that supports various backends\n            logger.log(data=metrics, step=self.global_steps)\n\n            if self.global_steps >= self.total_training_steps:\n                pprint(f\"Final validation metrics: {last_val_metrics}\")\n                return\n\n            self.global_steps += 1\n"
  },
  {
    "path": "verl_rl/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=0.5b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct\n\nset -x\nnproc_per_gpu=116\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-1.5b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=1.5b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-1.5B-Instruct\n\nset -x\nnproc_per_gpu=128\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-14b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=14b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-14B-Instruct\n\nset -x\nnproc_per_gpu=58 # 32√ → 64× → 48√ → 56√ → 60× → 58√ → 59×\nnnodes=1\nngpu_per_node=2\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.25 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2.5-Coder-14B-Instruct\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nPYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_14b_function_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_rl/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-32b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=32b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-32B-Instruct\n\nset -x\nnproc_per_gpu=45 # 32√ → 64× → 48× → 40√ → 44√ → 46× → 45×\nnnodes=1\nngpu_per_node=4\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh",
    "content": "set -x\n\n# we need this to avoid fragmentation of GPU memory\nexport PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nmodel_path=Qwen/Qwen2.5-32B\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=6144 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.actor.megatron.param_offload=True \\\n    actor_rollout_ref.actor.megatron.grad_offload=True \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.ref.megatron.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='megatron_vllm_qwen2_32b' \\\n    trainer.experiment_name='qwen2_32b_grpo_8_h20' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-3b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=3b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-3B-Instruct\n\nset -x\nnproc_per_gpu=62\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh",
    "content": "set -x\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_val_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-72B-Instruct\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$data_path \\\n    data.val_files=$gsm8k_val_path \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='Qwen2_72B_Instruct' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_rl/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh",
    "content": "set -x\n\n#### important: vllm version must be >= 0.8.3\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_val_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-72B-Instruct\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$gsm8k_train_path \\\n    data.val_files=$gsm8k_val_path \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=16 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='Qwen2_72B_Instruct' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=4 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_rl/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-72b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=72b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-72B-Instruct\n\nset -x\nnproc_per_gpu=22 # 16√ → 32× → 24× → 20√ → 22√ → 23×\nnnodes=1\nngpu_per_node=8\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh",
    "content": "# -*- coding: utf-8 -*-\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\nNOW=$(date +%Y%m%d)\nexport WANDB_DIR=gsm8k-grpo-lora-qwen2.5-7b-${NOW}\nexport WANDB_PROJECT=${WANDB_DIR}\nexport WANDB_EXP=7b-${NOW}\nMODEL_PATH=Qwen/Qwen2.5-7B-Instruct\n\nset -x\nnproc_per_gpu=16 # 64√ → 128× → 96√ → 112× → 104× → 100√ → 102× → 101×\nnnodes=1\nngpu_per_node=1\ntotal_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node ))\nmini_batch_size=$(( total_procs ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=data/gsm8k/train.parquet \\\n    data.val_files=data/gsm8k/test.parquet \\\n    data.train_batch_size=${total_procs} \\\n    data.val_batch_size=${total_procs} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$MODEL_PATH  \\\n    actor_rollout_ref.model.use_shm=True  \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=32 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.model.target_modules=all-linear \\\n    actor_rollout_ref.actor.optim.lr=3e-5 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.max_num_seqs=512 \\\n    actor_rollout_ref.rollout.max_model_len=1536 \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.entropy_coeff=0.001 \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=${WANDB_PROJECT} \\\n    trainer.experiment_name=${WANDB_EXP} \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log\n"
  },
  {
    "path": "verl_rl/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/rlhf/math/test.parquet\nmodel_path=Qwen/Qwen2-7B-Instruct\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nPYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=2 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/init_ray.sh",
    "content": "#!/bin/bash\n# Single Node Ray Initialization Script\n# Usage: bash init_ray.sh <HEAD_NODE_IP> <PORT> <RANK>\n#   HEAD_NODE_IP: IP address of the head node\n#   PORT: Ray port (default: 6379)\n#   RANK: Node rank (0 for head, >0 for workers)\n\nset -e\n\n# Parse arguments\nHEAD_NODE_IP=${1:-\"127.0.0.1\"}\nPORT=${2:-6379}\nRANK=${3:-0}\n\n# Configuration\nNUM_CPUS=${NUM_CPUS:-\"\"}\nNUM_GPUS=${NUM_GPUS:-\"\"}\nOBJECT_STORE_MEMORY=${OBJECT_STORE_MEMORY:-\"\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"verl\"}\n\n# Colors\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $(hostname): $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $(hostname): $1\"\n}\n\n# Activate conda environment\nif [ -f \"/root/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"/root/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/anaconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/anaconda3/etc/profile.d/conda.sh\"\nelif [ -f \"$HOME/miniconda3/etc/profile.d/conda.sh\" ]; then\n    source \"$HOME/miniconda3/etc/profile.d/conda.sh\"\nfi\n\nif command -v conda &> /dev/null; then\n    conda activate ${CONDA_ENV_NAME} 2>/dev/null || log_warn \"Could not activate conda env: ${CONDA_ENV_NAME}\"\nfi\n\n# Build ray start command options\nRAY_OPTS=\"\"\nif [ -n \"${NUM_CPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-cpus=${NUM_CPUS}\"\nfi\nif [ -n \"${NUM_GPUS}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --num-gpus=${NUM_GPUS}\"\nfi\nif [ -n \"${OBJECT_STORE_MEMORY}\" ]; then\n    RAY_OPTS=\"${RAY_OPTS} --object-store-memory=${OBJECT_STORE_MEMORY}\"\nfi\n\n# Stop existing Ray instance\nray stop --force 2>/dev/null || true\nsleep 2\n\n# Start Ray\nif [ \"${RANK}\" -eq 0 ]; then\n    log_info \"Starting Ray HEAD node on port ${PORT}...\"\n    ray start --head --port=${PORT} ${RAY_OPTS}\nelse\n    log_info \"Starting Ray WORKER node, connecting to ${HEAD_NODE_IP}:${PORT}...\"\n    ray start --address=${HEAD_NODE_IP}:${PORT} ${RAY_OPTS}\nfi\n\nsleep 3\n\n# Check status\nlog_info \"Ray node started. Checking status...\"\nray status\n"
  },
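Once `init_ray.sh` has started the head and workers, it is worth confirming from Python that every node and GPU actually joined before submitting a training job. A minimal check using Ray's public API, assuming it runs on a machine that is part of the cluster:

```python
import ray

# Attach to the running cluster started by init_ray.sh ("auto" finds the
# local Ray instance rather than starting a new one).
ray.init(address="auto")

# Aggregate resources across all nodes; the GPU count should equal
# nnodes * ngpu_per_node from your hostfile.
print(ray.cluster_resources())

# Per-node view: address, GPU count, and liveness of each Ray node.
for node in ray.nodes():
    print(node["NodeManagerAddress"], node["Resources"].get("GPU", 0), node["Alive"])
```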
  {
    "path": "verl_rl/init_ray_cluster.sh",
    "content": "#!/bin/bash\n# Multi-node Ray Cluster Initialization Script\n# Usage: bash init_ray_cluster.sh [--stop]\n#   --stop: Stop Ray on all nodes instead of starting\n\nset -e\n\nSCRIPT_DIR=$(cd $(dirname $0); pwd)\nPROJECT_DIR=${SCRIPT_DIR}\n\n# Configuration\nPORT=${RAY_PORT:-6379}\nHOSTFILE=${HOSTFILE:-\"/etc/mpi/hostfile\"}\nCONDA_ENV_NAME=${CONDA_ENV_NAME:-\"verl\"}\nLOG_DIR=\"${PROJECT_DIR}/logs/ray\"\n\n# Colors\nRED='\\033[0;31m'\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m'\n\nlog_info() {\n    echo -e \"${GREEN}[INFO]${NC} $1\"\n}\n\nlog_warn() {\n    echo -e \"${YELLOW}[WARN]${NC} $1\"\n}\n\nlog_error() {\n    echo -e \"${RED}[ERROR]${NC} $1\"\n}\n\n# Generate conda initialization command that works with both anaconda and miniconda\nget_conda_init_cmd() {\n    cat << 'EOF'\nfor conda_sh in /root/miniconda3/etc/profile.d/conda.sh \\\n                /root/anaconda3/etc/profile.d/conda.sh \\\n                $HOME/miniconda3/etc/profile.d/conda.sh \\\n                $HOME/anaconda3/etc/profile.d/conda.sh \\\n                /opt/conda/etc/profile.d/conda.sh; do\n    [ -f \"$conda_sh\" ] && source \"$conda_sh\" && break\ndone\nEOF\n}\n\n# Function to stop Ray on all nodes\nstop_cluster() {\n    log_info \"Stopping Ray on all nodes...\"\n\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_warn \"Hostfile not found, stopping local Ray only\"\n        ray stop --force 2>/dev/null || true\n        return\n    fi\n\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    for node in ${ALL_NODES}; do\n        log_info \"Stopping Ray on ${node}...\"\n        ssh -n ${node} \"$(get_conda_init_cmd) && conda activate ${CONDA_ENV_NAME} && ray stop --force\" 2>/dev/null &\n    done\n\n    wait\n    log_info \"Ray stopped on all nodes\"\n}\n\n# Function to start Ray cluster\nstart_cluster() {\n    # Check hostfile\n    if [ ! -f \"${HOSTFILE}\" ]; then\n        log_error \"Hostfile not found: ${HOSTFILE}\"\n        log_info \"Please create a hostfile with one IP per line\"\n        log_info \"Example:\"\n        echo \"  192.168.1.100\"\n        echo \"  192.168.1.101\"\n        echo \"  192.168.1.102\"\n        exit 1\n    fi\n\n    # Get head node (first line)\n    HEAD_NODE=$(awk 'NR==1 {print $1}' ${HOSTFILE})\n    ALL_NODES=$(awk '!a[$1]++ {print $1}' ${HOSTFILE})\n\n    log_info \"Head node: ${HEAD_NODE}\"\n    log_info \"Ray port: ${PORT}\"\n    log_info \"Conda env: ${CONDA_ENV_NAME}\"\n    echo \"\"\n    log_info \"Nodes in cluster:\"\n    echo \"${ALL_NODES}\"\n    echo \"\"\n\n    # Create log directory\n    mkdir -p \"${LOG_DIR}\"\n\n    # Stop existing Ray instances first\n    log_info \"Stopping any existing Ray instances...\"\n    stop_cluster\n    sleep 3\n\n    # Start head node first (synchronously)\n    log_info \"Starting Ray HEAD on ${HEAD_NODE}...\"\n    ssh -n ${HEAD_NODE} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} 0\" \\\n        > \"${LOG_DIR}/ray_${HEAD_NODE}.log\" 2>&1\n\n    if [ $? -ne 0 ]; then\n        log_error \"Failed to start Ray HEAD. 
Check ${LOG_DIR}/ray_${HEAD_NODE}.log\"\n        exit 1\n    fi\n    log_info \"Ray HEAD started successfully\"\n\n    # Wait for head to be ready\n    sleep 5\n\n    # Start worker nodes (asynchronously)\n    rank=1\n    for node in ${ALL_NODES}; do\n        if [ \"${node}\" == \"${HEAD_NODE}\" ]; then\n            continue\n        fi\n\n        log_info \"Starting Ray WORKER on ${node} (rank ${rank})...\"\n        ssh -n ${node} \"CONDA_ENV_NAME=${CONDA_ENV_NAME} bash ${SCRIPT_DIR}/init_ray.sh ${HEAD_NODE} ${PORT} ${rank}\" \\\n            > \"${LOG_DIR}/ray_${node}.log\" 2>&1 &\n        rank=$((rank + 1))\n    done\n\n    # Wait for all workers\n    log_info \"Waiting for all workers to join...\"\n    wait\n    sleep 3\n\n    # Check cluster status\n    echo \"\"\n    log_info \"Ray cluster initialization complete!\"\n    log_info \"Logs saved to: ${LOG_DIR}/\"\n    echo \"\"\n    log_info \"Cluster status:\"\n    ssh -n ${HEAD_NODE} \"$(get_conda_init_cmd) && conda activate ${CONDA_ENV_NAME} && ray status\"\n}\n\n# Main\ncase \"${1}\" in\n    --stop)\n        stop_cluster\n        ;;\n    *)\n        start_cluster\n        ;;\nesac\n"
  },
  {
    "path": "verl_rl/pyproject.toml",
    "content": "# -------------------------------\n# build-system\n# -------------------------------\n[build-system]\nrequires = [\n    \"setuptools>=61.0\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n\n# -------------------------------\n# project (PEP 621 metadata)\n# -------------------------------\n[project]\nname = \"verl\"\n# We'll mark the version as \"dynamic\" because it's read from the file \"verl/version/version\" \n# (PEP 621 calls this \"dynamic version\"). \n# The actual version is specified in the [tool.setuptools.dynamic] section below.\ndynamic = [\"version\", \"dependencies\", \"optional-dependencies\", \"authors\", \"urls\"]\n\ndescription = \"verl: Volcano Engine Reinforcement Learning for LLM\"\nlicense = {text = \"Apache-2.0\"}  # Changed from file to text format\nreadme = {file = \"README.md\", content-type = \"text/markdown\"}\nrequires-python = \">=3.10\"\n\n# -------------------------------\n# tool.ruff - Linting configuration\n# -------------------------------\n[tool.ruff]\n# Note: While the formatter will attempt to format lines such that they remain within the line-length,\n# it isn't a hard upper bound, and formatted lines may exceed the line-length.\nline-length = 120\nexclude = [\"tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\", \"scripts/legacy_model_merger.py\"]\n\n[tool.ruff.lint]\nisort = {known-first-party = [\"verl\"]}\n# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml\nselect = [\n    # pycodestyle\n    \"E\",\n    # Pyflakes\n    \"F\",\n    # pyupgrade\n    \"UP\",\n    # flake8-bugbear\n    \"B\",\n    # isort\n    \"I\",\n    \"G\",\n]\nignore = [\n    # star imports\n    \"F405\", \"F403\",\n    # lambda expression assignment\n    \"E731\",\n    # Loop control variable not used within loop body\n    \"B007\",\n    # f-string format\n    \"UP032\",\n    # `.log()` statement uses f-string\n    \"G004\",\n    # X | None for type annotations\n    \"UP045\",\n    # deprecated import\n    \"UP035\",\n]\n\n# -------------------------------\n# tool.setuptools - Additional config\n# -------------------------------\n[tool.setuptools]\n# True means `setuptools` will attempt to include all relevant files in package_data automatically.\n# This corresponds to `include_package_data=True` in setup.py.\ninclude-package-data = true\n\n# We read the version from a file in 'verl/version/version'\n[tool.setuptools.dynamic]\nversion = {file = \"verl/version/version\"}\n\n# If you need to mimic `package_dir={'': '.'}`:\n[tool.setuptools.package-dir]\n\"\" = \".\"\n\n# If you need to include specific non-Python data (like YAML files or version file):\n# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']}\n[tool.setuptools.package-data]\nverl = [\n  \"version/*\",\n  \"trainer/config/*.yaml\",\n  \"trainer/config/*/*.yaml\",\n]\n"
  },
  {
    "path": "verl_rl/recipe/README.md",
    "content": "# Recipe\nThe examples under `recipes/` are representative extensions to verl for specific end-to-end RL training recipes.\nThe help the community reproduce experiments, verl team provides a snapshot of the codebase when each recipe is initially PR'ed to verl main. You can find them via [github branches](https://github.com/volcengine/verl/branches/all?query=recipe)\n\n# Awesome work using verl\n\n- [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset. ![GitHub Repo stars](https://img.shields.io/github/stars/Unakar/Logic-RL)\n- [Seed-Coder](https://github.com/ByteDance-Seed/Seed-Coder): RL training of Seed-Coder boosts performance on competitive programming ![GitHub Repo stars](https://img.shields.io/github/stars/ByteDance-Seed/Seed-Coder)\n- [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195)\n- [s3](https://github.com/pat-jj/s3) **Efficient Yet Effective** Search Agent Training via RL ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/s3)\n- [Rec-R1](https://arxiv.org/pdf/2503.24289): Bridging Generative Large Language Models and Recommendation Systems via Reinforcement Learning\n- [Explore RL Data Scaling](https://arxiv.org/abs/2503.22230): Exploring Data Scaling Trends and Effects in Reinforcement Learning from Human Feedback\n- [FIRE](https://arxiv.org/abs/2410.21236): Flaming-hot initiation with regular execution sampling for large language models\n- [DQO](https://arxiv.org/abs/2410.09302): Enhancing multi-Step reasoning abilities of language models through direct Q-function optimization\n- [ProRL](https://arxiv.org/abs/2505.24864): Prolonged Reinforcement Learning Expands Reasoning Boundaries in Large Language Models\n- [cognition-engineering](https://github.com/gair-nlp/cognition-engineering): Test time scaling drives cognition engineering. ![GitHub Repo stars](https://img.shields.io/github/stars/gair-nlp/cognition-engineering)\n- [Trust Region Preference Approximation](https://github.com/XueruiSu/Trust-Region-Preference-Approximation): A simple and stable **reinforcement learning algorithm** for LLM reasoning. 
![GitHub Repo stars](https://img.shields.io/github/stars/XueruiSu/Trust-Region-Preference-Approximation)\n- [AdaRFT](https://github.com/uscnlp-lime/verl): Efficient Reinforcement Finetuning via **Adaptive Curriculum Learning** ![GitHub Repo stars](https://img.shields.io/github/stars/uscnlp-lime/verl)\n- [critic-rl](https://github.com/HKUNLP/critic-rl): LLM critics for code generation ![GitHub Repo stars](https://img.shields.io/github/stars/HKUNLP/critic-rl)\n- [self-rewarding-reasoning-LLM](https://arxiv.org/pdf/2502.19613): self-rewarding and correction with **generative reward models** ![GitHub Repo stars](https://img.shields.io/github/stars/RLHFlow/Self-rewarding-reasoning-LLM)\n- [DeepEnlighten](https://github.com/DolbyUUU/DeepEnlighten): Reproduce R1 with **social reasoning** tasks and analyze key findings ![GitHub Repo stars](https://img.shields.io/github/stars/DolbyUUU/DeepEnlighten)\n- [MetaSpatial](https://github.com/PzySeere/MetaSpatial): Reinforcing **3D Spatial Reasoning** in **VLMs** for the **Metaverse** ![GitHub Repo stars](https://img.shields.io/github/stars/PzySeere/MetaSpatial)\n- [PURE](https://github.com/CJReinforce/PURE): **Credit assignment** is the key to successful reinforcement fine-tuning using **process reward model** ![GitHub Repo stars](https://img.shields.io/github/stars/CJReinforce/PURE)\n- [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs ![GitHub Repo stars](https://img.shields.io/github/stars/kanishkg/cognitive-behaviors)\n- [deepscaler](https://github.com/agentica-project/rllm/tree/deepscaler): iterative context scaling with GRPO ![GitHub Repo stars](https://img.shields.io/github/stars/agentica-project/deepscaler)\n- [DAPO](https://dapo-sia.github.io/): the fully open source SOTA RL algorithm that beats DeepSeek-R1-zero-32B ![GitHub Repo stars](https://img.shields.io/github/stars/volcengine/verl)\n"
  },
  {
    "path": "verl_rl/recipe/char_count/README.md",
    "content": "# Char Count\n## Introduction\nChar count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB.\n\n## Problem formulation\nThe prompt is: \"How many {char} are there in {word}?\". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example,\n\n```text\nQuestion: How many n are there in n-i-n-e?\nAnswer:\nn = n\ni != n\nn = n\ne != n\n\\boxed{2}\n```\n\nNote that\n- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer.\n- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box.\n- The task can be verified.\n- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range.\n\n## Scripts\nTo create the dataset, run\n```bash\npython3 create_dataset.py\n```\nWe create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path.\n\nTo run the SFT\n```bash\nbash train_sft.sh\n```\nWe train SFT for 3 epochs. After 3 epochs, the validation score is around 0.12.\n\nTo run GRPO\n```bash\nbash train_grpo.sh\n```\nWe train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.36.\n"
  },
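The README notes that the task can be verified: the ground truth is just the count of the target char, so a rule-based checker only has to read the final `\boxed{...}` answer. A minimal sketch of such a checker with a plain regex (`verify_char_count` is a hypothetical name; the recipe's actual reward function below uses verl's `math` utilities instead):

```python
import re

def verify_char_count(word: str, target_char: str, solution: str) -> bool:
    """Compare the last \\boxed{...} answer in `solution` against the true count."""
    ground_truth = word.replace("-", "").count(target_char)
    matches = re.findall(r"\\boxed\{(\d+)\}", solution)
    # Like the math verifier, only the last boxed answer counts.
    return bool(matches) and int(matches[-1]) == ground_truth

# The README's worked example: two n's in n-i-n-e.
print(verify_char_count("n-i-n-e", "n", "n = n\ni != n\nn = n\ne != n\n\\boxed{2}"))  # True
```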
  {
    "path": "verl_rl/recipe/char_count/create_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nTask description:\nGiven a random word and a random char, count the number of occurrence of char in the word.\n\nCreate CoT dataset that split the word into separate char. Then list the char and count the occurrence.\n\nThe word set comes from shakespeare\n\"\"\"\n\nimport os.path\nimport random\n\nprompt_template = \"How many {} are there in word {}?\"\n\n\ndef generate_random_char():\n    return chr(97 + random.randint(0, 25))\n\n\ndef create_prompt_response(min_length=3, max_length=5):\n    # randomly generate a length\n    word_length = random.randint(min_length, max_length)\n    # randomly generate a target count number. This makes the target number\n    target_count_number = random.randint(1, word_length)\n\n    char_lst = []\n    # generate the word\n    # step 1: generate the target word\n    target_char = generate_random_char()\n\n    for _ in range(target_count_number):\n        char_lst.append(target_char)\n\n    # step 2: generate other words\n    for _ in range(word_length - target_count_number):\n        while True:\n            char = generate_random_char()\n            if char != target_char:\n                char_lst.append(char)\n                break\n\n    # step 3: random permute char_lst\n    random.shuffle(char_lst)\n\n    word = \"-\".join(char_lst)\n\n    prompt = prompt_template.format(target_char, word)\n    final_answer = []\n\n    # cot\n    number = 0\n    for i, char in enumerate(char_lst):\n        cot = f\"{char}\"\n        if char != target_char:\n            cot += \" != \"\n        else:\n            cot += \" = \"\n            number += 1\n        cot += f\"{target_char}.\"\n\n        final_answer.append(cot)\n\n    conclusion = f\"\\\\boxed{{{number}}} {target_char} in {word}.\"\n\n    final_answer.append(conclusion)\n\n    final_answer = \"\\n\".join(final_answer)\n\n    return prompt, final_answer\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--total_number\", type=int, default=10000)\n    parser.add_argument(\"--min_length\", type=int, default=5)\n    parser.add_argument(\"--max_length\", type=int, default=20)\n    parser.add_argument(\"--data_path\", type=str, default=\"~/data/char_count\")\n\n    args = vars(parser.parse_args())\n\n    total_number = args[\"total_number\"]\n    min_length = args[\"min_length\"]\n    max_length = args[\"max_length\"]\n    data_path = args[\"data_path\"]\n    data_path = os.path.expanduser(data_path)\n\n    full_output = []\n    for _ in range(total_number):\n        output = create_prompt_response(min_length=min_length, max_length=max_length)\n        full_output.append(output)\n\n    # random reorder\n    random.shuffle(full_output)\n\n    # split for train and test\n    train_split_len = int(0.9 * len(full_output))\n    train_outputs = full_output[:train_split_len]\n    test_output = 
full_output[train_split_len:]\n\n    sft_train_dataset = {\"prompt\": [], \"response\": []}\n\n    for o in train_outputs:\n        sft_train_dataset[\"prompt\"].append(o[0])\n        sft_train_dataset[\"response\"].append(o[1])\n\n    sft_test_dataset = {\"prompt\": [], \"response\": []}\n\n    for o in test_output:\n        sft_test_dataset[\"prompt\"].append(o[0])\n        sft_test_dataset[\"response\"].append(o[1])\n\n    import pandas as pd\n\n    sft_train_dataset = pd.DataFrame(data=sft_train_dataset)\n    sft_test_dataset = pd.DataFrame(data=sft_test_dataset)\n\n    folder = os.path.join(data_path, \"sft\")\n\n    os.makedirs(folder, exist_ok=True)\n\n    sft_train_dataset.to_parquet(os.path.join(folder, \"train.parquet\"))\n    sft_test_dataset.to_parquet(os.path.join(folder, \"test.parquet\"))\n\n    # build RL dataset\n    rl_train_dataset = {\"prompt\": [], \"data_source\": [], \"ability\": [], \"reward_model\": [], \"extra_info\": []}\n\n    rl_test_dataset = {\"prompt\": [], \"data_source\": [], \"ability\": [], \"reward_model\": [], \"extra_info\": []}\n\n    from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed\n\n    for o in train_outputs:\n        prompt = o[0]\n        response = o[1]\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_train_dataset[\"prompt\"].append(prompt_with_template)\n        rl_train_dataset[\"data_source\"].append(\"char_count\")\n        rl_train_dataset[\"ability\"].append(\"other\")\n        rl_train_dataset[\"reward_model\"].append(\n            {\"style\": \"rule\", \"ground_truth\": remove_boxed(last_boxed_only_string(response))}\n        )\n        rl_train_dataset[\"extra_info\"].append({\"response\": response})\n\n    for o in test_output:\n        prompt = o[0]\n        response = o[1]\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_test_dataset[\"prompt\"].append(prompt_with_template)\n        rl_test_dataset[\"data_source\"].append(\"char_count\")\n        rl_test_dataset[\"ability\"].append(\"other\")\n        rl_test_dataset[\"reward_model\"].append(\n            {\"style\": \"rule\", \"ground_truth\": remove_boxed(last_boxed_only_string(response))}\n        )\n        rl_test_dataset[\"extra_info\"].append({\"response\": response})\n\n    rl_train_dataset = pd.DataFrame(data=rl_train_dataset)\n    rl_test_dataset = pd.DataFrame(data=rl_test_dataset)\n\n    folder = os.path.join(data_path, \"rl\")\n\n    os.makedirs(folder, exist_ok=True)\n\n    rl_train_dataset.to_parquet(os.path.join(folder, \"train.parquet\"))\n    rl_test_dataset.to_parquet(os.path.join(folder, \"test.parquet\"))\n"
  },
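Since `create_prompt_response` is self-contained, you can sanity-check the generated CoT format before writing the full parquet files. A quick sketch, assuming it is run from `verl_rl/recipe/char_count/` so the module is importable:

```python
import random

from create_dataset import create_prompt_response

random.seed(0)  # make the sample reproducible
prompt, answer = create_prompt_response(min_length=5, max_length=8)
print(prompt)  # e.g. 'How many q are there in word q-x-...?'
print(answer)  # one "char = / != target" line per char, ending in \boxed{count}
```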
  {
    "path": "verl_rl/recipe/char_count/reward_function.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nReward function\n\"\"\"\n\nfrom verl.utils.reward_score import math\n\n\ndef char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None):\n    try:\n        last_boxed_string = math.last_boxed_only_string(solution_str)\n        if last_boxed_string is None:\n            return 0\n        solution = math.remove_boxed(last_boxed_string)\n        if solution == ground_truth:\n            return 1\n        else:\n            return 0\n    except Exception:\n        print(ground_truth, solution_str)\n        return 0\n"
  },
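A quick way to see the 0/1 rule-based reward in action is to call `char_count_reward_function` directly with solution strings that do and do not end in the right `\boxed{...}` answer (assuming the repo root is on `PYTHONPATH` so both `verl` and the recipe import):

```python
from recipe.char_count.reward_function import char_count_reward_function

# Correct final boxed answer -> reward 1
print(char_count_reward_function(
    data_source="char_count",
    solution_str="n = n.\ni != n.\nn = n.\ne != n.\n\\boxed{2} n in n-i-n-e.",
    ground_truth="2",
))  # 1

# No \boxed{...} at all -> reward 0
print(char_count_reward_function("char_count", "I think the answer is 2.", "2"))  # 0
```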
  {
    "path": "verl_rl/recipe/char_count/train_grpo.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/char_count/rl/train.parquet \\\n    data.val_files=$HOME/data/char_count/rl/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=128 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=./models/sft/global_step_105 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"tensorboard\"]' \\\n    trainer.project_name='verl_example' \\\n    trainer.experiment_name='smol135m_grpo' \\\n    trainer.val_before_train=True \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    custom_reward_function.path=recipe/char_count/reward_function.py \\\n    custom_reward_function.name=char_count_reward_function\n"
  },
  {
    "path": "verl_rl/recipe/char_count/train_sft.sh",
    "content": "set -x\n\nnproc_per_node=1\nsave_path=./models/sft\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/char_count/sft/train.parquet \\\n    data.val_files=$HOME/data/char_count/sft/test.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=response \\\n    data.micro_batch_size_per_gpu=8 \\\n    data.max_length=256 \\\n    data.train_batch_size=256 \\\n    use_remove_padding=True \\\n    model.partial_pretrain=HuggingFaceTB/SmolLM2-135M-Instruct \\\n    trainer.default_local_dir=$save_path \\\n    trainer.project_name=char_count-sft \\\n    trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \\\n    trainer.total_epochs=3 \\\n    trainer.logger=console"
  },
  {
    "path": "verl_rl/recipe/dapo/README.md",
    "content": "# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)\n\n> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)\n\n> [!IMPORTANT]\n>\n> **🔥 News!!!**\n>\n> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `recipe/dapo`](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n).\n> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n).\n\n🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)\n\n> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps.\n>\n> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)\n\n## Quickstart\n\n1. Prepare the datasets **on the Ray cluster**:\n\n```bash\nbash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default\n```\n\n2. Submit the job to the Ray cluster **from any machine**:\n\n```bash\ncd verl # Repo root\nexport RAY_ADDRESS=\"http://${RAY_IP:-localhost}:8265\" # The Ray cluster address to connect to\nexport WORKING_DIR=\"${PWD}\" # The local directory to package to the Ray cluster\n# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml\nexport RUNTIME_ENV=\"./recipe/dapo/runtime_env.yaml\" # This sets environment variables for the Ray cluster\nbash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts\n```\n\n## Reproduction Runs\n\n| Setup                                        | AIME 2024 Acc. 
| Hardware  | Image                                                                | Commit                                                                                       | Environment Variables                                                                                                             | Training Script                                                                                                                                             | Training Record                                                                           |\n| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |\n| DAPO                                         | 52%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh)             | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Dynamic Sampling                    | 50%            | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n| DAPO w/o Token-level Loss & Dynamic Sampling | 44%            | 16x8xH20  | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix`                    | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) |\n\n> [!IMPORTANT]\n>\n> **📢 Call for Contribution!**\n>\n> Welcome to submit your reproduction runs and setups!\n\n## Configuration\n\n### Separated Clip Epsilons (-> Clip-Higher)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.28\n```\n\n`clip_ratio_low` and `clip_ratio_high` specify the $\\varepsilon_{\\text {low }}$ and $\\varepsilon_{\\text {high }}$ in the DAPO objective.\n\nCore relevant 
code:\n\n```python\npg_losses1 = -advantages * ratio\npg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\npg_losses = torch.maximum(pg_losses1, pg_losses2)\n```\n\n### Dynamic Sampling (with Group Filtering)\n\nAn example configuration:\n\n```yaml\ndata:\n  gen_batch_size: 1536\n  train_batch_size: 512\nalgorithm:\n  filter_groups:\n    enable: True\n    metric: acc # score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 10 # Non-positive values mean no upper limit\n```\n\nSetting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` values are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0.\n\nThe trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups to fill `train_batch_size`, or until it reaches the upper limit specified by `max_num_gen_batches`.\n\nCore relevant code:\n\n```python\nprompt_bsz = self.config.data.train_batch_size\nif num_prompt_in_batch < prompt_bsz:\n    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')\n    num_gen_batches += 1\n    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')\n        continue\n    else:\n        raise ValueError(\n            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'\n        )\nelse:\n    # Align the batch\n    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n    batch = batch[:traj_bsz]\n```\n\n### Flexible Loss Aggregation Mode (-> Token-level Loss)\n\nAn example configuration:\n\n```yaml\nactor_rollout_ref:\n  actor:\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n```\n\nSetting `loss_agg_mode` to `token-mean` will average the (policy gradient) loss across all the tokens of all the sequences in a mini-batch.\n\nCore relevant code:\n\n```python\nif loss_agg_mode == \"token-mean\":\n    loss = verl_F.masked_mean(loss_mat, loss_mask)\nelif loss_agg_mode == \"seq-mean-token-sum\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n    loss = torch.mean(seq_losses)  # seq-mean\nelif loss_agg_mode == \"seq-mean-token-mean\":\n    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n    loss = torch.mean(seq_losses)  # seq-mean\nelse:\n    raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n```\n\n### Overlong Reward Shaping\n\nAn example configuration:\n\n```yaml\ndata:\n  max_response_length: 20480 # 16384 + 4096\nreward_model:\n  overlong_buffer:\n    enable: True\n    len: 4096\n    penalty_factor: 1.0\n```\n\nSetting `overlong_buffer.enable` to `True` penalizes outputs whose lengths are overlong but still within the hard context limit.\n\nSpecifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` as the output length exceeds the expected length (`max_response_length - overlong_buffer.len`) by `0` to `overlong_buffer.len` tokens, reaching the maximum penalty at the hard limit `max_response_length`.\n\nCore relevant code:\n\n```python\nif self.overlong_buffer_cfg.enable:\n    overlong_buffer_len = self.overlong_buffer_cfg.len\n    expected_len = self.max_resp_len - overlong_buffer_len\n    exceed_len = valid_response_length - expected_len\n    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n    overlong_reward = min(-exceed_len / 
overlong_buffer_len * overlong_penalty_factor, 0)\n    reward += overlong_reward\n```\n\n## FAQ\n\n### Where is the \"Overlong Filtering\" in the paper?\n\nMost experiments in the paper, including the best-performing one, are run without Overlong Filtering, because it largely overlaps with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here.\n\n### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)?\n\n[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features.\n\n[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, and it will be maintained with new features.\n\n### Why can't I reproduce similar results after modifications?\n\nToday's RL infrastructure still has inherent sources of instability, which we are working hard to improve.\n\nWe strongly recommend modifying only one thing at a time.\n\nWe also list some known problems here:\n\n1. Enabling CUDA graphs (`enforce_eager=False`) might cause model performance degradation; the cause is still under investigation.\n"
  },
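To make the overlong shaping concrete: with the example config above (`max_response_length: 20480`, `len: 4096`, `penalty_factor: 1.0`), responses up to 16384 tokens incur no penalty, and the penalty then grows linearly to -1.0 at the 20480-token hard limit. A standalone replay of the quoted snippet (the `overlong_penalty` wrapper is ours; the arithmetic is the README's):

```python
def overlong_penalty(valid_response_length: int,
                     max_resp_len: int = 20480,
                     buffer_len: int = 4096,
                     penalty_factor: float = 1.0) -> float:
    """Linear penalty inside the overlong buffer, as in the README snippet."""
    expected_len = max_resp_len - buffer_len  # 16384: where the buffer starts
    exceed_len = valid_response_length - expected_len
    return min(-exceed_len / buffer_len * penalty_factor, 0.0)

print(overlong_penalty(16000))  # 0.0  (shorter than the expected length)
print(overlong_penalty(18432))  # -0.5 (halfway into the buffer)
print(overlong_penalty(20480))  # -1.0 (at the hard context limit)
```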
  {
    "path": "verl_rl/recipe/dapo/config/dapo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\nreward_model:\n  reward_manager: dapo\n  overlong_buffer: \n    enable: False # We try to avoid forgetting to set enable\n    len: 0\n    penalty_factor: 0.0\n    log: False\n\nalgorithm:\n  filter_groups:\n    _target_: verl.trainer.config.FilterGroupsConfig\n    enable: False # We try to avoid forgetting to set enable\n    metric: null # acc / score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 0 # Non-positive values mean no upper limit\n\ntrainer:\n  project_name: verl-dapo\n"
  },
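The `filter_groups` block above drives the dynamic-sampling loop in `dapo_ray_trainer.py` below: a prompt group is kept only when its `rollout.n` samples do not all score the same, since identical scores yield zero advantage under GRPO. A minimal standalone sketch of that keep/drop rule (`kept_group_uids` is a hypothetical helper; the trainer operates on `DataProto` batches keyed by `uid`):

```python
from collections import defaultdict

import numpy as np

def kept_group_uids(uids, metric_vals):
    """Keep prompt groups whose metric values are not all identical."""
    groups = defaultdict(list)
    for uid, val in zip(uids, metric_vals):
        groups[uid].append(val)
    # std > 0 means the group carries a learning signal; singletons are kept.
    return [uid for uid, vals in groups.items() if np.std(vals) > 0 or len(vals) == 1]

# Two prompts with n=4 rollouts each: "p1" is all-correct (dropped), "p2" is mixed (kept).
uids = ["p1"] * 4 + ["p2"] * 4
accs = [1, 1, 1, 1, 0, 1, 0, 0]
print(kept_group_uids(uids, accs))  # ['p2']
```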
  {
    "path": "verl_rl/recipe/dapo/dapo_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    reduce_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.utils.profiler import marked_timer\n\n\nclass RayDAPOTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n        self.gen_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        self.gen_steps += 1\n        last_val_metrics = None\n\n        timing_raw = defaultdict(float)\n        batch = None\n        num_prompt_in_batch = 0\n        num_gen_batches = 0\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n\n                do_profile = (\n                    self.global_steps in 
self.config.trainer.profile_steps\n                    if self.config.trainer.profile_steps is not None\n                    else False\n                )\n                with marked_timer(\"start_profile\", timing_raw):\n                    if do_profile:\n                        self.actor_rollout_wg.start_profile(role=\"e2e\", profile_step=self.global_steps)\n                        if self.use_reference_policy:\n                            self.ref_policy_wg.start_profile()\n                        if self.use_critic:\n                            self.critic_wg.start_profile()\n                        if self.use_rm:\n                            self.rm_wg.start_profile()\n\n                new_batch: DataProto = DataProto.from_single_dict(batch_dict)\n                num_gen_batches += 1\n                # pop those keys for generation\n                if \"multi_modal_data\" in new_batch.non_tensor_batch.keys():\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\", \"multi_modal_data\"],\n                    )\n                else:\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\"],\n                    )\n                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n                is_last_step = self.gen_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, \"red\"):\n                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with marked_timer(\"gen_max\", timing_raw, \"red\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            new_batch = new_batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(new_batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            new_batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    new_batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    new_batch = new_batch.union(gen_batch_output)\n\n                    with marked_timer(\"reward\", timing_raw, \"yellow\"):\n                    
    # compute scores. Support both model and function-based.\n                        # We first compute the scores using reward model. Then, we call reward_fn to combine\n                        # the results from reward model and rule-based results.\n                        if self.use_rm:\n                            # we first compute reward model score\n                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)\n                            new_batch = new_batch.union(reward_tensor)\n\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        try:\n                            reward_result = self.reward_fn(new_batch, return_dict=True)\n                            reward_tensor = reward_result[\"reward_tensor\"]\n                            reward_extra_infos_dict = reward_result.get(\"reward_extra_info\", {})\n                        except Exception as e:\n                            print(f\"Error in reward_fn: {e}\")\n                            reward_tensor = self.reward_fn(new_batch)\n                            reward_extra_infos_dict = {}\n\n                        new_batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        if reward_extra_infos_dict:\n                            new_batch.non_tensor_batch.update(\n                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                            )\n\n                        # compute rewards. apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            new_batch, kl_metrics = apply_kl_penalty(\n                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(\n                                kl_metrics\n                            )  # TODO: This will be cleared if we use multiple genenration batches\n                        else:\n                            new_batch.batch[\"token_level_rewards\"] = new_batch.batch[\"token_level_scores\"]\n\n                    if not self.config.algorithm.filter_groups.enable:\n                        batch = new_batch\n                    else:  # NOTE: When prompts after filtering is less than train batch size,\n                        # we skip to the next generation batch\n                        metric_name = self.config.algorithm.filter_groups.metric\n                        if metric_name == \"seq_final_reward\":\n                            # Turn to numpy for easier filtering\n                            new_batch.non_tensor_batch[\"seq_final_reward\"] = (\n                                new_batch.batch[\"token_level_rewards\"].sum(dim=-1).numpy()\n                            )\n                        elif metric_name == \"seq_reward\":\n                            new_batch.non_tensor_batch[\"seq_reward\"] = (\n                                new_batch.batch[\"token_level_scores\"].sum(dim=-1).numpy()\n                            )\n\n                        # Collect the sequence reward for each trajectory\n                        prompt_uid2metric_vals = defaultdict(list)\n                        for uid, metric_val in zip(\n                            new_batch.non_tensor_batch[\"uid\"], new_batch.non_tensor_batch[metric_name], strict=True\n                        ):\n                            
prompt_uid2metric_vals[uid].append(metric_val)\n\n                        prompt_uid2metric_std = {}\n                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():\n                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)\n\n                        kept_prompt_uids = [\n                            uid\n                            for uid, std in prompt_uid2metric_std.items()\n                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1\n                        ]\n                        num_prompt_in_batch += len(kept_prompt_uids)\n\n                        kept_traj_idxs = []\n                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch[\"uid\"]):\n                            if traj_from_prompt_uid in kept_prompt_uids:\n                                kept_traj_idxs.append(idx)\n\n                        new_batch = new_batch[kept_traj_idxs]\n                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])\n\n                        prompt_bsz = self.config.data.train_batch_size\n                        if num_prompt_in_batch < prompt_bsz:\n                            print(f\"{num_prompt_in_batch=} < {prompt_bsz=}\")\n                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n                                print(f\"{num_gen_batches=}. Keep generating...\")\n                                progress_bar.update(1)\n                                self.gen_steps += 1\n                                continue\n                            else:\n                                raise ValueError(\n                                    f\"{num_gen_batches=} >= {max_num_gen_batches=}.\"\n                                    + \" Generated too many. 
Please check if your data are too difficult.\"\n                                    + \" You could also try set max_num_gen_batches=0 to enable endless trials.\"\n                                )\n                        else:\n                            # Align the batch\n                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n                            batch = batch[:traj_bsz]\n\n                    # === Updating ===\n\n                    batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    # recompute old_log_probs\n                    with marked_timer(\"old_log_prob\", timing_raw, \"blue\"):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = batch.batch[\"response_mask\"]\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with marked_timer(\"ref\", timing_raw, \"olive\"):\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, \"cyan\"):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with marked_timer(\"adv\", timing_raw, \"brown\"):\n                        # compute advantages, executed on the driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        )\n\n                    # update 
critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, \"pink\"):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, \"red\"):\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                    ):\n                        with marked_timer(\"testing\", timing_raw, \"green\"):\n                            val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                        metrics.update(val_metrics)\n\n                    if self.config.trainer.save_freq > 0 and (\n                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                    ):\n                        with marked_timer(\"save_checkpoint\", timing_raw, \"green\"):\n                            self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    if do_profile:\n                        self.actor_rollout_wg.stop_profile()\n                        if self.use_reference_policy:\n                            self.ref_policy_wg.stop_profile()\n                        if self.use_critic:\n                            self.critic_wg.stop_profile()\n                        if self.use_rm:\n                            self.rm_wg.stop_profile()\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflpo and theoretical tflpo\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                timing_raw = defaultdict(float)  # clear timing\n\n                metrics[\"train/num_gen_batches\"] = num_gen_batches\n                batch = None\n                num_prompt_in_batch = 0\n                num_gen_batches = 0\n\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n                self.gen_steps += 1\n"
  },
  {
    "path": "verl_rl/recipe/dapo/main_dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.utils.device import is_cuda_available\n\nfrom .dapo_ray_trainer import RayDAPOTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"dapo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\n                \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n            },\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    if (\n        is_cuda_available\n        and OmegaConf.select(config.trainer, \"profile_steps\") is not None\n        and len(OmegaConf.select(config.trainer, \"profile_steps\")) > 0\n    ):\n        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)\n        runner = TaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        tokenizer = hf_tokenizer(local_path)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n    
\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            Role.Critic: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based RM, we directly compute a reward score\n        # - for model-based RM, we call a model\n        # - for code-related prompts, we send them to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - the reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # the third positional argument (num_examine) controls how many decoded samples\n        # are printed for inspection: 0 during training, 1 during validation\n        reward_fn = load_reward_manager(\n            config,\n            tokenizer,\n            0,\n            max_resp_len=config.data.max_response_length,\n            overlong_buffer_cfg=config.reward_model.overlong_buffer,\n        )\n\n        # Note that we always use function-based RM for validation\n        val_reward_fn = load_reward_manager(\n            config,\n            tokenizer,\n            1,\n            max_resp_len=config.data.max_response_length,\n            overlong_buffer_cfg=config.reward_model.overlong_buffer,\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RayDAPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n        )\n        trainer.init_workers()\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/dapo/prepare_dapo_data.sh",
    "content": "#!/usr/bin/env bash\nset -uxo pipefail\n\nexport VERL_HOME=${VERL_HOME:-\"${HOME}/verl\"}\nexport TRAIN_FILE=${TRAIN_FILE:-\"${VERL_HOME}/data/dapo-math-17k.parquet\"}\nexport TEST_FILE=${TEST_FILE:-\"${VERL_HOME}/data/aime-2024.parquet\"}\nexport OVERWRITE=${OVERWRITE:-0}\n\nmkdir -p \"${VERL_HOME}/data\"\n\nif [ ! -f \"${TRAIN_FILE}\" ] || [ \"${OVERWRITE}\" -eq 1 ]; then\n  wget -O \"${TRAIN_FILE}\" \"https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true\"\nfi\n\nif [ ! -f \"${TEST_FILE}\" ] || [ \"${OVERWRITE}\" -eq 1 ]; then\n  wget -O \"${TEST_FILE}\" \"https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true\"\nfi\n"
  },
  {
    "path": "verl_rl/recipe/dapo/run_dapo_early_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Early-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\n# An early version for DAPO\nloss_agg_mode=\"seq-mean-token-mean\"\n\nenable_filter_groups=False\ngen_prompt_bsz=512 # NOTE: no filtering here\ntrain_prompt_bsz=512\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=16\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_rl/recipe/dapo/run_dapo_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 
\\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_rl/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh",
    "content": "#!/usr/bin/env bash\nset -euxo pipefail\n# DAPO (w/o Dynamic Sampling)\n\nproject_name='DAPO-verl'\nexp_name='DAPO-wo-DS-Qwen2.5-32B'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 20))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=False\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-16}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-32B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=8\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$((max_prompt_length + max_response_length))\ninfer_ppo_max_token_len=$((max_prompt_length + max_response_length))\noffload=True\ngen_tp=4\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto\n"
  },
  {
    "path": "verl_rl/recipe/dapo/runtime_env.yaml",
    "content": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  VLLM_USE_V1: \"1\"\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_7b.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7B-Math-Test'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 2))\nenable_overlong_buffer=True\noverlong_buffer_len=512\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=512\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=16\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nray job submit --no-wait --runtime-env=\"${RUNTIME_ENV}\" \\\n    --working-dir \"${WORKING_DIR}\" \\\n    -- python3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=2 \\\n    trainer.save_freq=2 \\\n    trainer.total_epochs=1 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_7b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_7b_math_lora.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\n# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.model.lora_rank=8 \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=200 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_7b_math_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-megatron-0519a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\ntrain_tp=4\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=16 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_dspk_671b_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# 0. download the config\n# only need to download the configuration_deepseek.py and config.json\n# remove the `quantization_config` in the `config.json`\n# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported\nhuggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json\n\nproject_name='DAPO'\nexp_name='DAPO-DeepSeek-671b-megatron'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 4))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=0.1\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512 # must be > n_gpus. need to fix\nn_resp_per_prompt=2\ntrain_prompt_mini_bsz=16  # mini_bsz * n >= micro_bsz * pp * dp\n\nNNODES=${NNODES:-64}\n\n# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main\n# change the MODEL_PATH and MCORE_MODEL_PATH to your own path\n# Paths\nMODEL_PATH=\"<path_to_dsv3_config>\"\nMCORE_MODEL_PATH=\"<path_to_dpsk-v3-671B-BF16-dist_ckpt>\"\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\naime24_test_path=${RAY_DATA_HOME}/data/aime-2024.parquet\n# TEST_FILE=\"['$math500_test_path', '$aime24_test_path']\"\n\nTEST_FILE=\"['$aime24_test_path']\"\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=32\ntrain_tp=1\ntrain_ep=32\ntrain_pp=16\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    
actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \\\n    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=5 \\\n    trainer.save_freq=5 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=10 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_qwen3_30b_math.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-8}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=32\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n   
 actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=10 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/dapo/test_dapo_qwen3_30b_math_single_node.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0719a1'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 4))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=0.1\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=64\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=16\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nsp_size=4\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=4\nfsdp_size=8\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=300 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/entropy/32b_clip_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='clipcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=1\nclip_ratio_high=1\nclip_cov_ratio=0.0002\nclip_cov_lb=1.0\nclip_cov_ub=5.0\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"clip_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.02\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.warmup_style=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/entropy/32b_kl_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.0002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.warmup_style=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/entropy/32b_kl_cov_mininbsz.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-32B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=16\nn_resp_per_prompt=8\nmax_token=20480\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.0002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    
actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.warmup_style=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/entropy/7b_clip_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-7B'\nexp_name='clipcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=1\nclip_ratio_high=1\nclip_cov_ratio=0.0002\nclip_cov_lb=1.0\nclip_cov_ub=5.0\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"clip_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=30720\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.2\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \\\n    actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.warmup_style=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/entropy/7b_kl_cov.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport WANDB_API_KEY=YOUR_WANDB_API_KEY\n# export VLLM_USE_V1=1\n\nproject_name='Qwen2.5-7B'\nexp_name='klcov'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.2\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=False\noverlong_buffer_len=$((1024 * 2))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\nloss_mode=\"kl_cov\"\nenable_filter_groups=True\nfilter_groups_metric=acc\nmax_num_gen_batches=10\ntrain_prompt_bsz=256\ngen_prompt_bsz=$((train_prompt_bsz * 3))\ntrain_prompt_mini_bsz=32\nn_resp_per_prompt=8\nmax_token=30720\n\n# Ray\nRAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\nWORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\nRUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-4}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"/YOUR_MODELPATH\"}\nCKPTS_DIR=${CKPTS_DIR:-\"/YOUR_CKPTS_PATH\"}\nTRAIN_FILE=${TRAIN_FILE:-\"/YOUR_TRAIN_FILE_PATH\"}\nTEST_FILE=${TEST_FILE:-[\"/YOUR_TRAIN_FILE_PATH\"]}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nppo_kl_coef=1\nkl_cov_ratio=0.002\n\n# Mathematically equivalent\nuse_dynamic_bsz=True\ninfer_micro_batch_size=null\ntrain_micro_batch_size=null\noffload=False\n\nHYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.filter_overlong_prompts=False \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \\\n    actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \\\n    actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.mode=sync \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    
actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.weight_decay=0 \\\n    actor_rollout_ref.actor.optim.warmup_style=constant \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=\"${top_k}\" \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=False \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \\\n    reward_model.reward_manager=dapo \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=False \\\n    trainer.test_freq=4 \\\n    trainer.save_freq=32 \\\n    trainer.total_epochs=1000 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=disable\n"
  },
  {
    "path": "verl_rl/recipe/entropy/README.md",
    "content": "<div align=\"center\">\n\n# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning.\n\n[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617)  [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue\n)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861)\n\n\n<div align=\"center\" style=\"font-family: Arial, sans-serif;\">\n  <p>\n    <a href=\"#🎉news\" style=\"text-decoration: none; font-weight: bold;\">🎉 News</a> •\n    <a href=\"#✨getting-started\" style=\"text-decoration: none; font-weight: bold;\">✨ Getting Started</a> •\n    <a href=\"#📖introduction\" style=\"text-decoration: none; font-weight: bold;\">📖 Introduction</a>\n  </p>\n  <p>\n    <a href=\"#🎈citation\" style=\"text-decoration: none; font-weight: bold;\">🎈 Citation</a> •\n    <a href=\"#🌻acknowledgement\" style=\"text-decoration: none; font-weight: bold;\">🌻 Acknowledgement</a> •\n    <a href=\"#📬Contact\" style=\"text-decoration: none; font-weight: bold;\">📬 Contact</a> •\n    <a href=\"#📈star-history\" style=\"text-decoration: none; font-weight: bold;\">📈 Star History</a>\n  </p>\n</div>\n\n</div>\n\n\n# 🎉News\n\n- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29).\n- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. \n\n\n\n# ✨Getting started\n\nAfter preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/7b_kl_cov.sh\n```\n\nWhile for training Qwen2.5-32B on multi nodes, you can run the following commands:\n\n```\ncd verl\nconda activate your_env\nbash recipe/dapo/32b_kl_cov.sh\n```\n\n# 📖Introduction\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/e2a.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nThis paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. 
\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/cov.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\nTheoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. \n\n# 📃Evaluation\n\n<div align=\"left\">\n  <img src=\"https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/figures/performance_fig.jpg?raw=true\" alt=\"issue\" style=\"width: 96%; height: auto;\">\n</div>\n\n\nOur method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. \n| **Method**        | **AIME24** | **AIME25** |  **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** |\n| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: |\n| *Qwen2.5-7B*      |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.2 |        9.6 |     58.7 |         78.8 |          27.9 |              40.7 |        36.7 |     38.6 |\n| w. Clip-higher    |       18.1 |       11.5 |     56.6 |         79.2 |          29.8 |              43.3 |        40.4 |     38.8 |\n| w. **`CLIP-Cov`** |       22.1 |   **15.8** |     58.2 |         80.4 |      **30.5** |          **44.1** |    **41.1** |     40.4 |\n| w. **`KL-Cov`**   |   **22.6** |       12.9 | **61.4** |     **80.8** |          29.1 |              42.6 |        38.2 | **40.6** |\n| *Qwen2.5-32B*     |            |            |          |              |               |                   |             |          |\n| GRPO              |       21.8 |       16.2 |     69.7 |         84.2 |          35.2 |              43.6 |        45.5 |     45.8 |\n| w. Clip-higher    |       35.6 |       22.3 |     69.5 |         77.2 |          35.1 |              42.5 |        43.0 |     47.2 |\n| w. **`CLIP-Cov`** |       32.3 |       22.7 |     67.2 |     **87.0** |      **42.0** |          **57.2** |        46.0 |     50.3 |\n| w. **`KL-Cov`**   |   **36.8** |   **30.8** | **74.5** |         84.6 |          39.1 |              49.0 |    **46.3** | **52.2** |\n\nOur two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. 
\n# 🎈Citation\nIf you find this paper or repo helpful, please cite us.\n\n```bibtex\n@article{cui2025entropy,\n  title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models},\n  author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others},\n  journal={arXiv preprint arXiv:2505.22617},\n  year={2025}\n}\n```\n\n# 🌻Acknowledgement\nWe implement our reinforcement learning algorithms by extending [verl](https://github.com/volcengine/verl). We use [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are built primarily on the [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions!\n\n# 📬 Contact\n\nFor questions, discussion, or collaboration opportunities, feel free to contact:\n- Ganqu Cui: cuiganqu@pjlab.org.cn\n- Yuchen Zhang: yuchen.zhang2003@gmail.com\n- Jiacheng Chen: jackchan9345@gmail.com\n- Ning Ding: ningding.cs@gmail.com\n\n"
  },
  {
    "path": "verl_rl/recipe/entropy/config/entropy_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  gen_batch_size: ${data.train_batch_size}\n\nreward_model:\n  reward_kwargs:\n        overlong_buffer_cfg: ${reward_model.overlong_buffer}\n  reward_manager: dapo\n  overlong_buffer: \n    enable: False \n    len: 0\n    penalty_factor: 0.0\n    log: False\n\nalgorithm:\n  filter_groups:\n    enable: False # We try to avoid forgetting to set enable\n    metric: null # acc / score / seq_reward / seq_final_reward / ...\n    max_num_gen_batches: 0 # Non-positive values mean no upper limit\n\ntrainer:\n  project_name: verl-entropy\n\nactor_rollout_ref:\n  actor:\n    policy_loss:\n      loss_mode: \"vanilla\" # /clip-cov / kl-cov from https://arxiv.org/abs/2505.\n      clip_cov_ratio: 0.0002 # for clip-cov loss\n      clip_cov_lb: 1.0 # for clip-cov loss\n      clip_cov_ub: 5.0 # for clip-cov loss\n      kl_cov_ratio: 0.0002 # for kl-cov loss\n      ppo_kl_coef: 0.1 # for kl-cov loss"
  },
  {
    "path": "verl_rl/recipe/entropy/entropy_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    reduce_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.utils.profiler import simple_timer\n\n\nclass RayEntropyTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only need to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        timing_raw = defaultdict(float)\n        batch = None\n        num_prompt_in_batch = 0\n        num_gen_batches = 0\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n\n                new_batch: DataProto = DataProto.from_single_dict(batch_dict)\n                num_gen_batches += 1\n                # pop those keys for generation\n                if 
\"multi_modal_inputs\" in new_batch.non_tensor_batch.keys():\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\", \"multi_modal_data\", \"multi_modal_inputs\"],\n                    )\n                else:\n                    gen_batch = new_batch.pop(\n                        batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n                        non_tensor_batch_keys=[\"raw_prompt_ids\"],\n                    )\n                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    # with simple_timer(\"gen\", timing_raw):\n                    #     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                    with simple_timer(\"gen\", timing_raw):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            new_batch = new_batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(new_batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            new_batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    new_batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    new_batch = new_batch.union(gen_batch_output)\n\n                    with simple_timer(\"reward\", timing_raw):\n                        # compute scores. Support both model and function-based.\n                        # We first compute the scores using reward model. 
Then, we call reward_fn to combine\n                        # the results from reward model and rule-based results.\n                        if self.use_rm:\n                            # we first compute reward model score\n                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)\n                            new_batch = new_batch.union(reward_tensor)\n\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        try:\n                            reward_result = self.reward_fn(new_batch, return_dict=True)\n                            reward_tensor = reward_result[\"reward_tensor\"]\n                            reward_extra_infos_dict = reward_result[\"reward_extra_info\"]\n                        except Exception as e:\n                            print(f\"Error in reward_fn: {e}\")\n                            reward_tensor = self.reward_fn(new_batch)\n                            reward_extra_infos_dict = {}\n\n                        new_batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        print(f\"{list(reward_extra_infos_dict.keys())=}\")\n                        if reward_extra_infos_dict:\n                            new_batch.non_tensor_batch.update(\n                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                            )\n\n                        # compute rewards. apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            new_batch, kl_metrics = apply_kl_penalty(\n                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(\n                                kl_metrics\n                            )  # TODO: This will be cleared if we use multiple generation batches\n                        else:\n                            new_batch.batch[\"token_level_rewards\"] = new_batch.batch[\"token_level_scores\"]\n\n                    if not self.config.algorithm.filter_groups.enable:\n                        batch = new_batch\n                    else:  # NOTE: When the number of prompts left after filtering is less than the\n                        # train batch size, we skip to the next generation batch\n                        metric_name = self.config.algorithm.filter_groups.metric\n                        if metric_name == \"seq_final_reward\":\n                            # Convert to numpy for easier filtering\n                            new_batch.non_tensor_batch[\"seq_final_reward\"] = (\n                                new_batch.batch[\"token_level_rewards\"].sum(dim=-1).numpy()\n                            )\n                        elif metric_name == \"seq_reward\":\n                            new_batch.non_tensor_batch[\"seq_reward\"] = (\n                                new_batch.batch[\"token_level_scores\"].sum(dim=-1).numpy()\n                            )\n\n                        # Collect the sequence reward for each trajectory\n                        prompt_uid2metric_vals = defaultdict(list)\n                        for uid, metric_val in zip(\n                            new_batch.non_tensor_batch[\"uid\"], new_batch.non_tensor_batch[metric_name], strict=True\n                        ):\n                            prompt_uid2metric_vals[uid].append(metric_val)\n\n
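                        # DAPO-style dynamic sampling: if every sampled response for a prompt gets the\n                        # same reward (std == 0, i.e. all correct or all wrong), its group-relative\n                        # advantages are identically zero and provide no gradient, so the prompt is\n                        # dropped and more prompts are generated to refill the training batch.\n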
prompt_uid2metric_std = {}\n                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():\n                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)\n\n                        kept_prompt_uids = [\n                            uid\n                            for uid, std in prompt_uid2metric_std.items()\n                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1\n                        ]\n                        num_prompt_in_batch += len(kept_prompt_uids)\n\n                        kept_traj_idxs = []\n                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch[\"uid\"]):\n                            if traj_from_prompt_uid in kept_prompt_uids:\n                                kept_traj_idxs.append(idx)\n\n                        new_batch = new_batch[kept_traj_idxs]\n                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])\n\n                        prompt_bsz = self.config.data.train_batch_size\n                        if num_prompt_in_batch < prompt_bsz:\n                            print(f\"{num_prompt_in_batch=} < {prompt_bsz=}\")\n                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches\n                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:\n                                print(f\"{num_gen_batches=}. Keep generating...\")\n                                continue\n                            else:\n                                raise ValueError(\n                                    f\"{num_gen_batches=} >= {max_num_gen_batches=}.\"\n                                    + \" Generated too many batches. Please check whether your data are too difficult.\"\n                                    + \" You could also try setting max_num_gen_batches=0 to enable endless trials.\"\n                                )\n                        else:\n                            # Align the batch\n                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n\n                            print(\n                                f\"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompts. 
\"\n                                f\"Collecting finished.\"\n                            )\n                            batch = batch[:traj_bsz]\n\n                    # === Updating ===\n\n                    batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n                    # balance the number of valid tokens on each dp rank.\n                    # Note that this breaks the order of data inside the batch.\n                    # Please take care when you implement group based adv computation such as GRPO and rloo\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    # recompute old_log_probs\n                    with simple_timer(\"old_log_prob\", timing_raw):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        batch = batch.union(old_log_prob)\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with simple_timer(\"ref\", timing_raw):\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    # compute values\n                    if self.use_critic:\n                        with simple_timer(\"values\", timing_raw):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with simple_timer(\"adv\", timing_raw):\n                        # compute advantages, executed on the driver process\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\"norm_adv_by_std_in_grpo\", True)\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        )\n\n                    # update critic\n                    if self.use_critic:\n                        with simple_timer(\"update_critic\", timing_raw):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with simple_timer(\"update_actor\", timing_raw):\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and 
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                    ):\n                        with simple_timer(\"testing\", timing_raw):\n                            val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                        metrics.update(val_metrics)\n\n                    if self.config.trainer.save_freq > 0 and (\n                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                    ):\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual tflops and theoretical tflops\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                timing_raw = defaultdict(float)  # clear timing\n\n                metrics[\"train/num_gen_batches\"] = num_gen_batches\n                batch = None\n                num_prompt_in_batch = 0\n                num_gen_batches = 0\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n"
  },
  {
    "path": "verl_rl/recipe/entropy/main_entropy.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\n\nfrom .entropy_ray_trainer import RayEntropyTrainer\nfrom .reward import load_reward_manager\n\n\n@hydra.main(config_path=\"config\", config_name=\"entropy_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\n                \"env_vars\": {\n                    \"TOKENIZERS_PARALLELISM\": \"true\",\n                    \"NCCL_DEBUG\": \"WARN\",\n                    \"VLLM_LOGGING_LEVEL\": \"WARN\",\n                    \"WANDB_API_KEY\": \"YOUR_WANDB_API_KEY\",\n                }\n            },\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\ndef merge_dict(a: dict, b: dict) -> dict:\n    \"\"\"Return a new dict that has `a` updated with `b` (b wins on conflicts).\n\n    Example::\n\n      >>> d1 = {\"x\": 1, \"y\": 2}\n      >>> d2 = {\"y\": 20, \"z\": 3}\n      >>> new_dict = merge_dict(d1, d2)\n      >>> print(new_dict)   # {'x': 1, 'y': 20, 'z': 3}\n      >>> print(d1)         # {\"x\": 1, \"y\": 2} (unchanged)\n      >>> print(d2)         # {\"y\": 20, \"z\": 3} (unchanged)\n    \"\"\"\n    return a | b\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n        print(f\"{config.actor_rollout_ref.model.path}\")\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n               
 else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = ActorRolloutRefWorker\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            Role.Critic: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        reward_kwargs = {\n            \"max_resp_len\": config.data.max_response_length,\n            \"overlong_buffer_cfg\": config.reward_model.overlong_buffer,\n        }\n        cfg_reward_kwargs = config.reward_model.get(\"reward_kwargs\", {})\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs)\n        )\n        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n        trainer = RayEntropyTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            
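# Wiring: the worker classes, resource pools, reward fns, and datasets built\n            # above are handed to the entropy trainer here.\n            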
ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        trainer.init_workers()\n        trainer.fit()\n\n\ndef create_rl_dataset(data_paths, data_config, tokenizer, processor):\n    \"\"\"Create a dataset.\n\n    Arguments:\n        data_config: The data config.\n        tokenizer (Tokenizer): The tokenizer.\n        processor (Processor): The processor.\n\n    Returns:\n        dataset (Dataset): The dataset.\n    \"\"\"\n    from torch.utils.data import Dataset\n\n    from verl.utils.dataset.rl_dataset import RLHFDataset\n\n    if \"custom_cls\" in data_config and data_config.custom_cls.get(\"path\", None) is not None:\n        from verl.utils.import_utils import load_extern_type\n\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n        if not issubclass(dataset_cls, Dataset):\n            raise TypeError(\n                f\"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' \"\n                f\"must inherit from torch.utils.data.Dataset\"\n            )\n    else:\n        dataset_cls = RLHFDataset\n    print(f\"Using dataset class: {dataset_cls.__name__}\")\n\n    dataset = dataset_cls(\n        data_files=data_paths,\n        tokenizer=tokenizer,\n        processor=processor,\n        config=data_config,\n    )\n\n    return dataset\n\n\ndef create_rl_sampler(data_config, dataset):\n    \"\"\"Create a sampler for the dataset.\n\n    Arguments:\n        data_config: The data config.\n        dataset (Dataset): The dataset.\n\n    Returns:\n        sampler (Sampler): The sampler.\n    \"\"\"\n    import torch\n    from torch.utils.data import RandomSampler, SequentialSampler\n\n    # use sampler for better ckpt resume\n    if data_config.shuffle:\n        train_dataloader_generator = torch.Generator()\n        train_dataloader_generator.manual_seed(data_config.get(\"seed\", 1))\n        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)\n    else:\n        sampler = SequentialSampler(data_source=dataset)\n\n    return sampler\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/entropy/reward.py",
    "content": "# Copyright 2025 Individual Contributor: Thibaut Barroyer\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport multiprocessing\nfrom functools import partial\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn\n\nfrom .reward_score import _default_compute_score\n\n\ndef load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):\n    \"\"\"\n    Load and initialize a reward manager based on the configuration.\n\n    Args:\n        config: PPO trainer configuration object containing reward_model fields.\n        tokenizer: Tokenizer object used for processing text.\n        num_examine: Number of samples to examine.\n        **reward_kwargs: Additional keyword arguments for the reward manager.\n\n    Returns:\n        An instance of the specified reward manager class.\n    \"\"\"\n    from verl.workers.reward_manager import get_reward_manager_cls\n\n    # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:\n    # naive: NaiveRewardManager\n    # prime: PrimeRewardManager\n    # batch: BatchRewardManager\n    # dapo: DAPORewardManager\n    # Note(haibin.lin): For custom reward managers, please make sure they are imported and\n    # registered via `verl.workers.reward_manager.register`\n    # By default reward_manager is set to naive (NaiveRewardManager)\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n    # Try to get a custom reward function based on the configuration\n    compute_score = get_custom_reward_fn(config)\n    final_compute_score = compute_score\n\n    if compute_score is None:\n        sandbox_config = config.reward_model.get(\"sandbox_fusion\")\n        sandbox_url = sandbox_config.get(\"url\") if sandbox_config else None\n        if sandbox_url:\n            sandbox_manager = multiprocessing.Manager()\n            # Create a semaphore to control concurrent access to the sandbox\n            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get(\"max_concurrent\", 64))\n            final_compute_score = partial(\n                _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore\n            )\n        else:\n            final_compute_score = _default_compute_score\n\n    # Instantiate and return the reward manager with the specified parameters\n    return reward_manager_cls(\n        tokenizer=tokenizer,\n        num_examine=num_examine,\n        compute_score=final_compute_score,\n        reward_fn_key=config.data.reward_fn_key,\n        **reward_kwargs,\n    )\n\n\n@ray.remote(num_cpus=1)\ndef compute_reward_async(data: DataProto, config, tokenizer):\n    \"\"\"\n    Load the reward manager and compute the reward for a batch of data.\n    This is meant to be run in a separate Ray worker.\n    \"\"\"\n    reward_fn = load_reward_manager(config, tokenizer, num_examine=0, 
**config.reward_model.get(\"reward_kwargs\", {}))\n    return compute_reward(data, reward_fn)\n"
  },
  {
    "path": "verl_rl/recipe/entropy/reward_score/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# from . import gsm8k, math, prime_math, prime_code\n\nimport traceback\n\nfrom . import entropy_math\n\n\ndef _default_compute_score(\n    data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None\n):\n    try:\n        res = entropy_math.compute_score(solution_str, str(ground_truth))\n        # print(f\"data_source: {data_source}\")\n        # raise NotImplementedError(f\"Reward function is not implemented for {data_source=}\")\n\n        if isinstance(res, dict):\n            return res\n        elif isinstance(res, int | float | bool):\n            return float(res)\n        else:\n            return float(res[0])\n    except Exception as e:\n        print(f\"[ERROR] Error in process_completion for task : {str(e)}\")\n        traceback.print_exc()  # 打印完整堆栈\n        raise  # 重新抛出异常以便上层捕获\n"
  },
  {
    "path": "verl_rl/recipe/entropy/reward_score/entropy_math/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except Exception in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provides a math answer grading function with high recall.\nBased on HF math_verify, verl, open reasoner zero, etc.\n\"\"\"\n\nimport os\nimport re\nimport signal\nfrom itertools import islice, zip_longest\nfrom math import isclose\nfrom typing import Optional\n\nimport sympy\nfrom latex2sympy2_extended import latex2sympy\nfrom math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify\nfrom pylatexenc import latex2text\nfrom sympy import N, simplify\nfrom sympy.parsing import sympy_parser\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n\"\"\"\nThis code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py).\n\"\"\"\n\n\ndef timeout_ours(timeout_seconds: int = 8):\n    if os.name == \"posix\":\n        import signal\n\n        def decorator(func):\n            def handler(signum, frame):\n                raise TimeoutError(\"Operation timed out!\")\n\n            def wrapper(*args, **kwargs):\n                old_handler = signal.getsignal(signal.SIGALRM)\n                signal.signal(signal.SIGALRM, handler)\n                signal.alarm(timeout_seconds)\n\n                try:\n                    return func(*args, **kwargs)\n                finally:\n                    signal.alarm(0)\n                    signal.signal(signal.SIGALRM, old_handler)\n\n            return wrapper\n\n        return decorator\n    else:\n        raise NotImplementedError(f\"Unsupported OS: {os.name}\")\n\n\n# Dan Hendrycks' code\ndef mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:\n        return answer\n\n\n# units mainly from MathQA\nunit_texts = [\n    \"east\",\n    \"degree\",\n    \"mph\",\n    \"kmph\",\n    \"ft\",\n    \"m sqaure\",\n    \" m east\",\n    \"sq m\",\n    \"deg\",\n    \"mile\",\n    \"q .\",\n    \"monkey\",\n    \"prime\",\n    \"ratio\",\n    \"profit of rs\",\n    \"rd\",\n    \"o\",\n    \"gm\",\n    \"p . m\",\n    \"lb\",\n    \"tile\",\n    \"per\",\n    \"dm\",\n    \"lt\",\n    \"gain\",\n    \"ab\",\n    \"way\",\n    \"west\",\n    \"a .\",\n    \"b .\",\n    \"c .\",\n    \"d .\",\n    \"e .\",\n    \"f .\",\n    \"g .\",\n    \"h .\",\n    \"t\",\n    \"a\",\n    \"h\",\n    \"no change\",\n    \"men\",\n    \"soldier\",\n    \"pie\",\n    \"bc\",\n    \"excess\",\n    \"st\",\n    \"inches\",\n    \"noon\",\n    \"percent\",\n    \"by\",\n    \"gal\",\n    \"kmh\",\n    \"c\",\n    \"acre\",\n    \"rise\",\n    \"a . 
m\",\n    \"th\",\n    \"π r 2\",\n    \"sq\",\n    \"mark\",\n    \"l\",\n    \"toy\",\n    \"coin\",\n    \"sq . m\",\n    \"gallon\",\n    \"° f\",\n    \"profit\",\n    \"minw\",\n    \"yr\",\n    \"women\",\n    \"feet\",\n    \"am\",\n    \"pm\",\n    \"hr\",\n    \"cu cm\",\n    \"square\",\n    \"v â € ™\",\n    \"are\",\n    \"rupee\",\n    \"rounds\",\n    \"cubic\",\n    \"cc\",\n    \"mtr\",\n    \"s\",\n    \"ohm\",\n    \"number\",\n    \"kmph\",\n    \"day\",\n    \"hour\",\n    \"minute\",\n    \"min\",\n    \"second\",\n    \"man\",\n    \"woman\",\n    \"sec\",\n    \"cube\",\n    \"mt\",\n    \"sq inch\",\n    \"mp\",\n    \"∏ cm ³\",\n    \"hectare\",\n    \"more\",\n    \"sec\",\n    \"unit\",\n    \"cu . m\",\n    \"cm 2\",\n    \"rs .\",\n    \"rs\",\n    \"kg\",\n    \"g\",\n    \"month\",\n    \"km\",\n    \"m\",\n    \"cm\",\n    \"mm\",\n    \"apple\",\n    \"liter\",\n    \"loss\",\n    \"yard\",\n    \"pure\",\n    \"year\",\n    \"increase\",\n    \"decrease\",\n    \"d\",\n    \"less\",\n    \"Surface\",\n    \"litre\",\n    \"pi sq m\",\n    \"s .\",\n    \"metre\",\n    \"meter\",\n    \"inch\",\n]\n\nunit_texts.extend([t + \"s\" for t in unit_texts])\n\n\ndef _strip_string(string):\n    def _fix_fracs(string):\n        substrs = string.split(\"\\\\frac\")\n        new_str = substrs[0]\n        if len(substrs) > 1:\n            substrs = substrs[1:]\n            for substr in substrs:\n                new_str += \"\\\\frac\"\n                if substr[0] == \"{\":\n                    new_str += substr\n                else:\n                    try:\n                        assert len(substr) >= 2\n                    except Exception:\n                        return string\n                    a = substr[0]\n                    b = substr[1]\n                    if b != \"{\":\n                        if len(substr) > 2:\n                            post_substr = substr[2:]\n                            new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                        else:\n                            new_str += \"{\" + a + \"}{\" + b + \"}\"\n                    else:\n                        if len(substr) > 2:\n                            post_substr = substr[2:]\n                            new_str += \"{\" + a + \"}\" + b + post_substr\n                        else:\n                            new_str += \"{\" + a + \"}\" + b\n        string = new_str\n        return string\n\n    def _fix_a_slash_b(string):\n        if len(string.split(\"/\")) != 2:\n            return string\n        a = string.split(\"/\")[0]\n        b = string.split(\"/\")[1]\n        try:\n            a = int(a)\n            b = int(b)\n            assert string == \"{}/{}\".format(a, b)\n            new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n            return new_string\n        except Exception:\n            return string\n\n    def _remove_right_units(string):\n        # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n        if \"\\\\text{ \" in string:\n            splits = string.split(\"\\\\text{ \")\n            assert len(splits) == 2\n            return splits[0]\n        else:\n            return string\n\n    def _fix_sqrt(string):\n        if \"\\\\sqrt\" not in string:\n            return string\n        splits = string.split(\"\\\\sqrt\")\n        new_string = splits[0]\n        for split in splits[1:]:\n            if split[0] != \"{\":\n                a = split[0]\n                new_substr 
= \"\\\\sqrt{\" + a + \"}\" + split[1:]\n            else:\n                new_substr = \"\\\\sqrt\" + split\n            new_string += new_substr\n        return new_string\n\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n    # print(string)\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n    # print(string)\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n    # print(string)\n\n    # matrix\n    string = re.sub(r\"\\\\begin\\{array\\}\\{.*?\\}\", r\"\\\\begin{pmatrix}\", string)\n    string = re.sub(r\"\\\\end\\{array\\}\", r\"\\\\end{pmatrix}\", string)\n    string = string.replace(\"bmatrix\", \"pmatrix\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n    string = string.replace(\"\\\\neq\", \"\\\\ne\").replace(\"\\\\leq\", \"\\\\le\").replace(\"\\\\geq\", \"\\\\ge\")\n    # print(string)\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n    # print(string)\n\n    # Remove unit: miles, dollars if after is not none\n    _string = re.sub(r\"\\\\text{.*?}$\", \"\", string).strip()\n    if _string != \"\" and _string != string:\n        # print(\"Warning: unit not removed: '{}' -> '{}'\".format(string, _string))\n        string = _string\n\n    # Remove unit: texts\n    for _ in range(2):\n        for unit_text in unit_texts:\n            # use regex, the prefix should be either the start of the string or a non-alphanumeric character\n            # the suffix should be either the end of the string or a non-alphanumeric character\n            _string = re.sub(r\"(^|\\W)\" + unit_text + r\"($|\\W)\", r\"\\1\\2\", string)\n            if _string != \"\":\n                string = _string\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2:\n        if len(string.split(\"=\")[0]) <= 2:\n            string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. 
Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n\n\nSUBSTITUTIONS = [\n    (\"an \", \"\"),\n    (\"a \", \"\"),\n    (\".$\", \"$\"),\n    (\"\\\\$\", \"\"),\n    (r\"\\ \", \"\"),\n    (\" \", \"\"),\n    (\"mbox\", \"text\"),\n    (\",\\\\text{and}\", \",\"),\n    (\"\\\\text{and}\", \",\"),\n    (\"\\\\text{m}\", \"\\\\text{}\"),\n]\n\n\nREMOVED_EXPRESSIONS = [\n    \"square\",\n    \"ways\",\n    \"integers\",\n    \"dollars\",\n    \"mph\",\n    \"inches\",\n    \"ft\",\n    \"hours\",\n    \"km\",\n    \"units\",\n    \"\\\\ldots\",\n    \"sue\",\n    \"points\",\n    \"feet\",\n    \"minutes\",\n    \"digits\",\n    \"cents\",\n    \"degrees\",\n    \"cm\",\n    \"gm\",\n    \"pounds\",\n    \"meters\",\n    \"meals\",\n    \"edges\",\n    \"students\",\n    \"childrentickets\",\n    \"multiples\",\n    \"\\\\text{s}\",\n    \"\\\\text{.}\",\n    \"\\\\text{\\ns}\",\n    \"\\\\text{}^2\",\n    \"\\\\text{}^3\",\n    \"\\\\text{\\n}\",\n    \"\\\\text{}\",\n    r\"\\mathrm{th}\",\n    r\"^\\circ\",\n    r\"^{\\circ}\",\n    r\"\\;\",\n    r\",\\!\",\n    \"{,}\",\n    '\"',\n    \"\\\\dots\",\n]\n\n\ndef normalize_final_answer(final_answer: str) -> str:\n    \"\"\"\n    Normalize a final answer to a quantitative reasoning question.\n    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.\n    \"\"\"\n    # final_answer = final_answer.split(\"=\")[-1]\n\n    for before, after in SUBSTITUTIONS:\n        final_answer = final_answer.replace(before, after)\n    for expr in REMOVED_EXPRESSIONS:\n        final_answer = final_answer.replace(expr, \"\")\n\n    # Extract answer that is in LaTeX math, is bold,\n    # is surrounded by a box, etc.\n    final_answer = re.sub(r\"(.*?)(\\$)(.*?)(\\$)(.*)\", \"$\\\\3$\", final_answer)\n    final_answer = re.sub(r\"(\\\\text\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\textbf\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\overline\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\boxed\\{)(.*)(\\})\", \"\\\\2\", final_answer)\n\n    # Normalize shorthand TeX:\n    # \\fracab -> \\frac{a}{b}\n    # \\frac{abc}{bef} -> \\frac{abc}{bef}\n    # \\fracabc -> \\frac{a}{b}c\n    # \\sqrta -> \\sqrt{a}\n    # \\sqrtab -> sqrt{a}b\n    final_answer = re.sub(r\"(frac)([^{])(.)\", \"frac{\\\\2}{\\\\3}\", final_answer)\n    final_answer = re.sub(r\"(sqrt)([^{])\", \"sqrt{\\\\2}\", final_answer)\n    final_answer = final_answer.replace(\"$\", \"\")\n\n    # Normalize 100,000 -> 100000\n    if final_answer.replace(\",\", \"\").isdigit():\n        final_answer = final_answer.replace(\",\", \"\")\n\n    return final_answer\n\n\ndef repeatness(s: str):\n    def ranks(seq):\n        index = {v: i for i, v in enumerate(sorted(set(seq)))}\n        return [index[v] for v in seq]\n\n    def suffixArray(s):\n        line = ranks(s)\n        n, k, ans, sa = len(s), 1, line, [0] * len(s)\n        while k < n - 1:\n            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))\n            ans, k = line, k << 1\n        for i, k in enumerate(ans):\n            sa[k] = i\n        return ans, sa\n\n    
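# lcp: Kasai-style longest-common-prefix array over the suffix array; the summed\n    # LCP below measures how self-repetitive the string is.\n    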
def lcp(arr, suffixArr, inv_suff):\n        n, ans, k = len(arr), [0] * len(arr), 0\n\n        for i in range(n):\n            if inv_suff[i] == n - 1:\n                k = 0\n                continue\n\n            j = suffixArr[inv_suff[i] + 1]\n            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:\n                k += 1\n\n            ans[inv_suff[i]] = k\n            if k > 0:\n                k -= 1\n\n        return ans\n\n    arr = [ord(i) for i in s]\n    n = len(arr)\n    if n <= 1:\n        return 0\n    c, sa = suffixArray(arr)\n    cnt = sum(lcp(arr, sa, c))\n\n    return (cnt * 2 / (n * (n + 1))) > 0.2\n\n\nclass timeout:\n    def __init__(self, seconds=1, error_message=\"Timeout\"):\n        self.seconds = seconds\n        self.error_message = error_message\n\n    def handle_timeout(self, signum, frame):\n        raise TimeoutError(self.error_message)\n\n    def __enter__(self):\n        signal.signal(signal.SIGALRM, self.handle_timeout)\n        signal.alarm(self.seconds)\n\n    def __exit__(self, type, value, traceback):\n        signal.alarm(0)\n\n\ndef latex_eval(latex):\n    sym = parse_latex(latex)\n    val = sym.evalf()\n    return sym, val\n\n\ndef numeric_equal(prediction: float, reference: float):\n    # Note that relative tolerance has significant impact\n    # on the result of the synthesized GSM-Hard dataset\n    # if reference.is_integer():\n    #     return isclose(reference, round(prediction), abs_tol=1e-4)\n    # else:\n    # prediction = round(prediction, len(str(reference).split(\".\")[-1]))\n    return isclose(reference, prediction, rel_tol=1e-4)\n\n\n@timeout_ours(timeout_seconds=5)\ndef symbolic_equal(a, b):\n    def _parse(s):\n        for f in [parse_latex, parse_expr, latex2sympy]:\n            try:\n                return f(s.replace(\"\\\\\\\\\", \"\\\\\"))\n            except Exception:\n                try:\n                    return f(s)\n                except Exception:\n                    pass\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    # direct equal\n    try:\n        if str(a) == str(b) or a == b:\n            return True\n    except Exception:\n        pass\n\n    # simplify equal\n    try:\n        if a.equals(b) or simplify(a - b) == 0:\n            return True\n    except Exception:\n        pass\n\n    # equation equal\n    try:\n        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):\n            return True\n    except Exception:\n        pass\n\n    try:\n        if numeric_equal(float(N(a)), float(N(b))):\n            return True\n    except Exception:\n        pass\n\n    # matrix\n    try:\n        # if a and b are matrix\n        if a.shape == b.shape:\n            _a = a.applyfunc(lambda x: round(x, 3))\n            _b = b.applyfunc(lambda x: round(x, 3))\n            if _a.equals(_b):\n                return True\n    except Exception:\n        pass\n\n    return False\n\n\ndef _is_latex_equal(str1, str2):\n    try:\n        sym1, val1 = latex_eval(str1)\n        sym2, val2 = latex_eval(str2)\n        if sym1 == sym2 or val1 == val2:\n            return True\n        else:\n            raise ValueError\n    except Exception:  # noqa\n        try:\n            norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)\n            sym1, val1 = latex_eval(norm1)\n            sym2, val2 = latex_eval(norm2)\n            if sym1 == sym2 or val1 == val2:\n                return True\n        except Exception:  # noqa\n            return norm1 == norm2\n    return 
False\n\n\ndef is_latex_equal(given_answer: str, ground_truth: str) -> bool:\n    try:\n        with timeout(1):\n            try:\n                if (len(given_answer) > 128 and repeatness(given_answer)) or (\n                    len(ground_truth) > 128 and repeatness(ground_truth)\n                ):\n                    return False\n                # First conduct normalized string matching.\n                ground_truth_normalized = _normalize(ground_truth)\n                given_normalized = _normalize(given_answer)\n                if ground_truth_normalized is None:\n                    return False\n                if ground_truth_normalized == given_normalized:\n                    return True\n\n                # Next call math verify.\n                given_answer = given_answer.replace(\"\\n\", \"\")\n                ground_truth = ground_truth.replace(\"\\n\", \"\")\n                if \"$\" not in given_answer:\n                    given_answer = f\"${given_answer}$\"\n                if \"$\" not in ground_truth:\n                    ground_truth = f\"${ground_truth}$\"\n                return verify(\n                    parse(\n                        ground_truth,\n                        extraction_config=(\n                            LatexExtractionConfig(boxed_match_priority=0),\n                            ExprExtractionConfig(),\n                        ),\n                        fallback_mode=\"no_fallback\",\n                        extraction_mode=[\"first_match\"],\n                        parsing_timeout=1,\n                    ),\n                    parse(\n                        given_answer,\n                        extraction_config=(\n                            LatexExtractionConfig(boxed_match_priority=0),\n                            ExprExtractionConfig(),\n                        ),\n                        fallback_mode=\"no_fallback\",\n                        extraction_mode=[\"first_match\"],\n                        parsing_timeout=1,\n                    ),\n                    timeout_seconds=1,\n                )\n                # or symbolic_equal(ground_truth, given_answer)\n            except Exception:\n                return False\n    except TimeoutError:\n        return False\n\n\ndef is_value_equal(given_answer: str, ground_truth: str) -> bool:\n    assert ground_truth is not None\n    ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)\n    given_answer_normalized_mathd = mathd_normalize_answer(given_answer)\n\n    str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd\n    try:\n        number_equal = float(ground_truth_normalized_mathd) == float(given_answer_normalized_mathd)\n        return str_equal or number_equal\n    except Exception:\n        return str_equal\n\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [r\"\\^[0-9]+\\^\", r\"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" 
\\\\frac\")  # Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x: str) -> bool:\n    try:\n        x = _strip_properly_formatted_commas(x)\n        x = float(x)\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _str_to_int(x: str) -> bool:\n    x = x.replace(\",\", \"\")\n    x = float(x)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evalable\n    e.g. 7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(\"([0-9]) +([0-9])\")\n    step = p1.sub(\"\\\\1+\\\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(\"\\\\1\\\\3\\\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n    ]:\n        expr = re.sub(f\"{unit}(es)?(s)? *(\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(\"\\^ *\\\\\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        try:\n            expr = _parse_latex(expr)\n        except Exception:\n            pass\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n    expr = expr.replace(\" \", \"\")\n\n    # if we somehow still have latex braces here, just drop them\n    expr = expr.replace(\"{\", \"\")\n    expr = expr.replace(\"}\", \"\")\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = set([x for x in expr if x.isalpha()])\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    for bad_regex in BAD_REGEXES:\n        if re.search(bad_regex, expr) is not None:\n            return False\n\n    return True\n\n\n@timeout_ours(timeout_seconds=5)\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if (\n        len(expr) > 2\n        and expr[0] in TUPLE_CHARS\n        and expr[-1] in TUPLE_CHARS\n        and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])\n    ):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n    if right_brace_idx is None:\n        retval = None\n    else:\n        retval = string[idx : right_brace_idx + 1]\n\n    return retval\n\n\ndef remove_boxed(s):\n    left = \"\\\\boxed{\"\n    try:\n        assert s[: len(left)] == left\n        assert s[-1] == \"}\"\n        return s[len(left) : -1]\n    except Exception:\n        return None\n\n\ndef extract_boxed_answer(solution: str) -> str:\n    \"\"\"Extract the answer from inside a LaTeX \\\\boxed{} command\"\"\"\n    solution = last_boxed_only_string(solution)\n    solution = remove_boxed(solution)\n    return 
solution\n\n\ndef grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    if len(ground_truth_elems) > 1 and (\n        ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]\n    ):\n        is_correct = False\n    elif len(ground_truth_elems) != len(given_elems):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):\n            if _is_frac(ground_truth_elem) and _is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match\n                # (no sympy.simplify)\n                is_correct = False\n            else:\n                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:\n    ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)\n    given_answer_normalized_mathd = mathd_normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n    return False\n\n\ndef extract_answer(passage: str) -> Optional[str]:\n    if \"\\\\boxed\" in passage:\n        return extract_boxed_answer(passage)\n    return None\n\n\ndef grade(model_answer: str, gt_answer: str, fast: bool = True):\n    if \"\\\\boxed\" in gt_answer:\n        gt_answer = extract_answer(gt_answer)\n    correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer)\n    if not fast:\n        # This mode additionally uses math_verify to recover answers the fast graders\n        # wrongly marked incorrect (false negatives). Will be a bit slower, and sensitive to bad inputs.\n        correct = correct or is_latex_equal(\n            model_answer,\n            gt_answer,\n        )\n    return correct\n\n\ndef compute_score(model_response, gt_answer, fast=False):\n    model_answer = extract_answer(model_response)\n    if model_answer is None:\n        return {\n            \"score\": 0.0,\n            \"format_score\": 0.0,\n            \"acc\": False,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n        # return 0.0, 0.0  # Cannot even parse anything.\n    is_correct = False\n    if isinstance(gt_answer, float) or isinstance(gt_answer, int):\n        gt_answer = str(gt_answer)\n    if isinstance(gt_answer, str):\n        is_correct = grade(model_answer, gt_answer, fast)\n    elif isinstance(gt_answer, list):\n        is_correct = False\n        for gt in gt_answer:\n            is_correct |= grade(model_answer, gt, fast)\n    if is_correct:\n        return {\n            \"score\": 1.0,\n          
  \"format_score\": 1.0,\n            \"acc\": True,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n    else:\n        return {\n            \"score\": 0.0,\n            \"format_score\": 1.0,\n            \"acc\": False,\n            \"extracted_gt\": gt_answer,\n            # \"extracted_pred\": None,\n        }\n"
  },
  {
    "path": "verl_rl/recipe/entropy/reward_score/entropy_math/grader.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) Microsoft Corporation.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE\n\n# Copyright (c) 2023 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:\n- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py\n- https://github.com/microsoft/ProphetNet/tree/master/CRITIC\n- https://github.com/openai/prm800k\n\"\"\"\n\nimport contextlib\nimport math\nimport re\nfrom math import isclose\n\n# sympy related\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n# verl related\nfrom verl.utils.py_functional import timeout_limit\n\n\ndef is_digit(s):\n    try:\n        if \"{,}\" in str(s):\n            num = float(str(s).replace(\"{,}\", \"\"))\n            return True, num\n\n        num = float(str(s).replace(\",\", \"\"))\n        return True, num\n    except ValueError:\n        return False, None\n\n\ndef normalize(answer, pi) -> str:\n    # checking if answer is $<number> and removing $ in that case to compare\n    if isinstance(answer, str) and bool(re.match(r\"\\$\\d+(\\.\\d+)?\", answer)):\n        return answer[1:]\n\n    # checking if answer is <number>% or <number>\\\\% and removing %\n    if isinstance(answer, str) and (\n        bool(re.match(r\"^\\d+(\\.\\d+)?%$\", answer)) or bool(re.match(r\"^\\d+(\\.\\d+)?\\\\%$\", answer))\n    ):\n        return answer.replace(\"\\\\%\", \"\").replace(\"%\", \"\")\n\n    # handle base\n    answer = handle_base(answer)\n\n    # handle pi\n    answer = handle_pi(answer, pi)\n\n    return answer\n\n\ndef handle_base(x) -> str:\n    if 
isinstance(x, str) and \"_\" in x:\n        # Due to base\n        x = x.split(\"_\")[0]\n        x = float(x)\n        return int(x)\n    return x\n\n\ndef handle_pi(string, pi):\n    if isinstance(string, str) and \"\\pi\" in string:\n        # Find the first occurrence of \"\\pi\"\n        idx = string.find(\"\\pi\")\n\n        # Iterate over the string and find all occurrences of \"\\pi\" with a valid previous character\n        while idx != -1:\n            if idx > 0 and string[idx - 1].isdigit():\n                # Replace \"\\pi\" with \"*math.pi\" if the previous character is a digit\n                string = string[:idx] + f\"*{pi}\" + string[idx + 3 :]\n            else:\n                # Replace \"\\pi\" with \"1*math.pi\" if the previous character is not a digit\n                string = string[:idx] + f\"1*{pi}\" + string[idx + 3 :]\n\n            # Find the next occurrence of \"\\pi\"\n            idx = string.find(\"\\pi\", idx + 1)\n\n        # Evaluate the expression using eval() function\n        with contextlib.suppress(Exception):\n            string = eval(string)\n\n    return string\n\n\ndef math_equal(\n    prediction: bool | float | str,\n    reference: float | str,\n    include_percentage: bool = True,\n    tolerance: float = 1e-4,\n    timeout: float = 10.0,\n    pi: float = math.pi,\n) -> bool:\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n\n    prediction = normalize(prediction, pi)\n    reference = normalize(reference, pi)\n\n    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases\n        prediction = prediction[:1000]\n\n    # 0. string comparison\n    if isinstance(prediction, str) and isinstance(reference, str):\n        if prediction.strip().lower() == reference.strip().lower():\n            return True\n        if prediction.replace(\" \", \"\") == reference.replace(\" \", \"\"):\n            return True\n\n    try:  # 1. numerical equal\n        if is_digit(prediction)[0] and is_digit(reference)[0]:\n            prediction = is_digit(prediction)[1]\n            reference = is_digit(reference)[1]\n            # number questions\n            gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]\n            for item in gt_result:\n                try:\n                    if isclose(item, prediction, rel_tol=tolerance):\n                        return True\n                except Exception:\n                    continue\n            return False\n    except Exception:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    ## deal with [], (), {}\n    prediction = format_intervals(prediction)\n\n    pred_str, ref_str = prediction, reference\n    if (prediction.startswith(\"[\") and prediction.endswith(\"]\") and not reference.startswith(\"(\")) or (\n        prediction.startswith(\"(\") and prediction.endswith(\")\") and not reference.startswith(\"[\")\n    ):\n        pred_str = pred_str.strip(\"[]()\")\n        ref_str = ref_str.strip(\"[]()\")\n    for s in [\"{\", \"}\", \"(\", \")\"]:\n        ref_str = ref_str.replace(s, \"\")\n        pred_str = pred_str.replace(s, \"\")\n    if pred_str == ref_str:\n        return True\n\n    ## [a, b] vs. 
[c, d], return a==c and b==d\n    if (\n        prediction\n        and reference\n        and prediction[0] in \"([\"\n        and prediction[-1] in \")]\"\n        and prediction[0] == reference[0]\n        and prediction[-1] == reference[-1]\n    ):\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    if \",\" in prediction and \",\" in reference:\n        pred_parts = [item.strip() for item in prediction.split(\",\")]\n        ref_parts = [item.strip() for item in reference.split(\",\")]\n\n        if len(pred_parts) == len(ref_parts):\n            return bool(\n                all(\n                    [\n                        math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)\n                        for i in range(len(pred_parts))\n                    ]\n                )\n            )\n\n    # if we have point == tuple of values\n    if prediction.startswith(\"Point\") and reference[0] == \"(\" and reference[-1] == \")\":\n        pred_parts = prediction[prediction.find(\"(\") + 1 : -1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    # if reference is a matrix\n    if \"\\begin{pmatrix}\" in reference and prediction.startswith(\"Matrix\"):\n        try:\n            pred_matrix = parse_expr(prediction)\n            ref_matrix_items = reference.split()[1:-1:2]\n            if len(pred_matrix) == len(ref_matrix_items) and all(\n                [\n                    math_equal(pred, ref, include_percentage, tolerance)\n                    for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)\n                ]\n            ):\n                return True\n        except Exception:\n            pass\n    elif \"\\begin{pmatrix}\" in reference and prediction.startswith(\"[\") and prediction.endswith(\"]\"):\n        if isinstance(eval(prediction), list):\n            try:\n                pred_matrix = eval(prediction)\n                # ref_matrix_items = reference.split()[1:-1:2]\n                ref_matrix_items = (\n                    reference.lstrip(\"\\\\begin{pmatrix}\")  # noqa: B005\n                    .lstrip(\"\\begin{pmatrix}\")\n                    .rstrip(\"\\\\end{pmatrix}\")\n                    .rstrip(\"\\end{pmatrix}\")\n                )  # noqa: B005\n                ref_matrix_items = ref_matrix_items.split(\"\\\\\")\n                ref_matrix_items = [row.split(\"&\") if \"&\" in row else row for row in ref_matrix_items]\n                if len(pred_matrix) == len(ref_matrix_items) and all(\n                    [\n                        math_equal(pred, ref, include_percentage, tolerance)\n                        for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)\n                    ]\n                ):\n                    return True\n            except Exception:\n                pass\n\n    return symbolic_equal(prediction, reference, tolerance, timeout)\n\n\ndef symbolic_equal(a, b, 
tolerance, timeout=10.0):\n    def _parse(s):\n        for f in [parse_expr, parse_latex]:\n            try:\n                with timeout_limit(seconds=timeout):\n                    return f(s)\n            except TimeoutError:\n                print(f\"Parsing timed out for {s}\")\n                continue\n            except Exception:\n                continue\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if simplify(a - b) == 0:\n                return True\n    except TimeoutError:\n        print(f\"Simplification timed out for {a} - {b}\")\n        pass\n    except Exception:\n        pass\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if isclose(N(a), N(b), rel_tol=tolerance):\n                return True\n    except TimeoutError:\n        print(f\"Numerical evaluation timed out for {a}, {b}\")\n        pass\n    except Exception:\n        pass\n    return False\n\n\ndef format_intervals(prediction):\n    patterns = {\n        \"Interval(\": r\"^Interval\\((.*)\\)$\",\n        \"Interval.Ropen(\": r\"^Interval\\.Ropen\\((.*)\\)$\",\n        \"Interval.Lopen(\": r\"^Interval\\.Lopen\\((.*)\\)$\",\n        \"Interval.open(\": r\"^Interval\\.open\\((.*)\\)$\",\n    }\n\n    for key, pattern in patterns.items():\n        match = re.match(pattern, prediction)\n        if match:\n            inner_content = match.group(1)\n\n            if key == \"Interval(\":  # Interval(a, b) == [a, b]\n                return f\"[{inner_content}]\"\n            elif key == \"Interval.Ropen(\":  # Interval.Ropen(a, b) == [a, b)\n                return f\"[{inner_content})\"\n            elif key == \"Interval.Lopen(\":  # Interval.Lopen(a, b) == (a, b]\n                return f\"({inner_content}]\"\n            elif key == \"Interval.open(\":  # Interval.open(a, b) == (a, b)\n                return f\"({inner_content})\"\n\n    return prediction\n"
  },
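A minimal usage sketch of the grader defined in the file above; the import path is hypothetical, and the symbolic cases depend on the installed sympy (plus antlr4 for `parse_latex`):

```python
# Hypothetical usage of math_equal/format_intervals from the grader above.
# The module path is an assumption; symbolic results depend on sympy/antlr4.
from math_grader import format_intervals, math_equal

print(math_equal("50%", "0.5"))           # True: percent handling tries x/100, x, 100x
print(math_equal("(1, 2)", "(1,2)"))      # True: whitespace-insensitive string match
print(math_equal("0.5", "\\frac{1}{2}"))  # True: falls through to sympy simplify()
print(format_intervals("Interval.Ropen(0, 1)"))  # "[0, 1)"
```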
  {
    "path": "verl_rl/recipe/entropy/reward_score/entropy_math/math_normalize.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence).\n\nFrom: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py\n\"\"\"\n\nimport re\nfrom typing import Optional\n\n\ndef normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:  # noqa: E722\n        return answer\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:  # noqa: E722\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef 
_fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:  # noqa: E722\n        return string\n\n\ndef _remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n"
  },
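A few illustrative calls to `normalize_answer` from the file above (a sketch; the module path is an assumption):

```python
# Illustrative calls to normalize_answer; module path is hypothetical.
from math_normalize import normalize_answer

print(normalize_answer("\\tfrac{1}{2}"))  # '\\frac{1}{2}': tfrac/dfrac -> frac
print(normalize_answer("0.5"))            # '\\frac{1}{2}': hard-coded special case
print(normalize_answer("x = 3"))          # '3': short "k = " prefixes are dropped
print(normalize_answer("10\\%"))          # '10': percent signs stripped
```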
  {
    "path": "verl_rl/recipe/genrm_remote/README.md",
    "content": "# Generative Reward Model\n\n## Scripts\n\n### Step 1: Launch a vLLM Server (Optional)\n\nDeploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service.\n\n```bash \nvllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo\n```\n\n### Step 2: Perform RL using GenRM\n\n```bash\nbash recipe/api-genrm/run_genrm_remote.sh\n```\n\nThe implementation works by passing a customized reward function (see `reward_function.py`)\n\nFor convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first.\n\n## Advanced: Customizing Your GenRM\n\nYou can use sglang server with data parallel for faster inference:\n\n```bash\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4\n```\n\nNote that you should modify the `BASE_URL` in `reward_function.py` to match your SGLang Server address.\n\nYou can also create your own customized GenRM by implementing a custom reward function. Here are some tips for customizing your own GenRM based on `reward_function.py`:\n\n- Design appropriate prompts for your GenRM\n- Convert GenRM responses into RL rewards\n- ...\n\nSince these aspects are highly flexible, we only provide a demo implementation. The actual design and implementation of GenRM is left to the user's discretion.\n"
  },
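For a quick smoke test of the served GenRM before starting RL, you can hit its OpenAI-compatible endpoint directly (a sketch; the port is an assumption — vLLM defaults to 8000, while `reward_function.py` in this recipe points at 30000):

```python
# Sketch: smoke-test the GenRM server via its OpenAI-compatible API.
# URL/port and model name are assumptions; match them to your deployment.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "genrm-demo",
        "messages": [{"role": "user", "content": "Is 1 + 1 = 2? Answer 'True' or 'False' in \\boxed{}."}],
    },
    timeout=30,
)
print(resp.json()["choices"][0]["message"]["content"])
```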
  {
    "path": "verl_rl/recipe/genrm_remote/reward_function.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom time import sleep\n\nimport requests\n\nfrom verl.utils.reward_score.math import last_boxed_only_string, remove_boxed\n\nBASE_URL = \"http://localhost:30000\"\nAPI_KEY = \"EMPTY\"\nMAX_RETRIES = 3\nBASE_DELAY = 2\nMAX_WORKERS = 32\nMODEL_NAME = \"genrm-demo\"\nGENRM_PROMPT_TEMPLATE = \"\"\"\nThe following is a math problem and an AI solution:\n\n[Math Problem]\n\n{problem}\n\n[AI Solution]\n\n{solution}\n\nYour task is to review and critique the solution step by step, and output whether the AI solution is correct.\n\nPlease put your final answer (i.e., 'True' or 'False') in \\\\boxed{{}}.\n\"\"\".strip()\n\n\ndef get_response(problem, solution_str, ground_truth):\n    prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str)\n    messages = [{\"role\": \"user\", \"content\": prompt}]\n    for attempt in range(MAX_RETRIES):\n        try:\n            headers = {\"Content-Type\": \"application/json\"}\n            chat_url = f\"{BASE_URL}/v1/chat/completions\"\n            data = {\"model\": MODEL_NAME, \"messages\": messages}\n            output = requests.post(chat_url, headers=headers, json=data, timeout=30)\n            response = output.json()[\"choices\"][0][\"message\"][\"content\"]\n            return response\n        except Exception as e:\n            if attempt < MAX_RETRIES - 1:\n                print(\"Exception: \", repr(e))\n                delay = BASE_DELAY * (2**attempt)\n                print(f\"Retrying in {delay} seconds...\")\n                sleep(delay)\n            else:\n                print(f\"Failed after {MAX_RETRIES} attempts. 
Error: {e}\")\n\n    raise ConnectionRefusedError(f\"Failed to run the model for {prompt}!\")\n\n\ndef compute_reward(response):\n    reward_score = 0.0\n    try:\n        boxed_result = last_boxed_only_string(response)\n        if boxed_result is not None:\n            result = remove_boxed(boxed_result)\n            reward_score = float(result == \"True\")\n    except Exception as e:\n        print(e)\n    return reward_score\n\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info):\n    split = extra_info[\"split\"]\n    from verl.utils.reward_score import default_compute_score\n\n    func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info)\n\n    if split == \"test\":\n        return func_rm_score\n    else:\n        problem = extra_info[\"question\"]\n        response = get_response(problem, solution_str, ground_truth)\n        if response is not None:\n            reward_score = compute_reward(response)\n        else:\n            reward_score = 0.0\n\n        return reward_score\n\n\ndef compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos):\n    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:\n        futures = []\n        for data_source, solution_str, ground_truth, extra_info in zip(\n            data_sources, solution_strs, ground_truths, extra_infos, strict=True\n        ):\n            future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info)\n            futures.append(future)\n\n        results = [future.result() for future in futures]\n\n    return results\n"
  },
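An offline sketch of how the batched entry point might be exercised; during training, verl's `batch` reward manager supplies these arguments. A running GenRM server is required, and the sample values below are made up:

```python
# Hypothetical offline call to compute_score_batch; values are illustrative.
from recipe.genrm_remote.reward_function import compute_score_batch

scores = compute_score_batch(
    data_sources=["lighteval/MATH"],
    solution_strs=["... reasoning ... The answer is \\boxed{42}"],
    ground_truths=["42"],
    extra_infos=[{"split": "train", "question": "What is 6 * 7?"}],
)
print(scores)  # e.g. [1.0] if the GenRM's verdict is \boxed{True}
```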
  {
    "path": "verl_rl/recipe/genrm_remote/run_genrm_remote.sh",
    "content": "# vllm server\n# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo\n\n# sglang server\n# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4\n\nset -x\n\nCUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=${HOME}/data/gsm8k/train.parquet \\\n    data.val_files=${HOME}/data/gsm8k/test.parquet \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=8 \\\n    algorithm.use_kl_in_reward=False \\\n    reward_model.reward_manager=batch \\\n    custom_reward_function.path=recipe/genrm_remote/reward_function.py \\\n    custom_reward_function.name=compute_score_batch \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_func_rm_example_gsm8k' \\\n    trainer.experiment_name='qwen2_5_3b_gen_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.val_before_train=True \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=20 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode='disable'\n"
  },
  {
    "path": "verl_rl/recipe/langgraph_agent/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/recipe/langgraph_agent/chat_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRef: https://python.langchain.com/docs/how_to/custom_chat_model/\n\"\"\"\n\nimport asyncio\nimport json\nimport logging\nimport os\nimport uuid\nfrom typing import Any, Optional\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.language_models.base import LanguageModelInput\nfrom langchain_core.messages import (\n    AIMessage,\n    BaseMessage,\n    convert_to_openai_messages,\n)\nfrom langchain_core.messages.tool import InvalidToolCall, ToolCall\nfrom langchain_core.outputs import ChatGeneration, ChatResult\nfrom langchain_core.runnables import Runnable, RunnableConfig\nfrom langchain_core.tools import StructuredTool\nfrom langchain_core.utils.function_calling import convert_to_openai_tool\nfrom pydantic import Field\n\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager\nfrom verl.experimental.agent_loop.tool_parser import ToolParser\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MaxTokenExceededError(Exception):\n    \"\"\"Indicate that history chat messages + tool message exceeds LLM max_tokens.\"\"\"\n\n    pass\n\n\nclass ChatModel(BaseChatModel):\n    model_name: str = Field(alias=\"model\")\n    \"\"\"The name of the model\"\"\"\n\n    client: AsyncLLMServerManager\n    \"\"\"AsyncLLM server manager\"\"\"\n\n    tokenizer: Any\n    \"\"\"Tokenizer for the model\"\"\"\n\n    max_tokens: int\n    \"\"\"Max tokens to generate\"\"\"\n\n    tool_parser: str = \"hermes\"\n    \"\"\"Tool parser for the model\"\"\"\n\n    max_parallel_calls: int = 1\n    \"\"\"Max parallel tool calls\"\"\"\n\n    temperature: float = 1.0\n    \"\"\"Temperature for sampling\"\"\"\n\n    top_p: float = 1.0\n    \"\"\"Top p for sampling\"\"\"\n\n    repetition_penalty: float = 1.0\n    \"\"\"Repetition penalty for sampling\"\"\"\n\n    def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:\n        \"\"\"Bind tools to the model.\n\n        Args:\n            tools: Sequence of tools to bind to the model.\n\n        Returns:\n            A Runnable that returns a message.\n        \"\"\"\n        formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]\n\n        # used to remove system prompt prefix when encoding tool response\n        system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)\n        kwargs[\"system_prompt\"] = system_prompt\n\n        return self.bind(tools=formatted_tools, **kwargs)\n\n    def with_structured_output(\n        self,\n        schema: dict | type,\n        *,\n        include_raw: bool = False,\n        **kwargs: Any,\n    ) -> Runnable[LanguageModelInput, dict | BaseChatModel]:\n        \"\"\"Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/\"\"\"\n        raise NotImplementedError\n\n    
def _generate(\n        self,\n        messages: list[BaseMessage],\n        stop: Optional[list[str]] = None,\n        **kwargs: Any,\n    ) -> ChatResult:\n        raise NotImplementedError\n\n    async def _agenerate(\n        self,\n        messages: list[BaseMessage],\n        stop: Optional[list[str]] = None,\n        **kwargs: Any,\n    ) -> ChatResult:\n        \"\"\"Asynchronously generate chat completion message.\n\n        Args:\n            messages (list[BaseMessage]): List of messages.\n            stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the\n                first occurrence of any of these substrings. Defaults to None.\n\n        Returns:\n            ChatResult: Chat result.\n        \"\"\"\n        request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)\n\n        sampling_params = {\n            \"temperature\": self.temperature,\n            \"top_p\": self.top_p,\n            \"repetition_penalty\": self.repetition_penalty,\n        }\n        if \"sampling_params\" in kwargs:\n            sampling_params.update(kwargs[\"sampling_params\"])\n\n        response_ids = await self.client.generate(\n            request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n        )\n\n        message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)\n        generation = ChatGeneration(message=message)\n        return ChatResult(generations=[generation])\n\n    @property\n    def _llm_type(self) -> str:\n        \"\"\"Get the type of language model used by this chat model.\"\"\"\n        return self.model_name\n\n    async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:\n        \"\"\"Preprocess messages for chat completion.\n\n        To ensure strong consistency with the policy model, the AsyncLLM server generates responses token-in/token-out\n        rather than from a messages list.\n\n        But agent frameworks represent chat history as a messages list. To bridge the gap, we store the trajectory\n        (prompt_ids, response_mask) in the latest AIMessage.response_metadata.\n\n        1. Encode ToolMessage to token ids.\n        2. Retrieve trajectory (prompt_ids, response_mask) from the latest AIMessage.response_metadata.\n        3. 
Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.\n\n        Ref: https://python.langchain.com/docs/concepts/chat_history/\n\n        Args:\n            messages (list[BaseMessage]): List of messages.\n\n        Returns:\n            tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.\n        \"\"\"\n        # messages: [system], human, ai, human|tool, ai, human|tool, ...\n        assert messages[-1].type in [\"human\", \"tool\"], (\n            f\"Last message must be human or tool, but got {messages[-1].type}\"\n        )\n        loop = asyncio.get_running_loop()\n\n        # Case 1: initial chat completion: [system], human\n        if messages[-1].type == \"human\" and (len(messages) == 1 or messages[-2].type != \"ai\"):\n            prompt_ids = await loop.run_in_executor(\n                None,\n                lambda: self.tokenizer.apply_chat_template(\n                    convert_to_openai_messages(messages),\n                    tools=kwargs.get(\"tools\"),\n                    add_generation_prompt=True,\n                    tokenize=True,\n                ),\n            )\n            return str(uuid.uuid4()), prompt_ids, []\n\n        # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...\n        for i in range(len(messages) - 1, -1, -1):\n            if messages[i].type == \"ai\":\n                break\n        assert \"prompt_ids\" in messages[i].response_metadata, \"Last message must have prompt_ids in response_metadata\"\n        assert \"response_mask\" in messages[i].response_metadata, (\n            \"Last message must have response_mask in response_metadata\"\n        )\n\n        # encode tool response\n        tool_responses = convert_to_openai_messages(messages[i + 1 :])\n        tool_response_ids = await loop.run_in_executor(\n            None,\n            lambda messages=tool_responses: self.tokenizer.apply_chat_template(\n                messages, add_generation_prompt=True, tokenize=True\n            ),\n        )\n        tool_response_ids = tool_response_ids[len(kwargs[\"system_prompt\"]) :]\n\n        # stop generation if response length exceeds max response length\n        if len(messages[i].response_metadata[\"response_mask\"]) + len(tool_response_ids) >= self.max_tokens:\n            raise MaxTokenExceededError(f\"Max response length {self.max_tokens} exceeded\")\n\n        # append tool response to prompt\n        request_id = messages[i].response_metadata.pop(\"request_id\")\n        prompt_ids = messages[i].response_metadata.pop(\"prompt_ids\")\n        response_mask = messages[i].response_metadata.pop(\"response_mask\")\n        prompt_ids += tool_response_ids\n        response_mask += [0] * len(tool_response_ids)\n\n        return request_id, prompt_ids, response_mask\n\n    async def _postprocess(\n        self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any\n    ) -> AIMessage:\n        \"\"\"Postprocess response_ids when chat completion is done.\n\n        1. Decode response_ids, parse tool calls to AIMessage.\n        2. Append response_ids to prompt_ids, and append 1 to response_mask.\n        3. 
Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.\n\n        Args:\n            request_id (str): Unique request id.\n            prompt_ids (list[int]): Input prompt token ids in this chat completion.\n            response_mask (list[int]): Response mask before this chat completion.\n            response_ids (list[int]): LLM generated token ids in this chat completion.\n\n        Returns:\n            AIMessage: Postprocessed message.\n        \"\"\"\n        prompt_ids += response_ids\n        response_mask += [1] * len(response_ids)\n\n        tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)\n        content, function_calls = await tool_parser.extract_tool_calls(response_ids)\n\n        tool_calls, invalid_tool_calls = [], []\n        for function_call in function_calls:\n            try:\n                args = json.loads(function_call.arguments)\n                if not isinstance(args, dict):\n                    # JSONDecodeError requires (msg, doc, pos)\n                    raise json.JSONDecodeError(f\"Invalid json tool arguments: {args}\", function_call.arguments, 0)\n                tool_call = ToolCall(\n                    args=args,\n                    name=function_call.name,\n                    id=str(uuid.uuid4()),\n                )\n                tool_calls.append(tool_call)\n            except json.JSONDecodeError as e:\n                logger.warning(f\"Invalid json tool arguments: {e}\")\n                tool_call = InvalidToolCall(\n                    args=function_call.arguments,\n                    name=function_call.name,\n                    error=f\"Invalid json tool arguments: {e}\",\n                )\n                invalid_tool_calls.append(tool_call)\n\n        message = AIMessage(\n            content=content,\n            tool_calls=tool_calls[: self.max_parallel_calls],\n            invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],\n            response_metadata={\n                \"request_id\": request_id,\n                \"prompt_ids\": prompt_ids,\n                \"response_mask\": response_mask,\n            },\n        )\n        return message\n\n\nclass TruncateStructuredTool(StructuredTool):\n    \"\"\"Structured tool with response truncation.\"\"\"\n\n    tool_response_truncate_side: str\n    \"\"\"truncate side of tool response: left, middle, right\"\"\"\n\n    max_tool_response_length: int\n    \"\"\"max length of tool response\"\"\"\n\n    async def _arun(\n        self,\n        *args: Any,\n        config: RunnableConfig,\n        **kwargs: Any,\n    ) -> Any:\n        tool_response = await super()._arun(*args, config=config, **kwargs)\n        tool_response = str(tool_response)\n\n        if len(tool_response) > self.max_tool_response_length:\n            if self.tool_response_truncate_side == \"left\":\n                tool_response = tool_response[: self.max_tool_response_length] + \"...(truncated)\"\n            elif self.tool_response_truncate_side == \"right\":\n                tool_response = \"(truncated)...\" + tool_response[-self.max_tool_response_length :]\n            else:\n                length = self.max_tool_response_length // 2\n                tool_response = tool_response[:length] + \"...(truncated)...\" + tool_response[-length:]\n\n        return tool_response\n\n\ndef convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:\n    \"\"\"Convert messages to AgentLoopOutput.\n\n    Args:\n        messages (List[BaseMessage]): List of messages, last message must be assistant\n            with 
response_metadata containing `prompt_ids` and `response_mask`.\n        response_length (int): Max length of response.\n\n    Returns:\n        AgentLoopOutput: agent loop output trajectory used for training.\n    \"\"\"\n    # skip last tool calls\n    for i in range(len(messages) - 1, -1, -1):\n        if messages[i].type != \"tool\":\n            break\n    last_message = messages[i]\n    assert last_message.type == \"ai\", f\"Last message must be assistant, but got {last_message.type}\"\n    assert \"prompt_ids\" in last_message.response_metadata, \"Last message must have prompt_ids in response_metadata\"\n    assert \"response_mask\" in last_message.response_metadata, (\n        \"Last message must have response_mask in response_metadata\"\n    )\n\n    num_turns = 0\n    for i in range(len(messages)):\n        if messages[i].type == \"system\":\n            continue\n        # parallel tool calls are in single turn\n        if i == 0 or messages[i].type != messages[i - 1].type:\n            num_turns += 1\n\n    prompt_ids = last_message.response_metadata[\"prompt_ids\"]\n    response_mask = last_message.response_metadata[\"response_mask\"]\n\n    response_ids = prompt_ids[-len(response_mask) :]\n    prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]\n\n    output = AgentLoopOutput(\n        prompt_ids=prompt_ids,\n        response_ids=response_ids[:response_length],\n        response_mask=response_mask[:response_length],\n        num_turns=num_turns,\n        metrics={},\n    )\n    return output\n"
  },
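The core invariant in `ChatModel` above is that one flat token list grows across turns while `response_mask` marks policy-generated tokens (1) versus tool/template tokens (0). A toy sketch of that bookkeeping, with made-up token ids:

```python
# Toy sketch of the trajectory bookkeeping in ChatModel; token ids are made up.
prompt_ids, response_mask = [101, 102, 103], []  # initial user prompt, no responses yet

response_ids = [201, 202]                 # turn 1: tokens generated by the policy
prompt_ids += response_ids
response_mask += [1] * len(response_ids)  # marked as trainable

tool_response_ids = [301, 302]            # encoded ToolMessage appended next turn
prompt_ids += tool_response_ids
response_mask += [0] * len(tool_response_ids)  # excluded from the loss

# convert_to_agent_output() then splits the flat list back apart:
response_ids = prompt_ids[-len(response_mask):]                  # [201, 202, 301, 302]
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]  # [101, 102, 103]
```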
  {
    "path": "verl_rl/recipe/langgraph_agent/example/README.md",
    "content": "# MathExpression: LangGraph Agent Example\n\nMathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).\n\n### Define react agent with tool\nFirstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`:\n```python\n@tool(parse_docstring=True)\ndef calculate(a: int, b: int, operand: str) -> int:\n    \"\"\"\n    Compute the results using operand with two integers\n\n    Args:\n        a: the first operand\n        b: the second operand\n        operand: '+' or '-' or '*' or '@'\n    \"\"\"\n    assert operand in [\"+\", \"-\", \"*\", \"@\"], f\"unknown operand {operand}\"\n    if operand == \"@\":\n        return 3 * a - 2 * b\n    return eval(f\"{a} {operand} {b}\")\n```\n\nWithout calling `calculate`, ReactAgent is impossible to evaluate math expression correctly.\n\nThen, we can equip ReactAgent with `calculate` tool:\n```python\nclass MathExpressionReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer):\n        cls.tools = [calculate]\n        super().init_class(config, tokenizer)\n```\n\nWe can define agent loop config in yaml file, which will be used by AgentLoopWorker to dynamic load custom AgentLoop class.\n```yaml\n- name: math_expression\n  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n```\n\n### Prepare dataset\nNow, let's prepare two small datasets for training and evaluation:\n```bash\npython recipe/langgraph_agent/example/create_dataset.py\n```\n\nNote that dataset should contain a column `agent_name` with `math_expression`, which is used by `AgentLoopWorker` to select the\nagent loop class.\n| prompt | reward_model | agent_name |\n|--------------------------------------|------------------------------|-----------------|\n| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |\n| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |\n\nGenerated math expressions are like below, requiring model to call `calculate` multiple times to solve sub expressions.\n```\n(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6\n(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6\n9 @ 4\n((2 @ 2) @ (3 @ 3)) @ 4\n```\n\n### Training\nHook all these up and start training:\n```bash\nbash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log\n```\n\nAfter total 39 steps, model should achieve 100% accuray on test dataset:\n- val-aux/lighteval/MATH/reward: 1.0\n- val-aux/num_turns/mean: 9.0, average number of messages include assistant and tool turns.\n\nWe can see model can solve math expression by calling `calculate` tool:\n```\nsystem\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.\nuser\nWe define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \\boxed{} only in the last turn. Now answer the following questions:\nCompute 5 @ (3 @ 6 @ 8)\nassistant\nSure, let's break this down step by step. We need to compute the expression \\(5 @ (3 @ 6 @ 8)\\). First, we need to compute the inner expression \\(3 @ 6 @ 8\\).\n\nLet's compute \\(3 @ 6\\) first using the external tool.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": 3, \"b\": 6, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n-3\n</tool_response>\nassistant\nWe have computed that \\(3 @ 6 = -3\\). 
Now we need to compute \\(-3 @ 8\\).\n\nLet's use the external tool to compute this.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": -3, \"b\": 8, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n-25\n</tool_response>\nassistant\nWe have computed that \\(-3 @ 8 = -25\\). Now we can put this result into the original expression to get the final result of \\(5 @ (3 @ 6 @ 8)\\), which is \\(5 @ (-25)\\).\n\nLet's compute \\(5 @ (-25)\\) using the external tool.\n<tool_call>\n{\"name\": \"calculate\", \"arguments\": {\"a\": 5, \"b\": -25, \"operand\": \"@\"}}\n</tool_call>user\n<tool_response>\n65\n</tool_response>\nassistant\nThe final result of \\(5 @ (3 @ 6 @ 8)\\) is \\(\\boxed{65}\\).\n```\n"
  },
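A quick check of the transcript's arithmetic under the recipe's definition `a @ b = 3a - 2b` (note `@` is evaluated left to right):

```python
# Verify the transcript above against a @ b = 3a - 2b (left-associative).
def at(a, b):
    return 3 * a - 2 * b

assert at(3, 6) == -3                # 3 @ 6
assert at(at(3, 6), 8) == -25        # 3 @ 6 @ 8, evaluated left to right
assert at(5, at(at(3, 6), 8)) == 65  # 5 @ (3 @ 6 @ 8) -> \boxed{65}
```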
  {
    "path": "verl_rl/recipe/langgraph_agent/example/agent.yaml",
    "content": "- name: math_expression\n  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n"
  },
  {
    "path": "verl_rl/recipe/langgraph_agent/example/create_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCreate dataset for calculator\n\"\"\"\n\nimport random\n\nimport pandas as pd\n\n\ndef generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):\n    \"\"\"\n    Generate a random mathematical expression with operators +, -, *, /, and parentheses.\n\n    Args:\n        min_terms (int): Minimum number of terms in the expression.\n        max_terms (int): Maximum number of terms in the expression.\n        max_number (int): Maximum value for numbers in the expression.\n        allow_decimals (bool): Whether to allow decimal numbers.\n        max_depth (int): Maximum nesting depth for parentheses.\n\n    Returns:\n        str: A valid mathematical expression as a string.\n    \"\"\"\n\n    def generate_number():\n        \"\"\"Generate a random number (integer or float).\"\"\"\n        assert min_number < max_number\n        num = random.uniform(min_number, max_number)\n        if not allow_decimals:\n            num = int(num)\n        else:\n            num = round(num, random.randint(0, 2))  # Round to 0-2 decimal places\n        return str(num)\n\n    def generate_term(depth=0):\n        \"\"\"Generate a term (number or parenthesized expression).\"\"\"\n        if depth < max_depth and random.random() < 0.5:  # 50% chance to add parentheses\n            expr = generate_expression(depth + 1)\n            return f\"({expr})\"\n        else:\n            return generate_number()\n\n    def generate_expression(depth=0):\n        \"\"\"Generate a full expression with multiple terms and operators.\"\"\"\n        num_terms = random.randint(min_terms, max_terms)\n        terms = [generate_term(depth) for _ in range(num_terms)]\n\n        # Randomly select operators\n        operators = [\"+\", \"-\", \"*\", \"/\", \"@\"]\n        expr = terms[0]\n\n        for i in range(1, num_terms):\n            # Bias towards + and - for readability\n            op = random.choices(\n                operators,\n                weights=[0, 0, 0, 0, 1],  # + and - are 1.5x more likely than * and /\n            )[0]\n            expr += f\" {op} \" + terms[i]\n\n        return expr\n\n    return generate_expression()\n\n\ndef test():\n    # Example 1: Basic integer expression\n    print(generate_math_expression())\n    # Output: (3 + 7) * 2 - 5\n\n    # Example 2: Expression with decimals\n    print(generate_math_expression(allow_decimals=True))\n    # Output: 4.5 / (2.1 + 3.7) - 1.2\n\n    # Example 3: More complex expression with higher depth\n    print(generate_math_expression(max_terms=6, max_depth=3))\n    # Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4\n\n    # Example 4: Simplified expression\n    print(generate_math_expression(min_terms=2, max_terms=3, max_number=5))\n    # Output: 4 - 2 * 3\n\n\ndef calculate(expression: str) -> float:\n    \"\"\"\n    Evaluate a mathematical expression with +, -, *, /, @, and 
parentheses.\n    The @ operator is defined as: a @ b = 3a - 2b.\n\n    Args:\n        expression (str): Input mathematical expression (e.g., \"3@2+4\").\n\n    Returns:\n        float: Result of the evaluated expression.\n\n    Raises:\n        ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero).\n    \"\"\"\n\n    def tokenize(s: str) -> list:\n        \"\"\"Convert the input string into tokens (numbers, operators, parentheses).\"\"\"\n        tokens = []\n        i = 0\n        while i < len(s):\n            if s[i].isdigit() or s[i] == \".\":\n                # Parse number (integer or float)\n                j = i\n                while j < len(s) and (s[j].isdigit() or s[j] == \".\"):\n                    j += 1\n                tokens.append(s[i:j])\n                i = j\n            elif s[i] in \"+-*/@()\":\n                # Operator or parenthesis\n                tokens.append(s[i])\n                i += 1\n            elif s[i].isspace():\n                # Skip whitespace\n                i += 1\n            else:\n                raise ValueError(f\"Invalid character: {s[i]}\")\n        return tokens\n\n    def infix_to_postfix(tokens: list) -> list:\n        \"\"\"Convert infix notation to postfix notation (Reverse Polish Notation).\"\"\"\n        output = []\n        stack = []\n        # '@' binds tightest, then '*' and '/', then '+' and '-'\n        precedence = {\"@\": 3, \"*\": 2, \"/\": 2, \"+\": 1, \"-\": 1}\n\n        for token in tokens:\n            if token.isdigit() or \".\" in token:\n                output.append(token)\n            elif token == \"(\":\n                stack.append(token)\n            elif token == \")\":\n                while stack and stack[-1] != \"(\":\n                    output.append(stack.pop())\n                if not stack or stack[-1] != \"(\":\n                    raise ValueError(\"Mismatched parentheses\")\n                stack.pop()  # Discard '('\n            else:  # Operator\n                while stack and stack[-1] != \"(\" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):\n                    output.append(stack.pop())\n                stack.append(token)\n\n        # Pop remaining operators\n        while stack:\n            if stack[-1] in \"()\":\n                raise ValueError(\"Mismatched parentheses\")\n            output.append(stack.pop())\n\n        return output\n\n    def evaluate_postfix(postfix: list) -> float:\n        \"\"\"Evaluate postfix expression using a stack.\"\"\"\n        stack = []\n        for token in postfix:\n            if token.isdigit() or \".\" in token:\n                stack.append(float(token))\n            else:\n                if len(stack) < 2:\n                    raise ValueError(\"Invalid expression\")\n                b = stack.pop()\n                a = stack.pop()\n                if token == \"+\":\n                    res = a + b\n                elif token == \"-\":\n                    res = a - b\n                elif token == \"*\":\n                    res = a * b\n                elif token == \"/\":\n                    if b == 0:\n                        raise ValueError(\"Division by zero\")\n                    res = a / b\n                elif token == \"@\":\n                    res = 3 * a - 2 * b  # Custom @ operator implementation\n                else:\n                    raise ValueError(f\"Invalid operator: {token}\")\n                stack.append(res)\n\n        if len(stack) != 1:\n            raise 
ValueError(\"Invalid expression\")\n        return stack[0]\n\n    # Remove spaces and validate parentheses\n    expression = expression.replace(\" \", \"\")\n    if expression.count(\"(\") != expression.count(\")\"):\n        raise ValueError(\"Mismatched parentheses\")\n\n    tokens = tokenize(expression)\n    postfix = infix_to_postfix(tokens)\n    result = evaluate_postfix(postfix)\n\n    # Convert integers to integer representation\n    if result.is_integer():\n        return int(result)\n    return result\n\n\ndef generate_data(total_num_dataset, split):\n    rl_dataset = {\n        \"prompt\": [],\n        \"data_source\": [],\n        \"ability\": [],\n        \"reward_model\": [],\n        \"extra_info\": [],\n        \"agent_name\": [],\n    }\n\n    for idx in range(total_num_dataset):\n        while True:\n            try:\n                expression: str = generate_math_expression(\n                    min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1\n                )\n\n                num_plus = expression.count(\"+\")\n                num_minus = expression.count(\"-\")\n                num_mul = expression.count(\"*\")\n                num_star = expression.count(\"@\")\n\n                answer = str(calculate(expression))\n                # answer = str(eval(expression))\n                break\n            except Exception as e:\n                print(e)\n                continue\n\n        num_tool_calls = num_plus + num_minus + num_mul + num_star\n\n        prompt = (\n            f\"We define a new math operator @, where you can only call an external tool to compute. \"\n            f\"Please put your final answer inside \\\\boxed{{}} only in the last turn. Now answer the \"\n            f\"following questions:\\nCompute {expression}\"\n        )\n        prompt_with_template = [\n            {\n                \"role\": \"user\",\n                \"content\": prompt,\n            }\n        ]\n\n        rl_dataset[\"prompt\"].append(prompt_with_template)\n        rl_dataset[\"data_source\"].append(\"lighteval/MATH\")\n        rl_dataset[\"ability\"].append(\"math\")\n        rl_dataset[\"reward_model\"].append({\"style\": \"lighteval/MATH\", \"ground_truth\": answer})\n        rl_dataset[\"extra_info\"].append(\n            {\"index\": idx, \"expression\": expression, \"split\": split, \"expected_tool_calls\": num_tool_calls}\n        )\n        rl_dataset[\"agent_name\"].append(\"math_expression\")\n\n    rl_dataset = pd.DataFrame(data=rl_dataset)\n    return rl_dataset\n\n\nif __name__ == \"__main__\":\n    # print(calculate(\"3@2\"))          # Output: 5 (3*3 - 2*2)\n    # print(calculate(\"3@2+4\"))        # Output: 9 (5 + 4)\n    # print(calculate(\"3*(4@2)\"))      # Output: 24 (3 * 8)\n    # print(calculate(\"(5@3)*2\"))      # Output: 18 (9 * 2)\n\n    train_dataset = generate_data(total_num_dataset=5000, split=\"train\")\n    test_dataset = generate_data(total_num_dataset=500, split=\"test\")\n\n    train_dataset.to_parquet(\"train.parquet\")\n    test_dataset.to_parquet(\"test.parquet\")\n"
  },
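The `calculate` above is a standard shunting-yard evaluator with `@` on its own precedence level; the file's own commented examples can be checked directly (a sketch, assuming the script is importable as `create_dataset`):

```python
# Spot checks for the shunting-yard evaluator above; import path is assumed.
from create_dataset import calculate

assert calculate("3@2") == 5       # 3*3 - 2*2
assert calculate("3@2+4") == 9     # '@' binds tighter than '+'
assert calculate("3*(4@2)") == 24  # parentheses first: 4@2 = 8
assert calculate("(5@3)*2") == 18  # 5@3 = 9
```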
  {
    "path": "verl_rl/recipe/langgraph_agent/example/math_expression.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom langchain_core.tools import tool\n\nfrom recipe.langgraph_agent.react_agent_loop import ReactAgentLoop\n\n\n@tool(parse_docstring=True)\ndef calculate(a: int, b: int, operand: str) -> int:\n    \"\"\"\n    Compute the results using operand with two integers\n\n    Args:\n        a: the first operand\n        b: the second operand\n        operand: '+' or '-' or '*' or '@'\n    \"\"\"\n    assert operand in [\"+\", \"-\", \"*\", \"@\"], f\"unknown operand {operand}\"\n    if operand == \"@\":\n        return 3 * a - 2 * b\n    return eval(f\"{a} {operand} {b}\")\n\n\nclass MathExpressionReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        cls.tools = [calculate]\n        super().init_class(config, tokenizer)\n"
  },
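Since `calculate` is declared with LangChain's `@tool`, it can be sanity-checked outside the agent loop via `invoke` (a sketch):

```python
# Sketch: exercise the LangChain tool directly before wiring it into the loop.
from recipe.langgraph_agent.example.math_expression import calculate

print(calculate.invoke({"a": 3, "b": 6, "operand": "@"}))    # -3 = 3*3 - 2*6
print(calculate.invoke({"a": 5, "b": -25, "operand": "@"}))  # 65 = 3*5 - 2*(-25)
print(calculate.args)  # argument schema generated from the docstring
```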
  {
    "path": "verl_rl/recipe/langgraph_agent/example/run_qwen2.5_3b.sh",
    "content": "set -x\n\n# ================= data/model/tool =================\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\nmodel_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct\n\ntrain_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet\ntest_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet\n\n# agent\nagent_loop_config_path=recipe/langgraph_agent/example/agent.yaml\n\n# wandb\nproject_name=math_expression_tool\nexperiment_name=qwen2.5-3b\ndefault_local_dir=$DATA_ROOT/checkpoint/$experiment_name\n\n# ================= algorithm =================\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_turns=8\nmax_prompt_length=1024\nmax_response_length=2048\nactor_lr=1e-6\n\ntrain_batch_size=128\nppo_mini_batch_size=16\nn_resp_per_prompt=8\nn_resp_per_prompt_val=1\n\n# ================= perfomance =================\ninfer_tp=2 # vllm\ntrain_sp=4 # train\noffload=True\n\nactor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))\nlog_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=$adv_estimator \\\n    algorithm.use_kl_in_reward=$use_kl_in_reward \\\n    algorithm.kl_ctrl.kl_coef=$kl_coef \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.return_raw_chat=True \\\n    data.train_batch_size=$train_batch_size \\\n    data.max_prompt_length=$max_prompt_length \\\n    data.max_response_length=$max_response_length \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \\\n    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \\\n    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \\\n    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.actor.optim.lr=$actor_lr \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.mode=async \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \\\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \\\n    actor_rollout_ref.rollout.multi_turn.format=hermes \\\n    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \\\n    actor_rollout_ref.rollout.n=$n_resp_per_prompt \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \\\n    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \\\n    trainer.logger=['console','wandb'] \\\n    trainer.project_name=$project_name 
\\\n    trainer.experiment_name=$experiment_name \\\n    trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \\\n    trainer.val_before_train=True \\\n    trainer.log_val_generations=50 \\\n    trainer.nnodes=$ARNOLD_WORKER_NUM \\\n    trainer.save_freq=-1 \\\n    trainer.default_local_dir=$default_local_dir \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n"
  },
  {
    "path": "verl_rl/recipe/langgraph_agent/react_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nLangGraph React Agent Loop.\n\nThis implementation is exact same as `ToolAgentLoop`.\n\nRef: https://langchain-ai.github.io/langgraph/tutorials/workflows/\n\"\"\"\n\nfrom typing import Any, Literal\n\nfrom langchain_core.runnables import RunnableConfig\nfrom langgraph.graph import END, MessagesState, StateGraph\nfrom langgraph.prebuilt import ToolNode\n\nfrom recipe.langgraph_agent.chat_model import (\n    ChatModel,\n    MaxTokenExceededError,\n    convert_to_agent_output,\n)\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput\n\n\nasync def call_model(state: MessagesState, config: RunnableConfig):\n    model = config[\"configurable\"][\"model\"]\n    sampling_params = config[\"configurable\"][\"sampling_params\"]\n    try:\n        message = await model.ainvoke(state[\"messages\"], sampling_params=sampling_params)\n        return {\"messages\": [message]}\n    except MaxTokenExceededError:\n        # last message is ToolMessage\n        return {\"messages\": []}\n\n\ndef should_continue(state: MessagesState, config: RunnableConfig) -> Literal[\"tools\", END]:\n    max_assistant_turns = config[\"configurable\"][\"max_assistant_turns\"]\n    num_assistant_turns = 0\n    for message in state[\"messages\"]:\n        if message.type == \"ai\":\n            num_assistant_turns += 1\n\n    last_message = state[\"messages\"][-1]\n\n    # LLM call failed, e.g: max response length exceeded\n    if last_message.type == \"tool\":\n        return END\n\n    # max assistant turns exceeded\n    if max_assistant_turns and num_assistant_turns >= max_assistant_turns:\n        return END\n\n    # no tool calls\n    if not last_message.tool_calls:\n        return END\n\n    return \"tools\"\n\n\nclass ReactAgentLoop(AgentLoopBase):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n        print(\"Performing class-level ReactAgentLoop initialization\")\n\n        # build graph\n        cls.graph = cls.build_graph()\n\n    @classmethod\n    def build_graph(cls) -> StateGraph:\n        workflow = StateGraph(MessagesState)\n\n        workflow.add_node(\"agent\", call_model)\n        workflow.add_node(\"tools\", ToolNode(cls.tools))\n        workflow.set_entry_point(\"agent\")\n        workflow.add_conditional_edges(\n            \"agent\",\n            should_continue,\n            {\n                \"tools\": \"tools\",\n                END: END,\n            },\n        )\n\n        workflow.add_edge(\"tools\", \"agent\")\n        graph = workflow.compile()\n        return graph\n\n    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:\n        model_path = self.config.actor_rollout_ref.model.path\n        model_name = \"/\".join(model_path.split(\"/\")[-2:])\n\n        rollout 
\n        rollout = self.config.actor_rollout_ref.rollout\n        model = ChatModel(\n            model=model_name,\n            client=self.server_manager,\n            tokenizer=self.tokenizer,\n            max_tokens=rollout.response_length,\n            max_parallel_calls=rollout.multi_turn.max_parallel_calls,\n            tool_parser=rollout.multi_turn.format,\n        )\n\n        model = model.bind_tools(self.tools, tool_choice=\"any\")\n\n        config = {\n            \"configurable\": {\n                \"model\": model,\n                \"sampling_params\": sampling_params,\n                \"max_user_turns\": rollout.multi_turn.max_user_turns,\n                \"max_assistant_turns\": rollout.multi_turn.max_assistant_turns,\n            }\n        }\n\n        # TODO: how to handle multiple trajectories in a graph invocation?\n        # Each graph node may have its own LLM calls and state, e.g.:\n        # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart\n        state = await self.graph.ainvoke(input={\"messages\": messages}, config=config)\n\n        output = convert_to_agent_output(state[\"messages\"], rollout.response_length)\n        return output\n"
  },
  {
    "path": "verl_rl/recipe/langgraph_agent/test_react_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\n\nimport numpy as np\nimport pytest\nimport ray\nfrom langchain_core.tools import tool\nfrom omegaconf import DictConfig\n\nfrom recipe.langgraph_agent.react_agent_loop import ReactAgentLoop\nfrom tests.experimental.agent_loop.agent_utils import init_agent_loop_manager\nfrom verl.protocol import DataProto\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    model_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n    config.actor_rollout_ref.rollout.agent.num_workers = 2\n\n    # test sleep/wake_up with fsdp offload\n    config.actor_rollout_ref.actor.fsdp_config.param_offload = True\n    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True\n\n    return config\n\n\n@tool(parse_docstring=True)\ndef get_current_temperature(location: str, unit: str = \"celsius\"):\n    \"\"\"Get current temperature at a location.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, and the unit in a dict\n    \"\"\"\n    print(f\"[DEBUG] get_current_temperature: {location}, {unit}\")\n    return {\n        \"temperature\": 26.1,\n        \"location\": location,\n        \"unit\": unit,\n    }\n\n\n@tool(parse_docstring=True)\ndef get_temperature_date(location: str, date: str, unit: str = \"celsius\"):\n    \"\"\"Get temperature at a location and date.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        date: The date to get the temperature for, in the format \"Year-Month-Day\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, the date and the unit in a dict\n    \"\"\"\n    print(f\"[DEBUG] get_temperature_date: {location}, {date}, {unit}\")\n    return {\n        \"temperature\": 25.9,\n        \"location\": location,\n        \"date\": date,\n        \"unit\": unit,\n    }\n\n\nclass TestReactAgentLoop(ReactAgentLoop):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        # TODO: find better way to configure tools\n        cls.tools = [get_current_temperature, get_temperature_date]\n        super().init_class(config, tokenizer, **kwargs)\n\n\ndef test_react_agent(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    agent_loop_config = [\n        {\n            \"_target_\": \"recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop\",\n            \"name\": \"react_agent\",\n        },\n    ]\n    agent_loop_config_path = \"/tmp/agent_loop_config.json\"\n    with open(agent_loop_config_path, \"w\") as f:\n        json.dump(agent_loop_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2\n    init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in New York now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? 
How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"react_agent\"] * len(raw_prompts)),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # [user, assistant]\n            assert num_turns[i] == 2\n        else:\n            # [user, assistant, tool, assistant]\n            assert num_turns[i] == 4\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    response_length = response_mask.size(1)\n\n    for i in range(len(responses)):\n        # response with tool response\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        assert \"<tool_response>\" not in response_without_obs, (\n            f\"found <tool_response> in response: {response_without_obs}\"\n        )\n        assert \"</tool_response>\" not in response_without_obs, (\n            f\"found </tool_response> in response: {response_without_obs}\"\n        )\n        print(\"=========================\")\n        print(response_with_obs)\n        print(\"---\")\n        print(response_without_obs)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/recipe/minicpmo/rl_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport logging\nimport math\nimport os\nimport re\nfrom typing import Optional\n\nimport datasets\nimport torch\nfrom omegaconf import DictConfig, ListConfig\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.dataset.vision_utils import process_image\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\n\ndef build_transform():\n    IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_MEAN\n    IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_STD\n    return transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n        ]\n    )\n\n\ndef build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):\n    if new_schema:\n        start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)\n        end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)\n    else:\n        start_cond = input_ids == tokenizer.im_start_id\n        end_cond = input_ids == tokenizer.im_end_id\n    image_start_tokens = torch.where(start_cond)[0]\n    image_start_tokens += 1\n    image_end_tokens = torch.where(end_cond)[0]\n    if len(image_start_tokens) != len(image_end_tokens):\n        logger.error(\"image start token != image end tokens\")\n        raise Exception(\"image start token != image end tokens\")\n    if len(image_start_tokens) > 0:\n        image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])\n    else:\n        image_bound = []\n    return image_bound\n\n\ndef preprocess(\n    images_dict,\n    conversations,\n    tokenizer,\n    transform,\n    query_nums=64,\n    slice_config=None,\n    llm_type=None,\n    patch_size=14,\n    batch_vision=False,\n    max_length=2048,\n    truncation=\"error\",\n    logger=None,\n):\n    \"\"\"\n    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation\n    \"\"\"\n    conversations = copy.deepcopy(conversations)\n    assert conversations[0][\"role\"] == \"user\", \"the first role must be user\"\n\n    if slice_config is not None:\n        assert isinstance(slice_config, dict)\n        assert \"patch_size\" in slice_config\n        assert \"max_slice_nums\" in slice_config\n        assert \"scale_resolution\" in slice_config\n    default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end\n    new_schema = False\n    use_image_id = False\n    if llm_type == \"qwen\":\n        new_schema 
= True\n        use_image_id = True\n    image_placeholder_dict = {}\n    images = []\n    image_id_cnt = 0\n    for img_name, image in images_dict.items():\n        if slice_config:\n            source_image, patches, best_grid = slice_image(\n                image,\n                slice_config[\"max_slice_nums\"],\n                slice_config[\"scale_resolution\"],\n                slice_config[\"patch_size\"],\n            )\n            images.append(source_image)\n            image_placeholder = default_image_placeholder\n            if len(patches) > 0:\n                for i in range(len(patches)):\n                    for j in range(len(patches[0])):\n                        images.append(patches[i][j])\n                if use_image_id:\n                    image_placeholder = (\n                        f\"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}\" + image_placeholder\n                    )\n                    image_id_cnt += 1\n                image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)\n            image_placeholder_dict[img_name] = image_placeholder\n        else:\n            images.append(image)\n            if use_image_id:\n                image_placeholder = f\"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}\" + image_placeholder\n                image_id_cnt += 1\n            else:\n                image_placeholder = default_image_placeholder\n            image_placeholder_dict[img_name] = image_placeholder\n\n    images = [transform(i) for i in images]\n\n    if len(images_dict) == 1 and \"<image>\" in images_dict:\n        if \"<image>\" in conversations[0][\"content\"]:\n            conversations[0][\"content\"] = conversations[0][\"content\"].replace(\"<image>\", image_placeholder)\n        else:\n            conversations[0][\"content\"] = image_placeholder + \"\\n\" + conversations[0][\"content\"]\n    else:\n        pattern = r\"<image_\\d+>\"\n        new_conversations = []\n        for conversation in conversations:\n            content = conversation[\"content\"]\n            parts = re.split(f\"({pattern})\", content)\n            for i, part in enumerate(parts):\n                if not part.strip():\n                    continue\n                if re.match(pattern, part):\n                    if part in image_placeholder_dict:\n                        parts[i] = image_placeholder_dict[part]\n                    else:\n                        raise Exception(f\"not found {part} in image dict\")\n            conversation[\"content\"] = \"\\n\".join(parts)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # TODO change role in conversation for different llm\n    prompt_with_chat_template = tokenizer.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)\n\n    input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(\n        prompt=prompt_with_chat_template,\n        tokenizer=tokenizer,\n        max_length=max_length,\n        pad_token_id=tokenizer.pad_token_id,\n        left_pad=True,\n        truncation=truncation,\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n    image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger)\n\n    input_dict = {\n        \"input_ids\": input_ids[0],\n        \"attention_mask\": attention_mask[0],\n        \"position_ids\": position_ids[0],\n        \"image_bound\": image_bound,\n    }\n\n    if batch_vision:\n    
    tgt_sizes = []\n        reshape_images = []\n        for image in images:\n            H, W = image.shape[1:]\n            reshape_image = reshape_by_patch(image, patch_size)\n            reshape_images.append(reshape_image)\n            tgt_sizes.append([H // patch_size, W // patch_size])\n        if tgt_sizes:\n            tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)\n\n        input_dict[\"pixel_values\"] = reshape_images\n        input_dict[\"tgt_sizes\"] = tgt_sizes\n\n    else:\n        input_dict[\"pixel_values\"] = images\n        input_dict[\"tgt_sizes\"] = []\n\n    return input_dict\n\n\ndef slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):\n    original_size = image.size\n    original_width, original_height = original_size\n    log_ratio = math.log(original_width / original_height)\n    ratio = original_width * original_height / (scale_resolution * scale_resolution)\n    multiple = min(math.ceil(ratio), max_slice_nums)\n\n    source_image = None\n    best_grid = None\n    patches = []\n\n    if multiple <= 1 or never_split:\n        # dont need to slice, upsample\n        best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)\n        source_image = image.resize(best_size, Image.Resampling.BICUBIC)\n    else:\n        candidate_split_grids_nums = []\n        for i in [multiple - 1, multiple, multiple + 1]:\n            if i == 1 or i > max_slice_nums:\n                continue\n            candidate_split_grids_nums.append(i)\n\n        # source image, down-sampling and ensure divided by patch_size\n        best_resize = find_best_resize(original_size, scale_resolution, patch_size)\n        source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)\n        candidate_grids = []\n\n        # find best grid\n        for split_grids_nums in candidate_split_grids_nums:\n            m = 1\n            while m <= split_grids_nums:\n                if split_grids_nums % m == 0:\n                    candidate_grids.append([m, split_grids_nums // m])\n                m += 1\n\n        best_grid = [1, 1]\n        min_error = float(\"inf\")\n        for grid in candidate_grids:\n            error = abs(log_ratio - math.log(grid[0] / grid[1]))\n            if error < min_error:\n                best_grid = grid\n                min_error = error\n\n        refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)\n\n        refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)\n        patches = split_to_patches(refine_image, best_grid)\n\n    return source_image, patches, best_grid\n\n\ndef ensure_divide(length, patch_size):\n    return max(round(length / patch_size) * patch_size, patch_size)\n\n\ndef find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):\n    width, height = original_size\n    if (width * height > scale_resolution * scale_resolution) or allow_upscale:\n        r = width / height\n        height = int(scale_resolution / math.sqrt(r))\n        width = int(height * r)\n    best_width = ensure_divide(width, patch_size)\n    best_height = ensure_divide(height, patch_size)\n    return (best_width, best_height)\n\n\ndef get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):\n    width, height = original_size\n    grid_x, grid_y = grid\n\n    refine_width = ensure_divide(width, grid_x)\n    refine_height = ensure_divide(height, grid_y)\n\n    
grid_width = refine_width / grid_x\n    grid_height = refine_height / grid_y\n\n    best_grid_size = find_best_resize(\n        (grid_width, grid_height),\n        scale_resolution,\n        patch_size,\n        allow_upscale=allow_upscale,\n    )\n\n    refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)\n\n    return refine_size\n\n\ndef split_to_patches(image, grid):\n    patches = []\n    width, height = image.size\n    grid_x = int(width / grid[0])\n    grid_y = int(height / grid[1])\n\n    for i in range(0, height, grid_y):\n        images = []\n        for j in range(0, width, grid_x):\n            box = (j, i, j + grid_x, i + grid_y)\n            patch = image.crop(box)\n            images.append(patch)\n        patches.append(images)\n\n    return patches\n\n\ndef get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):\n    if new_schema:\n        image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end\n    else:\n        image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end\n\n    cols = grid[0]\n    rows = grid[1]\n    slices = []\n    for i in range(rows):\n        lines = []\n        for j in range(cols):\n            lines.append(image_placeholder)\n        slices.append(\"\".join(lines))\n    if new_schema:\n        slice_placeholder = \"\\n\".join(slices)\n    else:\n        slice_placeholder = tokenizer.slice_start + \"\\n\".join(slices) + tokenizer.slice_end\n    return slice_placeholder\n\n\ndef reshape_by_patch(image_tensor, patch_size):\n    \"\"\"\n    :param image_tensor: shape [3, H, W]\n    :param patch_size:\n    :return: [3, patch_size, HW/patch_size]\n    \"\"\"\n    patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size))\n\n    patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)\n    patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)\n    return patches\n\n\ndef init_minicpmo_config(processor, config):\n    \"\"\"Initialize MiniCPM-o specific configuration\"\"\"\n    minicpmo_config = {\n        \"transform\": build_transform(),\n        \"patch_size\": config.get(\"patch_size\", 14),\n        \"query_nums\": config.get(\"query_nums\", 64),\n        \"slice_config\": config.get(\n            \"slice_config\", {\"max_slice_nums\": 9, \"patch_size\": config.get(\"patch_size\", 14), \"scale_resolution\": 448}\n        ),\n        \"llm_type\": config.get(\"llm_type\", \"qwen\"),\n        \"batch_vision\": config.get(\"batch_vision\", True),\n    }\n    return minicpmo_config\n\n\ndef process_minicpmo_data(\n    row_dict, messages, tokenizer, minicpmo_config, image_key, max_prompt_length, truncation, logger\n):\n    \"\"\"Process data for MiniCPM-o model\"\"\"\n    if len(row_dict[image_key]) == 1:\n        multi_modal_data = {}\n        image = process_image(row_dict.pop(image_key)[0])\n        multi_modal_data[\"image\"] = [image]\n        images_dict = {\"<image>\": image}\n    else:\n        raise NotImplementedError\n\n    model_inputs = preprocess(\n        images_dict,\n        messages,\n        tokenizer,\n        minicpmo_config[\"transform\"],\n        query_nums=minicpmo_config[\"query_nums\"],\n        slice_config=minicpmo_config[\"slice_config\"],\n        llm_type=minicpmo_config[\"llm_type\"],\n        patch_size=minicpmo_config[\"patch_size\"],\n        batch_vision=minicpmo_config[\"batch_vision\"],\n        
max_length=max_prompt_length,\n        truncation=truncation,\n        logger=logger,\n    )\n\n    raw_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n    raw_prompt = raw_prompt.replace(\"<image>\", \"(<image>./</image>)\")\n\n    return model_inputs, multi_modal_data, raw_prompt\n\n\nclass RLHFDataset(Dataset):\n    \"\"\"\n    Load and preprocess RLHF data from Parquet files.\n\n    - Caches files locally.\n    - Reads into a HuggingFace Dataset and tokenizes prompts.\n    - Optionally handles images/videos via a ProcessorMixin.\n    - Filters prompts over a max length.\n    - Supports resuming from checkpoints.\n\n    Args:\n        data_files (str or list): Path(s) to Parquet file(s).\n        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.\n        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.\n        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n    ):\n        if not isinstance(data_files, list | ListConfig):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(data_files)\n        self.original_data_files = copy.deepcopy(data_files)  # use for resume\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n\n        self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n\n        self.num_workers = config.get(\"filter_overlong_prompts_workers\", max(1, os.cpu_count() // 4))\n        self.num_workers = min(self.num_workers, os.cpu_count())\n        self.use_shm = config.get(\"use_shm\", False)\n        self.chat_template_func = config.get(\"chat_template_func\", None)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.serialize_dataset = False\n        self.minicpmo_config = init_minicpmo_config(self.processor, config)\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self, use_origin_parquet=False):\n        from verl.utils.fs import copy_to_local\n\n        data_files = self.data_files if not use_origin_parquet else self.original_data_files\n        for i, parquet_file in enumerate(data_files):\n            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n            dataframes.append(dataframe)\n        
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n    def resume_dataset_state(self):\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        # resume dataframe if it's not serialized in data.pt\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)  # download and resume from original parquet files\n            self._read_files_and_tokenize()\n        else:\n            print(r\"old dataloader ckpt file is used, please train from scratch for better ckpt performance\")\n\n    def __len__(self):\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict):\n        return example.pop(self.prompt_key)\n\n    def __getitem__(self, item):\n        \"\"\"\n        Note that we also return the raw_input_ids so that it can be combined with other chat template\n        \"\"\"\n        row_dict: dict = self.dataframe[item]\n        messages = self._build_messages(row_dict)\n        model_inputs = {}\n\n        if self.processor is not None:\n            model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data(\n                row_dict,\n                messages,\n                self.tokenizer,\n                self.minicpmo_config,\n                self.image_key,\n                self.max_prompt_length,\n                self.truncation,\n                logger,\n            )\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n            position_ids = model_inputs.pop(\"position_ids\")\n\n            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature\n            row_dict[\"multi_modal_data\"] = multi_modal_data\n            row_dict[\"multi_modal_inputs\"] = dict(model_inputs)\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row_dict[\"input_ids\"] = input_ids\n        row_dict[\"attention_mask\"] = attention_mask\n        row_dict[\"position_ids\"] = position_ids\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        row_dict[\"raw_prompt_ids\"] = raw_prompt_ids\n        # encode prompts without chat template\n        if self.return_raw_chat:\n            row_dict[\"raw_prompt\"] = messages\n\n        # get prompts with chat template\n        if self.return_full_prompt:\n            row_dict[\"full_prompts\"] = raw_prompt  # array of strings\n\n        # add index for each prompt\n        index = row_dict.get(\"extra_info\", {}).get(\"index\", 0)\n        tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"tools_kwargs\", {})\n        interaction_kwargs = row_dict.get(\"extra_info\", {}).get(\"interaction_kwargs\", {})\n        need_tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"need_tools_kwargs\", self.need_tools_kwargs)\n        if need_tools_kwargs and not tools_kwargs:\n            logger.warning(\"tools_kwargs is empty for index %s, data source: %s\", index, row_dict[\"data_source\"])\n        row_dict[\"index\"] = index\n        row_dict[\"tools_kwargs\"] = tools_kwargs\n        row_dict[\"interaction_kwargs\"] = interaction_kwargs\n        return row_dict\n\n    def __getstate__(self):\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n\n            if \"dataframe\" in state:\n                del state[\"dataframe\"]\n            return state\n\n        return self.__dict__.copy()\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/README.md",
    "content": "# Recipe: One Step Off Policy Async Trainer\n\n**Author:**  `https://github.com/meituan-search`\n\nLast updated: 07/17/2025.\n\n## Introduction\n\n### Background\n\nThe current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic\nworkflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest\nmodel, and the model is updated after training completes. While this approach aligns with off-policy reinforcement\nlearning and stabilizes RL training, but it suffers from severe efficiency issues.\nModel updates must wait for the longest output in the generation phase to complete.\nDuring the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization.\nThe more severe the long-tail problem in sample generation, the lower the overall training efficiency.\nFor example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time,\nand increasing resources does not reduce the Rollout duration.\n\n![DAPO 32B Math Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png)\n> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361\n\n### Solution\n\nWe have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the\ngeneration and training processes, utilizing samples generated in the previous step for current training.\nIt also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically\nassigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time\nduring long-tail sample generation. Throughout this process, generation and training parameters maintain a one-step off\npolicy.\n\n![One Step Off Policy Diagram](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)\n> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](\n> https://arxiv.org/abs/2505.24298)\n\nOur core contributions include:\n\n1. **Parallel Generation and Training**:  \n   Samples for the next batch are asynchronously generated while the current batch is being trained.\n\n2. **Resource Isolation**:  \n   Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources\n   automatically assigned to training.\n\n3. 
\n3. **NCCL Parameter Synchronization**:  \n   Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.\n\n### Experimental Results\n\n- **Machine Configuration**: 2 nodes, 16 H20 GPUs in total\n    - Generation: 4 GPUs\n    - Training: 12 GPUs\n- **Model**: Qwen2.5-Math-7B\n- **Rollout Configuration**:\n    - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens\n- **Algorithm**: DAPO\n- **Rollout Engine**: vLLM\n\n| training mode          | engine        | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time    | acc/best@32/mean | acc/maj@32/mean |\n|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|\n| colocate sync          | VLLM+FSDP2    | 749  | 321 | -             | 247                | 88           | 286          | 19h18m        | 0.5948           | 0.417           |\n| one-step-overlap async | VLLM+FSDP2    | 520  | -   | 45            | 458                | 108          | 337          | 15h34m (+23%) | 0.6165           | 0.494           |\n| colocate sync          | VLLM+Megatron | 699  | 207 | -             | 162                | 119          | 344          | 18h21m        | 0.605            | 0.4217          |\n| one-step-overlap async | VLLM+Megatron | 566  | -   | 59            | 501                | 120          | 347          | 13h06m (+40%) | 0.6569           | 0.4038          |\n\n* colocate sync: step ≈ gen + old_log_prob + update_actor\n* one-step-overlap async: step ≈ wait_prev_gen + old_log_prob + update_actor\n  (e.g., for VLLM+FSDP2: 45 + 108 + 337 = 490 ≈ 520)\n\n![One Step Off Megatron Performance](\nhttps://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)\n\n> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg\n\n## Implementation\n\n### One Step Off Policy Async Pipeline\n\nOur implemented **One Step Off Policy Async Pipeline** integrates seamlessly into existing training logic at minimal cost,\neliminating the need for additional sample storage management. The core mechanism uses `_async_gen_next_batch`\nfor asynchronous rollout generation while maintaining continuous operation during epoch transitions\nvia `_create_continuous_iterator`.\n
\n```python\n# iterator generator that simplifies one-step-off integration of the training process\ndef _create_continuous_iterator(self):\n    for epoch in range(self.config.trainer.total_epochs):\n        iterator = iter(self.train_dataloader)\n        for batch_dict in iterator:\n            yield epoch, batch_dict\n\n\n# read the next batch of samples, sync parameters, and launch async sequence generation\ndef _async_gen_next_batch(self, continuous_iterator):\n    # read train_data\n    try:\n        epoch, batch_dict = next(continuous_iterator)\n    except StopIteration:\n        return None\n    batch = DataProto.from_single_dict(batch_dict)\n    gen_batch = batch_process(batch)\n    # sync weights from actor to rollout\n    self.sync_rollout_weights()\n    # async generation\n    gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n    # wrap the in-flight generation in a future\n    return GenerationBatchFuture(epoch, batch, gen_batch_output)\n\n\ncontinuous_iterator = self._create_continuous_iterator()\n# run rollout first to achieve one-step-off\nbatch_data_future = self._async_gen_next_batch(continuous_iterator)\n\nwhile batch_data_future is not None:\n    # wait for the gen_seq result from the previous step\n    batch = batch_data_future.get()\n    # launch the next async call to generate sequences\n    batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n    # compute advantages\n    batch = critic.compute_values(batch)\n    batch = reference.compute_log_prob(batch)\n    batch = reward.compute_reward(batch)\n    batch = compute_advantages(batch)\n\n    # model update\n    critic_metrics = critic.update_critic(batch)\n    actor_metrics = actor.update_actor(batch)\n```\n\n### Parameter Synchronization\n\nA particularly nice property is that our NCCL-based weight update for the rollout model is very fast:\nmost of the time the latency is under 300 ms, which is negligible for RLHF.\n\n> **sync_rollout_weights**: The time to synchronize parameters from actor to rollout is extremely fast and can almost\n> be ignored because it is implemented with NCCL.\n\n```python\nclass ActorRolloutRefWorker:\n    # actor acquires the meta-info of model parameters for parameter sync\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        params = self._get_actor_params()\n        ret = []\n        for key, tensor in params.items():\n            ret.append((key, tensor.size(), tensor.dtype))\n        self._weights_info = ret\n        return ret\n\n    # rollout sets the meta-info of model parameters for parameter sync\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        self._weights_info = weights_info\n\n\nclass AsyncRayPPOTrainer(RayPPOTrainer):\n    def init_workers(self):\n        ...\n        # rollout obtains the meta-info of model parameters from the actor for parameter sync\n        weights_info = self.actor_wg.get_actor_weights_info()[0]\n        self.rollout_wg.set_actor_weights_info(weights_info)\n\n        # Create an actor-rollout communication group for parameter sync\n        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers\n        collective.create_collective_group(\n            actor_rollout_workers,\n            len(actor_rollout_workers),\n            list(range(0, len(actor_rollout_workers))),\n            backend=\"nccl\",\n            group_name=\"actor_rollout\",\n        )\n```\n
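\nAs a rough check of this claim on your own cluster (a minimal sketch, not part of the recipe; it assumes it runs on the driver inside the trainer loop, where the `sync_rollout_weights` method shown below is available):\n\n```python\nimport time\n\n# minimal sketch: time one actor -> rollout weight sync on the driver\nstart = time.perf_counter()\nself.sync_rollout_weights()\nprint(f\"rollout weight sync took {(time.perf_counter() - start) * 1e3:.0f} ms\")\n```\n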
group_name=\"actor_rollout\"\n)\n```\n\n```python\n# drive process call the actor and rollout respectively to sync parameters by nccl \ndef sync_rollout_weights(self):\n    self.actor_wg.sync_rollout_weights()\n    ray.get(self.rollout_wg.sync_rollout_weights())\n\n\n# fsdp model parameter sync\n@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\ndef sync_rollout_weights(self):\n    params = self._get_actor_params() if self._is_actor else None\n    if self._is_rollout:\n        inference_model = (\n            self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n        )\n        patch_vllm_moe_model_weight_loader(inference_model)\n    # Model parameters are broadcast tensor-by-tensor from actor to rollout\n    for key, shape, dtype in self._weights_info:\n        tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n        if self._is_actor:\n            assert key in params\n            origin_data = params[key]\n            if hasattr(origin_data, \"full_tensor\"):\n                origin_data = origin_data.full_tensor()\n            if torch.distributed.get_rank() == 0:\n                tensor.copy_(origin_data)\n        from ray.util.collective import collective\n\n        collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n        if self._is_rollout:\n            inference_model.load_weights([(key, tensor)])\n```\n\n## Usage\n\n### FSDP2 Configuration Example\n\n```shell\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    # actor and rollout are placed separately\n    actor_rollout_ref.hybrid_engine=False \\\n    # actor and rollout resource\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Megatron Configuration Example\n\n```shell\npython3 -m recipe.one_step_off_policy.async_main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    # actor and rollout are placed separately\n    actor_rollout_ref.hybrid_engine=False \\\n    # actor and rollout resource\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=6 \\\n    rollout.nnodes=1 \\\n    rollout.n_gpus_per_node=2\n```\n\n### Configuration Guidelines\n\n1. **Card Number Relationships**  \n   Maintain either of these relationships for optimal batch distribution:\n    - `actor_rollout_ref.rollout.n` should be an integer divisor of:  \n      `trainer.n_gpus_per_node * trainer.nnodes`\n    - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by:  \n      `trainer.n_gpus_per_node * trainer.nnodes`\n\n   > Rationale: Ensures training samples can be evenly distributed across training GPUs when using partial resources for\n   generation.\n\n2. 
\n2. **Dynamic Resource Tuning**  \n   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on phase\n   durations:\n    - **Ideal state**: Rollout and training phases have comparable durations\n    - **Diagnostic metrics**:\n        - Monitor `wait_prev_gen` duration\n        - Analyze `sequence_length` distribution\n    - **Adjustment strategy**:\n        - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources\n        - High `wait_prev_gen` + long-tail sequences → Optimize stopping criteria (resource increase won't help)\n\n   > **wait_prev_gen**: The time consumed waiting for the previous rollout to end (the part that is not fully\n   > overlapped).\n\n   **Resource Configuration Strategies:**\n    - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios,\n      keeping the number of nodes equal to allow training and rollout to share nodes.\n        - Configure `trainer.nnodes = rollout.nnodes` with\n          `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`. Control rollout resource\n          allocation by adjusting `n_gpus_per_node`.\n    - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes,\n      keeping the number of GPUs per node equal to enable independent scaling of training and rollout\n      parallelism.\n        - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control rollout resource allocation by\n          adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.\n\n   > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The\n   > actual calculation depends on GPU capacity:\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,\n   >   the required node count is `max(trainer.nnodes, rollout.nnodes)`\n   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,\n   >   the required node count is `trainer.nnodes + rollout.nnodes`\n\n## Functional Support\n\n| Category           | Supported Options                                                                                                  |\n|--------------------|--------------------------------------------------------------------------------------------------------------------|\n| train engine       | FSDP2 <br/> Megatron                                                                                               |\n| rollout engine     | vLLM                                                                                                               |\n| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE <br/> GPG |\n| Reward             | all                                                                                                                |\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_megatron_trainer\n  - _self_\n\n# config for the rollout (only for resource isolation)\nrollout:\n  # Number of nodes used in the rollout\n  nnodes: 1\n  # Number of GPUs per node\n  n_gpus_per_node: 8"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml",
    "content": "hydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\n# config for the rollout (only for resource isolation)\nrollout:\n  # Number of nodes used in the rollout\n  nnodes: 1\n  # Number of GPUs per node\n  n_gpus_per_node: 8"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-one-step-off-4-12'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\"\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-colocate'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=2\nsp_size=4\nfsdp_size=2\n\n# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361\n\npython3 -m verl.trainer.main_ppo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    
actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.grad_clip=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \\\n    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0527a1-megatron-one-step-off-4-12'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=12\ntrain_prompt_mini_bsz=32\n\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\nref_offload=True\nactor_offload=False\ngen_tp=2\ntrain_tp=2\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    --config-path=config \\\n    --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    critic.strategy=megatron \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\"\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nproject_name='DAPO'\nexp_name='DAPO-Qwen2.5-7b-MATH-0519a1-megatron-colocate'\n\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=$((1024 * 2))\nmax_response_length=$((1024 * 8))\nenable_overlong_buffer=True\noverlong_buffer_len=$((1024 * 4))\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\ntrain_prompt_bsz=512\nn_resp_per_prompt=16\ntrain_prompt_mini_bsz=32\n\n# Ray\n# RAY_ADDRESS=${RAY_ADDRESS:-\"http://localhost:8265\"}\n# WORKING_DIR=${WORKING_DIR:-\"${PWD}\"}\n# RUNTIME_ENV=${RUNTIME_ENV:-\"${WORKING_DIR}/verl/trainer/runtime_env.yaml\"}\nNNODES=${NNODES:-2}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\n# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/dapo-math-17k.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/aime-2024.parquet\"}\n\n# Algorithm\ntemperature=1.0\ntop_p=1.0\ntop_k=-1 # 0 for HF rollout, -1 for vLLM rollout\nval_top_p=0.7\n\n# Performance Related Parameter\nuse_dynamic_bsz=True\nactor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))\ninfer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))\noffload=True\ngen_tp=2\ntrain_tp=2\ntrain_pp=2\n\n# TODO: support dynamic_bsz for megatron\n# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \\\n# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \\\n# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \\\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml' \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.prompt_key=prompt \\\n    data.truncation='left' \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.strategy=megatron \\\n    critic.strategy=megatron \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    actor_rollout_ref.actor.clip_ratio_c=10.0 \\\n    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.megatron.param_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${offload} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.optim.clip_grad=1.0 \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \\\n    actor_rollout_ref.rollout.temperature=${temperature} \\\n    actor_rollout_ref.rollout.top_p=${top_p} \\\n    actor_rollout_ref.rollout.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \\\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \\\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \\\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True \\\n    actor_rollout_ref.rollout.val_kwargs.n=1 \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n    actor_rollout_ref.ref.megatron.param_offload=${offload} \\\n    reward_model.reward_manager=dapo \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \\\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \\\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.val_before_train=True \\\n    trainer.test_freq=10 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.total_training_steps=100 \\\n    trainer.default_local_dir=\"${CKPTS_DIR}\" \\\n    trainer.resume_mode=auto \\\n    trainer.log_val_generations=10\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/fsdp_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import AutoConfig\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass\nfrom verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage\nfrom verl.utils.device import (\n    get_device_name,\n    get_nccl_backend,\n    get_torch_device,\n)\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    fsdp_version,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import get_generation_config, update_model_config\nfrom verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker\nfrom verl.workers.fsdp_workers import CriticWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n__all__ = [\"ActorRolloutRefWorker\", \"AsyncActorRolloutRefWorker\", \"CriticWorker\", \"RolloutWorker\"]\n\n\nclass ActorRolloutRefWorker(ARRWorker):\n    def _get_actor_params(self):\n        assert self._is_actor\n        params = self.actor_module_fsdp.state_dict()\n        from verl.utils.model import convert_weight_keys\n\n        params = convert_weight_keys(\n            params, getattr(self.actor_module_fsdp, \"_fsdp_wrapped_module\", self.actor_module_fsdp)\n        )\n        return params\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params = self._get_actor_params() if self._is_actor else None\n        if self._is_rollout:\n            inference_model = (\n                self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n            )\n            patch_vllm_moe_model_weight_loader(inference_model)\n        for key, shape, dtype in self._weights_info:\n            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())\n            if self._is_actor:\n                assert key in params\n                origin_data = params[key]\n                if hasattr(origin_data, \"full_tensor\"):\n                    origin_data = origin_data.full_tensor()\n                if torch.distributed.get_rank() == 0:\n                    tensor.copy_(origin_data)\n            from 
ray.util.collective import collective\n\n            collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n            if self._is_rollout:\n                inference_model.load_weights([(key, tensor)])\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n        if fsdp_version(self.actor_module_fsdp) == 1:\n            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType\n\n            FSDP.set_state_dict_type(\n                self.actor_module_fsdp,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n        params = self._get_actor_params()\n        ret = []\n        for key, tensor in params.items():\n            ret.append((key, tensor.size(), tensor.dtype))\n        self._weights_info = ret\n        return ret\n\n\nclass RolloutWorker(ActorRolloutRefWorker):\n    def __init__(self, config: DictConfig, role: str):\n        Worker.__init__(self)\n        assert role == \"rollout\"\n        self.config = config\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ.get(\"RANK\", 0))\n            world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n            torch.distributed.init_process_group(\n                backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\",\n                rank=rank,\n                world_size=world_size,\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n        # TODO(haibin.lin):\n        # As of now the type of config is DictConfig; if we assign config.profiler with ProfilerConfig,\n        # it will actually convert the ProfilerConfig dataclass back to a DictConfig.\n        # We can still use ProfilerConfig for testing purposes (tests/utils/test_nvtx_profile.py)\n        # as it provides a DictConfig-like interface.\n        # The benefit of creating the dataclass config is to perform validation during __post_init__\n        profiler_config = omega_conf_to_dataclass(config.rollout.get(\"profiler\", {}))\n        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config))\n        self._is_rollout = True\n        self._is_actor = False\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n\n        use_shm = self.config.model.get(\"use_shm\", False)\n        local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n        trust_remote_code = self.config.model.get(\"trust_remote_code\", False)\n\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        # override model kwargs\n        
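# NOTE: attn_implementation=\"flash_attention_2\" below assumes flash-attn is available in the runtime image.\n        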
actor_model_config = AutoConfig.from_pretrained(\n            local_path, trust_remote_code=trust_remote_code, attn_implementation=\"flash_attention_2\"\n        )\n\n        # patch for kimi-vl\n        if getattr(actor_model_config, \"model_type\", None) == \"kimi_vl\":\n            actor_model_config.text_config.topk_method = \"greedy\"\n\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config)\n        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)\n        if self.rank == 0:\n            print(f\"Model config after override: {actor_model_config}\")\n\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n        )\n        rollout_device_mesh = init_device_mesh(\n            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n        )\n        rollout_name = self.config.rollout.name\n        assert rollout_name == \"vllm\"\n\n        from verl.workers.rollout.vllm_rollout import vLLMRollout\n\n        log_gpu_memory_usage(f\"Before building {rollout_name} rollout\", logger=logger)\n\n        from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout\n\n        vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == \"sync\" else vLLMAsyncRollout\n        rollout = vllm_rollout_cls(\n            model_path=local_path,\n            config=self.config.rollout,\n            tokenizer=self.tokenizer,\n            model_hf_config=actor_model_config,\n            device_mesh=rollout_device_mesh,\n            trust_remote_code=trust_remote_code,\n        )\n        log_gpu_memory_usage(f\"After building {rollout_name} rollout\", logger=logger)\n        from .vllm_sharding_manager import VLLMShardingManager\n\n        rollout_sharding_manager = VLLMShardingManager(\n            inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh\n        )\n\n        log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        self.rollout = rollout\n        self.rollout_sharding_manager = rollout_sharding_manager\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\n    def async_generate_sequences(self, *args, **kwargs):\n        return super().generate_sequences(*args, **kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh",
    "content": "set -x\n\nproject_name='GRPO'\nexp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen3-0.6B\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1152 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=192 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=True \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" $@"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh",
    "content": "set -x\n\nproject_name='GRPO'\nexp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'\n\n# Paths\nRAY_DATA_HOME=${RAY_DATA_HOME:-\"${HOME}/verl\"}\nMODEL_PATH=${MODEL_PATH:-\"${RAY_DATA_HOME}/models/Qwen/Qwen2.5-3B-Instruct\"}\nCKPTS_DIR=${CKPTS_DIR:-\"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}\"}\nTRAIN_FILE=${TRAIN_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/train.parquet\"}\nTEST_FILE=${TEST_FILE:-\"${RAY_DATA_HOME}/data/gsm8k/test.parquet\"}\n\nNNODES=${NNODES:-1}\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-8}\n\nn_gpus_rollout=2\nn_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))\n\npython3 -m recipe.one_step_off_policy.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=\"${TRAIN_FILE}\" \\\n    data.val_files=\"${TEST_FILE}\" \\\n    data.train_batch_size=1152 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    critic.strategy=fsdp2 \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=192 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.val_before_train=True \\\n    trainer.logger=['console','tensorboard'] \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=2 \\\n    trainer.nnodes=\"${NNODES}\" \\\n    trainer.n_gpus_per_node=\"${n_gpus_training}\" \\\n    rollout.nnodes=\"${NNODES}\" \\\n    rollout.n_gpus_per_node=\"${n_gpus_rollout}\" $@"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/main_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\nfrom verl.trainer.ppo.reward import load_reward_manager\n\nfrom .ray_trainer import OneStepOffRayTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"one_step_off_ppo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\n# Define a function to run the PPO-like training process\ndef run_ppo(config) -> None:\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        # Set environment variables in the runtime environment to control tokenizer parallelism,\n        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating\n        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration\n        ray.init(\n            runtime_env=get_ppo_ray_runtime_env(),\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    # Create a remote instance of the TaskRunner class, and\n    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete\n    if (\n        OmegaConf.select(config.trainer, \"profile_steps\") is not None\n        and len(OmegaConf.select(config.trainer, \"profile_steps\")) > 0\n    ):\n        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)\n        runner = TaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n    # [Optional] get the path of the timeline trace file from the configuration, default to None\n    # This file is used for performance analysis\n    timeline_json_file = config.ray_init.get(\"timeline_json_file\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # Print the initial configuration. 
`resolve=True` will evaluate symbolic values.\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n\n        pprint(OmegaConf.to_container(config, resolve=True))\n\n        OmegaConf.resolve(config)\n\n        # Download the checkpoint from HDFS to the local machine.\n        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor.\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Define worker classes based on the actor strategy.\n        if config.actor_rollout_ref.actor.strategy == \"fsdp2\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray import RayWorkerGroup\n\n            from .fsdp_workers import (\n                ActorRolloutRefWorker,\n                AsyncActorRolloutRefWorker,\n                CriticWorker,\n                RolloutWorker,\n            )\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n\n            from .megatron_workers import (\n                ActorRolloutRefWorker,\n                AsyncActorRolloutRefWorker,\n                CriticWorker,\n                RolloutWorker,\n            )\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from .ray_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            Role.Actor: ray.remote(actor_rollout_cls),\n            Role.Rollout: ray.remote(RolloutWorker),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"actor_pool\"\n\n        assert config.trainer.n_gpus_per_node > 0, \"config.trainer.n_gpus_per_node must be greater than 0\"\n        assert config.trainer.nnodes > 0, \"config.trainer.nnodes must be greater than 0\"\n        assert config.rollout.n_gpus_per_node > 0, \"config.rollout.n_gpus_per_node must be greater than 0\"\n        assert config.rollout.nnodes > 0, \"config.rollout.nnodes must be greater than 0\"\n\n        actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes\n        rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes\n\n       
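 # NOTE: actor (training) and rollout (generation) workers are placed in disjoint resource pools;\n        # this separation is what lets next-batch generation overlap with training (one step off-policy).\n       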
 resource_pool_spec = {\n            \"actor_pool\": actor_pool,\n            \"rollout_pool\": rollout_pool,\n        }\n        mapping = {\n            Role.Actor: \"actor_pool\",\n            Role.Rollout: \"rollout_pool\",\n            Role.Critic: \"actor_pool\",\n        }\n        print(f\"resource_pool_spec: {resource_pool_spec}\")\n        # We should adopt a multi-source reward function here:\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # finally, we combine all the rewards together\n        # The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in [\"fsdp2\"]:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # Add a reference policy worker if KL loss or KL reward is used.\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # Load the reward manager for training and validation.\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets.\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the PPO trainer.\n        trainer = OneStepOffRayTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n            device_name=config.trainer.device,\n        )\n        # Initialize the workers of the trainer.\n        trainer.init_workers()\n        # Start the training process.\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/megatron_workers.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils.debug import (\n    log_gpu_memory_usage,\n)\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader\nfrom verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker\nfrom verl.workers.megatron_workers import CriticWorker, RewardModelWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n__all__ = [\"ActorRolloutRefWorker\", \"AsyncActorRolloutRefWorker\", \"CriticWorker\", \"RewardModelWorker\", \"RolloutWorker\"]\n\n\nclass ActorRolloutRefWorker(ARRWorker):\n    def __init__(self, config: DictConfig, role: str):\n        assert role in [\"actor\", \"ref\"]\n        tmp_role = \"ref\" if role == \"ref\" else \"actor_rollout\"\n        super().__init__(config, tmp_role)\n        if role == \"actor\":\n            self._is_rollout = False\n        self.role = role\n\n    def _get_actor_params_generator(self):\n        assert self._is_actor\n        from verl.models.mcore import get_mcore_weight_converter\n        from verl.utils.megatron_utils import per_tensor_generator\n\n        layer_name_mapping = {\n            \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n            \"gate_proj_layer_name\": \"linear_fc1.\",\n        }\n        weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)\n        generator = per_tensor_generator(\n            self.actor.actor_module,\n            self.actor_model_config,\n            weight_converter,\n            self.tf_config,\n            layer_name_mapping,\n        )\n        return generator\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def sync_rollout_weights(self):\n        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine\n        assert hasattr(self, \"_weights_info\") and self._weights_info is not None\n\n        params_generator = self._get_actor_params_generator() if self._is_actor else None\n        if self._is_rollout:\n            inference_model = (\n                self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n            )\n            patch_vllm_moe_model_weight_loader(inference_model)\n        for key, shape, dtype in self._weights_info:\n            if self._is_actor:\n                weight_key, weight = next(params_generator)\n                assert key == weight_key\n                assert shape == weight.size()\n                assert dtype == weight.dtype\n\n            tensor = torch.empty(shape, dtype=dtype, 
device=get_torch_device().current_device())\n            if self._is_actor and torch.distributed.get_rank() == 0:\n                tensor.copy_(weight)\n            from ray.util.collective import collective\n\n            collective.broadcast(tensor, src_rank=0, group_name=\"actor_rollout\")\n            if self._is_rollout:\n                inference_model.load_weights([(key, tensor)])\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def get_actor_weights_info(self):\n        assert self._is_actor\n        if hasattr(self, \"_weights_info\"):\n            return self._weights_info\n\n        params_generator = self._get_actor_params_generator()\n        ret = []\n        for key, tensor in params_generator:\n            ret.append((key, tensor.size(), tensor.dtype))\n\n        self._weights_info = ret\n        return ret\n\n\nclass RolloutWorker(ActorRolloutRefWorker):\n    def __init__(self, config: DictConfig, role: str):\n        assert role == \"rollout\"\n        ARRWorker.__init__(self, config, role)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_transformer_config = {}\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        trust_remote_code = self.config.model.get(\"trust_remote_code\", False)\n\n        from verl.utils.model import get_generation_config\n\n        self._init_hf_config_and_tf_config(\n            self.config.model.path,\n            self.config.model.path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            trust_remote_code,\n        )\n        self.generation_config = get_generation_config(self.local_path)\n\n        from torch.distributed.device_mesh import init_device_mesh\n\n        assert self.config.rollout.name == \"vllm\"\n        assert self.config.rollout.mode == \"sync\"\n\n        from verl.workers.rollout.vllm_rollout import vLLMRollout\n\n        from .vllm_sharding_manager import VLLMShardingManager\n\n        # NOTE(sgm): If the QKV and gate_up projection layers are concatenated in the actor,\n        # we will reorganize their weight format when resharding from actor to rollout.\n\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n        )\n        rollout_device_mesh = init_device_mesh(\n            get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n        )\n        log_gpu_memory_usage(\"Before building vllm rollout\", logger=None)\n\n        local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get(\"use_shm\", False))\n        from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout\n\n        vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == \"sync\" else vLLMAsyncRollout\n        rollout = vllm_rollout_cls(\n      
     model_path=local_path,\n            config=self.config.rollout,\n            tokenizer=self.tokenizer,\n            model_hf_config=self.hf_config,\n            device_mesh=rollout_device_mesh,\n            trust_remote_code=trust_remote_code,\n        )\n        log_gpu_memory_usage(\"After building vllm rollout\", logger=logger)\n\n        sharding_manager = VLLMShardingManager(\n            inference_engine=rollout.inference_engine,\n            device_mesh=rollout_device_mesh,\n        )\n        log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        self.rollout, self.sharding_manager = rollout, sharding_manager\n        self.rollout.sharding_manager = sharding_manager\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\n    def async_generate_sequences(self, *args, **kwargs):\n        return super().generate_sequences(*args, **kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def set_actor_weights_info(self, weights_info):\n        assert self._is_rollout\n        self._weights_info = weights_info\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/recipe/one_step_off_policy/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom pprint import pprint\n\nimport numpy as np\nimport ray\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import Dataset, Sampler\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import (\n    RayPPOTrainer,\n    ResourcePoolManager,\n    Role,\n    WorkerType,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\nclass GenerationBatchFuture:\n    \"\"\"\n    Wrapper class for encapsulating batch generation results\n    \"\"\"\n\n    def __init__(self, epoch, batch, gen_batch_output):\n        \"\"\"\n        :param epoch: current epoch\n        :param batch: Input batch data\n        :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)\n        \"\"\"\n        self.epoch = epoch\n        self.batch = batch\n        self.gen_batch_output = gen_batch_output\n\n    def get(self):\n        \"\"\"\n        Get the actual results by calling get() method on gen_batch_output\n\n        Returns:\n            tuple: (batch, gen_batch_result)\n                - batch: Original input batch data\n                - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself\n        \"\"\"\n        # Call get() method on gen_batch_output if available\n        if hasattr(self.gen_batch_output, \"get\"):\n            gen_batch_result = self.gen_batch_output.get()\n        else:\n            gen_batch_result = self.gen_batch_output\n\n        return self.epoch, self.batch, gen_batch_result\n\n\nclass OneStepOffRayTrainer(RayPPOTrainer):\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n   
     train_dataset: Dataset | None = None,\n        val_dataset: Dataset | None = None,\n        collate_fn=None,\n        train_sampler: Sampler | None = None,\n        device_name=\"cuda\",\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). Defaults to \"cuda\".\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n\n        assert not self.hybrid_engine\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = Role.RefPolicy in role_worker_mapping\n        self.use_rm = Role.RewardModel in role_worker_mapping\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name\n        self.validation_generations_logger = ValidationGenerationsLogger()\n\n        # if ref_in_actor is True, the reference policy will be the actor without LoRA applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # KL loss control is currently not supported\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            self.use_critic = True\n        elif self.config.algorithm.adv_estimator in [\n            AdvantageEstimator.GRPO,\n            AdvantageEstimator.GRPO_PASSK,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS,\n            # AdvantageEstimator.REMAX, # TODO: REMAX advantage estimator is not yet supported in one_step_off_policy\n            AdvantageEstimator.RLOO,\n            AdvantageEstimator.OPO,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,\n            AdvantageEstimator.GPG,\n        ]:\n            self.use_critic = False\n        else:\n            raise NotImplementedError\n\n        self._validate_config()\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, 
train_sampler)\n\n    def _validate(self):\n        self.actor_rollout_wg = self.rollout_wg\n        ret = super()._validate()\n        self.actor_rollout_wg = self.actor_wg\n        return ret\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        for role, role_name in [(Role.Actor, \"actor\"), (Role.Rollout, \"rollout\")]:\n            resource_pool = self.resource_pool_manager.get_resource_pool(role)\n            role_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[role],\n                config=self.config.actor_rollout_ref,\n                role=role_name,\n            )\n            self.resource_pool_to_cls[resource_pool][role_name] = role_cls\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=\"ref\",\n                profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.trainer, \"profile_steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.trainer, \"profile_steps\")\n            assert OmegaConf.select(self.config.trainer, \"worker_nsight_options\") is not None, (\n                \"worker_nsight_options must be set when profile_steps is set\"\n            )\n            wg_kwargs[\"worker_nsight_options\"] = 
OmegaConf.to_container(\n                OmegaConf.select(self.config.trainer, \"worker_nsight_options\")\n            )\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                device_name=self.device_name,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        self.actor_wg = all_wg[\"actor\"]\n        self.rollout_wg = all_wg[\"rollout\"]\n        self.actor_wg.init_model()\n        self.rollout_wg.init_model()\n        self.actor_rollout_wg = self.actor_wg  # kept for compatibility with functions that have not been modified\n        weights_info = self.actor_wg.get_actor_weights_info()[0]\n        self.rollout_wg.set_actor_weights_info(weights_info)\n        from ray.util.collective import collective\n\n        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers\n        collective.create_collective_group(\n            actor_rollout_workers,\n            len(actor_rollout_workers),\n            list(range(0, len(actor_rollout_workers))),\n            backend=\"nccl\",\n            group_name=\"actor_rollout\",\n        )\n        self.sync_rollout_weights()\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\" and self._is_rollout:\n            from verl.workers.rollout.async_server import AsyncLLMServerManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AsyncLLMServerManager(\n                config=self.config,\n                worker_group=self.rollout_wg,\n            )\n\n    def sync_rollout_weights(self):\n        if not self.hybrid_engine:\n            self.actor_wg.sync_rollout_weights()\n            ray.get(self.rollout_wg.sync_rollout_weights())\n\n    def _create_continuous_iterator(self):\n        \"\"\"\n        Create a continuous data iterator across epochs.\n        \"\"\"\n        for epoch in range(self.config.trainer.total_epochs):\n            iterator = iter(self.train_dataloader)\n            for batch_dict in iterator:\n                yield epoch, batch_dict\n\n    def _async_gen_next_batch(self, continuous_iterator):\n        \"\"\"\n        Call parameter synchronization and asynchronous sequence generation.\n        \"\"\"\n        try:\n            epoch, batch_dict = next(continuous_iterator)\n        except StopIteration:\n            return None\n        except Exception as e:\n            print(f\"Error in async_gen_next_batch: {e}\")\n            return None\n        batch = DataProto.from_single_dict(batch_dict)\n        # pop those keys for generation\n        batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n        non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n        if \"multi_modal_data\" in 
batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n        if \"raw_prompt\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n        if \"tools_kwargs\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n        if \"interaction_kwargs\" in batch.non_tensor_batch:\n            non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n        gen_batch = batch.pop(\n            batch_keys=batch_keys_to_pop,\n            non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n        )\n        gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n        # sync weights from actor to rollout\n        self.sync_rollout_weights()\n        # async generation\n        gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)\n        return GenerationBatchFuture(epoch, batch, gen_batch_output)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        # across-epoch iterator\n        continuous_iterator = self._create_continuous_iterator()\n\n        # Start the first asynchronous generation task.\n        batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n        while batch_data_future is not None:\n            do_profile = (\n                self.global_steps in self.config.trainer.profile_steps\n                if self.config.trainer.profile_steps is not None\n                else False\n            )\n            if do_profile:\n                self.actor_wg.start_profile()\n                if not self.hybrid_engine:\n                    self.rollout_wg.start_profile()\n                if self.use_reference_policy:\n                    self.ref_policy_wg.start_profile()\n                if self.use_critic:\n                    self.critic_wg.start_profile()\n                if self.use_rm:\n                    self.rm_wg.start_profile()\n\n            metrics = {}\n  
          timing_raw = {}\n            is_last_step = self.global_steps >= self.total_training_steps\n\n            with marked_timer(\"step\", timing_raw):\n                # wait for the previous batch\n                with marked_timer(\"wait_prev_gen\", timing_raw, color=\"red\"):\n                    epoch, batch, gen_batch_output = batch_data_future.get()\n                    timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                    gen_batch_output.meta_info.pop(\"timing\", None)\n\n                # async next generation (with synced weights from actor to rollout)\n                with marked_timer(\"sync_rollout_weights\", timing_raw, color=\"purple\"):\n                    if not is_last_step:\n                        batch_data_future = self._async_gen_next_batch(continuous_iterator)\n\n                batch.non_tensor_batch[\"uid\"] = np.array(\n                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                )\n                # repeat to align with repeated responses in rollout\n                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                batch = batch.union(gen_batch_output)\n\n                batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                # Balance the number of valid tokens across DP ranks.\n                # NOTE: This usually changes the order of data in the `batch`,\n                # which won't affect the advantage calculation (since it's based on uid),\n                # but might affect the loss calculation (due to the change of mini-batching).\n                # TODO: Decouple the DP balancing and mini-batching.\n                if self.config.trainer.balance_batch:\n                    self._balance_batch(batch, metrics=metrics)\n\n                # compute global_valid tokens\n                batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                    # compute reward model score\n                    if self.use_rm:\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    if self.config.reward_model.launch_reward_fn_async:\n                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)\n                    else:\n                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                # recompute old_log_probs\n                with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                    old_log_prob = self.actor_wg.compute_log_prob(batch)\n                    entropys = old_log_prob.batch[\"entropys\"]\n                    response_masks = batch.batch[\"response_mask\"]\n                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                    old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                    metrics.update(old_log_prob_metrics)\n                    old_log_prob.batch.pop(\"entropys\")\n                    batch = batch.union(old_log_prob)\n\n                    if \"rollout_log_probs\" in batch.batch.keys():\n                        # TODO: we may want to add diff of 
probs too.\n                        rollout_old_log_probs = batch.batch[\"rollout_log_probs\"]\n                        actor_old_log_probs = batch.batch[\"old_log_probs\"]\n                        attention_mask = batch.batch[\"attention_mask\"]\n                        responses = batch.batch[\"responses\"]\n                        response_length = responses.size(1)\n                        response_mask = attention_mask[:, -response_length:]\n\n                        rollout_probs = torch.exp(rollout_old_log_probs)\n                        actor_probs = torch.exp(actor_old_log_probs)\n                        rollout_probs_diff = torch.abs(rollout_probs - actor_probs)\n                        rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())\n                        rollout_probs_diff_max = torch.max(rollout_probs_diff)\n                        rollout_probs_diff_mean = torch.mean(rollout_probs_diff)\n                        rollout_probs_diff_std = torch.std(rollout_probs_diff)\n                        metrics.update(\n                            {\n                                \"training/rollout_probs_diff_max\": rollout_probs_diff_max.detach().item(),\n                                \"training/rollout_probs_diff_mean\": rollout_probs_diff_mean.detach().item(),\n                                \"training/rollout_probs_diff_std\": rollout_probs_diff_std.detach().item(),\n                            }\n                        )\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                        if not self.ref_in_actor:\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        else:\n                            ref_log_prob = self.actor_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                    # we combine with rule-based rm\n                    reward_extra_infos_dict: dict[str, list]\n                    if self.config.reward_model.launch_reward_fn_async:\n                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    if reward_extra_infos_dict:\n                        batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                    # compute rewards. 
apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    # compute advantages, executed on the driver process\n\n                    norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                        \"norm_adv_by_std_in_grpo\", True\n                    )  # GRPO adv normalization factor\n\n                    batch = compute_advantage(\n                        batch,\n                        adv_estimator=self.config.algorithm.adv_estimator,\n                        gamma=self.config.algorithm.gamma,\n                        lam=self.config.algorithm.lam,\n                        num_repeat=self.config.actor_rollout_ref.rollout.n,\n                        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                        config=self.config.algorithm,\n                    )\n\n                # update critic\n                if self.use_critic:\n                    with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                        critic_output = self.critic_wg.update_critic(batch)\n                    critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                    metrics.update(critic_output_metrics)\n\n                # implement critic warmup\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                        batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                        actor_output = self.actor_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                # Log rollout generations if enabled\n                rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                if rollout_data_dir:\n                    with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                        inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                        outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                        scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                        self._dump_generations(\n                            inputs=inputs,\n                            outputs=outputs,\n                            scores=scores,\n                            reward_extra_infos_dict=reward_extra_infos_dict,\n                            dump_path=rollout_data_dir,\n                        )\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                  
      val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                        self._save_checkpoint()\n\n            # training metrics\n            metrics.update(\n                {\n                    \"training/global_step\": self.global_steps,\n                    \"training/epoch\": epoch,\n                }\n            )\n            # collect metrics\n            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n            # TODO: implement actual tflops and theoretical tflops\n            n_gpus = self.resource_pool_manager.get_n_gpus()\n            metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n\n            # TODO: make a canonical logger that supports various backends\n            logger.log(data=metrics, step=self.global_steps)\n\n            progress_bar.update(1)\n            self.global_steps += 1\n\n            if do_profile:\n                self.actor_wg.stop_profile()\n                if not self.hybrid_engine:\n                    self.rollout_wg.stop_profile()\n                if self.use_reference_policy:\n                    self.ref_policy_wg.stop_profile()\n                if self.use_critic:\n                    self.critic_wg.stop_profile()\n                if self.use_rm:\n                    self.rm_wg.stop_profile()\n\n            if is_last_step:\n                pprint(f\"Final validation metrics: {last_val_metrics}\")\n                progress_bar.close()\n                return\n"
  },
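  {
    "path": "verl_rl/recipe/one_step_off_policy/pipelining_sketch.py",
    "content": "# Illustrative sketch only: NOT part of verl. A minimal, self-contained\n# model of the one-step-off-policy overlap implemented in ray_trainer.py:\n# fit() waits on the previous GenerationBatchFuture, immediately launches\n# generation for the next batch, then trains on the batch it just received.\n# generate() and train_step() below are hypothetical stand-ins for the\n# rollout/actor worker-group RPCs.\nimport time\nfrom concurrent.futures import Future, ThreadPoolExecutor\n\n\ndef generate(step: int) -> str:\n    # Stand-in for sync_rollout_weights() + rollout_wg.async_generate_sequences().\n    time.sleep(0.1)\n    return f\"rollout-{step}\"\n\n\ndef train_step(rollout: str) -> None:\n    # Stand-in for actor_wg.update_actor(batch).\n    time.sleep(0.1)\n    print(f\"trained on {rollout}\")\n\n\ndef fit(total_steps: int) -> None:\n    executor = ThreadPoolExecutor(max_workers=1)\n    # Start the first asynchronous generation before the loop, as fit() does.\n    future: Future = executor.submit(generate, 0)\n    for step in range(total_steps):\n        rollout = future.result()  # wait_prev_gen\n        if step + 1 < total_steps:\n            # Launch the next generation now so it overlaps with training.\n            future = executor.submit(generate, step + 1)\n        train_step(rollout)\n    executor.shutdown()\n\n\nif __name__ == \"__main__\":\n    fit(3)\n"
  },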
  {
    "path": "verl_rl/recipe/one_step_off_policy/vllm_sharding_manager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 Meituan Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.debug import GPUMemoryLogger\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.workers.sharding_manager.base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass VLLMShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(self, inference_engine, device_mesh: DeviceMesh):\n        self.device_mesh = device_mesh\n        self.inference_engine = inference_engine\n        inference_engine.wake_up()\n        assert device_mesh is not None\n        assert inference_engine is not None\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n        self.timing = {}\n        gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n        get_torch_device().manual_seed(gen_dp_rank + 1000)\n        self.gen_random_states = get_torch_device().get_rng_state()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.gen_random_states = get_torch_device().get_rng_state()\n        self.inference_engine.reset_prefix_cache()\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
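  {
    "path": "verl_rl/recipe/one_step_off_policy/tp_gather_chunk_sketch.py",
    "content": "# Illustrative sketch only: NOT part of verl. Models the data flow of\n# VLLMShardingManager in vllm_sharding_manager.py: preprocess_data\n# all-gathers each tensor-parallel rank's shard so every rank sees an\n# identical batch, and postprocess_data keeps only this rank's chunk of\n# the outputs. Plain lists stand in for DataProto and the TP group.\n\n\ndef preprocess_data(per_rank_shards: list[list[int]]) -> list[int]:\n    # Stand-in for all_gather_data_proto over the infer_tp device group.\n    gathered: list[int] = []\n    for shard in per_rank_shards:\n        gathered.extend(shard)\n    return gathered\n\n\ndef postprocess_data(data: list[int], tp_size: int, tp_rank: int) -> list[int]:\n    # Stand-in for data.chunk(chunks=tp_size)[tp_rank].\n    chunk = len(data) // tp_size\n    return data[tp_rank * chunk : (tp_rank + 1) * chunk]\n\n\nif __name__ == \"__main__\":\n    shards = [[0, 1], [2, 3]]  # inputs held by TP ranks 0 and 1\n    full = preprocess_data(shards)  # both ranks now hold [0, 1, 2, 3]\n    print(postprocess_data(full, tp_size=2, tp_rank=1))  # rank 1 keeps [2, 3]\n"
  },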
  {
    "path": "verl_rl/recipe/onerec/main_onerec_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nOneRec custom main entry point for PPO training using custom onerec_ray_trainer.\n\"\"\"\n\nimport os\nimport sys\n\n# Add project root to path to ensure imports work correctly\nproject_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nif project_root not in sys.path:\n    sys.path.insert(0, project_root)\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\n# Import the custom trainer from onerec_ray_trainer.py\nfrom recipe.onerec.onerec_ray_trainer import RayPPOTrainer\n\n# Import other necessary components from verl\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.main_ppo import TaskRunner as BaseTaskRunner, create_rl_dataset, create_rl_sampler\nfrom verl.utils.device import is_cuda_available\n\n\n@hydra.main(config_path=\"../../verl/trainer/config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for OneRec PPO training with Hydra configuration management.\n\n    Args:\n        config: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    \"\"\"Run PPO training process with OneRec custom trainer.\n\n    Args:\n        config: Training configuration object containing all necessary parameters\n                for distributed PPO training including Ray initialization settings,\n                model paths, and training hyperparameters.\n    \"\"\"\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        ray.init(\n            runtime_env=get_ppo_ray_runtime_env(),\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    # Create a remote instance of the TaskRunner class\n    if (\n        is_cuda_available\n        and config.trainer.get(\"profile_steps\") is not None\n        and len(config.trainer.get(\"profile_steps\", [])) > 0\n    ):\n        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)\n        runner = OneRecTaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = OneRecTaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n    # Optional: get the path of the timeline trace file from the configuration\n    timeline_json_file = config.trainer.get(\"ray_timeline_filename\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\n@ray.remote(num_cpus=1)\nclass OneRecTaskRunner:\n    \"\"\"Ray remote class for executing distributed OneRec PPO training tasks.\n\n    This class encapsulates the main training logic and runs as a Ray remote actor\n    to enable distributed execution across multiple nodes and GPUs.\n    Uses the custom onerec_ray_trainer.RayPPOTrainer instead of the default trainer.\n    \"\"\"\n\n    def run(self, config):\n    
    \"\"\"Execute the main PPO training workflow with OneRec custom trainer.\n\n        Args:\n            config: Training configuration object containing all parameters needed\n                   for setting up and running the PPO training process.\n        \"\"\"\n        import socket\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.trainer.ppo.reward import load_reward_manager\n        from verl.utils.fs import copy_to_local\n        from verl.utils.import_utils import load_extern_type\n        \n        # Import Role and ResourcePoolManager from the custom onerec_ray_trainer\n        # to ensure we use the same Role enum\n        from recipe.onerec.onerec_ray_trainer import ResourcePoolManager, Role\n\n        print(f\"OneRecTaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        print(\"=\" * 80)\n        print(\"Using Custom OneRec RayPPOTrainer from recipe/onerec/onerec_ray_trainer.py\")\n        print(\"=\" * 80)\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        # Download the checkpoint from HDFS to the local machine\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Define worker classes based on the actor strategy\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n            # Use custom OneRecActorRolloutRefWorker instead of standard ActorRolloutRefWorker\n            from recipe.onerec.onerec_fsdp_workers import OneRecActorRolloutRefWorker as ActorRolloutRefWorker\n            from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker\n\n            use_legacy_worker_impl = config.trainer.get(\"use_legacy_worker_impl\", \"auto\")\n            if use_legacy_worker_impl in [\"auto\", \"enable\"]:\n                from verl.workers.fsdp_workers import CriticWorker\n            elif use_legacy_worker_impl == \"disable\":\n                from verl.workers.roles import CriticWorker\n                print(\"Using new worker implementation\")\n            else:\n                raise ValueError(f\"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}\")\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode 
== \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError(f\"Unknown strategy: {config.actor_rollout_ref.actor.strategy}\")\n\n        # Load reward model worker if enabled\n        if config.reward_model.get(\"enable\", False):\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError(f\"Unknown reward model strategy: {config.reward_model.strategy}\")\n        else:\n            RewardModelWorker = None\n\n        # Setup resource pool configuration\n        n_gpus_per_node = config.trainer.n_gpus_per_node\n        nnodes = config.trainer.nnodes\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {global_pool_id: [n_gpus_per_node] * nnodes}\n\n        # Map roles to workers\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n        }\n\n        if config.critic.get(\"enable\", True):\n            role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)\n            mapping[Role.Critic] = global_pool_id\n\n        if config.reward_model.get(\"enable\", False):\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(actor_rollout_cls)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # Load reward managers\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # ========================================================================\n        # KEY CHANGE: Use the custom OneRec RayPPOTrainer instead of default\n        # ========================================================================\n        trainer = RayPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            
train_sampler=train_sampler,\n        )\n\n        # Initialize the workers of the trainer\n        trainer.init_workers()\n        # Start the training process\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
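  {
    "path": "verl_rl/recipe/onerec/resource_pool_sketch.py",
    "content": "# Illustrative sketch only: NOT part of the recipe. Shows the shape of the\n# resource-pool spec and role mapping assembled in main_onerec_ppo.py:\n# one global pool listing GPUs per node, with every role assigned to it.\n# The GPU counts below are made-up example values.\nfrom enum import Enum\n\n\nclass Role(Enum):\n    ActorRollout = 2\n    Critic = 3\n    RefPolicy = 4\n\n\nn_gpus_per_node = 8  # e.g. config.trainer.n_gpus_per_node\nnnodes = 2  # e.g. config.trainer.nnodes\nglobal_pool_id = \"global_pool\"\n\n# {\"global_pool\": [8, 8]} -> two nodes, eight GPUs each\nresource_pool_spec = {global_pool_id: [n_gpus_per_node] * nnodes}\nmapping = {role: global_pool_id for role in (Role.ActorRollout, Role.Critic, Role.RefPolicy)}\n\nif __name__ == \"__main__\":\n    total_gpus = sum(sum(gpus) for gpus in resource_pool_spec.values())\n    print(resource_pool_spec, mapping, f\"total GPUs required: {total_gpus}\")\n"
  },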
  {
    "path": "verl_rl/recipe/onerec/onerec_fsdp_workers.py",
    "content": "\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\nfrom recipe.onerec.onerec_vllm_rollout import OneRecvLLMRollout\nfrom verl.utils.fs import copy_to_local\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom verl.utils.device import get_device_name\nimport logging\nimport torch\n\nlogger = logging.getLogger(__name__)\n\nclass OneRecActorRolloutRefWorker(ActorRolloutRefWorker):\n    \"\"\"\n    Custom ActorRolloutRefWorker that uses OneRecvLLMRollout instead of standard vLLMRollout.\n    \"\"\"\n    def _build_rollout(self, trust_remote_code=False):\n        # We only override the two_stage rollout path\n        if self.config.rollout.name == \"two_stage\":\n            from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager\n            from verl.utils.profiler import log_gpu_memory_usage\n\n            # Original logic from ActorRolloutRefWorker._build_rollout\n            infer_tp = self.config.rollout.tensor_model_parallel_size\n            dp = self.world_size // infer_tp\n            assert self.world_size % infer_tp == 0, (\n                f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n            )\n            device_name = get_device_name()\n            rollout_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n            )\n\n            log_gpu_memory_usage(f\"Before building vllm rollout (OneRec Custom)\", logger=logger)\n            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get(\"use_shm\", False))\n            lora_kwargs = (\n                {\"lora_kwargs\": {\"enable_lora\": True, \"max_loras\": 1, \"max_lora_rank\": self._lora_rank}}\n                if self._is_lora\n                else {}\n            )\n            \n            # Use our custom class!\n            # We check for async mode but currently only support Sync OneRecvLLMRollout\n            if self.config.rollout.mode == \"async\":\n                 logger.warning(\"OneRecvLLMRollout currently only supports SYNC mode fully. 
Async might fallback or fail if logic differs.\")\n                 # If you implemented AsyncOneRecvLLMRollout, use it here.\n                 # For now, we assume sync mode or that OneRecvLLMRollout works for both structure wise \n                 # (vLLMAsyncRollout inherits from different base, so simple substitution might fail for async)\n                 # Fallback to original for async if you haven't implemented Async wrapper\n                 return super()._build_rollout(trust_remote_code)\n\n            rollout = OneRecvLLMRollout(\n                model_path=local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                device_mesh=rollout_device_mesh,\n                trust_remote_code=trust_remote_code,\n                **lora_kwargs,\n            )\n\n            log_gpu_memory_usage(f\"After building vllm rollout (OneRec Custom)\", logger=logger)\n            full_params = torch.distributed.get_world_size() == 1\n            rollout_sharding_manager = FSDPVLLMShardingManager(\n                module=self.actor_module_fsdp,\n                inference_engine=rollout.inference_engine,\n                model_config=self.actor_model_config,\n                rollout_config=self.config.rollout,\n                full_params=full_params,\n                device_mesh=rollout_device_mesh,\n                offload_param=self._is_offload_param,\n                load_format=self.config.rollout.load_format,\n                layered_summon=self.config.rollout.get(\"layered_summon\", False),\n            )\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n            \n            return rollout, rollout_sharding_manager\n        \n        else:\n            # Fallback to parent implementation for other backends\n            return super()._build_rollout(trust_remote_code)\n"
  },
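  {
    "path": "verl_rl/recipe/onerec/rollout_override_sketch.py",
    "content": "# Illustrative sketch only: NOT part of the recipe. Distills the override\n# pattern used by OneRecActorRolloutRefWorker in onerec_fsdp_workers.py:\n# the subclass intercepts a single configuration branch (\"two_stage\") and\n# defers every other rollout backend to the parent implementation via\n# super(). The classes below are hypothetical stand-ins, not verl APIs.\n\n\nclass BaseWorker:\n    def __init__(self, rollout_name: str):\n        self.rollout_name = rollout_name\n\n    def _build_rollout(self) -> str:\n        return \"standard-vllm-rollout\"\n\n\nclass OneRecWorker(BaseWorker):\n    def _build_rollout(self) -> str:\n        if self.rollout_name == \"two_stage\":\n            # Only this branch is customized, mirroring the OneRec worker.\n            return \"onerec-two-stage-rollout\"\n        # All other backends fall back to the parent implementation.\n        return super()._build_rollout()\n\n\nif __name__ == \"__main__\":\n    print(OneRecWorker(\"two_stage\")._build_rollout())  # onerec-two-stage-rollout\n    print(OneRecWorker(\"vllm\")._build_rollout())  # standard-vllm-rollout\n"
  },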
  {
    "path": "verl_rl/recipe/onerec/onerec_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport json\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom pprint import pprint\nfrom typing import Optional\n\nimport numpy as np\nimport ray\nimport torch\nimport wandb\nfrom omegaconf import OmegaConf, open_dict\nfrom tensordict import TensorDict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.config import AlgoConfig\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\nfrom verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance\nfrom verl.utils.torch_functional import masked_mean, postprocess_data\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\nWorkerType = type[Worker]\n\n\nclass Role(Enum):\n    \"\"\"\n    To create more roles dynamically, you can subclass Role and add new members\n    \"\"\"\n\n    Actor = 0\n    Rollout = 1\n    ActorRollout = 2\n    Critic = 3\n    RefPolicy = 4\n    RewardModel = 5\n    ActorRolloutRef = 6\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. 
Resource pool will be initialized first.\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        \"\"\"Create Ray resource pools for distributed training.\n\n        Initializes resource pools based on the resource pool specification,\n        with each pool managing GPU resources across multiple nodes.\n        For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups.\n        For Megatron backend, uses max_colocate_count>1 for different models.\n        \"\"\"\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1 that merges all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1\n            # that can utilize a different WorkerGroup for different models\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray.state.available_resources_per_node()\n        node_available_gpus = {\n            node: node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0)\n            for node, node_info in node_available_resources.items()\n        }\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n        # check each resource pool can be satisfied, O(#resource_pools * #nodes)\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)\n            for node, available_gpus in node_available_gpus.items():\n                if available_gpus >= num_gpus:\n                    node_available_gpus[node] -= num_gpus\n                    num_nodes -= 1\n                    if num_nodes == 0:\n                        break\n            if num_nodes > 0:\n                raise ValueError(\n                    f\"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}\"\n                    + \" cannot be satisfied in this ray cluster\"\n                )\n\n\ndef apply_kl_penalty(data: DataProto, kl_ctrl: 
core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    \"\"\"Apply KL penalty to the token-level rewards.\n\n    This function computes the KL divergence between the reference policy and current policy,\n    then applies a penalty to the token-level rewards based on this divergence.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.\n        kl_penalty (str, optional): Type of KL penalty to apply. Defaults to \"kl\".\n\n    Returns:\n        tuple: A tuple containing:\n            - The updated data with token-level rewards adjusted by KL penalty\n            - A dictionary of metrics related to the KL penalty\n    \"\"\"\n    response_mask = data.batch[\"response_mask\"]\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_response_mask(data: DataProto):\n    \"\"\"Compute the attention mask for the response part of the sequence.\n\n    This function extracts the portion of the attention mask that corresponds to the model's response,\n    which is used for masking computations that should only apply to response tokens.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n\n    Returns:\n        torch.Tensor: The attention mask for the response tokens.\n    \"\"\"\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_advantage(\n    data: DataProto,\n    adv_estimator: AdvantageEstimator,\n    gamma: float = 1.0,\n    lam: float = 1.0,\n    num_repeat: int = 1,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n    tokenizer=None,\n) -> DataProto:\n    \"\"\"Compute advantage estimates for policy optimization.\n\n    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.\n    The advantage estimates are used to guide policy optimization in RL algorithms.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).\n        gamma (float, optional): Discount factor for future rewards. 
Defaults to 1.0.\n        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.\n        num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.\n        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in\n            GRPO. Defaults to True.\n        config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None.\n\n    Returns:\n        DataProto: The updated data with computed advantages and returns.\n    \"\"\"\n    # Back-compatible with trainers that do not compute response mask in fit\n    if \"response_mask\" not in data.batch.keys():\n        data.batch[\"response_mask\"] = compute_response_mask(data)\n    # prepare response group\n    if adv_estimator == AdvantageEstimator.GAE:\n        # Compute advantages and returns using Generalized Advantage Estimation (GAE)\n        advantages, returns = core_algos.compute_gae_advantage_return(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            values=data.batch[\"values\"],\n            response_mask=data.batch[\"response_mask\"],\n            gamma=gamma,\n            lam=lam,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n        if config.get(\"use_pf_ppo\", False):\n            data = core_algos.compute_pf_ppo_reweight_data(\n                data,\n                config.pf_ppo.reweight_method,\n                config.pf_ppo.weight_pow,\n            )\n    elif adv_estimator == AdvantageEstimator.GRPO:\n        # Initialize the mask for GRPO calculation\n        grpo_calculation_mask = data.batch[\"response_mask\"]\n        # Call compute_grpo_outcome_advantage with parameters matching its definition\n        advantages, returns = core_algos.compute_grpo_outcome_advantage(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            response_mask=grpo_calculation_mask,\n            index=data.non_tensor_batch[\"uid\"],\n            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    else:\n        # handle all other adv estimator type other than GAE and GRPO\n        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)\n        adv_kwargs = {\n            \"token_level_rewards\": data.batch[\"token_level_rewards\"],\n            \"response_mask\": data.batch[\"response_mask\"],\n            \"config\": config,\n        }\n        if \"uid\" in data.non_tensor_batch:  # optional\n            adv_kwargs[\"index\"] = data.non_tensor_batch[\"uid\"]\n        if \"reward_baselines\" in data.batch:  # optional\n            adv_kwargs[\"reward_baselines\"] = data.batch[\"reward_baselines\"]\n\n        # calculate advantage estimator\n        advantages, returns = adv_estimator_fn(**adv_kwargs)\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    return data\n\n\nclass RayPPOTrainer:\n    \"\"\"Distributed PPO trainer using Ray for scalable reinforcement learning.\n\n    This trainer orchestrates distributed PPO training across multiple nodes and GPUs,\n    managing actor rollouts, critic training, and reward computation with Ray backend.\n    Supports various model architectures including FSDP, Megatron, and vLLM integration.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different 
role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). 
Defaults to None.\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = Role.RefPolicy in role_worker_mapping\n        self.use_rm = Role.RewardModel in role_worker_mapping\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n        self.validation_generations_logger = ValidationGenerationsLogger(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n        )\n\n        # if ref_in_actor is True, the reference policy will be the actor without LoRA applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # KL loss control is currently not supported\n        if self.config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)\n\n        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            self.use_critic = True\n        elif self.config.algorithm.adv_estimator in [\n            AdvantageEstimator.GRPO,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS,\n            AdvantageEstimator.REMAX,\n            AdvantageEstimator.RLOO,\n            AdvantageEstimator.OPO,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,\n            AdvantageEstimator.GPG,\n        ]:\n            self.use_critic = False\n        else:\n            raise NotImplementedError\n\n        self._validate_config()\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _validate_config(self):\n        config = self.config\n        # number of GPUs total\n        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes\n        if config.actor_rollout_ref.actor.strategy == \"megatron\":\n            model_parallel_size = (\n                config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size\n                * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size\n            )\n            assert (\n                n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0\n            ), (\n                f\"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times \"\n                f\"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})\"\n            )\n            megatron_dp = n_gpus // (\n                model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size\n            )\n            minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu\n        else:\n            minimal_bsz = n_gpus\n\n        # 1. 
Check total batch size for data correctness\n        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n\n        assert real_train_batch_size % minimal_bsz == 0, (\n            f\"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size \"\n            f\"({minimal_bsz})\"\n        )\n\n        # A helper function to check \"micro_batch_size\" vs \"micro_batch_size_per_gpu\"\n        # We throw an error if the user sets both. The new convention is \"..._micro_batch_size_per_gpu\".\n        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n            \"\"\"Validate mutually exclusive micro batch size configuration options.\n\n            Ensures that users don't set both deprecated micro_batch_size and\n            the new micro_batch_size_per_gpu parameters simultaneously.\n\n            Args:\n                mbs: Deprecated micro batch size parameter value.\n                mbs_per_gpu: New micro batch size per GPU parameter value.\n                name (str): Configuration section name for error messages.\n\n            Raises:\n                ValueError: If both parameters are set or neither is set.\n            \"\"\"\n            settings = {\n                \"actor_rollout_ref.actor\": \"micro_batch_size\",\n                \"critic\": \"micro_batch_size\",\n                \"reward_model\": \"micro_batch_size\",\n                \"actor_rollout_ref.ref\": \"log_prob_micro_batch_size\",\n                \"actor_rollout_ref.rollout\": \"log_prob_micro_batch_size\",\n            }\n\n            if name in settings:\n                param = settings[name]\n                param_per_gpu = f\"{param}_per_gpu\"\n\n                if mbs is None and mbs_per_gpu is None:\n                    raise ValueError(\n                        f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\"\n                    )\n\n                if mbs is not None and mbs_per_gpu is not None:\n                    raise ValueError(\n                        f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove \"\n                        f\"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).\"\n                    )\n\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.actor.ppo_micro_batch_size,\n                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.actor\",\n            )\n\n            if self.use_reference_policy:\n                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n                check_mutually_exclusive(\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,\n                    \"actor_rollout_ref.ref\",\n                )\n\n            #  The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size,\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.rollout\",\n            )\n\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            # Check for critic micro-batch size conflicts\n            check_mutually_exclusive(\n                config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, \"critic\"\n            )\n\n        # Check for reward model micro-batch size conflicts\n        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:\n            check_mutually_exclusive(\n                config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, \"reward_model\"\n            )\n\n        # Actor\n        # check if train_batch_size is larger than ppo_mini_batch_size\n        # if NOT dynamic_bsz, we must ensure:\n        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size\n        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size\n            sp_size = config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:\n                assert (\n                    config.actor_rollout_ref.actor.ppo_mini_batch_size\n                    % config.actor_rollout_ref.actor.ppo_micro_batch_size\n                    == 0\n                )\n                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus\n\n        assert config.actor_rollout_ref.actor.loss_agg_mode in [\n            \"token-mean\",\n            \"seq-mean-token-sum\",\n            \"seq-mean-token-mean\",\n            \"seq-mean-token-sum-norm\",\n        ], f\"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}\"\n\n        if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:\n            print(\"NOTICE: You have both enabled in-reward kl and kl loss.\")\n\n        # critic\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size\n            sp_size = config.critic.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.critic.ppo_micro_batch_size is not None:\n                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0\n                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus\n\n        # Check if use_remove_padding is enabled when using sequence parallelism for fsdp\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"} and (\n            config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1) > 1\n            or config.actor_rollout_ref.ref.get(\"ulysses_sequence_parallel_size\", 1) > 1\n        ):\n            assert config.actor_rollout_ref.model.use_remove_padding, (\n                \"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`.\"\n            )\n\n        if self.use_critic and config.critic.strategy in {\"fsdp\", \"fsdp2\"}:\n            if config.critic.get(\"ulysses_sequence_parallel_size\", 1) > 
1:\n                assert config.critic.model.use_remove_padding, (\n                    \"When using sequence parallelism for critic, you must enable `use_remove_padding`.\"\n                )\n\n        if config.data.get(\"val_batch_size\", None) is not None:\n            print(\n                \"WARNING: val_batch_size is deprecated.\"\n                + \" Validation datasets are sent to the inference engines as one whole batch,\"\n                + \" and the engines schedule memory themselves.\"\n            )\n\n        # check eval config\n        if config.actor_rollout_ref.rollout.val_kwargs.do_sample:\n            assert config.actor_rollout_ref.rollout.temperature > 0, (\n                \"validation gen temperature should be greater than 0 when enabling do_sample\"\n            )\n\n        print(\"[validate_config] All configuration checks passed successfully!\")\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files, self.config.data, self.tokenizer, self.processor\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files, self.config.data, self.tokenizer, self.processor\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        num_workers = self.config.data[\"dataloader_num_workers\"]\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=num_workers,\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=num_workers,\n            shuffle=self.config.data.get(\"validation_shuffle\", True),\n            drop_last=False,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: \"\n            f\"{len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}\")\n\n    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path, ground_truths=None):\n        \"\"\"Dump rollout/validation samples as JSONL.\"\"\"\n        os.makedirs(dump_path, exist_ok=True)\n        filename = os.path.join(dump_path, f\"{self.global_steps}.jsonl\")\n\n        n = len(inputs)\n        base_data = {\n            \"input\": inputs,\n            \"output\": outputs,\n            \"score\": scores,\n            \"step\": [self.global_steps] * n,\n        }\n\n        if ground_truths and len(ground_truths) == n:\n            base_data[\"ground_truth\"] = ground_truths\n\n        for k, v in reward_extra_infos_dict.items():\n            if len(v) == n:\n                base_data[k] = v\n\n        lines = []\n        for i in range(n):\n            entry = {k: v[i] for k, v in base_data.items()}\n            lines.append(json.dumps(entry, ensure_ascii=False))\n\n        with open(filename, \"w\") as f:\n            f.write(\"\\n\".join(lines) + \"\\n\")\n\n        print(f\"Dumped generations to {filename}\")\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort by input text\n\n        # Use fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Debug: print dataset sizes before validation\n        print(f\"[_validate] Starting validation. 
train_dataset size: {len(self.train_dataset)}, val_dataset size: {len(self.val_dataset)}\")\n        print(f\"[_validate] actor_rollout_wg world_size: {self.actor_rollout_wg.world_size}\")\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_scores = []\n        sample_turns = []\n        sample_ground_truths = []\n\n        batch_idx = 0\n        for test_data in self.val_dataloader:\n            test_batch = DataProto.from_single_dict(test_data)\n            print(f\"[Validation Debug] Batch {batch_idx}: test_batch size = {len(test_batch)}\")\n            batch_idx += 1\n\n            # Check if beam search or two-stage rollout is enabled for validation\n            val_kwargs = self.config.actor_rollout_ref.rollout.val_kwargs\n            rollout_config = self.config.actor_rollout_ref.rollout\n            use_beam_search_val = val_kwargs.get(\"use_beam_search\", False)\n            is_two_stage_rollout_val = rollout_config.get(\"name\") == \"two_stage\"\n\n            # Only repeat if NOT using beam search (beam search will expand outputs internally)\n            # For two-stage rollout, we DO repeat (for different CoT samples), beam expansion happens in rollout\n            if not use_beam_search_val:\n                # repeat test batch for sampling-based generation\n                test_batch = test_batch.repeat(\n                    repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n                )\n\n            # we only do validation on rule-based rm\n            if self.config.reward_model.enable and test_batch[0].non_tensor_batch[\"reward_model\"][\"style\"] == \"model\":\n                return {}\n\n            # Store original inputs (will be expanded later if beam search returns all beams)\n            input_ids = test_batch.batch[\"input_ids\"]\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            # Note: sample_inputs will be extended after beam search expansion handling\n\n            if \"reward_model\" in test_batch.non_tensor_batch:\n                ground_truths = [item[\"ground_truth\"] for item in test_batch.non_tensor_batch[\"reward_model\"]]\n                # Note: ground_truths will be extended after beam search expansion handling\n\n            batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n            non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n            if \"multi_modal_data\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n            if \"raw_prompt\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n            if \"tools_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n            if \"interaction_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n            if \"agent_name\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"agent_name\")\n            test_gen_batch = test_batch.pop(\n                batch_keys=batch_keys_to_pop,\n                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n            )\n\n            # Validation configuration\n            val_kwargs = self.config.actor_rollout_ref.rollout.val_kwargs\n            rollout_config = 
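self.config.actor_rollout_ref.rollout\n            # The meta_info assembled below is what the rollout workers consume at eval time:\n            # eos/pad token ids for decoding, recompute_log_prob=False (validation does not train\n            # on these samples), do_sample from val_kwargs, and validate=True to select eval behavior.\n            rollout_config = 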
self.config.actor_rollout_ref.rollout\n            meta_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": val_kwargs.do_sample,\n                \"validate\": True,\n                \"global_steps\": self.global_steps,\n            }\n\n            # Check for Two-Stage Rollout in Validation\n            if rollout_config.get(\"enable_two_stage_rollout\", False):\n                meta_info[\"enable_two_stage_rollout\"] = True\n                meta_info[\"stage2_beam_size\"] = rollout_config.get(\"stage2_beam_size\", 32)\n                meta_info[\"stage2_max_tokens\"] = rollout_config.get(\"stage2_max_tokens\", 16)\n                \n                # Stage 1 CoT config\n                meta_info[\"max_tokens\"] = self.config.data.get(\"max_response_length\", 1024)\n                # Disable standard beam search for Stage 1 (use sampling)\n                meta_info[\"use_beam_search\"] = False\n                meta_info[\"n\"] = val_kwargs.get(\"n\", 1)\n                \n                print(f\"[OneRecTrainer] Validation Two-Stage Enabled: {meta_info}\")\n\n            # Inject Beam Search parameters if enabled for validation (Single Stage)\n            elif val_kwargs.get(\"use_beam_search\", False):\n                meta_info[\"use_beam_search\"] = True\n                meta_info[\"best_of\"] = val_kwargs.get(\"best_of\", 4)\n                # Use max_response_length from config for validation as well\n                meta_info[\"max_tokens\"] = self.config.data.get(\"max_response_length\", 16)\n                meta_info[\"temperature\"] = 0\n                # n controls how many beams to return per prompt (will expand output)\n                meta_info[\"n\"] = val_kwargs.get(\"n\", 1)\n                # Signal rollout to return all beams (no repeat, expand internally)\n                meta_info[\"return_all_beams\"] = True\n\n                print(f\"[OneRecTrainer] Validation Beam Search Enabled (optimized, no repeat): {meta_info}\")\n\n            test_gen_batch.meta_info = meta_info\n            print(f\"test_gen_batch meta info: {test_gen_batch.meta_info}\")\n\n            # pad to be divisible by dp_size\n            size_divisor = (\n                self.actor_rollout_wg.world_size\n                if not self.async_rollout_mode\n                else self.config.actor_rollout_ref.rollout.agent.num_workers\n            )\n            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)\n            if not self.async_rollout_mode:\n                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)\n            else:\n                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)\n\n            # unpad - For beam search or two-stage rollout, output is expanded, so we need to unpad accordingly\n            if use_beam_search_val or is_two_stage_rollout_val:\n                # For two-stage rollout, expansion is val_kwargs.n * stage2_beam_size\n                if is_two_stage_rollout_val:\n                    stage2_beam_size = rollout_config.get(\"stage2_beam_size\", 2)\n                    n_beams = stage2_beam_size  # rollout already expands by beam_width\n                    print(f\"[Validation Debug] Two-stage unpad: original pad_size={pad_size}, 
stage2_beam_size={stage2_beam_size}, actual_pad_size={pad_size * n_beams}\")\n                else:\n                    n_beams = val_kwargs.get(\"n\", 1)\n                    print(f\"[Validation Debug] Beam search unpad: original pad_size={pad_size}, n_beams={n_beams}, actual_pad_size={pad_size * n_beams}\")\n                actual_pad_size = pad_size * n_beams\n            else:\n                actual_pad_size = pad_size\n            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=actual_pad_size)\n\n            # Debug: Check keys returned from worker\n            print(f\"[Trainer Debug] test_output_gen_batch keys: {test_output_gen_batch.non_tensor_batch.keys()}\")\n\n            print(\"validation generation end\")\n\n            # Handle beam search or two-stage rollout expansion: output may be larger than input\n            # When return_all_beams=True, rollout expands output to batch_size * beam_width\n            output_len = len(test_output_gen_batch)\n            input_len = len(test_batch)\n            if output_len > input_len and (use_beam_search_val or is_two_stage_rollout_val):\n                # Rollout guarantees output_len = input_len * expand_factor, so we can use simple repeat\n                expand_factor = output_len // input_len\n                print(f\"[Validation Debug] Batch {batch_idx-1}: Beam/TwoStage expansion - input={input_len}, output={output_len}, factor={expand_factor}\")\n                test_batch = test_batch.repeat(repeat_times=expand_factor, interleave=True)\n                input_texts = [t for t in input_texts for _ in range(expand_factor)]\n                if \"reward_model\" in test_batch.non_tensor_batch:\n                    ground_truths = [t for t in ground_truths for _ in range(expand_factor)]\n                print(f\"[Validation Debug] Batch {batch_idx-1}: After expansion - len(input_texts)={len(input_texts)}, len(test_batch)={len(test_batch)}\")\n\n            # Now extend sample_inputs and sample_ground_truths\n            before_extend = len(sample_inputs)\n            sample_inputs.extend(input_texts)\n            print(f\"[Validation Debug] Batch {batch_idx-1}: Extended sample_inputs from {before_extend} to {len(sample_inputs)} (+{len(input_texts)})\")\n            if \"reward_model\" in test_batch.non_tensor_batch:\n                sample_ground_truths.extend(ground_truths)\n\n            # Store generated outputs\n            output_ids = test_output_gen_batch.batch[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            # Collect response lengths for validation metrics\n            response_lengths = [(ids != self.tokenizer.pad_token_id).sum().item() for ids in output_ids]\n            reward_extra_infos_dict[\"response_length\"].extend(response_lengths)\n\n            test_batch = test_batch.union(test_output_gen_batch)\n            test_batch.meta_info[\"validate\"] = True\n            \n            # Debug: Check keys after union\n            print(f\"[Trainer Debug] test_batch keys after union: {test_batch.non_tensor_batch.keys()}\")\n\n            # Critical Step: Move generated_items into extra_info for NaiveRewardManager\n            if \"generated_items\" in test_batch.non_tensor_batch:\n                print(\"[Trainer Debug] Moving generated_items into extra_info...\")\n                generated_items_arr = test_batch.non_tensor_batch[\"generated_items\"]\n     
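           # NaiveRewardManager reads per-sample fields from extra_info, so the rollout's\n                # generated_items are copied into each sample's extra_info dict below; extra_info\n                # is created on the fly when the dataset did not provide one.\n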
                batch_size = len(generated_items_arr)\n                \n                # Ensure extra_info exists\n                if \"extra_info\" not in test_batch.non_tensor_batch:\n                    test_batch.non_tensor_batch[\"extra_info\"] = np.array([{} for _ in range(batch_size)], dtype=object)\n                \n                extra_info_arr = test_batch.non_tensor_batch[\"extra_info\"]\n                for i in range(batch_size):\n                    if extra_info_arr[i] is None:\n                        extra_info_arr[i] = {}\n                    # Update dict (reference modification)\n                    extra_info_arr[i][\"generated_items\"] = generated_items_arr[i]\n\n            # evaluate using reward_function\n            result = self.val_reward_fn(test_batch, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            reward_extra_infos_dict[\"reward\"].extend(scores)\n            print(f\"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}\")\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n                    print(f\"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}\")\n\n            # collect num_turns of each prompt\n            if \"__num_turns__\" in test_batch.non_tensor_batch:\n                sample_turns.append(test_batch.non_tensor_batch[\"__num_turns__\"])\n\n            # Fetch data_source info for per-task grouped statistics (consistent with the training logic)\n            reward_fn_key = self.config.data.get(\"reward_fn_key\", \"data_source\")\n            data_sources_batch = test_batch.non_tensor_batch.get(reward_fn_key, None)\n\n            # If not found, try other common field names\n            if data_sources_batch is None:\n                data_sources_batch = test_batch.non_tensor_batch.get(\"source\", None)\n            if data_sources_batch is None:\n                data_sources_batch = test_batch.non_tensor_batch.get(\"data_source\", None)\n\n            # If still not found, fall back to a default value\n            if data_sources_batch is None:\n                data_sources_batch = [\"unknown\"] * reward_tensor.shape[0]\n\n            data_source_lst.append(data_sources_batch)\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n                ground_truths=sample_ground_truths,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n\n        # Debug: Check for duplicate prompts\n        from collections import Counter\n        prompt_counts = Counter(sample_inputs)\n        duplicate_prompts = {p: c for p, c in prompt_counts.items() if c > 1}\n        if duplicate_prompts:\n            print(f\"[Validation Debug] Found {len(duplicate_prompts)} duplicate prompts!\")\n            for p, c in list(duplicate_prompts.items())[:3]:  # Show first 3\n                print(f\"  Prompt (truncated): '{p[:100]}...' appears {c} times\")\n        else:\n            print(f\"[Validation Debug] No duplicate prompts found. Total unique prompts: {len(prompt_counts)}\")\n        print(f\"[Validation Debug] Total samples: {len(sample_inputs)}, Total scores: {len(sample_scores)}\")\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\", \"pass\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        if len(sample_turns) > 0:\n            sample_turns = np.concatenate(sample_turns)\n            metric_dict[\"val-aux/num_turns/min\"] = sample_turns.min()\n            metric_dict[\"val-aux/num_turns/max\"] = sample_turns.max()\n            metric_dict[\"val-aux/num_turns/mean\"] = sample_turns.mean()\n\n        # Add validation response_length statistics\n        if \"response_length\" in reward_extra_infos_dict:\n            response_lengths = reward_extra_infos_dict[\"response_length\"]\n            if len(response_lengths) > 0:\n                import torch\n                response_lengths_tensor = torch.tensor(response_lengths)\n                metric_dict[\"val/response_length/mean\"] = response_lengths_tensor.float().mean().item()\n                metric_dict[\"val/response_length/max\"] = response_lengths_tensor.max().item()\n                metric_dict[\"val/response_length/min\"] = response_lengths_tensor.min().item()\n\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. 
Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=\"actor_rollout\",\n                profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=\"ref\",\n                profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.trainer, \"profile_steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.trainer, \"profile_steps\")\n            assert OmegaConf.select(self.config.trainer, \"worker_nsight_options\") is not None, (\n                \"worker_nsight_options must be set when profile_steps is set\"\n            )\n            wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                OmegaConf.select(self.config.trainer, \"worker_nsight_options\")\n            )\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = all_wg[\"actor_rollout\"]\n        self.actor_rollout_wg.init_model()\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\":\n            from verl.experimental.agent_loop import AgentLoopManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AgentLoopManager(\n                config=self.config,\n                worker_group=self.actor_rollout_wg,\n            )\n\n    def _save_checkpoint(self):\n        from verl.utils.fs import local_mkdir_safe\n\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated,\"\n                + \" set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n        )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, \"critic\")\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"critic\")\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n 
       local_mkdir_safe(local_global_step_folder)\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = os.path.join(global_step_folder, \"critic\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader,\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n    def _start_profiling(self, do_profile: bool) -> None:\n        \"\"\"Start profiling for all 
worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.start_profile(role=\"e2e\", profile_step=self.global_steps)\n            if self.use_reference_policy:\n                self.ref_policy_wg.start_profile()\n            if self.use_critic:\n                self.critic_wg.start_profile()\n            if self.use_rm:\n                self.rm_wg.start_profile()\n\n    def _stop_profiling(self, do_profile: bool) -> None:\n        \"\"\"Stop profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.stop_profile()\n            if self.use_reference_policy:\n                self.ref_policy_wg.stop_profile()\n            if self.use_critic:\n                self.critic_wg.stop_profile()\n            if self.use_rm:\n                self.rm_wg.stop_profile()\n\n    def _balance_batch(self, batch: DataProto, metrics, logging_prefix=\"global_seqlen\"):\n        \"\"\"Reorder the data on the single controller so that each DP rank gets a similar number of total tokens\"\"\"\n        attention_mask = batch.batch[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = batch.batch[\"attention_mask\"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)\n        world_size = self.actor_rollout_wg.world_size\n        global_partition_lst = get_seqlen_balanced_partitions(\n            global_seqlen_lst, k_partitions=world_size, equal_size=True\n        )\n        # reorder based on index. The data will be automatically equally partitioned by dispatch function\n        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])\n        batch.reorder(global_idx)\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker groups through RPC\n        to construct the PPO dataflow.\n        The lightweight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n
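        # One training step below (driver-side dataflow): rollout generation -> reward computation\n        # (reward fn and/or reward model) -> recompute old log-probs and entropy metrics ->\n        # reference log-probs (if enabled) -> critic values (if GAE) -> advantage estimation -> updates.\n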
        self.max_steps_duration = 0\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                do_profile = (\n                    self.global_steps in self.config.trainer.profile_steps\n                    if self.config.trainer.profile_steps is not None\n                    else False\n                )\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(do_profile)\n\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n                non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n                if \"multi_modal_data\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n                if \"raw_prompt\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n                if \"tools_kwargs\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n                if \"interaction_kwargs\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n                if \"index\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"index\")\n                if \"agent_name\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"agent_name\")\n\n                gen_batch = batch.pop(\n                    batch_keys=batch_keys_to_pop,\n                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n                )\n\n                # pass global_steps to trace\n                gen_batch.meta_info[\"global_steps\"] = self.global_steps\n\n                # Get original batch size for beam_idx calculation\n                original_bs = len(gen_batch)\n\n                # Check if beam search is enabled - if so, don't repeat (optimization)\n                # Two-stage rollout: still repeat (for different CoT samples), but beam expansion happens in rollout\n                rollout_config = self.config.actor_rollout_ref.rollout\n                use_beam_search_train = rollout_config.get(\"use_beam_search\", False)\n                is_two_stage_rollout = rollout_config.get(\"name\") == \"two_stage\"\n                rollout_n = self.config.actor_rollout_ref.rollout.n\n\n                if not use_beam_search_train:\n                    # Standard sampling or two-stage rollout: repeat the batch for n_rollout different samples\n                    gen_batch = gen_batch.repeat(repeat_times=rollout_n, interleave=True)\n\n                    if \"reward_model\" in batch.non_tensor_batch:\n                        # repeat reward_model to match gen_batch size\n                        repeated_reward_model = np.repeat(\n                            batch.non_tensor_batch[\"reward_model\"],\n                            rollout_n,\n                            axis=0\n                        )\n                        gen_batch.non_tensor_batch[\"reward_model\"] = repeated_reward_model\n                else:\n                    print(\"[OneRecTrainer] Beam search enabled, skipping repeat (optimized path)\")\n\n                is_last_step = self.global_steps >= 
self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                        # Dynamically configure generation parameters based on config\n                        rollout_config = self.config.actor_rollout_ref.rollout\n                        \n                        # Check if beam search is enabled in config\n                        if rollout_config.get(\"use_beam_search\", False):\n                            gen_batch.meta_info[\"use_beam_search\"] = True\n                            gen_batch.meta_info[\"best_of\"] = rollout_config.get(\"best_of\", 4)\n                            # Use max_response_length from data config if available, otherwise default\n                            gen_batch.meta_info[\"max_tokens\"] = self.config.data.get(\"max_response_length\", 16)\n                            gen_batch.meta_info[\"temperature\"] = 0\n                            n = rollout_config.get(\"n\", 1)\n                            gen_batch.meta_info[\"n\"] = n\n                            # Optimized: return all beams from rollout, no repeat needed\n                            gen_batch.meta_info[\"return_all_beams\"] = True\n\n                            print(f\"[OneRecTrainer] Beam Search Enabled (optimized, no repeat): {gen_batch.meta_info}\")\n                        \n                        # Check if Two-Stage Rollout is enabled\n                        if rollout_config.get(\"enable_two_stage_rollout\", False):\n                            gen_batch.meta_info[\"enable_two_stage_rollout\"] = True\n                            gen_batch.meta_info[\"stage2_beam_size\"] = rollout_config.get(\"stage2_beam_size\", 32)\n                            gen_batch.meta_info[\"stage2_max_tokens\"] = rollout_config.get(\"stage2_max_tokens\", 16)\n                            # For Stage 1 (CoT), we use sampling params\n                            gen_batch.meta_info[\"max_tokens\"] = self.config.data.get(\"max_response_length\", 1024) # CoT length\n                            gen_batch.meta_info[\"temperature\"] = rollout_config.get(\"temperature\", 1.0)\n                            gen_batch.meta_info[\"top_p\"] = rollout_config.get(\"top_p\", 1.0)\n                            # Disable use_beam_search flag to prevent conflict in standard flow if both are set\n                            gen_batch.meta_info[\"use_beam_search\"] = False \n                            \n                            print(f\"[OneRecTrainer] Two-Stage Rollout Enabled: {gen_batch.meta_info}\")\n\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                        # Handle beam search/two-stage rollout expansion: output may be larger than input\n                        # When return_all_beams=True, rollout expands output to batch_size * beam_width\n                        if use_beam_search_train or is_two_stage_rollout:\n                            output_len = len(gen_batch_output)\n                            input_len = len(batch)\n                            
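# Illustrative, assumed numbers: with input_len=4 prompts and stage2_beam_size=8 the rollout\n                            # returns output_len=32; the UIDs assigned below (before expansion) make each prompt's\n                            # 8 beams share one UID, which is what GRPO group-normalization keys on.\n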
                            print(f\"[OneRecTrainer] Beam/TwoStage: gen_batch_output size={output_len}, batch size={input_len}, n={rollout_n}\")\n\n                            # CRITICAL FIX: Generate UIDs BEFORE expansion so that beams from\n                            # the same prompt share the same UID for correct GRPO grouping\n                            # This must happen regardless of whether expansion is needed\n                            batch.non_tensor_batch[\"uid\"] = np.array(\n                                [str(uuid.uuid4()) for _ in range(input_len)], dtype=object\n                            )\n                            print(f\"[OneRecTrainer] Generated UIDs before expansion: {len(batch.non_tensor_batch['uid'])} unique UIDs\")\n\n                            if output_len > input_len:\n                                # Rollout guarantees output_len = input_len * expand_factor\n                                assert output_len % input_len == 0, \\\n                                    f\"Output size {output_len} must be a multiple of input size {input_len}\"\n                                expand_factor = output_len // input_len\n                                print(f\"[OneRecTrainer] Expanding batch using repeat: factor={expand_factor}\")\n\n                                batch = batch.repeat(repeat_times=expand_factor, interleave=True)\n                                print(f\"[OneRecTrainer] After expansion: batch size={len(batch)}, UIDs will be repeated {expand_factor}x\")\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with marked_timer(\"gen_max\", timing_raw, color=\"purple\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            if not self.async_rollout_mode:\n                                gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n                            else:\n                                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)\n                            batch = batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    # Generate UIDs and repeat batch for standard sampling path\n                    # Skip if beam search or two-stage rollout already handled this above\n                    if not use_beam_search_train and not is_two_stage_rollout:\n                        # Use original_bs (captured above, before the rollout repeat) for consistency\n                        batch.non_tensor_batch[\"uid\"] = np.array(\n                            [str(uuid.uuid4()) for _ in range(original_bs)], dtype=object\n                        )\n                        # repeat to align with repeated responses in rollout\n                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    \n                    # FORCE INJECTION: Bypass union and inject directly into extra_info\n                    if \"generated_items\" in gen_batch_output.non_tensor_batch:\n                        print(\"[Trainer Fit Debug] Force injecting generated_items into extra_info...\")\n                        gen_items = gen_batch_output.non_tensor_batch[\"generated_items\"]\n                        \n                        # Ensure extra_info exists in batch\n                        if \"extra_info\" not in batch.non_tensor_batch:\n                            batch.non_tensor_batch[\"extra_info\"] = np.array([{} for _ in range(len(batch))], dtype=object)\n                        \n                        extra_infos = batch.non_tensor_batch[\"extra_info\"]\n                        \n                        if len(gen_items) == len(extra_infos):\n                            for i in range(len(gen_items)):\n                                if extra_infos[i] is None:\n                                    extra_infos[i] = {}\n                                extra_infos[i][\"generated_items\"] = gen_items[i]\n                        else:\n                            print(f\"[Trainer Fit Error] Batch size mismatch during injection: {len(gen_items)} vs {len(extra_infos)}\")\n\n                    batch = batch.union(gen_batch_output)\n\n                    if \"response_mask\" not in batch.batch.keys():\n                        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                        # compute reward model score\n                        if self.use_rm:\n                            reward_tensor = self.rm_wg.compute_rm_score(batch)\n                            batch = batch.union(reward_tensor)\n\n                        if self.config.reward_model.launch_reward_fn_async:\n                            future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)\n                        else:\n                            reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                    # recompute old_log_probs\n                    with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = batch.batch[\"response_mask\"]\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n\n                        # per-position entropy plot\n                        # masked_entropys = entropys * response_masks\n                        # 
sum_entropy_per_position = torch.sum(masked_entropys, dim=0)\n                        # num_tokens_per_position = torch.sum(response_masks, dim=0)\n                        # mean_entropy_per_position = sum_entropy_per_position / torch.clamp(\n                        #     num_tokens_per_position, min=1\n                        # )\n                        # try:\n                        #     entropy_list = mean_entropy_per_position.cpu().tolist()\n                        #     table_data = [[i, ent] for i, ent in enumerate(entropy_list)]\n                        #     table = wandb.Table(data=table_data, columns=[\"position\", \"entropy\"])\n                        #     old_log_prob_metrics[\"actor/per_position_entropy_plot\"] = wandb.plot.line(\n                        #         table, \"position\", \"entropy\", title=\"Per-Position Entropy\"\n                        #     )\n                        # except Exception as e:\n                        #     print(f\"Warning: Could not create wandb per-position entropy plot. Error: {e}\")\n\n                        # token-type entropy\n                        try:\n                            responses = batch.batch[\"responses\"]\n                            # mask for token type 1 (id >= 151669)\n                            type1_mask = (responses >= 151669) * response_masks\n                            # mask for token type 2 (id < 151669)\n                            type2_mask = (responses < 151669) * response_masks\n\n                            count_type1 = type1_mask.sum().item()\n                            count_type2 = type2_mask.sum().item()\n\n                            if count_type1 > 0:\n                                entropy_type1 = masked_mean(entropys, mask=type1_mask, axis=None).item()\n                                old_log_prob_metrics[\"actor/entropy_itemic_token\"] = entropy_type1\n\n                            if count_type2 > 0:\n                                entropy_type2 = masked_mean(entropys, mask=type2_mask, axis=None).item()\n                                old_log_prob_metrics[\"actor/entropy_lang_token\"] = entropy_type2\n\n                            old_log_prob_metrics[\"actor/token_count_itemic_token\"] = count_type1\n                            old_log_prob_metrics[\"actor/token_count_lang_token\"] = count_type2\n                        except Exception as e:\n                            print(f\"Warning: Could not compute token-type entropy metrics. 
Error: {e}\")\n\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                        if \"rollout_log_probs\" in batch.batch.keys():\n                            # TODO: we may want to add diff of probs too.\n                            rollout_old_log_probs = batch.batch[\"rollout_log_probs\"]\n                            actor_old_log_probs = batch.batch[\"old_log_probs\"]\n                            attention_mask = batch.batch[\"attention_mask\"]\n                            responses = batch.batch[\"responses\"]\n                            response_length = responses.size(1)\n                            response_mask = attention_mask[:, -response_length:]\n\n                            rollout_probs = torch.exp(rollout_old_log_probs)\n                            actor_probs = torch.exp(actor_old_log_probs)\n                            rollout_probs_diff = torch.abs(rollout_probs - actor_probs)\n                            rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())\n                            rollout_probs_diff_max = torch.max(rollout_probs_diff)\n                            rollout_probs_diff_mean = torch.mean(rollout_probs_diff)\n                            rollout_probs_diff_std = torch.std(rollout_probs_diff)\n                            metrics.update(\n                                {\n                                    \"training/rollout_probs_diff_max\": rollout_probs_diff_max.detach().item(),\n                                    \"training/rollout_probs_diff_mean\": rollout_probs_diff_mean.detach().item(),\n                                    \"training/rollout_probs_diff_std\": rollout_probs_diff_std.detach().item(),\n                                }\n                            )\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                            if not self.ref_in_actor:\n                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            else:\n                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        if self.config.reward_model.launch_reward_fn_async:\n                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                        batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        if reward_extra_infos_dict:\n                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                            # 获取 data_source 信息，用于按task分组统计\n                            # 尝试多个可能的字段名：source, data_source, data_source_key\n                            reward_fn_key 
= self.config.data.get(\"reward_fn_key\", \"data_source\")\n                            data_sources = batch.non_tensor_batch.get(reward_fn_key, None)\n\n                            # If not found, fall back to other common field names\n                            if data_sources is None:\n                                data_sources = batch.non_tensor_batch.get(\"source\", None)\n                            if data_sources is None:\n                                data_sources = batch.non_tensor_batch.get(\"data_source\", None)\n\n                            # Debug info: print the available fields\n                            if self.global_steps <= 2:  # only print during the first few steps\n                                print(f\"[DEBUG] Batch size: {len(batch)}\")\n                                print(f\"[DEBUG] Available non_tensor_batch keys: {list(batch.non_tensor_batch.keys())}\")\n                                print(f\"[DEBUG] reward_fn_key from config: {reward_fn_key}\")\n                                print(f\"[DEBUG] data_sources found: {data_sources is not None}\")\n                                if data_sources is not None:\n                                    print(f\"[DEBUG] data_sources type: {type(data_sources)}, shape: {getattr(data_sources, 'shape', len(data_sources))}\")\n                                    print(f\"[DEBUG] first 10 sources: {data_sources[:10] if len(data_sources) > 0 else []}\")\n                                    print(f\"[DEBUG] unique sources: {np.unique(data_sources)}\")\n\n                            if data_sources is not None:\n                                # Group scores by data_source to report per-task statistics\n                                unique_sources = np.unique(data_sources)\n                                print(f\"[Task Statistics] Found {len(unique_sources)} unique tasks: {unique_sources}\")\n\n                                for source in unique_sources:\n                                    source_mask = data_sources == source\n                                    num_samples = int(np.sum(source_mask))\n\n                                    for key, values in reward_extra_infos_dict.items():\n                                        if values and len(values) > 0:\n                                            values_array = np.array(values)\n                                            # Only record numeric metrics\n                                            if np.issubdtype(values_array.dtype, np.number):\n                                                source_values = values_array[source_mask]\n                                                if len(source_values) > 0:\n                                                    metrics[f\"reward/{source}/{key}/mean\"] = float(np.mean(source_values))\n                                                    metrics[f\"reward/{source}/{key}/max\"] = float(np.max(source_values))\n                                                    metrics[f\"reward/{source}/{key}/min\"] = float(np.min(source_values))\n                                                    metrics[f\"reward/{source}/{key}/count\"] = num_samples\n                            else:\n                                print(f\"[WARNING] data_sources not found in batch.non_tensor_batch. 
Available keys: {list(batch.non_tensor_batch.keys())}\")\n\n                            # Global statistics (all tasks merged)\n                            for key, values in reward_extra_infos_dict.items():\n                                if values and len(values) > 0:\n                                    values_array = np.array(values)\n                                    # Only record numeric metrics\n                                    if np.issubdtype(values_array.dtype, np.number):\n                                        metrics[f\"reward/all/{key}/mean\"] = float(np.mean(values_array))\n                                        metrics[f\"reward/all/{key}/max\"] = float(np.max(values_array))\n                                        metrics[f\"reward/all/{key}/min\"] = float(np.min(values_array))\n\n                        # compute rewards; apply_kl_penalty if enabled\n                        if self.config.algorithm.use_kl_in_reward:\n                            batch, kl_metrics = apply_kl_penalty(\n                                batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(kl_metrics)\n                        else:\n                            batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                        # 🎯 Add a think-quality reward based on ground-truth (GT) PPL\n                        enable_sid_ppl_reward = self.config.get(\"enable_sid_ppl_reward\", False)\n                        if enable_sid_ppl_reward:\n                            gt_ppl_reward_weight = self.config.get(\"sid_ppl_reward_weight\", 0.1)\n\n                            # 1. Construct the probe data\n                            probe_batch_dict, probe_non_tensor_dict, probe_mapping = construct_gt_probe_data(\n                                batch=batch,\n                                tokenizer=self.tokenizer\n                            )\n\n                            if probe_batch_dict:\n                                # 2. Build a DataProto and compute log-probs\n                                # Note: the probe data must be moved onto the device first\n                                device = batch.batch[\"input_ids\"].device\n                                for k, v in probe_batch_dict.items():\n                                    probe_batch_dict[k] = v.to(device)\n\n                                # Convert the dict into a TensorDict\n                                probe_batch_size = probe_batch_dict[\"input_ids\"].shape[0]\n                                probe_tensor_dict = TensorDict(probe_batch_dict, batch_size=probe_batch_size)\n\n                                probe_batch = DataProto(\n                                    batch=probe_tensor_dict,\n                                    non_tensor_batch=probe_non_tensor_dict\n                                )\n\n                                # compute_log_prob returns a DataProto whose batch[\"old_log_probs\"] holds the log-probs\n                                probe_output = self.actor_rollout_wg.compute_log_prob(probe_batch)\n                                probe_log_probs = probe_output.batch[\"old_log_probs\"]  # (num_probes, seq_len)\n\n                                # 3. 
Extract the log-probs of the GT tokens and turn them into a reward\n                                # We aggregate the max reward per original_idx\n                                original_idx_to_rewards = defaultdict(list)\n                                original_idx_to_think_end = {}\n\n                                for i, mapping in enumerate(probe_mapping):\n                                    original_idx = mapping[\"original_idx\"]\n                                    gt_len = mapping[\"gt_len\"]\n                                    think_end_idx = mapping[\"think_end_idx\"]\n\n                                    original_idx_to_think_end[original_idx] = think_end_idx\n\n                                    # Take the log-probs of the last gt_len tokens.\n                                    # Note: old_log_probs are the log-probs of input_ids, where\n                                    # input_ids = [prompt, thought, </think>, GT];\n                                    # we only care about the GT part.\n                                    gt_log_probs = probe_log_probs[i, -gt_len:]\n\n                                    # Mean log-prob (equivalently, -log(PPL))\n                                    reward = gt_log_probs.mean().item()\n\n                                    original_idx_to_rewards[original_idx].append(reward)\n\n                                # 4. Write the reward back\n                                reward_added_count = 0\n                                reward_sum = 0.0\n                                max_reward_val = -float('inf')\n                                min_reward_val = float('inf')\n\n                                for i, rewards in original_idx_to_rewards.items():\n                                    # Take the max reward (the best-matching GT)\n                                    max_reward = max(rewards)\n\n                                    think_end_idx = original_idx_to_think_end[i]\n\n                                    # Make sure the index stays in bounds\n                                    if think_end_idx < batch.batch[\"token_level_rewards\"].shape[1]:\n                                        # Apply the weight\n                                        weighted_reward = max_reward * gt_ppl_reward_weight\n                                        batch.batch[\"token_level_rewards\"][i, think_end_idx] += weighted_reward\n\n                                        reward_added_count += 1\n                                        reward_sum += max_reward\n                                        max_reward_val = max(max_reward_val, max_reward)\n                                        min_reward_val = min(min_reward_val, max_reward)\n\n                                # Log metrics\n                                if reward_added_count > 0:\n                                    metrics[\"gt_ppl_reward/mean\"] = reward_sum / reward_added_count\n                                    metrics[\"gt_ppl_reward/max\"] = max_reward_val\n                                    metrics[\"gt_ppl_reward/min\"] = min_reward_val\n                                    metrics[\"gt_ppl_reward/count\"] = reward_added_count\n                                    print(f\"[Step {self.global_steps}] GT PPL Reward added to {reward_added_count} samples. 
Mean raw reward: {reward_sum / reward_added_count:.4f}\")\n\n                        # compute advantages, executed on the driver process\n\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                            \"norm_adv_by_std_in_grpo\", True\n                        )  # GRPO adv normalization factor\n\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                            config=self.config.algorithm,\n                            tokenizer=self.tokenizer,\n                        )\n\n                        if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO:\n                            # \"score\" holds the scalar hit reward (pass@1) returned by compute_score\n                            hit_rewards = batch.non_tensor_batch[\"score\"]\n                            if isinstance(hit_rewards, np.ndarray):\n                                hit_rewards_tensor = torch.tensor(hit_rewards, dtype=torch.float32)\n                            else:\n                                hit_rewards_tensor = torch.tensor(list(hit_rewards), dtype=torch.float32)\n\n                            # Group rollouts by uid\n                            uids = batch.non_tensor_batch[\"uid\"]\n                            unique_uids = np.unique(uids)\n\n                            zero_hit_reward_group_ratios = []\n                            all_group_zero_count = 0  # number of groups whose hit rewards are all zero\n\n                            for uid in unique_uids:\n                                # Find all samples belonging to the current uid\n                                uid_mask = (uids == uid)\n                                uid_hit_rewards = hit_rewards_tensor[uid_mask]\n\n                                # Count samples whose hit_reward is 0\n                                zero_count = (uid_hit_rewards == 0).sum().item()\n                                total_count = len(uid_hit_rewards)\n\n                                # Ratio of zero-hit_reward samples within the current group\n                                zero_ratio = zero_count / total_count if total_count > 0 else 0\n                                zero_hit_reward_group_ratios.append(zero_ratio)\n\n                                # If the whole group's hit_reward is 0, increment the counter\n                                if zero_count == total_count:\n                                    all_group_zero_count += 1\n\n                            # Compute summary statistics\n                            if len(zero_hit_reward_group_ratios) > 0:\n                                # Mean in-group ratio of zero-hit_reward samples\n                                mean_zero_hit_reward_ratio_in_group = np.mean(zero_hit_reward_group_ratios)\n                                # Fraction of groups whose hit_reward is entirely zero\n                                all_zero_group_ratio = all_group_zero_count / len(unique_uids)\n\n                                metrics[\"training/grpo_zero_hit_reward_ratio_in_group_mean\"] = mean_zero_hit_reward_ratio_in_group\n                                metrics[\"training/grpo_all_zero_hit_reward_group_ratio\"] = all_zero_group_ratio\n                                metrics[\"training/grpo_all_zero_hit_reward_group_count\"] = all_group_zero_count\n                                metrics[\"training/grpo_total_group_count\"] = len(unique_uids)\n\n            
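        # Worked example for the metrics above (hypothetical numbers): with two uid\n                    # groups whose hit rewards are [0, 0, 0, 0] and [1, 0, 0, 1], the per-group\n                    # zero ratios are [1.0, 0.5], the logged in-group mean is 0.75, and the\n                    # all-zero group ratio is 1 / 2 = 0.5; an all-zero group yields zero GRPO\n                    # advantage for every rollout in it.\n\n            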
        # update critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                            batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                            inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                            outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                            scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                            ground_truths = None\n                            if \"reward_model\" in batch.non_tensor_batch:\n                                ground_truths = [item[\"ground_truth\"] for item in batch.non_tensor_batch[\"reward_model\"]]\n                            if \"request_id\" in batch.non_tensor_batch:\n                                reward_extra_infos_dict.setdefault(\n                                    \"request_id\",\n                                    batch.non_tensor_batch[\"request_id\"].tolist(),\n                                )\n                            self._dump_generations(\n                                inputs=inputs,\n                                outputs=outputs,\n                                scores=scores,\n                                reward_extra_infos_dict=reward_extra_infos_dict,\n                                dump_path=rollout_data_dir,\n                                ground_truths=ground_truths,\n                            )\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                    ):\n                        with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                            val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                        metrics.update(val_metrics)\n\n                    # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.\n                    esi_close_to_expiration = should_save_ckpt_esi(\n                        max_steps_duration=self.max_steps_duration,\n                        
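# esi_redundant_time: extra safety margin reserved before the instance/plan expires\n                        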
redundant_time=self.config.trainer.esi_redundant_time,\n                    )\n                    # Check if the conditions for saving a checkpoint are met.\n                    # The conditions include a mandatory condition (1) and\n                    # one of the following optional conditions (2/3/4):\n                    # 1. The save frequency is set to a positive value.\n                    # 2. It's the last training step.\n                    # 3. The current step number is a multiple of the save frequency.\n                    # 4. The ESI (Elastic Server Instance)/training plan is close to expiration.\n                    if self.config.trainer.save_freq > 0 and (\n                        is_last_step\n                        or self.global_steps % self.config.trainer.save_freq == 0\n                        or esi_close_to_expiration\n                    ):\n                        if esi_close_to_expiration:\n                            print(\"Force saving checkpoint: ESI instance expiration approaching.\")\n                        with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                            self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    self._stop_profiling(do_profile)\n\n                steps_duration = timing_raw[\"step\"]\n                self.max_steps_duration = max(self.max_steps_duration, steps_duration)\n\n                # training metrics\n                metrics.update(\n                    {\n                        \"training/global_step\": self.global_steps,\n                        \"training/epoch\": epoch,\n                    }\n                )\n                # collect metrics\n                train_data_metrics = compute_data_metrics(batch=batch, use_critic=self.use_critic)\n                # Add train/ prefix to response_length and prompt_length metrics\n                train_data_metrics_prefixed = {}\n                for key, value in train_data_metrics.items():\n                    if key.startswith(\"response_length/\") or key.startswith(\"prompt_length/\"):\n                        train_data_metrics_prefixed[f\"train/{key}\"] = value\n                    else:\n                        train_data_metrics_prefixed[key] = value\n                metrics.update(train_data_metrics_prefixed)\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual and theoretical TFLOPs\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n\n                # this is experimental and may be changed/removed in the future in favor of a general-purpose one\n                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):\n                    self.train_dataloader.sampler.update(batch=batch)\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                progress_bar.update(1)\n                self.global_steps += 1\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                # this is experimental and may be changed/removed in the future\n                # in favor of a general-purpose data buffer pool\n                if 
hasattr(self.train_dataset, \"on_batch_end\"):\n                    # The dataset may be changed after each training batch\n                    self.train_dataset.on_batch_end(batch=batch)\n"
  },
  {
    "path": "verl_rl/recipe/onerec/onerec_recipe.py",
    "content": "from __future__ import annotations\n\nimport ast\nimport copy\nimport logging\nimport os\nimport re\nfrom collections import defaultdict\nfrom typing import Any, Optional\n\nimport datasets\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig, ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\n__all__ = [\"collate_fn\", \"OneRecDataset\", \"compute_score\"]\n\ndef collate_fn(samples: list[dict[str, Any]]) -> dict[str, Any]:\n    tensors: dict[str, list[torch.Tensor]] = defaultdict(list)\n    non_tensors: dict[str, list[Any]] = defaultdict(list)\n\n    for sample in samples:\n        for key, value in sample.items():\n            if isinstance(value, torch.Tensor):\n                tensors[key].append(value)\n            else:\n                non_tensors[key].append(value)\n\n    batch: dict[str, Any] = {}\n    for key, value in tensors.items():\n        batch[key] = torch.stack(value, dim=0)\n\n    for key, value in non_tensors.items():\n        batch[key] = np.array(value, dtype=object)\n\n    return batch\n\n\nclass OneRecDataset(Dataset):\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n        max_samples: int = -1,\n    ) -> None:\n        if not isinstance(data_files, (list, ListConfig)):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(list(data_files))\n        self.original_data_files = copy.deepcopy(list(data_files))\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.max_samples = max_samples\n        self.config = config\n\n        self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.return_multi_modal_inputs = config.get(\"return_multi_modal_inputs\", True)\n        self.enable_think = config.get(\"enable_think\", True)\n        self.enable_nonthink = config.get(\"enable_nonthink\", False)\n\n        self.use_force_prefix = config.get(\"use_force_prefix\", False)\n        self._FORCE_PREFIX_CONTENT = \"<think>\\n</think><|sid_begin|>\"\n\n        if self.enable_think and self.enable_nonthink:\n            raise ValueError(\"enable_think and enable_nonthink cannot be both True\") \n\n        self.num_workers = os.cpu_count()\n        self.use_shm = config.get(\"use_shm\", False)\n        self.serialize_dataset = False\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self, use_origin_parquet: bool = False) -> None:\n        from verl.utils.fs 
import copy_to_local\n\n        target_files = self.original_data_files if use_origin_parquet else self.data_files\n        for idx, parquet_file in enumerate(target_files):\n            local_path = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n            target_files[idx] = local_path\n\n        if use_origin_parquet:\n            self.data_files = target_files\n\n    def _read_files_and_tokenize(self) -> None:\n        dataframes: list[datasets.Dataset] = []\n        for parquet_file in self.data_files:\n            dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n            dataframes.append(dataframe)\n\n        self.dataframe = datasets.concatenate_datasets(dataframes)  # type: ignore[attr-defined]\n        logger.info(\"dataset len: %s\", len(self.dataframe))\n\n        if self.max_samples > 0 and self.max_samples < len(self.dataframe):\n            total_samples = len(self.dataframe)\n            # Read shuffle/seed from the dataset config; they are not instance attributes.\n            if self.config.get(\"shuffle\", False):\n                rng = np.random.default_rng(self.config.get(\"seed\", None))\n                indices = rng.choice(total_samples, size=self.max_samples, replace=False)\n            else:\n                indices = np.arange(self.max_samples)\n            self.dataframe = self.dataframe.select(indices.tolist())\n            logger.info(\"selected %s samples out of %s\", self.max_samples, total_samples)\n\n        self.dataframe = self.dataframe.map(\n            self._extract_prompt_fields,\n            num_proc=self.num_workers,\n            desc=\"Extract prompts and reward annotations\",\n        )\n\n        logger.info(\"processed dataset len: %s\", len(self.dataframe))\n        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)\n\n    def _extract_prompt_fields(self, row: dict[str, Any]) -> dict[str, Any]:\n        raw_messages = row.get(\"messages\")\n        if isinstance(raw_messages, str):\n            messages = ast.literal_eval(raw_messages)\n        else:\n            messages = raw_messages or []\n\n        clean_chats = [\n            {\n                \"role\": message.get(\"role\"),\n                \"content\": \"\".join(segment.get(\"text\", \"\") for segment in message.get(\"content\", []) if segment.get(\"type\") == \"text\"),\n            }\n            for message in messages\n        ]\n\n        if not clean_chats:\n            raise ValueError(\"Sample has empty messages; please check data integrity.\")\n\n        prompt_messages = clean_chats[:-1]\n\n        # Append /think or /no_think suffix to user messages based on config\n        if self.enable_think:\n            for message in prompt_messages:\n                if message[\"role\"] == \"user\":\n                    message[\"content\"] = message[\"content\"] + \"/think\"\n        if self.enable_nonthink:\n            for message in prompt_messages:\n                if message[\"role\"] == \"user\":\n                    message[\"content\"] = message[\"content\"] + \"/no_think\"\n\n        ground_truth_message = clean_chats[-1][\"content\"]\n\n        reward_payload = {\n            \"ground_truth\": ground_truth_message,\n            \"style\": \"rule\",\n        }\n\n        row[self.prompt_key] = prompt_messages\n        row[\"reward_model\"] = reward_payload\n        return row\n\n    def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset) -> datasets.Dataset:\n        if not self.filter_overlong_prompts:\n            return dataframe\n\n        tokenizer = 
self.tokenizer\n        processor = self.processor\n        prompt_key = self.prompt_key\n        image_key = self.image_key\n        video_key = self.video_key\n\n        if processor is not None:\n            from verl.utils.dataset.vision_utils import process_image, process_video\n\n            def doc_length(doc: dict[str, Any]) -> int:\n                messages = self._build_messages(dict(doc))\n                raw_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n                images = [process_image(image) for image in doc.get(image_key, [])]\n                videos = [process_video(video) for video in doc.get(video_key, [])]\n                encoded = processor(text=[raw_prompt], images=images or None, videos=videos or None, return_tensors=\"pt\")\n                return int(encoded[\"input_ids\"].shape[-1])\n\n        else:\n\n            def doc_length(doc: dict[str, Any]) -> int:\n                messages = doc[prompt_key]\n                return len(tokenizer.apply_chat_template(messages, add_generation_prompt=True))\n\n        filtered = dataframe.filter(\n            lambda doc: doc_length(doc) <= self.max_prompt_length - 10,\n            num_proc=self.num_workers,\n            desc=f\"Filtering prompts longer than {self.max_prompt_length - 10} tokens\",\n        )\n\n        logger.info(\"filtered dataset len: %s\", len(filtered))\n        return filtered\n\n    def resume_dataset_state(self) -> None:\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)\n            self._read_files_and_tokenize()\n        else:\n            logger.warning(\"resume with serialized dataloader, consider restarting from scratch for better perf\")\n\n    def __len__(self) -> int:  # type: ignore[override]\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict[str, Any]) -> list[dict[str, Any]]:\n        messages: list[dict[str, Any]] = example.pop(self.prompt_key)\n\n        if self.image_key in example or self.video_key in example:\n            for message in messages:\n                content = message[\"content\"]\n                segments = [segment for segment in re.split(r\"(<image>|<video>)\", content) if segment]\n                parsed_segments = []\n                for segment in segments:\n                    if segment == \"<image>\":\n                        parsed_segments.append({\"type\": \"image\"})\n                    elif segment == \"<video>\":\n                        parsed_segments.append({\"type\": \"video\"})\n                    else:\n                        parsed_segments.append({\"type\": \"text\", \"text\": segment})\n                message[\"content\"] = parsed_segments\n\n        return messages\n\n    def __getitem__(self, index: int) -> dict[str, Any]:  # type: ignore[override]\n        row: dict[str, Any] = dict(self.dataframe[index])\n        messages = self._build_messages(dict(row))\n        model_inputs: dict[str, Any] = {}\n\n        if self.processor is not None:\n            from verl.utils.dataset.vision_utils import process_image, process_video\n\n            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n\n            if self.use_force_prefix:\n                raw_prompt = raw_prompt + self._FORCE_PREFIX_CONTENT\n\n            multi_modal_data: dict[str, Any] = {}\n\n            images = None\n            
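# Gather any images/videos referenced by <image>/<video> placeholders so the\n            # processor can encode them together with the text prompt.\n            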
if self.image_key in row and row.get(self.image_key):\n                images = [process_image(image) for image in row.pop(self.image_key)]\n                multi_modal_data[\"image\"] = images\n\n            videos = None\n            if self.video_key in row and row.get(self.video_key):\n                videos = [process_video(video) for video in row.pop(self.video_key)]\n                multi_modal_data[\"video\"] = [video.numpy() for video in videos]\n\n            model_inputs = self.processor(\n                text=[raw_prompt],\n                images=images,\n                videos=videos,\n                return_tensors=\"pt\",\n            )\n\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            row[\"multi_modal_data\"] = multi_modal_data\n            if self.return_multi_modal_inputs:\n                mm_inputs = dict(model_inputs)\n                mm_inputs.pop(\"second_per_grid_ts\", None)\n                row[\"multi_modal_inputs\"] = mm_inputs\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n\n            if self.use_force_prefix:\n                raw_prompt = raw_prompt + self._FORCE_PREFIX_CONTENT\n\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = verl_F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if (\n            self.processor is not None\n            and hasattr(self.processor, \"image_processor\")\n            and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__\n        ):\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            position_ids = [\n                get_rope_index(\n                    self.processor,\n                    input_ids=input_ids[0],\n                    image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                    video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                    second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                    attention_mask=attention_mask[0],\n                )\n            ]\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row[\"input_ids\"] = input_ids[0]\n        row[\"attention_mask\"] = attention_mask[0]\n        row[\"position_ids\"] = position_ids[0]\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            raw_prompt_ids = self._truncate_ids(raw_prompt_ids)\n\n        row[\"raw_prompt_ids\"] = raw_prompt_ids\n        if self.return_raw_chat:\n            row[\"raw_prompt\"] = messages\n        if self.return_full_prompt:\n            row[\"full_prompts\"] = raw_prompt\n\n        extra_info = row.get(\"extra_info\", {}) or {}\n        row[\"index\"] = extra_info.get(\"index\", index)\n        row[\"tools_kwargs\"] = extra_info.get(\"tools_kwargs\", {})\n        row[\"interaction_kwargs\"] = extra_info.get(\"interaction_kwargs\", 
{})\n\n        if \"source\" not in row and \"data_source\" not in row:\n            row[\"data_source\"] = \"unknown\"\n            logger.warning(\"No source/data_source field found for index %s, set to 'unknown'\", row[\"index\"])\n\n        if self.need_tools_kwargs and not row[\"tools_kwargs\"]:\n            logger.warning(\"tools_kwargs is empty for index %s, data source: %s\", row[\"index\"], row.get(\"data_source\", row.get(\"source\", \"unknown\")))\n\n        return row\n\n    def _truncate_ids(self, token_ids: list[int]) -> list[int]:\n        if self.truncation == \"left\":\n            return token_ids[-self.max_prompt_length :]\n        if self.truncation == \"right\":\n            return token_ids[: self.max_prompt_length]\n        if self.truncation == \"middle\":\n            left = self.max_prompt_length // 2\n            right = self.max_prompt_length - left\n            return token_ids[:left] + token_ids[-right:]\n        if self.truncation == \"error\":\n            raise RuntimeError(\n                f\"Prompt length {len(token_ids)} exceeds max_prompt_length={self.max_prompt_length}. \"\n                \"Consider increasing max_prompt_length or enabling truncation.\"\n            )\n        raise ValueError(f\"Unsupported truncation mode: {self.truncation}\")\n\n    def __getstate__(self) -> dict[str, Any]:\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n            state.pop(\"dataframe\", None)\n            return state\n        return self.__dict__.copy()\n\n\nSLOT_PATTERN = re.compile(r\"<s_a_(\\d+)><s_b_(\\d+)><s_c_(\\d+)>\")\n\n\ndef _extract_all_tuples(text: Any) -> list[tuple[str, str, str]]:\n    if not isinstance(text, str):\n        logger.warning(\"_extract_all_tuples received non-string input: %s\", type(text))\n        return []\n\n    matches = SLOT_PATTERN.findall(text)\n    return [tuple(match) for match in matches] if matches else []\n\n\ndef think_format_reward(prediction: str) -> float:\n    \"\"\"Check if prediction contains valid think format.\n\n    Args:\n        prediction: Model prediction text.\n\n    Returns:\n        1.0 if contains valid <think>...</think> with content length > 10, else 0.0.\n    \"\"\"\n    if \"<think>\" not in prediction or \"</think>\" not in prediction:\n        return 0.0\n\n    start_idx = prediction.find(\"<think>\") + len(\"<think>\")\n    end_idx = prediction.find(\"</think>\")\n\n    if end_idx < start_idx:\n        return 0.0\n\n    content = prediction[start_idx:end_idx]\n    content_stripped = content.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\r\", \"\").replace(\"\\t\", \"\")\n\n    return 1.0 if len(content_stripped) > 10 else 0.0\n\n\ndef partial_hit_reward(prediction: str, ground_truth: str) -> float:\n    \"\"\"Calculate hierarchical matching reward with partial match support.\n\n    Args:\n        prediction: Model prediction text, may contain multiple sids.\n        ground_truth: Ground truth text, may contain multiple sids.\n\n    Returns:\n        Weighted match score:\n        - Full match (s_a, s_b, s_c): 100 points\n        - s_a and s_b match: 10 points\n        - Only s_a match: 1 point\n        - No match: 0 points\n        Returns average score across all predicted sids.\n    \"\"\"\n    pred_tuples = _extract_all_tuples(prediction)\n    gt_tuples = _extract_all_tuples(ground_truth)\n\n    if not pred_tuples or not gt_tuples:\n        return 0.0\n\n    total_reward = 0.0\n\n    # Find best match for each predicted 
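sid; e.g. a prediction tuple (\"1\",\"2\",\"3\")\n    # against gt (\"1\",\"2\",\"4\") scores 10.0 because s_a and s_b match but s_c differs.\n    # Below, find the best match for each predicted 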
sid and calculate score\n    for pred_tuple in pred_tuples:\n        max_score = 0.0\n        \n        for gt_tuple in gt_tuples:\n            # Full match (s_a, s_b, s_c)\n            if pred_tuple == gt_tuple:\n                max_score = max(max_score, 100.0)\n            # s_a and s_b match\n            elif pred_tuple[:2] == gt_tuple[:2]:\n                max_score = max(max_score, 10.0)\n            # Only s_a match\n            elif pred_tuple[0] == gt_tuple[0]:\n                max_score = max(max_score, 1.0)\n        \n        total_reward += max_score\n    \n    # Return average score to avoid inflated scores with multiple predictions\n    return total_reward / len(pred_tuples)\n\ndef hit_reward(prediction: str, ground_truth: str) -> float:\n    \"\"\"Calculate hit reward: intersection ratio between prediction and ground truth.\n\n    Args:\n        prediction: Model prediction text, may contain multiple sids.\n        ground_truth: Ground truth text, may contain multiple sids.\n\n    Returns:\n        Hit reward: intersection count / prediction count.\n    \"\"\"\n    if \"</think>\" in prediction and \"<think>\" in prediction:\n        think_end_idx = prediction.find(\"</think>\") + len(\"</think>\")\n        prediction = prediction[think_end_idx:]\n    else:\n        return 0.0\n\n\n    pred_tuples = _extract_all_tuples(prediction)\n    gt_tuples = _extract_all_tuples(ground_truth)\n    if not pred_tuples or not gt_tuples:\n        return 0.0\n\n    pred_set = set(pred_tuples)\n    gt_set = set(gt_tuples)\n    return len(pred_set & gt_set) / len(pred_tuples)\n\ndef first_sid_hit_reward(prediction: str, ground_truth: str) -> float:\n    \"\"\"Calculate Pass@1 reward: whether the first sid after </think> hits ground truth.\n\n    Args:\n        prediction: Model prediction text.\n        ground_truth: Ground truth text.\n\n    Returns:\n        1.0 if first sid is in ground truth, else 0.0.\n    \"\"\"\n    # Extract content after </think>\n    if \"</think>\" in prediction and \"<think>\" in prediction:\n        think_end_idx = prediction.find(\"</think>\") + len(\"</think>\")\n        prediction = prediction[think_end_idx:]\n    else:\n        return 0.0\n\n    pred_tuples = _extract_all_tuples(prediction)\n    if not pred_tuples:\n        return 0.0\n\n    # Get the first predicted sid tuple\n    first_pred_tuple = pred_tuples[0]\n\n    gt_tuples = _extract_all_tuples(ground_truth)\n    if not gt_tuples:\n        return 0.0\n\n    gt_set = set(gt_tuples)\n    \n    return float(first_pred_tuple in gt_set)\n\ndef pass_rate(prediction: str, ground_truth: str) -> float:\n    \"\"\"Calculate pass rate: whether prediction and ground truth have intersection.\n\n    Args:\n        prediction: Model prediction text, may contain multiple sids.\n        ground_truth: Ground truth text, may contain multiple sids.\n\n    Returns:\n        1.0 if there is intersection, else 0.0.\n    \"\"\"\n    pred_tuples = _extract_all_tuples(prediction)\n    gt_tuples = _extract_all_tuples(ground_truth)\n    if not pred_tuples or not gt_tuples:\n        return 0.0\n\n    # Convert to set for intersection calculation\n    pred_set = set(pred_tuples)\n    gt_set = set(gt_tuples)\n    intersection_count = len(pred_set & gt_set)\n    \n    return float(intersection_count > 0)\n\n\n\ndef compute_score(\n    data_source: str,  # noqa: ARG001\n    solution_str: str,\n    ground_truth: str,\n    extra_info: dict[str, Any],  # noqa: ARG001\n) -> dict[str, float]:\n    \"\"\"Compute reward scores for 
recommendation results.\n\n    Args:\n        data_source: Data source identifier (kept for API compatibility).\n        solution_str: Model generated prediction text.\n        ground_truth: Ground truth text.\n        extra_info: Extra information (kept for API compatibility).\n\n    Returns:\n        Dictionary containing various reward scores.\n    \"\"\"\n    prediction = solution_str\n    format_reward_value = think_format_reward(prediction)\n    partial_hit_reward_value = partial_hit_reward(prediction, ground_truth)\n    hit_reward_value = hit_reward(prediction, ground_truth)\n    pass_rate_value = pass_rate(prediction, ground_truth)\n    pass_at_1_value = first_sid_hit_reward(prediction, ground_truth)\n\n    return {\n        \"score\": pass_at_1_value,\n        \"format_reward\": format_reward_value,\n        \"partial_hit_reward\": partial_hit_reward_value,\n        \"hit_reward\": hit_reward_value,\n        \"pass_rate\": pass_rate_value,\n        \"pass_at_1\": pass_at_1_value,\n    }"
  },
  {
    "path": "verl_rl/recipe/onerec/onerec_vllm_rollout.py",
    "content": "import numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom vllm import SamplingParams\nfrom vllm.lora.request import LoRARequest\nfrom verl import DataProto\nfrom verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length\nfrom verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout, _pre_process_inputs\n\ntry:\n    from vllm.sampling_params import BeamSearchParams\nexcept ImportError:\n    BeamSearchParams = None\n\n\nclass OneRecvLLMRollout(vLLMRollout):\n    \"\"\"\n    Custom vLLM Rollout for OneRec with Two-Stage Generation:\n    1. Sample CoT until </think>\n    2. Beam search items using Prompt + CoT + Prefix\n    \"\"\"\n\n    @torch.no_grad()\n    def _two_stage_generation(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"\n        Two-stage generation:\n        1. Sample CoT until </think>.\n        2. Beam search items using Prompt + CoT + Prefix.\n        \"\"\"\n        idx = prompts.batch[\"input_ids\"]\n        attention_mask = prompts.batch[\"attention_mask\"]\n        position_ids = prompts.batch[\"position_ids\"]\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n        batch_size = idx.size(0)\n        \n        # Prepare vllm inputs (same as standard)\n        non_tensor_batch = prompts.non_tensor_batch\n        if \"raw_prompt_ids\" not in non_tensor_batch:\n            non_tensor_batch[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object\n            )\n            \n        if \"multi_modal_data\" in non_tensor_batch:\n            vllm_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                non_tensor_batch.pop(\"raw_prompt_ids\"), non_tensor_batch.pop(\"multi_modal_data\"), strict=True\n            ):\n                vllm_inputs.append({\"prompt_token_ids\": raw_prompt_ids, \"multi_modal_data\": multi_modal_data})\n        else:\n            vllm_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop(\"raw_prompt_ids\")\n            ]\n            \n        for input_data in vllm_inputs:\n            if isinstance(input_data[\"prompt_token_ids\"], np.ndarray):\n                input_data[\"prompt_token_ids\"] = input_data[\"prompt_token_ids\"].tolist()\n\n        # Stage 1: CoT Sampling\n        # Use standard sampling parameters but stop at </think>\n        # Read stage1_max_tokens from config (set via ++actor_rollout_ref.rollout.stage1_max_tokens)\n        stage1_max_tokens = kwargs.get(\"stage1_max_tokens\",\n                            getattr(self.config, \"stage1_max_tokens\",\n                            kwargs.get(\"max_tokens\", 1024)))  # fallback to max_tokens or 1024\n\n        cot_sampling_params = SamplingParams(\n            n=1, # We generate 1 CoT per prompt (prompts are already repeated by Trainer if needed)\n            temperature=kwargs.get(\"temperature\", 1.0),\n            top_p=kwargs.get(\"top_p\", 1.0),\n            top_k=kwargs.get(\"top_k\", -1),\n            max_tokens=stage1_max_tokens,  # Use stage1_max_tokens for CoT length\n            stop=[\"</think>\"],\n            include_stop_str_in_output=True,\n        )\n\n        print(f\"[TwoStage] Stage 1 params: max_tokens={stage1_max_tokens}, temperature={kwargs.get('temperature', 1.0)}\")\n        \n        lora_requests = None # Assuming LoRA logic is same as standard, omitted for brevity or copy if needed\n        if 
self.lora_kwargs:\n             # Copy lora logic\n             lora_int_ids = list(self.inference_engine.llm_engine.list_loras())\n             if len(lora_int_ids) > 0:\n                 lora_int_id = lora_int_ids[0]\n                 lora_requests = [\n                     LoRARequest(lora_name=f\"{lora_int_id}\", lora_int_id=lora_int_id, lora_path=\"/simon-stub-path\")\n                 ] * batch_size\n\n        # print(f\"[TwoStage] Starting Stage 1: CoT Generation...\")\n        cot_outputs = self.inference_engine.generate(\n            prompts=vllm_inputs,\n            sampling_params=cot_sampling_params,\n            lora_request=lora_requests,\n            use_tqdm=False,\n        )\n        \n        # Process Stage 1 Outputs and Prepare Stage 2 Inputs\n        stage2_inputs = []\n        cot_responses = []\n\n        # We need tokenizer to encode the prefix\n        tokenizer = self.inference_engine.get_tokenizer()\n        # Prefix: \\n<|sid_begin|> (assuming </think> is already in output)\n        # Note: if model stopped because of length, </think> might be missing.\n        # We should handle this.\n        prefix_ids = tokenizer.encode(\"\\n<|sid_begin|>\", add_special_tokens=False)\n\n        # Get vocab size for OOV filtering\n        vocab_size = len(tokenizer)\n\n        for i, output in enumerate(cot_outputs):\n            # Get CoT tokens\n            cot_token_ids = list(output.outputs[0].token_ids)\n\n            # Filter OOV tokens from CoT output\n            cot_token_ids_filtered = [tid for tid in cot_token_ids if tid < vocab_size]\n            if len(cot_token_ids_filtered) < len(cot_token_ids):\n                print(f\"[TwoStage] Filtered {len(cot_token_ids) - len(cot_token_ids_filtered)} OOV tokens from CoT output {i}\")\n\n            cot_responses.append(cot_token_ids_filtered)\n\n            # Construct Stage 2 Prompt: Original Prompt + CoT + Prefix\n            original_prompt_ids = vllm_inputs[i][\"prompt_token_ids\"]\n            new_prompt_ids = original_prompt_ids + cot_token_ids_filtered + prefix_ids\n\n            stage2_input = {\"prompt_token_ids\": new_prompt_ids}\n            if \"multi_modal_data\" in vllm_inputs[i]:\n                stage2_input[\"multi_modal_data\"] = vllm_inputs[i][\"multi_modal_data\"]\n            stage2_inputs.append(stage2_input)\n\n        # Stage 2: Item Beam Search\n        # Read from kwargs first, then fallback to config, then default\n        beam_width = kwargs.get(\"stage2_beam_size\", getattr(self.config, \"stage2_beam_size\", 32))\n        # Support both stage2_max_tokens and stage2_num_tokens (for backward compatibility)\n        max_tokens_item = kwargs.get(\"stage2_max_tokens\",\n                          kwargs.get(\"stage2_num_tokens\",\n                          getattr(self.config, \"stage2_max_tokens\",\n                          getattr(self.config, \"stage2_num_tokens\", 16))))\n\n        print(f\"[TwoStage] Stage 2 params: beam_width={beam_width}, max_tokens={max_tokens_item}, batch_size={batch_size}\")\n\n        if BeamSearchParams is None:\n             raise ImportError(\"BeamSearchParams not available, cannot run Stage 2\")\n\n        beam_params = BeamSearchParams(\n            beam_width=beam_width,\n            max_tokens=max_tokens_item,\n        )\n\n        # Call beam search (aligned with standard implementation)\n        item_outputs = self.inference_engine.beam_search(\n            prompts=stage2_inputs,\n            params=beam_params,\n        )\n\n        # Post-process beam search 
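outputs. Each returned sequence's .tokens\n        # includes the stage-2 prompt; responses are sliced from the ORIGINAL prompt\n        # length, so they keep CoT + \"\\n<|sid_begin|>\" + item tokens.\n        # Post-process beam search 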
outputs (aligned with standard beam search logic)\n        # For two-stage rollout, always return all beams for both training and evaluation\n        return_all_beams = kwargs.get(\"return_all_beams\", True)\n        # For two-stage rollout, n_beams_to_return should be beam_width, not kwargs[\"n\"] (which is rollout_n)\n        n_beams_to_return = beam_width\n\n        print(f\"[TwoStage] Post-process: return_all_beams={return_all_beams}, n_beams_to_return={n_beams_to_return}\")\n\n        response = []\n\n        if return_all_beams:\n            # Return all beams, expand output\n            # Output will be exactly batch_size * n_beams_to_return (pad if needed)\n            expanded_idx = []\n            beam_indices = []  # Track which beam index within each prompt\n\n            for i, output in enumerate(item_outputs):\n                # Prompt length including CoT + Prefix\n                stage2_prompt_len = len(stage2_inputs[i][\"prompt_token_ids\"])\n                original_prompt_len = len(vllm_inputs[i][\"prompt_token_ids\"])\n\n                # Get top n beams for this prompt, pad if not enough\n                num_seqs = len(output.sequences)\n                for seq_idx in range(n_beams_to_return):\n                    if seq_idx < num_seqs:\n                        best_seq = output.sequences[seq_idx]\n                        full_seq = best_seq.tokens\n                        # Response = full_seq - original_prompt (not stage2_prompt!)\n                        response_ids = full_seq[original_prompt_len:]\n                    else:\n                        # Pad with first beam's result if not enough beams\n                        best_seq = output.sequences[0]\n                        full_seq = best_seq.tokens\n                        response_ids = full_seq[original_prompt_len:]\n                    response.append(response_ids)\n                    expanded_idx.append(i)\n                    beam_indices.append(seq_idx)\n\n            # Expand idx, attention_mask, position_ids to match expanded output\n            idx = idx[expanded_idx]  # (batch_size * n, prompt_length)\n            attention_mask = attention_mask[expanded_idx]\n            position_ids = position_ids[expanded_idx]\n\n            # Expand non_tensor_batch to match expanded output\n            expanded_non_tensor_batch = {}\n            for key, val in non_tensor_batch.items():\n                if isinstance(val, np.ndarray):\n                    expanded_non_tensor_batch[key] = val[expanded_idx]\n                elif isinstance(val, list):\n                    expanded_non_tensor_batch[key] = [val[i] for i in expanded_idx]\n                else:\n                    expanded_non_tensor_batch[key] = val\n            non_tensor_batch = expanded_non_tensor_batch\n\n            # Store beam indices for reference\n            non_tensor_batch[\"_beam_indices\"] = np.array(beam_indices, dtype=np.int64)\n\n            batch_size = len(response)  # Update batch_size\n\n            print(f\"[TwoStage] Expanded output: original_bs={len(item_outputs)}, expanded_bs={batch_size}, n_beams={n_beams_to_return}\")\n        else:\n            # Original path: use beam_idx to select specific beam\n            beam_idxs = non_tensor_batch.get(\"beam_idx\", None)\n\n            for i, output in enumerate(item_outputs):\n                original_prompt_len = len(vllm_inputs[i][\"prompt_token_ids\"])\n\n                seq_idx = 0\n                if beam_idxs is not None:\n                    seq_idx = beam_idxs[i]\n\n  
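              # Fall back to the first beam when the requested beam index is out of range.\n  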
                if seq_idx >= len(output.sequences):\n                    seq_idx = 0\n\n                best_seq = output.sequences[seq_idx]\n                full_seq = best_seq.tokens\n                response_ids = full_seq[original_prompt_len:]\n                response.append(response_ids)\n\n        # Pad responses\n        response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)\n\n        if self.config.calculate_log_probs:\n            # Two-stage rollout does not recompute per-token log-probs; fill a zero\n            # placeholder so downstream consumers still find \"rollout_log_probs\"\n            rollout_log_probs = torch.zeros_like(response, dtype=torch.float32)\n\n        seq = torch.cat([idx, response], dim=-1)\n\n        # Position IDs & Attention Mask Update (standard logic)\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)\n\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)\n\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        batch = TensorDict(\n            {\n                \"prompts\": idx,\n                \"responses\": response,\n                \"input_ids\": seq,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n        if self.config.calculate_log_probs:\n            batch[\"rollout_log_probs\"] = rollout_log_probs\n\n        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"\n        Generate sequences using two-stage generation.\n        \"\"\"\n        # Extract params from meta_info and merge into kwargs\n        for key in [\"max_tokens\", \"temperature\", \"n\", \"top_p\", \"top_k\",\n                    \"stage2_beam_size\", \"stage2_max_tokens\", \"return_all_beams\"]:\n            if key in prompts.meta_info:\n                kwargs[key] = prompts.meta_info[key]\n\n        return self._two_stage_generation(prompts, **kwargs)\n"
  },
  {
    "path": "verl_rl/recipe/onerec/run_grpo.sh",
    "content": "#!/bin/bash\n# GRPO Training Script with Two-Stage Rollout\n# Two-Stage Rollout: first generate to </think>, then insert <sid_begin> and beam search\n\nset -e\n\n# ============================================================================\n# Cluster Configuration (auto-detect from Ray)\n# ============================================================================\nRAY_INFO=$(python -c \"import ray; ray.init(address='auto', ignore_reinit_error=True); nodes = [n for n in ray.nodes() if n['Alive']]; gpus=next((int(n.get('Resources',{}).get('GPU',0)) for n in nodes if n.get('Resources',{}).get('GPU',0)>0), 0); print(f'{len(nodes)} {gpus}')\" 2>/dev/null)\n\nexport N_NODES=$(echo $RAY_INFO | awk '{print $1}')\nexport N_GPUS=$(echo $RAY_INFO | awk '{print $2}')\n\nif [ -z \"$N_NODES\" ] || [ -z \"$N_GPUS\" ] || [ \"$N_NODES\" -eq 0 ]; then\n    echo \"Could not detect Ray cluster. Using defaults: N_NODES=1, N_GPUS=8\"\n    export N_NODES=1\n    export N_GPUS=8\nelse\n    echo \"Detected Ray cluster: $N_NODES nodes, $N_GPUS GPUs per node\"\nfi\n\nPROJECT_DIR=\"$(cd \"$(dirname \"$0\")/../..\" && pwd)\"\nSCRIPT_DIR=\"$(cd \"$(dirname \"$0\")\" && pwd)\"\n\n# ============================================================================\n# Model Configuration\n# ============================================================================\nexport BASE_MODEL=${BASE_MODEL:-\"/path/to/your/model\"}\nexport ROLLOUT_TP_SIZE=${ROLLOUT_TP_SIZE:-1}\nexport VLLM_ATTENTION_BACKEND=XFORMERS\n\n# ============================================================================\n# Training Hyperparameters\n# ============================================================================\nexport LEARNING_RATE=${LEARNING_RATE:-2e-6}\nexport KL_LOSS_COEF=${KL_LOSS_COEF:-0.001}\nexport TEMPERATURE=${TEMPERATURE:-1}\n\n# ============================================================================\n# Batch Size Configuration\n# ============================================================================\nexport USE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True}\nexport MAX_TOKENS_PER_GPU=${MAX_TOKENS_PER_GPU:-40960}\nexport TRAIN_BATCH_SIZE=$((N_GPUS * N_NODES))\n\n# ============================================================================\n# Rollout Configuration\n# ============================================================================\nexport ROLLOUT_N=${ROLLOUT_N:-1}\nexport STAGE2_BEAM_SIZE=${STAGE2_BEAM_SIZE:-32}\nexport RESPONSE_LENGTH=${RESPONSE_LENGTH:-2048}\nexport STAGE1_MAX_TOKENS=${STAGE1_MAX_TOKENS:-1024}\nexport STAGE2_NUM_TOKENS=${STAGE2_NUM_TOKENS:-3}\n\n# Think mode configuration\nexport ENABLE_THINK=${ENABLE_THINK:-False}\nexport ENABLE_NONTHINK=${ENABLE_NONTHINK:-False}\nexport USE_FORCE_PREFIX=${USE_FORCE_PREFIX:-False}\n\n# ============================================================================\n# Data Configuration\n# ============================================================================\nexport DATA_DIR=${DATA_DIR:-\"$(realpath ../output/rl_data)\"}\nexport TRAIN_FILES=${TRAIN_FILES:-\"[$DATA_DIR/train.parquet]\"}\nexport VAL_FILES=${VAL_FILES:-\"[$DATA_DIR/test.parquet]\"}\n\n# ============================================================================\n# Output Configuration\n# ============================================================================\nexport PROJECT_NAME=${PROJECT_NAME:-\"OneRec_RL\"}\nexport EXPERIMENT_NAME=${EXPERIMENT_NAME:-\"grpo_two_stage\"}\nexport OUTPUT_DIR=${OUTPUT_DIR:-\"./output\"}\nexport WANDB_MODE=${WANDB_MODE:-offline}\n\n# 
============================================================================\n# Network Configuration (for distributed training)\n# ============================================================================\nexport TCP_NIC=$(ifconfig 2>/dev/null | grep -B1 \" \"$(hostname -i 2>/dev/null)\" \" | grep -o \"^\\w*\" || echo \"eth0\")\nexport NCCL_IB_DISABLE=${NCCL_IB_DISABLE:-0}\nexport NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:-3}\n\n# ============================================================================\n# Print Configuration\n# ============================================================================\necho \"===================================\"\necho \"GRPO Training with Two-Stage Rollout\"\necho \"===================================\"\necho \"Model: $BASE_MODEL\"\necho \"Cluster: $N_NODES nodes x $N_GPUS GPUs\"\necho \"Batch Size: $TRAIN_BATCH_SIZE\"\necho \"Learning Rate: $LEARNING_RATE\"\necho \"Rollout N: $ROLLOUT_N\"\necho \"Stage2 Beam Size: $STAGE2_BEAM_SIZE\"\necho \"Enable Think: $ENABLE_THINK\"\necho \"Enable NonThink: $ENABLE_NONTHINK\"\necho \"===================================\"\n\n# ============================================================================\n# Launch Training\n# ============================================================================\nmkdir -p logs\n\n# Make conda usable from a non-interactive shell before activating the environment\nsource \"$(conda info --base)/etc/profile.d/conda.sh\"\nconda activate verl\n\npython3 -u -m recipe.onerec.main_onerec_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$TRAIN_FILES \\\n    data.val_files=$VAL_FILES \\\n    data.max_prompt_length=10240 \\\n    ++data.enable_think=$ENABLE_THINK \\\n    ++data.enable_nonthink=$ENABLE_NONTHINK \\\n    ++data.use_force_prefix=$USE_FORCE_PREFIX \\\n    data.prompt_key='prompt' \\\n    data.shuffle=True \\\n    data.max_response_length=$RESPONSE_LENGTH \\\n    data.train_batch_size=$TRAIN_BATCH_SIZE \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.custom_cls.path=$SCRIPT_DIR/onerec_recipe.py \\\n    data.custom_cls.name=OneRecDataset \\\n    data.reward_fn_key='source' \\\n    ++data.data_source_key='source' \\\n    actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \\\n    actor_rollout_ref.actor.entropy_checkpointing=True \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    actor_rollout_ref.rollout.calculate_log_probs=False \\\n    actor_rollout_ref.actor.clip_ratio_high=0.28 \\\n    actor_rollout_ref.model.enable_activation_offload=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    custom_reward_function.path=$SCRIPT_DIR/onerec_recipe.py \\\n    custom_reward_function.name=compute_score \\\n    actor_rollout_ref.actor.use_dynamic_bsz=$USE_DYNAMIC_BSZ \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_TOKENS_PER_GPU \\\n    actor_rollout_ref.rollout.max_num_seqs=2048 \\\n    actor_rollout_ref.actor.optim.lr=$LEARNING_RATE \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \\\n    actor_rollout_ref.actor.optim.weight_decay=0.1 \\\n    actor_rollout_ref.model.path=$BASE_MODEL \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.rollout.n=$ROLLOUT_N \\\n    actor_rollout_ref.rollout.dtype=bfloat16 \\\n    
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \\\n    actor_rollout_ref.rollout.name=two_stage \\\n    ++actor_rollout_ref.rollout.backend=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    ++actor_rollout_ref.rollout.max_length=$RESPONSE_LENGTH \\\n    ++actor_rollout_ref.rollout.stage1_max_tokens=$STAGE1_MAX_TOKENS \\\n    ++actor_rollout_ref.rollout.stage2_num_tokens=$STAGE2_NUM_TOKENS \\\n    ++actor_rollout_ref.rollout.stage2_beam_size=$STAGE2_BEAM_SIZE \\\n    ++actor_rollout_ref.rollout.engine_kwargs.vllm.max_logprobs=320 \\\n    actor_rollout_ref.rollout.temperature=$TEMPERATURE \\\n    actor_rollout_ref.rollout.top_p=1.0 \\\n    actor_rollout_ref.rollout.do_sample=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=$KL_LOSS_COEF \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    algorithm.norm_adv_by_std_in_grpo=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.default_hdfs_dir=null \\\n    trainer.n_gpus_per_node=$N_GPUS \\\n    trainer.nnodes=$N_NODES \\\n    trainer.save_freq=50 \\\n    trainer.test_freq=50 \\\n    trainer.project_name=$PROJECT_NAME \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.default_local_dir=$OUTPUT_DIR/ckpt \\\n    trainer.total_epochs=20 \\\n    trainer.val_before_train=True \\\n    actor_rollout_ref.ref.strategy=fsdp2 \\\n    actor_rollout_ref.actor.strategy=fsdp2 \\\n    ++critic.enable=False \\\n    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \\\n    ++actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \\\n    \"$@\"\n"
  },
  {
    "path": "verl_rl/recipe/prime/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/recipe/prime/config/prime_trainer.yaml",
    "content": "# the prime config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\ndata:\n  filter_accuracy: True\n  accuracy_lower_bound: 0.2\n  accuracy_upper_bound: 0.8\n  oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized.\n  filter_truncate: True\n  truncation: right\n\nactor_rollout_ref:\n  hybrid_engine: True\n  model:\n    use_remove_padding: True\n  rollout:\n    # number of responses (i.e. num sample times)\n    n: 4\n  actor:\n    entropy_coeff: 0.001\n\nreward_model:\n  enable: True\n  strategy: fsdp\n  model:\n    ref_path: ${reward_model.model.path}\n    use_remove_padding:  True\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n    fused_kernel_options:\n      impl_backend: torch # triton, torch\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}\n    ref_type: freeze\n    fsdp_config:\n      min_num_params: 0\n      param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}\n#      grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload}\n      optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}\n    update: before # ``before`` for double-forward, ``after`` for single-forward\n    optim:\n      lr: 1e-6\n      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      min_lr_ratio: null\n      warmup_style: constant\n      total_training_steps: -1  # must be overridden by program\n      weight_decay: 0.\n      grad_clip: 10.0\n    beta_train: 0.05\n    loss_type: ce # currently only supports ce loss\n  prime_granularity: token\n  prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train\n  mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  reward_manager: prime\n\nalgorithm:\n  adv_estimator: rloo\n  # now supports rloo. it treats different source of reward separately.\n  kl_ctrl:\n    type: fixed\n    kl_coef: 0.000\n  reward_gt_coef: 5\n  reward_dpo_coef: 5\n\ntrainer:\n  project_name: prime\n  experiment_name: examples\n  val_before_train: False\n  balance_batch: False\n"
  },
  {
    "path": "verl_rl/recipe/prime/main_prime.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport hydra\nimport ray\n\nfrom .prime_ray_trainer import RayPRIMETrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"prime_trainer\", version_base=None)\ndef main(config):\n    run_prime(config)\n\n\ndef run_prime(config, compute_score=None):\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}},\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    ray.get(main_task.remote(config, compute_score))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\ndef main_task(config, compute_score=None):\n    # print initial config\n    from pprint import pprint\n\n    from omegaconf import OmegaConf\n\n    from verl.utils.fs import copy_local_path_from_hdfs\n\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    # download the checkpoint from hdfs\n    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)\n\n    # instantiate tokenizer\n    from verl.utils import hf_tokenizer\n\n    tokenizer = hf_tokenizer(local_path)\n\n    # define worker classes\n    if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n        assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n        from verl.single_controller.ray import RayWorkerGroup\n        from verl.workers.fsdp_workers import ActorRolloutRefWorker\n\n        ray_worker_group_cls = RayWorkerGroup\n\n    elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n        from verl.workers.megatron_workers import ActorRolloutRefWorker\n\n        ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n    else:\n        raise NotImplementedError\n\n    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n    }\n\n    global_pool_id = 
\"global_pool\"\n    resource_pool_spec = {\n        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n    }\n    mapping = {\n        Role.ActorRollout: global_pool_id,\n    }\n\n    # use reference model\n    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        mapping[Role.RefPolicy] = global_pool_id\n\n    if config.reward_model.enable:\n        from .prime_fsdp_workers import PRIMERewardModelWorker\n\n        role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)\n        mapping[Role.RewardModel] = global_pool_id\n\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    if reward_manager_name == \"naive\":\n        from verl.workers.reward_manager import NaiveRewardManager\n\n        reward_manager_cls = NaiveRewardManager\n    elif reward_manager_name == \"prime\":\n        from verl.workers.reward_manager import PrimeRewardManager\n\n        reward_manager_cls = PrimeRewardManager\n    else:\n        raise NotImplementedError\n    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)\n\n    # Note that we always use function-based RM for validation\n    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)\n\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n    trainer = RayPRIMETrainer(\n        config=config,\n        tokenizer=tokenizer,\n        role_worker_mapping=role_worker_mapping,\n        resource_pool_manager=resource_pool_manager,\n        ray_worker_group_cls=ray_worker_group_cls,\n        reward_fn=reward_fn,\n        val_reward_fn=val_reward_fn,\n    )\n    trainer.init_workers()\n    trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/prime/prime_core_algos.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nimport verl\nimport verl.utils.torch_functional as verl_F\n\n\ndef compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):\n    # calculate rloo reward on different reward sources, and sum again\n    def masked_rloo(reward_tensor_original, mask_tensor):\n        reward_tensor = reward_tensor_original.clone()\n        reward_tensor[~mask_tensor] = 0\n        for start_pos in range(0, reward_tensor.shape[0], n_samples):\n            cur_rewards_mean = torch.cat(\n                [\n                    reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)\n                    for pos in range(start_pos, start_pos + n_samples)\n                ],\n                dim=0,\n            )\n            cur_rewards_sum = cur_rewards_mean.sum()\n            cur_reward_baseline = cur_rewards_sum / (n_samples - 1)\n            reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (\n                reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]\n                * (n_samples / (n_samples - 1))\n                - cur_reward_baseline\n            )\n\n        return reward_tensor\n\n    reward_tensors = []\n\n    with torch.no_grad():\n        if \"rm_scores\" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:\n            reward_tensor = data.batch[\"rm_scores\"]\n            reward_mask = response_mask.bool()\n\n            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)\n\n        if \"acc\" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:\n            reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)\n            reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)\n\n            prompt_ids = data.batch[\"prompts\"]\n            prompt_length = prompt_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][:, prompt_length:].sum(-1)\n\n            reward_mask[\n                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),\n                valid_response_length - 1,\n            ] = True\n            reward_tensor[\n                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),\n                valid_response_length - 1,\n            ] = data.batch[\"acc\"]\n\n            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)\n\n        final_reward_tensor = sum(reward_tensors)\n\n        returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])\n\n        advantages = returns.clone()\n        advantages = verl_F.masked_whiten(advantages, response_mask)\n\n        return 
advantages, returns\n\n\ndef compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):\n    cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()\n    cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)\n    return cur_dpo_loss\n\n\ndef compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode=\"none\"):\n    # we always assume that the BoN size equals n_samples\n    # mode1: use acc as rm\n    # mode2: use Q as rm\n    cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta\n    other_Q = torch.zeros_like(cur_Q)\n    for i in range(token_level_scores.shape[0]):\n        Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]]\n        if len(Q_chosen) > 0:\n            other_Q[i] = Q_chosen.mean() * beta\n        else:\n            other_Q[i] = 0\n    dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))\n    if bon_mode == \"none\":\n        dpo_loss = dpo_loss.mean()\n    else:\n        weight = torch.zeros_like(dpo_loss)\n        n_samples = acc_bc.shape[1]\n        if bon_mode == \"bon_rm\":\n            for i in range(token_level_scores.shape[0]):\n                weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)\n        elif bon_mode == \"bon_acc\":\n            for i in range(token_level_scores.shape[0]):\n                weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)\n        else:\n            raise NotImplementedError\n        dpo_loss = (dpo_loss * weight).sum()\n\n    return dpo_loss\n\n\ndef compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):\n    dpo_acc = []\n    for start_id in range(0, token_level_scores.shape[0], n_samples):\n        cur_scores = (\n            token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]\n        ).sum(dim=1)\n\n        def get_upper_triangle(tensor_x):\n            diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)\n            upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)\n            return diff_matrix[upper_tri_indices]\n\n        cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples])  # in range [-1,1]\n        cur_score_diff = get_upper_triangle(cur_scores)  # in R\n        cur_score_prediction = (cur_score_diff > 0).float()  # in [0,1]\n        if cur_acc_diff.abs().sum() == 0:\n            cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5\n        else:\n            cur_acc = (\n                ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()\n            ).sum() / cur_acc_diff.abs().sum()\n\n        dpo_acc.append(cur_acc.unsqueeze(0))\n\n    return torch.cat(dpo_acc, dim=0).mean()\n\n\ndef compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):\n    return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()\n"
  },
  {
    "path": "verl_rl/recipe/prime/prime_dp_rm.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nfrom torch import nn, optim\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\nfrom .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm\n\n__all__ = [\"DataParallelPRIMERewardModel\"]\n\n\nclass DataParallelPRIMERewardModel:\n    def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):\n        self.config = config\n        self.reward_module = reward_module\n        self.ref_module = ref_module\n        self.reward_optimizer = reward_optimizer\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        print(f\"Reward model use_remove_padding={self.use_remove_padding}\")\n        self.use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n        print(f\"Reward model use_fused_kernels={self.use_fused_kernels}\")\n\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n\n    def _forward_micro_batch(self, micro_batch, prompt_length):\n        input_ids = micro_batch[\"input_ids\"]\n        batch_size, seqlen = input_ids.shape\n        attention_mask = micro_batch[\"attention_mask\"]\n        position_ids = micro_batch[\"position_ids\"]\n\n        num_actions = micro_batch[\"input_ids\"].shape[-1] - prompt_length\n        max_positions = micro_batch[\"attention_mask\"][:, prompt_length:].sum(-1)\n\n        if self.use_remove_padding:\n            input_ids_rmpad, indices, *_ = unpad_input(\n                input_ids.unsqueeze(-1), attention_mask\n            )  # input_ids_rmpad (total_nnz, ...)\n            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n            # unpad the position_ids to align the rotary\n            position_ids_rmpad = index_first_axis(\n                rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n            ).transpose(0, 1)\n\n            # for compute the log_prob\n            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n            # pad and slice the inputs if sp > 1\n            if self.ulysses_sequence_parallel_size > 1:\n                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                )\n                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size\n                )\n\n            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)\n            output = self.reward_module(\n                input_ids=input_ids_rmpad,\n                attention_mask=None,\n                position_ids=position_ids_rmpad,\n                use_cache=False,\n                return_dict=self.use_fused_kernels,\n            )\n\n            if self.use_fused_kernels:\n                rm_log_labels = output.log_probs.squeeze(0)  # (total_nnz,)\n                rm_log_labels = rm_log_labels.to(torch.float32)\n\n            else:\n                rm_output_logits = output.logits.squeeze(0)\n                rm_log_labels = verl_F.logprobs_from_logits(\n                    logits=rm_output_logits,\n                    labels=input_ids_rmpad_rolled,\n                )\n\n            if self.ulysses_sequence_parallel_size > 1:\n                rm_log_labels = gather_outputs_and_unpad(\n                    rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                )\n            rm_log_labels = pad_input(\n                hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n            ).squeeze(-1)[:, -num_actions - 1 : -1]\n\n        else:\n            output = self.reward_module(\n                input_ids=micro_batch[\"input_ids\"],\n                attention_mask=micro_batch[\"attention_mask\"],\n                position_ids=micro_batch[\"position_ids\"],\n                use_cache=False,\n                return_dict=self.use_fused_kernels,\n            )\n\n            if self.use_fused_kernels:\n                rm_log_labels = output.log_probs[:, :-1]  # (bsz, seq_length)\n                rm_log_labels = rm_log_labels.to(torch.float32)\n\n            else:\n                rm_output_logits = output.logits\n                rm_log_prob = torch.nn.functional.log_softmax(\n                    rm_output_logits[:, :-1, :], dim=-1\n                )  # (batch_size, seq_length, vocab_size)\n                rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch[\"input_ids\"][:, 1:].unsqueeze(-1)).squeeze(\n                    -1\n                )  # (batch, seq_length)\n\n        if self.ref_module is not None:\n            # do not have to pad again\n            with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n                if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:\n                    ref_output = self.ref_module(\n                        input_ids=input_ids_rmpad,\n                        attention_mask=None,\n                        position_ids=position_ids_rmpad,\n                        use_cache=False,\n                    )\n\n                    if self.use_fused_kernels:\n                        ref_log_labels = 
ref_output.log_probs.squeeze(0)  # (total_nnz,)\n                        ref_log_labels = ref_log_labels.to(torch.float32)\n\n                    else:\n                        ref_output_logits = ref_output.logits.squeeze(0)\n                        ref_log_labels = verl_F.logprobs_from_logits(\n                            logits=ref_output_logits, labels=input_ids_rmpad_rolled\n                        )\n\n                    ref_log_labels = gather_outputs_and_unpad(\n                        ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n                    ref_log_labels = pad_input(\n                        hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n                    ).squeeze(-1)[:, -num_actions - 1 : -1]\n                else:\n                    ref_output = self.ref_module(\n                        input_ids=micro_batch[\"input_ids\"],\n                        attention_mask=micro_batch[\"attention_mask\"],\n                        position_ids=micro_batch[\"position_ids\"],\n                        use_cache=False,\n                    )\n\n                    if self.use_fused_kernels:\n                        ref_log_labels = ref_output.log_probs[:, :-1]  # (batch_size, seq_length)\n                        ref_log_labels = ref_log_labels.to(torch.float32)\n\n                    else:\n                        ref_output_logits = ref_output.logits\n                        ref_log_prob = torch.nn.functional.log_softmax(\n                            ref_output_logits[:, :-1, :], dim=-1\n                        )  # (batch_size, seq_length, vocab_size)\n                        ref_log_labels = ref_log_prob.gather(\n                            dim=-1, index=micro_batch[\"input_ids\"][:, 1:].unsqueeze(-1)\n                        ).squeeze(-1)  # (batch, seq_length)\n\n        else:\n            ref_log_labels = micro_batch[\"old_log_probs\"]\n\n        # Tensor.to() is not in-place, so the cast must be assigned back\n        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)\n        q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:]  # this is actually diff of q\n\n        # trim unnecessary logprobs here\n        for i in range(micro_batch[\"input_ids\"].shape[0]):\n            q[i, max_positions[i] :] = 0\n\n        # reward computation does not need gradient. only q needs\n        with torch.no_grad():\n            # generalized estimation of r should go before the reward filling. r means the process reward\n            # for the policy model, i.e. the advantage of the reward model.\n
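            # Editor's note (sketch of the reward computed below, matching the code):\n            #   q_t = log pi_rm(y_t | y_<t) - log pi_ref(y_t | y_<t)   (the tensor q above)\n            #   lam == 0:  r_t = beta * q_t\n            #   lam  > 0:  r_t = sum_{s >= t} lam^(s-t) * beta * q_s   (GAE-style, gamma = 1)\n            # With prime_use_gt, the last valid token's q is replaced so the total matches acc.\n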
            lam = self.config.get(\"lambda\", 0.0)\n            beta = self.config.model.get(\"beta_train\", 0.05)\n            if lam == 0.0:\n                r = q * beta\n            else:\n                # reward coefficient takes no effect here\n                acc = micro_batch[\"acc\"]\n                q_ = q * beta\n                r = torch.zeros_like(q)\n                lastgaelam = 0\n                # change the last token and mask out all paddings to make this process easier if we rely on\n                # outcome reward to calculate V\n                for i in range(q.shape[0]):\n                    if self.config.prime_use_gt:\n                        q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum()\n                    q_[i, max_positions[i] :] = 0\n\n                for t in reversed(range(num_actions)):\n                    delta = q_[:, t]\n                    lastgaelam = delta + lam * lastgaelam\n                    r[:, t] = lastgaelam\n\n            token_level_score = torch.zeros_like(q)\n\n            if self.config.prime_granularity == \"token\":\n                for i in range(micro_batch[\"input_ids\"].shape[0]):\n                    token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1]\n            elif self.config.prime_granularity == \"whole\":\n                # place the summed reward on the last valid token\n                for i in range(micro_batch[\"input_ids\"].shape[0]):\n                    token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]].sum()\n            else:\n                raise NotImplementedError\n\n        return token_level_score, q\n\n    def _optimizer_step(self):\n        assert self.config.model.optim.grad_clip is not None\n\n        if isinstance(self.reward_module, FSDP):\n            grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(\n                self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip\n            )\n        self.reward_optimizer.step()\n        return grad_norm\n\n    def prime_norm(self, token_level_scores):\n        if self.config.prime_norm == \"batch_norm\":\n            reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])\n            token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)\n        return token_level_scores\n\n    def compute_rm_score(self, data: DataProto):\n        self.reward_module.eval()\n        self.ref_module.eval()\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\", \"acc\"]\n        batch = data.select(batch_keys=select_keys).batch\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        prompt_length = data.batch[\"input_ids\"].shape[-1] - data.batch[\"responses\"].shape[-1]\n\n        if use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        rm_scores_lst = []\n        q_lst = []\n        for micro_batch in micro_batches:\n            with torch.no_grad():\n                rm_score, q = 
self._forward_micro_batch(micro_batch, prompt_length)\n            rm_scores_lst.append(rm_score)\n            q_lst.append(q)\n        rm_scores = torch.concat(rm_scores_lst, dim=0)\n        q = torch.concat(q_lst, dim=0)\n\n        rm_scores = self.prime_norm(rm_scores)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == rm_scores.size(0), f\"{len(indices)} vs. {rm_scores.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            rm_scores = rm_scores[revert_indices]\n\n        return (\n            rm_scores,\n            q.detach(),\n            {\n                \"reward_model/reward\": rm_scores.sum(dim=-1).mean().item(),\n                \"reward_model/raw_reward\": q.sum(dim=-1).mean().item(),\n            },\n        )\n\n    def update_rm(self, data: DataProto):\n        # make sure we are in training mode\n        self.reward_module.train()\n        metrics = {}\n\n        beta = self.config.model.get(\"beta_train\", 0.05)\n\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"acc\", \"prompts\"]\n\n        for key in [\"Q_bc\", \"acc_bc\"]:\n            if key in data.batch.keys():\n                select_keys.append(key)\n\n        batch = data.select(batch_keys=select_keys).batch\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. https://arxiv.org/abs/1707.06347\n        dataloader = batch.split(self.config.mini_batch_size)\n\n        rm_scores_lst = []\n        q_lst = []\n\n        for batch_idx, data in enumerate(dataloader):\n            # split batch into micro_batches\n            mini_batch = data\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n            else:\n                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)\n                self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu\n\n            self.reward_optimizer.zero_grad()\n\n            for data in micro_batches:\n                data = data.to(get_device_name())\n                attention_mask = data[\"attention_mask\"]\n                acc = data[\"acc\"]\n\n                prompt_ids = data[\"prompts\"]\n                prompt_length = prompt_ids.shape[-1]\n\n                response_mask = attention_mask[:, prompt_length:]\n\n                rm_score, q = self._forward_micro_batch(data, prompt_length)\n\n                rm_scores_lst.append(rm_score)\n                q_lst.append(q.detach())\n\n                if self.config.model.loss_type == \"ce\":\n                    dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)\n                elif self.config.model.loss_type == \"dpo\":\n                    # the implementation of dpo is actually detached, which means we have to know the average\n                    # value of w/l reward before the update.\n                    dpo_loss = compute_detach_dpo_loss_rm(\n                        q, acc, Q_bc=data[\"Q_bc\"], acc_bc=data[\"acc_bc\"], response_mask=response_mask, beta=beta\n                    )\n                elif self.config.model.loss_type == \"bon_acc\":\n                    # change the original 
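distribution of each sample to the BoN (best-of-n) distribution, then update the reward\n                    # model. Editor's note: the bon_* modes weight sample i by n * F(x_i)^(n - 1),\n                    # the max-of-n density under the empirical CDF F of the group.\n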
                    dpo_loss = compute_detach_dpo_loss_rm(\n                        q,\n                        acc,\n                        Q_bc=data[\"Q_bc\"],\n                        acc_bc=data[\"acc_bc\"],\n                        response_mask=response_mask,\n                        beta=beta,\n                        bon_mode=\"bon_acc\",\n                    )\n                elif self.config.model.loss_type == \"bon_rm\":\n                    dpo_loss = compute_detach_dpo_loss_rm(\n                        q,\n                        acc,\n                        Q_bc=data[\"Q_bc\"],\n                        acc_bc=data[\"acc_bc\"],\n                        response_mask=response_mask,\n                        beta=beta,\n                        bon_mode=\"bon_rm\",\n                    )\n                else:\n                    raise NotImplementedError\n\n                # keep metrics in their own dict so `data` (the micro-batch) is not shadowed;\n                # len(data) below must be the micro-batch size, not the size of a metrics dict\n                micro_batch_metrics = {\"reward_model/dpo_loss\": dpo_loss.detach().item()}\n\n                if self.config.use_dynamic_bsz:\n                    # relative to the dynamic bsz\n                    loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)\n                else:\n                    loss = dpo_loss / self.gradient_accumulation\n\n                loss.backward()\n\n                append_to_dict(metrics, micro_batch_metrics)\n\n            grad_norm = self._optimizer_step()\n            append_to_dict(metrics, {\"reward_model/grad_norm\": grad_norm.detach().item()})\n        self.reward_optimizer.zero_grad()\n\n        rm_scores = torch.cat(rm_scores_lst, dim=0)\n        q = torch.concat(q_lst, dim=0)\n\n        rm_scores = self.prime_norm(rm_scores)\n\n        metrics.update(\n            {\n                \"reward_model/reward\": rm_scores.sum(dim=-1).mean().item(),\n                \"reward_model/raw_reward\": q.sum(dim=-1).mean().item(),\n            }\n        )\n\n        return rm_scores, metrics\n"
  },
  {
    "path": "verl_rl/recipe/prime/prime_fsdp_workers.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nimport warnings\n\nimport torch\nimport torch.distributed\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl import DataProto\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.device import get_device_id, get_device_name, get_nccl_backend\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_local_path_from_hdfs\nfrom verl.utils.fsdp_utils import (\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nfrom .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass PRIMERewardModelWorker(Worker):\n    def __init__(self, config):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size 
//= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n            assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0\n\n    def _build_reward_ref_model_optimizer(self, config):\n        # the following line is necessary\n        from torch import optim\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.fsdp import MixedPrecision\n\n        from verl.utils.model import print_model_size\n        from verl.utils.torch_dtypes import PrecisionType\n\n        local_path = copy_local_path_from_hdfs(config.model.path)\n\n        tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)\n        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        from omegaconf import OmegaConf\n\n        override_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        if self.rank == 0:\n            print(f\"Reward model overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig, AutoModelForCausalLM\n\n        trust_remote_code = False\n        reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        reward_model_config.num_labels = 1\n\n        init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings)\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            reward_model_config.classifier_dropout = 0.0\n            reward_model_config.hidden_dropout = \"0\"\n            reward_module = AutoModelForCausalLM.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                torch_dtype=torch_dtype,\n                config=reward_model_config,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            fused_kernel_options = config.model.get(\"fused_kernel_options\", None)\n            fused_kernels_backend = (\n                fused_kernel_options.get(\"impl_backend\", None) if fused_kernel_options is not None else None\n            )\n\n            apply_monkey_patch(\n                model=reward_module,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_remove_padding=config.model.get(\"use_remove_padding\", False),\n                use_fused_kernels=config.model.get(\"use_fused_kernels\", False),\n                fused_kernels_backend=fused_kernels_backend,\n            )\n\n            # some parameters may not in torch_dtype\n            reward_module.to(torch_dtype)\n\n            if config.model.get(\"enable_gradient_checkpointing\", False):\n                reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n        if self.rank == 0:\n            
print_model_size(reward_module)\n\n        self.reward_model_config = reward_model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)\n\n        log_gpu_memory_usage(\"Before reward model FSDP\", logger=None)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            reward_model_config.classifier_dropout = 0.0\n            reward_model_config.hidden_dropout = \"0\"\n            ref_module = AutoModelForCausalLM.from_pretrained(\n                pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path),\n                torch_dtype=torch_dtype,\n                config=reward_model_config,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            # some parameters may not in torch_dtype\n            ref_module.to(torch_dtype)\n\n        reward_module = FSDP(\n            reward_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,\n            mixed_precision=mixed_precision,\n            sync_module_states=True,\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n            cpu_offload=None,\n        )\n\n        log_gpu_memory_usage(\"After reward FSDP\", logger=None)\n\n        ref_module = FSDP(\n            ref_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,\n            mixed_precision=mixed_precision,\n            sync_module_states=True,\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n            cpu_offload=None,\n        )\n\n        reward_optimizer = optim.AdamW(\n            reward_module.parameters(),\n            lr=config.model.optim.lr,\n            betas=config.model.optim.get(\"betas\", (0.9, 0.999)),\n            weight_decay=config.model.optim.get(\"weight_decay\", 1e-2),\n        )\n\n        total_steps = config.model.optim.get(\"total_training_steps\", 0)\n        num_warmup_steps = int(config.model.optim.get(\"lr_warmup_steps\", -1))\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.model.optim.get(\"lr_warmup_steps_ratio\", 0.0)\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        print(f\"Total steps: {total_steps}, 
num_warmup_steps: {num_warmup_steps}\")\n\n        from verl.utils.torch_functional import get_constant_schedule_with_warmup\n\n        reward_lr_scheduler = get_constant_schedule_with_warmup(\n            optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps\n        )\n\n        return reward_module, ref_module, reward_optimizer, reward_lr_scheduler\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from .prime_dp_rm import DataParallelPRIMERewardModel\n\n        self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = (\n            self._build_reward_ref_model_optimizer(config=self.config)\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.reward_optimizer)\n\n        self.rm = DataParallelPRIMERewardModel(\n            config=self.config,\n            reward_module=self.reward_module,\n            ref_module=self.ref_module,\n            reward_optimizer=self.reward_optimizer,\n        )\n\n        self.flops_counter = FlopsCounter(self.reward_model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.reward_module,\n            optimizer=self.reward_optimizer,\n            lr_scheduler=self.reward_lr_scheduler,\n            tokenizer=self.tokenizer,\n        )\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_rm_score(self, data: DataProto):\n        data = data.to(get_device_name())\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n            load_fsdp_model_to_gpu(self.ref_module)\n        micro_batch_size = self.config.micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            rm_scores, q, metrics = self.rm.compute_rm_score(data=data)\n\n            prompt_length = data.batch[\"prompts\"].shape[-1]\n            response_mask = data.batch[\"attention_mask\"][:, prompt_length:]\n            acc = data.batch[\"acc\"]\n\n            dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info[\"n\"])\n            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info[\"n\"])\n\n            metrics[\"reward_model/dpo_acc\"] = dpo_acc.detach().item()\n            metrics[\"reward_model/dpo_acc_abs\"] = dpo_acc_abs.detach().item()\n\n            output = DataProto.from_dict(tensors={\"rm_scores\": rm_scores, \"q\": q}, meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def 
update_rm(self, data: DataProto):\n        data = data.to(get_device_name())\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.ref_module)\n            load_fsdp_model_to_gpu(self.reward_module)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            rm_scores, metrics = self.rm.update_rm(data=data)\n\n            self.reward_lr_scheduler.step()\n            lr = self.reward_lr_scheduler.get_last_lr()[0]\n            metrics[\"rm/lr\"] = lr\n\n            prompt_length = data.batch[\"prompts\"].shape[-1]\n            response_mask = data.batch[\"attention_mask\"][:, prompt_length:]\n            acc = data.batch[\"acc\"]\n\n            dpo_acc_before = compute_dpo_accuracy(\n                rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info[\"n\"]\n            )\n            dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info[\"n\"])\n\n            metrics[\"reward_model/dpo_acc_before\"] = dpo_acc_before.detach().item()\n            metrics[\"reward_model/dpo_acc_abs_before\"] = dpo_acc_abs.detach().item()\n\n            output = DataProto.from_dict(tensors={\"rm_scores\": rm_scores}, meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n            offload_fsdp_model_to_cpu(self.ref_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.reward_optimizer)\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, del_local_after_load=True):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.reward_module)\n\n        self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load)\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.reward_module)\n"
  },
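  {
    "path": "verl_rl/recipe/prime/prime_implicit_reward_sketch.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original recipe.\n\nThe PRIME reward worker above returns token-level values `q` alongside\n`rm_scores`, and it keeps a frozen reference model next to the trainable\nreward model. In the PRIME formulation, the implicit process reward of a\nresponse token is the beta-scaled log-probability ratio between the two\nmodels. This toy computes that quantity on fake tensors; the tensor shapes,\nthe function name, and the default `beta` (matching `beta_train` in this\nrecipe's run scripts) are assumptions for illustration only.\n\"\"\"\n\nimport torch\n\n\ndef implicit_token_rewards(\n    logp_rm: torch.Tensor,  # (batch, response_len) token log-probs under the reward model\n    logp_ref: torch.Tensor,  # (batch, response_len) token log-probs under the reference model\n    response_mask: torch.Tensor,  # (batch, response_len) 1 for valid response tokens\n    beta: float = 0.05,\n) -> torch.Tensor:\n    # q_t = beta * (log pi_rm(y_t) - log pi_ref(y_t)), zeroed on padding\n    return beta * (logp_rm - logp_ref) * response_mask\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    logp_rm = torch.randn(2, 4)  # stand-ins for gathered token log-probs\n    logp_ref = torch.randn(2, 4)\n    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])\n    q = implicit_token_rewards(logp_rm, logp_ref, mask)\n    print(q, q.sum(dim=-1))  # token-level rewards and their sequence-level sums\n"
  },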
  {
    "path": "verl_rl/recipe/prime/prime_ray_trainer.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport os\nimport statistics\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\n\nimport numpy as np\nimport torch\nfrom omegaconf import OmegaConf, open_dict\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import _compute_response_info\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path\nfrom verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\nfrom verl.utils.metric import reduce_metrics\nfrom verl.utils.profiler.performance import simple_timer\n\nfrom . import prime_core_algos\n\n\ndef compute_advantage(data: DataProto, adv_estimator, config):\n    if adv_estimator == \"rloo\":\n        responses = data.batch[\"responses\"]\n        response_length = responses.size(-1)\n        attention_mask = data.batch[\"attention_mask\"]\n        response_mask = attention_mask[:, -response_length:]\n        advantages, returns = prime_core_algos.compute_rloo_advantage_return(\n            data, response_mask, config.actor_rollout_ref.rollout.n, config\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    else:\n        raise NotImplementedError\n    return data\n\n\ndef compute_data_metrics(batch, use_critic=True):\n    advantages = batch.batch[\"advantages\"]\n    returns = batch.batch[\"returns\"]\n\n    max_response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = batch.batch[\"attention_mask\"][:, -max_response_length:].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(batch)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n\n    if use_critic:\n        values = batch.batch[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n\n    metrics = {\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        \"critic/returns/max\": 
torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n                \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": torch.min(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if use_critic\n            else {}\n        ),\n        # response length\n        \"response_length/mean\": torch.mean(response_length).detach().item(),\n        \"response_length/max\": torch.max(response_length).detach().item(),\n        \"response_length/min\": torch.min(response_length).detach().item(),\n        \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float())\n        .detach()\n        .item(),\n        # prompt length\n        \"prompt_length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt_length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt_length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n    }\n    return metrics\n\n\ndef compute_response_mask(data: DataProto):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_timing_metrics(batch, timing_raw):\n    response_info = _compute_response_info(batch)\n    num_prompt_tokens = torch.sum(response_info[\"prompt_length\"]).item()\n    num_response_tokens = torch.sum(response_info[\"response_length\"]).item()\n    num_overall_tokens = num_prompt_tokens + num_response_tokens\n\n    num_tokens_of_section = {\n        \"gen\": num_response_tokens,\n        **{name: num_overall_tokens for name in [\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\"]},\n    }\n\n    return {\n        **{f\"timing_s/{name}\": value for name, value in timing_raw.items()},\n        **{\n            f\"timing_per_token_ms/{name}\": timing_raw[name] * 1000 / num_tokens_of_section[name]\n            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())\n        },\n    }\n\n\nclass RayPRIMETrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        reward_fn=None,\n        val_reward_fn=None,\n        device_name=\"cuda\",\n    ):\n        # assert get_torch_device().is_available(), 'cuda must be available on driver'\n\n        super().__init__(\n            config,\n            tokenizer,\n            role_worker_mapping,\n            resource_pool_manager,\n            ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            device_name=device_name,\n        )\n\n        self.use_critic = False\n\n    
def _validate_config(self):\n        super()._validate_config()\n        # TODO: Additional config checks can be added here\n\n    def _create_dataloader(self, *args, **kwargs):\n        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        self.train_dataset = RLHFDataset(\n            data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data\n        )\n        # use sampler for better ckpt resume\n        if self.config.data.shuffle:\n            train_dataloader_generator = torch.Generator()\n            train_dataloader_generator.manual_seed(self.config.data.get(\"seed\", 1))\n            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)\n        else:\n            sampler = SequentialSampler(data_source=self.train_dataset)\n\n        self.train_dataloader = DataLoader(\n            dataset=self.train_dataset,\n            batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor),\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=sampler,\n        )\n\n        self.val_dataset = RLHFDataset(\n            data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data\n        )\n        self.val_dataloader = DataLoader(\n            dataset=self.val_dataset,\n            batch_size=len(self.val_dataset),\n            shuffle=True,\n            drop_last=True,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1\n        assert len(self.val_dataloader) >= 1\n\n        print(f\"Size of train dataloader: {len(self.train_dataloader)}\")\n        print(f\"Size of val dataloader: {len(self.val_dataloader)}\")\n\n        # inject total_training_steps to actor/critic optim_config. 
This is hacky.\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        OmegaConf.set_struct(self.config, True)\n        with open_dict(self.config):\n            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n            self.config.critic.optim.total_training_steps = total_training_steps\n\n    def _save_checkpoint(self):\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path,\n            actor_remote_path,\n            self.global_steps,\n        )\n\n        if self.use_rm:\n            reward_local_path = os.path.join(local_global_step_folder, \"reward\")\n            reward_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"reward\")\n            )\n            self.rm_wg.save_checkpoint(\n                reward_local_path,\n                reward_remote_path,\n                self.global_steps,\n            )\n\n        # save dataloader\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        import dill\n\n        torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        reward_path = os.path.join(global_step_folder, \"reward\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load rm\n        if self.use_rm:\n            self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)\n\n        # load dataloader\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        self.train_dataloader = torch.load(dataloader_local_path)\n        if isinstance(self.train_dataloader.dataset, RLHFDataset):\n            self.train_dataloader.dataset.resume_dataset_state()\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC to\n        construct the PPO dataflow. 
The light-weight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # we start from step 1\n        self.global_steps += 1\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                gen_batch = batch.pop(batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"])\n                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with simple_timer(\"gen\", timing_raw):\n                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == \"remax\":\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            batch = batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    # Balance the number of valid tokens across DP ranks.\n   
                 # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    # verify\n                    with simple_timer(\"verify\", timing_raw):\n                        scores = self.reward_fn.verify(batch)\n                        metrics[\"acc\"] = statistics.mean(scores)\n\n                    # filter the batch. 1/oversample_factor samples will be kept.\n                    # If there is a filter, prompts passing it will be prioritized.\n\n                    batch = self.filter_and_downsample(scores, batch)\n                    batch.meta_info[\"n\"] = self.config.actor_rollout_ref.rollout.n\n                    n_samples = self.config.actor_rollout_ref.rollout.n\n\n                    # recompute old_log_probs\n                    with simple_timer(\"old_log_prob\", timing_raw):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = compute_response_mask(batch)\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with simple_timer(\"ref\", timing_raw):\n                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    with simple_timer(\"adv\", timing_raw):\n                        if self.use_rm:\n                            update_style = self.config.reward_model.model.get(\"update\", \"none\")\n                            if update_style == \"none\":  # only run forward\n                                reward_output = self.rm_wg.compute_rm_score(batch)\n                            elif update_style == \"after\":  # update and directly return the reward\n                                reward_output = self.rm_wg.update_rm(batch)\n                            elif update_style == \"before\":  # update reward model, and then run forward\n                                reward_output = self.rm_wg.update_rm(batch)\n                                if \"metrics\" in reward_output.meta_info.keys():\n                                    reward_output_metrics = reduce_metrics(reward_output.meta_info[\"metrics\"])\n                                    metrics.update(reward_output_metrics)\n\n                                reward_output = self.rm_wg.compute_rm_score(batch)\n                       
     elif (\n                                update_style == \"reverse\"\n                            ):  # run forward to calculate statistics, then update reward model\n                                reward_output = self.rm_wg.compute_rm_score(batch)\n                                # broadcast q and acc tensor to each result\n                                bc_td = DataProto.from_dict(\n                                    tensors={\n                                        \"Q_bc\": reward_output.batch[\"q\"]\n                                        .sum(dim=-1)\n                                        .view(-1, n_samples)\n                                        .unsqueeze(1)\n                                        .expand(-1, n_samples, -1)\n                                        .reshape(-1, n_samples),\n                                        \"acc_bc\": batch.batch[\"acc\"]\n                                        .view(-1, n_samples)\n                                        .unsqueeze(1)\n                                        .expand(-1, n_samples, -1)\n                                        .reshape(-1, n_samples),\n                                    }\n                                )\n                                batch = batch.union(bc_td)\n                                reward_output = self.rm_wg.update_rm(batch)\n                            else:\n                                raise NotImplementedError\n                            batch = batch.union(reward_output)\n                            if \"metrics\" in reward_output.meta_info.keys():\n                                reward_output_metrics = reduce_metrics(reward_output.meta_info[\"metrics\"])\n                                metrics.update(reward_output_metrics)\n\n                        # compute advantages, executed on the driver process\n                        batch = compute_advantage(\n                            batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config\n                        )\n\n                    # update actor\n                    with simple_timer(\"update_actor\", timing_raw):\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and self.global_steps % self.config.trainer.test_freq == 0\n                    ):\n                        with simple_timer(\"testing\", timing_raw):\n                            val_metrics: dict = self._validate()\n                        metrics.update(val_metrics)\n\n                    if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n\n                # TODO: make a canonical logger that supports various backend\n                logger.log(data=metrics, step=self.global_steps)\n\n                self.global_steps += 1\n\n                if 
self.global_steps >= self.total_training_steps:\n                    # perform validation after training\n                    if self.val_reward_fn is not None:\n                        val_metrics = self._validate()\n                        pprint(f\"Final validation metrics: {val_metrics}\")\n                        logger.log(data=val_metrics, step=self.global_steps)\n                    if (\n                        self.config.trainer.save_freq > 0\n                        and (self.global_steps - 1) % self.config.trainer.save_freq != 0\n                    ):\n                        with simple_timer(\"save_checkpoint\", timing_raw):\n                            self._save_checkpoint()\n                    return\n\n    def filter_and_downsample(self, scores, batch: DataProto):\n        \"\"\"\n        downsample the batch according to oversample_factor\n        samples passing the filters will be prioritized\n        \"\"\"\n        n_samples = int(self.config.actor_rollout_ref.rollout.n)\n        reward_matrix = torch.tensor(scores).reshape(-1, n_samples)\n\n        filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool)\n\n        if self.config.data.filter_accuracy:\n            acc_tensor = torch.mean(reward_matrix, dim=-1)\n            filter_mask[\n                (acc_tensor > self.config.data.accuracy_upper_bound)\n                | (acc_tensor < self.config.data.accuracy_lower_bound)\n            ] = False\n\n        if self.config.data.filter_truncate:\n            length_matrix = (\n                batch.batch[\"attention_mask\"][:, -batch.batch[\"responses\"].shape[-1] :]\n                .sum(dim=-1)\n                .reshape(-1, n_samples)\n            )\n            length_tensor = torch.max(length_matrix, dim=-1)[0]\n            filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False\n\n        reorder_index = torch.argsort(filter_mask, descending=True)\n        reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)\n        batch.reorder(\n            reorder_index[: int(len(batch) // self.config.data.oversample_factor)]\n        )  # this operation is inplace\n\n        return batch\n"
  },
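  {
    "path": "verl_rl/recipe/prime/prime_rloo_sketch.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original recipe.\n\n`compute_advantage` in prime_ray_trainer.py delegates to\n`prime_core_algos.compute_rloo_advantage_return`, whose body is not shown in\nthis dump. This toy shows the leave-one-out (RLOO) baseline the name refers\nto: each sampled response is scored against the mean reward of the other\nn - 1 responses drawn for the same prompt. The (num_prompts, n_samples)\nreward layout is an assumption for illustration.\n\"\"\"\n\nimport torch\n\n\ndef rloo_advantage(rewards: torch.Tensor) -> torch.Tensor:\n    \"\"\"rewards: (num_prompts, n_samples) scalar reward per sampled response.\"\"\"\n    n = rewards.size(1)\n    # leave-one-out baseline: mean reward of the other n - 1 samples in the group\n    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (n - 1)\n    return rewards - baseline\n\n\nif __name__ == \"__main__\":\n    rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]])\n    # first row: the two correct samples get +2/3, the incorrect ones -2/3\n    print(rloo_advantage(rewards))\n"
  },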
  {
    "path": "verl_rl/recipe/prime/run_prime_qwen.sh",
    "content": "set -x\n\n\ngsm8k_train_path=$HOME/data/gsm8k/train.parquet\ngsm8k_test_path=$HOME/data/gsm8k/test.parquet\n\n# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data\nmath_train_path=$HOME/data/math/train.parquet\nmath_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path', '$math_train_path']\"\ntest_files=\"['$gsm8k_test_path', '$math_test_path']\"\n\nmodel_path=PRIME-RL/Eurus-2-7B-SFT\n# model_path=Qwen/Qwen2.5-0.5B-Instruct\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=64 \\\n    data.val_batch_size=6312 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=3072 \\\n    data.filter_overlong_prompts=True \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=$model_path \\\n    reward_model.micro_batch_size_per_gpu=1 \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=64 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='prime_example' \\\n    trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=64 \\\n    trainer.test_freq=64 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/recipe/prime/run_prime_qwen_code.sh",
    "content": "set -x\n\n\n# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data\ncode_train_path=$HOME/data/code/train.parquet\ncode_test_path=$HOME/data/code/test.parquet\n\ntrain_files=\"['$code_train_path']\"\ntest_files=\"['$code_test_path']\"\n\nmodel_path=PRIME-RL/Eurus-2-7B-SFT\n# model_path=Qwen/Qwen2.5-0.5B-Instruct\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=64 \\\n    data.val_batch_size=6312 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=3072 \\\n    data.filter_overlong_prompts=True \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    actor_rollout_ref.model.path=$model_path \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=4 \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=$model_path \\\n    reward_model.micro_batch_size_per_gpu=1 \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=64 \\\n    trainer.val_before_train=False \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='prime_example' \\\n    trainer.experiment_name='Eurus-2-7B-SFT-code' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=64 \\\n    trainer.test_freq=64 \\\n    trainer.total_epochs=15 $@\n"
  },
  {
    "path": "verl_rl/recipe/r1/README.md",
    "content": "# DeepSeek R1 Reproduction\n\nThis recipe is under development, if you are interested, checkout the TODO list and join this project! https://github.com/volcengine/verl/issues/708 \n\n## Reproducing Evaluation\n\nEval Results of DS-R1-Distill-Qwen2.5-1.5B (k=8)\n\nDataset | Test Results | Reported\n-- | -- | --\nGPQA Diamond | 35.3 | 33.8\nLiveCodeBench | 16.9 | 16.9\nAIME 2024 | 30.4 | 28.9\nCNMO 2024 (en) | 45.1 | -\nCNMO 2024 (zh) | 41.0 | -\n\n---\n\nEval Results (DS-R1)\n\nDataset | Test Results (k=1) | Test Results (k=4) | Reported\n-- | -- | -- | --\nGPQA Diamond | 67.7 | 69.6 | 71.5\nLiveCodeBench | 64.7 | 63.1 | 65.9\nAIME 2024 | 86.7 | 79.2 | 79.8\nCNMO 2024 | 75.0 | 78.5 | 78.8\n"
  },
  {
    "path": "verl_rl/recipe/r1/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/recipe/r1/config/evaluation.yaml",
    "content": "data:\n  path: /tmp/math_Qwen2-7B-Instruct.parquet\n  prompt_key: prompt\n  response_key: responses\n  data_source_key: data_source\n  reward_model_key: reward_model\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then."
  },
  {
    "path": "verl_rl/recipe/r1/data_process.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\n\nfrom datasets import concatenate_datasets, load_dataset\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\n\ndef example_map_fn(example, idx, process_fn, data_source, ability, split):\n    question, solution = process_fn(example)\n    data = {\n        \"data_source\": data_source,\n        \"prompt\": [{\"role\": \"user\", \"content\": question}],\n        \"ability\": ability,\n        \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n        \"extra_info\": {\"split\": split, \"index\": idx},\n    }\n    return data\n\n\ndef build_aime2024_dataset():\n    def process_aime2024(example):\n        return example[\"Problem\"], str(example[\"Answer\"])\n\n    data_source = \"Maxwell-Jia/AIME_2024\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"train\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_aime2024, data_source=data_source, ability=\"English\", split=\"test\"\n    )\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_gpqa_dimond_dataset():\n    import random\n\n    GPQA_QUERY_TEMPLATE = (\n        \"Answer the following multiple choice question. The last line of your response should be of the following \"\n        \"format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. 
Think step by step before \"\n        \"answering.\\n\\n{Question}\\n\\nA) {A}\\nB) {B}\\nC) {C}\\nD) {D}\"\n    )\n\n    def process_gpqa_diamond(example):\n        choices = [example[\"Incorrect Answer 1\"], example[\"Incorrect Answer 2\"], example[\"Incorrect Answer 3\"]]\n        random.shuffle(choices)\n        gold_index = random.randint(0, 3)\n        choices.insert(gold_index, example[\"Correct Answer\"])\n        query_prompt = GPQA_QUERY_TEMPLATE.format(\n            A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example[\"Question\"]\n        )\n        gold_choice = \"ABCD\"[gold_index]\n        return query_prompt, gold_choice\n\n    data_source = \"Idavidrein/gpqa\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n\n    dataset = load_dataset(data_source, \"gpqa_diamond\", split=\"train\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability=\"Math\", split=\"test\"\n    )\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)\n    return dataset\n\n\ndef build_cnmo2024_dataset():\n    def process_cnmo2024(example):\n        return example[\"question\"], example[\"answer\"]\n\n    data_source = \"opencompass/LiveMathBench\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n\n    dataset_en = load_dataset(data_source, \"v202412_CNMO_en\", split=\"test\")\n    map_fn_en = partial(\n        example_map_fn, process_fn=process_cnmo2024, data_source=\"opencompass/cnmo2024_en\", ability=\"Math\", split=\"test\"\n    )\n    dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names)\n\n    dataset_zh = load_dataset(data_source, \"v202412_CNMO_cn\", split=\"test\")\n    map_fn_zh = partial(\n        example_map_fn, process_fn=process_cnmo2024, data_source=\"opencompass/cnmo2024_zh\", ability=\"Math\", split=\"test\"\n    )\n    dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names)\n\n    dataset = concatenate_datasets([dataset_en, dataset_zh])\n    return dataset\n\n\ndef build_livecodebench_dataset():\n    import base64\n    import json\n    import pickle\n    import zlib\n\n    def process_livecodebench(example):\n        # Construct Query Prompt\n        # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140\n        query_prompt = (\n            f\"You will be given a question (problem specification) and will generate a correct Python program \"\n            f\"that matches the specification and passes all tests.\\n\\nQuestion: {example['question_content']}\\n\\n\"\n        )\n        if example[\"starter_code\"]:\n            query_prompt += (\n                f\"You will use the following starter code to write the solution to the problem and enclose your \"\n                f\"code within delimiters.\\n```python\\n{example['starter_code']}\\n```\"\n            )\n        else:\n            query_prompt += (\n                \"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test \"\n                \"on the sample inputs). Enclose your code within delimiters as follows. 
Ensure that when the python \"\n                \"program runs, it reads the inputs, runs the algorithm and writes output to STDOUT.\"\n                \"```python\\n# YOUR CODE HERE\\n```\"\n            )\n\n        # Construct test cases\n        public_test_cases = json.loads(example[\"public_test_cases\"])\n        try:\n            private_test_cases = json.loads(example[\"private_test_cases\"])\n        except Exception as e:\n            print(f\"Error loading private test cases: {e}\")\n            private_test_cases = json.loads(\n                pickle.loads(zlib.decompress(base64.b64decode(example[\"private_test_cases\"].encode(\"utf-8\"))))\n            )\n        full_test_cases = public_test_cases + private_test_cases\n\n        metadata = json.loads(example[\"metadata\"])\n        test_cases = {\n            \"inputs\": [t[\"input\"] for t in full_test_cases],\n            \"outputs\": [t[\"output\"] for t in full_test_cases],\n            \"fn_name\": metadata.get(\"func_name\", None),\n        }\n        text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode(\"utf-8\")\n        return query_prompt, text_cases_compressed\n\n    data_source = \"livecodebench/code_generation_lite\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = load_dataset(data_source, split=\"test\")\n    # R1 Evaluation use LiveCodeBench 24.08-25.01\n    dataset = dataset.filter(lambda line: \"2024-08-00T00:00:00\" <= line[\"contest_date\"] < \"2025-01-00T00:00:00\")\n    map_fn = partial(\n        example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability=\"Code\", split=\"test\"\n    )\n\n    dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8)\n    return dataset\n\n\nTASK2DATA = {\n    \"aime2024\": build_aime2024_dataset,\n    \"gpqa_diamond\": build_gpqa_dimond_dataset,\n    \"cnmo2024\": build_cnmo2024_dataset,\n    \"livecodebench\": build_livecodebench_dataset,\n}\nSUPPORTED_TASKS = TASK2DATA.keys()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/r1\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--tasks\", default=\"all\")\n\n    args = parser.parse_args()\n\n    if args.tasks.lower() == \"all\":\n        args.tasks = SUPPORTED_TASKS\n    else:\n        args.tasks = [task.strip() for task in args.tasks.split(\",\") if task.strip()]\n        for task in args.tasks:\n            if task not in SUPPORTED_TASKS:\n                raise NotImplementedError(f\"{task} has not been supported.\")\n\n    datasets = []\n    for task in args.tasks:\n        datasets.append(TASK2DATA[task]())\n    test_dataset = concatenate_datasets(datasets)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
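  {
    "path": "verl_rl/recipe/r1/data_process_demo.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original recipe.\n\nEvery dataset builder in data_process.py funnels its examples through\n`example_map_fn`, so all four tasks end up with one parquet row schema. This\nsnippet builds a single row from a made-up AIME-style example to show that\nschema; the input dict and its values are assumptions for illustration.\n\"\"\"\n\nfrom recipe.r1.data_process import example_map_fn\n\nrow = example_map_fn(\n    {\"Problem\": \"What is 2 + 2?\", \"Answer\": 4},  # made-up example\n    idx=0,\n    process_fn=lambda ex: (ex[\"Problem\"], str(ex[\"Answer\"])),\n    data_source=\"Maxwell-Jia/AIME_2024\",\n    ability=\"Math\",\n    split=\"test\",\n)\n# Expected keys: data_source, prompt (a chat-format message list), ability,\n# reward_model ({\"style\": \"rule\", \"ground_truth\": ...}) and extra_info.\nprint(row)\n"
  },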
  {
    "path": "verl_rl/recipe/r1/main_eval.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nOffline evaluate the performance of a generated file using reward model and ground truth verifier.\nThe input is a parquet file that contains N generated sequences and (optional) the ground truth.\n\n\"\"\"\n\nfrom collections import defaultdict\n\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport ray\nfrom tqdm import tqdm\n\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.utils.fs import copy_to_local\n\n\n@ray.remote\ndef process_item(config, data_source, response_lst, reward_data):\n    reward_fn = get_custom_reward_fn(config)\n    ground_truth = reward_data[\"ground_truth\"]\n    score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]\n    return data_source, np.mean(score_lst)\n\n\n@hydra.main(config_path=\"config\", config_name=\"evaluation\", version_base=None)\ndef main(config):\n    local_path = copy_to_local(config.data.path)\n    dataset = pd.read_parquet(local_path)\n    responses = dataset[config.data.response_key]\n    data_sources = dataset[config.data.data_source_key]\n    reward_model_data = dataset[config.data.reward_model_key]\n\n    total = len(dataset)\n\n    # Initialize Ray\n    if not ray.is_initialized():\n        ray.init(num_cpus=config.ray_init.num_cpus)\n\n    # evaluate test_score based on data source\n    data_source_reward = defaultdict(list)\n\n    # Create remote tasks\n    remote_tasks = [\n        process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)\n    ]\n\n    # Process results as they come in\n    with tqdm(total=total) as pbar:\n        while len(remote_tasks) > 0:\n            # Use ray.wait to get completed tasks\n            done_ids, remote_tasks = ray.wait(remote_tasks)\n            for result_id in done_ids:\n                data_source, score = ray.get(result_id)\n                data_source_reward[data_source].append(score)\n                pbar.update(1)\n\n    metric_dict = {}\n    for data_source, rewards in data_source_reward.items():\n        metric_dict[f\"test_score/{data_source}\"] = np.mean(rewards)\n\n    print(metric_dict)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/r1/reward_score.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef reward_func(data_source, solution_str, ground_truth, extra_info=None):\n    if data_source in [\"Maxwell-Jia/AIME_2024\", \"opencompass/cnmo2024_en\", \"opencompass/cnmo2024_zh\"]:\n        from recipe.r1.tasks import math\n\n        return math.compute_score(solution_str, ground_truth)\n    elif data_source == \"Idavidrein/gpqa\":\n        from recipe.r1.tasks import gpqa\n\n        return gpqa.compute_score(solution_str, ground_truth)\n    elif data_source in [\"livecodebench/code_generation_lite\", \"livecodebench/code_generation\"]:\n        from recipe.r1.tasks import livecodebench\n\n        return livecodebench.compute_score(solution_str, ground_truth)\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/recipe/r1/run_r1_distill_qwen.sh",
    "content": "MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B\nDATA_PATH=/workspace/datasets/r1_bench\n\n# Eval Data Process\npython3 -m recipe.r1.data_process \\\n    --local_dir $DATA_PATH \\\n    --tasks all\n\n# Generation\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$DATA_PATH/test.parquet \\\n    data.prompt_key=prompt \\\n    data.batch_size=1024 \\\n    data.n_samples=8 \\\n    data.output_path=$DATA_PATH/test-output-8.parquet \\\n    model.path=$MODEL_PATH \\\n    rollout.temperature=0.6 \\\n    rollout.top_p=0.95 \\\n    rollout.prompt_length=1024 \\\n    rollout.response_length=32768 \\\n    rollout.tensor_model_parallel_size=1 \\\n    rollout.gpu_memory_utilization=0.9 \\\n    rollout.max_num_batched_tokens=65536\n\n# Evaluation\npython3 -m recipe.r1.main_eval \\\n    data.path=$DATA_PATH/test-output-8.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=responses \\\n    custom_reward_function.path=recipe/r1/reward_score.py \\\n    custom_reward_function.name=reward_func\n"
  },
  {
    "path": "verl_rl/recipe/r1/tasks/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/recipe/r1/tasks/gpqa.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25\nANSWER_PATTERN_MULTICHOICE = r\"(?i)Answer[ \\t]*:[ \\t]*\\$?([A-D])\\$?\"\n\n\ndef compute_score(solution_str, ground_truth) -> float:\n    match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str)\n    extracted_answer = match.group(1) if match else None\n    score = 1.0 if extracted_answer == ground_truth else 0.0\n    return score\n"
  },
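  {
    "path": "verl_rl/recipe/r1/tasks/gpqa_demo.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original recipe.\n\nShows how the extraction pattern in gpqa.py grades a response: the answer\nline must match 'Answer: <LETTER>' (case-insensitive, optional '$' wrappers),\nand anything else scores 0.0. The sample response strings are made up.\n\"\"\"\n\nfrom recipe.r1.tasks.gpqa import compute_score\n\nresponse = \"The density argument rules out A and B.\\nAnswer: C\"\nassert compute_score(response, \"C\") == 1.0  # extracted letter matches the gold choice\nassert compute_score(response, \"D\") == 0.0  # extracted letter differs from the gold choice\nassert compute_score(\"no final answer line\", \"C\") == 0.0  # nothing to extract\nprint(\"gpqa extraction demo passed\")\n"
  },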
  {
    "path": "verl_rl/recipe/r1/tasks/livecodebench.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport base64\nimport json\nimport multiprocessing\nimport pickle\nimport zlib\n\n# Reuse `run_test` for convenience\nfrom verl.utils.reward_score.prime_code.testing_util import run_test\n\n\ndef _temp_run(in_outs, generation, debug, result, metadata_list, timeout):\n    res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)\n    result.append(res)\n    metadata_list.append(metadata)\n\n\ndef check_correctness(in_outs, generation, timeout, debug=True):\n    \"\"\"Check correctness of code generation with a global timeout.\n    The global timeout is to catch some extreme/rare cases not handled by the timeouts\n    inside `run_test`\"\"\"\n\n    manager = multiprocessing.Manager()\n    result = manager.list()\n    metadata_list = manager.list()\n    p = multiprocessing.Process(\n        target=_temp_run,\n        args=(in_outs, generation, debug, result, metadata_list, timeout),\n    )\n    p.start()\n    p.join(timeout=(timeout + 1) * len(in_outs[\"inputs\"]) + 5)\n    if p.is_alive():\n        p.kill()\n    if not result:\n        # consider that all tests failed\n        result = [[-1 for i in range(len(in_outs[\"inputs\"]))]]\n        if debug:\n            print(\"global timeout\")\n    return result[0], metadata_list[0]\n\n\ndef compute_score(completion, test_cases):\n    solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n\n    # extract test cases\n    try:\n        in_outs = json.loads(test_cases)\n    except Exception as e:\n        print(f\"Error loading test cases: {e}\")\n        in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode(\"utf-8\")))))\n\n    success = False\n    try:\n        res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False)\n        success = all(map(lambda x: x is True, res))\n    except Exception:\n        pass\n\n    return success\n"
  },
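  {
    "path": "verl_rl/recipe/r1/tasks/livecodebench_demo.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original recipe.\n\ndata_process.py stores LiveCodeBench test cases as\nbase64(zlib(pickle(json.dumps(...)))), and `compute_score` above reverses\nthat chain whenever a plain json.loads fails. This round trip on a made-up\ntest-case dict shows the two directions agree.\n\"\"\"\n\nimport base64\nimport json\nimport pickle\nimport zlib\n\ntest_cases = {\"inputs\": [\"1 2\\n\"], \"outputs\": [\"3\\n\"], \"fn_name\": None}  # made up\n\n# encode, as in data_process.process_livecodebench\nblob = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode(\"utf-8\")\n\n# decode, as in the fallback branch of compute_score\nrecovered = json.loads(pickle.loads(zlib.decompress(base64.b64decode(blob.encode(\"utf-8\")))))\n\nassert recovered == test_cases\nprint(\"round trip ok, blob length:\", len(blob))\n"
  },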
  {
    "path": "verl_rl/recipe/r1/tasks/math.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport contextlib\n\ntry:\n    from math_verify.metric import math_metric\n    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig\nexcept ImportError:\n    print(\"To use Math-Verify, please install it first by running `pip install math-verify`.\")\n\n\ndef compute_score(model_output: str, ground_truth: str) -> bool:\n    verify_func = math_metric(\n        gold_extraction_target=(LatexExtractionConfig(),),\n        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),\n    )\n    ret_score = 0.0\n\n    # Wrap the ground truth in \\boxed{} format for verification\n    ground_truth_boxed = \"\\\\boxed{\" + ground_truth + \"}\"\n    with contextlib.suppress(Exception):\n        ret_score, _ = verify_func([ground_truth_boxed], [model_output])\n\n    return ret_score\n"
  },
  {
    "path": "verl_rl/recipe/retool/retool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport re\nfrom typing import Any\n\nimport datasets\n\nfrom verl.tools.base_tool import OpenAIFunctionToolSchema\nfrom verl.tools.sandbox_fusion_tools import SandboxFusionTool\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.reward_score import math_dapo\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__name__)\n\n\nclass CustomSandboxFusionTool(SandboxFusionTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        self.code_pattern = re.compile(r\"```python(.*?)```\", re.DOTALL)\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        code = parameters[\"code\"]\n        matches = self.code_pattern.findall(code)\n        if matches:\n            code = matches[0].strip()\n\n        # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script\n        lines = code.split(\"\\n\")\n        for i, line in reversed(list(enumerate(lines))):\n            if line == \"\":\n                continue\n            if not lines[i].startswith(\"print\"):\n                lines[i] = f\"print({line})\"\n            break\n        code = \"\\n\".join(lines)\n\n        timeout = parameters.get(\"timeout\", self.default_timeout)\n        language = parameters.get(\"language\", self.default_language)\n        if not isinstance(code, str):\n            code = str(code)\n\n        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n        # sandbox has no score or metrics, use Nones\n        return result, None, None\n\n\nanswer_format = \"\"\"\\nThe answer format must be: \\\\boxed{'The final answer goes here.'}\"\"\"\n\n\nclass CustomRLHFDataset(RLHFDataset):\n    \"\"\"Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.\"\"\"\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(parquet_file)[\"train\"]\n            data_source = \"/\".join(parquet_file.split(\"/\")[-2:])\n            if data_source in [\"Maxwell-Jia/AIME_2024\", \"yentinglin/aime_2025\"]:\n                dataframe = dataframe.map(\n                    self.map_fn, fn_kwargs={\"data_source\": data_source}, remove_columns=dataframe.column_names\n                )\n            else:\n                dataframe = dataframe.map(self.map_fn2, num_proc=16)\n            dataframes.append(dataframe)\n        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n    def map_fn(self, row: dict, *, data_source: str = None):\n   
     if data_source == \"Maxwell-Jia/AIME_2024\":\n            problem, answer = row[\"Problem\"], row[\"Answer\"]\n        elif data_source == \"yentinglin/aime_2025\":\n            problem, answer = row[\"problem\"], row[\"answer\"]\n\n        prompt = problem + answer_format\n        data = {\n            \"data_source\": data_source.split(\"/\")[1].lower(),  # aime_2024, aime_2025\n            \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n            \"ability\": \"MATH\",\n            \"reward_model\": {\"ground_truth\": str(answer)},\n            \"agent_name\": \"tool_agent\",\n        }\n        return data\n\n    def map_fn2(self, row: dict):\n        content = row[\"prompt\"][0][\"content\"]\n        row[\"prompt\"][0][\"content\"] = content + answer_format\n        row[\"agent_name\"] = \"tool_agent\"\n        return row\n\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info):\n    # use \\\\boxed{...} answer\n    result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True)\n\n    # encourage model to call tools\n    num_turns = extra_info[\"num_turns\"]\n    if result[\"score\"] < 0:\n        tool_call_reward = (num_turns - 2) / 2 * 0.1\n        result[\"score\"] = min(0, result[\"score\"] + tool_call_reward)\n\n    if result[\"pred\"] is None:\n        result[\"pred\"] = \"\"\n\n    return result\n"
  },
  {
    "path": "verl_rl/recipe/retool/retool_multi_turn_sft_preprocess.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the Retool dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/retool_multiturn\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--train_ratio\", default=0.9, type=float)\n    parser.add_argument(\"--seed\", default=42, type=int)\n    args = parser.parse_args()\n\n    data_source = \"swordfaith/ReTool-SFT-multi-turn\"\n    dataset = datasets.load_dataset(data_source, \"default\")\n\n    train_dataset = dataset[\"train\"]\n    shuffled_train_dataset = train_dataset.shuffle(seed=args.seed)\n    split_idx = int(len(shuffled_train_dataset) * args.train_ratio)\n    train_dataset = shuffled_train_dataset.select(range(split_idx))\n    test_dataset = shuffled_train_dataset.select(range(split_idx, len(shuffled_train_dataset)))\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            messages = example.pop(\"messages\")\n            tools = example.pop(\"tools\")\n            data = {\n                \"data_source\": data_source,\n                \"messages\": messages,\n                \"tools\": tools,\n                \"enable_thinking\": False,\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    # Create output directory\n    local_dir = os.path.expanduser(args.local_dir)\n    os.makedirs(local_dir, exist_ok=True)\n\n    # Save to parquet files\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    # Handle HDFS if specified\n    if hdfs_dir is not None:\n        try:\n            from verl.utils.hdfs_io import copy, makedirs\n\n            makedirs(hdfs_dir)\n            copy(src=local_dir, dst=hdfs_dir)\n        except ImportError:\n            print(\"Warning: HDFS support not available. Skipping HDFS copy.\")\n\n    # Print statistics\n    print(f\"Train dataset size: {len(train_dataset)}\")\n    print(f\"Test dataset size: {len(test_dataset)}\")\n    print(f\"Data saved to {local_dir}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/retool/retool_sft_preprocess.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nConvert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.\n\"\"\"\n\nimport json\nimport re\nfrom typing import Any\n\nimport datasets\nfrom omegaconf import OmegaConf\n\ncode_pattern = re.compile(r\"```python(.*?)```\", re.DOTALL)\n\n\ndef extract_code_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<code>\", \"</code>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    code = content[i + len(start) : j]\n    matches = code_pattern.findall(code)\n    if matches:\n        code = matches[0].strip()\n\n    message = {\n        \"role\": \"assistant\",\n        \"content\": content[:i].strip(),\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\"code\": code},\n                },\n            },\n        ],\n    }\n    return message, content[j + len(stop) :]\n\n\ndef extract_answer_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<answer>\", \"</answer>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    answer = content[:i] + content[i + len(start) : j]\n    message = {\n        \"role\": \"assistant\",\n        \"content\": answer.strip(),\n    }\n    return message, content[j + len(stop) :]\n\n\ndef extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:\n    start, stop = \"<interpreter>\", \"</interpreter>\"\n    i = content.find(start)\n    if i == -1:\n        return None, content\n    j = content.find(stop)\n    assert j > i\n\n    interpreter = content[i + len(start) : j]\n    message = {\n        \"role\": \"tool\",\n        \"content\": interpreter.strip(),\n    }\n    return message, content[j + len(stop) :]\n\n\ndef process(row: dict, *, tools: str):\n    messages = []\n\n    # extract problem\n    content = row[\"messages\"][0][\"content\"]\n    start = \"*user question:*\"\n    i = content.find(start)\n    assert i != -1\n    prompt = content[i + len(start) :].replace(\"<answer>\", \"\").replace(\"</answer>\", \"\").strip()\n    messages.append(\n        {\n            \"role\": \"user\",\n            \"content\": prompt,\n        }\n    )\n\n    # extract multi turns\n    content = row[\"messages\"][1][\"content\"]\n    role = \"assistant\"\n    while len(content) > 0:\n        if role == \"assistant\":\n            message, content = extract_code_message(content)\n            if message is None:\n                message, content = extract_answer_message(content)\n            assert message is not None\n            messages.append(message)\n            role = \"tool\"\n        else:\n            message, content = extract_interpreter_message(content)\n       
     assert message is not None\n            messages.append(message)\n            role = \"assistant\"\n\n    return {\"messages\": messages, \"tools\": tools}\n\n\nif __name__ == \"__main__\":\n    tools_config_file = \"recipe/retool/sandbox_fusion_tool_config.yaml\"\n    tools_config = OmegaConf.load(tools_config_file)\n    tool_schema = OmegaConf.to_container(tools_config[\"tools\"][0][\"tool_schema\"])\n    tools = json.dumps([tool_schema])\n\n    data = datasets.load_dataset(\"JoeYing/ReTool-SFT\")[\"train\"]\n    data = data.map(process, fn_kwargs={\"tools\": tools})\n    data.to_parquet(\"wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\")\n"
  },
  {
    "path": "verl_rl/recipe/retool/run_qwen2-32b_sft.sh",
    "content": "#!/bin/bash\nset -x\n\n# set dist args\nnproc_per_node=${ARNOLD_WORKER_GPU}\nif [ ! -z \"$SINGLE\" ] && [ \"$SINGLE\" != \"0\" ]; then\n  echo \"[single node alone] SINGLE=$SINGLE\"\n  MASTER_NODE_ID=${ARNOLD_ID}\n  nnodes=1\n  node_rank=0\nelse\n  MASTER_NODE_ID=0\n  nnodes=${ARNOLD_WORKER_NUM}\n  node_rank=${ARNOLD_ID}\nfi\nmaster_addr=\"METIS_WORKER_${MASTER_NODE_ID}_HOST\"\nmaster_addr=${!master_addr}\nmaster_port=\"METIS_WORKER_${MASTER_NODE_ID}_PORT\"\nmaster_port=${!master_port}\nports=(`echo $master_port | tr ',' ' '`)\nmaster_port=${ports[0]}\necho \"[nproc_per_node: ${nproc_per_node}]\"\necho \"[nnodes: ${nnodes}]\"\necho \"[node_rank: ${node_rank}]\"\necho \"[master_addr: ${master_addr}]\"\necho \"[master_port: ${master_port}]\"\n\nexperiment_name=multiturn-sft-qwen-2.5-32b-instruct\nHDFS_ROOT=${HDFS_ROOT:-$PWD}\nDATA_ROOT=${DATA_ROOT:-$PWD}\n\nTRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nEVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet\nMODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct\nSAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name\n\ntorchrun --nnodes=$ARNOLD_WORKER_NUM \\\n     --nproc_per_node=$ARNOLD_WORKER_GPU \\\n     --master-addr=$master_addr \\\n     --master-port=$master_port \\\n     --node-rank=$node_rank \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$TRAIN_DATA \\\n    data.val_files=$EVAL_DATA \\\n    data.max_length=16384 \\\n    data.train_batch_size=32 \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    data.micro_batch_size_per_gpu=4 \\\n    model.partial_pretrain=$MODEL_PATH \\\n    model.strategy=fsdp \\\n    trainer.default_local_dir=$SAVE_PATH \\\n    trainer.project_name=wuxibin-multiturn-sft \\\n    trainer.experiment_name=$experiment_name \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=6 \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true"
  },
  {
    "path": "verl_rl/recipe/retool/run_qwen2.5_32b_sp8.sh",
    "content": "#!/bin/bash\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\n\nulimit -n 65535\n\nEXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-32b-sp8\n\ntorchrun --nnodes=1 --nproc_per_node=8 \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.max_length=16384 \\\n    data.train_batch_size=128 \\\n    data.micro_batch_size_per_gpu=4 \\\n    data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-32B-Instruct \\\n    model.trust_remote_code=true \\\n    model.fsdp_config.cpu_offload=true \\\n    model.fsdp_config.offload_params=true \\\n    optim.lr=1e-6 \\\n    trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \\\n    trainer.project_name=retool-multiturn-sft \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=12 $@ \\\n    ulysses_sequence_parallel_size=8 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_rl/recipe/retool/run_qwen2.5_7b_sp4.sh",
    "content": "#!/bin/bash\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\n\nulimit -n 65535\n\nEXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-7b-sp4\n\ntorchrun --nnodes=1 --nproc_per_node=8 \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.max_length=16384 \\\n    data.train_batch_size=128 \\\n    data.micro_batch_size_per_gpu=16 \\\n    data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-7B-Instruct \\\n    model.trust_remote_code=true \\\n    model.fsdp_config.cpu_offload=false \\\n    model.fsdp_config.offload_params=false \\\n    optim.lr=1e-6 \\\n    trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \\\n    trainer.project_name=retool-multiturn-sft \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=8 $@ \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_rl/recipe/retool/run_qwen3_4b_sp4.sh",
    "content": "#!/bin/bash\nset -x\n\nexport PYTHONUNBUFFERED=1\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\n\nulimit -n 65535\n\nEXPERIMENT_NAME=retool-multiturn-sft-qwen3-4b-sp4\n\ntorchrun --nnodes=1 --nproc_per_node=8 \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.max_length=16384 \\\n    data.train_batch_size=128 \\\n    data.micro_batch_size_per_gpu=16 \\\n    data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \\\n    data.multiturn.enable=true \\\n    data.multiturn.messages_key=messages \\\n    data.multiturn.tools_key=tools \\\n    model.partial_pretrain=$HOME/models/Qwen/Qwen3-4B \\\n    model.trust_remote_code=true \\\n    optim.lr=1e-6 \\\n    trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \\\n    trainer.project_name=retool-multiturn-sft \\\n    trainer.experiment_name=$EXPERIMENT_NAME \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.total_epochs=12 $@ \\\n    ulysses_sequence_parallel_size=4 \\\n    use_remove_padding=true\n"
  },
  {
    "path": "verl_rl/recipe/retool/sandbox_fusion_tool_config.yaml",
    "content": "tools:\n  - class_name: \"recipe.retool.retool.CustomSandboxFusionTool\"\n    config:\n      sandbox_fusion_url: \"https://***.apigateway-cn-beijing.volceapi.com/run_code\"\n      num_workers: 128\n      enable_global_rate_limit: true\n      rate_limit: 128\n      default_timeout: 30\n      default_language: \"python\"\n      memory_limit_mb: 1024\n      type: native\n\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]\n"
  },
  {
    "path": "verl_rl/recipe/spin/README.md",
    "content": "# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\n\nThis repository hosts a `verl` recipe inspired by the paper **\"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models\"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.\n\n**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:\n\n1.  **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations.\n2.  **Two-Player Game Setup:** A game involving two players acted by a single LLM.\n3.  **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.\n\nPaper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\\*, [Yihe Deng](https://github.com/uclaml/SPIN)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]\n\nverl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n---\n\n## Key Function (compute_online_dpo_loss) and Related works\nSPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). \n\nThis `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.\n\nSpecifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. 
By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets.\n\n**Reference Papers:**\n* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) \n* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) \n* [Some things are more cringe than others: Preference Optimization with the Pairwise Cringe Loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) \n* [Iterative Preference Learning from Human Feedback: Bridging Theory and Practice for RLHF under KL-Constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)\n* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)\n* [Direct Language Model Alignment from Online AI Feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)\n\n\n## Our Online DPO Implementation\n\nOur `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include:\n\n* **No Critic:** Unlike PPO, we omit the value function critic.\n* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for the DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.\n* **Online Preference Generation:** The `compute_onlinedpo_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).\n* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences.\n* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.\n\n---\n## Algorithm\n\nThis recipe implements an online DPO algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models.\n\n**Online Loop:** Instead of maximizing a scalar reward signal as in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training:\n\n1.  **Generation:** The current model generates multiple responses for each prompt in a batch.\n2.  **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on math problems.)\n3.  **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model.\n\n**Connection with SPIN:**\nInstead of only using a fixed target data distribution, the online generation loop in step 2 dynamically changes the target data distribution through its preference labeling method (in this recipe, rule-based ranking that selects the better of the two responses to a math problem). 
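As a toy illustration of this labeling step (hypothetical reward values; `compute_onlinedpo_pref` is the helper in `core_algos.py`):\n\n```python\nimport torch\n\nfrom recipe.spin.core_algos import compute_onlinedpo_pref\n\n# Interleaved rollouts: [resp1_prompt0, resp2_prompt0, resp1_prompt1, resp2_prompt1]\ntoken_level_rewards = torch.tensor(\n    [\n        [1.0, 0.0, 0.0],  # prompt 0, response 1 -> score 1.0\n        [0.0, 2.0, 0.0],  # prompt 0, response 2 -> score 2.0 (chosen)\n        [3.0, 0.0, 0.0],  # prompt 1, response 1 -> score 3.0 (chosen)\n        [0.0, 1.0, 0.0],  # prompt 1, response 2 -> score 1.0\n    ]\n)\nresponse_mask = torch.ones_like(token_level_rewards)\n\nchosen_mask = compute_onlinedpo_pref(token_level_rewards, response_mask)\nprint(chosen_mask)  # tensor([False,  True,  True, False])\n```\n\n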
This explores the direction mentioned in SPIN's paper Section 7 about \"dynamically changing target data distribution\" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling.\n\n---\n\n## Reproduce the Experiment (Example Setup)\n\nThe following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct.\n\n1.  **Setup Environment (Example using Docker):**\n    ```bash\n    # Start a container with GPU access and shared memory\n    docker run -it --name spin_test --gpus all \\\n        --shm-size=32g \\\n        --ipc=host \\\n        -v /path/to/host/.cache:/root/.cache \\\n        -e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \\\n        lmsysorg/sglang:latest \\\n        /bin/bash\n\n    # Inside the container or on your host machine:\n    # Ensure /tmp is writable\n    mkdir -p /tmp\n    chmod 1777 /tmp\n\n    # Install Python 3.10 (if not present) and venv\n    sudo apt update\n    sudo apt install -y python3.10 python3.10-venv tmux\n    python3 -m ensurepip --upgrade\n\n    # Create and activate a virtual environment\n    python3 -m venv ~/.python/spin_env\n    source ~/.python/spin_env/bin/activate\n\n    # Install uv (fast package installer)\n    python3 -m pip install uv\n    ```\n\n2.  **Install verl and Dependencies:**\n    ```bash\n    # Clone the verl repository and checkout the spin branch\n    cd ~\n    git clone git@github.com:volcengine/verl.git && cd verl\n\n    # Install flash-attn (handle potential build issues)\n    python3 -m uv pip install wheel packaging\n    python3 -m uv pip install flash-attn --no-build-isolation --no-deps\n\n    # Install verl with sglang extras\n    python3 -m uv pip install -e \".[sglang]\"\n    ```\n    *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.*\n\n3.  **Login & Download Data/Model:**\n    ```bash\n    # Login to Weights & Biases (optional, for logging)\n    export WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n    # wandb login\n\n    # Download the GSM8K dataset\n    python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k # Adjusted path\n\n    # Download the base model (Example: Qwen2.5-3B-Instruct)\n    huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct\n    ```\n\n4.  **Configure:**\n    * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node).\n    * Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`.\n\n5.  
**Run Training:**\n    ```bash\n    # Set CUDA visible devices (adjust based on your hardware and config)\n    export CUDA_VISIBLE_DEVICES=0,1,2,3\n\n    # Launch the training script (e.g., test.sh or a custom script)\n    # Ensure test.sh points to the correct config and main script\n    bash recipe/spin/run_spin.sh\n    ```\n\n---\n\n## Configuration\n\n* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).\n* Key configuration sections:\n    * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths.\n    * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler).\n    * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function.\n    * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.\n    * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).\n\n---\n\n## Key Files\n\n* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.\n* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.\n* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP.\n* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.\n* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`.\n* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.\n* `run_spin.sh` (or similar): Example bash script for launching a training run.\n* `README.md`: This file.\n\n---\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO):\n\n* [Zixiang Chen](https://sites.google.com/view/zxchen)\n* [Yuhao Yang](https://github.com/yhyang201)\n* [Yifan Zhang](https://github.com/yifanzhang-pro)\n* [Yongan Xiang](https://github.com/BearBiscuit05)\n* [Junrong Lin](https://github.com/ocss884)\n* [Yuxuan Tong](https://github.com/tongyx361)\n* [Guangming Shen](https://github.com/PeterSH6)\n* [Biao He](https://www.linkedin.com/in/biao-he/)\n* [Qingquan Song](https://qingquansong.github.io/)\n* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)\n* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\n---\n"
  },
  {
    "path": "verl_rl/recipe/spin/config/spin_trainer.yaml",
    "content": "# the sppo config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nactor_rollout_ref:\n  actor:\n    dpo_beta: 0.1\n    optim:\n      lr_warmup_steps: 15\n  rollout:\n    name: sglang\n    tensor_model_parallel_size: 2\n    gpu_memory_utilization: 0.5\n    val_kwargs:\n      n: 2  # 2 will trigger validation, 1 will bypass\n\nalgorithm:\n  adv_estimator: null\n\ntrainer:\n  log_val_generations: 0\n  ref_update_freq: 1"
  },
  {
    "path": "verl_rl/recipe/spin/core_algos.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport torch\n\n\nclass AdaptiveKLController:\n    \"\"\"\n    Adaptive KL controller described in the paper:\n    https://arxiv.org/pdf/1909.08593.pdf\n    \"\"\"\n\n    def __init__(self, init_kl_coef, target_kl, horizon):\n        self.value = init_kl_coef\n        self.target = target_kl\n        self.horizon = horizon\n\n    def update(self, current_kl, n_steps):\n        target = self.target\n        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current_kl, n_steps):\n        pass\n\n\ndef get_kl_controller(kl_ctrl):\n    if kl_ctrl.type == \"fixed\":\n        return FixedKLController(kl_coef=kl_ctrl.kl_coef)\n    elif kl_ctrl.type == \"adaptive\":\n        assert kl_ctrl.horizon > 0, f\"horizon must be larger than 0. Got {kl_ctrl.horizon}\"\n        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)\n    else:\n        raise NotImplementedError\n\n\ndef compute_onlinedpo_pref(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Computes preferences between pairs of sequences based on summed rewards\n    and returns a mask aligned with the interleaved batch.\n\n    Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]\n\n    Args:\n        token_level_rewards: Tensor of shape [batch_size * 2, seq_len]\n        response_mask: Tensor of shape [batch_size * 2, seq_len]\n\n    Returns:\n        torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates\n                      the corresponding entry is the chosen response for its pair.\n                      Example: [True, False, False, True, ...] means for prompt 0,\n                               response 1 was chosen; for prompt 1, response 2 was chosen.\n    \"\"\"\n    # print(f\"---- [DEBUG] Inside compute_onlinedpo_pref ----\")\n    if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:\n        raise ValueError(\n            f\"Input tensor batch dimension must be even for pair comparison, got shapes: \"\n            f\"{token_level_rewards.shape}, {response_mask.shape}\"\n        )\n    if token_level_rewards.shape != response_mask.shape:\n        raise ValueError(f\"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}\")\n\n    # 1. Calculate Sequence Scores\n    scores = (token_level_rewards * response_mask).sum(dim=-1)\n    # print(f\"  Calculated sequence scores shape: {scores.shape}\") # [batch_size * 2]\n\n    # 2. 
Reshape scores to group pairs: [batch_size, 2]\n    try:\n        score_pairs = scores.view(-1, 2)\n    except RuntimeError as e:\n        print(f\"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}\")\n        raise e\n    print(f\"  Reshaped score pairs shape: {score_pairs.shape}\")  # [batch_size, 2]\n\n    # 3. Compare scores to find which index (0 or 1) is the winner within each pair\n    #    winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1\n    winner_indices = torch.argmax(score_pairs, dim=1)  # 0 if first is max, 1 if second is max\n    # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max)\n    # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]\n    # print(f\"  Winner indices shape: {winner_indices.shape}\") # [batch_size]\n    # print(f\"  Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}\") # Counts number of 1s\n\n    # 4. Create the final [batch_size * 2] mask\n    num_pairs = score_pairs.shape[0]\n    full_batch_size = num_pairs * 2\n    # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]\n    # full_indices = torch.arange(full_batch_size, device=scores.device)\n    # Create indices corresponding to the winner within each pair's original index\n    # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]\n    # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4]\n    pair_indices = torch.arange(num_pairs, device=scores.device)\n    winner_global_indices = (pair_indices * 2) + winner_indices\n\n    # Create boolean mask - True at the winner's position\n    output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)\n    output_preference_mask[winner_global_indices] = True\n\n    # print(f\"  Output preference mask shape: {output_preference_mask.shape}\") # Should be [batch_size * 2]\n    # print(f\"  Output mask True count (Chosen): {output_preference_mask.sum().item()}\") # Should be batch_size\n    # print(f\"  Output mask False count (Rejected): {(~output_preference_mask).sum().item()}\") # Should be batch_size\n    # print(f\"---- [DEBUG] Exiting compute_onlinedpo_pref ----\")\n\n    return output_preference_mask\n\n\ndef compute_online_dpo_loss(\n    policy_chosen_logps: torch.Tensor,\n    policy_rejected_logps: torch.Tensor,\n    reference_chosen_logps: torch.Tensor,\n    reference_rejected_logps: torch.Tensor,\n    beta: float,\n    label_smoothing: float = 0.0,\n    loss_type: str = \"sigmoid\",\n    reference_free: bool = False,\n) -> torch.Tensor:\n    import torch.nn.functional as F\n\n    pi_logratios = policy_chosen_logps - policy_rejected_logps\n    ref_logratios = reference_chosen_logps - reference_rejected_logps\n\n    if reference_free:\n        ref_logratios = torch.zeros_like(pi_logratios)\n\n    logits = pi_logratios - ref_logratios\n\n    if loss_type == \"sigmoid\":\n        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing\n    elif loss_type == \"ipo\":\n        losses = (logits - 1 / (2 * beta)) ** 2\n    else:\n        raise ValueError(f\"Unsupported loss_type: {loss_type}. 
Choose 'sigmoid' or 'ipo'.\")\n\n    return losses.mean()\n\n\ndef get_batch_logps(\n    logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False\n) -> torch.FloatTensor:\n    \"\"\"\n    Compute the log probabilities of the given labels under the given logits.\n\n    Args:\n        logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).\n                Shape: (batch_size, sequence_length, vocab_size)\n        labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length)\n        average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum.\n\n    Returns:\n        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.\n    \"\"\"\n    if logits.shape[:-1] != labels.shape:\n        raise ValueError(\"logits.shape[:-1] must match labels.shape\")\n\n    # Ensure labels are contiguous and on the same device as logits\n    labels = labels.contiguous().to(logits.device)\n    # Shift so that tokens < n predict n\n    shift_logits = logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n\n    # Calculate per token log probability\n    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction=\"none\")\n    per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n    per_token_logps = per_token_logps.view(\n        shift_logits.size(0), shift_logits.size(1)\n    )  # Reshape back to (batch_size, seq_len-1)\n\n    # Create a mask for the labels that are not -100\n    loss_mask = shift_labels != -100\n\n    # Apply the mask to the per token log probabilities\n    masked_logps = per_token_logps * loss_mask\n\n    # Calculate the sum or average log probability per sequence\n    sequence_logps = masked_logps.sum(dim=-1)\n\n    if average_log_prob:\n        # Avoid division by zero for sequences with no valid tokens\n        num_valid_tokens = loss_mask.sum(dim=-1)\n        return sequence_logps / torch.clamp(num_valid_tokens, min=1)\n    else:\n        return sequence_logps\n"
  },
  {
    "path": "verl_rl/recipe/spin/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport itertools\nimport math\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\n\nfrom recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.workers.actor import DataParallelPPOActor\n\n__all__ = [\"DataParallelPPOActor\"]\n\n\nclass SPINDataParallelPPOActor(DataParallelPPOActor):\n    def compute_log_prob(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. 
torch.int64.\n\n        Returns:\n            torch.Tensor: the log_prob tensor\n        \"\"\"\n        # set to eval\n        self.actor_module.eval()\n\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        batch = data.select(batch_keys=select_keys).batch\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n        if has_multi_modal_inputs:\n            num_micro_batches = data.batch.batch_size[0] // micro_batch_size\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n        elif use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        log_probs_lst = []\n        for micro_batch in micro_batches:\n            if isinstance(micro_batch, DataProto):\n                micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n\n            with torch.no_grad():\n                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)\n            log_probs_lst.append(log_probs)\n        log_probs = torch.concat(log_probs_lst, dim=0)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            log_probs = log_probs[revert_indices]\n\n        return log_probs\n\n    def update_policy_dpo_with_ref(self, data: DataProto):\n        \"\"\"\n        Performs the DPO update step using pre-calculated reference log probs\n        from an external, periodically updated reference model.\n        \"\"\"\n        self.actor_module.train()  # Ensure training mode\n\n        # --- Retrieve necessary data ---\n        try:\n            # Expects batch prepared by fit_dpo loop, including reference log probs\n            batch_td = data.batch\n            chosen_labels = batch_td[\"chosen_labels\"]\n            rejected_labels = batch_td[\"rejected_labels\"]\n            # ... 
other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ...\n\n            # === Get PRE-CALCULATED reference log probs from input data ===\n            reference_chosen_logps = batch_td[\"reference_chosen_logps\"]  # Should be sequence-level logps\n            reference_rejected_logps = batch_td[\"reference_rejected_logps\"]  # Should be sequence-level logps\n            # ============================================================\n\n            # Get DPO params from meta_info\n            # beta = data.meta_info.get('dpo_beta', 0.1) # Default beta\n            beta = self.config.get(\"dpo_beta\", 0.1)  # Default beta\n            loss_type = data.meta_info.get(\"dpo_loss_type\", \"sigmoid\")\n            label_smoothing = data.meta_info.get(\"dpo_label_smoothing\", 0.0)\n            # reference_free should now be False as we provide ref logps\n            reference_free = data.meta_info.get(\"reference_free\", False)  # Default False\n\n        except KeyError as e:\n            print(f\"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}\")\n            print(f\"Available keys in data.batch: {list(batch_td.keys())}\")  # Debug print\n            return {}  # Return empty metrics on error\n        except Exception as e_data:\n            print(f\"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}\")\n            return {}\n\n        # --- Micro-batching Setup ---\n        micro_batch_size = self.config.get(\"ppo_micro_batch_size_per_gpu\")\n        if micro_batch_size is None:\n            # Fallback or default if not set, or raise error\n            micro_batch_size = 1  # Example fallback, adjust as needed\n            print(f\"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}\")\n            # raise ValueError(\"Config 'ppo_micro_batch_size_per_gpu' must be set.\")\n\n        # Ensure chosen_input_ids exists before getting shape\n        if \"chosen_input_ids\" not in batch_td:\n            print(\"ERROR: 'chosen_input_ids' not found in batch_td for DPO update.\")\n            return {}\n        bsz = batch_td[\"chosen_input_ids\"].shape[0]\n\n        if bsz == 0:\n            print(\"Warning: DPO batch size is 0 in update_policy_dpo. 
Skipping update.\")\n            return {\"actor/dpo_loss\": 0.0, \"actor/grad_norm\": 0.0}  # Return zero metrics if batch is empty\n\n        num_micro_batches = math.ceil(bsz / micro_batch_size)\n        gradient_accumulation_steps = num_micro_batches\n\n        # --- Metrics Accumulation ---\n        total_loss = 0.0\n        accumulated_metrics = defaultdict(list)\n        metrics = {}  # Final metrics dict\n\n        # --- Zero Gradients ---\n        self.actor_optimizer.zero_grad(set_to_none=True)\n\n        # --- Micro-batch Loop ---\n        for i in range(num_micro_batches):\n            start_idx = i * micro_batch_size\n            end_idx = min(start_idx + micro_batch_size, bsz)\n            if start_idx >= end_idx:\n                continue\n\n            # Slice the full DPO batch into micro-batches\n            # Important: Slice ALL required tensors, including labels and inputs\n            micro_batch_chosen_labels = chosen_labels[start_idx:end_idx]\n            micro_batch_rejected_labels = rejected_labels[start_idx:end_idx]\n            micro_batch_chosen_inputs = {\n                \"input_ids\": batch_td[\"chosen_input_ids\"][start_idx:end_idx],\n                \"attention_mask\": batch_td[\"chosen_attention_mask\"][start_idx:end_idx],\n            }\n            if \"chosen_position_ids\" in batch_td:\n                micro_batch_chosen_inputs[\"position_ids\"] = batch_td[\"chosen_position_ids\"][start_idx:end_idx]\n\n            micro_batch_rejected_inputs = {\n                \"input_ids\": batch_td[\"rejected_input_ids\"][start_idx:end_idx],\n                \"attention_mask\": batch_td[\"rejected_attention_mask\"][start_idx:end_idx],\n            }\n            if \"rejected_position_ids\" in batch_td:\n                micro_batch_rejected_inputs[\"position_ids\"] = batch_td[\"rejected_position_ids\"][start_idx:end_idx]\n\n            # Determine autocast dtype\n            autocast_dtype = torch.bfloat16  # Or get dynamically from config/FSDP settings\n            # --- Autocast Forward Pass ---\n            with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype):\n                # --- Step 1: Forward pass for CURRENT policy log probs (with grad) ---\n                policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False)\n                policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False)\n\n                # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps ---\n                policy_chosen_logps = get_batch_logps(\n                    policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False\n                )\n                policy_rejected_logps = get_batch_logps(\n                    policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False\n                )\n\n                # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) ---\n                # Slice the full batch reference logps for the current micro-batch\n                micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx]\n                micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx]\n                # --- The ActorAsRef calculation block is REMOVED ---\n\n                # --- Step 4: Calculate DPO Logits and Loss ---\n                pi_logratios = policy_chosen_logps - policy_rejected_logps\n                ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps  # 
Uses pre-calculated values\n                logits = pi_logratios - ref_logratios  # DPO logits\n\n                loss = compute_online_dpo_loss(\n                    policy_chosen_logps=policy_chosen_logps,  # Has grad\n                    policy_rejected_logps=policy_rejected_logps,  # Has grad\n                    reference_chosen_logps=micro_ref_chosen_logps,  # No grad (from input)\n                    reference_rejected_logps=micro_ref_rejected_logps,  # No grad (from input)\n                    beta=beta,\n                    label_smoothing=label_smoothing,\n                    loss_type=loss_type,\n                    reference_free=reference_free,  # Should be False now\n                )\n\n                # --- Scale loss for gradient accumulation ---\n                scaled_loss = loss / gradient_accumulation_steps\n\n                # --- Accumulate Metrics ---\n                total_loss += loss.item()  # Unscaled loss\n                accumulated_metrics[\"actor/dpo_loss_batch\"].append(loss.item())\n                accumulated_metrics[\"actor/dpo_logits_batch\"].append(logits.mean().item())\n                # Accumulate policy and reference log probs/ratios if needed for debugging\n                accumulated_metrics[\"actor/policy_chosen_logps_batch\"].append(policy_chosen_logps.mean().item())\n                accumulated_metrics[\"actor/policy_rejected_logps_batch\"].append(policy_rejected_logps.mean().item())\n                accumulated_metrics[\"actor/reference_chosen_logps_batch\"].append(micro_ref_chosen_logps.mean().item())\n                accumulated_metrics[\"actor/reference_rejected_logps_batch\"].append(\n                    micro_ref_rejected_logps.mean().item()\n                )\n\n            # --- Backward Pass (outside autocast) ---\n            # Check if loss requires grad before backward\n            if scaled_loss.requires_grad:\n                scaled_loss.backward()\n            else:\n                print(f\"Warning: Scaled loss at micro-batch {i} does not require grad. 
Skipping backward.\")\n\n        # --- End Micro-batch Loop ---\n\n        # --- Optimizer Step (after accumulating gradients for all micro-batches) ---\n        grad_norm = self._optimizer_step()\n\n        # --- Populate Final Metrics ---\n        if num_micro_batches > 0 and bsz > 0:  # Check if any processing happened\n            metrics[\"actor/dpo_loss\"] = total_loss / num_micro_batches\n            metrics[\"actor/grad_norm\"] = (\n                grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float(\"inf\")\n            )\n            # Average other accumulated metrics\n            for key, val_list in accumulated_metrics.items():\n                if val_list:\n                    metrics[key.replace(\"_batch\", \"\")] = np.mean(val_list)\n\n            # Calculate accuracy / rewards / margins based on averaged logprobs if desired\n            if (\n                \"actor/policy_chosen_logps\" in metrics\n                and \"actor/policy_rejected_logps\" in metrics\n                and \"actor/reference_chosen_logps\" in metrics\n                and \"actor/reference_rejected_logps\" in metrics\n            ):\n                policy_ratio_mean = metrics[\"actor/policy_chosen_logps\"] - metrics[\"actor/policy_rejected_logps\"]\n                ref_ratio_mean = metrics[\"actor/reference_chosen_logps\"] - metrics[\"actor/reference_rejected_logps\"]\n                logits_mean = policy_ratio_mean - ref_ratio_mean\n                metrics[\"actor/rewards_chosen\"] = beta * (\n                    metrics[\"actor/policy_chosen_logps\"] - metrics[\"actor/reference_chosen_logps\"]\n                )\n                metrics[\"actor/rewards_rejected\"] = beta * (\n                    metrics[\"actor/policy_rejected_logps\"] - metrics[\"actor/reference_rejected_logps\"]\n                )\n                metrics[\"actor/rewards_accuracies\"] = float(logits_mean > 0)  # Mean accuracy proxy\n                metrics[\"actor/rewards_margins\"] = metrics[\"actor/rewards_chosen\"] - metrics[\"actor/rewards_rejected\"]\n\n        else:  # Handle case where no micro-batches were run (e.g., bsz=0)\n            metrics[\"actor/dpo_loss\"] = 0.0\n            metrics[\"actor/grad_norm\"] = 0.0\n            # Initialize other metrics to 0 or NaN as appropriate\n            for key in accumulated_metrics.keys():\n                metrics[key.replace(\"_batch\", \"\")] = 0.0\n            metrics[\"actor/rewards_chosen\"] = 0.0\n            metrics[\"actor/rewards_rejected\"] = 0.0\n            metrics[\"actor/rewards_accuracies\"] = 0.0\n            metrics[\"actor/rewards_margins\"] = 0.0\n\n        return metrics  # Return aggregated metrics\n"
  },
  {
    "path": "verl_rl/recipe/spin/fsdp_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\nimport os\nimport warnings\n\nimport psutil\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom omegaconf import open_dict\nfrom torch.distributed.device_mesh import init_device_mesh\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2\")\n    return sharding_strategy\n\n\nclass SPINRolloutRefWorker(ActorRolloutRefWorker):\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from omegaconf import OmegaConf\n\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        
use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout or self._is_ref:\n            # we need the model for actor and rollout\n            if self._is_actor or self._is_ref:\n                optim_config = self.config.actor.optim\n                fsdp_config = self.config.actor.fsdp_config\n            else:\n                optim_config = None\n                fsdp_config = OmegaConf.create()\n            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (\n                self._build_model_optimizer(\n                    model_path=self.config.model.path,\n                    fsdp_config=fsdp_config,\n                    optim_config=optim_config,\n                    override_model_config=override_model_config,\n                    use_remove_padding=use_remove_padding,\n                    use_fused_kernels=use_fused_kernels,\n                    enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                    trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                    use_liger=self.config.model.get(\"use_liger\", False),\n                    role=\"actor\",\n                )\n            )\n\n            # get the original unwrapped module\n            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n        # load from checkpoint\n        if self._is_actor or self._is_ref:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                self.config.actor.use_remove_padding = use_remove_padding\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = DataParallelPPOActor(\n                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self.rollout, self.rollout_sharding_manager = self._build_rollout(\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n\n        if self._is_ref:\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                
lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_ref_log_prob(self, data: DataProto):\n        assert self._is_ref\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            output = self.ref_policy.compute_log_prob(data=data)\n            output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n            output = self.ulysses_sharding_manager.postprocess_data(output)\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1:\n            self.ref_policy.actor_module._handle.reshard(True)\n\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_log_prob(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        # perform recompute log_prob\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            output = self.actor.compute_log_prob(data=data)\n            output = DataProto.from_dict(\n                tensors={\"old_log_probs\": output}, meta_info={\"temperature\": self.config.rollout.temperature}\n            )\n            output = self.ulysses_sharding_manager.postprocess_data(output)\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1:\n            self.actor.actor_module._handle.reshard(True)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        log_gpu_memory_usage(\"After 
compute_log_prob\", logger=logger)\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def update_actor_dpo(self, data: DataProto):\n        \"\"\"\n        Wrapper for actor update step. Handles FSDP state management.\n        Calls self.actor.update_policy which now contains DPO logic based\n        on pre-calculated log probabilities.\n        \"\"\"\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        assert self._is_actor  # Make sure this worker has the actor role\n        if self.actor is None:\n            raise RuntimeError(\"Actor instance (self.actor) not initialized in worker.\")\n\n        # --- FSDP State Management ---\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())\n\n        log_gpu_memory_usage(\"Before update policy (DPO via PPO path)\", logger=logger)\n\n        # --- Ulysses Sharding (if used) ---\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            # --- Call the core update method (now containing DPO logic) ---\n            with Timer(name=\"update_policy_dpo_via_ppo\", logger=None) as timer:  # Use a distinct timer name\n                # Calls the modified update_policy method\n                metrics = self.actor.update_policy_dpo_with_ref(data=data)  # <-- THIS CALLS THE MODIFIED FUNCTION\n            delta_time = timer.last\n\n            # --- Add Performance Metrics ---\n            # MFU calculation might be less accurate/meaningful here for DPO\n            metrics[\"perf/approx_tokens_processed\"] = torch.sum(\n                data.batch.get(\"attention_mask\", torch.tensor(0))\n            ).item()  # Approx tokens\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n\n            # --- LR Scheduler Step ---\n            lr = self.actor_lr_scheduler.get_last_lr()[0]\n            metrics[\"actor/lr\"] = lr\n            self.actor_lr_scheduler.step()\n\n            log_gpu_memory_usage(\"After update policy (DPO via PPO path)\", logger=logger)\n\n            # --- Prepare Output ---\n            output = DataProto(meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n            output = output.to(\"cpu\")\n\n        # --- FSDP State Management (Offload) ---\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n\n        return output\n\n\n# TODO(sgm): we may need to extract it to dp_reward_model.py\nclass RewardModelWorker(Worker):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.\n    \"\"\"\n\n    def 
__init__(self, config):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= torch.distributed.get_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_model(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import CPUOffload\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.model.path)\n\n        if self.config.model.input_tokenizer is None:\n            self._do_switch_chat_template = False\n        else:\n            self._do_switch_chat_template = True\n            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)\n            self.input_tokenizer = hf_tokenizer(\n                input_tokenizer_local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False)\n            )\n            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        trust_remote_code = config.model.get(\"trust_remote_code\", False)\n        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        model_config.num_labels = 1\n\n        # NOTE: the reward model is inference-only (no optimizer here), 
so it is loaded directly in bf16\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            reward_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                config=model_config,\n                torch_dtype=torch.bfloat16,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            if config.model.get(\"use_remove_padding\", False) or self.ulysses_sequence_parallel_size > 1:\n                from verl.models.transformers.monkey_patch import apply_monkey_patch\n\n                apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)\n\n            reward_module.to(torch.bfloat16)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        reward_module = FSDP(\n            reward_module,\n            param_init_fn=init_fn,\n            use_orig_params=False,\n            auto_wrap_policy=auto_wrap_policy,\n            device_id=get_device_id(),\n            sharding_strategy=sharding_strategy,  # zero3\n            sync_module_states=True,\n            cpu_offload=CPUOffload(offload_params=True),\n            forward_prefetch=False,\n            device_mesh=self.device_mesh,\n        )\n\n        return reward_module\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        self.reward_module = self._build_model(config=self.config)\n\n    def _forward_micro_batch(self, micro_batch):\n        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n\n        from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\n        with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                position_ids_rmpad = index_first_axis(\n                    rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n                ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.reward_module(\n                    input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False\n                )  # prevent model thinks we are generating\n                reward_rmpad = output.logits\n                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    reward_rmpad = gather_outputs_and_unpad(\n                        reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)\n            else:\n                output = self.reward_module(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                rm_score = output.logits  # (batch_size, seq_len, 1)\n                rm_score = rm_score.squeeze(-1)\n\n            # extract the result of the last valid token\n            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]\n            return rm_score\n\n    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):\n        batch_size = data.batch.batch_size[0]\n        # expand as token_level_reward\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        response_length = data.batch[\"responses\"].shape[-1]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)\n        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores\n\n        # select the response part\n        token_level_scores = token_level_scores[:, -response_length:]\n\n        return token_level_scores\n\n    def _switch_chat_template(self, data: DataProto):\n        src_max_length = data.batch[\"attention_mask\"].shape[-1]\n\n        src_tokenizer = self.input_tokenizer\n        target_tokenizer = self.tokenizer\n\n        rm_input_ids = []\n        rm_attention_mask = []\n\n        for i in range(data.batch.batch_size[0]):\n            # extract raw prompt\n            if isinstance(data.non_tensor_batch[\"raw_prompt\"][i], list):\n                chat: list = data.non_tensor_batch[\"raw_prompt\"][i]\n            else:\n                chat: list = data.non_tensor_batch[\"raw_prompt\"][i].tolist()\n\n            # extract response\n            response_ids = data.batch[\"responses\"][i]\n            response_length = response_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][i][-response_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            response = 
src_tokenizer.decode(valid_response_ids)\n            # strip the eos token from the decoded response\n            response = response.replace(src_tokenizer.eos_token, \"\")\n\n            chat.append({\"role\": \"assistant\", \"content\": response})\n\n            prompt_with_chat_template = target_tokenizer.apply_chat_template(\n                chat, add_generation_prompt=False, tokenize=False\n            )\n            if self.rank == 0 and i == 0:\n                # for debugging purpose\n                print(f\"Switch template. chat: {prompt_with_chat_template}\")\n\n            # the maximum length is actually determined by the reward model itself\n            max_length = self.config.get(\"max_length\", src_max_length)\n            if max_length is None:\n                max_length = src_max_length\n\n            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids, attention_mask = verl_F.postprocess_data(\n                input_ids=model_inputs[\"input_ids\"],\n                attention_mask=model_inputs[\"attention_mask\"],\n                max_length=max_length,\n                pad_token_id=target_tokenizer.pad_token_id,\n                left_pad=False,  # right padding\n                truncation=self.config.get(\"truncation\", \"right\"),\n            )  # truncate from the right\n\n            rm_input_ids.append(input_ids)\n            rm_attention_mask.append(attention_mask)\n\n        rm_input_ids = torch.cat(rm_input_ids, dim=0)\n        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)\n\n        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)\n\n        rm_inputs = {\"input_ids\": rm_input_ids, \"attention_mask\": rm_attention_mask, \"position_ids\": rm_position_ids}\n\n        return DataProto.from_dict(rm_inputs)\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_rm_score(self, data: DataProto):\n        import itertools\n\n        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._do_switch_chat_template:\n            rm_data = self._switch_chat_template(data)\n        else:\n            rm_input_ids = data.batch[\"input_ids\"]\n            rm_attention_mask = data.batch[\"attention_mask\"]\n            rm_position_ids = data.batch[\"position_ids\"]\n            rm_inputs = {\n                \"input_ids\": rm_input_ids,\n                \"attention_mask\": rm_attention_mask,\n                \"position_ids\": rm_position_ids,\n            }\n            rm_data = DataProto.from_dict(rm_inputs)\n\n        # Support all hardwares\n        rm_data.batch = rm_data.batch.to(get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            use_dynamic_bsz = self.config.use_dynamic_bsz\n            if use_dynamic_bsz:\n                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)\n            else:\n                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)\n            output = []\n            for micro_batch in micro_batches:\n                
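# each micro-batch yields one scalar score per sequence (read at its last valid token);\n                # if dynamic batching reordered samples, the original order is restored below via revert_indices.\n                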
rm_score = self._forward_micro_batch(micro_batch)\n                output.append(rm_score)\n            scores = torch.cat(output, dim=0)  # (batch_size)\n\n            if use_dynamic_bsz:\n                indices = list(itertools.chain.from_iterable(indices))\n                assert len(indices) == scores.size(0), f\"{len(indices)} vs. {scores.size()}\"\n                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                scores = scores[revert_indices]\n\n            token_level_scores = self._expand_to_token_level(data, scores)\n            # Note that this is only the scores, may not be the final rewards used to train RL\n            output = DataProto.from_dict(tensors={\"rm_scores\": token_level_scores})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        self.reward_module._handle.reshard(True)\n\n        output = output.to(\"cpu\")\n        return output\n"
  },
  {
    "path": "verl_rl/recipe/spin/main_spin.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport hydra\nimport ray\n\nfrom recipe.spin.spin_trainer import RaySPINTrainer\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\n\n\n@hydra.main(config_path=\"config\", config_name=\"spin_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices\n    # isolation, will solve in the future\n    os.environ[\"ENSURE_CUDA_VISIBLE_DEVICES\"] = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\n                \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n            }\n        )\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            # from recipe.spin.fsdp_workers import ActorRolloutRefWorker\n            from recipe.spin.fsdp_workers import SPINRolloutRefWorker\n            from verl.single_controller.ray import RayWorkerGroup\n\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from recipe.spin.spin_trainer import ResourcePoolManager, Role\n\n        role_worker_mapping = {\n            # Role.ActorRollout: ray.remote(ActorRolloutRefWorker),\n            Role.ActorRollout: ray.remote(SPINRolloutRefWorker),\n            # 
Role.Critic: ray.remote(CriticWorker),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            # Role.Critic: global_pool_id,\n        }\n\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from recipe.spin.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        # if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n        # role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n        role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)\n        mapping[Role.RefPolicy] = global_pool_id\n\n        from verl.workers.reward_manager import get_reward_manager_cls\n\n        # Note(haibin.lin): please make sure custom reward managers are imported and\n        # registered via `verl.workers.reward_manager.register`\n        reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n        reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n        compute_score = get_custom_reward_fn(config)\n        reward_kwargs = dict(config.reward_model.get(\"reward_kwargs\", {}))\n        reward_fn = reward_manager_cls(\n            tokenizer=tokenizer,\n            num_examine=0,\n            compute_score=compute_score,\n            reward_fn_key=config.data.reward_fn_key,\n            **reward_kwargs,\n        )\n\n        # Note that we always use function-based RM for validation\n        val_reward_fn = reward_manager_cls(\n            tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RaySPINTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n        )\n        trainer.init_workers()\n        trainer.fit_dpo()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/spin/run_spin.sh",
    "content": "set -e\nset -x\nVISIBLE_DEVICES=\"4,5,6,7\"\nexport HYDRA_FULL_ERROR=1\n\nCUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=1024 \\\n  data.max_prompt_length=1024 \\\n  data.max_response_length=1024 \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=8 \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size=64 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=True \\\n  trainer.n_gpus_per_node=4 \\\n  trainer.nnodes=1 \\\n  trainer.save_freq=-1 \\\n  trainer.test_freq=1 \\\n  +trainer.log_freq=1 \\\n  trainer.ref_update_freq=1 \\\n  trainer.total_epochs=1000 2>&1 | tee verl_demo.log"
  },
  {
    "path": "verl_rl/recipe/spin/spin_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport traceback\nimport uuid\nfrom collections import defaultdict\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom pprint import pprint\nfrom typing import Any, Optional\n\nimport numpy as np\nimport ray\nimport torch\nfrom codetiming import Timer\nfrom omegaconf import OmegaConf, open_dict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom recipe.spin import core_algos\nfrom verl import DataProto\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.metric_utils import (\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics,\n    reduce_metrics,\n)\nfrom verl.trainer.ppo.ray_trainer import Role\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path\nfrom verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\nWorkerType = type[Worker]\n\n\nclass AdvantageEstimator(str, Enum):\n    \"\"\"\n    Using an enumeration class to avoid spelling errors in adv_estimator\n    \"\"\"\n\n    GAE = \"gae\"\n    GRPO = \"grpo\"\n    REINFORCE_PLUS_PLUS = \"reinforce_plus_plus\"\n    REINFORCE_PLUS_PLUS_BASELINE = \"reinforce_plus_plus_baseline\"\n    REMAX = \"remax\"\n    RLOO = \"rloo\"\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. Resource pool will be initialized first.\n    Mapping\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. 
processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different\n            # WorkerGroup for different models\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray.state.available_resources_per_node()\n        node_available_gpus = {node: node_info.get(\"GPU\", 0) for node, node_info in node_available_resources.items()}\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n        # check each resource pool can be satisfied, O(#resource_pools * #nodes)\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)\n            for node, available_gpus in node_available_gpus.items():\n                if available_gpus >= num_gpus:\n                    node_available_gpus[node] -= num_gpus\n                    num_nodes -= 1\n                    if num_nodes == 0:\n                        break\n            if num_nodes > 0:\n                raise ValueError(\n                    f\"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this \"\n                    f\"ray cluster\"\n                )\n\n\ndef _compute_response_info(batch: DataProto) -> dict[str, Any]:\n    \"\"\"Placeholder: Computes prompt and response lengths.\"\"\"\n    try:\n        # Assuming 'prompts' and 'responses' keys exist after generation/union\n        prompt_len = batch.batch[\"prompts\"].shape[1]\n        resp_len = batch.batch[\"responses\"].shape[1]\n        # This is simplified - real implementation might use attention masks\n        # to get actual lengths per sample.\n        batch_size = batch.batch.batch_size[0]\n        prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)\n        response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)\n\n        # Try getting actual lengths from attention mask if possible (more accurate)\n        if \"response_mask\" in batch.batch:\n            response_lengths_tensor = 
batch.batch[\"response_mask\"].sum(dim=1).float()\n            # if \"attention_mask\" in batch.batch and \"response_mask\" in batch.batch:\n            # full_mask = batch.batch[\"attention_mask\"]\n            # resp_mask = batch.batch[\"response_mask\"]\n            # Infer prompt mask length based on where response mask starts or total length\n            # This logic depends heavily on how your masks are constructed.\n            # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor\n            # Fallback to using prompt shape if mask logic is complex:\n            prompt_lengths_tensor = torch.tensor(\n                [batch.batch[\"prompts\"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device\n            )\n\n        return {\n            \"prompt_length\": prompt_lengths_tensor,\n            \"response_length\": response_lengths_tensor,\n            \"max_response_length\": resp_len,\n            \"max_prompt_length\": prompt_len,  # Or from config if fixed padding\n        }\n    except KeyError as e:\n        print(f\"Warning: Missing key in _compute_response_info: {e}. Returning defaults.\")\n        # Return default/dummy values if keys are missing\n        b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1\n        max_resp = batch.batch.get(\"responses\").shape[1] if batch.batch.get(\"responses\") is not None else 0\n        max_prompt = batch.batch.get(\"prompts\").shape[1] if batch.batch.get(\"prompts\") is not None else 0\n        return {\n            \"prompt_length\": torch.zeros(b_size),\n            \"response_length\": torch.zeros(b_size),\n            \"max_response_length\": max_resp,\n            \"max_prompt_length\": max_prompt,\n        }\n\n\n# --- Modified Metric Function ---\ndef compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:\n    \"\"\"\n    Computes and returns metrics relevant for the DPO-like process.\n    Assumes 'batch' contains results after generation and preference marking,\n    potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.\n    Removes PPO-specific advantage/return/critic metrics.\n    \"\"\"\n    print(\"---- [DEBUG] Computing DPO Data Metrics ----\")\n    metrics = {}\n    try:\n        # --- Scores and Rewards (from reward_fn) ---\n        if \"token_level_scores\" in batch.batch and batch.batch[\"token_level_scores\"] is not None:\n            sequence_score = batch.batch[\"token_level_scores\"].sum(-1)\n            metrics.update(\n                {\n                    \"reward/score/mean\": torch.mean(sequence_score).item(),\n                    \"reward/score/max\": torch.max(sequence_score).item(),\n                    \"reward/score/min\": torch.min(sequence_score).item(),\n                }\n            )\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.\")\n\n        if \"token_level_rewards\" in batch.batch and batch.batch[\"token_level_rewards\"] is not None:\n            sequence_reward = batch.batch[\"token_level_rewards\"].sum(-1)\n            metrics.update(\n                {\n                    \"reward/rewards/mean\": torch.mean(sequence_reward).item(),\n                    \"reward/rewards/max\": torch.max(sequence_reward).item(),\n                    \"reward/rewards/min\": torch.min(sequence_reward).item(),\n                }\n            )\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.\")\n\n        # --- 
DPO Specific Metrics (if stored previously) ---\n        if \"dpo_logits\" in batch.batch and batch.batch[\"dpo_logits\"] is not None:\n            metrics[\"actor/dpo_logits\"] = batch.batch[\"dpo_logits\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.\")\n\n        if \"chosen_logps\" in batch.batch and batch.batch[\"chosen_logps\"] is not None:\n            metrics[\"actor/chosen_logps\"] = batch.batch[\"chosen_logps\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.\")\n\n        if \"rejected_logps\" in batch.batch and batch.batch[\"rejected_logps\"] is not None:\n            metrics[\"actor/rejected_logps\"] = batch.batch[\"rejected_logps\"].mean().item()\n        else:\n            print(\"DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.\")\n\n        # Add metrics based on the 'preferences' mask if available\n        # if \"preferences\" in batch.batch and batch.batch[\"preferences\"] is not None:\n        # prefs_mask = batch.batch[\"preferences\"]  # Shape [batch_size * n]\n        # Calculate accuracy based on RM scores (assuming higher score -> True in mask)\n        # Requires chosen/rejected scores to be available or recalculated\n        # This is complex here, better calculated in the main loop or update function\n\n        # --- Length Metrics ---\n        response_info = _compute_response_info(batch)\n        prompt_length = response_info[\"prompt_length\"]\n        response_length = response_info[\"response_length\"]\n        max_response_length = response_info[\"max_response_length\"]\n        max_prompt_length = response_info[\"max_prompt_length\"]  # Use calculated or from config\n\n        metrics.update(\n            {\n                \"response_length/mean\": torch.mean(response_length).item(),\n                \"response_length/max\": torch.max(response_length).item(),\n                \"response_length/min\": torch.min(response_length).item(),\n                \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float()).item(),\n                \"prompt_length/mean\": torch.mean(prompt_length).item(),\n                \"prompt_length/max\": torch.max(prompt_length).item(),\n                \"prompt_length/min\": torch.min(prompt_length).item(),\n                # Prompt clip ratio might need adjustment based on how max_prompt_length is defined\n                \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),\n            }\n        )\n\n    except KeyError as e:\n        print(f\"ERROR in compute_dpo_data_metrics: Missing key {e}\")\n    except Exception as e:\n        print(f\"ERROR in compute_dpo_data_metrics: {e}\")\n        traceback.print_exc()\n\n    print(f\"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----\")\n    return metrics\n\n\ndef apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n    attention_mask = data.batch[\"attention_mask\"]\n    response_mask = attention_mask[:, -response_length:]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = 
core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_response_mask(data: DataProto):\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_onlineDPO_pref(data: DataProto):\n    \"\"\"\n    Wrapper to compute DPO preference and add it to the DataProto batch.\n    Includes debugging prints.\n    \"\"\"\n    # print(f\"\\n---- [DEBUG] Entering compute_onlineDPO_pref ----\")\n    # print(f\"  Input batch keys: {list(data.batch.keys())}\")\n\n    # Check inputs\n    rewards_tensor = data.batch.get(\"token_level_rewards\")\n    mask_tensor = data.batch.get(\"response_mask\")\n\n    if rewards_tensor is None or mask_tensor is None:\n        print(\"  ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!\")\n        # Handle error case - maybe return original data or raise?\n        # Returning original data for now to potentially allow skipping\n        return data\n\n    try:\n        preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)\n        # Store the result\n        data.batch[\"preferences\"] = preferences\n\n    except AttributeError:\n        print(\"ERROR: Function 'compute_onlinedpo_pref' not found in core_algos.py!\")\n        # Assign dummy value or raise error\n        data.batch[\"preferences\"] = None  # Indicate failure\n    except Exception as e_pref:\n        print(f\"ERROR during core_algos.compute_onlinedpo_pref: {e_pref}\")\n        traceback.print_exc()  # traceback is imported at module level\n        data.batch[\"preferences\"] = None  # Indicate failure\n\n    # print(f\"---- [DEBUG] Exiting compute_onlineDPO_pref ----\")\n    return data\n\n\n@contextmanager\ndef _timer(name: str, timing_raw: dict[str, float]):\n    with Timer(name=name, logger=None) as timer:\n        yield\n    timing_raw[name] = timer.last\n\n\nclass RaySPINTrainer:\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support an individual ray_worker_group_cls for each role,\n    # i.e., support a different backend per role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        
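# NOTE: this constructor only wires up driver-side state (role mappings, reward fns, KL controller,\n        # dataloaders); the Ray worker groups themselves are created later in init_workers().\n        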
# assert get_torch_device().is_available(), 'cuda must be available on driver'\n\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only the hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = Role.RefPolicy in role_worker_mapping\n        self.use_rm = Role.RewardModel in role_worker_mapping\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.validation_generations_logger = ValidationGenerationsLogger()\n        self.async_rollout_mode = False\n        self.device_name = device_name if device_name else self.config.trainer.device\n\n        # define in-reward KL control\n        # kl loss control currently not supported\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        self.use_critic = False\n        self._validate_config()\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _validate_config(self):\n        config = self.config\n        # number of GPUs total\n        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes\n\n        # 1. Check total batch size for data correctness\n        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n\n        assert real_train_batch_size % n_gpus == 0, (\n            f\"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus}).\"\n        )\n\n        # A helper function to check \"micro_batch_size\" vs \"micro_batch_size_per_gpu\"\n        # We throw an error if the user sets both. The new convention is \"..._micro_batch_size_per_gpu\".\n        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n            settings = {\n                \"actor_rollout_ref.actor\": \"micro_batch_size\",\n                \"critic\": \"micro_batch_size\",\n                \"reward_model\": \"micro_batch_size\",\n                \"actor_rollout_ref.ref\": \"log_prob_micro_batch_size\",\n                \"actor_rollout_ref.rollout\": \"log_prob_micro_batch_size\",\n            }\n\n            if name in settings:\n                param = settings[name]\n                param_per_gpu = f\"{param}_per_gpu\"\n\n                if mbs is None and mbs_per_gpu is None:\n                    raise ValueError(\n                        f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\"\n                    )\n\n                if mbs is not None and mbs_per_gpu is not None:\n                    raise ValueError(\n                        f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. \"\n                        f\"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported \"\n                        f\"(the former is deprecated).\"\n                    )\n\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            # actor: ppo_micro_batch_size vs. 
ppo_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.actor.ppo_micro_batch_size,\n                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.actor\",\n            )\n\n            if self.use_reference_policy:\n                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n                check_mutually_exclusive(\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,\n                    \"actor_rollout_ref.ref\",\n                )\n\n            #  The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size,\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.rollout\",\n            )\n\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            # Check for critic micro-batch size conflicts\n            check_mutually_exclusive(\n                config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, \"critic\"\n            )\n\n        # Check for reward model micro-batch size conflicts\n        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:\n            check_mutually_exclusive(\n                config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, \"reward_model\"\n            )\n\n        # Actor\n        # check if train_batch_size is larger than ppo_mini_batch_size\n        # if NOT dynamic_bsz, we must ensure:\n        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size\n        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size\n            sp_size = config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:\n                assert (\n                    config.actor_rollout_ref.actor.ppo_mini_batch_size\n                    % config.actor_rollout_ref.actor.ppo_micro_batch_size\n                    == 0\n                )\n                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus\n\n        assert config.actor_rollout_ref.actor.loss_agg_mode in [\n            \"token-mean\",\n            \"seq-mean-token-sum\",\n            \"seq-mean-token-mean\",\n        ], f\"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}\"\n\n        if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:\n            print(\"NOTICE: You have both enabled in-reward kl and kl loss.\")\n\n        # critic\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size\n            sp_size = config.critic.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.critic.ppo_micro_batch_size is not None:\n                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0\n                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus\n\n        # 
Check if use_remove_padding is enabled when using sequence parallelism for fsdp\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            if (\n                config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1) > 1\n                or config.actor_rollout_ref.ref.get(\"ulysses_sequence_parallel_size\", 1) > 1\n            ):\n                assert config.actor_rollout_ref.model.use_remove_padding, (\n                    \"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`.\"\n                )\n\n        if self.use_critic and config.critic.strategy in {\"fsdp\", \"fsdp2\"}:\n            if config.critic.get(\"ulysses_sequence_parallel_size\", 1) > 1:\n                assert config.critic.model.use_remove_padding, (\n                    \"When using sequence parallelism for critic, you must enable `use_remove_padding`.\"\n                )\n\n        if config.data.get(\"val_batch_size\", None) is not None:\n            print(\n                \"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines \"\n                \"as a whole batch, which will schedule the memory themselves.\"\n            )\n\n        # check eval config\n        if config.actor_rollout_ref.rollout.val_kwargs.do_sample:\n            assert config.actor_rollout_ref.rollout.temperature > 0, (\n                \"validation gen temperature should be greater than 0 when enabling do_sample\"\n            )\n\n        print(\"[validate_config] All configuration checks passed successfully!\")\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files, self.config.data, self.tokenizer, self.processor\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files, self.config.data, self.tokenizer, self.processor\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=self.config.data.get(\"dataloader_num_workers\", 8),\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=self.config.data.get(\"dataloader_num_workers\", 8),\n            shuffle=False,\n            drop_last=False,\n            
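# reuse the training collate_fn so tensor and non-tensor fields are batched identically at validation\n            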
collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, \"\n            f\"Size of val dataloader: {len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}\")\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort by input text\n\n        # Use fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_scores = []\n\n        for test_data in self.val_dataloader:\n            test_batch = DataProto.from_single_dict(test_data)\n\n            # repeat test batch\n            test_batch = test_batch.repeat(\n                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n            )\n\n            # we only do validation on rule-based rm\n            if self.config.reward_model.enable and test_batch[0].non_tensor_batch[\"reward_model\"][\"style\"] == \"model\":\n                return {}\n\n            # Store original inputs\n            input_ids = test_batch.batch[\"input_ids\"]\n            # TODO: Can we keep special tokens except for padding tokens?\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            sample_inputs.extend(input_texts)\n\n            batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n            non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n            if 
\"multi_modal_inputs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.extend([\"multi_modal_data\", \"multi_modal_inputs\"])\n            if \"raw_prompt\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n            if \"tools_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n            test_gen_batch = test_batch.pop(\n                batch_keys=batch_keys_to_pop,\n                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n            )\n\n            test_gen_batch.meta_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,\n                \"validate\": True,\n            }\n            print(f\"test_gen_batch meta info: {test_gen_batch.meta_info}\")\n\n            # pad to be divisible by dp_size\n            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)\n            if not self.async_rollout_mode:\n                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)\n            else:\n                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)\n\n            # unpad\n            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)\n            print(\"validation generation end\")\n\n            # Store generated outputs\n            output_ids = test_output_gen_batch.batch[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            test_batch = test_batch.union(test_output_gen_batch)\n\n            # evaluate using reward_function\n            result = self.val_reward_fn(test_batch, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            reward_extra_infos_dict[\"reward\"].extend(scores)\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n\n            data_source_lst.append(test_batch.non_tensor_batch.get(\"data_source\", [\"unknown\"] * reward_tensor.shape[0]))\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n        print(f\"DEBUG: Data sources shape: {data_sources.shape}\")  # Added Print\n        print(f\"DEBUG: 
reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}\")  # Added Print\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)\n        print(\n            f\"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}\"\n        )  # Added Print\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Init resource pool and worker group\"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=\"actor_rollout\",\n            )\n            self.resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role=\"ref\"\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different\n        # parallel size,\n        # you should not use `create_colocated_worker_cls`. 
Instead, directly pass different resource pool to\n        # different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        self.wg_dicts = []\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n            # keep a reference to WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699\n            self.wg_dicts.append(wg_dict)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimate of kv cache memory\n        self.actor_rollout_wg = all_wg[\"actor_rollout\"]\n        self.actor_rollout_wg.init_model()\n\n    def _save_checkpoint(self):\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and \"\n                \"max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n        )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, \"critic\")\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else 
os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"critic\")\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = os.path.join(global_step_folder, \"critic\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader,\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            
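# StatefulDataLoader (torchdata) restores the iterator/sampler position from\n            # this state dict, so a resumed run continues from the first batch that was\n            # not yet consumed rather than replaying the epoch from the start.\n            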
self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n        # Return the resumed step (0 when starting fresh) so callers such as fit_dpo\n        # do not clobber the global step that was just restored.\n        return self.global_steps\n\n    def _balance_batch(self, batch: DataProto, metrics, logging_prefix=\"global_seqlen\"):\n        \"\"\"Reorder the data on single controller such that each dp rank gets similar total tokens\"\"\"\n        attention_mask = batch.batch[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = batch.batch[\"attention_mask\"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)\n        world_size = self.actor_rollout_wg.world_size\n        global_partition_lst = get_seqlen_balanced_partitions(\n            global_seqlen_lst, k_partitions=world_size, equal_size=True\n        )\n        # reorder based on index. The data will be automatically equally partitioned by dispatch function\n        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])\n        batch.reorder(global_idx)\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n\n    def fit_dpo(self):  # Online DPO counterpart of the standard PPO fit() loop\n        \"\"\"\n        The training loop of Online DPO using a periodically updated reference model.\n        The driver process calls worker groups for computation.\n        Advantage computation is replaced by DPO preference logic.\n        \"\"\"\n        import traceback  # used by the error handlers below\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        # Initialize logger\n        logger = None\n        try:\n            logger = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n                config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),\n            )\n        except Exception as e:\n            print(f\"Warning: Failed to initialize logger: {e}\")\n\n        self.global_steps = 0\n        # Load checkpoint before doing anything\n        loaded_step = self._load_checkpoint()\n        self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1\n        print(\n            f\"Starting Online DPO training from global step {self.global_steps}. \"\n            f\"Total steps: {self.total_training_steps}\"\n        )\n        print(f\"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}\")\n\n        # Check if reference policy is configured correctly for this mode\n        if not self.use_reference_policy:\n            print(\n                \"WARNING: 'use_reference_policy' is False. Periodic reference model update requires a \"\n                \"reference policy worker. 
DPO updates might fail or use incorrect logic.\"\n            )\n            # Consider raising an error if strict adherence is required:\n            # raise ValueError(\"Periodic reference model update requires 'use_reference_policy' to be True \"\n            #                  \"and a configured reference worker.\")\n\n        # Perform validation before training\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            print(\"Running validation before Online DPO training...\")\n            val_metrics = self._validate()\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            if logger and val_metrics:\n                logger.log(data=val_metrics, step=max(0, self.global_steps - 1))\n            if self.config.trainer.get(\"val_only\", False):\n                print(\"Validation only mode enabled. Exiting training.\")\n                if logger and hasattr(logger, \"finish\"):\n                    logger.finish()\n                return\n\n        # Add tqdm progress bar\n        progress_bar = tqdm(\n            total=self.total_training_steps,\n            initial=self.global_steps,\n            desc=\"Online DPO Training Progress\",\n            position=0,\n            leave=True,\n        )\n\n        last_val_metrics = None\n        should_stop = False\n\n        for epoch in range(self.config.trainer.total_epochs):\n            if should_stop:\n                break\n            print(f\"--- Starting Online DPO Epoch {epoch} ---\")\n            try:\n                train_iterator = iter(self.train_dataloader)\n            except TypeError:\n                print(\"Warning: Dataloader is not iterable.\")\n                train_iterator = self.train_dataloader  # Fallback attempt\n\n            for batch_idx, batch_dict in enumerate(train_iterator):\n                if self.global_steps > self.total_training_steps:\n                    should_stop = True\n                    break\n\n                metrics = {}\n                timing_raw = {}\n                step_timer = Timer(logger=None)\n                ref_log_prob_computed = False  # Flag to track if ref log probs were computed\n\n                try:  # Outer try-except for the whole step\n                    step_timer.start()\n                    with _timer(\"step\", timing_raw):\n                        batch: DataProto = DataProto.from_single_dict(batch_dict)\n                        current_batch_size = batch.batch.batch_size[0]\n                        print(\n                            f\"\\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: \"\n                            f\"{current_batch_size}\"\n                        )\n\n                        # --- Reference Model Update ---\n                        ref_update_freq = self.config.trainer.get(\"ref_update_freq\", -1)\n                        if (\n                            self.use_reference_policy\n                            and ref_update_freq > 0\n                            and self.global_steps % ref_update_freq == 0\n                        ):\n                            print(f\"\\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...\")\n                            try:\n                                # --- This requires careful implementation with FSDP ---\n                                # 1. 
Save actor state dict (potentially to CPU memory or disk)\n                                #    This needs to be done collectively across actor worker ranks.\n                                #    The checkpoint_manager might be adaptable, or use FSDP APIs directly.\n                                #    Example placeholder using a conceptual save/load mechanism:\n                                actor_state_path = \"/tmp/actor_state_mid\"  # Temporary path\n                                self.actor_rollout_wg.save_checkpoint(actor_state_path)  # Adapt save logic\n\n                                # 2. Load the state dict onto the reference model worker group\n                                #    This also needs collective loading on the ref worker ranks.\n                                self.ref_policy_wg.load_checkpoint(actor_state_path, None, True)  # Adapt load logic\n\n                                print(f\"[Step {self.global_steps}] Reference Model Weights Updated.\")\n                                # Optionally remove the temporary state file\n                                # os.remove(actor_state_path) # Needs rank-aware removal or shared storage\n\n                            except Exception as sync_e:\n                                print(f\"ERROR during reference model sync at step {self.global_steps}: {sync_e}\")\n                                traceback.print_exc()\n\n                        # Pop keys for generation\n                        pop_batch_keys = [\"input_ids\", \"attention_mask\"]\n                        if \"position_ids\" in batch.batch:\n                            pop_batch_keys.append(\"position_ids\")\n                        pop_non_tensor_keys = [\"raw_prompt_ids\"] if \"raw_prompt_ids\" in batch.non_tensor_batch else []\n                        if \"multi_modal_inputs\" in batch.non_tensor_batch.keys():\n                            pop_non_tensor_keys.extend([\"multi_modal_data\", \"multi_modal_inputs\"])\n                        original_non_tensor_data = batch.non_tensor_batch\n                        gen_batch = batch.pop(\n                            batch_keys=pop_batch_keys,\n                            non_tensor_batch_keys=pop_non_tensor_keys,\n                        )\n                        gen_batch = gen_batch.repeat(\n                            repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True\n                        )\n                        # (Add Debug prints for gen_batch if needed)\n\n                        # Generate sequences (chosen/rejected pairs)\n                        with _timer(\"gen\", timing_raw):\n                            try:\n                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                                # (Add Debug prints for gen_batch_output if needed)\n                            except Exception as gen_e:\n                                print(f\"\\n!!!!!!!! 
ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!\")\n                                print(gen_e)\n                                traceback.print_exc()\n                                print(\"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\")\n                                step_timer.stop()\n                                continue\n\n                        # Combine original prompts with generated sequences\n                        batch.non_tensor_batch = original_non_tensor_data  # Restore non-tensor data\n                        batch.non_tensor_batch[\"uid\"] = np.array(\n                            [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object\n                        )\n                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                        batch = batch.union(gen_batch_output)\n                        # (Add Debug prints after union if needed)\n\n                        # Compute response mask (needed for ref logprob calc and DPO prep)\n                        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n\n                        if self.config.trainer.balance_batch:\n                            self._balance_batch(batch, metrics=metrics)\n\n                        batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                        # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef\n                        # fallback) ---\n                        # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed\n                        #       unless used for other metrics or a fallback. 
Keep it for now.\n                        with _timer(\"policy_log_prob\", timing_raw):\n                            policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)\n                            batch = batch.union(policy_log_prob_output)  # Adds 'old_log_probs'\n                            # (Debug prints for old_log_probs)\n\n                        # --- Compute Log Probs using the EXTERNAL Reference Model ---\n                        if self.use_reference_policy:\n                            with _timer(\"ref_log_prob_dpo\", timing_raw):\n                                # print(f\"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----\")\n                                try:\n                                    # 'batch' contains interleaved chosen/rejected sequences\n                                    ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(\n                                        batch\n                                    )  # Returns DataProto with 'ref_log_prob'\n                                    batch = batch.union(\n                                        ref_log_prob_output\n                                    )  # Adds 'ref_log_prob' key [batch_size * n, seq_len]\n                                    ref_log_prob_computed = True  # Mark success\n                                    # print(f\"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: \"\n                                    #       f\"{batch.batch['ref_log_prob'].shape} ----\")\n                                except Exception as ref_e:\n                                    print(f\"ERROR computing reference log probs at step {self.global_steps}: {ref_e}\")\n                                    traceback.print_exc()\n                                    # Do not write a None tensor into the TensorDict; the\n                                    # ref_log_prob_computed flag below is the failure signal checked downstream.\n                                    ref_log_prob_computed = False\n                        else:\n                            print(\n                                \"Warning: Skipping external reference log prob calculation as use_reference_policy \"\n                                \"is False.\"\n                            )\n                            # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor\n\n                        # --- Compute Rewards/Scores (used to determine preference) ---\n                        with _timer(\"reward_calc\", timing_raw):\n                            # (Reward calculation logic using RM or reward_fn as before)\n                            # ... Ensure this calculates 'token_level_rewards' or similar ...\n                            if self.use_rm:\n                                reward_tensor_rm = self.rm_wg.compute_rm_score(batch)\n                                batch = batch.union(reward_tensor_rm)  # Adds 'rm_scores'\n\n                            reward_extra_infos_dict = {}\n                            try:\n                                if self.reward_fn is None:\n                                    #  print(f\"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! \"\n                                    #        f\"Using dummy rewards. 
----\")\n                                    # Use rm_scores if available, otherwise zeros\n                                    reward_tensor = batch.batch.get(\n                                        \"rm_scores\", torch.zeros_like(batch.batch[\"response_mask\"], dtype=torch.float32)\n                                    )\n                                else:\n                                    reward_result = self.reward_fn(batch, return_dict=True)\n                                    reward_tensor = reward_result[\"reward_tensor\"]  # Final combined reward\n                                    reward_extra_infos_dict = reward_result.get(\"reward_extra_info\", {})\n\n                            except Exception:\n                                # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '\n                                #       f'Using dummy rewards. ----')\n                                traceback.print_exc()\n                                reward_tensor = torch.zeros_like(batch.batch[\"response_mask\"], dtype=torch.float32)\n                                reward_extra_infos_dict = {}\n\n                            # Use 'token_level_rewards' as the key for preference calculation\n                            batch.batch[\"token_level_rewards\"] = reward_tensor\n                            if reward_extra_infos_dict:\n                                batch.non_tensor_batch.update(\n                                    {k: np.array(v) for k, v in reward_extra_infos_dict.items()}\n                                )\n\n                        # --- Determine Preferences ---\n                        # Uses 'token_level_rewards' to determine chosen/rejected based on score\n                        batch = compute_onlineDPO_pref(batch)  # Adds 'preferences' key\n\n                        # --- Prepare DPO Batch ---\n                        dpo_update_batch_proto = None  # Initialize\n                        with _timer(\"prepare_dpo_batch\", timing_raw):\n                            try:\n                                if \"preferences\" not in batch.batch or batch.batch[\"preferences\"] is None:\n                                    raise ValueError(\"'preferences' key missing or None after compute_onlineDPO_pref.\")\n\n                                # Check if reference log probs were computed successfully (if needed)\n                                if self.use_reference_policy and not ref_log_prob_computed:\n                                    raise ValueError(\"Reference log probs required but failed to compute.\")\n\n                                # Check required base keys\n                                required_keys = [\"input_ids\", \"attention_mask\", \"response_mask\"]\n                                for rk in required_keys:\n                                    if rk not in batch.batch or batch.batch[rk] is None:\n                                        raise KeyError(f\"Required key '{rk}' missing from batch for DPO prep.\")\n\n                                preferences_mask = batch.batch[\"preferences\"]  # Shape [batch_size * n]\n                                not_preferences_mask = ~preferences_mask\n\n                                # Gather Chosen/Rejected Base Tensors\n                                chosen_input_ids = batch.batch[\"input_ids\"][preferences_mask]\n                                chosen_attention_mask = batch.batch[\"attention_mask\"][preferences_mask]\n                                rejected_input_ids = 
batch.batch[\"input_ids\"][not_preferences_mask]\n                                rejected_attention_mask = batch.batch[\"attention_mask\"][not_preferences_mask]\n                                chosen_position_ids = (\n                                    batch.batch.get(\"position_ids\")[preferences_mask]\n                                    if \"position_ids\" in batch.batch\n                                    else None\n                                )\n                                rejected_position_ids = (\n                                    batch.batch.get(\"position_ids\")[not_preferences_mask]\n                                    if \"position_ids\" in batch.batch\n                                    else None\n                                )\n\n                                # Create Labels\n                                print(\"WARNING: Creating DPO labels using configured max_prompt_length...\")\n                                prompt_len = self.config.data.max_prompt_length\n                                chosen_labels = chosen_input_ids.clone()\n                                chosen_labels[:, :prompt_len] = -100\n                                rejected_labels = rejected_input_ids.clone()\n                                rejected_labels[:, :prompt_len] = -100\n\n                                # Calculate and Gather Reference Log Probs (Sequence Level)\n                                if self.use_reference_policy:\n                                    ref_log_prob_tensor = batch.batch[\"ref_log_prob\"]  # Token level [bsz * n, seq_len]\n                                    response_mask_full = batch.batch[\n                                        \"response_mask\"\n                                    ]  # Response mask [bsz * n, seq_len]\n                                    ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(\n                                        dim=-1\n                                    )  # Sequence level [bsz * n]\n                                    reference_chosen_logps = ref_sequence_logps[preferences_mask]\n                                    reference_rejected_logps = ref_sequence_logps[not_preferences_mask]\n                                else:\n                                    # If not using external ref, DPO needs ActorAsRef logic in dp_actor\n                                    # We won't add the keys here, dp_actor will handle it (or fail if not modified)\n                                    print(\n                                        \"Info: Not adding explicit reference logps to DPO batch \"\n                                        \"(use_reference_policy=False).\"\n                                    )\n                                    reference_chosen_logps = None  # Explicitly None\n                                    reference_rejected_logps = None\n\n                                # Package Tensors\n                                dpo_tensors = {\n                                    \"chosen_input_ids\": chosen_input_ids,\n                                    \"chosen_attention_mask\": chosen_attention_mask,\n                                    \"chosen_labels\": chosen_labels,\n                                    \"rejected_input_ids\": rejected_input_ids,\n                                    \"rejected_attention_mask\": rejected_attention_mask,\n                                    \"rejected_labels\": rejected_labels,\n                                }\n                                # 
Conditionally add reference logps if computed\n                                if reference_chosen_logps is not None:\n                                    dpo_tensors[\"reference_chosen_logps\"] = reference_chosen_logps\n                                if reference_rejected_logps is not None:\n                                    dpo_tensors[\"reference_rejected_logps\"] = reference_rejected_logps\n                                # Add position ids if they exist\n                                if chosen_position_ids is not None:\n                                    dpo_tensors[\"chosen_position_ids\"] = chosen_position_ids\n                                if rejected_position_ids is not None:\n                                    dpo_tensors[\"rejected_position_ids\"] = rejected_position_ids\n\n                                # Prepare Meta Info\n                                dpo_meta = {\n                                    \"dpo_beta\": OmegaConf.select(self.config.algorithm, \"dpo_beta\", default=0.1),\n                                    \"dpo_loss_type\": OmegaConf.select(\n                                        self.config.algorithm, \"dpo_loss_type\", default=\"sigmoid\"\n                                    ),\n                                    \"dpo_label_smoothing\": OmegaConf.select(\n                                        self.config.algorithm, \"dpo_label_smoothing\", default=0.0\n                                    ),\n                                    \"use_reference_policy\": self.use_reference_policy,\n                                    \"reference_free\": not self.use_reference_policy,  # False if using external ref\n                                    \"global_step\": self.global_steps,\n                                }\n\n                                dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)\n                                # print(f\"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----\")\n                                # print(f\"  Keys: {list(dpo_update_batch_proto.batch.keys())}\")\n                                # print(f\"  Meta Info: {dpo_meta}\")\n\n                            except Exception as e_prep:\n                                print(f\"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}\")\n                                traceback.print_exc()\n                                dpo_update_batch_proto = None  # Skip update on error\n\n                        # --- Actor Update Step ---\n                        actor_output = None\n                        if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto:\n                            with _timer(\"update_actor\", timing_raw):\n                                # Pass the batch containing reference log probs (if computed)\n                                # The modified update_actor_dpo expects them if reference_free=False\n                                actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)\n                            if actor_output and \"metrics\" in actor_output.meta_info:\n                                metrics.update(reduce_metrics(actor_output.meta_info[\"metrics\"]))\n                        elif dpo_update_batch_proto is None:\n                            print(\n                                f\"Skipping actor update at step {self.global_steps} due to DPO batch preparation error.\"\n                            )\n\n                    
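    # A minimal sketch (not necessarily this worker's exact implementation) of the\n                        # objective update_actor_dpo is expected to optimize for dpo_loss_type=\"sigmoid\":\n                        #   pi_c, pi_r   = policy sequence log-probs on chosen / rejected responses\n                        #   ref_c, ref_r = reference_chosen_logps / reference_rejected_logps\n                        #   loss = -log(sigmoid(dpo_beta * ((pi_c - ref_c) - (pi_r - ref_r))))\n                        # dpo_label_smoothing > 0 blends in the mirrored term with that weight.\n                    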
    # --- Validation and Saving ---\n                        test_freq = OmegaConf.select(self.config.trainer, \"test_freq\", default=-1)\n                        is_last_step = self.global_steps >= self.total_training_steps\n                        if (\n                            self.val_reward_fn is not None\n                            and test_freq > 0\n                            and (is_last_step or self.global_steps % test_freq == 0)\n                        ):\n                            print(f\"\\nRunning DPO validation at step {self.global_steps}...\")\n                            val_timing_raw = {}\n                            with _timer(\"testing\", val_timing_raw):\n                                val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                            if val_metrics:\n                                metrics[\"time/validation_run\"] = val_timing_raw.get(\"testing\", 0)\n                                metrics.update(val_metrics)\n                            else:\n                                print(\"Validation skipped or returned no metrics.\")\n\n                        save_freq = OmegaConf.select(self.config.trainer, \"save_freq\", default=-1)\n                        if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):\n                            print(f\"\\nSaving DPO checkpoint at step {self.global_steps}...\")\n                            with _timer(\"save_checkpoint\", timing_raw):\n                                self._save_checkpoint()  # Saves actor (and potentially critic if used elsewhere)\n                            metrics[\"time/save_checkpoint\"] = timing_raw.get(\"save_checkpoint\", 0)\n\n                    # --- End main step timer context ---\n\n                    # --- Metrics calculation AFTER the 'step' timer block ---\n                    metrics.update(compute_dpo_data_metrics(batch=batch))  # Use DPO-specific metrics\n                    metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                    n_gpus = self.resource_pool_manager.get_n_gpus()\n                    if \"step\" in timing_raw:\n                        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n                    else:\n                        print(\n                            f\"Warning: 'step' key missing from timing_raw at step {self.global_steps}. 
\"\n                            f\"Skipping throughput.\"\n                        )\n\n                    step_timer.stop()\n                    metrics[\"time/step\"] = step_timer.last\n\n                    # Log metrics\n                    log_freq = OmegaConf.select(self.config.trainer, \"log_freq\", default=1)\n                    if logger and self.global_steps % log_freq == 0:\n                        log_payload = metrics.copy()\n                        # Add learning rate to log payload\n                        if actor_output and \"actor/lr\" in metrics:\n                            log_payload[\"actor/lr\"] = metrics[\"actor/lr\"]\n\n                        print(f\"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}\")\n                        try:\n                            logger.log(data=log_payload, step=self.global_steps)\n                        except Exception as e:\n                            print(f\"Logging failed at step {self.global_steps}: {e}\")\n\n                    # Update progress bar\n                    postfix_metrics = {\n                        k: f\"{v:.3f}\" if isinstance(v, float) else v\n                        for k, v in metrics.items()\n                        if isinstance(v, int | float)\n                    }\n                    progress_bar.set_postfix(postfix_metrics)\n\n                except Exception as step_e:\n                    print(f\"\\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!\")\n                    print(f\"Caught Exception: {step_e}\")\n                    traceback.print_exc()\n                    print(\"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\")\n                    step_timer.stop()\n                    should_stop = True\n                    break\n\n                if is_last_step or should_stop:\n                    print(f\"Stopping DPO training at step {self.global_steps}.\")\n                    break\n\n                self.global_steps += 1\n                progress_bar.update(1)\n\n            # End of epoch handling\n            if hasattr(self.train_dataloader, \"reset\"):\n                try:\n                    self.train_dataloader.reset()\n                except Exception as e:\n                    print(f\"Warning: Failed to reset train dataloader state: {e}\")\n            if should_stop:\n                break\n\n        # --- Final cleanup and logging ---\n        progress_bar.close()\n        final_step = max(0, self.global_steps - 1)\n        print(f\"Online DPO Training finished at step {final_step}.\")\n        # Save final checkpoint\n        save_freq = OmegaConf.select(self.config.trainer, \"save_freq\", default=-1)\n        if not self.config.trainer.get(\"val_only\", False) and (save_freq <= 0 or final_step % save_freq != 0):\n            print(f\"Saving final DPO checkpoint at step {final_step}...\")\n            self._save_checkpoint()\n\n        # Final validation run\n        if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get(\"val_only\", False):\n            print(\"Running final validation...\")\n            last_val_metrics = self._validate()\n            if last_val_metrics and logger:\n                last_val_metrics[\"final_validation\"] = True\n                try:\n                    logger.log(data=last_val_metrics, step=final_step)\n                except Exception as e:\n                    print(f\"[Final Val Metrics Log Error]: {e}\")\n\n        
pprint(f\"Final validation metrics: {last_val_metrics}\")\n        if logger and hasattr(logger, \"finish\"):\n            logger.finish()\n        print(\"Online DPO Training Run Complete.\")\n"
  },
  {
    "path": "verl_rl/recipe/sppo/README.md",
    "content": "# SPPO: Self-Play Preference Optimization for Language Model Alignment\n\nThis repository hosts the community implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets.\n\nPaper Authors: [Yue Wu](https://yuewu.us/)\\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n\nverl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20)\n\n[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)]\n\n## Reproduce the Experiment\n\nWe evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework.\n\n```\ngit clone git@github.com:volcengine/verl.git\ncd verl\npython3 -m uv pip install -e \".[sglang]\"\n\nexport WANDB_API_KEY=<YOUR_WANDB_API_KEY>\n\npython3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct\n\nexport CUDA_VISIBLE_DEVICES=0,1,2,3\nbash recipe/sppo/run_qwen2.5-7b_rm.sh\n```\n\nNote that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running:\n\n```bash\npython3 -m uv pip install wheel\npython3 -m uv pip install packaging\npython3 -m uv pip install flash-attn --no-build-isolation --no-deps\n```\n\n## Acknowledgement\n\nWe sincerely thank the contribution and guidance from:\n\n- [Yue Wu](https://yuewu.us/)\n- [Chendong Wang](https://cdwang96.github.io/)\n- [Yifan Zhang](https://github.com/yifanzhang-pro)\n- [Yongan Xiang](https://github.com/BearBiscuit05)\n- [Junrong Lin](https://github.com/ocss884)\n- [Yuxuan Tong](https://github.com/tongyx361)\n- [Guangming Shen](https://github.com/PeterSH6)\n- [Biao He](https://www.linkedin.com/in/biao-he/)\n- [Qingquan Song](https://qingquansong.github.io/)\n- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)\n"
  },
  {
    "path": "verl_rl/recipe/sppo/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/recipe/sppo/config/sppo_trainer.yaml",
    "content": "# the sppo config will override default ppo_trainer.yaml\n\nhydra:\n  searchpath:\n    - file://verl/trainer/config\n\ndefaults:\n  - ppo_trainer\n  - _self_\n\nactor_rollout_ref:\n  actor:\n    sppo_eta: 1.0\n    optim:\n      lr_warmup_steps: 15\n  rollout:\n    name: sglang\n    tensor_model_parallel_size: 2\n    gpu_memory_utilization: 0.5\n    val_kwargs:\n      n: 2  # 2 will trigger validation, 1 will bypass\n\nalgorithm:\n  adv_estimator: null\n  sppo_eta: 1.0\n\ntrainer:\n  log_val_generations: 0"
  },
  {
    "path": "verl_rl/recipe/sppo/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, kl_penalty\nfrom verl.utils.device import get_device_id\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import rearrange_micro_batches\nfrom verl.workers.actor.dp_actor import DataParallelPPOActor\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef compute_sppo_loss(\n    old_log_prob: torch.Tensor,  # (bs, seq_len)\n    log_prob: torch.Tensor,  # (bs, seq_len)\n    rewards: torch.Tensor,  # (bs,)\n    response_mask: torch.Tensor,  # (bs, seq_len)\n    eta: float = 1.0,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    SPPO Loss computation.\n    \"\"\"\n    # Compute log-ratios over masked tokens\n    log_prob_sum = (log_prob * response_mask).sum(dim=1)  # (bs,)\n    old_log_prob_sum = (old_log_prob * response_mask).sum(dim=1)  # (bs,)\n    log_ratios = log_prob_sum - old_log_prob_sum  # (bs,)\n\n    scaled_rewards = eta * (rewards)\n    loss_vec = (log_ratios - scaled_rewards) ** 2  # (bs,)\n\n    if loss_agg_mode == \"token-mean\":\n        sample_mask = response_mask.any(dim=1).float()  # (bs,)\n        loss = verl_F.masked_mean(loss_vec, sample_mask)\n\n    return loss, log_ratios, scaled_rewards\n\n\nclass DataParallelSPPOActor(DataParallelPPOActor):\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def update_policy(self, data: DataProto):\n        # make sure we are in training mode\n        self.actor_module.train()\n\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid slient error\n        multi_turn = data.meta_info.get(\"multi_turn\", False)\n\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\", \"old_log_probs\", \"seq_level_rewards\"]\n        if multi_turn:\n            select_keys.append(\"loss_mask\")\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        batch = data.select(batch_keys=select_keys).batch\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)\n        else:\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n        metrics = {}\n        for epoch in range(self.config.ppo_epochs):\n            for batch_idx, data in enumerate(dataloader):\n                # split batch into micro_batches\n                mini_batch = data\n                if has_multi_modal_inputs:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu\n                    micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n                elif self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    # split batch into micro_batches\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.actor_optimizer.zero_grad()\n\n                for data in micro_batches:\n                    # Support all hardware\n                    if isinstance(data, DataProto):\n                        data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}\n                    else:\n                        data = data.to(get_device_id())  # actor device is cpu when using offload\n                    responses = data[\"responses\"]\n                    response_length = responses.size(1)\n                    attention_mask = data[\"attention_mask\"]\n                    if multi_turn:\n                        response_mask = data[\"loss_mask\"][:, -response_length:]\n                    else:\n                        response_mask = attention_mask[:, -response_length:]\n\n                    old_log_prob = data[\"old_log_probs\"]\n                    rewards = data[\"seq_level_rewards\"]\n\n                    entropy_coeff = self.config.entropy_coeff\n                    loss_agg_mode = self.config.loss_agg_mode\n                    eta = self.config.get(\"sppo_eta\", 1.0)\n\n                    # both returned tensors: (bsz, response_length)\n                    calculate_entropy = False\n                    if entropy_coeff != 0:\n                        calculate_entropy = True\n                    entropy, log_prob = self._forward_micro_batch(\n                        micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy\n                    )\n\n                    pg_loss, log_ratios, preference = compute_sppo_loss(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        rewards=rewards,\n                        response_mask=response_mask,\n                        eta=eta,\n                        
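# NOTE: token-mean is the only aggregation mode compute_sppo_loss implements\n                        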
loss_agg_mode=loss_agg_mode,\n                    )\n\n                    if entropy_coeff != 0:\n                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        # compute policy loss\n                        policy_loss = pg_loss - entropy_loss * entropy_coeff\n                    else:\n                        policy_loss = pg_loss\n\n                    if self.config.use_kl_loss:\n                        ref_log_prob = data[\"ref_log_prob\"]\n                        # compute kl loss\n                        kld = kl_penalty(\n                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type\n                        )\n                        kl_loss = agg_loss(\n                            loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode\n                        )\n\n                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                        metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n                        metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                    if self.config.use_dynamic_bsz:\n                        # scale the loss relative to the dynamic batch size\n                        loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)\n                    else:\n                        loss = policy_loss / self.gradient_accumulation\n                    loss.backward()\n\n                    data = {\n                        \"actor/loss\": loss.detach().item(),\n                        \"actor/log_ratio_mean\": log_ratios.mean().detach().item(),\n                        \"actor/preference_mean\": preference.mean().detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n\n                grad_norm = self._optimizer_step()\n                data = {\"actor/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, data)\n        self.actor_optimizer.zero_grad()\n        return metrics\n"
  },
  {
    "path": "verl_rl/recipe/sppo/main_sppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\n\nimport hydra\nimport ray\n\nfrom verl.trainer.ppo.reward import load_reward_manager\n\nfrom .sppo_ray_trainer import RaySPPOTrainer\n\n\n@hydra.main(config_path=\"config\", config_name=\"sppo_trainer\", version_base=None)\ndef main(config):\n    run_ppo(config)\n\n\ndef run_ppo(config) -> None:\n    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices\n    # isolation, will solve in the future\n    os.environ[\"ENSURE_CUDA_VISIBLE_DEVICES\"] = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\n                \"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\", \"VLLM_LOGGING_LEVEL\": \"WARN\"}\n            },\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    def run(self, config):\n        # print initial config\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n        OmegaConf.resolve(config)\n\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n\n        # instantiate tokenizer\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none\n\n        # define worker classes\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n\n            from .sppo_worker import SPPOActorRolloutRefWorker  # , CriticWorker\n\n            actor_rollout_cls = SPPOActorRolloutRefWorker\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker\n\n            actor_rollout_cls = ActorRolloutRefWorker\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            
raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        # sppo does not use critic\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n        }\n\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n        }\n\n        # we should adopt a multi-source reward function here\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # - finally, we combine all the rewards together\n        # - The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # use reference model\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        trainer = RaySPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n        )\n        trainer.init_workers()\n        trainer.fit()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/recipe/sppo/run_qwen2.5-7b_rm.sh",
    "content": "# Discliamer: the model used in the script is only for academic purpose.\nset -x\n\n# Data preparation scripts are available in ``examples/data_preprocess``.\n# Example usage:\n#\n#   python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math\n#   python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\ngsm8k_train_path=$HOME/data/math/train.parquet\ngsm8k_test_path=$HOME/data/math/test.parquet\n\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\n# prepare model ckpt\nhuggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct &\n# huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 &\nwait\n\npython3 -m recipe.sppo.main_sppo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"$HOME/models/Qwen2.5-7B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang  \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='sppo-sglang' \\\n    trainer.val_before_train=True \\\n    trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=1 \\\n    trainer.total_epochs=1000 $@\n    # Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml\n    # The experiment will converge to 0.656 on MATH dataset after 20 epochs"
  },
  {
    "path": "verl_rl/recipe/sppo/sppo_ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nFSDP PPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport uuid\nfrom copy import deepcopy\nfrom pprint import pprint\nfrom typing import Optional\n\nimport numpy as np\nimport ray\nimport torch\nfrom torch.utils.data import Dataset, Sampler\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import agg_loss\nfrom verl.trainer.ppo.metric_utils import reduce_metrics\nfrom verl.trainer.ppo.ray_trainer import (\n    AdvantageEstimator,\n    RayPPOTrainer,\n    ResourcePoolManager,\n    Role,\n    WorkerType,\n    apply_kl_penalty,\n    compute_response_mask,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.utils.profiler.performance import simple_timer\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\n\ndef softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor:\n    \"\"\"\n    Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) )\n    Falls back to arithmetic mean when β=0.\n    \"\"\"\n    if beta == 0.0:\n        return x.mean(dim=dim, keepdim=keepdim)\n\n    # cast beta to tensor on same device/dtype\n    beta_t = x.new_tensor(beta)\n    # numerically-stable logsumexp(β x)\n    lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim)\n    n = x.size(dim)\n    log_n = x.new_tensor(n).log()\n\n    return (lse - log_n) / beta_t\n\n\ndef compute_advantage(data: DataProto, beta=1.0):\n    rewards = data.batch[\"token_level_rewards\"].sum(axis=-1)  # (bs, )\n    s_mean = softmean(rewards, beta, keepdim=True)  # (bs, )\n    rewards = rewards - s_mean  # (bs, )\n    data.batch[\"seq_level_rewards\"] = rewards  # (bs, )\n    return data\n\n\nclass RaySPPOTrainer(RayPPOTrainer):\n    \"\"\"\n    Note that this trainer runs on the driver process on a single CPU/GPU node.\n    \"\"\"\n\n    # TODO: support each role have individual ray_worker_group_cls,\n    # i.e., support different backend of different role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only the hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = Role.RefPolicy in role_worker_mapping\n        self.use_rm = Role.RewardModel in role_worker_mapping\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.validation_generations_logger = ValidationGenerationsLogger()\n        self.device_name = device_name if device_name else self.config.trainer.device\n\n        # define in-reward KL control\n        # kl loss control currently not supported\n        if config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n\n        self.use_critic = False\n\n        self._validate_config()\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the\n        worker group through RPC to construct the PPO dataflow.\n        The lightweight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n                non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n                if \"multi_modal_data\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n                if \"raw_prompt\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n                if \"tools_kwargs\" in batch.non_tensor_batch:\n                    
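# tool-calling metadata must travel with the prompts into generation\n                    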
non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n                gen_batch = batch.pop(\n                    batch_keys=batch_keys_to_pop,\n                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n                )\n                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with simple_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with simple_timer(\"gen\", timing_raw):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with simple_timer(\"gen_max\", timing_raw):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n\n                            batch = batch.union(gen_baseline_output)\n                            reward_baseline_tensor = self.reward_fn(batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                with simple_timer(\"reward\", timing_raw):\n                    # compute reward model score\n                    if self.use_rm:\n                        reward_tensor = self.rm_wg.compute_rm_score(batch)\n                        batch = batch.union(reward_tensor)\n\n                    if 
self.config.reward_model.launch_reward_fn_async:\n                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)\n                    else:\n                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                # recompute old_log_probs\n                with simple_timer(\"old_log_prob\", timing_raw):\n                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                    entropys = old_log_prob.batch[\"entropys\"]\n                    response_masks = batch.batch[\"response_mask\"]\n                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                    old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n                    metrics.update(old_log_prob_metrics)\n                    old_log_prob.batch.pop(\"entropys\")\n                    batch = batch.union(old_log_prob)\n\n                if self.use_reference_policy:\n                    # compute reference log_prob\n                    with simple_timer(\"ref\", timing_raw):\n                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                        batch = batch.union(ref_log_prob)\n\n                # compute values\n                if self.use_critic:\n                    with simple_timer(\"values\", timing_raw):\n                        values = self.critic_wg.compute_values(batch)\n                        batch = batch.union(values)\n\n                with simple_timer(\"adv\", timing_raw):\n                    # we combine with rule-based rm\n                    reward_extra_infos_dict: dict[str, list]\n                    if self.config.reward_model.launch_reward_fn_async:\n                        reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                    batch.batch[\"token_level_scores\"] = reward_tensor\n\n                    if reward_extra_infos_dict:\n                        batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n\n                    # compute rewards. 
apply_kl_penalty if available\n                    if self.config.algorithm.use_kl_in_reward:\n                        batch, kl_metrics = apply_kl_penalty(\n                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                        )\n                        metrics.update(kl_metrics)\n                    else:\n                        batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n                        batch.batch[\"seq_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                    beta = self.config.algorithm.sppo_eta\n                    batch = compute_advantage(batch, beta=beta)\n\n                # update critic\n                if self.use_critic:\n                    with simple_timer(\"update_critic\", timing_raw):\n                        critic_output = self.critic_wg.update_critic(batch)\n                    critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                    metrics.update(critic_output_metrics)\n\n                # implement critic warmup\n                if self.config.trainer.critic_warmup <= self.global_steps:\n                    # update actor\n                    with simple_timer(\"update_actor\", timing_raw):\n                        batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                        actor_output = self.actor_rollout_wg.update_actor(batch)\n                    actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                    metrics.update(actor_output_metrics)\n\n                # Log rollout generations if enabled\n                rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                if rollout_data_dir:\n                    with simple_timer(\"dump_rollout_generations\", timing_raw):\n                        print(batch.batch.keys())\n                        inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                        outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], skip_special_tokens=True)\n                        scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                        self._dump_generations(\n                            inputs=inputs,\n                            outputs=outputs,\n                            scores=scores,\n                            reward_extra_infos_dict=reward_extra_infos_dict,\n                            dump_path=rollout_data_dir,\n                        )\n\n                # validate\n                if (\n                    self.val_reward_fn is not None\n                    and self.config.trainer.test_freq > 0\n                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                ):\n                    with simple_timer(\"testing\", timing_raw):\n                        val_metrics: dict = self._validate()\n                        if is_last_step:\n                            last_val_metrics = val_metrics\n                    metrics.update(val_metrics)\n\n                if self.config.trainer.save_freq > 0 and (\n                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0\n                ):\n                    with simple_timer(\"save_checkpoint\", timing_raw):\n                        self._save_checkpoint()\n\n                # training metrics\n                
metrics.update(\n                    {\n                        \"training/global_step\": self.global_steps,\n                        \"training/epoch\": epoch,\n                    }\n                )\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                progress_bar.update(1)\n                self.global_steps += 1\n"
  },
  {
    "path": "verl_rl/recipe/sppo/sppo_worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nfrom omegaconf import open_dict\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass SPPOActorRolloutRefWorker(ActorRolloutRefWorker):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from .dp_actor import DataParallelSPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from omegaconf import OmegaConf\n\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            if self._is_actor:\n                optim_config = self.config.actor.optim\n                fsdp_config = self.config.actor.fsdp_config\n            else:\n                optim_config = None\n                fsdp_config = OmegaConf.create()\n            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (\n                self._build_model_optimizer(\n                    model_path=self.config.model.path,\n                    fsdp_config=fsdp_config,\n                    optim_config=optim_config,\n                    override_model_config=override_model_config,\n                    use_remove_padding=use_remove_padding,\n                    use_fused_kernels=use_fused_kernels,\n                    enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                    trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                    use_liger=self.config.model.get(\"use_liger\", False),\n                    role=\"actor\",\n                )\n            )\n\n            # get the original unwrapped module\n            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_param:\n       
         offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n                log_gpu_memory_usage(\"After offload actor model during init\", logger=logger)\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n        # load from checkpoint\n        if self._is_actor:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                self.config.actor.use_remove_padding = use_remove_padding\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = DataParallelSPPOActor(\n                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self.rollout, self.rollout_sharding_manager = self._build_rollout(\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n\n        if self._is_ref:\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n"
  },
  {
    "path": "verl_rl/requirements-npu.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\ndatasets\ndill\nhydra-core\nnumpy<2.0.0\npandas\npeft\npyarrow>=15.0.0\npybind11\npylatexenc\ntensordict>=0.8.0,<=0.9.1,!=0.9.0\ntransformers==4.52.4\nray==2.46.0\nwandb\nmathruler\ntorchdata\neinops\nqwen_vl_utils\ntorchvision==0.20.1\n"
  },
  {
    "path": "verl_rl/requirements.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\nclick==8.0.4\ndatasets\ndill\n# flash-attn\nhydra-core\nliger-kernel\nnumpy<2.0.0\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\npre-commit\nray==2.49.0\ntensordict>=0.8.0,<=0.9.1,!=0.9.0\ntorchdata\ntransformers\nvllm==0.8.5.post1\nopentelemetry-api>=1.26.0,<1.27.0\nopentelemetry-sdk>=1.26.0,<1.27.0\nopentelemetry-exporter-otlp-proto-grpc>=1.26.0,<1.27.0\nopentelemetry-exporter-otlp-proto-http>=1.26.0,<1.27.0\nwandb\npackaging>=20.0\nuvicorn\nfastapi\nlatex2sympy2_extended\nmath_verify"
  },
  {
    "path": "verl_rl/requirements_sglang.txt",
    "content": "# requirements.txt records the full set of dependencies for development\naccelerate\ncodetiming\ndatasets\ndill\nflash-attn\nhydra-core\nnumpy<2.0.0\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\nray[default]>=2.10\ntensordict>=0.8.0,<=0.9.1,!=0.9.0\ntorchdata\ntorchvision\ntransformers\nwandb\nsglang[all]==0.4.6.post5\ntorch-memory-saver>=0.0.5\nhuggingface_hub\n"
  },
  {
    "path": "verl_rl/scripts/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/scripts/converter_hf_to_mcore.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, ContextManager, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom accelerate import init_empty_weights\nfrom megatron.core import dist_checkpointing\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.dist_checkpointing.mapping import ShardedTensor\nfrom megatron.core.dist_checkpointing.serialization import StrictHandling\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed\nfrom transformers import AutoConfig\n\nfrom verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards\nfrom verl.models.mcore import hf_to_mcore_config\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.megatron_utils import get_model\n\n\ndef _init_args():\n    \"\"\"\n    Examples:\n\n    1. single rank conversion for any model:\n        > python converter_hf_to_mcore.py --hf_model_path %{hf_model} --output_path ${output_path}\n    2. 
distributed conversion for DeepseekV3 671B:\n        > torchrun --nproc_per_node 1 --nnodes 4 --node_rank ${RANK} converter_hf_to_mcore.py \\\n          --hf_model_path ${hf_model} --output_path ${output_path}\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hf_model_path\", type=str, required=True, help=\"The path for the huggingface model\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"The path for the output mcore model\")\n    parser.add_argument(\"--use_cpu_initialization\", action=\"store_true\", help=\"Whether to use cpu initialization\")\n    parser.add_argument(\"--test\", action=\"store_true\", help=\"Whether to test the conversion\")\n    parser.add_argument(\"--trust_remote_code\", action=\"store_true\", help=\"Whether to trust remote code\")\n    args = parser.parse_args()\n    return args\n\n\ndef test_conversion(megatron_model_provider, tfconfig, output_path, model):\n    ########### test ###########\n    # load model\n    model_test = get_model(\n        model_provider_func=megatron_model_provider,\n        model_type=ModelType.encoder_or_decoder,\n        wrap_with_ddp=True,\n        transformer_config=tfconfig,\n    )\n    ref_state_dict = model_test[0].module.sharded_state_dict()\n    dist_checkpointing.load(ref_state_dict, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)\n\n    dut_state_dict = model[0].module.state_dict()\n    for name in dut_state_dict.keys():\n        if dut_state_dict[name] is None:\n            print(f\"[Warning] {name} is none in dut_state_dict\")\n            continue\n        dut_data = dut_state_dict[name].data\n        if name in ref_state_dict:\n            ref_data = ref_state_dict[name]\n            if isinstance(ref_data, ShardedTensor):\n                ref_data = ref_data.data.view(ref_data.local_shape)\n            else:\n                ref_data = ref_data.data\n            assert dut_data.shape == ref_data.shape, f\"{name=} {dut_data.shape=} {ref_data.shape=}\"\n            assert (dut_data == ref_data).all(), f\"{name} is not equal\"\n            print(f\"{name} is equal\")\n        else:\n            print(f\"[Warning] {name} is not in ref_state_dict\")\n    for name in ref_state_dict.keys():\n        if ref_state_dict[name] is None:\n            print(f\"[Warning] {name} is none in ref_state_dict\")\n            continue\n        ref_data = ref_state_dict[name]\n        if isinstance(ref_data, ShardedTensor):\n            ref_data = ref_data.data.view(ref_data.local_shape)\n        else:\n            ref_data = ref_data.data\n        if name in dut_state_dict:\n            dut_data = dut_state_dict[name].data\n            assert dut_data.shape == ref_data.shape, f\"{name=} {dut_data.shape=} {ref_data.shape=}\"\n            assert (dut_data == ref_data).all(), f\"{name} is not equal\"\n            print(f\"{name} is equal\")\n        else:\n            print(f\"[Warning] {name} is not in dut_state_dict\")\n    print(\"Conversion test passed!\")\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron(\n    hf_model, model, hf_config, layer_start_end: Optional[tuple[int, int]] = None\n):\n    if layer_start_end is None:\n        layer_start_end = (0, len(model.decoder.layers))\n    layer_start, layer_end = layer_start_end\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    numel = 0\n\n    num_attention_heads = hf_config.num_attention_heads\n    num_key_value_heads = 
hf_config.num_key_value_heads\n    hidden_dim = hf_config.hidden_size\n    head_dim = getattr(hf_config, \"head_dim\", hidden_dim // num_attention_heads)\n    if num_attention_heads != num_key_value_heads:\n        print(\"[WARNING] Converting GQA model\")\n    has_qkv_bias = getattr(hf_config, \"qkv_bias\", False) or getattr(hf_config, \"attention_bias\", False)\n    has_share_expert = getattr(hf_config, \"shared_expert_intermediate_size\", None)\n    if pp_rank == 0:\n        numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)\n\n    assert len(model.decoder.layers) == (layer_end - layer_start), (\n        f\"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}\"\n    )\n    for layer_idx, (layer, hf_layer) in enumerate(\n        zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)\n    ):\n        global_layer_idx = layer_idx + layer_start\n        numel_cur = numel\n        numel += safe_copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight)\n\n        q = hf_layer.self_attn.q_proj.weight.view(\n            [num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]\n        )\n        k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1])\n        v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1])\n        qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()\n        numel += safe_copy(qkv, layer.self_attention.linear_qkv.weight)\n\n        if has_qkv_bias:\n            q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1])\n            k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1])\n            v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])\n            qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()\n            numel += safe_copy(qkv_bias, layer.self_attention.linear_qkv.bias)\n\n        if hasattr(hf_layer.self_attn, \"q_norm\"):\n            numel += safe_copy(hf_layer.self_attn.q_norm.weight.data, layer.self_attention.q_layernorm.weight)\n            numel += safe_copy(hf_layer.self_attn.k_norm.weight.data, layer.self_attention.k_layernorm.weight)\n\n        numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)\n        numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)\n\n        numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)\n\n        for idx, hf_expert in enumerate(hf_layer.mlp.experts):\n            fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n            numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f\"weight{idx}\"])\n            numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f\"weight{idx}\"])\n\n        if has_share_expert:\n            numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight)\n            shared_fc1_weight = torch.cat(\n                [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]\n            )\n            numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)\n            numel += safe_copy(hf_layer.mlp.shared_expert.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)\n        print(f\"{pp_rank=} {global_layer_idx=} 
{layer_idx=} {numel=} numel this layer={numel - numel_cur}\")\n\n    if pp_rank == pp_size - 1:\n        numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)\n        numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)\n    return numel\n\n\ndef safe_copy(\n    src_tensor: torch.Tensor,\n    dst_tensor: torch.Tensor,\n    skip_dtype_assert: bool = False,\n):\n    if not skip_dtype_assert:\n        if src_tensor.dtype != dst_tensor.dtype:\n            raise ValueError(f\"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}\")\n    assert src_tensor.shape == dst_tensor.shape\n    dst_tensor.data.copy_(src_tensor.data)\n    return src_tensor.numel()\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config):\n    mgmodel = mgmodel.bfloat16()\n    hfmodel = hfmodel.bfloat16()\n    num_attention_heads = hf_config.num_attention_heads\n    num_query_groups = hf_config.num_key_value_heads\n    hidden_size = hf_config.hidden_size\n    head_dim = hidden_size // num_attention_heads\n\n    # 1. vision model\n    hfvision = hfmodel.visual\n    mgvision = mgmodel.vision_model\n    vision_hidden_size = mgvision.config.hidden_size\n    vision_num_query_groups = mgvision.config.num_query_groups\n    vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads\n    copied_numel = 0\n    safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq)\n    copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight)\n    for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True):\n        # norm1 --> linear_qkv.norm\n        copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight)\n        # norm2 --> mlp.linear_fc1.norm\n        copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight)\n        # qkv --> self_attention.linear_qkv\n        converted_weight = (\n            hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size)\n            .transpose(0, 1)\n            .flatten(1, 2)\n            .reshape(-1, vision_hidden_size)\n            .contiguous()\n        )\n        copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight)\n        converted_bias = (\n            hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1)\n            .transpose(0, 1)\n            .flatten(1, 2)\n            .view(-1)\n            .contiguous()\n        )\n        copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias)\n        # proj --> self_attention.linear_proj\n        copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight)\n        copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias)\n        # mlp --> mlp: gate\n        fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight])\n        fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias])\n        copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight)\n        copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias)\n        copied_numel += safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight)\n        copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias)\n\n    # 2. 
vision projector\n    hfprojector = hfvision.merger\n    mgprojector = mgvision.projection\n    copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight)\n\n    copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight)\n    copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias)\n    copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight)\n    copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias)\n    n_params = sum([t.numel() for t in hfvision.state_dict().values()])\n    assert n_params == copied_numel\n    # 3. llm [just Qwen2]\n    hfllm = hfmodel.model\n    mgllm = mgmodel.language_model\n    copied_numel = 0\n    copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight)\n    for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True):\n        copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight)\n\n        q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)\n        qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous()\n        copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight)\n\n        q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1)\n        k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1)\n        v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1)\n        qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous()\n        copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias)\n        copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight)\n\n        fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight])\n        copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight)\n\n        copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight)\n        copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight)\n\n    copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight)\n    if not hf_config.tie_word_embeddings:\n        safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight)\n\n    n_params = sum([t.numel() for t in hfllm.state_dict().values()])\n\n    assert n_params == copied_numel\n\n\n@torch.inference_mode()\ndef convert_checkpoint_from_transformers_to_megatron_dpskv3(\n    hf_model,\n    model,\n    hf_config,\n    tfconfig,\n    layer_start_end: Optional[tuple[int, int]] = None,\n):\n    warnings.warn(\"MTP model is not supported yet\", stacklevel=2)\n    if layer_start_end is None:\n        layer_start_end = (0, len(model.decoder.layers))\n    layer_start, layer_end = layer_start_end\n    numel: int = 0\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    if pp_rank == 0:\n        numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)\n\n    assert 
len(model.decoder.layers) == (layer_end - layer_start), (\n        f\"Expected {len(model.decoder.layers)} layers, but got {layer_end - layer_start}\"\n    )\n    for layer_idx, (layer, hf_layer) in enumerate(\n        zip(model.decoder.layers, hf_model.model.layers[layer_start:layer_end], strict=True)\n    ):\n        global_layer_idx = layer_idx + layer_start\n        numel_cur: int = numel\n        numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)\n\n        if hf_config.q_lora_rank is None:\n            numel += safe_copy(hf_layer.self_attn.q_proj.weight, layer.self_attention.linear_q_proj.weight)\n        else:\n            numel += safe_copy(hf_layer.self_attn.q_a_proj.weight, layer.self_attention.linear_q_down_proj.weight)\n            numel += safe_copy(hf_layer.self_attn.q_b_proj.weight, layer.self_attention.linear_q_up_proj.weight)\n            numel += safe_copy(\n                hf_layer.self_attn.q_a_layernorm.weight, layer.self_attention.linear_q_up_proj.layer_norm_weight\n            )\n\n        numel += safe_copy(\n            hf_layer.self_attn.kv_a_proj_with_mqa.weight, layer.self_attention.linear_kv_down_proj.weight\n        )\n        numel += safe_copy(hf_layer.self_attn.kv_b_proj.weight, layer.self_attention.linear_kv_up_proj.weight)\n        numel += safe_copy(\n            hf_layer.self_attn.kv_a_layernorm.weight, layer.self_attention.linear_kv_up_proj.layer_norm_weight\n        )\n        numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight)\n\n        if not hasattr(layer.mlp, \"router\"):\n            numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight)\n            numel += safe_copy(\n                torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), layer.mlp.linear_fc1.weight\n            )\n            numel += safe_copy(hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight)\n        else:\n            numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight)\n            # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \\\n            # recover to fp32 in the first forward. 
There is always a diff in the bias between two models (~0.3%)\n            numel += safe_copy(\n                hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True\n            )\n            if tfconfig.moe_grouped_gemm:\n                for i, hf_expert in enumerate(hf_layer.mlp.experts):\n                    fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n                    linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, \"weight\" + str(i))\n                    numel += safe_copy(fc1_weight, linear_fc1_weighti)\n                    linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, \"weight\" + str(i))\n                    numel += safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti)\n            else:\n                for i, hf_expert in enumerate(hf_layer.mlp.experts):\n                    expert = layer.mlp.experts.local_experts[i]\n                    fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])\n                    numel += safe_copy(fc1_weight, expert.linear_fc1.weight)\n                    numel += safe_copy(hf_expert.down_proj.weight, expert.linear_fc2.weight)\n            numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)\n            shared_fc1_weight = torch.cat(\n                [hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight]\n            )\n            numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight)\n            numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight)\n        print(f\"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}\")\n        assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), \"numel mismatch\"\n\n    if pp_rank == pp_size - 1:\n        numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight)\n        if not hf_config.tie_word_embeddings:\n            numel += safe_copy(hf_model.lm_head.weight, model.output_layer.weight)\n    print(f\"{pp_rank=} {numel=}\")\n    return numel\n\n\n@contextmanager\ndef noop_context() -> Any:\n    yield\n\n\ndef support_distributed_convert(hf_config: AutoConfig) -> bool:\n    for arch in [\"DeepseekV3ForCausalLM\", \"Qwen3MoeForCausalLM\", \"Qwen2MoeForCausalLM\"]:\n        if arch in hf_config.architectures:\n            return True\n    return False\n\n\ndef convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):\n    os.makedirs(output_path, exist_ok=True)\n    if len(os.listdir(output_path)) > 0 and not test:\n        print(f\"Output path {output_path} is not empty, skipping conversion\")\n        return\n\n    # init torch distributed and mpu\n    if \"WORLD_SIZE\" not in os.environ:\n        os.environ[\"RANK\"] = \"0\"\n        os.environ[\"WORLD_SIZE\"] = \"1\"\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        os.environ[\"MASTER_PORT\"] = \"12355\"\n\n    torch.distributed.init_process_group(\"nccl\")\n\n    rank = dist.get_rank()\n    local_rank = os.getenv(\"LOCAL_RANK\", 0)\n    world_size = dist.get_world_size()\n    get_torch_device().set_device(f\"{get_device_name()}:{local_rank}\")\n\n    mpu.initialize_model_parallel(\n        tensor_model_parallel_size=1,\n        pipeline_model_parallel_size=world_size,\n        
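# NOTE: the conversion is sharded across pipeline stages only (one stage per rank); TP/CP/EP sizes below stay at 1\n        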
virtual_pipeline_model_parallel_size=None,\n        context_parallel_size=1,\n        expert_model_parallel_size=1,\n    )\n    model_parallel_cuda_manual_seed(0)\n\n    # init hf config\n    hf_config = AutoConfig.from_pretrained(hf_model_path)\n    print(hf_config, flush=True)\n\n    if world_size > 1 and not support_distributed_convert(hf_config):\n        raise NotImplementedError(f\"Distributed conversion is not supported for {hf_config.architectures} yet.\")\n\n    pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size)\n    print(f\"Pipeline shards: {pipeline_shards}\", flush=True)\n\n    tfconfig = hf_to_mcore_config(\n        hf_config,\n        torch.bfloat16,\n        num_layers_in_first_pipeline_stage=pipeline_shards[0] if len(pipeline_shards) > 1 else None,\n        num_layers_in_last_pipeline_stage=pipeline_shards[-1] if len(pipeline_shards) > 2 else None,\n    )\n    tfconfig.use_cpu_initialization = use_cpu_initialization\n    tie_word_embeddings = getattr(hf_config, \"tie_word_embeddings\", False)\n\n    # init megatron model\n    def megatron_model_provider(pre_process, post_process):\n        from verl.models.mcore import init_mcore_model\n\n        parallel_model = init_mcore_model(\n            tfconfig,\n            hf_config,\n            pre_process,\n            post_process,\n            share_embeddings_and_output_weights=tie_word_embeddings,\n            value=False,\n        )\n        return parallel_model\n\n    context: Callable[..., ContextManager] = init_empty_weights if use_cpu_initialization else noop_context\n    with context():\n        model = get_model(\n            model_provider_func=megatron_model_provider,\n            model_type=ModelType.encoder_or_decoder,\n            wrap_with_ddp=False,\n            transformer_config=tfconfig,\n        )\n\n    if use_cpu_initialization:\n        # convert meta device to empty tensor so it can use `copy_` function\n        model[0].module = model[0].module.to_empty(device=\"cpu\")\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        # keep the import inside the context manager so its import-time warnings are suppressed\n        from transformers import AutoModelForCausalLM, AutoModelForImageTextToText\n\n    # init hf model\n    if \"Qwen2_5_VLForConditionalGeneration\" in hf_config.architectures:\n        hf_model = AutoModelForImageTextToText.from_pretrained(\n            hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code\n        )\n    else:\n        hf_model = AutoModelForCausalLM.from_pretrained(\n            hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code\n        )\n    hf_state_dict = hf_model.state_dict()\n\n    # distributed convert\n    if world_size > 1 and support_distributed_convert(hf_config):\n        pipeline_cumsum = np.cumsum(pipeline_shards)\n        layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1]\n        layer_end = pipeline_cumsum[rank]\n        if \"DeepseekV3ForCausalLM\" in hf_config.architectures:\n            numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3(\n                hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end)\n            )\n        elif \"Qwen3MoeForCausalLM\" in hf_config.architectures or \"Qwen2MoeForCausalLM\" in hf_config.architectures:\n            numel_partial: int = convert_checkpoint_from_transformers_to_megatron(\n                hf_model, model[0].module, hf_config, layer_start_end=(layer_start, layer_end)\n            )\n        else:\n  
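          # defensive fallback; support_distributed_convert() already rejected unsupported architectures above\n  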
          raise NotImplementedError(f\"Distributed conversion is not supported for {hf_config.architectures} yet.\")\n\n        numel_tensor = torch.tensor([numel_partial]).to(get_device_name())\n        dist.all_reduce(numel_tensor, op=dist.ReduceOp.SUM)\n        numel = int(numel_tensor.cpu().item())\n        print(f\"total numel={numel} vs {hf_model.num_parameters()=}\")\n        if numel != hf_model.num_parameters():\n            warnings.warn(f\"numel mismatch: {numel=} != {hf_model.num_parameters()=}\", stacklevel=1)\n\n    # load hf state dict to megatron model\n    elif \"Qwen2MoeForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)\n    elif \"Qwen2_5_VLForConditionalGeneration\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config)\n    elif \"DeepseekV3ForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)\n    elif \"Qwen3MoeForCausalLM\" in hf_config.architectures:\n        convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)\n    else:\n        assert not use_cpu_initialization, \"use_cpu_initialization is only supported for MoE model\"\n        from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n        load_state_dict_to_megatron_gptmodel(\n            state_dict=hf_state_dict,\n            wrapped_models=model,\n            config=hf_config,\n            params_dtype=torch.bfloat16,\n            is_value_model=False,\n        )\n\n    megatron_state_dict = model[0].module.sharded_state_dict()\n    del hf_state_dict, hf_model\n\n    # save megatron model\n    if len(os.listdir(output_path)) == 0:\n        dist_checkpointing.save(megatron_state_dict, output_path, sharded_strategy=None, async_sharded_save=False)\n    if test:\n        test_conversion(megatron_model_provider, tfconfig, output_path, model)\n\n\nif __name__ == \"__main__\":\n    args = _init_args()\n    convert_hf_to_mcore(\n        args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code\n    )\n"
  },
  {
    "path": "verl_rl/scripts/diagnose.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Diagnose script for checking OS/hardware/python/pip/verl/network.\nThe output of this script can be a very good hint to issue/problem.\n\"\"\"\n\nimport os\nimport platform\nimport socket\nimport subprocess\nimport sys\nimport time\n\nimport psutil\n\ntry:\n    from urllib.parse import urlparse\n    from urllib.request import urlopen\nexcept ImportError:\n    from urllib2 import urlopen\n    from urlparse import urlparse\nimport argparse\nimport importlib.metadata\n\nimport torch\n\nURLS = {\n    \"PYPI\": \"https://pypi.python.org/pypi/pip\",\n}\n\nREGIONAL_URLS = {\n    \"cn\": {\n        \"PYPI(douban)\": \"https://pypi.douban.com/\",\n        \"Conda(tsinghua)\": \"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\",\n    }\n}\n\n\ndef test_connection(name, url, timeout=10):\n    \"\"\"Simple connection test\"\"\"\n    urlinfo = urlparse(url)\n    start = time.time()\n    try:\n        socket.gethostbyname(urlinfo.netloc)\n    except Exception as e:\n        print(\"Error resolving DNS for {}: {}, {}\".format(name, url, e))\n        return\n    dns_elapsed = time.time() - start\n    start = time.time()\n    try:\n        _ = urlopen(url, timeout=timeout)\n    except Exception as e:\n        print(\"Error open {}: {}, {}, DNS finished in {} sec.\".format(name, url, e, dns_elapsed))\n        return\n    load_elapsed = time.time() - start\n    print(\"Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.\".format(name, url, dns_elapsed, load_elapsed))\n\n\ndef check_python():\n    print(\"----------Python Info----------\")\n    print(\"Version      :\", platform.python_version())\n    print(\"Compiler     :\", platform.python_compiler())\n    print(\"Build        :\", platform.python_build())\n    print(\"Arch         :\", platform.architecture())\n\n\ndef check_pip():\n    print(\"------------Pip Info-----------\")\n    try:\n        import pip\n\n        print(\"Version      :\", pip.__version__)\n        print(\"Directory    :\", os.path.dirname(pip.__file__))\n    except ImportError:\n        print(\"No corresponding pip install for current python.\")\n\n\ndef _get_current_git_commit():\n    try:\n        result = subprocess.run([\"git\", \"rev-parse\", \"HEAD\"], capture_output=True, text=True, check=True)\n        return result.stdout.strip()\n    except subprocess.CalledProcessError as e:\n        print(f\"Error running git command: {e.stderr.strip()}\")\n        return None\n    except FileNotFoundError:\n        print(\"Did not find command: git\")\n        return None\n\n\ndef check_verl():\n    print(\"----------verl Info-----------\")\n    try:\n        sys.path.insert(0, os.getcwd())\n        import verl\n\n        print(\"Version      :\", verl.__version__)\n        verl_dir = os.path.dirname(verl.__file__)\n        print(\"Directory    :\", verl_dir)\n        try:\n            commit_hash = _get_current_git_commit()\n            
print(\"Commit Hash  :\", commit_hash)\n        except AttributeError:\n            print(\"Commit hash not found. \")\n    except ImportError as e:\n        print(f\"No verl installed: {e}\")\n    except Exception as e:\n        import traceback\n\n        if not isinstance(e, IOError):\n            print(\"An error occurred trying to import verl.\")\n            print(\"This is very likely due to missing or incompatible library files.\")\n        print(traceback.format_exc())\n\n\ndef check_os():\n    print(\"----------Platform Info----------\")\n    print(\"Platform     :\", platform.platform())\n    print(\"system       :\", platform.system())\n    print(\"node         :\", platform.node())\n    print(\"release      :\", platform.release())\n    print(\"version      :\", platform.version())\n\n\ndef check_hardware():\n    print(\"----------Hardware Info----------\")\n    print(\"machine      :\", platform.machine())\n    print(\"processor    :\", platform.processor())\n    if sys.platform.startswith(\"darwin\"):\n        pipe = subprocess.Popen((\"sysctl\", \"-a\"), stdout=subprocess.PIPE)\n        output = pipe.communicate()[0]\n        for line in output.split(b\"\\n\"):\n            if b\"brand_string\" in line or b\"features\" in line:\n                print(line.strip())\n    elif sys.platform.startswith(\"linux\"):\n        subprocess.call([\"lscpu\"])\n    elif sys.platform.startswith(\"win32\"):\n        subprocess.call([\"wmic\", \"cpu\", \"get\", \"name\"])\n\n\ndef check_network(args):\n    print(\"----------Network Test----------\")\n    if args.timeout > 0:\n        print(\"Setting timeout: {}\".format(args.timeout))\n        socket.setdefaulttimeout(10)\n    for region in args.region.strip().split(\",\"):\n        r = region.strip().lower()\n        if not r:\n            continue\n        if r in REGIONAL_URLS:\n            URLS.update(REGIONAL_URLS[r])\n        else:\n            import warnings\n\n            warnings.warn(\"Region {} do not need specific test, please refer to global sites.\".format(r), stacklevel=2)\n    for name, url in URLS.items():\n        test_connection(name, url, args.timeout)\n\n\ndef check_environment():\n    print(\"----------Environment----------\")\n    for k, v in os.environ.items():\n        if k.startswith(\"VERL_\") or k.startswith(\"OMP_\") or k.startswith(\"KMP_\") or k == \"CC\" or k == \"CXX\":\n            print('{}=\"{}\"'.format(k, v))\n\n\ndef check_pip_package_versions():\n    packages = [\"vllm\", \"sglang\", \"ray\", \"torch\"]\n    for package in packages:\n        try:\n            version = importlib.metadata.version(package)\n            print(f\"{package}\\t     : {version}\")\n        except importlib.metadata.PackageNotFoundError:\n            print(f\"{package}\\t     : not found.\")\n\n\ndef check_cuda_versions():\n    if torch.cuda.is_available():\n        try:\n            cuda_runtime_version = torch.version.cuda\n            print(f\"CUDA Runtime : {cuda_runtime_version}\")\n            import subprocess\n\n            nvcc_output = subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"utf-8\")\n            cuda_compiler_version = next((line for line in nvcc_output.splitlines() if \"release\" in line), None)\n            if cuda_compiler_version:\n                print(f\"CUDA Compiler : {cuda_compiler_version.strip()}\")\n            else:\n                print(\"Could not determine CUDA compiler version.\")\n        except FileNotFoundError as e:\n            print(f\"CUDA compiler : Not found: {e}\")\n 
       except Exception as e:\n            print(f\"An error occurred while checking CUDA versions: {e}\")\n    else:\n        print(\"CUDA is not available.\")\n\n\ndef _get_cpu_memory():\n    \"\"\"\n    Get the total CPU memory capacity in GB.\n    \"\"\"\n    memory = psutil.virtual_memory()\n    return memory.total / (1024**3)\n\n\ndef _get_gpu_info():\n    \"\"\"\n    Get GPU type, GPU memory, and GPU count using nvidia-smi command.\n    \"\"\"\n    try:\n        result = subprocess.run(\n            [\"nvidia-smi\", \"--query-gpu=gpu_name,memory.total\", \"--format=csv,noheader,nounits\"],\n            capture_output=True,\n            text=True,\n            check=True,\n        )\n        gpu_lines = result.stdout.strip().split(\"\\n\")\n        gpu_count = len(gpu_lines)\n        gpu_info = []\n        for line in gpu_lines:\n            gpu_name, gpu_memory = line.split(\", \")\n            gpu_info.append(\n                {\n                    \"type\": gpu_name,\n                    \"memory\": float(gpu_memory) / 1024,  # Convert to GB\n                }\n            )\n        return gpu_count, gpu_info\n    except subprocess.CalledProcessError:\n        print(\"Failed to execute nvidia-smi command.\")\n        return 0, []\n\n\ndef _get_system_info():\n    \"\"\"\n    Get CPU memory capacity, GPU type, GPU memory, and GPU count.\n    \"\"\"\n    cpu_memory = _get_cpu_memory()\n    gpu_count, gpu_info = _get_gpu_info()\n    return {\"cpu_memory\": cpu_memory, \"gpu_count\": gpu_count, \"gpu_info\": gpu_info}\n\n\ndef check_system_info():\n    print(\"----------System Info----------\")\n    system_info = _get_system_info()\n    print(f\"CPU Memory\\t: {system_info['cpu_memory']:.2f} GB\")\n    print(f\"GPU Count\\t: {system_info['gpu_count']}\")\n    for i, gpu in enumerate(system_info[\"gpu_info\"]):\n        print(f\"GPU {i + 1}\\tType    : {gpu['type']}\")\n        print(f\"GPU {i + 1}\\tMemory  : {gpu['memory']:.2f} GB\")\n\n\ndef parse_args():\n    \"\"\"Parse arguments.\"\"\"\n    parser = argparse.ArgumentParser(\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n        description=\"Diagnose script for checking the current system.\",\n    )\n    choices = [\"python\", \"pip\", \"verl\", \"system\", \"os\", \"environment\"]\n    for choice in choices:\n        parser.add_argument(\"--\" + choice, default=1, type=int, help=\"Diagnose {}.\".format(choice))\n    parser.add_argument(\"--network\", default=0, type=int, help=\"Diagnose network.\")\n    parser.add_argument(\"--hardware\", default=0, type=int, help=\"Diagnose hardware.\")\n    parser.add_argument(\n        \"--region\",\n        default=\"\",\n        type=str,\n        help=\"Additional sites in which region(s) to test. \\\n                        Specify 'cn' for example to test mirror sites in China.\",\n    )\n    parser.add_argument(\"--timeout\", default=10, type=int, help=\"Connection test timeout threshold, 0 to disable.\")\n    args = parser.parse_args()\n    return args\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if args.python:\n        check_python()\n\n    if args.pip:\n        check_pip()\n        check_pip_package_versions()\n\n    if args.verl:\n        check_verl()\n\n    if args.os:\n        check_os()\n\n    if args.hardware:\n        check_hardware()\n\n    if args.network:\n        check_network(args)\n\n    if args.environment:\n        check_environment()\n        check_cuda_versions()\n\n    if args.system:\n        check_system_info()\n"
  },
  {
    "path": "verl_rl/scripts/generate_trainer_config.sh",
    "content": "#!/usr/bin/env bash\nset -euox pipefail\n\n\n# Define config specifications: \"config_name:output_file:config_arg\"\nCONFIG_SPECS=(\n    \"ppo_trainer:_generated_ppo_trainer.yaml:\"\n    \"ppo_megatron_trainer:_generated_ppo_megatron_trainer.yaml:--config-name=ppo_megatron_trainer.yaml\"\n)\n\ngenerate_config() {\n    local config_name=\"$1\"\n    local output_file=\"$2\"\n    local config_arg=\"$3\"\n    \n    local target_cfg=\"verl/trainer/config/${output_file}\"\n    local tmp_header=$(mktemp)\n    local tmp_cfg=$(mktemp)\n    \n    echo \"# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\" > \"$tmp_header\"\n    echo \"# in which it invokes 'python3 scripts/print_cfg.py --cfg job ${config_arg}' to flatten the 'verl/trainer/config/${config_name}.yaml' config fields into a single file.\" >> \"$tmp_header\"\n    echo \"# Do not modify this file directly.\" >> \"$tmp_header\"\n    echo \"# The file is usually only for reference and never used.\" >> \"$tmp_header\"\n    echo \"\" >> \"$tmp_header\"\n    \n    python3 scripts/print_cfg.py --cfg job ${config_arg} > \"$tmp_cfg\"\n    \n    cat \"$tmp_header\" > \"$target_cfg\"\n    sed -n '/^actor_rollout_ref/,$p' \"$tmp_cfg\" >> \"$target_cfg\"\n    \n    rm \"$tmp_cfg\" \"$tmp_header\"\n    \n    echo \"Generated: $target_cfg\"\n}\n\nfor spec in \"${CONFIG_SPECS[@]}\"; do\n    IFS=':' read -r config_name output_file config_arg <<< \"$spec\"\n    generate_config \"$config_name\" \"$output_file\" \"$config_arg\"\ndone\n\nfor spec in \"${CONFIG_SPECS[@]}\"; do\n    IFS=':' read -r config_name output_file config_arg <<< \"$spec\"\n    target_cfg=\"verl/trainer/config/${output_file}\"\n    if ! git diff --exit-code -- \"$target_cfg\" >/dev/null; then\n        echo \"✖ $target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes.\"\n        exit 1\n    fi\ndone\n\necho \"All good\"\nexit 0\n"
  },
  {
    "path": "verl_rl/scripts/init_random_model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script override a model with custom config and random weights, mainly for create small models for \ndebugging purposes.\n\nUsage:\n    python scripts/init_random_model.py \\\n        --hf_model_path <path_to_hf_model> \\\n        --new_config_path <path_to_new_config.json> \\\n        --output_path <path_to_output_model>\n\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport warnings\nfrom typing import Any\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig\n\n\ndef _init_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hf_model_path\", type=str, required=True, help=\"The path for the huggingface model\")\n    parser.add_argument(\"--new_config_path\", type=str, required=True, help=\"The path for the new config file\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"The path for the output random model\")\n    args = parser.parse_args()\n    return args\n\n\ndef check_output_path(output_path: str):\n    if os.path.exists(output_path):\n        warnings.warn(f\"Output path '{output_path}' already exists. 
Will do nothing.\", stacklevel=2)\n        exit()\n    else:\n        os.makedirs(output_path, exist_ok=True)\n        print(f\"Output path '{output_path}' created.\")\n\n\ndef check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) -> bool:\n    \"\"\"\n    Check if the original config and new config are compatible.\n    This is a placeholder function; actual implementation may vary based on requirements.\n    \"\"\"\n    # Example check: ensure 'model_type' is the same\n    if new_config.get(\"model_type\", None) is not None and original_config.get(\"model_type\") != new_config.get(\n        \"model_type\"\n    ):\n        raise RuntimeError(\"Model types do not match.\")\n    for key in new_config:\n        if key not in original_config:\n            warnings.warn(\n                f\"Key '{key}' in new config does not exist in original config, may not take effect.\", stacklevel=2\n            )\n\n\ndef init_random_model(hf_model_path, new_config_path, output_path):\n    config = AutoConfig.from_pretrained(hf_model_path)\n    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)\n    config_dict = PretrainedConfig.get_config_dict(hf_model_path)[0]\n    print(config_dict)\n    with open(new_config_path) as f:\n        new_config_dict = json.load(f)\n    check_configs(config_dict, new_config_dict)\n    config_dict.update(new_config_dict)\n    new_confg = config.from_dict(config_dict)\n    print(f\"new_config: {new_confg}\")\n    model = AutoModelForCausalLM.from_config(new_confg)\n    model.save_pretrained(output_path)\n    tokenizer.save_pretrained(output_path)\n    new_confg.save_pretrained(output_path)\n    print(f\"Random model initialized and saved to {output_path}\")\n\n\nif __name__ == \"__main__\":\n    args = _init_args()\n    check_output_path(args.output_path)\n    init_random_model(\n        hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path\n    )\n"
  },
  {
    "path": "verl_rl/scripts/install_vllm_sglang_mcore.sh",
    "content": "#!/bin/bash\n\nUSE_MEGATRON=${USE_MEGATRON:-1}\nUSE_SGLANG=${USE_SGLANG:-1}\n\nexport MAX_JOBS=32\n\necho \"1. install inference frameworks and pytorch they need\"\nif [ $USE_SGLANG -eq 1 ]; then\n    pip install \"sglang[all]==0.4.6.post1\" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir\nfi\npip install --no-cache-dir \"vllm==0.8.5.post1\" \"torch==2.6.0\" \"torchvision==0.21.0\" \"torchaudio==2.6.0\" \"tensordict==0.6.2\" torchdata\n\necho \"2. install basic packages\"\npip install \"transformers[hf_xet]>=4.51.0\" accelerate datasets peft hf-transfer \\\n    \"numpy<2.0.0\" \"pyarrow>=15.0.0\" pandas \\\n    ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \\\n    pytest py-spy pyext pre-commit ruff\n\npip install \"nvidia-ml-py>=12.560.30\" \"fastapi[standard]>=0.115.0\" \"optree>=0.13.0\" \"pydantic>=2.9\" \"grpcio>=1.62.1\"\n\n\necho \"3. install FlashAttention and FlashInfer\"\n# Install flash-attn-2.7.4.post1 (cxx11abi=False)\nwget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \\\n    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n\n# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)\n# vllm-0.8.3 does not support flashinfer>=0.2.3\n# see https://github.com/vllm-project/vllm/pull/15777\nwget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \\\n    pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl\n\n\nif [ $USE_MEGATRON -eq 1 ]; then\n    echo \"4. install TransformerEngine and Megatron\"\n    echo \"Notice that TransformerEngine installation can take very long time, please be patient\"\n    NVTE_FRAMEWORK=pytorch pip3 install --no-deps git+https://github.com/NVIDIA/TransformerEngine.git@v2.2.1\n    pip3 install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\nfi\n\n\necho \"5. May need to fix opencv\"\npip install opencv-python\npip install opencv-fixer && \\\n    python -c \"from opencv_fixer import AutoFix; AutoFix()\"\n\n\nif [ $USE_MEGATRON -eq 1 ]; then\n    echo \"6. Install cudnn python package (avoid being overridden)\"\n    pip install nvidia-cudnn-cu12==9.8.0.87\nfi\n\necho \"Successfully installed all packages\"\n"
  },
  {
    "path": "verl_rl/scripts/legacy_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends.\n\nTo merge FSDP checkpoints:\n```sh\npython scripts/legacy_model_merger.py merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nTo merge Megatron checkpoints:\n```sh\npython scripts/legacy_model_merger.py merge \\\n    --backend megatron \\\n    --tie-word-embedding \\\n    --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nFor more details, please refer to documentation:\nhttps://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model\n\"\"\"\n\nimport argparse\nimport os\nimport re\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom concurrent.futures import ThreadPoolExecutor\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\nfrom accelerate import init_empty_weights\nfrom safetensors.torch import load_file\nfrom torch.distributed._tensor import Placement, Shard\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    AutoModelForVision2Seq,\n    GenerationConfig,\n    PretrainedConfig,\n)\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom tqdm import tqdm\n\nfrom verl.utils import hf_processor, hf_tokenizer\n\n\n@dataclass\nclass ModelMergerConfig:\n    operation: str  # 'merge' or 'test'\n    backend: str\n    local_dir: str\n    hf_model_config_path: str\n    target_dir: Optional[str] = \"tmp\"\n    hf_upload_path: Optional[str] = None\n    private: bool = False\n    test_hf_dir: Optional[str] = None\n    tie_word_embedding: bool = False\n    is_value_model: bool = False\n    hf_model_path: Optional[str] = None\n    hf_upload: bool = field(init=False)\n\n    def __post_init__(self):\n        self.hf_upload = self.operation == \"merge\" and bool(self.hf_upload_path)\n        if self.operation == \"test\":\n            self.target_dir = None\n            self.hf_upload_path = None\n            self.private = False\n\n\nclass BaseModelMerger(ABC):\n    def __init__(self, config: ModelMergerConfig):\n        self.config = config\n        self.hf_model_config_path = config.hf_model_config_path\n\n        if config.hf_model_path:\n            print(\n                \"Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. 
\"\n            )\n            self.hf_model_config_path = config.hf_model_path\n\n        self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path)\n\n    def get_transformers_auto_model_class(self):\n        if \"ForTokenClassification\" in self.model_config.architectures[0]:\n            return AutoModelForTokenClassification\n        elif \"ForCausalLM\" in self.model_config.architectures[0]:\n            return AutoModelForCausalLM\n        elif \"ForConditionalGeneration\" in self.model_config.architectures[0]:\n            return AutoModelForVision2Seq\n\n        raise NotImplementedError(f\"Unknown architecture {self.model_config.architectures}\")\n\n    def patch_model_generation_config(self, model):\n        \"\"\"\n        The generation_config created from model config may be different to the pretrained model,\n        this may lead to error when generating: https://github.com/volcengine/verl/issues/1246\n\n        This function patch the generation_config created from model config to the pretrained model.\n        \"\"\"\n        if model.can_generate():\n            try:\n                model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path)\n            except OSError:\n                print(\n                    f\"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.\"\n                )\n        return model\n\n    def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Save lora adapter to safetensors.\n\n        Returns:\n            lora_path: str, the path to the lora adapter. None if no lora adapter found.\n\n        Note:\n            This function change the 'state_dict' in place.\n        \"\"\"\n        lora_params_names = [name for name in state_dict.keys() if \"lora_\" in name]\n\n        if len(lora_params_names) == 0:\n            return None\n\n        import json\n        from typing import OrderedDict\n\n        import peft\n        from safetensors.torch import save_file\n\n        lora_params = OrderedDict()\n        target_modules = set()\n        lora_key = None\n\n        for name in lora_params_names:\n            lora_key = name.replace(\".default.weight\", \".weight\")\n            target_modules.add(lora_key.split(\".\")[-3])\n            lora_params[lora_key] = state_dict.pop(name)\n\n        lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])\n        peft_dict = {\n            \"r\": lora_rank,\n            \"lora_alpha\": 0,  # lora_alpha is not set. 
An error should be raised to inform the user to set it manually.\n            \"target_modules\": list(target_modules),\n        }\n        peft_config = peft.LoraConfig(**peft_dict).to_dict()\n        peft_config[\"task_type\"] = peft_config[\"task_type\"].value if peft_config[\"task_type\"] else None\n        peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value if peft_config[\"peft_type\"] else None\n        peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n\n        lora_path = os.path.join(self.config.target_dir, \"lora_adapter\")\n        os.makedirs(lora_path, exist_ok=True)\n        with open(os.path.join(lora_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n        save_file(lora_params, os.path.join(lora_path, \"adapter_model.safetensors\"))\n\n        for name in list(state_dict.keys()):\n            key = (\n                name.replace(\"base_model.model.\", \"\")\n                .replace(\".base_layer.weight\", \".weight\")\n                .replace(\".base_layer.bias\", \".bias\")\n            )\n            state_dict[key] = state_dict.pop(name)\n\n        return lora_path\n\n    def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n        with init_empty_weights():\n            model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16)\n        model.to_empty(device=\"cpu\")\n        model = self.patch_model_generation_config(model)\n\n        lora_path = self.save_lora_adapter(state_dict)\n        if lora_path:\n            print(f\"Saving lora adapter to {lora_path}\")\n\n        print(f\"Saving model to {self.config.target_dir}\")\n        model.save_pretrained(self.config.target_dir, state_dict=state_dict)\n        del state_dict\n        del model\n\n        processor = hf_processor(self.hf_model_config_path)\n        tokenizer = hf_tokenizer(self.hf_model_config_path)\n        if processor is not None:\n            print(f\"Saving processor to {self.config.target_dir}\")\n            processor.save_pretrained(self.config.target_dir)\n        if tokenizer is not None:\n            print(f\"Saving tokenizer to {self.config.target_dir}\")\n            tokenizer.save_pretrained(self.config.target_dir)\n\n    def upload_to_huggingface(self):\n        from huggingface_hub import HfApi\n\n        api = HfApi()\n        api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True)\n        api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type=\"model\")\n\n    @abstractmethod\n    def merge_and_save(self):\n        raise NotImplementedError(\"Subclasses should implement this method\")\n\n\nclass FSDPModelMerger(BaseModelMerger):\n    def _get_world_size(self) -> int:\n        \"\"\"Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').\"\"\"\n        for filename in os.listdir(self.config.local_dir):\n            match = re.match(r\"model_world_size_(\\d+)_rank_0\\.pt\", filename)\n            if match:\n                return int(match.group(1))\n        raise FileNotFoundError(\n            f\"Could not determine world size. 
No file matching 'model_world_size_(\\d+)_rank_0.pt' found in {self.config.local_dir}\"\n        )\n\n    def _load_rank_zero_state_dict(self, world_size: int) -> dict:\n        return torch.load(\n            Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_0.pt\",\n            map_location=\"cpu\",\n            weights_only=False,\n        )\n\n    def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]:\n        \"\"\"\n        Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict.\n        If no DTensor is found, infers a simple FSDP mesh based on world_size.\n        \"\"\"\n        pivot_key = sorted(list(state_dict.keys()))[0]\n        weight = state_dict[pivot_key]\n\n        if isinstance(weight, DTensor):\n            # get sharding info\n            device_mesh = weight.device_mesh\n            mesh = device_mesh.mesh\n            mesh_dim_names = device_mesh.mesh_dim_names\n        else:\n            # for non-DTensor\n            mesh = np.array([world_size], dtype=np.int64)\n            mesh_dim_names = (\"fsdp\",)\n\n        return mesh, mesh_dim_names\n\n    def _calculate_shard_configuration(\n        self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]\n    ) -> tuple[int, tuple[int, ...]]:\n        \"\"\"Calculates the total number of shards and the shape of the device mesh.\"\"\"\n        assert mesh_dim_names in ((\"fsdp\",), (\"ddp\", \"fsdp\")), f\"Unsupported mesh_dim_names {mesh_dim_names}\"\n\n        if \"tp\" in mesh_dim_names:\n            # TODO: \"tp\" is not supported yet due to the above assert\n            total_shards = mesh.shape[-1] * mesh.shape[-2]\n            mesh_shape = (mesh.shape[-2], mesh.shape[-1])\n        else:\n            total_shards = mesh.shape[-1]\n            mesh_shape = (mesh.shape[-1],)\n\n        return total_shards, mesh_shape\n\n    def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor:\n        \"\"\"Merges a list of tensors based on their DTensor placement\"\"\"\n        if placement.is_replicate():\n            return tensors[0]\n        elif placement.is_partial():\n            raise NotImplementedError(\"Partial placement is not supported yet\")\n        elif placement.is_shard():\n            return torch.cat(tensors, dim=placement.dim).contiguous()\n\n        raise NotImplementedError(f\"Unsupported placement: {placement}\")\n\n    def _load_and_merge_state_dicts(\n        self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]\n    ) -> dict[str, torch.Tensor]:\n        model_state_dict_lst = [None] * total_shards\n\n        def process_one_shard(rank: int, model_state_dict_lst: list):\n            model_path = Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_{rank}.pt\"\n            state_dict = torch.load(model_path, map_location=\"cpu\", weights_only=False)\n            model_state_dict_lst[rank] = state_dict\n            return state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)]\n            for future in tqdm(futures, desc=f\"Loading {total_shards} FSDP shards\", total=total_shards):\n                future.result()\n\n        # Merge state dicts from all shards\n        state_dict = {}\n        param_placements: dict[str, list] = 
{}\n\n        for key in set(model_state_dict_lst[0].keys()):\n            state_dict[key] = []\n            for model_state_shard in model_state_dict_lst:\n                # add tensor shard in order of rank to state_dict[key]\n                tensor = model_state_shard.pop(key)\n                if isinstance(tensor, DTensor):\n                    state_dict[key].append(tensor._local_tensor.bfloat16())\n\n                    placements = tuple(tensor.placements)\n                    # replicated placement at dp dimension can be discarded\n                    if mesh_dim_names[0] in (\"dp\", \"ddp\"):\n                        placements = placements[1:]\n\n                    if key not in param_placements:\n                        param_placements[key] = placements\n                    else:\n                        assert param_placements[key] == placements\n                else:\n                    state_dict[key].append(tensor.bfloat16())\n\n        del model_state_dict_lst\n\n        # Merge tensors\n        for key in sorted(state_dict):\n            if not isinstance(state_dict[key], list):\n                print(f\"No need to merge key {key}\")\n                continue\n            if key in param_placements:\n                # merge shards\n                placements: tuple[Shard] = param_placements[key]\n                if len(mesh_shape) == 1:\n                    # 1-D list, FSDP without TP\n                    assert len(placements) == 1\n                    shards = state_dict[key]\n                    state_dict[key] = self._merge_by_placement(shards, placements[0])\n                else:\n                    # 2-D list, FSDP + TP\n                    raise NotImplementedError(\"FSDP + TP is not supported yet\")\n            else:\n                state_dict[key] = torch.cat(state_dict[key], dim=0)\n\n        return state_dict\n\n    def merge_and_save(self):\n        world_size = self._get_world_size()\n        rank_zero_state_dict = self._load_rank_zero_state_dict(world_size)\n\n        mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size)\n        print(f\"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}\")\n\n        total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names)\n        print(f\"Processing model shards with {total_shards} {mesh_shape} in total\")\n\n        merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names)\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._test_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n\n        hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16)\n        hf_state_dict = hf_model.state_dict()\n        del hf_model\n\n        hf_model_keys = set(hf_state_dict.keys())\n        collected_keys = set(state_dict.keys())\n\n        missing_keys = hf_model_keys - collected_keys\n        assert len(missing_keys) == 0, 
f\"Missing keys in collected state dict: {list(sorted(missing_keys))}\"\n\n        extra_keys = collected_keys - hf_model_keys\n        assert len(extra_keys) == 0, f\"Extra keys in collected state dict: {list(sorted(extra_keys))}\"\n\n        for key in hf_model_keys:\n            hf_shape = hf_state_dict[key].shape\n            collected_shape = state_dict[key].shape\n            assert hf_shape == collected_shape, (\n                f\"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}\"\n            )\n\n            hf_dtype = hf_state_dict[key].dtype\n            collected_dtype = state_dict[key].dtype\n            assert hf_dtype == collected_dtype, (\n                f\"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}\"\n            )\n\n            torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6)\n\n        print(\"FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.\")\n\n\nclass MegatronModelMerger(BaseModelMerger):\n    def __init__(self, config: ModelMergerConfig):\n        from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path\n\n        config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir)\n        super().__init__(config)\n\n        self.params_mapping = {\n            # megatron core gpt model name, huggingface model name\n            # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the longer key within the containing relationship is processed first.\n            \"embedding.word_embeddings\": \"model.embed_tokens\",\n            # attn\n            \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_qkv.layer_norm_bias\": \"input_layernorm.bias\",\n            \"self_attention.linear_qkv\": \"self_attn.qkv_proj\",\n            \"self_attention.q_layernorm\": \"self_attn.q_norm\",\n            \"self_attention.k_layernorm\": \"self_attn.k_norm\",\n            \"self_attention.linear_proj\": \"self_attn.o_proj\",\n            # mla\n            \"self_attention.linear_q_proj\": \"self_attn.q_proj\",\n            \"self_attention.linear_q_down_proj\": \"self_attn.q_a_proj\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n            \"self_attention.linear_q_up_proj\": \"self_attn.q_b_proj\",\n            \"self_attention.linear_kv_down_proj\": \"self_attn.kv_a_proj_with_mqa\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj\": \"self_attn.kv_b_proj\",\n            # mlp\n            \"pre_mlp_layernorm\": \"post_attention_layernorm\",\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc1.layer_norm_bias\": \"post_attention_layernorm.bias\",\n            \"mlp.linear_fc1\": \"mlp.gate_up_proj\",\n            \"mlp.linear_fc2\": \"mlp.down_proj\",\n            # moe\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n            \"mlp.router\": \"mlp.gate\",\n            \"mlp.shared_experts.linear_fc1\": \"mlp.shared_experts.gate_up_proj\",\n            \"mlp.shared_experts.linear_fc2\": \"mlp.shared_experts.down_proj\",\n            \"linear_fc1\": \"gate_up_proj\",\n            \"linear_fc2\": \"down_proj\",\n     
       # output\n            \"final_layernorm\": \"norm\",\n            \"output_layer\": \"lm_head\",\n        }\n\n    def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]:\n        tp_rank = pp_rank = None\n        rank_list = sharded_dir.split(\"_\")[2:]\n        if re.match(r\"mp_rank_(\\d\\d)_(\\d\\d\\d)\", sharded_dir):\n            tp_rank = int(rank_list[0])\n            pp_rank = int(rank_list[1])\n        elif re.match(r\"mp_rank_(\\d\\d)\", sharded_dir):\n            tp_rank = int(rank_list[0])\n            pp_rank = 0\n\n        assert tp_rank is not None and pp_rank is not None, f\"Invalid sharded dir {sharded_dir}\"\n\n        return tp_rank, pp_rank\n\n    def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]:\n        \"\"\"\n        Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories).\n        Determines TP and PP sizes from directory names.\n        \"\"\"\n        tp_size = 0\n        pp_size = 0\n        sharded_dirs = sorted(os.listdir(model_path))\n        for sharded_dir in sharded_dirs:\n            assert \"model.pt\" in os.listdir(Path(model_path) / sharded_dir), f\"model.pt not found in {sharded_dir}\"\n            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)\n            tp_size = max(tp_size, tp_rank + 1)\n            pp_size = max(pp_size, pp_rank + 1)\n        return sharded_dirs, tp_size, pp_size\n\n    def _merge_across_tp(\n        self,\n        key: str,\n        tp_data: list[torch.Tensor],\n        config: PretrainedConfig,\n        tp_size: int,\n        is_value_model: bool = False,\n    ) -> Union[torch.Tensor, list[torch.Tensor]]:\n        if \"linear_fc1.weight\" in key:\n            # if the tensor is gate and proj\n            gate_lst = []\n            up_lst = []\n            for infer_param in tp_data:\n                gate, up = infer_param.chunk(2)\n                gate_lst.append(gate)\n                up_lst.append(up)\n            gate = torch.cat(gate_lst, dim=0)\n            up = torch.cat(up_lst, dim=0)\n            return [gate, up]\n        elif \"self_attention.linear_qkv.\" in key and \"layer_norm\" not in key:\n            # if the tensor is qkv, for each param on tp, split into q, k, v\n            # concat q, k, v separately.\n            q_lst = []\n            k_lst = []\n            v_lst = []\n            assert config.num_attention_heads % config.num_key_value_heads == 0\n            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads\n            assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0\n            kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)\n            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]\n\n            for infer_param in tp_data:\n                num_query_groups_per_partition = config.num_key_value_heads // tp_size\n                for chunk in infer_param.chunk(num_query_groups_per_partition):\n                    split_size = [\n                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,\n                        kv_size_per_tp // num_query_groups_per_partition,\n                        kv_size_per_tp // num_query_groups_per_partition,\n                    ]\n                    q, k, v = chunk.split(split_size)\n                    q_lst.append(q)\n                    k_lst.append(k)\n                    v_lst.append(v)\n\n            q = torch.cat(q_lst, dim=0)\n            
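# each TP rank stores qkv interleaved per query group, so the per-group chunk/split above restores contiguous q, k, v\n            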
k = torch.cat(k_lst, dim=0)\n            v = torch.cat(v_lst, dim=0)\n            return [q, k, v]\n        elif \"layer_norm\" in key or \"layernorm\" in key or \"router\" in key or (\"output_layer\" in key and is_value_model):\n            return tp_data[0]\n        else:\n            dim = 0\n            if \"linear_fc2.weight\" in key or \"self_attention.linear_proj\" in key:\n                dim = 1\n            return torch.cat(tp_data, dim=dim)\n\n    def _load_state_dicts(\n        self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int\n    ) -> list[list[dict]]:\n        model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)]\n\n        def _process_one_megatron_shard(sharded_dir: str):\n            model_file_path = Path(model_ckpt_path) / sharded_dir / \"model.pt\"\n            state_dict = torch.load(model_file_path, map_location=\"cpu\", weights_only=False)\n            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)\n            model_state_dict_lst[pp_rank][tp_rank] = state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs]\n            for future in tqdm(futures, desc=f\"Loading {len(sharded_dirs)} Megatron shards\", total=len(sharded_dirs)):\n                future.result()\n\n        return model_state_dict_lst\n\n    def _check_megatron_state_key(self, key: str) -> bool:\n        \"\"\"\n        Checks if the key is a valid Megatron state key.\n\n        Now the model merger only supports keys that start with \"decoder/embedding/output_layer\" in TransformerLayer.\n        Shall not use key starts with \"model.\"\n        \"\"\"\n        if key.startswith(\"model.\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer.\"\n            )\n\n        skip_checking_keys = [\"embedding.word_embeddings\", \"output_layer\"]\n        for skip_key in skip_checking_keys:\n            if skip_key in key:\n                print(f\"skip checking key {key}\")\n                return\n\n        # Exclude extra state keys\n        if not key.startswith(\"decoder\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. 
Expected keys to start with 'decoder' in TransformerLayer.\"\n            )\n\n    def _merge_state_dicts(\n        self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int\n    ) -> dict[str, torch.Tensor]:\n        state_dict = {}\n        vpp_size = len(model_state_dict_lst[0][0])\n        layers_cum = 0\n\n        for vpp_rank in range(vpp_size):\n            for pp_rank in range(pp_size):\n                layers_handled = 0\n                keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()\n                for key in keys:\n                    if \"extra_state\" in key:\n                        continue\n                    if self.config.tie_word_embedding and (\"output_layer\" in key):\n                        print(\"skip lm_head and reward_head loading because of tie_word_embeddings\")\n                        continue\n\n                    self._check_megatron_state_key(key)\n                    hf_name = self._replace_name(key, self.params_mapping)\n                    assert hf_name is not None, f\"Failed to convert layer name [{key}] from megatron to huggingface.\"\n                    if \"model.layers.\" in hf_name:\n                        local_layer_no = int(hf_name.split(\".\")[2])\n                        layers_handled = max(local_layer_no, layers_handled)\n                        global_layer_no = local_layer_no + layers_cum\n                        new_key_list = hf_name.split(\".\")\n                        new_key_list[2] = str(global_layer_no)\n                        hf_name = \".\".join(new_key_list)\n                    else:\n                        warnings.warn(f\"hf_name {hf_name} will not be fixed with layer number\", stacklevel=2)\n\n                    tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]\n                    merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model)\n\n                    if not isinstance(merged, list):\n                        state_dict[hf_name] = merged\n                    elif len(merged) == 3:\n                        # split qkv\n                        for n, d in zip([\"q\", \"k\", \"v\"], merged):\n                            state_dict[hf_name.replace(\"qkv\", n)] = d\n                    elif len(merged) == 2:\n                        # split gate up\n                        state_dict[hf_name.replace(\"gate_up\", \"gate\")] = merged[0]\n                        state_dict[hf_name.replace(\"gate_up\", \"up\")] = merged[1]\n                    print(\n                        f\"converted {key} to {hf_name} with shape {merged.shape if isinstance(merged, torch.Tensor) else [t.shape for t in merged]}\"\n                    )\n\n                layers_cum += layers_handled + 1  # zero based\n\n        return state_dict\n\n    def merge_and_save(self):\n        from verl.utils.megatron_utils import get_model_checkpoint_path\n\n        model_ckpt_path = get_model_checkpoint_path(self.config.local_dir)\n        sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path)\n        print(f\"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}\")\n\n        model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size)\n        merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size)\n        del model_state_dict_lst\n\n        if self.config.operation == \"test\":\n            if not 
self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._test_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Compares the merged Megatron state_dict against a reference safetensors model.\n        Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name.\n        \"\"\"\n        ref_state_dict = load_file(Path(self.config.test_hf_dir) / \"model.safetensors\")\n\n        for name, loaded_weight in state_dict.items():\n            # name = self._replace_name(original_name, self.params_mapping)\n            if not name or name.endswith(\".bias\") and name not in ref_state_dict:\n                continue\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if self.config.tie_word_embedding and \"lm_head.weight\" in name:\n                continue\n            if name not in ref_state_dict:\n                raise RuntimeError(f\"key: {name} not exist in state_dict\")\n            param = ref_state_dict[name]\n            assert loaded_weight.dtype == param.dtype\n            torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2)\n\n    def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str:\n        for m_name, v_name in name_mapping.items():\n            if m_name not in megatron_name:\n                continue\n\n            megatron_name = megatron_name.replace(\"decoder\", \"model\")\n            param_name = megatron_name.replace(m_name, v_name)\n            return param_name\n\n        return None  # Return None if no mapping found\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"verl model merger\")\n    subparsers = parser.add_subparsers(dest=\"operation\", required=True, help=\"Specify 'merge' or 'test' operation.\")\n\n    base_op_parser = argparse.ArgumentParser(add_help=False)\n    base_op_parser.add_argument(\n        \"--backend\", type=str, required=True, choices=[\"fsdp\", \"megatron\"], help=\"The backend of the model\"\n    )\n    base_op_parser.add_argument(\"--local_dir\", type=str, required=True, help=\"Path to the saved model checkpoints\")\n    base_op_parser.add_argument(\n        \"--hf_model_path\",\n        type=str,\n        default=None,\n        help=\"(Deprecated) Path to the original Hugging Face model for config.\",\n    )\n    base_op_parser.add_argument(\n        \"--tie-word-embedding\",\n        action=\"store_true\",\n        help=\"Whether to tie word embedding weights (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\n        \"--is-value-model\",\n        action=\"store_true\",\n        help=\"Whether the model is a value model (currently only Megatron supported)\",\n    )\n\n    merge_parser = subparsers.add_parser(\"merge\", parents=[base_op_parser], help=\"Merge model checkpoints and save.\")\n    merge_parser.add_argument(\n        \"--target_dir\", default=\"tmp\", type=str, help=\"Directory to save the merged huggingface model\"\n    )\n    merge_parser.add_argument(\n        \"--hf_upload_path\", default=None, type=str, help=\"Hugging Face 
repository ID to upload the model\"\n    )\n    merge_parser.add_argument(\n        \"--private\", action=\"store_true\", help=\"Whether to upload the model to a private Hugging Face repository\"\n    )\n\n    test_parser = subparsers.add_parser(\n        \"test\", parents=[base_op_parser], help=\"Test merged model against a reference Hugging Face model\"\n    )\n    test_parser.add_argument(\n        \"--test_hf_dir\", type=str, required=True, help=\"Path to the reference Hugging Face model directory for testing\"\n    )\n\n    args = parser.parse_args()\n\n    common_config_args = {\n        \"operation\": args.operation,\n        \"backend\": args.backend,\n        \"tie_word_embedding\": args.tie_word_embedding,\n        \"is_value_model\": args.is_value_model,\n        \"local_dir\": args.local_dir,\n        \"hf_model_path\": args.hf_model_path,\n        \"hf_model_config_path\": args.local_dir,\n    }\n\n    if args.operation == \"merge\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            target_dir=args.target_dir,\n            hf_upload_path=args.hf_upload_path,\n            private=args.private,\n            test_hf_dir=None,\n        )\n        os.makedirs(config.target_dir, exist_ok=True)\n    elif args.operation == \"test\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            test_hf_dir=args.test_hf_dir,\n            # the following args are not used by test operation\n            target_dir=None,\n            hf_upload_path=None,\n            private=False,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown operation: {args.operation}\")\n\n    if config.backend == \"fsdp\":\n        merger = FSDPModelMerger(config)\n    elif config.backend == \"megatron\":\n        merger = MegatronModelMerger(config)\n    else:\n        raise NotImplementedError(f\"Unknown backend: {config.backend}\")\n\n    merger.merge_and_save()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
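  {
    "path": "verl_rl/scripts/model_merger_example.py",
    "content": "# A minimal usage sketch for the model merger CLI defined in this directory.\n# The script location (scripts/model_merger.py) and the checkpoint paths below\n# are assumptions for illustration; adjust them to your own training run.\nimport subprocess\nimport sys\n\n# Merge sharded actor weights into a single Hugging Face checkpoint.\nsubprocess.run(\n    [\n        sys.executable,\n        \"scripts/model_merger.py\",\n        \"merge\",\n        \"--backend\",\n        \"megatron\",\n        \"--local_dir\",\n        \"checkpoints/global_step_100/actor\",\n        \"--target_dir\",\n        \"/tmp/merged_hf_model\",\n    ],\n    check=True,\n)\n\n# Optionally verify the merged weights against a reference HF checkpoint.\nsubprocess.run(\n    [\n        sys.executable,\n        \"scripts/model_merger.py\",\n        \"test\",\n        \"--backend\",\n        \"megatron\",\n        \"--local_dir\",\n        \"checkpoints/global_step_100/actor\",\n        \"--test_hf_dir\",\n        \"/path/to/reference_hf_model\",\n    ],\n    check=True,\n)\n"
  },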
  {
    "path": "verl_rl/scripts/print_cfg.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\ntry:\n    import hydra\nexcept ImportError as e:\n    raise ImportError(\"Please install hydra-core via 'pip install hydra-core' and retry.\") from e\n\n\n@hydra.main(config_path=\"../verl/trainer/config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for PPO training with Hydra configuration management.\n\n    Args:\n        config_dict: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    print(config)\n    from verl.utils.config import omega_conf_to_dataclass\n\n    profiler_config = omega_conf_to_dataclass(config.critic.profiler)\n    print(profiler_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
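  {
    "path": "verl_rl/scripts/print_cfg_example.py",
    "content": "# A small sketch showing how print_cfg.py is typically driven from the command\n# line with Hydra overrides. The override key (trainer.nnodes) exists in the\n# ppo_trainer config; the working directory and script path are assumptions.\nimport subprocess\nimport sys\n\n# Print the fully resolved ppo_trainer config with one field overridden.\nsubprocess.run(\n    [sys.executable, \"scripts/print_cfg.py\", \"trainer.nnodes=2\"],\n    check=True,\n)\n"
  },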
  {
    "path": "verl_rl/scripts/rollout_viewer.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport re\nimport traceback\nfrom pathlib import Path\nfrom typing import Annotated, Optional\n\nimport aiofiles\n\ntry:\n    import ujson as json\nexcept ImportError:\n    import json\nimport typer\nfrom rich.highlighter import ReprHighlighter\nfrom rich.markdown import Markdown\nfrom rich.table import Table\nfrom rich.text import Text\nfrom textual import on\nfrom textual.app import App, ComposeResult\nfrom textual.containers import Horizontal, Vertical, VerticalScroll\nfrom textual.widgets import Input, ProgressBar, Select, SelectionList, Static\n\nINDEX_KEY = \"__IDX\"\nFILE_SUFFIX = \".jsonl\"\n\n\ndef check_textual_version():\n    # check if textual version is equal to 0.52.1\n    import textual\n    from packaging.version import Version\n\n    if Version(textual.__version__) != Version(\"0.52.1\"):\n        raise ImportError(f\"Textual version {textual.__version__} is not supported, please pip install textual==0.52.1\")\n\n\ncheck_textual_version()\n\n\nasync def load_path(p: Path, data: dict, mask_strs: str, idx: int, pbar):\n    samples = []\n    async with aiofiles.open(p, encoding=\"utf-8\") as f:\n        async for line in f:\n            d = json.loads(line)\n            for k in d:\n                if isinstance(d[k], str):\n                    if mask_strs:\n                        d[k] = re.sub(rf\"{mask_strs}\", \"*\", d[k])\n                else:\n                    d[k] = json.dumps(d[k], ensure_ascii=False, indent=4)\n\n            d[INDEX_KEY] = len(samples)\n            samples.append(d)\n        data[idx] = {\"samples\": samples}\n\n    print(f\"path {p} loaded\")\n    pbar.advance(1)\n\n\nasync def load_dir(path: Path, data: dict[int, dict], pbar, mask_strs: str = \"\"):\n    paths = list(path.glob(f\"*{FILE_SUFFIX}\"))\n    paths = sorted(paths, key=lambda x: int(x.stem))\n\n    tasks = [load_path(p, data, mask_strs, i, pbar) for i, p in enumerate(paths)]\n\n    await asyncio.gather(*tasks)\n\n\nclass Highlighter(ReprHighlighter):\n    highlights = ReprHighlighter.highlights + [\n        r\"(?P<tag_name>[][\\<\\>{}()\\|（）【】\\[\\]=`])\",\n        r\"\\<\\|(?P<tag_name>[\\w\\W]*?)\\|\\>\",\n    ]\n\n\ndef center_word_with_equals_exactly(word: str, total_length: int, char: str = \"=\") -> str:\n    if len(word) > total_length:\n        return word\n\n    padding = total_length - len(word)\n    left_pad = (padding) // 2\n    right_pad = (padding + 1) // 2\n    return char * left_pad + \" \" + word + \" \" + char * right_pad\n\n\ndef highlight_keyword(content: str, keyword: Optional[str]):\n    if not keyword:\n        return Text(content)\n    text = Text()\n    parts = content.split(keyword)\n    for i, part in enumerate(parts):\n        text.append(part, style=None)\n        if i < len(parts) - 1:\n            # text.append(keyword, style=Style(color=\"#d154d1\", 
bgcolor=\"yellow\", bold=True))\n            text.append(keyword, style=\"on #8f51b5\")\n    return text\n\n\nhelp_doc = \"\"\"\n⌨️   keybinds：\n\n- `f/esc`: find/cancel\n- `tab/←/→`: change focus\n- `j/k`: page down/up\n- `g/G`: scroll home/end\n- `n/N`: next sample/step\n- `p/P`: previous sample/step\n- `s`: switch display mode\n  - plain text\n  - rich table\n\n\"\"\"\n\n\nclass JsonLineViewer(App):\n    BINDINGS = [\n        (\"left\", \"focus_previous\", \"Focus Previous\"),\n        (\"right\", \"focus_next\", \"Focus Next\"),\n        (\"s\", \"swith_render\", \"switch render\"),\n        # control\n        (\"n\", \"next_sample\", \"Next Sample\"),\n        (\"N\", \"next_step\", \"Next Step\"),\n        (\"p\", \"previous_sample\", \"Previous Sample\"),\n        (\"P\", \"previous_step\", \"Previous Step\"),\n        # search\n        (\"f\", \"toggle_search\", \"find\"),\n        (\"enter\", \"next_search\", \"find next\"),\n        (\"escape\", \"cancel_search\", \"cancel find\"),\n        # scroll\n        (\"j\", \"page_down\", \"page down\"),\n        (\"k\", \"page_up\", \"page up\"),\n        (\"g\", \"page_home\", \"page home\"),\n        (\"G\", \"page_end\", \"page end\"),\n    ]\n\n    CSS = \"\"\"\n\n    Select:focus > SelectCurrent {\n        border: tall #8f51b5;\n    }\n    Select.-expanded > SelectCurrent {\n        border: tall #8f51b5;\n    }\n    #select-container {\n        width: 15%;\n        height: 100%;\n        align: center top;\n    }\n    #search-container {\n        height: 10%;\n        align: center top;\n    }\n    #search-box {\n        width: 50%;\n    }\n    #reqid-box {\n        width: 50%;\n    }\n    \"\"\"\n\n    def __init__(self, step_num: int, data: dict[int, dict], pbar):\n        super().__init__()\n        self.step_num = step_num\n\n        self.data = data\n        self.render_table = False\n        self.selected_step_index = 0\n        self.selected_sample_index = 0\n        self.pbar = pbar\n\n        self.matches = []\n        self.current_match_index = 0\n\n        self.highlighter = Highlighter()\n\n        first_samples = data[list(data.keys())[0]][\"samples\"]\n        # Prepare the initial field filter list (all keys from the first sample)\n        self.filter_fields = [(f, f, True) for f in first_samples[0].keys()]\n\n        # Internal set used for fast membership checks when we add new fields on the fly.\n        # We keep it here so that when new columns appear in later steps (e.g. 
`request_id`),\n        # they can be added to the UI automatically without restarting the viewer.\n        self._field_set: set[str] = set(first_samples[0].keys())\n        self.sample_num = len(first_samples)\n\n    def compose(self) -> ComposeResult:\n        with Horizontal(id=\"search-container\"):\n            yield Input(placeholder=\"find something...\", id=\"search-box\")\n            yield Input(placeholder=\"request id...\", id=\"reqid-box\")\n            with Vertical(id=\"search-container2\"):\n                yield self.pbar\n                yield Static(\"\", id=\"search-status\")\n\n        with Horizontal():\n            with Vertical(id=\"select-container\"):\n                yield Static(\"\\n\")\n                yield Static(\n                    renderable=Markdown(\n                        help_doc,\n                    ),\n                    markup=False,\n                )\n                yield Static(\"\\n\")\n                yield Select(\n                    id=\"step-select\",\n                    value=0,\n                    prompt=\"select step\",\n                    options=[(\"step: 1\", 0)],\n                    allow_blank=False,\n                )\n                yield Select(\n                    id=\"sample-select\",\n                    value=0,\n                    prompt=\"select sample\",\n                    options=[(\"sample: 1\", 0)],\n                    allow_blank=False,\n                )\n                yield Select(\n                    id=\"sample-sort\",\n                    value=0,\n                    prompt=\"排序\",\n                    options=[\n                        (\"sort\", 0),\n                        (\"score asc\", 1),\n                        (\"score desc\", 2),\n                    ],\n                    allow_blank=False,\n                )\n\n                yield SelectionList[int]((\"Select ALL\", 1, True), id=\"fields-select-all\")\n                with VerticalScroll(id=\"scroll-view2\"):\n                    yield SelectionList[str](*self.filter_fields, id=\"fields-select\")\n            with VerticalScroll(id=\"scroll-view\"):\n                yield Static(id=\"content\", markup=False)\n\n    async def on_mount(self) -> None:\n        self.step_select = self.query_one(\"#step-select\", Select)\n        self.sample_select = self.query_one(\"#sample-select\", Select)\n        self.sample_sort = self.query_one(\"#sample-sort\", Select)\n        self.content_display = self.query_one(\"#content\", Static)\n        self.search_box = self.query_one(\"#search-box\", Input)\n        self.reqid_box = self.query_one(\"#reqid-box\", Input)\n        self.scroll_view = self.query_one(\"#scroll-view\", VerticalScroll)\n        self.search_status = self.query_one(\"#search-status\", Static)\n        self.fields_select = self.query_one(\"#fields-select\", SelectionList)\n        self.fields_select.border_title = \"field filter\"\n\n        if self.data:\n            self.step_select.set_options([(f\"step: {i + 1}\", i) for i in range(self.step_num)])\n            self.sample_select.set_options([(f\"sample: {i + 1}\", i) for i in range(self.sample_num)])\n            self.step_select.focus()\n            await self.update_content()\n\n    def update_result_options(self, offset: int = 0, sort_desc: Optional[bool] = None):\n        options = []\n        if isinstance(self.selected_step_index, int) and self.selected_step_index < len(self.data):\n            if self.sample_num is None or sort_desc is not None:\n                
samples = self.data[self.selected_step_index].get(\"samples\", [])\n                if not samples:\n                    self.selected_sample_index = offset\n                    return\n                if sort_desc is not None:\n                    samples = sorted(\n                        samples,\n                        key=lambda x: x.get(\"score\", x.get(\"score_1\", 0)),\n                        reverse=sort_desc,\n                    )\n\n                options = [(f\"sample: {r[INDEX_KEY] + 1}\", r[INDEX_KEY]) for r in samples]\n                self.sample_select.set_options(options)\n                self.sample_num = len(samples)\n\n            if sort_desc is not None and options:\n                self.selected_sample_index = options[0][1]\n            else:\n                self.selected_sample_index = offset\n\n    async def update_content(self, search_keyword: Optional[str] = None):\n        content = \"\"\n        try:\n            samples = self.data[self.selected_step_index].get(\"samples\", [])\n            content_dict_full = samples[self.selected_sample_index]\n\n            # Dynamically track any NEW keys that appear and add them to the field filter.\n            self._update_fields_select(content_dict_full.keys())\n\n            # Apply field selection filter (only show selected fields)\n            content_dict = {k: v for k, v in content_dict_full.items() if k in self.fields_select.selected}\n            if self.render_table:\n                content = Table(\"key\", \"value\", show_lines=True)\n                for k in content_dict:\n                    v = content_dict[k]\n                    v = f\"{v}\"\n                    content.add_row(\n                        k,\n                        self.highlighter(highlight_keyword(v, search_keyword)),\n                    )\n            else:\n                text = Text()\n                for k in content_dict:\n                    v = content_dict[k]\n                    s = center_word_with_equals_exactly(k, 64) + f\"\\n{v}\\n\"\n                    text.append(highlight_keyword(s, search_keyword))\n                content = self.highlighter(text)\n        except KeyError:\n            content = f\"Loading data asynchronously, progress: {len(self.data)}/{self.step_num} step\"\n\n        except Exception:\n            content = self.highlighter(traceback.format_exc())\n\n        self.content_display.update(content)\n\n    # ---------------------------------------------------------------------\n    # Request-ID jump logic\n    # ---------------------------------------------------------------------\n\n    @on(Input.Submitted, \"#reqid-box\")\n    async def on_reqid_submitted(self, event: Input.Submitted) -> None:\n        \"\"\"Jump to the sample that has a matching `request_id`.\"\"\"\n\n        req_id_raw = event.value.strip()\n        # Remove hyphens so search is tolerant to different id formats\n        req_id = req_id_raw.replace(\"-\", \"\")\n        if not req_id:\n            return\n\n        found = False\n        for step_idx, step_data in self.data.items():\n            for sample in step_data.get(\"samples\", []):\n                sample_id = str(sample.get(\"request_id\", \"\"))\n                if sample_id.replace(\"-\", \"\") == req_id:\n                    # Update selected indices\n                    self.selected_step_index = step_idx\n                    self.step_select.value = step_idx\n\n                    # Ensure sample list is updated and select sample\n                    
self.update_result_options(offset=sample[INDEX_KEY])\n                    self.selected_sample_index = sample[INDEX_KEY]\n                    self.sample_select.value = sample[INDEX_KEY]\n\n                    await self._clear_search()\n                    await self.update_content()\n\n                    found = True\n                    break\n            if found:\n                break\n\n        if not found:\n            self.search_status.update(Text(f\"request_id '{req_id_raw}' not found\", style=\"bold red\"))\n        else:\n            # Keep the typed id in the input box so users see what was searched.\n            pass\n\n    # ---------------------------------------------------------------------\n    # Helper: add new fields to SelectionList on-the-fly\n    # ---------------------------------------------------------------------\n\n    def _update_fields_select(self, keys):\n        \"\"\"Add any unseen *keys* to the field-selection widget so they can be toggled.\n\n        The viewer is often launched with only the first step loaded. Later steps may\n        introduce new columns (e.g. `request_id`). This helper ensures those fields\n        become visible without requiring a restart.\n        \"\"\"\n        # Ensure we have the widget (only after on_mount)\n        if not hasattr(self, \"fields_select\"):\n            return\n\n        for k in keys:\n            if k not in self._field_set:\n                self._field_set.add(k)\n                try:\n                    # By default, new fields are selected so they appear immediately.\n                    self.fields_select.add_option(k, k, selected=True)\n                except Exception:\n                    # Fallback for older textual versions where signature is different.\n                    self.fields_select.add_option((k, k, True))\n\n    @on(Select.Changed, \"#step-select\")\n    async def step_changed(self, event):\n        self.selected_step_index = event.value\n        self.update_result_options()\n        await self.update_content()\n\n    @on(Select.Changed, \"#sample-select\")\n    async def sample_changed(self, event):\n        self.selected_sample_index = event.value\n        await self._clear_search()\n        await self.update_content()\n\n    @on(Select.Changed, \"#sample-sort\")\n    async def sort_changed(self, event):\n        v = event.value\n        self.update_result_options(sort_desc=None if v == 0 else False if v == 1 else True)\n        await self.update_content()\n\n    @on(SelectionList.SelectedChanged, \"#fields-select\")\n    async def fields_changed(self, event):\n        await self.update_content()\n\n    @on(SelectionList.SelectedChanged, \"#fields-select-all\")\n    async def fields_all_changed(self, event):\n        s = self.query_one(\"#fields-select-all\", SelectionList)\n        if s.selected:\n            self.fields_select.select_all()\n        else:\n            self.fields_select.deselect_all()\n\n    def action_focus_previous(self):\n        self.screen.focus_previous()\n\n    def action_focus_next(self):\n        self.screen.focus_next()\n\n    async def action_next_step(self) -> None:\n        self.selected_step_index += 1\n        if self.selected_step_index >= self.step_num:\n            self.selected_step_index = 0\n        self.step_select.value = self.selected_step_index\n        self.update_result_options()\n        await self.update_content()\n\n    async def action_next_sample(self) -> None:\n        self.selected_sample_index += 1\n        if not self.sample_num 
or self.selected_sample_index >= self.sample_num:\n            self.selected_sample_index = 0\n        self.sample_select.value = self.selected_sample_index\n        await self._clear_search()\n        await self.update_content()\n\n    async def action_previous_step(self) -> None:\n        self.selected_step_index -= 1\n        if self.selected_step_index < 0:\n            self.selected_step_index = self.step_num - 1\n        self.step_select.value = self.selected_step_index\n        self.update_result_options()\n        await self.update_content()\n\n    async def action_previous_sample(self) -> None:\n        self.selected_sample_index -= 1\n        if self.selected_sample_index < 0:\n            self.selected_sample_index = self.sample_num - 1\n        self.sample_select.value = self.selected_sample_index\n        await self._clear_search()\n        await self.update_content()\n\n    async def action_swith_render(self):\n        self.render_table = not self.render_table\n        await self.update_content()\n\n    def action_toggle_search(self) -> None:\n        self.search_box.focus()\n\n    async def action_cancel_search(self) -> None:\n        self.search_box.value = \"\"\n        await self._clear_search()\n        await self.update_content()\n\n    async def _clear_search(self):\n        self.matches = []\n        self.search_status.update(\"\")\n        self.current_match_index = 0\n\n    @on(Input.Submitted, \"#search-box\")\n    async def on_search_submitted(self, event: Input.Submitted) -> None:\n        self.matches = []\n        self.current_match_index = 0\n        if event.value:\n            await self.update_content(event.value)\n            renderable = self.content_display.render()\n            if isinstance(renderable, Table):\n                return\n\n            assert isinstance(renderable, Text)\n            console = self.content_display._console\n            lines = renderable.wrap(console, self.scroll_view.container_size.width)\n            line_idx_recorded = set()\n            for line_idx, line in enumerate(lines):\n                if line_idx in line_idx_recorded:\n                    continue\n                if event.value in line:\n                    self.matches.append(\n                        {\n                            \"line\": line_idx,\n                            \"word\": event.value,\n                        }\n                    )\n                    line_idx_recorded.add(line_idx)\n            self.scroll_view.focus()\n            await self.action_next_search()\n\n    async def action_next_search(self) -> None:\n        if not self.matches or self.current_match_index >= len(self.matches):\n            return\n\n        target_line = self.matches[self.current_match_index][\"line\"]\n        self.scroll_view.scroll_to(x=0, y=target_line * 1, animate=False)\n        self.current_match_index = (self.current_match_index + 1) % len(self.matches)\n        self.search_status.update(\n            Text(\n                f\"Find ：{self.current_match_index + 1}/{len(self.matches)}\",\n                style=\"bold on #8f51b5\",\n            )\n        )\n\n    def action_page_up(self):\n        self.scroll_view.scroll_page_up(animate=False)\n\n    def action_page_down(self):\n        self.scroll_view.scroll_page_down(animate=False)\n\n    def action_page_home(self):\n        self.scroll_view.scroll_home(animate=False)\n\n    def action_page_end(self):\n        self.scroll_view.scroll_end(animate=False)\n\n\nasync def _run(path: Path, mask_str: 
str):\n    assert path.exists(), f\"{path} not exist\"\n\n    paths = list(path.glob(f\"*{FILE_SUFFIX}\"))\n    paths = sorted(paths, key=lambda x: int(x.stem))\n\n    if not paths:\n        raise ValueError(f\"no available reward dump files under f{path}\")\n\n    print(f\"get jsonl file nums: {len(paths)}\")\n\n    pbar = ProgressBar(total=len(paths), name=\"data load progress\")\n    data = {}\n    await load_path(paths[0], data, mask_str, 0, pbar)\n    app = JsonLineViewer(step_num=len(paths), data=data, pbar=pbar)\n    await asyncio.gather(load_dir(path, data, pbar, mask_str), app.run_async())\n\n\napp = typer.Typer()\n\n\n@app.command(help=\"launch TUI APP\")\ndef run(\n    rollout_data_dir: Path,\n    mask_str: Annotated[str, typer.Option(help=\"string that will be masked to *\")] = \"<\\|image_pad\\|>|<\\|imgpad\\|>\",\n):\n    loop = asyncio.get_event_loop()\n    loop.run_until_complete(_run(rollout_data_dir, mask_str))\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
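  {
    "path": "verl_rl/scripts/rollout_viewer_example.py",
    "content": "# A sketch that fabricates a rollout dump directory in the layout rollout_viewer.py\n# expects: one <step>.jsonl file per step, named by step number, with one JSON object\n# per sample. The field names (score, request_id) mirror what the viewer sorts and\n# searches on; the dump path is an assumption for illustration.\nimport json\nfrom pathlib import Path\n\ndump_dir = Path(\"/tmp/rollout_dump\")\ndump_dir.mkdir(parents=True, exist_ok=True)\n\nfor step in range(3):\n    samples = [\n        {\n            \"prompt\": f\"question {i}\",\n            \"response\": f\"answer {i}\",\n            \"score\": i / 10,\n            \"request_id\": f\"req-{step}-{i}\",\n        }\n        for i in range(4)\n    ]\n    with open(dump_dir / f\"{step}.jsonl\", \"w\", encoding=\"utf-8\") as f:\n        for sample in samples:\n            f.write(json.dumps(sample, ensure_ascii=False) + \"\\n\")\n\n# Then launch the TUI with: python scripts/rollout_viewer.py /tmp/rollout_dump\n"
  },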
  {
    "path": "verl_rl/setup.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# setup.py is the fallback installation script when pyproject.toml does not work\nimport os\nfrom pathlib import Path\n\nfrom setuptools import find_packages, setup\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\nwith open(os.path.join(version_folder, \"verl/version/version\")) as f:\n    __version__ = f.read().strip()\n\ninstall_requires = [\n    \"accelerate\",\n    \"codetiming\",\n    \"datasets\",\n    \"dill\",\n    \"hydra-core\",\n    \"numpy<2.0.0\",\n    \"pandas\",\n    \"peft\",\n    \"pyarrow>=19.0.0\",\n    \"pybind11\",\n    \"pylatexenc\",\n    \"ray[default]>=2.41.0\",\n    \"torchdata\",\n    \"tensordict>=0.8.0,<=0.9.1,!=0.9.0\",\n    \"transformers\",\n    \"wandb\",\n    \"packaging>=20.0\",\n]\n\nTEST_REQUIRES = [\"pytest\", \"pre-commit\", \"py-spy\", \"pytest-asyncio\"]\nPRIME_REQUIRES = [\"pyext\"]\nGEO_REQUIRES = [\"mathruler\", \"torchvision\", \"qwen_vl_utils\"]\nGPU_REQUIRES = [\"liger-kernel\", \"flash-attn\"]\nMATH_REQUIRES = [\"math-verify\"]  # Add math-verify as an optional dependency\nVLLM_REQUIRES = [\"tensordict>=0.8.0,<=0.9.1,!=0.9.0\", \"vllm>=0.7.3,<=0.8.5\"]\nSGLANG_REQUIRES = [\n    \"tensordict>=0.8.0,<=0.9.1,!=0.9.0\",\n    \"sglang[srt,openai]==0.4.6.post5\",\n    \"torch-memory-saver>=0.0.5\",\n    \"torch==2.6.0\",\n]\nTRL_REQUIRES = [\"trl<=0.9.6\"]\nMCORE_REQUIRES = [\"mbridge\"]\n\nextras_require = {\n    \"test\": TEST_REQUIRES,\n    \"prime\": PRIME_REQUIRES,\n    \"geo\": GEO_REQUIRES,\n    \"gpu\": GPU_REQUIRES,\n    \"math\": MATH_REQUIRES,\n    \"vllm\": VLLM_REQUIRES,\n    \"sglang\": SGLANG_REQUIRES,\n    \"trl\": TRL_REQUIRES,\n    \"mcore\": MCORE_REQUIRES,\n}\n\n\nthis_directory = Path(__file__).parent\nlong_description = (this_directory / \"README.md\").read_text()\n\nsetup(\n    name=\"verl\",\n    version=__version__,\n    package_dir={\"\": \".\"},\n    packages=find_packages(where=\".\"),\n    url=\"https://github.com/volcengine/verl\",\n    license=\"Apache 2.0\",\n    author=\"Bytedance - Seed - MLSys\",\n    author_email=\"zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk\",\n    description=\"verl: Volcano Engine Reinforcement Learning for LLM\",\n    install_requires=install_requires,\n    extras_require=extras_require,\n    package_data={\n        \"\": [\"version/*\"],\n        \"verl\": [\"trainer/config/*.yaml\"],\n    },\n    include_package_data=True,\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n)\n"
  },
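  {
    "path": "verl_rl/setup_example.py",
    "content": "# A sketch of installing verl with one of the optional dependency groups declared\n# in setup.py; run it from the verl_rl/ directory. The chosen extra (vllm) is\n# illustrative; pick the group that matches your rollout backend (e.g. sglang).\nimport subprocess\nimport sys\n\n# Editable install plus the vLLM extra (pins tensordict and vllm versions).\nsubprocess.run(\n    [sys.executable, \"-m\", \"pip\", \"install\", \"-e\", \".[vllm]\"],\n    check=True,\n)\n"
  },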
  {
    "path": "verl_rl/tests/README.md",
    "content": "# Tests layout\n\nEach folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:\n- `tests/trainer` for testing functionality related to `verl/trainer`\n- `tests/models` for testing functionality related to `verl/models`\n- ...\n\nThere are a few folders with `special_` prefix, created for special purposes:\n- `special_distributed`: unit tests that must run with multiple GPUs\n- `special_e2e`: end-to-end tests with training/generation scripts\n- `special_npu`: tests for NPUs\n- `special_sanity`: a suite of quick sanity tests\n- `special_standalone`: a set of test that are designed to run in dedicated environments\n\nAccelerators for tests \n- By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.\n- For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment.\n\n# Workflow layout\n\nAll CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:\n1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml`\n2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`\n3. End-to-end tests: `e2e_*.yml`\n4. Unit tests\n  - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`\n  - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix.\n  - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when\n    - new workflow yaml is added to `.github/workflows`\n    - new tests are added to workflow mentioned in 2."
  },
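  {
    "path": "verl_rl/tests/special_sanity/test_example_on_cpu.py",
    "content": "# A minimal sketch of the naming convention described in the tests README: a file\n# matching tests/**/test_*_on_cpu.py is collected by cpu_unit_tests.yml and runs on\n# CPU-only runners. The test body is a placeholder for illustration.\nimport pytest\n\n\ndef test_addition_on_cpu():\n    # Pure-Python logic needs no accelerator, so it belongs in an on_cpu test.\n    assert 1 + 1 == 2\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },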
  {
    "path": "verl_rl/tests/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/tests/experimental/agent_loop/agent_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\nfrom omegaconf import DictConfig\n\nfrom verl.experimental.agent_loop import AgentLoopManager\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker\n\n\ndef init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:\n    # =========================== 1. Create hybrid ActorRollout workers ===========================\n    actor_rollout_cls = (\n        AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == \"async\" else ActorRolloutRefWorker\n    )\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(actor_rollout_cls),\n    }\n    global_pool_id = \"global_pool\"\n    resource_pool_spec = {\n        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n    }\n    mapping = {\n        Role.ActorRollout: global_pool_id,\n    }\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n    resource_pool_manager.create_resource_pool()\n    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}\n\n    # create actor and rollout\n    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)\n    actor_rollout_cls = RayClassWithInitArgs(\n        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role=\"actor_rollout\"\n    )\n    resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n\n    all_wg = {}\n    for resource_pool, class_dict in resource_pool_to_cls.items():\n        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)\n        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n        all_wg.update(spawn_wg)\n    actor_rollout_wg = all_wg[\"actor_rollout\"]\n    actor_rollout_wg.init_model()\n\n    if config.actor_rollout_ref.rollout.mode == \"sync\":\n        return actor_rollout_wg\n\n    # =========================== 2. Create AgentLoopManager ===========================\n    agent_loop_manager = AgentLoopManager(\n        config=config,\n        worker_group=actor_rollout_wg,\n    )\n\n    return agent_loop_manager\n"
  },
  {
    "path": "verl_rl/tests/experimental/agent_loop/test_basic_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\nfrom typing import Any\n\nimport numpy as np\nimport pytest\nimport ray\nfrom omegaconf import DictConfig\nfrom transformers.utils import get_json_schema\n\nfrom tests.experimental.agent_loop.agent_utils import init_agent_loop_manager\nfrom verl.experimental.agent_loop.agent_loop import get_trajectory_info\nfrom verl.protocol import DataProto\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    model_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.name = os.getenv(\"ROLLOUT_NAME\", \"vllm\")\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n    config.actor_rollout_ref.rollout.agent.num_workers = 2\n\n    # test sleep/wake_up with fsdp offload\n    config.actor_rollout_ref.actor.fsdp_config.param_offload = True\n    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True\n\n    return config\n\n\ndef test_single_turn(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    raw_prompts = [\n        [\n            {\n                \"role\": \"user\",\n                \"content\": \"Let's play a role playing game. Your name is Alice, your favorite color is blue.\",\n            }\n        ],\n        [{\"role\": \"user\", \"content\": \"Let's play a role playing game. 
Your name is Bob, your favorite color is red.\"}],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array(raw_prompts),\n            \"agent_name\": np.array([\"single_turn_agent\"] * len(raw_prompts)),\n        },\n    )\n    n = init_config.actor_rollout_ref.rollout.n\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # check result\n    seq_len = result.batch[\"prompts\"].size(1) + result.batch[\"responses\"].size(1)\n    assert result.batch[\"input_ids\"].size(1) == seq_len\n    assert result.batch[\"attention_mask\"].size(1) == seq_len\n    assert result.batch[\"position_ids\"].size(1) == seq_len\n\n    # check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    assert np.all(num_turns == 2)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\nclass WeatherTool(BaseTool):\n    def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n        \"\"\"Get current temperature at a location.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, and the unit in a dict\n        \"\"\"\n        print(f\"[DEBUG] get_current_temperature: {location}, {unit}\")\n        return {\n            \"temperature\": 26.1,\n            \"location\": location,\n            \"unit\": unit,\n        }\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_current_temperature)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_current_temperature(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\nclass WeatherToolWithData(BaseTool):\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_temperature_date)\n        return OpenAIFunctionToolSchema(**schema)\n\n    def get_temperature_date(self, location: str, date: str, unit: str = \"celsius\"):\n        \"\"\"Get temperature at a location and date.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            date: The date to get the temperature for, in the format \"Year-Month-Day\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, the date and the unit in a dict\n        \"\"\"\n        print(f\"[DEBUG] get_temperature_date: {location}, {date}, {unit}\")\n        return {\n            \"temperature\": 25.9,\n            \"location\": location,\n            \"date\": date,\n            \"unit\": unit,\n        }\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_temperature_date(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\ndef test_tool_agent(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n\n    n = 2\n    init_config.actor_rollout_ref.rollout.n = n\n    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2\n    agent_loop_manager = init_agent_loop_manager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in New York now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? 
How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n            \"agent_name\": np.array([\"tool_agent\"] * len(raw_prompts)),\n        },\n    )\n    batch = batch.repeat(n)\n    result = agent_loop_manager.generate_sequences(prompts=batch)\n    assert len(result) == len(raw_prompts) * n\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    for i in range(len(num_turns)):\n        if i // n == 0:\n            # [user, assistant]\n            assert num_turns[i] == 2\n        else:\n            # [user, assistant, tool, assistant]\n            assert num_turns[i] == 4\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    attention_mask = result.batch[\"attention_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n    response_length = response_mask.size(1)\n\n    for i in range(len(responses)):\n        # response with tool response\n        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]\n        response_with_obs = tokenizer.decode(valid_tokens)\n\n        # response without tool response\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_without_obs = tokenizer.decode(valid_tokens)\n\n        assert \"<tool_response>\" not in response_without_obs, (\n            f\"found <tool_response> in response: {response_without_obs}\"\n        )\n        assert \"</tool_response>\" not in response_without_obs, (\n            f\"found </tool_response> in response: {response_without_obs}\"\n        )\n        print(\"=========================\")\n        print(response_with_obs)\n        print(\"---\")\n        print(response_without_obs)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\n@pytest.mark.asyncio\nasync def test_get_trajectory_info():\n    \"\"\"Tests the get_trajectory_info method.\"\"\"\n    # Initialize the class to set up class-level attributes\n    step = 10\n    index = [1, 1, 3, 3]\n    expected_info = [\n        {\"step\": step, \"sample_index\": 1, \"rollout_n\": 0, \"validate\": False},\n        {\"step\": step, \"sample_index\": 1, \"rollout_n\": 1, \"validate\": False},\n        {\"step\": step, \"sample_index\": 3, \"rollout_n\": 0, \"validate\": False},\n        {\"step\": step, \"sample_index\": 3, \"rollout_n\": 1, \"validate\": False},\n    ]\n\n    trajectory_info = await get_trajectory_info(step, index, validate=False)\n\n    assert trajectory_info == expected_info\n"
  },
  {
    "path": "verl_rl/tests/interactions/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/tests/interactions/test_gsm8k_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom verl.interactions.gsm8k_interaction import Gsm8kInteraction\n\n\nclass TestGsm8kInteraction:\n    \"\"\"Test cases for Gsm8kInteraction class.\"\"\"\n\n    def setup_method(self):\n        \"\"\"Set up test environment before each test method.\"\"\"\n        self.config = {\"name\": \"gsm8k\"}\n        self.interaction = Gsm8kInteraction(self.config)\n\n    def test_init(self):\n        \"\"\"Test Gsm8kInteraction initialization.\"\"\"\n        assert self.interaction._instance_dict == {}\n        assert self.interaction.config == self.config\n        assert self.interaction.name == \"gsm8k\"\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_with_instance_id(self):\n        \"\"\"Test start_interaction with provided instance_id.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert result_id == instance_id\n        assert instance_id in self.interaction._instance_dict\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"\"\n        assert self.interaction._instance_dict[instance_id][\"ground_truth\"] == ground_truth\n        assert self.interaction._instance_dict[instance_id][\"reward\"] == 0.0\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_without_instance_id(self):\n        \"\"\"Test start_interaction without provided instance_id (auto-generated).\"\"\"\n        ground_truth = \"42\"\n\n        result_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        assert result_id is not None\n        assert len(result_id) == 36  # UUID4 length\n        assert result_id in self.interaction._instance_dict\n        assert self.interaction._instance_dict[result_id][\"ground_truth\"] == ground_truth\n\n    @pytest.mark.asyncio\n    async def test_start_interaction_without_ground_truth(self):\n        \"\"\"Test start_interaction without ground_truth parameter.\"\"\"\n        instance_id = \"test_instance\"\n\n        result_id = await self.interaction.start_interaction(instance_id=instance_id)\n\n        assert result_id == instance_id\n        assert self.interaction._instance_dict[instance_id][\"ground_truth\"] is None\n\n    @pytest.mark.asyncio\n    async def test_generate_response_correct_answer_with_prefix(self):\n        \"\"\"Test generate_response with correct answer already having #### prefix.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"user\", \"content\": \"#### 42\"}]\n\n        
with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert reward == 1.0\n        assert metadata == {}\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_correct_answer_without_prefix(self):\n        \"\"\"Test generate_response with correct answer missing #### prefix.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"user\", \"content\": \"42\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert reward == 1.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_incorrect_answer(self):\n        \"\"\"Test generate_response with incorrect answer.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"user\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert response == \"Your response is incorrect! 
You need to reflect on your answer and try again.\"\n        assert reward == 0.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 24\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_multiple_messages(self):\n        \"\"\"Test generate_response with multiple messages (should use last user message).\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [\n            {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n            {\"role\": \"assistant\", \"content\": \"Let me think about this...\"},\n            {\"role\": \"user\", \"content\": \"#### 42\"},\n        ]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert response == \"Your response is correct!\"\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### 42\"\n\n    @pytest.mark.asyncio\n    async def test_generate_response_no_user_message(self):\n        \"\"\"Test generate_response with no user messages.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [{\"role\": \"assistant\", \"content\": \"Hello!\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### \"\n\n    @pytest.mark.asyncio\n    async def test_calculate_score_direct_call(self):\n        \"\"\"Test calculate_score method directly.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        # Set a response\n        self.interaction._instance_dict[instance_id][\"response\"] = \"#### 42\"\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0) as mock_compute:\n            score = await self.interaction.calculate_score(instance_id)\n\n            assert score == 1.0\n            mock_compute.assert_called_once_with(\"#### 42\", \"42\", method=\"flexible\", format_score=0.0, score=1.0)\n\n    @pytest.mark.asyncio\n    async def test_calculate_score_with_kwargs(self):\n        \"\"\"Test calculate_score method with additional kwargs.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        # Set a response\n        self.interaction._instance_dict[instance_id][\"response\"] = \"#### 24\"\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0) as mock_compute:\n            score = await self.interaction.calculate_score(instance_id, extra_param=\"test\")\n\n            assert score == 0.0\n  
          mock_compute.assert_called_once_with(\"#### 24\", \"42\", method=\"flexible\", format_score=0.0, score=1.0)\n\n    @pytest.mark.asyncio\n    async def test_finalize_interaction(self):\n        \"\"\"Test finalize_interaction method.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert instance_id in self.interaction._instance_dict\n\n        await self.interaction.finalize_interaction(instance_id)\n\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_finalize_interaction_with_kwargs(self):\n        \"\"\"Test finalize_interaction method with additional kwargs.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        assert instance_id in self.interaction._instance_dict\n\n        await self.interaction.finalize_interaction(instance_id, extra_param=\"test\")\n\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_finalize_nonexistent_interaction(self):\n        \"\"\"Test finalize_interaction with non-existent instance_id.\"\"\"\n        instance_id = \"nonexistent_instance\"\n\n        # This should raise KeyError\n        with pytest.raises(KeyError):\n            await self.interaction.finalize_interaction(instance_id)\n\n    @pytest.mark.asyncio\n    async def test_full_interaction_workflow_correct(self):\n        \"\"\"Test complete interaction workflow with correct answer.\"\"\"\n        ground_truth = \"42\"\n\n        # Start interaction\n        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        # Generate response with correct answer\n        messages = [{\"role\": \"user\", \"content\": \"42\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert reward == 1.0\n\n        # Finalize interaction\n        await self.interaction.finalize_interaction(instance_id)\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_full_interaction_workflow_incorrect(self):\n        \"\"\"Test complete interaction workflow with incorrect answer.\"\"\"\n        ground_truth = \"42\"\n\n        # Start interaction\n        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)\n\n        # Generate response with incorrect answer\n        messages = [{\"role\": \"user\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert reward == 0.0\n\n        # Continue with another attempt\n        messages.append({\"role\": \"assistant\", \"content\": response})\n        messages.append({\"role\": \"user\", \"content\": \"42\"})\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", 
return_value=1.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is True\n        assert reward == 1.0\n\n        # Finalize interaction\n        await self.interaction.finalize_interaction(instance_id)\n        assert instance_id not in self.interaction._instance_dict\n\n    @pytest.mark.asyncio\n    async def test_multiple_concurrent_interactions(self):\n        \"\"\"Test multiple concurrent interaction instances.\"\"\"\n        ground_truth_1 = \"42\"\n        ground_truth_2 = \"24\"\n\n        # Start multiple interactions\n        instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1)\n        instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2)\n\n        assert len(self.interaction._instance_dict) == 2\n        assert instance_id_1 in self.interaction._instance_dict\n        assert instance_id_2 in self.interaction._instance_dict\n\n        # Test responses for both instances\n        messages_1 = [{\"role\": \"user\", \"content\": \"42\"}]\n        messages_2 = [{\"role\": \"user\", \"content\": \"24\"}]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", side_effect=[1.0, 1.0]):\n            should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)\n            should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2)\n\n        assert should_terminate_1 is True\n        assert should_terminate_2 is True\n        assert reward_1 == 1.0\n        assert reward_2 == 1.0\n\n        # Finalize both interactions\n        await self.interaction.finalize_interaction(instance_id_1)\n        await self.interaction.finalize_interaction(instance_id_2)\n\n        assert len(self.interaction._instance_dict) == 0\n\n    @pytest.mark.asyncio\n    async def test_edge_case_empty_messages(self):\n        \"\"\"Test edge case with empty messages list.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = []\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert reward == 0.0\n        assert self.interaction._instance_dict[instance_id][\"response\"] == \"#### \"\n\n    @pytest.mark.asyncio\n    async def test_edge_case_message_without_content(self):\n        \"\"\"Test edge case with message without content field.\"\"\"\n        instance_id = \"test_instance\"\n        ground_truth = \"42\"\n\n        # Setup instance\n        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)\n\n        messages = [\n            {\"role\": \"user\"}  # Missing content field\n        ]\n\n        with patch(\"verl.utils.reward_score.gsm8k.compute_score\", return_value=0.0):\n            should_terminate, response, reward, metadata = await self.interaction.generate_response(\n                instance_id, messages\n            )\n\n        assert should_terminate is False\n        assert reward == 0.0\n        assert 
self.interaction._instance_dict[instance_id][\"response\"] == \"#### None\"\n\n    def test_inheritance_from_base_interaction(self):\n        \"\"\"Test that Gsm8kInteraction properly inherits from BaseInteraction.\"\"\"\n        from verl.interactions.base import BaseInteraction\n\n        assert isinstance(self.interaction, BaseInteraction)\n\n        # Test that all required methods are implemented\n        assert hasattr(self.interaction, \"start_interaction\")\n        assert hasattr(self.interaction, \"generate_response\")\n        assert hasattr(self.interaction, \"calculate_score\")\n        assert hasattr(self.interaction, \"finalize_interaction\")\n\n        # Test that methods are callable\n        assert callable(self.interaction.start_interaction)\n        assert callable(self.interaction.generate_response)\n        assert callable(self.interaction.calculate_score)\n        assert callable(self.interaction.finalize_interaction)\n\n    def test_name_attribute_initialization(self):\n        \"\"\"Test name attribute initialization with different configs.\"\"\"\n        # Test with explicit name in config\n        config_with_name = {\"name\": \"custom_gsm8k\"}\n        interaction_with_name = Gsm8kInteraction(config_with_name)\n        assert interaction_with_name.name == \"custom_gsm8k\"\n\n        # Test with default name when not provided in config\n        config_without_name = {}\n        interaction_without_name = Gsm8kInteraction(config_without_name)\n        assert interaction_without_name.name == \"interaction_agent\"  # Default from BaseInteraction\n\n        # Test that name is accessible as attribute\n        assert hasattr(self.interaction, \"name\")\n        assert self.interaction.name == \"gsm8k\"\n"
  },
  {
    "path": "verl_rl/tests/interactions/test_interaction_registry.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport tempfile\n\nimport pytest\nfrom omegaconf import OmegaConf\n\nfrom verl.interactions.base import BaseInteraction\nfrom verl.interactions.gsm8k_interaction import Gsm8kInteraction\nfrom verl.interactions.utils.interaction_registry import (\n    get_interaction_class,\n    initialize_interactions_from_config,\n)\n\n\nclass TestInteractionRegistry:\n    def test_get_interaction_class(self):\n        \"\"\"Test getting interaction class by name.\"\"\"\n        # Test getting base interaction class\n        base_cls = get_interaction_class(\"verl.interactions.base.BaseInteraction\")\n        assert base_cls == BaseInteraction\n\n        # Test getting gsm8k interaction class\n        gsm8k_cls = get_interaction_class(\"verl.interactions.gsm8k_interaction.Gsm8kInteraction\")\n        assert gsm8k_cls == Gsm8kInteraction\n\n    def test_initialize_single_interaction_from_config(self):\n        \"\"\"Test initializing single interaction from config.\"\"\"\n        # Create temporary config file\n        config_content = {\n            \"interaction\": [\n                {\n                    \"name\": \"test_gsm8k\",\n                    \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                }\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that interaction was created\n            assert len(interaction_map) == 1\n            assert \"test_gsm8k\" in interaction_map\n            assert isinstance(interaction_map[\"test_gsm8k\"], Gsm8kInteraction)\n            assert interaction_map[\"test_gsm8k\"].name == \"test_gsm8k\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_multiple_interactions_from_config(self):\n        \"\"\"Test initializing multiple interactions from config.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\n                    \"name\": \"gsm8k_solver\",\n                    \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                },\n                {\n                    \"name\": \"base_agent\",\n                    \"class_name\": \"verl.interactions.base.BaseInteraction\",\n                    \"config\": {\"custom_param\": \"test_value\"},\n                },\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            
interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that both interactions were created\n            assert len(interaction_map) == 2\n            assert \"gsm8k_solver\" in interaction_map\n            assert \"base_agent\" in interaction_map\n\n            # Check types\n            assert isinstance(interaction_map[\"gsm8k_solver\"], Gsm8kInteraction)\n            assert isinstance(interaction_map[\"base_agent\"], BaseInteraction)\n\n            # Check names were injected\n            assert interaction_map[\"gsm8k_solver\"].name == \"gsm8k_solver\"\n            assert interaction_map[\"base_agent\"].name == \"base_agent\"\n\n            # Check custom config was passed\n            assert interaction_map[\"base_agent\"].config.get(\"custom_param\") == \"test_value\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_interaction_without_explicit_name(self):\n        \"\"\"Test that interaction name is derived from class name when not specified.\"\"\"\n        config_content = {\n            \"interaction\": [{\"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}}]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that interaction name was derived from class name\n            assert len(interaction_map) == 1\n            assert \"gsm8k\" in interaction_map  # Should be \"gsm8k\" after removing \"interaction\" suffix\n            assert isinstance(interaction_map[\"gsm8k\"], Gsm8kInteraction)\n            assert interaction_map[\"gsm8k\"].name == \"gsm8k\"\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_initialize_empty_config(self):\n        \"\"\"Test initializing from empty config.\"\"\"\n        config_content = {\"interaction\": []}\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n            assert len(interaction_map) == 0\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_invalid_class_name(self):\n        \"\"\"Test handling of invalid class name.\"\"\"\n        config_content = {\n            \"interaction\": [{\"name\": \"invalid\", \"class_name\": \"invalid.module.InvalidClass\", \"config\": {}}]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            with pytest.raises(ModuleNotFoundError):\n                initialize_interactions_from_config(temp_config_path)\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_duplicate_interaction_names(self):\n        \"\"\"Test handling of duplicate interaction names.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\"name\": \"duplicate\", \"class_name\": \"verl.interactions.base.BaseInteraction\", \"config\": {}},\n                {\n                    \"name\": \"duplicate\",\n                    \"class_name\": 
\"verl.interactions.gsm8k_interaction.Gsm8kInteraction\",\n                    \"config\": {},\n                },\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            with pytest.raises(ValueError, match=\"Duplicate interaction name 'duplicate' found\"):\n                initialize_interactions_from_config(temp_config_path)\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_auto_name_generation_edge_cases(self):\n        \"\"\"Test automatic name generation for various class name patterns.\"\"\"\n        config_content = {\n            \"interaction\": [\n                {\"class_name\": \"verl.interactions.base.BaseInteraction\", \"config\": {}},\n                {\"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}},\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(config_content, f.name)\n            temp_config_path = f.name\n\n        try:\n            interaction_map = initialize_interactions_from_config(temp_config_path)\n\n            # Check that names were generated correctly\n            assert len(interaction_map) == 2\n            assert \"base\" in interaction_map  # BaseInteraction -> base\n            assert \"gsm8k\" in interaction_map  # Gsm8kInteraction -> gsm8k\n        finally:\n            os.unlink(temp_config_path)\n"
  },
  {
    "path": "verl_rl/tests/kill_github_tests.sh",
    "content": "#!/bin/bash\n\nif [ \"$#\" -ne 1 ]; then\n    echo \"Usage: $0 YOUR_GITHUB_TOKEN\"\n    echo \"Please provide exactly one input argument for your github token.\"\n    exit 1\nfi\n\n# Set your GitHub repository details\nOWNER=\"volcengine\"\nREPO=\"verl\"\nTOKEN=$1\n\n# API URL for workflow runs\nAPI_URL=\"https://api.github.com/repos/$OWNER/$REPO/actions/runs?status=queued\"\n\n# Check required commands\ncommand -v jq >/dev/null 2>&1 || { echo \"jq is required but not installed. Aborting.\"; exit 1; }\n\n# Get queued workflow runs\nresponse=$(curl -s -H \"Authorization: token $TOKEN\" -H \"Accept: application/vnd.github.v3+json\" \"$API_URL\")\n\n# Run this for debugging\n# echo $response\n\n# Extract run IDs\nqueued_run_ids=$(echo \"$response\" | jq -r '.workflow_runs[] | .id')\n\nif [ -z \"$queued_run_ids\" ]; then\n    echo \"No queued workflow runs found.\"\n    exit 0\nfi\n\n# Cancel each queued run\nfor run_id in $queued_run_ids; do\n    echo \"Cancelling run $run_id\"\n    cancel_url=\"https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/cancel\"\n    curl -s -X POST -H \"Authorization: token $TOKEN\" -H \"Accept: application/vnd.github.v3+json\" \"$cancel_url\"\ndone\n\necho \"Cancelled all queued workflow runs.\"\n"
  },
  {
    "path": "verl_rl/tests/models/test_transformer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    GemmaConfig,\n    LlamaConfig,\n    MistralConfig,\n    Qwen2Config,\n)\n\nfrom verl.utils.model import compute_position_id_with_mask, create_random_mask\nfrom verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean\n\n# TODO(sgm): add more models for test\n# we only need one scale for each model\ntest_configs = [\n    LlamaConfig(num_hidden_layers=1),\n    MistralConfig(num_hidden_layers=1),\n    GemmaConfig(num_hidden_layers=1),\n    Qwen2Config(num_hidden_layers=1),\n]\n\n\ndef test_hf_casual_models():\n    batch_size = 4\n    seqlen = 128\n    response_length = 127\n\n    for config in test_configs:\n        # config = AutoConfig.from_pretrained(test_case)\n        with torch.device(\"cuda\"):\n            model = AutoModelForCausalLM.from_config(\n                config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n            )\n            model = model.to(device=\"cuda\")\n        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n        attention_mask = create_random_mask(\n            input_ids=input_ids,\n            max_ratio_of_left_padding=0.1,\n            max_ratio_of_valid_token=0.8,\n            min_ratio_of_valid_token=0.5,\n        )\n        position_ids = compute_position_id_with_mask(\n            attention_mask\n        )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_rmpad = model(\n            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n        ).logits  # (1, total_nnz, vocab_size)\n\n        origin_logits = model(\n            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n        ).logits\n        origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)\n\n        logits_rmpad = logits_rmpad.squeeze(0)\n        log_probs = log_probs_from_logits_all_rmpad(\n            input_ids_rmpad=input_ids_rmpad,\n            logits_rmpad=logits_rmpad,\n            indices=indices,\n            batch_size=batch_size,\n            seqlen=seqlen,\n            response_length=response_length,\n        )  # (batch, seqlen)\n        origin_log_probs = log_probs_from_logits_all_rmpad(\n            input_ids_rmpad=input_ids_rmpad,\n            logits_rmpad=origin_logits_rmpad,\n            indices=origin_logits_indices,\n            batch_size=batch_size,\n            seqlen=seqlen,\n            response_length=response_length,\n        )  # (batch, seqlen)\n\n        torch.testing.assert_close(\n            masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),\n            masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),\n            atol=1e-2,\n            rtol=1e-5,\n        )\n    print(\"Check pass\")\n\n\ndef test_hf_value_models():\n    batch_size = 4\n    seqlen = 128\n\n    for config in test_configs:\n        # config = AutoConfig.from_pretrained(test_case)\n        config.num_labels = 1\n        config.classifier_dropout = 0\n        config.hidden_dropout = 0\n        with torch.device(\"cuda\"):\n            model = AutoModelForTokenClassification.from_config(\n                config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n            )\n            model = model.to(device=\"cuda\")\n        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n        attention_mask = create_random_mask(\n            input_ids=input_ids,\n            max_ratio_of_left_padding=0.1,\n            max_ratio_of_valid_token=0.8,\n            min_ratio_of_valid_token=0.5,\n        )\n        position_ids = compute_position_id_with_mask(\n            attention_mask\n        )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        origin_logits = model(\n            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n        ).logits\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        rmpad_logits = model(\n            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n        ).logits  # (1, total_nnz, 1)\n        rmpad_logits = rmpad_logits.squeeze(0)\n        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)\n\n        torch.testing.assert_close(\n            masked_mean(pad_logits, attention_mask[:, :, None]),\n            masked_mean(origin_logits, attention_mask[:, :, None]),\n            atol=1e-2,\n            rtol=1e-5,\n        )\n    print(\"Value model check pass\")\n\n\nif __name__ == \"__main__\":\n    test_hf_casual_models()\n    test_hf_value_models()\n"
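\n\n# Illustrative sketch (hypothetical helper, never invoked by the tests above) of\n# the unpad/pad round trip both tests rely on, using the flash_attn.bert_padding\n# utilities imported at the top of this file.\ndef _unpad_roundtrip_demo():\n    batch, seqlen, hidden = 2, 4, 8\n    x = torch.randn(batch, seqlen, hidden)\n    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=torch.int32)\n    # unpad_input drops padded positions: x_rmpad has total_nnz == mask.sum() rows\n    x_rmpad, indices, *_ = unpad_input(x, mask)\n    # pad_input scatters the rows back into their original (batch, seqlen) slots\n    x_back = pad_input(x_rmpad, indices, batch, seqlen)\n    # valid positions round-trip exactly; padded slots come back zero-filled\n    assert torch.equal(x_back[mask.bool()], x_rmpad)\n"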
  },
  {
    "path": "verl_rl/tests/models/test_transformers_ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport contextlib\nimport copy\nfrom dataclasses import dataclass\n\nimport pytest\nimport torch\nimport torch.distributed\nfrom flash_attn.bert_padding import index_first_axis, rearrange, unpad_input\nfrom torch.distributed import init_device_mesh\nfrom transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config\n\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.protocol import DataProto\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.model import compute_position_id_with_mask, create_random_mask\nfrom verl.utils.ulysses import (\n    gather_outputs_and_unpad,\n    get_ulysses_sequence_parallel_world_size,\n    set_ulysses_sequence_parallel_group,\n    ulysses_pad_and_slice_inputs,\n)\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\n# TODO(sgm): add more models for test\n# we only need one scale for each model\n\n\n@dataclass\nclass SequenceParallelConfig:\n    config: PretrainedConfig\n    sp_size: int\n    is_valid: bool\n\n\ndef test_configs():\n    return [\n        SequenceParallelConfig(\n            LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),\n            sp_size=4,\n            is_valid=True,\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),\n            sp_size=8,\n            is_valid=False,\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True\n        ),\n        SequenceParallelConfig(\n            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True\n        ),\n    ]\n\n\ndef sync_model_parameters_global(layer):\n    # synchronize weights\n    for p in layer.parameters():\n        torch.distributed.broadcast(tensor=p.data, src=0)\n\n\n@pytest.mark.parametrize(\"test_config\", test_configs())\ndef test_hf_casual_fwd_bwd(test_config):\n    if not torch.distributed.is_initialized():\n        initialize_global_process_group()\n\n    context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError)\n    with context:\n        world_size = torch.distributed.get_world_size()\n        _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size)\n\n    # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`\n    # torch.distributed.destroy_process_group()\n\n\ndef _hf_casual_fwd(config, sp_size, dp_size):\n    assert 
torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, sp_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)\n\n    batch_size = 1\n    seqlen = 128\n    # response_length = 127\n\n    # patch before load\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        apply_monkey_patch(model, sp_size)\n        model = model.to(device=\"cuda\")\n        sync_model_parameters_global(model)\n\n    # different rank will generate different input_ids following fsdp\n    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8\n    )\n    position_ids = compute_position_id_with_mask(\n        attention_mask\n    )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n    model_inputs = {\n        \"input_ids\": input_ids.cuda(),\n        \"attention_mask\": attention_mask.cuda(),\n        \"position_ids\": position_ids.int().cuda(),\n    }\n\n    model_inputs = DataProto.from_dict(model_inputs)\n\n    # 1. perform ulysses forward\n    with sharding_manager:\n        model_inputs = sharding_manager.preprocess_data(model_inputs)\n        input_ids = model_inputs.batch[\"input_ids\"]\n        attention_mask = model_inputs.batch[\"attention_mask\"]\n        position_ids = model_inputs.batch[\"position_ids\"]\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # slice input tensor for ulysses\n        # input_ids are padded and sliced\n        # postition_ids are only padded but not sliced\n        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n        )\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_split_in_seq = model(\n            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False\n        ).logits  # (1, total_nnz/n, vocab_size)\n\n        # all_gather output\n        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)\n\n    # 2. 
perform normal forward\n    set_ulysses_sequence_parallel_group(None)\n    logits_rmpad_local = model(\n        input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False\n    ).logits  # (1, total_nnz, vocab_size)\n\n    mean_local = logits_rmpad_local.mean()\n    mean_full = logits_full.mean()\n    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)\n\n\ndef _hf_casual_fwd_bwd(config, sp_size, dp_size):\n    assert torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, sp_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)\n\n    batch_size = 1\n    seqlen = 128\n    # response_length = 127\n\n    # patch before load\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        apply_monkey_patch(model, sp_size)\n        model = model.to(device=\"cuda\")\n        sync_model_parameters_global(model)\n\n    # different rank will generate different input_ids following fsdp\n    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device=\"cuda\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8\n    )\n    position_ids = compute_position_id_with_mask(\n        attention_mask\n    )  # TODO(sgm): we can construct the position_ids_rmpad here\n\n    model_inputs = {\n        \"input_ids\": input_ids.cuda(),\n        \"attention_mask\": attention_mask.cuda(),\n        \"position_ids\": position_ids.int().cuda(),\n    }\n\n    model_inputs = DataProto.from_dict(model_inputs)\n\n    # 1. perform ulysses forward\n    with sharding_manager:\n        model_inputs = sharding_manager.preprocess_data(model_inputs)\n        input_ids = model_inputs.batch[\"input_ids\"]\n        attention_mask = model_inputs.batch[\"attention_mask\"]\n        position_ids = model_inputs.batch[\"position_ids\"]\n        input_ids_rmpad, indices, *_ = unpad_input(\n            input_ids.unsqueeze(-1), attention_mask\n        )  # input_ids_rmpad (total_nnz, ...)\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n        # unpad the position_ids to align the rotary\n        position_ids_rmpad = index_first_axis(\n            rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n        ).transpose(0, 1)\n\n        # slice input tensor for ulysses\n        # input_ids are padded and sliced\n        # postition_ids are only padded but not sliced\n        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n        )\n\n        # input with input_ids_rmpad and postition_ids to enable flash attention varlen\n        logits_split_in_seq = model(\n            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False\n        ).logits  # (1, total_nnz/n, vocab_size)\n\n        # all_gather output\n        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)\n\n    # 2. 
perform normal forward\n    set_ulysses_sequence_parallel_group(None)\n    input_ids_full = copy.deepcopy(input_ids_rmpad)\n    position_ids_full = copy.deepcopy(position_ids_rmpad)\n    model_no_sp = copy.deepcopy(model)\n    logits_rmpad_local = model_no_sp(\n        input_ids_full, position_ids=position_ids_full, use_cache=False\n    ).logits  # (1, total_nnz, vocab_size)\n\n    mean_local = logits_rmpad_local.mean()\n    mean_full = logits_full.mean()\n\n    mean_full.backward()\n    mean_local.backward()\n\n    # 3. check the gradients\n    grad = model.model.layers[0].self_attn.q_proj.weight.grad\n    grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad\n    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)\n    torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-svv\"])\n"
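\n# NOTE: running this file assumes a multi-GPU distributed launch, e.g. something\n# like `torchrun --nproc_per_node=8 test_transformers_ulysses.py`, so that\n# world_size is divisible by every sp_size used in test_configs().\n"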
  },
  {
    "path": "verl_rl/tests/single_controller/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/tests/single_controller/base/test_decorator.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nimport verl.single_controller.base.decorator as decorator_module\nfrom verl.single_controller.base.decorator import (\n    DISPATCH_MODE_FN_REGISTRY,\n    Dispatch,\n    _check_dispatch_mode,\n    get_predefined_dispatch_fn,\n    register_dispatch_mode,\n    update_dispatch_mode,\n)\n\n\n@pytest.fixture\ndef reset_dispatch_registry():\n    # Store original state\n    original_registry = DISPATCH_MODE_FN_REGISTRY.copy()\n    yield\n    # Reset registry after test\n    decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()\n    decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry)\n\n\ndef test_register_new_dispatch_mode(reset_dispatch_registry):\n    # Test registration\n    def dummy_dispatch(worker_group, *args, **kwargs):\n        return args, kwargs\n\n    def dummy_collect(worker_group, output):\n        return output\n\n    register_dispatch_mode(\"TEST_MODE\", dummy_dispatch, dummy_collect)\n\n    # Verify enum extension\n    _check_dispatch_mode(Dispatch.TEST_MODE)\n\n    # Verify registry update\n    assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {\n        \"dispatch_fn\": dummy_dispatch,\n        \"collect_fn\": dummy_collect,\n    }\n    # Clean up\n    Dispatch.remove(\"TEST_MODE\")\n\n\ndef test_update_existing_dispatch_mode(reset_dispatch_registry):\n    # Store original implementation\n    original_mode = Dispatch.ONE_TO_ALL\n\n    # New implementations\n    def new_dispatch(worker_group, *args, **kwargs):\n        return args, kwargs\n\n    def new_collect(worker_group, output):\n        return output\n\n    # Test update=\n    update_dispatch_mode(original_mode, new_dispatch, new_collect)\n\n    # Verify update\n    assert get_predefined_dispatch_fn(original_mode)[\"dispatch_fn\"] == new_dispatch\n    assert get_predefined_dispatch_fn(original_mode)[\"collect_fn\"] == new_collect\n"
  },
  {
    "path": "verl_rl/tests/single_controller/check_worker_alive/main.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nimport time\n\nimport ray\n\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestActor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)\n    def foo(self, wait_time):\n        time.sleep(wait_time)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    wait_time = int(os.getenv(\"WAIT_TIME\", \"10\"))\n\n    ray.init()\n\n    # test single-node-no-partition\n    print(\"test single-node-no-partition\")\n    resource_pool = RayResourcePool([2], use_gpu=False)\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    print(\"create worker group\")\n    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"test\")\n\n    wg.start_worker_aliveness_check(1)\n    time.sleep(1)\n\n    print(time.time(), \"start foo\")\n\n    _ = wg.foo(wait_time)\n    print(\"foo started\")\n\n    print(\n        time.time(),\n        f\"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time\",\n    )\n    time.sleep(wait_time * 6)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/detached_worker/README.md",
    "content": "# Detached Worker\n## How to run (Only on a single node)\n- Start a local ray cluster: \n```bash\nray start --head --port=6379\n```\n- Run the server\n```bash\npython3 server.py\n```\n- On another terminal, Run the client\n```bash\npython3 client.py\n```\n"
  },
  {
    "path": "verl_rl/tests/single_controller/detached_worker/client.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn client, we can get the server handler and send RPC request\n\"\"\"\n\nimport ray\nimport torch\nfrom server import Trainer\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.single_controller.ray import RayClassWithInitArgs\nfrom verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n\n\ndef compute_position_id_with_mask(mask):\n    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\", namespace=\"verl\")\n    # get the worker group using names\n    worker_names = [\"trainerTrainer_0:0\", \"trainerTrainer_0:1\"]\n    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)\n    worker_group = NVMegatronRayWorkerGroup.from_detached(\n        worker_names=worker_names, ray_cls_with_init=cls_with_init_args\n    )\n\n    batch_size = 16\n    sequence_length = 1024\n\n    # give Trainer some data to train\n    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device=\"cuda\")\n    attention_mask = torch.ones_like(input_ids)\n    position_ids = compute_position_id_with_mask(attention_mask)\n\n    data = DataProto(\n        batch=TensorDict(\n            {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids},\n            batch_size=batch_size,\n        ),\n        meta_info={},\n    )\n\n    output = worker_group.train_model(data)\n\n    print(output)\n"
  },
  {
    "path": "verl_rl/tests/single_controller/detached_worker/run.sh",
    "content": "#!/bin/bash\nray start --head --port=6379\npython3 server.py\npython3 client.py\nray stop --force"
  },
  {
    "path": "verl_rl/tests/single_controller/detached_worker/server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nServer starts a Trainer. Client sends data to the server to train.\n\"\"\"\n\nimport os\n\nos.environ[\"MEGATRON_USE_CUDA_TIMER\"] = \"0\"\nos.environ[\"MEGATRON_START_PROCESS_TIMER\"] = \"False\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport ray\nimport torch\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core import tensor_parallel\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom omegaconf import OmegaConf\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl import DataProto\nfrom verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.megatron.worker import MegatronWorker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool\nfrom verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\nfrom verl.utils.megatron.optimizer import get_megatron_optimizer\nfrom verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config\n\n\n@ray.remote\nclass Trainer(MegatronWorker):\n    def __init__(self):\n        super().__init__()\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(backend=\"nccl\")\n            torch.cuda.set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=2,\n                pipeline_model_parallel_size=1,\n                virtual_pipeline_model_parallel_size=None,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=1,\n                expert_model_parallel_size=1,\n                nccl_communicator_config_path=None,\n            )\n            tensor_parallel.model_parallel_cuda_manual_seed(10)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        actor_model_config = LlamaConfig(\n            vocab_size=256,\n            hidden_size=2048,\n            intermediate_size=5504,\n            num_hidden_layers=24,\n            num_attention_heads=16,\n            num_key_value_heads=16,\n        )\n\n        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)\n        self.megatron_config = megatron_config\n\n        def megatron_actor_model_provider(pre_process, post_process):\n            # vpp is not supported yet because it will hang for some reason. 
Need debugging\n            # this_megatron_config = copy.deepcopy(megatron_config)\n            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank\n            parallel_model = ParallelLlamaForCausalLMRmPadPP(\n                config=actor_model_config,\n                megatron_config=megatron_config,\n                pre_process=pre_process,\n                post_process=post_process,\n            )\n            parallel_model.cuda()\n            return parallel_model\n\n        actor_module = get_model(\n            model_provider_func=megatron_actor_model_provider,\n            model_type=ModelType.encoder_or_decoder,\n            wrap_with_ddp=True,\n        )\n        actor_module = nn.ModuleList(actor_module)\n\n        optim_config = OmegaConf.create({\"lr\": 1e-6, \"clip_grad\": 1.0})\n\n        optim_config = init_megatron_optim_config(optim_config)\n        self.optimizer_config = optim_config\n        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)\n\n        self.model = actor_module[0]\n        self.optimizer = actor_optimizer\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    def train_model(self, data: DataProto) -> DataProto:\n        input_ids = data.batch[\"input_ids\"]\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n\n        self.optimizer.zero_grad()\n        self.model.zero_grad_buffer(\n            zero_buffer=(not self.optimizer_config.use_distributed_optimizer)\n        )  # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n        # update for 1 iteration\n        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits\n        output.mean().backward()\n\n        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(\n            self.megatron_config, self.megatron_config.timers\n        )\n\n        return DataProto(batch=TensorDict({\"loss\": output.detach()}, batch_size=output.shape[0]))\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\", namespace=\"verl\")\n\n    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)\n    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)\n    worker_group = NVMegatronRayWorkerGroup(\n        resource_pool=resource_pool,\n        ray_cls_with_init=cls_with_init_args,\n        name_prefix=\"trainer\",\n        detached=True,\n    )\n\n    worker_group.init_model()\n\n    worker_names = worker_group.worker_names\n    print(worker_names)\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_auto_padding_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport ray\nimport torch\n\nfrom verl import DataProto\nfrom verl.protocol import DataProtoConfig\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n# or set env var VERL_AUTO_PADDING = \"1\" / \"true\"\nDataProtoConfig.auto_padding = True\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\ndef test_auto_padding():\n    ray.init(num_cpus=100)\n\n    chunk_size = 4\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n\n    # test locally first\n    for test_size in range(4, 20):\n        local_data = DataProto.from_dict({\"a\": torch.zeros(test_size)}, {\"na\": np.zeros(test_size, dtype=object)})\n        # print(f\"before padding, local_data = {local_data}\")\n        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0\n        local_data.padding(padding_size)\n        # print(f\"after padding, local_data = {local_data}\")\n        assert len(local_data) == len(local_data) + len(local_data) % chunk_size, (\n            f\"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}\"\n        )\n        chunked = local_data.chunk(chunk_size)\n        assert len(chunked) == chunk_size, f\"during test_size = {test_size}, expecting {chunk_size}, got {chunked}\"\n        for dp in chunked:\n            assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), (\n                f\"test size = {test_size}, expecting dp to be length of \"\n                f\"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}\"\n            )\n\n    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO\n    data = DataProto.from_dict({\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(1)}, {\"na\": np.array([str(i) for i in range(1)], dtype=object)})\n    output = 
actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(1)}, {\"na\": np.array([str(i) for i in range(1)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(8)}, {\"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict({\"a\": torch.zeros(8)}, {\"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in kwargs split and padding.\"\n\n    # test data proto specific config\n    DataProtoConfig.auto_padding = False\n\n    data = DataProto.from_dict(\n        {\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data)\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in args split and padding.\"\n\n    data = DataProto.from_dict(\n        {\"a\": torch.zeros(10)}, {\"na\": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data=data)\n    print(output.batch[\"a\"])\n    assert len(output) == 10, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_single_dict(\n        {\"a\": torch.zeros(1), \"na\": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in args split and padding.\"\n\n    data = DataProto.from_single_dict(\n        {\"a\": torch.zeros(1), \"na\": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True\n    )\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 1, \"Failed in kwargs split and padding.\"\n\n    data = DataProto.from_single_dict({\"a\": torch.zeros(8), \"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in args split and padding.\"\n\n    data = DataProto.from_single_dict({\"a\": torch.zeros(8), \"na\": np.array([str(i) for i in range(8)], dtype=object)})\n    output = actor_wg.add(data=data)\n\n    print(output.batch[\"a\"])\n    assert len(output) == 8, \"Failed in kwargs split and padding.\"\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_auto_padding()\n"
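\n\n# A minimal sketch of the padding arithmetic exercised above (hypothetical\n# helper, not used by the test): pad each batch up to the next multiple of\n# chunk_size so it chunks evenly across the worker group.\ndef _expected_padding(test_size: int, chunk_size: int) -> int:\n    # e.g. test_size=10, chunk_size=4 -> pad by 2, so 12 rows chunk into 4 x 3\n    return (chunk_size - test_size % chunk_size) % chunk_size\n"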
  },
  {
    "path": "verl_rl/tests/single_controller/test_colocated_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.config = config\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    async def sub(self, data: DataProto):\n        data.batch[\"a\"] -= self.config[\"b\"]\n        return data\n\n\ndef test_colocated_workers():\n    ray.init()\n\n    import torch\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)})\n    # create separate workers on the same resource pool\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    critic_cls = RayClassWithInitArgs(cls=Critic, config={\"b\": 10})\n    resource_pool = RayResourcePool(process_on_nodes=[2])\n\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)\n\n    expected_actor_output = actor_wg.add(data)\n    expected_critic_output = critic_wg.sub(data)\n\n    # create colocated workers\n    cls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\n    ray_cls_with_init = create_colocated_worker_cls(cls_dict)\n    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())\n\n    colocated_actor_wg = spawn_wg[\"actor\"]\n    colocated_critic_wg = spawn_wg[\"critic\"]\n\n    actor_output = colocated_actor_wg.add(data)\n    critic_output = colocated_critic_wg.sub(data)\n\n    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)\n    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_colocated_workers_fused.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls_fused,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def add(self, data: DataProto):\n        data.batch[\"a\"] += self.rank\n        return data\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.config = config\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def sub(self, data: DataProto):\n        data.batch[\"a\"] -= self.config[\"b\"]\n        return data\n\n\ndef test_colocated_workers_fused():\n    ray.init()\n\n    import torch\n\n    data = DataProto.from_dict({\"a\": torch.zeros(10)})\n    # create separate workers on the same resource pool\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    critic_cls = RayClassWithInitArgs(cls=Critic, config={\"b\": 10})\n    resource_pool = RayResourcePool(process_on_nodes=[2])\n\n    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)\n    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)\n\n    expected_actor_output = actor_wg.add(data)\n    expected_critic_output = critic_wg.sub(data)\n\n    # create colocated workers\n    cls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\n    ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict)\n    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)\n    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())\n\n    colocated_actor_wg = spawn_wg[\"actor\"]\n    colocated_critic_wg = spawn_wg[\"critic\"]\n\n    actor_output = colocated_actor_wg.add(data)\n    critic_output = colocated_critic_wg.sub(data)\n\n    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)\n    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_data_transfer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn this test, we instantiate a data parallel worker with 8 GPUs\n\"\"\"\n\nimport ray\nimport tensordict\nimport torch\nfrom codetiming import Timer\nfrom torch import distributed as dist\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.utils.ray_utils import parallel_put\n\n\n@ray.remote\nclass DummyWorker(Worker):\n    def __init__(self):\n        super().__init__()\n        dist.init_process_group()\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)\n    def do_nothing(self, data):\n        for key in data.batch.keys():\n            data.batch[key] += 1\n        if tensordict.__version__ >= \"0.5.0\":\n            data.batch = data.batch.consolidate()\n        return data\n\n\ndef test_data_transfer():\n    ray.init()\n    # construct resource pool\n    resource_pool = RayResourcePool([8])\n    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)\n    # construct worker group\n    wg = RayWorkerGroup(resource_pool, cls_with_init)\n\n    # this is real dataset size\n    batch_size = 4096\n    seqlen = 32768\n\n    data_dict = {}\n\n    for i in range(2):\n        data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))\n\n    data = DataProto.from_dict(tensors=data_dict)\n\n    print(data)\n\n    # we manually split data here and send to each worker\n    data_list = data.chunk(wg.world_size)\n\n    for i in range(wg.world_size):\n        # consolidate is necessary\n        if tensordict.__version__ >= \"0.5.0\":\n            data_list[i].batch = data_list[i].batch.consolidate()\n\n    with Timer(name=\"ray.pickle\", initial_text=True):\n        for i in range(wg.world_size):\n            ray.cloudpickle.pickle.dumps(data_list[i])\n\n    with Timer(name=\"raw.pickle\", initial_text=True):\n        import pickle\n\n        for i in range(wg.world_size):\n            pickle.dumps(data_list[i])\n\n    # we put in advance\n    with Timer(name=\"put\", initial_text=True):\n        # takes around 40 seconds\n        data_list_ref = parallel_put(data_list)\n        # for i in range(wg.world_size):\n        #     data_list[i] = ray.put(data_list[i])\n\n    with Timer(name=\"launch\", initial_text=True):\n        output_ref = wg.do_nothing(data_list_ref)\n\n    with Timer(name=\"get\", initial_text=True):\n        # takes around 40 seconds\n        output_lst = ray.get(output_ref)\n\n    for input_data, output_data in zip(data_list, output_lst, strict=True):\n        for key in input_data.batch.keys():\n            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (\n                input_data.batch[key],\n                output_data.batch[key],\n                key,\n            )\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_decorator_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport time\n\nimport pytest\nimport ray\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl.protocol import DataProto, DataProtoFuture\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n# Pytest fixture for Ray setup/teardown\n@pytest.fixture\ndef ray_init_shutdown():\n    ray.init(num_cpus=100)\n    yield\n    ray.shutdown()\n\n\n# Define a simple worker for testing\n@ray.remote\nclass DecoratorTestWorker(Worker):\n    def __init__(self, initial_value=0):\n        super().__init__()\n        self.value = initial_value\n        # Simulate some setup if needed\n        time.sleep(0.1)  # Ensure worker init completes\n\n    # Test method for synchronous DP compute (default behavior)\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def dp_compute(self, data: DataProto) -> DataProto:\n        time.sleep(0.1)  # Simulate work\n        rank_value = torch.tensor(self.rank, device=data.batch[\"input\"].device, dtype=data.batch[\"input\"].dtype)\n        data.batch[\"output\"] = data.batch[\"input\"] + self.value + rank_value\n        return data\n\n    # Test async def method with DP compute (default behavior)\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)\n    async def async_dp_compute(self, data: DataProto) -> DataProto:\n        # Simulate async work\n        await asyncio.sleep(0.1)  # Simulate async work\n        rank_value = torch.tensor(self.rank, device=data.batch[\"input\"].device, dtype=data.batch[\"input\"].dtype)\n        data.batch[\"output_async\"] = data.batch[\"input\"] * 2 + self.value + rank_value\n        return data\n\n\n# Test function for synchronous DP compute\ndef test_decorator_dp_compute(ray_init_shutdown):\n    \"\"\"\n    Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO.\n    Verifies the result correctness.\n    \"\"\"\n    num_workers = 2\n    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)  # Use CPU for simplicity\n    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)\n    worker_group = RayWorkerGroup(\n        resource_pool, cls_with_args, name_prefix=f\"decorator_test_sync_dp_{int(time.time())}\"\n    )\n\n    # Prepare input data (size 4, for 2 workers)\n    input_tensor = torch.arange(4, dtype=torch.float32)\n    data = DataProto(batch=TensorDict({\"input\": input_tensor}, batch_size=[4]))\n\n    # Call the decorated method\n    output = worker_group.dp_compute(data)\n\n    # Assert the result correctness\n    assert isinstance(output, DataProto), \"Expected DataProto result\"\n    assert \"output\" in output.batch.keys()\n    assert len(output) == len(data), \"Output length should match input length\"\n\n  
  # Expected output calculation for DP_COMPUTE_PROTO with 2 workers\n    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]\n    # Worker 0 adds initial_value(10) + rank(0) = 10\n    # Worker 1 adds initial_value(10) + rank(1) = 11\n    expected_output_part1 = torch.tensor([0, 1], dtype=torch.float32) + 10 + 0\n    expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1\n    expected_output = torch.cat([expected_output_part1, expected_output_part2])\n\n    torch.testing.assert_close(output.batch[\"output\"], expected_output, msg=\"Sync DP compute output data mismatch\")\n\n\n# Test function for async def method with DP compute\ndef test_decorator_async_function(ray_init_shutdown):\n    \"\"\"\n    Tests the decorator with an `async def` method using DP_COMPUTE_PROTO.\n    Verifies that the call returns a future and the result is correct after .get().\n    \"\"\"\n    num_workers = 2\n    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)\n    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)\n    worker_group = RayWorkerGroup(\n        resource_pool, cls_with_args, name_prefix=f\"decorator_test_async_dp_{int(time.time())}\"\n    )\n\n    # Prepare input data (size 4, for 2 workers)\n    input_tensor = torch.arange(4, dtype=torch.float32)\n    data = DataProto(batch=TensorDict({\"input\": input_tensor}, batch_size=[4]))\n\n    # Call the async decorated method - this should return a future\n    future_output: DataProtoFuture = worker_group.async_dp_compute(data)\n\n    # Assert that the call returned a future\n    assert isinstance(future_output, DataProtoFuture), \"Expected DataProtoFuture for async def call\"\n\n    # Get the result (this should block)\n    result_data = future_output.get()\n\n    # Assert the result correctness\n    assert isinstance(result_data, DataProto)\n    assert \"output_async\" in result_data.batch.keys()\n    assert len(result_data) == len(data), \"Output length should match input length\"\n\n    # Expected output calculation for DP_COMPUTE_PROTO with 2 workers\n    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]\n    # Worker 0 calculates: input * 2 + initial_value(5) + rank(0)\n    # Worker 1 calculates: input * 2 + initial_value(5) + rank(1)\n    expected_output_part1 = (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0\n    expected_output_part2 = (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1\n    expected_output = torch.cat([expected_output_part1, expected_output_part2])\n\n    torch.testing.assert_close(\n        result_data.batch[\"output_async\"], expected_output, msg=\"Async DP compute output data mismatch\"\n    )\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_driverfunc_to_worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport ray\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray import RayWorkerGroup\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool\n\nos.environ[\"RAY_DEDUP_LOGS\"] = \"0\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\n\n@ray.remote\nclass ModelActor(Worker):\n    def __init__(self):\n        pass\n\n\nclass HackSelf:\n    def __init__(self):\n        pass\n\n\ndef get_aux_metrics(self, test_proto):\n    sequence_ids = test_proto.batch[\"sequence_ids\"]\n    decode_count = []\n    for i in range(sequence_ids.size(0)):\n        decode_count.append(len(sequence_ids[i].tolist()))\n    ret_proto = DataProto(\n        batch=TensorDict(\n            {\"sequence_ids\": sequence_ids, \"decode_count\": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0)\n        )\n    )\n    return ret_proto\n\n\ndef test():\n    # construct model\n    ray.init()\n\n    # create 2 workers, each hold a GPU\n    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix=\"a\")\n\n    class_with_args = RayClassWithInitArgs(cls=ModelActor)\n    shard_wg = RayWorkerGroup(resource_pool, class_with_args)\n\n    test_bs = 8\n    test_proto = DataProto(\n        TensorDict(\n            {\n                \"sequence_ids\": torch.ones([test_bs, 2048], dtype=torch.int64),\n            },\n            batch_size=test_bs,\n        ),\n        meta_info={\"query_length\": 1536},\n    )\n\n    # Sharding among different ranks\n    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)\n\n    # compare execute on driver\n    hs = HackSelf()\n    ret_proto2 = get_aux_metrics(hs, test_proto)\n\n    torch.testing.assert_close(ret_proto1.batch[\"decode_count\"], ret_proto2.batch[\"decode_count\"])\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_fused_workers_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray.base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_raw_cls,\n)\n\n\n@ray.remote\nclass Actor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def add(self, x):\n        x += self.rank\n        return x\n\n\n@ray.remote\nclass Critic(Worker):\n    def __init__(self, val) -> None:\n        super().__init__()\n        self.val = val\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL)\n    def sub(self, x):\n        x -= self.val\n        return x\n\n\nactor_cls = RayClassWithInitArgs(cls=Actor)\ncritic_cls = RayClassWithInitArgs(cls=Critic, val=10)\ncls_dict = {\"actor\": actor_cls, \"critic\": critic_cls}\nFusedBaseClass = create_colocated_worker_raw_cls(cls_dict)\n\n\n@ray.remote\nclass HybridWorker(FusedBaseClass):\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def foo(self, x):\n        return self.critic.sub(self.actor.add(x))\n\n\ndef test_fused_workers():\n    ray.init(num_cpus=100)\n\n    # create separate workers on the same resource pool\n    process_on_nodes = [2]\n    resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=False)\n\n    # create colocated workers\n    hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)\n    hybrid_cls_with_init.fused_worker_used = True\n\n    fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)\n    fused_wg.fuse(cls_dict.keys())\n\n    x = fused_wg.actor.add(0.1)\n    print(x)\n    y = fused_wg.critic.sub(x)\n    print(y)\n    z = fused_wg.foo(0.1)\n    print(z)\n    for i, j in zip(y, z, strict=True):\n        assert i == j\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_fused_workers()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_high_level_scheduling_api.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport ray\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool\n\n\n@ray.remote\nclass TestActor(Worker):\n    # TODO: pass *args and **kwargs is bug prone and not very convincing\n    def __init__(self, cuda_visible_devices=None) -> None:\n        super().__init__(cuda_visible_devices)\n\n    def get_node_id(self):\n        return ray.get_runtime_context().get_node_id()\n\n\ndef test():\n    ray.init()\n\n    # test single-node-no-partition\n    print(\"test single-node-no-partition\")\n    resource_pool = RayResourcePool([8], use_gpu=True)\n\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    print(\"create actor worker group\")\n    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_actor\")\n    print(\"create critic worker group\")\n    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"hight_level_api_critic\")\n    print(\"create rm worker group\")\n    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_rm\")\n    print(\"create ref worker group\")\n    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"high_level_api_ref\")\n\n    assert actor_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert critic_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert rm_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert ref_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n\n    del actor_wg\n    del critic_wg\n    del rm_wg\n    del ref_wg\n\n    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]\n    print(\"wait 5s to remove placemeng_group\")\n    time.sleep(5)\n    # test single-node-multi-partition\n\n    print(\"test single-node-multi-partition\")\n    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix=\"rm\")\n    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix=\"ref\")\n    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)\n\n    assert rm_resource_pool.world_size == 4\n    assert ref_resource_pool.world_size == 4\n    assert total_resource_pool.world_size == 8\n\n    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix=\"high_level_api_actor\")\n    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix=\"high_level_api_critic\")\n    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix=\"high_level_api_rm\")\n    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix=\"high_level_api_ref\")\n\n    assert actor_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in 
range(8)]\n    assert critic_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(8)]\n    assert rm_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(4)]\n    assert ref_wg.execute_all_sync(\"get_cuda_visible_devices\") == [str(i) for i in range(4, 8)]\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_ray_collectives.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest for using ray collective group.\nSuppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to\nRollout relationship by using collective groups\nActor: rank 0, 1 - Rollout rank 0\nRollout rank 2, 3 - Rollout rank 1\nThen, we initiate 4 p2p comms from actor to rollout\n\"\"\"\n\nimport ray\nimport ray.util.collective as collective\nimport torch\n\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass Actor(Worker):\n    @register(Dispatch.ONE_TO_ALL)\n    def init(self):\n        remote_rank = self.rank // 2\n        self.group_name = f\"A{self.rank}_R{remote_rank}\"\n        collective.init_collective_group(world_size=2, rank=0, backend=\"nccl\", group_name=self.group_name)\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def send_tensors(self):\n        tensor = torch.ones(size=(4,), dtype=torch.float32, device=\"cuda\") * self.rank\n        collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name)\n\n\n@ray.remote\nclass Rollout(Worker):\n    @register(Dispatch.ONE_TO_ALL)\n    def init(self):\n        self.remote_first_rank = self.rank * 2\n        self.remote_second_rank = self.remote_first_rank + 1\n        self.first_group_name = f\"A{self.remote_first_rank}_R{self.rank}\"\n        self.second_group_name = f\"A{self.remote_second_rank}_R{self.rank}\"\n\n        collective.init_collective_group(world_size=2, rank=1, backend=\"nccl\", group_name=self.first_group_name)\n        collective.init_collective_group(world_size=2, rank=1, backend=\"nccl\", group_name=self.second_group_name)\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def receive_tensors(self):\n        self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device=\"cuda\")\n        self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device=\"cuda\")\n\n        collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name)\n        collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name)\n\n    @register(Dispatch.ONE_TO_ALL)\n    def get_tensors(self):\n        return {f\"src_{self.remote_first_rank}\": self.tensor1, f\"src_{self.remote_second_rank}\": self.tensor2}\n\n\ndef test_ray_collective_group():\n    ray.init()\n\n    actor_resource_pool = RayResourcePool([4])\n    rollout_resource_pool = RayResourcePool([2])\n\n    actor_cls = RayClassWithInitArgs(cls=Actor)\n    rollout_cls = RayClassWithInitArgs(cls=Rollout)\n\n    actor_wg = RayWorkerGroup(\n        resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix=\"collective_group_actor\"\n    )\n    rollout_wg = RayWorkerGroup(\n        resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, 
name_prefix=\"collective_group_rollout\"\n    )\n\n    actor_wg.init()\n    rollout_wg.init()\n\n    out1 = actor_wg.send_tensors()\n    out2 = rollout_wg.receive_tensors()\n\n    # block to wait\n    ray.get(out1)\n    ray.get(out2)\n\n    output = rollout_wg.get_tensors()\n\n    rollout_0_output = output[0]\n    rollout_1_output = output[1]\n\n    output = rollout_0_output | rollout_1_output\n\n    print(output)\n\n    for i in range(4):\n        assert torch.sum(output[f\"src_{i}\"]).item() == 4 * i\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_ray_collective_group()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_ray_local_envs_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ne2e test verl.single_controller.ray\n\"\"\"\n\nimport os\n\nimport ray\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestActor(Worker):\n    def __init__(self) -> None:\n        super().__init__()\n\n    def getenv(self, key):\n        val = os.getenv(key, f\"{key} not set\")\n        return val\n\n\ndef test_basics():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=False)\n    class_with_args = RayClassWithInitArgs(cls=TestActor)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic\"\n    )\n\n    output = worker_group.execute_all_sync(\"getenv\", key=\"RAY_LOCAL_WORLD_SIZE\")\n    assert output == [\"4\", \"4\", \"4\", \"4\"]\n\n    output = worker_group.execute_all_sync(\"getenv\", key=\"RAY_LOCAL_RANK\")\n    assert set(output) == set([\"0\", \"1\", \"2\", \"3\"])\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_basics()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_ray_utils_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport ray\n\nfrom verl.utils.ray_utils import parallel_put\n\n\n# Initialize Ray for testing if not already done globally\n@pytest.fixture()\ndef init_ray():\n    ray.init(num_cpus=4)\n    yield\n    ray.shutdown()\n\n\ndef test_parallel_put_basic(init_ray):\n    data = [1, \"hello\", {\"a\": 2}, [3, 4]]\n    refs = parallel_put(data)\n    assert len(refs) == len(data)\n    retrieved_data = [ray.get(ref) for ref in refs]\n    assert retrieved_data == data\n\n\ndef test_parallel_put_empty(init_ray):\n    data = []\n    with pytest.raises(AssertionError):\n        _ = parallel_put(data)\n\n\ndef test_parallel_put_workers(init_ray):\n    data = list(range(20))\n    # Test with specific number of workers\n    refs = parallel_put(data, max_workers=4)\n    assert len(refs) == len(data)\n    retrieved_data = [ray.get(ref) for ref in refs]\n    assert retrieved_data == data\n    # Test with default workers (should cap)\n    refs_default = parallel_put(data)\n    assert len(refs_default) == len(data)\n    retrieved_data_default = [ray.get(ref) for ref in refs_default]\n    assert retrieved_data_default == data\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_rvdz.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\n\n\n@ray.remote\nclass TestWorker:\n    def __init__(self, rank, world_size, group_name):\n        self.rank = rank\n        self.world_size = world_size\n        self.group_name = group_name\n        self.communicator = None\n\n    def init(self):\n        from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray\n\n        self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)\n\n    def test(self):\n        if self.communicator is None:\n            return None\n        return self.communicator.rank_id()\n\n\ndef test_rvdz():\n    ray.init()\n\n    group_name = \"test_group\"\n    world_size = 2\n\n    workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]\n\n    ray.get([worker.init.remote() for worker in workers])\n\n    ranks = ray.get([worker.test.remote() for worker in workers])\n\n    assert ranks == [0, 1], f\"expecting [0, 1], got {ranks}\"\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_worker_group_basics.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ne2e test verl.single_controller.ray\n\"\"\"\n\nimport ray\nimport torch\n\nfrom verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\ndef two_to_all_dispatch_fn(worker_group, *args, **kwargs):\n    \"\"\"\n    Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\n    \"\"\"\n    for arg in args:\n        assert len(arg) == 2\n        for i in range(worker_group.world_size - 2):\n            arg.append(arg[i % 2])\n    for k, v in kwargs.items():\n        assert len(v) == 2\n        for i in range(worker_group.world_size - 2):\n            v.append(v[i % 2])\n    return args, kwargs\n\n\n@ray.remote\nclass TestActor(Worker):\n    # TODO: pass *args and **kwargs is bug prone and not very convincing\n    def __init__(self, x) -> None:\n        super().__init__()\n        self._x = x\n\n    def foo(self, y):\n        return self._x + y\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n    def foo_rank_zero(self, x, y):\n        return self._x + y + x\n\n    @register(Dispatch.ONE_TO_ALL, blocking=False)\n    def foo_one_to_all(self, x, y):\n        return self._x + y + x\n\n    @register(Dispatch.ALL_TO_ALL, blocking=False)\n    def foo_all_to_all(self, x, y):\n        return self._x + y + x\n\n    @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n    def foo_custom(self, x, y):\n        return self._x + y + x\n\n\n@ray.remote(num_gpus=0.1)\ndef remote_call_wg(worker_names):\n    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n    worker_group = RayWorkerGroup.from_detached(\n        worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None\n    )\n    print(worker_group.worker_names)\n\n    output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\n    assert output_ref == [8, 10, 8, 10]\n\n    output_ref = worker_group.foo_rank_zero(x=1, y=2)\n    assert output_ref == 5\n\n    return worker_group.worker_names\n\n\ndef add_one(data):\n    data = data.to(\"cuda\")\n    data += 1\n    data = data.to(\"cpu\")\n    return data\n\n\ndef test_basics():\n    ray.init(num_cpus=100)\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n\n    worker_group = RayWorkerGroup(\n        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix=\"worker_group_basic\"\n    )\n\n    print(worker_group.worker_names)\n\n    # this will wait for all the results\n    output = worker_group.execute_all_sync(\"foo\", y=3)\n    assert output == [5, 5, 5, 5]\n\n    # this is a list of object reference. 
It won't block.\n    output_ref = worker_group.execute_all_async(\"foo\", y=4)\n    print(output_ref)\n\n    assert ray.get(output_ref) == [6, 6, 6, 6]\n\n    output_ref = worker_group.foo_one_to_all(x=1, y=2)\n    assert ray.get(output_ref) == [5, 5, 5, 5]\n\n    output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8])\n    assert ray.get(output_ref) == [8, 10, 12, 14]\n\n    print(ray.get(remote_call_wg.remote(worker_group.worker_names)))\n\n    output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2))\n    torch.testing.assert_close(output, torch.ones(2, 2) + 1)\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    test_basics()\n"
  },
  {
    "path": "verl_rl/tests/single_controller/test_worker_group_torch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nos.environ[\"RAY_DEDUP_LOGS\"] = \"0\"\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport ray\nimport torch\nimport torch.distributed\n\nfrom verl.single_controller.base.worker import Worker\nfrom verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n@ray.remote\nclass TestAllGatherActor(Worker):\n    def __init__(self, size) -> None:\n        super().__init__()\n        self.size = size\n\n    def init(self):\n        torch.distributed.init_process_group()\n        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=\"cuda\")\n        self.tensor += self.rank\n\n    def all_gather(self):\n        world_size = self._world_size\n        output = torch.zeros(\n            size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device\n        )\n        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)\n        return output\n\n\n@ray.remote\nclass TestAllGatherActorV2(Worker):\n    def __init__(self, size) -> None:\n        super().__init__()\n        self.size = size\n\n        torch.distributed.init_process_group()\n        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=\"cuda\")\n        self.tensor += self.rank\n\n    def all_gather(self):\n        world_size = self._world_size\n        output = torch.zeros(\n            size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device\n        )\n        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)\n        return output\n\n\ndef test_all_gather_torch():\n    \"\"\"\n    In this test, we instantiate 4 GPUs in a group and test the all_gather\n    \"\"\"\n    ray.init()\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)\n\n    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"worker_group_torch\")\n\n    worker_group.execute_all_sync(\"init\")\n    output = worker_group.execute_all_sync(\"all_gather\")\n    for i in range(1, len(output)):\n        assert torch.all(output[i] == output[0])\n\n    output = output[0].cpu()\n    print(output)\n    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))\n\n    ray.shutdown()\n\n\ndef test_all_gather_torch_v2():\n    \"\"\"\n    In this test, we instantiate 4 GPUs in a group and test the all_gather\n    \"\"\"\n    ray.init()\n\n    # create 4 workers, each hold a GPU\n    resource_pool = RayResourcePool([4], use_gpu=True)\n    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)\n\n    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix=\"worker_group_torch\")\n\n    output = worker_group.execute_all_sync(\"all_gather\")\n    for i in 
range(1, len(output)):\n        assert torch.all(output[i] == output[0])\n\n    output = output[0].cpu()\n    print(output)\n    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))\n\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/special_distributed/README.md",
    "content": "This folder is reserved for unit tests (instead of end-to-end tests) that require multiple GPUs.\n"
  },
  {
    "path": "verl_rl/tests/special_distributed/run_all.sh",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env bash\n\nset -e -x\ntorchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py"
  },
  {
    "path": "verl_rl/tests/special_distributed/test_fsdp_ckpt.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport shutil\nimport tempfile\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config\n\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2\n\n\ndef test_fsdp_ckpt(strategy=\"fsdp\"):\n    assert torch.cuda.device_count() >= 2, \"need at least 2 gpus for test\"\n    local_rank, rank, world_size = initialize_global_process_group()\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"dp\",))\n\n    model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n    config = Qwen2Config(num_hidden_layers=1)\n\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        model = model.to(device=\"cuda\")\n\n    # Wrap model with FSDP\n    if strategy == \"fsdp\":\n        mixed_precision = MixedPrecision(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32\n        )\n\n        model = FSDP(\n            model,\n            use_orig_params=False,\n            device_id=torch.cuda.current_device(),\n            sharding_strategy=ShardingStrategy.FULL_SHARD,\n            mixed_precision=mixed_precision,\n            device_mesh=device_mesh,\n        )\n    else:\n        mp_policy = MixedPrecisionPolicy(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n        )\n        fsdp_kwargs = {\n            \"mesh\": device_mesh,\n            \"mp_policy\": mp_policy,\n        }\n        apply_fsdp2(model, fsdp_kwargs, {})\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)\n\n    # Create checkpoint manager\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    checkpoint_manager = FSDPCheckpointManager(\n        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer\n    )\n\n    # Generate sample input\n    batch_size = 2\n    seq_len = 32\n    vocab_size = 32000\n    # First input for initial update\n    input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n    attention_mask1 = torch.ones_like(input_ids1)\n\n    # Second input for verification\n    input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n    attention_mask2 = torch.ones_like(input_ids2)\n\n    # Step 1: Initial update and save checkpoint\n    outputs1 = 
model(input_ids=input_ids1, attention_mask=attention_mask1)\n    loss1 = outputs1.logits.mean()\n    loss1.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Save checkpoint after first update\n    temp_dir = tempfile.mkdtemp()\n    checkpoint_path = os.path.join(temp_dir, \"checkpoint\")\n    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)\n\n    # Step 2: Second update and forward pass\n    outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)\n    loss2 = outputs2.logits.mean()\n    loss2.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after second update\n    with torch.no_grad():\n        logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits\n\n    # Step 3: Load checkpoint and repeat second update\n    checkpoint_manager.load_checkpoint(checkpoint_path)\n\n    # Repeat the second update with same input\n    outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)\n    loss3 = outputs3.logits.mean()\n    loss3.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after loaded checkpoint and update\n    with torch.no_grad():\n        logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits\n\n    # Step 4: Verify outputs match\n    torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)\n    print(\"Checkpoint save/load test passed!\")\n\n    # Cleanup\n    shutil.rmtree(temp_dir)\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    strategy = os.environ.get(\"STRATEGY\", \"fsdp\")\n    test_fsdp_ckpt(strategy=strategy)\n"
  },
  {
    "path": "verl_rl/tests/special_distributed/test_tensor_dict.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\n\nimport numpy as np\nimport torch\nimport torch.distributed\n\nfrom verl.protocol import DataProto, all_gather_data_proto\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef test_all_gather_data_proto():\n    device_mesh = torch.distributed.device_mesh.init_device_mesh(\"cuda\", mesh_shape=[2, 2], mesh_dim_names=[\"dp\", \"tp\"])\n\n    global_rank = torch.distributed.get_rank()\n\n    obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])\n\n    labels = [\"a\", \"b\"] if global_rank % 2 == 0 else [\"b\", \"a\"]\n    labels = np.array(labels, dtype=object)\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    all_gather_data_proto(data=data, process_group=device_mesh.get_group(\"dp\"))\n\n    if global_rank == 0:\n        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device=\"cuda\")\n        expected_labels = [\"a\", \"b\", \"a\", \"b\"]\n    elif global_rank == 1:\n        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device=\"cuda\")\n        expected_labels = [\"b\", \"a\", \"b\", \"a\"]\n    elif global_rank == 2:\n        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device=\"cuda\")\n        expected_labels = [\"a\", \"b\", \"a\", \"b\"]\n    elif global_rank == 3:\n        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device=\"cuda\")\n        expected_labels = [\"b\", \"a\", \"b\", \"a\"]\n\n    torch.testing.assert_close(data.batch[\"obs\"], expected_obs, atol=0, rtol=0)\n    assert (data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert data.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_vocab_parallel_entropy():\n    from megatron.core import parallel_state as mpu\n\n    from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy\n    from verl.utils.profiler import log_gpu_memory_usage\n    from verl.utils.torch_functional import entropy_from_logits\n\n    mpu.initialize_model_parallel(\n        tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None\n    )\n\n    batch_size = 2\n    seqlen = 128\n    vocab_size = 155136\n\n    logits = torch.randn(batch_size * seqlen, vocab_size, device=\"cuda\", requires_grad=True)\n    target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device=\"cuda\", dtype=torch.int64)\n\n    # broadcast across tp\n    torch.distributed.broadcast(\n        logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()\n    )\n    torch.distributed.broadcast(\n        target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()\n    )\n\n    tp_rank = mpu.get_tensor_model_parallel_rank()\n    
vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()\n\n    # get the local logits of each tp\n    vocab_parallel_logits = (\n        logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_()\n    )\n    logits.grad = None\n    vocab_parallel_logits.grad = None\n\n    log_gpu_memory_usage(\"begin\")\n    output_entropy = vocab_parallel_entropy(vocab_parallel_logits)\n    log_gpu_memory_usage(\"after forward\")\n    grad_output = torch.randn_like(output_entropy)\n    output_entropy.backward(grad_output)\n    log_gpu_memory_usage(\"after backward\")\n\n    target_entropy = entropy_from_logits(logits)\n    torch.testing.assert_close(output_entropy, target_entropy)\n    target_entropy.backward(grad_output)\n    torch.testing.assert_close(\n        logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad\n    )\n    # make sure logits is not altered\n    torch.testing.assert_close(\n        logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits\n    )\n\n    if mpu.get_tensor_model_parallel_rank() == 0:\n        print(\"test_vocab_parallel_entropy passes\")\n\n    mpu.destroy_model_parallel()\n\n\nif __name__ == \"__main__\":\n    local_rank, rank, world_size = initialize_global_process_group()\n    test_all_gather_data_proto()\n    test_vocab_parallel_entropy()\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/README.md",
    "content": "This folder is reserved for end-to-end tests that typically require multiple GPUs.\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/check_custom_rwd_fn.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\n\ndef check_congratulations_in_file(output_file):\n    with open(output_file) as f:\n        output = f.read()\n\n    success_message = \"Congratulations!!! You have called my_reward_function successfully!!!\"\n    assert success_message in output, f\"Success message of my_reward_function not found in {output_file}\"\n    print(\"Check passes\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output_file\", required=True, type=str)\n\n    args = parser.parse_args()\n\n    check_congratulations_in_file(args.output_file)\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/check_results.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nimport numpy as np\n\n\ndef extract_reward_from_line(line):\n    # TODO: this function needs error handling\n    try:\n        key_vals = line.split(\" - \")\n        for key_val in key_vals:\n            key, val = key_val.split(\":\")\n            if key == \"critic/rewards/mean\":\n                reward = float(val)\n                return reward\n        return -np.inf\n    except Exception:\n        return -np.inf\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output_file\", required=True, type=str)\n    parser.add_argument(\"--target\", type=float, default=0.2, help=\"target reward score\")\n\n    args = parser.parse_args()\n\n    with open(args.output_file) as f:\n        output = f.read().split(\"\\n\")\n\n    best_reward = -np.inf\n    for line in output:\n        if line.startswith(\"step\"):\n            reward = extract_reward_from_line(line)\n            if reward > best_reward:\n                best_reward = reward\n\n    print(f\"Best reward is {best_reward}\")\n    assert best_reward > args.target, f\"Best reward must be greater than {args.target}. best_reward: {best_reward}\"\n    print(\"Check passes\")\n"
  },
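  {
    "path": "verl_rl/tests/special_e2e/examples/check_results_demo.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A minimal sketch of the log-line format that check_results.py consumes: each\n# training step emits a line starting with \"step\" whose \" - \"-separated fields\n# are \"key:value\" pairs, one of which is \"critic/rewards/mean\". The exact field\n# set below is invented for illustration.\n\n\ndef parse_mean_reward(line):\n    # Mirrors extract_reward_from_line() in check_results.py.\n    for key_val in line.split(\" - \"):\n        key, _, val = key_val.partition(\":\")\n        if key == \"critic/rewards/mean\":\n            return float(val)\n    return float(\"-inf\")\n\n\nif __name__ == \"__main__\":\n    sample = \"step:3 - critic/rewards/mean:0.3125 - actor/entropy:1.92\"\n    assert parse_mean_reward(sample) == 0.3125\n    print(\"parsed reward:\", parse_mean_reward(sample))\n"
  },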
  {
    "path": "verl_rl/tests/special_e2e/envs/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .digit_completion import DigitCompletion\n\n__all__ = [\"DigitCompletion\"]\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/envs/digit_completion/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import AutoTokenizer, LlamaConfig\n\nfrom .task import DigitCompletion, generate_ground_truth_response\nfrom .tokenizer import CharTokenizer\n\nAutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)\n\n__all__ = [\"DigitCompletion\", \"generate_ground_truth_response\", \"CharTokenizer\"]\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/envs/digit_completion/task.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Task and environment definition for digit completion.\"\"\"\n\nimport numpy as np\n\n\nclass DigitCompletion:\n    \"\"\"\n    The implementation of a simple digit completion task.\n    The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers.\n    If the max number is reached, the next number should be modulo with max number.\n\n    For example,\n    - prompt = [1, 2, 3]\n    - N = 5\n    - max_number = 6\n\n    the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]\n\n    Note that the tokenizer is char-level to increase the difficulty.\n    \"\"\"\n\n    def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):\n        \"\"\"\n\n        Args:\n            max_number: the maximum number allowed in the arithmetic sequence\n            max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff]\n            max_num_in_response: the maximum number in the response\n        \"\"\"\n        super().__init__()\n        self.max_number = max_number\n        self.max_diff = max_diff\n        self.max_num_in_response = max_num_in_response\n        assert self.max_num_in_response < 10\n        assert self.max_number > 0\n        assert self.max_diff > 0\n        self.max_number_length = len(str(max_number))\n        # {num1},{num2}:{max_num_in_response},{max_number}\n        self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length  # no negative is allowed\n\n        self.np_rng = np.random.default_rng(seed=seed)\n\n    def __str__(self):\n        return (\n            f\"Prompt length: {self.prompt_length}. Response length: {self.response_length}, \"\n            f\"Max number: {self.max_number}. 
Max diff: {self.max_diff}, \"\n            f\"Max number in response: {self.max_num_in_response}\"\n        )\n\n    def get_state(self):\n        return {\"rng\": self.np_rng}\n\n    def set_state(self, state):\n        assert \"rng\" in state, \"rng must be inside state\"\n        self.np_rng = state[\"rng\"]\n\n    @property\n    def prompt_length(self):\n        return self._prompt_length\n\n    @property\n    def response_length(self):\n        # number length + comma length + [EOS]\n        # The actual number times 1.5 to allow 'U'\n        return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2\n\n    def add(self, a, b):\n        return (a + b) % self.max_number\n\n    def get_all_prompts(self):\n        all_prompts = []\n        for first_num in range(self.max_number + 1):\n            for diff in range(0, self.max_diff + 1):\n                second_num = self.add(first_num, diff)\n                for num_to_complete in range(self.max_num_in_response + 1):\n                    prompt = str(first_num) + \",\" + str(second_num) + f\":{self.max_number},{num_to_complete}\"\n                    all_prompts.append(prompt)\n        return all_prompts\n\n    def sample_str_prompts(self):\n        # step 1: sample initial numbers\n        first_num = self.np_rng.integers(self.max_number + 1)\n        diff = self.np_rng.integers(self.max_diff + 1)\n        second_num = self.add(first_num, diff)\n        num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)\n        prompt = str(first_num) + \",\" + str(second_num) + f\":{self.max_number},{num_to_complete}\"\n        return prompt\n\n    def sample_batch_str_prompts(self, batch_size):\n        str_prompts = []\n        for _ in range(batch_size):\n            str_prompts.append(self.sample_str_prompts())\n        return str_prompts\n\n\ndef compute_attention_mask(prompts, pad_token_id):\n    mask = np.ones_like(prompts)\n    mask[prompts == pad_token_id] = 0\n    return mask\n\n\ndef compute_position_id_with_mask(mask):\n    return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None)\n\n\ndef generate_ground_truth_response(prompt: str):\n    \"\"\"Generate ground truth response given a prompt.\"\"\"\n    num, info = prompt.split(\":\")\n    num1, num2 = num.split(\",\")\n    max_number, num_to_gen = info.split(\",\")\n    num1 = int(num1)\n    num2 = int(num2)\n    max_number = int(max_number)\n    num_to_gen = int(num_to_gen)\n    diff = (num2 - num1) % max_number\n    results = []\n    last_num = num2\n    for _ in range(num_to_gen):\n        curr = (last_num + diff) % max_number\n        results.append(str(curr))\n        last_num = curr\n    response = \",\".join(results)\n    return response\n\n\ndef compute_reward(prompt: str, response: str, sequence_reward=1.0):\n    \"\"\"We compute dense reward here so that we can directly train RL without SFT\"\"\"\n    response_length = len(response)\n    ground_truth_response = generate_ground_truth_response(prompt)\n    per_token_reward = sequence_reward / (len(ground_truth_response) + 1)  # including [EOS]\n\n    # pad\n    reward = np.zeros(response_length, dtype=np.float32)  # this assumes that each char is a token\n    # assign reward until mismatches\n    ground_truth_idx = 0\n    for i in range(response_length):\n        if ground_truth_idx == len(ground_truth_response):\n            break\n\n        ground_truth_response_token = ground_truth_response[ground_truth_idx]\n        response_token = response[i]\n        if 
ground_truth_response_token == response_token:\n            reward[i] = per_token_reward\n            ground_truth_idx += 1\n        else:\n            # no matches\n            break\n\n    return reward, {\"ground_truth_response\": ground_truth_response}\n\n\nif __name__ == \"__main__\":\n    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)\n    print(task.sample_str_prompts())\n\n    prompt = \"7,8:20,0\"\n    response = \"\"\n    print(compute_reward(prompt, response))\n\n    prompt = \"7,8:20,0\"\n    response = \"E000\"\n    print(compute_reward(prompt, response))\n\n    prompt = \"9,10:20,2\"\n    response = \"11,12,13\"\n    print(compute_reward(prompt, response))\n"
  },
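  {
    "path": "verl_rl/tests/special_e2e/examples/digit_completion_reward_demo.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A minimal sketch of the dense reward in task.py, assuming the script is run\n# from the verl_rl root so the tests package is importable (the import path is\n# an assumption, not a documented API).\nfrom tests.special_e2e.envs.digit_completion.task import (\n    DigitCompletion,\n    compute_reward,\n    generate_ground_truth_response,\n)\n\nif __name__ == \"__main__\":\n    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)\n    prompt = task.sample_str_prompts()\n    gt = generate_ground_truth_response(prompt)\n    reward, info = compute_reward(prompt, gt)  # a perfect response\n    # Each matched char earns 1/(len(gt)+1); the remaining 1/(len(gt)+1) share\n    # is reserved for [EOS], so a perfect response sums to len(gt)/(len(gt)+1).\n    assert abs(reward.sum() - len(gt) / (len(gt) + 1)) < 1e-6\n    print(prompt, \"->\", gt, \"| reward sum:\", reward.sum())\n"
  },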
  {
    "path": "verl_rl/tests/special_e2e/envs/digit_completion/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py\n\nCharacterTokenzier for Hugging Face Transformers.\n\nThis is heavily inspired from CanineTokenizer in transformers package.\n\"\"\"\n\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Sequence\n\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\n\n\nclass CharTokenizer(PreTrainedTokenizer):\n    def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):\n        \"\"\"Character tokenizer for Hugging Face transformers.\n\n        Args:\n            characters (Sequence[str]): List of desired characters. Any character which\n                is not included in this list will be replaced by a special token called\n                [UNK] with id=6. Following are list of all of the special tokens with\n                their corresponding ids:\n                    \"[CLS]\": 0\n                    \"[SEP]\": 1\n                    \"[BOS]\": 2\n                    \"[MASK]\": 3\n                    \"[PAD]\": 4\n                    \"[RESERVED]\": 5\n                    \"[UNK]\": 6\n                an id (starting at 7) will be assigned to each character.\n\n            model_max_length (int): Model maximum sequence length.\n        \"\"\"\n        eos_token_str = \"E\"\n        sep_token_str = \"S\"\n        pad_token_str = \"P\"\n        unk_token_str = \"U\"\n\n        self.characters = characters\n        self.model_max_length = model_max_length\n        eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False)\n        sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False)\n        pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False)\n        unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False)\n\n        self._vocab_str_to_int = {\n            sep_token_str: 0,\n            eos_token_str: 1,\n            pad_token_str: 2,\n            unk_token_str: 3,\n            **{ch: i + 4 for i, ch in enumerate(characters)},\n        }\n        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}\n\n        super().__init__(\n            eos_token=eos_token,\n            sep_token=sep_token,\n            pad_token=pad_token,\n            unk_token=unk_token,\n            add_prefix_space=False,\n            model_max_length=model_max_length,\n            **kwargs,\n        )\n\n        self.chat_template = chat_template\n\n    @property\n    def vocab_size(self) -> int:\n        return len(self._vocab_str_to_int)\n\n    def get_vocab(self):\n        return self._vocab_str_to_int\n\n    def _tokenize(self, text: str) -> list[str]:\n        return list(text)\n\n    def _convert_token_to_id(self, token: str) -> int:\n        return self._vocab_str_to_int.get(token, self._vocab_str_to_int[\"U\"])\n\n    def 
_convert_id_to_token(self, index: int) -> str:\n        return self._vocab_int_to_str[index]\n\n    def convert_tokens_to_string(self, tokens):\n        return \"\".join(tokens)\n\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None\n    ) -> list[int]:\n        sep = [self.sep_token_id]\n        cls = [self.cls_token_id]\n        result = cls + token_ids_0 + sep\n        if token_ids_1 is not None:\n            result += token_ids_1 + sep\n        return result\n\n    def get_special_tokens_mask(\n        self,\n        token_ids_0: list[int],\n        token_ids_1: Optional[list[int]] = None,\n        already_has_special_tokens: bool = False,\n    ) -> list[int]:\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0,\n                token_ids_1=token_ids_1,\n                already_has_special_tokens=True,\n            )\n\n        result = [1] + ([0] * len(token_ids_0)) + [1]\n        if token_ids_1 is not None:\n            result += ([0] * len(token_ids_1)) + [1]\n        return result\n\n    def get_config(self) -> dict:\n        return {\n            \"char_ords\": [ord(ch) for ch in self.characters],\n            \"model_max_length\": self.model_max_length,\n            \"chat_template\": self.chat_template,\n        }\n\n    @classmethod\n    def from_config(cls, config: dict):\n        cfg = {}\n        cfg[\"characters\"] = [chr(i) for i in config[\"char_ords\"]]\n        cfg[\"model_max_length\"] = config[\"model_max_length\"]\n        cfg[\"chat_template\"] = config[\"chat_template\"]\n        return cls(**cfg)\n\n    def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        cfg = self.get_config()\n        with open(cfg_file, \"w\") as f:\n            json.dump(cfg, f, indent=4)\n\n    @classmethod\n    def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        with open(cfg_file) as f:\n            cfg = json.load(f)\n        return cls.from_config(cfg)\n"
  },
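  {
    "path": "verl_rl/tests/special_e2e/examples/char_tokenizer_demo.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A minimal round-trip through CharTokenizer, assuming the script is run from\n# the verl_rl root so the tests package is importable.\nfrom tests.special_e2e.envs.digit_completion.tokenizer import CharTokenizer\n\nif __name__ == \"__main__\":\n    vocab = [str(d) for d in range(10)] + [\",\", \":\"]\n    tok = CharTokenizer(characters=vocab, model_max_length=32, chat_template=None)\n    text = \"1,2:6,3\"\n    # add_special_tokens=False sidesteps build_inputs_with_special_tokens, which\n    # references a [CLS] token that this tokenizer never registers.\n    ids = tok(text, add_special_tokens=False)[\"input_ids\"]\n    assert tok.decode(ids) == text\n    print(text, \"->\", ids)\n"
  },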
  {
    "path": "verl_rl/tests/special_e2e/generation/run_gen_qwen05.sh",
    "content": "#!/usr/bin/env bash\n# Tested with 1 & 4 GPUs\nset -xeuo pipefail\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\n\nNGPUS_PER_NODE=${NGPUS_PER_NODE:-4}\nOUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}\nGEN_TP=${GEN_TP:-2}  # Default tensor parallel size to 2\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NGPUS_PER_NODE}\" \\\n    data.path=\"${HOME}/data/gsm8k/test.parquet\" \\\n    data.prompt_key=prompt \\\n    data.n_samples=1 \\\n    data.output_path=\"${OUTPUT_PATH}\" \\\n    model.path=\"${MODEL_ID}\" \\\n    +model.trust_remote_code=True \\\n    rollout.temperature=1.0 \\\n    rollout.top_k=50 \\\n    rollout.top_p=0.7 \\\n    rollout.prompt_length=2048 \\\n    rollout.response_length=1024 \\\n    rollout.tensor_model_parallel_size=\"${GEN_TP}\" \\\n    rollout.gpu_memory_utilization=0.8\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json",
    "content": "{\n    \"num_hidden_layers\": 2,\n    \"max_window_layers\": 2\n}"
  },
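  {
    "path": "verl_rl/tests/special_e2e/examples/dummy_model_config_demo.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A minimal sketch of how an override file like qwen2moe_minimal.json can shrink\n# a HF model for CI, in the spirit of the scripts/init_random_model.py call in\n# run_ppo_trainer_megatron.sh; the real script may differ, and the model id and\n# function name here are assumptions.\nimport json\n\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\n\ndef init_random_model(hf_model_path, new_config_path, output_path):\n    config = AutoConfig.from_pretrained(hf_model_path)\n    with open(new_config_path) as f:\n        config.update(json.load(f))  # e.g. num_hidden_layers -> 2\n    # from_config() builds randomly initialized weights; no checkpoint is loaded.\n    model = AutoModelForCausalLM.from_config(config)\n    model.save_pretrained(output_path)\n\n\nif __name__ == \"__main__\":\n    init_random_model(\"Qwen/Qwen1.5-MoE-A2.7B\", \"qwen2moe_minimal.json\", \"/tmp/dummy_qwen2moe\")\n"
  },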
  {
    "path": "verl_rl/tests/special_e2e/ppo_trainer/run_function_reward.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\nMAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512}\nMAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}\n\nENGINE=${ENGINE:-vllm}\nROLLOUT_MODE=${ROLLOUT_MODE:-sync}\n\nRETURN_RAW_CHAT=\"False\"\nif [ \"$ROLLOUT_MODE\" = \"async\" ]; then\n    RETURN_RAW_CHAT=\"True\"\nfi\n\nGPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8}\nACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}\nACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}\nREF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}\nRM_PAD=${RM_PAD:-True}\nFUSED_KERNELS=${FUSED_KERNELS:-False}\nFUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend\nADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}\nUSE_KL=${USE_KL:-False}\nCUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}\nENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185\nSTRATEGY=${STRATEGY:-fsdp}\n# LoRA config\nLORA_RANK=${LORA_RANK:-0}\nLORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}}\nLORA_TARGET=${LORA_TARGET:-\"all-linear\"}\nLORA_EXCLUDE=${LORA_EXCLUDE:-\"DONT_EXCLUDE\"}\nUSE_SHM=${USE_SHM:-False}\nLOAD_FORMAT=${LOAD_FORMAT:-dummy_dtensor}\nLAYERED_SUMMON=${LAYERED_SUMMON:-False}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\n# whether to save hf_model\nSAVE_HF_MODEL=${SAVE_HF_MODEL:-False}\nFSDP_SIZE=${FSDP_SIZE:--1}\nSP_SIZE=${SP_SIZE:-1}\n\nif [ \"${SAVE_HF_MODEL}\" = \"True\" ]; then\n    CHECKPOINT_CONTENTS=\"['model','hf_model','optimizer','extra']\"\nelse\n    CHECKPOINT_CONTENTS=\"['model','optimizer','extra']\"\nfi\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\nreward_fn_name=null\nreward_fn_file_path=null\noutput_file=\"$(pwd)/output.txt\"\nif [ \"${CUSTOM_REWARD_FN}\" = \"True\" ]; then\n    reward_fn_name=\"my_reward_function\"\n    reward_fn_file_path=\"$(pwd)/my_reward_function.py\"\n    rm -rf \"${reward_fn_file_path}\"\n    cat <<EOF > \"$reward_fn_file_path\"\ndef ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None):\n    print(f\"Congratulations!!! 
You have called ${reward_fn_name} successfully!!!\")\n    return 0.1\nEOF\n\n    rm -rf \"${output_file}\"\nfi\n\nexp_name=\"${VERL_EXP_NAME:-$(basename \"${MODEL_ID,,}\")-function-reward-minimal}\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=\"${ADV_ESTIMATOR}\" \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=\"${train_prompt_bsz}\" \\\n    data.max_prompt_length=\"${MAX_PROMPT_LEN}\" \\\n    data.max_response_length=\"${MAX_RESPONSE_LEN}\" \\\n    data.return_raw_chat=${RETURN_RAW_CHAT} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_shm=${USE_SHM} \\\n    actor_rollout_ref.model.lora_rank=${LORA_RANK} \\\n    actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \\\n    actor_rollout_ref.model.target_modules=${LORA_TARGET} \\\n    actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=\"${RM_PAD}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.strategy=${STRATEGY} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \\\n    actor_rollout_ref.actor.use_kl_loss=\"${USE_KL}\" \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=\"${ENGINE}\" \\\n    actor_rollout_ref.rollout.mode=\"${ROLLOUT_MODE}\" \\\n    actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \\\n    actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=\"${GPU_MEMORY_UTILIZATION}\" \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=\"${ENABLE_CHUNKED_PREFILL}\" \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=\"${REF_FSDP_PARAM_OFFLOAD}\" \\\n    critic.optim.lr=1e-5 \\\n    critic.model.use_remove_padding=\"${RM_PAD}\" \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    custom_reward_function.path=\"${reward_fn_file_path}\"\\\n    custom_reward_function.name=\"${reward_fn_name}\"\\\n    algorithm.use_kl_in_reward=\"${USE_KL}\" \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NUM_GPUS}\" \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    
trainer.test_freq=\"${TEST_FREQ}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.device=cuda \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" $@ \\\n    | tee \"${output_file}\"\n\nif [ \"${CUSTOM_REWARD_FN}\" = \"True\" ]; then\n    python3 tests/special_e2e/check_custom_rwd_fn.py --output_file=\"${output_file}\"\n    check_exit_code=$?\n    rm -rf \"${reward_fn_file_path}\"\n    rm -rf \"${output_file}\"\n    # Return the exit code of check_custom_rwd_fn.py if it fails\n    if [ $check_exit_code -ne 0 ]; then\n        exit $check_exit_code\n    fi\nfi\n"
  },
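  {
    "path": "verl_rl/tests/special_e2e/examples/custom_reward_fn_example.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A slightly richer custom reward function in the same (data_source,\n# solution_str, ground_truth, extra_info) contract that run_function_reward.sh\n# wires in via custom_reward_function.path/name. The GSM8K-style \"#### answer\"\n# parsing below is an assumption for illustration.\n\n\ndef my_reward_function(data_source, solution_str, ground_truth, extra_info=None):\n    # Reward 1.0 for an exact final-answer match, 0.0 otherwise.\n    answer = solution_str.split(\"####\")[-1].strip()\n    return 1.0 if answer == str(ground_truth).strip() else 0.0\n"
  },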
  {
    "path": "verl_rl/tests/special_e2e/ppo_trainer/run_model_reward.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\n\nRM_PAD=${RM_PAD:-True}\nFUSED_KERNELS=${FUSED_KERNELS:-False}\nFUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend\nSP_SIZE=${SP_SIZE:-1}\nSEQ_BALANCE=${SEQ_BALANCE:-False}\nLIGER=${LIGER:-False}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ntrain_max_token_num_per_gpu=32768\ninfer_max_token_num_per_gpu=32768\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-model-reward-minimal\"\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=gae \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_liger=\"${LIGER}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=\"${RM_PAD}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \\\n    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.optim.lr=1e-5 \\\n    critic.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    critic.model.use_remove_padding=\"${RM_PAD}\" \\\n    critic.optim.lr_warmup_steps_ratio=0.05 \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.model.enable_gradient_checkpointing=False \\\n    critic.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    
critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.model.fsdp_config.param_offload=False \\\n    critic.model.fsdp_config.optimizer_offload=False \\\n    reward_model.enable=True \\\n    reward_model.ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.model.use_remove_padding=\"${RM_PAD}\" \\\n    reward_model.model.fsdp_config.param_offload=True \\\n    reward_model.use_dynamic_bsz=\"${SEQ_BALANCE}\" \\\n    reward_model.forward_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=\"${NUM_GPUS}\" \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.test_freq=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" $@\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/ppo_trainer/run_single_gpu.sh",
    "content": "PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=256  \\\n  data.max_prompt_length=512 \\\n  data.max_response_length=256  \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4  \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n  critic.optim.lr=1e-5 \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  critic.ppo_micro_batch_size_per_gpu=4 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=1 \\\n  trainer.nnodes=1 \\\n  actor_rollout_ref.rollout.name=hf \\\n  trainer.total_training_steps=2"
  },
  {
    "path": "verl_rl/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh",
    "content": "PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=256  \\\n  data.max_prompt_length=512 \\\n  data.max_response_length=256  \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4  \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n  critic.optim.lr=1e-5 \\\n  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  critic.ppo_micro_batch_size_per_gpu=4 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=['console'] \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=1 \\\n  trainer.nnodes=1 \\\n  actor_rollout_ref.rollout.name=hf \\\n  trainer.use_legacy_worker_impl=disable \\\n  trainer.total_training_steps=2"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_dapo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nadv_estimator=grpo\n\nkl_coef=0.0\nuse_kl_in_reward=False\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=seq_reward\nmax_num_gen_batches=10\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ngen_prompt_bsz=$((train_prompt_bsz * 4))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-dapo-minimal\"\n\npython3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\" \\\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n    reward_model.reward_manager=dapo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} 
\\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=2 \\\n    trainer.resume_mode=disable \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_genrm_remote.sh",
    "content": "#!/usr/bin/env bash\n\nexport no_proxy=\"localhost,127.0.0.1\"\n\nset -x\n\n# Launch a vllm server\nCUDA_VISIBLE_DEVICES=0 vllm serve verl-team/GenRM-CI-Test-1.5B \\\n    --served_model_name genrm-demo --host localhost --port 30000 > /dev/null &\nSERVER_PID=$!\n\n# kill server when script exits\ncleanup() {\n    echo \"Cleaning up...\"\n    kill $SERVER_PID 2>/dev/null || true\n    wait $SERVER_PID 2>/dev/null || true\n    echo \"Cleanup done\"\n}\ntrap cleanup EXIT\n\n# wait for server to start\nwait_for_server() {\n    local max_attempts=60\n    local attempt=0\n    local sleep_time=10\n\n    while [ $attempt -lt $max_attempts ]; do\n        if curl -s \"http://localhost:30000/health\" >/dev/null; then\n            echo \"Server is up and running!\"\n            return 0\n        fi\n        echo \"Waiting for server to start... (attempt $((attempt+1))/$max_attempts)\"\n        sleep $sleep_time\n        ((attempt++))\n    done\n    \n    echo \"Error: Failed to start server after $max_attempts attempts\" >&2\n    return 1\n}\n\nif ! wait_for_server; then\n    exit 1\nfi\n\nCUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=${HOME}/data/gsm8k/train.parquet \\\n    data.val_files=${HOME}/data/gsm8k/test.parquet \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=4 \\\n    algorithm.use_kl_in_reward=False \\\n    reward_model.reward_manager=batch \\\n    custom_reward_function.path=recipe/genrm_remote/reward_function.py \\\n    custom_reward_function.name=compute_score_batch \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name='qwen2.5-0.5b-gen-rm' \\\n    trainer.n_gpus_per_node=4 \\\n    trainer.val_before_train=False \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=10 \\\n    trainer.resume_mode='disable' \\\n    trainer.total_training_steps=1\n"
  },
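  {
    "path": "verl_rl/tests/special_e2e/examples/genrm_client_sketch.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# A sketch of how a batch reward function might query the vLLM server that\n# run_genrm_remote.sh launches on port 30000. The real scorer lives in\n# recipe/genrm_remote/reward_function.py; the judging prompt, the YES/NO\n# parsing, and the batch signature below are all assumptions.\nimport requests\n\nURL = \"http://localhost:30000/v1/chat/completions\"\n\n\ndef judge_one(solution_str, ground_truth):\n    prompt = (\n        f\"Reference answer: {ground_truth}\\n\"\n        f\"Candidate answer: {solution_str}\\n\"\n        \"Reply with exactly YES if the candidate is correct, otherwise NO.\"\n    )\n    resp = requests.post(\n        URL,\n        json={\"model\": \"genrm-demo\", \"messages\": [{\"role\": \"user\", \"content\": prompt}]},\n        timeout=60,\n    )\n    text = resp.json()[\"choices\"][0][\"message\"][\"content\"]\n    return 1.0 if \"YES\" in text else 0.0\n\n\ndef compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos=None):\n    return [judge_one(s, g) for s, g in zip(solution_strs, ground_truths)]\n"
  },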
  {
    "path": "verl_rl/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nhuggingface-cli download Qwen/Qwen2.5-VL-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-VL-3B-Instruct\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nFSDP_STRATEGY=${FSDP_STRATEGY:-fsdp}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='geo3k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=64 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='geo3k_async_rl' \\\n    trainer.experiment_name=qwen2.5-vl-3b_function_rm-geo3k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0619-verify-n8 \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    data.train_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml\" \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_grpo_lora_with_merge.sh",
    "content": "#!/usr/bin/env bash\n#\n#  An e2e test script for testing the GRPO LoRA training process \n#  and processing the generated checkpoint using the merge_model.py script.  \n\nset -xeuo pipefail\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nif [ ! -d \"$MODEL_PATH\" ]; then\n    echo \"Downloading model to ${MODEL_PATH}...\"\n    huggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\nelse\n    echo \"Model directory ${MODEL_PATH} already exists, skip downloading.\"\nfi\n\n\nBATCH_SIZE=16\nEXP_NAME=\"qwen2.5_0.5b_grpo_lora\"\n# step 1. train model with grpo-lora for 1 step\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=${BATCH_SIZE} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=${MODEL_PATH} \\\n    actor_rollout_ref.model.use_shm=True \\\n    actor_rollout_ref.model.lora_rank=64 \\\n    actor_rollout_ref.model.lora_alpha=32 \\\n    actor_rollout_ref.actor.optim.lr=3e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${BATCH_SIZE} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.rollout.load_format=safetensors \\\n    actor_rollout_ref.rollout.layered_summon=True \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name=${EXP_NAME} \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.total_training_steps=1 \\\n    trainer.save_freq=1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 $@\n\n# step 2. merge model\npython3 -m verl.model_merger merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/ \\\n    --target_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf\n\n# step 3. assert\n# make sure adapter_model.safetensors exists and its size is larger than 1MB\nfile_path=\"checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf/lora_adapter/adapter_model.safetensors\"\n\nif [ ! 
-f \"$file_path\" ]; then\n    echo \"Error: File $file_path does not exist!\"\n    exit 1\nfi\n\nfile_size=$(stat -c %s \"$file_path\")\n\nmin_size_mb=1\nmin_size=$((min_size_mb * 1024 * 1024))  # 1MB = 1048576 bytes\n\nif [ \"$file_size\" -lt \"$min_size\" ]; then\n    echo \"Error: File $file_path is too small! Current size: $((file_size/1024))KB, Required: ${min_size_mb}MB\"\n    exit 1\nfi\n\necho \"Check passed: File exists and size is $(($file_size/1024/1024))MB\"\nexit 0\n"
  },
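  {
    "path": "verl_rl/tests/special_e2e/examples/load_lora_adapter_demo.py",
    "content": "# Editor's note: hypothetical illustration, not part of the verl distribution.\n# Sanity-loads the adapter that run_grpo_lora_with_merge.sh asserts on; the\n# paths mirror that script and peft must be installed.\nfrom peft import PeftModel\nfrom transformers import AutoModelForCausalLM\n\nif __name__ == \"__main__\":\n    base = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-0.5B\")\n    adapter_dir = (\n        \"checkpoints/verl_grpo_example_gsm8k/qwen2.5_0.5b_grpo_lora/\"\n        \"global_step_1/actor/hf/lora_adapter\"\n    )\n    model = PeftModel.from_pretrained(base, adapter_dir)\n    print(model.peft_config)\n"
  },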
  {
    "path": "verl_rl/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh",
    "content": "# run on 8xH20\n# make sure your current working directory is the root of the project\n\nset -x\n\n\nexport PYTHONUNBUFFERED=1\nexport RAY_DEDUP_LOGS=0\nexport RUST_BACKTRACE=1\nexport HYDRA_FULL_ERROR=1\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_sf_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=16384 \\\n    data.filter_overlong_prompts=False \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    data.train_files=$HOME/data/retool_dapo/train.parquet \\\n    data.val_files=$HOME/data/retool_aime2024/train.parquet \\\n    actor_rollout_ref.model.path=Qwen/Qwen3-4B \\\n    actor_rollout_ref.actor.use_dynamic_bsz=True \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_liger=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    +actor_rollout_ref.model.enable_activation_offloading=True \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=128 \\\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.actor.kl_loss_coef=0.0 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml\" \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger='[\"console\",\"wandb\"]' \\\n    trainer.project_name='retool_async_rl' \\\n    trainer.experiment_name='qwen3-4b_function_rm-retool-async-sgl-no-sft-n8-v2505271300' \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=100 \\\n    trainer.test_freq=20 \\\n    trainer.total_training_steps=1000 \\\n    trainer.total_epochs=1 $@"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh",
    "content": "# run on 8xH100\n# make sure your current working directory is the root of the project\n\nset -x\n\nhuggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-3B-Instruct\n\nulimit -n 65535\n\nPROJECT_DIR=\"$(pwd)\"\nCONFIG_PATH=\"$PROJECT_DIR/examples/sglang_multiturn/config\"\nFSDP_STRATEGY=${FSDP_STRATEGY:-fsdp}\n\npython3 -m verl.trainer.main_ppo \\\n    --config-path=\"$CONFIG_PATH\" \\\n    --config-name='gsm8k_multiturn_grpo' \\\n    algorithm.adv_estimator=grpo \\\n    data.train_batch_size=256 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=1024 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=sglang \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \\\n    actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='gsm8k_async_rl' \\\n    trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/train.parquet \\\n    data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/test.parquet \\\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml\" \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 $@\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_one_step_off_policy.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# Test script for one_step_off_policy E2E regression testing\n# This script runs one_step_off_policy with both FSDP2 and Megatron backends\n# to ensure the asynchronous training mechanism works correctly\n\nNUM_GPUS=${NUM_GPUS:-8}\nACTOR_STRATEGY=${ACTOR_STRATEGY:-\"fsdp2\"}  # fsdp2 or megatron\n\n# Download model if not exists\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\n# Algorithm parameters\nadv_estimator=grpo\n\nuse_kl_in_reward=False\nkl_coef=0.0\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\n# Response length parameters\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\n# Training parameters\nloss_agg_mode=\"token-mean\"\ntrain_prompt_bsz=8\nn_resp_per_prompt=3\ntrain_prompt_mini_bsz=4\n\n# Temperature parameters\ntemperature=1.0\ntop_p=1.0\ntop_k=-1\nval_top_p=0.7\n\n# One-step-off-policy specific parameters\n# Allocate 2 GPUs for rollout, remaining for training\nn_gpus_rollout=2\nn_gpus_training=$((NUM_GPUS - n_gpus_rollout))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-one-step-off-policy-${ACTOR_STRATEGY}-minimal\"\n\necho \"Running one_step_off_policy with ${ACTOR_STRATEGY} strategy\"\necho \"Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}\"\n\n# Common parameters for both FSDP2 and Megatron\ncommon_params=(\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\"\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\"\n    data.prompt_key=prompt\n    data.truncation='left'\n    data.max_prompt_length=${max_prompt_length}\n    data.max_response_length=${max_response_length}\n    data.train_batch_size=${train_prompt_bsz}\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt}\n    algorithm.adv_estimator=${adv_estimator}\n    algorithm.use_kl_in_reward=${use_kl_in_reward}\n    algorithm.kl_ctrl.kl_coef=${kl_coef}\n    actor_rollout_ref.hybrid_engine=False \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}\n    actor_rollout_ref.actor.clip_ratio_c=10.0\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\"\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.optim.lr_warmup_steps=-1\n    actor_rollout_ref.actor.optim.weight_decay=0.1\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.80\n    actor_rollout_ref.rollout.temperature=${temperature}\n    actor_rollout_ref.rollout.top_p=${top_p}\n    actor_rollout_ref.rollout.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}\n    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}\n    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    actor_rollout_ref.rollout.enable_chunked_prefill=True \\\n    reward_model.reward_manager=dapo\n    
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}\n    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False\n    +reward_model.reward_kwargs.max_resp_len=${max_response_length}\n    trainer.logger=['console']\n    trainer.project_name='verl-test'\n    trainer.experiment_name=\"${exp_name}\"\n    trainer.val_before_train=False\n    trainer.test_freq=-1\n    trainer.save_freq=-1\n    trainer.total_epochs=2\n    trainer.total_training_steps=2\n    trainer.resume_mode=disable\n    trainer.nnodes=1\n    trainer.n_gpus_per_node=${n_gpus_training}\n    rollout.nnodes=1\n    rollout.n_gpus_per_node=${n_gpus_rollout}\n\n)\n\nif [ \"${ACTOR_STRATEGY}\" == \"fsdp2\" ]; then\n    echo \"Running with FSDP2 strategy...\"\n    # FSDP2 specific parameters\n    gen_tp=2\n    sp_size=2\n    fsdp_size=2\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.one_step_off_policy.main_ppo \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.actor.strategy=fsdp2 \\\n        critic.strategy=fsdp2 \\\n        actor_rollout_ref.actor.grad_clip=1.0 \\\n        actor_rollout_ref.model.use_remove_padding=True \\\n        actor_rollout_ref.actor.use_dynamic_bsz=True \\\n        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \\\n        actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \\\n        actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@\n\nelif [ \"${ACTOR_STRATEGY}\" == \"megatron\" ]; then\n    echo \"Running with Megatron strategy...\"\n    # Megatron specific parameters\n    gen_tp=2\n    train_tp=1\n    train_pp=2\n    ref_offload=True\n    actor_offload=False\n\n    python3 -m recipe.one_step_off_policy.main_ppo \\\n        --config-path=config \\\n        --config-name='one_step_off_ppo_megatron_trainer.yaml' \\\n        \"${common_params[@]}\" \\\n        actor_rollout_ref.actor.strategy=megatron \\\n        critic.strategy=megatron \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n        actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \\\n        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \\\n        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \\\n        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \\\n        actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@\nelse\n    echo \"Error: Unknown strategy 
${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'\"\n    exit 1\nfi\n\necho \"One-step-off-policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy\""
  },
  {
    "path": "verl_rl/tests/special_e2e/run_ppo_trainer_megatron.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping\nexport VERL_LOGGING_LEVEL=INFO\nexport VERL_PPO_LOGGING_LEVEL=INFO\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nUSE_DUMMY_MODEL=${USE_DUMMY_MODEL:-False}\nDUMMY_MODEL_PATH=${DUMMY_MODEL_PATH:-${HOME}/dummy_models/${MODEL_ID}}\nif [ \"$USE_DUMMY_MODEL\" = \"True\" ]; then\n    if [ -z \"${DUMMY_MODEL_CONFIG_PATH}\"  ]; then\n        echo \"[ERROR] DUMMY_MODEL_CONFIG_PATH not set\"\n        exit 1\n    fi\n    \n    python scripts/init_random_model.py \\\n        --hf_model_path \"${MODEL_PATH}\" \\\n        --new_config_path \"${DUMMY_MODEL_CONFIG_PATH}\" \\\n        --output_path \"${DUMMY_MODEL_PATH}\"\n\n    MODEL_PATH=\"${DUMMY_MODEL_PATH}\"\nfi\n\nTRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet}\n\nADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}\n# Validation\nVAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}\nTEST_FREQ=${TEST_FREQ:--1}\n# Save & Resume\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:--1}\nTOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}\n\nUSE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True}\nppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN:-2400}\nforward_max_token_len_per_gpu=${FWD_MAX_TOKEN_LEN:-4800}\ntrain_traj_micro_bsz_per_gpu=${MICRO_BSZ:-2} # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / 
g\n\nMAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512}\nMAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512}\n\nCOMMON_PP=${COMMON_PP:-2}\nCOMMON_VPP=${COMMON_VPP:-2}\nCOMMON_CP=${COMMON_CP:-2}\nCOMMON_TP=${COMMON_TP:-2}\nCOMMON_EP=${COMMON_EP:-1}\nCOMMON_ETP=${COMMON_ETP:-null}\n\nTRAIN_TP=${TRAIN_TP:-$COMMON_TP}\nINFER_TP=${INFER_TP:-$COMMON_TP}\n\nACTOR_PP=${ACTOR_PP:-$COMMON_PP}\nACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP}\nACTOR_CP=${ACTOR_CP:-$COMMON_CP}\nACTOR_TP=${ACTOR_TP:-$TRAIN_TP}\nACTOR_EP=${ACTOR_EP:-$COMMON_EP}\nACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP}\nROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP}\nREF_PP=${REF_PP:-$COMMON_PP}\nREF_VPP=${REF_VPP:-$COMMON_VPP}\nREF_CP=${REF_CP:-$COMMON_CP}\nREF_TP=${REF_TP:-$TRAIN_TP}\nREF_EP=${REF_EP:-$COMMON_EP}\nREF_ETP=${REF_ETP:-$COMMON_ETP}\nCRITIC_PP=${CRITIC_PP:-$COMMON_PP}\nCRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP}\nCRITIC_CP=${CRITIC_CP:-$COMMON_CP}\nCRITIC_TP=${CRITIC_TP:-$TRAIN_TP}\nCRITIC_EP=${CRITIC_EP:-$COMMON_EP}\nCRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP}\nRM_PP=${RM_PP:-$COMMON_PP}\nRM_VPP=${RM_VPP:-$COMMON_VPP}\nRM_CP=${RM_CP:-$COMMON_CP}\nRM_TP=${RM_TP:-$TRAIN_TP}\nRM_EP=${RM_EP:-$COMMON_EP}\nRM_ETP=${RM_ETP:-$COMMON_ETP}\n\nALL_OFFLOAD=${ALL_OFFLOAD:-False}\nCOMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}\nCOMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}\n\nACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nREF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nCRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}\nCRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}\nRM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}\nUSE_MBRIDGE=${USE_MBRIDGE:-False}\nUSE_FUSED_KERNELS=${USE_FUSED_KERNELS:-False}\n\nLR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null}\n\nCHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']\nSKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}\nif [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then\n    CHECKPOINT_CONTENTS=['model','optimizer','extra']\nfi\n\nUSE_DIST_CKPT=${USE_DIST_CKPT:-False}\nDIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/${MODEL_ID}}\nif [ \"$USE_DIST_CKPT\" = \"True\" ]; then\n    if [ \"$USE_DUMMY_MODEL\" = \"True\" ]; then\n        DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID}\n    fi\n    python scripts/converter_hf_to_mcore.py \\\n        --hf_model_path \"${MODEL_PATH}\" \\\n        --output_path \"${DIST_CKPT_PATH}\"\nfi\n\nENGINE=${ENGINE:-\"vllm\"}\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-megatron-gsm8k-minimal\"\n\nif [ \"$ENGINE\" = \"vllm\" ]; then\n    MODE=${MODE:-\"sync\"}\n    ROLLOUT_MODE_ARG=\"actor_rollout_ref.rollout.mode=${MODE}\"\n    if [ \"$MODE\" = \"async\" ]; then\n        ROLLOUT_MODE_ARG=\"${ROLLOUT_MODE_ARG} data.return_raw_chat=True\"\n    fi\nelse\n    ROLLOUT_MODE_ARG=\"\"\nfi\n\npython3 -m verl.trainer.main_ppo --config-path=config \\\n    --config-name='ppo_megatron_trainer.yaml'\\\n    algorithm.adv_estimator=\"${ADV_ESTIMATOR}\" \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=${MAX_PROMPT_LENGTH} \\\n    data.max_response_length=${MAX_RESPONSE_LENGTH} \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' 
\\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.model.use_fused_kernels=${USE_FUSED_KERNELS} \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \\\n    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \\\n    actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \\\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \\\n    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \\\n    actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \\\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \\\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$ACTOR_EP \\\n    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ACTOR_ETP \\\n    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \\\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \\\n    actor_rollout_ref.rollout.name=\"${ENGINE}\" ${ROLLOUT_MODE_ARG}\\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \\\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \\\n    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \\\n    actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \\\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \\\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$REF_EP \\\n    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$REF_ETP \\\n    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \\\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    critic.optim.lr=2e-5 \\\n    critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \\\n    critic.model.path=\"${MODEL_PATH}\" \\\n    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \\\n    critic.megatron.use_mbridge=${USE_MBRIDGE} \\\n    critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \\\n    critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \\\n    critic.megatron.context_parallel_size=$CRITIC_CP \\\n    
critic.megatron.tensor_model_parallel_size=$CRITIC_TP \\\n    critic.megatron.expert_model_parallel_size=$CRITIC_EP \\\n    critic.megatron.expert_tensor_parallel_size=$CRITIC_ETP \\\n    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \\\n    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \\\n    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \\\n    critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \\\n    reward_model.enable=True \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    reward_model.megatron.use_mbridge=${USE_MBRIDGE} \\\n    reward_model.megatron.pipeline_model_parallel_size=$RM_PP \\\n    reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \\\n    reward_model.megatron.context_parallel_size=$RM_CP \\\n    reward_model.megatron.tensor_model_parallel_size=$RM_TP \\\n    reward_model.megatron.expert_model_parallel_size=$RM_EP \\\n    reward_model.megatron.expert_tensor_parallel_size=$RM_ETP \\\n    reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \\\n    reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \\\n    reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \\\n    algorithm.use_kl_in_reward=False \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.val_before_train=\"${VAL_BEFORE_TRAIN}\" \\\n    trainer.test_freq=\"${TEST_FREQ}\" \\\n    trainer.save_freq=\"${SAVE_FREQ}\" \\\n    trainer.resume_mode=\"${RESUME_MODE}\" \\\n    trainer.total_epochs=2 \\\n    trainer.total_training_steps=\"${TOTAL_TRAIN_STEPS}\" $@\n"
  },
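The `b`/`n`/`g` bookkeeping in this script (repeated verbatim in run_prime.sh and the DAPO script below) is easy to misread, not least because the shell multiplies by `n_resp_per_prompt` while the trailing comments read `/ g`. A worked check of the arithmetic exactly as the script computes it, using the defaults:

```python
# Reproduce the script's batch-size arithmetic with its defaults.
# b = trajectory micro batch per GPU, n = NUM_GPUS, g = responses per prompt.
b, n, g = 2, 8, 4  # MICRO_BSZ, NUM_GPUS, n_resp_per_prompt

train_traj_micro_bsz = b * n                    # 16 trajectories per micro step
train_traj_mini_bsz = train_traj_micro_bsz * 2  # 32 trajectories per mini batch
# Note: the shell multiplies by g here even though its comment reads "/ g",
# so the prompt-level mini batch comes out to 128, not 8.
train_prompt_mini_bsz = train_traj_mini_bsz * g
train_prompt_bsz = train_prompt_mini_bsz * 2    # 256, fed to data.train_batch_size

# The mini batch must evenly divide the train batch for PPO-style updates.
assert train_prompt_bsz % train_prompt_mini_bsz == 0
print(train_traj_micro_bsz, train_traj_mini_bsz, train_prompt_mini_bsz, train_prompt_bsz)
```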
  {
    "path": "verl_rl/tests/special_e2e/run_prime.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet}\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-prime-minimal\"\n\npython3 -m recipe.prime.main_prime \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=512 \\\n    data.filter_accuracy=True \\\n    data.accuracy_lower_bound=0.2 \\\n    data.accuracy_upper_bound=0.8 \\\n    data.oversample_factor=4 \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=False \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.adv_estimator=rloo \\\n    algorithm.use_kl_in_reward=True \\\n    algorithm.kl_penalty=kl \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    reward_model.model.path=\"${MODEL_PATH}\" \\\n    reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    reward_model.model.update=before \\\n    reward_model.model.beta_train=0.05 \\\n    reward_model.model.optim.lr=1e-6 \\\n    reward_model.model.optim.grad_clip=10.0 \\\n    reward_model.model.input_tokenizer=null \\\n    reward_model.mini_batch_size=${train_prompt_bsz} \\\n    reward_model.reward_manager=prime \\\n    trainer.val_before_train=False \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_training_steps=1 $@\n"
  },
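run_prime.sh turns on accuracy-based prompt filtering (`data.filter_accuracy=True` with bounds 0.2/0.8 and `data.oversample_factor=4`): prompts whose sampled responses are nearly always right or nearly always wrong carry little training signal, so batches are oversampled and then filtered. A sketch of that rule, assuming inclusive bounds (the function name and data layout are illustrative, not the recipe's code):

```python
def filter_by_accuracy(per_prompt_correct: list[list[bool]],
                       lower: float = 0.2, upper: float = 0.8) -> list[int]:
    """Return indices of prompts whose mean accuracy lies within [lower, upper]."""
    kept = []
    for i, outcomes in enumerate(per_prompt_correct):
        acc = sum(outcomes) / len(outcomes)
        if lower <= acc <= upper:
            kept.append(i)
    return kept


# Four rollouts per prompt, matching n_resp_per_prompt=4 in the script:
outcomes = [
    [True, True, True, True],      # acc 1.0 -> dropped (too easy)
    [True, False, False, True],    # acc 0.5 -> kept
    [False, False, False, False],  # acc 0.0 -> dropped (too hard)
]
assert filter_by_accuracy(outcomes) == [1]
```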
  {
    "path": "verl_rl/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nhuggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \\\n    --local-dir $HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\n\npython3 -m verl.trainer.main_generation \\\n    trainer.nnodes=1 \\\n    trainer.n_gpus_per_node=8 \\\n    data.path=$HOME/data/r1/test.parquet \\\n    data.prompt_key=prompt \\\n    data.batch_size=1024 \\\n    data.n_samples=1 \\\n    data.output_path=$HOME/data/r1/test-output-k1.parquet \\\n    model.path=$HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \\\n    rollout.temperature=0.6 \\\n    rollout.top_p=0.95 \\\n    rollout.prompt_length=1024 \\\n    rollout.response_length=32768 \\\n    rollout.tensor_model_parallel_size=1 \\\n    rollout.gpu_memory_utilization=0.95 \\\n    rollout.max_num_batched_tokens=65536 \\\n    rollout.enforce_eager=False \\\n    rollout.free_cache_engine=True\n\npython3 -m recipe.r1.main_eval \\\n    data.path=$HOME/data/r1/test-output-k1.parquet \\\n    data.prompt_key=prompt \\\n    data.response_key=responses \\\n    custom_reward_function.path=recipe/r1/reward_score.py \\\n    custom_reward_function.name=reward_func"
  },
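The AIME24 check runs in two stages: `main_generation` writes sampled responses to a parquet file, then `recipe.r1.main_eval` scores them with a custom reward function; with `data.n_samples=1` the reported metric is effectively pass@1. A sketch of the scoring stage under assumed column names (`responses`, `ground_truth`) and a stand-in scorer rather than `recipe/r1/reward_score.py`:

```python
import pandas as pd


def reward_func(response: str, ground_truth: str) -> float:
    """Stand-in scorer: 1.0 on exact match, 0.0 otherwise."""
    return float(response.strip() == ground_truth.strip())


def mean_best_score(parquet_path: str, truth_column: str = "ground_truth") -> float:
    """Average the best score per prompt; with n_samples=1 this is pass@1."""
    df = pd.read_parquet(parquet_path)
    scores = []
    for _, row in df.iterrows():
        # "responses" holds the list of n_samples generations for one prompt.
        scores.append(max(reward_func(r, row[truth_column]) for r in row["responses"]))
    return sum(scores) / len(scores)


# e.g. mean_best_score("$HOME/data/r1/test-output-k1.parquet")
```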
  {
    "path": "verl_rl/tests/special_e2e/run_spin.sh",
    "content": "set -e\nset -x\nNUM_GPUS=${NUM_GPUS:-8}\n\nexp_name=\"Qwen2.5-0.5B-Instruct-spin-minimal\"\n\nCUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \\\n  data.train_files=$HOME/data/gsm8k/train.parquet \\\n  data.val_files=$HOME/data/gsm8k/test.parquet \\\n  data.train_batch_size=1024 \\\n  data.max_prompt_length=1024 \\\n  data.max_response_length=1024 \\\n  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n  actor_rollout_ref.actor.optim.lr=1e-6 \\\n  actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n  actor_rollout_ref.actor.ppo_micro_batch_size=8 \\\n  actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \\\n  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n  actor_rollout_ref.ref.log_prob_micro_batch_size=64 \\\n  algorithm.kl_ctrl.kl_coef=0.001 \\\n  trainer.logger=console \\\n  trainer.val_before_train=False \\\n  trainer.n_gpus_per_node=4 \\\n  trainer.nnodes=1 \\\n  trainer.save_freq=-1 \\\n  trainer.test_freq=1 \\\n  +trainer.log_freq=1 \\\n  trainer.ref_update_freq=1 \\\n  trainer.total_training_steps=1 \\\n  trainer.total_epochs=1000 2>&1 | tee verl_demo.log"
  },
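SPIN plays the current policy against the previous iteration's model, which is presumably why the script passes `trainer.ref_update_freq=1`: the reference ("opponent") weights are refreshed from the policy every step. A minimal PyTorch sketch of that refresh schedule (plain modules, not the recipe's trainer):

```python
import torch.nn as nn


def maybe_refresh_ref(policy: nn.Module, ref: nn.Module, step: int, ref_update_freq: int) -> None:
    """Copy policy weights into the frozen reference every `ref_update_freq` steps."""
    if ref_update_freq > 0 and step % ref_update_freq == 0:
        ref.load_state_dict(policy.state_dict())
        for p in ref.parameters():
            p.requires_grad_(False)  # the opponent is never trained directly


policy, ref = nn.Linear(4, 4), nn.Linear(4, 4)
for step in range(1, 4):
    # ... one SPIN update of `policy` against `ref` would run here ...
    maybe_refresh_ref(policy, ref, step, ref_update_freq=1)
```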
  {
    "path": "verl_rl/tests/special_e2e/run_sppo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\n# in e2e_sppo.yml, we set NUM_GPUS=8 L20\n\nNUM_GPUS=${NUM_GPUS:-8}\n\ngsm8k_train_path=./data/math/train.parquet\ngsm8k_test_path=./data/math/test.parquet\ntrain_files=\"['$gsm8k_train_path']\"\ntest_files=\"['$gsm8k_test_path']\"\n\nexp_name=\"Qwen2.5-0.5B-Instruct-sppo-minimal\"\n\npython3 -m recipe.sppo.main_sppo \\\n    data.train_files=\"$train_files\" \\\n    data.val_files=\"$test_files\" \\\n    data.train_batch_size=1024 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=512 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.return_raw_chat=True \\\n    actor_rollout_ref.model.path=\"./models/Qwen2.5-0.5B-Instruct\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.actor.use_kl_loss=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=sglang  \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.val_before_train=False \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_training_steps=1 \\\n    trainer.total_epochs=2 $@\n"
  },
  {
    "path": "verl_rl/tests/special_e2e/run_test.sh",
    "content": "#!/bin/bash\nset -xeuo pipefail\n\n# Get the configuration name and engine name from arguments\nCONFIG_NAME=\"$1\"\nENGINE=\"${2:-vllm}\"\n\n# Download model if needed\nhuggingface-cli download Qwen/Qwen2.5-0.5B --local-dir \"$HOME/models/Qwen/Qwen2.5-0.5B\"\n\n# Run the training with the specified configuration\npython3 -m verl.trainer.main_ppo \\\n    --config-name \"$CONFIG_NAME\" \"$@\" "
  },
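run_test.sh consumes two positional arguments and forwards everything else to Hydra. The same launcher-args/passthrough-overrides split is easy to express in Python with `argparse.parse_known_args` (a generic sketch, not verl code):

```python
import argparse


def split_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
    """Consume launcher-level options; return everything else for Hydra."""
    parser = argparse.ArgumentParser()
    parser.add_argument("config_name")
    parser.add_argument("engine", nargs="?", default="vllm")
    return parser.parse_known_args(argv)


known, passthrough = split_args(["e2e_ppo_trainer", "sglang", "trainer.total_epochs=1"])
assert known.engine == "sglang"
assert passthrough == ["trainer.total_epochs=1"]
```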
  {
    "path": "verl_rl/tests/special_e2e/sft/run_sft.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nENTRYPOINT=${ENTRYPOINT:-\"-m verl.trainer.fsdp_sft_trainer\"}\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nTRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}\nVAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}\n\nSP_SIZE=${SP_SIZE:-1}\nLIGER=${LIGER:-False}\nMULTITURN=${MULTITURN:-False}\nLORA_RANK=${LORA_RANK:-0}\nRM_PAD=${RM_PAD:-True}\n\nTOTAL_TRAIN_STEP=${TOTAL_TRAIN_STEP:-1}\nRESUME_MODE=${RESUME_MODE:-disable}\nSAVE_FREQ=${SAVE_FREQ:-1}\n\nmicro_bsz=2\nNUM_GPUS=8\n\nproject_name=\"verl-test\"\nexp_name=\"$(basename \"${MODEL_ID,,}\")-sft-minimal\"\nckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}}\n\nmkdir -p \"${ckpts_home}\"\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \\\n    data.train_files=\"${TRAIN_FILES}\" \\\n    data.val_files=\"${VAL_FILES}\" \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    data.prompt_dict_keys=['question'] \\\n    data.response_dict_keys=['answer'] \\\n    data.multiturn.enable=\"${MULTITURN}\" \\\n    data.multiturn.messages_key=messages \\\n    optim.lr=1e-4 \\\n    data.micro_batch_size_per_gpu=${micro_bsz} \\\n    model.strategy=fsdp \\\n    model.partial_pretrain=\"${MODEL_PATH}\" \\\n    model.lora_rank=\"${LORA_RANK}\" \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.use_liger=\"${LIGER}\" \\\n    ulysses_sequence_parallel_size=\"${SP_SIZE}\" \\\n    use_remove_padding=\"${RM_PAD}\" \\\n    trainer.default_local_dir=\"${ckpts_home}\" \\\n    trainer.project_name=\"${project_name}\" \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.total_training_steps=${TOTAL_TRAIN_STEP} \\\n    trainer.save_freq=${SAVE_FREQ} \\\n    trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \\\n    trainer.max_ckpt_to_keep=1 \\\n    trainer.resume_mode=${RESUME_MODE} \\\n    trainer.logger=['console'] $@\n\nrm -rf \"${ckpts_home:?}/*\""
  },
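Note how the SFT data plumbing works here: both `data.prompt_key` and `data.response_key` point at the `extra_info` column, and `prompt_dict_keys`/`response_dict_keys` select nested fields inside it. What that two-level lookup amounts to (the helper is illustrative; the real logic lives in verl's SFT dataset code):

```python
def extract_field(row: dict, top_key: str, dict_keys: list[str]) -> str:
    """Follow top_key, then each key in dict_keys, into a nested record."""
    value = row[top_key]
    for k in dict_keys:
        value = value[k]
    return value


row = {"extra_info": {"question": "What is 2 + 2?", "answer": "4"}}
prompt = extract_field(row, "extra_info", ["question"])
response = extract_field(row, "extra_info", ["answer"])
assert (prompt, response) == ("What is 2 + 2?", "4")
```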
  {
    "path": "verl_rl/tests/special_e2e/sft/test_sp_loss_match.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):\n    \"\"\"Test consistency between original forward pass and SP+rmpad forward passes.\n\n    Args:\n        trainer: The FSDPSFTTrainer instance to test\n        total_steps: Number of steps to test (default: 4)\n    \"\"\"\n    if trainer.device_mesh.get_rank() == 0:\n        print(\"\\nStarting debug comparison between original and SP+rmpad forward passes...\")\n        print(f\"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}\")\n        print(f\"Remove padding: {trainer.use_remove_padding}\\n\")\n\n    steps_remaining = total_steps\n\n    for epoch in range(1):  # Just one epoch for testing\n        trainer.train_sampler.set_epoch(epoch=epoch)\n        for data in trainer.train_dataloader:\n            data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()\n            trainer.fsdp_model.train()\n            micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu)\n\n            for idx, micro_batch in enumerate(micro_batches):\n                if trainer.device_mesh.get_rank() == 0:\n                    print(f\"\\nProcessing micro batch {idx + 1}/{len(micro_batches)}\")\n\n                # Compute losses using both methods\n                # Disable SP and rmpad\n                trainer.use_remove_padding = False\n                old_sp = trainer.config.ulysses_sequence_parallel_size\n                trainer.config.ulysses_sequence_parallel_size = 1\n                loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)\n\n                # Do SP and rmpad\n                trainer.config.ulysses_sequence_parallel_size = old_sp\n                trainer.use_remove_padding = True\n                loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)\n\n                # Collect losses across all ranks\n                loss_ref_all = loss_ref.clone()\n                loss_sp_all = loss_sp.clone()\n                torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)\n                torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)\n\n                # Calculate relative difference of averaged losses\n                rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)\n\n                if trainer.device_mesh.get_rank() == 0:\n                    print(\"\\nComparison Results (Averaged across ranks):\")\n                    print(f\"Reference Loss: {loss_ref_all.item():.6f}\")\n                    print(f\"SP+rmpad 
Loss: {loss_sp_all.item():.6f}\")\n                    print(f\"Relative Difference: {rel_diff.item():.6f}\")\n\n                    assert rel_diff.item() < 1e-2, \"Significant difference detected between averaged losses!\"\n                    print(\"Loss difference is within the acceptable range.\")\n\n                steps_remaining -= 1\n                if steps_remaining == 0:\n                    break\n            if steps_remaining == 0:\n                break\n        break\n\n    if trainer.device_mesh.get_rank() == 0:\n        print(\"\\nDebug comparison completed successfully.\")\n\n\ndef create_trainer(config):\n    \"\"\"Create and initialize a trainer instance with the given config.\n\n    Args:\n        config: Configuration object with training parameters\n\n    Returns:\n        FSDPSFTTrainer: Initialized trainer instance\n    \"\"\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    device_mesh = init_device_mesh(device_type=\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"fsdp\",))\n\n    dp_size = world_size // config.ulysses_sequence_parallel_size\n    ulysses_device_mesh = init_device_mesh(\n        device_type=\"cuda\", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=(\"dp\", \"sp\")\n    )\n\n    # build tokenizer and datasets first\n    from verl.trainer.fsdp_sft_trainer import create_sft_dataset\n    from verl.utils import hf_tokenizer\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)\n    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)\n    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)\n    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)\n\n    return FSDPSFTTrainer(\n        config=config,\n        device_mesh=device_mesh,\n        ulysses_device_mesh=ulysses_device_mesh,\n        tokenizer=tokenizer,\n        train_dataset=train_dataset,\n        val_dataset=val_dataset,\n    )\n\n\ndef main(config):\n    \"\"\"Main function to run trainer tests.\n\n    Args:\n        config: Configuration object with training parameters\n    \"\"\"\n    trainer = create_trainer(config)\n    test_trainer_forward_consistency(trainer)\n\n\nif __name__ == \"__main__\":\n    import hydra\n    from omegaconf import DictConfig\n\n    @hydra.main(config_path=\"../../../verl/trainer/config\", config_name=\"sft_trainer\")\n    def hydra_entry(cfg: DictConfig) -> None:\n        main(cfg)\n\n    hydra_entry()\n"
  },
  {
    "path": "verl_rl/tests/special_npu/run_qwen2_5_05b_dapo.sh",
    "content": "#!/usr/bin/env bash\nset -xeuo pipefail\n\nNUM_GPUS=${NUM_GPUS:-8}\n\nMODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}\nMODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}\nhuggingface-cli download \"${MODEL_ID}\" --local-dir \"${MODEL_PATH}\"\n\nadv_estimator=grpo\n\nkl_coef=0.0\nuse_kl_in_reward=False\nuse_kl_loss=False\nkl_loss_coef=0.0\n\nclip_ratio_low=0.2\nclip_ratio_high=0.28\n\nmax_prompt_length=1024\nmax_response_length=2048\nenable_overlong_buffer=True\noverlong_buffer_len=128\noverlong_penalty_factor=1.0\n\nloss_agg_mode=\"token-mean\"\n\nenable_filter_groups=True\nfilter_groups_metric=seq_reward\nmax_num_gen_batches=10\n\ntrain_traj_micro_bsz_per_gpu=2 # b\nn_resp_per_prompt=4 # g\n\ntrain_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n\ntrain_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n\ntrain_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g\ntrain_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g\n\ngen_prompt_bsz=$((train_prompt_bsz * 4))\n\nexp_name=\"$(basename \"${MODEL_ID,,}\")-dapo-minimal\"\n\npython3 -m recipe.dapo.main_dapo \\\n    data.train_files=\"${HOME}/data/gsm8k/train.parquet\" \\\n    data.val_files=\"${HOME}/data/gsm8k/test.parquet\" \\\n    reward_model.reward_manager=dapo \\\n    algorithm.adv_estimator=${adv_estimator} \\\n    algorithm.use_kl_in_reward=${use_kl_in_reward} \\\n    algorithm.kl_ctrl.kl_coef=${kl_coef} \\\n    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \\\n    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \\\n    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \\\n    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \\\n    data.max_prompt_length=${max_prompt_length} \\\n    data.max_response_length=${max_response_length} \\\n    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \\\n    reward_model.overlong_buffer.len=${overlong_buffer_len} \\\n    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \\\n    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \\\n    data.train_batch_size=${train_prompt_bsz} \\\n    data.gen_batch_size=${gen_prompt_bsz} \\\n    algorithm.filter_groups.enable=${enable_filter_groups} \\\n    algorithm.filter_groups.metric=${filter_groups_metric} \\\n    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \\\n    actor_rollout_ref.model.path=\"${MODEL_PATH}\" \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=True \\\n    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \\\n    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \\\n    
actor_rollout_ref.actor.entropy_checkpointing=True \\\n    actor_rollout_ref.ref.entropy_checkpointing=True \\\n    actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \\\n    actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \\\n    trainer.logger=console \\\n    trainer.project_name='verl-test' \\\n    trainer.experiment_name=\"${exp_name}\" \\\n    trainer.n_gpus_per_node=${NUM_GPUS} \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.total_epochs=2 \\\n    trainer.resume_mode=disable \\\n    trainer.val_before_train=False \\\n    trainer.total_training_steps=1 \\\n    trainer.device=npu $@\n"
  },
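DAPO's overlong buffer replaces hard truncation with a soft penalty: with `max_response_length=2048` and `overlong_buffer_len=128`, responses are free up to 1920 tokens and then penalized linearly up to the cap. A sketch assuming the linear shaping from the DAPO paper (the reward manager's exact code may differ):

```python
def overlong_penalty(response_len: int, max_resp_len: int = 2048,
                     buffer_len: int = 128, penalty_factor: float = 1.0) -> float:
    """Linear penalty once a response enters the last `buffer_len` tokens."""
    expected_len = max_resp_len - buffer_len   # 1920: longest unpenalized length
    exceed = response_len - expected_len
    if exceed <= 0:
        return 0.0
    # Grows linearly to -penalty_factor at the hard cap.
    return -min(exceed / buffer_len, 1.0) * penalty_factor


assert overlong_penalty(1900) == 0.0
assert overlong_penalty(1984) == -0.5
assert overlong_penalty(2048) == -1.0
```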
  {
    "path": "verl_rl/tests/special_npu/run_qwen2_5_05b_grpo.sh",
    "content": "set -x\n\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=512 \\\n    data.max_response_length=128 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=5e-7 \\\n    actor_rollout_ref.model.use_remove_padding=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_gsm8k' \\\n    trainer.experiment_name='qwen2_7b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=5 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=2 \\\n    trainer.device=npu $@\n"
  },
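With `algorithm.adv_estimator=grpo` and `rollout.n=5`, each response's advantage is its reward normalized within its own 5-response group instead of by a learned critic. A compact sketch of the standard GRPO outcome advantage (the epsilon placement is an assumption):

```python
import torch


def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_prompts, n) scalar rewards; returns group-normalized advantages."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)


# One prompt, rollout.n=5 sampled responses:
r = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0]])
adv = grpo_advantages(r)
# Correct responses get positive advantage, incorrect ones negative.
assert (adv[0, 0] > 0) and (adv[0, 1] < 0)
print(adv)
```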
  {
    "path": "verl_rl/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh",
    "content": "set -x\n\nmkdir -p ./save_ckpts\n\ntorchrun --standalone --nnodes=1 --nproc_per_node=8 \\\n     -m verl.trainer.fsdp_sft_trainer \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.prompt_key=extra_info \\\n    data.response_key=extra_info \\\n    optim.lr=1e-4 \\\n    data.prompt_dict_keys=['question'] \\\n    +data.response_dict_keys=['answer'] \\\n    data.micro_batch_size_per_gpu=32 \\\n    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \\\n    trainer.default_local_dir=./save_ckpts \\\n    trainer.project_name=gsm8k-sft \\\n    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \\\n    trainer.logger=console \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=1 $@ \\\n    model.lora_rank=32 \\\n    model.lora_alpha=16 \\\n    model.target_modules=all-linear \\\n    model.strategy=fsdp \\\n    ulysses_sequence_parallel_size=2 \\\n    use_remove_padding=true \\\n    trainer.device=npu\n\nrm -rf ./outputs ./save_ckpts\n"
  },
  {
    "path": "verl_rl/tests/special_npu/run_qwen2_5_vl_3b_npu.sh",
    "content": "set -x\nENGINE=${1:-vllm}\n\n# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, \n# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model.\nexport USE_OPTIMIZED_MODEL=0\n\npython3 -m verl.trainer.main_ppo \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/geo3k/train.parquet \\\n    data.val_files=$HOME/data/geo3k/test.parquet \\\n    data.train_batch_size=512 \\\n    data.max_prompt_length=1024 \\\n    data.max_response_length=2048 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.image_key=images \\\n    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=16 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.actor.entropy_coeff=0 \\\n    actor_rollout_ref.actor.use_torch_compile=False \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n    actor_rollout_ref.rollout.name=$ENGINE \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=True \\\n    actor_rollout_ref.rollout.free_cache_engine=True \\\n    actor_rollout_ref.rollout.n=5 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=console \\\n    trainer.project_name='verl_grpo_example_geo3k' \\\n    trainer.experiment_name='qwen2_5_vl_3b_function_rm' \\\n    trainer.n_gpus_per_node=8 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=-1 \\\n    trainer.test_freq=-1 \\\n    trainer.total_epochs=1 \\\n    trainer.total_training_steps=1 \\\n    trainer.device=npu $@"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_api_docs.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nFail CI if any function or class that is publicly exported via\n``__all__`` lacks a docstring.\n\nUsage\n-----\n  # Check specific modules or packages\n  python check_docstrings.py mypkg.core mypkg.utils\n\n  # Check an entire source tree (all top-level packages under cwd)\n  python check_docstrings.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport importlib\nimport inspect\nimport pkgutil\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType\nfrom typing import Iterable\n\n_ALLOW_LIST = [\n    \"verl.third_party.vllm.LLM\",\n    \"verl.third_party.vllm.parallel_state\",\n    \"verl.utils.profiler.WorkerProfiler\",\n    \"verl.utils.profiler.WorkerProfilerExtension\",\n    \"verl.utils.profiler.log_gpu_memory_usage\",\n    \"verl.utils.profiler.log_print\",\n    \"verl.utils.profiler.mark_annotate\",\n    \"verl.utils.profiler.mark_end_range\",\n    \"verl.utils.profiler.mark_start_range\",\n    \"verl.models.mcore.qwen2_5_vl.get_vision_model_config\",\n    \"verl.models.mcore.qwen2_5_vl.get_vision_projection_config\",\n]\n\n\ndef iter_submodules(root: ModuleType) -> Iterable[ModuleType]:\n    \"\"\"Yield *root* and every sub-module inside it.\"\"\"\n    yield root\n    if getattr(root, \"__path__\", None):  # only packages have __path__\n        for mod_info in pkgutil.walk_packages(root.__path__, prefix=f\"{root.__name__}.\"):\n            try:\n                yield importlib.import_module(mod_info.name)\n            except Exception as exc:  # noqa: BLE001\n                print(f\"[warn] Skipping {mod_info.name!r}: {exc}\", file=sys.stderr)\n\n\ndef names_missing_doc(mod: ModuleType) -> list[str]:\n    \"\"\"Return fully-qualified names that need docstrings.\"\"\"\n    missing: list[str] = []\n    public = getattr(mod, \"__all__\", [])\n    for name in public:\n        obj = getattr(mod, name, None)\n        if f\"{mod.__name__}.{name}\" in _ALLOW_LIST:\n            continue\n        if obj is None:\n            # Exported but not found in the module: flag it anyway.\n            missing.append(f\"{mod.__name__}.{name}  (not found)\")\n            continue\n\n        if inspect.isfunction(obj) or inspect.isclass(obj):\n            doc = inspect.getdoc(obj)\n            if not doc or not doc.strip():\n                missing.append(f\"{mod.__name__}.{name}\")\n    return missing\n\n\ndef check_module(qualname: str) -> list[str]:\n    \"\"\"Import *qualname* and check it (and sub-modules).\"\"\"\n    try:\n        module = importlib.import_module(qualname)\n    except ModuleNotFoundError as exc:\n        print(f\"[error] Cannot import '{qualname}': {exc}\", file=sys.stderr)\n        return [qualname]\n\n    missing: list[str] = []\n    for submod in iter_submodules(module):\n        missing.extend(names_missing_doc(submod))\n    return missing\n\n\ndef autodiscover_packages() -> list[str]:\n    \"\"\"Detect top-level 
packages under CWD when no argument is given.\"\"\"\n    pkgs: list[str] = []\n    for p in Path.cwd().iterdir():\n        if p.is_dir() and (p / \"__init__.py\").exists():\n            pkgs.append(p.name)\n    return pkgs\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(description=__doc__)\n    parser.add_argument(\n        \"modules\",\n        nargs=\"*\",\n        help=\"Fully-qualified module or package names (defaults to every top-level package found in CWD).\",\n    )\n    args = parser.parse_args()\n\n    targets = args.modules or autodiscover_packages()\n    if not targets:\n        raise ValueError(\"[error] No modules specified and none detected automatically.\")\n\n    all_missing: list[str] = []\n    for modname in targets:\n        all_missing.extend(check_module(modname))\n\n    if all_missing:\n        print(\"\\nMissing docstrings:\")\n        for name in sorted(all_missing):\n            print(f\"  - {name}\")\n        raise ValueError(\"Missing docstrings detected. Please enhance them with docs accordingly.\")\n\n    print(\"✅ All exported functions/classes have docstrings.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_device_api_usage.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`.\nSearch targets include .py files in verl/recipe and verl/verl.\nSome files that must contain \".cuda\", \"cuda\" or \"nccl\" keyword is pre-defined in whitelist below.\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\n# directory or file path must contain keyword \".cuda\" or \"cuda\"\nCUDA_KEYWORD_CHECK_WHITELIST = [\n    \"verl/utils/device.py\",\n    \"recipe/prime/prime_ray_trainer.py\",  # appear in default device_name\n    \"recipe/spin/spin_trainer.py\",  # appear in default device_name\n    \"recipe/sppo/sppo_ray_trainer.py\",  # appear in default device_name\n    \"recipe/one_step_off_policy/ray_trainer.py\",  # appear in default device_name\n    \"verl/utils/profiler/nvtx_profile.py\",  # appear in NsightSystemsProfiler\n    \"verl/utils/kernel/linear_cross_entropy.py\",  # appear in nvidia nvtx\n    \"verl/utils/rendezvous/ray_backend.py\",  # appear in cupy importance\n    \"verl/single_controller/ray/base.py\",  # appear in default device_name\n    \"verl/trainer/ppo/ray_trainer.py\",  # appear in default device_name\n    \"verl/utils/reward_score/sandbox_fusion/utils.py\",  # appear in sandbox language type\n    \"verl/workers/reward_model/megatron/reward_model.py\",  # appear in default device_name\n    \"verl/third_party/torch/distributed/_state_dict_utils.py\",  # torch monkey patch fixes\n    \"verl/third_party/torch/distributed/checkpoint/state_dict.py\",  # torch monkey patch fixes\n    \"verl/workers/engine/fsdp/engine_impl.py\",\n]\n\n# directory or file path must contain keyword \"nccl\"\nNCCL_KEYWORD_CHECK_WHITELIST = [\n    \"verl/utils/device.py\",\n    \"verl/third_party/sglang/parallel_state.py\",  # appear in default backend\n]\n\nSEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST\n\nSEARCH_KEYWORDS = [\".cuda\", '\"cuda\"', '\"nccl\"']\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--directory\", \"-d\", required=True, type=str)\n    args = parser.parse_args()\n    directory_in_str = args.directory\n\n    pathlist = Path(directory_in_str).glob(\"**/*.py\")\n    for path in pathlist:\n        path_in_str = str(path.absolute())\n\n        # judge whether current path is in pre-defined search whitelist or not.\n        path_in_whitelist = False\n\n        for sw in SEARCH_WHITELIST:\n            # for easy debugging in non-linux system\n            sw = sw.replace(\"/\", os.sep)\n            if sw in path_in_str:\n                print(f\"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.\")\n                path_in_whitelist = True\n                break\n\n        if path_in_whitelist:\n            continue\n\n        with open(path_in_str, encoding=\"utf-8\") as f:\n    
        file_content = f.read()\n\n            find_invalid_device_management = False\n\n            for sk in SEARCH_KEYWORDS:\n                if sk in file_content:\n                    find_invalid_device_management = True\n                    break\n\n            print(\n                f\"[CHECK] File {path_in_str} is checked for device api usage, check result: \"\n                f\"{'success' if not find_invalid_device_management else f'failed, detected {sk}'}.\"\n            )\n\n            assert not find_invalid_device_management, (\n                f'file {path_in_str} contains .cuda/\"cuda\"/\"nccl\" usage, please use the APIs in '\n                f\"verl/utils/device.py instead.\"\n            )\n"
  },
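This checker exists to funnel every device literal through `verl/utils/device.py`. A minimal sketch of the kind of helpers such a module centralizes; the function names here are illustrative rather than verl's actual API:

```python
import torch


def get_device_name() -> str:
    """Return the accelerator family available on this host."""
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch, "npu", None) is not None and torch.npu.is_available():
        return "npu"  # Ascend builds expose torch.npu via torch_npu
    return "cpu"


def get_collective_backend() -> str:
    """Pick the matching collective backend for the detected device."""
    return {"cuda": "nccl", "npu": "hccl"}.get(get_device_name(), "gloo")


# Call sites then write get_device_name() instead of hard-coding ".cuda"/"nccl",
# which is exactly what this sanity check enforces.
```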
  {
    "path": "verl_rl/tests/special_sanity/check_docs_time_info.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCheck that every .md and .rst file under docs/ contains the substring \"Last updated\",\nwith an allow-list for exceptions.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\n\n# === CONFIGURATION ===\n\n# Relative paths (to docs/) or glob patterns to skip checking\nALLOW_LIST = {\n    \"docs/README.md\",  # you can list individual files\n    \"docs/legacy/*.rst\",  # or glob patterns\n    \"docs/index.rst\",\n    \"docs/start/install.rst\",\n    \"docs/start/quickstart.rst\",\n    \"docs/README_vllm0.7.md\",\n}\n\n# The folder to scan\nDOCS_DIR = Path(\"docs\")\n\n# === SCRIPT ===\n\n\ndef is_allowed(path: Path) -> bool:\n    \"\"\"\n    Return True if `path` matches any entry in ALLOW_LIST.\n    \"\"\"\n    rel = str(path)\n    for pattern in ALLOW_LIST:\n        if Path(rel).match(pattern):\n            return True\n    return False\n\n\ndef main():\n    if not DOCS_DIR.exists():\n        print(f\"Error: Documentation directory '{DOCS_DIR}' does not exist.\", file=sys.stderr)\n        sys.exit(1)\n\n    missing = []\n\n    # Gather all .md and .rst files under docs/\n    for ext in (\"*.md\", \"*.rst\"):\n        for path in DOCS_DIR.rglob(ext):\n            if is_allowed(path):\n                continue\n\n            text = path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n            if \"Last updated\" not in text:\n                missing.append(path)\n\n    # Report\n    if missing:\n        print(\"\\nThe following files are missing the 'Last updated' string:\\n\")\n        for p in missing:\n            print(f\"  - {p}\")\n        print(f\"\\nTotal missing: {len(missing)}\\n\", file=sys.stderr)\n        raise AssertionError(\n            \"Some documentation files lack a 'Last updated' line. Please include info such as \"\n            \"'Last updated: mm/dd/yyyy' to indicate the last update time of the document.\"\n        )\n    else:\n        print(\"✅ All checked files contain 'Last updated'.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_docstrings.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPython script to check docstrings for functions and classes in specified files.\nChecks that every public function and class has proper docstring documentation.\n\"\"\"\n\nimport ast\nimport os\nimport sys\n\n\nclass DocstringChecker(ast.NodeVisitor):\n    \"\"\"AST visitor to check for missing docstrings in functions and classes.\"\"\"\n\n    def __init__(self, filename: str):\n        self.filename = filename\n        self.missing_docstrings: list[tuple[str, str, int]] = []\n        self.current_class = None\n        self.function_nesting_level = 0\n\n    def visit_FunctionDef(self, node: ast.FunctionDef):\n        \"\"\"Visit function definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\") and self.function_nesting_level == 0:\n            if not self._has_docstring(node):\n                func_name = f\"{self.current_class}.{node.name}\" if self.current_class else node.name\n                self.missing_docstrings.append((func_name, self.filename, node.lineno))\n\n        self.function_nesting_level += 1\n        self.generic_visit(node)\n        self.function_nesting_level -= 1\n\n    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):\n        \"\"\"Visit async function definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\") and self.function_nesting_level == 0:\n            if not self._has_docstring(node):\n                func_name = f\"{self.current_class}.{node.name}\" if self.current_class else node.name\n                self.missing_docstrings.append((func_name, self.filename, node.lineno))\n\n        self.function_nesting_level += 1\n        self.generic_visit(node)\n        self.function_nesting_level -= 1\n\n    def visit_ClassDef(self, node: ast.ClassDef):\n        \"\"\"Visit class definitions and check for docstrings.\"\"\"\n        if not node.name.startswith(\"_\"):\n            if not self._has_docstring(node):\n                self.missing_docstrings.append((node.name, self.filename, node.lineno))\n\n        old_class = self.current_class\n        self.current_class = node.name\n        self.generic_visit(node)\n        self.current_class = old_class\n\n    def _has_docstring(self, node) -> bool:\n        \"\"\"Check if a node has a docstring.\"\"\"\n        return ast.get_docstring(node) is not None\n\n\ndef check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]:\n    \"\"\"Check docstrings in a single file.\"\"\"\n    try:\n        with open(filepath, encoding=\"utf-8\") as f:\n            content = f.read()\n\n        tree = ast.parse(content, filename=filepath)\n        checker = DocstringChecker(filepath)\n        checker.visit(tree)\n        return checker.missing_docstrings\n\n    except Exception as e:\n        print(f\"Error processing {filepath}: {e}\")\n        return []\n\n\ndef main():\n    \"\"\"Main function to check docstrings in 
specified files.\"\"\"\n\n    files_to_check = [\n        \"verl/trainer/ppo/ray_trainer.py\",\n        \"verl/trainer/main_ppo.py\",\n        \"verl/trainer/ppo/reward.py\",\n        \"verl/utils/reward_score/__init__.py\",\n        \"verl/trainer/ppo/core_algos.py\",\n        \"verl/experimental/agent_loop/agent_loop.py\",\n        \"verl/workers/sharding_manager/fsdp_vllm.py\",\n        \"verl/workers/sharding_manager/fsdp_ulysses.py\",\n    ]\n\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    repo_path = os.path.dirname(os.path.dirname(script_dir))\n\n    if not os.path.exists(repo_path):\n        print(f\"Repository path {repo_path} does not exist!\")\n        sys.exit(1)\n\n    os.chdir(repo_path)\n\n    all_missing_docstrings = []\n\n    print(\"Checking docstrings in specified files...\")\n    print(\"=\" * 60)\n\n    for file_path in files_to_check:\n        if not os.path.exists(file_path):\n            print(f\"Warning: File {file_path} does not exist!\")\n            continue\n\n        print(f\"Checking {file_path}...\")\n        missing = check_file_docstrings(file_path)\n        all_missing_docstrings.extend(missing)\n\n        if missing:\n            print(f\"  Found {len(missing)} missing docstrings\")\n        else:\n            print(\"  All functions and classes have docstrings ✓\")\n\n    print(\"=\" * 60)\n\n    if all_missing_docstrings:\n        print(f\"\\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:\")\n        print(\"-\" * 60)\n\n        by_file = {}\n        for name, filepath, lineno in all_missing_docstrings:\n            if filepath not in by_file:\n                by_file[filepath] = []\n            by_file[filepath].append((name, lineno))\n\n        for filepath in sorted(by_file.keys()):\n            print(f\"\\n{filepath}:\")\n            for name, lineno in sorted(by_file[filepath], key=lambda x: x[1]):\n                print(f\"  - {name} (line {lineno})\")\n\n        print(f\"\\nTotal missing docstrings: {len(all_missing_docstrings)}\")\n\n        raise Exception(f\"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!\")\n\n    else:\n        print(\"\\n✅ All functions and classes have proper docstrings!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_license.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom argparse import ArgumentParser\nfrom pathlib import Path\n\nlicense_head_bytedance = \"Copyright 2024 Bytedance Ltd. and/or its affiliates\"\nlicense_head_bytedance_25 = \"Copyright 2025 Bytedance Ltd. and/or its affiliates\"\n# Add custom license headers below\nlicense_head_prime = \"Copyright 2024 PRIME team and/or its affiliates\"\nlicense_head_individual = \"Copyright 2025 Individual Contributor:\"\nlicense_head_sglang = \"Copyright 2023-2024 SGLang Team\"\nlicense_head_modelbest = \"Copyright 2025 ModelBest Inc. and/or its affiliates\"\nlicense_head_amazon = \"Copyright 2025 Amazon.com Inc and/or its affiliates\"\nlicense_head_facebook = \"Copyright (c) 2016-     Facebook, Inc\"\nlicense_headers = [\n    license_head_bytedance,\n    license_head_bytedance_25,\n    license_head_prime,\n    license_head_individual,\n    license_head_sglang,\n    license_head_modelbest,\n    license_head_amazon,\n    license_head_facebook,\n]\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--directory\", \"-d\", required=True, type=str)\n    args = parser.parse_args()\n    directory_in_str = args.directory\n\n    pathlist = Path(directory_in_str).glob(\"**/*.py\")\n    for path in pathlist:\n        # because path is object not string\n        path_in_str = str(path.absolute())\n        print(path_in_str)\n        with open(path_in_str, encoding=\"utf-8\") as f:\n            file_content = f.read()\n\n            has_license = False\n            for lh in license_headers:\n                if lh in file_content:\n                    has_license = True\n                    break\n            assert has_license, f\"file {path_in_str} does not contain license\"\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_pr_description.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env python3\nimport json\nimport os\n\n# Number of lines to check\nNUM_LINES = 5\n\n\n# Custom exception types for clear error handling\nclass TemplateFileError(Exception):\n    pass\n\n\nclass PRBodyLoadError(Exception):\n    pass\n\n\nclass PRDescriptionError(Exception):\n    pass\n\n\n# Path to the PR template file\ntemplate_file = os.path.join(os.getenv(\"GITHUB_WORKSPACE\", \".\"), \".github\", \"PULL_REQUEST_TEMPLATE.md\")\n\n\ndef load_template(path):\n    \"\"\"\n    Load only the first NUM_LINES of the PR template file as a list of lines,\n    without stripping any characters.\n    \"\"\"\n    lines = []\n    try:\n        with open(path, encoding=\"utf-8\") as f:\n            for _ in range(NUM_LINES):\n                line = f.readline()\n                if not line:\n                    break\n                lines.append(line.strip())\n        return lines\n    except Exception as e:\n        raise TemplateFileError(f\"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}\") from e\n\n\ndef load_pr_body(event_path):\n    try:\n        with open(event_path, encoding=\"utf-8\") as f:\n            payload = json.load(f)\n        return payload.get(\"pull_request\", {}).get(\"body\", \"\") or \"\"\n    except Exception as e:\n        raise PRBodyLoadError(f\"Failed to read PR body from {event_path}: {e}\") from e\n\n\ndef check_pr_description(body, template_lines):\n    \"\"\"\n    Compare the first NUM_LINES lines of the PR body to the template lines.\n    If they match exactly, the placeholder was not modified.\n    \"\"\"\n    pr_lines = body.splitlines(keepends=True)\n    pr_first = [x.strip() for x in pr_lines[:NUM_LINES]]\n    if pr_first == template_lines:\n        raise PRDescriptionError(\n            \"It looks like you haven't updated the '### What does this PR do?' section. Please replace \"\n            \"the placeholder text with a concise description of what your PR does.\"\n        )\n    else:\n        print(pr_first)\n        print(template_lines)\n\n\ndef main():\n    event_path = os.getenv(\"GITHUB_EVENT_PATH\")\n    if not event_path:\n        raise OSError(\"GITHUB_EVENT_PATH is not set.\")\n\n    template_lines = load_template(template_file)\n    pr_body = load_pr_body(event_path)\n    check_pr_description(pr_body, template_lines)\n\n    print(\"✅ '### What does this PR do?' section has been filled out.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/check_pr_title.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport re\n\n# Get PR title from environment\npr_title = os.environ.get(\"PR_TITLE\", \"\").strip()\n\n# Define rules\nallowed_modules = [\"fsdp\", \"megatron\", \"sglang\", \"vllm\", \"rollout\", \"trainer\"]\nallowed_modules += [\"tests\", \"training_utils\", \"recipe\", \"hardware\", \"deployment\"]\nallowed_modules += [\"ray\", \"worker\", \"single_controller\", \"misc\", \"docker\", \"ci\"]\nallowed_modules += [\"perf\", \"model\", \"algo\", \"env\", \"tool\", \"ckpt\", \"doc\", \"data\", \"cfg\"]\nallowed_types = [\"feat\", \"fix\", \"refactor\", \"chore\", \"test\"]\n\n# Check for [BREAKING] prefix and extract the rest of the title\nbreaking_match = re.match(r\"^\\[BREAKING\\]\\s*(.+)$\", pr_title, re.IGNORECASE)\nif breaking_match:\n    core_pr_title = breaking_match.group(1).strip()\n    is_breaking = True\nelse:\n    core_pr_title = pr_title\n    is_breaking = False\n\n# Build dynamic regex pattern for modules (now working on core_pr_title)\nre_modules_pattern = re.compile(r\"^\\[([a-z_,\\s]+)\\]\", re.IGNORECASE)\nre_modules = re_modules_pattern.match(core_pr_title)\nif not re_modules:\n    print(f\"❌ Invalid PR title: '{pr_title}'\")\n    print(\"Expected format: [BREAKING][module] type: description\")\n    print(f\"Allowed modules: {', '.join(allowed_modules)}\")\n    raise Exception(\"Invalid PR title\")\nelse:\n    modules = re.findall(r\"[a-z_]+\", re_modules.group(1).lower())\n    if not all(module in allowed_modules for module in modules):\n        invalid_modules = [module for module in modules if module not in allowed_modules]\n        print(f\"❌ Invalid modules: {', '.join(invalid_modules)}\")\n        print(f\"Allowed modules: {', '.join(allowed_modules)}\")\n        raise Exception(\"Invalid PR title\")\n\ntypes_pattern = \"|\".join(re.escape(t) for t in allowed_types)\nre_types_pattern = re.compile(rf\"^\\[[a-z_,\\s]+\\]\\s+({types_pattern}):\\s+.+$\", re.IGNORECASE)\nmatch = re_types_pattern.match(core_pr_title)\n\nif not match:\n    print(f\"❌ Invalid PR title: '{pr_title}'\")\n    print(\"Expected format: [BREAKING][module] type: description\")\n    print(f\"Allowed types: {', '.join(allowed_types)}\")\n    raise Exception(\"Invalid PR title\")\n\nchange_type = match.group(1).lower()\n\n# Build the success message\nbreaking_info = \" (BREAKING CHANGE)\" if is_breaking else \"\"\nprint(f\"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}\")\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/test_config_docs.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\nfrom pathlib import Path\n\n\ndef validate_yaml_format(yaml_lines):\n    errors = []\n    i = 0\n\n    while i < len(yaml_lines):\n        line = yaml_lines[i]\n        stripped = line.strip()\n\n        # Skip empty lines\n        if stripped == \"\":\n            i += 1\n            continue\n\n        # Match YAML keys like \"field:\" or \"field: value\"\n        key_match = re.match(r\"^(\\s*)([a-zA-Z0-9_]+):\", line)\n        if key_match:\n            # Check if there's a comment above\n            if i == 0 or not yaml_lines[i - 1].strip().startswith(\"#\"):\n                errors.append(f\"Missing comment above line {i + 1}: {line.strip()}\")\n\n            # Check for inline comment\n            if \"#\" in line and not stripped.startswith(\"#\"):\n                comment_index = line.index(\"#\")\n                colon_index = line.index(\":\")\n                if comment_index > colon_index:\n                    errors.append(f\"Inline comment found on line {i + 1}: {line.strip()}\")\n\n            # Check for blank line after this key line (unless next is a deeper indent)\n            if i + 1 < len(yaml_lines):\n                next_line = yaml_lines[i + 1]\n                next_stripped = next_line.strip()\n\n                # If next is not empty and not a deeper nested line, enforce blank line\n                if next_stripped != \"\":\n                    errors.append(f\"Missing blank line after line {i + 1}: {line.strip()}\")\n\n        i += 1\n\n    return errors\n\n\ndef test_trainer_config_doc():\n    yamls_to_inspect = [\n        \"verl/trainer/config/ppo_trainer.yaml\",\n        \"verl/trainer/config/actor/actor.yaml\",\n        \"verl/trainer/config/actor/dp_actor.yaml\",\n        \"verl/trainer/config/ref/ref.yaml\",\n        \"verl/trainer/config/ref/dp_ref.yaml\",\n        \"verl/trainer/config/rollout/rollout.yaml\",\n    ]\n    success = True\n    for yaml_to_inspect in yamls_to_inspect:\n        yaml_path = Path(yaml_to_inspect)  # path to your YAML file\n        with open(yaml_path) as f:\n            lines = f.readlines()\n\n        validation_errors = validate_yaml_format(lines)\n        if validation_errors:\n            success = False\n            print(\"YAML documentation format check failed:\")\n            print(f\"Please read the top block of {yaml_to_inspect} to see format rules:\\n\")\n            for err in validation_errors:\n                print(\" -\", err)\n\n    if not success:\n        raise Exception(\"Please fix documentation format.\")\n    else:\n        print(\"YAML format check passed ✅\")\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/test_import.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef test_import():\n    import verl\n\n    print(verl.__version__)\n\n\ndef test_single_controller_import():\n    import verl.single_controller\n\n    print(verl.single_controller.__version__)\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/type_coverage_check.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Custom type annotation check tool.\nTo inspect the type annotation for functions in the entire codebase, please run:\nfind verl -type f -name \"*.py\" | xargs -n 1 python3 tests/special_sanity/type_coverage_check.py --all-lines\n--debug --target-file\n\"\"\"\n\nimport argparse\nimport ast\nimport linecache\nimport subprocess\nfrom pathlib import Path\n\n\ndef get_changed_files() -> list[Path]:\n    result = subprocess.run(\n        [\"git\", \"diff\", \"--name-only\", \"--diff-filter=AM\", \"origin/main...HEAD\"], stdout=subprocess.PIPE, text=True\n    )\n    return [Path(f) for f in result.stdout.splitlines() if f.endswith(\".py\")]\n\n\ndef get_changed_lines(file_path: Path) -> set[int]:\n    result = subprocess.run(\n        [\"git\", \"diff\", \"-U0\", \"origin/main...HEAD\", \"--\", str(file_path)],\n        stdout=subprocess.PIPE,\n        text=True,\n    )\n    lines: set[int] = set()\n    for line in result.stdout.splitlines():\n        if line.startswith(\"@@\"):\n            for part in line.split():\n                try:\n                    if part.startswith(\"+\") and \",\" in part:\n                        start, count = map(int, part[1:].split(\",\"))\n                        lines.update(range(start, start + count))\n                    elif part.startswith(\"+\") and \",\" not in part:\n                        lines.add(int(part[1:]))\n                except Exception:\n                    # (vermouth1992) There are many edge cases here because + can be in the changed program\n                    pass\n    return lines\n\n\nCHECK_SUCCESS = 0\nCHECK_WARNING = 1\nCHECK_FAILURE = -1\n\n\ndef should_check_type(arg_name: str) -> bool:\n    if arg_name in (\"self\", \"cls\"):\n        return False\n    if arg_name.startswith(\"*\"):\n        return False\n    return True\n\n\ndef has_type_annotations(node: ast.AST, debug: bool = False) -> int:\n    if isinstance(node, ast.FunctionDef):\n        is_private = node.name.startswith(\"_\")\n        has_ann = (\n            all(arg.annotation is not None for arg in node.args.args if should_check_type(arg.arg))\n            and node.returns is not None\n        )\n        if has_ann or is_private:\n            return CHECK_SUCCESS\n        else:\n            if debug:\n                print(node, [(arg.annotation, arg.arg) for arg in node.args.args if should_check_type(arg.arg)])\n            return CHECK_FAILURE\n    return CHECK_SUCCESS\n\n\ndef check_file(\n    file_path: Path, changed_lines: set[int], debug: bool = False\n) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]:\n    with open(file_path) as f:\n        source: str = f.read()\n    tree = ast.parse(source, filename=str(file_path))\n    annotated = 0\n    total = 0\n    warning_lines: list[tuple[Path, int, str]] = []\n    failure_lines: list[tuple[Path, int, str]] = []\n\n    for node in ast.walk(tree):\n        if 
hasattr(node, \"lineno\") and node.lineno in changed_lines:\n            if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign):\n                total += 1\n                result = has_type_annotations(node, debug)\n                if result == CHECK_SUCCESS or result == CHECK_WARNING:\n                    annotated += 1\n                    if result == CHECK_WARNING:\n                        warning_lines.append(\n                            (file_path, node.lineno, linecache.getline(str(file_path), node.lineno).strip())\n                        )\n                else:\n                    source_line = linecache.getline(str(file_path), node.lineno).strip()\n                    failure_lines.append((file_path, node.lineno, source_line))\n\n    return annotated, total, warning_lines, failure_lines\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--threshold\", type=float, default=0.3, help=\"Minimum ratio of annotated lines required (0.0 - 1.0)\"\n    )\n    parser.add_argument(\"--target-file\", type=str, default=None, help=\"Path to the Python source file to analyse\")\n    parser.add_argument(\n        \"--all-lines\",\n        action=\"store_true\",\n        help=\"Check all lines in the file instead of only changed lines based on git\",\n    )\n    parser.add_argument(\"--debug\", action=\"store_true\", help=\"Add debugging logs\")\n    args = parser.parse_args()\n\n    total_changed = 0\n    total_annotated = 0\n    all_warnings: list[tuple[Path, int, str]] = []\n    all_failures: list[tuple[Path, int, str]] = []\n\n    target_files = [args.target_file] if args.target_file is not None else get_changed_files()\n    for fpath in target_files:\n        if \"tests/\" in str(fpath):\n            continue\n        if args.all_lines:\n            changed_lines = [i + 1 for i in range(len(open(fpath).readlines()))]\n        else:\n            changed_lines = get_changed_lines(fpath)\n        annotated, total, warning_lines, failure_lines = check_file(fpath, changed_lines, args.debug)\n        total_annotated += annotated\n        total_changed += total\n        all_warnings.extend(warning_lines)\n        all_failures.extend(failure_lines)\n\n    ratio = (total_annotated / total_changed) if total_changed else 1.0\n\n    print(\n        f\"🔍 Type coverage on {'all' if args.all_lines else 'changed'} lines: \"\n        f\"{total_annotated}/{total_changed} = {ratio:.2%}. Files inspected: {target_files}\"\n    )\n\n    if all_warnings:\n        print(\"\\n⚠️ Suggest Improve: Lines missing type annotations for inputs and outputs:\\n\")\n        for fname, lineno, line in all_warnings:\n            print(f\"{fname}:{lineno}: {line}\")\n\n    if all_failures:\n        print(\"⚠️ [ERROR] Lines missing type annotations for inputs and outputs:\\n\")\n        for fname, lineno, line in all_failures:\n            print(f\"{fname}:{lineno}: {line}\")\n\n    if ratio < args.threshold:\n        print(\n            f\"Please add type annotations for inputs and outputs to meet threshold {args.threshold}. \"\n            f\"Cases exempt from checking:\"\n        )\n        print(\"1. Private methods.\")\n        print(\"2. Args with name in ('self', 'cls'), or *args / **kwargs\")\n        print(\"3. 
Files under tests/\")\n        raise Exception(f\"\\n❌ Type coverage below threshold ({args.threshold:.0%}).\")\n    else:\n        if all_warnings or all_failures:\n            print(\"\")\n        print(\"✅ Type annotation coverage acceptable.\\n\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/validate_imported_docs.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nverify_imported_docs.py\n\nAssert that every function or class *explicitly imported* (via\n`from <module> import <name>`) in a given Python file has a docstring.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport ast\nimport importlib\nimport inspect\nimport pathlib\nimport sys\n\n\ndef _parse_args() -> argparse.Namespace:\n    p = argparse.ArgumentParser(description=\"Verify that imported functions/classes have docstrings.\")\n    p.add_argument(\n        \"--target-file\",\n        default=\"verl/trainer/ppo/ray_trainer.py\",\n        help=\"Path to the Python source file to analyse (e.g. verl/trainer/ppo/ray_trainer.py)\",\n    )\n    p.add_argument(\n        \"--allow-list\",\n        default=[\"omegaconf.open_dict\"],\n        help=\"a list of third_party dependencies that do not have proper docs :(\",\n    )\n    p.add_argument(\n        \"--project-root\",\n        default=\".\",\n        help=\"Directory to prepend to PYTHONPATH so local packages resolve (default: .)\",\n    )\n    p.add_argument(\n        \"--quiet\",\n        action=\"store_true\",\n        help=\"Suppress success message (still prints errors).\",\n    )\n    return p.parse_args()\n\n\ndef _import_attr(module_name: str, attr_name: str):\n    \"\"\"Import `module_name` then return `getattr(module, attr_name)`.\"\"\"\n    module = importlib.import_module(module_name)\n    return getattr(module, attr_name)\n\n\ndef _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str]) -> list[str]:\n    \"\"\"Return a list of error strings (empty == success).\"\"\"\n    # Ensure local packages resolve\n    sys.path.insert(0, str(project_root.resolve()))\n\n    tree = ast.parse(py_file.read_text(), filename=str(py_file))\n    problems: list[str] = []\n\n    for node in ast.walk(tree):\n        if not isinstance(node, ast.ImportFrom):\n            continue\n\n        # Relative imports (level > 0) get the leading dots stripped\n        module_name = \".\" * node.level + (node.module or \"\")\n        for alias in node.names:\n            if alias.name == \"*\":\n                problems.append(\n                    f\"{py_file}:{node.lineno} - wildcard import `from {module_name} import *` cannot be verified.\"\n                )\n                continue\n\n            imported_name = alias.name\n\n            try:\n                obj = _import_attr(module_name, imported_name)\n            except Exception:  # pragma: no cover – wide net for import quirks\n                pass\n                # For some reason the module cannot be imported, skip for now\n                # problems.append(\n                #     f\"{py_file}:{node.lineno} - could not resolve \"\n                #     f\"`{imported_name}` from `{module_name}` ({exc})\"\n                # )\n                continue\n\n            if f\"{module_name}.{imported_name}\" in allow_list:\n  
              continue\n            if inspect.isfunction(obj) or inspect.isclass(obj):\n                doc = inspect.getdoc(obj)\n                if not (doc and doc.strip()):\n                    kind = \"class\" if inspect.isclass(obj) else \"function\"\n                    problems.append(\n                        f\"{py_file}:{node.lineno} - {kind} `{module_name}.{imported_name}` is missing a docstring.\"\n                    )\n\n    return problems\n\n\ndef main() -> None:\n    args = _parse_args()\n    target_path = pathlib.Path(args.target_file).resolve()\n    project_root = pathlib.Path(args.project_root).resolve()\n\n    if not target_path.is_file():\n        raise Exception(f\"❌ Target file not found: {target_path}\")\n\n    errors = _check_file(target_path, project_root, args.allow_list)\n\n    if errors:\n        print(\"Docstring verification failed:\\n\")\n        print(\"\\n\".join(f\" • {e}\" for e in errors))\n        raise Exception(\"❌ Docstring verification failed.\")\n\n    if not args.quiet:\n        print(f\"✅ All explicitly imported functions/classes in {target_path} have docstrings.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_sanity/validate_structure.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#!/usr/bin/env python3\n\"\"\"\nValidate that test file subfolders mirror the top-level package layout.\n\nUsage examples\n--------------\n\n# Typical run (defaults: impl_root=my_project, tests_root=tests)\npython check_tests_structure.py\n\n# Custom layout and extra allowed folders\npython check_tests_structure.py \\\n    --impl-root verl \\\n    --tests-root tests \\\n    --allow-dirs special_e2e special_sanity special_standalone special_distributed\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport sys\nfrom pathlib import Path\n\n\ndef discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[str]:\n    \"\"\"Return the set of first-level directories that tests may live under.\"\"\"\n    allowed = {p.name for p in impl_root.iterdir() if p.is_dir()}\n    allowed.update(extra)\n    return allowed\n\n\ndef find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str]) -> list[str]:\n    \"\"\"Return a list of error strings for test files in the wrong place.\"\"\"\n    errors: list[str] = []\n    for test_file in tests_root.rglob(\"test*.py\"):\n        if str(test_file) in allowed_files:\n            continue\n        rel_parts = test_file.relative_to(tests_root).parts\n        if len(rel_parts) < 2:\n            errors.append(f\"{test_file}: must be inside one of {sorted(allowed)} (not at tests root)\")\n            continue\n\n        first_folder = rel_parts[0]\n        if first_folder not in allowed:\n            errors.append(\n                f\"{test_file}: subfolder '{first_folder}' under tests/ is not an allowed module. 
\"\n                f\"The valid ones are: {sorted(allowed)}\"\n            )\n    return errors\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(description=\"Check that test files follow tests/<module>/… layout.\")\n    parser.add_argument(\n        \"--impl-root\",\n        type=Path,\n        default=\"verl\",\n        help=\"Implementation root (default: my_project)\",\n    )\n    parser.add_argument(\n        \"--tests-root\",\n        type=Path,\n        default=\"tests\",\n        help=\"Root of test tree (default: tests)\",\n    )\n    parser.add_argument(\n        \"--allow-dirs\",\n        nargs=\"*\",\n        default=[\"special_e2e\", \"special_sanity\", \"special_standalone\", \"special_distributed\"],\n        help=\"Extra top-level test folders that are exempt from the rule\",\n    )\n    parser.add_argument(\n        \"--allow-files\",\n        nargs=\"*\",\n        default=[\"tests/test_protocol_on_cpu.py\", \"tests/test_base_config_on_cpu.py\"],\n        help=\"Extra top-level test folders that are exempt from the rule\",\n    )\n    args = parser.parse_args()\n\n    if not args.impl_root.is_dir():\n        raise Exception(f\"Implementation root '{args.impl_root}' does not exist.\")\n    if not args.tests_root.is_dir():\n        raise Exception(f\"Tests root '{args.tests_root}' does not exist.\")\n\n    allowed = discover_allowed_modules(args.impl_root, args.allow_dirs)\n    violations = find_violations(args.tests_root, allowed, args.allow_files)\n\n    if violations:\n        print(\"❌  Test layout violations found:\\n\", file=sys.stderr)\n        for err in violations:\n            print(\"  -\", err, file=sys.stderr)\n\n        print(\n            f\"\\nGuideline:\\n  Place each test file under   tests/<module_name>/…\\n  where <module_name> is \"\n            f\"one of the top-level packages inside '{args.impl_root}', or is explicitly listed via --allow-dirs.\\n\",\n            file=sys.stderr,\n        )\n        raise Exception(\"❌  Test layout violations found.\")\n\n    print(\"✅  Tests folder structure looks good.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/special_standalone/README.md",
    "content": "The standalone test folder is reserved for tests that require dedicated environment (e.g. memory stress tests)\n"
  },
  {
    "path": "verl_rl/tests/special_standalone/test_memory_buffers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest memory buffers\n- We start with two models with the same weights\n- We use Memory buffer to make one of the models and then compare the parameters\n\"\"\"\n\nimport gc\n\nimport torch\nfrom transformers import LlamaConfig, LlamaModel\n\n\ndef test_memory_buffers():\n    llama_config = LlamaConfig(\n        vocab_size=256,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=2,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n    )\n\n    model = LlamaModel(config=llama_config).cuda()\n    model_copy = LlamaModel(config=llama_config).cuda()\n    model_copy.load_state_dict(model.state_dict())\n\n    norm_factor = 1024**3\n\n    t_before = torch.cuda.get_device_properties(0).total_memory / norm_factor\n    r_before = torch.cuda.memory_reserved(0) / norm_factor\n    a_before = torch.cuda.memory_allocated(0) / norm_factor\n\n    print(f\"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB\")\n\n    t = torch.cuda.get_device_properties(0).total_memory / norm_factor\n    r = torch.cuda.memory_reserved(0) / norm_factor\n    a = torch.cuda.memory_allocated(0) / norm_factor\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    print(f\"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB\")\n\n    change_ratio = (a - a_before) / a_before\n    assert change_ratio < 0.01, f\"make sure the allocated change is less than 1%, Got {change_ratio}\"\n\n    for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True):\n        assert name1 == name2\n        assert torch.eq(param1.data, param2.data).all(), f\"{param1.data}, {param2.data}, {name1}\"\n\n\nif __name__ == \"__main__\":\n    test_memory_buffers()\n"
  },
  {
    "path": "verl_rl/tests/test_base_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.base_config import BaseConfig\n\n\n@pytest.fixture\ndef base_config_mock():\n    \"\"\"Fixture to create a mock BaseConfig instance with test attributes.\"\"\"\n    mock_config = BaseConfig()\n    mock_config.test_attr = \"test_value\"\n    return mock_config\n\n\ndef test_getitem_success(base_config_mock):\n    \"\"\"Test __getitem__ with existing attribute (happy path).\"\"\"\n    assert base_config_mock[\"test_attr\"] == \"test_value\"\n\n\ndef test_getitem_nonexistent_attribute(base_config_mock):\n    \"\"\"Test __getitem__ with non-existent attribute (exception path 1).\"\"\"\n    with pytest.raises(AttributeError):\n        _ = base_config_mock[\"nonexistent_attr\"]\n\n\ndef test_getitem_invalid_key_type(base_config_mock):\n    \"\"\"Test __getitem__ with invalid key type (exception path 2).\"\"\"\n    with pytest.raises(TypeError):\n        _ = base_config_mock[123]  # type: ignore\n"
  },
  {
    "path": "verl_rl/tests/test_protocol_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\n\nimport numpy as np\nimport pytest\nimport torch\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.protocol import union_numpy_dict, union_tensor_dict\n\n\ndef test_union_tensor_dict():\n    obs = torch.randn(100, 10)\n\n    data1 = TensorDict({\"obs\": obs, \"act\": torch.randn(100, 3)}, batch_size=[100])\n    data2 = TensorDict({\"obs\": obs, \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100)}, batch_size=[100])\n\n    data_with_copied_obs = TensorDict(\n        {\"obs\": obs.clone(), \"next_obs\": torch.randn(100, 10), \"rew\": torch.randn(100)}, batch_size=[100]\n    )\n\n    data = union_tensor_dict(data1, data2)\n    with pytest.raises(AssertionError):\n        data = union_tensor_dict(data1, data_with_copied_obs)\n\n    data = np.random.random(100)\n    data2 = [float(\"nan\") for _ in range(99)]\n    data2.append(\"nan\")\n    data2 = np.array(data2, dtype=object)\n    data3 = np.tile(data2, (2, 1))\n    a = {\"a\": data, \"b\": data2, \"c\": data3}\n    b = {\"a\": data, \"b\": data2, \"c\": data3}\n    b_ = {\"a\": np.random.random(100)}\n    union_numpy_dict(a, b)\n    with pytest.raises(AssertionError):\n        union_numpy_dict(a, b_)\n\n\ndef test_tensor_dict_constructor():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 10, 3)\n    data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act})\n\n    assert data.batch.batch_size == torch.Size([100])\n\n    with pytest.raises(AssertionError):\n        data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act}, num_batch_dims=2)\n\n    with pytest.raises(AssertionError):\n        data = DataProto.from_dict(tensors={\"obs\": obs, \"act\": act}, num_batch_dims=3)\n\n\ndef test_tensor_dict_make_iterator():\n    obs = torch.randn(100, 10)\n    labels = [random.choice([\"abc\", \"cde\"]) for _ in range(100)]\n    dataset = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels})\n\n    data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)\n    data_list_1 = []\n    for data in data_iter_1:\n        data_list_1.append(data)\n\n    data_iter_2 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)\n    data_list_2 = []\n    for data in data_iter_2:\n        data_list_2.append(data)\n\n    for data1, data2 in zip(data_list_1, data_list_2, strict=True):\n        assert isinstance(data1, DataProto)\n        assert isinstance(data2, DataProto)\n        result = torch.all(torch.eq(data1.batch[\"obs\"], data2.batch[\"obs\"]))\n        if not result.item():\n            print(data1.batch[\"obs\"])\n            print(data2.batch[\"obs\"])\n            raise AssertionError()\n        non_tensor_result = np.all(np.equal(data1.non_tensor_batch[\"labels\"], data2.non_tensor_batch[\"labels\"]))\n        if not non_tensor_result.item():\n            print(data1.non_tensor_batch[\"labels\"])\n            
print(data2.non_tensor_batch[\"labels\"])\n\n\ndef test_reorder():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abdce\"})\n    data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))\n\n    assert torch.all(torch.eq(data.batch[\"obs\"], torch.tensor([4, 5, 3, 1, 2, 6])))\n    assert np.all(data.non_tensor_batch[\"labels\"] == np.array([\"d\", \"e\", \"c\", \"a\", \"b\", \"f\"]))\n    assert data.meta_info == {\"name\": \"abdce\"}\n\n\ndef test_chunk_concat():\n    obs = torch.tensor([1, 2, 3, 4, 5, 6])\n    labels = [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abdce\"})\n\n    with pytest.raises(AssertionError):\n        data.chunk(5)\n\n    data_split = data.chunk(2)\n    assert len(data_split) == 2\n    assert torch.all(torch.eq(data_split[0].batch[\"obs\"], torch.tensor([1, 2, 3])))\n    assert np.all(data_split[0].non_tensor_batch[\"labels\"] == np.array([\"a\", \"b\", \"c\"]))\n    assert data_split[0].meta_info == {\"name\": \"abdce\"}\n\n    assert torch.all(torch.eq(data_split[1].batch[\"obs\"], torch.tensor([4, 5, 6])))\n    assert np.all(data_split[1].non_tensor_batch[\"labels\"] == np.array([\"d\", \"e\", \"f\"]))\n    assert data_split[1].meta_info == {\"name\": \"abdce\"}\n\n    concat_data = DataProto.concat(data_split)\n    assert torch.all(torch.eq(concat_data.batch[\"obs\"], data.batch[\"obs\"]))\n    assert np.all(concat_data.non_tensor_batch[\"labels\"] == data.non_tensor_batch[\"labels\"])\n    assert concat_data.meta_info == data.meta_info\n\n\ndef test_pop():\n    obs = torch.randn(100, 10)\n    act = torch.randn(100, 3)\n    dataset = DataProto.from_dict({\"obs\": obs, \"act\": act}, meta_info={\"2\": 2, \"1\": 1})\n    poped_dataset = dataset.pop(batch_keys=[\"obs\"], meta_info_keys=[\"2\"])\n\n    assert poped_dataset.batch.keys() == {\"obs\"}\n    assert poped_dataset.meta_info.keys() == {\"2\"}\n\n    assert dataset.batch.keys() == {\"act\"}\n    assert dataset.meta_info.keys() == {\"1\"}\n\n\ndef test_repeat():\n    # Create a DataProto object with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    # Test interleave=True\n    repeated_data_interleave = data.repeat(repeat_times=2, interleave=True)\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave.batch[\"obs\"], expected_obs_interleave))\n    assert (repeated_data_interleave.non_tensor_batch[\"labels\"] == expected_labels_interleave).all()\n    assert repeated_data_interleave.meta_info == {\"info\": \"test_info\"}\n\n    # Test interleave=False\n    repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False)\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave.batch[\"obs\"], expected_obs_no_interleave))\n    assert (repeated_data_no_interleave.non_tensor_batch[\"labels\"] == 
expected_labels_no_interleave).all()\n    assert repeated_data_no_interleave.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_pad_unpad():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=2)\n    assert pad_size == 1\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\"]\n\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3)\n    assert pad_size == 0\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    expected_labels = [\"a\", \"b\", \"c\"]\n\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)\n    assert pad_size == 4\n\n    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])\n    expected_labels = [\"a\", \"b\", \"c\", \"a\", \"b\", \"c\", \"a\"]\n    assert torch.all(torch.eq(padded_data.batch[\"obs\"], expected_obs))\n    assert (padded_data.non_tensor_batch[\"labels\"] == expected_labels).all()\n    assert padded_data.meta_info == {\"info\": \"test_info\"}\n\n    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)\n    assert torch.all(torch.eq(unpadd_data.batch[\"obs\"], obs))\n    assert (unpadd_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert unpadd_data.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_fold_unfold():\n    from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim\n\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    data1 = data.repeat(repeat_times=2, interleave=True)\n\n    data2 = fold_batch_dim(data1, new_batch_size=3)\n\n    torch.testing.assert_close(data2.batch[\"obs\"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]))\n    assert (data2.non_tensor_batch[\"labels\"] == [[\"a\", \"a\"], [\"b\", \"b\"], [\"c\", \"c\"]]).all()\n\n    data2.reorder(indices=torch.tensor([1, 2, 0]))\n\n    data3 = unfold_batch_dim(data2, batch_dims=2)\n\n    torch.testing.assert_close(data3.batch[\"obs\"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]))\n    assert 
(data3.non_tensor_batch[\"labels\"] == [\"b\", \"b\", \"c\", \"c\", \"a\", \"a\"]).all()\n    assert data3.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_torch_save_data_proto():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n    data.save_to_disk(\"test_data.pt\")\n    loaded_data = DataProto.load_from_disk(\"test_data.pt\")\n\n    assert torch.all(torch.eq(loaded_data.batch[\"obs\"], data.batch[\"obs\"]))\n    assert (loaded_data.non_tensor_batch[\"labels\"] == data.non_tensor_batch[\"labels\"]).all()\n    assert loaded_data.meta_info == data.meta_info\n\n    import os\n\n    os.remove(\"test_data.pt\")\n\n\ndef test_len():\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = np.array([\"a\", \"b\", \"c\"], dtype=object)\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 3\n\n    data = DataProto(batch=None, non_tensor_batch={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 3\n\n    data = DataProto(batch=None, non_tensor_batch={}, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 0\n\n    data = DataProto(batch=None, non_tensor_batch=None, meta_info={\"info\": \"test_info\"})\n\n    assert len(data) == 0\n\n\ndef test_dataproto_index():\n    data_len = 100\n    idx_num = 10\n\n    obs = torch.randn(data_len, 10)\n    labels = [random.choice([\"abc\", \"cde\"]) for _ in range(data_len)]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels})\n    labels_np = np.array(labels)\n\n    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))\n    result_np_int = data[idx_np_int]\n    assert result_np_int.batch.keys() == data.batch.keys()\n    assert result_np_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_np_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_np_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_np_int.batch[\"obs\"].cpu().numpy(), obs[idx_np_int].numpy())\n    assert np.array_equal(result_np_int.non_tensor_batch[\"labels\"], labels_np[idx_np_int])\n\n    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))\n    result_torch_int = data[idx_torch_int]\n    assert result_torch_int.batch.keys() == data.batch.keys()\n    assert result_torch_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_torch_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_torch_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_torch_int.batch[\"obs\"].cpu().numpy(), obs[idx_torch_int].cpu().numpy())\n    assert np.array_equal(result_torch_int.non_tensor_batch[\"labels\"], labels_np[idx_torch_int.cpu().numpy()])\n\n    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]\n    result_list_int = data[idx_list_int]\n    assert result_list_int.batch.keys() == data.batch.keys()\n    assert result_list_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_list_int.batch[\"obs\"].shape[0] == idx_num\n    assert result_list_int.non_tensor_batch[\"labels\"].shape[0] == idx_num\n    assert np.array_equal(result_list_int.batch[\"obs\"].cpu().numpy(), obs[idx_list_int].cpu().numpy())\n    assert np.array_equal(result_list_int.non_tensor_batch[\"labels\"], 
labels_np[idx_list_int])\n\n    idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)\n    result_np_bool = data[idx_np_bool]\n    assert result_np_bool.batch.keys() == data.batch.keys()\n    assert result_np_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_np_bool.batch[\"obs\"].shape[0] == idx_np_bool.sum()\n    assert result_np_bool.non_tensor_batch[\"labels\"].shape[0] == idx_np_bool.sum()\n    assert np.array_equal(result_np_bool.batch[\"obs\"].cpu().numpy(), obs[idx_np_bool].cpu().numpy())\n    assert np.array_equal(result_np_bool.non_tensor_batch[\"labels\"], labels_np[idx_np_bool])\n\n    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)\n    result_torch_bool = data[idx_torch_bool]\n    assert result_torch_bool.batch.keys() == data.batch.keys()\n    assert result_torch_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_torch_bool.batch[\"obs\"].shape[0] == idx_torch_bool.sum().item()\n    assert result_torch_bool.non_tensor_batch[\"labels\"].shape[0] == idx_torch_bool.sum().item()\n    assert np.array_equal(result_torch_bool.batch[\"obs\"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy())\n    assert np.array_equal(result_torch_bool.non_tensor_batch[\"labels\"], labels_np[idx_torch_bool])\n\n    idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]\n    result_list_bool = data[idx_list_bool]\n    assert result_list_bool.batch.keys() == data.batch.keys()\n    assert result_list_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()\n    assert result_list_bool.batch[\"obs\"].shape[0] == sum(idx_list_bool)\n    assert result_list_bool.non_tensor_batch[\"labels\"].shape[0] == sum(idx_list_bool)\n    assert np.array_equal(result_list_bool.batch[\"obs\"].cpu().numpy(), obs[idx_list_bool].cpu().numpy())\n    assert np.array_equal(result_list_bool.non_tensor_batch[\"labels\"], labels_np[idx_list_bool])\n\n\ndef test_old_vs_new_from_single_dict():\n    class CustomProto(DataProto):\n        \"\"\"Uses the new, fixed from_single_dict.\"\"\"\n\n        pass\n\n    class OriginProto(DataProto):\n        \"\"\"Mimics the *old* from_single_dict (always returns a DataProto).\"\"\"\n\n        @classmethod\n        def from_single_dict(cls, data, meta_info=None, auto_padding=False):\n            tensors, non_tensors = {}, {}\n            for k, v in data.items():\n                if torch.is_tensor(v):\n                    tensors[k] = v\n                else:\n                    non_tensors[k] = v\n            # always calls DataProto.from_dict, ignoring `cls`\n            return DataProto.from_dict(\n                tensors=tensors,\n                non_tensors=non_tensors,\n                meta_info=meta_info,\n                auto_padding=auto_padding,\n            )\n\n    sample = {\"x\": torch.tensor([0])}\n\n    orig = OriginProto.from_single_dict(sample)\n    # old behavior: always DataProto, not a CustomOriginProto\n    assert type(orig) is DataProto\n    assert type(orig) is not OriginProto\n\n    cust = CustomProto.from_single_dict(sample)\n    # new behavior: respects subclass\n    assert type(cust) is CustomProto\n\n\ndef test_dataproto_no_batch():\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n    selected = data.select(non_tensor_batch_keys=[\"labels\"])\n    assert (selected.non_tensor_batch[\"labels\"] == labels).all()\n    pop_data = 
data.pop(non_tensor_batch_keys=[\"labels\"])\n    assert (pop_data.non_tensor_batch[\"labels\"] == labels).all()\n    assert data.non_tensor_batch == {}\n\n\ndef test_sample_level_repeat():\n    # Create a DataProto object with some batch and non-tensor data\n    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"info\": \"test_info\"})\n\n    # list\n    repeated_data_interleave = data.sample_level_repeat(repeat_times=[3, 1, 2])\n    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])\n    expected_labels_interleave = [\"a\", \"a\", \"a\", \"b\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_interleave.batch[\"obs\"], expected_obs_interleave))\n    assert (repeated_data_interleave.non_tensor_batch[\"labels\"] == expected_labels_interleave).all()\n    assert repeated_data_interleave.meta_info == {\"info\": \"test_info\"}\n\n    # torch.tensor\n    repeated_data_no_interleave = data.sample_level_repeat(repeat_times=torch.tensor([1, 2, 3]))\n    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])\n    expected_labels_no_interleave = [\"a\", \"b\", \"b\", \"c\", \"c\", \"c\"]\n\n    assert torch.all(torch.eq(repeated_data_no_interleave.batch[\"obs\"], expected_obs_no_interleave))\n    assert (repeated_data_no_interleave.non_tensor_batch[\"labels\"] == expected_labels_no_interleave).all()\n    assert repeated_data_no_interleave.meta_info == {\"info\": \"test_info\"}\n\n\ndef test_dataproto_unfold_column_chunks():\n    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])\n    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])\n\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\"])\n\n    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])\n    expect_labels = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])\n    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])\n\n    labels = [[\"a1\", \"a2\"], [\"b1\", \"b2\"], [\"c1\", \"c2\"]]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\", \"labels\"])\n\n    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])\n    expect_labels = [[\"a1\"], [\"a2\"], [\"b1\"], [\"b2\"], [\"c1\"], [\"c2\"]]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n    obs1 = torch.tensor(\n        [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], 
[8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]]\n    )\n    obs2 = torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]])\n\n    labels = [\"a\", \"b\", \"c\"]\n    data = DataProto.from_dict(\n        tensors={\"obs1\": obs1, \"obs2\": obs2}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"}\n    )\n    ret = data.unfold_column_chunks(2, split_keys=[\"obs1\"])\n\n    expect_obs1 = torch.tensor(\n        [\n            [[1, 1], [2, 2]],\n            [[3, 3], [4, 4]],\n            [[5, 5], [6, 6]],\n            [[7, 7], [8, 8]],\n            [[9, 9], [10, 10]],\n            [[11, 11], [12, 12]],\n        ]\n    )\n    expect_obs2 = torch.tensor(\n        [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]]\n    )\n    expect_labels = [\"a\", \"a\", \"b\", \"b\", \"c\", \"c\"]\n    assert torch.all(torch.eq(ret.batch[\"obs1\"], expect_obs1))\n    assert torch.all(torch.eq(ret.batch[\"obs2\"], expect_obs2))\n    assert (ret.non_tensor_batch[\"labels\"] == expect_labels).all()\n    assert ret.meta_info == {\"name\": \"abc\"}\n\n\ndef test_dataproto_chunk_after_index():\n    data_len = 4\n    obs = torch.randn(data_len, 4)\n    labels = [f\"label_{i}\" for i in range(data_len)]\n    data = DataProto.from_dict(tensors={\"obs\": obs}, non_tensors={\"labels\": labels}, meta_info={\"name\": \"abc\"})\n\n    # Test with boolean numpy array\n    bool_mask = np.array([True, False, True, False])\n    selected = data[bool_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)  # int or List[int]\n\n    # Test with integer numpy array\n    int_mask = np.array([0, 2])\n    selected = data[int_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with boolean list\n    list_mask = [True, False, True, False]\n    selected = data[list_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with list\n    list_mask = [0, 2]\n    selected = data[list_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with torch tensor (bool)\n    torch_bool_mask = torch.tensor([True, False, True, False])\n    selected = data[torch_bool_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n\n    # Test with torch tensor (int)\n    torch_int_mask = torch.tensor([0, 2])\n    selected = data[torch_int_mask]\n    assert isinstance(selected.batch.batch_size, torch.Size)\n    assert all(isinstance(d, int) for d in selected.batch.batch_size)\n"
  },
  {
    "path": "verl_rl/tests/tools/test_base_tool_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Unit Tests for `initialize_tools_from_config`\nimport json\nimport os\nfrom typing import Any\n\nimport pytest\nfrom transformers.utils import get_json_schema\n\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\n\n\nclass WeatherToolForTest(BaseTool):\n    def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n        \"\"\"Get current temperature at a location.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, and the unit in a dict\n        \"\"\"\n        return {\n            \"temperature\": 26.1,\n            \"location\": location,\n            \"unit\": unit,\n        }\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_current_temperature)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_current_temperature(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\nclass WeatherToolWithDataForTest(BaseTool):\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_temperature_date)\n        return OpenAIFunctionToolSchema(**schema)\n\n    def get_temperature_date(self, location: str, date: str, unit: str = \"celsius\"):\n        \"\"\"Get temperature at a location and date.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            date: The date to get the temperature for, in the format \"Year-Month-Day\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, the date and the unit in a dict\n        \"\"\"\n        return {\n            \"temperature\": 25.9,\n            \"location\": location,\n            \"date\": date,\n            \"unit\": unit,\n        }\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_temperature_date(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\n@pytest.fixture\ndef create_local_tool_config():\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.tools.test_base_tool_on_cpu.WeatherToolForTest\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.tools.test_base_tool_on_cpu.WeatherToolWithDataForTest\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n    yield tool_config_path\n    if os.path.exists(tool_config_path):\n        os.remove(tool_config_path)\n\n\n@pytest.fixture\ndef create_fake_tool_config():\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherToolWithData\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n    yield tool_config_path\n    if os.path.exists(tool_config_path):\n        os.remove(tool_config_path)\n\n\ndef test_initialize_tools_from_fake_config(create_fake_tool_config):\n    tool_config_path = create_fake_tool_config\n\n    # Use pytest.raises to check if an exception is raised when calling initialize_tools_from_config.\n    # Since the tool configuration uses fake paths, an exception is expected during the tool initialization process.\n    with pytest.raises(ModuleNotFoundError):\n        _ = initialize_tools_from_config(tool_config_path)\n\n\ndef test_initialize_tools_from_local_config(create_local_tool_config):\n    \"\"\"\n    Test the `initialize_tools_from_config` function using a local tool configuration.\n    This test verifies that the function can correctly initialize tools based on a local configuration file.\n\n    Args:\n        create_local_tool_config: A pytest fixture that creates a local tool configuration file\n                                  and returns its path. 
After the test is completed, the fixture\n                                  will clean up the configuration file.\n    \"\"\"\n    # Retrieve the path of the local tool configuration file generated by the fixture\n    tool_config_path = create_local_tool_config\n\n    tools = initialize_tools_from_config(tool_config_path)\n\n    assert len(tools) == 2\n    from tests.tools.test_base_tool_on_cpu import WeatherToolForTest, WeatherToolWithDataForTest\n\n    assert isinstance(tools[0], WeatherToolForTest)\n    assert isinstance(tools[1], WeatherToolWithDataForTest)\n    assert tools[0].config == {\"type\": \"native\"}\n    assert tools[1].config == {\"type\": \"native\"}\n"
  },
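  {
    "path": "verl_rl/tests/tools/example_tool_registry_usage.py",
    "content": "# NOTE: Editor-added illustrative sketch; not part of the original verl code base.\n# It shows the tool-registry flow exercised by test_base_tool_on_cpu above:\n# write a tool config file, let initialize_tools_from_config build the tools,\n# inspect the generated OpenAI schema, and run one async tool call.\n# The file name and the temporary config path are assumptions for illustration.\nimport asyncio\nimport json\nimport tempfile\n\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\n\n\ndef main():\n    # Same layout as the fixtures above: a top-level \"tools\" list whose entries\n    # name each tool class by its dotted import path.\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.tools.test_base_tool_on_cpu.WeatherToolForTest\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    with tempfile.NamedTemporaryFile(\"w\", suffix=\".json\", delete=False) as f:\n        json.dump(tool_config, f)\n        tool_config_path = f.name\n\n    tools = initialize_tools_from_config(tool_config_path)\n    tool = tools[0]\n\n    # The schema is derived from the tool method's docstring via\n    # transformers.utils.get_json_schema (see get_openai_tool_schema above).\n    print(tool.get_openai_tool_schema())\n\n    # execute() is async and returns (response_text, reward, metrics).\n    response, reward, metrics = asyncio.run(\n        tool.execute(\"instance-0\", {\"location\": \"Paris, Ile-de-France, France\"})\n    )\n    print(response, reward, metrics)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },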
  {
    "path": "verl_rl/tests/trainer/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the trainer module.\n\"\"\"\n"
  },
  {
    "path": "verl_rl/tests/trainer/config/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/tests/trainer/config/legacy_ppo_megatron_trainer.yaml",
    "content": "data:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves\n  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs\n  return_raw_chat: False\n  return_full_prompt: False\n  shuffle: True\n  filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  trust_remote_code: False  # main_ppo will check this config to determine whether to use remote code for tokenizer\n  custom_cls:\n      path: null\n      name: null\n  sampler:\n    class_path: null\n    class_name: null\n  dataloader_num_workers: 8\n  return_multi_modal_inputs: True\n\nactor_rollout_ref:\n  hybrid_engine: True\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    custom_chat_template: null\n    external_lib: null\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: False\n    enable_gradient_checkpointing: False\n    gradient_checkpointing_kwargs:\n      ## Activation Checkpointing\n      activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'\n      # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk\n      # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity\n      activations_checkpoint_granularity: null # 'selective' or 'full'\n      # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention\n      activations_checkpoint_num_layers: null # not used with 'selective'\n    trust_remote_code: False\n  actor:\n    strategy: megatron  # This is for backward-compatibility\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: False\n    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}\n    use_torch_compile: True # False to disable torch compile\n    # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\n    clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729\n    loss_agg_mode: \"token-mean\" # / \"seq-mean-token-sum\" / \"seq-mean-token-mean\"\n    # NOTE: \"token-mean\" is the default behavior\n    entropy_coeff: 0\n    use_kl_loss: False # True for GRPO\n    kl_loss_coef: 0.001 # for grpo\n    kl_loss_type: low_var_kl # for grpo\n    ppo_epochs: 1\n    data_loader_seed: null\n    shuffle: False\n    policy_loss:   # policy loss config\n      loss_mode: \"vanilla\" # Loss function mode: vanilla / clip-cov / kl-cov / gpg from 
https://arxiv.org/abs/2505.22617\n      clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss\n      clip_cov_lb: 1.0 # Lower bound for clip-cov loss\n      clip_cov_ub: 5.0 # Upper bound for clip-cov loss\n      kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss\n      ppo_kl_coef: 0.1 # KL divergence penalty coefficient\n    optim:\n      optimizer: adam\n      lr: 1e-6\n      clip_grad: 1.0\n      total_training_steps: -1  # must be overridden by the program\n      lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n      lr_warmup_steps: null # Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.\n      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime\n      lr_decay_steps: null\n      lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root\n      min_lr: 0.0 # minimum learning rate, default to 0.0\n      weight_decay: 0.01\n      weight_decay_incr_style: constant # select from constant/linear/cosine\n      lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n    megatron:\n      param_offload: False\n      grad_offload: False\n      optimizer_offload: False\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: null\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n      context_parallel_size: 1\n      sequence_parallel: True\n      use_distributed_optimizer: True\n      use_dist_checkpointing: False\n      dist_checkpointing_path: null\n      seed: 42\n      override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage\n      use_mbridge: False\n    profile: # profile the actor model in `update_policy`\n      use_profile: False # open it when you want to profile the actor model\n      profile_ranks: null # list, you can specify the ranks to profile\n      step_start: -1 # start step in update_policy\n      step_end: -1 # end step\n      save_path: null # the path to save the profile result\n    load_weight: True\n    checkpoint:\n      async_save: False # save checkpoint asynchronously\n      # What to include in saved checkpoints\n      # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n      save_contents: ['model', 'optimizer', 'extra']\n      # For more flexibility, you can specify the contents to load from the checkpoint.\n      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n  ref:\n    strategy: ${actor_rollout_ref.actor.strategy}\n    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}\n    megatron:\n      param_offload: False\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: null\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n      context_parallel_size: 1\n      sequence_parallel: True\n      use_distributed_optimizer: False\n      use_dist_checkpointing: False\n      dist_checkpointing_path: null\n      seed: ${actor_rollout_ref.actor.megatron.seed}\n      override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}\n      
use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n    profile:\n      use_profile: False\n      profile_ranks: null\n      step_start: -1\n      step_end: -1\n      save_path: null\n    load_weight: True\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n  rollout:\n    name: vllm\n    mode: sync # sync: LLM, async: AsyncLLM\n    temperature: 1.0\n    top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n    top_p: 1\n    prompt_length: ${data.max_prompt_length}  # for xperf_gpt\n    response_length: ${data.max_response_length}\n    # for vllm rollout\n    dtype: bfloat16 # should align with FSDP\n    gpu_memory_utilization: 0.5\n    ignore_eos: False\n    enforce_eager: True\n    free_cache_engine: True\n    load_format: dummy_megatron\n    tensor_model_parallel_size: 1\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n    disable_log_stats: True\n    enable_chunked_prefill: False # could get higher throughput\n    # for hf rollout\n    do_sample: True\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n    # number of responses (i.e. num sample times)\n    n: 1\n    engine_kwargs: # inference engine parameters\n      vllm:\n        swap_space: null # null means \"use the engine default value\" (usually 4 GB), setting it to, e.g., 32 means 32 GB\n        disable_mm_preprocessor_cache: False # whether to disable the preprocessor cache for multimodal models.\n      sglang:\n        attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla\n    val_kwargs:\n      # sampling parameters for validation\n      top_k: -1 # 0 for hf rollout, -1 for vllm rollout\n      top_p: 1.0\n      temperature: 0\n      n: 1\n      do_sample: False # greedy decoding for validation by default\n\n    # Multi-turn interaction config for tools or chat.\n    multi_turn:\n      # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well\n      enable: False\n\n      # null for no limit (default max_length // 3)\n      max_assistant_turns: null\n\n      # null for no tool\n      tool_config_path: null\n\n      # null for no limit (default max_length // 3)\n      max_user_turns: null\n\n      # max parallel call for tools in single turn\n      max_parallel_calls: 1\n\n      # max length of tool response\n      max_tool_response_length: 256\n\n      # truncate side of tool response: left, middle, right\n      tool_response_truncate_side: middle\n\n      # null for no interaction\n      interaction_config_path: null\n\n      # null for default callback\n      completion_callback: null\n\n      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n      #   which may contain 
additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts.\n      use_inference_chat_template: False\n\n      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n      # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n      # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n      # - disable: disable tokenization sanity check\n      # - strict: enable strict tokenization sanity check (default)\n      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n      tokenization_sanity_check_mode: strict\n\n      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...\n      format: hermes\n\n    # [Experimental] agent loop based rollout configs\n    agent:\n\n      # Number of agent loop workers\n      num_workers: 8\n\n      custom_async_server:\n        path: null\n        name: null\n\n    # support logging rollout prob for debugging purposes\n    calculate_log_probs: False\n    # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\n\ncritic:\n  rollout_n: ${actor_rollout_ref.rollout.n}\n  strategy: ${actor_rollout_ref.actor.strategy}\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  optim:\n    optimizer: adam\n    lr: 1e-6\n    clip_grad: 1.0\n    total_training_steps: -1  # must be overridden by the program\n    lr_warmup_init: 0.0  # initial learning rate for warmup, default to 0.0\n    lr_warmup_steps: null # Prioritized. None, 0, or negative values mean delegating to lr_warmup_steps_ratio.\n    lr_warmup_steps_ratio: 0.  
# the total steps will be injected during runtime\n    lr_decay_steps: null\n    lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root\n    min_lr: 0.0 # minimum learning rate, default to 0.0\n    weight_decay: 0.01\n    weight_decay_incr_style: constant # select from constant/linear/cosine\n    lr_wsd_decay_style: exponential # select from constant/exponential/cosine\n    lr_wsd_decay_steps: null\n    use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${actor_rollout_ref.model.path}\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: False\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: False\n    enable_gradient_checkpointing: False\n    gradient_checkpointing_kwargs:\n      ## Activation Checkpointing\n      activations_checkpoint_method: null\n      activations_checkpoint_granularity: null\n      activations_checkpoint_num_layers: null\n  megatron:\n    param_offload: False\n    grad_offload: False\n    optimizer_offload: False\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n    context_parallel_size: 1\n    sequence_parallel: True\n    use_distributed_optimizer: True\n    use_dist_checkpointing: False\n    dist_checkpointing_path: null\n    seed: ${actor_rollout_ref.actor.megatron.seed}\n    override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}\n    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n  load_weight: True\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu\n  ppo_micro_batch_size_per_gpu: null\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2\n  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n  data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n  cliprange_value: 0.5\n  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}\n  checkpoint:\n    async_save: False # save checkpoint asynchronously\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: ['model', 'optimizer', 'extra']\n    load_contents: ${critic.checkpoint.save_contents}\n  # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\nreward_model:\n  enable: False\n  strategy: ${actor_rollout_ref.actor.strategy}\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n  megatron:\n    param_offload: False\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n    context_parallel_size: 1\n    sequence_parallel: True\n    use_distributed_optimizer: False\n    use_dist_checkpointing: False\n    dist_checkpointing_path: null\n    seed: ${actor_rollout_ref.actor.megatron.seed}\n    override_transformer_config: {}\n    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    trust_remote_code: False\n    external_lib: ${actor_rollout_ref.model.external_lib}\n  load_weight: True\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  max_length: null\n  reward_manager: naive\n  launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob\n  sandbox_fusion:\n    url: null # faas url to run code in cloud sandbox\n    max_concurrent: 64 # max concurrent requests to sandbox\n    memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB\n  # Nsight system profiler configs\n  profiler:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nalgorithm:\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: True\n  use_kl_in_reward: False\n  kl_penalty: kl  # how to estimate kl divergence\n  kl_ctrl:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: False\n  pf_ppo:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.PFPPOConfig\n    reweight_method: pow  # [\"pow\", \"max_min\", \"max_random\"]\n    weight_pow: 2.0\n\ntrainer:\n  balance_batch: True\n  total_epochs: 30\n  total_training_steps: null\n  profile_steps: null # [1,2,5] or [] or null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: ['console', 'wandb']\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n\n  # auto: find the last ckpt to resume. 
If none is found, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  del_local_ckpt_after_load: False\n  val_before_train: True\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  # The timeout for ray worker group to wait for the register center to be ready\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  # see ppo_trainer.yaml for more details\n  controller_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n  worker_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n    capture-range: \"cudaProfilerApi\"\n    capture-range-end: null\n    kill: none\n  npu_profile:\n    options:\n      save_path: ./profiler_data\n      level: level1\n      with_memory: False\n      record_shapes: False\n      with_npu: True\n      with_cpu: True\n      with_module: False\n      with_stack: False\n      analysis: True\n\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause a hang if CPUs are limited in systems like SLURM. In that case, set it to the number of CPUs you are allowed to use.\n  timeline_json_file: null\n"
  },
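  {
    "path": "verl_rl/tests/trainer/config/example_config_resolution.py",
    "content": "# NOTE: Editor-added illustrative sketch; not part of the original verl code base.\n# A minimal example of how the `${...}` interpolations and `_target_` fields in\n# the trainer YAMLs here are consumed: OmegaConf resolves the references, and\n# verl.utils.config.omega_conf_to_dataclass turns a `_target_`-tagged section\n# into its dataclass (as test_algo_config_on_cpu.py below does for AlgoConfig).\n# The hard-coded YAML path is an assumption for illustration.\nfrom omegaconf import OmegaConf\n\nfrom verl.utils.config import omega_conf_to_dataclass\n\ncfg = OmegaConf.load(\"verl_rl/tests/trainer/config/legacy_ppo_megatron_trainer.yaml\")\n\n# `${data.max_prompt_length}`-style references resolve against the same config tree,\n# so rollout.prompt_length mirrors data.max_prompt_length (512 in this file).\nassert cfg.actor_rollout_ref.rollout.prompt_length == cfg.data.max_prompt_length\n\n# The `algorithm` section carries `_target_: verl.trainer.config.AlgoConfig`,\n# so it can be instantiated as a typed dataclass.\nalgo_config = omega_conf_to_dataclass(cfg.algorithm)\nprint(type(algo_config).__name__, algo_config.adv_estimator)  # AlgoConfig gae\n"
  },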
  {
    "path": "verl_rl/tests/trainer/config/legacy_ppo_trainer.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# dataset config\ndata:\n\n  # Tokenizer class or path. If null, it will be inferred from the model.\n  tokenizer: null\n\n  # Whether to use shared memory for data loading.\n  use_shm: False\n\n  # Training set parquet. Can be a list or a single file.\n  # The program will read all files into memory, so it can't be too large (< 100GB).\n  # The path can be either a local path or an HDFS path.\n  # For HDFS path, we provide utils to download it to DRAM and convert it to a local path.\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n\n  # Validation parquet. Can be a list or a single file.\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n\n  # The field in the dataset where the prompt is located. Default is 'prompt'.\n  prompt_key: prompt\n\n  # The field used to select the reward function (if using different ones per example).\n  reward_fn_key: data_source\n\n  # Maximum prompt length. All prompts will be left-padded to this length.\n  # An error will be reported if the length is too long.\n  max_prompt_length: 512\n\n  # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.\n  max_response_length: 512\n\n  # Batch size sampled for one training iteration of different RL algorithms.\n  train_batch_size: 1024\n\n  # Batch size used during validation. Can be null.\n  val_batch_size: null\n\n  # Whether to return the original input_ids without adding chat template.\n  # This is used when the reward model's chat template differs from the policy.\n  # If using a model-based RM with different templates, this should be True.\n  return_raw_input_ids: False\n\n  # Whether to return the original chat (prompt) without applying chat template.\n  return_raw_chat: False\n\n  # Whether to return the full prompt with chat template.\n  return_full_prompt: False\n\n  # Whether to shuffle the data in the dataloader.\n  shuffle: True\n\n  # num dataloader workers\n  dataloader_num_workers: 8\n\n  # Whether to shuffle the validation set.\n  validation_shuffle: False\n\n  # Whether to filter overlong prompts.\n  filter_overlong_prompts: False\n\n  # Number of workers for filtering overlong prompts.\n  # For large-scale datasets, filtering can be time-consuming.\n  # Use multiprocessing to speed up. Default is 1.\n  filter_overlong_prompts_workers: 1\n\n  # Truncate the input_ids or prompt if they exceed max_prompt_length.\n  # Options: 'error', 'left', or 'right'. Default is 'error'.\n  truncation: error\n\n  # The field in the multi-modal dataset where the image is located. Default is 'images'.\n  image_key: images\n\n  # The field in the multi-modal dataset where the video is located.\n  video_key: videos\n\n  # If the remote tokenizer has a Python file, this flag determines whether to allow using it.\n  trust_remote_code: False\n\n  # Optional: specify a custom dataset class path and name if overriding default loading behavior.\n  custom_cls:\n\n    # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.\n    path: null\n\n    # The name of the dataset class within the specified file.\n    name: null\n\n  # Whether to return multi-modal inputs in the dataset. 
Set to False if rollout generates new multi-modal inputs.\n  return_multi_modal_inputs: True\n\n  # Data generation configuration for augmenting the dataset.\n  datagen:\n\n    # The path to the file containing your customized data generation class.\n    # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'\n    path: null\n\n    # The class name of the data generation class within the specified file.\n    # E.g. 'MockDataGenerator'\n    name: null\n\n  # settings related to data sampler\n  sampler:\n\n    # the path to the module containing a curriculum class which implements the\n    # AbstractSampler interface\n    class_path: null\n\n    # the name of the curriculum class like `MySampler`\n    class_name: null\n\n# config for actor, rollout and reference model\nactor_rollout_ref:\n\n  # Whether it's a hybrid engine, currently only supports hybrid engine\n  hybrid_engine: true\n\n  # common configs for the model\n  model:\n\n    # Huggingface model path. This can be either local path or HDFS path.\n    path: ~/models/deepseek-llm-7b-chat\n\n    # Custom chat template for the model.\n    custom_chat_template: null\n\n    # Whether to use shared memory (SHM) for accelerating the loading of model weights\n    use_shm: false\n\n    # Additional Python packages to register huggingface models/tokenizers.\n    external_lib: null\n\n    # Used to override model's original configurations, mainly dropout\n    override_config: {}\n\n    # Enable gradient checkpointing for actor\n    enable_gradient_checkpointing: true\n\n    # Enable activation offloading for actor\n    enable_activation_offload: false\n\n    # Whether to remove padding tokens in inputs during training\n    use_remove_padding: false\n\n    # Set to positive value to enable LoRA (e.g., 32)\n    lora_rank: 0\n\n    # LoRA scaling factor\n    lora_alpha: 16\n\n    # Target modules to apply LoRA. Options: \"all-linear\" (not recommended for VLMs) or\n    # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]\n    target_modules: all-linear\n\n    # Exclude modules from applying Lora. Similar usage to target_modules and Peft.\n    # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora.\n    exclude_modules: null\n\n    # Whether to use Liger for linear layer fusion\n    use_liger: false\n\n    # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)\n    use_fused_kernels: false\n\n    # Options for fused kernels. If use_fused_kernels is true, this will be used.\n    fused_kernel_options:\n\n      # Implementation backend for fused kernels. Options: \"triton\" or \"torch\".\n      impl_backend: torch\n\n    # Whether to enable loading a remote code model\n    trust_remote_code: false\n\n  # actor configs\n  actor:\n\n    # fsdp, fsdp2 or megatron. 
fsdp backend used here.\n    strategy: fsdp\n\n    # Split each sample into sub-batches of this size for PPO\n    ppo_mini_batch_size: 256\n\n    # [Deprecated] Global micro batch size\n    ppo_micro_batch_size: null\n\n    # Local per-GPU micro batch size\n    ppo_micro_batch_size_per_gpu: null\n\n    # Whether to automatically adjust batch size at runtime\n    use_dynamic_bsz: false\n\n    # Max tokens per GPU in one PPO batch; affects gradient accumulation\n    # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}\n    ppo_max_token_len_per_gpu: 16384\n\n    # Gradient clipping for actor updates\n    grad_clip: 1.0\n\n    # PPO clip ratio\n    clip_ratio: 0.2\n\n    # Lower bound for asymmetric clipping (used in dual-clip PPO)\n    clip_ratio_low: 0.2\n\n    # Upper bound for asymmetric clipping (used in dual-clip PPO)\n    clip_ratio_high: 0.2\n\n    # policy loss config\n    policy_loss:\n\n      # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617\n      loss_mode: \"vanilla\"\n\n      # Ratio of tokens to be clipped for clip-cov loss\n      clip_cov_ratio: 0.0002\n\n      # Lower bound for clip-cov loss\n      clip_cov_lb: 1.0\n\n      # Upper bound for clip-cov loss\n      clip_cov_ub: 5.0\n\n      # Ratio of tokens to be applied kl penalty for kl-cov loss\n      kl_cov_ratio: 0.0002\n\n      # KL divergence penalty coefficient\n      ppo_kl_coef: 0.1\n\n    # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C\n    clip_ratio_c: 3.0\n\n    # Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\n    loss_agg_mode: token-mean\n\n    # Entropy regularization coefficient in PPO loss\n    entropy_coeff: 0\n\n    # Whether to use KL loss instead of KL reward penalty. True for GRPO\n    use_kl_loss: false\n\n    # Whether to use torch.compile()\n    use_torch_compile: true\n\n    # KL loss coefficient when use_kl_loss is enabled. For GRPO\n    kl_loss_coef: 0.001\n\n    # Type of KL divergence loss. 
Options: \"kl\"(k1), \"abs\", \"mse\"(k2), \"low_var_kl\"(k3), \"full\"\n    kl_loss_type: low_var_kl\n\n    # Number of PPO epochs per batch\n    ppo_epochs: 1\n\n    # Shuffle training data across PPO epochs\n    shuffle: false\n\n    # Sequence parallelism size for Ulysses-style model parallelism\n    ulysses_sequence_parallel_size: 1\n\n    # calculate entropy with chunking to reduce memory peak\n    entropy_from_logits_with_chunking: False\n\n    # recompute entropy\n    entropy_checkpointing: False\n\n    # checkpoint configs\n    checkpoint:\n\n      # What to include in saved checkpoints\n      # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n      save_contents: ['model', 'optimizer', 'extra']\n\n      # For more flexibility, you can specify the contents to load from the checkpoint.\n      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}\n\n    # optimizer configs\n    optim:\n\n      # Learning rate\n      lr: 1e-6\n\n      # Warmup steps; negative value delegates to lr_warmup_steps_ratio\n      lr_warmup_steps: -1\n\n      # Warmup steps ratio (used if lr_warmup_steps is negative)\n      lr_warmup_steps_ratio: 0.0\n\n      # Minimum LR ratio for cosine schedule\n      min_lr_ratio: 0.0\n\n      # Number of cosine cycles in LR schedule\n      num_cycles: 0.5\n\n      # LR warmup style: \"constant\" or \"cosine\"\n      warmup_style: constant\n\n      # Total training steps (must be overridden at runtime)\n      total_training_steps: -1\n\n      # Weight decay\n      weight_decay: 0.01\n\n    # configs for FSDP\n    fsdp_config:\n\n      # policy for wrapping the model\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping a layer with FSDP\n        min_num_params: 0\n\n      # Whether to offload model parameters to CPU (trades speed for memory)\n      param_offload: false\n\n      # Whether to offload optimizer state to CPU\n      optimizer_offload: false\n\n      # Only for FSDP2: offload param/grad/optimizer during train\n      offload_policy: false\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: true\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  ref:\n\n    # actor_rollout_ref.ref: FSDP config same as actor. 
For models larger than 7B, it’s recommended to turn on offload for ref by default\n    strategy: ${actor_rollout_ref.actor.strategy}\n\n    # config for FSDP strategy\n    fsdp_config:\n\n      # whether to offload parameters in FSDP\n      param_offload: False\n\n      # whether to perform reshard after model forward to save memory.\n      # only for fsdp2, [True, False, int between 1 and fsdp_size]\n      reshard_after_forward: True\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n      # the wrap policy for FSDP model\n      wrap_policy:\n\n        # minimum number of params in a wrapped module\n        min_num_params: 0\n\n    # whether to enable torch.compile\n    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}\n\n    # [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n    # The batch size for one forward pass in the computation of log_prob. Global batch size.\n    log_prob_micro_batch_size: null\n\n    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\n    log_prob_micro_batch_size_per_gpu: null\n\n    # enable dynamic batch size (sequence packing) for log_prob computation\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n    # the max token length per GPU\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n\n    # sequence parallel size\n    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}\n\n    # calculate entropy with chunking to reduce memory peak\n    entropy_from_logits_with_chunking: False\n\n    # recompute entropy\n    entropy_checkpointing: False\n\n  # Rollout model config.\n  rollout:\n\n    # actor_rollout_ref.rollout.name: hf/vllm/sglang.\n    name: vllm\n\n    # sync: LLM, async: AsyncLLM\n    mode: sync\n\n    # Sampling temperature for rollout.\n    temperature: 1.0\n\n    # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n    top_k: -1\n\n    # Top-p sampling parameter. Default 1.0.\n    top_p: 1\n\n\n    # typically the same as data max prompt length\n    prompt_length: ${data.max_prompt_length}\n\n    # typically the same as data max response length\n    response_length: ${data.max_response_length}\n\n    # for vllm rollout\n    # Rollout model parameters type. Align with actor model's FSDP/Megatron type.\n    dtype: bfloat16\n\n    # Fraction of GPU memory used by vLLM/SGLang for KV cache.\n    gpu_memory_utilization: 0.5\n\n    # Whether to ignore EOS and continue generating after EOS is hit.\n    ignore_eos: False\n\n    # Whether to disable CUDA graph. Default True to allow cache freeing.\n    enforce_eager: True\n\n    # Whether to free engine KVCache after generation. Set enforce_eager=True when enabled.\n    free_cache_engine: True\n\n    # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc.\n    # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight\n    load_format: dummy_dtensor\n\n    # for huge model, layered summon can save memory (prevent OOM) but make it slower\n    layered_summon: False\n\n    # TP size for rollout. 
Only effective for vLLM.\n    tensor_model_parallel_size: 2\n\n    # max number of tokens in a batch\n    max_num_batched_tokens: 8192\n\n    # max length for rollout\n    max_model_len: null\n\n    # max length of sequences\n    max_num_seqs: 1024\n\n    # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.\n    log_prob_micro_batch_size: null\n\n    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\n    log_prob_micro_batch_size_per_gpu: null\n\n    # enable dynamic batch size (sequence packing) for log_prob computation\n    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n    # max token length for log_prob computation\n    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}\n\n    # disable logging statistics\n    disable_log_stats: True\n\n    # May get higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.\n    enable_chunked_prefill: True\n\n    # for hf rollout\n    # Whether to sample during training rollout. False uses greedy sampling.\n    do_sample: True\n\n    # number of responses (i.e. num sample times). > 1 for grpo\n    n: 1\n\n    # Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache)\n    multi_stage_wake_up: false\n\n    # Extra inference engine arguments (vllm, sglang).\n    engine_kwargs:\n\n      # for vllm\n      vllm:\n\n        # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB).\n        swap_space: null\n\n        # Whether to disable the preprocessor cache for multimodal models.\n        disable_mm_preprocessor_cache: False\n\n      # for sglang\n      sglang:\n\n        # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default.\n        attention_backend: null\n\n    # Sampling parameters used during validation.\n    val_kwargs:\n\n      # sampling parameters for validation\n      # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n      top_k: -1\n\n      # Top-p sampling parameter. Default 1.0.\n      top_p: 1.0\n\n      # Sampling temperature for rollout.\n      temperature: 0\n\n      # whether to repeat n times for validation\n      n: 1\n\n      # Whether to sample during training rollout. 
False uses greedy sampling.\n      do_sample: False\n\n    # Multi-turn interaction config for tools or chat.\n    multi_turn:\n\n      # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well\n      enable: False\n\n      # null for no limit (default max_length // 3)\n      max_assistant_turns: null\n\n      # null for no tool\n      tool_config_path: null\n\n      # null for no limit (default max_length // 3)\n      max_user_turns: null\n\n      # max parallel call for tools in single turn\n      max_parallel_calls: 1\n\n      # max length of tool response\n      max_tool_response_length: 256\n\n      # truncate side of tool response: left, middle, right\n      tool_response_truncate_side: middle\n\n      # null for no interaction\n      interaction_config_path: null\n\n      # null for default callback\n      completion_callback: null\n\n      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n      #   which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts.\n      use_inference_chat_template: False\n\n      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n      # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n      # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n      # - disable: disable tokenization sanity check\n      # - strict: enable strict tokenization sanity check (default)\n      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n      tokenization_sanity_check_mode: strict\n\n      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...\n      format: hermes\n\n    # support logging rollout prob for debugging purposes\n    calculate_log_probs: False\n\n    # [Experimental] agent loop based rollout configs\n    agent:\n\n      # Number of agent loop workers\n      num_workers: 8\n\n      # custom async server configs\n      custom_async_server:\n\n        # Path to the custom async server implementation\n        path: null\n\n        # Class name of the custom async server class (e.g. AsyncvLLMServer)\n        name: null\n\n  # profiler configs\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. 
[] or [0,1,...]\n    ranks: []\n\n# configs for the critic\ncritic:\n\n  # Number of rollouts per update (mirrors actor rollout_n)\n  rollout_n: ${actor_rollout_ref.rollout.n}\n\n  # fsdp or fsdp2 strategy used for critic model training\n  strategy: ${actor_rollout_ref.actor.strategy}\n\n  # optimizer configs\n  optim:\n\n    # Learning rate\n    lr: 1e-5\n\n    # Warmup steps ratio; total steps will be injected at runtime\n    lr_warmup_steps_ratio: 0.\n\n    # Minimum LR ratio for cosine schedule\n    min_lr_ratio: null\n\n    # LR warmup style: \"constant\" or \"cosine\"\n    warmup_style: constant\n\n    # Total training steps (must be overridden at runtime)\n    total_training_steps: -1\n\n    # Weight decay\n    weight_decay: 0.01\n\n  # model config for the critic\n  model:\n\n    # Path to pretrained model weights\n    path: ~/models/deepseek-llm-7b-chat\n\n    # Whether to use shared memory for loading the model\n    use_shm: False\n\n    # Tokenizer path (defaults to actor's model path)\n    tokenizer_path: ${actor_rollout_ref.model.path}\n\n    # Hugging Face config override\n    override_config: { }\n\n    # External model implementation (optional)\n    external_lib: ${actor_rollout_ref.model.external_lib}\n\n    # Enable gradient checkpointing to save memory\n    enable_gradient_checkpointing: True\n\n    # Offload activations to CPU to reduce GPU memory usage\n    enable_activation_offload: False\n\n    # Use remove padding optimization (saves compute)\n    use_remove_padding: False\n\n    # Whether to trust remote code from Hugging Face models\n    trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}\n\n    # FSDP-specific config\n    fsdp_config:\n\n      # Whether to offload model parameters to CPU\n      param_offload: False\n\n      # Whether to offload optimizer state to CPU\n      optimizer_offload: False\n\n      # Only for FSDP2: offload param/grad/optimizer during train\n      offload_policy: False\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: True\n\n      # Policy for wrapping layers with FSDP\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping\n        min_num_params: 0\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n    # Set to positive value to enable LoRA (e.g., 32)\n    lora_rank: 0\n\n    # LoRA scaling factor\n    lora_alpha: 16\n\n    # LoRA target modules: \"all-linear\" or list of linear projection layers\n    target_modules: all-linear\n\n  # PPO mini-batch size per update\n  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}\n\n  # [Deprecated] Global micro batch size\n  ppo_micro_batch_size: null\n\n  # Local per-GPU micro batch size\n  ppo_micro_batch_size_per_gpu: null\n\n  # Forward-only batch size (global)\n  forward_micro_batch_size: ${critic.ppo_micro_batch_size}\n\n  # Forward-only batch size (per GPU)\n  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}\n\n  # Whether to automatically adjust batch size at runtime\n  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}\n\n  # Max tokens per GPU in one PPO batch (doubled for critic)\n  ppo_max_token_len_per_gpu: 32768\n\n  # Max token length per GPU in forward pass\n  forward_max_token_len_per_gpu: 
${critic.ppo_max_token_len_per_gpu}\n\n  # Sequence parallelism size for Ulysses-style model parallelism\n  ulysses_sequence_parallel_size: 1\n\n  # Number of PPO epochs per batch\n  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}\n\n  # Shuffle training data across PPO epochs\n  shuffle: ${actor_rollout_ref.actor.shuffle}\n\n  # Gradient clipping for critic updates\n  grad_clip: 1.0\n\n  # PPO value function clipping range\n  cliprange_value: 0.5\n\n  # Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\n  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}\n\n  # checkpoint configs\n  checkpoint:\n\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: ['model', 'optimizer', 'extra']\n\n    # What to include when loading checkpoints\n    load_contents: ${critic.checkpoint.save_contents}\n\n  # profiler configs\n  # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. [] or [0,1,...]\n    ranks: []\n\n# configs for the reward model\nreward_model:\n\n  # Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions.\n  # In GSM8K and Math examples, we disable reward model.\n  # For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses.\n  # If False, the following parameters are not effective\n  enable: False\n\n  # FSDP strategy: \"fsdp\" or \"fsdp2\"\n  strategy: ${actor_rollout_ref.actor.strategy}\n\n  # model config for reward scoring\n  model:\n\n    # Input tokenizer. If the reward model’s chat template is inconsistent with the policy,\n    # we need to first decode to plaintext, then apply the rm’s chat_template.\n    # Then score with RM. If chat_templates are consistent, it can be set to null.\n    input_tokenizer: ${actor_rollout_ref.model.path}\n\n    # RM’s HDFS path or local path. 
Note that RM only supports AutoModelForSequenceClassification.\n    # Other model types need to define their own RewardModelWorker and pass it from the code.\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n\n    # Whether to use shared memory for loading the model\n    use_shm: False\n\n    # External model implementation (optional)\n    external_lib: ${actor_rollout_ref.model.external_lib}\n\n    # Use remove padding optimization (saves compute)\n    use_remove_padding: False\n\n    # Whether to use fused reward kernels for speedup\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n\n    # Whether to enable loading a remote code model, default to False\n    trust_remote_code: False\n\n    # FSDP-specific config\n    fsdp_config:\n\n      # Policy for wrapping layers with FSDP\n      wrap_policy:\n\n        # Minimum number of parameters to trigger wrapping\n        min_num_params: 0\n\n      # Whether to offload model parameters to CPU\n      param_offload: False\n\n      # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n      reshard_after_forward: True\n\n      # Number of GPUs in each FSDP shard group; -1 means auto\n      fsdp_size: -1\n\n      # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n      # before the current forward computation.\n      forward_prefetch: False\n\n  # [Deprecated] Global micro batch size\n  micro_batch_size: null\n\n  # Local per-GPU micro batch size\n  micro_batch_size_per_gpu: null\n\n  # Maximum sequence length to process for scoring\n  max_length: null\n\n  # Sequence parallelism size for Ulysses-style model parallelism\n  ulysses_sequence_parallel_size: 1\n\n  # Whether to dynamically adjust batch size at runtime\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n\n  # Maximum number of tokens per GPU in one forward pass\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n\n  # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.\n  # Default is naive. If all verification functions are multiprocessing-safe,\n  # the reward manager can be set to prime for parallel verification.\n  reward_manager: naive\n\n  # Whether to launch custom reward function asynchronously during log_prob\n  launch_reward_fn_async: False\n\n  # Cloud/local sandbox fusion configuration for custom reward logic\n  sandbox_fusion:\n\n    # Cloud/local function URL for sandbox execution\n    url: null\n\n    # Max concurrent requests allowed to sandbox\n    max_concurrent: 64\n\n    # Max memory limit for each sandbox process in MB\n    memory_limit_mb: 1024\n\n  # profiler configs\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. [] or [0,1,...]\n    ranks: []\n\n# custom reward function definition\ncustom_reward_function:\n\n  # The path to the file containing your customized reward function.\n  # If not specified, pre-implemented reward functions will be used.\n  path: null\n\n  # The name of the reward function within the specified file. 
Default is 'compute_score'.\n  name: compute_score\n\n# config for the algorithm\nalgorithm:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.trainer.config.AlgoConfig\n\n  # Discount factor for future rewards\n  gamma: 1.0\n\n  # Trade-off between bias and variance in the GAE estimator\n  lam: 1.0\n\n  # Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n  adv_estimator: gae\n\n  # Whether to normalize advantages by std (specific to GRPO)\n  norm_adv_by_std_in_grpo: True\n\n  # Whether to enable in-reward KL penalty\n  use_kl_in_reward: False\n\n  # How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\"\n  kl_penalty: kl\n\n  # KL control configuration\n  kl_ctrl:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.KLControlConfig\n\n    # KL control type: \"fixed\" or \"adaptive\"\n    type: fixed\n\n    # Initial coefficient for KL penalty\n    kl_coef: 0.001\n\n    # Horizon value for adaptive controller (if enabled)\n    horizon: 10000\n\n    # Target KL divergence (used for adaptive controller)\n    target_kl: 0.1\n\n  # Whether to enable preference feedback PPO\n  use_pf_ppo: False\n\n  # Preference feedback PPO settings\n  pf_ppo:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.PFPPOConfig\n\n    # Method for reweighting samples: \"pow\", \"max_min\", or \"max_random\"\n    reweight_method: pow\n\n    # Power used for weight scaling in \"pow\" method\n    weight_pow: 2.0\n\n# config for the trainer\ntrainer:\n\n  # Whether to balance batch sizes across distributed workers\n  balance_batch: True\n\n  # Number of epochs in training\n  total_epochs: 30\n\n  # Total training steps (can be set explicitly or derived from epochs)\n  total_training_steps: null\n\n  # The steps that will be profiled. null means no profiling. null or [1,2,5,...]\n  profile_steps: null\n\n  # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None.\n  ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html\n  ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html\n  controller_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n  # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None.\n  worker_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n    # Profiling only in a range of torch.cuda.profiler.start and stop. 
Do not change this config.\n    capture-range: \"cudaProfilerApi\"\n\n    # Specify the desired behavior when a capture range ends.\n    # In verl we need the torch.cuda.profiler.start/stop pair to repeat n times.\n    # Valid values are \"repeat-shutdown:n\" or null.\n    # For normal whole step profiling, n = len(profile_steps);\n    # but for discrete profiling, n = len(profile_steps) * Number(subtasks).\n    # Or you can just leave it null and the program will use n = len(profile_steps) * 6.\n    capture-range-end: null\n\n    # Send signal to the target application's process group. We let the program exit by itself.\n    kill: none\n\n  # Config for npu profiler. Must set when profile_steps is not None and torch_npu is available.\n  npu_profile:\n\n    # Options for the npu profiler\n    options:\n\n      # Storage path of collected data.\n      save_path: ./profiler_data\n\n      # Collection level, optional values: level_none, level0, level1, level2.\n      level: level1\n\n      # Whether to enable memory analysis.\n      with_memory: False\n\n      # Whether to record tensor shape.\n      record_shapes: False\n\n      # Whether to record Device-side performance data.\n      with_npu: True\n\n      # Whether to record Host-side performance data.\n      with_cpu: True\n\n      # Whether to record Python call stack information.\n      with_module: False\n\n      # Whether to record operator call stack information.\n      with_stack: False\n\n      # Whether to automatically parse the data.\n      analysis: True\n\n  # Project name for experiment tracking (e.g., wandb)\n  project_name: verl_examples\n\n  # Experiment name for run identification in tracking tools\n  experiment_name: gsm8k\n\n  # Logging backends to use: \"console\", \"wandb\", etc.\n  logger: [ 'console', 'wandb' ]\n\n  # Number of generations to log during validation\n  log_val_generations: 0\n\n  # Directory for logging rollout data; no dump if null\n  rollout_data_dir: null\n\n  # Directory for logging validation data; no dump if null\n  validation_data_dir: null\n\n  # Number of nodes used in the training\n  nnodes: 1\n\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n\n  # Save frequency (by iteration) for model checkpoints\n  save_freq: -1\n\n  # ESI refers to the elastic server instance used during training, similar to the training plan. 
For example,\n  # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.\n  # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance.\n  # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time.\n  # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.\n  esi_redundant_time: 0\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (only used when resume_mode is \"resume_path\")\n  resume_from_path: null\n\n  # Whether to run validation before training begins\n  val_before_train: True\n\n  # Whether to run validation only\n  val_only: False\n\n  # Validation frequency (in training iterations)\n  test_freq: -1\n\n  # Number of iterations to warm up the critic before updating policy\n  critic_warmup: 0\n\n  # Default path to distributed filesystem for saving checkpoints\n  default_hdfs_dir: null\n\n  # Whether to delete local checkpoints after loading\n  del_local_ckpt_after_load: False\n\n  # Default local directory for saving checkpoints\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\n  # Maximum number of actor checkpoints to keep\n  max_actor_ckpt_to_keep: null\n\n  # Maximum number of critic checkpoints to keep\n  max_critic_ckpt_to_keep: null\n\n  # Timeout (in seconds) for Ray worker to wait for registration\n  ray_wait_register_center_timeout: 300\n\n  # Device to run training on (e.g., \"cuda\", \"cpu\")\n  device: cuda\n\n# configs related to ray initialization\nray_init:\n\n  # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.\n  num_cpus: null\n\n  # Path to save Ray timeline JSON for performance profiling\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_rl/tests/trainer/config/test_algo_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig\nfrom verl.trainer.ppo.core_algos import (\n    compute_gae_advantage_return,\n    compute_grpo_outcome_advantage,\n    get_adv_estimator_fn,\n)\nfrom verl.utils.config import omega_conf_to_dataclass\n\n\nclass TestAlgoConfig(unittest.TestCase):\n    \"\"\"Test the AlgoConfig dataclass and its integration with core algorithms.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        # Create a sample algorithm config as DictConfig (similar to what comes from YAML)\n        self.config_dict = {\n            \"_target_\": \"verl.trainer.config.AlgoConfig\",\n            \"gamma\": 0.99,\n            \"lam\": 0.95,\n            \"adv_estimator\": \"gae\",\n            \"norm_adv_by_std_in_grpo\": True,\n            \"use_kl_in_reward\": True,\n            \"kl_penalty\": \"kl\",\n            \"kl_ctrl\": {\n                \"_target_\": \"verl.trainer.config.KLControlConfig\",\n                \"type\": \"adaptive\",\n                \"kl_coef\": 0.002,\n                \"horizon\": 5000,\n                \"target_kl\": 0.05,\n            },\n            \"use_pf_ppo\": True,\n            \"pf_ppo\": {\"_target_\": \"verl.trainer.config.PFPPOConfig\", \"reweight_method\": \"max_min\", \"weight_pow\": 3.0},\n        }\n        self.omega_config = OmegaConf.create(self.config_dict)\n\n    def test_dataclass_creation_from_dict(self):\n        \"\"\"Test creating AlgoConfig from dictionary.\"\"\"\n        config = omega_conf_to_dataclass(self.config_dict)\n\n        self.assertIsInstance(config, AlgoConfig)\n        self.assertEqual(config.gamma, 0.99)\n        self.assertEqual(config.lam, 0.95)\n        self.assertEqual(config.adv_estimator, \"gae\")\n        self.assertTrue(config.norm_adv_by_std_in_grpo)\n        self.assertTrue(config.use_kl_in_reward)\n        self.assertEqual(config.kl_penalty, \"kl\")\n        self.assertTrue(config.use_pf_ppo)\n\n    def test_dataclass_creation_from_omega_config(self):\n        \"\"\"Test creating AlgoConfig from OmegaConf DictConfig.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        self.assertIsInstance(config, AlgoConfig)\n        self.assertEqual(config.gamma, 0.99)\n        self.assertEqual(config.lam, 0.95)\n\n    def test_nested_configs(self):\n        \"\"\"Test that nested configurations are properly converted.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        # Test KL control config\n        self.assertIsInstance(config.kl_ctrl, KLControlConfig)\n        self.assertEqual(config.kl_ctrl.type, \"adaptive\")\n        self.assertEqual(config.kl_ctrl.kl_coef, 0.002)\n        self.assertEqual(config.kl_ctrl.horizon, 5000)\n        self.assertEqual(config.kl_ctrl.target_kl, 0.05)\n\n        
# Test PF PPO config\n        self.assertIsInstance(config.pf_ppo, PFPPOConfig)\n        self.assertEqual(config.pf_ppo.reweight_method, \"max_min\")\n        self.assertEqual(config.pf_ppo.weight_pow, 3.0)\n\n    def test_default_values(self):\n        \"\"\"Test that default values are properly set.\"\"\"\n        minimal_config = {\"gamma\": 0.8}\n        config = omega_conf_to_dataclass(minimal_config, AlgoConfig)\n\n        self.assertEqual(config.gamma, 0.8)\n        self.assertEqual(config.lam, 1.0)  # default value\n        self.assertEqual(config.adv_estimator, \"gae\")  # default value\n        self.assertTrue(config.norm_adv_by_std_in_grpo)  # default value\n        self.assertFalse(config.use_kl_in_reward)  # default value\n        self.assertEqual(config.kl_penalty, \"kl\")  # default value\n        self.assertFalse(config.use_pf_ppo)  # default value\n\n    def test_get_method_backward_compatibility(self):\n        \"\"\"Test the get method for backward compatibility.\"\"\"\n        config = omega_conf_to_dataclass(self.omega_config)\n\n        # Test existing attribute\n        self.assertEqual(config.get(\"gamma\"), 0.99)\n        self.assertEqual(config.get(\"gamma\", 1.0), 0.99)\n\n        # Test non-existing attribute\n        self.assertIsNone(config.get(\"non_existing\"))\n        self.assertEqual(config.get(\"non_existing\", \"default\"), \"default\")\n\n    def test_post_init_nested_configs(self):\n        \"\"\"Test that __post_init__ properly initializes nested configs when None.\"\"\"\n        # Create config without nested configs\n        minimal_config = AlgoConfig(gamma=0.9)\n\n        # Check that nested configs are initialized\n        self.assertIsNotNone(minimal_config.kl_ctrl)\n        self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)\n        self.assertIsNone(minimal_config.pf_ppo)\n\n    def test_config_init_from_yaml(self):\n        import os\n\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n            cfg = compose(config_name=\"ppo_trainer\")\n        algo_config = omega_conf_to_dataclass(cfg.algorithm)\n        from verl.trainer.config import AlgoConfig, PFPPOConfig\n\n        assert isinstance(algo_config, AlgoConfig)\n        assert isinstance(algo_config.pf_ppo, PFPPOConfig)\n\n\nclass TestAlgoCompute(unittest.TestCase):\n    \"\"\"Test the AlgoConfig dataclass and its integration with core algorithms.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        self.algo_config = AlgoConfig(\n            gamma=0.99,\n            lam=0.95,\n            adv_estimator=\"gae\",\n            norm_adv_by_std_in_grpo=True,\n            use_kl_in_reward=True,\n            kl_penalty=\"kl\",\n            kl_ctrl=KLControlConfig(type=\"adaptive\", kl_coef=0.002, horizon=5000, target_kl=0.05),\n            use_pf_ppo=True,\n            pf_ppo=PFPPOConfig(reweight_method=\"max_min\", weight_pow=3.0),\n        )\n\n    def test_advantage_estimator_with_cfg(self):\n        \"\"\"Test integration with advantage estimators from core_algos.\"\"\"\n        config = self.algo_config\n\n        # Test GAE advantage estimator\n        adv_fn = get_adv_estimator_fn(config.adv_estimator)\n        self.assertIsNotNone(adv_fn)\n\n        # Test with actual GAE computation\n        batch_size, seq_len = 2, 5\n        token_level_rewards = torch.randn(batch_size, seq_len)\n        values = torch.randn(batch_size, seq_len)\n        
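# An all-ones response mask treats every token as response, so GAE runs over the\n        # full sequence with no masked positions.\n        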
response_mask = torch.ones(batch_size, seq_len)\n\n        advantages, returns = compute_gae_advantage_return(\n            token_level_rewards=token_level_rewards,\n            values=values,\n            response_mask=response_mask,\n            gamma=config.gamma,\n            lam=config.lam,\n        )\n\n        self.assertEqual(advantages.shape, (batch_size, seq_len))\n        self.assertEqual(returns.shape, (batch_size, seq_len))\n\n    def test_grpo_advantage_estimator_with_cfg(self):\n        \"\"\"Test integration with GRPO advantage estimator.\"\"\"\n        grpo_config = AlgoConfig(adv_estimator=\"grpo\", norm_adv_by_std_in_grpo=True)\n\n        # Test GRPO advantage computation\n        batch_size, seq_len = 4, 3\n        token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]])\n        response_mask = torch.ones(batch_size, seq_len)\n        index = np.array([0, 0, 1, 1])  # Two groups\n\n        advantages, returns = compute_grpo_outcome_advantage(\n            token_level_rewards=token_level_rewards,\n            response_mask=response_mask,\n            index=index,\n            norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo,\n        )\n\n        self.assertEqual(advantages.shape, (batch_size, seq_len))\n        self.assertEqual(returns.shape, (batch_size, seq_len))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_rl/tests/trainer/config/test_critic_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom pathlib import Path\n\nimport pytest\nfrom hydra import compose, initialize_config_dir\n\nfrom verl.trainer.config.config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig\nfrom verl.utils.config import omega_conf_to_dataclass\n\n\nclass TestCriticConfig:\n    \"\"\"Test suite for critic configuration dataclasses.\"\"\"\n\n    @pytest.fixture\n    def config_dir(self):\n        \"\"\"Get the path to the config directory.\"\"\"\n        return Path(__file__).parent.parent.parent.parent / \"verl\" / \"trainer\" / \"config\" / \"critic\"\n\n    def test_megatron_critic_config_instantiation_from_yaml(self, config_dir):\n        \"\"\"Test that MegatronCriticConfig can be instantiated from megatron_critic.yaml.\"\"\"\n        yaml_path = config_dir / \"megatron_critic.yaml\"\n        assert yaml_path.exists(), f\"Config file not found: {yaml_path}\"\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/critic\")):\n            test_config = compose(config_name=\"megatron_critic\")\n\n        megatron_config_obj = omega_conf_to_dataclass(test_config)\n\n        assert isinstance(megatron_config_obj, MegatronCriticConfig)\n        assert isinstance(megatron_config_obj, CriticConfig)\n\n        expected_attrs = [\n            \"strategy\",\n            \"rollout_n\",\n            \"optim\",\n            \"model\",\n            \"ppo_mini_batch_size\",\n            \"ppo_max_token_len_per_gpu\",\n            \"cliprange_value\",\n            \"get\",\n            \"nccl_timeout\",\n            \"megatron\",\n            \"load_weight\",\n        ]\n        for attr in expected_attrs:\n            assert hasattr(megatron_config_obj, attr), f\"Missing attribute: {attr}\"\n\n        assert callable(megatron_config_obj.get)\n        assert megatron_config_obj.strategy == \"megatron\"\n\n    def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):\n        \"\"\"Test that FSDPCriticConfig can be instantiated from dp_critic.yaml.\"\"\"\n        yaml_path = config_dir / \"dp_critic.yaml\"\n        assert yaml_path.exists(), f\"Config file not found: {yaml_path}\"\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config/critic\")):\n            test_config = compose(config_name=\"dp_critic\")\n\n        fsdp_config_obj = omega_conf_to_dataclass(test_config)\n\n        assert isinstance(fsdp_config_obj, FSDPCriticConfig)\n        assert isinstance(fsdp_config_obj, CriticConfig)\n\n        expected_attrs = [\n            \"strategy\",\n            \"rollout_n\",\n            \"optim\",\n            \"model\",\n            \"ppo_mini_batch_size\",\n            \"ppo_max_token_len_per_gpu\",\n            \"cliprange_value\",\n            \"get\",\n            \"forward_micro_batch_size\",\n            \"forward_micro_batch_size_per_gpu\",\n            
\"ulysses_sequence_parallel_size\",\n            \"grad_clip\",\n        ]\n        for attr in expected_attrs:\n            assert hasattr(fsdp_config_obj, attr), f\"Missing attribute: {attr}\"\n\n        assert callable(fsdp_config_obj.get)\n        assert fsdp_config_obj.strategy == \"fsdp\"\n\n    def test_config_inheritance_hierarchy(self):\n        \"\"\"Test that the inheritance hierarchy is correct.\"\"\"\n        megatron_config = MegatronCriticConfig()\n        assert isinstance(megatron_config, CriticConfig)\n        assert isinstance(megatron_config, MegatronCriticConfig)\n\n        fsdp_config = FSDPCriticConfig()\n        assert isinstance(fsdp_config, CriticConfig)\n        assert isinstance(fsdp_config, FSDPCriticConfig)\n\n        critic_config = CriticConfig()\n        assert isinstance(critic_config, CriticConfig)\n        assert not isinstance(critic_config, MegatronCriticConfig)\n        assert not isinstance(critic_config, FSDPCriticConfig)\n\n    def test_config_dict_interface(self):\n        \"\"\"Test that configs provide dict-like interface from BaseConfig.\"\"\"\n        config = CriticConfig()\n\n        assert \"strategy\" in config\n        assert config[\"strategy\"] == \"fsdp\"\n\n        assert config.get(\"strategy\") == \"fsdp\"\n        assert config.get(\"nonexistent_key\", \"default\") == \"default\"\n\n        keys = list(config)\n        assert \"strategy\" in keys\n        assert \"rollout_n\" in keys\n\n        assert len(config) > 0\n\n    def test_frozen_fields_immutability(self):\n        \"\"\"Test that frozen fields raise exceptions when modified after creation.\"\"\"\n        critic_config = CriticConfig()\n        frozen_fields = [\"rollout_n\", \"strategy\", \"cliprange_value\"]\n\n        for field_name in frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(critic_config, field_name, \"modified_value\")\n\n        megatron_config = MegatronCriticConfig()\n        megatron_frozen_fields = [\"nccl_timeout\", \"load_weight\", \"data_loader_seed\"]\n\n        for field_name in megatron_frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(megatron_config, field_name, \"modified_value\")\n\n        fsdp_config = FSDPCriticConfig()\n        fsdp_frozen_fields = [\"ulysses_sequence_parallel_size\", \"grad_clip\"]\n\n        for field_name in fsdp_frozen_fields:\n            with pytest.raises((AttributeError, TypeError, ValueError)):\n                setattr(fsdp_config, field_name, \"modified_value\")\n\n    def test_batch_size_fields_modifiable(self):\n        \"\"\"Test that batch size fields can be modified after creation.\"\"\"\n        critic_config = CriticConfig()\n\n        critic_config.ppo_mini_batch_size = 8\n        critic_config.ppo_micro_batch_size = 4\n        critic_config.ppo_micro_batch_size_per_gpu = 2\n\n        assert critic_config.ppo_mini_batch_size == 8\n        assert critic_config.ppo_micro_batch_size == 4\n        assert critic_config.ppo_micro_batch_size_per_gpu == 2\n\n        fsdp_config = FSDPCriticConfig()\n\n        fsdp_config.forward_micro_batch_size = 16\n        fsdp_config.forward_micro_batch_size_per_gpu = 8\n\n        assert fsdp_config.forward_micro_batch_size == 16\n        assert fsdp_config.forward_micro_batch_size_per_gpu == 8\n"
  },
  {
    "path": "verl_rl/tests/trainer/config/test_legacy_config_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport unittest\nimport warnings\n\nfrom hydra import compose, initialize_config_dir\nfrom hydra.core.global_hydra import GlobalHydra\nfrom omegaconf import OmegaConf\n\n\nclass TestConfigComparison(unittest.TestCase):\n    \"\"\"Test that current configs match their legacy counterparts exactly.\"\"\"\n\n    ignored_keys = [\n        \"enable_gradient_checkpointing\",\n        \"gradient_checkpointing_kwargs\",\n        \"activations_checkpoint_method\",\n        \"activations_checkpoint_granularity\",\n        \"activations_checkpoint_num_layers\",\n    ]\n\n    def _compare_configs_recursively(\n        self, current_config, legacy_config, path=\"\", legacy_allow_missing=True, current_allow_missing=False\n    ):\n        \"\"\"Recursively compare two OmegaConf configs and assert they are identical.\n\n        Args:\n            legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and\n              we allow that to happen\n        \"\"\"\n        if isinstance(current_config, dict) and isinstance(legacy_config, dict):\n            current_keys = set(current_config.keys())\n            legacy_keys = set(legacy_config.keys())\n\n            missing_in_current = legacy_keys - current_keys\n            missing_in_legacy = current_keys - legacy_keys\n\n            # Ignore specific keys that are allowed to be missing\n            for key in self.ignored_keys:\n                if key in missing_in_current:\n                    missing_in_current.remove(key)\n                if key in missing_in_legacy:\n                    missing_in_legacy.remove(key)\n\n            if missing_in_current:\n                msg = f\"Keys missing in current config at {path}: {missing_in_current}\"\n                if current_allow_missing:\n                    warnings.warn(msg, stacklevel=1)\n                else:\n                    self.fail(f\"Keys missing in current config at {path}: {missing_in_current}\")\n            if missing_in_legacy:\n                # if the legacy\n                msg = f\"Keys missing in legacy config at {path}: {missing_in_legacy}\"\n                if legacy_allow_missing:\n                    warnings.warn(msg, stacklevel=1)\n                else:\n                    self.fail(msg)\n\n            for key in current_keys:\n                current_path = f\"{path}.{key}\" if path else key\n                if key in legacy_config:\n                    self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)\n        elif isinstance(current_config, list) and isinstance(legacy_config, list):\n            self.assertEqual(\n                len(current_config),\n                len(legacy_config),\n                f\"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}\",\n            )\n            for i, (current_item, legacy_item) in 
enumerate(zip(current_config, legacy_config, strict=True)):\n                self._compare_configs_recursively(current_item, legacy_item, f\"{path}[{i}]\")\n        else:\n            self.assertEqual(\n                current_config,\n                legacy_config,\n                f\"Values differ at {path}: current={current_config}, legacy={legacy_config}\",\n            )\n\n    def test_ppo_trainer_config_matches_legacy(self):\n        \"\"\"Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly.\"\"\"\n        import os\n\n        from hydra import compose, initialize_config_dir\n        from hydra.core.global_hydra import GlobalHydra\n\n        GlobalHydra.instance().clear()\n\n        try:\n            with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n                current_config = compose(config_name=\"ppo_trainer\")\n\n            legacy_config = OmegaConf.load(\"tests/trainer/config/legacy_ppo_trainer.yaml\")\n            current_dict = OmegaConf.to_container(current_config, resolve=True)\n            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)\n\n            if \"defaults\" in current_dict:\n                del current_dict[\"defaults\"]\n\n            self._compare_configs_recursively(current_dict, legacy_dict)\n        finally:\n            GlobalHydra.instance().clear()\n\n    def test_ppo_megatron_trainer_config_matches_legacy(self):\n        \"\"\"Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.\"\"\"\n\n        GlobalHydra.instance().clear()\n\n        try:\n            with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n                current_config = compose(config_name=\"ppo_megatron_trainer\")\n\n            legacy_config = OmegaConf.load(\"tests/trainer/config/legacy_ppo_megatron_trainer.yaml\")\n            current_dict = OmegaConf.to_container(current_config, resolve=True)\n            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)\n\n            if \"defaults\" in current_dict:\n                del current_dict[\"defaults\"]\n\n            self._compare_configs_recursively(\n                current_dict, legacy_dict, legacy_allow_missing=True, current_allow_missing=False\n            )\n        finally:\n            GlobalHydra.instance().clear()\n\n    def test_load_component(self):\n        \"\"\"Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.\"\"\"\n\n        GlobalHydra.instance().clear()\n        configs_to_load = [\n            (\"verl/trainer/config/actor\", \"dp_actor\"),\n            (\"verl/trainer/config/actor\", \"megatron_actor\"),\n            (\"verl/trainer/config/ref\", \"dp_ref\"),\n            (\"verl/trainer/config/ref\", \"megatron_ref\"),\n            (\"verl/trainer/config/rollout\", \"rollout\"),\n        ]\n        for config_dir, config_file in configs_to_load:\n            try:\n                with initialize_config_dir(config_dir=os.path.abspath(config_dir)):\n                    compose(config_name=config_file)\n            finally:\n                GlobalHydra.instance().clear()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_rl/tests/trainer/ppo/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the PPO trainer module.\n\"\"\"\n"
  },
  {
    "path": "verl_rl/tests/trainer/ppo/test_core_algos_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport random\nimport unittest\n\nimport pytest\nimport torch\n\nimport verl.trainer.ppo.core_algos\nfrom verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est\n\n\ndef mock_test_fn():\n    pass\n\n\nclass TestRegisterAdvEst(unittest.TestCase):\n    def setUp(self):\n        \"\"\"Clear the registry before each test\"\"\"\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {\n            \"gae\": lambda x: x * 2,\n            \"vtrace\": lambda x: x + 1,\n        }\n        self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY\n\n    def tearDown(self) -> None:\n        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()\n        return super().tearDown()\n\n    def test_register_new_function(self):\n        \"\"\"Test registering a new function with a string name\"\"\"\n\n        @register_adv_est(\"test_estimator\")\n        def test_fn():\n            pass\n\n        self.assertIn(\"test_estimator\", self.ADV_ESTIMATOR_REGISTRY)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"test_estimator\"], test_fn)\n\n    def test_register_with_enum(self):\n        \"\"\"Test registering with an enum value (assuming AdvantageEstimator exists)\"\"\"\n        from enum import Enum\n\n        class AdvantageEstimator(Enum):\n            TEST = \"test_enum_estimator\"\n\n        @register_adv_est(AdvantageEstimator.TEST)\n        def test_fn():\n            pass\n\n        self.assertIn(\"test_enum_estimator\", self.ADV_ESTIMATOR_REGISTRY)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"test_enum_estimator\"], test_fn)\n\n    def test_duplicate_registration_same_function(self):\n        \"\"\"Test that registering the same function twice doesn't raise an error\"\"\"\n        register_adv_est(\"duplicate_test\")(mock_test_fn)\n        register_adv_est(\"duplicate_test\")(mock_test_fn)\n\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"duplicate_test\"], mock_test_fn)\n\n    def test_duplicate_registration_different_function(self):\n        \"\"\"Test that registering different functions with same name raises ValueError\"\"\"\n\n        @register_adv_est(\"conflict_test\")\n        def test_fn1():\n            pass\n\n        with self.assertRaises(ValueError):\n\n            @register_adv_est(\"conflict_test\")\n            def test_fn2():\n                pass\n\n    def test_decorator_preserves_function(self):\n        \"\"\"Test that the decorator returns the original function\"\"\"\n\n        def test_fn():\n            return \"original\"\n\n        decorated = register_adv_est(\"preserve_test\")(test_fn)\n        self.assertEqual(decorated(), \"original\")\n\n    def test_multiple_registrations(self):\n        \"\"\"Test registering multiple different functions\"\"\"\n        init_adv_count = 
len(self.ADV_ESTIMATOR_REGISTRY)\n\n        @register_adv_est(\"estimator1\")\n        def fn1():\n            pass\n\n        @register_adv_est(\"estimator2\")\n        def fn2():\n            pass\n\n        self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"estimator1\"], fn1)\n        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY[\"estimator2\"], fn2)\n\n    def test_get_adv_estimator_fn_valid_names(self):\n        \"\"\"Test that valid names return the correct function from registry.\"\"\"\n        # Test GAE\n        gae_fn = get_adv_estimator_fn(\"gae\")\n        assert gae_fn(5) == 10  # 5 * 2 = 10\n\n        # Test Vtrace\n        vtrace_fn = get_adv_estimator_fn(\"vtrace\")\n        assert vtrace_fn(5) == 6  # 5 + 1 = 6\n\n    def test_get_adv_estimator_fn_invalid_name(self):\n        \"\"\"Test that invalid names raise ValueError.\"\"\"\n        with pytest.raises(ValueError) as excinfo:\n            get_adv_estimator_fn(\"invalid_name\")\n        assert \"Unknown advantage estimator simply: invalid_name\" in str(excinfo.value)\n\n    def test_get_adv_estimator_fn_case_sensitive(self):\n        \"\"\"Test that name lookup is case-sensitive.\"\"\"\n        with pytest.raises(ValueError):\n            get_adv_estimator_fn(\"GAE\")  # Different case\n\n\ndef test_multi_turn_compute_gae_advantage_return():\n    \"\"\"Test multi-turn GAE skip observation tokens.\"\"\"\n    gamma = random.uniform(0.0, 1.0)\n    lam = random.uniform(0.0, 1.0)\n\n    rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float)\n\n    values1 = torch.tensor(\n        [\n            [\n                random.uniform(-100.0, 100.0),\n                random.random(),\n                4.0,\n                5.0,\n                6.0,\n                random.uniform(-100.0, 0),\n                random.random(),\n                7.0,\n                9.0,\n                0.0,\n                0.0,\n            ]\n        ],\n        dtype=torch.float,\n    )\n\n    values2 = torch.tensor(\n        [\n            [\n                random.random(),\n                random.uniform(-100.0, 100.0),\n                4.0,\n                5.0,\n                6.0,\n                random.random(),\n                random.uniform(0.0, 100.0),\n                7.0,\n                9.0,\n                0.0,\n                0.0,\n            ]\n        ],\n        dtype=torch.float,\n    )\n\n    response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float)\n\n    adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam)\n    adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam)\n\n    ret1 *= response_mask\n    ret2 *= response_mask\n    assert torch.equal(adv1, adv2), f\"{adv1=}, {adv2=}\"\n    assert torch.equal(ret1, ret2), f\"{ret1=}, {ret2=}\"\n    print(f\" [CORRECT] \\n\\n{adv1=}, \\n\\n{ret1=}\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_rl/tests/trainer/ppo/test_metric_utils_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for the metric utilities in verl.trainer.ppo.metric_utils.\n\"\"\"\n\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nimport numpy as np\nimport torch\n\nfrom verl.trainer.ppo.metric_utils import (\n    bootstrap_metric,\n    calc_maj_val,\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics,\n)\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\n\n\nclass TestReduceMetrics(unittest.TestCase):\n    \"\"\"Tests for the reduce_metrics function.\"\"\"\n\n    def test_reduce_metrics_basic(self):\n        \"\"\"Test that reduce_metrics correctly computes means.\"\"\"\n        metrics = {\n            \"loss\": [1.0, 2.0, 3.0],\n            \"accuracy\": [0.0, 0.5, 1.0],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertEqual(result[\"loss\"], 2.0)\n        self.assertEqual(result[\"accuracy\"], 0.5)\n\n    def test_reduce_metrics_empty(self):\n        \"\"\"Test that reduce_metrics handles empty lists.\"\"\"\n        metrics = {\n            \"empty\": [],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertTrue(np.isnan(result[\"empty\"]))\n\n    def test_reduce_metrics_single_value(self):\n        \"\"\"Test that reduce_metrics works with single values.\"\"\"\n        metrics = {\n            \"single\": [5.0],\n        }\n        result = reduce_metrics(metrics)\n\n        self.assertEqual(result[\"single\"], 5.0)\n\n\nclass TestComputeDataMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_data_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto object\n        self.batch = MagicMock()\n        self.batch.batch = {\n            \"token_level_scores\": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),\n            \"token_level_rewards\": torch.tensor([[0.5, 1.0], [1.5, 2.0]]),\n            \"advantages\": torch.tensor([[0.1, 0.2], [0.3, 0.4]]),\n            \"returns\": torch.tensor([[1.1, 1.2], [1.3, 1.4]]),\n            \"responses\": torch.zeros((2, 2)),  # 2 samples, 2 tokens each\n            \"attention_mask\": torch.tensor(\n                [\n                    [1, 1, 1, 1],  # 2 prompt tokens, 2 response tokens\n                    [1, 1, 1, 1],\n                ]\n            ),\n            \"response_mask\": torch.tensor(\n                [\n                    [1, 1],  # 2 response tokens\n                    [1, 1],\n                ]\n            ),\n            \"values\": torch.tensor([[0.9, 1.0], [1.1, 1.2]]),\n        }\n\n    def test_compute_data_metrics_with_critic(self):\n        \"\"\"Test compute_data_metrics with critic enabled.\"\"\"\n        metrics = compute_data_metrics(self.batch, use_critic=True)\n\n        # Check that all expected metrics are present\n        self.assertIn(\"critic/score/mean\", metrics)\n        
self.assertIn(\"critic/rewards/mean\", metrics)\n        self.assertIn(\"critic/advantages/mean\", metrics)\n        self.assertIn(\"critic/returns/mean\", metrics)\n        self.assertIn(\"critic/values/mean\", metrics)\n        self.assertIn(\"critic/vf_explained_var\", metrics)\n        self.assertIn(\"response_length/mean\", metrics)\n        self.assertIn(\"prompt_length/mean\", metrics)\n\n        # Check some specific values\n        self.assertAlmostEqual(metrics[\"critic/score/mean\"], 5.0)  # Sum of token_level_scores\n        self.assertAlmostEqual(metrics[\"critic/rewards/mean\"], 2.5)  # Sum of token_level_rewards\n\n    def test_compute_data_metrics_without_critic(self):\n        \"\"\"Test compute_data_metrics with critic disabled.\"\"\"\n        metrics = compute_data_metrics(self.batch, use_critic=False)\n\n        # Check that critic-specific metrics are not present\n        self.assertNotIn(\"critic/values/mean\", metrics)\n        self.assertNotIn(\"critic/vf_explained_var\", metrics)\n\n        # Check that other metrics are still present\n        self.assertIn(\"critic/score/mean\", metrics)\n        self.assertIn(\"critic/rewards/mean\", metrics)\n        self.assertIn(\"response_length/mean\", metrics)\n\n\nclass TestComputeTimingMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_timing_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto object\n        self.batch = MagicMock()\n        self.batch.batch = {\n            \"responses\": torch.zeros((2, 3)),  # 2 samples, 3 response tokens each\n            \"attention_mask\": torch.tensor(\n                [\n                    [1, 1, 1, 1, 1, 1],  # 3 prompt tokens, 3 response tokens\n                    [1, 1, 1, 1, 1, 1],\n                ]\n            ),\n        }\n\n        # Mock the _compute_response_info function to return known values\n        self.response_info = {\n            \"prompt_length\": torch.tensor([3.0, 3.0]),\n            \"response_length\": torch.tensor([3.0, 3.0]),\n            \"response_mask\": torch.ones((2, 3)),\n        }\n\n    @patch(\"verl.trainer.ppo.metric_utils._compute_response_info\")\n    def test_compute_timing_metrics(self, mock_compute_response_info):\n        \"\"\"Test compute_timing_metrics with various timing data.\"\"\"\n        mock_compute_response_info.return_value = self.response_info\n\n        timing_raw = {\n            \"gen\": 0.5,  # 500ms\n            \"ref\": 0.3,  # 300ms\n            \"values\": 0.2,  # 200ms\n        }\n\n        metrics = compute_timing_metrics(self.batch, timing_raw)\n\n        # Check raw timing metrics\n        self.assertEqual(metrics[\"timing_s/gen\"], 0.5)\n        self.assertEqual(metrics[\"timing_s/ref\"], 0.3)\n        self.assertEqual(metrics[\"timing_s/values\"], 0.2)\n\n        # Check per-token timing metrics\n        # gen uses only response tokens (6 tokens)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/gen\"], 0.5 * 1000 / 6, places=5)\n\n        # ref and values use all tokens (12 tokens)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/ref\"], 0.3 * 1000 / 12, places=5)\n        self.assertAlmostEqual(metrics[\"timing_per_token_ms/values\"], 0.2 * 1000 / 12, places=5)\n\n\nclass TestComputeThroughputMetrics(unittest.TestCase):\n    \"\"\"Tests for the compute_throughout_metrics function.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up common test data.\"\"\"\n        # Create a mock DataProto 
object\n        self.batch = MagicMock()\n        self.batch.meta_info = {\n            \"global_token_num\": [100, 200, 300],  # 600 tokens total\n        }\n\n    def test_compute_throughout_metrics(self):\n        \"\"\"Test compute_throughout_metrics with various timing data.\"\"\"\n        timing_raw = {\n            \"step\": 2.0,  # 2 seconds per step\n        }\n\n        # Test with 1 GPU\n        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1)\n\n        self.assertEqual(metrics[\"perf/total_num_tokens\"], 600)\n        self.assertEqual(metrics[\"perf/time_per_step\"], 2.0)\n        self.assertEqual(metrics[\"perf/throughput\"], 600 / 2.0)  # 300 tokens/sec\n\n        # Test with 2 GPUs\n        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2)\n\n        self.assertEqual(metrics[\"perf/total_num_tokens\"], 600)\n        self.assertEqual(metrics[\"perf/time_per_step\"], 2.0)\n        self.assertEqual(metrics[\"perf/throughput\"], 600 / (2.0 * 2))  # 150 tokens/sec/GPU\n\n\nclass TestBootstrapMetric(unittest.TestCase):\n    \"\"\"Tests for the bootstrap_metric function.\"\"\"\n\n    def test_bootstrap_metric_basic(self):\n        \"\"\"Test bootstrap_metric with simple data and functions.\"\"\"\n        data = [1, 2, 3, 4, 5]\n        reduce_fns = [np.mean, np.max]\n\n        # Use a fixed seed for reproducibility\n        result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42)\n\n        # Check that we get two results (one for each reduce_fn)\n        self.assertEqual(len(result), 2)\n\n        # Each result should be a tuple of (mean, std)\n        mean_result, max_result = result\n        self.assertEqual(len(mean_result), 2)\n        self.assertEqual(len(max_result), 2)\n\n        # The mean of means should be close to the true mean (3.0)\n        self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3)\n\n        # The mean of maxes should be close to the expected value for samples of size 3\n        # For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5\n        self.assertGreater(max_result[0], 3.5)\n        self.assertLess(max_result[0], 5.0)\n\n    def test_bootstrap_metric_empty(self):\n        \"\"\"Test bootstrap_metric with empty data.\"\"\"\n        with self.assertRaises(ValueError):\n            bootstrap_metric([], subset_size=1, reduce_fns=[np.mean])\n\n\nclass TestCalcMajVal(unittest.TestCase):\n    \"\"\"Tests for the calc_maj_val function.\"\"\"\n\n    def test_calc_maj_val_basic(self):\n        \"\"\"Test calc_maj_val with simple data.\"\"\"\n        data = [\n            {\"pred\": \"A\", \"val\": 0.9},\n            {\"pred\": \"B\", \"val\": 0.8},\n            {\"pred\": \"A\", \"val\": 0.7},\n        ]\n\n        result = calc_maj_val(data, vote_key=\"pred\", val_key=\"val\")\n\n        # \"A\" is the majority vote, so we should get the first \"val\" for \"A\"\n        self.assertEqual(result, 0.9)\n\n    def test_calc_maj_val_tie(self):\n        \"\"\"Test calc_maj_val with tied votes.\"\"\"\n        data = [\n            {\"pred\": \"A\", \"val\": 0.9},\n            {\"pred\": \"B\", \"val\": 0.8},\n            {\"pred\": \"B\", \"val\": 0.7},\n            {\"pred\": \"A\", \"val\": 0.6},\n        ]\n\n        # In case of a tie, the first key in sorted order wins\n        # This depends on Python's dict implementation, but for this test\n        # we just verify that one of the valid values is returned\n        result = calc_maj_val(data, 
vote_key=\"pred\", val_key=\"val\")\n\n        self.assertTrue(result in [0.9, 0.8])\n\n\nclass TestProcessValidationMetrics(unittest.TestCase):\n    \"\"\"Tests for the process_validation_metrics function.\"\"\"\n\n    def test_process_validation_metrics_basic(self):\n        \"\"\"Test process_validation_metrics with simple data.\"\"\"\n        data_sources = [\"source1\", \"source1\", \"source2\"]\n        sample_inputs = [\"prompt1\", \"prompt1\", \"prompt2\"]\n        infos_dict = {\n            \"score\": [0.8, 0.9, 0.7],\n        }\n\n        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)\n\n        # Check the structure of the result\n        self.assertIn(\"source1\", result)\n        self.assertIn(\"source2\", result)\n\n        # Check that source1 has metrics for score\n        self.assertIn(\"score\", result[\"source1\"])\n\n        # Check that mean@2 is present for source1/score\n        self.assertIn(\"mean@2\", result[\"source1\"][\"score\"])\n\n        # Check the value of mean@2 for source1/score\n        self.assertAlmostEqual(result[\"source1\"][\"score\"][\"mean@2\"], 0.85)\n\n    def test_process_validation_metrics_with_pred(self):\n        \"\"\"Test process_validation_metrics with prediction data.\"\"\"\n        data_sources = [\"source1\", \"source1\", \"source1\"]\n        sample_inputs = [\"prompt1\", \"prompt1\", \"prompt1\"]\n        infos_dict = {\n            \"score\": [0.8, 0.9, 0.7],\n            \"pred\": [\"A\", \"B\", \"A\"],\n        }\n\n        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)\n\n        # Check that majority voting metrics are present\n        self.assertIn(\"maj@2/mean\", result[\"source1\"][\"score\"])\n\n        # For bootstrap with n=2, the majority vote could be either A or B\n        # depending on the random sampling, so we don't check the exact value\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_rl/tests/utils/_test_module.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Test module for import_utils.load_extern_type testing\nclass TestClass:\n    \"\"\"A test class to be imported by load_extern_type\"\"\"\n\n    def __init__(self, value=None):\n        self.value = value or \"default\"\n\n    def get_value(self):\n        return self.value\n\n\nTEST_CONSTANT = \"test_constant_value\"\n\n\ndef test_function():\n    return \"test_function_result\"\n"
  },
  {
    "path": "verl_rl/tests/utils/dataset/test_create_rl_sampler_on_cpu.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\ntest create_rl_sampler\n\"\"\"\n\nfrom collections.abc import Sized\n\nimport pytest\nimport torch\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch.utils.data import Dataset, RandomSampler\n\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.trainer.main_ppo import create_rl_sampler\n\n\nclass RandomCurriculumSampler(AbstractCurriculumSampler):\n    def __init__(\n        self,\n        data_source: Sized,\n        data_config: DictConfig,\n    ):\n        train_dataloader_generator = torch.Generator()\n        train_dataloader_generator.manual_seed(1)\n        sampler = RandomSampler(data_source=data_source)\n        self.sampler = sampler\n\n    def __iter__(self):\n        return self.sampler.__iter__()\n\n    def __len__(self) -> int:\n        return len(self.sampler)\n\n    def update(self, batch) -> None:\n        return\n\n\nclass MockIncorrectSampler:\n    \"\"\"A fake sampler class that does not adhere to the AbstractCurriculumSampler interface.\"\"\"\n\n    def __init__(self, data_source, data_config):\n        pass\n\n\nclass MockChatDataset(Dataset):\n    def __init__(self):\n        self.data = [\n            {\"prompt\": \"What's your name?\", \"response\": \"My name is Assistant.\"},\n            {\"prompt\": \"How are you?\", \"response\": \"I'm doing well, thank you.\"},\n            {\"prompt\": \"What is the capital of France?\", \"response\": \"Paris.\"},\n            {\n                \"prompt\": \"Tell me a joke.\",\n                \"response\": \"Why did the chicken cross the road? To get to the other side!\",\n            },\n            {\"prompt\": \"What is 2+2?\", \"response\": \"4\"},\n        ]\n\n    def __getitem__(self, index):\n        return self.data[index]\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef test_create_custom_curriculum_samper():\n    data_config = OmegaConf.create(\n        {\n            \"dataloader_num_workers\": 0,\n            \"sampler\": {\n                \"class_path\": \"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\",\n                \"class_name\": \"RandomCurriculumSampler\",\n            },\n        }\n    )\n\n    dataset = MockChatDataset()\n\n    # doesn't raise\n    create_rl_sampler(data_config, dataset)\n\n\ndef test_create_custom_curriculum_samper_wrong_class():\n    data_config = OmegaConf.create(\n        {\n            \"sampler\": {\n                \"class_path\": \"pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu\",\n                \"class_name\": \"MockIncorrectSampler\",\n            }\n        }\n    )\n\n    dataset = MockChatDataset()\n\n    # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises\n    with pytest.raises(AssertionError):\n        create_rl_sampler(data_config, dataset)\n"
  },
  {
    "path": "verl_rl/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTest the MultiTurnSFTDataset implementation\n\"\"\"\n\nimport os\n\nimport pandas as pd\nimport torch\nfrom transformers import AutoTokenizer\n\nfrom verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset\n\n\ndef test_multiturn_sft_dataset():\n    print(\"Starting test...\")\n    # Create a temporary parquet file with test data\n    test_data = {\n        \"messages\": [\n            [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n                {\"role\": \"assistant\", \"content\": \"2+2 equals 4.\"},\n                {\"role\": \"user\", \"content\": \"And what is 4+4?\"},\n                {\"role\": \"assistant\", \"content\": \"4+4 equals 8.\"},\n            ],\n            [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Tell me a joke.\"},\n                {\"role\": \"assistant\", \"content\": \"Why did the chicken cross the road?\"},\n                {\"role\": \"user\", \"content\": \"Why?\"},\n                {\"role\": \"assistant\", \"content\": \"To get to the other side!\"},\n            ],\n        ]\n    }\n\n    # Create test directory if it doesn't exist\n    os.makedirs(\"test_data\", exist_ok=True)\n    test_file = \"test_data/test.parquet\"\n\n    # Save test data to parquet\n    df = pd.DataFrame(test_data)\n    df.to_parquet(test_file)\n\n    # Initialize tokenizer and dataset\n    tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-Coder-7B-Instruct\")\n    config = {\"max_length\": 512, \"truncation\": \"error\", \"multiturn\": {\"messages_key\": \"messages\"}}\n    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)\n\n    # Test 1: Dataset Length\n    assert len(dataset) == 2, f\"Expected dataset length 2, got {len(dataset)}\"\n\n    # Get items for testing\n    item0 = dataset[0]  # Math conversation\n    item1 = dataset[1]  # Joke conversation\n\n    # Test 2: Required Keys and Types\n    required_keys = [\"input_ids\", \"attention_mask\", \"position_ids\", \"loss_mask\"]\n    for key in required_keys:\n        assert key in item0, f\"Missing key {key} in dataset item\"\n        assert isinstance(item0[key], torch.Tensor), f\"Expected torch.Tensor for {key}\"\n        assert item0[key].dtype == torch.long, f\"Expected torch.long for {key}, got {item0[key].dtype}\"\n\n    # Test 3: Shape Consistency\n    assert item0[\"loss_mask\"].shape == item0[\"input_ids\"].shape, \"Loss mask shape doesn't match input_ids shape\"\n    assert item0[\"attention_mask\"].shape == item0[\"input_ids\"].shape, (\n        \"Attention mask shape doesn't match input_ids shape\"\n    )\n    assert item0[\"position_ids\"].shape == item0[\"input_ids\"].shape, \"Position IDs shape doesn't match 
input_ids shape\"\n\n    # Test 4: Loss Mask Pattern - Math Conversation\n    loss_mask0 = item0[\"loss_mask\"]\n    input_ids0 = item0[\"input_ids\"]\n\n    # Find assistant response positions\n    assistant_positions0 = torch.where(loss_mask0 == 1)[0]\n    assert len(assistant_positions0) > 0, \"No assistant positions found in loss mask\"\n\n    # Decode and verify assistant responses\n    assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1])\n    print(f\"Math conversation assistant text: {assistant_text0}\")\n    assert \"2+2 equals 4\" in assistant_text0, \"First assistant response not found\"\n    assert \"4+4 equals 8\" in assistant_text0, \"Second assistant response not found\"\n\n    # Test 5: Loss Mask Pattern - Joke Conversation\n    loss_mask1 = item1[\"loss_mask\"]\n    input_ids1 = item1[\"input_ids\"]\n\n    # Find assistant response positions\n    assistant_positions1 = torch.where(loss_mask1 == 1)[0]\n    assert len(assistant_positions1) > 0, \"No assistant positions found in loss mask\"\n\n    # Decode and verify assistant responses\n    assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1])\n    print(f\"Joke conversation assistant text: {assistant_text1}\")\n    assert \"chicken cross the road\" in assistant_text1, \"First assistant response not found\"\n    assert \"other side\" in assistant_text1, \"Second assistant response not found\"\n\n    # Test 6: Attention Mask Pattern\n    attention_mask0 = item0[\"attention_mask\"]\n    sequence_length = torch.sum(attention_mask0)\n    assert sequence_length > 0, \"No tokens marked as attended in attention mask\"\n    assert torch.all(attention_mask0[:sequence_length] == 1), \"Incorrect attention mask pattern\"\n    if sequence_length < len(attention_mask0):\n        assert torch.all(attention_mask0[sequence_length:] == 0), \"Padding not properly masked\"\n\n    # Test 7: Position IDs Pattern\n    position_ids0 = item0[\"position_ids\"]\n    assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), (\n        \"Position IDs not sequential for non-padded tokens\"\n    )\n    if sequence_length < len(position_ids0):\n        assert torch.all(position_ids0[sequence_length:] == 0), \"Padding position IDs not zero\"\n\n    # Test 8: Verify loss mask for assistant responses\n    # Get the full conversation text\n    full_text = tokenizer.decode(input_ids0)\n    print(f\"\\nFull conversation text:\\n{full_text}\")\n\n    # Get the assistant responses\n    assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1])\n    print(f\"\\nAssistant responses (from loss mask):\\n{assistant_text}\")\n\n    # Verify that loss mask is set for all assistant responses\n    for msg in test_data[\"messages\"][0]:  # First conversation\n        if msg[\"role\"] == \"assistant\":\n            # The content should appear in the masked text\n            assert msg[\"content\"] in assistant_text, f\"Assistant message '{msg['content']}' not found in masked text\"\n\n            # The content should NOT appear in the non-masked text\n            non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])\n            assert msg[\"content\"] not in non_assistant_text, (\n                f\"Assistant message '{msg['content']}' found in non-assistant text\"\n            )\n\n    # Test 9: Verify non-assistant parts have loss_mask=0\n    # Get non-assistant text\n    non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])\n    print(f\"\\nNon-assistant text (from loss 
mask):\\n{non_assistant_text}\")\n\n    # Verify that system and user messages are in the non-assistant text\n    for msg in test_data[\"messages\"][0]:  # First conversation\n        if msg[\"role\"] in [\"system\", \"user\"]:\n            assert msg[\"content\"] in non_assistant_text, (\n                f\"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text\"\n            )\n\n            # And verify they're NOT in the assistant text\n            assert msg[\"content\"] not in assistant_text, (\n                f\"{msg['role'].title()} message '{msg['content']}' found in assistant text\"\n            )\n\n    # Test 10: Verify padding behavior\n    padding_config = {\"max_length\": 1024, \"truncation\": \"error\", \"multiturn\": {\"messages_key\": \"messages\"}}\n    padded_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)\n    padded_item = padded_dataset[0]\n\n    # Get actual sequence length (before padding)\n    actual_length = torch.sum(padded_item[\"attention_mask\"])\n\n    # Verify padding tokens\n    assert torch.all(padded_item[\"input_ids\"][actual_length:] == tokenizer.pad_token_id), (\n        \"Padding tokens not set correctly\"\n    )\n    assert torch.all(padded_item[\"attention_mask\"][actual_length:] == 0), \"Attention mask not set correctly for padding\"\n    assert torch.all(padded_item[\"loss_mask\"][actual_length:] == 0), \"Loss mask not set correctly for padding\"\n\n    print(\"All tests passed!\")\n"
  },
  {
    "path": "verl_rl/tests/utils/dataset/test_rl_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import DataLoader\n\n\ndef get_gsm8k_data():\n    # prepare test dataset\n    local_folder = os.path.expanduser(\"~/verl-data/gsm8k/\")\n    local_path = os.path.join(local_folder, \"train.parquet\")\n    os.makedirs(local_folder, exist_ok=True)\n    return local_path\n\n\ndef test_rl_dataset():\n    from verl.utils import hf_tokenizer\n    from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n    tokenizer = hf_tokenizer(\"deepseek-ai/deepseek-coder-1.3b-instruct\")\n    local_path = get_gsm8k_data()\n    config = OmegaConf.create(\n        {\n            \"prompt_key\": \"prompt\",\n            \"max_prompt_length\": 256,\n            \"filter_overlong_prompts\": True,\n            \"filter_overlong_prompts_workers\": 2,\n        }\n    )\n    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config)\n\n    dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)\n\n    a = next(iter(dataloader))\n\n    from verl import DataProto\n\n    tensors = {}\n    non_tensors = {}\n\n    for key, val in a.items():\n        if isinstance(val, torch.Tensor):\n            tensors[key] = val\n        else:\n            non_tensors[key] = val\n\n    data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)\n    assert \"input_ids\" in data_proto.batch\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    print(f\"type: type{output}\")\n    print(f\"\\n\\noutput: {output}\")\n\n\ndef test_image_rl_data():\n    from verl.utils import hf_processor, hf_tokenizer\n    from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn\n\n    tokenizer = hf_tokenizer(\"Qwen/Qwen2-VL-2B-Instruct\")\n    processor = hf_processor(\"Qwen/Qwen2-VL-2B-Instruct\")\n    config = OmegaConf.create(\n        {\n            \"prompt_key\": \"prompt\",\n            \"max_prompt_length\": 1024,\n            \"filter_overlong_prompts\": True,\n            \"filter_overlong_prompts_workers\": 2,\n        }\n    )\n    dataset = RLHFDataset(\n        data_files=os.path.expanduser(\"~/data/geo3k/train.parquet\"),\n        tokenizer=tokenizer,\n        config=config,\n        processor=processor,\n    )\n\n    dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)\n\n    a = next(iter(dataloader))\n\n    from verl import DataProto\n\n    tensors = {}\n    non_tensors = {}\n\n    for key, val in a.items():\n        if isinstance(val, torch.Tensor):\n            tensors[key] = val\n        else:\n            non_tensors[key] = val\n\n    data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)\n\n    assert \"multi_modal_data\" in data_proto.non_tensor_batch, data_proto\n    assert \"multi_modal_inputs\" in 
data_proto.non_tensor_batch, data_proto\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    print(f\"type: type{output}\")\n    print(f\"\\n\\noutput: {output}\")\n"
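\n\n# A minimal direct-run sketch (assumption: the GSM8K and geo3k parquet files referenced\n# above already exist locally; in CI these tests are collected and run by pytest instead):\nif __name__ == \"__main__\":\n    test_rl_dataset()\n    test_image_rl_data()\n"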
  },
  {
    "path": "verl_rl/tests/utils/dataset/test_sft_dataset_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.dataset.sft_dataset import SFTDataset\n\n\ndef get_gsm8k_data():\n    # prepare test dataset\n    local_folder = os.path.expanduser(\"~/verl-data/gsm8k/\")\n    local_path = os.path.join(local_folder, \"train.parquet\")\n    return local_path\n\n\ndef test_sft_cot_dataset():\n    tokenizer = hf_tokenizer(\"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\")\n    local_path = get_gsm8k_data()\n    from omegaconf import OmegaConf\n\n    dataset = SFTDataset(\n        parquet_files=local_path,\n        tokenizer=tokenizer,\n        config=OmegaConf.create(\n            {\n                \"prompt_key\": \"prompt\",\n                \"prompt_dict_keys\": [\"content\"],\n                \"response_key\": \"extra_info\",\n                \"response_dict_keys\": [\"answer\"],\n                \"max_length\": 512,\n            }\n        ),\n    )\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    assert len(output) > 1\n    assert isinstance(output, str)\n\n\ndef test_sft_dataset():\n    tokenizer = hf_tokenizer(\"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\")\n    local_path = get_gsm8k_data()\n    from omegaconf import OmegaConf\n\n    dataset = SFTDataset(\n        parquet_files=local_path,\n        tokenizer=tokenizer,\n        config=OmegaConf.create(\n            {\n                \"prompt_key\": \"extra_info\",\n                \"prompt_dict_keys\": [\"question\"],\n                \"response_key\": \"extra_info\",\n                \"response_dict_keys\": [\"answer\"],\n                \"max_length\": 512,\n            }\n        ),\n    )\n\n    data = dataset[0][\"input_ids\"]\n    output = tokenizer.batch_decode([data])[0]\n    assert len(output) > 1\n    assert isinstance(output, str)\n"
  },
  {
    "path": "verl_rl/tests/utils/megatron/test_pipeline_parallel.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\n\n\ndef test_make_batch_generator_no_vpp():\n    batches = [1, 2, 3]\n    vpp_size = 1\n    generator = make_batch_generator(batches, vpp_size)\n    assert list(generator) == batches\n\n\ndef test_make_batch_generator_with_vpp():\n    batches = [{\"data\": 1}, {\"data\": 2}]\n    vpp_size = 2\n    generators = make_batch_generator(batches, vpp_size)\n    assert isinstance(generators, list)\n    assert len(generators) == vpp_size\n\n    # Check each generator yields the original batches\n    for gen in generators:\n        assert list(gen) == batches\n\n\ndef test_make_batch_generator_empty():\n    batches = []\n    vpp_size = 1\n    generator = make_batch_generator(batches, vpp_size)\n    assert list(generator) == []\n\n    vpp_size = 3\n    generators = make_batch_generator(batches, vpp_size)\n    assert len(generators) == vpp_size\n    for gen in generators:\n        assert list(gen) == []\n\n\n@pytest.mark.parametrize(\n    \"layer_num,pp_size,gt\",\n    [\n        (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]),\n        (61, 7, [8, 9, 9, 9, 9, 9, 8]),\n        (61, 1, [61]),\n        (61, 0, ValueError),\n        (10, 16, ValueError),\n    ],\n)\ndef test_get_dynamic_pipeline_shards(layer_num, pp_size, gt):\n    if isinstance(gt, list):\n        shards = get_dynamic_pipeline_shards(layer_num, pp_size)\n        assert len(shards) == len(gt) == pp_size, f\"Expected {pp_size} shards, got {len(shards)}\"\n        assert all([shard == gt[i] for i, shard in enumerate(shards)]), f\"Expected shards {gt}, got {shards}\"\n    elif issubclass(gt, Exception):\n        with pytest.raises(gt):\n            shards = get_dynamic_pipeline_shards(layer_num, pp_size)\n"
  },
  {
    "path": "verl_rl/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport multiprocessing\nimport os\nimport time\nfrom concurrent.futures import ProcessPoolExecutor\nfrom unittest.mock import patch\n\nimport pytest\n\n# Import the function to be tested\nfrom verl.utils.reward_score.sandbox_fusion.utils import check_correctness\n\n# Get SANDBOX_URL from environment variable\nSANDBOX_URL = os.environ.get(\"SANDBOX_FUSION_URL\")\n# Define skip condition and reason\nskip_reason = \"SANDBOX_FUSION_URL environment variable not set\"\nskip_condition = not SANDBOX_URL\n\n# --- Test code (for real API calls) ---\nCODE_SUCCESS = \"\"\"\nimport sys\ndata = sys.stdin.read()\nif data == 'input1':\n    print('output1\\\\n', end='')\nelif data == 'input2':\n    print('output2\\\\n', end='')\nelse:\n    print('unexpected input', end='')\n\"\"\"\n\nCODE_WRONG_OUTPUT = \"\"\"\nprint('wrong_output\\\\n', end='')\n\"\"\"\n\nCODE_COMPILE_ERROR = \"\"\"\na=b\n\"\"\"\n\nCODE_RUNTIME_ERROR = \"\"\"\nimport sys\nprint(\"About to raise error\", file=sys.stderr)\nraise ValueError(\"This is a runtime error\")\n\"\"\"\n\nCODE_TIMEOUT = \"\"\"\nimport time\nimport sys\nprint(\"Sleeping...\", file=sys.stderr)\ntime.sleep(10) # Sleep time should be longer than the timeout set in the test\nprint(\"Finished sleeping\", file=sys.stderr)\n\"\"\"\n\n# --- Test input/output data ---\nINPUT_OUTPUT_VALID = {\"inputs\": [\"input1\", \"input2\"], \"outputs\": [\"output1\\n\", \"output2\\n\"]}\n\nINPUT_OUTPUT_SINGLE = {\"inputs\": [\"input1\"], \"outputs\": [\"output1\\n\"]}\n\nINPUT_OUTPUT_MISMATCH = {\"inputs\": [\"input1\"], \"outputs\": [\"output1\\n\", \"output2\\n\"]}\n\nINPUT_OUTPUT_INVALID_MISSING_KEY = {\"inputs\": [\"input1\"]}\n\n# --- Integration test cases (calling real API) ---\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_success_correct():\n    \"\"\"Integration test: Code is correct, output is correct\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS)\n    assert results == [True, True]\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[0][\"stdout\"] == \"output1\\n\"\n    assert metadata_list[1][\"status\"] == \"success\"\n    assert metadata_list[1][\"stdout\"] == \"output2\\n\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_success_wrong_output():\n    \"\"\"Integration test: Code runs successfully, but output is wrong\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT)\n    assert results == [False, False]\n    assert metadata_list[0][\"status\"] == \"wrong_answer\"\n    assert metadata_list[0][\"stdout\"] == \"wrong_output\\n\"\n    assert metadata_list[1][\"status\"] == \"wrong_answer\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_compile_error():\n    \"\"\"Integration test: Code causes compile error\"\"\"\n   
 results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language=\"cpp\")\n    assert results == [-4, -4]\n    assert metadata_list[0][\"status\"] == \"compile_error\"\n    assert metadata_list[1][\"status\"] == \"compile_error\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_runtime_error():\n    \"\"\"Integration test: Code causes runtime error\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR)\n    assert results == [-2]\n    assert metadata_list[0][\"status\"] == \"runtime_error\"\n    # More assertions can be added based on the actual API response, e.g., exit_code, stderr\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_runtime_timeout():\n    \"\"\"Integration test: Code causes runtime timeout\"\"\"\n    test_timeout = 5  # Set a timeout shorter than the sleep time in CODE_TIMEOUT\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout)\n    assert results == [-3]\n    assert metadata_list[0][\"status\"] == \"timeout\"\n    # More assertions can be added based on the actual API response, e.g., run_status\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_concurrency_high_load():\n    \"\"\"Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong\n    answer, timeout)\"\"\"\n    concurrency_level = 100\n    # Indices for different expected outcomes\n    wrong_answer_indices = {10, 25, 50}\n    timeout_indices = {5, 30, 60, 90}  # Indices where we expect a timeout\n\n    # Generate 100 input/output pairs and code\n    high_load_inputs = []\n    high_load_outputs = []\n    expected_results_map = {}  # Store expected result for each index\n\n    for i in range(concurrency_level):\n        if i in timeout_indices:\n            # Use a special input to trigger timeout in the code\n            high_load_inputs.append(f\"input_timeout_{i}\")\n            # Output doesn't matter for timeout, but keep it consistent\n            high_load_outputs.append(f\"output_{i}\\n\")\n            expected_results_map[i] = -3  # Expect timeout\n        elif i in wrong_answer_indices:\n            high_load_inputs.append(f\"input_{i}\")\n            # Intentionally set wrong expected output\n            high_load_outputs.append(f\"wrong_output_{i}\\n\")\n            expected_results_map[i] = False  # Expect wrong answer\n        else:\n            high_load_inputs.append(f\"input_{i}\")\n            # Correct expected output\n            high_load_outputs.append(f\"output_{i}\\n\")\n            expected_results_map[i] = True  # Expect success\n\n    high_load_in_outs = {\"inputs\": high_load_inputs, \"outputs\": high_load_outputs}\n\n    # Code that handles normal inputs, and sleeps on specific \"timeout\" inputs\n    code_mixed_concurrent = \"\"\"\nimport sys\nimport time\ndata = sys.stdin.read()\nif data.startswith('input_timeout_'):\n    time.sleep(20) # Sleep longer than the test timeout\n    print(f\"output_{data.split('_')[-1]}\\\\n\", end='') # Still print something in case it finishes early\nelif data.startswith('input_'):\n    print(f\"output_{data.split('_')[-1]}\\\\n\", end='')\nelse:\n    print(\"unknown_input\\\\n\", end='')\n\"\"\"\n    # Set a reasonable timeout per case (must be less than the sleep time in the code)\n    test_timeout = 15  # Allow slightly more time due to potential API load, 
but less than 20s sleep\n\n    start_time = time.time()\n    results, metadata_list = check_correctness(\n        SANDBOX_URL,\n        high_load_in_outs,\n        code_mixed_concurrent,  # Use the new code\n        timeout=test_timeout,\n    )\n    end_time = time.time()\n    duration = end_time - start_time\n    print(\n        f\"\\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, \"\n        f\"{len(timeout_indices)} timeouts) duration: {duration:.2f} seconds\"\n    )\n\n    # Verify results against the expected map\n    assert len(results) == concurrency_level, f\"Expected {concurrency_level} results, got {len(results)}\"\n\n    correct_count = 0\n    wrong_count = 0\n    timeout_count = 0\n    unexpected_results = []\n    for i, r in enumerate(results):\n        expected = expected_results_map[i]\n        if r == expected:\n            if expected is True:\n                correct_count += 1\n            elif expected is False:\n                wrong_count += 1\n            elif expected == -3:\n                timeout_count += 1\n        else:\n            unexpected_results.append((i, r, f\"Expected {expected}\"))\n\n    print(\n        f\"Correct results (True): {correct_count}/\"\n        f\"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}\"\n    )\n    print(f\"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}\")\n    print(f\"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}\")\n\n    if unexpected_results:\n        print(\"Unexpected results found:\")\n        for idx, res, expected_str in unexpected_results[:10]:  # Print first 10 unexpected\n            print(f\"  Index {idx}: Got {res}, {expected_str}. 
Metadata: {metadata_list[idx]}\")\n        raise AssertionError(f\"Found {len(unexpected_results)} unexpected results.\")\n\n    assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), (\n        \"Incorrect number of successful results\"\n    )\n    assert wrong_count == len(wrong_answer_indices), \"Incorrect number of identified wrong answers\"\n    assert timeout_count == len(timeout_indices), \"Incorrect number of identified timeouts\"\n\n    # Verify metadata count and basic status of one of each type\n    assert len(metadata_list) == concurrency_level\n    # Find the first correct index\n    first_correct_index = next(\n        i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices\n    )\n    assert metadata_list[first_correct_index][\"status\"] == \"success\"\n    assert metadata_list[first_correct_index][\"stdout\"] == f\"output_{first_correct_index}\\n\"\n\n    # Check the status of the first intentionally wrong case\n    first_wrong_index = min(wrong_answer_indices)\n    assert metadata_list[first_wrong_index][\"status\"] == \"wrong_answer\"\n    assert metadata_list[first_wrong_index][\"stdout\"] == f\"output_{first_wrong_index}\\n\"\n    assert metadata_list[first_wrong_index][\"expected_output\"] == f\"wrong_output_{first_wrong_index}\\n\"\n\n    # Check the status of the first intentionally timeout case\n    first_timeout_index = min(timeout_indices)\n    assert metadata_list[first_timeout_index][\"status\"] == \"timeout\"\n    # For timeout, stdout might be None or empty depending on when the timeout occurred\n    # assert metadata_list[first_timeout_index][\"stdout\"] is None or metadata_list[first_timeout_index][\"stdout\"] == \"\"\n\n\n# --- Unit test cases (using mock) ---\n\n\n@patch(\"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\")\ndef test_unit_concurrency_order(mock_call_sandbox_api):\n    sandbox_url = \"mock_url\"\n    generation = \"print(input())\"\n    language = \"python\"\n    timeout = 5\n    in_outs = {\"inputs\": [\"input1\", \"input2\", \"input3\"], \"outputs\": [\"output1\", \"output2\", \"output3\"]}\n\n    def side_effect(*args, **kwargs):\n        stdin = kwargs.get(\"stdin\")\n        if stdin == \"input1\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output1\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input2\":\n            time.sleep(0.1)\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output2\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input3\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output3\", \"return_code\": 0}},\n                None,\n            )\n        else:\n            return (None, \"Unknown input in mock\")\n\n    mock_call_sandbox_api.side_effect = side_effect\n\n    results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language)\n\n    assert results == [True, True, True]\n    assert len(metadata_list) == 3\n    assert metadata_list[0][\"case_index\"] == 0\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[1][\"case_index\"] == 1\n    assert metadata_list[1][\"status\"] == \"success\"\n    assert metadata_list[2][\"case_index\"] == 2\n    assert 
metadata_list[2][\"status\"] == \"success\"\n    assert mock_call_sandbox_api.call_count == 3\n\n\n@patch(\"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\")\ndef test_unit_api_timeout_error_concurrent(mock_call_sandbox_api):\n    sandbox_url = \"mock_url\"\n    generation = \"print(input())\"\n    language = \"python\"\n    timeout = 5\n    in_outs = {\"inputs\": [\"input1\", \"input2_timeout\", \"input3\"], \"outputs\": [\"output1\", \"output2\", \"output3\"]}\n\n    api_error_message = \"API Call Failed: Gateway Timeout (504) on attempt 3/3\"\n\n    def side_effect(*args, **kwargs):\n        stdin = kwargs.get(\"stdin\")\n        if stdin == \"input1\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output1\", \"return_code\": 0}},\n                None,\n            )\n        elif stdin == \"input2_timeout\":\n            return (None, api_error_message)\n        elif stdin == \"input3\":\n            return (\n                {\"status\": \"Success\", \"run_result\": {\"status\": \"Finished\", \"stdout\": \"output3\", \"return_code\": 0}},\n                None,\n            )\n        else:\n            return (None, \"Unknown input in mock\")\n\n    mock_call_sandbox_api.side_effect = side_effect\n\n    results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language)\n\n    assert results == [True, -1, True]\n    assert len(metadata_list) == 3\n    assert metadata_list[0][\"status\"] == \"success\"\n    assert metadata_list[1][\"status\"] == \"api_error\"\n    assert metadata_list[1][\"api_request_error\"] == api_error_message\n    assert metadata_list[2][\"status\"] == \"success\"\n    assert mock_call_sandbox_api.call_count == 3\n\n\n# --- Constants for the new concurrency test ---\n# Define a low global concurrency limit to test the semaphore's effect\nMAX_GLOBAL_CONCURRENCY_LIMIT_TEST = 5\n# Define the number of processes used in the test\nNUM_PROCESSES_TEST = 4\n# Define the number of tasks processed by check_correctness in each process (i.e., internal\n# ThreadPoolExecutor's concurrency potential)\nNUM_TASKS_PER_PROCESS_TEST = 3\n# Simulate API call duration to ensure calls can overlap\nSIMULATED_API_CALL_DURATION_TEST = 0.2  # seconds\n\n\n# --- Mock API call function for concurrency tracking ---\n# This function will replace the real call_sandbox_api and use shared variables to track concurrency\ndef _mock_api_call_for_concurrency_tracking(\n    active_calls_counter,  # multiprocessing.Value\n    max_calls_tracker,  # multiprocessing.Value\n    call_lock,  # multiprocessing.Lock\n    # Standard call_sandbox_api parameters\n    sandbox_fusion_url,\n    code,\n    stdin,\n    compile_timeout,\n    run_timeout,\n    memory_limit_mb,\n    language,\n):\n    # entry_time = time.time() # For detailed logging\n    with call_lock:\n        active_calls_counter.value += 1\n        if active_calls_counter.value > max_calls_tracker.value:\n            max_calls_tracker.value = active_calls_counter.value\n        # Optional debug log:\n        # print(f\"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. 
Active: \"\n        #       f\"{active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}\")\n\n    time.sleep(SIMULATED_API_CALL_DURATION_TEST)  # Simulate actual work duration\n\n    # exit_time = time.time() # For detailed logging\n    with call_lock:\n        active_calls_counter.value -= 1\n        # Optional debug log:\n        # print(f\"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: \"\n        #       f\"{active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s\")\n\n    # Return a simulated successful API response\n    return {\n        \"status\": \"Success\",\n        \"run_result\": {\"status\": \"Finished\", \"stdout\": f\"mock_output_for_{stdin}\", \"return_code\": 0},\n    }, None\n\n\n# --- Worker function for ProcessPoolExecutor ---\n# This function runs in each child process of ProcessPoolExecutor\ndef _process_pool_worker_for_concurrency_test(\n    sandbox_url,\n    in_outs,\n    generation,\n    memory_limit_mb,\n    language,\n    timeout,\n    mp_semaphore_for_check_correctness,\n    active_calls_counter,\n    max_calls_tracker,\n    call_lock,\n):\n    # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage\n    curried_mock_api_call = (\n        lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: (\n            _mock_api_call_for_concurrency_tracking(\n                active_calls_counter,\n                max_calls_tracker,\n                call_lock,\n                sandbox_fusion_url,\n                code,\n                stdin,\n                compile_timeout,\n                run_timeout,\n                memory_limit_mb,\n                language,\n            )\n        )\n    )\n\n    # ---- START DEBUG PRINTS ----\n    import os\n\n    import verl.utils.reward_score.sandbox_fusion.utils\n\n    print(\n        f\"[Worker PID:{os.getpid()}] Original call_sandbox_api: \"\n        f\"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}\",\n        flush=True,\n    )\n    # ---- END DEBUG PRINTS ----\n\n    with patch(\n        \"verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api\", side_effect=curried_mock_api_call\n    ) as mock_obj:\n        # ---- START DEBUG PRINTS ----\n        print(\n            f\"[Worker PID:{os.getpid()}] Patched call_sandbox_api: \"\n            f\"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}\",\n            flush=True,\n        )\n        print(f\"[Worker PID:{os.getpid()}] Mock object: {mock_obj}\", flush=True)\n        # ---- END DEBUG PRINTS ----\n        results, metadata_list = check_correctness(\n            sandbox_fusion_url=sandbox_url,\n            in_outs=in_outs,\n            generation=generation,\n            timeout=timeout,\n            memory_limit_mb=memory_limit_mb,\n            language=language,\n            concurrent_semaphore=mp_semaphore_for_check_correctness,  # Pass multiprocessing.Semaphore\n        )\n        # print(f\"Process {os.getpid()} finished check_correctness. 
Processed {len(results)} tasks.\")\n    return len(results)  # Return the number of processed tasks for basic validation\n\n\n# --- The actual test case for multiprocess concurrency control ---\ndef test_multiprocess_global_concurrency_limit_with_semaphore():\n    \"\"\"\n    Tests that the global concurrent_semaphore (multiprocessing.Semaphore)\n    correctly limits the number of concurrent calls to call_sandbox_api\n    across multiple processes, each potentially running multiple threads\n    via check_correctness's internal ThreadPoolExecutor.\n    \"\"\"\n    manager = multiprocessing.Manager()\n    active_calls_counter = manager.Value(\"i\", 0)  # Current active mock API calls\n    max_calls_tracker = manager.Value(\"i\", 0)  # Observed maximum concurrent mock API calls\n    call_lock = manager.Lock()  # Lock to protect counters\n\n    # Create a multiprocessing.Semaphore instance, this is the global semaphore we are testing.\n    # It will be passed to check_correctness and used by _process_single_case to limit calls to call_sandbox_api.\n    global_mp_semaphore = manager.Semaphore(MAX_GLOBAL_CONCURRENCY_LIMIT_TEST)\n\n    mock_sandbox_url = \"mock_url_for_concurrency_test\"\n    mock_generation = \"pass\"  # Specific code content is not important as API call is mocked\n    mock_memory_limit_mb = 1024\n    mock_language = \"python\"\n    mock_timeout = 5  # Timeout setting, not critical for mock calls\n\n    # Input/output data for each process\n    # NUM_TASKS_PER_PROCESS_TEST tasks will be handled by check_correctness's internal ThreadPoolExecutor\n    process_in_outs = {\n        \"inputs\": [f\"task_input_{i}\" for i in range(NUM_TASKS_PER_PROCESS_TEST)],\n        \"outputs\": [f\"task_output_{i}\" for i in range(NUM_TASKS_PER_PROCESS_TEST)],\n    }\n\n    futures = []\n    total_tasks_expected_to_run = NUM_PROCESSES_TEST * NUM_TASKS_PER_PROCESS_TEST\n\n    test_start_time = time.time()\n\n    with ProcessPoolExecutor(max_workers=NUM_PROCESSES_TEST) as executor:\n        for i in range(NUM_PROCESSES_TEST):\n            future = executor.submit(\n                _process_pool_worker_for_concurrency_test,  # Worker function\n                mock_sandbox_url,\n                process_in_outs,\n                mock_generation,\n                mock_memory_limit_mb,\n                mock_language,\n                mock_timeout,\n                global_mp_semaphore,  # Global semaphore to test\n                active_calls_counter,  # Shared variables for tracking\n                max_calls_tracker,\n                call_lock,\n            )\n            futures.append(future)\n\n    # Wait for all processes to complete and collect results\n    num_tasks_processed_per_worker = [f.result() for f in futures]\n    test_end_time = time.time()\n    total_execution_time = test_end_time - test_start_time\n\n    # Print some test statistics for debugging and validation\n    print(\"\\n--- Global Concurrency Test Stats ---\")\n    print(f\"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}\")\n    print(f\"Number of Processes (NUM_PROCESSES_TEST): {NUM_PROCESSES_TEST}\")\n    print(f\"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}\")\n    print(f\"Total Tasks Submitted: {total_tasks_expected_to_run}\")\n    print(f\"Simulated API Call Duration: {SIMULATED_API_CALL_DURATION_TEST}s\")\n    print(f\"Total Test Execution Time: {total_execution_time:.2f}s\")\n    print(f\"Max Concurrent Mock API Calls Observed: 
{max_calls_tracker.value}\")\n    # print(f\"Tasks processed per worker: {num_tasks_processed_per_worker}\")\n\n    # Verify that all submitted tasks have been processed\n    assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, (\n        \"Mismatch in the number of tasks processed.\"\n    )\n\n    # Verify that the mock API was called at least once\n    assert max_calls_tracker.value > 0, \"The mocked API call_sandbox_api was not called.\"\n\n    # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit\n    assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, (\n        f\"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit \"\n        f\"({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}).\"\n    )\n\n    # Optional: Rough check on execution time to verify semaphore is working to limit concurrency\n    # Theoretical minimum execution time = (Total tasks / Concurrency limit) * Single task duration\n    # Actual time will be longer due to various overheads\n    min_expected_duration = (\n        total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST\n    ) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST\n    # print(f\"Minimum Expected Execution Time (approx): {min_expected_duration:.2f}s\")\n    # Allow some margin, e.g., 80% of theoretical minimum time\n    assert total_execution_time >= min_expected_duration * 0.8, (\n        f\"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the \"\n        f\"semaphore might not be effectively limiting concurrency as expected \"\n        f\"(min expected: {min_expected_duration * 0.8:.2f}s).\"\n    )\n\n\ndef test_unit_invalid_input_format():\n    \"\"\"Unit test: Invalid in_outs format passed\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, None, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n    results, metadata_list = check_correctness(SANDBOX_URL, {}, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS)\n    assert results == [-1]\n    assert metadata_list[0][\"error\"] == \"Invalid input/output data\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_unit_input_output_mismatch():\n    \"\"\"Unit test: Mismatch between the number of inputs and outputs\"\"\"\n    results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS)\n    assert results == [-1]\n    assert len(metadata_list) == 1\n    assert metadata_list[0][\"error\"] == \"Input/output count mismatch\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_integration_concurrency_all_timeout():\n    \"\"\"Integration test: High concurrency (100 cases) against real API, all causing timeout\"\"\"\n    concurrency_level = 100\n    # Compute-heavy code (not literally an infinite loop): with the input below, the\n    # quadratic DP table is large enough that every case exceeds the per-case timeout\n    code_long_running = \"\"\"\ndef knight_moves(X, Y):\n    MOD = 10**9 + 7\n    dp = [[0] * (Y + 1) for _ in range(X + 1)]\n    dp[0][0] = 1\n    for i in range(1, X + 1):\n        for j in range(1, Y + 1):\n            dp[i][j] = (dp[i - 1][j] + dp[i][j - 1]) % MOD\n    return dp[X][Y]\n\ndef solve():\n    X, Y = map(int, input().split())\n    print(knight_moves(X, Y))\n\nif __name__ == \"__main__\":\n    solve()\n    \"\"\"\n\n    # Generate 100 identical input/output pairs; the large input (X=324, Y=384429) is what forces the timeout\n    timeout_inputs = [\"324 384429\" for i in range(concurrency_level)]\n    timeout_outputs = [f\"output_{i}\\n\" for i in range(concurrency_level)]\n    timeout_in_outs = {\"inputs\": timeout_inputs, \"outputs\": timeout_outputs}\n\n    # Per-case timeout in seconds, far below the DP runtime\n    test_timeout = 10\n\n    start_time = time.time()\n    results, metadata_list = check_correctness(SANDBOX_URL, timeout_in_outs, code_long_running, timeout=test_timeout)\n    end_time = time.time()\n    duration = end_time - start_time\n    print(f\"\\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds\")\n\n    # Verify all results are -3 (timeout)\n    assert len(results) == concurrency_level, f\"Expected {concurrency_level} results, got {len(results)}\"\n    all_timed_out = all(r == -3 for r in results)\n    if not all_timed_out:\n        non_timeout_indices = [i for i, r in enumerate(results) if r != -3]\n        print(f\"Indices that did not time out: {non_timeout_indices}\")\n        # Print metadata for the first few non-timeout cases for debugging\n        for i in non_timeout_indices[:5]:\n            print(f\"Metadata for non-timeout case {i}: {metadata_list[i]}\")\n    assert all_timed_out, f\"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}\"\n\n    # Verify metadata count and status of the first case\n    assert len(metadata_list) == concurrency_level\n    assert metadata_list[0][\"status\"] == \"timeout\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_fn_name_success_single_case():\n    \"\"\"Tests successful execution for a single test case with fn_name.\n    from livecodebench/code_generation_lite test 510\n    \"\"\"\n    generation_code = \"\"\"\nclass Solution:\n    def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> List[int]:\n        positions = defaultdict(list)\n        for idx, num in enumerate(nums):\n            positions[num].append(idx)\n\n        x_positions = positions[x]\n        answer = []\n        for k in queries:\n            if k > len(x_positions):\n                answer.append(-1)\n            else:\n                answer.append(x_positions[k-1])\n        return answer\n\"\"\"\n    in_outs = {\n        \"fn_name\": \"occurrencesOfElement\",\n        \"inputs\": [\"[1, 3, 1, 7]\\n[1, 3, 2, 4]\\n1\", \"[1, 2, 3]\\n[10]\\n5\"],\n        \"outputs\": [\"[0, -1, 2, -1]\", \"[-1]\"],\n    }\n\n    # Use a short timeout for fast tests\n    results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5)\n    # from verl.utils.reward_score.prime_code import apps_check_correctness\n    # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code,\n    #                                                        timeout=50000, debug=True)\n\n    assert results == [True, True]\n    assert \"error\" not in metadata_list[0]\n    assert metadata_list[0].get(\"status\") != \"compilation error\"\n    assert metadata_list[0].get(\"status\") != \"runtime error\"\n\n\n@pytest.mark.skipif(skip_condition, reason=skip_reason)\ndef test_none_and_empty_stdin_passed_correctly():\n    \"\"\"\n    Tests that when stdin data is set to an empty string or None, it is still\n    passed correctly to Sandbox Fusion as an empty string.\n    \"\"\"\n    echo_code = \"\"\"\nimport sys\nprint(f\"You said '{sys.stdin.readline().strip()}'\")\n\"\"\"\n    in_outs = {\n        \"inputs\": [None, \"\", \"hello\"],\n        \"outputs\": [\"You said ''\", \"You said ''\", \"You said 'hello'\"],\n    }\n\n    # Use a short timeout for fast tests\n    results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5)\n\n    assert results == [True, True, True]\n    assert \"error\" not in metadata_list[0]\n    assert metadata_list[0].get(\"status\") != \"compilation error\"\n    assert metadata_list[0].get(\"status\") != \"runtime error\"\n
  },
  {
    "path": "verl_rl/tests/utils/reward_score/test_sandbox_on_cpu.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport json\nimport os\n\nimport pytest\n\nfrom verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion\nfrom verl.utils.reward_score.prime_code import apps_check_correctness\nfrom verl.workers.reward_manager.prime import parallel_compute_score_async\n\nprime_math_answers = [\n    \"\"\"\\\\begin{bmatrix}\\n -7 & 6 & -8 \\\\\\\\\\n 11 & -9 & 12 \\\\\\\\\\n 15 & -16 & 19 \\n \\\\end{bmatrix}\"\"\",\n    \"\"\"\\\\frac{\\\\sqrt{505}}{7}\"\"\",\n    \"\"\"x^2 + y^2 + 4x - 6y + 13\"\"\",\n]\nprime_math_gts = [\n    \"\"\"\\\\begin{pmatrix}\\n -7 & 6 & -8 \\\\\\\\\\n 11 & -9 & 12 \\\\\\\\\\n 15 & -16 & 19\\n \\\\end{pmatrix}\"\"\",  # mat test\n    \"\"\"\\\\frac{\\\\sqrt{505}}{7}\"\"\",  # frac test\n    \"\"\"(x + 2)^2 + (y - 3)^2 \"\"\",  # symbolic test\n]\n\nprime_code_answers = [\n    \"\"\"import sys\nfrom collections import deque\n\ndef main():\n    data = sys.stdin.read().split()\n    it = iter(data)\n    \n    # Read start and target positions\n    x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it))\n    \n    n = int(next(it))\n    allowed = set()\n    # The total number of allowed cells is at most 10^5.\n    for _ in range(n):\n        r = int(next(it))\n        a = int(next(it))\n        b = int(next(it))\n        for c in range(a, b + 1):\n            allowed.add((r, c))\n    \n    # Directions for the king (8 neighboring cells)\n    directions = [(-1, -1), (-1, 0), (-1, 1),\n                  (0, -1),           (0, 1),\n                  (1, -1),  (1, 0),  (1, 1)]\n    \n    start = (x0, y0)\n    target = (x1, y1)\n    \n    # BFS initialization\n    queue = deque()\n    queue.append((x0, y0, 0))\n    # Mark the starting cell as visited by removing it from allowed set.\n    allowed.discard(start)\n    \n    while queue:\n        x, y, moves = queue.popleft()\n        if (x, y) == target:\n            print(moves)\n            return\n        for dx, dy in directions:\n            nx, ny = x + dx, y + dy\n            if (nx, ny) in allowed:\n                allowed.remove((nx, ny))\n                queue.append((nx, ny, moves + 1))\n    \n    print(-1)\n\nif __name__ == '__main__':\n    main()\n\"\"\"\n] * 2\nprime_code_gts = [\n    \"\"\"{\\n \\\"inputs\\\": [\\n \\\"5 7 6 11\\\\n3\\\\n5 3 8\\\\n6 7 11\\\\n5 2 5\\\\n\\\",\\n \\\"3 4 3 10\\\\n3\\\\n3 1 4\\\\n4 5 9\\\\n3 10 10\\\\n\\\",\\n \\\"1 1 2 10\\\\n2\\\\n1 1 3\\\\n2 6 10\\\\n\\\",\\n \\\"9 8 7 8\\\\n9\\\\n10 6 6\\\\n10 6 6\\\\n7 7 8\\\\n9 5 6\\\\n8 9 9\\\\n9 5 5\\\\n9 8 8\\\\n8 5 6\\\\n9 10 10\\\\n\\\",\\n \\\"6 15 7 15\\\\n9\\\\n6 15 15\\\\n7 14 14\\\\n6 15 15\\\\n9 14 14\\\\n7 14 16\\\\n6 15 15\\\\n6 15 15\\\\n7 14 14\\\\n8 15 15\\\\n\\\",\\n \\\"13 16 20 10\\\\n18\\\\n13 16 16\\\\n20 10 10\\\\n19 10 10\\\\n12 15 15\\\\n20 10 10\\\\n18 11 11\\\\n19 10 10\\\\n19 10 10\\\\n20 10 10\\\\n19 10 10\\\\n20 10 10\\\\n20 10 10\\\\n19 10 10\\\\n18 
11 11\\\\n13 16 16\\\\n12 15 15\\\\n19 10 10\\\\n19 10 10\\\\n\\\",\\n \\\"89 29 88 30\\\\n16\\\\n87 31 31\\\\n14 95 95\\\\n98 88 89\\\\n96 88 88\\\\n14 97 97\\\\n13 97 98\\\\n100 88 88\\\\n88 32 32\\\\n99 88 89\\\\n90 29 29\\\\n87 31 31\\\\n15 94 96\\\\n89 29 29\\\\n88 32 32\\\\n97 89 89\\\\n88 29 30\\\\n\\\",\\n \\\"30 14 39 19\\\\n31\\\\n35 7 11\\\\n37 11 12\\\\n32 13 13\\\\n37 5 6\\\\n46 13 13\\\\n37 14 14\\\\n31 13 13\\\\n43 13 19\\\\n45 15 19\\\\n46 13 13\\\\n32 17 17\\\\n41 14 19\\\\n30 14 14\\\\n43 13 17\\\\n34 16 18\\\\n44 11 19\\\\n38 13 13\\\\n40 12 20\\\\n37 16 18\\\\n46 16 18\\\\n34 10 14\\\\n36 9 10\\\\n36 15 19\\\\n38 15 19\\\\n42 13 19\\\\n33 14 15\\\\n35 15 19\\\\n33 17 18\\\\n39 12 20\\\\n36 5 7\\\\n45 12 12\\\\n\\\",\\n \\\"2 1 1 1\\\\n2\\\\n1 1 2\\\\n2 1 2\\\\n\\\",\\n \\\"1 1 1 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\",\\n \\\"1 1 1000000000 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\"\\n ],\\n \\\"outputs\\\": [\\n \\\"4\\\\n\\\",\\n \\\"6\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"2\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"9\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\"\\n ]\\n}\"\"\",  # A correct sample # noqa: E501\n    \"\"\"{\\n \\\"inputs\\\": [\\n \\\"5 7 6 11\\\\n3\\\\n5 3 8\\\\n6 7 11\\\\n5 2 5\\\\n\\\",\\n \\\"3 4 3 10\\\\n3\\\\n3 1 4\\\\n4 5 9\\\\n3 10 10\\\\n\\\",\\n \\\"1 1 2 10\\\\n2\\\\n1 1 3\\\\n2 6 10\\\\n\\\",\\n \\\"9 8 7 8\\\\n9\\\\n10 6 6\\\\n10 6 6\\\\n7 7 8\\\\n9 5 6\\\\n8 9 9\\\\n9 5 5\\\\n9 8 8\\\\n8 5 6\\\\n9 10 10\\\\n\\\",\\n \\\"6 15 7 15\\\\n9\\\\n6 15 15\\\\n7 14 14\\\\n6 15 15\\\\n9 14 14\\\\n7 14 16\\\\n6 15 15\\\\n6 15 15\\\\n7 14 14\\\\n8 15 15\\\\n\\\",\\n \\\"13 16 20 10\\\\n18\\\\n13 16 16\\\\n20 10 10\\\\n19 10 10\\\\n12 15 15\\\\n20 10 10\\\\n18 11 11\\\\n19 10 10\\\\n19 10 10\\\\n20 10 10\\\\n19 10 10\\\\n20 10 10\\\\n20 10 10\\\\n19 10 10\\\\n18 11 11\\\\n13 16 16\\\\n12 15 15\\\\n19 10 10\\\\n19 10 10\\\\n\\\",\\n \\\"89 29 88 30\\\\n16\\\\n87 31 31\\\\n14 95 95\\\\n98 88 89\\\\n96 88 88\\\\n14 97 97\\\\n13 97 98\\\\n100 88 88\\\\n88 32 32\\\\n99 88 89\\\\n90 29 29\\\\n87 31 31\\\\n15 94 96\\\\n89 29 29\\\\n88 32 32\\\\n97 89 89\\\\n88 29 30\\\\n\\\",\\n \\\"30 14 39 19\\\\n31\\\\n35 7 11\\\\n37 11 12\\\\n32 13 13\\\\n37 5 6\\\\n46 13 13\\\\n37 14 14\\\\n31 13 13\\\\n43 13 19\\\\n45 15 19\\\\n46 13 13\\\\n32 17 17\\\\n41 14 19\\\\n30 14 14\\\\n43 13 17\\\\n34 16 18\\\\n44 11 19\\\\n38 13 13\\\\n40 12 20\\\\n37 16 18\\\\n46 16 18\\\\n34 10 14\\\\n36 9 10\\\\n36 15 19\\\\n38 15 19\\\\n42 13 19\\\\n33 14 15\\\\n35 15 19\\\\n33 17 18\\\\n39 12 20\\\\n36 5 7\\\\n45 12 12\\\\n\\\",\\n \\\"2 1 1 1\\\\n2\\\\n1 1 2\\\\n2 1 2\\\\n\\\",\\n \\\"1 1 1 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\",\\n \\\"1 1 1000000000 2\\\\n5\\\\n1000000000 1 10000\\\\n19920401 1188 5566\\\\n1000000000 1 10000\\\\n1 1 10000\\\\n5 100 200\\\\n\\\"\\n ],\\n \\\"outputs\\\": [\\n \\\"4\\\\n\\\",\\n \\\"6\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"9\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"1\\\\n\\\",\\n \\\"-1\\\\n\\\"\\n ]\\n}\"\"\",  # noqa: E501\n]  # A failed sample with first several in-out passed\n\nprime_code_scores = [1.0, 0.9]\n\n\ndef test_parallelism():\n    \"\"\"\n    Test if process pool works properly\n    \"\"\"\n    
sequences_str = []\n    ground_truth = []\n    data_sources = []\n    while len(sequences_str) < 32:\n        sequences_str.extend(prime_code_answers)\n        ground_truth.extend(prime_code_gts)\n        data_sources.extend([\"codecontests\"] * len(prime_code_answers))\n\n        sequences_str.extend(prime_math_answers)\n        ground_truth.extend(prime_math_gts)\n        data_sources.extend([\"numina_aops_forum\"] * len(prime_math_answers))\n\n    scores = asyncio.run(\n        parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16)\n    )\n    print(scores)\n\n\ndef test_prime_code():\n    \"\"\"\n    Test PRIME code sandbox.\n    \"\"\"\n    data_source = \"codecontests\"\n    for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):\n        score = default_compute_score(data_source, completion, ground_truth)\n        assert float(score) == score_\n\n\n@pytest.mark.skipif(not os.environ.get(\"SANDBOX_FUSION_URL\"), reason=\"SANDBOX_FUSION_URL environment variable not set\")\ndef test_prime_code_sandbox_fusion():\n    \"\"\"\n    Test PRIME code on sandbox fusion. Skips if SANDBOX_FUSION_URL is not set.\n    \"\"\"\n    data_source = \"codecontests\"\n    # Get the URL from the environment variable; skipif ensures it is set at this point\n    sandbox_fusion_url = os.environ.get(\"SANDBOX_FUSION_URL\")\n\n    for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):\n        score = default_compute_score(\n            data_source, completion, ground_truth, extra_info={\"sandbox_fusion_url\": sandbox_fusion_url}\n        )\n        assert float(score) == score_\n\n\n@pytest.mark.skipif(not os.environ.get(\"SANDBOX_FUSION_URL\"), reason=\"SANDBOX_FUSION_URL environment variable not set\")\ndef test_continuous_score_consistency():\n    \"\"\"\n    Verify that continuous score calculation is consistent between prime_code and sandbox_fusion.\n    Uses a test case where the first 9 out of 11 sub-cases pass (expected score 0.9).\n    \"\"\"\n    completion = prime_code_answers[1]  # Use the second sample\n    ground_truth = prime_code_gts[1]  # Use the second sample (9/11 pass, first 9 pass)\n    expected_continuous_score = 0.9\n\n    # 1. Calculate the continuous score via the sandbox_fusion scorer\n    fusion_score, _ = sandbox_fusion.compute_score(\n        os.environ.get(\"SANDBOX_FUSION_URL\"), None, completion, ground_truth, continuous=True\n    )\n\n    # 2. Calculate the continuous score via the default prime_code scorer\n    prime_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True)\n\n    # 3. Assert scores are equal (using pytest.approx for float comparison)\n    assert float(prime_score) == pytest.approx(expected_continuous_score)\n    assert float(fusion_score) == pytest.approx(expected_continuous_score)\n    assert float(prime_score) == pytest.approx(float(fusion_score))\n    print(f\"Continuous Score (Prime Code): {prime_score}\")\n    print(f\"Continuous Score (Sandbox Fusion): {fusion_score}\")\n\n\ndef test_check_correctness():\n    completion = prime_code_answers[0]\n    ground_truth = json.loads(prime_code_gts[0])\n    ground_truth_single = {\"inputs\": ground_truth[\"inputs\"][:1], \"outputs\": ground_truth[\"outputs\"][:1]}\n    res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False)\n    print(res, meta)\n\n\ndef test_prime_math():\n    data_source = \"numina_aops_forum\"\n    for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True):\n        score = default_compute_score(data_source, completion, ground_truth)\n        assert float(score) == 1.0\n"
  },
  {
    "path": "verl_rl/tests/utils/test_activation_offload.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport shutil\nimport tempfile\n\nimport pytest\nimport torch\nimport torch.distributed\nimport torch.multiprocessing as mp\nfrom torch.distributed import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config\n\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy\n\n\ndef _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy=\"fsdp\"):\n    torch.cuda.set_device(rank)\n    torch.distributed.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=(\"dp\",))\n\n    model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n    config = Qwen2Config(num_hidden_layers=4)\n\n    with torch.device(\"cuda\"):\n        model = AutoModelForCausalLM.from_config(\n            config=config, torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n        )\n        model = model.to(device=\"cuda\")\n\n    # Wrap model with FSDP\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    if strategy == \"fsdp\":\n        model = FSDP(\n            model,\n            use_orig_params=False,\n            device_id=torch.cuda.current_device(),\n            sharding_strategy=ShardingStrategy.FULL_SHARD,\n            mixed_precision=mixed_precision,\n            device_mesh=device_mesh,\n            auto_wrap_policy=get_fsdp_wrap_policy(module=model),\n        )\n    else:\n        mp_policy = MixedPrecisionPolicy(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n        )\n        fsdp_kwargs = {\n            \"mesh\": device_mesh,\n            \"mp_policy\": mp_policy,\n        }\n        apply_fsdp2(model, fsdp_kwargs, {})\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)\n\n    # Create checkpoint manager\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    checkpoint_manager = FSDPCheckpointManager(\n        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer\n    )\n\n    # Generate sample input\n    batch_size = 2\n    seq_len = 32\n    vocab_size = 32000\n    # First input for initial update\n    input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n    attention_mask1 = torch.ones_like(input_ids1)\n\n    # Second 
input for verification\n    input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device=\"cuda\")\n    attention_mask2 = torch.ones_like(input_ids2)\n\n    # Step 1: Initial update and save checkpoint\n    outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)\n    loss1 = outputs1.logits.mean()\n    loss1.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Save checkpoint after first update\n    temp_dir = tempfile.mkdtemp()\n    checkpoint_path = os.path.join(temp_dir, \"checkpoint\")\n    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)\n\n    # Step 2: Second update and forward pass\n    outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)\n    loss2 = outputs2.logits.mean()\n    loss2.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after second update\n    with torch.no_grad():\n        logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits\n\n    # Step 3: Wrap module with activation offloading and load checkpoint\n    enable_activation_offloading(model, strategy)  # propagate the parametrized strategy (\"fsdp\" or \"fsdp2\")\n    checkpoint_manager.load_checkpoint(checkpoint_path)\n\n    # Step 4: Repeat the second update with same input\n    outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)\n    loss3 = outputs3.logits.mean()\n    loss3.backward()\n    optimizer.step()\n    lr_scheduler.step()\n    optimizer.zero_grad()\n\n    # Record logits after loaded checkpoint and update\n    with torch.no_grad():\n        logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits\n\n    # Step 5: Verify outputs match\n    torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)\n    print(f\"Activation offloading for {strategy} test passed on {world_size} GPUs!\")\n\n    # Cleanup\n    shutil.rmtree(temp_dir)\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\n@pytest.mark.parametrize(\"world_size\", (2, 4))\n@pytest.mark.parametrize(\"strategy\", (\"fsdp\", \"fsdp2\"))\ndef test_activation_offloading(world_size, strategy, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_file\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_fsdp_activation_offloading_test,\n        args=(world_size, rendezvous_file, strategy),\n        nprocs=world_size,\n        join=True,\n    )\n
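\n\n# Note: this test spawns real NCCL process groups via mp.spawn (world_size 2 and 4), so it\n# requires a host with enough GPUs; it is not expected to pass on CPU-only machines.\n"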
  },
  {
    "path": "verl_rl/tests/utils/test_config_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom dataclasses import dataclass\n\nfrom omegaconf import OmegaConf\n\nfrom verl.utils import omega_conf_to_dataclass\n\n\n@dataclass\nclass TestDataclass:\n    hidden_size: int\n    activation: str\n\n\n@dataclass\nclass TestTrainConfig:\n    batch_size: int\n    model: TestDataclass\n\n\n_cfg_str = \"\"\"train_config:\n  batch_size: 32\n  model:\n    hidden_size: 768\n    activation: relu\"\"\"\n\n\nclass TestConfigOnCPU(unittest.TestCase):\n    \"\"\"Test cases for configuration utilities on CPU.\n\n    Test Plan:\n    1. Test basic OmegaConf to dataclass conversion for simple nested structures\n    2. Test nested OmegaConf to dataclass conversion for complex hierarchical configurations\n    3. Verify all configuration values are correctly converted and accessible\n    \"\"\"\n\n    def setUp(self):\n        self.config = OmegaConf.create(_cfg_str)\n\n    def test_omega_conf_to_dataclass(self):\n        sub_cfg = self.config.train_config.model\n        cfg = omega_conf_to_dataclass(sub_cfg, TestDataclass)\n        self.assertEqual(cfg.hidden_size, 768)\n        self.assertEqual(cfg.activation, \"relu\")\n        assert isinstance(cfg, TestDataclass)\n\n    def test_nested_omega_conf_to_dataclass(self):\n        cfg = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig)\n        self.assertEqual(cfg.batch_size, 32)\n        self.assertEqual(cfg.model.hidden_size, 768)\n        self.assertEqual(cfg.model.activation, \"relu\")\n        assert isinstance(cfg, TestTrainConfig)\n        assert isinstance(cfg.model, TestDataclass)\n\n\nclass TestPrintCfgCommand(unittest.TestCase):\n    \"\"\"Test suite for the print_cfg.py command-line tool.\"\"\"\n\n    def test_command_with_override(self):\n        \"\"\"Test that the command runs without error when overriding config values.\"\"\"\n        import subprocess\n\n        # Run the command\n        result = subprocess.run(\n            [\"python3\", \"scripts/print_cfg.py\", \"critic.profiler.discrete=True\", \"+critic.profiler.extra.any_key=val\"],\n            capture_output=True,\n            text=True,\n        )\n\n        # Verify the command exited successfully\n        self.assertEqual(result.returncode, 0, f\"Command failed with stderr: {result.stderr}\")\n\n        # Verify the output contains expected config information\n        self.assertIn(\"critic\", result.stdout)\n        self.assertIn(\"profiler\", result.stdout)\n        self.assertIn(\"discrete=True\", result.stdout)\n        self.assertIn(\"extra={'any_key': 'val'}\", result.stdout)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "verl_rl/tests/utils/test_flops_counter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport pytest\n\nfrom verl.utils.flops_counter import FlopsCounter\n\nVALID_CONFIG_TYPE = {\"llama\", \"qwen2\", \"qwen3\", \"qwen3_moe\", \"deepseek_v3\", \"mistral\", \"gemma3_text\"}\n\n\nclass Config:\n    def __init__(self, config_dict):\n        for key, value in config_dict.items():\n            setattr(self, key, value)\n\n\nCONFIG = {\n    \"llama\": {\n        \"config\": {  # llama2-7B\n            \"model_type\": \"llama\",\n            \"vocab_size\": 32000,\n            \"hidden_size\": 4096,\n            \"intermediate_size\": 11008,\n            \"num_hidden_layers\": 32,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 32,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*32*4096\n        # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*32*4096\n        \"expected_flops_tuple\": (153555818250240 / 1e12, 575955114393600 / 1e12),\n    },\n    \"qwen2\": {\n        \"config\": {  # Qwen/Qwen2.5-7B-Instruct\n            \"model_type\": \"qwen2\",\n            \"vocab_size\": 152064,\n            \"hidden_size\": 3584,\n            \"intermediate_size\": 18944,\n            \"num_hidden_layers\": 28,\n            \"num_attention_heads\": 28,\n            \"num_key_value_heads\": 4,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*28*3584\n        # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*28*3584\n        \"expected_flops_tuple\": (170388331954176 / 1e12, 622070178250752 / 1e12),\n    },\n    \"qwen3\": {\n        \"config\": {  # Qwen/Qwen3-8B\n            \"model_type\": \"qwen3\",\n            \"vocab_size\": 151936,\n            \"hidden_size\": 4096,\n            \"intermediate_size\": 12288,\n            \"num_hidden_layers\": 36,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        # 
6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*36*128*32\n        # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*36*128*32\n        \"expected_flops_tuple\": (185867930959872 / 1e12, 692924253732864 / 1e12),\n    },\n    \"qwen3_moe\": {\n        \"config\": {  # Qwen/Qwen3-30B-A3B-Base\n            \"model_type\": \"qwen3_moe\",\n            \"hidden_size\": 2048,\n            \"vocab_size\": 151936,\n            \"num_hidden_layers\": 48,\n            \"num_key_value_heads\": 4,\n            \"num_attention_heads\": 32,\n            \"head_dim\": 128,\n            \"moe_intermediate_size\": 768,\n            \"num_experts_per_tok\": 8,\n            \"num_experts\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 +\n        # hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim\n        # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*48*128*32\n        # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*48*128*32\n        \"expected_flops_tuple\": (85087060230144 / 1e12, 365944098521088 / 1e12),\n    },\n    \"deepseek_v3\": {\n        \"config\": {  # deepseek-ai/DeepSeek-Prover-V2-671B\n            \"model_type\": \"deepseek_v3\",\n            \"hidden_size\": 7168,\n            \"vocab_size\": 129280,\n            \"moe_intermediate_size\": 2048,\n            \"num_hidden_layers\": 61,\n            \"first_k_dense_replace\": 3,\n            \"num_attention_heads\": 128,\n            \"n_routed_experts\": 256,\n            \"num_experts_per_tok\": 8,\n            \"n_shared_experts\": 1,\n            \"kv_lora_rank\": 512,\n            \"qk_rope_head_dim\": 64,\n            \"v_head_dim\": 128,\n            \"intermediate_size\": 18432,\n            \"qk_nope_head_dim\": 128,\n            \"q_lora_rank\": 1536,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280\n        # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) +\n        # 12*(512*512+1024*1024+2048*2048)*61*192*128\n        # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) +\n        # 12*(4096*4096+4096*4096+4096*4096)*61*192*128\n        \"expected_flops_tuple\": (906535995703296 / 1e12, 3674028304760832 / 1e12),\n    },\n    \"mistral\": {\n        \"config\": {  # mistralai/Mistral-Small-24B-Instruct-2501\n            \"model_type\": \"mistral\",\n            \"vocab_size\": 131072,\n            \"hidden_size\": 5120,\n            \"intermediate_size\": 32768,\n            \"num_hidden_layers\": 40,\n            \"num_attention_heads\": 32,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 128,\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # Mistral uses same architecture as Llama, with GQA\n        # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +\n        # 12*sum(seqlen^2)*layer*head*head_dim\n        
# vocab part: 131072*5120*2 = 1342177280\n        # attn part per layer: 5120*(128*32+128*8+128*8+128*32) = 5120*10240 = 52428800\n        # mlp part per layer: 5120*32768*3 = 503316480\n        # total per layer: 52428800 + 503316480 = 555745280\n        # all layers: 1342177280 + 40*555745280 = 23571988480\n        # For batch [512, 1024, 2048], tokens_sum = 3584:\n        # dense flops: 6 * 23571988480 * 3584 = 506892040273920\n        # attn flops: 12 * 5505024 * 40 * 128 * 32 = 10823317585920\n        # total: 517715357859840 / 1e12 = 517.71535785984\n        # For batch [4096, 4096, 4096], tokens_sum = 12288:\n        # dense flops: 6 * 23571988480 * 12288 = 1737915566653440\n        # attn flops: 12 * 50331648 * 40 * 128 * 32 = 98956046499840\n        # total: 1836871613153280 / 1e12 = 1836.87161315328\n        \"expected_flops_tuple\": (517715357859840 / 1e12, 1836871613153280 / 1e12),\n    },\n    \"gemma3_text\": {\n        \"config\": {  # Gemma3-12B-IT-TextOnly\n            \"model_type\": \"gemma3_text\",\n            \"vocab_size\": 262208,\n            \"hidden_size\": 3840,\n            \"intermediate_size\": 15360,\n            \"num_hidden_layers\": 48,\n            \"num_attention_heads\": 16,\n            \"num_key_value_heads\": 8,\n            \"head_dim\": 256,\n            \"sliding_window\": 1024,\n            \"layer_types\": None,\n            # Will be auto-generated based on sliding_window_pattern\n            \"sliding_window_pattern\": 6,\n            # Every 6th layer is full attention\n        },\n        \"batch_seqlens_tuple\": ([512, 1024, 2048], [4096, 4096, 4096]),\n        # Gemma3 has alternating sliding window attention\n        # With sliding_window_pattern=6: layers 5,11,17,23,29,35,41,47 use full attention (8 layers)\n        # Other 40 layers use sliding window attention with window_size=1024\n        #\n        # Non-attention FLOPs:\n        # vocab part: 262208*3840*2 = 2013757440\n        # attn part per layer: 3840*(256*16+256*8+256*8+256*16) = 3840*12288 = 47185920\n        # mlp part per layer: 3840*15360*3 = 176947200\n        # total per layer: 47185920 + 176947200 = 224133120\n        # all layers: 2013757440 + 48*224133120 = 12772147200\n        #\n        # For batch [512, 1024, 2048], tokens_sum = 3584:\n        # dense flops: 6 * 12772147200 * 3584 = 274652253388800\n        # seqlen_square_sum: 180355072 (calculated with sliding window logic)\n        # attn flops: 12 * 180355072 * 256 * 16 = 8864812498944\n        # total: 283517065887744 / 1e12 = 283.517065887744\n        #\n        # For batch [4096, 4096, 4096], tokens_sum = 12288:\n        # dense flops: 6 * 12772147200 * 12288 = 941664868761600\n        # seqlen_square_sum: 905969664 (calculated with sliding window logic)\n        # attn flops: 12 * 905969664 * 256 * 16 = 44530220924928\n        # total: 986195089686528 / 1e12 = 986.195089686528\n        \"expected_flops_tuple\": (283517065887744 / 1e12, 986195089686528 / 1e12),\n    },\n}\n\n\n@pytest.mark.parametrize(\n    \"config_type\",\n    [\"llama\", \"qwen2\", \"qwen3\", \"qwen3_moe\", \"deepseek_v3\", \"mistral\", \"gemma3_text\"],\n)\ndef test_flops_counter(config_type: str):\n    test_config = CONFIG[config_type]\n    config = Config(test_config[\"config\"])\n    flops_counter = FlopsCounter(config)\n    for batch_seqlens, expected_flops in zip(\n        test_config[\"batch_seqlens_tuple\"], test_config[\"expected_flops_tuple\"], strict=True\n    ):\n        # set delta time to 1 to get the flops\n        
counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)\n        print(f\"Expected flops for {test_config['config']} is {expected_flops}, but got {counted_flops}\")\n        assert math.isclose(counted_flops, expected_flops), (\n            f\"Expected flops for {test_config['config']} is {expected_flops}, but got {counted_flops}\"\n        )\n"
  },
  {
    "path": "verl_rl/tests/utils/test_fs_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom pathlib import Path\n\nimport verl.utils.fs as fs\n\n\ndef test_record_and_check_directory_structure(tmp_path):\n    # Create test directory structure\n    test_dir = tmp_path / \"test_dir\"\n    test_dir.mkdir()\n    (test_dir / \"file1.txt\").write_text(\"test\")\n    (test_dir / \"subdir\").mkdir()\n    (test_dir / \"subdir\" / \"file2.txt\").write_text(\"test\")\n\n    # Create structure record\n    record_file = fs._record_directory_structure(test_dir)\n\n    # Verify record file exists\n    assert os.path.exists(record_file)\n\n    # Initial check should pass\n    assert fs._check_directory_structure(test_dir, record_file) is True\n\n    # Modify structure and verify check fails\n    (test_dir / \"new_file.txt\").write_text(\"test\")\n    assert fs._check_directory_structure(test_dir, record_file) is False\n\n\ndef test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):\n    # Mock HDFS dependencies\n    monkeypatch.setattr(fs, \"is_non_local\", lambda path: True)\n\n    # side_effect will simulate the copy by creating parent dirs + empty file\n    def fake_copy(src: str, dst: str, *args, **kwargs):\n        dst_path = Path(dst)\n        dst_path.parent.mkdir(parents=True, exist_ok=True)\n        dst_path.write_bytes(b\"\")  # touch an empty file\n\n    monkeypatch.setattr(fs, \"copy\", fake_copy)  # Mock actual HDFS copy\n\n    # Test parameters\n    test_cache = tmp_path / \"cache\"\n    hdfs_path = \"hdfs://test/path/file.txt\"\n\n    # Test initial copy\n    local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path))\n    assert local_path == expected_path\n    assert os.path.exists(local_path)\n\n\ndef test_always_recopy_flag(tmp_path, monkeypatch):\n    # Mock HDFS dependencies\n    monkeypatch.setattr(fs, \"is_non_local\", lambda path: True)\n\n    copy_call_count = 0\n\n    def fake_copy(src: str, dst: str, *args, **kwargs):\n        nonlocal copy_call_count\n        copy_call_count += 1\n        dst_path = Path(dst)\n        dst_path.parent.mkdir(parents=True, exist_ok=True)\n        dst_path.write_bytes(b\"\")\n\n    monkeypatch.setattr(fs, \"copy\", fake_copy)  # Mock actual HDFS copy\n\n    test_cache = tmp_path / \"cache\"\n    hdfs_path = \"hdfs://test/path/file.txt\"\n\n    # Initial copy (always_recopy=False)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    assert copy_call_count == 1\n\n    # Force recopy (always_recopy=True)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache, always_recopy=True)\n    assert copy_call_count == 2\n\n    # Subsequent normal call (always_recopy=False)\n    fs.copy_to_local(hdfs_path, cache_dir=test_cache)\n    assert copy_call_count == 2  # Should not increment\n"
  },
  {
    "path": "verl_rl/tests/utils/test_import_utils_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\n\nfrom verl.utils.import_utils import load_extern_type\n\n# Path to the test module\nTEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), \"_test_module.py\")\n\n\ndef test_load_extern_type_class():\n    \"\"\"Test loading a class from an external file\"\"\"\n    TestClass = load_extern_type(TEST_MODULE_PATH, \"TestClass\")\n\n    # Verify the class was loaded correctly\n    assert TestClass is not None\n    assert TestClass.__name__ == \"TestClass\"\n\n    # Test instantiation and functionality\n    instance = TestClass()\n    assert instance.value == \"default\"\n\n    # Test with a custom value\n    custom_instance = TestClass(\"custom\")\n    assert custom_instance.get_value() == \"custom\"\n\n\ndef test_load_extern_type_function():\n    \"\"\"Test loading a function from an external file\"\"\"\n    test_function = load_extern_type(TEST_MODULE_PATH, \"test_function\")\n\n    # Verify the function was loaded correctly\n    assert test_function is not None\n    assert callable(test_function)\n\n    # Test function execution\n    result = test_function()\n    assert result == \"test_function_result\"\n\n\ndef test_load_extern_type_constant():\n    \"\"\"Test loading a constant from an external file\"\"\"\n    constant = load_extern_type(TEST_MODULE_PATH, \"TEST_CONSTANT\")\n\n    # Verify the constant was loaded correctly\n    assert constant is not None\n    assert constant == \"test_constant_value\"\n\n\ndef test_load_extern_type_nonexistent_file():\n    \"\"\"Test behavior when file doesn't exist\"\"\"\n    with pytest.raises(FileNotFoundError):\n        load_extern_type(\"/nonexistent/path.py\", \"SomeType\")\n\n\ndef test_load_extern_type_nonexistent_type():\n    \"\"\"Test behavior when type doesn't exist in the file\"\"\"\n    with pytest.raises(AttributeError):\n        load_extern_type(TEST_MODULE_PATH, \"NonExistentType\")\n\n\ndef test_load_extern_type_none_path():\n    \"\"\"Test behavior when file path is None\"\"\"\n    result = load_extern_type(None, \"SomeType\")\n    assert result is None\n\n\ndef test_load_extern_type_invalid_module():\n    \"\"\"Test behavior when module has syntax errors\"\"\"\n    # Create a temporary file with syntax errors\n    import tempfile\n\n    with tempfile.NamedTemporaryFile(suffix=\".py\", mode=\"w+\", delete=False) as temp_file:\n        temp_file.write(\"This is not valid Python syntax :\")\n        temp_path = temp_file.name\n\n    try:\n        with pytest.raises(RuntimeError):\n            load_extern_type(temp_path, \"SomeType\")\n    finally:\n        # Clean up the temporary file\n        if os.path.exists(temp_path):\n            os.remove(temp_path)\n"
  },
  {
    "path": "verl_rl/tests/utils/test_linear_cross_entropy.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.experimental.torch_functional import FusedLinearForPPO\nfrom verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nfrom verl.utils.torch_functional import logprobs_from_logits\n\ncompute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)\nfused_linear_for_ppo = FusedLinearForPPO()\nfused_linear_for_ppo.compile(dynamic=True)\n\nMAX_TEST_CASES = os.environ.get(\"MAX_TEST_CASES\", 5)\n\n\ndef run_torch_entropy(\n    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction=\"none\"\n) -> list[torch.Tensor]:\n    hidden = hidden.squeeze(0).to(torch.float32)\n    weight = weight.transpose(0, 1).to(torch.float32)\n    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]\n    logits /= temperature\n    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]\n    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]\n    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]\n    entropy = entropy_a - entropy_b\n    logprobs = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)  # [num_tokens]\n    logprobs = torch.neg(logprobs)\n    return logprobs, entropy\n\n\ndef run_verl_original_entropy(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    temperature: float,\n) -> list[torch.Tensor]:\n    hidden = hidden.squeeze(0).to(torch.float32)\n    weight = weight.transpose(0, 1).to(torch.float32)\n    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]\n    logits /= temperature\n    # compute entropy\n    entropy = compute_entropy_from_logits(logits)  # ((total_nnz / sp) + pad)\n    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n    logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False)\n    return logprobs, entropy\n\n\n# To be tested\ndef run_verl_torch_fused_entropy(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    temperature: float,\n):\n    hidden = hidden.to(torch.float32)\n    weight = 
weight.to(torch.float32)\n    logprobs, entropy = fused_linear_for_ppo(\n        hidden,\n        weight,\n        labels,\n        temperature=temperature,\n    )\n    return logprobs.squeeze(0), entropy.squeeze(0)\n\n\nclass TestLinearCrossEntropy:\n    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:\n        self.test_case_idx = test_case_idx\n        self.temperature = temperature\n\n    def cleanup(self):\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        import gc\n\n        gc.collect()\n        torch.cuda.synchronize()\n\n    def generate_hyper(self):\n        global MAX_TEST_CASES\n\n        self.dtype = torch.bfloat16\n        if self.test_case_idx == 0:\n            self.batch_size = 1\n            self.num_tokens = 1937\n            self.hidden_size = 3584\n            self.vocab_size = 152064\n        elif self.test_case_idx == 1:\n            self.batch_size = 1\n            self.num_tokens = 2169\n            self.hidden_size = 896\n            self.vocab_size = 151936\n        elif self.test_case_idx == 2:\n            self.batch_size = 1\n            self.num_tokens = 1530\n            self.hidden_size = 2048\n            self.vocab_size = 32256\n        elif self.test_case_idx == 3:\n            self.batch_size = 1\n            self.num_tokens = 1388\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        elif self.test_case_idx == 4:\n            self.batch_size = 1\n            self.num_tokens = 8192\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        else:\n            raise ValueError(f\"Invalid test case index: {self.test_case_idx}\")\n        assert MAX_TEST_CASES <= 5, \"MAX_TEST_CASES should be less than or equal to 5.\"\n\n    def generate_forward_inputs(self):\n        hidden = (\n            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        weight = (\n            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device=\"cuda\")\n        return hidden, weight, labels\n\n    def generate_backward_inputs(self):\n        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-0.5, 0.5)\n        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-1, 1)\n        return g_entropy, g_logprobs\n\n    def verify_correctness(self, iterations=5):\n        self.cleanup()\n        self.generate_hyper()\n\n        torch_forward_latency = list()\n        torch_backward_latency = list()\n        verl_forward_latency = list()\n        verl_backward_latency = list()\n        verl_fused_forward_latency = list()\n        verl_fused_backward_latency = list()\n        kernel_forward_latency = list()\n        kernel_backward_latency = list()\n\n        start_event = torch.cuda.Event(enable_timing=True)\n        end_event = torch.cuda.Event(enable_timing=True)\n\n        for i in range(iterations):\n            print(f\"[INFO]: Iteration {i + 1} / {iterations}...\", end=\"\\r\")\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            start_event.record()\n            (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, 
labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(\n                hidden, weight, labels, self.temperature\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_fused_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature)\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_forward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)\n\n            torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)\n\n            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n            torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)\n            torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)\n\n            # backward\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n\n            start_event.record()\n            (d_torch_hidden, d_torch_weight) = torch.autograd.grad(\n                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_backward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (d_verl_hidden, d_verl_weight) = torch.autograd.grad(\n                (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_backward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad(\n                (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            verl_fused_backward_latency.append(start_event.elapsed_time(end_event))\n\n            
start_event.record()\n            (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(\n                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_backward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)\n\n            torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)\n            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)\n\n            torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)\n\n        # remove first latency\n        torch_forward_latency = torch_forward_latency[1:]\n        torch_backward_latency = torch_backward_latency[1:]\n        verl_forward_latency = verl_forward_latency[1:]\n        verl_backward_latency = verl_backward_latency[1:]\n        verl_fused_forward_latency = verl_fused_forward_latency[1:]\n        verl_fused_backward_latency = verl_fused_backward_latency[1:]\n        kernel_forward_latency = kernel_forward_latency[1:]\n        kernel_backward_latency = kernel_backward_latency[1:]\n\n        print(\"\\n[INFO]: Verified forward & backward correctness.\")\n\n        print(\n            f\"[INFO]: Forward pass: Torch implementation average time: \"\n            f\"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: torch implementation average time: \"\n            f\"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: VeRL implementation average time: \"\n            f\"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: VeRL implementation average time: \"\n            f\"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: \"\n            f\"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: \"\n            f\"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Forward pass: 
Kernel implementation average time: \"\n            f\"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms\"\n        )\n        print(\n            f\"[INFO]: Backward pass: kernel implementation average time: \"\n            f\"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms\"\n        )\n\n    def check_storage(self, method_name, run_forward):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        torch.cuda.reset_peak_memory_stats()\n        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)\n        torch.cuda.synchronize()\n        torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        print(f\"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB\")\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_torch_hidden, d_torch_weight) = torch.autograd.grad(\n            (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        print(f\"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB\")\n\n    def check_storage_all(self):\n        self.check_storage(\"Torch\", run_torch_entropy)\n        self.check_storage(\"VeRL\", run_verl_original_entropy)\n        self.check_storage(\"VeRL Torch Fused\", run_verl_torch_fused_entropy)\n        self.check_storage(\"Kernel\", linear_cross_entropy)\n\n\nif __name__ == \"__main__\":\n    # torch.cuda.memory._record_memory_history()\n\n    for test_case_idx in range(MAX_TEST_CASES):\n        print(f\"[INFO] Running test case {test_case_idx}\")\n        test = TestLinearCrossEntropy(test_case_idx)\n\n        test.verify_correctness()\n        test.check_storage_all()\n\n    # torch.cuda.memory._dump_snapshot(\"test_linear_cross_entropy.pkl\")\n"
  },
  {
    "path": "verl_rl/tests/utils/test_linear_cross_entropy_tp.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\nimport torch.distributed as dist\n\ntry:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nexcept ImportError:\n    # FIXME: remove these manually included paths\n    import sys\n\n    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), \"../../\")))\nfinally:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\nimport verl.utils.torch_functional as verl_F\n\ncompute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)\n\nMAX_TEST_CASES = os.environ.get(\"MAX_TEST_CASES\", 5)\nVERIFY_TORCH_SELF = os.environ.get(\"VERIFY_TORCH_SELF\", False)\nLOW_MEMORY = os.environ.get(\"LOW_MEMORY\", False)\nLOW_MEMORY_DIV_FACTOR = os.environ.get(\"LOW_MEMORY_DIV_FACTOR\", 16)\n\n\ndef run_torch_entropy(\n    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction=\"none\"\n) -> list[torch.Tensor]:\n    # [num_tokens, vocab_size]\n    if len(hidden.shape) > 2:\n        hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]\n    if len(labels.shape) > 1:\n        labels = labels.view(-1)\n    logits = torch.matmul(\n        hidden.to(torch.float32),\n        weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32),\n    )\n    logits /= temperature\n    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]\n    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]\n    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]\n    entropy = entropy_a - entropy_b\n    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]\n    logprobs = torch.neg(logprobs)\n    return logprobs, entropy\n\n\nclass TorchEntropyTP(torch.autograd.Function):\n    \"\"\"\n    it is used for testing the correctness of the kernel\n    it is not efficient and is not recommended to use in practice\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        hidden: torch.Tensor,\n        weight: torch.Tensor,\n        labels: torch.Tensor,\n        temperature: float,\n        dist_process_group: 
torch.distributed.ProcessGroup,\n    ):\n        # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size]\n        ctx.original_hidden_shape = hidden.shape\n        if len(hidden.shape) > 2:\n            hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]\n        if len(labels.shape) > 1:\n            labels = labels.view(-1)\n\n        logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T)  # [num_tokens, vocab_size]\n        logits /= temperature\n        whole_logits = torch.empty(\n            (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),\n            dtype=logits.dtype,\n            device=logits.device,\n        )\n        whole_logits_ref = [\n            whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]\n            for i in range(dist.get_world_size(dist_process_group))\n        ]\n        dist.all_gather(whole_logits_ref, logits, group=dist_process_group)\n\n        pd = torch.nn.functional.softmax(whole_logits, dim=-1)\n        entropy_a = torch.logsumexp(whole_logits, dim=-1)  # [num_tokens]\n        entropy_b = torch.sum(pd * whole_logits, dim=-1)  # [num_tokens]\n        entropy = entropy_a - entropy_b\n\n        logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction=\"none\")\n        logprobs = torch.neg(logprobs)\n\n        ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)\n        ctx.dist_process_group = dist_process_group\n        ctx.temperature = temperature\n        return logprobs, entropy\n\n    @staticmethod\n    def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):\n        hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors\n        dist_process_group = ctx.dist_process_group\n        temperature = ctx.temperature\n        batch_size, hidden_size = hidden.shape\n        vocab_size, hidden_size = weight.shape\n        rank = dist.get_rank(dist_process_group)\n\n        # Compute softmax probabilities\n        maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)\n        exp_logits = torch.exp(whole_logits - maximum)\n        accumulate = exp_logits.sum(dim=-1, keepdim=True)\n        pd = exp_logits / accumulate\n\n        # Gradient for entropy\n        # entropy = entropy_a - entropy_b\n        # entropy_a = log(sum(exp(logits)))\n        # entropy_b = sum(pd * logits)\n        # d_entropy_a/d_logits = pd\n        # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1)\n        # d_entropy/d_logits = d_entropy_a - d_entropy_b\n        # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1)\n        # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1))\n        d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))\n\n        # Gradient for logprobs\n        # logprobs = -cross_entropy = -log(pd[labels])\n        # d_logprobs/d_logits = (pd - one_hot(labels))\n        one_hot = torch.zeros_like(whole_logits)\n        one_hot.scatter_(1, labels.unsqueeze(1), 1)\n        g_logprobs = torch.neg(g_logprobs)\n        d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)\n        # NOTE: This will lead to wrong result\n        # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot\n\n        # Combine gradients\n        d_logits = d_logits_entropy + d_logits_logprobs\n        d_logits /= temperature\n\n        # Get local slice of gradients\n        local_d_logits = d_logits[:, rank * vocab_size : (rank 
+ 1) * vocab_size]\n\n        # Compute gradients for hidden and weight\n        d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))\n        d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))\n        d_hidden = d_hidden.view(ctx.original_hidden_shape)\n\n        return d_hidden, d_weight, None, None, None\n\n\nrun_torch_entropy_tp = TorchEntropyTP.apply\n\n\nclass TestLinearCrossEntropy_TensorParallel:\n    def __init__(self):\n        dist.init_process_group(backend=\"nccl\")\n        self.group = dist.group.WORLD\n\n        self.local_rank = dist.get_rank(self.group)\n        self.world_size = dist.get_world_size(self.group)\n        device = torch.device(f\"cuda:{self.local_rank}\")\n        torch.cuda.set_device(device)\n        print(f\"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}\")\n\n    def initialize(self, test_case_idx: int, temperature: float = 1.5):\n        self.test_case_idx = test_case_idx\n        self.temperature = temperature\n\n    def shutdown(self):\n        dist.destroy_process_group()\n\n    def cleanup(self):\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        import gc\n\n        gc.collect()\n        torch.cuda.synchronize()\n\n    def generate_hyper(self):\n        global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES\n\n        self.dtype = torch.bfloat16\n        if self.test_case_idx == 0:\n            self.batch_size = 1\n            self.num_tokens = 1937\n            self.hidden_size = 3584\n            self.vocab_size = 152064\n        elif self.test_case_idx == 1:\n            self.batch_size = 1\n            self.num_tokens = 2169\n            self.hidden_size = 896\n            self.vocab_size = 151936\n        elif self.test_case_idx == 2:\n            self.batch_size = 1\n            self.num_tokens = 1530\n            self.hidden_size = 2048\n            self.vocab_size = 32256\n        elif self.test_case_idx == 3:\n            self.batch_size = 1\n            self.num_tokens = 1388\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        elif self.test_case_idx == 4:\n            self.batch_size = 1\n            self.num_tokens = 8192\n            self.hidden_size = 4096\n            self.vocab_size = 102400\n        else:\n            raise ValueError(f\"Invalid test case index: {self.test_case_idx}\")\n        if LOW_MEMORY:\n            self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)\n        assert MAX_TEST_CASES <= 5, \"MAX_TEST_CASES should be less than or equal to 5.\"\n\n    def generate_forward_inputs(self):\n        hidden = (\n            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        weight = (\n            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device=\"cuda\")\n            .uniform_(-0.5, 0.5)\n            .requires_grad_()\n        )\n        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device=\"cuda\")\n        return hidden, weight, labels\n\n    def generate_backward_inputs(self):\n        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-0.5, 0.5)\n        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device=\"cuda\").uniform_(-1, 1)\n        return g_entropy, g_logprobs\n\n    def verify_torch_itself(self, iterations: int = 5):\n        
self.cleanup()\n        self.generate_hyper()\n\n        for i in range(iterations):\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            # NOTE: we need to manually synchronize hidden and labels among Process Group\n            dist.broadcast(hidden, src=0, group=self.group)\n            dist.broadcast(labels, src=0, group=self.group)\n\n            # forward pass\n            # Create a tensor to hold the gathered weights from all ranks\n            # weight has shape [vocab_size, hidden_size]\n            # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size]\n\n            # Create a single contiguous tensor to hold all gathered weights\n            whole_weight = torch.empty(\n                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device\n            )\n\n            # Create views into the tensor for each rank's portion\n            whole_weight_views = [\n                whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)\n            ]\n\n            # Perform all_gather operation using the views\n            dist.all_gather(whole_weight_views, weight, group=self.group)\n\n            # Set requires_grad for autograd\n            whole_weight.requires_grad_()\n\n            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)\n\n            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n\n            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)\n            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)\n\n            # backward pass\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n            dist.broadcast(g_entropy, src=0, group=self.group)\n            dist.broadcast(g_logprobs, src=0, group=self.group)\n\n            (single_d_hidden, single_d_weight) = torch.autograd.grad(\n                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n\n            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(\n                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)\n            # Extract the corresponding slice from single_d_weight for comparison\n            # tp_d_weight has shape [vocab_size, hidden_size]\n            # single_d_weight has shape [vocab_size * world_size, hidden_size]\n            torch.testing.assert_close(\n                tp_d_weight,\n                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],\n                atol=1e-2,\n                rtol=1e-4,\n            )\n\n            # atol=1e-3, rtol=1e-4)\n        if self.local_rank == 0:\n            print(\"[PASS] torch TP correctness is verified\")\n\n    def check_torch_storage(self):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        # NOTE: we need to manually 
synchronize hidden and labels among Process Group\n        dist.broadcast(hidden, src=0, group=self.group)\n        dist.broadcast(labels, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n        torch.cuda.synchronize()\n        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n        dist.broadcast(g_entropy, src=0, group=self.group)\n        dist.broadcast(g_logprobs, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(\n            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n        # NOTE: all-reduce on hidden is conducted outside the kernel\n        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n        if self.local_rank == 0:\n            print(f\"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB\")\n            print(f\"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB\")\n\n    def verify_kernel_correctness(self, iterations: int = 5):\n        self.cleanup()\n        self.generate_hyper()\n\n        torch_forward_latency = list()\n        torch_backward_latency = list()\n        kernel_forward_latency = list()\n        kernel_backward_latency = list()\n\n        start_event = torch.cuda.Event(enable_timing=True)\n        end_event = torch.cuda.Event(enable_timing=True)\n\n        for i in range(iterations):\n            hidden, weight, labels = self.generate_forward_inputs()\n\n            # NOTE: we need to manually synchronize hidden and labels among Process Group\n            dist.broadcast(hidden, src=0, group=self.group)\n            dist.broadcast(labels, src=0, group=self.group)\n\n            start_event.record()\n            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_forward_latency.append(start_event.elapsed_time(end_event))\n\n            start_event.record()\n            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(\n                hidden, weight, labels, self.temperature, \"none\", self.group\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_forward_latency.append(start_event.elapsed_time(end_event))\n\n            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)\n            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)\n\n            # backward pass\n            g_entropy, g_logprobs = self.generate_backward_inputs()\n            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n            dist.broadcast(g_entropy, src=0, group=self.group)\n            dist.broadcast(g_logprobs, src=0, group=self.group)\n\n            start_event.record()\n            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(\n                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n         
   )\n            end_event.record()\n            torch.cuda.synchronize()\n            torch_backward_latency.append(start_event.elapsed_time(end_event))\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            start_event.record()\n            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(\n                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n            )\n            end_event.record()\n            torch.cuda.synchronize()\n            kernel_backward_latency.append(start_event.elapsed_time(end_event))\n            # NOTE: all-reduce on hidden is conducted outside the kernel\n            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)\n            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)\n\n        # remove first latency\n        torch_forward_latency = torch_forward_latency[1:]\n        torch_backward_latency = torch_backward_latency[1:]\n        kernel_forward_latency = kernel_forward_latency[1:]\n        kernel_backward_latency = kernel_backward_latency[1:]\n\n        if self.local_rank == 0:\n            print(\"\\n[PASS]: Verified kernel forward & backward correctness.\")\n\n            print(\n                f\"[INFO]: Forward pass: Torch implementation average time: \"\n                f\"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Backward pass: torch implementation average time: \"\n                f\"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Forward pass: Kernel implementation average time: \"\n                f\"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms\"\n            )\n            print(\n                f\"[INFO]: Backward pass: kernel implementation average time: \"\n                f\"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms\"\n            )\n\n    def check_kernel_storage(self):\n        self.cleanup()\n        self.generate_hyper()\n\n        hidden, weight, labels = self.generate_forward_inputs()\n\n        # NOTE: we need to manually synchronize hidden and labels among Process Group\n        dist.broadcast(hidden, src=0, group=self.group)\n        dist.broadcast(labels, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(\n            hidden, weight, labels, self.temperature, \"none\", self.group\n        )\n        torch.cuda.synchronize()\n        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024\n\n        g_entropy, g_logprobs = self.generate_backward_inputs()\n        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group\n        dist.broadcast(g_entropy, src=0, group=self.group)\n        dist.broadcast(g_logprobs, src=0, group=self.group)\n\n        torch.cuda.reset_peak_memory_stats()\n        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(\n            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False\n        )\n        torch.cuda.synchronize()\n        kernel_backward_max_memory = 
torch.cuda.max_memory_allocated() / 1024 / 1024\n        # NOTE: all-reduce on hidden is conducted outside the kernel\n        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)\n\n        if self.local_rank == 0:\n            print(f\"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB\")\n            print(f\"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB\")\n\n\nif __name__ == \"__main__\":\n    # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py\n\n    # Check if running with torchrun (distributed mode)\n    assert int(os.environ[\"WORLD_SIZE\"]) > 1, (\n        \"[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to \"\n        \"execute this script.\"\n    )\n    torch.manual_seed(233376 + int(os.environ.get(\"RANK\", 0)))\n\n    # set_backward_method(BackwardEnum._Total_Fuse_MN)\n    # set_backward_method(BackwardEnum._Split_Dlogits_N)\n\n    test = TestLinearCrossEntropy_TensorParallel()\n    for test_case_idx in range(MAX_TEST_CASES):\n        print(f\"[INFO] Running test case {test_case_idx}\")\n        test.initialize(test_case_idx)\n        if VERIFY_TORCH_SELF:\n            test.verify_torch_itself()\n        test.check_torch_storage()\n        test.verify_kernel_correctness()\n        test.check_kernel_storage()\n\n    test.shutdown()\n"
  },
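  {
    "path": "verl_rl/tests/kernels/cuda_event_timing_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It isolates the CUDA-event\n# timing idiom used by test_linear_cross_entropy_tp.py: record events around\n# the region, synchronize before reading, and drop warm-up iterations.\n# The helper name `time_cuda_ms` is an assumption for this example only.\nimport torch\n\n\ndef time_cuda_ms(fn, warmup: int = 1, iters: int = 8) -> float:\n    \"\"\"Return the average latency of `fn` in milliseconds, excluding warm-up.\"\"\"\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    latencies = []\n    for i in range(warmup + iters):\n        start.record()\n        fn()\n        end.record()\n        # elapsed_time is only valid once both events have completed\n        torch.cuda.synchronize()\n        if i >= warmup:\n            latencies.append(start.elapsed_time(end))\n    return sum(latencies) / len(latencies)\n\n\nif __name__ == \"__main__\":\n    if torch.cuda.is_available():\n        x = torch.randn(4096, 4096, device=\"cuda\")\n        print(f\"matmul: {time_cuda_ms(lambda: x @ x):.2f} ms\")\n"
  },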
  {
    "path": "verl_rl/tests/utils/test_model_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom types import SimpleNamespace  # Or use a mock object library\n\nimport pytest\n\nfrom verl.utils.model import update_model_config\n\n\n# Parametrize with different override scenarios\n@pytest.mark.parametrize(\n    \"override_kwargs\",\n    [\n        {\"param_a\": 5, \"new_param\": \"plain_added\"},\n        {\"param_a\": 2, \"nested_params\": {\"sub_param_x\": \"updated_x\", \"sub_param_z\": True}},\n    ],\n)\ndef test_update_model_config(override_kwargs):\n    \"\"\"\n    Tests that update_model_config correctly updates attributes,\n    handling both plain and nested overrides via parametrization.\n    \"\"\"\n    # Create a fresh mock config object for each test case\n    mock_config = SimpleNamespace(\n        param_a=1, nested_params=SimpleNamespace(sub_param_x=\"original_x\", sub_param_y=100), other_param=\"keep_me\"\n    )\n    # Apply the updates using the parametrized override_kwargs\n    update_model_config(mock_config, override_kwargs)\n\n    # Assertions to check if the config was updated correctly\n    if \"nested_params\" in override_kwargs:  # Case 2: Nested override\n        override_nested = override_kwargs[\"nested_params\"]\n        assert mock_config.nested_params.sub_param_x == override_nested[\"sub_param_x\"], \"Nested sub_param_x mismatch\"\n        assert mock_config.nested_params.sub_param_y == 100, \"Nested sub_param_y should be unchanged\"\n        assert hasattr(mock_config.nested_params, \"sub_param_z\"), \"Expected nested sub_param_z to be added\"\n        assert mock_config.nested_params.sub_param_z == override_nested[\"sub_param_z\"], \"Value of sub_param_z mismatch\"\n    else:  # Case 1: Plain override (nested params untouched)\n        assert mock_config.nested_params.sub_param_x == \"original_x\", \"Nested sub_param_x should be unchanged\"\n        assert mock_config.nested_params.sub_param_y == 100, \"Nested sub_param_y should be unchanged\"\n        assert not hasattr(mock_config.nested_params, \"sub_param_z\"), \"Nested sub_param_z should not exist\"\n"
  },
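  {
    "path": "verl_rl/tests/utils/model_config_update_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It mirrors the behaviour that\n# test_model_on_cpu.py exercises for update_model_config: plain keys are set\n# directly, while dict-valued keys recurse into nested namespaces. verl's\n# actual implementation may differ; `update_config_sketch` is a name invented\n# for this example.\nfrom types import SimpleNamespace\n\n\ndef update_config_sketch(config, override_kwargs):\n    for key, value in override_kwargs.items():\n        if isinstance(value, dict) and hasattr(config, key):\n            # descend into the nested config instead of replacing it wholesale\n            update_config_sketch(getattr(config, key), value)\n        else:\n            setattr(config, key, value)\n\n\nif __name__ == \"__main__\":\n    cfg = SimpleNamespace(param_a=1, nested_params=SimpleNamespace(sub_param_x=\"original_x\", sub_param_y=100))\n    update_config_sketch(cfg, {\"param_a\": 2, \"nested_params\": {\"sub_param_x\": \"updated_x\", \"sub_param_z\": True}})\n    assert cfg.param_a == 2\n    assert cfg.nested_params.sub_param_x == \"updated_x\"\n    assert cfg.nested_params.sub_param_y == 100  # untouched sibling survives\n    assert cfg.nested_params.sub_param_z is True\n"
  },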
  {
    "path": "verl_rl/tests/utils/test_nvtx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nfrom verl.utils import omega_conf_to_dataclass\nfrom verl.utils.profiler import ProfilerConfig\nfrom verl.utils.profiler.nvtx_profile import NsightSystemsProfiler\n\n\nclass TestProfilerConfig(unittest.TestCase):\n    def test_config_init(self):\n        import os\n\n        from hydra import compose, initialize_config_dir\n\n        with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n            cfg = compose(config_name=\"ppo_trainer\")\n        arr = cfg.actor_rollout_ref\n        for config in [\n            cfg.critic.profiler,\n            arr.profiler,\n            cfg.reward_model.profiler,\n        ]:\n            profiler_config = omega_conf_to_dataclass(config)\n            self.assertEqual(profiler_config.discrete, config.discrete)\n            self.assertEqual(profiler_config.all_ranks, config.all_ranks)\n            self.assertEqual(profiler_config.ranks, config.ranks)\n            assert isinstance(profiler_config, ProfilerConfig)\n            with self.assertRaises(AttributeError):\n                _ = profiler_config.non_existing_key\n            assert config.get(\"non_existing_key\") == profiler_config.get(\"non_existing_key\")\n            assert config.get(\"non_existing_key\", 1) == profiler_config.get(\"non_existing_key\", 1)\n            assert config[\"discrete\"] == profiler_config[\"discrete\"]\n            from dataclasses import FrozenInstanceError\n\n            with self.assertRaises(FrozenInstanceError):\n                profiler_config.discrete = False\n\n    def test_frozen_config(self):\n        \"\"\"Test that modifying frozen keys in ProfilerConfig raises exceptions.\"\"\"\n        from dataclasses import FrozenInstanceError\n\n        from verl.utils.profiler.config import ProfilerConfig\n\n        # Create a new ProfilerConfig instance\n        config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0])\n\n        # Test direct attribute assignment\n        with self.assertRaises(FrozenInstanceError):\n            config.discrete = False\n\n        with self.assertRaises(FrozenInstanceError):\n            config.all_ranks = True\n\n        with self.assertRaises(FrozenInstanceError):\n            config.ranks = [1, 2, 3]\n\n        # Test dictionary-style assignment\n        with self.assertRaises(TypeError):\n            config[\"discrete\"] = False\n\n        with self.assertRaises(TypeError):\n            config[\"all_ranks\"] = True\n\n        with self.assertRaises(TypeError):\n            config[\"ranks\"] = [1, 2, 3]\n\n        config[\"extra\"][\"key\"] = \"value\"\n\n\nclass TestNsightSystemsProfiler(unittest.TestCase):\n    \"\"\"Test suite for NsightSystemsProfiler functionality.\n\n    Test Plan:\n    1. 
Initialization: Verify profiler state after creation\n    2. Basic Profiling: Test start/stop functionality\n    3. Discrete Mode: Test discrete profiling behavior\n    4. Annotation: Test the annotate decorator in both normal and discrete modes\n    5. Config Validation: Verify proper config initialization from OmegaConf\n    \"\"\"\n\n    def setUp(self):\n        self.config = ProfilerConfig(all_ranks=True)\n        self.rank = 0\n        self.profiler = NsightSystemsProfiler(self.rank, self.config)\n\n    def test_initialization(self):\n        self.assertEqual(self.profiler.this_rank, True)\n        self.assertEqual(self.profiler.this_step, False)\n        self.assertEqual(self.profiler.discrete, False)\n\n    def test_start_stop_profiling(self):\n        with patch(\"torch.cuda.profiler.start\") as mock_start, patch(\"torch.cuda.profiler.stop\") as mock_stop:\n            # Test start\n            self.profiler.start()\n            self.assertTrue(self.profiler.this_step)\n            mock_start.assert_called_once()\n\n            # Test stop\n            self.profiler.stop()\n            self.assertFalse(self.profiler.this_step)\n            mock_stop.assert_called_once()\n\n    def test_discrete_profiling(self):\n        discrete_config = ProfilerConfig(discrete=True, all_ranks=True)\n        profiler = NsightSystemsProfiler(self.rank, discrete_config)\n\n        with patch(\"torch.cuda.profiler.start\") as mock_start, patch(\"torch.cuda.profiler.stop\") as mock_stop:\n            profiler.start()\n            self.assertTrue(profiler.this_step)\n            mock_start.assert_not_called()  # Shouldn't start immediately in discrete mode\n\n            profiler.stop()\n            self.assertFalse(profiler.this_step)\n            mock_stop.assert_not_called()  # Shouldn't stop immediately in discrete mode\n\n    def test_annotate_decorator(self):\n        mock_self = MagicMock()\n        mock_self.profiler = self.profiler\n        mock_self.profiler.this_step = True\n\n        @NsightSystemsProfiler.annotate(message=\"test\")\n        def test_func(self, *args, **kwargs):\n            return \"result\"\n\n        with (\n            patch(\"torch.cuda.profiler.start\") as mock_start,\n            patch(\"torch.cuda.profiler.stop\") as mock_stop,\n            patch(\"verl.utils.profiler.nvtx_profile.mark_start_range\") as mock_start_range,\n            patch(\"verl.utils.profiler.nvtx_profile.mark_end_range\") as mock_end_range,\n        ):\n            result = test_func(mock_self)\n            self.assertEqual(result, \"result\")\n            mock_start_range.assert_called_once()\n            mock_end_range.assert_called_once()\n            mock_start.assert_not_called()  # Not discrete mode\n            mock_stop.assert_not_called()  # Not discrete mode\n\n    def test_annotate_discrete_mode(self):\n        discrete_config = ProfilerConfig(discrete=True, all_ranks=True)\n        profiler = NsightSystemsProfiler(self.rank, discrete_config)\n        mock_self = MagicMock()\n        mock_self.profiler = profiler\n        mock_self.profiler.this_step = True\n\n        @NsightSystemsProfiler.annotate(message=\"test\")\n        def test_func(self, *args, **kwargs):\n            return \"result\"\n\n        with (\n            patch(\"torch.cuda.profiler.start\") as mock_start,\n            patch(\"torch.cuda.profiler.stop\") as mock_stop,\n            patch(\"verl.utils.profiler.nvtx_profile.mark_start_range\") as mock_start_range,\n            
patch(\"verl.utils.profiler.nvtx_profile.mark_end_range\") as mock_end_range,\n        ):\n            result = test_func(mock_self)\n            self.assertEqual(result, \"result\")\n            mock_start_range.assert_called_once()\n            mock_end_range.assert_called_once()\n            mock_start.assert_called_once()  # Should start in discrete mode\n            mock_stop.assert_called_once()  # Should stop in discrete mode\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
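  {
    "path": "verl_rl/tests/utils/nvtx_annotate_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It shows the decorator shape\n# that test_nvtx_profile.py asserts on: `annotate` brackets a method with\n# range marks, and in \"discrete\" mode additionally starts/stops the profiler\n# around each call. The range and profiler functions below are stand-ins,\n# not the real NVTX bindings.\nimport functools\n\n\ndef mark_start_range(message):\n    print(f\"[range start] {message}\")\n\n\ndef mark_end_range():\n    print(\"[range end]\")\n\n\nclass SketchProfiler:\n    def __init__(self, discrete: bool = False):\n        self.discrete = discrete\n        self.this_step = False  # set True between start() and stop()\n\n    @staticmethod\n    def annotate(message=\"\"):\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(self, *args, **kwargs):\n                profiler = self.profiler\n                if not profiler.this_step:\n                    return func(self, *args, **kwargs)\n                if profiler.discrete:\n                    print(\"[profiler start] (discrete)\")\n                mark_start_range(message or func.__name__)\n                try:\n                    return func(self, *args, **kwargs)\n                finally:\n                    mark_end_range()\n                    if profiler.discrete:\n                        print(\"[profiler stop] (discrete)\")\n\n            return wrapper\n\n        return decorator\n"
  },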
  {
    "path": "verl_rl/tests/utils/test_rollout_trace_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op\n\n\n@pytest.fixture(autouse=True)\ndef reset_rollout_trace_config_singleton():\n    \"\"\"Fixture to reset the RolloutTraceConfig singleton before each test.\"\"\"\n    RolloutTraceConfig.reset()\n\n\n@pytest.fixture\ndef mock_weave_client():\n    \"\"\"Mocks the weave module and its client, yielding the mock client.\"\"\"\n    mock_weave = MagicMock()\n    mock_client = MagicMock()\n    mock_call = MagicMock()\n    mock_client.create_call.return_value = mock_call\n    mock_weave.init.return_value = mock_client\n\n    # Also mock the call_context if it's used internally by the decorator\n    mock_weave.trace.context.call_context.return_value = MagicMock()\n\n    with patch.dict(sys.modules, {\"weave\": mock_weave, \"weave.trace.context\": mock_weave.trace.context}):\n        yield mock_client\n\n\nclass TracedClass:\n    @rollout_trace_op\n    # @weave.op\n    # @mlflow.trace\n    async def my_method(self, a, b=\"default\"):\n        return f\"result: {a}, {b}\"\n\n    @rollout_trace_op\n    # @weave.op\n    # @mlflow.trace\n    async def middle_method(self, a, b=\"default\"):\n        await self.my_method(\"test_a1\", b=\"test_b1\")\n        return f\"result: {a}, {b}\"\n\n    @rollout_trace_op\n    # @mlflow.trace\n    async def my_method_with_exception(self):\n        raise ValueError(\"Test Exception\")\n\n    async def upper_method(self):\n        await self.my_method(\"test_a0\", b=\"test_b0\")\n        await self.middle_method(\"test_a2\", b=\"test_b2\")\n        return True\n\n\nclass UntracedClass:\n    @rollout_trace_op\n    async def my_method(self, x):\n        return x * 2\n\n\nasync def test_rollout_trace_on_untraced_class():\n    \"\"\"Tests that the decorator works correctly when no backend is configured.\"\"\"\n    instance = UntracedClass()\n    assert await instance.my_method(10) == 20\n\n\nasync def test_rollout_trace_with_tracer(mock_weave_client):\n    \"\"\"Tests that the decorator calls the tracer's methods correctly.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n    instance = TracedClass()\n    assert RolloutTraceConfig.get_client() is mock_weave_client\n\n    result = await instance.my_method(\"test_a\", b=\"test_b\")\n\n    assert result == \"result: test_a, test_b\"\n    mock_weave_client.create_call.assert_called_once()\n    call_kwargs = mock_weave_client.create_call.call_args.kwargs\n    assert call_kwargs[\"op\"] == \"TracedClass.my_method\"\n    expected_inputs = {\"a\": \"test_a\", \"b\": \"test_b\"}\n    assert call_kwargs[\"inputs\"] == expected_inputs\n\n    mock_call = mock_weave_client.create_call.return_value\n    mock_weave_client.finish_call.assert_called_once_with(mock_call, 
output=result)\n\n\nasync def test_rollout_trace_with_exception(mock_weave_client):\n    \"\"\"Tests that `finish` is called with the exception when one is raised.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n    instance = TracedClass()\n\n    with pytest.raises(ValueError, match=\"Test Exception\"):\n        await instance.my_method_with_exception()\n\n    mock_weave_client.create_call.assert_called_once()\n    mock_call = mock_weave_client.create_call.return_value\n    mock_weave_client.finish_call.assert_called_once()\n\n    # Check that finish_call was called with the exception\n    args, kwargs = mock_weave_client.finish_call.call_args\n    assert args[0] == mock_call\n    assert \"exception\" in kwargs\n    assert isinstance(kwargs[\"exception\"], ValueError)\n\n\nasync def test_rollout_trace_with_dummy_backend(mock_weave_client):\n    \"\"\"Tests that the tracer is not called when the backend is 'dummy'.\"\"\"\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"dummy\")\n    instance = TracedClass()\n\n    await instance.my_method(\"test_a\")\n\n    mock_weave_client.create_call.assert_not_called()\n\n\n@pytest.mark.skipif(\n    os.environ.get(\"RUN_WEAVE_INTEGRATION_TESTS\", \"false\").lower() != \"true\",\n    reason=\"Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.\",\n)\nasync def test_rollout_trace_with_real_weave_backend():\n    \"\"\"Integration test with a real weave backend.\"\"\"\n\n    # This assumes that the weave environment (e.g., project) is configured\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"weave\")\n\n    instance = TracedClass()\n\n    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):\n        await instance.upper_method()\n\n    with pytest.raises(ValueError, match=\"Test Exception\"):\n        await instance.my_method_with_exception()\n\n    print(\"\\nWeave integration test ran successfully. Check your weave project for the trace.\")\n\n\n@pytest.mark.skipif(\n    os.environ.get(\"RUN_MLFLOW_INTEGRATION_TESTS\", \"false\").lower() != \"true\",\n    reason=\"Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.\",\n)\nasync def test_rollout_trace_with_real_mlflow_backend():\n    \"\"\"Integration test with a real mlflow backend.\"\"\"\n\n    # This assumes that the mlflow environment (e.g., project) is configured\n    RolloutTraceConfig.init(project_name=\"my-project\", experiment_name=\"my-experiment\", backend=\"mlflow\")\n\n    instance = TracedClass()\n\n    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name=\"agent_run\"):\n        assert await instance.upper_method()\n\n    # with pytest.raises(ValueError, match=\"Test Exception\"):\n    #     await instance.my_method_with_exception()\n\n    print(\"\\nMLflow integration test ran successfully. Check your mlflow project for the trace.\")\n"
  },
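  {
    "path": "verl_rl/tests/utils/rollout_trace_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It reproduces the tracing\n# contract that test_rollout_trace_on_cpu.py checks against a mocked weave\n# client: create a call with the qualified op name and bound inputs, then\n# finish it with either the output or the raised exception. The client\n# interface here is an assumption modelled on the mocks in the test.\nimport functools\nimport inspect\n\n\ndef rollout_trace_op_sketch(get_client):\n    \"\"\"`get_client` returns a weave-like client, or None when tracing is off.\"\"\"\n\n    def decorator(func):\n        @functools.wraps(func)\n        async def wrapper(self, *args, **kwargs):\n            client = get_client()\n            if client is None:\n                return await func(self, *args, **kwargs)\n            bound = inspect.signature(func).bind(self, *args, **kwargs)\n            bound.apply_defaults()\n            inputs = {k: v for k, v in bound.arguments.items() if k != \"self\"}\n            call = client.create_call(op=f\"{type(self).__name__}.{func.__name__}\", inputs=inputs)\n            try:\n                result = await func(self, *args, **kwargs)\n            except Exception as e:\n                client.finish_call(call, exception=e)\n                raise\n            client.finish_call(call, output=result)\n            return result\n\n        return wrapper\n\n    return decorator\n"
  },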
  {
    "path": "verl_rl/tests/utils/test_seqlen_balancing.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom verl import DataProto\nfrom verl.utils.model import create_random_mask\nfrom verl.utils.seqlen_balancing import (\n    ceildiv,\n    get_reverse_idx,\n    prepare_dynamic_batch,\n    rearrange_micro_batches,\n    restore_dynamic_batch,\n)\n\n\ndef test_seqlen_balancing():\n    input_ids = torch.randint(low=0, high=10, size=(20, 100))\n\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5\n    )\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n    micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)\n    batch = torch.cat(micro_batches)\n    micro_bsz_idx = []\n    for idx in micro_bsz_idx_lst:\n        micro_bsz_idx.extend(idx)\n    reverse_idx_map = get_reverse_idx(micro_bsz_idx)\n    reverse_idx_map = torch.tensor(reverse_idx_map)\n    new_batch = batch[reverse_idx_map]\n    torch.testing.assert_close(new_batch, dataproto.batch)\n\n\ndef test_dynamic_batch():\n    input_ids = torch.randint(low=0, high=10, size=(20, 100))\n\n    attention_mask = create_random_mask(\n        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5\n    )\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n    micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300)\n    input_ids = torch.cat([micro_batch.batch[\"input_ids\"] for micro_batch in micro_batches], dim=0)\n    input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst)\n    torch.testing.assert_close(input_ids, dataproto.batch[\"input_ids\"])\n\n\ndef _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):\n    # 1) init process group & CUDA\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    # 2) build a small random batch (each rank different length to force mismatch)\n    torch.manual_seed(42 + rank)\n    input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f\"cuda:{rank}\")\n    attention_mask = create_random_mask(\n        input_ids=input_ids,\n        max_ratio_of_left_padding=0.1,\n        max_ratio_of_valid_token=0.9,\n        min_ratio_of_valid_token=0.5,\n    )\n    dp = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    proto = DataProto.from_single_dict(dp)\n    batch = proto.batch\n\n    # 3) call rearrange_micro_batches with one of the two params under test\n    micros, idx_lst = rearrange_micro_batches(\n        batch,\n        
max_token_len=max_token_len,\n        dp_group=dist.group.WORLD,\n        same_micro_num_in_dp=use_same_dp,\n        min_num_micro_batch=min_mb,\n    )\n\n    # 4) check the enforced counts\n    seq_len_effective: torch.Tensor = batch[\"attention_mask\"].sum(dim=1)\n    total_seqlen = seq_len_effective.sum().item()\n    local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))\n\n    if min_mb is not None:\n        expected = max(local, min_mb)\n        assert len(micros) == expected\n    if use_same_dp:\n        # gather all local_counts\n        counts = [torch.zeros(1, device=f\"cuda:{rank}\") for _ in range(world_size)]\n        counts[rank].fill_(local)\n        dist.all_gather(counts, counts[rank])\n        expected = max(int(c.item()) for c in counts)\n        assert len(micros) == expected\n    elif min_mb is None:\n        # if neither knob is set, we expect the local natural count\n        assert len(micros) == local\n\n    # 5) reconstruction sanity: concat→reverse_idx→orig\n    flat = torch.cat(micros, dim=0)\n    idx = []\n    for sub in idx_lst:\n        idx.extend(sub)\n    inv = get_reverse_idx(idx)\n    inv = torch.tensor(inv, device=flat.device)\n    reconstructed = flat[inv]\n    torch.testing.assert_close(reconstructed, batch)\n\n    dist.destroy_process_group()\n\n\ndef test_dataproto_split_uneven():\n    \"\"\"Test DataProto.split with uneven splits\"\"\"\n    # Create test data with 10 items\n    input_ids = torch.randint(low=0, high=10, size=(10, 5))\n    attention_mask = torch.ones(10, 5)\n    data = {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n    dataproto = DataProto.from_single_dict(data)\n\n    # Test split with size 3 (should create chunks of [3, 3, 3, 1])\n    splits = dataproto.split(3)\n    assert len(splits) == 4\n    assert len(splits[0]) == 3\n    assert len(splits[1]) == 3\n    assert len(splits[2]) == 3\n    assert len(splits[3]) == 1\n\n    reconstructed = DataProto.concat(splits)\n    torch.testing.assert_close(reconstructed.batch[\"input_ids\"], dataproto.batch[\"input_ids\"])\n    torch.testing.assert_close(reconstructed.batch[\"attention_mask\"], dataproto.batch[\"attention_mask\"])\n\n    # Test split with size equal to length (should create one chunk)\n    splits = dataproto.split(10)\n    assert len(splits) == 1\n    assert len(splits[0]) == 10\n\n    # Test split with size larger than length (should create one chunk with all data)\n    splits = dataproto.split(15)\n    assert len(splits) == 1\n    assert len(splits[0]) == 10\n\n    # Test with non-tensor batch data\n    import numpy as np\n\n    data_with_non_tensor = {\n        \"input_ids\": input_ids,\n        \"attention_mask\": attention_mask,\n        \"labels\": np.array([f\"label_{i}\" for i in range(10)], dtype=object),\n    }\n    dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor)\n\n    splits = dataproto_with_non_tensor.split(3)\n    assert len(splits) == 4\n    assert len(splits[0]) == 3\n    assert len(splits[1]) == 3\n    assert len(splits[2]) == 3\n    assert len(splits[3]) == 1\n\n    # Verify non-tensor data integrity\n    reconstructed = DataProto.concat(splits)\n    np.testing.assert_array_equal(\n        reconstructed.non_tensor_batch[\"labels\"], dataproto_with_non_tensor.non_tensor_batch[\"labels\"]\n    )\n\n\ndef test_seqlen_balancing_distributed_params(tmp_path):\n    world_size = 2\n    init_file = tmp_path / \"dist_init\"\n    init_file.write_text(\"\")  # empty file\n    init_method = f\"file://{init_file}\"\n\n    # test 
min_num_micro_batch only\n    mp.spawn(\n        _worker,\n        args=(world_size, init_method, 300, False, 4),\n        nprocs=world_size,\n        join=True,\n    )\n\n    # test same_micro_num_in_dp only\n    mp.spawn(\n        _worker,\n        args=(world_size, init_method, 300, True, None),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
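  {
    "path": "verl_rl/tests/utils/seqlen_reverse_idx_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It demonstrates the\n# permutation round-trip that test_seqlen_balancing.py relies on: flatten the\n# micro-batch index lists, build the inverse permutation, and index the\n# concatenated batch with it to restore the original order.\nimport torch\n\n\ndef reverse_idx_sketch(idx):\n    \"\"\"Inverse permutation: if idx[pos] == original, then inv[original] == pos.\"\"\"\n    inv = [0] * len(idx)\n    for pos, original in enumerate(idx):\n        inv[original] = pos\n    return inv\n\n\nif __name__ == \"__main__\":\n    data = torch.arange(6)\n    order = [3, 0, 5, 1, 4, 2]  # e.g. a sequence-length-balanced visit order\n    shuffled = data[torch.tensor(order)]\n    inv = torch.tensor(reverse_idx_sketch(order))\n    assert torch.equal(shuffled[inv], data)\n    print(\"round-trip ok\")\n"
  },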
  {
    "path": "verl_rl/tests/utils/test_temp_env_on_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\n\nfrom verl.utils.py_functional import temp_env_var\n\n\n@pytest.fixture(autouse=True)\ndef clean_env():\n    \"\"\"Fixture to clean up environment variables before and after each test.\"\"\"\n    # Store original environment state\n    original_env = dict(os.environ)\n\n    # Clean up any test variables that might exist\n    test_vars = [\"TEST_VAR\", \"TEST_VAR_2\", \"EXISTING_VAR\"]\n    for var in test_vars:\n        if var in os.environ:\n            del os.environ[var]\n\n    # Yield control to the test function\n    yield\n\n    # Restore original environment state after test\n    os.environ.clear()\n    os.environ.update(original_env)\n\n\ndef test_set_new_env_var():\n    \"\"\"Test setting a new environment variable that didn't exist before.\"\"\"\n    # Ensure variable doesn't exist\n    assert \"TEST_VAR\" not in os.environ\n\n    with temp_env_var(\"TEST_VAR\", \"test_value\"):\n        # Variable should be set inside context\n        assert os.environ[\"TEST_VAR\"] == \"test_value\"\n        assert \"TEST_VAR\" in os.environ\n\n    # Variable should be removed after context\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_restore_existing_env_var():\n    \"\"\"Test restoring an environment variable that already existed.\"\"\"\n    # Set up existing variable\n    os.environ[\"EXISTING_VAR\"] = \"original_value\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"temporary_value\"):\n        # Variable should be temporarily changed\n        assert os.environ[\"EXISTING_VAR\"] == \"temporary_value\"\n\n    # Variable should be restored to original value\n    assert os.environ[\"EXISTING_VAR\"] == \"original_value\"\n\n\ndef test_env_var_restored_on_exception():\n    \"\"\"Test that environment variables are restored even when exceptions occur.\"\"\"\n    # Set up existing variable\n    os.environ[\"EXISTING_VAR\"] = \"original_value\"\n\n    with pytest.raises(ValueError):\n        with temp_env_var(\"EXISTING_VAR\", \"temporary_value\"):\n            # Verify variable is set\n            assert os.environ[\"EXISTING_VAR\"] == \"temporary_value\"\n            # Raise exception\n            raise ValueError(\"Test exception\")\n\n    # Variable should still be restored despite exception\n    assert os.environ[\"EXISTING_VAR\"] == \"original_value\"\n\n\ndef test_nested_context_managers():\n    \"\"\"Test nested temp_env_var context managers.\"\"\"\n    # Set up original variable\n    os.environ[\"TEST_VAR\"] = \"original\"\n\n    with temp_env_var(\"TEST_VAR\", \"level1\"):\n        assert os.environ[\"TEST_VAR\"] == \"level1\"\n\n        with temp_env_var(\"TEST_VAR\", \"level2\"):\n            assert os.environ[\"TEST_VAR\"] == \"level2\"\n\n        # Should restore to level1\n        assert os.environ[\"TEST_VAR\"] == \"level1\"\n\n    # Should restore to original\n    assert os.environ[\"TEST_VAR\"] == \"original\"\n\n\ndef 
test_multiple_different_vars():\n    \"\"\"Test setting multiple different environment variables.\"\"\"\n    # Set up one existing variable\n    os.environ[\"EXISTING_VAR\"] = \"existing_value\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"modified\"):\n        with temp_env_var(\"TEST_VAR\", \"new_value\"):\n            assert os.environ[\"EXISTING_VAR\"] == \"modified\"\n            assert os.environ[\"TEST_VAR\"] == \"new_value\"\n\n    # Check restoration\n    assert os.environ[\"EXISTING_VAR\"] == \"existing_value\"\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_empty_string_value():\n    \"\"\"Test setting environment variable to empty string.\"\"\"\n    with temp_env_var(\"TEST_VAR\", \"\"):\n        assert os.environ[\"TEST_VAR\"] == \"\"\n        assert \"TEST_VAR\" in os.environ\n\n    # Should be removed after context\n    assert \"TEST_VAR\" not in os.environ\n\n\ndef test_overwrite_with_empty_string():\n    \"\"\"Test overwriting existing variable with empty string.\"\"\"\n    os.environ[\"EXISTING_VAR\"] = \"original\"\n\n    with temp_env_var(\"EXISTING_VAR\", \"\"):\n        assert os.environ[\"EXISTING_VAR\"] == \"\"\n\n    # Should restore original value\n    assert os.environ[\"EXISTING_VAR\"] == \"original\"\n\n\ndef test_context_manager_returns_none():\n    \"\"\"Test that context manager yields None.\"\"\"\n    with temp_env_var(\"TEST_VAR\", \"value\") as result:\n        assert result is None\n        assert os.environ[\"TEST_VAR\"] == \"value\"\n"
  },
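  {
    "path": "verl_rl/tests/utils/temp_env_var_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It implements the contract\n# that test_temp_env_on_cpu.py verifies for temp_env_var: set the variable on\n# entry, and on exit either restore the previous value or delete the variable\n# if it did not exist before, even when an exception escapes the block.\nimport contextlib\nimport os\n\n\n@contextlib.contextmanager\ndef temp_env_var_sketch(key: str, value: str):\n    missing = object()\n    previous = os.environ.get(key, missing)\n    os.environ[key] = value\n    try:\n        yield  # yields None, matching test_context_manager_returns_none\n    finally:\n        if previous is missing:\n            os.environ.pop(key, None)\n        else:\n            os.environ[key] = previous\n\n\nif __name__ == \"__main__\":\n    with temp_env_var_sketch(\"SKETCH_VAR\", \"on\"):\n        assert os.environ[\"SKETCH_VAR\"] == \"on\"\n    assert \"SKETCH_VAR\" not in os.environ\n"
  },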
  {
    "path": "verl_rl/tests/utils/test_timeout_decorator_cpu.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport multiprocessing\nimport sys\nimport threading\nimport time\n\nimport pytest  # Import pytest\n\nfrom verl.utils.py_functional import timeout_limit as timeout\n\n# --- Test Task Functions ---\nTEST_TIMEOUT_SECONDS = 1.5  # Timeout duration for tests\nLONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5  # Duration slightly longer than timeout\n\n\n@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests\ndef quick_task(x):\n    \"\"\"A task that completes quickly.\"\"\"\n    time.sleep(0.1)\n    return \"quick_ok\"\n\n\n@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests\ndef slow_task(x):\n    \"\"\"A task that takes longer than the timeout.\"\"\"\n    time.sleep(LONG_TASK_DURATION)\n    return \"slow_finished\"  # This return value indicates it didn't time out\n\n\n# REMOVE global decorator here\ndef task_raises_value_error():  # Now truly not globally decorated\n    \"\"\"A task that intentionally raises a ValueError.\"\"\"\n    raise ValueError(\"Specific value error from task\")\n\n\n# --- Top-level function for signal test in subprocess ---\n# Keep this decorated globally for the specific subprocess test case\n@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\ndef top_level_decorated_quick_task_signal():\n    \"\"\"A pickleable top-level function decorated with signal timeout.\"\"\"\n    # Assuming this calls the logic of quick_task directly for the test purpose\n    time.sleep(0.1)\n    return \"quick_ok_signal_subprocess\"  # Different return for clarity if needed\n\n\n# --- Top-level function for signal test in subprocess ---\n# Keep this decorated globally for the specific subprocess test case\n@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\ndef top_level_decorated_slow_task_signal():\n    \"\"\"A pickleable top-level function decorated with signal timeout.\"\"\"\n    time.sleep(LONG_TASK_DURATION)\n    return \"slow_finished\"\n\n\n# --- NEW: Top-level helper function to run target in process ---\ndef run_target_and_put_in_queue(target_func, q):\n    \"\"\"\n    Top-level helper function to run a target function and put its result or exception into a queue.\n    This function is pickleable and can be used as the target for multiprocessing.Process.\n    \"\"\"\n    try:\n        result = target_func()\n        q.put((\"success\", result))\n    except Exception as e:\n        q.put((\"error\", e))\n\n\n# Use a module-level fixture to set the start method on macOS\n@pytest.fixture(scope=\"module\", autouse=True)  # Changed scope to module\ndef set_macos_start_method():\n    if sys.platform == \"darwin\":\n        # Force fork method on macOS to avoid pickling issues with globally decorated functions\n        # when running tests via pytest discovery.\n        current_method = multiprocessing.get_start_method(allow_none=True)\n        # Only set if not already set or if set to something else (less 
likely in test run)\n        if current_method is None or current_method != \"fork\":\n            try:\n                multiprocessing.set_start_method(\"fork\", force=True)\n            except RuntimeError:\n                # Might fail if context is already started, ignore in that case.\n                pass\n\n\ndef test_quick_task():  # Renamed from test_multiprocessing_quick_task\n    \"\"\"Tests timeout handles a quick task correctly.\"\"\"\n    # Call the globally decorated function directly\n    result = quick_task(1)\n    assert result == \"quick_ok\"  # Use pytest assert\n\n\ndef test_slow_task_timeout():  # Renamed from test_multiprocessing_slow_task_timeout\n    \"\"\"Tests timeout correctly raises TimeoutError for a slow task.\"\"\"\n    # Call the globally decorated function directly within pytest.raises\n    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises\n        slow_task(1)\n    # Check the error message from the multiprocessing implementation\n    assert f\"timed out after {TEST_TIMEOUT_SECONDS} seconds\" in str(excinfo.value)  # Use pytest assert\n\n\ndef test_internal_exception():  # Renamed from test_multiprocessing_internal_exception\n    \"\"\"Tests timeout correctly propagates internal exceptions.\"\"\"\n    # Apply the default timeout decorator dynamically to the undecorated function\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error)  # Apply decorator dynamically\n    with pytest.raises(ValueError) as excinfo:  # Use pytest.raises\n        decorated_task()  # Call the dynamically decorated function\n    assert str(excinfo.value) == \"Specific value error from task\"  # Use pytest assert\n\n\n# --- Test the signal implementation (use_signals=True) ---\n# Note: As per py_functional.py, use_signals=True currently falls back to\n# multiprocessing on POSIX. These tests verify that behavior.\n\n\ndef test_signal_quick_task_main_process():  # Removed self\n    \"\"\"Tests signal timeout handles a quick task correctly in the main process.\"\"\"\n\n    # Apply the signal decorator dynamically\n    def plain_quick_task_logic():\n        time.sleep(0.1)\n        return \"quick_ok_signal\"\n\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic)\n    assert decorated_task() == \"quick_ok_signal\"  # Use pytest assert\n\n\ndef test_signal_slow_task_main_process_timeout():  # Removed self\n    \"\"\"Tests signal timeout correctly raises TimeoutError for a slow task in the main process.\"\"\"\n\n    # Apply the signal decorator dynamically\n    def plain_slow_task_logic():\n        time.sleep(LONG_TASK_DURATION)\n        return \"slow_finished_signal\"\n\n    decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic)\n    with pytest.raises(TimeoutError) as excinfo:  # Use pytest.raises\n        decorated_task()\n    # Check the error message (falls back to multiprocessing message on POSIX)\n    assert f\"timed out after {TEST_TIMEOUT_SECONDS} seconds\" in str(excinfo.value)  # Use pytest assert\n\n\n@pytest.mark.skip(reason=\"this test won't pass. 
Just to show why use_signals should not be used\")\ndef test_signal_in_thread_does_not_timeout():\n    \"\"\"\n    Tests that signal-based timeout does NOT work reliably in a child thread.\n    The TimeoutError from the signal handler is not expected to be raised.\n    \"\"\"\n    result_container = []  # Use a list to store result from thread\n    exception_container = []  # Use a list to store exception from thread\n\n    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)\n    def slow_task_in_thread():\n        try:\n            print(\"Thread: Starting slow task...\")\n            time.sleep(LONG_TASK_DURATION)\n            print(\"Thread: Slow task finished.\")\n            return \"slow_finished_in_thread\"\n        except Exception as e:\n            # Catch any exception within the thread's target function\n            print(f\"Thread: Caught exception: {e}\")\n            exception_container.append(e)\n            return None  # Indicate failure\n\n    def thread_target():\n        try:\n            # Run the decorated function inside the thread\n            res = slow_task_in_thread()\n            if res is not None:\n                result_container.append(res)\n        except Exception as e:\n            # This might catch exceptions happening *outside* the decorated function\n            # but still within the thread target, though less likely here.\n            print(f\"Thread Target: Caught exception: {e}\")\n            exception_container.append(e)\n\n    thread = threading.Thread(target=thread_target)\n    print(\"Main: Starting thread...\")\n    thread.start()\n    # Wait longer than the timeout + task duration to ensure the thread finishes\n    # regardless of whether timeout worked or not.\n    thread.join(timeout=LONG_TASK_DURATION + 1)\n\n    assert len(exception_container) == 1\n    assert isinstance(exception_container[0], TimeoutError)\n    assert not result_container\n\n\ndef test_in_thread_timeout():\n    result_container = []  # Use a list to store result from thread\n    exception_container = []  # Use a list to store exception from thread\n\n    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)\n    def slow_task_in_thread():\n        try:\n            print(\"Thread: Starting slow task...\")\n            time.sleep(LONG_TASK_DURATION)\n            print(\"Thread: Slow task finished.\")\n            return \"slow_finished_in_thread\"\n        except Exception as e:\n            # Catch any exception within the thread's target function\n            print(f\"Thread: Caught exception: {e}\")\n            exception_container.append(e)\n            return None  # Indicate failure\n\n    def thread_target():\n        try:\n            # Run the decorated function inside the thread\n            res = slow_task_in_thread()\n            if res is not None:\n                result_container.append(res)\n        except Exception as e:\n            # This might catch exceptions happening *outside* the decorated function\n            # but still within the thread target, though less likely here.\n            print(f\"Thread Target: Caught exception: {e}\")\n            exception_container.append(e)\n\n    thread = threading.Thread(target=thread_target)\n    print(\"Main: Starting thread...\")\n    thread.start()\n    # Wait longer than the timeout + task duration to ensure the thread finishes\n    # regardless of whether timeout worked or not.\n    thread.join(timeout=LONG_TASK_DURATION + 1)\n\n    assert len(exception_container) == 1\n    assert 
isinstance(exception_container[0], TimeoutError)\n    assert not result_container\n"
  },
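  {
    "path": "verl_rl/tests/utils/timeout_limit_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It shows one way to build the\n# process-based timeout that test_timeout_decorator_cpu.py exercises: run the\n# function in a child process and terminate it if it exceeds the limit. This\n# requires the target to be picklable (hence the fork note in the tests);\n# verl's timeout_limit may be implemented differently.\nimport functools\nimport multiprocessing\n\n\ndef _call_into_queue(queue, func, args, kwargs):\n    try:\n        queue.put((\"ok\", func(*args, **kwargs)))\n    except Exception as e:  # ship the exception object back to the parent\n        queue.put((\"err\", e))\n\n\ndef timeout_sketch(seconds: float):\n    def decorator(func):\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            queue = multiprocessing.Queue()\n            proc = multiprocessing.Process(target=_call_into_queue, args=(queue, func, args, kwargs))\n            proc.start()\n            proc.join(seconds)\n            if proc.is_alive():\n                proc.terminate()\n                proc.join()\n                raise TimeoutError(f\"{func.__name__} timed out after {seconds} seconds\")\n            status, payload = queue.get()\n            if status == \"err\":\n                raise payload\n            return payload\n\n        return wrapper\n\n    return decorator\n"
  },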
  {
    "path": "verl_rl/tests/utils/test_torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std, masked_mean\n\n\ndef _worker_mean(rank: int, world_size: int, rendezvous_file: str):\n    # 1) set GPU and init NCCL\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    # each rank holds tensor [rank+1]\n    local = torch.tensor([float(rank + 1)], device=f\"cuda:{rank}\")\n    mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True)\n\n    values = [float(i + 1) for i in range(world_size)]\n    exp_mean = sum(values) / len(values)\n    exp_max = max(values)\n    exp_min = min(values)\n    var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1)\n    exp_std = var**0.5\n\n    # all ranks should see the same result\n    assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f\"mean@{rank}\"\n    assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f\"max@{rank}\"\n    assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f\"min@{rank}\"\n    assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f\"std@{rank}\"\n\n    dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\n    \"value,mask,gt\",\n    [\n        ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5),\n        ([1.0, 2.0, float(\"nan\"), 4.0], [1, 0, 0, 1], 2.5),\n        ([1.0, 2.0, float(\"nan\"), 4.0], [1, 0, 1, 0], float(\"nan\")),\n    ],\n)\ndef test_masked_mean(value, mask, gt):\n    res = masked_mean(torch.tensor(value), torch.tensor(mask))\n    gt = torch.tensor(gt)\n    assert torch.allclose(res, gt) or (torch.isnan(res) and torch.isnan(gt))\n\n\n@pytest.mark.parametrize(\"world_size\", [2, 4])\ndef test_distributed_mean_max_min_std(world_size, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_mean\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_worker_mean,\n        args=(world_size, rendezvous_file),\n        nprocs=world_size,\n        join=True,\n    )\n\n\ndef _worker_mask(rank: int, world_size: int, rendezvous_file: str):\n    torch.cuda.set_device(rank)\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=f\"file://{rendezvous_file}\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    # build per‐rank tensor and mask\n    local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f\"cuda:{rank}\")\n    if rank == 0:\n        mask = torch.tensor([1, 0], device=f\"cuda:{rank}\", dtype=torch.float32)\n    else:\n        mask = torch.tensor([0, 1], device=f\"cuda:{rank}\", dtype=torch.float32)\n\n    gmean = distributed_masked_mean(local_tensor, mask)\n\n    valid_values = [1.0] + [2 * i 
+ 2.0 for i in range(1, world_size)]\n    expected_mean = sum(valid_values) / len(valid_values)\n    assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f\"masked_mean@{rank}\"\n\n    dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\"world_size\", [2, 4])\ndef test_distributed_masked_mean(world_size, tmp_path):\n    rendezvous_file = str(tmp_path / \"rdzv_mask\")\n    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)\n\n    mp.spawn(\n        fn=_worker_mask,\n        args=(world_size, rendezvous_file),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
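  {
    "path": "verl_rl/tests/utils/masked_mean_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It captures the semantics\n# that test_torch_functional.py pins down for masked_mean: masked-out\n# positions are ignored entirely, so a NaN under a zero mask does not poison\n# the result (a naive `(values * mask).sum()` would propagate it). verl's\n# implementation may differ in details.\nimport torch\n\n\ndef masked_mean_sketch(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    mask = mask.bool()\n    # zero out unselected entries via where() so NaN * 0 never occurs\n    selected = torch.where(mask, values, torch.zeros_like(values))\n    return selected.sum() / mask.sum()\n\n\nif __name__ == \"__main__\":\n    v = torch.tensor([1.0, 2.0, float(\"nan\"), 4.0])\n    m = torch.tensor([1, 0, 0, 1])\n    assert torch.allclose(masked_mean_sketch(v, m), torch.tensor(2.5))\n"
  },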
  {
    "path": "verl_rl/tests/workers/reward_manager/test_registry_on_cpu.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\n# Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module\nfrom verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register\n\n\n@pytest.fixture\ndef setup():\n    \"\"\"Setup test cases with a mock registry.\"\"\"\n    REWARD_MANAGER_REGISTRY.clear()\n    REWARD_MANAGER_REGISTRY.update({\"manager1\": \"Manager1Class\", \"manager2\": \"Manager2Class\"})\n    return REWARD_MANAGER_REGISTRY\n\n\ndef test_get_existing_manager(setup):\n    \"\"\"Test getting an existing reward manager class.\"\"\"\n    assert get_reward_manager_cls(\"manager1\") == \"Manager1Class\"\n    assert get_reward_manager_cls(\"manager2\") == \"Manager2Class\"\n\n\ndef test_get_nonexistent_manager(setup):\n    \"\"\"Test getting a non-existent reward manager raises ValueError.\"\"\"\n    with pytest.raises(ValueError) as excinfo:\n        get_reward_manager_cls(\"unknown_manager\")\n    assert \"Unknown reward manager: unknown_manager\" in str(excinfo.value)\n\n\ndef test_case_sensitivity(setup):\n    \"\"\"Test that manager names are case-sensitive.\"\"\"\n    with pytest.raises(ValueError):\n        get_reward_manager_cls(\"MANAGER1\")\n    with pytest.raises(ValueError):\n        get_reward_manager_cls(\"Manager1\")\n\n\ndef test_empty_registry(setup):\n    \"\"\"Test behavior when registry is empty.\"\"\"\n    REWARD_MANAGER_REGISTRY.clear()\n    with pytest.raises(ValueError) as excinfo:\n        get_reward_manager_cls(\"any_manager\")\n    assert \"Unknown reward manager: any_manager\" in str(excinfo.value)\n\n\ndef test_register_new_class(setup):\n    \"\"\"Test registering a new class with the decorator.\"\"\"\n\n    @register(\"test_manager\")\n    class TestManager:\n        pass\n\n    assert \"test_manager\" in REWARD_MANAGER_REGISTRY\n    assert REWARD_MANAGER_REGISTRY[\"test_manager\"] == TestManager\n\n\ndef test_register_different_classes_same_name(setup):\n    \"\"\"Test that registering different classes with same name raises ValueError.\"\"\"\n\n    @register(\"conflict_manager\")\n    class Manager1:\n        pass\n\n    with pytest.raises(ValueError):\n\n        @register(\"conflict_manager\")\n        class Manager2:\n            pass\n\n    assert REWARD_MANAGER_REGISTRY[\"conflict_manager\"] == Manager1\n\n\ndef test_decorator_returns_original_class(setup):\n    \"\"\"Test that the decorator returns the original class unchanged.\"\"\"\n\n    @register(\"return_test\")\n    class OriginalClass:\n        def method(setup):\n            return 42\n\n    assert OriginalClass().method() == 42\n    assert REWARD_MANAGER_REGISTRY[\"return_test\"] == OriginalClass\n"
  },
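  {
    "path": "verl_rl/tests/workers/reward_manager/registry_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It shows the registry pattern\n# that test_registry_on_cpu.py asserts on: a decorator that maps a name to a\n# class, rejects re-registration of a *different* class under the same name,\n# and returns the class unchanged. All names below are invented for this\n# example.\nSKETCH_REGISTRY: dict[str, type] = {}\n\n\ndef register_sketch(name: str):\n    def decorator(cls):\n        if name in SKETCH_REGISTRY and SKETCH_REGISTRY[name] is not cls:\n            raise ValueError(f\"reward manager {name} is already registered to another class\")\n        SKETCH_REGISTRY[name] = cls\n        return cls  # the decorator must not wrap or alter the class\n\n    return decorator\n\n\ndef get_sketch_cls(name: str):\n    if name not in SKETCH_REGISTRY:\n        raise ValueError(f\"Unknown reward manager: {name}\")\n    return SKETCH_REGISTRY[name]\n\n\n@register_sketch(\"naive\")\nclass NaiveManagerSketch:\n    pass\n\n\nassert get_sketch_cls(\"naive\") is NaiveManagerSketch\n"
  },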
  {
    "path": "verl_rl/tests/workers/rollout/async_rollout_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ray\nfrom omegaconf import DictConfig\n\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\nfrom verl.workers.fsdp_workers import AsyncActorRolloutRefWorker\nfrom verl.workers.rollout.async_server import AsyncLLMServerManager\n\n\ndef init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager:\n    # =========================== 1. Create hybrid ActorRollout workers ===========================\n    role_worker_mapping = {\n        Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),\n    }\n    global_pool_id = \"global_pool\"\n    resource_pool_spec = {\n        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n    }\n    mapping = {\n        Role.ActorRollout: global_pool_id,\n    }\n    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n    resource_pool_manager.create_resource_pool()\n    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}\n\n    # create actor and rollout\n    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)\n    actor_rollout_cls = RayClassWithInitArgs(\n        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role=\"actor_rollout\"\n    )\n    resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n\n    all_wg = {}\n    for resource_pool, class_dict in resource_pool_to_cls.items():\n        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)\n        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n        all_wg.update(spawn_wg)\n    actor_rollout_wg = all_wg[\"actor_rollout\"]\n    actor_rollout_wg.init_model()\n\n    # =========================== 2. Create AsyncLLMServerManager  ===========================\n    async_rollout_manager = AsyncLLMServerManager(\n        config=config,\n        worker_group=actor_rollout_wg,\n    )\n\n    return async_rollout_manager\n"
  },
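  {
    "path": "verl_rl/tests/workers/rollout/async_rollout_usage_sketch.py",
    "content": "# NOTE: Illustrative sketch, not part of verl. It shows how\n# init_async_rollout_manager from async_rollout_utils.py could be driven from\n# a hydra-composed trainer config, mirroring the pattern used elsewhere in\n# these tests. The GPU counts and rollout mode set below are assumptions for\n# the example.\nimport os\n\nimport ray\nfrom hydra import compose, initialize_config_dir\n\nfrom tests.workers.rollout.async_rollout_utils import init_async_rollout_manager\n\nif __name__ == \"__main__\":\n    ray.init()\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    config.trainer.nnodes = 1\n    config.trainer.n_gpus_per_node = 2\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    async_rollout_manager = init_async_rollout_manager(config)\n    print(f\"server manager ready: {type(async_rollout_manager).__name__}\")\n    ray.shutdown()\n"
  },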
  {
    "path": "verl_rl/tests/workers/rollout/perf/vllm_async_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCompare vLLM AsyncLLM backend: ExternalRayDistributedExecutor(remote call) vs RayDistributedExecutor(compiled graph)\n\n1. Prepare openai/gsm8k dataset\npython3 examples/data_preprocess/gsm8k.py\n\n2. Run perf test\npython3 tests/workers/rollout/perf/vllm_async_rollout.py >perf.log 2>&1\n\nhardware: Nvidia 8*H20\npackages:\n- torch==2.6.0\n- vllm==0.8.5\n\n[DEBUG] backend: sync, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 21.27 secs\n[DEBUG] backend: zeromq, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 23.40 secs\n[DEBUG] backend: ray, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 25.33 secs\n\"\"\"\n\nimport os\nimport time\n\nimport ray\nfrom omegaconf import DictConfig\nfrom torch.utils.data import SequentialSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\n\nfrom tests.experimental.agent_loop.agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager\nfrom verl.protocol import DataProto\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n\ndef init_config(n_gpus_per_node) -> DictConfig:\n    import os\n\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    config.trainer.n_gpus_per_node = n_gpus_per_node\n    config.data.train_batch_size = 128\n    config.data.return_raw_chat = True\n    config.actor_rollout_ref.model.path = \"Qwen/Qwen2.5-7B-Instruct\"\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2\n    config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9\n    config.actor_rollout_ref.rollout.multi_turn.format = \"hermes\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 16\n\n    # test sleep/wake_up with fsdp offload\n    config.actor_rollout_ref.actor.fsdp_config.param_offload = True\n    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True\n\n    return config\n\n\ndef initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]:\n    env_vars = {\n        \"NCCL_DEBUG\": \"WARN\",\n        \"VLLM_USE_V1\": \"1\",\n        \"VERL_VLLM_DISTRIBUTED_BACKEND\": backend,\n    }\n    ray.init(runtime_env={\"env_vars\": env_vars})\n\n    # STEP 1: init async llm server\n    server = init_agent_loop_manager(config)\n\n    # STEP 2: create dataloader\n    tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path)\n    dataset = RLHFDataset(\n        data_files=os.path.expanduser(\"~/data/gsm8k/train.parquet\"),\n        tokenizer=tokenizer,\n        config=config.data,\n    )\n    
dataloader = StatefulDataLoader(\n        dataset=dataset,\n        batch_size=config.data.get(\"gen_batch_size\", config.data.train_batch_size),\n        num_workers=config.data.get(\"dataloader_num_workers\", 8),\n        drop_last=True,\n        collate_fn=default_collate_fn,\n        sampler=SequentialSampler(dataset),\n    )\n\n    return server, dataloader\n\n\ndef perf_rollout(mode, backend, n_gpus_per_node, num_steps):\n    config = init_config(n_gpus_per_node)\n    config.actor_rollout_ref.rollout.mode = mode\n    agent_loop_manager, dataloader = initialize(config, backend)\n\n    for step, batch in enumerate(dataloader):\n        batch: DataProto = DataProto.from_single_dict(batch)\n        batch = batch.pop(\n            batch_keys=[\"input_ids\", \"attention_mask\", \"position_ids\"],\n            non_tensor_batch_keys=[\"raw_prompt_ids\", \"raw_prompt\"],\n        )\n        t_start = time.time()\n        gen_batch = agent_loop_manager.generate_sequences(batch)\n        t_end = time.time()\n        print(\n            f\"[DEBUG] backend: {backend}, n_gpus_per_node: {n_gpus_per_node}, batch_size: {len(gen_batch)}, \"\n            f\"step: {step}, step_time: {t_end - t_start:.2f} secs\"\n        )\n        if step + 1 >= num_steps:\n            break\n\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    num_steps = 1\n    n_gpus_per_node = 8\n\n    # test_cases = [(\"sync\", \"sync\"), (\"async\", \"zeromq\"), (\"async\", \"ray\")]\n    test_cases = [(\"async\", \"zeromq\"), (\"async\", \"ray\")]\n    for mode, backend in test_cases:\n        perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps)\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/resource/tool_configs/mcp_server.json",
    "content": "{\n    \"mcpServers\": {\n        \"Tavily Expert\": {\n            \"url\": \"https://tavily.api.tadata.com/mcp/tavily/your_expert\",\n            \"auth_token\": \"your_tavily_token\"\n        }\n    }\n}"
  },
  {
    "path": "verl_rl/tests/workers/rollout/resource/tool_configs/mcp_tool_config",
    "content": "tools:\n  - class_name: verl.tools.mcp_search_tool.MCPSearchTool\n    config:\n      rate_limit: 120\n      timeout: 120\n      type: mcp\n    mcp:\n      mcp_servers_config_path: ./resource/tool_configs/mcp_server.json\n      # optional\n      tool_selected_list: \n        - tavily_search_tool"
  },
  {
    "path": "verl_rl/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config",
    "content": "tools:\n  - class_name: \"verl.tools.sandbox_fusion_tools.SandboxFusionTool\"\n    config: \n      sandbox_fusion_url: \"https://xxx.apigateway-cn-beijing.volceapi.com/run_code\"\n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"code_interpreter\"\n        description: \"A tool for executing code.\"\n        parameters:\n          type: \"object\"\n          properties:\n            code:\n              type: \"string\"\n              description: \"The code to execute.\"\n          required: [\"code\"]"
  },
  {
    "path": "verl_rl/tests/workers/rollout/resource/tool_configs/search_tool_config",
    "content": "tools:\n  - class_name: verl.tools.search_tool.SearchTool\n    config:\n      retrieval_service_url: http://127.0.0.1:8000/retrieve\n      num_workers: 120\n      rate_limit: 120\n      timeout: 30\n      type: native\n    tool_schema:\n      type: function\n      function:\n        name: search\n        description: Searches the web for relevant information based on the given query.\n        parameters:\n          type: object\n          properties:\n            query_list:\n              type: array\n              item:\n                type: string\n              description: A list of fully-formed semantic queries. The tool will return search results for each query.\n          required: \n            - query_list"
  },
  {
    "path": "verl_rl/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\nfrom vllm import SamplingParams\n\nfrom verl.third_party.vllm import LLM\nfrom verl.utils.distributed import initialize_global_process_group\n\n\ndef main():\n    assert torch.cuda.is_available(), \"CUDA must be present to run FSDP vLLM example\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2-7B-Instruct\"\n\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True)\n    with torch.device(\"cuda\"):\n        actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n        actor_model.to(torch.bfloat16)\n\n    max_prompt_length = 16\n    response_length = 32\n    preencode_prompts = [\n        \"The president of the United States is\",\n        \"The capital of France is\",\n        \"The future of AI is\",\n    ]\n    tokenizer.pad_token = tokenizer.eos_token\n    prompts = tokenizer(preencode_prompts, return_tensors=\"pt\", padding=True)\n    input_ids = prompts[\"input_ids\"]\n    attention_mask = prompts[\"attention_mask\"]\n    from verl.utils.torch_functional import pad_sequence_to_length\n\n    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda()\n    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda()\n\n    from transformers import GenerationConfig\n\n    generation_config = GenerationConfig(do_sample=False)\n    actor_model.cuda()\n    output = actor_model.generate(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        max_new_tokens=32,\n        # max_length=max_length,\n        eos_token_id=tokenizer.eos_token_id,\n        pad_token_id=tokenizer.pad_token_id,\n        generation_config=generation_config,\n        # renormalize_logits=True,\n        output_scores=False,  # this is potentially very large\n        return_dict_in_generate=True,\n        use_cache=False,\n    )  # may OOM when use_cache = True\n    seq = output.sequences\n    response = seq[:, max_prompt_length:]\n\n    print(f\"hf response: {tokenizer.batch_decode(response)}\")\n\n    tensor_model_parallel_size = 4\n    from 
torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n\n    state_dict = fsdp_model.state_dict()\n\n    sampling_params = SamplingParams(\n        temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False\n    )\n\n    print(actor_model_config)\n    llm = LLM(\n        model=None,\n        tokenizer=tokenizer,\n        model_hf_config=actor_model_config,\n        tensor_parallel_size=tensor_model_parallel_size,\n        enforce_eager=True,\n        dtype=\"bfloat16\",\n        load_format=\"dummy_dtensor\",\n        gpu_memory_utilization=0.8,\n        trust_remote_code=True,\n    )\n\n    # Warmup iterations\n    for _ in range(10):\n        torch.cuda.synchronize()\n        llm.sync_model_weights(actor_weights=state_dict, load_format=\"dtensor\")\n        torch.cuda.synchronize()\n        dist.barrier()\n\n    start_time = time.time()\n    llm.sync_model_weights(actor_weights=state_dict, load_format=\"dtensor\")\n    torch.cuda.synchronize()\n    dist.barrier()\n    end_time = time.time()\n\n    # Calculate elapsed time\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.6f} seconds\")\n\n    input_ids = input_ids.cuda()\n    attention_mask = attention_mask.cuda()\n    idx_list = []\n    batch_size = input_ids.shape[0]\n\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import _pre_process_inputs\n\n    for i in range(batch_size):\n        idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))\n    print(\"start generation\")\n    outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)\n    vllm_output = outputs[0].cuda()\n    if torch.distributed.get_rank() == 0:\n        print(f\"hf response: {tokenizer.batch_decode(response)}\")\n        print(f\"vllm response: {tokenizer.batch_decode(vllm_output)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any\n\nimport numpy as np\nimport pytest\nimport ray\nfrom omegaconf import DictConfig\nfrom transformers.utils import get_json_schema\n\nfrom tests.workers.rollout.async_rollout_utils import init_async_rollout_manager\nfrom verl.protocol import DataProto\nfrom verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema\nfrom verl.utils import hf_tokenizer\n\n\n@pytest.fixture\ndef init_config() -> DictConfig:\n    import os\n\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    model_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.multi_turn.format = \"hermes\"\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n\n    # test sleep/wake_up with fsdp offload\n    config.actor_rollout_ref.actor.fsdp_config.param_offload = True\n    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True\n\n    return config\n\n\ndef test_vllm_async_rollout_without_tool_calls(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. Init rollout manager ===========================\n    async_rollout_manager = init_async_rollout_manager(init_config)\n\n    # test sleep and wake_up\n    async_rollout_manager.sleep()\n    async_rollout_manager.wake_up()\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\n                \"role\": \"user\",\n                \"content\": \"Let's play a role playing game. Your name is Alice, your favorite color is blue.\",\n            }\n        ],\n        [{\"role\": \"user\", \"content\": \"Let's play a role playing game. 
Your name is Bob, your favorite color is red.\"}],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array(raw_prompts),\n        },\n    )\n    result = async_rollout_manager.generate_sequences(prompts=batch)\n\n    # check result\n    seq_len = result.batch[\"prompts\"].size(1) + result.batch[\"responses\"].size(1)\n    assert len(result) == 2\n    assert result.batch[\"input_ids\"].size(1) == seq_len\n    assert result.batch[\"attention_mask\"].size(1) == seq_len\n    assert result.batch[\"position_ids\"].size(1) == seq_len\n\n    # check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    assert np.all(num_turns == 2)\n\n    print(\"Test passed!\")\n    ray.shutdown()\n\n\nclass WeatherTool(BaseTool):\n    def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n        \"\"\"Get current temperature at a location.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, and the unit in a dict\n        \"\"\"\n        return {\n            \"temperature\": 26.1,\n            \"location\": location,\n            \"unit\": unit,\n        }\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_current_temperature)\n        return OpenAIFunctionToolSchema(**schema)\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_current_temperature(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\nclass WeatherToolWithData(BaseTool):\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        schema = get_json_schema(self.get_temperature_date)\n        return OpenAIFunctionToolSchema(**schema)\n\n    def get_temperature_date(self, location: str, date: str, unit: str = \"celsius\"):\n        \"\"\"Get temperature at a location and date.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            date: The date to get the temperature for, in the format \"Year-Month-Day\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, the date and the unit in a dict\n        \"\"\"\n        return {\n            \"temperature\": 25.9,\n            \"location\": location,\n            \"date\": date,\n            \"unit\": unit,\n        }\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        try:\n            result = self.get_temperature_date(**parameters)\n            return json.dumps(result), 0, {}\n        except Exception as e:\n            return str(e), 0, {}\n\n\ndef test_vllm_async_rollout_with_tool_calls(init_config):\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # =========================== 1. 
Init rollout manager ===========================\n    tool_config = {\n        \"tools\": [\n            {\n                \"class_name\": \"tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherTool\",\n                \"config\": {\"type\": \"native\"},\n            },\n            {\n                \"class_name\": \"tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherToolWithData\",\n                \"config\": {\"type\": \"native\"},\n            },\n        ]\n    }\n    tool_config_path = \"/tmp/tool_config.json\"\n    with open(tool_config_path, \"w\") as f:\n        json.dump(tool_config, f)\n\n    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path\n    async_rollout_manager = init_async_rollout_manager(init_config)\n\n    # =========================== 2. Generate sequences  ===========================\n    raw_prompts = [\n        [\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ],\n        [\n            {\"role\": \"user\", \"content\": \"What's the temperature in Los Angeles now?\"},\n        ],\n        [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\\n\\n\"\n                \"Current Date: 2024-09-30\",\n            },\n            {\"role\": \"user\", \"content\": \"What's the temperature in San Francisco now? How about tomorrow?\"},\n        ],\n    ]\n    batch = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),\n        },\n    )\n    result = async_rollout_manager.generate_sequences(prompts=batch)\n\n    # Check turns\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    # [user, assistant]\n    assert num_turns[0] == 2\n    # [user, assistant, tool, assistant]\n    assert num_turns[1] == 4\n    # [system, user, assistant, tool, tool, assistant]\n    assert num_turns[2] == 6\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n\n    # Decode responses with response_mask\n    for i in range(len(responses)):\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_str = tokenizer.decode(valid_tokens)\n        assert \"<tool_response>\" not in response_str, f\"found <tool_response> in response: {response_str}\"\n        assert \"</tool_response>\" not in response_str, f\"found </tool_response> in response: {response_str}\"\n        print(f\"response: {response_str}\")\n\n    print(\"Test passed!\")\n    ray.shutdown()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\n\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nfrom omegaconf import OmegaConf\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout\n\n\ndef test_vllm_rollout_with_yarn_position_embeddings():\n    \"\"\"\n    Test the vLLM rollout with yarn position embeddings.\n    \"\"\"\n\n    local_rank, rank, world_size = initialize_global_process_group()\n    config = OmegaConf.create(\n        {\n            \"model_path\": \"OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN\",\n            \"prompt_length\": 35000,\n            \"response_length\": 512,\n            \"dtype\": \"bfloat16\",\n            \"enforce_eager\": True,\n            \"gpu_memory_utilization\": 0.4,\n            \"enable_chunked_prefill\": False,\n            \"free_cache_engine\": False,\n            \"disable_log_stats\": True,\n            \"max_model_len\": 35000 + 512,\n            \"load_format\": \"auto\",\n            \"val_kwargs\": {\n                \"top_k\": -1,\n                \"top_p\": 1.0,\n                \"temperature\": 0,\n                \"n\": 1,\n                \"do_sample\": False,\n            },\n            \"tensor_model_parallel_size\": 4,\n            \"trust_remote_code\": True,\n            \"calculate_log_probs\": False,\n            \"do_sample\": False,\n            \"temperature\": 0.0,\n            \"max_num_batched_tokens\": 35000 + 512,\n        }\n    )\n\n    tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side=\"left\")\n    tokenizer.pad_token = tokenizer.eos_token\n    model_hf_config = AutoConfig.from_pretrained(config.model_path)\n\n    # do_sample=False for temperate=0 deterministic\n    input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False)\n\n    vllm_rollout = vLLMRollout(\n        model_path=config.model_path,\n        config=config,\n        tokenizer=tokenizer,\n        model_hf_config=model_hf_config,\n    )\n    # rollout\n    rollout_response = vllm_rollout.generate_sequences(\n        prompts=input_dataproto,\n    )\n    if rank == 0:\n        print(\"VLLM Rollout Outputs:\")\n        print(tokenizer.batch_decode(rollout_response.batch[\"responses\"][:], skip_special_tokens=False))\n        for response in rollout_response.batch[\"responses\"]:\n            assert \"<|im_end|>\" in tokenizer.decode(response, skip_special_tokens=False), (\n                \"Response should contain <|im_end|> token\"\n            )\n    print(\"Checks passed.\")\n\n    del vllm_rollout\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.ipc_collect()\n    dist.barrier()\n    
torch.distributed.destroy_process_group()\n\n\ndef prepare_input_dataproto(tokenizer, config, validate, do_sample=False):\n    base_phrase = \"Roses are red, sky is blue. \" * 4096\n    preencode_prompts = [\n        # 32810 tokens > 32768 tokens\n        [{\"role\": \"user\", \"content\": base_phrase + \"Who won the Champions League in 2019?\"}],\n        [{\"role\": \"user\", \"content\": base_phrase + \"The founder of Apple is\"}],\n        [{\"role\": \"user\", \"content\": base_phrase + \"What's your name\"}],\n    ]\n    formatted_prompts = [\n        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)\n        for conversation in preencode_prompts\n    ]\n    prompts = tokenizer(formatted_prompts, return_tensors=\"pt\", padding=\"max_length\", max_length=config.prompt_length)\n    input_dataproto = DataProto.from_dict(\n        {\n            \"input_ids\": prompts[\"input_ids\"],\n            \"attention_mask\": prompts[\"attention_mask\"],\n            \"position_ids\": compute_position_id_with_mask(prompts[\"attention_mask\"]),\n        },\n        meta_info={\n            \"bos_token_id\": tokenizer.bos_token_id,\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n            \"validate\": validate,\n            \"do_sample\": do_sample,\n            \"response_length\": config.response_length,\n            \"temperature\": config.temperature,\n        },\n    )\n    return input_dataproto\n\n\nif __name__ == \"__main__\":\n    test_vllm_rollout_with_yarn_position_embeddings()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pytest\nimport torch\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom vllm import LLM, SamplingParams\n\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.torch_functional import pad_sequence_to_length\n\n\ndef levenshtein(s1, s2):\n    m, n = len(s1), len(s2)\n    # Initialize matrix of zeros\n    dp = [[0] * (n + 1) for _ in range(m + 1)]\n    # Initialize first column and first row of the matrix\n    for i in range(m + 1):\n        dp[i][0] = i  # Deletion from s1 to empty string\n    for j in range(n + 1):\n        dp[0][j] = j  # Insertion to s1 from empty string\n    # Compute the Levenshtein distance matrix\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            cost = 0 if s1[i - 1] == s2[j - 1] else 1  # No cost if characters match\n            dp[i][j] = min(\n                dp[i - 1][j] + 1,  # Deletion\n                dp[i][j - 1] + 1,  # Insertion\n                dp[i - 1][j - 1] + cost,  # Substitution\n            )\n    return dp[m][n]\n\n\ndef are_lists_similar(a, b):\n    if len(a) != len(b):\n        print(\"The lists are of different lengths.\")\n        return False\n\n    total_length = 0\n    total_diff = 0\n\n    for s1, s2 in zip(a, b, strict=True):\n        max_len = max(len(s1), len(s2))\n        total_length += max_len\n        diff = levenshtein(s1, s2)\n        total_diff += diff\n        print(f\"Comparing strings:\\n{s1}\\n{s2}\\nDifference: {diff} characters\\n\")\n\n    percentage_difference = (total_diff / total_length) * 100\n    print(f\"Total difference: {percentage_difference:.2f}%\")\n\n    return percentage_difference <= 15\n\n\n@pytest.mark.skip(\"https://github.com/vllm-project/vllm/issues/16993\")\ndef test_vllm_spmd():\n    assert torch.cuda.device_count() >= 2, \"At least 2 GPUs is required to run tp+dp tests.\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    # Initialize model and token\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2-7B-Instruct\"\n    from verl.utils.fs import copy_to_local\n\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\", trust_remote_code=True)\n\n    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model.to(torch.bfloat16)\n\n    # fill rollout config\n    max_prompt_length = 16\n    max_response_length = 32\n    preencode_prompts = [\n        \"Who won the Champions League in 2019?\",\n    
    \"The founder of Apple is\",\n        \"What's your name?\",\n    ]\n    tokenizer.pad_token = tokenizer.eos_token\n    prompts = tokenizer(preencode_prompts, return_tensors=\"pt\", padding=True)\n    input_ids = prompts[\"input_ids\"]\n    attention_mask = prompts[\"attention_mask\"]\n\n    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)\n    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)\n\n    print(\"start generation\")\n    input_ids = input_ids.cuda()\n    attention_mask = attention_mask.cuda()\n\n    temperature = 0\n    top_p = 1\n    kwargs = dict(\n        n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True\n    )\n\n    tensor_parallel_size = 4\n\n    from torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n\n    state_dict = fsdp_model.state_dict()\n\n    sampling_params = SamplingParams(**kwargs)\n    llm = LLM(\n        model=local_model_path,\n        enable_sleep_mode=True,\n        tensor_parallel_size=tensor_parallel_size,\n        distributed_executor_backend=\"external_launcher\",\n        dtype=\"bfloat16\",\n        enforce_eager=True,\n        gpu_memory_utilization=0.8,\n        disable_custom_all_reduce=True,\n        skip_tokenizer_init=False,\n        enable_prefix_caching=True,\n        trust_remote_code=True,\n        seed=1,\n    )\n\n    outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)\n    vllm_response_tokens = []\n    for output in outputs:\n        generated_text = output.outputs[0].text\n        vllm_response_tokens.append(generated_text)\n\n    world_size = torch.distributed.get_world_size()\n    model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model\n    model.load_weights(\n        ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items())\n    )\n\n    outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)\n    verl_vllm_response_tokens = []\n    for output in outputs:\n        generated_text = output.outputs[0].text\n        verl_vllm_response_tokens.append(generated_text)\n\n    if torch.distributed.get_rank() == 0:\n        print(f\"vllm response: {vllm_response_tokens}\")\n        print(f\"verl-vllm response: {verl_vllm_response_tokens}\")\n    assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), \"Strings differ more than 10%:\\n\"\n    print(\"Check Pass\")\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_vllm_spmd()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_async_sglang_server_on_cpu.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport pytest\nfrom omegaconf import DictConfig\n\n\n@patch.dict(\n    \"sys.modules\",\n    {\n        \"verl.workers.rollout.sglang_rollout.sglang_rollout\": MagicMock(SGLangRollout=MagicMock()),\n        \"verl.workers.rollout.chat_scheduler\": MagicMock(ChatCompletionScheduler=MagicMock()),\n        \"fastapi\": MagicMock(FastAPI=MagicMock()),\n        \"uvicorn\": MagicMock(FastAPI=MagicMock()),\n        \"starlette.requests\": MagicMock(Request=MagicMock()),\n        \"starlette.responses\": MagicMock(JSONResponse=MagicMock()),\n    },\n)\nclass TestAsyncSglangServer:\n    @pytest.fixture\n    def server_config(self):\n        return DictConfig({\"actor_rollout_ref\": {\"rollout\": {\"tensor_model_parallel_size\": 2}}})\n\n    @pytest.mark.asyncio\n    @patch(\"verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors\")\n    @patch(\"verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server\", new_callable=AsyncMock)\n    @pytest.mark.filterwarnings(\"ignore:Ray state API is no longer experimental:DeprecationWarning\")\n    async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config):\n        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSGLangServer\n\n        ActualClassToInstantiate = AsyncSGLangServer\n        if hasattr(AsyncSGLangServer, \"__ray_metadata__\") and hasattr(\n            AsyncSGLangServer.__ray_metadata__, \"modified_class\"\n        ):\n            ActualClassToInstantiate = AsyncSGLangServer.__ray_metadata__.modified_class\n\n        def mock_get_actor_side_effect(name, namespace=None):\n            # Create a new mock actor for each call\n            actor_mock = MagicMock()\n\n            # Support .name attribute access\n            actor_mock.name = name  # Use 'name' here\n\n            # Support ['name'] item access by mocking __getitem__\n            def getitem_mock(key):\n                if key == \"name\":\n                    return name  # Use 'name' here\n                # For other keys, return a new MagicMock to mimic default behavior or raise KeyError\n                # Returning a MagicMock is consistent with the original error's cause for unmocked keys\n                return MagicMock(name=f\"mock.__getitem__('{key}')\")\n\n            actor_mock.__getitem__.side_effect = getitem_mock\n\n            return actor_mock\n\n        # Verify instance.workers is correctly populated\n        with patch(\n            \"verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor\",\n            side_effect=mock_get_actor_side_effect,\n        ):\n            # nnodes: 2\n            # n_gpus_per_node: 4\n            # tensor_model_parallel_size: 2\n            # DP_size: 4\n            mock_list_actors.return_value = [\n              
  {\"name\": \"test_xxxx\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:3\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:3\", \"namespace\": \"test\"},\n            ]\n\n            # Instance 1\n            instance = ActualClassToInstantiate(server_config, 4, 0, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 2\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_0:0\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_0:0\"\n            assert instance.workers[1].name == \"test_prefixWorkerDict_0:1\"\n\n            # Instance 2\n            instance = ActualClassToInstantiate(server_config, 4, 1, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 2\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_0:2\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_0:2\"\n            assert instance.workers[1].name == \"test_prefixWorkerDict_0:3\"\n\n            # Instance 3\n            instance = ActualClassToInstantiate(server_config, 4, 3, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 2\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_1:2\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_1:2\"\n            assert instance.workers[1].name == \"test_prefixWorkerDict_1:3\"\n\n            # nnodes: 4\n            # n_gpus_per_node: 8\n            # tensor_model_parallel_size: 8\n            # DP_size: 4\n            mock_list_actors.return_value = [\n                {\"name\": \"test_prefixWorkerDict_0:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:3\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:4\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:5\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:6\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_0:7\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:3\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:4\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:5\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_1:6\", \"namespace\": \"test\"},\n       
         {\"name\": \"test_prefixWorkerDict_1:7\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:3\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:4\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:5\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:6\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_2:7\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:0\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:1\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:2\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:3\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:4\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:5\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:6\", \"namespace\": \"test\"},\n                {\"name\": \"test_prefixWorkerDict_3:7\", \"namespace\": \"test\"},\n            ]\n\n            server_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 8\n            # Instance 1\n            instance = ActualClassToInstantiate(server_config, 4, 0, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 8\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_0:0\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_0:0\"\n            assert instance.workers[7].name == \"test_prefixWorkerDict_0:7\"\n\n            # Instance 2\n            instance = ActualClassToInstantiate(server_config, 4, 1, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 8\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_1:0\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_1:0\"\n            assert instance.workers[7].name == \"test_prefixWorkerDict_1:7\"\n\n            # Instance 3\n            instance = ActualClassToInstantiate(server_config, 4, 3, \"test_prefix\")\n            await instance.init_engine()\n\n            assert len(instance.workers) == 8\n            assert instance.master_worker[\"name\"] == \"test_prefixWorkerDict_3:0\"\n            assert instance.workers[0].name == \"test_prefixWorkerDict_3:0\"\n            assert instance.workers[7].name == \"test_prefixWorkerDict_3:7\"\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_custom_completion_callback.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport concurrent.futures\nimport os\nimport re\nimport socket\nimport sys\nimport tempfile\nfrom contextlib import asynccontextmanager\nfrom typing import Any\n\nimport fastapi\nimport numpy as np\nimport ray\nimport uvicorn\nfrom datasets import load_dataset\nfrom omegaconf import DictConfig\nfrom openai.types.chat.chat_completion import ChatCompletion\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse\n\nfrom tests.workers.rollout.async_rollout_utils import init_async_rollout_manager\nfrom verl.protocol import DataProto\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.reward_score.sandbox_fusion.utils import _process_single_case\nfrom verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback\n\n\ndef _get_free_port():\n    with socket.socket() as sock:\n        sock.bind((\"\", 0))\n        return sock.getsockname()[1]\n\n\n@ray.remote(num_cpus=1)\nclass Sandbox:\n    \"\"\"Sandbox to execute python code.\n\n    WARNING: This class is for testing purpose only, do not use it in production.\n    Please use a sandbox with strong isolation and security restrictions instead.\n    \"\"\"\n\n    def __init__(self):\n        self.address = ray.util.get_node_ip_address()\n        self.port = None\n        self.server_ready = asyncio.Event()\n        asyncio.create_task(self._start_fastapi_server())\n\n    async def code_execution(self, request: Request):\n        request_json = await request.json()\n        code = request_json[\"code\"]\n        print(f\"execute code:\\n{code}\")\n\n        _, temp_file = tempfile.mkstemp(suffix=\".py\", prefix=\"temp_code\", dir=None, text=True)\n        with open(temp_file, \"w\") as f:\n            f.write(code)\n\n        try:\n            process = await asyncio.create_subprocess_exec(\n                sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE\n            )\n\n            stdout, stderr = await process.communicate()\n\n            response = {\n                \"status\": \"Success\" if process.returncode == 0 else \"Failed\",\n                \"run_result\": {\n                    \"status\": \"Finished\",\n                    \"stdout\": stdout.decode(),\n                    \"stderr\": stderr.decode(),\n                    \"return_code\": process.returncode,\n                },\n            }\n            return JSONResponse(content=response)\n        finally:\n            try:\n                os.unlink(temp_file)\n            except Exception:  # noqa: E722\n                pass\n\n    async def _start_fastapi_server(self):\n        @asynccontextmanager\n        async def lifespan(app: fastapi.FastAPI):\n            print(\"FastAPI startup\")\n            self.server_ready.set()\n            yield\n\n            print(\"FastAPI shutdown, maybe address already in use, exit process 
immediately.\")\n            os._exit(-1)\n\n        app = fastapi.FastAPI(lifespan=lifespan)\n        app.router.add_api_route(\"/run_code\", self.code_execution, methods=[\"POST\"])\n\n        self.port = _get_free_port()\n        config = uvicorn.Config(app, host=[\"::\", \"0.0.0.0\"], port=self.port, log_level=\"warning\")\n        server = uvicorn.Server(config)\n        await server.serve()\n\n    async def get_server_address(self) -> str:\n        \"\"\"Get FastAPI server address.\"\"\"\n        await self.server_ready.wait()\n        return f\"{self.address}:{self.port}\"\n\n\nclass CustomCompletionCallback(ToolCompletionCallback):\n    def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler):\n        super().__init__(config, scheduler)\n\n        self.max_assistant_turns = 16\n        self.answer_pattern = re.compile(r\"<answer>(.*?)</answer>\", re.DOTALL)\n        self.code_pattern = re.compile(r\"<code>\\s*```python(.*?)```\\s*</code>\", re.DOTALL)\n\n        self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url\n        self.default_timeout = 10\n        self.memory_limit_mb = config.reward_model.sandbox_fusion.memory_limit_mb\n        # TODO: support asyncio executor\n        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5))\n\n    async def sandbox_code_execution(self, code: str) -> dict[str, Any]:\n        loop = asyncio.get_running_loop()\n        result_status, metadata = await loop.run_in_executor(\n            self.executor,\n            _process_single_case,\n            0,  # case_index,\n            None,  # stdin_data,\n            None,  # expected_output,\n            self.sandbox_fusion_url,  # sandbox_fusion_url\n            code,  # generation\n            self.default_timeout,  # timeout\n            self.memory_limit_mb,  # memory limit\n            \"python\",  # language\n        )\n\n        return metadata\n\n    @property\n    def extra_body(self):\n        extra = {\n            \"include_stop_str_in_output\": True,\n            \"stop\": [\"</answer>\", \"</code>\"],\n        }\n        return extra\n\n    async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]):\n        role, content, finish_reason = (\n            completions.choices[0].message.role,\n            completions.choices[0].message.content,\n            completions.choices[0].finish_reason,\n        )\n        messages.append({\"role\": role, \"content\": content})\n        turn = len(messages)\n\n        # STEP 0: check if we reach max turns\n        if len(messages) >= self.max_assistant_turns:\n            print(f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!\")\n            return\n\n        # STEP 1: check if we reach max tokens\n        if finish_reason == \"length\":\n            print(f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!\")\n            return\n\n        # STEP 2: check if we got answer\n        matches = self.answer_pattern.findall(content)\n        if matches:\n            print(f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!\")\n            return\n\n        # STEP 3: check if we got code block\n        matches = self.code_pattern.findall(content)\n        if not matches:\n            print(f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!\")\n            return\n\n  
      # STEP 4: execute code block in sandbox\n        code = matches[0].strip()\n        metadata = await self.sandbox_code_execution(code)\n        if metadata[\"run_status\"] != \"Finished\":\n            print(\n                f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block execution failed: \"\n                f\"{metadata}, done!\"\n            )\n            return\n\n        stdout, stderr = metadata[\"stdout\"], metadata[\"stderr\"]\n        messages.append({\"role\": \"tool\", \"content\": f\"<interpreter>{stdout}{stderr}</interpreter>\"})\n        print(f\"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...\")\n\n        # STEP 5: resubmit chat completions with code block output\n        self.scheduler.submit_chat_completions(\n            messages=messages,\n            request_id=completions.id,\n            info=info,\n        )\n\n\nuser_prompt_template = \"\"\"\nYou are a helpful assistant. Let's solve math problem in following steps:\n1. Write a python code first and return the code to user, the code must be in following format:\n\n<code>\n```python\nimport os\n\nprint(...)\n```\n</code>\n\nThe code must explictly print necessary output to stdout. Remember stop generation at </code> immediately and \nreturn the code.\n2. User will send the python code to a external sandbox to execute and get output from stdout.\n3. User will send the output in format <interpreter>output</interpreter> to you, and you should use the \noutput to answer the question.\nThe answer format must be: <answer>\\\\boxed{'The final answer goes here.'}</answer>\n\n*user question:*\n{question}\n\"\"\"\n\n\nif __name__ == \"__main__\":\n    ray.init(\n        runtime_env={\n            \"env_vars\": {\n                \"TOKENIZERS_PARALLELISM\": \"true\",\n                \"NCCL_DEBUG\": \"WARN\",\n                \"VLLM_LOGGING_LEVEL\": \"INFO\",\n                \"VLLM_USE_V1\": \"1\",\n            }\n        }\n    )\n\n    # Load config\n    import os\n\n    from hydra import compose, initialize_config_dir\n\n    with initialize_config_dir(config_dir=os.path.abspath(\"verl/trainer/config\")):\n        config = compose(config_name=\"ppo_trainer\")\n    model_path = \"Qwen/Qwen2.5-1.5B-Instruct\"\n    config.actor_rollout_ref.model.path = model_path\n    config.actor_rollout_ref.rollout.mode = \"async\"\n    config.actor_rollout_ref.rollout.multi_turn.format = \"hermes\"\n    config.actor_rollout_ref.rollout.multi_turn.completion_callback = (\n        \"tests.workers.rollout.test_custom_completion_callback.CustomCompletionCallback\"\n    )\n    config.actor_rollout_ref.rollout.prompt_length = 4096\n    config.actor_rollout_ref.rollout.response_length = 4096\n    config.actor_rollout_ref.rollout.n = 4\n\n    # Init sandbox and async rollout manager\n    sandbox = Sandbox.options(num_cpus=1).remote()\n    sandbox_address = ray.get(sandbox.get_server_address.remote())\n    sandbox_fusion_url = f\"http://{sandbox_address}/run_code\"\n    config.reward_model.sandbox_fusion.url = sandbox_fusion_url\n    async_rollout_manager = init_async_rollout_manager(config)\n\n    # Build dataset\n    dataset = load_dataset(\"Maxwell-Jia/AIME_2024\", split=\"train\")\n    prompts = DataProto(\n        non_tensor_batch={\n            \"raw_prompt\": np.array(\n                [\n                    [{\"role\": \"user\", \"content\": user_prompt_template.replace(\"{question}\", problem)}]\n                    for problem in dataset[\"Problem\"]\n    
            ]\n            ),\n        },\n    )\n\n    result = async_rollout_manager.generate_sequences(prompts=prompts)\n    assert len(result) == len(dataset) * config.actor_rollout_ref.rollout.n\n\n    # Check max turns that sandbox is called\n    num_turns = result.non_tensor_batch[\"__num_turns__\"]\n    print(f\"num_turns: {num_turns}\")\n    assert np.max(num_turns) > 2, f\"max turns: {np.max(num_turns)}\"\n\n    # Check response_mask\n    tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path)\n    responses = result.batch[\"responses\"]\n    response_mask = result.batch[\"response_mask\"]\n    assert responses.size() == response_mask.size(), f\"{responses.size()} != {response_mask.size()}\"\n\n    # Decode responses with response_mask\n    for i in range(len(responses)):\n        valid_tokens = responses[i][response_mask[i].bool()]\n        response_str = tokenizer.decode(valid_tokens)\n        assert \"<tool_response>\" not in response_str, f\"found <tool_response> in response: {response_str}\"\n        assert \"</tool_response>\" not in response_str, f\"found </tool_response> in response: {response_str}\"\n        print(f\"response: {response_str}\")\n\n    print(\"Test passed!\")\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_hf_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.distributed import initialize_global_process_group\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.rollout.hf_rollout import HFRollout\n\nBASE_HF_ROLLOUT_CONFIG = {\n    \"temperature\": 1.0,\n    \"top_k\": -1,\n    \"top_p\": 1,\n    \"prompt_length\": 64,\n    \"response_length\": 64,\n    \"do_sample\": True,\n    \"n\": 1,\n    \"val_kwargs\": {\n        \"top_k\": -1,\n        \"top_p\": 1.0,\n        \"temperature\": 0,\n        \"n\": 1,\n        \"do_sample\": False,\n    },\n}\n\n\ndef prepare_input_dataproto(tokenizer, config, validate):\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": \"Who won the Champions League in 2019?\"}],\n        [{\"role\": \"user\", \"content\": \"The founder of Apple is\"}],\n        [{\"role\": \"user\", \"content\": \"What's your name\"}],\n    ]\n    formatted_prompts = [\n        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)\n        for conversation in preencode_prompts\n    ]\n    prompts = tokenizer(formatted_prompts, return_tensors=\"pt\", padding=\"max_length\", max_length=config.prompt_length)\n    input_dataproto = DataProto.from_dict(\n        {\n            \"input_ids\": prompts[\"input_ids\"],\n            \"attention_mask\": prompts[\"attention_mask\"],\n            \"position_ids\": compute_position_id_with_mask(prompts[\"attention_mask\"]),\n        },\n        meta_info={\n            \"bos_token_id\": tokenizer.bos_token_id,\n            \"eos_token_id\": tokenizer.eos_token_id,\n            \"pad_token_id\": tokenizer.pad_token_id,\n            \"validate\": validate,\n        },\n    )\n    return input_dataproto\n\n\ndef prepare_fsdp_model(model, world_size):\n    from torch.distributed.device_mesh import init_device_mesh\n\n    device_mesh = init_device_mesh(\"cuda\", mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n\n    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)\n\n    fsdp_model = FSDP(\n        model,\n        use_orig_params=True,\n        auto_wrap_policy=None,\n        device_id=torch.cuda.current_device(),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        mixed_precision=mixed_precision,\n        cpu_offload=CPUOffload(offload_params=False),\n        sync_module_states=False,\n        device_mesh=device_mesh,\n    )\n\n    FSDP.set_state_dict_type(\n        fsdp_model, 
state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig()\n    )\n    return fsdp_model\n\n\ndef test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False):\n    config = OmegaConf.create(BASE_HF_ROLLOUT_CONFIG)\n    config.update({\"n\": n, \"do_sample\": do_sample})\n\n    assert torch.cuda.device_count() >= 2, \"At least 2 GPUs are required to run this test.\"\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    # Initialize model and tokenizer\n    local_cache_path = \"~/.cache/verl/rlhf\"\n    local_cache_path = os.path.expanduser(local_cache_path)\n    hdfs_path = \"Qwen/Qwen2-7B-Instruct\"\n    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\", trust_remote_code=True)\n    tokenizer.pad_token = tokenizer.eos_token\n\n    # Initialize FSDP model\n    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True)\n    actor_model.to(torch.bfloat16)\n    fsdp_model = prepare_fsdp_model(actor_model, world_size)\n\n    # Initialize HFRollout and start generation\n    hf_rollout = HFRollout(fsdp_model, OmegaConf.create(config))\n    input = prepare_input_dataproto(tokenizer, config, validate).to(torch.cuda.current_device())\n    outputs = hf_rollout.generate_sequences(input)\n\n    # check generated batch size is as expected\n    generated_batch_size = outputs.batch.batch_size[0]\n    assert generated_batch_size == input.batch.batch_size[0] * config.n\n\n    for i in range(generated_batch_size):\n        prompt_tokens = outputs.batch[\"prompts\"][i]\n        prompt_mask = prompt_tokens != tokenizer.pad_token_id\n        prompt_tokens = prompt_tokens[prompt_mask]\n        decoded_prompt = tokenizer.decode(prompt_tokens, skip_special_tokens=False)\n\n        response_tokens = outputs.batch[\"responses\"][i]\n        response_mask = response_tokens != tokenizer.pad_token_id\n        response_tokens = response_tokens[response_mask]\n        decoded_response = tokenizer.decode(response_tokens, skip_special_tokens=False)\n\n        attention_mask = outputs.batch[\"attention_mask\"][i]\n        position_ids = outputs.batch[\"position_ids\"][i]\n        prompt_length = outputs.batch[\"prompts\"].size(1)\n        response_length = outputs.batch[\"responses\"].size(1)\n\n        assert attention_mask.size(0) == prompt_length + response_length\n        assert position_ids.size(0) == prompt_length + response_length\n\n        # check response attention mask is as expected\n        response_attention = attention_mask[prompt_length:]\n        eos_positions = (outputs.batch[\"responses\"][i] == tokenizer.pad_token_id).nonzero(as_tuple=True)[0]\n        if len(eos_positions) > 0:\n            first_eos_pos = eos_positions[0].item()\n            assert response_attention[: first_eos_pos + 1].all(), \"Response attention mask should be 1 until EOS\"\n            if first_eos_pos + 1 < response_length:\n                assert not response_attention[first_eos_pos + 1 :].any(), (\n                    \"Response attention mask should be 0 after EOS\"\n                )\n        else:\n            assert response_attention.all(), \"Response attention mask should be all 1 if no EOS token\"\n\n        # check response position ids are as expected\n        prompt_positions = position_ids[:prompt_length]\n        response_positions = position_ids[prompt_length:]\n        valid_response_length = 
min(len(response_tokens), response_length)\n        if valid_response_length > 0:\n            assert response_positions[0] == prompt_positions[-1] + 1\n            for j in range(1, valid_response_length):\n                assert response_positions[j] == response_positions[j - 1] + 1\n\n        # print generated text for inspection\n        if torch.distributed.get_rank() == 0:\n            print(f\"prompt: {decoded_prompt}\")\n            print(f\"response: {decoded_response}\")\n            print(\"=\" * 30)\n\n\nif __name__ == \"__main__\":\n    test_hf_rollout(n=2, do_sample=True, validate=False)\n    # test_hf_rollout(n=1, do_sample=False, validate=True)\n    # test_hf_rollout(n=1, do_sample=True, validate=False)\n"
  },
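  {
    "path": "verl_rl/tests/workers/rollout/notes/position_ids_from_mask_sketch.py",
    "content": "# NOTE (editor's sketch): illustrative file, not part of the verl test suite.\n# The hf_rollout test above derives position_ids from a left-padded attention mask via\n# verl.utils.model.compute_position_id_with_mask. The helper below is a local stand-in\n# (an assumption about that helper's cumsum-based behavior) so the relationship between\n# padding and positions can be checked in isolation with plain torch.\nimport torch\n\n\ndef position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:\n    # Count valid tokens from the left; clamp padded slots to position 0.\n    return torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0)\n\n\nif __name__ == \"__main__\":\n    # A batch of two left-padded prompts (0 = pad, 1 = real token).\n    mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]])\n    pos = position_ids_from_mask(mask)\n    # The first real token of each row sits at position 0, then positions increment by 1.\n    assert pos[0].tolist() == [0, 0, 0, 1, 2]\n    assert pos[1].tolist() == [0, 0, 1, 2, 3]\n    print(pos)\n"
  },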
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\n\n\nimport asyncio\nfrom copy import deepcopy\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport numpy as np\nimport pytest\nfrom tensordict import TensorDict\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import get_rollout_config, prepare_inputs\n\nfrom verl.protocol import DataProto\nfrom verl.tools.mcp_search_tool import MCPSearchTool\nfrom verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager\nfrom verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nDEFAULT_USER_CONTENT_PREFIX = (\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\n    \"first every time you get new information. After reasoning, if you find you lack \"\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\n    \"and it will return the top searched results between <tool_response> and \"\n    \"</tool_response>. You can search as many times as your want. If you find no \"\n    \"further external knowledge needed, you can directly provide the answer inside \"\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\n    \"<answer> Beijing </answer>. 
Question: \"\n)\nuser_content = DEFAULT_USER_CONTENT_PREFIX.rstrip(\"\\n\") + \"How's the weather lately?\"\n\n\ndef get_search_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": user_content,\n    }\n\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search the web.\",\n        \"tool_calls\": [\n            {\n                \"id\": \"10\",\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"arguments\": {\n                        \"what_is_your_intent\": \"Search for the weather lately\",\n                        \"query\": \"the weather in Beijing today\",\n                        \"search_depth\": \"basic\",\n                        \"time_range\": \"day\",\n                        \"include_domains\": [\"google.com\", \"baidu.com\"],\n                        \"max_results\": 2,\n                    },\n                },\n            }\n        ],\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search again.\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"arguments\": {\n                        \"what_is_your_intent\": \"Search for the weather lately\",\n                        \"query\": \"the weather in Beijing tomorrow\",\n                        \"search_depth\": \"basic\",\n                        \"time_range\": \"day\",\n                        \"include_domains\": [\"google.com\", \"baidu.com\"],\n                        \"max_results\": 2,\n                    },\n                },\n            }\n        ],\n    }\n\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"<answer>Today is sunny and tomorrow will be cloudy in Beijing.</answer>\",\n    }\n\n    # Mock search tool responses\n    tool_return_0_msg = {\"role\": \"tool\", \"content\": [{\"type\": \"text\", \"text\": \"Today's weather in Beijing is sunny.\"}]}\n    tool_return_1_msg = {\n        \"role\": \"tool\",\n        \"content\": [{\"type\": \"text\", \"text\": \"Tomorrow's weather in Beijing is cloudy.\"}],\n    }\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\nclass TestRolloutWithMCPSearchTools:\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        local_model_path = \"Qwen/Qwen2.5-0.5B\"\n        tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        local_model_path = \"Qwen/Qwen2.5-0.5B\"\n        config = AutoConfig.from_pretrained(local_model_path)\n        return config\n\n    @pytest.fixture\n    def search_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_search_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        
preencode_tool_return_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True)\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def search_rollout_config(self):\n        max_prompt_length = 4096\n        max_response_length = 3000\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/mcp_tool_config\"\n        rollout_config = get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def search_data_proto(self, search_data, qwen_tokenizer):\n        preencode_prompts, _, _ = search_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n\n        tools_kwargs = np.array(\n            [\n                {\n                    \"tavily_search_tool\": {\n                        \"create_kwargs\": {\"ground_truth\": \"Today is sunny and tomorrow will be cloudy in Beijing.\"},\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance with sampling_params initialized.\"\"\"\n        tool_schema = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tavily_search_tool\",\n                    \"description\": \"A powerful web search tool...\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"what_is_your_intent\": {\n                                \"type\": \"string\",\n                                \"description\": \"Describe your intent for using Tavily\",\n                            },\n                            \"query\": {\"type\": \"string\", \"description\": \"Search query\"},\n                            \"search_depth\": {\n                                \"type\": \"string\",\n                                \"description\": \"The depth of the search ('basic' or 'advanced')\",\n                            },\n                            \"topic\": {\n                                \"type\": \"string\",\n                                \"description\": \"The category of the search ('general' or 'news')\",\n                            },\n                            \"days\": {\n                                \"type\": \"integer\",\n                                \"description\": \"Number of days back to include in 
search results (only for \"\n                                \"'news' topic)\",\n                            },\n                            \"time_range\": {\n                                \"type\": \"string\",\n                                \"description\": \"Time range for results ('day', 'week', 'month', 'year', 'd', \"\n                                \"'w', 'm', 'y')\",\n                            },\n                            \"include_domains\": {\n                                \"type\": \"array\",\n                                \"description\": \"List of domains to specifically include in search results\",\n                            },\n                            \"exclude_domains\": {\n                                \"type\": \"array\",\n                                \"description\": \"List of domains to specifically exclude from search results\",\n                            },\n                            \"include_answer\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include an answer summary generated by an LLM\",\n                            },\n                            \"include_raw_content\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include the cleaned and parsed HTML content of each result\",\n                            },\n                            \"include_images\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include images from search results\",\n                            },\n                            \"include_image_descriptions\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to include descriptions with images\",\n                            },\n                            \"max_results\": {\n                                \"type\": \"integer\",\n                                \"description\": \"Maximum number of results to return (5-20)\",\n                            },\n                            \"async_search\": {\n                                \"type\": \"boolean\",\n                                \"description\": \"Whether to perform the search asynchronously\",\n                            },\n                        },\n                        \"required\": [\"what_is_your_intent\", \"query\"],\n                    },\n                    \"strict\": False,\n                },\n            }\n        ]\n        with (\n            patch.object(MCPClientManager, \"fetch_tool_schemas\", return_value=tool_schema),\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n        ):\n            rollout = SGLangRollout(\n                actor_module=\"\",\n                config=search_rollout_config,\n                processing_class=qwen_tokenizer,\n                model_hf_config=qwen_model_config,\n            )\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": search_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n                \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return 
rollout\n\n    def test_tools_registration(self, mock_rollout):\n        assert len(mock_rollout._tool_schemas) != 0\n        assert \"tavily_search_tool\" in mock_rollout._tool_map.keys()\n        assert isinstance(mock_rollout._tool_map[\"tavily_search_tool\"], MCPSearchTool)\n        # depends on the tokenizer\n        assert mock_rollout._tool_call_parser_type == \"qwen25\"\n\n    def test_rollout_req_creation(self, mock_rollout, search_data_proto):\n        req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n\n    def test_over_size_case(self, mock_rollout, search_data_proto, search_data):\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, _ = search_data\n        # here we mock a meta_info with finish_reason 'length' to indicate the response was truncated\n        mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 3000},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 2.23543,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"tavily_search_tool\") == []\n        # we should only have two messages: one for the prompt, one for the response.\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @patch.object(MCPSearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n        # Mock search tool execution to return predefined responses\n        mock_execute.side_effect = [(msg, 0.0, {\"status\": \"success\"}) for msg in tool_return_array]\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() 
for i in expect_turn_array]\n        for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n            i.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 2.23543,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list])\n        )\n\n        # Verify conversation completed successfully with proper tool usage\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert \"tavily_search_tool\" in output_req.metrics\n        assert output_req.metrics[\"tavily_search_tool\"][0][\"status\"] == \"success\"\n        assert mock_execute.await_count == 2\n        assert len(output_req.messages) == 6\n        # Verify tool response messages contain expected content\n        search_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[search_counter]\n                search_counter += 1\n        assert search_counter == 2\n\n    @patch.object(MCPSearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n        # Mock tool execution for large batch (100 requests * 2 calls each)\n        mock_execute.side_effect = [\n            (tool_return_array[0], 0.0, {\"status\": \"success\"}),\n            (tool_return_array[1], 0.0, {\"status\": \"success\"}),\n        ] * 100\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n\n        req_nums = 100\n        req_list = []\n        req_turns_map = {}\n        req_turns_counter = {}\n\n        for i in range(req_nums):\n            tmp_req = deepcopy(base_req)\n            tmp_req.batch_data_id = i\n            tmp_req.request_id = i\n            req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))\n\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n                fut.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"dummy\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                 
           \"completion_tokens\": 100,\n                        },\n                    }\n                )\n            req_turns_map[i] = futures\n            req_turns_counter[i] = 0\n\n        async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs):\n            fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            return await fut\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list])\n            )\n\n        # Verify all requests completed successfully\n        assert len(output_req_list) == req_nums\n        for out_req in output_req_list:\n            assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n            assert \"tavily_search_tool\" in out_req.metrics\n            for metric in out_req.metrics[\"tavily_search_tool\"]:\n                assert metric[\"status\"] == \"success\"\n            assert len(out_req.messages) == 6\n            assert sum(1 for m in out_req.messages if m.role == \"tool\") == 2\n\n        assert mock_execute.await_count == 2 * req_nums\n"
  },
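  {
    "path": "verl_rl/tests/workers/rollout/notes/preresolved_future_mock_sketch.py",
    "content": "# NOTE (editor's sketch): illustrative file, not part of the verl test suite.\n# The MCP/search tool tests above run multi-turn rollouts without a real engine by\n# replacing _handle_engine_call with a MagicMock whose side_effect is a list of\n# already-resolved asyncio.Future objects, one per expected assistant turn. The toy\n# class below (all names hypothetical) isolates that mocking pattern.\nimport asyncio\nfrom unittest.mock import MagicMock\n\n\nclass ToyRollout:\n    async def _handle_engine_call(self, prompt: str) -> dict:\n        raise RuntimeError(\"real engine is not available in unit tests\")\n\n    async def run(self, prompt: str, turns: int) -> list:\n        texts = []\n        for _ in range(turns):\n            out = await self._handle_engine_call(prompt)\n            texts.append(out[\"text\"])\n        return texts\n\n\nif __name__ == \"__main__\":\n    rollout = ToyRollout()\n    expected_turns = [\"turn-0\", \"turn-1\", \"turn-2\"]\n    futures = []\n    for turn in expected_turns:\n        fut = asyncio.Future()\n        fut.set_result({\"text\": turn, \"meta_info\": {\"finish_reason\": {\"type\": \"stop\"}}})\n        futures.append(fut)\n    # Each call to the mock returns the next pre-resolved future; awaiting it yields the\n    # canned engine output immediately, exactly as the tests above arrange.\n    rollout._handle_engine_call = MagicMock(side_effect=futures)\n    loop = asyncio.get_event_loop()\n    texts = loop.run_until_complete(rollout.run(\"hi\", turns=3))\n    assert texts == expected_turns\n    print(texts)\n"
  },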
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py",
    "content": "# Copyright 2025 Amazon.com, Inc. or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport pytest\n\nfrom verl.utils.dataset.vision_utils import process_image\nfrom verl.utils.tokenizer import hf_processor\nfrom verl.workers.rollout.schemas import (\n    AsyncRolloutRequest,\n    AsyncRolloutRequestStateEnum,\n    TokenizationSanityCheckModeEnum,\n)\n\n\ndef _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False):\n    assert len(image_list) == len(description_list)\n    # Get the smallest dimensions across all images\n    processed_images = []\n    for img_url in image_list:\n        img = process_image(img_url)\n        processed_images.append(img)\n\n    min_width = min(img.size[0] for img in processed_images)\n    min_height = min(img.size[1] for img in processed_images)\n    min_size = (min_width, min_height)\n\n    if resize_image:\n        processed_images_resized = []\n        for img in processed_images:\n            img = img.resize(min_size)\n            processed_images_resized.append(img)\n        processed_images = processed_images_resized\n\n    # Initial message history\n    system_prompt = (\n        \"You will be provided with an image. 
Describe this image and then generate a new image for the next round\"\n    )\n    messages = [\n        {\n            \"role\": \"system\",\n            \"content\": system_prompt,\n        },\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"Here is the first image provided: \"},\n                {\"type\": \"image\", \"image\": [processed_images[0]]},\n            ],\n        },\n    ]\n\n    # Initial multi_modal_data with one image\n    multi_modal_data = {\"image\": [processed_images[0]], \"video\": []}\n    # Minimal required fields for AsyncRolloutRequest\n\n    req = AsyncRolloutRequest(\n        batch_data_id=0,\n        request_id=\"test-req-1\",\n        state=AsyncRolloutRequestStateEnum.PENDING,\n        messages=messages,\n        multi_modal_keys=[\"image\", \"video\"],\n        multi_modal_data=multi_modal_data.copy(),\n        tool_schemas=[],\n        tools_kwargs={},\n        interaction_kwargs={},\n        input_ids=None,\n        prompt_ids=None,\n        response_ids=None,\n        attention_mask=None,\n        prompt_attention_mask=None,\n        response_attention_mask=None,\n        position_ids=None,\n        prompt_position_ids=None,\n        response_position_ids=None,\n        loss_mask=None,\n        prompt_loss_mask=None,\n        response_loss_mask=None,\n        reward_scores={},\n        max_prompt_len=8192,\n        max_response_len=8192,\n        max_model_len=16384,\n        metrics={},\n        use_inference_chat_template=True,\n        tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT,\n        generation_prompt_ids=None,\n        base_conv_wo_gen_prompt_end_pos=0,\n        base_conv_with_gen_prompt_end_pos=0,\n        processing_class=processor,\n    )\n\n    prev_generated_len = 0\n    # For each image after the first, add an assistant message followed by a tool response message (image)\n    for idx, img in enumerate(processed_images):\n        if idx == 0:\n            continue\n        _ = req.get_generation_prompt_ids(processor)\n        req.add_assistant_message(processor, content=description_list[idx - 1])\n        before_tool_call_len = req.input_ids.shape[-1]\n        req.add_tool_response_messages(processor, [{\"image\": [img], \"text\": \"Here is the new image you requested: \"}])\n        after_tool_call_len = req.input_ids.shape[-1]\n        if prev_generated_len == 0:\n            prev_generated_len = after_tool_call_len - before_tool_call_len\n        elif resize_image:\n            # identically sized images must contribute the same number of tokens each turn\n            assert after_tool_call_len - before_tool_call_len == prev_generated_len\n        assert req.multi_modal_data[\"image\"] == processed_images[: idx + 1]\n\n    _ = req.get_generation_prompt_ids(processor)\n    req.add_assistant_message(processor, content=description_list[-1])\n\n    messages = [msg.model_dump() for msg in req.messages]\n    tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None\n    full_prompt_info = req._handle_apply_chat_template(\n        processor,\n        messages,\n        multi_modal_data=req.multi_modal_data,\n        tools=tools,\n        add_generation_prompt=False,\n        tokenize=True,\n        return_dict=True,\n    )\n    full_prompt_ids = full_prompt_info[\"input_ids\"]\n    assert full_prompt_ids.eq(req.input_ids).all()\n\n    # Copy full_prompt_info (a BatchFeature) and drop input_ids/attention_mask so that\n    # only the multi-modal inputs are compared; np.array() would keep only the keys of a BatchFeature.\n    full_prompt_multi_modal_inputs = 
full_prompt_info.copy()\n    full_prompt_multi_modal_inputs.pop(\"input_ids\", None)\n    full_prompt_multi_modal_inputs.pop(\"attention_mask\", None)\n\n    for key in full_prompt_multi_modal_inputs:\n        assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all()\n\n\n@pytest.mark.skipif(\n    hf_processor(\"Qwen/Qwen2.5-VL-3B-Instruct\") is None, reason=\"Processor not available for Qwen/Qwen2.5-VL-3B-Instruct\"\n)\ndef test_add_tool_response_messages_image_delta():\n    processor = hf_processor(\"Qwen/Qwen2.5-VL-3B-Instruct\")\n\n    # From Qwen2.5-VL-3B-Instruct HF example\n    img_1_url = {\"image\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"}\n    img_1_description = \"A woman sits on the beach at sunset, smiling as she shares a high five with her large dog.\"\n    # GitHub Logo\n    img_2_url = {\"image\": \"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\"}\n    img_2_description = \"A GitHub Logo image\"\n    # Octocat\n    img_3_url = {\"image\": \"https://octodex.github.com/images/orderedlistocat.png\"}\n    img_3_description = \"An Octocat image\"\n\n    image_list = [img_1_url, img_2_url, img_3_url]\n    description_list = [img_1_description, img_2_description, img_3_description]\n    _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False)\n\n\n@pytest.mark.skipif(\n    hf_processor(\"Qwen/Qwen2.5-VL-3B-Instruct\") is None, reason=\"Processor not available for Qwen/Qwen2.5-VL-3B-Instruct\"\n)\ndef test_add_tool_response_messages_image_delta_resize_image():\n    processor = hf_processor(\"Qwen/Qwen2.5-VL-3B-Instruct\")\n\n    # From Qwen2.5-VL-3B-Instruct HF example\n    img_1_url = {\"image\": \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"}\n    img_1_description = \"A woman sits on the beach at sunset, smiling as she shares a high five with her large dog.\"\n    # GitHub Logo\n    img_2_url = {\"image\": \"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\"}\n    img_2_description = \"A GitHub Logo image\"\n    # Octocat\n    img_3_url = {\"image\": \"https://octodex.github.com/images/orderedlistocat.png\"}\n    img_3_description = \"An Octocat image\"\n\n    image_list = [img_1_url, img_2_url, img_3_url]\n    description_list = [img_1_description, img_2_description, img_3_description]\n    _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True)\n"
  },
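  {
    "path": "verl_rl/tests/workers/rollout/notes/incremental_encoding_delta_sketch.py",
    "content": "# NOTE (editor's sketch): illustrative file, not part of the verl test suite.\n# The multimodal delta test above checks an invariant of AsyncRolloutRequest: input_ids\n# accumulated turn by turn (the delta path) must equal re-encoding the whole message\n# history in one shot. The toy below uses whitespace splitting as a stand-in tokenizer\n# (a deliberate simplification) to show the invariant without loading a real processor.\n\n\ndef encode(text):\n    # Stand-in tokenizer; the invariant only needs encode(a + b) == encode(a) + encode(b)\n    # to hold at message boundaries.\n    return text.split()\n\n\nif __name__ == \"__main__\":\n    turns = [\"user: describe the image\\n\", \"assistant: a dog on a beach\\n\", \"tool: here is a new image\\n\"]\n    incremental = []\n    for turn in turns:\n        # Delta path: encode each new turn alone and append, as add_assistant_message and\n        # add_tool_response_messages do for input_ids.\n        incremental.extend(encode(turn))\n    # One-shot path: encode the full conversation, as _handle_apply_chat_template does.\n    full = encode(\"\".join(turns))\n    assert incremental == full\n    print(incremental)\n"
  },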
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_search_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py\n\n\nimport asyncio\nfrom copy import deepcopy\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport numpy as np\nimport pytest\nfrom tensordict import TensorDict\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import get_rollout_config, prepare_inputs\n\nfrom verl.protocol import DataProto\nfrom verl.tools.schemas import (\n    OpenAIFunctionParametersSchema,\n    OpenAIFunctionPropertySchema,\n    OpenAIFunctionSchema,\n    OpenAIFunctionToolSchema,\n)\nfrom verl.tools.search_tool import SearchTool\nfrom verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nDEFAULT_USER_CONTENT_PREFIX = (\n    \"Answer the given question. You must conduct reasoning inside <think> and </think> \"\n    \"first every time you get new information. After reasoning, if you find you lack \"\n    \"some knowledge, you can call a search engine by <tool_call> query </tool_call> \"\n    \"and it will return the top searched results between <tool_response> and \"\n    \"</tool_response>. You can search as many times as your want. If you find no \"\n    \"further external knowledge needed, you can directly provide the answer inside \"\n    \"<answer> and </answer>, without detailed illustrations. For example, \"\n    \"<answer> Beijing </answer>. 
Question: \"\n)\nuser_content = DEFAULT_USER_CONTENT_PREFIX.rstrip(\"\\n\") + \"How's the weather lately?\"\n\n\ndef get_search_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": user_content,\n    }\n\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search the web.\",\n        \"tool_calls\": [{\"type\": \"function\", \"function\": {\"name\": \"search\", \"arguments\": {\"query\": \"today's weather\"}}}],\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"Let me search again.\",\n        \"tool_calls\": [\n            {\"type\": \"function\", \"function\": {\"name\": \"search\", \"arguments\": {\"query\": \"tomorrow's weather\"}}}\n        ],\n    }\n\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"<answer>Today is sunny and tomorrow will be cloudy in Beijing.</answer>\",\n    }\n\n    # Mock search tool responses\n    tool_return_0_msg = {\"role\": \"tool\", \"content\": \"Today's weather in Beijing is sunny.\"}\n    tool_return_1_msg = {\"role\": \"tool\", \"content\": \"Tomorrow's weather in Beijing is cloudy.\"}\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\nclass TestRolloutWithSearchTools:\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        local_model_path = \"Qwen/Qwen2.5-0.5B\"\n        tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        local_model_path = \"Qwen/Qwen2.5-0.5B\"\n        config = AutoConfig.from_pretrained(local_model_path)\n        return config\n\n    @pytest.fixture\n    def search_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_search_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        preencode_tool_return_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True)\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def search_rollout_config(self):\n        max_prompt_length = 4096\n        max_response_length = 3000\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/search_tool_config\"\n        rollout_config = get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def search_data_proto(self, search_data, qwen_tokenizer):\n        preencode_prompts, _, _ = search_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            {\n      
          \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n\n        tools_kwargs = np.array(\n            [\n                {\n                    \"search\": {\n                        \"create_kwargs\": {\n                            \"ground_truth\": \"Today is sunny and tomorrow will be cloudy in Beijing.\",\n                            \"data_source\": \"searchR1_nq\",\n                        },\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance with sampling_params initialized.\"\"\"\n        with (\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n        ):\n            rollout = SGLangRollout(\n                actor_module=\"\",\n                config=search_rollout_config,\n                processing_class=qwen_tokenizer,\n                model_hf_config=qwen_model_config,\n            )\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": search_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n                \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return rollout\n\n    @patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None)\n    def test_tools_registration(\n        self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config\n    ):\n        rollout = SGLangRollout(\n            actor_module=\"\",\n            config=search_rollout_config,\n            processing_class=qwen_tokenizer,\n            model_hf_config=qwen_model_config,\n        )\n        assert len(rollout._tool_schemas) == 1\n        assert \"search\" in rollout._tool_map.keys()\n        from verl.tools.search_tool import SearchTool\n\n        assert isinstance(rollout._tool_map[\"search\"], SearchTool)\n        # depend on the tokenizer\n        assert rollout._tool_call_parser_type == \"qwen25\"\n\n    @patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None)\n    @patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None)\n    def test_rollout_req_creation(\n        self,\n        mock_env,\n        mock_engine,\n        mock_sampling,\n        search_rollout_config,\n        qwen_tokenizer,\n        qwen_model_config,\n        search_data_proto,\n    ):\n        rollout = SGLangRollout(\n            actor_module=\"\",\n            config=search_rollout_config,\n            
processing_class=qwen_tokenizer,\n            model_hf_config=qwen_model_config,\n        )\n        req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n        print(type(req_list[0].tool_schemas[0]))\n        assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(\n            type=\"function\",\n            function=OpenAIFunctionSchema(\n                name=\"search\",\n                description=\"Searches the web for relevant information based on the given query.\",\n                parameters=OpenAIFunctionParametersSchema(\n                    type=\"object\",\n                    properties={\n                        \"query_list\": OpenAIFunctionPropertySchema(\n                            type=\"array\",\n                            description=\"A list of fully-formed semantic queries. The tool will return search \"\n                            \"results for each query.\",\n                            items={\"type\": \"string\"},\n                        )\n                    },\n                    required=[\"query_list\"],\n                ),\n                strict=False,\n            ),\n        )\n\n    def test_over_size_case(self, mock_rollout, search_data_proto, search_data):\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, _ = search_data\n        mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 3000},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 2.23543,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"search\") == []\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @patch.object(SearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n\n        # Mock search tool execution to return predefined responses\n        mock_execute.side_effect = [(msg, 0.0, {\"status\": \"success\"}) for msg in tool_return_array]\n\n    
    mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"search\"].retrieval_service_url = \"mock://dummy\"\n\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() for i in expect_turn_array]\n        for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n            i.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 2.23543,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list])\n        )\n\n        # Verify conversation completed successfully with proper tool usage\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert \"search\" in output_req.metrics\n        assert output_req.metrics[\"search\"][0][\"status\"] == \"success\"\n        assert mock_execute.await_count == 2\n        assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n        # Verify tool response messages contain expected content\n        search_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[search_counter]\n                search_counter += 1\n        assert search_counter == 2\n\n    @patch.object(SearchTool, \"execute\", new_callable=AsyncMock)\n    def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data):\n        _, expect_turn_array, tool_return_array = search_data\n\n        # Mock tool execution for large batch (100 requests * 2 calls each)\n        mock_execute.side_effect = [\n            (tool_return_array[0], 0.0, {\"status\": \"success\"}),\n            (tool_return_array[1], 0.0, {\"status\": \"success\"}),\n        ] * 100\n\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"search\"].retrieval_service_url = \"mock://dummy\"\n\n        base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]\n\n        req_nums = 100\n        req_list = []\n        req_turns_map = {}\n        req_turns_counter = {}\n\n        for i in range(req_nums):\n            tmp_req = deepcopy(base_req)\n            tmp_req.batch_data_id = i\n            tmp_req.request_id = i\n            req_list.append(MagicMock(wraps=tmp_req, 
spec=AsyncRolloutRequest))\n\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):\n                fut.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"dummy\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                            \"completion_tokens\": 100,\n                        },\n                    }\n                )\n            req_turns_map[i] = futures\n            req_turns_counter[i] = 0\n\n        async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs):\n            fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            return await fut\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list])\n            )\n\n        # Verify all requests completed successfully\n        assert len(output_req_list) == req_nums\n        for out_req in output_req_list:\n            assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n            assert \"search\" in out_req.metrics\n            for metric in out_req.metrics[\"search\"]:\n                assert metric[\"status\"] == \"success\"\n            assert len(out_req.messages) == 6  # user + 3 assistant + 2 tool\n            assert sum(1 for m in out_req.messages if m.role == \"tool\") == 2\n\n        assert mock_execute.await_count == 2 * req_nums\n"
  },
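  {
    "path": "verl_rl/tests/workers/rollout/notes/per_request_future_map_sketch.py",
    "content": "# NOTE (editor's sketch): illustrative file, not part of the verl test suite.\n# The batch tests above cannot rely on a single side_effect list: 100 requests are\n# gathered concurrently and each must consume its own turn sequence in order. They\n# therefore patch _handle_engine_call with a function that routes on batch_data_id\n# through a per-request map of pre-resolved futures plus a per-request counter. The\n# toy below (names hypothetical) demonstrates that routing in isolation.\nimport asyncio\n\nreq_turns_map = {}\nreq_turns_counter = {}\n\n\nasync def handle_engine_call(batch_data_id):\n    # Fetch the next pre-resolved future belonging to this specific request.\n    fut = req_turns_map[batch_data_id][req_turns_counter[batch_data_id]]\n    req_turns_counter[batch_data_id] += 1\n    return await fut\n\n\nasync def run_request(batch_data_id, turns):\n    return [await handle_engine_call(batch_data_id) for _ in range(turns)]\n\n\nasync def main():\n    n_requests, n_turns = 4, 3\n    loop = asyncio.get_running_loop()\n    for i in range(n_requests):\n        futures = []\n        for t in range(n_turns):\n            fut = loop.create_future()\n            fut.set_result(f\"req{i}-turn{t}\")\n            futures.append(fut)\n        req_turns_map[i] = futures\n        req_turns_counter[i] = 0\n    results = await asyncio.gather(*[run_request(i, n_turns) for i in range(n_requests)])\n    # Even when gathered concurrently, every request sees only its own turns, in order.\n    for i, texts in enumerate(results):\n        assert texts == [f\"req{i}-turn{t}\" for t in range(n_turns)]\n    print(results)\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },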
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# noqa\nimport asyncio\nimport time\nfrom copy import deepcopy\nfrom functools import wraps\nfrom unittest.mock import MagicMock, patch\n\nimport numpy as np\nimport pytest\nimport ray\nfrom tensordict import TensorDict\nfrom torch.testing._internal.common_distributed import MultiProcessTestCase\nfrom transformers import AutoConfig, AutoTokenizer\nfrom utils_sglang import (\n    get_rollout_config,\n    prepare_inputs,\n)\n\nfrom verl.protocol import DataProto\nfrom verl.tools.sandbox_fusion_tools import TokenBucketWorker\nfrom verl.tools.schemas import (\n    OpenAIFunctionParametersSchema,\n    OpenAIFunctionPropertySchema,\n    OpenAIFunctionSchema,\n    OpenAIFunctionToolSchema,\n)\nfrom verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\nsandbox_url = \"\"\n\n\ndef get_sandbox_fusion_messages():\n    user_prompt = {\n        \"role\": \"user\",\n        \"content\": \"\"\"\n            Solve the following problem step by step. You now have the ability to selectively \n            write executable Python code to enhance your reasoning process. \\n\\n**user question:**\\nThere \n            are 152 students at Dala High School. Assume the following: \\n- 100 students take a Math class \\n- 94 \n            students take a Science class \\n- 57 students take an English class \\n- 73 students take a Math class \n            and a Science class \\n- 24 students take a Math class and an English class \\n- 27 students take a Science \n            class and an English class \\n- 22 students take a Math class and a Science class and an English class\\n \\nHow \n            many students take neither a Math class nor a Science class nor an Eglish class?\\n\\nRemember to place the final \n            answer in the last part using the format: \\n<answer>\\n\\boxed{'The final answer goes here.'}\\n</answer>\n        \"\"\",\n    }\n    expect_turn_0_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            Okay, so I need to find out how many students at Dala High School are not taking any of the three classes: Math, \n            Science, or English. The total number of students is 152. Let me see... I remember this is a problem about sets \n            and maybe using the principle of inclusion-exclusion. Let me recall how that works.\\n\\nFirst, the inclusion-exclusion \n            principle for three sets says that the total number of students taking at least one of the classes is equal to the \n            sum of the numbers in each individual class, minus the sum of the numbers in each pair of classes, plus the number in \n            all three classes. Then, subtract that total from the overall number of students to get those not taking any of the \n            three. \\n\\nLet me write that down step by step. 
Let M be the set of students taking Math, S for Science, and E for English. \n            Then:\\n\\nTotal in at least one class = |M ∪ S ∪ E| = |M| + |S| + |E| - |M ∩ S| - |M ∩ E| - |S ∩ E| + |M ∩ S ∩ E|\\n\\nGiven the \n            numbers:\\n\\n|M| = 100\\n\\n|S| = 94\\n\\n|E| = 57\\n\\n|M ∩ S| = 73\\n\\n|M ∩ E| = 24\\n\\n|S ∩ E| = 27\\n\\n|M ∩ S ∩ E| = 22\\n\\nSo plugging \n            these into the formula:\\n\\nTotal = 100 + 94 + 57 - 73 - 24 - 27 + 22\\n\\nLet me compute that step by step using code to ensure \n            accuracy.\\n\n        \"\"\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\n                        \"code\": \"\"\"M = 100\\\\nS = 94\\\\nE = 57\\\\nM_S = 73\\\\nM_E = 24\\\\nS_E = 27\\\\nM_S_E = \n                        22\\\\n\\\\ntotal_in_any = M + S + E - M_S - M_E - S_E + M_S_E\\\\nstudents_neither = 152 - total_in_any\\\\nprint(students_neither)\"\"\",\n                    },\n                },\n            }\n        ],\n    }\n\n    tool_return_0_msg = {\n        \"role\": \"tool\",\n        \"content\": \"\"\"3\"\"\",\n    }\n\n    expect_turn_1_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            The code returns 3, indicating that 3 students take none of the classes. To verify, I'll break down each\n            category to ensure the total matches:\\n\\n- Only Math and Science: 73 - 22 = 51\\n- Only Math and English:\n            24 - 22 = 2\\n- Only Science and English: 27 - 22 = 5\\n- Only Math: 100 - 51 - 2 - 22 = 25\\n- Only Science:\n            94 - 51 - 5 - 22 = 16\\n- Only English: 57 - 2 - 5 - 22 = 28\\n\\nSumming all categories:\\n\n        \"\"\",\n        \"tool_calls\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"code_interpreter\",\n                    \"arguments\": {\n                        \"code\": \"\"\"only_M_S = 73 - 22\\\\nonly_M_E = 24 - 22\\\\nonly_S_E = 27 - 22\\\\n\\\\nonly_M = 100 - only_M_S -\n                        only_M_E - 22\\\\nonly_S = 94 - only_M_S - only_S_E - 22\\\\nonly_E = 57 - only_M_E - only_S_E - 22\\\\n\\\\ntotal_verify\n                        = only_M + only_S + only_E + only_M_S + only_M_E + only_S_E + 22\\\\nprint(total_verify)\"\"\",\n                    },\n                },\n            }\n        ],\n    }\n\n    tool_return_1_msg = {\n        \"role\": \"tool\",\n        \"content\": \"\"\"149\"\"\",\n    }\n    expect_turn_2_msg = {\n        \"role\": \"assistant\",\n        \"content\": \"\"\"\n            The verification total is 149, so students not taking any classes are 152 - 149 = 3, confirming the initial\n            result.\\n\\n<answer>\\n\\\\boxed{3}\\n</answer>\n        \"\"\",\n    }\n\n    user_prompts = [user_prompt]\n    expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg]\n    tool_return_array = [tool_return_0_msg, tool_return_1_msg]\n\n    return user_prompts, expect_turn_array, tool_return_array\n\n\ndef skip_if_valid_sandbox(url):\n    def decorator(func):\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            if url == \"\" or url is None:\n                pytest.skip(\"No valid sandbox url provided\")\n            # run the wrapped test only when a sandbox url is configured\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\nclass TestRolloutWithTools:\n    @pytest.fixture\n    def qwen_tokenizer(self):\n        local_model_path 
= \"Qwen/Qwen2.5-0.5B\"\n        tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\")\n        tokenizer.pad_token = tokenizer.eos_token\n        return tokenizer\n\n    # we only need this for tokenizer\n    @pytest.fixture\n    def qwen_model_config(self):\n        local_model_path = \"Qwen/Qwen2.5-0.5B\"\n        config = AutoConfig.from_pretrained(local_model_path)\n        return config\n\n    @pytest.fixture\n    def sandbox_fusion_data(self, qwen_tokenizer):\n        user_prompt, expect_turn_array, tool_return_array = get_sandbox_fusion_messages()\n        prompts = [[message] for message in user_prompt]\n        preencode_turn_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)\n            for turn in expect_turn_array\n        ]\n        preencode_tool_return_array = [\n            qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True)\n            for turn in tool_return_array\n        ]\n        return prompts, preencode_turn_array, preencode_tool_return_array\n\n    @pytest.fixture\n    def sandbox_fusion_rollout_config(self):\n        max_prompt_length = 1024\n        max_response_length = 1024\n        dtype = \"bfloat16\"\n        tensor_parallel_size = 1\n        tool_path = \"./resource/tool_configs/sandbox_fusion_tool_config\"\n        rollout_config = get_rollout_config(\n            max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path\n        )\n        return rollout_config\n\n    @pytest.fixture\n    def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer):\n        preencode_prompts, _, _ = sandbox_fusion_data\n        prompts = [\n            qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n            for message in preencode_prompts\n        ]\n        input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000)\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        messages = np.asarray(preencode_prompts)\n        tools_kwargs = np.array(\n            [\n                {\n                    \"code_interpreter\": {\n                        \"create_kwargs\": {\"ground_truth\": \"test-solution-str\"},\n                    },\n                }\n            ],\n            dtype=object,\n        )\n        index = np.array([0], dtype=object)\n        prompts = DataProto(\n            batch=prompt_dict, non_tensor_batch={\"raw_prompt\": messages, \"tools_kwargs\": tools_kwargs, \"index\": index}\n        )\n        return prompts\n\n    @pytest.fixture\n    def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config):\n        \"\"\"Mock the rollout instance\"\"\"\n        with patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None), patch.object(\n            SGLangRollout, \"_init_inference_engine\", return_value=None\n        ), patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None):\n            rollout = SGLangRollout(\n                actor_module=\"\",\n                config=sandbox_fusion_rollout_config,\n                processing_class=qwen_tokenizer,\n                model_hf_config=qwen_model_config,\n            )\n            # set default 
sampling_params\n            rollout.sampling_params = {\n                \"n\": 1,\n                \"max_new_tokens\": sandbox_fusion_rollout_config.response_length,\n                \"presence_penalty\": 0.0,\n                \"frequency_penalty\": 0.0,\n                \"repetition_penalty\": 1.0,\n            }\n            return rollout\n\n    def test_tools_registration(self, mock_rollout):\n        \"\"\"Test tool registration functionality\"\"\"\n        assert len(mock_rollout._tool_schemas) == 1\n        assert \"code_interpreter\" in mock_rollout._tool_map.keys()\n        from verl.tools.sandbox_fusion_tools import SandboxFusionTool\n\n        assert isinstance(mock_rollout._tool_map[\"code_interpreter\"], SandboxFusionTool)\n        assert mock_rollout._tool_call_parser_type == \"qwen25\"\n\n    def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto):\n        \"\"\"Test request creation functionality\"\"\"\n        req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)\n        assert len(req_list) == 1\n        assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING\n        assert len(req_list[0].tool_schemas) == 1\n        print(type(req_list[0].tool_schemas[0]))\n        assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(\n            type=\"function\",\n            function=OpenAIFunctionSchema(\n                name=\"code_interpreter\",\n                description=\"A tool for executing code.\",\n                parameters=OpenAIFunctionParametersSchema(\n                    type=\"object\",\n                    properties={\n                        \"code\": OpenAIFunctionPropertySchema(\n                            type=\"string\",\n                            description=\"The code to execute.\",\n                            enum=None,\n                        )\n                    },\n                    required=[\"code\"],\n                ),\n                strict=False,\n            ),\n        )\n\n    def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test over-size response truncation case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 1\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        # here we mock meta_info with finish_reason 'length' to indicate the response was truncated\n        
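# with finish_reason 'length' the turn is treated as final: no tool call is parsed,\n        # so only the prompt and a single assistant message remain (asserted below)\n        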
mock_rollout._handle_engine_call = MagicMock()\n        future = asyncio.Future()\n        future.set_result(\n            {\n                \"text\": expect_turn_array[0],\n                \"meta_info\": {\n                    \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                    \"finish_reason\": {\"type\": \"length\", \"length\": 1024},\n                    \"prompt_tokens\": 132,\n                    \"completion_tokens\": 100,\n                    \"cached_tokens\": 0,\n                    \"e2e_latency\": 9.9304039478302,\n                },\n            }\n        )\n        mock_rollout._handle_engine_call.return_value = future\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        assert output_req.reward_scores.get(\"code_interpreter\") == []\n        # we should only have two messages: one for the prompt, one for the response.\n        assert len(output_req.messages) == 2\n        assert output_req.messages[1] == Message(\n            role=\"assistant\",\n            content=expect_turn_array[0],\n            tool_calls=None,\n        )\n\n    @skip_if_valid_sandbox(sandbox_url)\n    def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test basic tool call case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"code_interpreter\"].sandbox_fusion_url = sandbox_url\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req = MagicMock(wraps=req, spec=AsyncRolloutRequest)\n        req.finalize = MagicMock()\n        req_list = [req]\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        # mock one engine response per expected turn: finish_reason is 'tool_calls' until the final turn, which returns 'stop'\n        
mock_rollout._handle_engine_call = MagicMock()\n        futures = [asyncio.Future() for _ in expect_turn_array]\n        for idx, (future, turn) in enumerate(zip(futures, expect_turn_array)):\n            future.set_result(\n                {\n                    \"text\": turn,\n                    \"meta_info\": {\n                        \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                        \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                        \"prompt_tokens\": len(turn),\n                        \"completion_tokens\": 100,\n                        \"cached_tokens\": 0,\n                        \"e2e_latency\": 9.9304039478302,\n                    },\n                }\n            )\n            if idx < len(expect_turn_array) - 1:\n                assert mock_rollout._function_call_parser.has_tool_call(turn)\n                assert mock_rollout._function_call_parser.parse_non_stream(turn)\n\n        mock_rollout._handle_engine_call.side_effect = futures\n        mock_rollout._tp_rank = 0\n        loop = asyncio.get_event_loop()\n        output_req_list = loop.run_until_complete(\n            asyncio.gather(\n                *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n            )\n        )\n        assert len(output_req_list) == 1\n        output_req = output_req_list[0]\n        assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n        # here we verify that the code sandbox executed the code snippets correctly\n        assert output_req.metrics == {\"code_interpreter\": [\"3\", \"149\"]}\n        assert mock_rollout._handle_engine_call.call_count == 3\n        assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n        code_counter = 0\n        for msg in output_req.messages:\n            if msg.role == \"tool\":\n                assert msg.content == tool_return_array[code_counter]\n                code_counter += 1\n        assert code_counter == 2\n\n    @skip_if_valid_sandbox(sandbox_url)\n    def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data):\n        \"\"\"Test batch tool call case\"\"\"\n        mock_rollout.config.multi_turn.max_assistant_turns = 10\n        mock_rollout._tool_map[\"code_interpreter\"].sandbox_fusion_url = sandbox_url\n        req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]\n        req_nums = 100\n        req_list = []\n        req_turns_counter = {}\n        # this maps batch_data_id -> list of per-turn futures\n        req_turns_map = {}\n        _, expect_turn_array, tool_return_array = sandbox_fusion_data\n        for i in range(req_nums):\n            _temp_req = deepcopy(req)\n            _temp_req.batch_data_id = i\n            _temp_req.request_id = i\n            req_list.append(MagicMock(wraps=_temp_req, spec=AsyncRolloutRequest))\n            futures = [asyncio.Future() for _ in expect_turn_array]\n            for idx, (future, turn) in enumerate(zip(futures, expect_turn_array)):\n                future.set_result(\n                    {\n                        \"text\": turn,\n                        \"meta_info\": {\n                            \"id\": \"d1188d81cba840359df5b352b344bc8e\",\n                            \"finish_reason\": {\"type\": \"tool_calls\" if idx < len(expect_turn_array) - 1 else \"stop\"},\n                            \"prompt_tokens\": len(turn),\n                            \"completion_tokens\": 100,\n                            \"cached_tokens\": 0,\n                            \"e2e_latency\": 9.9304039478302,\n                        },\n                    }\n                )\n                if idx < len(expect_turn_array) - 1:\n                    assert mock_rollout._function_call_parser.has_tool_call(turn)\n                    assert mock_rollout._function_call_parser.parse_non_stream(turn)\n            req_turns_map[_temp_req.batch_data_id] = futures\n            req_turns_counter[_temp_req.batch_data_id] = 0\n\n        
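# each request id gets its own ordered queue of mocked engine replies; the patched\n        # handler below pops the next future for a request on every engine call\n        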
\"completion_tokens\": 100,\n                            \"cached_tokens\": 0,\n                            \"e2e_latency\": 9.9304039478302,\n                        },\n                    }\n                )\n                if idx < len(expect_turn_array) - 1:\n                    assert mock_rollout._function_call_parser.has_tool_call(turn)\n                    assert mock_rollout._function_call_parser.parse_non_stream(turn)\n            req_turns_map[_temp_req.batch_data_id] = futures\n            req_turns_counter[_temp_req.batch_data_id] = 0\n\n        async def hacked_handle_engine_call(\n            self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs\n        ):\n            result = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]]\n            req_turns_counter[_req.batch_data_id] += 1\n            re = await result\n            return re\n\n        with patch.object(SGLangRollout, \"_handle_engine_call\", new=hacked_handle_engine_call):\n            mock_rollout._tp_rank = 0\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(\n                    *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list],\n                )\n            )\n            assert len(output_req_list) == req_nums\n            # FIGUER out how to count this\n            # assert rollout._handle_engine_call.call_count == 3 * req_nums\n            for output_req in output_req_list:\n                assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED\n                # here we verify whether the code sandbox is executed correctly\n                assert output_req.metrics == {\"code_interpreter\": [\"3\", \"149\"]}\n                assert len(output_req.messages) == 6  # user + 3*assistant + 2*tool_call\n                code_counter = 0\n                for msg in output_req.messages:\n                    if msg.role == \"tool\":\n                        code_counter += 1\n                assert code_counter == 2\n\n    def test_sampling_params_functionality(self, mock_rollout):\n        \"\"\"Test sampling_params functionality\"\"\"\n        # test basic copy functionality\n        copied_params = mock_rollout.sampling_params.copy()\n        assert copied_params == mock_rollout.sampling_params\n        assert copied_params is not mock_rollout.sampling_params\n\n        # test parameter update\n        copied_params.update({\"temperature\": 0.8, \"top_p\": 0.9})\n        assert copied_params[\"temperature\"] == 0.8\n        assert copied_params[\"top_p\"] == 0.9\n\n        # ensure original parameters are not modified\n        assert \"temperature\" not in mock_rollout.sampling_params\n        assert \"top_p\" not in mock_rollout.sampling_params\n\n\nclass RayMultiProcessTestCase(MultiProcessTestCase):\n    def setUp(self):\n        super().setUp()\n        ray.init(ignore_reinit_error=True)\n        print(\"init_single cluster\")\n        self._spawn_processes()\n\n    def tearDown(self):\n        print(\"tearDown_single cluster\")\n        ray.shutdown()\n\n\n@ray.remote\nclass TestActor:\n    def __init__(self, rank, world_size):\n        self._world_size = world_size\n        self._rank = rank\n        self.rank_list = []\n        self.time_list = []\n\n    def record_rank(self, rank):\n        self.rank_list.append(rank)\n\n    def get_rank(self):\n        return self._rank\n\n    def ping(self):\n        return True\n\n    def 
class RayMultiProcessTestCase(MultiProcessTestCase):\n    def setUp(self):\n        super().setUp()\n        ray.init(ignore_reinit_error=True)\n        print(\"init_single cluster\")\n        self._spawn_processes()\n\n    def tearDown(self):\n        print(\"tearDown_single cluster\")\n        ray.shutdown()\n\n\n@ray.remote\nclass TestActor:\n    def __init__(self, rank, world_size):\n        self._world_size = world_size\n        self._rank = rank\n        self.rank_list = []\n        self.time_list = []\n\n    def record_rank(self, rank):\n        self.rank_list.append(rank)\n\n    def get_rank(self):\n        return self._rank\n\n    def ping(self):\n        return True\n\n    def record_execution_time(self, time):\n        self.time_list.append(time)\n\n    def get_time(self, timeout):\n        import time\n\n        now = time.time()\n        while time.time() - now < timeout:\n            # for start and end time\n            if len(self.time_list) == self._world_size * 2:\n                self.time_list.sort()\n                return self.time_list[-1] - self.time_list[0]\n            else:\n                time.sleep(1)\n                continue\n        return False\n\n    def verify_rank(self):\n        import time\n\n        now = time.time()\n        while time.time() - now < 10:\n            if len(self.rank_list) == self._world_size:\n                print(self.rank_list)\n                self.rank_list.sort()\n                for i in range(self._world_size):\n                    if self.rank_list[i] != i:\n                        return False\n                return True\n            else:\n                time.sleep(1)\n                continue\n        return False\n\n\nclass TestRayGlobalActorCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        # 2 ranks are enough for this test (a real DP setup might use 8)\n        return 2\n\n    def test_basic_multi_process_init(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        handle = TestActor.remote(self.rank, self.world_size)\n        re = ray.get(handle.get_rank.remote())\n        assert re == self.rank, f\"rank not match: {re} != {self.rank}\"\n\n    # def test_global_actor(self):\n    #     ray.init(\"auto\",namespace=\"test\",ignore_reinit_error=True)\n    #     handle = TestActor.options(get_if_exists=True,name=\"test-actor\").remote(self.rank,self.world_size)\n    #     handle.record_rank.remote(self.rank)\n    #     # since test actor's concurrency is 1, we need to wait for all processes to finish\n    #     time.sleep(5)\n    #     assert ray.get(handle.ping.remote()) == True # make sure actor handle is valid\n    #     if self.rank == 0:\n    #         assert ray.get(handle.verify_rank.remote()) == True\n    #     else:\n    #         # get_actor use weak_ref, so we need to make sure the actor is not garbage collected\n    #         time.sleep(10)\n\n\nclass TestSingleNodeRateLimiterCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        return 1\n\n    def test_rate_limiter(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=3)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode\n        )\n        center = TestActor.options(get_if_exists=True, name=\"test-actor\").remote(self.rank, self.world_size)\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            import time\n\n            time.sleep(3)\n            return i\n\n        start = time.time()\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(6)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        end = time.time()\n        duration = end - start\n        center.record_execution_time.remote(start)\n        center.record_execution_time.remote(end)\n        print(f\"Total time: {duration:.2f} seconds for rank: {self.rank}\")\n\n        assert results == list(range(6))\n        
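# illustrative timing, assuming the limiter admits at most 3 concurrent executions:\n        # t=0s tasks 0-2 acquire tokens; t=3s they finish and tasks 3-5 start; t=6s done\n        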
# we have 6 tasks with a rate limit of 3, therefore we need at least 2 rounds: 2 * 3s = 6 seconds\n        assert duration > 6\n        assert duration < 10\n\n    def test_rotten_execution(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode\n        )\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            if i == 10:\n                raise Exception(\"test\")\n            else:\n                return i\n\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(20)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        expect_result = [None] + list(range(10)) + list(range(11, 20))\n        sorted_data = sorted(results, key=lambda x: (x is not None, x))\n        assert sorted_data == expect_result, f\"results: {results}, expect_result: {expect_result}\"\n        rate_limiter = TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote()\n        rate = ray.get(rate_limiter.get_current_count.remote())\n        assert rate == 0, f\"rate: {rate}\"\n\n\nclass TestMultiNodeRateLimiterCase(RayMultiProcessTestCase):\n    @property\n    def world_size(self) -> int:\n        return 2\n\n    def test_rate_limiter(self):\n        ray.init(\"auto\", namespace=\"test\", ignore_reinit_error=True)\n        from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool\n\n        # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6)\n        exec_worker = init_execution_pool(\n            num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode\n        )\n        center = TestActor.options(get_if_exists=True, name=\"test-actor\").remote(self.rank, self.world_size)\n        ray.get(exec_worker.ping.remote())\n\n        def fn(i):\n            import time\n\n            time.sleep(2)\n            return i\n\n        start = time.time()\n        tasks = [exec_worker.execute.remote(fn, i) for i in range(6)]\n        loop = asyncio.get_event_loop()\n        results = loop.run_until_complete(asyncio.gather(*tasks))\n        end = time.time()\n        duration = end - start\n        center.record_execution_time.remote(start)\n        center.record_execution_time.remote(end)\n        print(f\"Total time: {duration:.2f} seconds for rank: {self.rank}\")\n        assert results == list(range(6))\n        time.sleep(5)\n        if self.rank == 0:\n            total_cost = ray.get(center.get_time.remote(10))\n            print(f\"for total cost: {total_cost}\")\n            # we have 6 tasks per node * 2 nodes = 12 tasks, each taking 2 seconds;\n            # with a rate limit of 6, we need at least 2 rounds: (12 / 6) * 2s = 4 seconds\n            assert total_cost > 4, total_cost\n        else:\n            time.sleep(10)\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_rollout_w_interaction.py\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    get_rollout_config,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\nfrom verl import DataProto\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\nfrom verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager\n\n\ndef test_async_sglang_rollout_w_interaction():\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group()\n    clean_torchelastic_env()\n\n    max_prompt_length = 32\n    max_response_length = 16\n    dtype = \"bfloat16\"\n    tensor_parallel_size = 2\n    local_model_path = \"Qwen/Qwen2.5-0.5B\"\n\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": prompt, \"tool_calls\": None}]\n        for prompt in [\n            \"Who won the Champions League in 2019?\",\n            \"The founder of Apple is\",\n            \"What's the best way to learn python?\",\n        ]\n    ]\n    interaction_kwargs = [\n        {\"name\": \"gsm8k\", \"query\": \"Who won the Champions League in 2019?\", \"ground_truth\": \"Real Madrid\"},\n        {\"name\": \"gsm8k\", \"query\": \"The founder of Apple is\", \"ground_truth\": \"Steve Jobs\"},\n        {\"name\": \"gsm8k\", \"query\": \"What's the best way to learn python?\", \"ground_truth\": \"Learn python from scratch\"},\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n        for message in preencode_prompts\n    ]\n    input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    fsdp_device_mesh = init_device_mesh(\"cuda\", mesh_shape=(tensor_parallel_size,), mesh_dim_names=(\"fsdp\",))\n    inference_device_mesh_cpu = init_device_mesh(\n        \"cpu\", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=(\"dp\", \"infer_tp\", \"pp\")\n    )\n\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        device_id=fsdp_device_mesh[\"fsdp\"].get_local_rank(),\n        mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        
device_mesh=fsdp_device_mesh,\n    )\n\n    # Create a temporary interaction config file for testing\n    import tempfile\n\n    from omegaconf import OmegaConf\n\n    interaction_config = {\n        \"interaction\": [\n            {\"name\": \"gsm8k\", \"class_name\": \"verl.interactions.gsm8k_interaction.Gsm8kInteraction\", \"config\": {}}\n        ]\n    }\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n        OmegaConf.save(interaction_config, f.name)\n        interaction_config_path = f.name\n\n    rollout_config = get_rollout_config(\n        max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path\n    )\n    rollout = SGLangRollout(\n        actor_module=local_model_path,\n        config=rollout_config,\n        processing_class=tokenizer,\n        model_hf_config=actor_model.config,\n    )\n\n    rollout_sharding_manager = FSDPSGLangShardingManager(\n        module=fsdp_model,\n        inference_engine=rollout._engine,\n        model_config=actor_model.config,\n        rollout_config=rollout_config,\n        full_params=True,\n        device_mesh=inference_device_mesh_cpu,\n    )\n\n    with rollout_sharding_manager:\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        print(f\"preprocessed {input_ids.shape=}\")\n\n        messages = np.asarray(preencode_prompts)\n        prompts = DataProto(\n            batch=prompt_dict,\n            non_tensor_batch={\"raw_prompt\": messages, \"interaction_kwargs\": np.asarray(interaction_kwargs)},\n        )\n\n        prompts.meta_info.update(\n            {\n                \"eos_token_id\": tokenizer.eos_token_id,\n                \"pad_token_id\": tokenizer.pad_token_id,\n            }\n        )\n\n        prompts = rollout_sharding_manager.preprocess_data(prompts)\n        # log_gpu_memory_usage(\"Before generating sequences\", logger=None)\n        output = rollout.generate_sequences(prompts=prompts)\n        print(f\"generated {output.batch['responses'].shape=}\")\n        # log_gpu_memory_usage(\"After generating sequences\", logger=None)\n        output = rollout_sharding_manager.postprocess_data(output)\n        print(f\"postprocessed {output.batch['responses'].shape=}\")\n        sglang_output = output.to(\"cpu\")\n\n    sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch[\"responses\"])\n\n    print(f\"hf response: {hf_response_tokens}\")\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens)\n    print(\"SGLang w interaction Test Passed!\")\n\n    # Clean up temporary config file\n    import os\n\n    os.unlink(interaction_config_path)\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_async_sglang_rollout_w_interaction()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_async_rollout_w_tools.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_rollout_w_tools.py\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    get_rollout_config,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\nfrom verl import DataProto\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\nfrom verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager\n\n\ndef test_async_sglang_rollout_w_tool():\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group()\n    clean_torchelastic_env()\n\n    max_prompt_length = 32\n    max_response_length = 16\n    dtype = \"bfloat16\"\n    tensor_parallel_size = 2\n    local_model_path = \"Qwen/Qwen2.5-0.5B\"\n\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\n        [{\"role\": \"user\", \"content\": prompt, \"tool_calls\": None}]\n        for prompt in [\n            \"Who won the Champions League in 2019?\",\n            \"The founder of Apple is\",\n            \"What's the best way to learn python?\",\n        ]\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n        for message in preencode_prompts\n    ]\n    input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    fsdp_device_mesh = init_device_mesh(\"cuda\", mesh_shape=(tensor_parallel_size,), mesh_dim_names=(\"fsdp\",))\n    inference_device_mesh_cpu = init_device_mesh(\n        \"cpu\", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=(\"dp\", \"infer_tp\", \"pp\")\n    )\n\n    fsdp_model = FSDP(\n        actor_model,\n        use_orig_params=True,\n        device_id=fsdp_device_mesh[\"fsdp\"].get_local_rank(),\n        mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)),\n        sharding_strategy=ShardingStrategy.FULL_SHARD,\n        device_mesh=fsdp_device_mesh,\n    )\n\n    rollout_config = get_rollout_config(\n        max_response_length,\n        max_prompt_length,\n        dtype,\n        tensor_parallel_size,\n        \"./resource/tool_configs/sandbox_fusion_tool_config\",\n    )\n    rollout = SGLangRollout(\n        actor_module=local_model_path,\n        config=rollout_config,\n        processing_class=tokenizer,\n        
model_hf_config=actor_model.config,\n    )\n\n    rollout_sharding_manager = FSDPSGLangShardingManager(\n        module=fsdp_model,\n        inference_engine=rollout._engine,\n        model_config=actor_model.config,\n        rollout_config=rollout_config,\n        full_params=True,\n        device_mesh=inference_device_mesh_cpu,\n    )\n\n    with rollout_sharding_manager:\n        prompt_dict = TensorDict(\n            {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=input_ids.shape[0],\n        )\n        print(f\"preprocessed {input_ids.shape=}\")\n\n        messages = np.asarray(preencode_prompts)\n        prompts = DataProto(\n            batch=prompt_dict,\n            non_tensor_batch={\n                \"raw_prompt\": messages,\n                \"tools_kwargs\": np.array([{}] * input_ids.shape[0], dtype=object),\n            },\n        )\n\n        prompts.meta_info.update(\n            {\n                \"eos_token_id\": tokenizer.eos_token_id,\n                \"pad_token_id\": tokenizer.pad_token_id,\n            }\n        )\n\n        prompts = rollout_sharding_manager.preprocess_data(prompts)\n        # log_gpu_memory_usage(\"Before generating sequences\", logger=None)\n        output = rollout.generate_sequences(prompts=prompts)\n        print(f\"generated {output.batch['responses'].shape=}\")\n        # log_gpu_memory_usage(\"After generating sequences\", logger=None)\n        output = rollout_sharding_manager.postprocess_data(output)\n        print(f\"postprocessed {output.batch['responses'].shape=}\")\n        sglang_output = output.to(\"cpu\")\n\n    sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch[\"responses\"])\n\n    print(f\"hf response: {hf_response_tokens}\")\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens)\n    print(\"SGLang w tool Test Passed!\")\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_async_sglang_rollout_w_tool()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_multi_interaction.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"\nTest for multi-interaction support in SGLangRollout.\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_multi_interaction.py\n\"\"\"\n\nimport os\nimport tempfile\nfrom unittest.mock import MagicMock, patch\n\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig, OmegaConf\nfrom transformers import AutoTokenizer\n\nfrom verl.interactions.base import BaseInteraction\nfrom verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout\n\n\nclass MockInteraction(BaseInteraction):\n    \"\"\"Mock interaction for testing.\"\"\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.started_instances = set()\n\n    async def start_interaction(self, instance_id=None, **kwargs):\n        if instance_id is None:\n            instance_id = \"mock_instance\"\n        self.started_instances.add(instance_id)\n        return instance_id\n\n    async def generate_response(self, instance_id, messages, **kwargs):\n        return False, f\"Mock response from {self.name}\", 1.0, {}\n\n\ndef create_mock_config_with_multi_interactions():\n    \"\"\"Create a mock configuration with multiple interactions.\"\"\"\n    # Create temporary interaction config file\n    interaction_config = {\n        \"interaction\": [\n            {\n                \"name\": \"mock_agent1\",\n                \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                \"config\": {\"param1\": \"value1\"},\n            },\n            {\n                \"name\": \"mock_agent2\",\n                \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                \"config\": {\"param2\": \"value2\"},\n            },\n        ]\n    }\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n        OmegaConf.save(interaction_config, f.name)\n        interaction_config_path = f.name\n\n    # Create mock SGLangRollout config\n    config = DictConfig(\n        {\n            \"multi_turn\": {\n                \"interaction_config_path\": interaction_config_path,\n                \"tool_config_path\": None,\n                \"enable\": True,\n                \"max_assistant_turns\": 5,\n                \"max_user_turns\": 3,\n                \"use_inference_chat_template\": True,\n                \"tokenization_sanity_check_mode\": \"off\",\n            },\n            \"prompt_length\": 32,\n            \"response_length\": 16,\n            \"max_model_len\": 512,\n            \"dtype\": \"bfloat16\",\n            \"gpu_memory_utilization\": 0.8,\n            \"load_format\": \"dummy\",\n            \"enforce_eager\": True,\n            \"free_cache_engine\": False,\n            \"calculate_log_probs\": False,\n            
\"tensor_model_parallel_size\": 1,\n            \"n\": 1,\n            \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n        }\n    )\n\n    return config, interaction_config_path\n\n\ndef setup_distributed():\n    \"\"\"Initialize distributed environment if not already initialized.\"\"\"\n    if not dist.is_initialized():\n        dist.init_process_group(backend=\"nccl\" if torch.cuda.is_available() else \"gloo\")\n\n\nclass TestSGLangMultiInteraction:\n    def test_initialize_multiple_interactions(self):\n        \"\"\"Test that SGLangRollout can initialize multiple interactions.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            # Mock SGLang engine and initialization methods like the reference test\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                # Create a real tokenizer like the reference test\n                tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\", padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                # Mock model config\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                # since this is a mock, we can set any rope scaling config\n                # to test the rope_scaling logic at the same time of this test\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                # Create SGLangRollout instance\n                rollout = SGLangRollout(\n                    actor_module=\"mock_model\",\n                    config=config,\n                    processing_class=tokenizer,\n                    model_hf_config=mock_model_config,\n                    port=None,\n                    trust_remote_code=False,\n                    device_mesh=None,\n                )\n\n                # Check that interactions were initialized\n                assert len(rollout.interaction_map) == 2\n                assert \"mock_agent1\" in rollout.interaction_map\n                assert \"mock_agent2\" in rollout.interaction_map\n\n                # Use class name comparison instead of isinstance for multi-process compatibility\n                assert rollout.interaction_map[\"mock_agent1\"].__class__.__name__ == \"MockInteraction\"\n                assert rollout.interaction_map[\"mock_agent2\"].__class__.__name__ == \"MockInteraction\"\n\n                # Also check that they are instances of BaseInteraction (which should work across processes)\n                assert isinstance(rollout.interaction_map[\"mock_agent1\"], BaseInteraction)\n                assert isinstance(rollout.interaction_map[\"mock_agent2\"], BaseInteraction)\n\n                # Check that names were set correctly\n                assert rollout.interaction_map[\"mock_agent1\"].name == \"mock_agent1\"\n                assert rollout.interaction_map[\"mock_agent2\"].name == \"mock_agent2\"\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_interaction_selection_by_name(self):\n        
\"\"\"Test that interactions are selected by name from interaction_kwargs.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\", padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout = SGLangRollout(\n                    actor_module=\"mock_model\",\n                    config=config,\n                    processing_class=tokenizer,\n                    model_hf_config=mock_model_config,\n                    port=None,\n                    trust_remote_code=False,\n                    device_mesh=None,\n                )\n\n                # Test interaction selection logic\n                from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message\n\n                # Create a mock request with specific interaction name\n                req = AsyncRolloutRequest(\n                    request_id=\"test_req\",\n                    state=AsyncRolloutRequestStateEnum.INTERACTING,\n                    messages=[Message(role=\"user\", content=\"test message\")],\n                    interaction_kwargs={\"name\": \"mock_agent2\", \"test_param\": \"value\"},\n                    input_ids=None,\n                    prompt_ids=None,\n                    response_ids=None,\n                    attention_mask=None,\n                    prompt_attention_mask=None,\n                    response_attention_mask=None,\n                    position_ids=None,\n                    prompt_position_ids=None,\n                    response_position_ids=None,\n                    loss_mask=None,\n                    prompt_loss_mask=None,\n                    response_loss_mask=None,\n                    reward_scores={},\n                    max_prompt_len=32,\n                    max_response_len=16,\n                    max_model_len=512,\n                    use_inference_chat_template=True,\n                    tokenization_sanity_check_mode=\"disable\",\n                    processing_class=tokenizer,\n                )\n\n                # Test that the correct interaction is selected\n                interaction_name = req.interaction_kwargs.get(\"name\", \"gsm8k\")\n                assert interaction_name == \"mock_agent2\"\n                assert interaction_name in rollout.interaction_map\n\n                selected_interaction = rollout.interaction_map[interaction_name]\n                assert selected_interaction.name == \"mock_agent2\"\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_fallback_to_default_interaction(self):\n        \"\"\"Test fallback to default interaction when name is not specified.\"\"\"\n        setup_distributed()\n        # Create config with gsm8k interaction\n    
    interaction_config = {\n            \"interaction\": [\n                {\n                    \"name\": \"gsm8k\",\n                    \"class_name\": \"tests.workers.rollout.test_sglang_multi_interaction.MockInteraction\",\n                    \"config\": {},\n                }\n            ]\n        }\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".yaml\", delete=False) as f:\n            OmegaConf.save(interaction_config, f.name)\n            interaction_config_path = f.name\n\n        config = DictConfig(\n            {\n                \"multi_turn\": {\n                    \"interaction_config_path\": interaction_config_path,\n                    \"tool_config_path\": None,\n                    \"enable\": True,\n                    \"max_assistant_turns\": 5,\n                    \"max_user_turns\": 3,\n                    \"use_inference_chat_template\": True,\n                    \"tokenization_sanity_check_mode\": \"disable\",\n                },\n                \"prompt_length\": 32,\n                \"response_length\": 16,\n                \"max_model_len\": 512,\n                \"dtype\": \"bfloat16\",\n                \"gpu_memory_utilization\": 0.8,\n                \"load_format\": \"dummy\",\n                \"enforce_eager\": True,\n                \"free_cache_engine\": False,\n                \"calculate_log_probs\": False,\n                \"tensor_model_parallel_size\": 1,\n                \"n\": 1,\n                \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n            }\n        )\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\", padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout = SGLangRollout(\n                    actor_module=\"mock_model\",\n                    config=config,\n                    processing_class=tokenizer,\n                    model_hf_config=mock_model_config,\n                    port=None,\n                    trust_remote_code=False,\n                    device_mesh=None,\n                )\n\n                # Test that default interaction name works\n                interaction_kwargs_without_name = {\"test_param\": \"value\"}\n                default_name = interaction_kwargs_without_name.get(\"name\", \"gsm8k\")\n                assert default_name == \"gsm8k\"\n                assert default_name in rollout.interaction_map\n\n        finally:\n            os.unlink(interaction_config_path)\n\n    def test_error_on_missing_interaction(self):\n        \"\"\"Test that error is raised when requested interaction is not found.\"\"\"\n        setup_distributed()\n        config, temp_config_path = create_mock_config_with_multi_interactions()\n\n        try:\n            with (\n                patch.object(SGLangRollout, \"_init_distributed_env\", 
return_value=None),\n                patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n                patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n            ):\n                tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\", padding_side=\"left\")\n                tokenizer.pad_token = tokenizer.eos_token\n\n                mock_model_config = MagicMock()\n                mock_model_config.max_position_embeddings = 2048\n                mock_model_config.rope_scaling = {\n                    \"factor\": 4.0,\n                    \"original_max_position_embeddings\": 32768,\n                    \"type\": \"yarn\",\n                }\n\n                rollout = SGLangRollout(\n                    actor_module=\"mock_model\",\n                    config=config,\n                    processing_class=tokenizer,\n                    model_hf_config=mock_model_config,\n                    port=None,\n                    trust_remote_code=False,\n                    device_mesh=None,\n                )\n\n                # Test error when requesting non-existent interaction\n                non_existent_name = \"non_existent_interaction\"\n                assert non_existent_name not in rollout.interaction_map\n\n                # This should raise ValueError in actual usage\n                available_interactions = list(rollout.interaction_map.keys())\n                assert \"mock_agent1\" in available_interactions\n                assert \"mock_agent2\" in available_interactions\n                assert non_existent_name not in available_interactions\n\n        finally:\n            os.unlink(temp_config_path)\n\n    def test_backward_compatibility_no_interaction_config(self):\n        \"\"\"Test backward compatibility when no interaction config is provided.\"\"\"\n        setup_distributed()\n        # Create config without interaction config\n        config = DictConfig(\n            {\n                \"multi_turn\": {\n                    \"interaction_config_path\": None,\n                    \"tool_config_path\": None,\n                    \"enable\": True,\n                    \"max_assistant_turns\": 5,\n                    \"max_user_turns\": 3,\n                    \"use_inference_chat_template\": True,\n                    \"tokenization_sanity_check_mode\": \"disable\",\n                },\n                \"prompt_length\": 32,\n                \"response_length\": 16,\n                \"max_model_len\": 512,\n                \"dtype\": \"bfloat16\",\n                \"gpu_memory_utilization\": 0.8,\n                \"load_format\": \"dummy\",\n                \"enforce_eager\": True,\n                \"free_cache_engine\": False,\n                \"calculate_log_probs\": False,\n                \"tensor_model_parallel_size\": 1,\n                \"n\": 1,\n                \"val_kwargs\": {\"top_k\": 1, \"top_p\": 1.0, \"temperature\": 0.0},\n            }\n        )\n\n        with (\n            patch.object(SGLangRollout, \"_init_distributed_env\", return_value=None),\n            patch.object(SGLangRollout, \"_init_inference_engine\", return_value=None),\n            patch.object(SGLangRollout, \"_init_sampling_params\", return_value=None),\n        ):\n            tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\", padding_side=\"left\")\n            tokenizer.pad_token = tokenizer.eos_token\n\n            mock_model_config = MagicMock()\n            
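# arbitrary rope_scaling values are fine here too, since the inference engine init is mocked\n            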
mock_model_config.max_position_embeddings = 2048\n            mock_model_config.rope_scaling = {\n                \"factor\": 4.0,\n                \"original_max_position_embeddings\": 32768,\n                \"type\": \"yarn\",\n            }\n\n            rollout = SGLangRollout(\n                actor_module=\"mock_model\",\n                config=config,\n                processing_class=tokenizer,\n                model_hf_config=mock_model_config,\n                port=None,\n                trust_remote_code=False,\n                device_mesh=None,\n            )\n\n            # Check that no interactions were initialized\n            assert len(rollout.interaction_map) == 0\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_rollout_sharding_manager.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\n_TENSOR_1MB = torch.zeros(512, 512)\n_BYTES_1MB = 1 << 20\n\n\n@pytest.mark.parametrize(\n    \"named_tensors, bucket_size_mb, gt_groups\",\n    [\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            0.5 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            1 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            1.5 * _BYTES_1MB,\n            [[\"a\"], [\"b\"]],\n        ),\n        (\n            [(\"a\", _TENSOR_1MB), (\"b\", _TENSOR_1MB)],\n            2 * _BYTES_1MB,\n            [[\"a\", \"b\"]],\n        ),\n    ],\n)\ndef test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]):\n    named_tensors_iter = iter(named_tensors)\n    groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb))\n    assert len(groups) == len(gt_groups)\n    for group, gt_group in zip(groups, gt_groups, strict=True):\n        assert len(group) == len(gt_group)\n        for (name, _), (gt_name) in zip(group, gt_group, strict=True):\n            assert name == gt_name\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/test_sglang_spmd.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nusage: torchrun --standalone --nnodes=1 \\\n    --nproc_per_node=2 $(which pytest) \\\n    -s test_sglang_async_spmd.py\n\"\"\"\n\nimport asyncio\n\nimport torch\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.utils import broadcast_pyobj\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom utils_sglang import (\n    are_lists_similar,\n    clean_torchelastic_env,\n    generate_hf_output,\n    initialize_global_process_group,\n    load_tokenizer_and_model,\n    prepare_inputs,\n)\n\n\ndef _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    token_ids = prompt_token_ids[non_pad_index:].tolist()\n    return token_ids\n\n\ndef test_sglang_spmd():\n    assert torch.cuda.device_count() >= 2\n    initialize_global_process_group(spmd=True)\n    clean_torchelastic_env()\n\n    max_prompt_length = 16\n    max_response_length = 16\n\n    local_model_path = \"Qwen/Qwen2.5-0.5B\"\n    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)\n\n    preencode_prompts = [\"Who won the Champions League in 2019?\", \"The founder of Apple is\", \"What's your name?\"]\n    input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)\n\n    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)\n\n    tensor_parallel_size = 2\n    inference_device_mesh_cpu = init_device_mesh(\n        \"cpu\", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=[\"dp\", \"tp\", \"pp\"]\n    )\n    tp_rank = inference_device_mesh_cpu[\"tp\"].get_local_rank()\n\n    if tp_rank == 0:\n        llm = Engine(\n            model_path=local_model_path,\n            dtype=\"bfloat16\",\n            mem_fraction_static=0.5,\n            enable_memory_saver=True,\n            tp_size=inference_device_mesh_cpu[\"tp\"].size(),\n            attention_backend=\"fa3\",\n        )\n\n        input_ids = input_ids.cuda()\n        idx_list = []\n\n        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n        for i in range(input_ids.shape[0]):\n            idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))\n\n        sampling_params = dict(\n            n=1,\n            temperature=0,\n            top_p=1,\n            top_k=-1,\n            max_new_tokens=max_response_length,\n            presence_penalty=0.0,\n            frequency_penalty=0.0,\n            repetition_penalty=1.0,\n            skip_special_tokens=True,\n            spaces_between_special_tokens=True,\n            ignore_eos=False,\n        )\n\n        loop = asyncio.get_event_loop()\n        outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, 
[outputs] = broadcast_pyobj(\n        [outputs],\n        rank=inference_device_mesh_cpu[\"tp\"].get_local_rank(),\n        src=inference_device_mesh_cpu[\"tp\"].mesh[0].item(),\n        dist_group=inference_device_mesh_cpu[\"tp\"].get_group(),\n        force_cpu_device=False,\n    )\n\n    sglang_response_tokens = [output[\"text\"] for output in outputs]\n\n    print(f\"sglang response: {sglang_response_tokens}\")\n    assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \"Strings differ more than 10%\"\n    print(\"SPMD Test Passed!\")\n\n    torch.distributed.barrier()\n    torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "verl_rl/tests/workers/rollout/utils_sglang.py",
    "content": "# Copyright 2023-2024 SGLang Team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom datetime import timedelta\n\nimport torch\nfrom omegaconf import OmegaConf\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.torch_functional import pad_sequence_to_length\n\n\n# ====================== utils ======================\ndef levenshtein(s1, s2):\n    m, n = len(s1), len(s2)\n    dp = [[0] * (n + 1) for _ in range(m + 1)]\n    for i in range(m + 1):\n        dp[i][0] = i\n    for j in range(n + 1):\n        dp[0][j] = j\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            cost = 0 if s1[i - 1] == s2[j - 1] else 1\n            dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)\n    return dp[m][n]\n\n\ndef are_lists_similar(a, b, threshold=10):\n    if len(a) != len(b):\n        print(\"The lists are of different lengths.\")\n        return False\n    total_length = 0\n    total_diff = 0\n    for s1, s2 in zip(a, b, strict=True):\n        max_len = max(len(s1), len(s2))\n        total_length += max_len\n        total_diff += levenshtein(s1, s2)\n    percentage_difference = (total_diff / total_length) * 100\n    print(f\"Total difference: {percentage_difference:.2f}%\")\n    return percentage_difference <= threshold\n\n\ndef initialize_global_process_group(timeout_second=36000, spmd=False):\n    import torch.distributed\n\n    if not torch.distributed.is_initialized():  # Check if already initialized\n        print(\"Initializing process group...\")\n        torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))\n    else:\n        print(\"Process group already initialized.\")\n\n    local_rank = int(os.environ[\"LOCAL_RANK\"])\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n    torch.cuda.set_device(local_rank)\n\n    CUDA_VISIBLE_DEVICES = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\")\n    if not CUDA_VISIBLE_DEVICES:\n        if spmd:\n            # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size))\n            CUDA_VISIBLE_DEVICES = \",\".join(str(i) for i in range(world_size))\n        else:\n            CUDA_VISIBLE_DEVICES = str(local_rank)\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = CUDA_VISIBLE_DEVICES\n        print(f\"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}\")\n\n    return local_rank, rank, world_size\n\n\ndef clean_torchelastic_env():\n    for k in [\"TORCHELASTIC_USE_AGENT_STORE\"]:\n        if k in os.environ:\n            del os.environ[k]\n\n\ndef load_tokenizer_and_model(local_model_path, dtype=\"bfloat16\"):\n    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side=\"left\")\n    tokenizer.pad_token = tokenizer.eos_token\n    model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map=\"cuda\")\n    
return tokenizer, model\n\n\ndef prepare_inputs(tokenizer, prompts, max_prompt_length):\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    tokenized = tokenizer(prompts, return_tensors=\"pt\", padding=True)\n    input_ids = pad_sequence_to_length(tokenized[\"input_ids\"], max_prompt_length, pad_token_id, left_pad=True)\n    attention_mask = pad_sequence_to_length(\n        tokenized[\"attention_mask\"], max_prompt_length, pad_token_id=0, left_pad=True\n    )\n    position_ids = compute_position_id_with_mask(attention_mask)\n    position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True)\n    return input_ids, attention_mask, position_ids\n\n\ndef generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length):\n    generation_config = GenerationConfig(do_sample=False)\n    output = model.generate(\n        input_ids=input_ids.cuda(),\n        attention_mask=attention_mask.cuda(),\n        max_new_tokens=max_response_length,\n        eos_token_id=tokenizer.eos_token_id,\n        pad_token_id=tokenizer.pad_token_id,\n        generation_config=generation_config,\n        output_scores=False,\n        return_dict_in_generate=True,\n        use_cache=False,\n    )\n    seq = output.sequences\n    response = seq[:, input_ids.shape[1] :]\n    return tokenizer.batch_decode(response)\n\n\ndef get_rollout_config(\n    max_response_length,\n    max_prompt_length,\n    dtype,\n    tensor_parallel_size,\n    tool_config_path=None,\n    interaction_config_path=None,\n):\n    sampling_params = dict(\n        n=1,\n        temperature=0,\n        top_p=1,\n        top_k=-1,\n        max_new_tokens=max_response_length,\n        presence_penalty=0.0,\n        frequency_penalty=0.0,\n        repetition_penalty=1.0,\n        skip_special_tokens=True,\n        spaces_between_special_tokens=True,\n        ignore_eos=False,\n    )\n\n    rollout_config = OmegaConf.create(\n        {\n            \"name\": \"sglang\",\n            \"mode\": \"sync\",\n            \"load_format\": \"dummy_dtensor\",\n            \"enforce_eager\": False,\n            \"free_cache_engine\": True,\n            \"dtype\": dtype,\n            \"gpu_memory_utilization\": 0.5,\n            \"ignore_eos\": False,\n            \"max_num_batched_tokens\": 8192,\n            \"prompt_length\": max_prompt_length,\n            \"response_length\": max_response_length,\n            \"tensor_model_parallel_size\": tensor_parallel_size,\n            # set to 128MB only for testing\n            \"update_weights_bucket_megabytes\": 128,\n            \"multi_turn\": {\n                \"max_assistant_turns\": 4,\n                \"max_user_turns\": 4,\n                \"enable\": True,\n                \"tool_config_path\": tool_config_path,\n                \"interaction_config_path\": interaction_config_path,\n                \"use_inference_chat_template\": False,\n                \"tokenization_sanity_check_mode\": \"strict\",\n            },\n            \"max_model_len\": None,\n            **sampling_params,\n        }\n    )\n\n    return rollout_config\n"
  },
  {
    "path": "verl_rl/verl/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport logging\nimport os\nfrom importlib.metadata import PackageNotFoundError\nfrom importlib.metadata import version as get_version\n\nfrom packaging.version import parse as parse_version\n\nfrom .protocol import DataProto\nfrom .utils.device import is_npu_available\nfrom .utils.logging_utils import set_basic_config\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\nwith open(os.path.join(version_folder, \"version/version\")) as f:\n    __version__ = f.read().strip()\n\n\nset_basic_config(level=logging.WARNING)\n\n\n__all__ = [\"DataProto\", \"__version__\"]\n\nif os.getenv(\"VERL_USE_MODELSCOPE\", \"False\").lower() == \"true\":\n    if importlib.util.find_spec(\"modelscope\") is None:\n        raise ImportError(\"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`\")\n    # Patch hub to download models from modelscope to speed up.\n    from modelscope.utils.hf_util import patch_hub\n\n    patch_hub()\n\nif is_npu_available:\n    from .models.transformers import npu_patch as npu_patch\n\n    package_name = \"transformers\"\n    required_version_spec = \"4.52.4\"\n    try:\n        installed_version = get_version(package_name)\n        installed = parse_version(installed_version)\n        required = parse_version(required_version_spec)\n\n        if installed < required:\n            raise ValueError(\n                f\"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is \"\n                f\"{installed}.\"\n            )\n    except PackageNotFoundError as e:\n        raise ImportError(\n            f\"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}\"\n        ) from e\n"
  },
  {
    "path": "verl_rl/verl/base_config.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport collections\nfrom dataclasses import (\n    dataclass,\n    field,\n    fields,  # Import the fields function to inspect dataclass fields\n)\nfrom typing import Any\n\n\n# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary\n@dataclass\nclass BaseConfig(collections.abc.Mapping):\n    \"\"\"The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config.\n\n    The BaseConfig class implements the Mapping Abstract Base Class.\n    This allows instances of this class to be used like dictionaries.\n    \"\"\"\n\n    extra: dict[str, Any] = field(default_factory=dict)\n\n    def __setattr__(self, name: str, value):\n        # if the field already exists (i.e. was set in __init__)\n        # and is in our frozen list, block assignment\n        if hasattr(self, \"_frozen_fields\") and name in self._frozen_fields and name in self.__dict__:\n            from dataclasses import FrozenInstanceError\n\n            raise FrozenInstanceError(f\"Field '{name}' is frozen and cannot be modified\")\n        # otherwise do the normal thing\n        super().__setattr__(name, value)\n\n    def get(self, key: str, default: Any = None) -> Any:\n        \"\"\"Get the value associated with the given key. If the key does not exist, return the default value.\n\n        Args:\n            key (str): The attribute name to retrieve.\n            default (Any, optional): The value to return if the attribute does not exist. Defaults to None.\n\n        Returns:\n            Any: The value of the attribute or the default value.\n        \"\"\"\n        try:\n            return getattr(self, key)\n        except AttributeError:\n            return default\n\n    def __getitem__(self, key: str):\n        \"\"\"Implement the [] operator for the class. Allows accessing attributes like dictionary items.\n\n        Args:\n            key (str): The attribute name to retrieve.\n\n        Returns:\n            Any: The value of the attribute.\n\n        Raises:\n            AttributeError: If the attribute does not exist.\n            TypeError: If the key type is not string\n        \"\"\"\n        return getattr(self, key)\n\n    def __iter__(self):\n        \"\"\"Implement the iterator protocol. Allows iterating over the attribute names of the instance.\n\n        Yields:\n            str: The name of each field in the dataclass.\n        \"\"\"\n        for f in fields(self):\n            yield f.name\n\n    def __len__(self):\n        \"\"\"\n        Return the number of fields in the dataclass.\n\n        Returns:\n            int: The number of fields in the dataclass.\n        \"\"\"\n        return len(fields(self))\n"
  },
  {
    "path": "verl_rl/verl/experimental/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/experimental/agent_loop/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .agent_loop import AgentLoopBase, AgentLoopManager\nfrom .single_turn_agent_loop import SingleTurnAgentLoop\nfrom .tool_agent_loop import ToolAgentLoop\n\n_ = [SingleTurnAgentLoop, ToolAgentLoop]\n\n__all__ = [\"AgentLoopBase\", \"AgentLoopManager\"]\n"
  },
  {
    "path": "verl_rl/verl/experimental/agent_loop/agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport heapq\nimport logging\nimport os\nimport random\nfrom abc import ABC, abstractmethod\nfrom typing import Any\n\nimport hydra\nimport numpy as np\nimport ray\nimport torch\nfrom cachetools import LRUCache\nfrom omegaconf import DictConfig, OmegaConf\nfrom pydantic import BaseModel\nfrom tensordict import TensorDict\nfrom transformers import AutoTokenizer\n\nfrom verl.protocol import DataProto\nfrom verl.single_controller.ray.base import RayWorkerGroup\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op\nfrom verl.workers.rollout.async_server import async_server_class\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass AsyncLLMServerManager:\n    \"\"\"\n    A class to manage multiple OpenAI compatible LLM servers. This class provides\n    - Load balance: least requests load balancing\n    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching\n    \"\"\"\n\n    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):\n        \"\"\"Initialize the AsyncLLMServerManager.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n            max_cache_size (int, optional): max cache size for request_id to server mapping. 
Defaults to 10000.\n        \"\"\"\n        self.config = config\n        self.server_handles = server_handles\n        random.shuffle(self.server_handles)\n\n        # Least requests load balancing\n        self.weighted_servers = [[0, (hash(server), server)] for server in server_handles]\n        heapq.heapify(self.weighted_servers)\n\n        # LRU cache to map request_id to server\n        self.request_id_to_server = LRUCache(maxsize=max_cache_size)\n\n    def _choose_server(self, request_id: str) -> ray.actor.ActorHandle:\n        # TODO: implement server pressure awareness load balancing\n        if request_id in self.request_id_to_server:\n            return self.request_id_to_server[request_id]\n\n        # Take the least-loaded server from the heap root, bump its request count,\n        # then heapreplace to restore the heap invariant.\n        server = self.weighted_servers[0][1][1]\n        self.weighted_servers[0][0] += 1\n        heapq.heapreplace(self.weighted_servers, self.weighted_servers[0])\n        self.request_id_to_server[request_id] = server\n        return server\n\n    @rollout_trace_op\n    async def generate(\n        self,\n        request_id,\n        *,\n        prompt_ids: list[int],\n        sampling_params: dict[str, Any],\n    ) -> list[int]:\n        \"\"\"Generate tokens from prompt ids.\n\n        Args:\n            request_id (str): request id for sticky session.\n            prompt_ids (List[int]): List of prompt token ids.\n            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.\n\n        Returns:\n            List[int]: List of generated token ids.\n        \"\"\"\n        server = self._choose_server(request_id)\n        output = await server.generate.remote(\n            request_id=request_id,\n            prompt_ids=prompt_ids,\n            sampling_params=sampling_params,\n        )\n        return output\n\n\nclass AgentLoopMetrics(BaseModel):\n    \"\"\"Agent loop performance metrics.\"\"\"\n\n    generate_sequences: float = 0.0\n    tool_calls: float = 0.0\n\n\nclass AgentLoopOutput(BaseModel):\n    \"\"\"Agent loop output.\"\"\"\n\n    prompt_ids: list[int]\n    \"\"\"Prompt token ids.\"\"\"\n    response_ids: list[int]\n    \"\"\"Response token ids, including LLM generated tokens and tool response tokens.\"\"\"\n    response_mask: list[int]\n    \"\"\"Response mask, 1 for LLM generated tokens, 0 for tool response tokens.\"\"\"\n    num_turns: int = 0\n    \"\"\"Number of chat turns, including user, assistant, tool.\"\"\"\n    metrics: AgentLoopMetrics\n    \"\"\"Auxiliary performance metrics\"\"\"\n\n\n# make hydra.utils.instantiate happy\nclass _DummyConfig:\n    def __init__(self, config: DictConfig) -> None:\n        self.config = config\n\n\nclass AgentLoopBase(ABC):\n    \"\"\"An agent loop takes an input message, chats with an OpenAI-compatible LLM server, and interacts with various\n    environments.\"\"\"\n\n    _class_initialized = False\n\n    def __init__(\n        self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs\n    ):\n        \"\"\"Initialize agent loop, each sample will have its own loop instance.\n\n        Args:\n            trainer_config (_DummyConfig): trainer config.\n            server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.\n            tokenizer (AutoTokenizer): Tokenizer for tokenizing messages.\n        \"\"\"\n        self.init_class(trainer_config.config, tokenizer, **kwargs)\n        self.config = trainer_config.config\n        self.server_manager = server_manager\n        self.tokenizer = tokenizer\n
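        # Cache the running event loop so blocking work (e.g. tokenization) can be offloaded via run_in_executor without stalling other agent loops.\n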
        self.loop = asyncio.get_running_loop()\n\n    @classmethod\n    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs):\n        \"\"\"This is used to do heavy initialization work that should be shared across all instances. It's only called once.\n\n        Args:\n            config (DictConfig): trainer config.\n            tokenizer (AutoTokenizer): Tokenizer for tokenizing messages.\n            **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.\n        \"\"\"\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n\n    @abstractmethod\n    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:\n        \"\"\"Run agent loop to interact with LLM server and environment.\n\n        Args:\n            messages (List[Dict[str, Any]]): Input messages.\n            sampling_params (Dict[str, Any]): LLM sampling params.\n\n        Returns:\n            AgentLoopOutput: Agent loop output.\n        \"\"\"\n        raise NotImplementedError\n\n\n\"\"\"Agent loop registry: key is agent_name, value is a dict of agent loop config\nused by hydra.utils.instantiate to initialize agent loop instance.\n\nhttps://hydra.cc/docs/advanced/instantiate_objects/overview/\n\"\"\"\n_agent_loop_registry: dict[str, dict] = {}\n\n\ndef register(agent_name: str):\n    \"\"\"Register agent loop class.\"\"\"\n\n    def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:\n        fqdn = f\"{subclass.__module__}.{subclass.__qualname__}\"\n        _agent_loop_registry[agent_name] = {\"_target_\": fqdn}\n        return subclass\n\n    return decorator\n\n\n@ray.remote\nclass AgentLoopWorker:\n    \"\"\"Agent loop worker takes a batch of messages and runs each message in an agent loop.\"\"\"\n\n    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]):\n        \"\"\"Initialize agent loop worker.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n        \"\"\"\n        self.config = config\n        self.server_manager = AsyncLLMServerManager(config, server_handles)\n\n        model_path = config.actor_rollout_ref.model.path\n        self.model_name = \"/\".join(model_path.split(\"/\")[-2:])\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)\n\n        agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path\n        if agent_loop_config_path:\n            agent_loop_configs = OmegaConf.load(agent_loop_config_path)\n            for agent_loop_config in agent_loop_configs:\n                _agent_loop_registry[agent_loop_config.name] = agent_loop_config\n\n        trace_config = self.config.actor_rollout_ref.rollout.get(\"trace\", {})\n        RolloutTraceConfig.init(\n            self.config.trainer.project_name,\n            self.config.trainer.experiment_name,\n            trace_config.get(\"backend\"),\n            trace_config.get(\"token2text\", False),\n        )\n\n    async def generate_sequences(self, batch: DataProto) -> DataProto:\n        \"\"\"Generate sequences from agent loop.\n\n        Args:\n            batch (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], 
prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        config = self.config.actor_rollout_ref.rollout\n        sampling_params = dict(\n            temperature=config.temperature,\n            top_p=config.top_p,\n            repetition_penalty=1.0,\n        )\n\n        # override sampling params for validation\n        if batch.meta_info.get(\"validate\", False):\n            sampling_params[\"top_p\"] = config.val_kwargs.top_p\n            sampling_params[\"temperature\"] = config.val_kwargs.temperature\n\n        # by default, we assume it's a single turn agent\n        if \"agent_name\" not in batch.non_tensor_batch:\n            batch.non_tensor_batch[\"agent_name\"] = np.array([\"single_turn_agent\"] * len(batch), dtype=object)\n\n        tasks = []\n        agent_names = batch.non_tensor_batch[\"agent_name\"]\n        raw_prompts = batch.non_tensor_batch[\"raw_prompt\"]\n        if \"index\" in batch.non_tensor_batch:\n            index = batch.non_tensor_batch[\"index\"]\n        else:\n            index = np.arange(len(raw_prompts))\n\n        trajectory_info = await get_trajectory_info(\n            batch.meta_info.get(\"global_steps\", -1), index, batch.meta_info.get(\"validate\", False)\n        )\n\n        for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):\n            tasks.append(\n                asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory))\n            )\n        outputs = await asyncio.gather(*tasks)\n\n        output = self._postprocess(outputs)\n        return output\n\n    async def _run_agent_loop(\n        self,\n        agent_name: str,\n        messages: list[dict[str, Any]],\n        sampling_params: dict[str, Any],\n        trajectory: dict[str, Any],\n    ) -> AgentLoopOutput:\n        with rollout_trace_attr(\n            step=trajectory[\"step\"],\n            sample_index=trajectory[\"sample_index\"],\n            rollout_n=trajectory[\"rollout_n\"],\n            validate=trajectory[\"validate\"],\n            name=\"agent_loop\",\n        ):\n            assert agent_name in _agent_loop_registry, (\n                f\"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}\"\n            )\n\n            agent_loop_config = _agent_loop_registry[agent_name]\n            agent_loop = hydra.utils.instantiate(\n                config=agent_loop_config,\n                trainer_config=_DummyConfig(config=self.config),\n                server_manager=self.server_manager,\n                tokenizer=self.tokenizer,\n            )\n            
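# A fresh agent loop instance is created per sample; heavy shared setup happens once in init_class.\n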
            output = await agent_loop.run(messages, sampling_params)\n            return output\n\n    def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:\n        # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n        # prompts: left pad\n        # responses: right pad\n        # input_ids: prompt + response\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n\n        # prompts\n        self.tokenizer.padding_side = \"left\"\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.prompt_ids} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.actor_rollout_ref.rollout.prompt_length,\n            return_tensors=\"pt\",\n            return_attention_mask=True,\n        )\n        prompt_ids, prompt_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n        # responses\n        self.tokenizer.padding_side = \"right\"\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.response_ids} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.actor_rollout_ref.rollout.response_length,\n            return_tensors=\"pt\",\n            return_attention_mask=True,\n        )\n        response_ids, response_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n        # response_mask\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.response_mask} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.actor_rollout_ref.rollout.response_length,\n            return_tensors=\"pt\",\n            return_attention_mask=False,\n        )\n        response_mask = outputs[\"input_ids\"]\n        assert response_ids.shape == response_mask.shape, (\n            f\"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}\"\n        )\n        response_mask = response_mask * response_attention_mask\n\n        input_ids = torch.cat([prompt_ids, response_ids], dim=1)\n        attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)\n        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,  # [bsz, prompt_length]\n                \"responses\": response_ids,  # [bsz, response_length]\n                \"response_mask\": response_mask,  # [bsz, response_length]\n                \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n                \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n                \"position_ids\": position_ids,  # [bsz, prompt_length + response_length]\n            },\n            batch_size=len(input_ids),\n        )\n\n        num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)\n        metrics = [input.metrics.model_dump() for input in inputs]\n        return DataProto(batch=batch, non_tensor_batch={\"__num_turns__\": num_turns}, meta_info={\"metrics\": metrics})\n\n\nasync def get_trajectory_info(step, index, validate):\n    \"\"\"Get trajectory info.\n\n    Args:\n        step (int): global steps in the trainer.\n        index (list): from the datastore extra_info.index column.\n        validate (bool): whether this is a validation step.\n\n    Returns:\n        list: trajectory info.\n    \"\"\"\n    trajectory_info = 
[]\n    rollout_n = 0\n    for i in range(len(index)):\n        if i > 0 and index[i - 1] == index[i]:\n            rollout_n += 1\n        else:\n            rollout_n = 0\n        trajectory_info.append({\"step\": step, \"sample_index\": index[i], \"rollout_n\": rollout_n, \"validate\": validate})\n    return trajectory_info\n\n\nclass AgentLoopManager:\n    \"\"\"Agent loop manager that manages a group of agent loop workers.\"\"\"\n\n    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):\n        \"\"\"Initialize agent loop manager.\n\n        Args:\n            config (DictConfig): trainer config.\n            worker_group (RayWorkerGroup): ActorRolloutRef worker group.\n        \"\"\"\n        self.config = config\n        self.worker_group = worker_group\n\n        self._initialize_llm_servers()\n        self._init_agent_loop_workers()\n\n        # Initially we're in sleep mode.\n        self.sleep()\n\n    def _initialize_llm_servers(self):\n        self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size\n        self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size\n\n        register_center = ray.get_actor(f\"{self.worker_group.name_prefix}_register_center\")\n        workers_info = ray.get(register_center.get_worker_info.remote())\n        assert len(workers_info) == self.worker_group.world_size\n\n        self.async_llm_servers = [None] * self.rollout_dp_size\n        self.server_addresses = [None] * self.rollout_dp_size\n\n        if self.config.actor_rollout_ref.rollout.agent.custom_async_server:\n            server_class = async_server_class(\n                rollout_backend=self.config.actor_rollout_ref.rollout.name,\n                rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path,\n                rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name,\n            )\n        else:\n            server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name)\n\n        # Start all server instances, restart if address already in use.\n        unready_dp_ranks = set(range(self.rollout_dp_size))\n        while len(unready_dp_ranks) > 0:\n            servers = {\n                rollout_dp_rank: server_class.options(\n                    # make sure AsyncvLLMServer colocates with its corresponding workers\n                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                        node_id=workers_info[rollout_dp_rank * self.rollout_tp_size],\n                        soft=False,\n                    ),\n                    name=f\"async_llm_server_{rollout_dp_rank}\",\n                ).remote(self.config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix)\n                for rollout_dp_rank in unready_dp_ranks\n            }\n\n            for rollout_dp_rank, server in servers.items():\n                try:\n                    address = ray.get(server.get_server_address.remote())\n                    self.server_addresses[rollout_dp_rank] = address\n                    self.async_llm_servers[rollout_dp_rank] = server\n                    unready_dp_ranks.remove(rollout_dp_rank)\n                except Exception:\n                    ray.kill(server)\n                    print(f\"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...\")\n\n        # All server instances are ready, init AsyncLLM 
engine.\n        ray.get([server.init_engine.remote() for server in self.async_llm_servers])\n\n    def _init_agent_loop_workers(self):\n        self.agent_loop_workers = []\n        for i in range(self.config.actor_rollout_ref.rollout.agent.num_workers):\n            self.agent_loop_workers.append(\n                AgentLoopWorker.options(\n                    name=f\"agent_loop_worker_{i}\",\n                ).remote(self.config, self.async_llm_servers)\n            )\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Split input batch and dispatch to agent loop workers.\n\n        Args:\n            prompts (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n        \"\"\"\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.wake_up()\n        chunks = prompts.chunk(len(self.agent_loop_workers))\n        outputs = ray.get(\n            [\n                worker.generate_sequences.remote(chunk)\n                for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True)\n            ]\n        )\n        output = DataProto.concat(outputs)\n        if self.config.actor_rollout_ref.rollout.free_cache_engine:\n            self.sleep()\n\n        # calculate performance metrics\n        metrics = [output.meta_info[\"metrics\"] for output in outputs]  # List[List[Dict[str, str]]]\n        timing = self._performance_metrics(metrics, output)\n\n        output.meta_info = {\"timing\": timing}\n        return output\n\n    def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:\n        timing = {}\n        t_generate_sequences = np.array([metric[\"generate_sequences\"] for chunk in metrics for metric in chunk])\n        t_tool_calls = np.array([metric[\"tool_calls\"] for chunk in metrics for metric in chunk])\n        timing[\"agent_loop/generate_sequences/min\"] = t_generate_sequences.min()\n        timing[\"agent_loop/generate_sequences/max\"] = t_generate_sequences.max()\n        timing[\"agent_loop/generate_sequences/mean\"] = t_generate_sequences.mean()\n        timing[\"agent_loop/tool_calls/min\"] = t_tool_calls.min()\n        timing[\"agent_loop/tool_calls/max\"] = t_tool_calls.max()\n        timing[\"agent_loop/tool_calls/mean\"] = t_tool_calls.mean()\n\n        # batch sequence generation is bounded by the slowest sample\n        slowest = np.argmax(t_generate_sequences + t_tool_calls)\n        attention_mask = output.batch[\"attention_mask\"][slowest]\n        prompt_length = output.batch[\"prompts\"].shape[1]\n        timing[\"agent_loop/slowest/generate_sequences\"] = t_generate_sequences[slowest]\n        timing[\"agent_loop/slowest/tool_calls\"] = t_tool_calls[slowest]\n        timing[\"agent_loop/slowest/prompt_length\"] = attention_mask[:prompt_length].sum().item()\n        timing[\"agent_loop/slowest/response_length\"] = attention_mask[prompt_length:].sum().item()\n\n        return timing\n\n    def wake_up(self):\n        \"\"\"Wake up all rollout server instances.\"\"\"\n        ray.get([server.wake_up.remote() for server in self.async_llm_servers])\n\n    def sleep(self):\n        \"\"\"Sleep all rollout server instances.\"\"\"\n        ray.get([server.sleep.remote() for server in self.async_llm_servers])\n"
  },
  {
    "path": "verl_rl/verl/experimental/agent_loop/single_turn_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register\nfrom verl.utils.profiler import simple_timer\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@register(\"single_turn_agent\")\nclass SingleTurnAgentLoop(AgentLoopBase):\n    \"\"\"Naive agent loop that only do single turn chat completion.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length\n        self.response_length = self.config.actor_rollout_ref.rollout.response_length\n\n    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:\n        metrics = {}\n        request_id = uuid4().hex\n        prompt_ids = await self.loop.run_in_executor(\n            None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)\n        )\n\n        with simple_timer(\"generate_sequences\", metrics):\n            response_ids = await self.server_manager.generate(\n                request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n            )\n        response_mask = [1] * len(response_ids)\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            num_turns=2,\n            metrics=metrics,\n        )\n        return output\n"
  },
  {
    "path": "verl_rl/verl/experimental/agent_loop/tool_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport json\nimport logging\nimport os\nfrom typing import Any\nfrom uuid import uuid4\n\nfrom verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register\nfrom verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\nfrom verl.utils.profiler import simple_timer\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n@register(\"tool_agent\")\nclass ToolAgentLoop(AgentLoopBase):\n    @classmethod\n    def init_class(cls, config, tokenizer, **kwargs):\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n        print(\"Performing class-level ToolAgentLoop initialization\")\n\n        # Initialize tools from config file\n        cls.tokenizer = tokenizer\n        cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns\n        cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns\n        cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls\n        cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length\n        cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side\n        tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []\n        cls.tools = {tool.name: tool for tool in tool_list}\n        cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]\n        cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer)\n        print(f\"Initialized tools: {cls.tools}\")\n\n        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length\n        cls.response_length = config.actor_rollout_ref.rollout.response_length\n        cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)\n\n    @rollout_trace_op\n    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:\n        metrics = {}\n        request_id = uuid4().hex\n        prompt_ids = await self.loop.run_in_executor(\n            None,\n            lambda: self.tokenizer.apply_chat_template(\n                messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True\n            ),\n        )\n        response_mask = []\n\n        user_turns, assistant_turns = 0, 0\n        while True:\n            with simple_timer(\"generate_sequences\", metrics):\n                response_ids = await 
self.server_manager.generate(\n                    request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n                )\n            prompt_ids += response_ids\n            response_mask += [1] * len(response_ids)\n            assistant_turns += 1\n\n            # reach max response length\n            if len(response_mask) >= self.response_length:\n                break\n\n            # reach max assistant turns\n            if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:\n                break\n\n            # reach max user turns\n            if self.max_user_turns and user_turns >= self.max_user_turns:\n                break\n\n            # no tool calls\n            _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)\n            if not tool_calls:\n                break\n\n            # call tools\n            tasks = []\n            for tool_call in tool_calls[: self.max_parallel_calls]:\n                tasks.append(self._call_tool(tool_call))\n            with simple_timer(\"tool_calls\", metrics):\n                tool_responses = await asyncio.gather(*tasks)\n            if any(isinstance(item, Exception) for item in tool_responses):\n                break\n\n            # append tool_response_ids\n            tool_response_ids = await self.loop.run_in_executor(\n                None,\n                lambda messages=tool_responses: self.tokenizer.apply_chat_template(\n                    messages, add_generation_prompt=True, tokenize=True\n                ),\n            )\n            tool_response_ids = tool_response_ids[len(self.system_prompt) :]\n\n            # NOTE: last turn should not be user turn, or the EOS token reward\n            # can't be propagated to previous token in GAE.\n            if len(response_mask) + len(tool_response_ids) >= self.response_length:\n                break\n\n            prompt_ids += tool_response_ids\n            response_mask += [0] * len(tool_response_ids)\n            user_turns += 1\n\n        response_ids = prompt_ids[-len(response_mask) :]\n        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            num_turns=user_turns + assistant_turns + 1,\n            metrics=metrics,\n        )\n        return output\n\n    async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:\n        \"\"\"Call tool and return tool response.\"\"\"\n        tool, instance_id = None, None\n        try:\n            # TODO: append malformed tool_call to the prompt: invalid function name or arguments\n            tool_name = tool_call.name\n            tool_args = json.loads(tool_call.arguments)\n            tool = self.tools[tool_name]\n\n            instance_id = await tool.create()\n            tool_response, _, _ = await tool.execute(instance_id, tool_args)\n        except Exception as e:\n            logger.exception(f\"Error when executing tool: {e}\")\n            return e\n        finally:\n            if tool and instance_id:\n                await tool.release(instance_id)\n\n        if len(tool_response) > self.max_tool_response_length:\n            if self.tool_response_truncate_side == \"left\":\n                tool_response = tool_response[: self.max_tool_response_length] + \"...(truncated)\"\n            elif 
self.tool_response_truncate_side == \"right\":\n                tool_response = \"(truncated)...\" + tool_response[-self.max_tool_response_length :]\n            else:\n                length = self.max_tool_response_length // 2\n                tool_response = tool_response[:length] + \"...(truncated)...\" + tool_response[-length:]\n\n        return {\n            \"role\": \"tool\",\n            \"content\": tool_response,\n        }\n"
  },
  {
    "path": "verl_rl/verl/experimental/agent_loop/tool_parser.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport json\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\n\nimport regex as re\nfrom pydantic import BaseModel\n\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass FunctionCall(BaseModel):\n    arguments: str\n    \"\"\"\n    The arguments to call the function with, as generated by the model in JSON\n    format. Note that the model does not always generate valid JSON, and may\n    hallucinate parameters not defined by your function schema. Validate the\n    arguments in your code before calling your function.\n    \"\"\"\n\n    name: str\n    \"\"\"The name of the function to call.\"\"\"\n\n\nclass ToolParser(ABC):\n    _registry: dict[str, type[\"ToolParser\"]] = {}\n\n    def __init__(self, tokenizer) -> None:\n        self.tokenizer = tokenizer\n\n    @abstractmethod\n    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:\n        \"\"\"Extract tool calls from the responses.\n\n        Args:\n            responses_ids (List[int]): The ids of the responses.\n\n        Returns:\n            Tuple[str, List[FunctionCall]]: Content and extracted tool calls.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def get_tool_parser(cls, name: str, tokenizer):\n        if name not in cls._registry:\n            raise ValueError(f\"Unknown tool parser: {name}\")\n        return cls._registry[name](tokenizer)\n\n    @classmethod\n    def register(cls, name: str):\n        def decorator(subclass: type[ToolParser]) -> type[ToolParser]:\n            cls._registry[name] = subclass\n            return subclass\n\n        return decorator\n\n\n@ToolParser.register(\"hermes\")\nclass HermesToolParser(ToolParser):\n    \"\"\"Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py\"\"\"\n\n    def __init__(self, tokenizer) -> None:\n        super().__init__(tokenizer)\n\n        self.tool_call_start_token: str = \"<tool_call>\"\n        self.tool_call_end_token: str = \"</tool_call>\"\n        self.tool_call_regex = re.compile(r\"<tool_call>(.*?)</tool_call>\", re.DOTALL)\n\n    @rollout_trace_op\n    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:\n        loop = asyncio.get_running_loop()\n        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)\n        if self.tool_call_start_token not in text or self.tool_call_end_token not in text:\n            return text, []\n\n        matches = self.tool_call_regex.findall(text)\n        function_calls = []\n        for match in matches:\n            try:\n                function_call = json.loads(match)\n                name, arguments = function_call[\"name\"], 
function_call[\"arguments\"]\n                function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))\n            except Exception as e:\n                logger.error(f\"Failed to decode tool call: {e}\")\n\n        # remaing text exclude tool call tokens\n        content = self.tool_call_regex.sub(\"\", text)\n\n        return content, function_calls\n"
  },
  {
    "path": "verl_rl/verl/experimental/dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/experimental/dataset/sampler.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom abc import abstractmethod\nfrom collections.abc import Sized\n\nfrom omegaconf import DictConfig\nfrom torch.utils.data import Sampler\n\nfrom verl import DataProto\n\n\nclass AbstractSampler(Sampler[int]):\n    \"\"\"Abstract interface for custom samplers.\"\"\"\n\n    @abstractmethod\n    def __init__(\n        self,\n        data_source: Sized,\n        data_config: DictConfig,\n    ):\n        pass\n\n\nclass AbstractCurriculumSampler(AbstractSampler):\n    \"\"\"Experimental interface for curriculum learning samplers.\"\"\"\n\n    @abstractmethod\n    def update(self, batch: DataProto) -> None:\n        pass\n"
  },
  {
    "path": "verl_rl/verl/experimental/dynamic_dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/experimental/dynamic_dataset/dynamicgen_dataset.py",
    "content": "# Copyright 2025 Amazon.com Inc and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nDataset class that enables dynamic data generation strategies between iterations of training.\nThis class extends RLHFDataset and uses an AbstractDataGen instance to generate data.\n\nThis is especially useful in settings where proposer model generates new tasks based\non rollout data.\n\"\"\"\n\nimport logging\nfrom abc import ABC, abstractmethod\nfrom typing import Optional\n\nimport datasets\nfrom omegaconf import DictConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom verl import DataProto\nfrom verl.utils.dataset import RLHFDataset\nfrom verl.utils.import_utils import load_extern_type\n\nlogger = logging.getLogger(__name__)\n\n\nclass AbstractDataGenerator(ABC):\n    def __init__(self, config: DictConfig):\n        self.config = config\n\n    @abstractmethod\n    def generate(self, dataset: Dataset) -> datasets.Dataset:\n        \"\"\"\n        Generate method must be implemented by subclasses.\n        Args:\n            dataset: The dataset to generate from.\n        Returns:\n            Processed data or result as implemented by the subclass.\n        \"\"\"\n        pass\n\n\nclass MockDataGenerator(AbstractDataGenerator):\n    \"\"\"\n    A noop data gen class that only reappends the first datapoint.\n    This class is useful as a placeholder and testing.\n    \"\"\"\n\n    def __init__(self, config: DictConfig = None):\n        super().__init__(config)\n\n    def generate(self, dataset: Dataset) -> datasets.Dataset:\n        print(\"MockDataGenerator: No operation performed on the dataset.\")\n        return dataset.dataframe.select([0])\n\n\nclass DynamicGenDataset(RLHFDataset):\n    \"\"\"\n    A dataset class that uses a data generation strategy to process data.\n    This class extends RLHFDataset and uses an AbstractDataGen instance to generate data.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n    ):\n        super().__init__(data_files, tokenizer, config, processor)\n        self.datagen: AbstractDataGenerator = config.datagen\n        assert \"datagen\" in config and config.datagen.get(\"path\", None) is not None, (\n            f\"datagen path is not set in config: {config}\"\n        )\n        # Dynamically load the custom datagen class\n        datagen_cls = load_extern_type(config.datagen.path, config.datagen.name)\n\n        # Verify that the custom datagen class inherits from AbstractDataGenerator\n        abs_cls = AbstractDataGenerator\n        if not issubclass(datagen_cls, abs_cls):\n            raise TypeError(\n                f\"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'\"\n                + \" must inherit from {abs_cls}\"\n            )\n\n        self.data_generator = 
        self.data_generator = datagen_cls(config.datagen)\n        self.on_batch_end()\n\n    def append_dataframe(self, new_dataframe: datasets.Dataset):\n        new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe)\n        self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe])\n\n        logger.info(f\"new dataset len: {len(self.dataframe)}\")\n\n    def on_batch_end(self, batch: Optional[DataProto] = None) -> None:\n        \"\"\"\n        Generate data using the provided data generation strategy.\n        Note: This method is intended to change the dataset after each training batch.\n        \"\"\"\n        new_data = self.data_generator.generate(self)\n        self.append_dataframe(new_data)\n"
  },
  {
    "path": "verl_rl/verl/interactions/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/interactions/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\n\nclass BaseInteraction:\n    def __init__(self, config: dict[str, Any]):\n        self.config = config\n        self.name: str = config.get(\"name\", \"interaction_agent\")  # More general agent default role name\n\n    async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4())\n        else:\n            return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict[str, Any]]:  # More clear response generation method\n        \"\"\"\n        Generates a response for the current turn of interaction.\n        Returns a tuple containing:\n        - should_terminate_sequence (bool): True if the interaction sequence should end.\n        - response_content (str): The textual content of the response.\n        - current_turn_score (float): The score for this specific turn/response.\n        - additional_data (dict): Any extra information or metadata.\n        \"\"\"\n        should_terminate_sequence: bool = False  # if True, end rollout\n        response_content: str = \"Your current result seems acceptable.\"\n        current_turn_score: float = 0.8\n        additional_data: dict[str, Any] = {}\n        return should_terminate_sequence, response_content, current_turn_score, additional_data\n\n    async def calculate_score(self) -> float:  # More clear score calculation method\n        \"\"\"\n        Calculates a score for the interaction,\n        potentially considering aspects like partial exposure & in-context task switching.\n        should be invoke at turn-level\n        \"\"\"\n        # ...implement the logic to calculate turn-level score...\n        score = 0.0\n        return score\n\n    async def finalize_interaction(self) -> None:  # More clear interaction end and resource release method\n        \"\"\"\n        Finalizes the interaction session and releases any associated state or resources.\n        Simulates: release state\n        \"\"\"\n        # ...implement the logic to release state...\n        pass\n"
  },
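`BaseInteraction` defines the async contract that multi-turn rollouts drive: `start_interaction`, then repeated `generate_response` calls, then `finalize_interaction`. Below is a minimal sketch of a custom subclass; the `EchoInteraction` class and its termination rule are hypothetical, for illustration only.

```python
# Hypothetical custom interaction built on BaseInteraction (illustration only).
from typing import Any, Optional
from uuid import uuid4

from verl.interactions.base import BaseInteraction


class EchoInteraction(BaseInteraction):
    """Terminates once the assistant's message contains "DONE"; otherwise asks again."""

    def __init__(self, config: dict[str, Any]):
        super().__init__(config)
        self._instances: dict[str, dict[str, Any]] = {}

    async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:
        instance_id = instance_id or str(uuid4())
        self._instances[instance_id] = {"turns": 0}
        return instance_id

    async def generate_response(
        self, instance_id: str, messages: list[dict[str, Any]], **kwargs
    ) -> tuple[bool, str, float, dict[str, Any]]:
        self._instances[instance_id]["turns"] += 1
        # Find the most recent assistant message, as the gsm8k example does.
        last = next((m for m in reversed(messages) if m.get("role") == "assistant"), {})
        done = "DONE" in (last.get("content") or "")
        reply = "Great, we are finished." if done else "Not yet, please continue."
        return done, reply, 1.0 if done else 0.0, {}

    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
        self._instances.pop(instance_id, None)
```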
  {
    "path": "verl_rl/verl/interactions/gsm8k_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import gsm8k\n\nfrom .base import BaseInteraction\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Gsm8kInteraction(BaseInteraction):\n    \"\"\"A demo interaction for calculating the reward of gsm8k.\n\n    - `start_interaction`: start a interaction instance for a trajectory.\n    - `generate_response`: generate the response of the user.\n    - `calculate_score`: calculate the score of the interaction.\n    - `finalize_interaction`: finalize the interaction instance.\n    \"\"\"\n\n    def __init__(self, config: dict):\n        super().__init__(config)\n        self._instance_dict = {}\n\n    async def start_interaction(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: list[dict[str, Any]], **kwargs\n    ) -> tuple[bool, str, float, dict]:\n        content = \"\"\n        for i in range(len(messages) - 1, -1, -1):\n            item = messages[i]\n            if item.get(\"role\") == \"assistant\":\n                content = item.get(\"content\")\n                break\n\n        self._instance_dict[instance_id][\"response\"] = content\n\n        reward = await self.calculate_score(instance_id)\n        if reward == 1.0:\n            response = \"Your response is correct!\"\n            should_terminate_sequence = True\n        else:\n            response = \"Your response is incorrect! You need to reflect on your answer and try again.\"\n            should_terminate_sequence = False\n\n        return should_terminate_sequence, response, reward, {}\n\n    async def calculate_score(self, instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"strict\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
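The lifecycle above can also be exercised directly, outside a rollout. A sketch, assuming verl is importable; the ground truth and messages are made up, and with `method="strict"` the scorer expects a GSM8K-style `#### <answer>` suffix:

```python
import asyncio

from verl.interactions.gsm8k_interaction import Gsm8kInteraction


async def main():
    interaction = Gsm8kInteraction({"name": "gsm8k"})
    instance_id = await interaction.start_interaction(ground_truth="72")

    messages = [{"role": "assistant", "content": "... so the answer is #### 72"}]
    done, reply, reward, _ = await interaction.generate_response(instance_id, messages)
    print(done, reply, reward)  # True "Your response is correct!" 1.0

    await interaction.finalize_interaction(instance_id)


asyncio.run(main())
```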
  {
    "path": "verl_rl/verl/interactions/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/interactions/utils/interaction_registry.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.util\nimport logging\nimport os\nimport sys\n\nfrom omegaconf import OmegaConf\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef get_interaction_class(cls_name):\n    \"\"\"Dynamically import and return the interaction class.\"\"\"\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    interaction_cls = getattr(module, class_name)\n    return interaction_cls\n\n\ndef initialize_interactions_from_config(interaction_config_file):\n    \"\"\"Initialize interactions from configuration file.\n\n    Args:\n        interaction_config_file: Path to the interaction configuration file.\n\n    Returns:\n        dict: A dictionary mapping interaction names to BaseInteraction instances.\n    \"\"\"\n    interaction_config = OmegaConf.load(interaction_config_file)\n    interaction_map = {}\n\n    for interaction_item in interaction_config.interaction:\n        cls_name = interaction_item.class_name\n        interaction_cls = get_interaction_class(cls_name)\n\n        # Extract config and name\n        config = OmegaConf.to_container(interaction_item.config, resolve=True)\n\n        # Get the interaction name - either from config or derive from class name\n        name = interaction_item.get(\"name\", None)\n        if name is None:\n            # If no name is specified, use the class name as default\n            class_simple_name = cls_name.split(\".\")[-1]\n            # Remove \"Interaction\" suffix if present, otherwise use full class name\n            if class_simple_name.endswith(\"Interaction\"):\n                name = class_simple_name[:-11].lower()  # Remove \"Interaction\" (11 chars)\n            else:\n                name = class_simple_name.lower()\n\n        # Check for duplicate names\n        if name in interaction_map:\n            raise ValueError(f\"Duplicate interaction name '{name}' found. Each interaction must have a unique name.\")\n\n        # Inject the name into the config\n        config[\"name\"] = name\n\n        # Create the interaction instance\n        interaction = interaction_cls(config=config)\n        interaction_map[name] = interaction\n\n        logger.info(f\"Initialized interaction '{name}' with class '{cls_name}'\")\n\n    return interaction_map\n"
  },
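The registry consumes an OmegaConf file with a top-level `interaction` list. A sketch of a minimal config and how it is loaded; the YAML path is a placeholder:

```python
# interaction_config.yaml (placeholder path, illustrative content):
#
# interaction:
#   - class_name: verl.interactions.gsm8k_interaction.Gsm8kInteraction
#     config: {}
#
# With no explicit `name`, the registry strips the "Interaction" suffix and
# lower-cases the class name, yielding "gsm8k".
from verl.interactions.utils.interaction_registry import initialize_interactions_from_config

interaction_map = initialize_interactions_from_config("interaction_config.yaml")
print(list(interaction_map))  # ['gsm8k']
```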
  {
    "path": "verl_rl/verl/model_merger/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/model_merger/__main__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends.\n\nTo merge FSDP checkpoints:\n```sh\npython -m verl.model_merger merge \\\n    --backend fsdp \\\n    --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nTo merge Megatron checkpoints:\n```sh\npython -m verl.model_merger merge \\\n    --backend megatron \\\n    --tie-word-embedding \\\n    --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\nor use distribtued merge for large models like dpskv3 671B\n\n```sh\ntorchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\\\n    --backend megatron \\\n    --local_dir ./checkpoints/global_step_1/actor \\\n    --target_dir /path/to/merged_hf_model\n```\n\n\nFor more details, please refer to documentation:\nhttps://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model\n\"\"\"\n\nfrom .base_model_merger import generate_config_from_args, parse_args\n\n\ndef main():\n    args = parse_args()\n    config = generate_config_from_args(args)\n    print(f\"config: {config}\")\n\n    if config.backend == \"fsdp\":\n        from .fsdp_model_merger import FSDPModelMerger\n\n        merger = FSDPModelMerger(config)\n    elif config.backend == \"megatron\":\n        from .megatron_model_merger import MegatronModelMerger\n\n        merger = MegatronModelMerger(config)\n    else:\n        raise NotImplementedError(f\"Unknown backend: {config.backend}\")\n\n    merger.merge_and_save()\n    merger.cleanup()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
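The same flow can be driven programmatically instead of through the CLI. A sketch mirroring `main()` above; the checkpoint paths are placeholders, and `hf_model_config_path` is passed explicitly because `generate_config_from_args` normally derives it as `<local_dir>/huggingface`:

```python
from verl.model_merger.base_model_merger import ModelMergerConfig
from verl.model_merger.fsdp_model_merger import FSDPModelMerger

config = ModelMergerConfig(
    operation="merge",
    backend="fsdp",
    local_dir="checkpoints/global_step_1/actor",  # placeholder
    hf_model_config_path="checkpoints/global_step_1/actor/huggingface",  # placeholder
    target_dir="/path/to/merged_hf_model",  # placeholder
)
merger = FSDPModelMerger(config)
merger.merge_and_save()
merger.cleanup()
```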
  {
    "path": "verl_rl/verl/model_merger/base_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nimport torch\nfrom accelerate import init_empty_weights\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForTokenClassification,\n    AutoModelForVision2Seq,\n    GenerationConfig,\n)\n\nfrom verl.utils import hf_processor, hf_tokenizer\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"verl model merger\")\n    subparsers = parser.add_subparsers(dest=\"operation\", required=True, help=\"Specify 'merge' or 'test' operation.\")\n\n    base_op_parser = argparse.ArgumentParser(add_help=False)\n    base_op_parser.add_argument(\n        \"--backend\", type=str, required=True, choices=[\"fsdp\", \"megatron\"], help=\"The backend of the model\"\n    )\n    base_op_parser.add_argument(\"--local_dir\", type=str, default=None, help=\"Path to the saved model checkpoints.\")\n    base_op_parser.add_argument(\n        \"--tie-word-embedding\",\n        action=\"store_true\",\n        help=\"Whether to tie word embedding weights (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\"--trust-remote-code\", action=\"store_true\", help=\"Whether to trust remote code\")\n    base_op_parser.add_argument(\n        \"--is-value-model\",\n        action=\"store_true\",\n        help=\"Whether the model is a value model (currently only Megatron supported)\",\n    )\n    base_op_parser.add_argument(\n        \"--use_cpu_initialization\",\n        action=\"store_true\",\n        help=\"Whether to use CPU initialization for the model. 
This is useful for large models that cannot \"\n        \"fit into GPU memory during initialization.\",\n    )\n\n    merge_parser = subparsers.add_parser(\"merge\", parents=[base_op_parser], help=\"Merge model checkpoints and save.\")\n    merge_parser.add_argument(\n        \"--target_dir\", default=\"tmp\", type=str, help=\"Directory to save the merged huggingface model\"\n    )\n    merge_parser.add_argument(\n        \"--hf_upload_path\", default=None, type=str, help=\"Hugging Face repository ID to upload the model\"\n    )\n    merge_parser.add_argument(\n        \"--private\", action=\"store_true\", help=\"Whether to upload the model to a private Hugging Face repository\"\n    )\n\n    test_parser = subparsers.add_parser(\n        \"test\", parents=[base_op_parser], help=\"Test merged model against a reference Hugging Face model\"\n    )\n    test_parser.add_argument(\n        \"--test_hf_dir\", type=str, required=True, help=\"Path to the reference Hugging Face model directory for testing\"\n    )\n\n    args = parser.parse_args()\n    return args\n\n\n@dataclass\nclass ModelMergerConfig:\n    \"\"\"Configuration for model merger operations.\n\n    Args:\n        operation (str): Operation type - 'merge' or 'test'.\n        backend (str): Backend type for the model ('fsdp' or 'megatron').\n        target_dir (Optional[str]): Directory to save the merged huggingface model. Defaults to \"tmp\".\n        hf_upload_path (Optional[str]): Hugging Face repository ID to upload the model. Defaults to None.\n        private (bool): Whether to upload the model to a private Hugging Face repository. Defaults to False.\n        test_hf_dir (Optional[str]): Path to the reference Hugging Face model directory for testing. Defaults to None.\n        tie_word_embedding (bool): Whether to tie word embedding weights (currently only Megatron\n            supported). Defaults to False.\n        trust_remote_code (bool): Whether to trust remote code. Defaults to False.\n        is_value_model (bool): Whether the model is a value model (currently only Megatron\n            supported). Defaults to False.\n        local_dir (Optional[str]): Path to the saved model checkpoints. Defaults to None.\n        hf_model_config_path (Optional[str]): Path to HuggingFace model configuration files. Defaults to None.\n        hf_upload (bool): Whether to upload to HuggingFace (computed automatically). Not for initialization.\n        use_cpu_initialization (bool): Whether to use CPU initialization for large models. 
Defaults to False.\n    \"\"\"\n\n    operation: str  # 'merge' or 'test'\n    backend: str\n    target_dir: Optional[str] = \"tmp\"\n    hf_upload_path: Optional[str] = None\n    private: bool = False\n    test_hf_dir: Optional[str] = None\n    tie_word_embedding: bool = False\n    trust_remote_code: bool = False\n    is_value_model: bool = False\n    local_dir: Optional[str] = None\n    hf_model_config_path: Optional[str] = None\n    hf_upload: bool = field(init=False)\n    use_cpu_initialization: bool = False\n\n    def __post_init__(self):\n        self.hf_upload = self.operation == \"merge\" and bool(self.hf_upload_path)\n        if self.operation == \"test\":\n            self.target_dir = None\n            self.hf_upload_path = None\n            self.private = False\n\n\ndef generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig:\n    common_config_args = {\n        \"operation\": args.operation,\n        \"backend\": args.backend,\n        \"tie_word_embedding\": args.tie_word_embedding,\n        \"trust_remote_code\": args.trust_remote_code,\n        \"is_value_model\": args.is_value_model,\n        \"local_dir\": args.local_dir,\n        \"hf_model_config_path\": os.path.join(args.local_dir, \"huggingface\"),\n        \"use_cpu_initialization\": args.use_cpu_initialization,\n    }\n\n    if args.operation == \"merge\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            target_dir=args.target_dir,\n            hf_upload_path=args.hf_upload_path,\n            private=args.private,\n            test_hf_dir=None,\n        )\n        os.makedirs(config.target_dir, exist_ok=True)\n    elif args.operation == \"test\":\n        config = ModelMergerConfig(\n            **common_config_args,\n            test_hf_dir=args.test_hf_dir,\n            # the following args are not used by test operation\n            target_dir=None,\n            hf_upload_path=None,\n            private=False,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown operation: {args.operation}\")\n    return config\n\n\nclass BaseModelMerger(ABC):\n    \"\"\"\n    Abstract base class for merging distributed model checkpoints into HuggingFace format.\n\n    This class provides common functionality for converting model checkpoints from different\n    distributed training backends (FSDP, Megatron) into standard HuggingFace format that\n    can be easily loaded and used for inference or further training.\n\n    The merger supports two main operations:\n    - merge: Convert and save checkpoints to HuggingFace format\n    - test: Validate merged checkpoints against a reference model\n\n    Args:\n        config (ModelMergerConfig): Configuration object containing paths, backend type,\n            and operation parameters.\n\n    Attributes:\n        config (ModelMergerConfig): The configuration object passed during initialization.\n        hf_model_config_path (str): Path to the HuggingFace model configuration files.\n        model_config (PretrainedConfig): Loaded HuggingFace model configuration.\n    \"\"\"\n\n    def __init__(self, config: ModelMergerConfig):\n        self.config = config\n        self.hf_model_config_path = config.hf_model_config_path\n        self.model_config = AutoConfig.from_pretrained(\n            self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code\n        )\n\n    def get_transformers_auto_model_class(self):\n        if \"ForTokenClassification\" in self.model_config.architectures[0]:\n            return 
AutoModelForTokenClassification\n        elif \"ForCausalLM\" in self.model_config.architectures[0]:\n            return AutoModelForCausalLM\n        elif \"ForConditionalGeneration\" in self.model_config.architectures[0]:\n            return AutoModelForVision2Seq\n\n        raise NotImplementedError(f\"Unknown architecture {self.model_config.architectures}\")\n\n    def patch_model_generation_config(self, model):\n        \"\"\"\n        The generation_config created from the model config may differ from the pretrained model's,\n        which may lead to errors when generating: https://github.com/volcengine/verl/issues/1246\n\n        This function patches the generation_config with the one saved alongside the pretrained model.\n        \"\"\"\n        if model.can_generate():\n            try:\n                model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path)\n            except OSError:\n                print(\n                    f\"Warning: Generation config file not found in {self.hf_model_config_path}, using a \"\n                    f\"generation config created from the model config.\"\n                )\n        return model\n\n    def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Save lora adapter to safetensors.\n\n        Returns:\n            lora_path: str, the path to the lora adapter. None if no lora adapter found.\n\n        Note:\n            This function changes the 'state_dict' in place.\n        \"\"\"\n        lora_params_names = [name for name in state_dict.keys() if \"lora_\" in name]\n\n        if len(lora_params_names) == 0:\n            return None\n\n        import json\n        from typing import OrderedDict\n\n        import peft\n        from safetensors.torch import save_file\n\n        lora_params = OrderedDict()\n        target_modules = set()\n        lora_key = None\n\n        for name in lora_params_names:\n            lora_key = name.replace(\".default.weight\", \".weight\")\n            target_modules.add(lora_key.split(\".\")[-3])\n            lora_params[lora_key] = state_dict.pop(name)\n\n        lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1])\n        peft_dict = {\n            \"r\": lora_rank,\n            \"lora_alpha\": 0,  # lora_alpha is not set. 
An error should be raised to inform the user to set it manually.\n            \"target_modules\": list(target_modules),\n        }\n        peft_config = peft.LoraConfig(**peft_dict).to_dict()\n        peft_config[\"task_type\"] = peft_config[\"task_type\"].value if peft_config[\"task_type\"] else None\n        peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value if peft_config[\"peft_type\"] else None\n        peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n\n        lora_path = os.path.join(self.config.target_dir, \"lora_adapter\")\n        os.makedirs(lora_path, exist_ok=True)\n        with open(os.path.join(lora_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n        save_file(lora_params, os.path.join(lora_path, \"adapter_model.safetensors\"))\n\n        for name in list(state_dict.keys()):\n            key = (\n                name.replace(\"base_model.model.\", \"\")\n                .replace(\".base_layer.weight\", \".weight\")\n                .replace(\".base_layer.bias\", \".bias\")\n            )\n            state_dict[key] = state_dict.pop(name)\n\n        return lora_path\n\n    def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n        with init_empty_weights():\n            model = auto_model_class.from_config(\n                self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code\n            )\n        model.to_empty(device=\"cpu\")\n        model = self.patch_model_generation_config(model)\n\n        lora_path = self.save_lora_adapter(state_dict)\n        if lora_path:\n            print(f\"Saving lora adapter to {lora_path}\")\n\n        print(f\"Saving model to {self.config.target_dir}\")\n        model.save_pretrained(self.config.target_dir, state_dict=state_dict)\n        del state_dict\n        del model\n\n        processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n        tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n        if processor is not None:\n            print(f\"Saving processor to {self.config.target_dir}\")\n            processor.save_pretrained(self.config.target_dir)\n        if tokenizer is not None:\n            print(f\"Saving tokenizer to {self.config.target_dir}\")\n            tokenizer.save_pretrained(self.config.target_dir)\n\n    def upload_to_huggingface(self):\n        import requests\n        from huggingface_hub import HfApi\n        from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError\n\n        api = HfApi()\n        try:\n            # Attempt to create repository\n            api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True)\n        except HfHubHTTPError as e:\n            # Handle authentication/API errors\n            if e.response.status_code == 401:\n                raise PermissionError(\n                    \"Hugging Face authentication failed. 
Verify your token is valid and has write permissions.\"\n                ) from e\n            elif e.response.status_code == 404:\n                raise RepositoryNotFoundError(f\"Repository path not found: {self.config.hf_upload_path}\") from e\n            else:\n                raise ConnectionError(f\"Failed to create repository ({e.response.status_code}): {e}\") from e\n        except requests.exceptions.ConnectionError as e:\n            raise ConnectionError(\"Network connection failed. Check your internet connection.\") from e\n\n        try:\n            # Attempt folder upload\n            api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type=\"model\")\n        except HfHubHTTPError as e:\n            if e.response.status_code == 401:\n                raise PermissionError(\"Authentication failed during upload. Token may have expired.\") from e\n            else:\n                raise RuntimeError(f\"Upload failed ({e.response.status_code}): {e}\") from e\n        except requests.exceptions.ConnectionError as e:\n            raise ConnectionError(\"Network interruption during upload. Try again with a stable connection.\") from e\n        except OSError as e:\n            raise FileNotFoundError(f\"Local folder error: {self.config.target_dir} - {str(e)}\") from e\n        except Exception as e:\n            raise RuntimeError(f\"Unexpected error during upload: {str(e)}\") from e\n\n    @abstractmethod\n    def merge_and_save(self):\n        raise NotImplementedError(\"Subclasses should implement this method\")\n\n    @abstractmethod\n    def cleanup(self):\n        raise NotImplementedError(\"Subclasses should implement this method to clean up resources if needed\")\n"
  },
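`ModelMergerConfig.__post_init__` derives `hf_upload` and clears the merge-only fields for `test` runs. A quick check of those semantics (paths are placeholders):

```python
from verl.model_merger.base_model_merger import ModelMergerConfig

merge_cfg = ModelMergerConfig(
    operation="merge",
    backend="megatron",
    local_dir="ckpt/global_step_1/actor",
    hf_model_config_path="ckpt/global_step_1/actor/huggingface",
    target_dir="out",
    hf_upload_path="org/my-model",  # placeholder repo id
)
assert merge_cfg.hf_upload  # merge + hf_upload_path set => upload enabled

test_cfg = ModelMergerConfig(
    operation="test",
    backend="megatron",
    local_dir="ckpt/global_step_1/actor",
    hf_model_config_path="ckpt/global_step_1/actor/huggingface",
    target_dir="out",  # cleared by __post_init__ for test runs
)
assert test_cfg.target_dir is None and not test_cfg.hf_upload
```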
  {
    "path": "verl_rl/verl/model_merger/fsdp_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nfrom concurrent.futures import ThreadPoolExecutor\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom torch.distributed._tensor import Placement, Shard\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom tqdm import tqdm\n\nfrom .base_model_merger import BaseModelMerger\n\n\nclass FSDPModelMerger(BaseModelMerger):\n    \"\"\"\n    Model merger for FSDP (Fully Sharded Data Parallel) checkpoints.\n\n    This class handles the conversion of FSDP distributed checkpoints into HuggingFace format.\n    FSDP shards model parameters across multiple processes, and this merger reconstructs\n    the full model by loading and concatenating the sharded parameters from all ranks.\n\n    The merger supports various FSDP configurations including:\n    - Pure FSDP (single dimension sharding)\n    - FSDP + DDP (data parallel + fully sharded data parallel)\n    - DTensor-based sharding with custom device meshes\n\n    Key features:\n    - Automatic detection of world size from checkpoint filenames\n    - Support for DTensor and non-DTensor checkpoints\n    - Parallel loading of checkpoint shards for efficiency\n    - Validation against reference HuggingFace models\n\n    Example:\n        To merge FSDP checkpoints:\n        ```python\n        config = ModelMergerConfig(\n            operation=\"merge\",\n            backend=\"fsdp\",\n            local_dir=\"path/to/fsdp/checkpoints\",\n            target_dir=\"path/to/output\"\n        )\n        merger = FSDPModelMerger(config)\n        merger.merge_and_save()\n        ```\n    \"\"\"\n\n    def _get_world_size(self) -> int:\n        \"\"\"_summary_\n        From FSDP json config file, extract the world size.\n\n        Returns:\n            int: world size\n        \"\"\"\n        config_path = Path(self.config.local_dir) / \"fsdp_config.json\"\n        if not config_path.exists():\n            raise FileNotFoundError(f\"Config file {config_path} does not exist.\")\n\n        with open(config_path) as f:\n            config = json.load(f)\n\n        # Extract world size from the config\n        world_size = config.get(\"world_size\", None)\n        if world_size is None:\n            raise ValueError(\"World size not found in the config file.\")\n\n        return world_size\n\n    def _load_rank_zero_state_dict(self, world_size: int) -> dict:\n        return torch.load(\n            Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_0.pt\",\n            map_location=\"cpu\",\n            weights_only=False,\n        )\n\n    def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]:\n        \"\"\"\n        Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict.\n        If no 
DTensor is found, infers a simple FSDP mesh based on world_size.\n        \"\"\"\n        pivot_key = sorted(list(state_dict.keys()))[0]\n        weight = state_dict[pivot_key]\n\n        if isinstance(weight, DTensor):\n            # get sharding info\n            device_mesh = weight.device_mesh\n            mesh = device_mesh.mesh\n            mesh_dim_names = device_mesh.mesh_dim_names\n        else:\n            # for non-DTensor\n            mesh = np.array([world_size], dtype=np.int64)\n            mesh_dim_names = (\"fsdp\",)\n\n        return mesh, mesh_dim_names\n\n    def _calculate_shard_configuration(\n        self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]\n    ) -> tuple[int, tuple[int, ...]]:\n        \"\"\"Calculates the total number of shards and the shape of the device mesh.\"\"\"\n        assert mesh_dim_names in ((\"fsdp\",), (\"ddp\", \"fsdp\")), f\"Unsupported mesh_dim_names {mesh_dim_names}\"\n\n        if \"tp\" in mesh_dim_names:\n            # TODO: \"tp\" is not supported yet due to the above assert\n            total_shards = mesh.shape[-1] * mesh.shape[-2]\n            mesh_shape = (mesh.shape[-2], mesh.shape[-1])\n        else:\n            total_shards = mesh.shape[-1]\n            mesh_shape = (mesh.shape[-1],)\n\n        return total_shards, mesh_shape\n\n    def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor:\n        \"\"\"Merges a list of tensors based on their DTensor placement\"\"\"\n        if placement.is_replicate():\n            return tensors[0]\n        elif placement.is_partial():\n            raise NotImplementedError(\"Partial placement is not supported yet\")\n        elif placement.is_shard():\n            return torch.cat(tensors, dim=placement.dim).contiguous()\n\n        raise NotImplementedError(f\"Unsupported placement: {placement}\")\n\n    def _load_and_merge_state_dicts(\n        self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]\n    ) -> dict[str, torch.Tensor]:\n        model_state_dict_lst = [None] * total_shards\n\n        def process_one_shard(rank: int, model_state_dict_lst: list):\n            model_path = Path(self.config.local_dir) / f\"model_world_size_{world_size}_rank_{rank}.pt\"\n            state_dict = torch.load(model_path, map_location=\"cpu\", weights_only=False)\n            model_state_dict_lst[rank] = state_dict\n            return state_dict\n\n        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:\n            futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)]\n            for future in tqdm(futures, desc=f\"Loading {total_shards} FSDP shards\", total=total_shards):\n                future.result()\n\n        # Merge state dicts from all shards\n        state_dict = {}\n        param_placements: dict[str, list] = {}\n\n        for key in set(model_state_dict_lst[0].keys()):\n            state_dict[key] = []\n            for model_state_shard in model_state_dict_lst:\n                # add tensor shard in order of rank to state_dict[key]\n                tensor = model_state_shard.pop(key)\n                if isinstance(tensor, DTensor):\n                    state_dict[key].append(tensor._local_tensor.bfloat16())\n\n                    placements = tuple(tensor.placements)\n                    # replicated placement at dp dimension can be discarded\n                    if mesh_dim_names[0] in (\"dp\", \"ddp\"):\n             
           placements = placements[1:]\n\n                    if key not in param_placements:\n                        param_placements[key] = placements\n                    else:\n                        assert param_placements[key] == placements\n                else:\n                    state_dict[key].append(tensor.bfloat16())\n\n        del model_state_dict_lst\n\n        # Merge tensors\n        for key in sorted(state_dict):\n            if not isinstance(state_dict[key], list):\n                print(f\"No need to merge key {key}\")\n                continue\n            if key in param_placements:\n                # merge shards\n                placements: tuple[Shard] = param_placements[key]\n                if len(mesh_shape) == 1:\n                    # 1-D list, FSDP without TP\n                    assert len(placements) == 1\n                    shards = state_dict[key]\n                    state_dict[key] = self._merge_by_placement(shards, placements[0])\n                else:\n                    # 2-D list, FSDP + TP\n                    raise NotImplementedError(\"FSDP + TP is not supported yet\")\n            else:\n                state_dict[key] = torch.cat(state_dict[key], dim=0)\n\n        return state_dict\n\n    def merge_and_save(self):\n        world_size = self._get_world_size()\n        rank_zero_state_dict = self._load_rank_zero_state_dict(world_size)\n\n        mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size)\n        print(f\"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}\")\n\n        total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names)\n        print(f\"Processing model shards with {total_shards} {mesh_shape} in total\")\n\n        merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names)\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._validate_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        auto_model_class = self.get_transformers_auto_model_class()\n\n        hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16)\n        hf_state_dict = hf_model.state_dict()\n        del hf_model\n\n        hf_model_keys = set(hf_state_dict.keys())\n        collected_keys = set(state_dict.keys())\n\n        missing_keys = hf_model_keys - collected_keys\n        assert len(missing_keys) == 0, f\"Missing keys in collected state dict: {list(sorted(missing_keys))}\"\n\n        extra_keys = collected_keys - hf_model_keys\n        assert len(extra_keys) == 0, f\"Extra keys in collected state dict: {list(sorted(extra_keys))}\"\n\n        for key in hf_model_keys:\n            hf_shape = hf_state_dict[key].shape\n            collected_shape = state_dict[key].shape\n            assert hf_shape == collected_shape, (\n                f\"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}\"\n            )\n\n            hf_dtype = hf_state_dict[key].dtype\n            
collected_dtype = state_dict[key].dtype\n            assert hf_dtype == collected_dtype, (\n                f\"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}\"\n            )\n\n            torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6)\n\n        print(\"FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.\")\n\n    def cleanup(self):\n        \"\"\"Cleanup temporary files if needed.\"\"\"\n        # FSDP merger does not create temporary files, so no cleanup is needed.\n        pass\n"
  },
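At its core, the FSDP merge rule implemented by `_merge_by_placement` is simple: for a `Shard(dim)` placement, concatenate each rank's local shard along `dim` in rank order; for a replicated placement, take rank 0's copy. A self-contained toy illustration with plain tensors (no DTensor involved):

```python
import torch

world_size = 4
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# What each rank would hold on disk under a Shard(0) placement.
local_shards = list(full_weight.chunk(world_size, dim=0))

# The merge step: concatenate local shards along the sharded dim, in rank order.
merged = torch.cat(local_shards, dim=0).contiguous()
assert torch.equal(merged, full_weight)
```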
  {
    "path": "verl_rl/verl/model_merger/megatron_model_merger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nimport warnings\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Any, Callable, ContextManager\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom accelerate import init_empty_weights\nfrom megatron.core import mpu\nfrom megatron.core.models.gpt.gpt_model import ModelType\nfrom megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed\nfrom safetensors.torch import load_file\nfrom transformers import (\n    AutoConfig,\n    PretrainedConfig,\n)\n\nfrom verl.models.mcore import hf_to_mcore_config\nfrom verl.utils.device import get_device_name, get_nccl_backend, get_torch_device\nfrom verl.utils.megatron.dist_checkpointing import load_dist_checkpointing\nfrom verl.utils.megatron_utils import get_model\nfrom verl.utils.tokenizer import hf_processor, hf_tokenizer\n\nfrom .base_model_merger import BaseModelMerger, ModelMergerConfig\n\n\n@contextmanager\ndef noop_context() -> Any:\n    yield\n\n\ndef get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]:\n    \"\"\"Calculate the pipeline sharding configuration for Megatron-LM.\n\n    Args:\n        layer_num: Total number of layers in the model.\n        pp_size: Number of pipeline parallel ranks.\n\n    Returns:\n        layer number of each pp rank. 
Make the sharding of the pipeline as uniform as possible.\n    \"\"\"\n    if layer_num < pp_size:\n        raise ValueError(f\"layer_num {layer_num} must be at least pp_size {pp_size}.\")\n\n    if pp_size < 1:\n        raise ValueError(f\"pp_size must be at least 1, got {pp_size}.\")\n    if pp_size == 1:\n        return [layer_num]\n\n    if pp_size == 2:\n        return [\n            layer_num // 2,\n            layer_num - layer_num // 2,\n        ]\n\n    middle_size = pp_size - 2\n    shards_strategy = []\n    for middle_layer_num in range(layer_num):\n        first_last_layer_num = layer_num - middle_layer_num * middle_size\n        first_layer_num = first_last_layer_num // 2\n        last_layer_num = first_last_layer_num - first_last_layer_num // 2\n        if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num:\n            shards_strategy.append(\n                (\n                    [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num],\n                    abs(first_layer_num - middle_layer_num),\n                )\n            )\n\n    # sort by diff of layer_num, to make it as uniform as possible\n    res = sorted(shards_strategy, key=lambda x: x[1])[0][0]\n    assert sum(res) == layer_num, f\"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}\"\n    return res\n\n\nclass MegatronModelMerger(BaseModelMerger):\n    \"\"\"\n    Model merger for Megatron-LM distributed checkpoints.\n\n    This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format.\n    Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute\n    large language models across multiple GPUs. This merger reconstructs the full model by\n    loading distributed checkpoints and applying the necessary transformations.\n\n    Key features:\n    - Support for tensor parallel, pipeline parallel, and data parallel configurations\n    - Automatic parameter name mapping from Megatron to HuggingFace conventions\n    - Handling of QKV and gate-up tensor splitting/merging\n    - Support for tied word embeddings and value models\n    - Integration with Megatron's distributed checkpointing system\n\n    The merger handles various model architectures and configurations:\n    - Standard transformer models (GPT-style)\n    - Models with tied word embeddings\n    - Value models for reinforcement learning\n    - Multi-head latent attention (MLA) architectures\n    - Mixture of Experts (MoE) models\n\n    Args:\n        config (ModelMergerConfig): Configuration object with Megatron-specific settings\n            including tie_word_embedding and is_value_model flags.\n\n    Example:\n        To merge Megatron checkpoints:\n        ```python\n        config = ModelMergerConfig(\n            operation=\"merge\",\n            backend=\"megatron\",\n            local_dir=\"path/to/megatron/checkpoints\",\n            hf_model_config_path=\"path/to/megatron/checkpoints/huggingface\",\n            target_dir=\"path/to/output\",\n            tie_word_embedding=True\n        )\n        merger = MegatronModelMerger(config)\n        merger.merge_and_save()\n        ```\n    \"\"\"\n\n    def __init__(self, config: ModelMergerConfig):\n        super().__init__(config)\n        # Currently we use only 1 rank to merge the dist_ckpt; we will move to multi-process saving later\n        if \"WORLD_SIZE\" not in os.environ:\n            os.environ[\"RANK\"] = \"0\"\n            os.environ[\"LOCAL_RANK\"] = \"0\"\n            os.environ[\"WORLD_SIZE\"] = \"1\"\n            
os.environ[\"MASTER_ADDR\"] = \"localhost\"\n            os.environ[\"MASTER_PORT\"] = \"12355\"\n\n        torch.distributed.init_process_group(get_nccl_backend())\n\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n        local_rank = os.environ.get(\"LOCAL_RANK\", 0)\n        get_torch_device().set_device(f\"{get_device_name()}:{local_rank}\")\n\n        mpu.initialize_model_parallel(\n            tensor_model_parallel_size=1,\n            pipeline_model_parallel_size=self.world_size,\n            virtual_pipeline_model_parallel_size=None,\n            context_parallel_size=1,\n            expert_model_parallel_size=1,\n        )\n        model_parallel_cuda_manual_seed(0)\n        self.hf_config = AutoConfig.from_pretrained(\n            self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code\n        )\n        print(self.hf_config, flush=True)\n\n        self.params_mapping = {\n            # megatron core gpt model name, huggingface model name\n            # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the\n            # longer key within the containing relationship is processed first.\n            \"embedding.word_embeddings\": \"model.embed_tokens\",\n            # input layer norm for dpskv3\n            \"input_layernorm.weight\": \"input_layernorm.weight\",\n            \"input_layernorm.bias\": \"input_layernorm.bias\",\n            # attn\n            \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_qkv.layer_norm_bias\": \"input_layernorm.bias\",\n            \"self_attention.linear_qkv\": \"self_attn.qkv_proj\",\n            \"self_attention.q_layernorm\": \"self_attn.q_norm\",\n            \"self_attention.k_layernorm\": \"self_attn.k_norm\",\n            \"self_attention.linear_proj\": \"self_attn.o_proj\",\n            # mla\n            \"self_attention.linear_q_proj\": \"self_attn.q_proj\",\n            \"self_attention.linear_q_down_proj\": \"self_attn.q_a_proj\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n            \"self_attention.linear_q_up_proj\": \"self_attn.q_b_proj\",\n            \"self_attention.linear_kv_down_proj\": \"self_attn.kv_a_proj_with_mqa\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj\": \"self_attn.kv_b_proj\",\n            # mlp\n            \"pre_mlp_layernorm\": \"post_attention_layernorm\",\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc1.layer_norm_bias\": \"post_attention_layernorm.bias\",\n            \"mlp.linear_fc1\": \"mlp.gate_up_proj\",\n            \"mlp.linear_fc2\": \"mlp.down_proj\",\n            # moe\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n            \"mlp.router\": \"mlp.gate\",\n            \"mlp.shared_experts.linear_fc1\": \"mlp.shared_experts.gate_up_proj\",\n            \"mlp.shared_experts.linear_fc2\": \"mlp.shared_experts.down_proj\",\n            \"linear_fc1\": \"gate_up_proj\",\n            \"linear_fc2\": \"down_proj\",\n            # output\n            \"final_layernorm\": \"norm\",\n            \"output_layer\": \"lm_head\",\n        }\n\n        if \"Qwen2MoeForCausalLM\" in self.hf_config.architectures:\n            
self.params_mapping[\"mlp.shared_experts.linear_fc1\"] = \"mlp.shared_expert.gate_up_proj\"\n            self.params_mapping[\"mlp.shared_experts.linear_fc2\"] = \"mlp.shared_expert.down_proj\"\n            self.params_mapping[\"mlp.shared_experts.gate_weight\"] = \"mlp.shared_expert_gate.weight\"\n\n    def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:\n        \"\"\"_summary_\n        Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.\n\n        Args:\n            model_ckpt_path (str): Path to the model checkpoint directory.\n\n        Returns:\n            State dict containing the model parameters.\n        \"\"\"\n\n        # init hf config\n        self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size)\n        print(f\"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}\")\n\n        tf_config = hf_to_mcore_config(\n            self.hf_config,\n            torch.bfloat16,\n            num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None,\n            num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None,\n        )\n        tf_config.use_cpu_initialization = self.config.use_cpu_initialization\n        tie_word_embeddings = getattr(self.hf_config, \"tie_word_embeddings\", False)\n\n        # init megatron model\n        def megatron_model_provider(pre_process, post_process):\n            from verl.models.mcore import init_mcore_model\n\n            parallel_model = init_mcore_model(\n                tf_config,\n                self.hf_config,\n                pre_process,\n                post_process,\n                share_embeddings_and_output_weights=tie_word_embeddings,\n                value=False,\n            )\n            return parallel_model\n\n        context: Callable[..., ContextManager] = (\n            init_empty_weights if self.config.use_cpu_initialization else noop_context\n        )\n        with context():\n            whole_model = get_model(\n                model_provider_func=megatron_model_provider,\n                model_type=ModelType.encoder_or_decoder,\n                wrap_with_ddp=False,\n                transformer_config=tf_config,\n            )\n\n        if self.config.use_cpu_initialization:\n            # convert meta device to empty tensor so it can use `copy_` function\n            whole_model[0].module = whole_model[0].module.to_empty(device=\"cpu\")\n\n        # load state dicts\n        sharded_state_dict = {}\n        for vpp_rank, model in enumerate(whole_model):\n            key = f\"model{vpp_rank}\" if len(whole_model) > 1 else \"model\"\n            mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n            sharded_state_dict[key] = model.sharded_state_dict()\n        model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path)\n        model_state_dict_list = []\n        for vpp_rank, model in enumerate(whole_model):\n            key = f\"model{vpp_rank}\" if len(whole_model) > 1 else \"model\"\n            mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n            model_state_dict_list.append(model_state_dict[key])\n\n        return model_state_dict_list\n\n    def _check_megatron_state_key(self, key: str) -> bool:\n        \"\"\"\n        Checks if the key is a valid Megatron state key.\n\n        Now the model merger only supports keys that start with 
\"decoder/embedding/output_layer\" in TransformerLayer.\n        Shall not use key starts with \"model.\"\n        \"\"\"\n        if key.startswith(\"model.\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with \"\n                f\"'decoder/embedding/output_layer' in TransformerLayer.\"\n            )\n\n        skip_checking_keys = [\"embedding.word_embeddings\", \"output_layer\"]\n        for skip_key in skip_checking_keys:\n            if skip_key in key:\n                print(f\"skip checking key {key}\")\n                return\n\n        # Exclude extra state keys\n        if not key.startswith(\"decoder\"):\n            raise ValueError(\n                f\"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer.\"\n            )\n\n    def _split_tensors(\n        self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Splits a tensor into multiple tensors based on the name.\n        This is used to handle qkv and gate_up tensors.\n        \"\"\"\n        if \"linear_fc1.weight\" in key:\n            # if the tensor is gate and proj\n            gate_lst = []\n            up_lst = []\n            gate, up = tensor.chunk(2)\n            gate_lst.append(gate)\n            up_lst.append(up)\n            gate = torch.cat(gate_lst, dim=0)\n            up = torch.cat(up_lst, dim=0)\n            return [gate, up]\n        elif \"self_attention.linear_qkv.\" in key and \"layer_norm\" not in key:\n            # if the tensor is qkv, for each param on tp, split into q, k, v\n            # concat q, k, v separately.\n            q_lst, k_lst, v_lst = [], [], []\n            assert config.num_attention_heads % config.num_key_value_heads == 0\n            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads\n            assert tensor.shape[0] % (num_q_per_kv + 2) == 0, (\n                f\"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}\"\n            )\n            kv_size = tensor.shape[0] // (num_q_per_kv + 2)\n            split_size = [kv_size * num_q_per_kv, kv_size, kv_size]\n\n            num_query_groups_per_partition = config.num_key_value_heads\n            for chunk in tensor.chunk(num_query_groups_per_partition):\n                split_size = [\n                    kv_size * num_q_per_kv // num_query_groups_per_partition,\n                    kv_size // num_query_groups_per_partition,\n                    kv_size // num_query_groups_per_partition,\n                ]\n                q, k, v = chunk.split(split_size)\n                q_lst.append(q)\n                k_lst.append(k)\n                v_lst.append(v)\n\n            return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)]\n        else:\n            return [tensor]\n\n    def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:\n        state_dict = {}\n        layers_cum = 0\n        if self.world_size > 1:\n            pipeline_cumsum = np.cumsum(self.pipeline_shards)\n            layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]\n\n        print(f\"{layers_cum=}\")\n        for model_state_dict in model_state_dict_list:\n            layers_handled = 0\n            keys = model_state_dict.keys()\n            for key in keys:\n                if \"extra_state\" in key:\n                 
   continue\n                if self.config.tie_word_embedding and (\"output_layer\" in key):\n                    print(\"skip lm_head and reward_head loading because of tie_word_embeddings\")\n                    continue\n\n                self._check_megatron_state_key(key)\n                hf_name = self._replace_name(key, self.params_mapping)\n                assert hf_name is not None, f\"Failed to convert layer name [{key}] from megatron to huggingface.\"\n                if \"model.layers.\" in hf_name:\n                    local_layer_no = int(hf_name.split(\".\")[2])\n                    layers_handled = max(local_layer_no, layers_handled)\n                    global_layer_no = local_layer_no + layers_cum\n                    new_key_list = hf_name.split(\".\")\n                    new_key_list[2] = str(global_layer_no)\n                    hf_name = \".\".join(new_key_list)\n                else:\n                    warnings.warn(f\"hf_name {hf_name} will not be fixed with layer number\", stacklevel=2)\n\n                if \"mlp.experts.\" in hf_name and \".weight\" in hf_name:\n                    name_prefix, expert_id = hf_name.split(\".weight\")\n                    for proj in [\"gate_up\", \"down\"]:\n                        if f\"{proj}_proj\" in hf_name:\n                            hf_name = hf_name.replace(\n                                f\"mlp.experts.{proj}_proj.weight{expert_id}\",\n                                f\"mlp.experts.{expert_id}.{proj}_proj.weight\",\n                            )\n\n                tensor = model_state_dict[key]\n                split_tensor = self._split_tensors(\n                    key, tensor, self.hf_config, is_value_model=self.config.is_value_model\n                )\n\n                if len(split_tensor) == 1:\n                    state_dict[hf_name] = split_tensor[0]\n                elif len(split_tensor) == 3:\n                    # split qkv\n                    for n, d in zip([\"q\", \"k\", \"v\"], split_tensor, strict=True):\n                        state_dict[hf_name.replace(\"qkv\", n)] = d\n                elif len(split_tensor) == 2:\n                    # split gate up\n                    state_dict[hf_name.replace(\"gate_up\", \"gate\")] = split_tensor[0]\n                    state_dict[hf_name.replace(\"gate_up\", \"up\")] = split_tensor[1]\n                shape_info = (\n                    split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor]\n                )\n                print(f\"converted {key} to {hf_name} with shape {shape_info}\")\n\n            layers_cum += layers_handled + 1  # zero based\n\n        return state_dict\n\n    def save_hf_model_and_tokenizer(self, merged_state_dict):\n        if self.world_size == 1:\n            return super().save_hf_model_and_tokenizer(merged_state_dict)\n\n        from safetensors.torch import save_file\n\n        layer_num = self.hf_config.num_hidden_layers\n\n        # FIXME: make configurable\n        saves_per_layer = 1 if layer_num < 30 else 2\n        saves_total = saves_per_layer * layer_num\n        saves_indexes = {}\n\n        # calculate the layer start index and key chunks\n        layer_this_rank = self.pipeline_shards[self.rank]\n        pipeline_cumsum = np.cumsum(self.pipeline_shards)\n        layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1]\n        keys = list(merged_state_dict.keys())\n        keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer)\n    
numel = 0\n\n        assert len(keys_chunk) == layer_this_rank * saves_per_layer, (\n            f\"Expected {layer_this_rank * saves_per_layer} chunks, but got {len(keys_chunk)} for rank {self.rank}.\"\n        )\n\n        # save to model shards manually\n        target_dir = Path(self.config.target_dir)\n        for i, keys in enumerate(keys_chunk):\n            sd_to_save = {k: merged_state_dict[k] for k in keys}\n            numel += sum(v.numel() for v in sd_to_save.values())\n            save_idx = layer_start * saves_per_layer + i\n            save_path = target_dir / f\"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors\"\n\n            save_file(sd_to_save, save_path)\n            for k in keys:\n                saves_indexes[k] = str(save_path.name)\n\n        tensor = torch.tensor([numel]).to(get_device_name())\n        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n        numel = tensor.cpu().item()\n\n        all_save_indexes = [{} for _ in range(self.world_size)]\n        dist.all_gather_object(all_save_indexes, saves_indexes)\n        saves_indexes = {k: v for d in all_save_indexes for k, v in d.items()}\n        if self.rank == 0:\n            with open(target_dir / \"model.safetensors.index.json\", \"w\") as f:\n                json.dump(\n                    {\n                        \"metadata\": {\n                            \"total_size\": numel,\n                        },\n                        \"weight_map\": saves_indexes,\n                    },\n                    f,\n                    indent=4,\n                )\n            print(f\"model saved to {target_dir} with {numel=}\")\n\n            self.model_config.save_pretrained(self.config.target_dir)\n\n            processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n            tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code)\n            if processor is not None:\n                print(f\"Saving processor to {self.config.target_dir}\")\n                processor.save_pretrained(self.config.target_dir)\n            if tokenizer is not None:\n                print(f\"Saving tokenizer to {self.config.target_dir}\")\n                tokenizer.save_pretrained(self.config.target_dir)\n\n    def merge_and_save(self):\n        from verl.utils.megatron_utils import get_dist_checkpoint_path\n\n        model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir)\n\n        model_state_dict = self._load_state_dicts(model_ckpt_path)\n        merged_state_dict = self._merge_state_dicts(model_state_dict)\n        del model_state_dict\n\n        if self.config.operation == \"test\":\n            if not self.config.test_hf_dir:\n                raise ValueError(\"test_hf_dir must be provided for test operation\")\n            self._validate_state_dict(merged_state_dict)\n        elif self.config.operation == \"merge\":\n            self.save_hf_model_and_tokenizer(merged_state_dict)\n            if self.config.hf_upload:\n                self.upload_to_huggingface()\n        else:\n            raise ValueError(f\"Unknown operation: {self.config.operation}\")\n\n    def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]):\n        \"\"\"\n        Compares the merged Megatron state_dict against a reference safetensors model.\n        Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name.\n        \"\"\"\n        ref_state_dict = load_file(Path(self.config.test_hf_dir) / \"model.safetensors\")\n\n        for name, loaded_weight in state_dict.items():\n            # name = self._replace_name(original_name, self.params_mapping)\n            if not name or (name.endswith(\".bias\") and name not in ref_state_dict):\n                continue\n            if \"rotary_emb.inv_freq\" in name:\n                continue\n            if \"lm_head.weight\" in name:\n                if self.config.is_value_model or self.config.tie_word_embedding:\n                    continue\n            if name not in ref_state_dict:\n                raise RuntimeError(f\"key: {name} does not exist in the reference state_dict\")\n            param = ref_state_dict[name]\n            assert loaded_weight.dtype == param.dtype\n            torch.testing.assert_close(loaded_weight.to(\"cpu\"), param, atol=1e-2, rtol=5e-2)\n\n    def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str | None:\n        for m_name, v_name in name_mapping.items():\n            if m_name not in megatron_name:\n                continue\n\n            megatron_name = megatron_name.replace(\"decoder\", \"model\")\n            param_name = megatron_name.replace(m_name, v_name)\n\n            return param_name\n\n        return None  # no mapping found for this megatron key\n\n    def cleanup(self):\n        torch.distributed.destroy_process_group()\n"
  },
  {
    "path": "verl_rl/verl/models/README.md",
    "content": "# Models\nCommon modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. \n## Adding a New Huggingface Model\n### Step 1: Copy the model file from HF to verl\n- Add a new file under verl/models/hf\n- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf\n\n### Step 2: Modify the model file to use packed inputs\n- Remove all the code related to inference (kv cache)\n- Modify the inputs to include only\n    - input_ids (total_nnz,)\n    - cu_seqlens (total_nnz + 1,)\n    - max_seqlen_in_batch: int\n- Note that this requires using flash attention with causal mask.\n\n### Step 2.5: Add tests\n- Add a test to compare this version and the huggingface version\n- Following the infrastructure and add tests to tests/models/hf\n\n### Step 3: Add a function to apply tensor parallelism\n- Please follow\n    - https://pytorch.org/docs/stable/distributed.tensor.parallel.html\n    - https://pytorch.org/tutorials/intermediate/TP_tutorial.html\n- General comments\n    - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward.\n\n### Step 4: Add a function to apply data parallelism\n- Please use FSDP2 APIs\n- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413\n\n### Step 5: Add a function to apply pipeline parallelism\n- Comes in Pytorch 2.4\n- Currently only in alpha in nightly version\n- Check torchtitan for more details\n\n"
  },
  {
    "path": "verl_rl/verl/models/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/llama/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_llama_megatron import (\n    ParallelLlamaForCausalLM,\n    # rmpad with megatron\n    ParallelLlamaForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelLlamaForCausalLMRmPadPP,\n    ParallelLlamaForValueRmPad,\n    ParallelLlamaForValueRmPadPP,\n    # original model with megatron\n    ParallelLlamaModel,\n)\n\n__all__ = [\n    \"ParallelLlamaForCausalLM\",\n    \"ParallelLlamaForCausalLMRmPad\",\n    \"ParallelLlamaForCausalLMRmPadPP\",\n    \"ParallelLlamaForValueRmPad\",\n    \"ParallelLlamaForValueRmPadPP\",\n    \"ParallelLlamaModel\",\n]\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/checkpoint_utils/llama_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not 
isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor.data.copy_(state_dict[name])\n\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp 
* 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(\n                total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(\n                total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor.data.copy_(tensor_chunk[tp_rank])\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    embed_tokens_weight = None\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n    _fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = 
mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (\n                mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk\n            )\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print_rank_0(f\"loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    _fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    print_rank_0(\"loading lm_head...\")\n    if pp_rank + 1 == pp_size:\n        lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _fetch_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            
_fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    
if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} 
tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(\n 
                   torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * 
q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n  
              chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/checkpoint_utils/llama_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), (\n        f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    )\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface 
with qwen2\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank()}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].model.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] does not exist, skip collecting\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, 
dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n   
 get_torch_device().empty_cache()\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        print_rank_0(\"collecting lm_head...\")\n\n        if is_value_model:\n            if pp_rank == pp_size - 1:\n                print(f\"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}\")\n            _broadcast_tensor(\n                gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n            _broadcast_tensor(\n                gpt_model_module.reward_head.weight\n                if pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_head\", None) is not None\n                else None,\n                \"reward_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n        else:\n            _broadcast_tp_shard_tensor(\n                getattr(gpt_model_module.lm_head, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        if dtype not in [torch.float16, torch.bfloat16, torch.float32]:\n            print(f\"Unknown/unsupported dtype to save: {dtype}\")\n            exit(1)\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
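The `_broadcast_tp_shard_tensor_qkv` helper above re-splits each rank's fused [q | k | v] shard before concatenating per-projection slices into HuggingFace-style q/k/v weights. Here is a minimal single-process sketch of that un-interleaving (hypothetical sizes, plain torch, no distributed collectives, covering only the `num_key_value_heads >= tp_size` branch):

import torch

# Illustrative (made-up) sizes, not taken from any real config.
hidden_size, num_heads, num_kv_heads, tp_size = 16, 4, 4, 2
head_dim = hidden_size // num_heads
q_size_tp = hidden_size // tp_size                  # q rows owned by one TP rank
kv_size_tp = head_dim * num_kv_heads // tp_size     # k (or v) rows owned by one TP rank
total_size = q_size_tp + 2 * kv_size_tp             # one rank's fused [q | k | v] shard

# full_tensor stands in for the row-wise concat of all per-rank fused shards.
full_tensor = torch.randn(tp_size * total_size, hidden_size)

q_parts, k_parts, v_parts = [], [], []
for i in range(tp_size):
    shard = full_tensor[i * total_size : (i + 1) * total_size]
    q_parts.append(shard[:q_size_tp])
    k_parts.append(shard[q_size_tp : q_size_tp + kv_size_tp])
    v_parts.append(shard[q_size_tp + kv_size_tp :])

q_proj = torch.cat(q_parts, dim=0)  # (hidden_size, hidden_size)
k_proj = torch.cat(k_parts, dim=0)  # (num_kv_heads * head_dim, hidden_size)
v_proj = torch.cat(v_parts, dim=0)
assert q_proj.shape == (hidden_size, hidden_size)
assert k_proj.shape == v_proj.shape == (num_kv_heads * head_dim, hidden_size)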
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelLlamaAttention\nfrom .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad\nfrom .parallel_linear import (\n    LinearForLastLayer,\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n)\nfrom .parallel_mlp import ParallelLlamaMLP\nfrom .parallel_rmsnorm import ParallelLlamaRMSNorm\n\n__all__ = [\n    \"LinearForLastLayer\",\n    \"MergedColumnParallelLinear\",\n    \"QKVParallelLinear\",\n    \"ParallelLlamaAttention\",\n    \"ParallelLlamaDecoderLayer\",\n    \"ParallelLlamaDecoderLayerRmPad\",\n    \"ParallelLlamaMLP\",\n    \"ParallelLlamaRMSNorm\",\n]\n"
  },
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import LlamaConfig\nfrom transformers.utils import is_flash_attn_2_available\n\nfrom verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__(dim, max_position_embeddings, base, device)\n\n        self.factor = config.rope_scaling[\"factor\"]  # `8` in the original implementation\n        self.high_freq_factor = config.rope_scaling[\"high_freq_factor\"]  # `4` in the original implementation\n        self.low_freq_factor = config.rope_scaling[\"low_freq_factor\"]  # `1` in the original implementation\n        self.old_context_len = config.rope_scaling[\n            \"original_max_position_embeddings\"\n        ]  # `8192` in the original implementation\n\n        low_freq_wavelen = self.old_context_len / self.low_freq_factor\n        high_freq_wavelen = self.old_context_len / self.high_freq_factor\n\n        wavelen = 2 * math.pi / self.inv_freq\n        # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor\n        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)\n        # otherwise: interpolate between the two, using a smooth factor\n        smooth_factor = (self.old_context_len 
/ wavelen - self.low_freq_factor) / (\n            self.high_freq_factor - self.low_freq_factor\n        )\n        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama\n        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelLlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, (\n            f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        )\n        assert self.num_key_value_heads % tp_size == 0, (\n            f\"num_key_value_heads must be divisible by tp_size. 
Got num_key_value_heads=\"\n            f\"{self.num_key_value_heads}, tp_size={tp_size}\"\n        )\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and \"\n                f\"`num_heads`: {self.num_heads}).\"\n            )\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            bias=config.attention_bias,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            bias=config.attention_bias,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = LlamaRotaryEmbedding(\n                self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            rope_type_key = \"type\" if \"type\" in self.config.rope_scaling else \"rope_type\"\n            scaling_type = self.config.rope_scaling[rope_type_key]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"llama3\":\n                self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding(\n                    self.head_dim,\n                    self.config,\n                    
max_position_embeddings=self.max_position_embeddings,\n                    base=self.rope_theta,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, \"\n                f\"but is {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, \"\n                f\"but is {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, 
seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin should be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(\n        q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    k_embed = apply_rotary_emb(\n        k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    return q_embed, k_embed\n\n\nclass ParallelLlamaAttentionRmPad(ParallelLlamaAttention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split(\n            [self.q_size, self.k_size, self.v_size], dim=-1\n        )  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(\n            query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch\n        )\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin,\n        # position_ids, indices,\n\n        # TODO: llama does not have dropout in the config??\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,\n        # so the input hidden states get silently cast to float32. Hence, we need to\n        # cast them back to float16 just to be sure everything works as expected.\n        # This might slow down training & inference, so it is recommended not to cast the LayerNorms\n        # to fp32. (LlamaRMSNorm handles it correctly)\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
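The rotary helpers in parallel_attention.py above (`rotate_half`, `apply_rotary_pos_emb`) implement the "concatenated halves" RoPE variant. Below is a self-contained sketch with illustrative shapes (not repo code); the final assert checks the defining property that the rotation is an isometry, i.e. it preserves per-position query norms:

import torch

def rotate_half(x):
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, n_heads, seq_len, head_dim = 1, 2, 8, 4
q = torch.randn(bsz, n_heads, seq_len, head_dim)

# Same cos/sin cache construction as LlamaRotaryEmbedding._set_cos_sin_cache
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)            # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq_len).unsqueeze(0)  # (bsz, seq_len)
cos_sel = cos[position_ids].unsqueeze(1)           # (bsz, 1, seq_len, head_dim)
sin_sel = sin[position_ids].unsqueeze(1)
q_rot = q * cos_sel + rotate_half(q) * sin_sel

# The rotation preserves the norm of each (position, head) vector.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)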
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad\nfrom .parallel_mlp import ParallelLlamaMLP\nfrom .parallel_rmsnorm import ParallelLlamaRMSNorm\n\n\nclass ParallelLlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelLlamaDecoderLayerRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
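For orientation, both decoder layer classes above follow the same pre-norm residual dataflow. A framework-free schematic (attention, MLP, and norms stubbed out as plain callables; shapes illustrative):

import torch

def decoder_layer(hidden, attn_fn, mlp_fn, norm1, norm2):
    # Pre-norm self-attention block with residual connection
    hidden = hidden + attn_fn(norm1(hidden))
    # Pre-norm feed-forward block with residual connection
    hidden = hidden + mlp_fn(norm2(hidden))
    return hidden

x = torch.randn(2, 8, 16)
identity = lambda t: t
assert decoder_layer(x, identity, identity, identity, identity).shape == x.shape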
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\nimport torch\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass LinearForLastLayer(torch.nn.Linear):\n    def __init__(\n        self,\n        input_size,\n        output_size,\n        *,\n        config,\n        bias=True,\n    ):\n        super().__init__(in_features=input_size, out_features=output_size, bias=bias)\n        self.sequence_parallel = config.sequence_parallel\n        if self.sequence_parallel:\n            self.weight.sequence_parallel = True\n\n    def forward(\n        self,\n        input_,\n        weight=None,\n        runtime_gather_output=None,\n    ):\n        logits = super().forward(input_)\n        logits = logits.float()\n        if self.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits, None\n"
  },
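A quick sanity check of the fused output size used by `QKVParallelLinear` above, `(num_heads + 2 * num_key_value_heads) * head_dim`, with hypothetical GQA numbers (illustrative only, not from any config in this repo):

# 32 query heads, 8 KV heads, head_dim 128 (made-up GQA configuration)
num_heads, num_key_value_heads, head_dim = 32, 8, 128
q_out = num_heads * head_dim                      # 4096
kv_out = num_key_value_heads * head_dim           # 1024 for k, 1024 for v
fused_out = (num_heads + 2 * num_key_value_heads) * head_dim
assert fused_out == q_out + 2 * kv_out == 6144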
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelLlamaMLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
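ParallelLlamaMLP above fuses the gate and up projections into one column-parallel matmul and splits the result before the SwiGLU nonlinearity. A minimal single-device sketch of the same computation (tp_size = 1, plain torch.nn.Linear standing in for the parallel linears):

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 44
x = torch.randn(2, 8, hidden_size)

# Merged gate+up projection, then SiLU-gated elementwise product, then down projection.
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

gate, up = gate_up_proj(x).split(intermediate_size, dim=-1)
out = down_proj(F.silu(gate) * up)  # SiLU-gated linear unit (Llama's hidden_act)
assert out.shape == x.shape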
  {
    "path": "verl_rl/verl/models/llama/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelLlamaRMSNorm(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
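For readers without apex installed, here is an unfused reference computation matching what the `fused_rms_norm_affine` call in ParallelLlamaRMSNorm above computes (a sketch; the memory_efficient flag only affects what is kept for backward):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)

x = torch.randn(2, 8, 16)
w = torch.ones(16)
assert rms_norm(x, w).shape == x.shape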
  {
    "path": "verl_rl/verl/models/llama/megatron/modeling_llama_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch LLaMA model with Megatron-style acceleration.\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import CausalLMOutputWithPast\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\nfrom verl.utils.megatron import tensor_parallel as tp_utils\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from meta LLama pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelLlamaModel(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelLlamaDecoderLayer(config, megatron_config, layer_idx=idx) for idx in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLM(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\nclass ParallelLlamaModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=idx) for idx in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLMRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure of the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            total_nnz = cu_seqlens[-1]\n            logits = logits[:total_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(\n            logits, indices, batch_size, seqlen=sequence_length\n        )  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence 
\n\n\nclass ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # the value head consumes sequence-parallel activations, so mark its weight accordingly\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelLlamaModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - this model only contains the layers assigned to this pp stage and vpp chunk\n    - when calling get_model in Megatron, each rank instantiates all the vpp chunks of its pp stage.\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n                num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n            )\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n
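\n        # Worked example (comment only): with num_hidden_layers = 8, pp_size = 2 and\n        # vpp_size = 2, each pp rank owns 4 layers split into 2 virtual chunks of 2:\n        # pp_rank 0 builds global layers [0, 1] (vpp_rank 0) and [4, 5] (vpp_rank 1),\n        # pp_rank 1 builds [2, 3] and [6, 7], matching the offset formula above.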
\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n            hidden states of this pp stage, shape (total_nnz, 1, hidden_size)\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # the vocab parallel embedding in open-source Megatron does not perform the\n            # sequence parallel reduce-scatter itself, so we handle it here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # hidden_states is passed in by Megatron via set_input_tensor\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n
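\n\n# Pipeline flow sketch (comment only): with pp_size = 2, stage 0 runs\n# embed_tokens -> its half of the decoder layers and ships hidden_states to\n# stage 1 over p2p communication; stage 1 receives them through\n# set_input_tensor, runs its half plus the final norm, and only the last\n# stage (post_process=True) computes lm_head logits in the classes below.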
\n\n\nclass ParallelLlamaForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: LlamaConfig,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights=False,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPadPP(\n            config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process\n        )\n        assert share_embeddings_and_output_weights is False, (\n            \"LlamaModel does not support sharing embedding and output weights\"\n        )\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n\n        # input_ids, attention_mask and position_ids are passed to every pp stage;\n        # the first stage consumes input_ids, while later stages consume the\n        # hidden_states delivered via set_input_tensor inside self.model\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n
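\n        # Illustrative example (comment only): with tp_size = 2 and two sequences of\n        # lengths 5 and 8, unpad_input yields total_nnz = 13 packed tokens and\n        # cu_seqlens = [0, 5, 13]; pad_to_sequence_parallel below would right-pad the\n        # packed tokens from 13 to 14 so every tp rank holds an equal shard, and the\n        # pad is sliced off again after the lm_head.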
\n\n        # pad input_ids to a multiple of the tp size for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer; the performance gap is unmeasured\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                total_nnz = cu_seqlens[-1]\n                logits = logits[:total_nnz]  # (total_nnz, vocab_size // tp)\n            # add removed padding back. If input is already rmpad, we let the caller pad_input\n            logits = pad_input(\n                logits, indices, batch_size, seqlen=sequence_length\n            )  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # the value head consumes sequence-parallel activations, so mark its weight accordingly\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import (\n    get_mcore_forward_fn,\n    get_mcore_forward_fused_fn,\n    get_mcore_weight_converter,\n    hf_to_mcore_config,\n    init_mcore_model,\n)\n\n__all__ = [\n    \"hf_to_mcore_config\",\n    \"init_mcore_model\",\n    \"get_mcore_forward_fn\",\n    \"get_mcore_weight_converter\",\n    \"get_mcore_forward_fused_fn\",\n]\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/config_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# convert huggingface config to mcore transformer config\n\n\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.transformer import MLATransformerConfig, TransformerConfig\nfrom transformers import PretrainedConfig\n\n\ndef _get_base_transformer_config(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> dict:\n    \"\"\"\n    Create a base TransformerConfig with common parameters across different model architectures.\n    TODO: (ycl) use dataclass or converter config?\n\n    Args:\n        hf_config: HuggingFace model configuration\n        dtype: Data type for the model\n        override_transformer_config_kwargs: Additional parameters to override defaults\n\n    Returns:\n        TransformerConfig with common parameters\n    \"\"\"\n\n    # Common parallel state parameters\n    overlap_p2p_comm = (\n        mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n        and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    )\n    batch_p2p_comm = False\n\n    # Base configuration with common parameters\n    base_config = {\n        # Model architecture parameters\n        \"num_layers\": hf_config.num_hidden_layers,\n        \"hidden_size\": hf_config.hidden_size,\n        \"num_attention_heads\": hf_config.num_attention_heads,\n        \"num_query_groups\": hf_config.num_key_value_heads,\n        \"ffn_hidden_size\": hf_config.intermediate_size,\n        \"attention_dropout\": hf_config.attention_dropout,\n        \"hidden_dropout\": getattr(hf_config, \"hidden_dropout\", 0.0),\n        \"kv_channels\": getattr(hf_config, \"head_dim\", None),\n        \"layernorm_epsilon\": hf_config.rms_norm_eps,\n        \"add_bias_linear\": True,\n        # Activation and normalization\n        \"activation_func\": F.silu,\n        \"normalization\": \"RMSNorm\",\n        \"gated_linear_unit\": True,\n        # Data types\n        \"pipeline_dtype\": dtype,\n        \"params_dtype\": dtype,\n        \"bf16\": dtype is torch.bfloat16,\n        # Parallel configuration\n        \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n        \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n        \"expert_model_parallel_size\": mpu.get_expert_model_parallel_world_size(),\n        \"expert_tensor_parallel_size\": mpu.get_expert_tensor_parallel_world_size(),\n        \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n        \"context_parallel_size\": mpu.get_context_parallel_world_size(),\n        \"overlap_p2p_comm\": overlap_p2p_comm,\n        \"batch_p2p_comm\": batch_p2p_comm,\n        
\"sequence_parallel\": mpu.get_tensor_model_parallel_world_size() > 1,\n        # Common settings\n        \"variable_seq_lengths\": True,\n        \"masked_softmax_fusion\": True,\n        \"moe_token_dispatcher_type\": \"alltoall\",\n    }\n\n    # Update with any provided overrides\n    # override_transformer_config_kwargs as kwargs shall never be none\n    base_config.update(override_transformer_config_kwargs)\n\n    return base_config\n\n\ndef _get_mla_transformer_config(\n    hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> dict:\n    \"\"\"\n    Create a MLATransformerConfig with common parameters across different model architectures.\n    This is specifically for MLA models like DeepseekV3.\n\n    Args:\n        hf_config: HuggingFace model configuration\n        mla_rope_config: MLA specific RoPE configuration\n        dtype: Data type for the model\n        override_transformer_config_kwargs: Additional parameters to override defaults\n\n    Returns:\n        MLATransformerConfig with common parameters\n    \"\"\"\n    base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs)\n    mla_config = {\n        # MLA specific parameters\n        \"q_lora_rank\": hf_config.q_lora_rank,\n        \"kv_lora_rank\": hf_config.kv_lora_rank,\n        \"qk_head_dim\": hf_config.qk_nope_head_dim,\n        \"qk_pos_emb_head_dim\": hf_config.qk_rope_head_dim,\n        \"v_head_dim\": hf_config.v_head_dim,\n        \"rotary_base\": hf_config.rope_theta,\n        \"rotary_scaling_factor\": mla_rope_config[\"factor\"],\n        \"rope_type\": mla_rope_config[\"type\"],\n        \"max_position_embeddings\": mla_rope_config[\"original_max_position_embeddings\"],\n        \"beta_fast\": mla_rope_config[\"beta_fast\"],\n        \"beta_slow\": mla_rope_config[\"beta_slow\"],\n        \"mscale\": mla_rope_config[\"mscale\"],\n        \"mscale_all_dim\": mla_rope_config[\"mscale_all_dim\"],\n    }\n\n    base_config.update(mla_config)\n    return base_config\n\n\ndef check_and_disable_incompatible_configs(original_config: dict) -> dict:\n    \"\"\"\n    Check and disable incompatible configurations for older Megatron version.\n\n    Args:\n        original_config (dict): The original model configuration.\n\n    Returns:\n        dict: The updated model configuration with incompatible settings disabled.\n    \"\"\"\n    removed_keys = []\n    for key in original_config.keys():\n        if not hasattr(TransformerConfig, key):\n            removed_keys.append(key)\n    if removed_keys:\n        warnings.warn(\n            f\"The following keys are not supported in the current Megatron version and will be removed: {removed_keys}\",\n            stacklevel=2,\n        )\n        for key in removed_keys:\n            original_config.pop(key)\n    return original_config\n\n\ndef hf_to_mcore_config_dense(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # for LlamaForCausalLM or Qwen2ForCausalLM\n    qkv_bias = True if \"Qwen2ForCausalLM\" in hf_config.architectures else getattr(hf_config, \"attention_bias\", False)\n    qk_layernorm = True if \"Qwen3ForCausalLM\" in hf_config.architectures else False\n\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        add_qkv_bias=qkv_bias,\n        
qk_layernorm=qk_layernorm,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    print(f\"Overridden TF init config: {args}\")\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_qwen2moe(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_shared_expert_overlap=True,\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=True,\n        add_qkv_bias=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    print(f\"Overridden TF init config: {args}\")\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_mixtral(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: dict = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        num_moe_experts=hf_config.num_local_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        moe_router_pre_softmax=True,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_router_score_function=\"softmax\",\n        moe_shared_expert_intermediate_size=None,  # mixtral has no shared expert\n        moe_shared_expert_overlap=False,  # mixtral has no shared expert\n        moe_ffn_hidden_size=hf_config.intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        # moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        # Other optimizations\n        persist_layer_norm=True,\n        apply_rope_fusion=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    print(f\"Overridden TF init config: {args}\")\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_qwen3moe(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    args: dict = 
_get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=False,\n        qk_layernorm=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    print(f\"Overridden TF init config: {args}\")\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_dpskv3(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> MLATransformerConfig:\n    # DeepseekV3ForCausalLM\n    from megatron.core.transformer.enums import AttnBackend\n\n    from .patch_v012 import apply_patch\n\n    apply_patch()\n\n    mla_rope_config = {\n        \"beta_fast\": 32,\n        \"beta_slow\": 1,\n        \"factor\": 1,\n        \"mscale\": 1.0,\n        \"mscale_all_dim\": 1.0,\n        \"original_max_position_embeddings\": 4096,\n        \"type\": \"rope\",\n    }\n    if \"rope_scaling\" in hf_config and hf_config.rope_scaling is not None:\n        mla_rope_config.update(hf_config.rope_scaling)\n    moe_layer_freq = [1] * hf_config.num_hidden_layers\n    for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)):\n        moe_layer_freq[i] = 0\n\n    # disable MTP and quantization for now\n    if \"num_nextn_predict_layers\" in hf_config:\n        assert hf_config.num_nextn_predict_layers == 0, (\n            \"MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0\"\n        )\n    assert \"quantization_config\" not in hf_config or not hf_config.quantization_config, (\n        \"quantization is not supported for now, please modify the config.json to remove quantization_config\"\n    )\n\n    args: dict = _get_mla_transformer_config(\n        hf_config=hf_config,\n        mla_rope_config=mla_rope_config,\n        dtype=dtype,\n        # Additional parameters\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        attention_backend=AttnBackend.fused,\n        qk_layernorm=True,\n        # Standard MoE parameters\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_token_dispatcher_type=\"alltoall\",\n        moe_router_bias_update_rate=0.001,\n        moe_router_enable_expert_bias=True,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.n_routed_experts,\n        moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,\n        moe_aux_loss_coeff=getattr(hf_config, \"aux_loss_alpha\", 0.001),\n        moe_router_load_balancing_type=\"seq_aux_loss\",\n        moe_shared_expert_overlap=True,\n        # 
moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"sigmoid\",\n        moe_router_pre_softmax=True,\n        moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,\n        moe_layer_freq=moe_layer_freq,\n        # mcore 0.12 moe\n        moe_router_dtype=\"fp64\",\n        disable_bf16_reduced_precision_matmul=True,\n        # Other optimizations\n        # deallocate_pipeline_outputs=True,\n        # gradient_accumulation_fusion=True,\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    transformer_config: MLATransformerConfig = MLATransformerConfig(**args)\n    print(f\"Overridden MLA TF init config: {transformer_config}\")\n    # MTP\n    if \"num_nextn_predict_layers\" in hf_config:\n        transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers\n        transformer_config.mtp_loss_scaling_factor = 0.1\n\n    return transformer_config\n\n\ndef hf_to_mcore_config_qwen2_5_vl(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # Qwen2_5_VLForConditionalGeneration\n\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        add_bias_linear=False,\n        # qwen specific\n        add_qkv_bias=True,\n        mrope_section=hf_config.rope_scaling[\"mrope_section\"],\n    )\n    # override_transformer_config_kwargs as kwargs shall never be none\n    args.update(override_transformer_config_kwargs)\n    args = check_and_disable_incompatible_configs(args)\n    print(f\"Overridden TF init config: {args}\")\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_llama4(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    # Llama4ForConditionalGeneration\n    raise NotImplementedError(\"Llama4ForConditionalGeneration is not supported yet\")\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/loader.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\nfrom .saver import _megatron_calc_global_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == src_rank:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        
\n\n\ndef load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == src_rank:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank()}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.decoder.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == src_rank:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=src_rank, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor 
{name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : 
intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    num_query_groups_per_partition = models[0].config.num_query_groups // tp_size\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)\n                    k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)\n                   
 v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)\n                    total_size_per_head = total_size // num_query_groups_per_partition\n                    for j in range(num_query_groups_per_partition):\n                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(\n                            torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)\n                        )\n\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)\n                    k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)\n                    v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)\n                    total_size_per_head = total_size // config.num_attention_heads\n                    for j in range(config.num_attention_heads):\n                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(\n                            torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)\n                        )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n 
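   # Mapping note (comment only): Megatron fuses the HF projections, so the helpers\n    # above rebuild the fused layouts before sharding across tp ranks: gate_proj and\n    # up_proj are concatenated per tp shard into linear_fc1, while the q/k/v\n    # projections are interleaved per query group into linear_qkv.\n\n 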
   if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            layer_name = f\"model.layers.{layer}\"\n            print_rank_0(f\"loading layer #{layer}, with layer_name model.layers.{layer}...\")\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            if f\"{layer_name}.self_attn.q_norm.weight\" in state_dict:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n            if f\"{layer_name}.self_attn.q_proj.bias\" in state_dict:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    bias=True,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading 
final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.output_layer.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"loaded lm_head from reward_head weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"failed to match lm_head in value_model\")\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/mbridge.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from mbridge import AutoBridge\n    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model\nexcept ImportError:\n    print(\"mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`\")\n    raise\n\n__all__ = [\"AutoBridge\", \"make_value_model\", \"freeze_moe_router\"]\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/model_forward.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom verl.utils.megatron_utils import unwrap_model\n\nfrom .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding\n\n\ndef gptmodel_forward(\n    model,\n    input_ids,\n    attention_mask,\n    position_ids,\n    sequence_parallel,\n    value_model=False,\n    pack_seqs=True,\n    logits_processor=None,\n    logits_processor_args: dict = None,\n    **kwargs,\n):\n    \"\"\"Default forward pass for GPT models with optional sequence packing.\"\"\"\n    pre_process = unwrap_model(model).pre_process\n    post_process = unwrap_model(model).post_process\n    if pack_seqs:\n        batch_size, seq_len = attention_mask.shape[:2]\n        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n        input_ids_rmpad = input_ids_rmpad.contiguous()\n        output_orig = model(\n            input_ids=input_ids_rmpad,\n            attention_mask=None,\n            position_ids=position_ids,\n            packed_seq_params=packed_seq_params,\n        )\n        if post_process and logits_processor is not None:\n            args = {\n                k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]\n                for k, v in logits_processor_args.items()\n            }\n            output_dict = logits_processor(output_orig, **args)\n            output = {\n                k: postprocess_packed_seqs(\n                    v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n                )\n                for k, v in output_dict.items()\n            }\n        else:\n            output = postprocess_packed_seqs(\n                output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n            )\n    else:\n        assert logits_processor is None, \"logits_processor is not supported for non-packed sequence\"\n        batch_size, sequence_length = attention_mask.shape\n        new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(\n            input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process\n        )\n        output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)\n        output = recover_left_padding(\n            output, new_attention_mask, attention_mask, sequence_length, post_process=post_process\n        )\n    if value_model and post_process:\n        output = output[..., 0]\n    return output\n\n\ndef gptmodel_forward_qwen2_5_vl(\n    model,\n    input_ids,\n    attention_mask,\n    position_ids,\n    sequence_parallel,\n    value_model=False,\n    pack_seqs=True,\n    multi_modal_inputs=None,\n    logits_processor=None,\n    
logits_processor_args: dict = None,\n    **kwargs,\n):\n    from megatron.core import parallel_state as mpu\n\n    assert mpu.get_context_parallel_world_size() == 1, \"qwen2_5_vl's context parallel is not accurate yet\"\n    pre_process = unwrap_model(model).pre_process\n    post_process = unwrap_model(model).post_process\n    pixel_values = (\n        multi_modal_inputs[\"pixel_values\"].to(input_ids.device) if \"pixel_values\" in multi_modal_inputs else None\n    )\n    image_grid_thw = (\n        multi_modal_inputs[\"image_grid_thw\"].to(input_ids.device) if \"image_grid_thw\" in multi_modal_inputs else None\n    )\n    if pack_seqs:\n        batch_size, seq_len = attention_mask.shape[:2]\n        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)\n        input_ids_rmpad = input_ids_rmpad.contiguous()\n        output_orig = model(\n            input_ids=input_ids_rmpad,\n            attention_mask=None,\n            position_ids=position_ids,\n            packed_seq_params=packed_seq_params,\n            pixel_values=pixel_values,\n            image_grid_thw=image_grid_thw,\n        )\n\n        if post_process and logits_processor is not None:\n            args = {\n                k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]\n                for k, v in logits_processor_args.items()\n            }\n            output_dict = logits_processor(output_orig, **args)\n            output = {\n                k: postprocess_packed_seqs(\n                    v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n                )\n                for k, v in output_dict.items()\n            }\n        else:\n            output = postprocess_packed_seqs(\n                output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n            )\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(\n            input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process\n        )\n        output = model(\n            input_ids=new_input_ids,\n            position_ids=new_position_ids,\n            attention_mask=new_attention_mask,\n            pixel_values=pixel_values,\n            image_grid_thw=image_grid_thw,\n        )\n        output = recover_left_padding(\n            output, new_attention_mask, attention_mask, sequence_length, post_process=post_process\n        )\n    if value_model and post_process:\n        output = output[..., 0]\n    return output\n"
  },
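The `pack_seqs` path in `gptmodel_forward` above strips padding before calling the Megatron model and scatters results back into the padded layout afterwards. The toy helpers below sketch that round trip in isolation; they are illustrative only, not verl's `preprocess_packed_seqs`/`postprocess_packed_seqs`, which additionally build Megatron `PackedSeqParams` (thd-format `cu_seqlens`) and handle sequence/context parallelism.

```python
import torch


def pack(x: torch.Tensor, attention_mask: torch.Tensor):
    """Drop padded positions: [b, s, ...] -> [total_tokens, ...] plus cu_seqlens."""
    seqlens = attention_mask.sum(dim=1)                # valid tokens per row
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)      # sequence boundaries in the packed dim
    return x[attention_mask.bool()], cu_seqlens


def unpack(packed: torch.Tensor, attention_mask: torch.Tensor, seq_len: int):
    """Scatter packed values back into a padded [b, seq_len, ...] tensor."""
    out = packed.new_zeros((attention_mask.size(0), seq_len) + tuple(packed.shape[1:]))
    out[attention_mask.bool()] = packed
    return out


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
ids = torch.arange(8).reshape(2, 4)
flat, cu = pack(ids, mask)
print(flat.tolist(), cu.tolist())                  # [0, 1, 2, 4, 5] [0, 3, 5]
assert torch.equal(unpack(flat, mask, 4), ids * mask)
```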
  {
    "path": "verl_rl/verl/models/mcore/model_forward_fused.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import parallel_state\nfrom megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk\nfrom megatron.core.inference.contexts import BaseInferenceContext\nfrom megatron.core.models.gpt.gpt_model import GPTModel\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region\nfrom torch import Tensor\n\nfrom verl.models.mcore.util import preprocess_packed_seqs\nfrom verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\nfrom verl.utils.megatron_utils import unwrap_model\nfrom verl.utils.model import CausalLMOutputForPPO\n\nfrom .qwen2_5_vl.model import Qwen2_5VLModel\nfrom .util import postprocess_packed_seqs_for_dict_output\n\n\ndef patch_fused_forward(model: torch.nn.Module):\n    model = unwrap_model(model)\n    if isinstance(model, GPTModel):\n        model = model\n    elif isinstance(model, Qwen2_5VLModel):\n        if not hasattr(model, \"language_model\"):\n            # the qwen2.5vl model might only have vision_model\n            return\n        model = model.language_model\n    else:\n        raise ValueError(\"Model is not a GPTModel or Qwen2_5VLModel\")\n    model.forward_backup = model.forward\n    model.forward = _fused_GPTModel_forward.__get__(model, model.__class__)\n    return\n\n\ndef unpatch_fused_forward(model: torch.nn.Module):\n    model = unwrap_model(model)\n    if isinstance(model, GPTModel):\n        model = model\n    elif isinstance(model, Qwen2_5VLModel):\n        model = model.language_model\n    else:\n        raise ValueError(\"Model is not a GPTModel or Qwen2_5VLModel\")\n    model.forward = model.forward_backup\n    return\n\n\ndef fused_forward_gptmodel(\n    model: GPTModel,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    labels: Tensor,\n    labels_mask: Tensor,\n    **kwargs,\n):\n    pre_process: bool = unwrap_model(model).pre_process\n    post_process: bool = unwrap_model(model).post_process\n\n    batch_size, seq_len = attention_mask.shape[:2]\n    input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n    input_ids_rmpad = input_ids_rmpad.contiguous()\n    labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True)\n    labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True)\n    labels_rmpad = labels_rmpad.contiguous()\n    labels_mask_rmpad = labels_mask_rmpad.contiguous()\n\n    output_orig: CausalLMOutputForPPO = model(\n        input_ids=input_ids_rmpad,\n        attention_mask=None,\n        
position_ids=position_ids,\n        labels=labels_rmpad,\n        packed_seq_params=packed_seq_params,\n    )\n\n    if post_process:\n        # output_orig is of type CausalLMOutputForPPO\n        output = postprocess_packed_seqs_for_dict_output(\n            labels_mask_rmpad,\n            output_orig,\n            packed_seq_params,\n            attention_mask,\n            batch_size,\n            seq_len,\n            post_process=post_process,\n        )\n    else:\n        output = output_orig\n    return output\n\n\ndef fused_forward_qwen2_5_vl(\n    model: Qwen2_5VLModel,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    labels: Tensor,\n    labels_mask: Tensor,\n    multi_modal_inputs=None,\n    **kwargs,\n):\n    # pre_process = unwrap_model(model).pre_process\n    post_process = unwrap_model(model).post_process\n\n    pixel_values = (\n        multi_modal_inputs[\"pixel_values\"].to(input_ids.device) if \"pixel_values\" in multi_modal_inputs else None\n    )\n    image_grid_thw = (\n        multi_modal_inputs[\"image_grid_thw\"].to(input_ids.device) if \"image_grid_thw\" in multi_modal_inputs else None\n    )\n\n    batch_size, seq_len = attention_mask.shape[:2]\n    input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)\n    labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True)\n    labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True)\n    labels_rmpad = labels_rmpad.contiguous()\n    labels_mask_rmpad = labels_mask_rmpad.contiguous()\n    input_ids_rmpad = input_ids_rmpad.contiguous()\n    output_orig: CausalLMOutputForPPO = model(\n        input_ids=input_ids_rmpad,\n        attention_mask=None,\n        position_ids=position_ids,\n        packed_seq_params=packed_seq_params,\n        pixel_values=pixel_values,\n        image_grid_thw=image_grid_thw,\n        labels=labels_rmpad,  # pass the packed labels so they align with input_ids_rmpad\n    )\n    if post_process:\n        # output_orig is of type CausalLMOutputForPPO\n        output = postprocess_packed_seqs_for_dict_output(\n            labels_mask_rmpad,\n            output_orig,\n            packed_seq_params,\n            attention_mask,\n            batch_size,\n            seq_len,\n            post_process=post_process,\n        )\n    else:\n        output = output_orig\n    return output\n\n\ndef _fused_GPTModel_forward(\n    self,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    decoder_input: Tensor = None,\n    labels: Tensor = None,\n    inference_context: BaseInferenceContext = None,\n    packed_seq_params: PackedSeqParams = None,\n    extra_block_kwargs: dict = None,\n    runtime_gather_output: Optional[bool] = None,\n    *,\n    inference_params: Optional[BaseInferenceContext] = None,\n    loss_mask: Optional[Tensor] = None,\n    temperature: float = 1.0,\n) -> CausalLMOutputForPPO:\n    \"\"\"\n    Forward pass for GPT models with fused kernel support.\n\n    Patch https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py\n    \"\"\"\n\n    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.\n    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.\n\n    # Decoder embedding.\n    if decoder_input is not None:\n        pass\n    elif self.pre_process:\n        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)\n    else:\n        # intermediate stage of 
pipeline\n        # decoder will get hidden_states from encoder.input_tensor\n        decoder_input = None\n\n    # Rotary positional embeddings (embedding is None for PP intermediate devices)\n    rotary_pos_emb = None\n    rotary_pos_cos = None\n    rotary_pos_sin = None\n    if self.position_embedding_type == \"rope\" and not self.config.multi_latent_attention:\n        if not self.training and self.config.flash_decode and inference_context:\n            assert inference_context.is_static_batching(), \"GPTModel currently only supports static inference batching.\"\n            # Flash decoding uses precomputed cos and sin for RoPE\n            rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(\n                inference_context.max_sequence_length,\n                self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),\n            )\n        else:\n            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(\n                inference_context, self.decoder, decoder_input, self.config, packed_seq_params\n            )\n            rotary_pos_emb = self.rotary_pos_emb(\n                rotary_seq_len,\n                packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\",\n            )\n    elif self.position_embedding_type == \"mrope\" and not self.config.multi_latent_attention:\n        if self.training or not self.config.flash_decode:\n            rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)\n        else:\n            # Flash decoding uses precomputed cos and sin for RoPE\n            raise NotImplementedError(\n                \"Flash decoding uses precomputed cos and sin for RoPE, not implemented in MultimodalRotaryEmbedding yet.\"\n            )\n\n    if (\n        (self.config.enable_cuda_graph or self.config.flash_decode)\n        and rotary_pos_cos is not None\n        and inference_context\n        and inference_context.is_static_batching()\n        and not self.training\n    ):\n        sequence_len_offset = torch.tensor(\n            [inference_context.sequence_len_offset] * inference_context.current_batch_size,\n            dtype=torch.int32,\n            device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors\n        )\n    else:\n        sequence_len_offset = None\n\n    # Upstream Megatron wraps decoder_input here so the decoder (TransformerBlock) can drop\n    # the caller's reference and garbage-collect it early during inference; this patched\n    # forward skips that wrapping.\n\n    # Run decoder.\n    hidden_states = self.decoder(\n        hidden_states=decoder_input,\n        attention_mask=attention_mask,\n        inference_context=inference_context,\n        rotary_pos_emb=rotary_pos_emb,\n        rotary_pos_cos=rotary_pos_cos,\n        rotary_pos_sin=rotary_pos_sin,\n        packed_seq_params=packed_seq_params,\n        sequence_len_offset=sequence_len_offset,\n        **(extra_block_kwargs or {}),\n    )\n\n    # Process inference output.\n    if inference_context and not inference_context.is_static_batching():\n        hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)\n\n    # logits and loss\n    output_weight = None\n    if self.share_embeddings_and_output_weights:\n        output_weight = self.shared_embedding_or_output_weight()\n\n    if self.mtp_process:\n        hidden_states = self.mtp(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            labels=labels,\n            
loss_mask=loss_mask,\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            inference_params=inference_params,\n            rotary_pos_emb=rotary_pos_emb,\n            rotary_pos_cos=rotary_pos_cos,\n            rotary_pos_sin=rotary_pos_sin,\n            packed_seq_params=packed_seq_params,\n            sequence_len_offset=sequence_len_offset,\n            embedding=self.embedding,\n            output_layer=self.output_layer,\n            output_weight=output_weight,\n            runtime_gather_output=runtime_gather_output,\n            compute_language_model_loss=self.compute_language_model_loss,\n            **(extra_block_kwargs or {}),\n        )\n\n    if not self.post_process:\n        return hidden_states\n\n    output = CausalLMOutputForPPO(\n        loss=None,\n        logits=None,\n        past_key_values=None,\n        hidden_states=hidden_states,\n        attentions=None,\n    )\n\n    if self.config.sequence_parallel:\n        hidden_states = gather_from_sequence_parallel_region(hidden_states)\n    logprobs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.output_layer.weight,\n        labels,\n        temperature,\n        \"none\",\n        parallel_state.get_tensor_model_parallel_group(),\n    )\n\n    if has_config_logger_enabled(self.config):\n        payload = OrderedDict(\n            {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids,\n                \"attention_mask\": attention_mask,\n                \"decoder_input\": decoder_input,\n                \"logprobs\": logprobs,\n                \"entropy\": entropy,\n            }\n        )\n        log_config_to_disk(self.config, payload, prefix=\"input_and_logits\")\n\n    output.entropy = entropy\n    output.log_probs = logprobs\n\n    return output\n"
  },
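`patch_fused_forward` and `unpatch_fused_forward` above swap a module's `forward` at runtime: the replacement function is bound to the instance with `__get__`, and the original method is stashed on the object so it can be restored later. A minimal self-contained sketch of that pattern on a hypothetical toy class:

```python
class ToyModel:
    def forward(self, x):
        return x + 1


def _fused_forward(self, x):
    # stand-in for a fused variant that returns extra tensors (e.g. log-probs, entropy)
    return x + 1, "extra outputs"


def patch(model: ToyModel):
    model.forward_backup = model.forward  # keep the original bound method
    model.forward = _fused_forward.__get__(model, model.__class__)


def unpatch(model: ToyModel):
    model.forward = model.forward_backup


m = ToyModel()
patch(m)
assert m.forward(1) == (2, "extra outputs")  # patched
unpatch(m)
assert m.forward(1) == 2                     # original restored
```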
  {
    "path": "verl_rl/verl/models/mcore/model_initializer.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# use mcore transformer config to initialize the model\nfrom abc import ABC, abstractmethod\n\nfrom megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec\nfrom megatron.core.models.gpt.gpt_model import GPTModel\n\nfrom .config_converter import PretrainedConfig, TransformerConfig\n\n\nclass BaseModelInitializer(ABC):\n    \"\"\"Base class for model initializers.\"\"\"\n\n    def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):\n        self.tfconfig = tfconfig\n        self.hf_config = hf_config\n\n    @abstractmethod\n    def get_transformer_layer_spec(self):\n        \"\"\"Get the transformer layer specification.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py\"\"\"\n        pass\n\n    def get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        if \"rope_scaling\" in self.hf_config:\n            if self.hf_config.rope_scaling is not None:\n                # assert self.hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n                rope_scaling_args[\"seq_len_interpolation_factor\"] = self.hf_config.rope_scaling[\"factor\"]\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        pre_process: bool = True,\n        post_process: bool = True,\n        share_embeddings_and_output_weights: bool = False,\n        value: bool = False,\n        **extra_kwargs,\n    ) -> GPTModel:\n        \"\"\"Initialize a GPT model with the given configuration.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py\n\n        Args:\n            pre_process (bool): include embedding layer.\n            post_process (bool): including an output layer.\n            share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.\n            value (bool): add an extra linear layer for classification or regression.\n\n        Returns:\n            GPTModel: An initialized GPT model instance\n        \"\"\"\n        transformer_layer_spec = self.get_transformer_layer_spec()\n        rope_scaling_args = self.get_rope_scaling_args()\n        mtp_block_spec = extra_kwargs.get(\"mtp_block_spec\", None)\n        model = GPTModel(\n            config=self.tfconfig,\n            transformer_layer_spec=transformer_layer_spec,\n            vocab_size=self.hf_config.vocab_size,\n            max_sequence_length=self.hf_config.max_position_embeddings,\n            pre_process=pre_process,\n            post_process=post_process,\n            share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n            
position_embedding_type=\"rope\",\n            rotary_base=self.hf_config.rope_theta,\n            **rope_scaling_args,\n            mtp_block_spec=mtp_block_spec,\n        )\n\n        if post_process and value:\n            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n            model.output_layer = LinearForLastLayer(\n                input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig\n            )\n\n        return model\n\n\nclass DenseModel(BaseModelInitializer):\n    \"\"\"Initializer for dense models like Llama and Qwen2.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n\n\nclass Qwen2MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n\n        # Patch layer spec for shared experts\n        for i in range(len(transformer_layer_spec.layer_specs)):\n            transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params[\"gate\"] = True\n\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass MixtralModel(BaseModelInitializer):\n    \"\"\"Initializer for Mixtral models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", False)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen3MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen3 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass DeepseekV3Model(BaseModelInitializer):\n    \"\"\"Initializer for DeepseekV3 models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return 
transformer_layer_spec\n\n    def get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        **kwargs,\n    ):\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            self.tfconfig.moe_router_load_balancing_type = \"none\"\n        # MTP\n        if self.tfconfig.mtp_num_layers is not None:\n            transformer_layer_spec = self.get_transformer_layer_spec()\n            mtp_block_spec = get_gpt_mtp_block_spec(self.tfconfig, transformer_layer_spec, use_transformer_engine=True)\n            kwargs[\"mtp_block_spec\"] = mtp_block_spec\n\n        model = super().initialize(**kwargs)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                if hasattr(layer.mlp, \"router\"):\n                    layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen25VLModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2.5 VL models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def initialize(\n        self,\n        pre_process=None,\n        post_process=None,\n        share_embeddings_and_output_weights=False,\n        value=False,\n        **extra_kwargs,\n    ):\n        tfconfig = self.tfconfig\n        hf_config = self.hf_config\n        # Qwen2_5_VLForConditionalGeneration\n        from copy import deepcopy\n\n        transformer_layer_spec = self.get_transformer_layer_spec()\n\n        from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear\n        from megatron.core.models.gpt.moe_module_specs import MLPSubmodules\n        from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec\n\n        from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config\n\n        vision_transformer_config = get_vision_model_config(deepcopy(tfconfig))\n        vision_transformer_config.pipeline_model_parallel_size = 1\n        vision_transformer_config.first_pipeline_num_layers = None\n\n        vision_projection_config = get_vision_projection_config(\n            deepcopy(tfconfig),\n            vision_transformer_config.hidden_size,\n            spatial_merge_size=hf_config.vision_config.spatial_merge_size,\n        )\n        vision_projection_layer_spec = MLPSubmodules(\n            linear_fc1=TEColumnParallelLinear,\n            linear_fc2=TERowParallelLinear,\n        )\n        vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec()\n\n        qwen25_vl_model = Qwen2_5VLModel(\n            language_transformer_config=tfconfig,\n            language_transformer_layer_spec=transformer_layer_spec,\n            language_vocab_size=hf_config.vocab_size,\n            language_max_sequence_length=hf_config.max_position_embeddings,\n            vision_transformer_config=vision_transformer_config,\n            vision_transformer_layer_spec=vision_transformer_layer_spec,\n            vision_projection_config=vision_projection_config,\n            vision_projection_layer_spec=vision_projection_layer_spec,\n            vision_projection_type=\"mlp\",\n            language_rotary_base=hf_config.rope_theta,\n            pre_process=pre_process,\n            
post_process=post_process,\n            add_decoder=True,\n            add_encoder=True,\n            parallel_output=True,\n            language_share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        )\n\n        if post_process and value:\n            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n            qwen25_vl_model.language_model.output_layer = LinearForLastLayer(\n                input_size=tfconfig.hidden_size, output_size=1, config=tfconfig\n            )\n\n        return qwen25_vl_model\n"
  },
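The MoE initializers above share one convention: after the model is built, router weights are frozen by setting `requires_grad = False`, so routing stays fixed during RL training (the Qwen MoE classes default to frozen, Mixtral to trainable). A self-contained sketch of the same idea on a hypothetical toy layer, not verl or Megatron code:

```python
import torch
import torch.nn as nn


class ToyMoELayer(nn.Module):
    """Hypothetical stand-in for a decoder layer whose MLP is a routed mixture of experts."""

    def __init__(self, hidden: int, n_experts: int):
        super().__init__()
        self.router = nn.Linear(hidden, n_experts, bias=False)
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

    def forward(self, x):
        weights = torch.softmax(self.router(x), dim=-1)            # [tokens, n_experts]
        outs = torch.stack([e(x) for e in self.experts], dim=-1)   # [tokens, hidden, n_experts]
        return (outs * weights.unsqueeze(1)).sum(dim=-1)


layers = nn.ModuleList(ToyMoELayer(8, 4) for _ in range(2))

freeze_moe_router = True  # mirrors the Qwen MoE default above
if freeze_moe_router:
    for layer in layers:
        layer.router.weight.requires_grad = False

trainable = {name for name, p in layers.named_parameters() if p.requires_grad}
assert not any("router" in name for name in trainable)  # routers excluded from optimization
```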
  {
    "path": "verl_rl/verl/models/mcore/patch_v012.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# there is some bug in mcore 0.12, so we need to patch it\n# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None\n\n\ndef apply_patch():\n    import torch\n    from megatron.core import parallel_state, tensor_parallel\n    from megatron.core.transformer.multi_latent_attention import (\n        MLASelfAttention,\n        apply_rotary_pos_emb,\n        deprecate_inference_params,\n        gather_from_sequence_parallel_region,\n        gather_from_tensor_model_parallel_region,\n        scatter_to_sequence_parallel_region,\n    )\n\n    def patch_get_query_key_value_tensors(\n        self,\n        hidden_states,\n        key_value_states=None,\n        position_ids=None,\n        packed_seq_params=None,\n        inference_context=None,\n        *,\n        inference_params=None,\n    ):\n        \"\"\"\n        Derives `query`, `key` and `value` tensors from `hidden_states`.\n        \"\"\"\n        # s = sequence length, b = batch size, h = hidden size, n = num attention heads\n        # Attention heads [s, b, n*h]\n        assert hidden_states.ndim == 3, f\"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        # =========================================\n        # Prepare RoPE and seqlen related params\n        # =========================================\n        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(\n            inference_context, None, hidden_states, self.config, packed_seq_params\n        )\n\n        # rotary_pos_emb:[s, b, 1, 64]\n        mscale = 1.0\n        if self.config.rope_type == \"rope\":\n            packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\"\n            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)\n        else:\n            rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)\n\n        # =========================================\n        # QKV down projection and layernorm\n        # =========================================\n        if self.config.q_lora_rank is not None:\n            # if linear_q_down_proj is ColumnParallelLinear:\n            #     q_compressed: [s, b, q_lora_rank / TP]\n            # elif linear_q_down_proj is Linear:\n            #     q_compressed: [s / TP, b, q_lora_rank]\n            q_compressed, _ = self.linear_q_down_proj(hidden_states)\n\n            # When output is sharded (ColumnParallelLinear), two things are needed to be\n            # identical to a normal Linear.\n            #   1. Manually gather output to restore output dim q_lora_rank;\n            #   2. 
Scatter sequence back to s / TP if sequence-parallel since it was\n            #      gathered by ColumnParallelLinear.\n            if q_compressed.size(-1) != self.config.q_lora_rank:\n                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)\n                if self.config.sequence_parallel:\n                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)\n\n            q_compressed = self.q_layernorm(q_compressed)\n        else:\n            q_compressed = hidden_states\n\n        # if linear_kv_down_proj is ColumnParallelLinear:\n        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]\n        # elif linear_kv_down_proj is Linear:\n        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n        kv_combined, _ = self.linear_kv_down_proj(hidden_states)\n        if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:\n            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)\n            # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(\n                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1\n            )\n            if self.config.sequence_parallel:\n                # kv_compressed:[s / TP, b, kv_lora_rank]\n                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)\n        else:\n            # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(\n                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1\n            )\n            if parallel_state.get_tensor_model_parallel_world_size() > 1:\n                # k_pos_emb: [s, b, qk_pos_emb_head_dim]\n                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)\n\n        kv_compressed = self.kv_layernorm(kv_compressed)\n\n        # =========================================\n        # QKV up projection and RoPE apply\n        # =========================================\n        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):\n            if self.config.q_lora_rank is not None:\n                q, _ = self.linear_q_up_proj(q_compressed)\n            else:\n                # hidden_states:[s, b, 2048], q: [s, b, n * 192]\n                q, _ = self.linear_q_proj(q_compressed)\n\n            q_len, bsz, _ = q.size()\n\n            # q: [s, b, n, 192]\n            q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)\n\n            # kv: [s, b, 2048]\n            kv, _ = self.linear_kv_up_proj(kv_compressed)\n\n            # kv: [s, b, n, 256]\n            kv = kv.view(\n                q_len,\n                bsz,\n                self.num_attention_heads_per_partition,\n                self.config.qk_head_dim + self.config.v_head_dim,\n            )\n\n            if inference_context is not None:\n                # add offset to the sequence start for inference\n                sequence_start = inference_context.sequence_len_offset\n                sequence_end = sequence_start + q_len\n                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]\n            else:\n                # Shorten rotary_pos_emb to the sequence length when inference_params\n          
      # is not provided. This makes sure we can run forward directly with\n                # any sequence length. During training, the sequence length is always\n                # the full rotary_pos_emb length.\n                rotary_pos_emb = rotary_pos_emb[0:q_len]\n\n            # [s, b, 64] -> [s, b, 1, 64]\n            k_pos_emb = torch.unsqueeze(k_pos_emb, 2)\n\n            # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64]\n            q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)\n\n            # k_no_pe: [s, b, n, 128], value: [s, b, n, 128]\n            k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)\n\n            if packed_seq_params is not None:\n                cu_seqlens_q = packed_seq_params.cu_seqlens_q\n                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv\n                q_pos_emb = q_pos_emb.squeeze(1)\n                k_pos_emb = k_pos_emb.squeeze(1)\n                q_no_pe = q_no_pe.squeeze(1)\n                k_no_pe = k_no_pe.squeeze(1)\n                value = value.squeeze(1)\n            else:\n                cu_seqlens_q = cu_seqlens_kv = None\n\n            # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64]\n            q_pos_emb = apply_rotary_pos_emb(\n                q_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_q,\n                mscale=mscale,\n            )\n            k_pos_emb = apply_rotary_pos_emb(\n                k_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_kv,\n                mscale=mscale,\n            )\n\n            # query: [s, b, n, 192]\n            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)\n            if packed_seq_params is not None:\n                k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n            else:\n                # key: [s, b, n, 192]\n                k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n\n            query = query.contiguous()\n            key = key.contiguous()\n            value = value.contiguous()\n            return query, key, value\n\n        if self.recompute_up_proj:\n            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()\n            query, key, value = self.qkv_up_checkpoint.checkpoint(\n                qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb\n            )\n        else:\n            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)\n\n        return query, key, value\n\n    MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors\n"
  },
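The heart of the patched `get_query_key_value_tensors` is MLA's low-rank decomposition: a single joint down-projection is split along the last dimension into a compressed KV part (which goes through `kv_layernorm` and the up-projection) and a shared positional part (which receives RoPE and is broadcast across heads). A shape-only sketch with illustrative sizes, not DeepSeek's actual configuration:

```python
import torch

s, b, n_heads = 16, 2, 8
kv_lora_rank, qk_pos_emb_head_dim = 512, 64

# joint output of linear_kv_down_proj
kv_combined = torch.randn(s, b, kv_lora_rank + qk_pos_emb_head_dim)

kv_compressed, k_pos_emb = torch.split(
    kv_combined, [kv_lora_rank, qk_pos_emb_head_dim], dim=-1
)
print(kv_compressed.shape)  # [16, 2, 512] -> kv_layernorm, then up-projection to per-head K/V
print(k_pos_emb.shape)      # [16, 2, 64]  -> RoPE applied, then broadcast to every head

key_pos = k_pos_emb.unsqueeze(2).expand(-1, -1, n_heads, -1)  # [16, 2, 8, 64]
```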
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .model import Qwen2_5VLModel\nfrom .vision_config import get_vision_model_config, get_vision_projection_config\n\n__all__ = [\"Qwen2_5VLModel\", \"get_vision_model_config\", \"get_vision_projection_config\"]\n"
  },
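These re-exports define the subpackage's public surface; `Qwen25VLModel.initialize` in `model_initializer.py` consumes exactly these three names via a relative import. From outside the package the equivalent import is:

```python
from verl.models.mcore.qwen2_5_vl import (
    Qwen2_5VLModel,
    get_vision_model_config,
    get_vision_projection_config,
)
```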
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/attention.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core.transformer.attention import *\n\nfrom .rope_utils import apply_rotary_pos_emb_absolute\n\n\nclass Qwen2_5VLSelfAttention(SelfAttention):\n    \"\"\"\n    Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute\n    instead of apply_rotary_pos_emb\n    \"\"\"\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        attention_mask: Tensor,\n        key_value_states: Optional[Tensor] = None,\n        inference_context: Optional[BaseInferenceContext] = None,\n        rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,\n        rotary_pos_cos: Optional[Tensor] = None,\n        rotary_pos_sin: Optional[Tensor] = None,\n        attention_bias: Optional[Tensor] = None,\n        packed_seq_params: Optional[PackedSeqParams] = None,\n        sequence_len_offset: Optional[int] = None,\n        *,\n        inference_params: Optional[BaseInferenceContext] = None,\n    ) -> Tuple[Tensor, Tensor]:\n        \"\"\"\n        Perform a forward pass through the attention module.\n\n        Args:\n            hidden_states (Tensor): Hidden states.\n            attention_mask (Tensor): Attention mask.\n            key_value_states (Optional[Tensor]): Key/value states (for cross attention).\n            inference_context (Optional[BaseInferenceContext]): Inference context that manages\n                KV cache.\n            rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary\n                embedding tensor(s).\n            rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.\n            rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.\n            attention_bias (Optional[Tensor]): Attention bias.\n            packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format.\n            sequence_len_offset (Optional[int]): Sequence length offset used for\n                inference CUDA graphs.\n\n        Return:\n            (Tuple[Tensor, Tensor]) Attention output and bias.\n\n        \"\"\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        if inference_context and inference_context.is_dynamic_batching():\n            assert flash_decode_and_prefill_kernel is not None, (\n                \"Internal use only: install package `nvidia_chunked_flash_attn`.\"\n            )\n\n        # hidden_states: [sq, b, h]\n        if self.config.flash_decode and not self.training and inference_context is not None:\n            rotary_pos_emb = None\n        else:\n            assert rotary_pos_cos is None and rotary_pos_sin is None\n\n        # For self attention we just duplicate the rotary_pos_emb if it isn't already\n        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, 
tuple):\n            rotary_pos_emb = (rotary_pos_emb,) * 2\n\n        # =====================\n        # Query, Key, and Value\n        # =====================\n        # Get the query, key and value tensors based on the type of attention -\n        # self or cross attn.\n        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)\n\n        # ===================================================\n        # Adjust key, value, and rotary_pos_emb for inference\n        # ===================================================\n\n        # This branch only runs in the decode phase of flash decoding and returns after the linear\n        # projection. This conditional is not used in the prefill phase or non-flash-decoding cases.\n        if (\n            self.config.flash_decode\n            and inference_context is not None\n            and inference_context.is_decode_only()\n            and not self.training\n            and rotary_pos_cos is not None\n        ):\n            assert self.layer_number in inference_context.key_value_memory_dict\n            assert inference_context.sequence_len_offset is not None\n            inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number]\n            output = self.flash_decode(\n                sequence_len_offset=sequence_len_offset,\n                query_layer=query,\n                key_layer=key,\n                value_layer=value,\n                inference_key_memory=inference_key_memory,\n                inference_value_memory=inference_value_memory,\n                rotary_cos=rotary_pos_cos,\n                rotary_sin=rotary_pos_sin,\n            )\n            out = output.transpose(0, 1).contiguous()\n            context_layer = out.view(out.size(0), out.size(1), -1)\n            output, bias = self.linear_proj(context_layer)\n            return output, bias\n\n        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(\n            inference_context,\n            query,\n            key,\n            value,\n            rotary_pos_emb,\n            rotary_pos_cos,\n            rotary_pos_sin,\n            sequence_len_offset,\n        )\n\n        if packed_seq_params is not None:\n            query = query.squeeze(1)\n            key = key.squeeze(1)\n            value = value.squeeze(1)\n\n        # ================================================\n        # relative positional embedding (rotary embedding)\n        # ================================================\n        if rotary_pos_emb is not None and not self.config.flash_decode:\n            q_pos_emb, k_pos_emb = rotary_pos_emb\n\n            if packed_seq_params is not None:\n                if packed_seq_params.cu_seqlens_q_padded is not None:\n                    cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded\n                else:\n                    cu_seqlens_q = packed_seq_params.cu_seqlens_q\n                if packed_seq_params.cu_seqlens_kv_padded is not None:\n                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded\n                else:\n                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv\n            else:\n                cu_seqlens_q = cu_seqlens_kv = None\n\n            if q_pos_emb is not None:\n                # TODO VIJAY: simplify\n                if inference_context is None or inference_context.is_static_batching():\n                    query = apply_rotary_pos_emb_absolute(query, q_pos_emb, 
config=self.config, cu_seqlens=cu_seqlens_q)\n                else:\n                    query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q)\n            if k_pos_emb is not None:\n                key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)\n\n            # TODO, can apply positional embedding to value_layer so it has\n            # absolute positional embedding.\n            # otherwise, only relative positional embedding takes effect\n            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)\n\n        # ==================================\n        # core attention computation\n        # ==================================\n\n        if self.checkpoint_core_attention and self.training:\n            core_attn_out = self._checkpointed_attention_forward(\n                query,\n                key,\n                value,\n                attention_mask,\n                attn_mask_type=attn_mask_type,\n                attention_bias=attention_bias,\n                packed_seq_params=packed_seq_params,\n            )\n        else:\n            if inference_context is None or inference_context.is_static_batching():\n                # Static batching attention kernel.\n                core_attn_out = self.core_attention(\n                    query,\n                    key,\n                    value,\n                    attention_mask,\n                    attn_mask_type=attn_mask_type,\n                    attention_bias=attention_bias,\n                    packed_seq_params=packed_seq_params,\n                )\n\n            else:\n                # Dynamic batching attention kernel.\n                q, k, v = (query, key, value)\n                cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths()\n                cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths()\n\n                core_attn_out = self.flash_decode_and_prefill(\n                    q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths\n                )\n                core_attn_out = core_attn_out.squeeze(0).unsqueeze(1)\n                core_attn_out = rearrange(core_attn_out, \"s b h d -> s b (h d)\")\n\n        if packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\":\n            # reshape to same output shape as unpacked case\n            # (t, np, hn) -> (t, b=1, h=np*hn)\n            # t is the pack size = sum (sq_i)\n            # note that batch is a dummy dimension in the packed case\n            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)\n\n        # =================\n        # Output. [sq, b, h]\n        # =================\n\n        output, bias = self.linear_proj(core_attn_out)\n\n        return output, bias\n"
  },
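The only behavioral difference from Megatron's base `SelfAttention` is the call to `apply_rotary_pos_emb_absolute`. For intuition, the snippet below applies absolute rotary embeddings with the standard rotate-half formulation; this is a textbook sketch, not verl's `rope_utils` implementation, which additionally handles packed (thd) layouts via `cu_seqlens`:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_absolute(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """t: [s, b, n_heads, head_dim]; freqs: [s, 1, 1, head_dim] absolute angles per position."""
    return t * freqs.cos() + rotate_half(t) * freqs.sin()


s, b, n, d = 8, 1, 2, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))  # [d/2]
angles = torch.outer(torch.arange(s).float(), inv_freq)          # [s, d/2]
freqs = torch.cat((angles, angles), dim=-1).view(s, 1, 1, d)     # duplicated to match rotate_half

q = torch.randn(s, b, n, d)
q_rot = apply_rope_absolute(q, freqs)
assert q_rot.shape == q.shape
```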
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport logging\n\nimport torch\nfrom megatron.core import InferenceParams, tensor_parallel\nfrom megatron.core.models.gpt.gpt_model import GPTModel\n\n# from .transformer_config import Qwen2VLTransformerConfig\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.transformer import MegatronModule\nfrom megatron.core.transformer.spec_utils import ModuleSpec\nfrom megatron.core.transformer.transformer_config import TransformerConfig\n\nfrom .attention import Qwen2_5VLSelfAttention\nfrom .vision_model import Qwen2_5VisionModel\n\n\n# Note: This is under development and may be missing features.\nclass Qwen2_5VLModel(MegatronModule):\n    \"\"\"Qwen2.5VL multi-modal model.\n\n    Args:\n        language_transformer_config (TransformerConfig): Transformer config for the language model.\n        language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the\n            language model.\n        language_vocab_size (int): Language model vocabulary size.\n        language_max_sequence_length (int): Language model maximum sequence length. This is used for\n            positional embedding.\n        vision_transformer_config (TransformerConfig): Transformer config for the vision model.\n        vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the\n            vision model.\n        vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to\n            language model inputs.\n        vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision\n            projection.\n        vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP.\n        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This\n            is typically True for training and False for inference.\n        language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings\n            in the language model. Defaults to 1.0.\n        pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism).\n            Defaults to True.\n        post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline\n            parallelism). Defaults to True.\n        add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True.\n            When we use pipelining, the encoder\n            will live on only a subset of the pipeline stages (specifically, only the first stage).\n        add_decoder (bool): Construct the decoder module (used with pipeline parallelism). 
Defaults to True.\n            When we use pipelining, the decoder\n            will live on only a subset of the pipeline stages (specifically, every stage after the first one).\n        img_h (int): The height of each image that the ViT will see.\n        img_w (int): The width of each image that the ViT will see.\n        patch_dim (int): The size of each patch side.\n        img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be\n            inserted. Defaults to 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        language_transformer_config: TransformerConfig,\n        language_transformer_layer_spec: ModuleSpec,\n        language_vocab_size: int,\n        language_max_sequence_length: int,\n        vision_transformer_config: TransformerConfig,\n        vision_transformer_layer_spec: ModuleSpec,\n        vision_projection_config: TransformerConfig,\n        vision_projection_layer_spec: ModuleSpec,\n        vision_projection_type: str = \"mlp\",\n        parallel_output: bool = True,\n        language_rotary_percent: float = 1.0,\n        pre_process: bool = True,\n        post_process: bool = True,\n        add_encoder: bool = True,\n        add_decoder: bool = True,\n        language_rotary_base: int = 10000,\n        fp16_lm_cross_entropy: bool = False,\n        language_share_embeddings_and_output_weights: bool = False,\n        image_token_id: int = 151655,\n        video_token_id: int = 151656,\n    ) -> None:\n        super().__init__(config=language_transformer_config)\n\n        # patch self_attention to use qwen2_5_vl attention\n        vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention\n        for layer_spec in language_transformer_layer_spec.layer_specs:\n            layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention\n\n        logging.getLogger(__name__).warning(\"Qwen2VL model is under development and may be missing features.\")\n\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.add_encoder = add_encoder\n        self.add_decoder = add_decoder\n\n        self.encoder_hidden_state = None\n        self.vision_model = None\n        self.vision_projection = None\n        self.language_model = None\n        self.image_token_id = image_token_id\n        self.video_token_id = video_token_id\n\n        self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size\n\n        # This attribute is needed to check if an all-reduce is required\n        # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.\n        self.share_embeddings_and_output_weights = False\n        if self.pre_process:\n            self.vision_model = Qwen2_5VisionModel(\n                vision_transformer_config,\n                vision_transformer_layer_spec,\n                vision_projection_config,\n                vision_projection_layer_spec,\n                projection_type=vision_projection_type,\n                pre_process=True,\n                post_process=True,\n            )\n\n        self.language_model = GPTModel(\n            config=language_transformer_config,\n            transformer_layer_spec=language_transformer_layer_spec,\n            vocab_size=language_vocab_size,\n            max_sequence_length=language_max_sequence_length,\n            parallel_output=parallel_output,\n            position_embedding_type=\"mrope\",\n            
rotary_percent=language_rotary_percent,\n            pre_process=self.pre_process,\n            post_process=self.post_process,\n            rotary_base=language_rotary_base,\n            fp16_lm_cross_entropy=fp16_lm_cross_entropy,\n            share_embeddings_and_output_weights=language_share_embeddings_and_output_weights,\n            scatter_embedding_sequence_parallel=False,\n        )\n\n        self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights\n\n    def shared_embedding_or_output_weight(self):\n        \"\"\"This is a convenience method to surface the language model's word embeddings, which is\n        necessary for `finalize_model_grads._allreduce_word_embedding_grads`.\"\"\"\n        if self.add_decoder:\n            return self.language_model.shared_embedding_or_output_weight()\n        return None\n\n    def set_input_tensor(self, input_tensor) -> None:\n        # This is usually handled in schedules.py but some inference code still\n        # gives us non-lists or None\n        if not isinstance(input_tensor, list):\n            input_tensor = [input_tensor]\n        assert len(input_tensor) == 1, \"input_tensor should only be length 1 for Qwen2VL\"\n\n        if self.pre_process:\n            self.encoder_hidden_state = input_tensor[0]\n        else:\n            self.language_model.set_input_tensor(input_tensor[0])\n\n    def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):\n        \"\"\"Freeze model modules.\n\n        Make specific modules non-trainable by setting requires_grad to False for the module's parameters.\n\n        Args:\n            freeze_language_model (bool): Freeze the language model module.\n            freeze_vision_model (bool): Freeze the vision model module.\n            freeze_vision_projection (bool): Freeze the vision projection module.\n        \"\"\"\n        modules = []\n        if freeze_language_model and self.language_model is not None:\n            modules.append(self.language_model)\n        if freeze_vision_model and self.vision_model is not None:\n            modules.append(self.vision_model)\n        if freeze_vision_projection and self.vision_projection is not None:\n            modules.append(self.vision_projection)\n\n        for module in modules:\n            for param in module.parameters():\n                param.requires_grad = False\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        attention_mask: torch.Tensor = None,\n        labels: torch.Tensor = None,\n        inference_params: InferenceParams = None,\n        packed_seq_params: PackedSeqParams = None,\n        extra_block_kwargs: dict = None,\n        pixel_values: torch.Tensor = None,\n        pixel_values_videos: torch.Tensor = None,\n        image_grid_thw: torch.Tensor = None,\n        video_grid_thw: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward function of the Qwen2VL model.\n\n        Args:\n            image_data (torch.Tensor): input image of shape [total_thw_size, n_features].\n            input_ids (torch.Tensor): input text ids [batch, text_seq_len].\n            position_ids (torch.Tensor): input text position ids [batch, text_seq_len].\n            attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len,\n                combined_seq_len].\n            labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].\n   
            inference_params (InferenceParams): Inference-time parameters including KV cache.\n\n            video_start_index:\n                0 -- all video\n                len(video_seq) -- all image\n                others -- mixture\n            *_input_mask: should not be None in the first PP stage\n        Returns:\n            output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape\n                [b, s, vocab_size].\n        \"\"\"\n        video_start_index = 0\n        vision_grid_thw = None\n        vision_data = None\n        if image_grid_thw is not None:\n            image_mask = input_ids == self.image_token_id\n            vision_grid_thw = image_grid_thw\n            vision_data = pixel_values\n            video_start_index = image_mask.sum().item()\n        if video_grid_thw is not None:\n            if vision_grid_thw is None:\n                # video-only input (no images), so there is nothing to concatenate with\n                vision_grid_thw = video_grid_thw\n                vision_data = pixel_values_videos\n            else:\n                vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0)\n                vision_data = torch.cat([vision_data, pixel_values_videos], dim=0)\n            # NOTE: video_start_index stays at the number of image tokens (0 if there are no\n            # images); vision_embeds is ordered [image embeddings..., video embeddings...], so\n            # the image token count is exactly where the video part starts.\n        use_inference_kv_cache = (\n            inference_params is not None and \"image_tokens_count\" in inference_params.key_value_memory_dict\n        )\n        if use_inference_kv_cache:\n            raise NotImplementedError()\n\n        if self.pre_process:\n            vision_embeds = None\n            if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0:\n                vision_embeds = self.vision_model(\n                    vision_data=vision_data,  # If None, vision model should use intermediate outputs (EPP > 1)\n                    grid_thw=vision_grid_thw,  # should be provided in each EPP stage\n                )\n\n            # If running inference, the language model KV cache will be updated for image token positions.\n            # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later.\n            if inference_params is not None:\n                raise NotImplementedError()\n                # inference_params.key_value_memory_dict[\"image_tokens_count\"] = (\n                #     vision_embeddings.shape[0]\n                # )\n\n            # If running inference, we can skip image token computation if they were computed already earlier\n            # for this sample.\n            if use_inference_kv_cache:\n                language_embeddings: torch.Tensor = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n                # NOTE: no concat with the vision embeddings here -- are the combined embeddings\n                # unnecessary once the image tokens already live in the KV cache?\n                combined_embeddings = language_embeddings\n            elif vision_embeds is not None:\n                if video_start_index == 0:\n                    image_embeds = None\n                    video_embeds = vision_embeds\n                elif video_start_index == vision_embeds.shape[0]:\n                    image_embeds = vision_embeds\n                    video_embeds = None\n                elif 0 < video_start_index < vision_embeds.shape[0]:\n                    image_embeds = vision_embeds[:video_start_index]\n                    video_embeds = vision_embeds[video_start_index:]\n                else:\n                    raise ValueError(\n                        f\"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got \"\n                        f\"{video_start_index}\"\n                    )\n\n                combined_embeddings = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n\n                if image_embeds is not None or video_embeds is not None:\n                    combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()\n                    if image_embeds is not None:\n                        image_mask = (input_ids == self.image_token_id).contiguous()\n                        if image_mask.sum() > 0:\n                            combined_embeddings = combined_embeddings.clone()\n                            combined_embeddings[image_mask] = image_embeds.to(\n                                dtype=combined_embeddings.dtype, device=combined_embeddings.device\n                            )\n                    if video_embeds is not None:\n                        video_mask = (input_ids == self.video_token_id).contiguous()\n                        if video_mask.sum() > 0:\n                            combined_embeddings = combined_embeddings.clone()\n                            combined_embeddings[video_mask] = video_embeds.to(\n                                dtype=combined_embeddings.dtype, device=combined_embeddings.device\n                            )\n                    combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()\n\n            else:\n                combined_embeddings = self.language_model.embedding(\n                    input_ids=input_ids,\n                    position_ids=None,  # NOTE: disable\n                )  # [text_seq_len, b, h_language]\n            if self.config.sequence_parallel:\n                combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)\n                combined_embeddings = combined_embeddings.contiguous()\n        else:\n            combined_embeddings = None\n        from .rope_utils import get_rope_index\n\n        position_ids, _ = get_rope_index(\n            input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask\n        )\n\n        output = self.language_model(\n            input_ids=None,\n            position_ids=position_ids,  # None in encoder\n            attention_mask=attention_mask,  # None in encoder\n            decoder_input=combined_embeddings,  # only not None in the first decoder PP stage\n            labels=labels,  # only not None in the last decoder PP stage\n            # inference_params=inference_params,  # currently always None\n            
packed_seq_params=packed_seq_params,  # currently always None\n            **(extra_block_kwargs or {}),\n        )\n\n        return output\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/rope_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import Optional\n\nimport torch\nfrom megatron.core.models.common.embeddings.rope_utils import *\nfrom megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd\nfrom torch import Tensor\n\nlogger = logging.getLogger(__name__)\n\n\n# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index\ndef get_rope_index(\n    input_ids: Optional[torch.LongTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n):\n    \"\"\"\n    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.\n\n    Explanation:\n\n        Each embedding sequence contains vision embedding and text embedding or just contains text embedding.\n\n        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.\n\n        Examples:\n\n            input_ids: [T T T T T], here T is for text.\n            temporal position_ids: [0, 1, 2, 3, 4]\n            height position_ids: [0, 1, 2, 3, 4]\n            width position_ids: [0, 1, 2, 3, 4]\n\n        For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part\n        and 1D rotary position embedding for text part.\n\n        Examples:\n\n            Temporal (Time): 3 patches, representing different segments of the video in time.\n            Height: 2 patches, dividing each frame vertically.\n            Width: 2 patches, dividing each frame horizontally.\n            We also have some important parameters:\n            fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each\n            second.\n            tokens_per_second: This is a crucial parameter. It dictates how many \"time-steps\" or \"temporal\n                               tokens\" are conceptually packed into a one-second interval of the video.\n                               In this case, we have 25 tokens per second. So each second of the video will be\n                               represented with 25 separate time points. It essentially defines the temporal\n                               granularity.\n            temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.\n            interval: The step size for the temporal position IDs, calculated as tokens_per_second *\n            temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. 
This means that each temporal patch will\n            have a difference of 50 in the temporal position IDs.\n            input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.\n            vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]\n            vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]\n            vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]\n            text temporal position_ids: [101, 102, 103, 104, 105]\n            text height position_ids: [101, 102, 103, 104, 105]\n            text width position_ids: [101, 102, 103, 104, 105]\n            Here we calculate the text start position_ids as the max vision position_ids plus 1.\n\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):\n            The temporal, height and width of feature shape of each image in LLM.\n        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):\n            The temporal, height and width of feature shape of each video in LLM.\n        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):\n            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n    Returns:\n        position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)\n        mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)\n    \"\"\"\n    spatial_merge_size = 2\n    tokens_per_second = 2\n    image_token_id = 151655\n    video_token_id = 151656\n    vision_start_token_id = 151652\n    mrope_position_deltas = []\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        total_input_ids = input_ids\n        if attention_mask is None:\n            attention_mask = torch.ones_like(total_input_ids)\n        position_ids = torch.ones(\n            3,\n            input_ids.shape[0],\n            input_ids.shape[1],\n            dtype=input_ids.dtype,\n            device=input_ids.device,\n        )\n        image_index, video_index = 0, 0\n        attention_mask = attention_mask.to(total_input_ids.device)\n        for i, input_ids in enumerate(total_input_ids):\n            input_ids = input_ids[attention_mask[i] == 1]\n            image_nums, video_nums = 0, 0\n            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)\n            vision_tokens = input_ids[vision_start_indices + 1]\n            image_nums = (vision_tokens == image_token_id).sum()\n            video_nums = (vision_tokens == video_token_id).sum()\n            input_tokens = input_ids.tolist()\n            llm_pos_ids_list: list = []\n            st = 0\n            remain_images, remain_videos = image_nums, video_nums\n            for _ in range(image_nums + video_nums):\n                if image_token_id in input_tokens and remain_images > 0:\n                    ed_image = input_tokens.index(image_token_id, st)\n              
  else:\n                    ed_image = len(input_tokens) + 1\n                if video_token_id in input_tokens and remain_videos > 0:\n                    ed_video = input_tokens.index(video_token_id, st)\n                else:\n                    ed_video = len(input_tokens) + 1\n                if ed_image < ed_video:\n                    t, h, w = (\n                        image_grid_thw[image_index][0],\n                        image_grid_thw[image_index][1],\n                        image_grid_thw[image_index][2],\n                    )\n                    second_per_grid_t = 0\n                    image_index += 1\n                    remain_images -= 1\n                    ed = ed_image\n\n                else:\n                    t, h, w = (\n                        video_grid_thw[video_index][0],\n                        video_grid_thw[video_index][1],\n                        video_grid_thw[video_index][2],\n                    )\n                    if second_per_grid_ts is not None:\n                        second_per_grid_t = second_per_grid_ts[video_index]\n                    else:\n                        second_per_grid_t = 1.0\n                    video_index += 1\n                    remain_videos -= 1\n                    ed = ed_video\n                llm_grid_t, llm_grid_h, llm_grid_w = (\n                    t.item(),\n                    h.item() // spatial_merge_size,\n                    w.item() // spatial_merge_size,\n                )\n                text_len = ed - st\n\n                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n                range_tensor = torch.arange(llm_grid_t).view(-1, 1)\n                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)\n\n                time_tensor = expanded_range * second_per_grid_t * tokens_per_second\n\n                time_tensor_long = time_tensor.long()\n                t_index = time_tensor_long.flatten()\n\n                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n                st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n            if st < len(input_tokens):\n                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n                text_len = len(input_tokens) - st\n                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)\n            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))\n        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)\n        return position_ids, mrope_position_deltas\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)\n            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, 
keepdim=True)[0]\n            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]\n        else:\n            position_ids = (\n                torch.arange(input_ids.shape[1], device=input_ids.device)\n                .view(1, 1, -1)\n                .expand(3, input_ids.shape[0], -1)\n            )\n            mrope_position_deltas = torch.zeros(\n                [input_ids.shape[0], 1],\n                device=input_ids.device,\n                dtype=input_ids.dtype,\n            )\n\n        return position_ids, mrope_position_deltas\n\n\ndef apply_rotary_pos_emb_thd_absolute(\n    t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False\n) -> Tensor:\n    \"\"\"A baseline implementation of applying RoPE for `thd` format.\n\n    Args:\n        t (Tensor): Input tensor T is of shape [t, h, d]\n        cu_seqlens(Tensor):  Cumulative sum of sequence lengths in a batch for `t`,\n        with shape [b + 1] and dtype torch.int32.\n        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]\n\n    Returns:\n        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.\n    \"\"\"\n    return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1)\n\n\ndef apply_rotary_pos_emb_absolute(\n    t: Tensor,\n    freqs: Tensor,\n    config: TransformerConfig,\n    cu_seqlens: Optional[Tensor] = None,\n):\n    \"\"\"\n    Reroute to the appropriate apply_rotary_pos_emb function depending on\n    bshd (conventional) / thd (packed seq) format\n\n    In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim]\n    \"\"\"\n\n    if config.apply_rope_fusion:\n        if cu_seqlens is None:\n            # NOTE: TE backends do not support mRoPE in bshd format when bs > 1\n            if freqs.shape[1] > 1:\n                return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)\n            else:\n                return fused_apply_rotary_pos_emb(t, freqs)\n        else:\n            # NOTE: as expected, thd format can use bshd\n            return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1)\n    else:\n        if cu_seqlens is None:\n            return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)\n        else:\n            return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved)\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/vision_config.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state\nfrom megatron.core.transformer import TransformerConfig\n\n\ndef get_vision_model_config(config: TransformerConfig) -> TransformerConfig:\n    # Given a Transformer Config from decoder, build vision encoder config\n    # diff: out_hidden_size & intermediate_size\n\n    # mlp: hidden_size -> intermediate_size -> embed_dim, silu\n    # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on\n    if config.num_layers in [28, 36]:\n        config.ffn_hidden_size = 3420\n    else:\n        config.ffn_hidden_size = 3456\n\n    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:\n        config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size()  # depth\n    else:\n        config.num_layers = 32  # depth\n    config.num_attention_heads = 16  # num_heads\n    config.add_bias_linear = True  # all nn.Linear has bias (MLP, attn)\n    config.add_qkv_bias = True  # qkv_proj in attn has bias\n    config.hidden_size = 1280  # hidden_size\n    config.hidden_dropout = 0.0\n    config.attention_dropout = 0.0\n\n    # config.gated_linear_unit = False # no gated\n    # config.activation_func = quick_gelu # hidden_act\n    config.kv_channels = config.hidden_size // config.num_attention_heads\n    config.num_query_groups = config.num_attention_heads  # no GQA\n    config.layernorm_zero_centered_gamma = False  # False\n    config.apply_query_key_layer_scaling = False  # factor=math.sqrt(head_dim)\n    config.bias_activation_fusion = False  # no swiglu, set false\n    config.bias_dropout_fusion = False  # no dropout, set false\n    config.attention_softmax_in_fp32 = True  # use True\n    # config.normalization = 'LayerNorm' # use RMSNorm\n    config.seq_length = 1\n\n    config.tp_comm_overlap = False\n    config.sequence_parallel = False\n    config.temporal_patch_size = 2\n    config.patch_size = 14\n    config.in_channels = 3\n    config.spatial_merge_size = 2\n\n    config.fullatt_block_indexes = [7, 15, 23, 31]\n    config._qwen2_5_vl_window_size = 112\n    return config\n\n\ndef get_vision_projection_config(\n    config: TransformerConfig, embed_dim: int, spatial_merge_size: int\n) -> TransformerConfig:\n    # merger:\n    # context_dim = hidden_size * merge_size**2\n    # out_hidden_size = hidden_size\n    # context_dim -> context_dim -> out_hidden_size\n    # MLP:\n    # input_size -> ffn_hidden_size -> hidden_size\n    # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True)\n    config.gated_linear_unit = False\n    config.bias_activation_fusion = False\n    config.add_bias_linear = True\n    config.ffn_hidden_size = embed_dim * (spatial_merge_size**2)\n    config.activation_func = torch.nn.functional.gelu\n    
config.tp_comm_overlap = False\n    config.sequence_parallel = False\n    return config\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/vision_model.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import InferenceParams\nfrom megatron.core.models.common.vision_module.vision_module import VisionModule\nfrom megatron.core.models.vision.multimodal_projector import MultimodalProjector\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.transformer.enums import ModelType\nfrom megatron.core.transformer.spec_utils import ModuleSpec\nfrom megatron.core.transformer.transformer_config import TransformerConfig\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock\n\n\n# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py\nclass PatchEmbed(nn.Module):\n    def __init__(\n        self,\n        patch_size: int = 14,\n        temporal_patch_size: int = 2,\n        in_channels: int = 3,\n        embed_dim: int = 1152,\n    ) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.temporal_patch_size = temporal_patch_size\n        self.in_channels = in_channels\n        self.embed_dim = embed_dim\n\n        kernel_size = [temporal_patch_size, patch_size, patch_size]\n        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(\n            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size\n        )\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\n# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py\nclass VisionRotaryEmbedding(nn.Module):\n    def __init__(self, dim: int, theta: float = 10000.0) -> None:\n        super().__init__()\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs.float()\n\n\nclass Qwen2_5VisionModel(VisionModule):\n    \"\"\"Qwen2.5 ViT vision model.\n\n    Args:\n        transformer_config (TransformerConfig): Transformer config.\n        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.\n        ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.\n        add_class_token (bool, optional): 
Include a class token. Defaults to True.\n        class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.\n        patch_dim (int): Image patch size.\n        img_h (int): Input image height.\n        img_w (int): Input image width.\n    \"\"\"\n\n    def __init__(\n        self,\n        transformer_config: TransformerConfig,\n        transformer_layer_spec: ModuleSpec,\n        projection_config: TransformerConfig,\n        projection_layer_spec: ModuleSpec,\n        projection_type: str = \"mlp\",\n        pre_process: bool = True,\n        post_process: bool = False,\n    ) -> None:\n        super().__init__(config=transformer_config)\n\n        self.spatial_merge_size = transformer_config.spatial_merge_size\n\n        embed_dim = transformer_config.hidden_size\n        num_heads = transformer_config.num_attention_heads\n        temporal_patch_size = transformer_config.temporal_patch_size\n        patch_size = transformer_config.patch_size\n        in_channels = transformer_config.in_channels\n\n        self.patch_size = transformer_config.patch_size\n        self.fullatt_block_indexes = transformer_config.fullatt_block_indexes\n        self.window_size = transformer_config._qwen2_5_vl_window_size\n        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size\n\n        self.max_sequence_length = transformer_config.seq_length\n        self.patch_embed = PatchEmbed(\n            patch_size=patch_size,\n            temporal_patch_size=temporal_patch_size,\n            in_channels=in_channels,\n            embed_dim=embed_dim,\n        )\n\n        head_dim = embed_dim // num_heads\n        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)\n\n        self.model_type = ModelType.encoder_or_decoder\n        self.pre_process = pre_process\n        self.post_process = post_process\n\n        # Transformer layers.\n        # TODO: Follow-up changes will make pre and post_process configurable. 
They are needed for supporting\n        # pipeline parallelism.\n        # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here.\n        self.decoder = TransformerBlock(\n            config=transformer_config,\n            spec=transformer_layer_spec,\n            pre_process=self.pre_process,\n            post_process=self.post_process,\n            post_layer_norm=True,\n        )\n\n        self.merge_hidden_size = projection_config.ffn_hidden_size\n        self.square_merge_size = self.merge_hidden_size // embed_dim\n\n        if self.post_process:\n            self.projection = MultimodalProjector(\n                projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size\n            )\n        else:\n            self.projection = None\n\n        self.input_tensor = None\n\n    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:\n        \"\"\"Sets input tensor to the model.\n\n        Args:\n            input_tensor (Tensor): Sets the input tensor for the model.\n        \"\"\"\n        if self.pre_process:  # always True\n            self.input_tensor = input_tensor\n        else:\n            raise NotImplementedError()\n\n    def rot_pos_emb(self, grid_thw):\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n        pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        return rotary_pos_emb\n\n    def get_window_index(self, grid_thw):\n        window_index: list = []\n        cu_window_seqlens: list = [0]\n        window_index_id = 0\n        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size\n\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.spatial_merge_size,\n                grid_w // self.spatial_merge_size,\n            )\n            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)\n            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size\n            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size\n            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size\n            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size\n            index_padded = F.pad(index, (0, pad_w, 0, pad_h), \"constant\", -100)\n            index_padded = index_padded.reshape(\n                grid_t,\n          
      num_windows_h,\n                vit_merger_window_size,\n                num_windows_w,\n                vit_merger_window_size,\n            )\n            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(\n                grid_t,\n                num_windows_h * num_windows_w,\n                vit_merger_window_size,\n                vit_merger_window_size,\n            )\n            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)\n            index_padded = index_padded.reshape(-1)\n            index_new = index_padded[index_padded != -100]\n            window_index.append(index_new + window_index_id)\n            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]\n            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())\n            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()\n        window_index = torch.cat(window_index, dim=0)\n\n        return window_index, cu_window_seqlens\n\n    def forward(\n        self,\n        vision_data: Optional[torch.Tensor],\n        grid_thw: torch.Tensor,\n        inference_params: Optional[InferenceParams] = None,\n        extra_block_kwargs: dict = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward function of the Qwen2 Vision Model. This function passes the input tensors\n        through the embedding layer and then the transformer.\n\n        Args:\n            vision_data (torch.Tensor): input image/video patches of shape [n_tokens, n_dims]\n            grid_thw (torch.Tensor): the size tensor indicating the grid size of each image/frame\n            inference_params (InferenceParams, optional): unsupported for now; must be None\n            extra_block_kwargs (dict, optional): extra kwargs forwarded to the transformer block\n\n        Returns:\n            hidden_states (torch.Tensor): projected vision embeddings of shape [n_merged_tokens, h].\n        \"\"\"\n        assert grid_thw is not None\n        assert self.input_tensor is None\n        assert inference_params is None\n\n        # Rotary positional embeddings (embedding is None for PP intermediate devices)\n        vision_data = self.patch_embed(vision_data)\n        window_index, cu_window_seqlens = self.get_window_index(grid_thw)\n        cu_window_seqlens = torch.tensor(\n            cu_window_seqlens,\n            device=vision_data.device,\n            dtype=torch.int32,\n        )\n        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)\n\n        seq_len, _ = vision_data.size()\n        vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        vision_data = vision_data[window_index, :, :]\n        vision_data = vision_data.reshape(seq_len, 1, -1)\n\n        rotary_pos_emb = self.rot_pos_emb(grid_thw)\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        rotary_pos_emb = rotary_pos_emb[window_index, :, :]\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2)\n\n        hidden_states = self.decoder(\n            hidden_states=vision_data,\n            attention_mask=None,\n            inference_params=inference_params,\n            rotary_pos_emb=rotary_pos_emb,\n            packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens),\n            packed_seq_params_full=self.build_packed_seq_params(grid_thw),\n            fullatt_block_indexes=self.fullatt_block_indexes,\n            **(extra_block_kwargs or {}),\n        )\n\n        hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size))\n        reverse_indices = 
torch.argsort(window_index)\n        return hidden_states[reverse_indices, :]\n\n    def build_packed_seq_params(\n        self,\n        grid_thw: Optional[torch.Tensor],\n        cu_seqlens: Optional[torch.Tensor] = None,\n    ) -> PackedSeqParams:\n        # NOTE: each frame is a sequence (rather than each grid)\n        if grid_thw is not None:\n            seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])\n            cu_seqlens = seqlens.cumsum(dim=0)\n            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int()\n        else:\n            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n\n        max_seqlen_q = seqlens.max()\n        return PackedSeqParams(\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_kv=cu_seqlens,\n            qkv_format=\"thd\",\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_kv=max_seqlen_q,\n        )\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright (c) 2024 Alibaba PAI Team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom megatron.core.transformer.transformer_block import *\n\n\nclass Qwen2_5VisionTransformerBlock(TransformerBlock):\n    def _checkpointed_forward(\n        self,\n        hidden_states: Tensor,\n        attention_mask: Tensor,\n        context: Tensor,\n        context_mask: Tensor,\n        rotary_pos_emb: Tensor,\n        attention_bias: Tensor,\n        packed_seq_params: PackedSeqParams,\n        packed_seq_params_full: PackedSeqParams,\n        fullatt_block_indexes,\n    ):\n        \"\"\"Forward method with activation checkpointing.\"\"\"\n\n        def custom(start: int, end: int):\n            def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb):\n                for index in range(start, end):\n                    if index in fullatt_block_indexes:\n                        packed_seq_params_now = packed_seq_params_full\n                    else:\n                        packed_seq_params_now = packed_seq_params\n                    layer = self._get_layer(index)\n                    hidden_states, context = layer(\n                        hidden_states=hidden_states,\n                        attention_mask=attention_mask,\n                        context=context,\n                        context_mask=context_mask,\n                        rotary_pos_emb=rotary_pos_emb,\n                        attention_bias=attention_bias,\n                        inference_context=None,\n                        packed_seq_params=packed_seq_params_now,\n                    )\n                return hidden_states, context\n\n            return custom_forward\n\n        def checkpoint_handler(forward_func):\n            \"\"\"Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`\"\"\"\n            if self.config.fp8:\n                return te_checkpoint(\n                    forward_func,\n                    self.config.distribute_saved_activations,\n                    tensor_parallel.random.get_cuda_rng_tracker,\n                    parallel_state.get_tensor_model_parallel_group(),\n                    hidden_states,\n                    attention_mask,\n                    context,\n                    context_mask,\n                    rotary_pos_emb,\n                )\n            else:\n                return tensor_parallel.checkpoint(\n                    forward_func,\n                    self.config.distribute_saved_activations,\n                    hidden_states,\n                    attention_mask,\n                    context,\n                    context_mask,\n                    rotary_pos_emb,\n                )\n\n        if self.config.recompute_method == \"uniform\":\n            # Uniformly divide the total number of Transformer layers and checkpoint\n            # the input activation 
of each divided chunk.\n            # This further reduces memory usage by reducing the number of stored checkpoints.\n            layer_idx = 0\n            while layer_idx < self.num_layers_per_pipeline_rank:\n                hidden_states, context = checkpoint_handler(\n                    custom(layer_idx, layer_idx + self.config.recompute_num_layers)\n                )\n\n                layer_idx += self.config.recompute_num_layers\n\n        elif self.config.recompute_method == \"block\":\n            # Checkpoint the input activation of only a set number of individual\n            # Transformer layers and skip the rest.\n            # This makes full use of device memory while removing redundant re-computation.\n            recompute_skip_num_layers = 0\n            for layer_idx in range(self.num_layers_per_pipeline_rank):\n                # Skip recomputation when input grad computation is not needed.\n                # Need to have at least one input tensor with gradient computation\n                # for the re-entrant autograd engine.\n                if self.config.fp8 and not hidden_states.requires_grad:\n                    recompute_skip_num_layers += 1\n                if (\n                    layer_idx >= recompute_skip_num_layers\n                    and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers\n                ):\n                    hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))\n                else:\n                    hidden_states, context = custom(layer_idx, layer_idx + 1)(\n                        hidden_states, attention_mask, context, context_mask, rotary_pos_emb\n                    )\n        else:\n            raise ValueError(\"Invalid activation recompute method.\")\n\n        return hidden_states\n\n    def forward(\n        self,\n        hidden_states: Union[Tensor, WrappedTensor],\n        attention_mask: Optional[Tensor],\n        context: Optional[Tensor] = None,\n        context_mask: Optional[Tensor] = None,\n        rotary_pos_emb: Optional[Tensor] = None,\n        rotary_pos_cos: Optional[Tensor] = None,\n        rotary_pos_sin: Optional[Tensor] = None,\n        attention_bias: Optional[Tensor] = None,\n        inference_context: Optional[BaseInferenceContext] = None,\n        packed_seq_params: Optional[PackedSeqParams] = None,\n        sequence_len_offset: Optional[Tensor] = None,\n        packed_seq_params_full: PackedSeqParams = None,\n        fullatt_block_indexes=None,\n        *,\n        inference_params: Optional[BaseInferenceContext] = None,\n    ):\n        \"\"\"\n        Perform the forward pass through the transformer block.\n\n        This method handles the core computation of the transformer, including\n        self-attention, optional cross-attention, and feed-forward operations.\n\n        Args:\n            hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h]\n                where s is the sequence length, b is the batch size, and h is the hidden size.\n                Can be passed as a WrappedTensor during inference to avoid an obsolete\n                reference in the calling function.\n            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking\n                self-attention.\n            context (Tensor, optional): Context tensor for cross-attention.\n            context_mask (Tensor, optional): Mask for cross-attention context\n            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.\n            
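rotary_pos_cos (Tensor, optional): Precomputed rotary cosine, passed through to the layers.\n            rotary_pos_sin (Tensor, optional): Precomputed rotary sine, passed through to the layers.\n            packed_seq_params_full (PackedSeqParams, optional): Packed-sequence parameters for the\n                layers listed in fullatt_block_indexes, which use full attention; all other layers\n                use packed_seq_params (window attention).\n            fullatt_block_indexes (list, optional): Indexes of the layers that use full attention.\n            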
attention_bias (Tensor): Bias tensor for Q * K.T, in shape broadcastable\n                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].\n                Used as an alternative to apply attention mask for TE cuDNN attention.\n            inference_context (BaseInferenceContext, optional): Parameters for inference-time\n                optimizations.\n            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence\n                processing.\n\n        Returns:\n            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape\n            [s, b, h], and optionally the updated context tensor if cross-attention is used.\n        \"\"\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        # Delete the obsolete reference to the initial input tensor if necessary\n        if isinstance(hidden_states, WrappedTensor):\n            hidden_states = hidden_states.unwrap()\n\n        if not self.pre_process:\n            # See set_input_tensor()\n            hidden_states = self.input_tensor\n\n        # Update the inference parameters with the current batch size in case it is variable\n        if inference_context and not self.training:\n            inference_context.current_batch_size = hidden_states.size(1)\n\n        # Viewless tensor.\n        # - We only need to create a viewless tensor in the case of micro batch\n        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'\n        #   above creates a view tensor, and '.contiguous()' is a pass-through.\n        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating\n        #   the need to make it viewless.\n        #\n        #   However, we don't explicitly check mbs == 1 here because\n        #   make_viewless_tensor() has negligible overhead when its input\n        #   is already viewless.\n        #\n        # - For the 'else' case above, calling make_viewless_tensor() here is\n        #   likely redundant, since p2p_communication.py (likely originator)\n        #   already creates viewless tensors. 
That said, make_viewless_tensor()\n        #   is called here to be future-proof and corner-case-proof.\n        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)\n\n        if self.config.sequence_parallel:\n            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()\n        else:\n            rng_context = nullcontext()\n\n        # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(),\n        # otherwise do nothing extra at the outer level\n        # if we are using other fp8 recipes, then the context manager enter&exit are free\n        # we can wrap fp8_context within the for loop over layers, so that we can fine-grained\n        # control which layer will be fp8 or bf16\n        use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed\n        use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed\n        outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext()\n\n        with rng_context, outer_fp8_context:\n            # Forward pass.\n            if self.config.recompute_granularity == \"full\" and self.training:\n                hidden_states = self._checkpointed_forward(\n                    hidden_states=hidden_states,\n                    attention_mask=attention_mask,\n                    context=context,\n                    context_mask=context_mask,\n                    rotary_pos_emb=rotary_pos_emb,\n                    attention_bias=attention_bias,\n                    packed_seq_params=packed_seq_params,\n                    packed_seq_params_full=packed_seq_params_full,\n                    fullatt_block_indexes=fullatt_block_indexes,\n                )\n            else:\n                for l_no, layer in enumerate(self.layers):\n                    inner_fp8_context = (\n                        get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext()\n                    )\n                    if l_no in fullatt_block_indexes:\n                        packed_seq_params_now = packed_seq_params_full\n                    else:\n                        packed_seq_params_now = packed_seq_params\n                    with self.offload_context, inner_fp8_context:\n                        hidden_states, context = layer(\n                            hidden_states=hidden_states,\n                            attention_mask=attention_mask,\n                            context=context,\n                            context_mask=context_mask,\n                            rotary_pos_emb=rotary_pos_emb,\n                            rotary_pos_cos=rotary_pos_cos,\n                            rotary_pos_sin=rotary_pos_sin,\n                            attention_bias=attention_bias,\n                            inference_context=inference_context,\n                            packed_seq_params=packed_seq_params_now,\n                            sequence_len_offset=sequence_len_offset,\n                        )\n\n                    if (\n                        torch.is_grad_enabled()\n                        and self.config.cpu_offloading\n                        and self.group_prefetch_offload_commit_async is not None\n                    ):\n                        hidden_states = self.group_prefetch_offload_commit_async(hidden_states)\n\n        # Final layer norm.\n        if self.final_layernorm is not None:\n            hidden_states = 
self.final_layernorm(hidden_states)\n            # TENorm produces a \"viewed\" tensor. This will result in schedule.py's\n            # deallocate_output_tensor() throwing an error, so a viewless tensor is\n            # created to prevent this.\n            hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)\n\n        return hidden_states\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/readme.md",
    "content": "# verl Megatron-Core Models\nThe earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features.\n\nThe migration has been successful with the help of the mcore team and the community. What we have done is:\n1. update `Megatron` version to `0.11.0`\n2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel`\n3. support sequence packing/thd format.\n4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`.\n5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format.\n\nWe are working on the following features:\n- support `Qwen2MoeForCausalLM`\n- support `MixtralForCausalLM`\n- support `DeepseekV3ForCausalLM`\n- support `expert parallel`\n\nFeatures we invite the community to contribute:\n- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format.\n    - conversion of large models with multiple GPUs\n    - conversion of large models with single GPU\n- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format.\n- support llama4\n- support qwen2.5-vl\n\nTo track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033).\n\n## How things work now\nTo engage the community in contributing, here are the key steps in our mcore integration process and features under development. \n\nThe huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two.\nmain steps:\n1. modelling the huggingface model with mcore `GPTModel`\n    - a. convert the huggingface config to mcore `TransformerConfig`\n    - b. init the mcore `GPTModel` with the converted config\n    - c. load the huggingface model weights to the `GPTModel`\n2. online weight conversion from mcore to huggingface (due to the rollout engine `vLLM` is using huggingface format)\n    - a. bridge the gap between mcore and huggingface weights format and name mapping\n    - b. online resharding the mcore weights to rollout engine\n        - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine\n3. support the mcore features in verl\n    - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`\n    - b. support recompute and other mcore speed up features\n\n4. checkpointing\n    - a. support recovering the verl training.\n    - b. support exporting the mcore checkpoint to huggingface format, for downstream inference.\n\n### Modelling the huggingface model with mcore `GPTModel`\nThe first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. 
The corresponding model forward code is in `verl/models/mcore/model_forward.py`.\n\nThere are two ways of loading the huggingface model weights into the `GPTModel`:\n1. Runtime loading\n    - every rank loads the entire huggingface model weights and then shards and converts them to mcore weights.\n    - speed is slow and memory consumption is high.\n    - this way is deprecated and will not support new models.\n2. Offline loading\n    - use an offline script to convert the huggingface model weights to mcore weights and save them in the mcore `dist_checkpointing` format.\n    - online loading and sharding are done automatically by the mcore `dist_checkpointing` format. The speed is fast and memory consumption is low.\n    - the offline script is in `verl/scripts/converter_hf_to_mcore.py`.\n\n### Online weight conversion from mcore to huggingface\nSee the function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details.\n\nIt should be refactored for extensibility and better performance.\n\n### Support the mcore features in verl\nMost features of `GPTModel` are supported out of the box in verl by changing the `TransformerConfig`, except those involving parallel strategies, such as `expert parallel`. \nFeatures involving parallel strategies require changes to the online weight conversion (especially the resharding part) and to verl work dispatching.\n\n### Checkpointing\nThe existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. The script to convert checkpoints to huggingface format is in `verl/scripts/model_merger`.\n\nThe existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored to use the `dist_checkpointing` format.\n\n\n## How to support new models\n1. make sure the model is supported by vLLM\n2. model the huggingface model with mcore `GPTModel` ([Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference)\n    - a. convert the huggingface config to mcore `TransformerConfig`\n    - b. init the mcore `GPTModel` with the converted config\n    - c. load the huggingface model weights into the `GPTModel`\n    - d. for VLMs the interface might differ; it is ok to add a new model class with `GPTModel` as a submodule.\n3. offline weights conversion from huggingface to the mcore `dist_checkpointing` format\n4. support online weights conversion from mcore to huggingface\n    - it is recommended to initialize a vLLM model with the converted mcore weights, and then test whether the generated sequences are correct.\n\n\n## How to scale up to larger models like deepseek-v3 or other 100B+ models\nThe greatest challenge for scaling up to larger models is memory consumption.\n\nThe necessary features under development for scaling up are:\n1. Training engine part\n    - expert parallel\n2. Rollout engine part\n    - pipeline parallel\n    - expert parallel\n    - more efficient and general weight resharding and loading\n3. Offline weights conversion\n    - support weights larger than single GPU memory\n"
  },
  {
    "path": "verl_rl/verl/models/mcore/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRegistry module for model architecture components.\n\"\"\"\n\nfrom enum import Enum\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\n\nfrom .config_converter import (\n    PretrainedConfig,\n    TransformerConfig,\n    hf_to_mcore_config_dense,\n    hf_to_mcore_config_dpskv3,\n    hf_to_mcore_config_llama4,\n    hf_to_mcore_config_mixtral,\n    hf_to_mcore_config_qwen2_5_vl,\n    hf_to_mcore_config_qwen2moe,\n    hf_to_mcore_config_qwen3moe,\n)\nfrom .model_forward import (\n    gptmodel_forward,\n    gptmodel_forward_qwen2_5_vl,\n)\nfrom .model_forward_fused import (\n    fused_forward_gptmodel,\n    fused_forward_qwen2_5_vl,\n)\nfrom .model_initializer import (\n    BaseModelInitializer,\n    DeepseekV3Model,\n    DenseModel,\n    MixtralModel,\n    Qwen2MoEModel,\n    Qwen3MoEModel,\n    Qwen25VLModel,\n)\nfrom .weight_converter import (\n    McoreToHFWeightConverterDense,\n    McoreToHFWeightConverterDpskv3,\n    McoreToHFWeightConverterMixtral,\n    McoreToHFWeightConverterQwen2_5_VL,\n    McoreToHFWeightConverterQwen2Moe,\n    McoreToHFWeightConverterQwen3Moe,\n)\n\n\nclass SupportedModel(Enum):\n    LLAMA = \"LlamaForCausalLM\"  # tested\n    QWEN2 = \"Qwen2ForCausalLM\"  # tested\n    QWEN2_MOE = \"Qwen2MoeForCausalLM\"  # pending\n    DEEPSEEK_V3 = \"DeepseekV3ForCausalLM\"  # not tested\n    MIXTRAL = \"MixtralForCausalLM\"  # tested\n    QWEN2_5_VL = \"Qwen2_5_VLForConditionalGeneration\"  # not supported\n    LLAMA4 = \"Llama4ForConditionalGeneration\"  # not tested\n    QWEN3 = \"Qwen3ForCausalLM\"  # tested\n    QWEN3_MOE = \"Qwen3MoeForCausalLM\"  # not tested\n\n\n# Registry for model configuration converters\nMODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {\n    SupportedModel.LLAMA: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,\n    SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,\n    SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,\n    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,\n    SupportedModel.LLAMA4: hf_to_mcore_config_llama4,\n    SupportedModel.QWEN3: hf_to_mcore_config_dense,\n    SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,\n    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,\n}\n\n# Registry for model initializers\nMODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {\n    SupportedModel.LLAMA: DenseModel,\n    SupportedModel.QWEN2: DenseModel,\n    SupportedModel.QWEN2_MOE: Qwen2MoEModel,\n    SupportedModel.MIXTRAL: MixtralModel,\n    SupportedModel.DEEPSEEK_V3: DeepseekV3Model,\n    SupportedModel.QWEN2_5_VL: Qwen25VLModel,\n    SupportedModel.LLAMA4: DenseModel,\n    SupportedModel.QWEN3: DenseModel,\n    
SupportedModel.QWEN3_MOE: Qwen3MoEModel,\n}\n\n# Registry for model forward functions\nMODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: gptmodel_forward,\n    SupportedModel.QWEN2: gptmodel_forward,\n    SupportedModel.QWEN2_MOE: gptmodel_forward,\n    SupportedModel.MIXTRAL: gptmodel_forward,\n    SupportedModel.DEEPSEEK_V3: gptmodel_forward,\n    SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,\n    SupportedModel.LLAMA4: gptmodel_forward,\n    SupportedModel.QWEN3: gptmodel_forward,\n    SupportedModel.QWEN3_MOE: gptmodel_forward,\n}\n\n# Registry for fused model forward functions\nMODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: fused_forward_gptmodel,\n    SupportedModel.QWEN2: fused_forward_gptmodel,\n    SupportedModel.QWEN2_MOE: fused_forward_gptmodel,\n    SupportedModel.MIXTRAL: fused_forward_gptmodel,\n    SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,\n    SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,\n    SupportedModel.LLAMA4: fused_forward_gptmodel,\n    SupportedModel.QWEN3: fused_forward_gptmodel,\n    SupportedModel.QWEN3_MOE: fused_forward_gptmodel,\n}\n\n# Registry for model weight converters\nMODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type[McoreToHFWeightConverterBase]] = {\n    SupportedModel.LLAMA: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,\n    SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,\n    SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,\n    SupportedModel.QWEN3: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,\n    SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,\n}\n\n\ndef get_supported_model(model_type: str) -> SupportedModel:\n    try:\n        return SupportedModel(model_type)\n    except ValueError as err:\n        supported_models = [e.value for e in SupportedModel]\n        raise NotImplementedError(\n            f\"Model Type: {model_type} not supported. 
Supported models: {supported_models}\"\n        ) from err\n\n\ndef hf_to_mcore_config(\n    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs\n) -> TransformerConfig:\n    \"\"\"Convert a huggingface PretrainedConfig to an mcore TransformerConfig.\n\n    Args:\n        hf_config: The huggingface PretrainedConfig.\n        dtype: The dtype of the model.\n        **override_transformer_config_kwargs: The kwargs to override the transformer config.\n\n    Returns:\n        The mcore TransformerConfig.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs)\n\n\ndef init_mcore_model(\n    tfconfig: TransformerConfig,\n    hf_config: PretrainedConfig,\n    pre_process: bool = True,\n    post_process: bool | None = None,\n    *,\n    share_embeddings_and_output_weights: bool = False,\n    value: bool = False,\n    **extra_kwargs,  # may be used for vlm and moe\n) -> nn.Module:\n    \"\"\"\n    Initialize an mcore model.\n\n    Args:\n        tfconfig: The transformer config.\n        hf_config: The HuggingFace config.\n        pre_process: Whether this pipeline stage holds the input embedding (first stage).\n        post_process: Whether this pipeline stage holds the output layer (last stage).\n        share_embeddings_and_output_weights: Whether to share embeddings and output weights.\n        value: Whether to initialize the model as a value model (scalar output head).\n        **extra_kwargs: Additional keyword arguments.\n\n    Returns:\n        The initialized model.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    initializer_cls = MODEL_INITIALIZER_REGISTRY[model]\n    initializer = initializer_cls(tfconfig, hf_config)\n    return initializer.initialize(\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        value=value,\n        **extra_kwargs,\n    )\n\n\ndef get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the forward function for a given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_REGISTRY[model]\n\n\ndef get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the fused forward function for a given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_FUSED_REGISTRY[model]\n\n\ndef get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> McoreToHFWeightConverterBase:\n    \"\"\"\n    Get the weight converter for a given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    tfconfig = hf_to_mcore_config(hf_config, dtype)\n    return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)\n\n\n
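# Usage sketch (illustrative only; assumes a supported HF checkpoint and an\n# initialized Megatron parallel state):\n#\n#     from transformers import AutoConfig\n#     hf_config = AutoConfig.from_pretrained(\"Qwen/Qwen3-1.7B\")\n#     tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)\n#     model = init_mcore_model(tfconfig, hf_config, pre_process=True, post_process=True)\n#     forward_fn = get_mcore_forward_fn(hf_config)\n"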
  },
  {
    "path": "verl_rl/verl/models/mcore/saver.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(\n    tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0\n):\n    \"\"\"Calculate global rank with support for CP/EP parallelism\"\"\"\n\n    # Get parallel sizes for each dimension\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    # ep_size = mpu.get_expert_model_parallel_world_size()\n\n    # Verify total GPU count matches (must be consistent with parallel_state.py)\n    total_size = tp_size * dp_size * pp_size * cp_size\n    assert total_size == torch.distributed.get_world_size(), (\n        f\"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}\"\n    )\n\n    # Core calculation logic (corresponds to RankGenerator order parameter)\n    # Assumes default order is \"tp-cp-ep-dp-pp\"\n    return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, 
tie_word_embeddings=False):\n    \"\"\"Merge the sharded parameters of a Megatron module into a single state dict.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (PretrainedConfig):\n            The HF config of the model.\n        dtype: The target dtype of the merged parameters.\n        is_value_model: Whether the model is a value model.\n        tie_word_embeddings: Whether input embeddings and lm_head are tied (lm_head is then skipped).\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank()}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = [wrapped_models]\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].decoder.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].decoder.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, 
group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if 
torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = 
qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                        q_weight_list.append(q_part)\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                        q_weight_list.append(q_part)\n                        if i * config.num_key_value_heads % tp_size == 0:\n                            k_weight_list.append(k_part)\n                            v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # Embeddings\n    # -------------------\n    if dp_rank == 0 and cp_rank == 0:  # models are identical across cp ranks\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.qk_layernorm:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            
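# The fused linear_qkv tensor stores q/k/v interleaved per query group;\n            # the helper below re-splits it into separate HF q/k/v projections.\n            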
_broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.add_qkv_bias:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                lm_head_weight = None\n                if pp_rank == pp_size - 1:\n                    lm_head_weight = getattr(gpt_model_module.output_layer, \"weight\", None)\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\", src_pp_rank=pp_size - 1)\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.output_layer, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n\n\ndef merge_megatron_ckpt_gptmodel_qwen_moe(\n    wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_qwen_moe is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_qwen2_5_vl(\n    wrapped_models, config, dtype, is_value_model=False, 
tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_dpskv3(\n    wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_dpskv3 is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_mixtral(\n    wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False\n):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_mixtral is not implemented\")\n\n\n
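# Usage sketch (illustrative; \"actor_module\" stands in for the list of\n# DDP-wrapped GPTModel chunks, one per virtual pipeline stage):\n#\n#     state_dict = merge_megatron_ckpt_gptmodel(\n#         actor_module, hf_config, torch.bfloat16, is_value_model=False\n#     )\n#     if torch.distributed.get_rank() == 0:\n#         torch.save(state_dict, \"merged_hf_state_dict.pt\")\n"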
  },
  {
    "path": "verl_rl/verl/models/mcore/util.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.packed_seq_params import PackedSeqParams\n\nfrom verl.utils.model import CausalLMOutputForPPO\n\n\ndef preprocess_packed_seqs(\n    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True\n) -> tuple[torch.Tensor, PackedSeqParams]:\n    \"\"\"\n    Preprocess packed sequences\n    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1\n    gets second and second last chunks, and so on), this is for load balancing with causal masking.\n    See https://github.com/NVIDIA/TransformerEngine/issues/1368\n    \"\"\"\n    batch_size = input_ids.shape[0]\n\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    cp_rank = mpu.get_context_parallel_rank()\n    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size\n\n    pad_size = (align_size - seqlens_in_batch % align_size) % align_size\n    seqlens_in_batch_padded = seqlens_in_batch + pad_size\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)\n    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)\n    max_seqlen_in_batch = seqlens_in_batch_padded.max().item()\n\n    shape = list(input_ids.shape[1:])\n    shape[0] = seqlens_in_batch_padded.sum().item() // cp_size\n    if pre_process:\n        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)\n        for i in range(batch_size):\n            if cp_size <= 1:\n                seqlen = seqlens_in_batch[i]\n                input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]\n                continue\n            seqlen = seqlens_in_batch_padded[i] // cp_size\n            half_seqlen = seqlen // 2\n            start_idx = cu_seqlens_padded[i] // cp_size\n            # split to 2 chunks\n            d = input_ids[i, attention_mask[i]]\n            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[\n                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)\n            ]\n\n            remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)\n            remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank\n            remain_end = min(remain_end, d.shape[0])\n            remain_len = remain_end - remain_start\n            if remain_len > 0:\n                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[\n                    remain_start:remain_end\n         
       ]\n\n    packed_seq_params = PackedSeqParams(\n        qkv_format=\"thd\",\n        cu_seqlens_q=cu_seqlens_padded,\n        max_seqlen_q=max_seqlen_in_batch,\n        cu_seqlens_kv=cu_seqlens_padded,\n        max_seqlen_kv=max_seqlen_in_batch,\n        cu_seqlens_q_padded=cu_seqlens_padded,\n        cu_seqlens_kv_padded=cu_seqlens_padded,\n    )\n    if pre_process:\n        return input_ids_rmpad.unsqueeze(0), packed_seq_params\n    else:\n        return input_ids, packed_seq_params\n\n\ndef postprocess_packed_seqs(\n    output: torch.Tensor,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Postprocess packed sequences\n    \"\"\"\n    if not post_process:\n        return output\n    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim\n    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)\n\n    cp_size = mpu.get_context_parallel_world_size()\n    # all gather output across context parallel group\n    if cp_size > 1:\n        # output shape: [1, packed_len, hidden_dim]\n        # need to gather across cp group and concatenate in sequence dimension\n        output_list = [torch.empty_like(output) for _ in range(cp_size)]\n        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())\n        output_list[mpu.get_context_parallel_rank()] = output\n    else:\n        output_list = [output]\n    for i in range(batch_size):\n        if cp_size <= 1:\n            s = attention_mask[i].sum().item()\n            output_new[i, attention_mask[i]] = output[0][\n                packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s\n            ]\n            continue\n        s_len_padded_chunk = (\n            packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]\n        ) // cp_size\n        half_seqlen = s_len_padded_chunk // 2\n        s_len = attention_mask[i].sum().item()\n        s_len_padded = s_len_padded_chunk * cp_size\n        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)\n        for j in range(cp_size):\n            o = output_list[j][0]\n            # split to 2 chunks\n            packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size\n            o0, o1 = (\n                o[packed_start_idx : packed_start_idx + half_seqlen],\n                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],\n            )\n            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0\n            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1\n        output_new[i, attention_mask[i]] = tmp[:s_len]\n\n    return output_new\n\n\ndef remove_left_padding(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    position_ids: torch.Tensor,\n    sequence_parallel: bool = False,\n    pre_process: bool = True,\n):\n    \"\"\"\n    Remove left padding from input_ids, attention_mask and position_ids\n    return new_input_ids, new_attention_mask, new_position_ids\n    \"\"\"\n    assert attention_mask.ndim == 2\n    assert position_ids.ndim == 2\n    cp_size = mpu.get_context_parallel_world_size()\n    assert cp_size == 1, \"Context parallel size without seq_pack is not supported\"\n    batch_size = input_ids.shape[0]\n    shape = list(input_ids.shape)  # batch_size, 
seq_len,...\n    seq_lens = attention_mask.sum(dim=1)\n    seq_len = seq_lens.max().item()\n    if sequence_parallel:\n        sp_world_size = mpu.get_tensor_model_parallel_world_size()\n        pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size\n        seq_len = seq_len + pad_size\n    shape[1] = seq_len\n    if pre_process:\n        new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)\n    new_attention_mask = torch.zeros(\n        dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len)\n    )\n    new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))\n    for i in range(batch_size):\n        if pre_process:\n            new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]]\n        new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]]\n        new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]]\n    if pre_process:\n        return new_input_ids, new_attention_mask, new_position_ids\n    else:\n        return input_ids, new_attention_mask, new_position_ids\n\n\ndef recover_left_padding(\n    result,\n    attention_mask: torch.Tensor,\n    original_attention_mask: torch.Tensor,\n    origin_seqlen: int,\n    post_process: bool = True,\n):\n    \"\"\"\n    Recover left padding from result\n    return result\n    \"\"\"\n    if not post_process:\n        return result\n    shape = list(result.shape)\n    batch_size = shape[0]\n    shape[1] = origin_seqlen\n    new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)\n    for i in range(batch_size):\n        new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]\n    return new_result\n\n\ndef postprocess_packed_seqs_for_dict_output(\n    labels_mask: torch.Tensor,\n    output: CausalLMOutputForPPO,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> dict[str, torch.Tensor]:\n    \"\"\"Unpack fused-kernel outputs back to the padded batch layout.\n\n    For fused kernels, the model returns log_probs and entropy as packed tensors;\n    this function masks out padded label positions and unpacks each tensor back to\n    [batch_size, seq_len].\n\n    Args:\n        labels_mask (torch.Tensor): mask of valid label positions in the packed layout.\n        output (CausalLMOutputForPPO): fused forward output holding packed log_probs and entropy.\n        packed_seq_params (PackedSeqParams): packing metadata from preprocess_packed_seqs.\n        attention_mask (torch.Tensor): the original [batch_size, seq_len] attention mask.\n        batch_size (int): the original batch size.\n        seq_len (int): the original (padded) sequence length.\n        post_process (bool, optional): if False, tensors are returned unchanged. Defaults to True.\n    Returns:\n        dict[str, torch.Tensor]: \"log_probs\" and \"entropy\" in [batch_size, seq_len] layout.\n    \"\"\"\n    ret = {}\n    output.entropy = output.entropy.view(1, -1)\n    output.log_probs = output.log_probs.view(1, -1)\n    output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0)\n    ret[\"entropy\"] = postprocess_packed_seqs(\n        output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    ret[\"log_probs\"] = postprocess_packed_seqs(\n        output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    return ret\n\n\n
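# Shape sketch (illustrative): preprocess_packed_seqs packs a padded batch into\n# a single [1, packed_len] sequence plus PackedSeqParams, and\n# postprocess_packed_seqs scatters model output back to [batch, seq_len, ...]:\n#\n#     packed_ids, params = preprocess_packed_seqs(input_ids, attention_mask)\n#     hidden = model(packed_ids, ...)  # [1, packed_len, hidden]\n#     hidden = postprocess_packed_seqs(\n#         hidden, params, attention_mask, input_ids.shape[0], input_ids.shape[1]\n#     )\n"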
  },
  {
    "path": "verl_rl/verl/models/mcore/weight_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# online convert mcore weight to pure huggingface weight, no any fusion\n# including format conversion and name mapping\n# not including resharding\nimport torch\nfrom megatron.core.transformer import TransformerConfig\nfrom transformers import PretrainedConfig\n\n\nclass McoreToHFWeightConverterBase:\n    def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig):\n        self.hf_config = hf_config\n        self.mcore_config = mcore_config\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor:\n        raise NotImplementedError\n\n\nclass McoreToHFWeightConverterDense(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.bias'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"self_attention.linear_qkv.bias\" in name or \"self_attention.linear_qkv.weight\" in name:\n            param_type = name.split(\".\")[-1]\n            assert param_type == \"bias\" or param_type == \"weight\"\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\")\n            assert len(params) == 3\n        elif \"self_attention.linear_proj.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.o_proj.weight\")\n            assert len(params) == 1\n        elif \"self_attention.linear_qkv.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.input_layernorm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.q_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_norm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.k_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_norm.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 
'decoder.layers.0.mlp.linear_fc1.weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"mlp.linear_fc1.weight\" in name:\n            # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.linear_fc1.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n\n        if \"self_attention\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # 'decoder.layers.0.mlp.shared_experts.gate_weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight'\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.gate_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert_gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.linear_fc1.weight\" in name:  # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight\")\n            
convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight\")\n            assert len(params) == 2\n        elif \"shared_experts.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense):\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"language_model.embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"language_model.decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"language_model.output_layer.weight\": \"lm_head.weight\",\n            \"vision_model.patch_embed.proj.weight\": \"visual.patch_embed.proj.weight\",\n            \"vision_model.decoder.final_layernorm.weight\": \"visual.merger.ln_q.weight\",\n            \"vision_model.projection.encoder.linear_fc1.weight\": \"visual.merger.mlp.0.weight\",\n            \"vision_model.projection.encoder.linear_fc1.bias\": \"visual.merger.mlp.0.bias\",\n            \"vision_model.projection.encoder.linear_fc2.weight\": \"visual.merger.mlp.2.weight\",\n            \"vision_model.projection.encoder.linear_fc2.bias\": \"visual.merger.mlp.2.bias\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n\n        if \"self_attention\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        model_type, _, _, layer_number = name.split(\".\")[:4]\n\n        convert_names = []\n        if model_type == \"language_model\":\n            name_map_after_layer = {\n                \"self_attention.linear_qkv.bias\": [\n                    \"self_attn.q_proj.bias\",\n                    \"self_attn.k_proj.bias\",\n                    \"self_attn.v_proj.bias\",\n                ],\n                \"self_attention.linear_qkv.weight\": [\n                    \"self_attn.q_proj.weight\",\n                    \"self_attn.k_proj.weight\",\n                    \"self_attn.v_proj.weight\",\n                ],\n                \"self_attention.linear_proj.weight\": \"self_attn.o_proj.weight\",\n                \"self_attention.linear_qkv.layer_norm_weight\": \"input_layernorm.weight\",\n            }\n            
name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n        elif model_type == \"vision_model\":\n            name_map_after_layer = {\n                \"self_attention.linear_proj.weight\": \"attn.proj.weight\",\n                \"self_attention.linear_proj.bias\": \"attn.proj.bias\",\n                \"self_attention.linear_qkv.layer_norm_weight\": \"norm1.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer, None)\n            if mapped_name is None:\n                assert \"linear_qkv\" in name_after_layer\n                assert len(params) == 3\n                new_param = torch.cat(params, dim=0)\n                params = [new_param]\n                if \"bias\" in name_after_layer:\n                    convert_names.append(f\"visual.blocks.{layer_number}.attn.qkv.bias\")\n                else:\n                    convert_names.append(f\"visual.blocks.{layer_number}.attn.qkv.weight\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"visual.blocks.{layer_number}.{mapped_name}\")\n        else:\n            raise NotImplementedError(f\"Unsupported model type: {model_type}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        model_type, _, _, layer_number = name.split(\".\")[:4]\n\n        convert_names = []\n        if model_type == \"language_model\":\n            name_map_after_layer = {\n                \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n                \"mlp.linear_fc1.bias\": [\"mlp.gate_proj.bias\", \"mlp.up_proj.bias\"],\n                \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n                \"mlp.linear_fc2.bias\": \"mlp.down_proj.bias\",\n                \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n\n        elif model_type == \"vision_model\":\n            name_map_after_layer = {\n                \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n                \"mlp.linear_fc1.bias\": [\"mlp.gate_proj.bias\", \"mlp.up_proj.bias\"],\n                \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n                \"mlp.linear_fc2.bias\": \"mlp.down_proj.bias\",\n                \"mlp.linear_fc1.layer_norm_weight\": \"norm2.weight\",\n            }\n            name_after_layer = \".\".join(name.split(\".\")[-3:])\n            
mapped_name = name_map_after_layer.get(name_after_layer)\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"visual.blocks.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"visual.blocks.{layer_number}.{mapped_name}\")\n        else:\n            raise NotImplementedError(f\"Unsupported model type: {model_type}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore\n        # 'decoder.layers.0.input_layernorm.weight'\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight'\n        # hf\n        # 'model.layers.0.input_layernorm.weight'\n        # 'model.layers.0.self_attn.o_proj.weight'\n        # 'model.layers.0.self_attn.q_proj.weight'\n        # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight'\n        # 'model.layers.0.self_attn.kv_a_layernorm.weight'\n        # 'model.layers.0.self_attn.kv_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_proj.weight'\n        # 'model.layers.0.self_attn.q_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_layernorm.weight'\n        name_map_after_layer = {\n            \"input_layernorm.weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_proj.weight\": \"self_attn.o_proj.weight\",\n            \"self_attention.linear_q_proj.weight\": \"self_attn.q_proj.weight\",\n            \"self_attention.linear_kv_down_proj.weight\": \"self_attn.kv_a_proj_with_mqa.weight\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj.weight\": \"self_attn.kv_b_proj.weight\",\n            \"self_attention.linear_q_down_proj.weight\": \"self_attn.q_a_proj.weight\",\n            \"self_attention.linear_q_up_proj.weight\": \"self_attn.q_b_proj.weight\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n        }\n        assert len(params) == 1\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        convert_names.append(f\"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore dense\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        # 'decoder.layers.0.mlp.linear_fc1.weight'\n        #       ---\n        # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight'\n        #       ---\n        # 
'decoder.layers.1.mlp.shared_experts.linear_fc2.weight'\n        # hf dense\n        # 'model.layers.0.post_attention_layernorm.weight'\n        # 'model.layers.0.mlp.down_proj.weight'\n        # 'model.layers.0.mlp.gate_proj.weight'\n        # 'model.layers.0.mlp.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.gate_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.down_proj.weight'\n\n        # mcore moe\n        # 'decoder.layers.1.pre_mlp_layernorm.weight'\n        # 'decoder.layers.1.mlp.router.weight'\n        # 'decoder.layers.1.mlp.router.expert_bias'\n        # 'decoder.layers.1.mlp.experts.linear_fc1.weight0'\n        #       ---\n        # 'decoder.layers.1.mlp.experts.linear_fc2.weight0'\n        # hf moe\n        # 'model.layers.1.post_attention_layernorm.weight'\n        # 'model.layers.1.mlp.gate.weight'\n        # 'model.layers.1.mlp.gate.e_score_correction_bias'\n        # 'model.layers.1.mlp.experts.0.gate_proj.weight'\n        # 'model.layers.1.mlp.experts.0.up_proj.weight'\n        # 'model.layers.1.mlp.experts.0.down_proj.weight'\n\n        name_map_after_layer = {\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n            \"mlp.shared_experts.linear_fc2.weight\": \"mlp.shared_experts.down_proj.weight\",\n            \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n            \"mlp.shared_experts.linear_fc1.weight\": [\n                \"mlp.shared_experts.gate_proj.weight\",\n                \"mlp.shared_experts.up_proj.weight\",\n            ],\n            \"pre_mlp_layernorm.weight\": \"post_attention_layernorm.weight\",\n            \"mlp.router.weight\": \"mlp.gate.weight\",\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n        }\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        if name_after_layer in name_map_after_layer:\n            mapped_name = name_map_after_layer[name_after_layer]\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n        else:\n            if \"mlp.experts.linear_fc1.weight\" in name:\n                expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n                assert len(params) == 2\n            elif \"mlp.experts.linear_fc2.weight\" in name:\n                expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n                assert len(params) == 1\n            else:\n                raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n        return convert_names, params\n\n    def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        assert self.mcore_config.mtp_num_layers == 1, 
\"only support one mtp layer for now\"\n        assert self.mcore_config.num_layers == 61, \"only support 61 layers for now\"\n        direct_name_mapping = {\n            \"mtp.layers.0.enorm.weight\": \"model.layers.61.enorm.weight\",\n            \"mtp.layers.0.hnorm.weight\": \"model.layers.61.hnorm.weight\",\n            \"mtp.layers.0.eh_proj.weight\": \"model.layers.61.eh_proj.weight\",\n            \"mtp.layers.0.final_layernorm.weight\": \"model.layers.61.shared_head.norm.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params[0]]\n        assert \"mtp.layers.0.transformer_layer\" in name, \"only support transformer layer for now\"\n        # use proxy name to convert\n        proxy_name = name.replace(\"mtp.layers.0.transformer_layer\", \"decoder.layers.61\")\n        if \"self_attention\" in proxy_name or \"input_layernorm.weight\" in proxy_name:\n            convert_names, params = self._convert_attention_param(proxy_name, params)\n        elif \"mlp\" in proxy_name:\n            convert_names, params = self._convert_mlp_param(proxy_name, params)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n        if \"mtp\" in name:\n            return self._convert_mtp_param(name, params_one_group)\n        elif \"self_attention\" in name or \"input_layernorm.weight\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # decoder.layers.0.mlp.router.weight\n        # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7\n        # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7\n\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.gate.weight\")\n        elif \"mlp.experts.linear_fc1.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight\")\n        elif \"mlp.experts.linear_fc2.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight\")\n        else:\n            raise 
NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # qwen3 moe no share expert\n\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_qwen2_megatron import (\n    ParallelQwen2ForCausalLM,\n    # rmpad with megatron\n    ParallelQwen2ForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelQwen2ForCausalLMRmPadPP,\n    ParallelQwen2ForValueRmPad,\n    ParallelQwen2ForValueRmPadPP,\n    # original model with megatron\n    ParallelQwen2Model,\n)\n\n__all__ = [\n    \"ParallelQwen2ForCausalLM\",\n    \"ParallelQwen2ForCausalLMRmPad\",\n    \"ParallelQwen2ForCausalLMRmPadPP\",\n    \"ParallelQwen2ForValueRmPad\",\n    \"ParallelQwen2ForValueRmPadPP\",\n    \"ParallelQwen2Model\",\n]\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = 
list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor = tensor.data.copy_(state_dict[name], non_blocking=True)\n\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : 
intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n                )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(\n                    total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                )\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        
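# vocab-parallel embedding: each TP rank copies its own row chunk of the full table\n        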
_fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (\n                mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk\n            )\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print(f\"{torch.distributed.get_rank()} loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        print(\n            f\"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, \"\n            f\"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}\"\n        )\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.bias\",\n            f\"{layer_name}.self_attn.k_proj.bias\",\n            f\"{layer_name}.self_attn.v_proj.bias\",\n            bias=True,\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    
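# the final RMSNorm weight is replicated across TP ranks, so a plain copy suffices\n    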
_fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    if tie_word_embeddings:\n        print_rank_0(\"tie_word_embeddings skip load lm_head\")\n    else:\n        print_rank_0(\"loading lm_head...\")\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n            if is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _fetch_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(\n    state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False\n):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from verl.utils.logger import print_rank_0\n    from verl.utils.megatron_utils import unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(\n                param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()\n            )\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = 
list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, (\n        f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: \"\n        f\"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n    )\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n    
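        # stage the broadcast through a scratch buffer shaped like the local TP shard\n    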
        sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(\n                config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id()\n            )\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(\n                    torch.cat([gate_weight_tp, up_weight_tp], dim=0)\n     
           )\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape \"\n                f\"{tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(\n                        total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id()\n                    )\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(\n                        total_size * tp_size, 
config.hidden_size, dtype=params_dtype, device=get_device_id()\n                    )\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(\n                        torch.cat([q_part, k_part, v_part], dim=0)\n                    )\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, (\n                f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            )\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n              
  f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                bias=True,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie_word_embeddings skip load lm_head\")\n        else:\n            print_rank_0(\"loading lm_head...\")\n            lm_head_weight = None\n            if pp_rank + 1 == pp_size:\n                lm_head_weight = gpt_model_module.lm_head.weight\n\n            if is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _broadcast_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import print_rank_0\nfrom verl.utils.megatron_utils import unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), (\n        f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    )\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = (\n                virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            )\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings\n    Returns:\n        state_dict (dict):\n      
      The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, list | tuple):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, (\n            \"len model layers {} not equal to num_layers_per_model {}\".format(\n                len(models[i].model.layers), num_layers_per_model\n            )\n        )\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == 
src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = 
torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # 
Embeddings\n    # -------------------\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                _broadcast_tensor(\n                    gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 
1,\n                )\n                _broadcast_tensor(\n                    gpt_model_module.reward_head.weight\n                    if pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_weight\", None) is not None\n                    else None,\n                    \"reward_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.lm_head, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
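  {
    "path": "verl_rl/docs_sketches/qkv_shard_resplit_sketch.py",
    "content": "# NOTE: illustrative sketch added alongside the checkpoint merging utility above; it is\n# NOT part of the original verl codebase, and this file name is hypothetical. It replays,\n# on plain CPU tensors, the slicing arithmetic that _broadcast_tp_shard_tensor_qkv uses\n# when num_key_value_heads >= tp_size: each tensor-parallel shard stores\n# [q_shard | k_shard | v_shard] stacked along dim 0, so the gathered fused weight must be\n# re-split per shard before the per-projection pieces are concatenated into full q/k/v.\nimport torch\n\nhidden_size = 8\nnum_attention_heads = 4\nnum_key_value_heads = 2  # grouped-query attention\ntp_size = 2\nhead_dim = hidden_size // num_attention_heads  # 2\n\nq_size_tp = hidden_size // tp_size  # q rows per shard\nkv_size_tp = head_dim * num_key_value_heads // tp_size  # k (and v) rows per shard\ntotal_size = q_size_tp + 2 * kv_size_tp  # rows per fused shard\n\n# stand-in for the tensor gathered across tp ranks (shards stacked on dim 0)\nfull_tensor = torch.arange(tp_size * total_size * hidden_size, dtype=torch.float32)\nfull_tensor = full_tensor.view(tp_size * total_size, hidden_size)\n\nq_parts, k_parts, v_parts = [], [], []\nfor i in range(tp_size):\n    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n    q_parts.append(qkv_part[:q_size_tp])\n    k_parts.append(qkv_part[q_size_tp : q_size_tp + kv_size_tp])\n    v_parts.append(qkv_part[q_size_tp + kv_size_tp : total_size])\n\nq_weight = torch.cat(q_parts, dim=0)\nk_weight = torch.cat(k_parts, dim=0)\nv_weight = torch.cat(v_parts, dim=0)\n\nassert q_weight.shape == (hidden_size, hidden_size)\nassert k_weight.shape == (num_key_value_heads * head_dim, hidden_size)\nassert v_weight.shape == (num_key_value_heads * head_dim, hidden_size)\n"
  },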
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelQwen2Attention\nfrom .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n__all__ = [\n    \"ParallelQwen2Attention\",\n    \"ParallelQwen2DecoderLayer\",\n    \"ParallelQwen2DecoderLayerRmPad\",\n    \"ParallelQwen2MLP\",\n    \"ParallelQwen2RMSNorm\",\n]\n"
  },
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional\n\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers.utils import is_flash_attn_2_available\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\nimport torch\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass Qwen2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelQwen2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, (\n            f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        )\n        assert self.num_key_value_heads % tp_size == 0, (\n            f\"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=\"\n            f\"{self.num_key_value_heads}, tp_size={tp_size}\"\n        )\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and \"\n                f\"`num_heads`: {self.num_heads}).\"\n            )\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            # bias=config.attention_bias,\n            bias=True,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            # 
bias=config.attention_bias,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        self.rotary_emb = Qwen2RotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, \"\n                f\"but is {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, \"\n                f\"but is {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # 
(batch_size, seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin should be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(\n        q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    k_embed = apply_rotary_emb(\n        k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen\n    )\n    return q_embed, k_embed\n\n\nclass ParallelQwen2AttentionRmPad(ParallelQwen2Attention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split(\n            [self.q_size, self.k_size, self.v_size], dim=-1\n        )  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # Flash attention varlen requires the input to have the shape\n        # total_nnz x num_heads x head_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(\n            query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch\n        )\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin,\n        #     position_ids, indices, sequence_length)\n\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states get silently cast to float32. Hence, we need to\n
        # cast them back to float16 just to be sure everything works as expected.\n        # This might slow down training & inference, so it is recommended not to cast the LayerNorms\n        # to fp32. (Qwen2RMSNorm handles it correctly)\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowParallelLinear if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
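  {
    "path": "verl_rl/docs_sketches/rope_and_gqa_helpers_sketch.py",
    "content": "# NOTE: illustrative sketch, NOT part of the original verl codebase (hypothetical file\n# name). It sanity-checks two pure helpers copied from parallel_attention.py:\n# repeat_kv matches torch.repeat_interleave along the head dim (how grouped-query\n# attention expands key/value heads to match query heads), and rotate_half produces\n# the half-swap used by the rotary position embedding.\nimport torch\n\n\ndef rotate_half(x):\n    # copied from parallel_attention.py\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    # copied from parallel_attention.py\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nkv = torch.randn(2, 2, 5, 4)  # (batch, num_key_value_heads, seqlen, head_dim)\nassert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))\n\nx = torch.tensor([[1.0, 2.0, 3.0, 4.0]])\nassert torch.equal(rotate_half(x), torch.tensor([[-3.0, -4.0, 1.0, 2.0]]))\n"
  },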
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n\nclass ParallelQwen2DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelQwen2DecoderLayerRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.hidden_size = config.hidden_size\n        self.layer_idx = layer_idx\n        self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
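  {
    "path": "verl_rl/docs_sketches/prenorm_decoder_dataflow_sketch.py",
    "content": "# NOTE: illustrative sketch, NOT part of the original verl codebase (hypothetical file\n# name). It shows the pre-norm residual dataflow that ParallelQwen2DecoderLayer\n# implements (normalize -> sublayer -> add residual, once for attention and once for\n# the MLP), with nn.Identity standing in for the tensor-parallel attention/MLP and\n# LayerNorm standing in for the fused RMSNorm.\nimport torch\nfrom torch import nn\n\n\nclass PreNormBlock(nn.Module):\n    def __init__(self, hidden_size: int):\n        super().__init__()\n        self.input_layernorm = nn.LayerNorm(hidden_size)  # stand-in for RMSNorm\n        self.post_attention_layernorm = nn.LayerNorm(hidden_size)\n        self.self_attn = nn.Identity()  # stand-in for ParallelQwen2Attention\n        self.mlp = nn.Identity()  # stand-in for ParallelQwen2MLP\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states = self.self_attn(self.input_layernorm(hidden_states))\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))\n        return residual + hidden_states\n\n\nout = PreNormBlock(16)(torch.randn(2, 5, 16))\nassert out.shape == (2, 5, 16)\n"
  },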
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\n\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n"
  },
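  {
    "path": "verl_rl/docs_sketches/fused_qkv_sizes_sketch.py",
    "content": "# NOTE: illustrative sketch, NOT part of the original verl codebase (hypothetical file\n# name; the config numbers below are assumed for illustration). It shows the\n# output-size bookkeeping behind QKVParallelLinear: q, k and v are fused into one\n# column-parallel GEMM whose output dim is (num_heads + 2 * num_kv_heads) * head_dim,\n# and the result is split back into the three projections, here with a plain Linear.\nimport torch\nfrom torch import nn\n\nhidden_size = 64\nnum_heads = 8\nnum_key_value_heads = 2  # grouped-query attention\nhead_dim = hidden_size // num_heads\n\nqkv_output_size = (num_heads + 2 * num_key_value_heads) * head_dim  # 96\nqkv_proj = nn.Linear(hidden_size, qkv_output_size, bias=True)  # single fused GEMM\n\nx = torch.randn(2, 5, hidden_size)\nq, k, v = qkv_proj(x).split(\n    [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim], dim=-1\n)\nassert q.shape[-1] == 64 and k.shape[-1] == 16 and v.shape[-1] == 16\n"
  },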
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom verl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelQwen2MLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
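  {
    "path": "verl_rl/docs_sketches/swiglu_mlp_sketch.py",
    "content": "# NOTE: illustrative sketch, NOT part of the original verl codebase (hypothetical file\n# name). Plain PyTorch version of the dataflow in ParallelQwen2MLP: one fused gate_up\n# projection, split into halves, silu(gate) * up (Qwen2's hidden_act is \"silu\"), then\n# the down projection. The tensor-parallel splitting is omitted here.\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nhidden_size, intermediate_size = 16, 64\ngate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)\ndown_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n\nx = torch.randn(2, 5, hidden_size)\ngate, up = gate_up_proj(x).split(intermediate_size, dim=-1)\ny = down_proj(F.silu(gate) * up)\nassert y.shape == x.shape\n"
  },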
  {
    "path": "verl_rl/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom verl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelQwen2RMSNorm(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
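  {
    "path": "verl_rl/docs_sketches/rmsnorm_reference_sketch.py",
    "content": "# NOTE: illustrative sketch, NOT part of the original verl codebase (hypothetical file\n# name). An unfused reference for what fused_rms_norm_affine computes inside\n# ParallelQwen2RMSNorm: normalize by the reciprocal root-mean-square (no mean\n# subtraction, unlike LayerNorm), then apply the learned elementwise weight.\nimport torch\n\n\ndef rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:\n    variance = x.float().pow(2).mean(dim=-1, keepdim=True)\n    x_normed = x.float() * torch.rsqrt(variance + eps)\n    return weight * x_normed.to(x.dtype)\n\n\nx = torch.randn(2, 5, 16)\nweight = torch.ones(16)\ny = rms_norm_reference(x, weight, eps=1e-6)\nassert y.shape == x.shape\n"
  },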
  {
    "path": "verl_rl/verl/models/qwen2/megatron/modeling_qwen2_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2 model.\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.qwen2.configuration_qwen2 import Qwen2Config\nfrom transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast\n\nfrom verl.utils.device import get_device_name\nfrom verl.utils.megatron import sequence_parallel as sp_utils\nfrom verl.utils.megatron import tensor_parallel as tp_utils\nfrom verl.utils.megatron_utils import TransformerConfig, convert_config\n\nfrom .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from Qwen2 pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelQwen2Model(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelQwen2DecoderLayer(config, megatron_config, layer_idx=i) for i in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. 
shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLM(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.model = ParallelQwen2Model(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): token ids.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): padding mask.\n            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): position indices.\n        \"\"\"\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\nclass ParallelQwen2ModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n        )\n\n        self.layers = nn.ModuleList(\n            [ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i) for i in range(config.num_hidden_layers)]\n        )\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config: Qwen2Config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n
            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): token ids.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): padding mask.\n            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): position indices.\n        \"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure of the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            total_nnz = cu_seqlens[-1]\n            logits = logits[:total_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(\n            logits, indices, batch_size, seqlen=sequence_length\n        )  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # the value head weight is replicated across tp ranks but computes on this rank's\n        # sequence shard, so its gradient must be all-reduced like any parameter under\n        #
sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelQwen2ModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - This model only contains layer in this pp stage and vpp chunk\n    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(\n                num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs\n            )\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def 
        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ) -> tuple | BaseModelOutputWithPast:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, total_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # vocab parallel embedding will not do the sequence parallel reduce-scatter in open-source Megatron,\n            # so we need to handle it here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # self.hidden_states should be passed by Megatron\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: Qwen2Config,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPadPP(\n            config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process\n        )\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n        if pre_process or post_process:\n            self.setup_embeddings_and_output_layer()\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n
\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,\n            **column_kwargs,\n        )\n\n    def setup_embeddings_and_output_layer(self) -> None:\n        \"\"\"Sets up embedding layer in first stage and output layer in last stage.\n\n        This function initializes word embeddings in the final stage when we are\n        using pipeline parallelism and sharing word embeddings, and sets up param\n        attributes on the embedding and output layers.\n        \"\"\"\n        # Set `is_embedding_or_output_parameter` attribute.\n        if self.pre_process:\n            self.model.embed_tokens.weight.is_embedding_or_output_parameter = True\n        if self.post_process and self.lm_head.weight is not None:\n            self.lm_head.weight.is_embedding_or_output_parameter = True\n\n        if not self.share_embeddings_and_output_weights:\n            return\n\n        if parallel_state.get_pipeline_model_parallel_world_size() == 1:\n            # Zero out wgrad if sharing embeddings between two layers on same\n            # pipeline stage to make sure grad accumulation into main_grad is\n            # correct and does not include garbage values (e.g., from torch.empty).\n            self.shared_embedding_or_output_weight().zero_out_wgrad = True\n            return\n\n        if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:\n            self.shared_embedding_or_output_weight().shared_embedding = True\n\n        if self.post_process and not self.pre_process:\n            assert not parallel_state.is_pipeline_first_stage()\n            # set word_embeddings weights to 0 here, then copy first\n            # stage's weights using all_reduce below.\n            self.lm_head.weight.data.fill_(0)\n            self.lm_head.weight.shared = True\n            self.lm_head.weight.shared_embedding = True\n\n        if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group():\n            weight = self.shared_embedding_or_output_weight()\n            weight.data = weight.data.to(get_device_name())\n            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())\n\n    def shared_embedding_or_output_weight(self) -> torch.Tensor:\n        if self.pre_process:\n            return self.model.embed_tokens.weight\n        elif self.post_process:\n            return self.lm_head.weight\n        return None\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n
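        # with tied weights, the projection reuses the vocab-parallel embedding matrix\n        # (via shared_embedding_or_output_weight) instead of a separately allocated\n        # lm_head weight (see skip_weight_param_allocation in _init_head)\n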
        output_weight = None\n        if self.share_embeddings_and_output_weights:\n            output_weight = self.shared_embedding_or_output_weight()\n        logits = self.lm_head(hidden_states, weight=output_weight)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        \"\"\"\n\n        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.\n        # The first pp stage consumes input_ids; later stages consume the hidden_states received inside self.model.\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(\n            input_ids.unsqueeze(dim=-1), attention_mask\n        )  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer; the performance gap has not been measured\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                total_nnz = cu_seqlens[-1]\n                logits = logits[:total_nnz]  # drop the sequence-parallel padding\n            # add removed padding back. If the input was already rmpad, we let the caller run pad_input\n
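            # pad_input is the inverse of unpad_input: it scatters the (total_nnz, ...) rows\n            # back into a dense (batch_size, sequence_length, ...) tensor at `indices`,\n            # zero-filling the positions that were padding\n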
            logits = pad_input(\n                logits, indices, batch_size, seqlen=sequence_length\n            )  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # the value head's input is still sequence-parallel, so mark the weight so that\n        # Megatron all-reduces its gradient across tensor-parallel ranks\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> tuple | CausalLMOutputWithPast:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n"
  },
  {
    "path": "verl_rl/verl/models/registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nfrom typing import Optional\n\nimport torch.nn as nn\n\n# Supported models in Megatron-LM\n# Architecture -> (module, class).\n_MODELS = {\n    \"LlamaForCausalLM\": (\n        \"llama\",\n        (\"ParallelLlamaForCausalLMRmPadPP\", \"ParallelLlamaForValueRmPadPP\", \"ParallelLlamaForCausalLMRmPad\"),\n    ),\n    \"Qwen2ForCausalLM\": (\n        \"qwen2\",\n        (\"ParallelQwen2ForCausalLMRmPadPP\", \"ParallelQwen2ForValueRmPadPP\", \"ParallelQwen2ForCausalLMRmPad\"),\n    ),\n    \"MistralForCausalLM\": (\n        \"mistral\",\n        (\"ParallelMistralForCausalLMRmPadPP\", \"ParallelMistralForValueRmPadPP\", \"ParallelMistralForCausalLMRmPad\"),\n    ),\n}\n\n\n# return model class\nclass ModelRegistry:\n    @staticmethod\n    def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]:\n        if model_arch not in _MODELS:\n            return None\n\n        megatron = \"megatron\"\n\n        module_name, model_cls_name = _MODELS[model_arch]\n        if not value:  # actor/ref\n            model_cls_name = model_cls_name[0]\n        elif value:  # critic/rm\n            model_cls_name = model_cls_name[1]\n\n        module = importlib.import_module(f\"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron\")\n        return getattr(module, model_cls_name, None)\n\n    @staticmethod\n    def get_supported_archs() -> list[str]:\n        return list(_MODELS.keys())\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/dense_common.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Union\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_base_model(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Cache] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n) -> CausalLMOutputWithPast:\n    r\"\"\"\n    Copy paste LLaMa's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py\n\n    This function should be generic enough for all pure text models.\n    ```\"\"\"\n\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = (\n        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    return outputs\n\n\ndef forward_with_torch_backend(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Union[\"Cache\", list[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: int | torch.Tensor = 0,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | CausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        
    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_torch_backend requires return_dict=True\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return CausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef forward_with_triton_backend(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Union[\"Cache\", list[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: int | torch.Tensor = 0,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | CausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_triton_backend requires return_dict=True\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n\n    return CausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/kimi_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\n\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef _ulysses_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n\n    if self.q_lora_rank is None:\n        q = self.q_proj(hidden_states)\n    else:\n        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)\n    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n    kv = (\n        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n        .transpose(1, 2)\n    )\n\n    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n\n    # patch\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads\n        k_pe = repeat_kv(k_pe, ulysses_sp_size)  # to keep heads=1 after a2a\n        k_nope = repeat_kv(k_nope, num_key_value_groups)\n        value_states = repeat_kv(value_states, num_key_value_groups)\n        q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1)\n        k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1)\n        k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n        # (batch_size, num_head / sp_size, seq_length, head_size)\n        full_q_len = q.size(2)  # full_q_len = seq_length\n\n    else:\n        full_q_len = q_len\n\n    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n    cos, sin = self.rotary_emb(value_states, seq_len=full_q_len)\n    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n    query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)\n    query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n    query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n    key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)\n    key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n    key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n    if self.q_head_dim != self.v_head_dim:\n        value_states = 
F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n    # TODO: These transposes are quite inefficient but Flash Attention requires the layout\n    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        dropout=dropout_rate,\n        sliding_window=None,\n        is_causal=self.is_causal,\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n        position_ids=position_ids,  # important: pass position ids\n        softmax_scale=self.softmax_scale,\n    )\n\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    if self.q_head_dim != self.v_head_dim:\n        attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    return attn_output, None, None\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/llama.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\nfrom typing import Callable, Optional\n\nimport torch\n\nif sys.version_info >= (3, 11):\n    pass\nelse:\n    pass\n\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb\nfrom transformers.utils import logging\n\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef llama_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].\n    \"\"\"\n    output_attentions = False\n\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    # trade off: repeat first and then all to all\n    # key_states = repeat_kv(key_states, self.num_key_value_groups)\n    # value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        
logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # TODO: These transposes are quite inefficient but Flash Attention requires the layout\n    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons,\n    # therefore the input hidden states get silently cast to float32. Hence, we need to\n    # cast them back to the correct dtype just to be sure everything works as expected.\n    # This might slow down training & inference, so it is recommended not to cast the LayerNorms\n    # in fp32. (LlamaRMSNorm handles it correctly)\n\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(\n            f\"The input hidden states seem to be silently cast to float32; this might be related to \"\n            f\"the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the \"\n            f\"input in {target_dtype}.\"\n        )\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=getattr(self, \"sliding_window\", None),\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n        is_causal=self.is_causal,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if not output_attentions:\n        attn_weights = None\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef llama_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n    from transformers.models.llama.modeling_llama import eager_attention_forward\n\n    bsz, q_len, _ = hidden_states.shape\n\n    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once(\n                \"`torch.nn.functional.scaled_dot_product_attention` 
does not support `output_attentions=True`. \"\n                \"Falling back to eager attention. This warning can be removed using the argument \"\n                '`attn_implementation=\"eager\"` when loading the model.'\n            )\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/monkey_patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nApply monkey-patch function to models\n\"\"\"\n\nimport importlib.metadata\nimport sys\nfrom functools import lru_cache\nfrom typing import Optional\n\nimport torch\nfrom packaging import version\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_utils import PreTrainedModel\n\nfrom verl.utils.import_utils import is_trl_available\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_group,\n    get_ulysses_sequence_parallel_world_size,\n    slice_input_tensor,\n)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,\n    seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)\n    \"\"\"\n    batch, slen, num_key_value_heads, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)\n    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)\n\n\ndef _ulysses_flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    *args,\n    position_ids: Optional[torch.Tensor] = None,\n    **kwargs,\n):\n    \"\"\"Insert all-to-all before and after flash attention.\n    DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509\n\n    Args:\n        query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)\n        key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size)\n\n    Returns:\n        torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)\n    \"\"\"\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        assert position_ids is not None, \"position_ids is required for Ulysses sequence parallelism\"\n\n        # NOTE: repeat kv heads to be divided by sequence parallel. 
Instead of repeating nheads_q//nheads_k,\n        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.\n        # For example:\n        # - nheads_k=4, sp=8, repeats=2\n        # - nheads_k=8, sp=8, repeats=1\n        # - nheads_k=16, sp=8, repeats=1\n        repeats = max(ulysses_sp_size // key_states.size(2), 1)\n        key_states = repeat_kv(key_states, repeats)\n        value_states = repeat_kv(value_states, repeats)\n\n        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)\n\n        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate\n        # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.\n        # https://github.com/huggingface/transformers/pull/33932\n\n        # (bsz, seq_len/n) -> (bsz, seq_len)\n        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]\n        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.concat(position_ids_list, dim=-1)\n\n    # (bsz, seq_len, n_head/n, head_dim)\n    attn_output = _flash_attention_forward(\n        query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs\n    )\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n\n    return attn_output\n\n\ndef patch_vlm_for_ulysses_input_slicing(model_class: type):\n    \"\"\"\n    Applies a monkey patch to the forward method of a given model class\n    to enable Ulysses sequence parallelism input slicing.\n    \"\"\"\n\n    def _create_ulysses_wrapped_decoder_forward(original_forward):\n        def ulysses_wrapped_decoder_forward(self, *args, **kwargs):\n            inputs_embeds = kwargs.get(\"inputs_embeds\")\n            call_kwargs = kwargs.copy()\n\n            current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n            slice_now = (\n                inputs_embeds is not None\n                and current_ulysses_sp_size > 1\n                and getattr(self, \"_needs_initial_slice\", True)\n            )\n            if slice_now:\n                call_kwargs[\"inputs_embeds\"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)\n                self._needs_initial_slice = False\n            try:\n                return original_forward(self, *args, **call_kwargs)\n            finally:\n                if slice_now:\n                    self._needs_initial_slice = True\n\n        return ulysses_wrapped_decoder_forward\n\n    original_forward = model_class.forward\n    wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward)\n    model_class.forward = wrapped_forward\n    print(f\"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.\")\n\n\ndef patch_forward_with_backends(\n    model: PreTrainedModel,\n    use_fused_kernels: bool = False,\n    fused_kernels_backend: str = None,\n):\n    \"\"\"\n    Choose the forward function based on the model and backend.\n    Args:\n        model 
(PreTrainedModel): The model to apply the monkey patch.\n        use_fused_kernels (bool): Whether to use fused kernels.\n        fused_kernels_backend (str): The backend to use for fused kernels.\n    \"\"\"\n    if not use_fused_kernels or fused_kernels_backend not in [\"triton\", \"torch\"]:\n        print(\n            f\"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is \"\n            f\"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}\"\n        )\n        return\n\n    forward_with_torch_backend_function = model.__class__.forward\n    forward_with_triton_backend_function = model.__class__.forward\n    if model.config.model_type == \"qwen2_5_vl\":\n        from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n    elif model.config.model_type == \"qwen2_vl\":\n        from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n    else:\n        from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend\n\n        forward_with_torch_backend_function = forward_with_torch_backend\n        forward_with_triton_backend_function = forward_with_triton_backend\n\n    if fused_kernels_backend == \"triton\":\n        model.__class__.forward = forward_with_triton_backend_function\n        print(f\"Using Triton backend for fused kernels in {model.__class__.__name__}\")\n    elif fused_kernels_backend == \"torch\":\n        model.__class__.forward = forward_with_torch_backend_function\n        print(f\"Using Torch backend for fused kernels in {model.__class__.__name__}\")\n    else:\n        raise ValueError(f\"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.\")\n\n\ndef apply_monkey_patch(\n    model: PreTrainedModel,\n    ulysses_sp_size: int = 1,\n    use_remove_padding: bool = True,\n    use_fused_kernels: bool = False,\n    fused_kernels_backend: str = None,\n):\n    \"\"\"\n    Apply monkey patches to the model for Ulysses sequence parallelism and fused kernels.\n\n    At the end of this function the model's forward function is patched for fused kernels.\n    If the model is not supported with fused kernels, return right after patching.\n    \"\"\"\n\n    # Replace _flash_attention_forward with _ulysses_flash_attention_forward\n    module = sys.modules[model.__module__]\n\n    try:\n        num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads\n    except AttributeError:\n        num_attention_heads, num_key_value_heads = (\n            model.config.text_config.num_attention_heads,\n            model.config.text_config.num_key_value_heads,\n        )\n\n    assert num_attention_heads % ulysses_sp_size == 0, (\n        f\"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}\"\n    )\n    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (\n        f\"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size \"\n        f\"{ulysses_sp_size} or vice versa. When ulysses_sp_size % num_key_value_heads == 0, \"\n        f\"kv heads are repeated to ensure correctness.\"\n    )\n
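    # Example: num_key_value_heads=2 with ulysses_sp_size=4 passes the check; the\n    # attention patch repeats kv heads 2x so every rank still holds a kv head after\n    # the all-to-all.\n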
\n    if is_trl_available():\n        from trl import AutoModelForCausalLMWithValueHead  # type: ignore\n\n        def state_dict(self, *args, **kwargs):\n            return torch.nn.Module.state_dict(self, *args, **kwargs)\n\n        AutoModelForCausalLMWithValueHead.state_dict = state_dict\n        print(\"Monkey patch state_dict in AutoModelForCausalLMWithValueHead.\")\n\n    # TODO: VLM models only; unify the monkey patch with LLM models.\n    if model.config.model_type == \"qwen2_5_vl\":\n        if is_transformers_version_in_range(min_version=\"4.53.0\"):\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention\n\n            # TODO: Support transformers 4.53\n            raise ValueError(\"Transformers 4.53 is not supported\")\n        else:\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n                Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,\n            )\n\n        if use_remove_padding or ulysses_sp_size > 1:\n            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward\n\n            Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward\n            print(\"Monkey patch FlashAttention2.forward in Qwen2.5VL\")\n\n        if ulysses_sp_size > 1:\n            if is_transformers_version_in_range(min_version=\"4.52.0\"):\n                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)\n            else:\n                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)\n\n    elif model.config.model_type == \"qwen2_vl\":\n        if is_transformers_version_in_range(min_version=\"4.53.0\"):\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention\n\n            # TODO: Support transformers 4.53\n            raise ValueError(\"Transformers 4.53 is not supported\")\n        else:\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention\n\n        if use_remove_padding or ulysses_sp_size > 1:\n            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward\n\n            Qwen2VLAttention.forward = ulysses_flash_attn_forward\n            print(\"Monkey patch FlashAttention2.forward in Qwen2VL\")\n\n        if ulysses_sp_size > 1:\n            if is_transformers_version_in_range(min_version=\"4.52.0\"):\n                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)\n            else:\n                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)\n\n    elif model.config.model_type == \"kimi_vl\":\n        if use_remove_padding or ulysses_sp_size > 1:\n            # TODO: Changes need to be made when transformers are adapted.\n            from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward\n\n            module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward\n            print(\"Monkey patch FlashAttention2.forward in KimiVL\")\n\n        if ulysses_sp_size > 1:\n            
patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)\n\n        if use_fused_kernels:\n            print(\"Fused kernels are not supported for KimiVL\")\n\n        return\n\n    # transformers<=4.47.1\n    if use_remove_padding or ulysses_sp_size > 1:\n        if hasattr(module, \"_flash_attention_forward\"):\n            module._flash_attention_forward = _ulysses_flash_attention_forward\n            print(f\"Monkey patch _flash_attention_forward in {model.__module__}\")\n        else:\n            # transformers>=4.48.0\n            from transformers.integrations import flash_attention\n\n            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward\n            print(f\"Monkey patch _flash_attention_forward in {flash_attention.__name__}\")\n\n    patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)\n\n\n@lru_cache\ndef is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:\n    try:\n        # Get the installed version of the transformers library\n        transformers_version_str = importlib.metadata.version(\"transformers\")\n    except importlib.metadata.PackageNotFoundError as e:\n        raise ModuleNotFoundError(\"The `transformers` package is not installed.\") from e\n\n    transformers_version = version.parse(transformers_version_str)\n\n    lower_bound_check = True\n    if min_version is not None:\n        lower_bound_check = version.parse(min_version) <= transformers_version\n\n    upper_bound_check = True\n    if max_version is not None:\n        upper_bound_check = transformers_version <= version.parse(max_version)\n\n    return lower_bound_check and upper_bound_check\n"
  },
  {
    "path": "verl_rl/verl/models/transformers/npu_patch.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Copyright 2025 The Qwen Team and The HuggingFace Inc. team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\n\r\nimport torch\r\nimport torch_npu\r\nfrom torch_npu import npu_rotary_mul as apply_rotary_emb\r\nfrom transformers.models.qwen2_5_vl import modeling_qwen2_5_vl\r\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm\r\n\r\n\r\n# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in\r\n# subsequent versions\r\n# https://github.com/huggingface/transformers/pull/38491\r\ndef apply_rotary_pos_emb_flashatt_npu(\r\n    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor\r\n) -> tuple[torch.Tensor, torch.Tensor]:\r\n    cos = cos.chunk(2, dim=-1)[0].contiguous()\r\n    sin = sin.chunk(2, dim=-1)[0].contiguous()\r\n    cos = cos.repeat(1, 2)\r\n    sin = sin.repeat(1, 2)\r\n    q_embed = apply_rotary_emb(\r\n        q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()\r\n    ).type_as(q)\r\n    k_embed = apply_rotary_emb(\r\n        k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()\r\n    ).type_as(k)\r\n    return q_embed, k_embed\r\n\r\n\r\n# This api can improve performance on ASCEND NPU\r\ndef rms_norm_forward(self, x):\r\n    return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]\r\n\r\n\r\nQwen2RMSNorm.forward = rms_norm_forward\r\nmodeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu\r\n"
  },
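`npu_patch.py` applies its optimizations by rebinding attributes at class and module level (`Qwen2RMSNorm.forward = rms_norm_forward`), so every existing and future instance immediately routes through the NPU kernel. Here is a toy illustration of that rebinding pattern with a stand-in RMSNorm; the `torch_npu.npu_rms_norm` call itself is replaced by equivalent eager math so the sketch runs on any device.

```python
import torch
import torch.nn as nn


class TinyRMSNorm(nn.Module):
    """Stand-in for Qwen2RMSNorm with the reference eager implementation."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # On Ascend this would be: torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]
    variance = x.pow(2).mean(-1, keepdim=True)
    return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


norm = TinyRMSNorm(8)
x = torch.randn(2, 8)
reference = norm(x)

# Class-level rebinding: instances created before or after the patch both pick it up.
TinyRMSNorm.forward = patched_forward
assert torch.allclose(norm(x), reference)
```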
  {
    "path": "verl_rl/verl/models/transformers/qwen2.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv\nfrom transformers.utils import logging\n\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef qwen2_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n):\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.\n    \"\"\"\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons,\n    # so the input hidden states get silently cast to float32. Hence, we need to\n    # cast them back to float16 just to be sure everything works as expected.\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(\n            f\"The input hidden states seem to have been silently cast to float32; this might be related to \"\n            f\"the fact you have upcasted embedding or layer norm layers to float32. We will cast the \"\n            f\"input back to {target_dtype}.\"\n        )\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    # Reshape to the expected shape for Flash Attention\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n    else:\n        sliding_window = None\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        is_causal=self.is_causal,\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n    )\n\n    # use full_q_len to reshape\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if not output_attentions:\n        attn_weights = None\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef qwen2_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: tuple[torch.Tensor, torch.Tensor],\n    attention_mask: 
Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    bsz, q_len, _ = hidden_states.shape\n    hidden_shape = (bsz, q_len, -1, self.head_dim)\n\n    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    sliding_window = None\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n\n    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once(\n                \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. \"\n                \"Falling back to eager attention. 
This warning can be removed using the argument \"\n                '`attn_implementation=\"eager\"` when loading the model.'\n            )\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        sliding_window=sliding_window,  # main diff with Llama\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
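Both patched forwards in `qwen2.py` follow the same Ulysses choreography: each rank starts with all attention heads for a `seq_len / sp` slice of tokens, `gather_seq_scatter_heads` regroups that into the full sequence for `heads / sp` heads (hence `validate_ulysses_config` requiring the head count to divide evenly by the SP size), attention runs over the full sequence, and `gather_heads_scatter_seq` inverts the exchange. The following single-process shape sketch imitates the two all-to-alls with plain slicing; it uses no `torch.distributed` and is only meant to show the shape bookkeeping, not the communication.

```python
import torch

bsz, n_heads, seq, head_dim, sp = 1, 8, 16, 4, 2

# Per-rank inputs after sequence sharding: (bsz, n_heads, seq/sp, head_dim)
shards = [torch.randn(bsz, n_heads, seq // sp, head_dim) for _ in range(sp)]

# gather_seq_scatter_heads: each "rank" ends up with the full sequence
# but only n_heads/sp of the heads.
full_seq_per_rank = []
for rank in range(sp):
    head_slice = slice(rank * n_heads // sp, (rank + 1) * n_heads // sp)
    full_seq_per_rank.append(torch.cat([s[:, head_slice] for s in shards], dim=2))
assert full_seq_per_rank[0].shape == (bsz, n_heads // sp, seq, head_dim)

# ... attention would run here on (bsz, n_heads/sp, seq, head_dim) ...
attn_out = [t.transpose(1, 2) for t in full_seq_per_rank]  # (bsz, seq, n_heads/sp, head_dim)

# gather_heads_scatter_seq: back to all heads with seq/sp tokens per rank.
restored = []
for rank in range(sp):
    seq_slice = slice(rank * seq // sp, (rank + 1) * seq // sp)
    restored.append(torch.cat([t[:, seq_slice] for t in attn_out], dim=2))
assert restored[0].shape == (bsz, seq // sp, n_heads, head_dim)
```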
  {
    "path": "verl_rl/verl/models/transformers/qwen2_5_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n    Qwen2_5_VLCausalLMOutputWithPast,\n    Qwen2_5_VLForConditionalGeneration,\n)\n\n\n@dataclass\nclass Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_base_model(\n    self: Qwen2_5_VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:\n    r\"\"\"\n    Copy paste Qwen2_5_VL's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py\n    ```\"\"\"\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = (\n        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    )\n    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    if inputs_embeds is None:\n        inputs_embeds = self.model.embed_tokens(input_ids)\n        if pixel_values is not None:\n            pixel_values = pixel_values.type(self.visual.dtype)\n            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()\n            n_image_features = image_embeds.shape[0]\n            if n_image_tokens != n_image_features:\n                raise ValueError(\n                    f\"Image features and image tokens do not match: tokens: {n_image_tokens}, \"\n                    f\"features {n_image_features}\"\n                )\n\n            mask = input_ids == self.config.image_token_id\n            mask_unsqueezed = mask.unsqueeze(-1)\n            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n            image_mask = mask_expanded.to(inputs_embeds.device)\n\n            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n        if pixel_values_videos is not None:\n            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)\n            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()\n            n_video_features = video_embeds.shape[0]\n            if n_video_tokens != n_video_features:\n                raise ValueError(\n                    f\"Video features and video tokens do not match: tokens: {n_video_tokens}, \"\n                    f\"features {n_video_features}\"\n                )\n\n            mask = input_ids == self.config.video_token_id\n            mask_unsqueezed = mask.unsqueeze(-1)\n            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n            video_mask = mask_expanded.to(inputs_embeds.device)\n\n            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n        if attention_mask is not None:\n            attention_mask = attention_mask.to(inputs_embeds.device)\n\n    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme\n    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):\n        # calculate RoPE index once per generation in the pre-fill stage only\n        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:\n            position_ids, rope_deltas = self.get_rope_index(\n                input_ids,\n                image_grid_thw,\n                video_grid_thw,\n                second_per_grid_ts,\n                attention_mask,\n            )\n            self.rope_deltas = rope_deltas\n        # then use the prev pre-calculated rope-deltas to get the correct position ids\n        else:\n            batch_size, seq_length, _ = inputs_embeds.shape\n            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0\n            position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n            position_ids = position_ids.view(1, -1).expand(batch_size, -1)\n            if cache_position is not None:  # otherwise `deltas` is an int `0`\n                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)\n            position_ids = position_ids.add(delta)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)\n\n    outputs = self.model(\n        input_ids=None,\n        position_ids=position_ids,\n        attention_mask=attention_mask,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n    return outputs\n\n\ndef forward_with_torch_backend(\n    self: Qwen2_5_VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: 
Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        pixel_values=pixel_values,\n        pixel_values_videos=pixel_values_videos,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=video_grid_thw,\n        rope_deltas=rope_deltas,\n        cache_position=cache_position,\n        second_per_grid_ts=second_per_grid_ts,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_torch_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return Qwen2_5_VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n\n\ndef forward_with_triton_backend(\n    self: Qwen2_5_VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = forward_base_model(\n        self,\n        
input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        pixel_values=pixel_values,\n        pixel_values_videos=pixel_values_videos,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=video_grid_thw,\n        rope_deltas=rope_deltas,\n        cache_position=cache_position,\n        second_per_grid_ts=second_per_grid_ts,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_triton_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n\n    return Qwen2_5_VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n"
  },
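Both fused-backend forwards above avoid materializing the full logits tensor for the loss: they hand the hidden states plus *rolled* labels to a fused kernel, where `torch.roll(labels, shifts=-1, dims=-1)` aligns position t's hidden state with the token at t+1. Below is a plain-PyTorch reference for what that path computes (per-token log-probs and entropy), as a sketch without `FusedLinearForPPO` or the Triton kernel and only up to kernel-level details:

```python
import torch
import torch.nn.functional as F

bsz, seq, hidden, vocab, temperature = 2, 5, 16, 32, 1.0
hidden_states = torch.randn(bsz, seq, hidden)
lm_head_weight = torch.randn(vocab, hidden)  # stands in for self.lm_head.weight
input_ids = torch.randint(0, vocab, (bsz, seq))

# Align each position's hidden state with the token that follows it.
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)

logits = hidden_states @ lm_head_weight.T / temperature  # (bsz, seq, vocab)
logp = F.log_softmax(logits, dim=-1)

# Log-prob of the next token at every position; the last position wraps
# around to token 0 of the row, so callers are expected to mask it out.
log_probs = logp.gather(-1, rolled_labels.unsqueeze(-1)).squeeze(-1)  # (bsz, seq)
entropy = -(logp.exp() * logp).sum(-1)                                # (bsz, seq)
print(log_probs.shape, entropy.shape)
```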
  {
    "path": "verl_rl/verl/models/transformers/qwen2_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.qwen2_vl.modeling_qwen2_vl import (\n    Qwen2VLCausalLMOutputWithPast,\n    Qwen2VLForConditionalGeneration,\n)\nfrom transformers.utils import is_flash_attn_greater_or_equal\n\nfrom verl.utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\ntry:\n    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\nexcept ImportError:\n    flash_attn_varlen_func = None\n\n\ndef get_rope_index(\n    processor,\n    input_ids: torch.Tensor,\n    image_grid_thw: Optional[torch.Tensor] = None,\n    video_grid_thw: Optional[torch.Tensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.\n    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.\n    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546\n    \"\"\"\n    spatial_merge_size = processor.image_processor.merge_size\n    tokens_per_second = 2\n    image_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|image_pad|>\")\n    video_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|video_pad|>\")\n    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|vision_start|>\")\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        if attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)\n\n        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)\n        image_index, video_index = 0, 0\n        input_ids = input_ids[attention_mask == 1]\n        image_nums, video_nums = 0, 0\n        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)\n        vision_tokens = input_ids[vision_start_indices + 1]\n        image_nums = (vision_tokens == image_token_id).sum()\n        video_nums = (vision_tokens == video_token_id).sum()\n        input_tokens = input_ids.tolist()\n        llm_pos_ids_list: list = []\n        st = 0\n        remain_images, remain_videos = image_nums, video_nums\n        for _ in range(image_nums + video_nums):\n            if image_token_id in input_tokens and remain_images > 0:\n                ed_image = input_tokens.index(image_token_id, st)\n            else:\n 
               ed_image = len(input_tokens) + 1\n            if video_token_id in input_tokens and remain_videos > 0:\n                ed_video = input_tokens.index(video_token_id, st)\n            else:\n                ed_video = len(input_tokens) + 1\n            if ed_image < ed_video:\n                t, h, w = (\n                    image_grid_thw[image_index][0],\n                    image_grid_thw[image_index][1],\n                    image_grid_thw[image_index][2],\n                )\n                second_per_grid_t = 0\n                image_index += 1\n                remain_images -= 1\n                ed = ed_image\n            else:\n                t, h, w = (\n                    video_grid_thw[video_index][0],\n                    video_grid_thw[video_index][1],\n                    video_grid_thw[video_index][2],\n                )\n                second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0\n\n                video_index += 1\n                remain_videos -= 1\n                ed = ed_video\n\n            llm_grid_t, llm_grid_h, llm_grid_w = (\n                t.item(),\n                h.item() // spatial_merge_size,\n                w.item() // spatial_merge_size,\n            )\n            text_len = ed - st\n\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)\n            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()\n            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n            st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n        if st < len(input_tokens):\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            text_len = len(input_tokens) - st\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)\n        else:\n            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)\n\n    return position_ids\n\n\ndef prepare_fa2_from_position_ids(\n    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor\n):\n    query = query.view(-1, query.size(-2), query.size(-1))\n    key = key.view(-1, key.size(-2), key.size(-1))\n    value = value.view(-1, value.size(-2), value.size(-1))\n    position_ids = position_ids.flatten()\n    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)\n    cu_seqlens = torch.cat(\n        (\n            indices_q[position_ids == 0],\n            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),\n        )\n  
  )\n    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope\n    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))\n\n\ndef flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: torch.Tensor,\n    query_length: int,\n    is_causal: bool = True,\n    position_ids: Optional[torch.Tensor] = None,\n    sliding_window: Optional[int] = None,\n    use_top_left_mask: bool = False,\n    deterministic: Optional[bool] = None,\n    **kwargs,\n):\n    \"\"\"\n    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)\n    \"\"\"\n    causal = is_causal if not use_top_left_mask else is_causal and query_length != 1\n\n    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).\n    use_sliding_windows = (\n        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window\n    )\n    flash_kwargs = {\"window_size\": (sliding_window, sliding_window)} if use_sliding_windows else {}\n\n    if is_flash_attn_greater_or_equal(\"2.4.1\"):\n        if deterministic is None:\n            deterministic = os.environ.get(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n        flash_kwargs[\"deterministic\"] = deterministic\n\n    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():\n        batch_size = query_states.size(0)\n        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(\n            query_states, key_states, value_states, position_ids[0]\n        )  # remove channel dimension\n        cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n        attn_output = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_in_batch_q,\n            max_seqlen_k=max_seqlen_in_batch_k,\n            dropout_p=kwargs.pop(\"dropout\", 0.0),\n            softmax_scale=kwargs.pop(\"softmax_scale\", None),\n            causal=causal,\n            **flash_kwargs,\n        )\n        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))\n    else:\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            query_length,\n            is_causal=is_causal,\n            sliding_window=sliding_window,\n            use_top_left_mask=use_top_left_mask,\n            deterministic=deterministic,\n            **kwargs,\n        )  # do not pass position_ids to old flash_attention_forward\n\n    return attn_output\n\n\ndef ulysses_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> tuple[torch.Tensor, None, None]:\n    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv\n\n    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size\n    query_states = 
self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n        # (batch_size, num_head / sp_size, seq_length, head_size)\n        full_q_len = query_states.size(2)  # full_q_len = seq_length\n    else:\n        full_q_len = q_len\n\n    # Because the input can be padded, the absolute sequence length depends on the max position id.\n    if position_embeddings is None:\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n\n    query_states, key_states = apply_multimodal_rotary_pos_emb(\n        query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"]\n    )\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # Reshape to the expected shape for Flash Attention\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if (\n        self.config.use_sliding_window\n        and getattr(self.config, \"sliding_window\", None) is not None\n        and self.layer_idx >= self.config.max_window_layers\n    ):\n        sliding_window = self.config.sliding_window\n    else:\n        sliding_window = None\n\n    attn_output = flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        is_causal=self.is_causal,\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n        position_ids=position_ids,  # important: pass position ids\n    )  # (batch_size, seq_length, num_head / sp_size, head_size)\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, None, None\n\n\n@dataclass\nclass Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_base_model(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = 
None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n) -> tuple | Qwen2VLCausalLMOutputWithPast:\n    r\"\"\"\n    Copied from Qwen2VL's forward:\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py\n    \"\"\"\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = (\n        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    )\n    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    if inputs_embeds is None:\n        inputs_embeds = self.model.embed_tokens(input_ids)\n        if pixel_values is not None:\n            pixel_values = pixel_values.type(self.visual.get_dtype())\n            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()\n            n_image_features = image_embeds.shape[0]\n            if n_image_tokens != n_image_features:\n                raise ValueError(\n                    f\"Image features and image tokens do not match: tokens: {n_image_tokens}, \"\n                    f\"features {n_image_features}\"\n                )\n            image_mask = (\n                (input_ids == self.config.image_token_id)\n                .unsqueeze(-1)\n                .expand_as(inputs_embeds)\n                .to(inputs_embeds.device)\n            )\n            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n        if pixel_values_videos is not None:\n            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())\n            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()\n            n_video_features = video_embeds.shape[0]\n            if n_video_tokens != n_video_features:\n                raise ValueError(\n                    f\"Video features and video tokens do not match: tokens: {n_video_tokens}, \"\n                    f\"features {n_video_features}\"\n                )\n            video_mask = (\n                (input_ids == self.config.video_token_id)\n                .unsqueeze(-1)\n                .expand_as(inputs_embeds)\n                .to(inputs_embeds.device)\n            )\n            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n        if attention_mask is not None:\n            attention_mask = attention_mask.to(inputs_embeds.device)\n\n    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):\n        # calculate RoPE index once per generation in the pre-fill stage only\n        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:\n            position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, 
attention_mask)\n            self.rope_deltas = rope_deltas\n        # then use the prev pre-calculated rope-deltas to get the correct position ids\n        else:\n            batch_size, seq_length, _ = inputs_embeds.shape\n            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0\n            position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n            position_ids = position_ids.view(1, -1).expand(batch_size, -1)\n            if cache_position is not None:  # otherwise `deltas` is an int `0`\n                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)\n            position_ids = position_ids.add(delta)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)\n\n    outputs = self.model(\n        input_ids=None,\n        position_ids=position_ids,\n        attention_mask=attention_mask,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    return outputs\n\n\ndef forward_with_torch_backend(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | Qwen2VLCausalLMOutputForPPO:\n    from verl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        pixel_values=pixel_values,\n        pixel_values_videos=pixel_values_videos,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=video_grid_thw,\n        rope_deltas=rope_deltas,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_torch_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_torch_backend, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        
hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return Qwen2VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n\n\ndef forward_with_triton_backend(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[list[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> tuple | Qwen2VLCausalLMOutputForPPO:\n    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n    outputs = forward_base_model(\n        self,\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        pixel_values=pixel_values,\n        pixel_values_videos=pixel_values_videos,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=video_grid_thw,\n        rope_deltas=rope_deltas,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_with_triton_backend has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_with_triton_backend, either labels or input_ids must be provided.\")\n\n    log_probs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.lm_head.weight,\n        rolled_labels,\n        temperature,\n        \"none\",\n    )\n\n    return Qwen2VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n"
  },
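The `prepare_fa2_from_position_ids` helper above recovers varlen FlashAttention metadata from the position ids alone: in a packed batch, every position equal to 0 marks the start of a sub-sequence, so the indices of those zeros plus the total length form `cu_seqlens`, and `cu_seqlens.diff().max()` gives the longest sub-sequence. A small sketch of that derivation on a made-up packed row (the real helper also flattens q/k/v and returns the pair `(cu_seqlens, cu_seqlens)` since queries and keys share the packing):

```python
import torch

# Three packed sequences of lengths 4, 2, and 3 flattened into one row.
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])

indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seqlens = torch.cat(
    (
        indices[position_ids == 0],                               # sequence starts: 0, 4, 6
        torch.tensor([position_ids.size(0)], dtype=torch.int32),  # total length: 9
    )
)
max_seqlen = cu_seqlens.diff().max()

print(cu_seqlens.tolist())  # [0, 4, 6, 9]
print(int(max_seqlen))      # 4
```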
  {
    "path": "verl_rl/verl/models/weight_loader_registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef get_weight_loader(arch: str):\n    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {\n        \"LlamaForCausalLM\": load_state_dict_to_megatron_gptmodel,\n        \"Qwen2ForCausalLM\": load_state_dict_to_megatron_gptmodel,\n    }\n\n    if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch]\n    raise ValueError(\n        f\"Model architectures {arch} loader are not supported for now. Supported architectures: \"\n        f\"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}\"\n    )\n\n\ndef get_weight_saver(arch: str):\n    from verl.models.mcore.saver import (\n        merge_megatron_ckpt_gptmodel,\n        merge_megatron_ckpt_gptmodel_dpskv3,\n        merge_megatron_ckpt_gptmodel_mixtral,\n        merge_megatron_ckpt_gptmodel_qwen2_5_vl,\n        merge_megatron_ckpt_gptmodel_qwen_moe,\n    )\n\n    _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = {\n        \"LlamaForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen2ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"MixtralForCausalLM\": merge_megatron_ckpt_gptmodel_mixtral,\n        \"Qwen2MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n        \"Qwen2_5_VLForConditionalGeneration\": merge_megatron_ckpt_gptmodel_qwen2_5_vl,\n        \"DeepseekV3ForCausalLM\": merge_megatron_ckpt_gptmodel_dpskv3,\n        \"Qwen3ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen3MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n    }\n    if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch]\n    raise ValueError(\n        f\"Model architectures {arch} saver are not supported for now. Supported architectures: \"\n        f\"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}\"\n    )\n"
  },
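`weight_loader_registry.py` keeps the architecture-to-function mapping inside the accessors and defers the heavy Megatron imports until first call, so importing the module stays cheap and unsupported architectures fail with an explicit error. A toy sketch of that lazy-registry pattern; `get_loader` and `load_gptmodel` are illustrative stand-ins, not verl APIs:

```python
from typing import Callable


def get_loader(arch: str) -> Callable:
    # The real accessor does `from verl.models.mcore.loader import ...` here,
    # on first call; this stand-in binds a local function so the sketch runs anywhere.
    def load_gptmodel(*args, **kwargs):
        raise NotImplementedError("stand-in for load_state_dict_to_megatron_gptmodel")

    registry: dict[str, Callable] = {
        "LlamaForCausalLM": load_gptmodel,
        "Qwen2ForCausalLM": load_gptmodel,
    }
    try:
        return registry[arch]
    except KeyError as e:
        raise ValueError(f"Model architecture {arch} is not supported. Supported: {sorted(registry)}") from e


print(get_loader("Qwen2ForCausalLM").__name__)  # -> load_gptmodel
```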
  {
    "path": "verl_rl/verl/protocol.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement base data transfer protocol between any two functions, modules.\nWe can subclass Protocol to define more detailed batch info with specific keys\n\"\"\"\n\nimport contextlib\nimport copy\nimport logging\nimport os\nimport pickle\nfrom dataclasses import dataclass, field\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport pandas as pd\nimport ray\nimport tensordict\nimport torch\nimport torch.distributed\nfrom packaging import version\nfrom tensordict import TensorDict\nfrom torch.utils.data import DataLoader\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.py_functional import union_two_dict\nfrom verl.utils.torch_functional import allgather_dict_tensors\n\n__all__ = [\"DataProto\", \"union_tensor_dict\"]\n\nwith contextlib.suppress(Exception):\n    tensordict.set_lazy_legacy(False).set()\n\n\nclass _DataProtoConfigMeta(type):\n    _config = {}\n\n    auto_padding_key = \"_verl_auto_padding\"\n\n    @property\n    def auto_padding(cls):\n        enabled_by_env = os.getenv(\"VERL_AUTO_PADDING\", \"FALSE\").upper() in [\"TRUE\", \"1\"]\n        return enabled_by_env or cls._config.get(cls.auto_padding_key, False)\n\n    @auto_padding.setter\n    def auto_padding(cls, enabled: bool):\n        assert isinstance(enabled, bool), f\"enabled must be a boolean, got {enabled} as {type(enabled)}\"\n        cls._config[cls.auto_padding_key] = enabled\n\n\nclass DataProtoConfig(metaclass=_DataProtoConfigMeta):\n    pass\n\n\n_padding_size_key = \"_padding_size_key_x123d\"\n\n\ndef pad_dataproto_to_divisor(data: \"DataProto\", size_divisor: int):\n    \"\"\"Pad a DataProto to size divisible by size_divisor\n\n    Args:\n        size_divisor (int): size divisor\n\n    Returns:\n        data: (DataProto): the padded DataProto\n        pad_size (int)\n    \"\"\"\n    assert isinstance(data, DataProto), \"data must be a DataProto\"\n    if len(data) % size_divisor != 0:\n        pad_size = size_divisor - len(data) % size_divisor\n        padding_protos = []\n        remaining_pad = pad_size\n        while remaining_pad > 0:\n            take_size = min(remaining_pad, len(data))\n            padding_protos.append(data[:take_size])\n            remaining_pad -= take_size\n        data_padded = DataProto.concat([data] + padding_protos)\n    else:\n        if len(data) == 0:\n            logging.warning(\"padding a DataProto with no item, no changed made\")\n        pad_size = 0\n        data_padded = data\n    return data_padded, pad_size\n\n\ndef unpad_dataproto(data: \"DataProto\", pad_size):\n    \"\"\"Unpad the data proto with pad_size. i.e. 
`data[:-pad_size]`\"\"\"\n    if pad_size != 0:\n        data = data[:-pad_size]\n    return data\n\n\ndef union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n    \"\"\"Union two tensordicts.\"\"\"\n    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (\n        f\"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n    )\n    for key in tensor_dict2.keys():\n        if key not in tensor_dict1.keys():\n            tensor_dict1[key] = tensor_dict2[key]\n        else:\n            assert tensor_dict1[key].equal(tensor_dict2[key]), (\n                f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n            )\n\n    return tensor_dict1\n\n\ndef union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:\n    for key, val in tensor_dict2.items():\n        if key in tensor_dict1:\n            assert isinstance(tensor_dict2[key], np.ndarray)\n            assert isinstance(tensor_dict1[key], np.ndarray)\n            # to properly deal with nan and object type\n            assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), (\n                f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n            )\n        tensor_dict1[key] = val\n\n    return tensor_dict1\n\n\ndef list_of_dict_to_dict_of_list(list_of_dict: list[dict]):\n    if len(list_of_dict) == 0:\n        return {}\n    keys = list_of_dict[0].keys()\n    output = {key: [] for key in keys}\n    for data in list_of_dict:\n        for key, item in data.items():\n            assert key in output\n            output[key].append(item)\n    return output\n\n\ndef fold_batch_dim(data: \"DataProto\", new_batch_size):\n    \"\"\"\n    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]\n    \"\"\"\n    batch_size = data.batch.batch_size[0]\n\n    assert batch_size % new_batch_size == 0\n\n    tensor: TensorDict = data.batch\n    non_tensor = data.non_tensor_batch\n\n    tensor = tensor.view(new_batch_size, -1)\n    tensor.auto_batch_size_(batch_dims=1)\n\n    for key, val in non_tensor.items():\n        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))\n\n    return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)\n\n\ndef unfold_batch_dim(data: \"DataProto\", batch_dims=2):\n    \"\"\"\n    Unfold the first n dims as new batch dim\n    \"\"\"\n    tensor: TensorDict = data.batch\n    non_tensor = data.non_tensor_batch\n    tensor.auto_batch_size_(batch_dims=batch_dims)\n    tensor = tensor.view(-1)\n\n    batch_size = tensor.batch_size[0]\n\n    non_tensor_new = {}\n\n    for key, val in non_tensor.items():\n        non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))\n\n    return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)\n\n\ndef collate_fn(x: list[\"DataProtoItem\"]):\n    batch = []\n    non_tensor_batch = []\n    for data in x:\n        batch.append(data.batch)\n        non_tensor_batch.append(data.non_tensor_batch)\n    batch = torch.stack(batch).contiguous()\n    non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)\n    for key, val in non_tensor_batch.items():\n        non_tensor_batch[key] = np.array(val, dtype=object)\n    return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n\n@dataclass\nclass DataProtoItem:\n    # 
TODO(zhangchi.usc1992) add consistency check\n    batch: TensorDict = None\n    non_tensor_batch: dict = field(default_factory=dict)\n    meta_info: dict = field(default_factory=dict)\n\n\n@dataclass\nclass DataProto:\n    \"\"\"\n    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.\n    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.\n    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the\n    same batch size should be put inside batch.\n    \"\"\"\n\n    batch: TensorDict = None\n    non_tensor_batch: dict = field(default_factory=dict)\n    meta_info: dict = field(default_factory=dict)\n\n    def __post_init__(self):\n        # perform necessary checking\n        self.check_consistency()\n\n    def __len__(self):\n        if self.batch is not None:\n            return self.batch.batch_size[0]\n        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:\n            random_key = list(self.non_tensor_batch.keys())[0]\n            return self.non_tensor_batch[random_key].shape[0]\n        else:\n            return 0\n\n    def __getitem__(self, item):\n        \"\"\"\n        Enhanced indexing for DataProto objects.\n\n        Args:\n            item: Can be one of:\n                - int: A single index\n                - slice: A slice object (start:stop:step)\n                - list: A list of indices\n                - numpy.ndarray: An array of indices\n                - torch.Tensor: A tensor of indices\n\n        Returns:\n            DataProto: For all indexing types except single integers\n            DataProtoItem: Only for single integer indices\n        \"\"\"\n        # Case 1: Slice object - use the slice method\n        if isinstance(item, slice):\n            return self.slice(item.start, item.stop, item.step)\n\n        # Case 2: List, numpy array, or torch tensor - use sel_idxs\n        elif isinstance(item, list | np.ndarray | torch.Tensor):\n            return self.select_idxs(item)\n\n        # Case 3: Single integer - return DataProtoItem for backward compatibility\n        elif isinstance(item, int | np.integer):\n            tensor_data = self.batch[item] if self.batch is not None else None\n            non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}\n            return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)\n\n        # # Case 4: Unsupported type\n        else:\n            raise TypeError(f\"Indexing with {type(item)} is not supported\")\n\n    def __getstate__(self):\n        import io\n\n        buffer = io.BytesIO()\n        if version.parse(tensordict.__version__) >= version.parse(\"0.5.0\") and self.batch is not None:\n            self.batch = self.batch.contiguous()\n            self.batch = self.batch.consolidate()\n        torch.save(self.batch, buffer)\n        buffer_bytes = buffer.getvalue()\n        return buffer_bytes, self.non_tensor_batch, self.meta_info\n\n    def __setstate__(self, data):\n        import io\n\n        batch_deserialized_bytes, non_tensor_batch, meta_info = data\n        batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)\n        batch = torch.load(\n            batch_deserialized,\n            weights_only=False,\n            map_location=\"cpu\" if not get_torch_device().is_available() else None,\n        )\n        
self.batch = batch\n        self.non_tensor_batch = non_tensor_batch\n        self.meta_info = meta_info\n\n    def save_to_disk(self, filepath):\n        with open(filepath, \"wb\") as f:\n            pickle.dump(self, f)\n\n    @staticmethod\n    def load_from_disk(filepath) -> \"DataProto\":\n        with open(filepath, \"rb\") as f:\n            data = pickle.load(f)\n            return data\n\n    def print_size(self, prefix=\"\"):\n        size_of_tensordict = 0\n        if self.batch is not None:\n            for _, tensor in self.batch.items():\n                size_of_tensordict += tensor.element_size() * tensor.numel()\n        size_of_numpy_array = 0\n        for _, numpy_array in self.non_tensor_batch.items():\n            size_of_numpy_array += numpy_array.nbytes\n\n        size_of_numpy_array /= 1024**3\n        size_of_tensordict /= 1024**3\n\n        message = f\"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB\"\n\n        if prefix:\n            message = f\"{prefix}, \" + message\n        print(message)\n\n    def check_consistency(self):\n        \"\"\"Check the consistency of the DataProto. Mainly for batch and non_tensor_batch.\n        We expose this function as a public one so that users can call it directly.\n        \"\"\"\n        if self.batch is not None:\n            assert len(self.batch.batch_size) == 1, \"only support num_batch_dims=1\"\n\n        if self.non_tensor_batch is not None:\n            for key, val in self.non_tensor_batch.items():\n                assert isinstance(val, np.ndarray)\n\n        if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0:\n            # TODO: we can actually lift this restriction if needed\n            assert len(self.batch.batch_size) == 1, \"only support num_batch_dims=1 when non_tensor_batch is not empty.\"\n\n            batch_size = self.batch.batch_size[0]\n            for key, val in self.non_tensor_batch.items():\n                assert isinstance(val, np.ndarray), (\n                    f\"data in the non_tensor_batch must be a numpy.array with dtype=object, but for \"\n                    f\"{key=}, got {type(val)=}\"\n                )\n                assert val.shape[0] == batch_size, (\n                    f\"key {key} length {len(val)} is not equal to batch size {batch_size}\"\n                )\n\n    @classmethod\n    def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False):\n        \"\"\"Create a DataProto from a dict of tensors and non_tensors\"\"\"\n        tensors = {}\n        non_tensors = {}\n\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor):\n                tensors[key] = val\n            elif isinstance(val, np.ndarray):\n                non_tensors[key] = val\n            else:\n                raise ValueError(f\"Unsupported type in data {type(val)}\")\n\n        return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding)\n\n    @classmethod\n    def from_dict(\n        cls,\n        tensors: Optional[dict[str, torch.Tensor]] = None,\n        non_tensors=None,\n        meta_info=None,\n        num_batch_dims=1,\n        auto_padding=False,\n    ):\n        \"\"\"Create a DataProto from a dict of tensors. This assumes that\n        1. All the tensors in tensors have the same dim0\n        2. 
Only dim0 is the batch dim\n        \"\"\"\n\n        assert num_batch_dims > 0, \"num_batch_dims must be greater than zero\"\n        if non_tensors is not None:\n            assert num_batch_dims == 1, \"only support num_batch_dims=1 when non_tensors is not None.\"\n\n        if tensors is None:\n            tensors = {}\n        if meta_info is None:\n            meta_info = {}\n        if non_tensors is None:\n            non_tensors = {}\n\n        assert isinstance(non_tensors, dict)\n\n        # get and check batch size\n        batch_size = None\n        pivot_key = None\n        for key, tensor in tensors.items():\n            if batch_size is None:\n                batch_size = tensor.shape[:num_batch_dims]\n                pivot_key = key\n            else:\n                current_batch = tensor.shape[:num_batch_dims]\n                assert batch_size == current_batch, (\n                    f\"Not all tensors in tensors have the same batch size with batch_dims={num_batch_dims}. \"\n                    f\"Got {pivot_key} has {batch_size}, {key} has {current_batch}\"\n                )\n\n        for key, val in non_tensors.items():\n            if not isinstance(val, np.ndarray):\n                non_tensors[key] = np.array(val, dtype=object)\n\n        tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None\n        if auto_padding:\n            meta_info[DataProtoConfig.auto_padding_key] = True\n        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)\n\n    def to(self, device) -> \"DataProto\":\n        \"\"\"move the batch to device\n\n        Args:\n            device (torch.device, str): torch device\n\n        Returns:\n            DataProto: the current DataProto\n\n        \"\"\"\n        if self.batch is not None:\n            self.batch = self.batch.to(device)\n        return self\n\n    def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> \"DataProto\":\n        \"\"\"Select a subset of the DataProto via batch_keys and meta_info_keys\n\n        Args:\n            batch_keys (list, optional): a list of strings indicating the keys in batch to select\n            non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch to select\n            meta_info_keys (list, optional): a list of keys indicating the meta info to select\n\n        Returns:\n            DataProto: the DataProto with the selected batch_keys and meta_info_keys\n        \"\"\"\n        # TODO (zhangchi.usc1992) whether to copy\n        if batch_keys is not None:\n            batch_keys = tuple(batch_keys)\n            sub_batch = self.batch.select(*batch_keys)\n        else:\n            sub_batch = self.batch\n\n        if non_tensor_batch_keys is not None:\n            non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}\n        else:\n            non_tensor_batch = self.non_tensor_batch\n\n        if deepcopy:\n            non_tensor_batch = copy.deepcopy(non_tensor_batch)\n\n        if meta_info_keys is not None:\n            sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}\n        else:\n            sub_meta_info = self.meta_info\n\n        if deepcopy:\n            sub_meta_info = copy.deepcopy(sub_meta_info)\n\n        return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)\n\n    def select_idxs(self, idxs):\n        \"\"\"\n        Select specific indices from the DataProto.\n\n        Args:\n            idxs (torch.Tensor 
or numpy.ndarray or list): Indices to select\n\n        Returns:\n            DataProto: A new DataProto containing only the selected indices\n        \"\"\"\n        if isinstance(idxs, list):\n            idxs = torch.tensor(idxs)\n            if idxs.dtype != torch.bool:\n                idxs = idxs.type(torch.int32)\n\n        if isinstance(idxs, np.ndarray):\n            idxs_np = idxs\n            idxs_torch = torch.from_numpy(idxs)\n        else:  # torch.Tensor\n            idxs_torch = idxs\n            idxs_np = idxs.detach().cpu().numpy()\n\n        batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]\n\n        if self.batch is not None:\n            # Use TensorDict's built-in indexing capabilities\n            selected_batch = TensorDict(\n                source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},\n                batch_size=(batch_size,),\n                device=self.batch.device,\n            )\n        else:\n            selected_batch = None\n\n        selected_non_tensor = {}\n        for key, val in self.non_tensor_batch.items():\n            selected_non_tensor[key] = val[idxs_np]\n\n        return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)\n\n    def slice(self, start=None, end=None, step=None):\n        \"\"\"\n        Slice the DataProto and return a new DataProto object.\n        This is an improved version of direct slicing, which used to return a DataProtoItem.\n\n        Args:\n            start (int, optional): Start index. Defaults to None (start from beginning).\n            end (int, optional): End index (exclusive). Defaults to None (go to end).\n            step (int, optional): Step size. Defaults to None (step=1).\n\n        Returns:\n            DataProto: A new DataProto containing the sliced data\n\n        Examples:\n            # Using the slice method directly\n            sliced_data = data_proto.slice(10, 20)\n\n            # Using enhanced indexing (returns DataProto)\n            sliced_data = data_proto[10:20]\n            sliced_data = data_proto[::2]  # Every other element\n\n            # Using list indexing (returns DataProto)\n            indices = [1, 5, 10]\n            selected_data = data_proto[indices]\n\n            # Single index still returns DataProtoItem\n            single_item = data_proto[5]\n        \"\"\"\n        # Create a slice object\n        slice_obj = slice(start, end, step)\n\n        # Handle the batch data\n        if self.batch is not None:\n            # Use TensorDict's built-in slicing capabilities\n            sliced_batch = self.batch[slice_obj]\n        else:\n            sliced_batch = None\n\n        # Handle the non-tensor batch data\n        sliced_non_tensor = {}\n        for key, val in self.non_tensor_batch.items():\n            sliced_non_tensor[key] = val[slice_obj]\n\n        # Return a new DataProto object\n        return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)\n\n    def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> \"DataProto\":\n        \"\"\"Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`\n\n        Args:\n            batch_keys (list, optional): a list of strings indicating the keys in batch to pop\n            non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch to pop\n            meta_info_keys (list, optional): a list of keys indicating the meta info to pop\n\n        Returns:\n            DataProto: the DataProto with the popped batch_keys and 
meta_info_keys\n        \"\"\"\n        if batch_keys is None:\n            batch_keys = []\n        if meta_info_keys is None:\n            meta_info_keys = []\n        if non_tensor_batch_keys is None:\n            non_tensor_batch_keys = []\n\n        tensors = {}\n        # tensor batch\n        for key in batch_keys:\n            assert key in self.batch.keys()\n            tensors[key] = self.batch.pop(key)\n        non_tensors = {}\n        # non tensor batch\n        for key in non_tensor_batch_keys:\n            assert key in self.non_tensor_batch.keys()\n            non_tensors[key] = self.non_tensor_batch.pop(key)\n        meta_info = {}\n        for key in meta_info_keys:\n            assert key in self.meta_info.keys()\n            meta_info[key] = self.meta_info.pop(key)\n        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)\n\n    def rename(self, old_keys=None, new_keys=None) -> \"DataProto\":\n        \"\"\"\n        Note that this function only renames the keys in the batch\n        \"\"\"\n\n        def validate_input(keys):\n            if keys is not None:\n                if isinstance(keys, str):\n                    keys = [keys]\n                elif isinstance(keys, list):\n                    pass\n                else:\n                    raise TypeError(f\"keys must be a list or a string, but got {type(keys)}\")\n            return keys\n\n        old_keys = validate_input(old_keys)\n        new_keys = validate_input(new_keys)\n\n        if len(new_keys) != len(old_keys):\n            raise ValueError(\n                f\"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}\"\n            )\n\n        self.batch.rename_key_(tuple(old_keys), tuple(new_keys))\n\n        return self\n\n    def union(self, other: \"DataProto\") -> \"DataProto\":\n        \"\"\"Union with another DataProto. Union batch and meta_info separately.\n        Throws an error if\n\n        - there are conflicting keys in batch and they are not equal\n        - the batch sizes of the two DataProtos are not the same\n        - there are conflicting keys in meta_info and they are not the same.\n\n        Args:\n            other (DataProto): another DataProto to union\n\n        Returns:\n            DataProto: the DataProto after union\n        \"\"\"\n        self.batch = union_tensor_dict(self.batch, other.batch)\n        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)\n        self.meta_info = union_two_dict(self.meta_info, other.meta_info)\n        return self\n\n    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):\n        r\"\"\"Make an iterator from the DataProto. This builds on the fact that a TensorDict can be used as a normal\n        PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.\n\n\n        Args:\n            mini_batch_size (int): mini-batch size when iterating the dataset. We require that\n                ``batch.batch_size[0] % mini_batch_size == 0``.\n            epochs (int): number of epochs when iterating the dataset.\n            seed (int, optional): seed for the dataloader's random generator.\n            dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The\n                dataloader_kwargs is the kwargs passed to the DataLoader.\n\n        Returns:\n            Iterator: an iterator that yields a mini-batch of data at a time. 
The total number of iteration\n                steps is ``self.batch.batch_size[0] * epochs // mini_batch_size``\n        \"\"\"\n        assert self.batch.batch_size[0] % mini_batch_size == 0, f\"{self.batch.batch_size[0]} % {mini_batch_size} != 0\"\n        # we can directly create a dataloader from TensorDict\n        if dataloader_kwargs is None:\n            dataloader_kwargs = {}\n\n        if seed is not None:\n            generator = torch.Generator()\n            generator.manual_seed(seed)\n        else:\n            generator = None\n\n        assert isinstance(dataloader_kwargs, dict)\n        train_dataloader = DataLoader(\n            dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs\n        )\n\n        def get_data():\n            for _ in range(epochs):\n                for d in train_dataloader:\n                    d.meta_info = self.meta_info\n                    yield d\n\n        return iter(get_data())\n\n    def is_padding_enabled(self):\n        \"\"\"\n        Check if padding is enabled for the DataProto.\n        Returns:\n            bool: True if padding is enabled, False otherwise.\n        \"\"\"\n        dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False)\n        return dataproto_specific_padding or DataProtoConfig.auto_padding\n\n    def padding(self, padding_size, padding_candidate=\"\"):\n        \"\"\"Pad the DataProto by concatenating with padding_candidate.repeat(padding_size)\n\n        Args:\n            padding_size (int): the number of repeated padding_candidate\n            padding_candidate: the item to be repeated and appended to the DataProto, only supporting [\"first\", \"last\"]\n        \"\"\"\n        if padding_size == 0:\n            return\n        padding_candidate = self.select_idxs([0 if padding_candidate == \"first\" else len(self) - 1])\n        padding_part = padding_candidate.repeat(padding_size)\n        padded_dp = DataProto.concat([self, padding_part])\n        self.batch = padded_dp.batch\n        self.non_tensor_batch = padded_dp.non_tensor_batch\n\n    def chunk(self, chunks: int) -> list[\"DataProto\"]:\n        \"\"\"Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.\n\n        Args:\n            chunks (int): the number of chunks to split on dim=0\n\n        Returns:\n            List[DataProto]: a list of DataProto after splitting\n        \"\"\"\n        if not self.is_padding_enabled():\n            assert len(self) % chunks == 0, (\n                f\"only support equal chunk. 
Got size of DataProto {len(self)} and chunk {chunks}.\"\n            )\n\n        bsz_in_batch = None\n        if self.batch is not None:\n            batch_lst = self.batch.chunk(chunks=chunks, dim=0)\n            bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])\n            chunk_indices = np.cumsum(bsz_in_batch)[:-1]\n        else:\n            batch_lst = [None for _ in range(chunks)]\n\n        non_tensor_batch_lst = [{} for _ in range(chunks)]\n        for key, val in self.non_tensor_batch.items():\n            assert isinstance(val, np.ndarray)\n            if bsz_in_batch is not None:\n                non_tensor_lst = np.array_split(val, chunk_indices.tolist())\n            else:\n                non_tensor_lst = np.array_split(val, chunks)\n            assert len(non_tensor_lst) == chunks\n            for i in range(chunks):\n                non_tensor_batch_lst[i][key] = non_tensor_lst[i]\n\n        output = []\n        for i in range(chunks):\n            output.append(\n                type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)\n            )\n\n        return output\n\n    def split(self, split_size: int) -> list[\"DataProto\"]:\n        \"\"\"Split the batch among dim=0 into chunks of size ``split_size``. The meta_info is passed to each DataProto after split.\n\n        Args:\n            split_size (int): the size of each split\n\n        Returns:\n            List[DataProto]: a list of DataProto after splitting\n        \"\"\"\n        return [self[i : i + split_size] for i in range(0, len(self), split_size)]\n\n    @staticmethod\n    def concat(data: list[\"DataProto\"]) -> \"DataProto\":\n        \"\"\"Concat a list of DataProto. The batch is concatenated among dim=0.\n        The meta_info is assumed to be identical and will use the first one.\n\n        Args:\n            data (List[DataProto]): list of DataProto\n\n        Returns:\n            DataProto: concatenated DataProto\n        \"\"\"\n        batch_lst = []\n        for batch in data:\n            batch_lst.append(batch.batch)\n        new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None\n\n        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])\n        for key, val in non_tensor_batch.items():\n            non_tensor_batch[key] = np.concatenate(val, axis=0)\n\n        cls = type(data[0]) if len(data) > 0 else DataProto\n        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)\n\n    def reorder(self, indices):\n        \"\"\"\n        Note that this operation is in-place\n        \"\"\"\n        indices_np = indices.detach().cpu().numpy()\n        self.batch = self.batch[indices]\n        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}\n\n    def repeat(self, repeat_times=2, interleave=True):\n        \"\"\"\n        Repeat the batch data a specified number of times.\n\n        Args:\n            repeat_times (int): Number of times to repeat the data.\n            interleave (bool): Whether to interleave the repeated data.\n\n        Returns:\n            DataProto: A new DataProto with repeated data.\n        \"\"\"\n        if self.batch is not None:\n            if interleave:\n                # Interleave the data\n                repeated_tensors = {\n                    key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()\n                }\n            
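# e.g. rows [x0, x1] with repeat_times=2 become [x0, x0, x1, x1] here,\n            # versus [x0, x1, x0, x1] in the non-interleaved branch below\n            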
else:\n                # Stack the data\n                repeated_tensors = {\n                    key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])\n                    for key, tensor in self.batch.items()\n                }\n\n            repeated_batch = TensorDict(\n                source=repeated_tensors,\n                batch_size=(self.batch.batch_size[0] * repeat_times,),\n            )\n        else:\n            repeated_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in self.non_tensor_batch.items():\n            if interleave:\n                repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)\n            else:\n                repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))\n\n        return type(self)(\n            batch=repeated_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n    def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):\n        \"\"\"Split along the second dim into `n_split` chunks and unfold them into the first (batch) dim.\n        Useful for passing grouped tensors that should not be shuffled in the dataset.\n        Keys not in `split_keys` are repeated along the batch dim to match the shape.\n        Note that if `split_keys` is not provided, all keys are repeated instead of split.\n        \"\"\"\n        if self.batch is not None:\n            unfolded_batch = {}\n            for key in self.batch.keys():\n                if split_keys is not None and key in split_keys:\n                    shape = list(self.batch[key].shape)\n                    shape[0] = self.batch[key].shape[0] * n_split\n                    shape[1] = self.batch[key].shape[1] // n_split\n                    unfolded_batch[key] = self.batch[key].reshape(*shape)\n                else:\n                    unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)\n            # place the `unfolded_batch` as a TensorDict on the same device as the original batch\n            unfolded_batch = TensorDict(\n                source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device\n            )\n        else:\n            unfolded_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in self.non_tensor_batch.items():\n            if split_keys is not None and key in split_keys:\n                shape = list(val.shape)\n                shape[0] = val.shape[0] * n_split\n                shape[1] = val.shape[1] // n_split\n                repeated_non_tensor_batch[key] = val.reshape(*shape)\n            else:\n                repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)\n\n        return type(self)(\n            batch=unfolded_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n    def sample_level_repeat(self, repeat_times):\n        \"\"\"\n        Repeat each row of the batch data a specified number of times.\n\n        Args:\n            repeat_times (torch.Tensor, list, tuple, np.ndarray): Number of times to repeat the data.\n\n        Returns:\n            DataProto: A new DataProto with repeated data.\n        \"\"\"\n        if isinstance(repeat_times, tuple):\n            repeat_times = list(repeat_times)\n        elif isinstance(repeat_times, torch.Tensor):\n            assert len(repeat_times.shape) == 1\n            
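# every accepted container type is normalized to a plain Python list before torch.tensor below\n            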
repeat_times = repeat_times.tolist()\n        elif isinstance(repeat_times, np.ndarray):\n            assert len(repeat_times.shape) == 1\n            repeat_times = repeat_times.tolist()\n        else:\n            assert isinstance(repeat_times, list), (\n                f\"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}\"\n            )\n        repeat_times = torch.tensor(repeat_times)\n\n        if self.batch is not None:\n            # Interleave the data\n            repeated_tensors = {\n                key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()\n            }\n\n            repeated_batch = TensorDict(\n                source=repeated_tensors,\n                batch_size=(repeat_times.sum().item(),),\n                device=self.batch.device,\n            )\n        else:\n            repeated_batch = None\n\n        repeated_non_tensor_batch = {}\n        for key, val in self.non_tensor_batch.items():\n            repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)\n\n        return type(self)(\n            batch=repeated_batch,\n            non_tensor_batch=repeated_non_tensor_batch,\n            meta_info=self.meta_info,\n        )\n\n\n@dataclass\nclass DataProtoFuture:\n    \"\"\"\n    DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver doesn't have to wait\n    for data so that asynchronous execution becomes possible.\n    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.\n    - collect_fn is a Callable that reduces the list of futures to a DataProto\n    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size\n        and then selects the chunk belonging to each worker\n\n    Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination\n    - DataProtoFuture only supports directly passing from the output of a method to another input. 
You can't perform any\n    operation on the DataProtoFuture on the driver.\n    \"\"\"\n\n    collect_fn: Callable\n    futures: list[ray.ObjectRef]\n    dispatch_fn: Callable = None\n\n    @staticmethod\n    def concat(data: list[ray.ObjectRef]) -> \"DataProtoFuture\":\n        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)\n        return output\n\n    def chunk(self, chunks: int) -> list[\"DataProtoFuture\"]:\n        from functools import partial\n\n        arg_future_lst = []\n        for i in range(chunks):\n            # note that we can't directly pass i and chunks\n            def dispatch_fn(x, i, chunks):\n                return x.chunk(chunks=chunks)[i]\n\n            arg_future = DataProtoFuture(\n                collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures\n            )\n            arg_future_lst.append(arg_future)\n        return arg_future_lst\n\n    def get(self):\n        output = ray.get(self.futures)  # dp_size.\n        for o in output:\n            assert isinstance(o, DataProto)\n        output = self.collect_fn(output)  # select dp, concat\n        if self.dispatch_fn is not None:\n            output = self.dispatch_fn(output)  # split in batch dim, select using dp\n        return output\n\n\ndef all_gather_data_proto(data: DataProto, process_group):\n    # Note that this is an in-place operation, just like torch.distributed.all_gather\n    group_size = torch.distributed.get_world_size(group=process_group)\n    assert isinstance(data, DataProto)\n    prev_device = data.batch.device\n    data.batch = data.batch.to(get_device_id())\n    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)\n    data.batch = data.batch.to(prev_device)\n    # all gather non_tensor_batch\n    all_non_tensor_batch = [None for _ in range(group_size)]\n    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group)\n    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}\n"
  },
  {
    "path": "verl_rl/verl/py.typed",
    "content": ""
  },
  {
    "path": "verl_rl/verl/single_controller/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nfrom . import base\nfrom .base import *\n\nversion_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))\n\n# Note(haibin.lin): single_controller.__version__ is deprecated\nwith open(os.path.join(os.path.join(version_folder, os.pardir), \"version/version\")) as f:\n    __version__ = f.read().strip()\n\n\n__all__ = base.__all__\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .worker import Worker\nfrom .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup\n\n__all__ = [\"Worker\", \"WorkerGroup\", \"ClassWithInitArgs\", \"ResourcePool\"]\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/decorator.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom functools import wraps\nfrom types import FunctionType\n\nfrom verl.protocol import DataProtoFuture, _padding_size_key\nfrom verl.utils.py_functional import DynamicEnum\n\n# here we add a magic number of avoid user-defined function already have this attribute\nMAGIC_ATTR = \"attrs_3141562937\"\n\n\nclass Dispatch(DynamicEnum):\n    \"\"\"Enum class defining different dispatch modes for distributed computation.\n\n    Each mode represents a specific strategy for distributing data across\n    different ranks in a distributed system. The modes are used to control\n    how data is partitioned and processed across different worker groups.\n    \"\"\"\n\n    _registry = {}\n    _next_value = 0\n\n\ndef init_predefined_dispatch_mode():\n    Dispatch.register(\"RANK_ZERO\")\n    Dispatch.register(\"ONE_TO_ALL\")\n    Dispatch.register(\"ALL_TO_ALL\")\n    Dispatch.register(\"MEGATRON_COMPUTE\")\n    Dispatch.register(\"MEGATRON_PP_AS_DP\")\n    Dispatch.register(\"MEGATRON_PP_ONLY\")\n    Dispatch.register(\"MEGATRON_COMPUTE_PROTO\")\n    Dispatch.register(\"MEGATRON_PP_AS_DP_PROTO\")\n    Dispatch.register(\"DP_COMPUTE\")\n    Dispatch.register(\"DP_COMPUTE_PROTO\")\n    Dispatch.register(\"DP_COMPUTE_PROTO_WITH_FUNC\")\n    Dispatch.register(\"DP_COMPUTE_METRIC\")\n    # This is a special dispatch mode for vllm ExternalRayDistributedExecutor\n    Dispatch.register(\"DIRECT_ROLLOUT_METHOD\")\n\n\nclass Execute(DynamicEnum):\n    \"\"\"Enum class defining different execution modes for distributed computation.\n\n    These modes control how a function should be executed across different ranks\n    in a distributed system.\n    \"\"\"\n\n    _registry = {}\n    _next_value = 0\n\n\ndef init_predefined_execute_mode():\n    Execute.register(\"ALL\")\n    Execute.register(\"RANK_ZERO\")\n\n\n# Initialize the two Dynamic Enum Classes\ninit_predefined_dispatch_mode()\ninit_predefined_execute_mode()\n\n\ndef _split_args_kwargs_data_proto(chunks, *args, **kwargs):\n    from verl.protocol import DataProto, DataProtoFuture\n\n    splitted_args = []\n    for arg in args:\n        assert isinstance(arg, DataProto | DataProtoFuture)\n        splitted_args.append(arg.chunk(chunks=chunks))\n\n    splitted_kwargs = {}\n    for key, val in kwargs.items():\n        assert isinstance(val, DataProto | DataProtoFuture)\n        splitted_kwargs[key] = val.chunk(chunks=chunks)\n\n    return splitted_args, splitted_kwargs\n\n\ndef _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):\n    from verl.protocol import DataProto, DataProtoFuture\n\n    data_proto_len = None\n    padding_size = None\n\n    def _padding_and_split_data(obj, chunks):\n        nonlocal data_proto_len, padding_size\n        assert isinstance(obj, DataProto | DataProtoFuture)\n        if isinstance(obj, DataProto) and obj.is_padding_enabled():\n            # for padding, we only support 
DataProto with same length\n            if data_proto_len is None:\n                data_proto_len = len(obj)\n                padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0\n            else:\n                assert data_proto_len == len(obj), (\n                    f\"expecting all arg share same length of {data_proto_len}, but got {len(obj)}\"\n                )\n            obj.padding(padding_size=padding_size)\n        return obj.chunk(chunks=chunks)\n\n    splitted_args = [_padding_and_split_data(arg, chunks) for arg in args]\n    splitted_kwargs = {key: _padding_and_split_data(val, chunks) for key, val in kwargs.items()}\n    if padding_size is not None:\n        splitted_kwargs[_padding_size_key] = padding_size\n\n    return splitted_args, splitted_kwargs\n\n\ndef dispatch_one_to_all(worker_group, *args, **kwargs):\n    args = tuple([arg] * worker_group.world_size for arg in args)\n    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}\n    return args, kwargs\n\n\ndef dummy_direct_rollout_call(worker_group, *args, **kwargs):\n    raise NotImplementedError(\"Direct rollout call is forbidden.\")\n\n\ndef dispatch_all_to_all(worker_group, *args, **kwargs):\n    return args, kwargs\n\n\ndef collect_all_to_all(worker_group, output):\n    return output\n\n\ndef dispatch_megatron_compute(worker_group, *args, **kwargs):\n    \"\"\"\n    User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp rank\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup), (\n        f\"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}\"\n    )\n\n    # ray put all the args in advance to avoid duplicate serialization cost\n    import ray\n\n    args = [[ray.put(dp_arg) for dp_arg in arg] for arg in args]\n    kwargs = {k: [ray.put(dp_v) for dp_v in v] for k, v in kwargs.items()}\n\n    def _transform_data(obj_list, worker_group):\n        assert isinstance(obj_list, tuple | list) and len(obj_list) == worker_group.dp_size\n        transformed_data = []\n        for i in range(worker_group.world_size):\n            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank\n            transformed_data.append(obj_list[local_dp_rank])\n        return transformed_data\n\n    all_args = tuple([_transform_data(arg, worker_group) for arg in args])\n    all_kwargs = {key: _transform_data(val, worker_group) for key, val in kwargs.items()}\n\n    return all_args, all_kwargs\n\n\ndef collect_megatron_compute(worker_group, output):\n    \"\"\"\n    Only collect the data from tp=0, cp=0, and pp=last, across every dp rank\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n    output_in_dp = []\n    pp_size = worker_group.get_megatron_global_info().pp_size\n    for global_rank in range(worker_group.world_size):\n        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)\n        if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1 and local_rank_info.cp_rank == 0:\n            output_in_dp.append(output[global_rank])\n    return output_in_dp\n\n\ndef dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):\n    \"\"\"\n    All the args and kwargs must be DataProto. 
The batch will be chunked by dp_size and passed to each rank\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)\n    return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)\n\n\ndef _concat_data_proto_or_future(output: list):\n    import ray\n\n    from verl.protocol import DataProto, DataProtoFuture\n\n    # make sure all the elements in output have the same type\n    for o in output:\n        assert type(o) is type(output[0])\n\n    o = output[0]\n\n    if isinstance(o, DataProto):\n        return DataProto.concat(output)\n    elif isinstance(o, ray.ObjectRef):\n        return DataProtoFuture.concat(output)\n    else:\n        raise NotImplementedError\n\n\ndef collect_megatron_compute_data_proto(worker_group, output):\n    \"\"\"\n    Each output must be a DataProto. We concat the dim=0 of output\n    \"\"\"\n    import ray\n\n    from verl.protocol import DataProto\n\n    output = collect_megatron_compute(worker_group, output)\n    for o in output:\n        assert isinstance(o, DataProto | ray.ObjectRef), f\"expecting {o} to be DataProto or ray.ObjectRef, but got {type(o)}\"\n\n    return _concat_data_proto_or_future(output)\n\n\ndef dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):\n    \"\"\"\n    treat pp as dp.\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n\n    pp_size = worker_group.pp_size\n    dp_size = worker_group.dp_size\n    cp_size = worker_group.cp_size\n    pp_dp_cp_size = pp_size * dp_size * cp_size\n\n    def _transform_data(obj_list, worker_group):\n        assert isinstance(obj_list, list | tuple) and len(obj_list) == pp_dp_cp_size\n        transformed_data = []\n        for i in range(worker_group.world_size):\n            local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank\n            local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank\n            local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank\n            # compute the rank in obj_list. Note that the order is dp then cp then pp\n            # Also note that the outputs within a pp group will first be all-gathered, then only the\n            # output of pp0 will be collected.\n            # For pp=2 dp=4, a batch of data \"ABCDEFGH\" should be dispatched and collected in the order below:\n            #    dispatch:       pp_allgather:       collect:\n            #   dp 0 1 2 3      dp  0  1  2  3\n            # pp +---------+  pp +-------------+\n            #  0 | A C E G |   0 | AB CD EF GH |     ABCDEFGH\n            #  1 | B D F H |   1 | AB CD EF GH |\n            #    +---------+     +-------------+\n            dp_cp_rank = local_cp_rank * dp_size + local_dp_rank\n            arg_rank = dp_cp_rank * pp_size + local_pp_rank\n            transformed_data.append(obj_list[arg_rank])\n        return transformed_data\n\n    all_args = tuple([_transform_data(arg, worker_group) for arg in args])\n    all_kwargs = {key: _transform_data(val, worker_group) for key, val in kwargs.items()}\n\n    return all_args, all_kwargs\n\n\ndef collect_megatron_pp_as_dp(worker_group, output):\n    \"\"\"\n    treat pp as dp. 
Only collect data on tp=0\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n    output_in_dp = []\n    for global_rank in range(worker_group.world_size):\n        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)\n        if local_rank_info.tp_rank == 0:\n            output_in_dp.append(output[global_rank])\n    return output_in_dp\n\n\ndef collect_megatron_pp_only(worker_group, output):\n    \"\"\"\n    Only collect output of megatron pp. This is useful when examining weight names, as they are identical across tp/dp\n    \"\"\"\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n    output_in_pp = []\n    for global_rank in range(worker_group.world_size):\n        local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)\n        if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:\n            output_in_pp.append(output[global_rank])\n    return output_in_pp\n\n\ndef dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n\n    pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_cp_size, *args, **kwargs)\n    ret = dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)\n    return ret\n\n\ndef collect_megatron_pp_as_dp_data_proto(worker_group, output):\n    from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\n    assert isinstance(worker_group, MegatronWorkerGroup)\n\n    output = collect_megatron_pp_as_dp(worker_group, output)\n    return _concat_data_proto_or_future(output)\n\n\ndef dispatch_dp_compute(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    for arg in args:\n        assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size\n    for k, v in kwargs.items():\n        assert isinstance(v, tuple | list) and len(v) == worker_group.world_size\n    return args, kwargs\n\n\ndef collect_dp_compute(worker_group, output):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    assert len(output) == worker_group.world_size\n    return output\n\n\ndef dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    # Note: enable auto padding for dp compute DataProto\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(\n        worker_group.world_size,\n        *args,\n        **kwargs,\n    )\n    return splitted_args, splitted_kwargs\n\n\ndef dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):\n    from verl.single_controller.base.worker_group import WorkerGroup\n\n    assert isinstance(worker_group, WorkerGroup)\n    assert isinstance(args[0], FunctionType)  # NOTE: The first arg is a function!\n\n    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)\n    splitted_args_with_func = [[args[0]] * 
worker_group.world_size] + splitted_args\n    return splitted_args_with_func, splitted_kwargs\n\n\ndef collect_dp_compute_data_proto(worker_group, output):\n    import ray\n\n    from verl.protocol import DataProto\n\n    for o in output:\n        assert isinstance(o, DataProto | ray.ObjectRef), f\"expecting {o} to be DataProto or ray.ObjectRef, but got {type(o)}\"\n\n    output = collect_dp_compute(worker_group, output)\n    return _concat_data_proto_or_future(output)\n\n\n# Global registry for dispatch mode.\nDISPATCH_MODE_FN_REGISTRY = {\n    Dispatch.ONE_TO_ALL: {\n        \"dispatch_fn\": dispatch_one_to_all,\n        \"collect_fn\": collect_all_to_all,\n    },\n    Dispatch.ALL_TO_ALL: {\n        \"dispatch_fn\": dispatch_all_to_all,\n        \"collect_fn\": collect_all_to_all,\n    },\n    Dispatch.MEGATRON_COMPUTE: {\n        \"dispatch_fn\": dispatch_megatron_compute,\n        \"collect_fn\": collect_megatron_compute,\n    },\n    Dispatch.MEGATRON_PP_AS_DP: {\n        \"dispatch_fn\": dispatch_megatron_pp_as_dp,\n        \"collect_fn\": collect_megatron_pp_as_dp,\n    },\n    Dispatch.MEGATRON_PP_ONLY: {\"dispatch_fn\": dispatch_one_to_all, \"collect_fn\": collect_megatron_pp_only},\n    Dispatch.MEGATRON_COMPUTE_PROTO: {\n        \"dispatch_fn\": dispatch_megatron_compute_data_proto,\n        \"collect_fn\": collect_megatron_compute_data_proto,\n    },\n    Dispatch.MEGATRON_PP_AS_DP_PROTO: {\n        \"dispatch_fn\": dispatch_megatron_pp_as_dp_data_proto,\n        \"collect_fn\": collect_megatron_pp_as_dp_data_proto,\n    },\n    Dispatch.DP_COMPUTE: {\"dispatch_fn\": dispatch_dp_compute, \"collect_fn\": collect_dp_compute},\n    Dispatch.DP_COMPUTE_PROTO: {\n        \"dispatch_fn\": dispatch_dp_compute_data_proto,\n        \"collect_fn\": collect_dp_compute_data_proto,\n    },\n    Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {\n        \"dispatch_fn\": dispatch_dp_compute_data_proto_with_func,\n        \"collect_fn\": collect_dp_compute_data_proto,\n    },\n    Dispatch.DP_COMPUTE_METRIC: {\"dispatch_fn\": dispatch_dp_compute_data_proto, \"collect_fn\": collect_dp_compute},\n    Dispatch.DIRECT_ROLLOUT_METHOD: {\n        \"dispatch_fn\": dummy_direct_rollout_call,\n        \"collect_fn\": dummy_direct_rollout_call,\n    },\n}\n\n\ndef get_predefined_dispatch_fn(dispatch_mode):\n    return DISPATCH_MODE_FN_REGISTRY[dispatch_mode]\n\n\ndef register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn):\n    \"\"\"\n    Register a new dispatch mode.\n    \"\"\"\n    dispatch_mode = Dispatch.register(dispatch_mode_name)\n    _check_dispatch_mode(dispatch_mode)\n    assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f\"dispatch_mode_name {dispatch_mode_name} already exists\"\n    DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {\"dispatch_fn\": dispatch_fn, \"collect_fn\": collect_fn}\n\n\ndef update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn):\n    \"\"\"\n    Update the dispatch mode.\n    \"\"\"\n    _check_dispatch_mode(dispatch_mode)\n    assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f\"dispatch_mode {dispatch_mode} not found\"\n    DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {\"dispatch_fn\": dispatch_fn, \"collect_fn\": collect_fn}\n\n\ndef get_predefined_execute_fn(execute_mode):\n    \"\"\"\n    Note that here we only ask execute_all and execute_rank_zero to be implemented,\n    leaving the choice of how these two functions handle the 'blocking' argument to users\n    \"\"\"\n    predefined_execute_mode_fn = {\n        Execute.ALL: {\"execute_fn_name\": \"execute_all\"},\n        
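# Execute.RANK_ZERO resolves to the worker group's execute_rank_zero method\n        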
Execute.RANK_ZERO: {\"execute_fn_name\": \"execute_rank_zero\"},\n    }\n    return predefined_execute_mode_fn[execute_mode]\n\n\ndef _check_dispatch_mode(dispatch_mode):\n    assert isinstance(dispatch_mode, Dispatch | dict), (\n        f\"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}\"\n    )\n    if isinstance(dispatch_mode, dict):\n        necessary_keys = [\"dispatch_fn\", \"collect_fn\"]\n        for key in necessary_keys:\n            assert key in dispatch_mode, f\"key {key} should be in dispatch_mode if it is a dictionary\"\n\n\ndef _check_execute_mode(execute_mode):\n    assert isinstance(execute_mode, Execute), f\"execute_mode must be a Execute. Got {execute_mode}\"\n\n\ndef _materialize_futures(*args, **kwargs):\n    new_args = []\n    for arg in args:\n        if isinstance(arg, DataProtoFuture):\n            arg = arg.get()\n        # add more type to materialize\n        new_args.append(arg)\n    for k, v in kwargs.items():\n        if isinstance(v, DataProtoFuture):\n            kwargs[k] = v.get()\n\n    new_args = tuple(new_args)\n    return new_args, kwargs\n\n\ndef register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):\n    \"\"\"Register a function with distributed execution configuration.\n\n    This decorator registers a function with specific dispatch and execution modes\n    for distributed computation. It handles both synchronous and asynchronous\n    functions, and optionally materializes futures before execution.\n\n    Args:\n        dispatch_mode:\n            Dispatch mode for computation distribution. Default: Dispatch.ALL_TO_ALL.\n        execute_mode:\n            Execute mode for computation distribution. Default: Execute.ALL.\n        blocking:\n            Whether the execution should be blocking. Defaults to True.\n        materialize_futures:\n            Whether to materialize the data before dispatching. Defaults to True.\n\n    Returns:\n        A decorator that wraps the original function with distributed execution\n        configuration.\n    \"\"\"\n    _check_dispatch_mode(dispatch_mode=dispatch_mode)\n    _check_execute_mode(execute_mode=execute_mode)\n\n    def decorator(func):\n        @wraps(func)\n        def inner(*args, **kwargs):\n            if materialize_futures:\n                args, kwargs = _materialize_futures(*args, **kwargs)\n            return func(*args, **kwargs)\n\n        @wraps(func)\n        async def async_inner(*args, **kwargs):\n            if materialize_futures:\n                args, kwargs = _materialize_futures(*args, **kwargs)\n            return await func(*args, **kwargs)\n\n        wrapper = async_inner if inspect.iscoroutinefunction(func) else inner\n        attrs = {\"dispatch_mode\": dispatch_mode, \"execute_mode\": execute_mode, \"blocking\": blocking}\n        setattr(wrapper, MAGIC_ATTR, attrs)\n        return wrapper\n\n    return decorator\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/megatron/worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom verl.single_controller.base.worker import DistGlobalInfo, DistRankInfo, Worker\n\n\nclass MegatronWorker(Worker):\n    def __init__(self, cuda_visible_devices=None) -> None:\n        super().__init__(cuda_visible_devices)\n\n    def get_megatron_global_info(self):\n        from megatron.core import parallel_state as mpu\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        dp_size = mpu.get_data_parallel_world_size()\n        pp_size = mpu.get_pipeline_model_parallel_world_size()\n        cp_size = mpu.get_context_parallel_world_size()\n        info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size)\n        return info\n\n    def get_megatron_rank_info(self):\n        from megatron.core import parallel_state as mpu\n\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        dp_rank = mpu.get_data_parallel_rank()\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        cp_rank = mpu.get_context_parallel_rank()\n        info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank)\n        return info\n\n    def _init_hf_config_and_tf_config(\n        self,\n        model_path,\n        tokenizer_or_path,\n        dtype,\n        override_model_config,\n        override_transformer_config,\n        trust_remote_code=False,\n        use_mbridge=False,\n    ):\n        from transformers import AutoConfig\n\n        from verl.models.mcore import hf_to_mcore_config\n        from verl.utils import hf_processor, hf_tokenizer\n        from verl.utils.fs import copy_to_local\n        from verl.utils.model import update_model_config\n\n        # Step 1: initialize the tokenizer\n        self.local_path = copy_to_local(model_path)\n        if tokenizer_or_path is None:\n            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)\n            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)\n        elif isinstance(tokenizer_or_path, str):\n            self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)\n            self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)\n        else:\n            self.tokenizer = tokenizer_or_path\n            self.processor = tokenizer_or_path\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        # Step 2: get the hf\n        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)\n\n        # Step 3: override the hf config\n        override_config_kwargs = {\n            
\"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config.get(\"model_config\", {}))\n        self.share_embeddings_and_output_weights = getattr(hf_config, \"tie_word_embeddings\", False)\n        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)\n        self.architectures = getattr(hf_config, \"architectures\", None)\n        if self.rank == 0:\n            print(f\"Model config after override: {hf_config}\")\n        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)\n\n        if use_mbridge:\n            from verl.models.mcore.mbridge import AutoBridge\n\n            bridge = AutoBridge.from_config(hf_config)\n            bridge.set_extra_args(**override_transformer_config)\n            tf_config = bridge.config\n            self.bridge = bridge\n        else:\n            self.bridge = None\n\n        print(f\"TF config: {tf_config}\")\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/megatron/worker_group.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom verl.single_controller.base import ResourcePool, WorkerGroup\n\nfrom .worker import DistGlobalInfo, DistRankInfo\n\n\nclass MegatronWorkerGroup(WorkerGroup):\n    def __init__(self, resource_pool: ResourcePool, **kwargs):\n        super().__init__(resource_pool=resource_pool, **kwargs)\n        self._megatron_rank_info = None\n        self._megatron_global_info: DistGlobalInfo = None\n\n    def init_megatron(self, default_megatron_kwargs: dict = None):\n        raise NotImplementedError(\"MegatronWorkerGroup.init_megatron should be overwritten\")\n\n    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:\n        assert 0 <= rank < self.world_size, f\"rank must be from [0, world_size), Got {rank}\"\n        return self._megatron_rank_info[rank]\n\n    @property\n    def tp_size(self):\n        assert self._megatron_global_info is not None, \"MegatronWorkerGroup._megatron_global_info must be initialized\"\n        return self._megatron_global_info.tp_size\n\n    @property\n    def dp_size(self):\n        assert self._megatron_global_info is not None, \"MegatronWorkerGroup._megatron_global_info must be initialized\"\n        return self._megatron_global_info.dp_size\n\n    @property\n    def pp_size(self):\n        assert self._megatron_global_info is not None, \"MegatronWorkerGroup._megatron_global_info must be initialized\"\n        return self._megatron_global_info.pp_size\n\n    @property\n    def cp_size(self):\n        assert self._megatron_global_info is not None, \"MegatronWorkerGroup._megatron_global_info must be initialized\"\n        return self._megatron_global_info.cp_size\n\n    def get_megatron_global_info(self):\n        return self._megatron_global_info\n"
  },
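The parallel-layout properties above only read `_megatron_global_info`, so a toy subclass (hypothetical, not part of verl) can be exercised without spawning any workers:

```python
# Toy sketch: a minimal MegatronWorkerGroup subclass that fills in the
# global parallel layout so the tp/dp/pp/cp properties can be queried.
from verl.single_controller.base import ResourcePool
from verl.single_controller.base.megatron.worker import DistGlobalInfo
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup

class ToyMegatronWorkerGroup(MegatronWorkerGroup):
    def init_megatron(self, default_megatron_kwargs: dict = None):
        # a real subclass would gather this layout from its remote workers
        self._megatron_global_info = DistGlobalInfo(tp_size=2, dp_size=4, pp_size=2, cp_size=1)

wg = ToyMegatronWorkerGroup(resource_pool=ResourcePool([16]))
wg.init_megatron()
print(wg.tp_size, wg.dp_size, wg.pp_size, wg.cp_size)  # -> 2 4 2 1
```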
  {
    "path": "verl_rl/verl/single_controller/base/register_center/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/single_controller/base/register_center/ray.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ray\n\n\n@ray.remote\nclass WorkerGroupRegisterCenter:\n    def __init__(self, rank_zero_info):\n        self.rank_zero_info = rank_zero_info\n        # rank -> node_id\n        self.workers_info: dict[int, str] = {}\n\n    def get_rank_zero_info(self):\n        return self.rank_zero_info\n\n    def set_worker_info(self, rank, node_id) -> None:\n        self.workers_info[rank] = node_id\n\n    def get_worker_info(self) -> dict[int, str]:\n        return self.workers_info\n\n\ndef create_worker_group_register_center(name, info):\n    return WorkerGroupRegisterCenter.options(name=name).remote(info)\n"
  },
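A quick usage sketch of the register center: rank 0 publishes its `MASTER_ADDR`/`MASTER_PORT`, and each worker records its rank-to-node mapping for node-affinity scheduling. The actor name and addresses below are illustrative:

```python
import ray

from verl.single_controller.base.register_center.ray import create_worker_group_register_center

ray.init()
center = create_worker_group_register_center(
    name="demo_wg_register_center",  # illustrative name
    info={"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"},
)
# each worker reports rank -> node_id
ray.get(center.set_worker_info.remote(0, ray.get_runtime_context().get_node_id()))
print(ray.get(center.get_rank_zero_info.remote()))  # {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500'}
print(ray.get(center.get_worker_info.remote()))     # {0: '<node-id>'}
```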
  {
    "path": "verl_rl/verl/single_controller/base/worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nthe class for Worker\n\"\"\"\n\nimport os\nimport socket\nfrom dataclasses import dataclass\n\nimport ray\n\nfrom verl.utils.device import get_torch_device, get_visible_devices_keyword\n\nfrom .decorator import Dispatch, Execute, register\n\n\n@dataclass\nclass DistRankInfo:\n    tp_rank: int\n    dp_rank: int\n    pp_rank: int\n    cp_rank: int\n\n\n@dataclass\nclass DistGlobalInfo:\n    tp_size: int\n    dp_size: int\n    pp_size: int\n    cp_size: int\n\n\nclass WorkerHelper:\n    @staticmethod\n    def _get_node_ip():\n        if os.getenv(\"WG_BACKEND\", None) == \"ray\":\n            return ray.util.get_node_ip_address()\n        else:\n            raise NotImplementedError(\"WG_BACKEND now just support ray mode.\")\n\n    @staticmethod\n    def _get_free_port():\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            return sock.getsockname()[1]\n\n    def get_availale_master_addr_port(self):\n        return self._get_node_ip().strip(\"[]\"), str(self._get_free_port())\n\n\n# we assume that in each WorkerGroup, there is a Master Worker\nclass Worker(WorkerHelper):\n    \"\"\"A distributed worker that handles initialization and configuration for distributed training.\n\n    This class manages worker initialization, configuration, and provides methods for executing\n    distributed operations. 
It handles communication settings, device configuration, and worker\n    metadata management.\n    \"\"\"\n\n    fused_worker_attr_name = \"fused_worker_dict\"\n\n    def __new__(cls, *args, **kwargs):\n        \"\"\"Create a new Worker instance with proper initialization based on environment settings.\"\"\"\n        instance = super().__new__(cls)\n\n        # cast to int so that an unset variable and \"0\" both leave worker init enabled\n        disable_worker_init = int(os.environ.get(\"DISABLE_WORKER_INIT\", 0))\n        if disable_worker_init:\n            return instance\n\n        rank = os.environ.get(\"RANK\", None)\n        worker_group_prefix = os.environ.get(\"WG_PREFIX\", None)\n\n        # __new__ is also called when the @ray.remote decorator is applied, in which case we don't want to run _configure_before_init\n        if None not in [rank, worker_group_prefix] and \"ActorClass(\" not in cls.__name__:\n            instance._configure_before_init(f\"{worker_group_prefix}_register_center\", int(rank))\n\n        return instance\n\n    def _configure_before_init(self, register_center_name: str, rank: int):\n        \"\"\"Configure worker settings before initialization.\n\n        Args:\n            register_center_name (str):\n                Name of the register center Ray actor for worker coordination\n            rank (int):\n                Rank of the worker in the distributed setup\n        \"\"\"\n        assert isinstance(rank, int), f\"rank must be int, instead of {type(rank)}\"\n\n        if rank == 0:\n            master_addr, master_port = self.get_availale_master_addr_port()\n            rank_zero_info = {\n                \"MASTER_ADDR\": master_addr,\n                \"MASTER_PORT\": master_port,\n            }\n\n            if os.getenv(\"WG_BACKEND\", None) == \"ray\":\n                from verl.single_controller.base.register_center.ray import create_worker_group_register_center\n\n                self.register_center = create_worker_group_register_center(\n                    name=register_center_name, info=rank_zero_info\n                )\n\n            os.environ.update(rank_zero_info)\n        else:\n            self.register_center = ray.get_actor(register_center_name)\n\n        # set worker info for node affinity scheduling\n        ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id()))\n\n    @classmethod\n    def env_keys(cls):\n        \"\"\"The keys of the environment variables that are used to configure the Worker.\"\"\"\n        return [\n            \"WORLD_SIZE\",\n            \"RANK\",\n            \"LOCAL_WORLD_SIZE\",\n            \"LOCAL_RANK\",\n            \"MASTER_ADDR\",\n            \"MASTER_PORT\",\n            get_visible_devices_keyword().upper(),\n        ]\n\n    def __init__(self, cuda_visible_devices=None) -> None:\n        \"\"\"Initialize the worker with environment settings and device configuration.\n\n        Args:\n            cuda_visible_devices (str, optional):\n                CUDA visible devices configuration. Defaults to None.\n        \"\"\"\n        # construct a meta from environment variable. 
Note that the import must be inside the class because\n        # it is executed remotely\n        import os\n\n        self._setup_env_cuda_visible_devices()\n\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n        rank = int(os.environ[\"RANK\"])\n        self._rank = rank\n        self._world_size = world_size\n\n        master_addr = os.environ[\"MASTER_ADDR\"]\n        master_port = os.environ[\"MASTER_PORT\"]\n\n        local_world_size = int(os.getenv(\"LOCAL_WORLD_SIZE\", \"1\"))\n        local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n        store = {\n            \"_world_size\": world_size,\n            \"_rank\": rank,\n            \"_local_world_size\": local_world_size,\n            \"_local_rank\": local_rank,\n            \"_master_addr\": master_addr,\n            \"_master_port\": master_port,\n        }\n        if cuda_visible_devices is not None:\n            store[f\"_{get_visible_devices_keyword()}\".lower()] = cuda_visible_devices\n\n        self._configure_with_store(store=store)\n\n        self.fused_worker_dict = {}\n\n    def get_fused_worker_by_name(self, worker_name: str):\n        \"\"\"Get a fused worker by its name.\n\n        Args:\n            worker_name (str):\n                Name of the worker to retrieve\n        \"\"\"\n        return self.fused_worker_dict.get(worker_name, None)\n\n    def _setup_env_cuda_visible_devices(self):\n        from verl.utils.ray_utils import ray_noset_visible_devices\n\n        is_ray_noset_visible_devices = ray_noset_visible_devices()\n\n        # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES`\n        rocr_val = os.environ.get(\"ROCR_VISIBLE_DEVICES\", None)\n        hip_val = os.environ.get(\"HIP_VISIBLE_DEVICES\", None)\n        cuda_val = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if hip_val:\n            # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency.\n            # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES\n            # at this point.\n            val = os.environ.pop(\"HIP_VISIBLE_DEVICES\")\n            hip_val = None\n            if cuda_val:\n                assert val == cuda_val, (\n                    f\"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistent values \"\n                    f\"found: {val} and {cuda_val}.\"\n                )\n            else:\n                cuda_val = val\n                os.environ[\"CUDA_VISIBLE_DEVICES\"] = val\n                # os.environ[\"HIP_VISIBLE_DEVICES\"] = val\n\n        if rocr_val:\n            # You must take care if both HIP/CUDA and ROCR env vars are set as they have\n            # different meanings. Both env vars accept either a list of ints or a\n            # list of UUIDs. 
The ROCR env var is processed first, which then reduces\n            # the number of GPUs that HIP can select from.\n            # https://github.com/pytorch/pytorch/pull/144026\n            # To avoid this complexity, we simply raise an error if both are set\n            # (also to keep consistency with Ray's practice as of 2.45.0).\n            # Otherwise, we copy ROCR_VISIBLE_DEVICES into CUDA_VISIBLE_DEVICES\n            # and remove ROCR_VISIBLE_DEVICES.\n            if cuda_val:\n                raise ValueError(\"Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.\")\n\n            cuda_val = os.environ.pop(\"ROCR_VISIBLE_DEVICES\")\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = cuda_val\n            rocr_val = None\n\n        if is_ray_noset_visible_devices:\n            # NOTE: Ray will automatically set the *_VISIBLE_DEVICES\n            # environment variable for each actor, unless\n            # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set,\n            # so we need to set local rank when the flag is set.\n            local_rank = os.environ.get(\"RAY_LOCAL_RANK\")\n            os.environ[\"LOCAL_RANK\"] = local_rank\n            get_torch_device().set_device(int(local_rank))\n\n    def _configure_with_store(self, store: dict):\n        \"\"\"\n        This function should only be called internally by WorkerGroup\n        \"\"\"\n        store_env_dict = {f\"_{key.lower()}\": store.get(f\"_{key.lower()}\", None) for key in type(self).env_keys()}\n        self.__dict__.update(store_env_dict)  # this is hacky\n        # print(f\"__dict__: {self.__dict__}\")\n        for key in type(self).env_keys():\n            val = self.__dict__.get(f\"_{key.lower()}\", None)\n            if val is not None:\n                # print(f\"set {key} to {val}\")\n                os.environ[key] = str(val)\n        os.environ[\"REDIS_STORE_SERVER_HOST\"] = (\n            str(self._master_addr).replace(\"[\", \"\").replace(\"]\", \"\") if self._master_addr else \"\"\n        )\n\n    def get_master_addr_port(self):\n        \"\"\"Get the master address and port for distributed communication.\"\"\"\n        return self._master_addr, self._master_port\n\n    def get_cuda_visible_devices(self):\n        \"\"\"Get the CUDA visible devices configuration.\"\"\"\n        import os\n\n        visible_devices = os.environ.get(get_visible_devices_keyword().upper(), \"not set\")\n        return visible_devices\n\n    @property\n    def world_size(self):\n        \"\"\"Get the total number of workers in the distributed setup.\"\"\"\n        return self._world_size\n\n    @property\n    def rank(self):\n        \"\"\"Get the rank of this worker in the distributed setup.\"\"\"\n        return self._rank\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)\n    def execute_with_func_generator(self, func, *args, **kwargs):\n        \"\"\"Execute a function with function generator dispatch mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        \"\"\"\n        ret_proto = func(self, *args, **kwargs)\n        return ret_proto\n\n    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n    def execute_func_rank_zero(self, func, *args, **kwargs):\n        \"\"\"Execute a function in rank zero execution mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        \"\"\"\n        result = func(*args, **kwargs)\n        return result\n"
  },
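A minimal sketch of how `Worker` bootstraps itself from environment variables, assuming a verl environment. Setting `DISABLE_WORKER_INIT=1` makes `__new__` skip the Ray register-center handshake, so no Ray cluster is needed for this demo; the addresses are placeholders:

```python
# Sketch: construct a Worker purely from environment variables.
import os

from verl.single_controller.base.worker import Worker

os.environ.update({
    "DISABLE_WORKER_INIT": "1",   # skip _configure_before_init in __new__
    "WORLD_SIZE": "1",
    "RANK": "0",
    "MASTER_ADDR": "127.0.0.1",   # placeholder
    "MASTER_PORT": "29500",       # placeholder
})

w = Worker()
print(w.rank, w.world_size)      # -> 0 1
print(w.get_master_addr_port())  # -> ('127.0.0.1', '29500')
```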
  {
    "path": "verl_rl/verl/single_controller/base/worker_group.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nthe class of WorkerGroup\n\"\"\"\n\nimport logging\nimport signal\nimport threading\nimport time\nfrom typing import Any, Callable\n\nfrom .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn\n\n\nclass ResourcePool:\n    \"\"\"\n    Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations.\n    The class provides methods to calculate world size, local world sizes, and local ranks\n    across all nodes in the pool.\n    \"\"\"\n\n    def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None:\n        \"\"\"Initialize the ResourcePool with node processes and GPU configuration.\n\n        Args:\n            process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list.\n            max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10.\n            n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8.\n        \"\"\"\n        if process_on_nodes is None:\n            process_on_nodes = []\n        self._store = process_on_nodes\n        self.max_colocate_count = max_colocate_count\n        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node\n\n    def add_node(self, process_count):\n        self._store.append(process_count)\n\n    @property\n    def world_size(self):\n        \"\"\"Total number of processes across all nodes in the pool.\"\"\"\n        return sum(self._store)\n\n    def __call__(self) -> Any:\n        return self._store\n\n    @property\n    def store(self):\n        return self._store\n\n    def local_world_size_list(self) -> list[int]:\n        \"\"\"Returns a flat list where each process has its local world size.\"\"\"\n        nested_local_world_size_list = [\n            [local_world_size for _ in range(local_world_size)] for local_world_size in self._store\n        ]\n        return [item for row in nested_local_world_size_list for item in row]\n\n    def local_rank_list(self) -> list[int]:\n        \"\"\"Returns a flat list of local ranks for all processes across all nodes.\"\"\"\n        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]\n        return [item for row in nested_local_rank_list for item in row]\n\n\nclass ClassWithInitArgs:\n    \"\"\"\n    Wrapper class that stores constructor arguments for deferred instantiation.\n    This class is particularly useful for remote class instantiation where\n    the actual construction needs to happen at a different time or location.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        \"\"\"Initialize the ClassWithInitArgs instance.\n\n        Args:\n            cls: The class to be instantiated later\n            *args: Positional 
arguments for the class constructor\n            **kwargs: Keyword arguments for the class constructor\n        \"\"\"\n        self.cls = cls\n        self.args = args\n        self.kwargs = kwargs\n\n        self.fused_worker_used = False\n\n    def __call__(self) -> Any:\n        \"\"\"Instantiate the stored class with the stored arguments.\"\"\"\n        return self.cls(*self.args, **self.kwargs)\n\n\ndef check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None:\n    \"\"\"Continuously monitors worker processes and raises SIGABRT if any worker dies.\n\n    Args:\n        workers (List):\n            List of worker objects to monitor\n        is_alive (Callable):\n            Function to check if a worker is alive\n        gap_time (float):\n            Time interval between checks\n    \"\"\"\n    while True:\n        for worker in workers:\n            if not is_alive(worker):\n                logging.warning(f\"worker {worker} is not alive, sending signal to main thread\")\n                signal.raise_signal(signal.SIGABRT)\n        time.sleep(gap_time)\n\n\nclass WorkerGroup:\n    \"\"\"\n    Base class for managing a group of workers in a distributed system.\n    The class provides methods for worker management, aliveness checking, and method binding.\n    \"\"\"\n\n    fused_worker_execute_fn_name = \"_fuw_execute\"\n\n    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:\n        self._is_init_with_detached_workers = resource_pool is None\n\n        self.fused_worker_used = False\n\n        if resource_pool is not None:\n            # handle the case when the WorkerGroup is attached to an existing resource pool\n            self._procecss_dispatch_config = resource_pool()\n        else:\n            self._procecss_dispatch_config = None\n\n        self._workers = []\n        self._worker_names = []\n\n        self._master_addr = None\n        self._master_port = None\n\n        self._checker_thread: threading.Thread = None\n\n    def _is_worker_alive(self, worker):\n        \"\"\"Check if a worker is alive. 
Must be implemented by derived classes.\"\"\"\n        raise NotImplementedError(\"WorkerGroup._is_worker_alive called, should be implemented in derived class.\")\n\n    def _block_until_all_workers_alive(self) -> None:\n        \"\"\"Blocks until all workers in the group are alive.\"\"\"\n        while True:\n            all_state = [self._is_worker_alive(worker) for worker in self._workers]\n            if False in all_state:\n                time.sleep(1)\n            else:\n                break\n\n    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:\n        \"\"\"Starts a background thread to monitor worker aliveness.\n\n        Args:\n            every_n_seconds (int): Interval between aliveness checks\n        \"\"\"\n        # before starting checking worker aliveness, make sure all workers are already alive\n        self._block_until_all_workers_alive()\n\n        self._checker_thread = threading.Thread(\n            target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)\n        )\n        self._checker_thread.start()\n\n    @property\n    def world_size(self):\n        \"\"\"Number of workers in the group.\"\"\"\n        return len(self._workers)\n\n    def _bind_worker_method(self, user_defined_cls, func_generator):\n        \"\"\"Binds worker methods to the WorkerGroup based on registered attributes.\n\n        Args:\n            user_defined_cls (type): The class containing methods to bind\n            func_generator (Callable): Function that generates the bound method\n\n        Returns:\n            List[str]: List of method names that were successfully bound\n        \"\"\"\n        method_names = []\n        for method_name in dir(user_defined_cls):\n            try:\n                method = getattr(user_defined_cls, method_name)\n                assert callable(method), f\"{method_name} in {user_defined_cls} is not callable\"\n            except Exception:\n                # if it is a property, it will fail because Class doesn't have instance property\n                continue\n\n            if hasattr(method, MAGIC_ATTR):\n                # this method is decorated by register\n                attribute = getattr(method, MAGIC_ATTR)\n                assert isinstance(attribute, dict), f\"attribute must be a dictionary. 
Got {type(attribute)}\"\n                assert \"dispatch_mode\" in attribute, \"attribute must contain dispatch_mode in its key\"\n\n                dispatch_mode = attribute[\"dispatch_mode\"]\n                execute_mode = attribute[\"execute_mode\"]\n                blocking = attribute[\"blocking\"]\n\n                # get dispatch fn\n                if isinstance(dispatch_mode, Dispatch):\n                    # get default dispatch fn\n                    fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)\n                    dispatch_fn = fn[\"dispatch_fn\"]\n                    collect_fn = fn[\"collect_fn\"]\n                else:\n                    assert isinstance(dispatch_mode, dict)\n                    assert \"dispatch_fn\" in dispatch_mode\n                    assert \"collect_fn\" in dispatch_mode\n                    dispatch_fn = dispatch_mode[\"dispatch_fn\"]\n                    collect_fn = dispatch_mode[\"collect_fn\"]\n\n                # get execute_fn_name\n                execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)\n                wg_execute_fn_name = execute_mode[\"execute_fn_name\"]\n\n                # get execute_fn from string\n                try:\n                    execute_fn = getattr(self, wg_execute_fn_name)\n                    assert callable(execute_fn), \"execute_fn must be callable\"\n                except Exception:\n                    print(f\"execute_fn {wg_execute_fn_name} is invalid\")\n                    raise\n\n                # bind a new method to the RayWorkerGroup\n                func = func_generator(\n                    self,\n                    method_name,\n                    dispatch_fn=dispatch_fn,\n                    collect_fn=collect_fn,\n                    execute_fn=execute_fn,\n                    blocking=blocking,\n                )\n\n                try:\n                    setattr(self, method_name, func)\n                    method_names.append(method_name)\n                except Exception as e:\n                    raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\n        return method_names\n"
  },
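The `ResourcePool` bookkeeping above in one concrete example: two nodes running 4 and 2 processes respectively.

```python
from verl.single_controller.base.worker_group import ResourcePool

pool = ResourcePool(process_on_nodes=[4, 2])
print(pool.world_size)               # -> 6
print(pool.local_world_size_list())  # -> [4, 4, 4, 4, 2, 2]
print(pool.local_rank_list())        # -> [0, 1, 2, 3, 0, 1]
```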
  {
    "path": "verl_rl/verl/single_controller/ray/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import (\n    RayClassWithInitArgs,\n    RayResourcePool,\n    RayWorkerGroup,\n    create_colocated_worker_cls,\n    create_colocated_worker_cls_fused,\n)\n\n__all__ = [\n    \"RayClassWithInitArgs\",\n    \"RayResourcePool\",\n    \"RayWorkerGroup\",\n    \"create_colocated_worker_cls\",\n    \"create_colocated_worker_cls_fused\",\n]\n"
  },
  {
    "path": "verl_rl/verl/single_controller/ray/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport logging\nimport time\nfrom copy import deepcopy\nfrom typing import Any, Optional\n\nimport ray\nfrom ray.experimental.state.api import get_actor\nfrom ray.util import list_named_actors\nfrom ray.util.placement_group import PlacementGroup, placement_group\nfrom ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy\n\nfrom verl.protocol import DataProto, _padding_size_key\nfrom verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup\nfrom verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch\nfrom verl.utils.py_functional import temp_env_var\n\n__all__ = [\"Worker\"]\n\n\ndef get_random_string(length: int) -> str:\n    import random\n    import string\n\n    letters_digits = string.ascii_letters + string.digits\n    return \"\".join(random.choice(letters_digits) for _ in range(length))\n\n\ndef func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):\n    class Functor:\n        def __call__(this, *args, **kwargs):\n            args, kwargs = dispatch_fn(self, *args, **kwargs)\n            padding_count = kwargs.pop(_padding_size_key, 0)\n            output = execute_fn(method_name, *args, **kwargs)\n            if blocking:\n                output = ray.get(output)\n            output = collect_fn(self, output)\n            if padding_count > 0:\n                if isinstance(output, DataProto):\n                    indices = [i for i in range(len(output))][:-padding_count]\n                    output = output.select_idxs(indices)\n                elif isinstance(output, list):\n                    output = output[:-padding_count]\n            return output\n\n    # use class type to pass the method_name to get a better observability\n    return type(method_name, (Functor,), {})()\n\n\ndef sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]:\n    \"\"\"\n    Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.\n\n    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK\n    to be consistent across nodes when resume from checkpoint.\n\n    With this function, if there's only one resource pool and there's no node change, RANK should be consistent\n    across nodes in multiple ray jobs, even if the whole ray cluster is restarted.\n    \"\"\"\n    node_ip = {node[\"NodeID\"]: node[\"NodeManagerAddress\"] for node in ray.nodes()}\n    pg_ip = {}\n    for pg in pgs:\n        specs = ray._private.state.state.placement_group_table(pg.id)\n        # all bunles should be on the same node\n        node_id = specs[\"bundles_to_node_id\"][0]\n        pg_ip[pg.id] = node_ip[node_id]\n    return sorted(pgs, key=lambda pg: pg_ip[pg.id])\n\n\nclass RayResourcePool(ResourcePool):\n    def __init__(\n        
self,\n        process_on_nodes: Optional[list[int]] = None,\n        use_gpu: bool = True,\n        name_prefix: str = None,\n        max_colocate_count: int = 10,\n        detached=False,\n        accelerator_type: Optional[str] = None,\n    ) -> None:\n        super().__init__(process_on_nodes, max_colocate_count)\n        self.use_gpu = use_gpu\n        # print(f\"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}\")\n        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix\n        self.pgs = None\n        self.detached = detached\n        self.accelerator_type = accelerator_type\n\n    def get_placement_groups(self, strategy=\"STRICT_PACK\", name=None, device_name=\"cuda\"):\n        if self.pgs is not None:\n            return self.pgs\n\n        pg_name_prefix = (\n            name if name else f\"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:\"\n        )\n        # print(f\"pg_name_prefix = {pg_name_prefix}\")\n        if device_name == \"npu\":\n            device_name = \"NPU\"\n        elif device_name == \"cuda\":\n            device_name = \"GPU\"\n\n        bundle = {\"CPU\": self.max_colocate_count}\n        if self.use_gpu:\n            bundle[device_name] = 1\n            if self.accelerator_type is not None:\n                bundle[self.accelerator_type] = 1e-4\n        pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]\n\n        lifetime = \"detached\" if self.detached else None\n\n        pgs = [\n            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)\n            for idx, bundles in enumerate(pg_scheme)\n        ]\n\n        ray.get([pg.ready() for pg in pgs])\n\n        self.pgs = pgs\n        return pgs\n\n\ndef extract_pg_from_exist(\n    resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool\n) -> list:\n    src_pgs = [\n        pg\n        for role_name, resource_pool in resource_pools.items()\n        for pg in resource_pool.get_placement_groups()\n        if role_name in src_role_names\n    ]\n\n    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)\n    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)\n\n    unsorted_pgs: list[tuple[int, PlacementGroup]] = []\n    searching_idx = 0\n    for request_process, original_idx in sorted_process_on_nodes:\n        assert searching_idx < len(sorted_src_pgs), f\"not enough nodes for request: searching the {searching_idx}-th node\"\n        assert request_process <= sorted_src_pgs[searching_idx].bundle_count, (\n            f\"requesting {request_process} processes, bundle count cannot satisfy\"\n        )\n        unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))\n        searching_idx += 1\n\n    return [pg for _, pg in sorted(unsorted_pgs)]\n\n\ndef merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:\n    assert rp1.use_gpu == rp2.use_gpu, \"Both RayResourcePools must agree on use_gpu\"\n    assert rp1.max_colocate_count == rp2.max_colocate_count, \"Both RayResourcePools must have the same max_colocate_count\"\n    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, \"Both RayResourcePools must have the same n_gpus_per_node\"\n    assert rp1.detached == rp2.detached, \"Detached ResourcePool cannot be merged with non-detached ResourcePool\"\n\n    new_store = 
rp1.store + rp2.store\n\n    merged = type(rp1)(new_store, rp1.use_gpu, f\"{rp1.name_prefix}_{rp2.name_prefix}\")\n    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()\n\n    return merged\n\n\nclass RayClassWithInitArgs(ClassWithInitArgs):\n    \"\"\"A wrapper class for Ray actors with initialization arguments.\n\n    This class extends ClassWithInitArgs to provide additional functionality for\n    configuring and creating Ray actors with specific resource requirements and\n    scheduling strategies.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        # self._options = kwargs.pop('options', dict())\n        super().__init__(cls, *args, **kwargs)\n        self._options = {}\n        self._additional_resource = {}\n\n    def set_additional_resource(self, additional_resource):\n        \"\"\"Set additional resource requirements for the actor.\n\n        Args:\n            additional_resource: Dictionary specifying additional resource requirements\n        \"\"\"\n        self._additional_resource = additional_resource\n\n    def update_options(self, options: dict):\n        \"\"\"Update the Ray actor creation options.\n\n        Args:\n            options: Dictionary of options to update\n        \"\"\"\n        self._options.update(options)\n\n    def __call__(\n        self,\n        placement_group,\n        placement_group_bundle_idx,\n        use_gpu: bool = True,\n        num_gpus=1,\n        sharing_with=None,\n        device_name=\"cuda\",\n    ) -> Any:\n        \"\"\"Create and return a Ray actor with the configured options.\n\n        Args:\n            placement_group: Ray placement group for scheduling\n            placement_group_bundle_idx: Index of the bundle in the placement group\n            use_gpu: Whether to use GPU resources\n            num_gpus: Number of GPUs to allocate\n            sharing_with: Actor to share resources with\n            device_name: Device for training\n\n        Returns:\n            A Ray actor handle with the configured options\n        \"\"\"\n        if sharing_with is not None:\n            target_node_id = ray.get(sharing_with.get_node_id.remote())\n            visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())\n            options = {\"scheduling_strategy\": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}\n            return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs)\n\n        options = {\n            \"scheduling_strategy\": PlacementGroupSchedulingStrategy(\n                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx\n            )\n        }\n        options.update(self._options)\n\n        if use_gpu and device_name == \"cuda\":\n            options[\"num_gpus\"] = num_gpus\n        if use_gpu and device_name == \"npu\":\n            options[\"resources\"] = {\"NPU\": num_gpus}\n\n        # apply any additional custom resources requested for this actor\n        if len(self._additional_resource) > 0:\n            for k, v in self._additional_resource.items():\n                options[k] = v\n\n        # print(\"cls:\", self.cls)\n        # print(\"args: \", self.args)\n        # print(\"kwargs: \", self.kwargs)\n        return self.cls.options(**options).remote(*self.args, **self.kwargs)\n\n\nclass RayWorkerGroup(WorkerGroup):\n    \"\"\"A group of Ray workers that can be managed collectively.\n\n    This class extends WorkerGroup to provide Ray-specific functionality for\n    creating and managing groups of Ray actors 
with specific resource requirements\n    and scheduling strategies.\n    \"\"\"\n\n    def __init__(\n        self,\n        resource_pool: RayResourcePool = None,\n        ray_cls_with_init: RayClassWithInitArgs = None,\n        bin_pack: bool = True,\n        name_prefix: str = None,\n        detached=False,\n        worker_names=None,\n        worker_handles: list[ray.actor.ActorHandle] = None,\n        ray_wait_register_center_timeout: int = 300,\n        **kwargs,\n    ) -> None:\n        \"\"\"Initialize a RayWorkerGroup.\n\n        Args:\n            resource_pool: Resource pool for worker allocation\n            ray_cls_with_init: Class with initialization arguments for workers\n            bin_pack: Whether to use strict bin packing for resource allocation\n            name_prefix: Prefix for worker names\n            detached: Whether workers should be detached\n            worker_names: Names of existing workers to attach to\n            worker_handles: Handles of existing workers to attach to\n            ray_wait_register_center_timeout: Timeout for waiting on register center\n            **kwargs: Additional keyword arguments\n        \"\"\"\n        super().__init__(resource_pool=resource_pool, **kwargs)\n        self.ray_cls_with_init = ray_cls_with_init\n        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix\n        self._ray_wait_register_center_timeout = ray_wait_register_center_timeout\n        # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker.\n        self.fused_worker_used = ray_cls_with_init.fused_worker_used\n        # if a WorkerGroup is spawned from a Colocate WorkerGroup, this indicates which sub-class is bound to\n        # this WorkerGroup.\n        self.sub_cls_name = \"\"\n        self.device_name = kwargs.get(\"device_name\", \"cuda\")\n        self.profile_steps = kwargs.get(\"profile_steps\", None)\n        self.worker_nsight_options = kwargs.get(\"worker_nsight_options\", None)\n        if self.worker_nsight_options is not None and self.worker_nsight_options[\"capture-range-end\"] is None:\n            self.worker_nsight_options[\"capture-range-end\"] = f\"repeat-shutdown:{6 * len(self.profile_steps)}\"\n\n        if worker_names is not None and (not self.fused_worker_used):\n            assert self._is_init_with_detached_workers\n            self._worker_names = worker_names\n\n        if self._is_init_with_detached_workers:\n            self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles)\n        else:\n            self._init_with_resource_pool(\n                resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached\n            )\n\n        if ray_cls_with_init is not None:\n            self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)\n\n        self.wg_dict = None\n        self.method_names = []\n\n    def _is_worker_alive(self, worker: ray.actor.ActorHandle):\n        \"\"\"Check if a worker actor is still alive.\n\n        Args:\n            worker: Ray actor handle to check\n\n        Returns:\n            bool: True if the worker is alive, False otherwise\n        \"\"\"\n        worker_state_dict = get_actor(worker._actor_id.hex())\n        return worker_state_dict.get(\"state\", \"undefined\") == \"ALIVE\" if worker_state_dict is not None else False\n\n    def _init_with_detached_workers(self, worker_names, worker_handles):\n        # ray.get_actor holds a weak reference to the actor, which can cause actors to be garbage collected unexpectedly\n        
# if we only hold the spawned RayWorkerGroup. By passing actor handles explicitly, the spawned RayWorkerGroup keeps\n        # strong references to these actors.\n        # https://github.com/ray-project/ray/pull/45699\n        workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names]\n        self._workers = workers\n        self._world_size = len(worker_names)\n\n    def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):\n        \"\"\"Initialize the worker group by creating new workers from a resource pool.\n\n        Args:\n            resource_pool: Resource pool for worker allocation\n            ray_cls_with_init: Class with initialization arguments for workers\n            bin_pack: Whether to use strict bin packing for resource allocation\n            detached: Whether workers should be detached\n        \"\"\"\n        use_gpu = resource_pool.use_gpu\n\n        strategy = \"PACK\"\n        if bin_pack:\n            strategy = \"STRICT_PACK\"\n        pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)\n        world_size = resource_pool.world_size\n        self._world_size = world_size\n        # cia.add_kwarg(\"_world_size\", world_size)\n        num_gpus = 1 / resource_pool.max_colocate_count\n\n        rank = -1\n        local_world_size = resource_pool.store[0]\n        for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):\n            assert local_world_size <= pg.bundle_count, (\n                f\"when generating workers for {self.name_prefix}, local_world_size {local_world_size} \"\n                f\"exceeds the bundle count {pg.bundle_count} of placement group {pg_idx}\"\n            )\n            for local_rank in range(local_world_size):\n                rank += 1\n\n                # pass environment variables via the actor options so that each Worker can configure itself from them\n                env_vars = {\n                    \"WORLD_SIZE\": str(world_size),\n                    \"RANK\": str(rank),\n                    \"WG_PREFIX\": self.name_prefix,\n                    \"WG_BACKEND\": \"ray\",\n                    \"RAY_LOCAL_WORLD_SIZE\": str(local_world_size),\n                    \"RAY_LOCAL_RANK\": str(local_rank),\n                }\n                if rank != 0:\n                    env_vars[\"MASTER_ADDR\"] = self._master_addr\n                    env_vars[\"MASTER_PORT\"] = self._master_port\n\n                import re\n\n                cia_name = type(ray_cls_with_init.cls).__name__\n                match = re.search(r\"ActorClass\\(([^)]+)\\)\", cia_name)  # ray.remote(Obj) -> \"ActorClass(Obj)\"\n                cia_name = match.group(1) if match else cia_name  # \"ActorClass(Obj)\" -> \"Obj\"\n                name = f\"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}\"  # e.g. 
Worker_2:5\n\n                if self.profile_steps and self.device_name == \"cuda\":\n                    ray_cls_with_init.update_options(\n                        {\n                            \"runtime_env\": {\n                                \"env_vars\": env_vars,\n                                \"nsight\": self.worker_nsight_options,\n                            },\n                            \"name\": name,\n                        }\n                    )\n                else:\n                    ray_cls_with_init.update_options({\"runtime_env\": {\"env_vars\": env_vars}, \"name\": name})\n\n                if detached:\n                    ray_cls_with_init.update_options({\"lifetime\": \"detached\"})\n\n                # create a worker\n                worker = ray_cls_with_init(\n                    placement_group=pg,\n                    placement_group_bundle_idx=local_rank,\n                    use_gpu=use_gpu,\n                    num_gpus=num_gpus,\n                    device_name=self.device_name,\n                )\n                self._workers.append(worker)\n                self._worker_names.append(name)\n\n                if rank == 0:\n                    register_center_actor = None\n                    actor_name = f\"{self.name_prefix}_register_center\"\n                    start_time = time.time()\n\n                    while time.time() - start_time < self._ray_wait_register_center_timeout:\n                        if actor_name in list_named_actors():\n                            register_center_actor = ray.get_actor(actor_name)\n                            break\n\n                        elapsed = int(time.time() - start_time)\n                        if elapsed % 30 == 0:\n                            logging.warning(\n                                \"Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of \"\n                                \"%s seconds.\",\n                                actor_name,\n                                elapsed,\n                                self._ray_wait_register_center_timeout,\n                            )\n                        time.sleep(1)\n\n                    if register_center_actor is None:\n                        raise TimeoutError(\n                            f\"Failed to get register_center_actor {actor_name} \"\n                            f\"in {list_named_actors(all_namespaces=True)} \"\n                            f\"for {self._ray_wait_register_center_timeout} seconds. 
\"\n                            \"Ensure that any lingering Ray resources from previous \"\n                            \"runs are cleaned up (e.g., by restarting the Ray cluster), \"\n                            \"or adjust the waiting time by modifying the config \"\n                            \"`trainer.ray_wait_register_center_timeout`.\"\n                        )\n\n                    rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())\n                    self._master_addr, self._master_port = rank_zero_info[\"MASTER_ADDR\"], rank_zero_info[\"MASTER_PORT\"]\n                    # print(f\"rank_zero_info: {rank_zero_info}\")\n                    # print(f\"master_addr: {self._master_addr}, master_port: {self._master_port}\")\n\n    @property\n    def worker_names(self):\n        return self._worker_names\n\n    @classmethod\n    def from_detached(\n        cls,\n        name_prefix=None,\n        worker_names=None,\n        worker_handles=None,\n        ray_cls_with_init=None,\n        **kwargs,\n    ):\n        \"\"\"Create a worker group from existing detached workers.\n\n        Args:\n            name_prefix: Prefix for worker names\n            worker_names: Names of existing workers to attach to\n            ray_cls_with_init: Class with initialization arguments for workers\n\n        Returns:\n            A new RayWorkerGroup instance\n        \"\"\"\n        worker_group = cls(\n            resource_pool=None,\n            ray_cls_with_init=ray_cls_with_init,\n            name_prefix=name_prefix,\n            worker_names=worker_names,\n            worker_handles=worker_handles,\n            **kwargs,\n        )\n        return worker_group\n\n    def spawn(self, prefix_set):\n        \"\"\"Spawn to a dictionary of worker groups, each with a subset of method with prefix.\n\n        Args:\n            prefix_set: Set of prefixes to create worker groups for\n\n        Returns:\n            Dictionary of worker groups keyed by prefix\n        \"\"\"\n        if self.fused_worker_used:\n            return self.spawn_fused(prefix_set)\n\n        def _rebind_actor_methods(worker_group, actor_name):\n            prefix: str = actor_name + \"_\"\n            for method_name in dir(worker_group):\n                if method_name.startswith(prefix):\n                    original_method_name = method_name.removeprefix(prefix)\n                    method = getattr(worker_group, method_name)\n                    setattr(worker_group, original_method_name, method)\n\n        new_worker_group_dict = {}\n        for prefix in prefix_set:\n            new_worker_group = self.from_detached(\n                name_prefix=self.name_prefix,\n                worker_names=self._worker_names,\n                worker_handles=self._workers,\n                ray_cls_with_init=self.ray_cls_with_init,\n                profile_steps=self.profile_steps,\n                worker_nsight_options=self.worker_nsight_options,\n            )\n\n            _rebind_actor_methods(new_worker_group, prefix)\n            new_worker_group_dict[prefix] = new_worker_group\n        return new_worker_group_dict\n\n    def spawn_fused(self, prefix_set):\n        \"\"\"Create a dictionary of worker groups for fused workers.\n\n        Args:\n            prefix_set: Set of prefixes to create worker groups for\n\n        Returns:\n            Dictionary of worker groups keyed by prefix\n        \"\"\"\n        wg_dict = dict()\n        for key in prefix_set:\n            new_wg = deepcopy(self)\n         
   new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator)\n            new_wg.sub_cls_name = key\n            wg_dict[key] = new_wg\n        return wg_dict\n\n    def fuse(self, prefix_set):\n        \"\"\"Fuse multiple worker groups into the current worker group.\n\n        Args:\n            prefix_set: Set of prefixes to fuse into the worker group\n        \"\"\"\n        if self.wg_dict is None:\n            self.wg_dict = self.spawn(prefix_set)\n        for role_name, role_wg in self.wg_dict.items():\n            setattr(self, role_name, role_wg)\n        self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)\n\n    def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on a single worker remotely.\n\n        Args:\n            worker: The worker actor handle\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        \"\"\"\n        if self.fused_worker_used and method_name not in self.method_names:\n            remote_call = getattr(worker, self.fused_worker_execute_fn_name)\n            return remote_call.remote(f\"{self.sub_cls_name}_fwmn_{method_name}\", *args, **kwargs)\n        # fused worker not used\n        remote_call = getattr(worker, method_name)\n        return remote_call.remote(*args, **kwargs)\n\n    def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on rank zero worker synchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Result of the method execution\n        \"\"\"\n        return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))\n\n    def execute_rank_zero_async(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on rank zero worker asynchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        \"\"\"\n        return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs)\n\n    def execute_rank_zero(self, method_name: str, *args, **kwargs):\n        \"\"\"Alias for execute_rank_zero_async.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            Remote object reference to the method execution\n        \"\"\"\n        return self.execute_rank_zero_async(method_name, *args, **kwargs)\n\n    def execute_all(self, method_name: str, *args, **kwargs):\n        \"\"\"Alias for execute_all_async.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of remote object references to the method executions\n        \"\"\"\n        return self.execute_all_async(method_name, *args, **kwargs)\n\n    def 
execute_all_sync(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on all workers synchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of results from all workers\n        \"\"\"\n        return ray.get(self.execute_all_async(method_name, *args, **kwargs))\n\n    def execute_all_async(self, method_name: str, *args, **kwargs):\n        \"\"\"Execute a method on all workers asynchronously.\n\n        Args:\n            method_name: Name of the method to execute\n            *args: Positional arguments for the method\n            **kwargs: Keyword arguments for the method\n\n        Returns:\n            List of remote object references to the method executions\n        \"\"\"\n        # Here, we assume that if all arguments in args and kwargs are lists,\n        # and their lengths match len(self._workers), we'll distribute each\n        # element in these lists to the corresponding worker\n        # print(f\"execute_all_async: method {method_name}({args}, {kwargs})\")\n        length = len(self._workers)\n        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):\n            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):\n                # print(f\"splitting args and kwargs into {length} shards\")\n                result = []\n                for i in range(length):\n                    sliced_args = tuple(arg[i] for arg in args)\n                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}\n                    result.append(\n                        self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs)\n                    )\n                return result\n\n        return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers]\n\n    @property\n    def master_address(self):\n        return self._master_addr\n\n    @property\n    def master_port(self):\n        return self._master_port\n\n    @property\n    def workers(self):\n        return self._workers\n\n    @property\n    def world_size(self):\n        return self._world_size\n\n\n\"\"\"\nUtilities that enable creating workers inside the same ray.Actor,\nwith code written in separate ray.Actors.\n\"\"\"\n\n\n# deprecated, switching to FusedWorker\ndef _bind_workers_method_to_parent(cls, key, user_defined_cls):\n    \"\"\"\n    Binds the methods of each worker to the WorkerDict.\n    Note that we only bind public methods that are decorated by register\n    \"\"\"\n\n    for method_name in dir(user_defined_cls):\n        try:\n            method = getattr(user_defined_cls, method_name)\n            assert callable(method), f\"{method_name} in {user_defined_cls} is not callable\"\n        except Exception:\n            # if it is a property, it will fail because Class doesn't have instance property\n            continue\n\n        if hasattr(method, MAGIC_ATTR):\n\n            def generate_function(name, key=key):\n                def func(self, *args, **kwargs):\n                    # dispatch to the actual worker\n                    return getattr(self.worker_dict[key], name)(*args, **kwargs)\n\n                async def async_func(self, *args, **kwargs):\n                    # dispatch to the actual worker\n        
            return await getattr(self.worker_dict[key], name)(*args, **kwargs)\n\n                wrapper = async_func if inspect.iscoroutinefunction(method) else func  # noqa: B023\n\n                return wrapper\n\n            func = generate_function(method_name)\n            # pass MAGIC_ATTR for outer worker group\n            attrs = getattr(method, MAGIC_ATTR)\n            setattr(func, MAGIC_ATTR, attrs)\n            try:\n                # bind direct rollout method to class without prefix\n                if attrs[\"dispatch_mode\"] == Dispatch.DIRECT_ROLLOUT_METHOD and \"rollout\" in key:\n                    assert not hasattr(cls, method_name), (\n                        f\"conflict direct rollout method {method_name} with role {key}\"\n                    )\n                    setattr(cls, method_name, func)\n                    print(f\"bind role {key} method {method_name} to class {cls}\")\n                else:\n                    method_name_with_prefix = key + \"_\" + method_name\n                    setattr(cls, method_name_with_prefix, func)\n            except Exception as e:\n                raise ValueError(f\"Fail to set method_name {method_name}\") from e\n\n\ndef _unwrap_ray_remote(cls):\n    if hasattr(cls, \"__ray_actor_class__\"):\n        cls = cls.__ray_actor_class__\n    return cls\n\n\ndef _determine_fsdp_megatron_base_class(mros: list):\n    \"\"\"\n    - megatron: base class should be MegatronWorker\n    - fsdp: base class should be Worker\n    \"\"\"\n    for cls in mros[0]:\n        if cls.__name__ == \"MegatronWorker\":\n            return cls\n        if cls.__name__ == \"Worker\":\n            return cls\n    raise ValueError(f\"Cannot determine base class for {mros}\")\n\n\n# deprecated, switching to FusedWorker\ndef create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function should return a class instance that delegates the calls to every\n    cls in cls_dict\n    \"\"\"\n    cls_dict = {}\n    init_args_dict = {}\n    worker_cls = _determine_fsdp_megatron_base_class(\n        [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]\n    )\n    assert issubclass(worker_cls, Worker), f\"worker_cls {worker_cls} should be a subclass of Worker\"\n    print(f\"colocated worker base class {worker_cls}\")\n\n    for key, cls in class_dict.items():\n        cls_dict[key] = cls.cls\n        init_args_dict[key] = {\"args\": cls.args, \"kwargs\": cls.kwargs}\n\n    assert cls_dict.keys() == init_args_dict.keys()\n\n    # TODO: create a class with customizable name\n    class WorkerDict(worker_cls):\n        def __init__(self):\n            super().__init__()\n            self.worker_dict = {}\n            for key, user_defined_cls in cls_dict.items():\n                user_defined_cls = _unwrap_ray_remote(user_defined_cls)\n                # directly instantiate the class without remote\n                # in worker class, e.g. 
<verl.single_controller.base.worker.Worker>\n                # when DISABLE_WORKER_INIT == 1 it will return immediately\n                with temp_env_var(\"DISABLE_WORKER_INIT\", \"1\"):\n                    self.worker_dict[key] = user_defined_cls(\n                        *init_args_dict[key].get(\"args\", ()), **init_args_dict[key].get(\"kwargs\", {})\n                    )\n\n    # now monkey-patch the methods from inner class to WorkerDict\n    for key, user_defined_cls in cls_dict.items():\n        user_defined_cls = _unwrap_ray_remote(user_defined_cls)\n        _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)\n\n    remote_cls = ray.remote(WorkerDict)\n    remote_cls = RayClassWithInitArgs(cls=remote_cls)\n    return remote_cls\n\n\nFusedWorkerCLSName = \"FusedWorker\"\n\n\ndef create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function returns a FusedWorker class.\n\n    `FusedWorker.{class_name}` -> FusedClass\n        Use `class_name` as a param to directly access the underlying class.\n\n    `FusedWorker._fuw_execute(\"{class_name}_fwmn_{method_name}\", *args, **kwargs)`\n        First param must be \"{class_name}_fwmn_{method_name}\" in order to access `method_name`\n        of underlying class `{class_name}`.\n\n    `FusedWorker.fused_worker_dict` -> {\"class_name\": FusedClass}\n        Stores all underlying classes.\n\n    `FusedClass.fused_worker_dict` -> {\"class_name\": FusedClass}\n        The same as `FusedWorker.fused_worker_dict`, enabling each underlying class to access the\n        other underlying classes.\n    \"\"\"\n    raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()}\n    init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()}\n    init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()}\n    cls_names = list(class_dict.keys())\n\n    # FusedWorker_Actor_Critic\n    class_name_renamed = \"_\".join([FusedWorkerCLSName] + cls_names)\n\n    class FusedWorker(Worker):\n        def __init__(self, *args, **kwargs):\n            super().__init__(*args, **kwargs)\n            self.cls_names = cls_names\n            self.raw_cls_dict = raw_cls_dict\n            self.init_args_dict = init_args_dict\n            self.init_kwargs_dict = init_kwargs_dict\n\n            for cls_name, udc, ud_args, ud_kwargs in zip(\n                self.cls_names,\n                self.raw_cls_dict.values(),\n                self.init_args_dict.values(),\n                self.init_kwargs_dict.values(),\n                strict=True,\n            ):\n                with temp_env_var(\"DISABLE_WORKER_INIT\", \"1\"):\n                    udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed\n                    udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f\"{name_prefixed}_\"\n                    # cls_name = \"actor\", \"critic\", udc = ActorWorker, CriticWorker\n                    self.fused_worker_dict[cls_name] = udc(*ud_args, **ud_kwargs)\n                    setattr(self, cls_name, self.fused_worker_dict[cls_name])\n\n            # inject fused_worker_dict into each sub-worker so they are aware of each other's existence\n            for _, worker in self.fused_worker_dict.items():\n                setattr(worker, Worker.fused_worker_attr_name, self.fused_worker_dict)\n\n        def _fuw_execute(self, method_name: str, *args, **kwargs):\n            # for fused_worker, method_name is in a 
form of \"{cls_name}_fwmn_{method_name}\"\n            # where fwmn stands for \"fused worker method name\"\n            names = method_name.split(\"_fwmn_\")\n            cls_name = names[0]\n            method_name = names[1]\n\n            assert cls_name in self.fused_worker_dict, (\n                f\"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict\"\n            )\n            udc_method = getattr(self.fused_worker_dict[cls_name], method_name)\n            return udc_method(*args, **kwargs)\n\n    renamed_fused_worker_cls = type(class_name_renamed, (FusedWorker,), {})\n    renamed_fused_worker_cls.is_fused_worker = True\n    renamed_fused_worker_cls.raw_cls_dict = raw_cls_dict\n\n    return renamed_fused_worker_cls\n\n\ndef create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs]):\n    \"\"\"\n    This function returns a RayClassWithInitArgs instance of FusedWorker, which is a replacement\n    for `create_colocated_worker_cls`. A WorkerGroup constructed using this class will be a colocated\n    WorkerGroup, which will be referenced as `ColocateWorkerGroup` below.\n\n    `ColocateWorkerGroup.spawn(prefix_set)`\n        returns a dict of WorkerGroup {\"class_name\": WorkerGroup}, WorkerGroup in this dict will\n        have methods of underlying class `class_name` attached.\n\n    `ColocateWorkerGroup.fuse(prefix_set)`\n        After executing this function, `ColocateWorkerGroup.{class_name}` will return WorkerGroup\n        with methods of underlying class `class_name` attached.\n    \"\"\"\n    raw_colocated_worker_cls = create_colocated_worker_raw_cls(class_dict)\n\n    remote_cls = ray.remote(raw_colocated_worker_cls)\n    cia = RayClassWithInitArgs(cls=remote_cls)\n    cia.fused_worker_used = True\n\n    return cia\n"
  },
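For readers tracing `execute_all_async` in the file above: arguments are sharded element-wise across workers only when *every* positional and keyword argument is a list whose length equals the worker count; otherwise the identical arguments are broadcast to all workers. A minimal, Ray-free sketch of that dispatch rule, where `workers` and `run` are hypothetical stand-ins for `self._workers` and `self._execute_remote_single_worker`:

```python
# Sketch of the all-or-nothing sharding rule in RayWorkerGroup.execute_all_async.
def execute_all(workers, run, method_name, *args, **kwargs):
    n = len(workers)
    # Shard only if every arg and kwarg is a list of length len(workers).
    shardable = all(isinstance(a, list) and len(a) == n for a in args) and all(
        isinstance(v, list) and len(v) == n for v in kwargs.values()
    )
    if shardable:
        # Worker i receives the i-th element of every list.
        return [
            run(w, method_name, *(a[i] for a in args), **{k: v[i] for k, v in kwargs.items()})
            for i, w in enumerate(workers)
        ]
    # Otherwise every worker receives the same arguments.
    return [run(w, method_name, *args, **kwargs) for w in workers]


run = lambda w, m, *a, **kw: (w, m, a, kw)
print(execute_all(["w0", "w1"], run, "fit", [[1, 2], [3, 4]]))                # sharded
print(execute_all(["w0", "w1"], run, "fit", [[1, 2], [3, 4]], verbose=True))  # broadcast
```

Note the rule is all-or-nothing: the single non-list `verbose=True` kwarg above disables sharding entirely, so each worker receives the full nested list. Similarly, `FusedWorker._fuw_execute` routes a name such as `"actor_fwmn_compute_log_prob"` to `fused_worker_dict["actor"].compute_log_prob` by splitting on the literal `"_fwmn_"` marker.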
  {
    "path": "verl_rl/verl/single_controller/ray/megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport ray\n\nfrom verl.single_controller.base.megatron.worker import DistGlobalInfo, DistRankInfo\nfrom verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup\n\nfrom .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n\n\n# NOTE(sgm): for open-source megatron-core\nclass NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):\n    \"\"\"\n    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup\n    so that the dispatcher can use it to dispatch data.\n    \"\"\"\n\n    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):\n        \"\"\"\n        Initialize the NVMegatronRayWorkerGroup.\n\n        Args:\n            resource_pool (RayResourcePool): The resource pool containing worker resources\n            ray_cls_with_init (RayClassWithInitArgs): The Ray class with initialization arguments\n            **kwargs: Additional keyword arguments to pass to the parent class\n        \"\"\"\n        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)\n        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name=\"get_megatron_rank_info\")\n        self._megatron_global_info: DistGlobalInfo = ray.get(\n            self.execute_rank_zero_async(method_name=\"get_megatron_global_info\")\n        )\n\n\nclass MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):\n    \"\"\"\n    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup\n    so that the dispatcher can use it to dispatch data.\n    \"\"\"\n\n    def __init__(\n        self,\n        resource_pool: RayResourcePool,\n        ray_cls_with_init: RayClassWithInitArgs,\n        default_megatron_kwargs: dict = None,\n        **kwargs,\n    ):\n        super().__init__(\n            resource_pool=resource_pool,\n            ray_cls_with_init=ray_cls_with_init,\n            default_megatron_kwargs=default_megatron_kwargs,\n            **kwargs,\n        )\n        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)\n        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name=\"get_megatron_rank_info\")\n        self._megatron_global_info: DistGlobalInfo = ray.get(\n            self.execute_rank_zero_async(method_name=\"get_megatron_global_info\")\n        )\n\n    def init_megatron(self, default_megatron_kwargs: Optional[dict] = None):\n        # after super, we will call init of each worker\n        if not self._is_init_with_detached_workers:\n            # only init_megatron if the WorkerGroup is created from scratch\n            self.execute_all_sync(method_name=\"init_megatron\", default_megatron_kwargs=default_megatron_kwargs)\n"
  },
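Both worker groups in `megatron.py` cache per-worker `DistRankInfo` precisely so the dispatcher can route data by Megatron rank. A hedged sketch of that idea, assuming (hypothetically, for illustration) that each rank-info entry carries integer `tp_rank` / `dp_rank` / `pp_rank` fields:

```python
from dataclasses import dataclass


@dataclass
class RankInfo:  # hypothetical stand-in for verl's DistRankInfo
    tp_rank: int
    dp_rank: int
    pp_rank: int


def split_by_dp(rank_infos: list, batches: list) -> list:
    """Give shard `dp_rank` to every worker in that data-parallel group:
    TP/PP replicas sharing a dp_rank all receive the same shard."""
    assert len(batches) == 1 + max(r.dp_rank for r in rank_infos)
    return [batches[r.dp_rank] for r in rank_infos]


# world_size=4 with tp=2, dp=2, pp=1: workers 0/1 share dp_rank 0, workers 2/3 share dp_rank 1
infos = [RankInfo(0, 0, 0), RankInfo(1, 0, 0), RankInfo(0, 1, 0), RankInfo(1, 1, 0)]
print(split_by_dp(infos, ["batch_A", "batch_B"]))
# -> ['batch_A', 'batch_A', 'batch_B', 'batch_B']
```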
  {
    "path": "verl_rl/verl/third_party/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/third_party/sglang/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/third_party/sglang/parallel_state.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The SGlang team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"Model and data parallel groups.\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport sglang.srt.distributed.parallel_state as ps\nimport torch\nimport torch.distributed\nfrom sglang.srt.distributed.parallel_state import (\n    get_pp_group,\n    get_world_group,\n    init_distributed_environment,\n    init_model_parallel_group,\n)\n\n\"\"\"\nThis version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.\n- We assume the Megatron tp+dp+pp world is already established before calling this function.\n\n\"\"\"\n\n# Device mesh for using DTensor\n_DEVICE_MESH = None\n\n# Tensor model parallel group that the current rank belongs to.\n_TP = None\n# Pipeline model parallel group that the current rank belongs to.\n_PP = None\n\n\n# This method is for initializing the ParallelGroup when using HybridEngine\n# NOTE(linjunrong): this function is for megatron\ndef initialize_parallel_state(\n    distributed_init_method: str = \"env://\",\n    backend: str = \"nccl\",\n    tensor_model_parallel_size: int = 1,\n    num_tp_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n):\n    # torch.distributed.all_reduce does not free the input tensor until\n    # the synchronization point. This causes the memory usage to grow\n    # as the number of all_reduce calls increases. This env var disables\n    # this behavior.\n    # Related issue:\n    # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573\n    os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n\n    # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.\n    rank = int(os.getenv(\"RANK\", \"-1\"))\n    local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n    # Use the world_size set by TORCHRUN\n    world_size = int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n    assert world_size != -1, \"The world_size is set to -1, not initialized by TORCHRUN\"\n    init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)\n    if torch.distributed.get_world_size() > 1:\n        # NOTE: build a separate inference group with infer tp & micro dp\n        initialize_model_parallel_for_sglang(\n            tensor_model_parallel_size=tensor_model_parallel_size,\n            num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp,\n        )\n    else:\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n\n\n# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call\n# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the conterparts\n# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None.\n# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for\n# verl itself as how it was done in verl.third_party.vllm.parallel_state. 
Note that the process is a little\n# bit different.\ndef ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    # get the backend of _DEVICE_WORLD_GROUP\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        f\"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        f\"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. \"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\n# TODO(sgm): deviates from v0.5.4, no pp for now\n# NOTE(linjunrong): the SGLang version uses _TP instead of ps._TP\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None\n    # and _PIPELINE_MODEL_PARALLEL_GROUP is not None)\n\n\ndef initialize_model_parallel_for_sglang(\n    tensor_model_parallel_size: int,\n    num_tensor_model_parallel_groups_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n) -> None:\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n\n    assert isinstance(tensor_model_parallel_size, int)\n\n    # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group\n    # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group\n\n    # Build the tensor model-parallel groups.\n    assert ps._TP is None, \"tensor model parallel group is already initialized\"\n\n    global _TP\n\n    world_size: int = torch.distributed.get_world_size()\n\n    backend = torch.distributed.get_backend()\n\n    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n\n    if num_tensor_model_parallel_groups_per_train_tp == 1:\n        # if tensor_model_parallel_size == train_tensor_parallel_size:\n        # using the same tp group as Megatron/vllm\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups):\n            ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n            group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n        # _MICRO_DATA_PARALLEL_GROUP is moved to the hybrid engine\n    else:\n        # initialize a micro_dp group and a tp group\n        # assume training tp=4, infer tp=2, then, weight is partitioned as\n        # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference\n\n        # Build the inference tp groups\n        # train_tp = train_tensor_parallel_size\n        train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size\n        # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):\n            start = train_tp * i\n            end = train_tp * (i + 1)\n            for j in range(num_tensor_model_parallel_groups_per_train_tp):\n                ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))\n                for k in range(len(ranks)):  # avoid shadowing the outer loop variable i\n                    ranks[k] += j\n                group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # Build the pipeline model-parallel groups.\n    # global _PIPELINE_MODEL_PARALLEL_GROUP\n    # global _PIPELINE_GLOBAL_RANKS\n    # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, (\"pipeline model parallel group is already initialized\")\n\n    # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()\n    # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()\n\n    # TODO: init using device mesh (hybrid engine not supported now)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n    ps._PP = _PP  # for verl\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    NOTE: This method is a hack of the open-sourced version that drops the\n    assertion world_size == tp * pp\n\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. Ensure some consistencies.
\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)\n\n    # NOTE(sgm) we don't assert world_size == tp * pp\n    # DP is not managed by vllm but by the VeRL WorkerGroup\n    # if (world_size !=\n    #         tensor_model_parallel_size * pipeline_model_parallel_size):\n    #     raise RuntimeError(\n    #         f\"world_size ({world_size}) is not equal to \"\n    #         f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n    #         f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\")\n\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_tensor_model_parallel_groups):\n        ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))\n        group_ranks.append(ranks)\n\n    # message queue broadcaster is only used in tensor model parallel group\n    if ps._TP is not None:\n        _TP = ps._TP\n    else:\n        _TP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            use_custom_allreduce=False,  # TODO: check why True does not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # TODO: init using device mesh (hybrid engine not supported now)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    if ps._PP is not None:\n        _PP = ps._PP\n    else:\n        _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n        ps._PP = _PP\n\n\n\"\"\"\nDevice mesh utilities\n\"\"\"\n\n\ndef get_device_mesh():\n    assert _DEVICE_MESH is not None, \"device mesh is not initialized\"\n    return _DEVICE_MESH\n\n\n\"\"\"\nTensor model parallel utilities\n\"\"\"\n\n\n# NOTE(linjunrong): In the vllm version of parallel_state.py, verl created its own _TP and _PP because verl wants to use\n# the process groups for some extra purposes. Under the hood, there is no difference between them and the original\n# ones in vllm.distributed.parallel_state. 
However, the implementation needs to hack the init process of the inference\n# engine; as we do not maintain another SGLang here, I just use the original _TP and _PP directly.\ndef get_tensor_model_parallel_group():\n    \"\"\"Get the tensor model parallel group the caller rank belongs to.\"\"\"\n\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP.device_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_src_rank():\n    \"\"\"Calculate the global rank corresponding to the first local rank\n    in the tensor model parallel group.\"\"\"\n    global_rank = torch.distributed.get_rank()\n    local_world_size = get_tensor_model_parallel_world_size()\n    return (global_rank // local_world_size) * local_world_size\n"
  },
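The docstring of `initialize_model_parallel` above walks through an 8-GPU example (tp=2, pp=4). The group enumeration it describes can be checked standalone, without `torch.distributed`; this sketch mirrors the two `group_ranks` loops in the file:

```python
def build_groups(world_size: int, tp: int, pp: int):
    # TP groups: consecutive blocks of `tp` ranks.
    tp_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]
    # PP groups: strided by the number of PP groups, as in the file above.
    num_pp_groups = world_size // pp
    pp_groups = [list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)]
    return tp_groups, pp_groups


tp_groups, pp_groups = build_groups(world_size=8, tp=2, pp=4)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

The printed groups match the docstring's example exactly: four tensor model-parallel groups of adjacent ranks and two strided pipeline model-parallel groups.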
  {
    "path": "verl_rl/verl/third_party/torch/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "verl_rl/verl/third_party/torch/distributed/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
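The file that follows, `_state_dict_utils.py`, is the vendored torch 2.7.0 copy referenced in the header above; its `_all_gather_sharded_tensor` pads every rank's flattened shard to a common chunk size, gathers once with `all_gather_into_tensor`, then narrows the result back to the true element count. A standalone check of that chunk arithmetic (the `gather_shape` helper is illustrative, not part of the file):

```python
import math

import torch


def gather_shape(full_shape: torch.Size, world_size: int):
    # Mirrors the chunk math in _all_gather_sharded_tensor: each rank
    # contributes ceil(dim_0 / world_size) rows' worth of elements.
    dim_0 = full_shape[0]
    numel = full_shape.numel()
    chunk_size = math.ceil(dim_0 / world_size) * numel // dim_0
    gathered = chunk_size * world_size
    assert gathered >= numel  # the padding is dropped by narrow() afterwards
    return chunk_size, gathered, numel


# A [10, 4] tensor sharded on dim 0 across 4 ranks: ceil(10/4) = 3 rows per rank.
print(gather_shape(torch.Size([10, 4]), world_size=4))  # (12, 48, 40)
```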
  {
    "path": "verl_rl/verl/third_party/torch/distributed/_state_dict_utils.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n\n\n# ruff: noqa: B028, UP038, UP007, E721, E501\n# mypy: allow-untyped-defs\nimport copy\nimport io\nimport math\nimport weakref\nfrom collections.abc import Mapping, MutableMapping\nfrom typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union, cast\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.distributed._functional_collectives import AsyncCollectiveTensor\n\nif dist.is_available() or TYPE_CHECKING:\n    from torch.distributed import distributed_c10d\n    from torch.distributed._shard.sharded_tensor import ShardedTensor\n    from torch.distributed.tensor import DTensor, Replicate, distribute_tensor\n    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset\n\n\ndef _identity_func(\n    obj: torch.Tensor,\n    pg: Optional[dist.ProcessGroup],\n    device: Optional[torch.device],\n    companion_obj: Any,\n) -> torch.Tensor:\n    return obj\n\n\ndef _all_gather_sharded_tensor(\n    sharded_tensor: \"ShardedTensor\",\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    if pg is None:\n        pg = distributed_c10d._get_default_group()\n    world_size = dist.get_world_size(pg)\n    shards = sharded_tensor.local_shards()\n    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]\n    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]\n    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size\n    pg_device = distributed_c10d._get_pg_default_device(pg) if device is None else device\n    if shards:\n        local_tensor = shards[0].tensor.flatten()\n        if local_tensor.device.type != pg_device.type:\n            local_tensor = local_tensor.to(pg_device)\n        num_padding = chunk_size - local_tensor.numel()\n        if num_padding > 0:\n            local_tensor = F.pad(local_tensor, [0, num_padding])\n    else:\n        local_tensor = torch.zeros(chunk_size, dtype=sharded_tensor.dtype, device=pg_device)\n\n    tensor = torch.empty(\n        chunk_size * world_size,\n        dtype=local_tensor.dtype,\n        device=pg_device,\n    )\n    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)\n\n    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())\n    return tensor\n\n\nclass CompanionMismatch(Exception):\n    pass\n\n\ndef _iterate_state_dict(\n    iter_object: Any,\n    sharded_tensor_func: Callable,\n    dtensor_func: Callable,\n    tensor_func: Callable,\n    *,\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n    cpu_offload: bool = False,\n    companion_obj: Any = None,\n    ranks_only: tuple[int, ...] 
= (),\n    type_check: bool = True,\n    non_blocking: bool = True,\n) -> dict[str, Any]:\n    \"\"\"Iterate through the state dict, applying the given functions to each tensor type.\n\n    Args:\n        iter_object (Any): the target state_dict.\n        sharded_tensor_func (Callable): the function to apply to ShardedTensor\n        dtensor_func (Callable): the function to apply to DTensor\n        tensor_func (Callable): the function to apply to Tensor\n        pg (Optional[dist.ProcessGroup]): process group passed to tensor functions\n        device (Optional[torch.device]): device passed to tensor functions\n        cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored\n            if a companion_obj is supplied.\n        companion_obj (Any): A companion object to the state dict. If this object\n            is supplied, we attempt to copy the tensor to the companion object.\n        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n        non_blocking (bool): whether to use non-blocking copy when copying to the companion object.\n    \"\"\"\n    # TODO: should we use pytree?\n    cpu_device = torch.device(\"cpu\")\n    if isinstance(iter_object, ShardedTensor):\n        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, DTensor):\n        ret = dtensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, torch.Tensor):\n        ret = tensor_func(iter_object, pg, device, companion_obj)\n    elif isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) or iter_object is None:\n        ret = iter_object\n    elif isinstance(iter_object, dict):\n        if companion_obj is not None and (\n            not isinstance(companion_obj, dict) or set(companion_obj.keys()) != set(iter_object.keys())\n        ):\n            msg = \"\" if isinstance(companion_obj, dict) else f\"{set(companion_obj.keys())=} {set(iter_object.keys())=}\"\n            raise CompanionMismatch(msg)\n\n        ret = {\n            key: _iterate_state_dict(\n                value,\n                sharded_tensor_func,\n                dtensor_func,\n                tensor_func,\n                pg=pg,\n                device=device,\n                cpu_offload=cpu_offload,\n                companion_obj=companion_obj[key] if companion_obj is not None else None,\n                ranks_only=ranks_only,\n                type_check=type_check,\n                non_blocking=non_blocking,\n            )\n            for key, value in iter_object.items()\n        }\n    elif isinstance(iter_object, (list, tuple)):\n        if companion_obj is not None and (\n            not isinstance(companion_obj, (list, tuple)) or len(companion_obj) != len(iter_object)\n        ):\n            raise CompanionMismatch\n\n        ret = [\n            _iterate_state_dict(\n                v,\n                sharded_tensor_func,\n                dtensor_func,\n                tensor_func,\n                pg=pg,\n                device=device,\n                cpu_offload=cpu_offload,\n                
companion_obj=companion_obj[idx] if companion_obj is not None else None,\n                ranks_only=ranks_only,\n                type_check=type_check,\n                non_blocking=non_blocking,\n            )\n            for idx, v in enumerate(iter_object)\n        ]\n        if isinstance(iter_object, tuple):\n            ret = tuple(ret)\n    elif not type_check:\n        ret = copy.deepcopy(iter_object)\n    else:\n        raise ValueError(f\"Unexpected value type {type(iter_object)}\")\n\n    if not ranks_only or dist.get_rank(pg) in ranks_only:\n        if isinstance(ret, torch.Tensor):\n            if cpu_offload and companion_obj is None:\n                ret = ret.to(cpu_device)\n\n            if companion_obj is not None:\n                if isinstance(companion_obj, DTensor):\n                    assert isinstance(ret, DTensor)\n                    companion_obj._local_tensor.copy_(ret._local_tensor, non_blocking=non_blocking)\n                else:\n                    companion_obj.copy_(ret, non_blocking=non_blocking)\n                ret = companion_obj\n    else:\n        ret = {} if isinstance(ret, dict) else None\n\n    return ret\n\n\ndef _gather_state_dict(\n    state_dict: dict[str, Any],\n    *,\n    pg: Optional[dist.ProcessGroup] = None,\n    device: Optional[torch.device] = None,\n    cpu_offload: bool = False,\n    ranks_only: tuple[int, ...] = (),\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, this API gathers all the ShardedTensors or DTensors in\n    the state_dict.\n\n\n    Args:\n        state_dict (Dict[str, Any]): the target sharded state_dict.\n        pg (Optional[dist.ProcessGroup]): the process group that is used to\n            gather ShardedTensor. Note that gathering a DTensor will use\n            the DeviceMesh. So this argument will be ignored when gathering a\n            DTensor.\n        device: (Optional[torch.device]): the device that is used to\n            perform allgather for ShardedTensor. Note that gathering a DTensor\n            will use the DeviceMesh. So this argument will be ignored when\n            gathering a DTensor.\n        cpu_offload (bool): whether to offload the tensors to CPU memory. The\n            default value is False.\n        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check: (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  
The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        The gathered state dictionary.\n    \"\"\"\n\n    def sharded_tensor_func(value, pg, device, companion_obj):\n        # ShardedTensor does not seem to record the original device type.\n        # So if the tensor is moved to CPU, we won't know the original type.\n        # As a result, we have to rely on the user to tell us the correct one.\n        cpu_device = torch.device(\"cpu\")\n        output_tensor = _all_gather_sharded_tensor(value, pg, device)\n        local_shard_device = value.local_shards()[0].tensor.device if value.local_shards() else cpu_device\n        if output_tensor.device != local_shard_device:\n            value = output_tensor.to(local_shard_device)\n        else:\n            value = output_tensor\n        return value\n\n    def dtensor_func(value, pg, device, companion_obj):\n        if value.device != value.device_mesh.device_type:\n            value = value.to(value.device_mesh.device_type)\n        # FSDP all_gather: [Shard(0)] -> [Replicate()]\n        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]\n        # 2D FSDP + TP all_gather:\n        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]\n        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]\n        placements = [Replicate() for _ in value.placements]\n        value = value.redistribute(\n            device_mesh=value.device_mesh,\n            placements=placements,\n        )\n        # Call `wait()` to force the tensor to be synchronous with respect\n        # to the main stream.\n        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.\n        value = value.to_local()\n        if isinstance(value, AsyncCollectiveTensor):\n            value = value.wait()\n        return value\n\n    return _iterate_state_dict(\n        state_dict,\n        sharded_tensor_func,\n        dtensor_func,\n        _identity_func,\n        pg=pg,\n        device=device,\n        cpu_offload=cpu_offload,\n        ranks_only=ranks_only,\n        type_check=type_check,\n    )\n\n\ndef _offload_state_dict_to_cpu(\n    state_dict: dict[str, Any],\n    *,\n    ranks_only: tuple[int, ...] = (),\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, this API offload all the tensors to CPU memory.\n\n    Args:\n        state_dict (Dict[str, Any]): the target state_dict.\n        pg (Optional[dist.ProcessGroup]): the process group that is used to\n            gather ShardedTensor. Note that gathering a DTensor will use\n            the DeviceMesh. So this argument will be ignored when gathering a\n            DTensor.\n        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will\n            have the same state_dicts. Otherwise only ranks that in ``ranks_only``\n            have the same state_dicts. Other ranks will get empty state_dicts.\n        type_check: (bool): check if the instance data type is a supported type\n            that can be saved by DCP.  
The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        The gathered state dictionary.\n    \"\"\"\n\n    ret = _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        _identity_func,\n        _identity_func,\n        pg=None,\n        device=None,\n        cpu_offload=True,\n        ranks_only=ranks_only,\n        type_check=type_check,\n    )\n    return ret\n\n\n@torch.no_grad()\ndef _copy_state_dict(\n    state_dict: dict[str, Any],\n    copy_state_dict: dict[str, Any],\n    non_blocking: bool = False,\n    type_check: bool = True,\n) -> dict[str, Any]:\n    \"\"\"\n    Copies all tensors in a given state dict into a different state_dict with the\n    same structure. Additionally, a copied state dict with the same value references\n    is returned. Editing the keys on this state dict will not affect the\n    passed in copy_state_dict (but the value references are the same).\n\n    .. warning::\n        It is expected by this function that state_dict and copy_state_dict share\n        the same structure and data types.\n\n    .. warning::\n        The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Args:\n        state_dict (Dict[str, Any]): the target state_dict.\n        copy_state_dict (Dict[str, Any]):\n            The state dict we are copying into. This state_dict must have exactly\n             the same structure as the source `state_dict`.\n        non_blocking: (bool): Whether copy ops should be performed asynchronously\n        type_check (bool): check if the instance data type is a supported type\n            that can be saved by DCP. The current supported data types are\n            torch.Tensor, DTensor, int, float, str, list, dict, None.\n\n    Returns:\n        State Dict copy\n    \"\"\"\n\n    return _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        _identity_func,\n        _identity_func,\n        pg=None,\n        device=None,\n        cpu_offload=False,\n        ranks_only=(),\n        companion_obj=copy_state_dict,\n        type_check=type_check,\n        non_blocking=non_blocking,\n    )\n\n\n@torch.no_grad()\ndef _create_cpu_state_dict(\n    state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False\n) -> dict[str, Any]:\n    \"\"\"\n    Given a state_dict, create another state_dict with the same structure and elements.\n    However, all tensors in the returned state_dict are new tensors on CPU. These\n    tensors can be placed on pin_memory or share_memory based on the provided arguments.\n\n    .. warning::\n        Setting both `pin_memory` and `share_memory` to True significantly increases the\n        latency of this method because of the nuances which require us to register memory\n        as pinned directly as opposed to relying on the pin_memory cache allocator. 
This\n        option should only be used for long lived tensors which are required to be shared.\n        This is not the case as long as at least one of `pin_memory` or `share_memory` is\n         set to False.\n\n    \"\"\"\n\n    def tensor_func(\n        obj: torch.Tensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        _: Any,\n    ) -> torch.Tensor:\n        if len(obj.size()) == 0:\n            return torch.tensor(0, dtype=obj.dtype)\n\n        if share_memory:\n            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)\n            t = t.share_memory_()\n            if pin_memory:\n\n                def unpin_memory(t):\n                    succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))\n                    assert succ == 0, f\"Unpinning shared memory failed with error-code: {succ}\"\n\n                weakref.finalize(t, unpin_memory, t)\n                succ = int(\n                    torch.cuda.cudart().cudaHostRegister(\n                        t.data_ptr(),\n                        t.numel() * t.element_size(),\n                        1,  # lines up with 'cudaHostRegisterPortable'\n                    )\n                )\n                assert succ == 0, f\"Pinning shared memory failed with error-code: {succ}\"\n            return t\n        elif pin_memory:\n            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()\n        else:\n            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)\n\n    def dtensor_func(\n        obj: DTensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        _: Any,\n    ) -> DTensor:\n        if len(obj.size()) == 0:\n            return obj\n\n        if obj.device != torch.device(\"cpu\"):\n            ret = cast(DTensor, obj.to(device=\"cpu\"))\n        else:\n            ret = copy.deepcopy(obj)\n        ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None)\n        return ret\n\n    ret = _iterate_state_dict(\n        state_dict,\n        _identity_func,\n        dtensor_func,\n        tensor_func,\n        pg=None,\n        device=None,\n        cpu_offload=False,\n        ranks_only=(),\n        type_check=False,\n    )\n    return ret\n\n\ndef _check_state_dict_similarity(\n    state_dict: dict[str, Any],\n    compared_state_dict: dict[str, Any],\n) -> bool:\n    \"\"\"\n    Given two state_dicts, check if the structures are the same. 
And\n    if a [key, tensor] pair exist in one state_dict there must be\n    the a corresponding pait, [key, other_tensor], in the other state_dict,\n    where tensor and other_tensor have the same size and dtype.\n\n    Return the check result.\n    \"\"\"\n\n    def tensor_func(\n        obj: torch.Tensor,\n        pg: Optional[dist.ProcessGroup],\n        device: Optional[torch.device],\n        companion_obj: Any,\n    ) -> torch.Tensor:\n        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():\n            raise CompanionMismatch\n        return obj\n\n    try:\n        _iterate_state_dict(\n            state_dict,\n            _identity_func,\n            _identity_func,\n            tensor_func,\n            pg=None,\n            device=None,\n            cpu_offload=False,\n            ranks_only=(),\n            companion_obj=compared_state_dict,\n            type_check=False,\n        )\n    except CompanionMismatch:\n        return False\n\n    return True\n\n\nclass _TensorInfo(NamedTuple):\n    size: torch.Size\n    dtype: torch.dtype\n\n\ndef _broadcast_tensors(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    keys: list[str],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    tensors = []\n    for key in keys:\n        if dist.get_rank() == 0:\n            full_state = full_state_dict[key]\n            assert isinstance(full_state, torch.Tensor)\n            full_tensor = full_state.detach().to(device)\n        else:\n            tensor_info = full_state_dict[key]\n            full_tensor = torch.empty(\n                size=tensor_info.size,\n                device=device,\n                dtype=tensor_info.dtype,\n            )\n        tensors.append(full_tensor)\n        local_state = local_state_dict.get(key, None)\n        if local_state is None:\n            continue\n        elif isinstance(local_state, DTensor):\n            local_state_dict[key] = (local_state, full_tensor)\n        else:\n            local_state_dict[key] = full_tensor\n\n    if pg is None:\n        pg = dist.distributed_c10d._get_default_group()\n\n    if len(tensors) > 1:\n        dist._broadcast_coalesced(pg, tensors, 500, 0)\n    else:\n        dist.broadcast(tensors[0], src=0, group=pg)\n\n    _distribute_tensors(local_state_dict, keys, device, pg)\n\n\ndef _distribute_tensors(\n    local_state_dict: dict[str, Any],\n    keys: list[str],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    if pg is None:\n        pg = dist.distributed_c10d._get_default_group()\n    for key in keys:\n        _local_state = local_state_dict.get(key, None)\n        if _local_state is None or torch.is_tensor(_local_state):\n            continue\n\n        local_state = _local_state[0]\n        full_tensor = _local_state[1]\n\n        shape, offset = compute_local_shape_and_global_offset(\n            full_tensor.shape, local_state.device_mesh, local_state.placements\n        )\n        slices = [\n            slice(cur_offset, cur_offset + cur_shape) for cur_shape, cur_offset in zip(shape, offset, strict=False)\n        ]\n        if local_state.is_meta:\n            # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost.\n            local_tensor = full_tensor[slices].detach().clone()\n            # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. 
For example,\n            # one of the cases that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)).\n            ret = DTensor.from_local(\n                local_tensor,\n                local_state.device_mesh,\n                local_state.placements,\n                shape=local_state.shape,\n                stride=local_state.stride(),\n            )\n        else:\n            ret = local_state\n            # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint.\n            ret.to_local().copy_(full_tensor[slices])\n        local_state_dict[key] = ret\n\n\ndef _broadcast_state_dict(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n    strict: bool = False,\n    cpu_offload: bool = False,\n) -> None:\n    # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`.\n    # If strict is True, any keys in `local_state_dict` but not in `full_state_dict`\n    # will be removed from `local_state_dict`.\n    ret = {}\n    if dist.get_rank() == 0:\n        for key, value in full_state_dict.items():\n            if not torch.is_tensor(value):\n                ret[key] = value\n            elif value.dim() == 0:\n                ret[key] = value.cpu()\n            else:\n                ret[key] = _TensorInfo(value.size(), value.dtype)\n\n    broadcast_list = [ret]\n    dist.broadcast_object_list(broadcast_list, src=0, group=pg)\n    ret = broadcast_list[0]\n    # Gather values\n    keys = []\n    local_state_dict_keys = set(local_state_dict.keys())\n    global_keys = set()\n    for key, value in ret.items():\n        global_keys.add(key)\n        if not isinstance(value, _TensorInfo):\n            if key in local_state_dict:\n                local_state_dict[key] = value\n            continue\n\n        if dist.get_rank() == 0:\n            ret[key] = full_state_dict[key]\n\n        keys.append(key)\n        # Broadcast every tensor to avoid OOM for now.\n        if len(keys) >= 1:\n            _broadcast_tensors(ret, local_state_dict, keys, device, pg)\n            if cpu_offload:\n                for key in keys:\n                    local_state_dict[key] = local_state_dict[key].cpu()\n            keys.clear()\n\n    if strict:\n        if missing_keys := (local_state_dict_keys - global_keys):\n            for key in missing_keys:\n                local_state_dict.pop(key)\n\n    if keys:\n        _broadcast_tensors(ret, local_state_dict, keys, device, pg)\n        if cpu_offload:\n            for key in keys:\n                local_state_dict[key] = local_state_dict[key].cpu()\n\n\ndef _distribute_state_dict(\n    full_state_dict: dict[str, Any],\n    local_state_dict: dict[str, Any],\n    device: torch.device,\n    pg: Optional[dist.ProcessGroup] = None,\n) -> None:\n    # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has\n    # full_state_dict. 
Skip the broadcast in ``_broadcast_state_dict`` and\n    # distribute tensors in each rank\n    for key, value in full_state_dict.items():\n        if not torch.is_tensor(value):\n            local_state_dict[key] = value\n        elif value.dim() == 0:\n            local_state_dict[key] = value.cpu()\n        else:\n            assert isinstance(value, torch.Tensor)\n            local_state = local_state_dict.get(key, None)\n            if local_state is None:\n                continue\n            elif isinstance(local_state, DTensor):\n                local_state_dict[key] = distribute_tensor(\n                    value.detach().to(device),\n                    local_state.device_mesh,\n                    local_state.placements,\n                )\n            else:\n                local_state_dict[key] = value.detach().to(device)\n\n\n# These APIs are from torch.distributed.checkpoint.\n# TODO: We should consolidate the code here, as not all modules can depend on\n# DCP.\nPATH_ITEM = Union[str, int]\nOBJ_PATH = tuple[PATH_ITEM, ...]\nFLATTEN_MAPPING = dict[str, OBJ_PATH]\nSTATE_DICT_TYPE = dict[str, Any]\nCONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]\n\n\ndef _traverse_state_dict(\n    state_dict: STATE_DICT_TYPE,\n    visitor: Callable[[OBJ_PATH, Any], None],\n) -> None:\n    \"\"\"\n    Invoke ``visitor`` for each value recursively in ``state_dict``.\n    Mapping, list, and tuple will be flattened and other value types are treated\n    as the terminal values and will invoke ``visitor``.\n    \"\"\"\n\n    def _traverse_obj(path: OBJ_PATH, value: Any) -> None:\n        if isinstance(value, Mapping):\n            for k, v in value.items():\n                _traverse_obj(path + (str(k),), v)\n        elif isinstance(value, (list, tuple)):\n            for i, v in enumerate(value):\n                _traverse_obj(path + (i,), v)\n        else:\n            visitor(path, value)\n\n    for key, value in state_dict.items():\n        _traverse_obj((str(key),), value)\n\n\ndef _flatten_state_dict(\n    state_dict: STATE_DICT_TYPE,\n) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:\n    \"\"\"\n    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.\n\n    Use ``_unflatten_state_dict`` to revert this process.\n    Returns:\n        A tuple with the flattened state_dict and a mapping from original to new state_dict.\n    N.B. 
The new keys are derived from the object paths, joined by dot.\n        For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.\n    \"\"\"\n    flattened: STATE_DICT_TYPE = {}\n    mappings: FLATTEN_MAPPING = {}\n\n    def flat_copy(path: OBJ_PATH, value: Any) -> None:\n        new_fqn = \".\".join(map(str, path))\n        if new_fqn in flattened:\n            raise ValueError(f\"duplicated flatten key {new_fqn}\")\n        flattened[new_fqn] = value\n        mappings[new_fqn] = path\n\n    _traverse_state_dict(state_dict, flat_copy)\n    return flattened, mappings\n\n\ndef _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None:\n    \"\"\"Set ``value`` in ``root_dict`` along the ``path`` object path.\"\"\"\n    cur_container = cast(CONTAINER_TYPE, root_dict)\n\n    def extend_list(lst: list[Any], idx: int) -> None:\n        while len(lst) <= idx:\n            lst.append(None)\n\n    for i in range(1, len(path)):\n        prev_key = path[i - 1]\n        key = path[i]\n        def_val: CONTAINER_TYPE | list[Any] = {} if type(key) == str else []\n\n        if isinstance(cur_container, Mapping):\n            cur_container = cast(CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val))\n        else:\n            extend_list(cur_container, prev_key)\n            if cur_container[prev_key] is None:\n                cur_container[prev_key] = def_val\n            cur_container = cur_container[prev_key]\n\n    key = path[-1]\n    if type(key) == int:\n        extend_list(cast(list[Any], cur_container), key)\n\n    cur_container[key] = value\n\n\ndef _unflatten_state_dict(state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING) -> STATE_DICT_TYPE:\n    \"\"\"Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.\"\"\"\n    nested: STATE_DICT_TYPE = {}\n    for key, value in state_dict.items():\n        _set_element(nested, mapping[key], value)\n    return nested\n"
  },
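The `_flatten_state_dict` / `_unflatten_state_dict` pair above turns a nested state_dict into a flat dict keyed by dot-joined object paths, plus a mapping that records the original (typed) paths so the nesting, including list positions, can be rebuilt. A minimal round-trip sketch, assuming the file above is importable as `verl.third_party.torch.distributed._state_dict_utils` (the same module path that `state_dict.py` below imports it from):

import torch

from verl.third_party.torch.distributed._state_dict_utils import (
    _flatten_state_dict,
    _unflatten_state_dict,
)

# Mappings are traversed by key, lists/tuples by index; leaves invoke the visitor.
nested = {"optim": {"state": [{"exp_avg": torch.zeros(2)}], "step": 10}}
flat, mapping = _flatten_state_dict(nested)
assert set(flat) == {"optim.state.0.exp_avg", "optim.step"}

# `mapping` keeps the list index 0 as an int, so unflattening rebuilds the list.
restored = _unflatten_state_dict(flat, mapping)
assert restored["optim"]["state"][0]["exp_avg"].equal(torch.zeros(2))
assert restored["optim"]["step"] == 10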
  {
    "path": "verl_rl/verl/third_party/torch/distributed/checkpoint/__init__.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n"
  },
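As the header of these vendored files notes, the stock torch 2.6.0 `set_model_state_dict` can OOM when loading a full checkpoint onto sharded ranks, which is why this torch 2.7.0 copy ships under `verl.third_party`. A usage sketch of the rank0-broadcast path, assuming `torch.distributed` is already initialized and `model` is an FSDP/DTensor-sharded module ("ckpt.pt" is a hypothetical checkpoint path):

import torch
import torch.distributed as dist

from verl.third_party.torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def load_from_rank0(model: torch.nn.Module) -> None:
    # Only rank0 materializes the full checkpoint; the other ranks pass an
    # empty dict and receive their shards tensor-by-tensor via broadcast.
    full_sd = torch.load("ckpt.pt", map_location="cpu") if dist.get_rank() == 0 else {}
    options = StateDictOptions(
        full_state_dict=True,       # required whenever broadcast_from_rank0 is set
        broadcast_from_rank0=True,  # rank0 streams tensors one by one to bound peak memory
        cpu_offload=True,           # keep the received shards on CPU
    )
    set_model_state_dict(model, full_sd, options=options)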
  {
    "path": "verl_rl/verl/third_party/torch/distributed/checkpoint/state_dict.py",
    "content": "# official torch 2.6.0 set_model_state_dict API leads to OOM\n# this is a copy of torch/distributed/checkpoint from torch 2.7.0\n\n# From PyTorch:\n\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n# From Caffe2:\n\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n\n# All contributions by Cruise LLC:\n# Copyright (c) 2022 Cruise LLC.\n# All rights reserved.\n\n# All contributions by Tri Dao:\n# Copyright (c) 2024 Tri Dao.\n# All rights reserved.\n\n# All contributions by Arm:\n# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n\n# ruff: noqa: B028, UP038, UP007, E721\n# mypy: allow-untyped-defs\nimport contextlib\nimport functools\nimport gc\nimport warnings\nfrom collections.abc import Generator, Iterable\nfrom dataclasses import asdict, dataclass, field\nfrom itertools import chain\nfrom typing import Any, Callable, Optional, Union, cast, no_type_check\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom torch.distributed._shard.sharded_tensor import ShardedTensor\nfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n    _CHECKPOINT_PREFIX,\n)\nfrom torch.distributed.fsdp import (\n    FullOptimStateDictConfig,\n    FullStateDictConfig,\n    OptimStateDictConfig,\n    ShardedOptimStateDictConfig,\n    ShardedStateDictConfig,\n    StateDictConfig,\n    StateDictType,\n)\nfrom torch.distributed.fsdp import (\n    FullyShardedDataParallel as FSDP,\n)\nfrom torch.distributed.fsdp._common_utils import (\n    FSDP_WRAPPED_MODULE,\n    _get_module_fsdp_state_if_fully_sharded_module,\n)\nfrom torch.distributed.tensor import DTensor\nfrom torch.nn.modules.module import _IncompatibleKeys\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils._pytree import tree_map_only\n\nfrom verl.third_party.torch.distributed._state_dict_utils import (\n    _broadcast_state_dict,\n    _distribute_state_dict,\n    _flatten_state_dict,\n    _gather_state_dict,\n    _offload_state_dict_to_cpu,\n    _unflatten_state_dict,\n)\n\n__all__ = [\n    \"FQNS_T\",\n    \"PrimitiveType\",\n    \"ValueType\",\n    \"DictValueType\",\n    \"ListDictValueType\",\n    \"OptimizerStateType\",\n    \"StateDictOptions\",\n    \"get_model_state_dict\",\n    \"get_optimizer_state_dict\",\n    \"get_state_dict\",\n    \"set_model_state_dict\",\n    \"set_optimizer_state_dict\",\n    \"set_state_dict\",\n]\n\n\n_FLAT_PARAM = \"_flat_param\"\n_PG = \"param_groups\"\n_PARAMS = \"params\"\n_STATE = \"state\"\n\nFQNS_T = set[str]\nPrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]\nValueType = Union[PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, \"ValueType\"]]\nDictValueType = dict[str, ValueType]\nListDictValueType = list[DictValueType]\nOptimizerStateType = dict[str, DictValueType | ListDictValueType]\n\n\n_patched_state_dict: set[Callable] = set()\n\n\n@contextlib.contextmanager\ndef _gc_context():\n    is_enabled = gc.isenabled()\n    gc.disable()\n    try:\n        yield\n    finally:\n        if is_enabled:\n            gc.enable()\n\n\n@dataclass\nclass StateDictOptions:\n    \"\"\"\n    This dataclass specifies how get_state_dict/set_state_dict will work.\n\n    - ``full_state_dict``: if this is set to True, all the tensors in the\n      returned state_dict will be gathered. No ShardedTensor and DTensor\n      will be in the returned state_dict.\n\n    - ``cpu_offload``: offload all the tensors to cpu. 
To prevent CPU OOM, if\n      ``full_state_dict`` is also true, then only rank0 will get the\n      state_dict and all other ranks will get an empty state_dict.\n\n    - ``ignore_frozen_params``: if the value is True, the returned state_dict\n      won't contain any frozen parameters -- parameters whose ``requires_grad`` is False.\n      The default value is False.\n\n    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option\n      indicates whether to keep the submodule prefixes from the state_dict keys.\n      For example, if the submodule is ``module.pretrain`` and the full FQN of\n      the parameter is ``pretrain.layer1.weight``. When this option\n      is True, the parameter's key in the returned state_dict will be\n      ``pretrain.layer1.weight``. If the option is False, the key will be\n      ``layer1.weight``.\n      Note that if ``keep_submodule_prefixes`` is False, there may be conflicted\n      FQNs, hence there should be only one submodule in ``submodules``.\n\n    - ``strict``: the ``strict`` option when ``set_state_dict`` calls\n      model.load_state_dict().\n\n    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a\n       full state_dict and will broadcast the tensors in the state_dict/\n       optim_state_dict one by one to other ranks. Other ranks will receive\n       the tensors and shard according to the local shards in the model and\n       optimizer. ``full_state_dict`` must be set to True when using this option.\n       This option currently only supports DTensor, not the legacy ShardedTensor.\n    \"\"\"\n\n    full_state_dict: bool = False\n    cpu_offload: bool = False\n    ignore_frozen_params: bool = False\n    keep_submodule_prefixes: bool = True\n    strict: bool = True\n    broadcast_from_rank0: bool = False\n    flatten_optimizer_state_dict: bool = False\n    dsd_fqn_modifiers: str = \"_fqn_modifiers\"\n\n\n@dataclass\nclass _StateDictInfo(StateDictOptions):\n    fqn_param_mapping: dict[\n        str | torch.Tensor,\n        FQNS_T | torch.Tensor,\n    ] = field(default_factory=dict)\n    shared_params_mapping: dict[\n        str | torch.Tensor,\n        FQNS_T | torch.Tensor,\n    ] = field(default_factory=dict)\n    submodule_prefixes: set[str] = field(default_factory=set)\n    handle_model: bool = True\n    handle_optim: bool = True\n    fsdp_context: Callable = contextlib.nullcontext\n    fsdp_modules: list[nn.Module] = field(default_factory=list)\n\n\n@functools.cache\ndef _get_fqns(\n    model: nn.Module,\n    name: str,\n    dsd_fqn_modifiers: str = \"_fqn_modifiers\",\n    skip_ddp_prefix: bool = True,\n    skip_compiler_prefix: bool = True,\n) -> FQNS_T:\n    \"\"\"\n    This API is used to convert the name of a parameter to the FQNs. For FSDP\n    without `use_orig_params`, the name of FlatParameter can be mapped to\n    multiple original parameters. 
As a result, the return type of this function\n    is `set[str]`.\n\n    Args:\n        model (nn.Module): the root model.\n        name (str): the name to convert.\n        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix\n\n    Returns:\n        The canonical FQNs based on the model traversal.\n    \"\"\"\n\n    # Remove the checkpoint prefix, if it exists.\n    name = name.replace(_CHECKPOINT_PREFIX, \"\")\n    if \".\" not in name:\n        return {name}\n\n    obj_names = name.split(\".\")\n    fqn_obj_names = []\n    curr_obj = model\n    for i, curr_obj_name in enumerate(obj_names):\n        if isinstance(curr_obj, DDP):\n            assert curr_obj_name == \"module\"\n            curr_obj = curr_obj.module\n            if not skip_ddp_prefix:\n                fqn_obj_names.append(curr_obj_name)\n        elif isinstance(curr_obj, FSDP):\n            if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:\n                prefix = \".\".join(fqn_obj_names)\n                flat_param = getattr(curr_obj, _FLAT_PARAM)\n                if prefix:\n                    prefix = f\"{prefix}.\"\n                return {f\"{prefix}{fqn}\" for fqn in flat_param._fqns}\n            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)\n            if curr_obj_name != FSDP_WRAPPED_MODULE:\n                fqn_obj_names.append(curr_obj_name)\n                curr_obj = getattr(curr_obj, curr_obj_name)\n        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):\n            assert curr_obj_name == \"_orig_mod\"\n            curr_obj = curr_obj._orig_mod\n            if not skip_compiler_prefix:\n                fqn_obj_names.append(curr_obj_name)\n        else:\n            # In some modules, _fqn_modifiers would not be shown in the state_dict keys;\n            # skip them in the fqn so that the state dict loads successfully for them.\n            if hasattr(curr_obj, dsd_fqn_modifiers):\n                if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(curr_obj_name):\n                    if hasattr(curr_obj, removed_fqn):\n                        curr_obj = getattr(curr_obj, removed_fqn)\n            fqn_obj_names.append(curr_obj_name)\n            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:\n                if i != len(obj_names) - 1:\n                    raise RuntimeError(\"Expect `_extra_state` to be the last obj name\")\n            else:\n                curr_obj = getattr(curr_obj, curr_obj_name)\n\n    return {\".\".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, \"\")}\n\n\nclass _EXTRA_STATE:\n    pass\n\n\ndef _iterate_valid_model_state(model, dsd_fqn_modifiers=\"_fqn_modifiers\"):\n    visited_modules: set[nn.Module] = set()\n\n    def recurse(module: nn.Module, curr_fqn: str) -> Generator:\n        visited_modules.add(module)\n\n        curr_fqn = f\"{curr_fqn}.\" if curr_fqn else \"\"\n        for name, submodule in module.named_children():\n            if submodule in visited_modules:\n                continue\n            # if users have state_dict_hooks in their model, they can record the state_dict key changes\n            # in dsd_fqn_modifiers to align with the behavior of the state_dict_hook\n            if hasattr(module, dsd_fqn_modifiers) and name in getattr(module, dsd_fqn_modifiers)().values():\n                # skip _fqn_modifiers here and thus remove the last `.` added\n                new_fqn = curr_fqn[:-1]\n            else:\n                new_fqn = f\"{curr_fqn}{name}\"\n            yield from recurse(submodule, 
new_fqn)\n\n        for name, obj in chain(module.named_buffers(recurse=False), module.named_parameters(recurse=False)):\n            if name in module._non_persistent_buffers_set:\n                continue\n            new_fqn = f\"{curr_fqn}{name}\"\n            yield new_fqn, obj\n\n        if getattr(module.__class__, \"get_extra_state\", nn.Module.get_extra_state) != nn.Module.get_extra_state:\n            new_fqn = f\"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}\"\n            yield new_fqn, _EXTRA_STATE()\n\n    yield from recurse(model, \"\")\n\n\ndef _verify_options(\n    model: nn.Module,\n    optims: tuple[torch.optim.Optimizer, ...],\n    optim_only: bool,\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> _StateDictInfo:\n    \"\"\"\n    Verify the model and options passed by the user and generates _StateDictInfo.\n    \"\"\"\n    if submodules:\n        warnings.warn(\n            \"Getting submodules only model/optim state_dict is deprecated and \"\n            \"will be removed in 2.5. This feature can be achieved by manually \"\n            \"filtering out the state_dict returned from get_state_dict.\",\n            FutureWarning,\n        )\n    if optim_only and not optims:\n        raise RuntimeError(\"Optimizers are not passed in but optim_only is set to True.\")\n\n    options = options or StateDictOptions()\n\n    fqn_param_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {}\n    shared_params_mapping: dict[str | torch.Tensor, set[str] | torch.Tensor] = {}\n    for name, param in _iterate_valid_model_state(model):\n        if isinstance(param, _EXTRA_STATE):\n            continue\n\n        fqns = _get_fqns(model, name)\n        fqn = fqn_param_mapping.get(param, None)\n        if fqn is not None:\n            cast(set[str], fqn_param_mapping[param]).update(fqns)\n            shared_params_mapping[param] = fqn_param_mapping[param]\n        else:\n            # We need to do copy as _get_fqns is lru_cached\n            fqn_param_mapping[param] = fqns.copy()\n        for fqn in fqns:\n            if not isinstance(param, _EXTRA_STATE):\n                fqn_param_mapping[fqn] = param\n\n    for param_, fqns_ in list(shared_params_mapping.items()):\n        for fqn in fqns_:\n            shared_params_mapping[fqn] = cast(torch.Tensor, param_)\n\n    submodule_prefixes: set[str] = set()\n    if submodules:\n        submodules = set(submodules)\n        for name, module in model.named_modules():\n            if module not in submodules:\n                continue\n            fqns = _get_fqns(model, name)\n            assert len(fqns) == 1, \"Submodule FQN should only have 1 instance\"\n            submodule_prefixes.update(f\"{fqn}.\" for fqn in fqns)\n\n    if options.broadcast_from_rank0 and not options.full_state_dict:\n        raise ValueError(\"full_state_dict must be True when broadcast_from_rank0 is True.\")\n    fsdp_modules = FSDP.fsdp_modules(model)\n    state_dict_config: StateDictConfig\n    optim_state_dict_config: OptimStateDictConfig\n    fsdp_context: Callable\n    if fsdp_modules:\n        # FSDP API only work if at least one FSDP instance exists.\n        if options.full_state_dict:\n            state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload)\n            optim_state_dict_config = FullOptimStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n                rank0_only=(options.cpu_offload or 
options.broadcast_from_rank0),\n            )\n            state_dict_type = StateDictType.FULL_STATE_DICT\n        else:\n            state_dict_config = ShardedStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n            )\n            optim_state_dict_config = ShardedOptimStateDictConfig(\n                offload_to_cpu=options.cpu_offload,\n            )\n            state_dict_type = StateDictType.SHARDED_STATE_DICT\n\n        @contextlib.contextmanager\n        def fsdp_state_dict_type_without_warning(\n            module,\n            state_dict_type,\n            state_dict_config,\n            optim_state_dict_config,\n        ):\n            with warnings.catch_warnings():\n                warnings.filterwarnings(\"ignore\", message=\"FSDP.state_dict_type\", category=FutureWarning)\n                with FSDP.state_dict_type(\n                    module=module,\n                    state_dict_type=state_dict_type,\n                    state_dict_config=state_dict_config,\n                    optim_state_dict_config=optim_state_dict_config,\n                ):\n                    yield\n\n        fsdp_context = functools.partial(\n            fsdp_state_dict_type_without_warning,\n            module=model,\n            state_dict_type=state_dict_type,\n            state_dict_config=state_dict_config,\n            optim_state_dict_config=optim_state_dict_config,\n        )\n    else:\n        fsdp_context = contextlib.nullcontext\n\n    return _StateDictInfo(\n        **asdict(options),\n        fqn_param_mapping=fqn_param_mapping,\n        shared_params_mapping=shared_params_mapping,\n        submodule_prefixes=submodule_prefixes,\n        fsdp_context=fsdp_context,\n        fsdp_modules=cast(list[nn.Module], fsdp_modules),\n        handle_model=not optim_only,\n        handle_optim=(len(optims) > 0),\n    )\n\n\ndef _verify_state_dict(\n    model_state_dict: dict[str, ValueType],\n    optim_state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> None:\n    for module in info.fsdp_modules:\n        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)\n        assert fsdp_state is not None, \"Expected a fsdp_state with a fsdp module.\"\n\n    # Verify if the model_state_dict and optim_state_dict are valid. This API\n    # should give the users an explicit error message to debug or report.\n    if (\n        info.handle_model\n        and not model_state_dict\n        and not info.submodule_prefixes\n        and not info.ignore_frozen_params\n        and not (info.cpu_offload and info.full_state_dict)\n        and info.strict\n        and not info.broadcast_from_rank0\n    ):\n        raise RuntimeError(\n            \"The option indicates that model state_dict is required to save \"\n            \"or load, but model state_dict is empty.\"\n            f\"rank = {dist.get_rank()=}.\"\n        )\n\n    if info.handle_optim:\n        if not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0):\n            raise RuntimeError(\n                \"The option indicates that model state_dict is required to save, \"\n                f\"or load but optim state_dict is empty. {optim_state_dict}\"\n            )\n\n    for key in model_state_dict.keys():\n        if _FLAT_PARAM in key:\n            raise RuntimeError(f\"{key} contains {_FLAT_PARAM}. 
This can happen if the model is not the root module.\")\n\n\ndef _state_dict_fn(obj: nn.Module | torch.optim.Optimizer, api: str) -> Callable:\n    call = getattr(obj, api)\n    if call in _patched_state_dict:\n        call = functools.partial(getattr(obj.__class__, api), self=obj)\n    return call\n\n\ndef _maybe_full_or_cpu_state_dict(state_dict: dict[str, Any], info: _StateDictInfo) -> dict[str, Any]:\n    if info.full_state_dict:\n        ranks_only = () if (not info.cpu_offload or not torch.distributed.is_initialized()) else (0,)\n        return _gather_state_dict(state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only)\n    elif info.cpu_offload:\n        return _offload_state_dict_to_cpu(state_dict)\n    else:\n        return state_dict\n\n\n@torch.no_grad()\ndef _get_model_state_dict(model: nn.Module, info: _StateDictInfo) -> dict[str, ValueType]:\n    if not info.handle_model:\n        return {}\n\n    with info.fsdp_context():\n        state_dict = _state_dict_fn(model, \"state_dict\")()\n\n    for key in list(state_dict.keys()):\n        fqns = _get_fqns(model, key)\n        assert len(fqns) == 1, (key, fqns)\n        fqn = next(iter(fqns))\n        if fqn != key:\n            # As we only support FSDP, DDP, and TP, the only cases are\n            # wrapper-based DDP and compiler. Verify if the assumption\n            # is correct.\n            def verify(key, fqn) -> bool:\n                if len(fqn) >= len(key):\n                    return False\n                fqn_split = fqn.split(\".\")\n                key_split = key.split(\".\")\n                fqn_idx = 0\n                for key_idx, key_name in enumerate(key_split):\n                    if key_name == fqn_split[fqn_idx]:\n                        fqn_idx += 1\n                        if fqn_idx == len(fqn_split):\n                            return key_idx == len(key_split) - 1\n                    elif key_name in (\"module\", \"_orig_mod\"):\n                        continue\n                    else:\n                        return False\n                return True\n\n            if not verify(key, fqn):\n                raise RuntimeError(f\"An unexpected key, {key}, exists. 
FQN is {fqn}\")\n            state_dict[fqn] = state_dict.pop(key)\n\n    if info.submodule_prefixes:\n        new_state_dict: dict[str, ValueType] = {}\n        # TODO: make this faster.\n        for fqn in state_dict.keys():\n            for prefix in info.submodule_prefixes:\n                if not fqn.startswith(prefix):\n                    continue\n                if info.keep_submodule_prefixes:\n                    new_state_dict[fqn] = state_dict[fqn]\n                else:\n                    new_fqn = fqn[len(prefix) :]\n                    new_state_dict[new_fqn] = state_dict[fqn]\n        state_dict = new_state_dict\n\n    if info.ignore_frozen_params:\n        for key, param in model.named_parameters():\n            if param.requires_grad:\n                continue\n            fqns = _get_fqns(model, key)\n            for fqn in fqns:\n                state_dict.pop(fqn)\n\n    for key, p in list(state_dict.items()):\n        if torch.is_tensor(p) and p.is_meta:\n            state_dict.pop(key)\n\n    return _maybe_full_or_cpu_state_dict(state_dict, info)\n\n\n@torch.no_grad()\ndef _load_model_state_dict(\n    model: nn.Module,\n    state_dict: dict[str, ValueType],\n    info: _StateDictInfo,\n) -> _IncompatibleKeys:\n    if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):\n        return _IncompatibleKeys({}, {})\n\n    local_state_dict = {}\n    for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers):\n        fqns = _get_fqns(model, key, info.dsd_fqn_modifiers)\n        fqns_with_prefix = _get_fqns(\n            model,\n            key,\n            info.dsd_fqn_modifiers,\n            skip_ddp_prefix=False,\n            skip_compiler_prefix=False,\n        )\n\n        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix, strict=False):\n            if (not info.broadcast_from_rank0 or dist.get_rank() == 0) and fqn != fqn_with_prefix:\n                load_value = state_dict.pop(fqn, None)\n                if load_value is None:\n                    if info.strict:\n                        raise RuntimeError(f\"Missing key: {fqn}.\")\n                else:\n                    state_dict[fqn_with_prefix] = load_value\n            local_state_dict[fqn_with_prefix] = value\n\n    assign = False\n    if info.broadcast_from_rank0 or info.full_state_dict:\n        devices = set()\n        for key, value in local_state_dict.items():\n            if torch.is_tensor(value) and value.dim() > 0:\n                devices.add(value.device)\n        # In a LoRA state_dict, there could be multiple devices, with the meta device among them.\n        # Take the other device for the broadcast/distribute, and set assign to True\n        if torch.device(\"meta\") in devices:\n            devices.remove(torch.device(\"meta\"))\n            assign = True\n        if len(devices) == 0:\n            devices.add(dist.distributed_c10d._get_pg_default_device())\n        elif len(devices) > 1:\n            raise ValueError(\"Multiple devices found\")\n\n        if info.broadcast_from_rank0:\n            _broadcast_state_dict(\n                state_dict,\n                local_state_dict,\n                device=devices.pop(),\n                strict=info.strict,\n                cpu_offload=info.cpu_offload,\n            )\n        elif info.full_state_dict:\n            _distribute_state_dict(state_dict, local_state_dict, device=devices.pop())\n        for fqn, local_state in local_state_dict.items():\n            state_dict[fqn] = local_state\n\n    with 
info.fsdp_context():\n        return cast(\n            _IncompatibleKeys,\n            _state_dict_fn(model, \"load_state_dict\")(state_dict=state_dict, strict=info.strict, assign=assign),\n        )\n\n\ndef _init_optim_state(optim: torch.optim.Optimizer) -> None:\n    \"\"\"\n    Initialize optim states by calling the step() with zero grads.\n    \"\"\"\n    if optim.state:\n        # The optimizer state is initialized.\n        return\n\n    # There are some stateless optimizers like SGD. These optimizers will\n    # not return in the above condition. So if gradients exist, we should also\n    # return. If gradients do not exist, the following initialization should\n    # not disturb SGD because the gradients and lr are both zero.\n    for param_group in optim.param_groups:\n        for param in param_group[_PARAMS]:\n            if param.grad is not None:\n                return\n\n    for param_group in optim.param_groups:\n        for param in param_group[_PARAMS]:\n            if param.requires_grad:\n                param.grad = torch.zeros_like(param)\n\n    # Some optimizers will update parameters regardless of grads due to lr, so\n    # set lr to zero when calling `step()`.\n    lrs = []\n    for param_group in optim.param_groups:\n        if \"lr\" in param_group:\n            lrs.append(param_group[\"lr\"])\n            param_group[\"lr\"] = torch.tensor(0.0) if isinstance(param_group[\"lr\"], torch.Tensor) else 0.0\n    optim.step(closure=None)\n    # Whether to recover the \"lr\" should not matter too much as we will\n    # restore the checkpoint later.\n    for param_group in optim.param_groups:\n        if \"lr\" in param_group:\n            param_group[\"lr\"] = lrs.pop(0)\n    optim.zero_grad(set_to_none=True)\n\n\ndef _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:\n    \"\"\"\n    This API flattens the optimizer state_dict to support optimizer resharding for\n    MPMD, e.g., pipeline parallelism.\n\n    Without the API, the original optimizer state_dict looks like:\n    {\n        \"state\": {\n            \"layer1.weight\": {\n                \"step\": 10, \"exp_avg\": SomeTensor, \"exp_avg_sq\": SomeTensor\n            },\n            \"layer2.weight\": {\n                \"step\": 10, \"exp_avg\": SomeTensor, \"exp_avg_sq\": SomeTensor\n            },\n        },\n        \"param_groups\": [\n            {\n                \"lr\": 0.0,\n                \"betas\": (0.9, 0.95), ...,\n                \"params\": [\"layer1.weight\", \"layer2.weight\"]\n            }\n        ]\n    }\n\n    With this API, the optimizer state_dict looks like:\n    {\n        \"state.layer1.weight.step\": 10,\n        \"state.layer2.weight.step\": 10,\n        \"state.layer1.weight.exp_avg\": SomeTensor,\n        \"state.layer2.weight.exp_avg\": SomeTensor,\n        \"state.layer1.weight.exp_avg_sq\": SomeTensor,\n        \"state.layer2.weight.exp_avg_sq\": SomeTensor,\n        \"param_groups.layer1.weight.lr\": 0.0,\n        \"param_groups.layer2.weight.lr\": 0.0,\n        \"param_groups.layer1.weight.betas\": (0.9, 0.95),\n        \"param_groups.layer2.weight.betas\": (0.9, 0.95),\n    }\n\n    Note that if any of the values is a container, like the betas in the example,\n    this API won't flatten it.\n    \"\"\"\n\n    def _raise_if_type_not_supported(v):\n        if not isinstance(v, (torch.Tensor, int, float)):\n            raise NotImplementedError(\n                f\"Flattening optimizer state_dict only supports tensor, int, float states now. 
Type is {type(v)}.\"\n            )\n\n    ret: dict[str, ValueType] = {}\n    for fqn, state in cast(DictValueType, state_dict[_STATE]).items():\n        for k, v in cast(DictValueType, state).items():\n            _raise_if_type_not_supported(v)\n            ret[f\"{_STATE}.{fqn}.{k}\"] = v\n\n    for param_group in cast(ListDictValueType, state_dict[_PG]):\n        fqns = param_group.pop(_PARAMS)\n        for fqn in cast(list[str], fqns):\n            for k, v in param_group.items():\n                ret[f\"{_PG}.{fqn}.{k}\"] = v\n    return ret\n\n\ndef _unflatten_optim_state_dict(\n    optim: torch.optim.Optimizer,\n    state_dict: dict[str, ValueType],\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    \"\"\"\n    This API unflattens the state_dict generated by _flatten_optim_state_dict().\n    See the docstring of _flatten_optim_state_dict() for more detail.\n    \"\"\"\n    state: DictValueType = {}\n    pg_state: ListDictValueType = []\n    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}\n\n    for param_group in optim.param_groups:\n        pg_state.append({_PARAMS: []})\n        for param in param_group[_PARAMS]:\n            for fqn in info.fqn_param_mapping[param]:\n                # If a parameter is shared, only one of the FQNs will be used.\n                # So we need to verify whether this fqn is actually used in\n                # the state_dict.\n                if fqn in info.shared_params_mapping:\n                    in_params = False\n                    for k in param_group.keys():\n                        if k == _PARAMS:\n                            continue\n                        flatten_key = f\"{_PG}.{fqn}.{k}\"\n                        if flatten_key in state_dict:\n                            in_params = True\n                        break\n                else:\n                    in_params = True\n\n                if not in_params:\n                    continue\n\n                params = pg_state[-1][_PARAMS]\n                assert isinstance(params, list)  # typing\n                params.append(fqn)\n                if not param.requires_grad:\n                    continue\n                state[fqn] = {}\n                for state_name in optim.state[param].keys():\n                    cast(DictValueType, state[fqn])[state_name] = state_dict[f\"{_STATE}.{fqn}.{state_name}\"]\n\n        first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]\n        for k in param_group.keys():\n            if k == _PARAMS:\n                continue\n            value = state_dict[f\"{_PG}.{first_param_fqn}.{k}\"]\n            if k not in pg_state[-1]:\n                pg_state[-1][k] = value\n            elif pg_state[-1][k] != value:\n                raise RuntimeError(\n                    \"All the parameters in the same parameter group should have \"\n                    f\"the same saved param_group value. 
But {first_param_fqn}.{k} \"\n                    f\"is {value} while other(s) is {pg_state[-1][k]}.\"\n                )\n\n    return return_osd\n\n\n@torch.no_grad()\ndef _get_optim_state_dict(\n    model: nn.Module,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    if not info.handle_optim:\n        return {}\n\n    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}\n    for optim in optimizers:\n        _init_optim_state(optim)\n        osd = _state_dict_fn(optim, \"state_dict\")()\n        if info.fsdp_modules:\n            with info.fsdp_context():\n                osd = FSDP.optim_state_dict(model, optim, osd)\n\n            # We need to specially handle FlatParameter FSDP as\n            # FlatParameter FSDP converts the FQNs.\n            # There are no easy ways to do this conversion systematically.\n            # We can only use a string replacement without a correctness check.\n            if not osd:\n                continue\n            for k in list(osd[_STATE].keys()):\n                if \"_orig_mod\" in k:\n                    osd[_STATE][k.replace(\"_orig_mod.\", \"\")] = osd[_STATE].pop(k)\n            for g in osd[_PG]:\n                params = [k.replace(\"_orig_mod.\", \"\") for k in g[_PARAMS]]\n                g[_PARAMS] = params\n        else:\n            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))\n            param_pid_mapping = dict(zip(params, range(len(params)), strict=False))\n            fqn_pid_mapping = {}\n            for key, param in model.named_parameters():\n                fqns = _get_fqns(model, key)\n                assert len(fqns) == 1\n                fqn = next(iter(fqns))\n                if param not in param_pid_mapping:\n                    continue\n                pid = param_pid_mapping[param]\n                fqn_pid_mapping[fqn] = pid\n                fqn_pid_mapping[pid] = fqn\n\n            for key in list(osd[_STATE].keys()):\n                fqn = fqn_pid_mapping[key]\n                osd[_STATE][fqn] = osd[_STATE].pop(key)\n\n            for group in osd[_PG]:\n                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]\n\n        if not osd:\n            continue\n\n        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])\n        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])\n\n    if info.flatten_optimizer_state_dict:\n        optim_state_dict = cast(OptimizerStateType, _flatten_optim_state_dict(optim_state_dict))\n\n    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)\n\n\ndef _split_optim_state_dict(\n    model: nn.Module,\n    optim: torch.optim.Optimizer,\n    optim_state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> OptimizerStateType:\n    \"\"\"\n    Extract the corresponding optim state_dict from ``optim_state_dict`` for\n    ``optim`` and return the result optim state_dict.\n\n    Args:\n        model (nn.Module): the root model.\n        optim (torch.optim.Optimizer): the optimizer.\n        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that\n            contains the optim state_dict of ``optim``.\n        info (_StateDictInfo): state dict information.\n\n    Returns:\n        The optim state_dict of ``optim``.\n    \"\"\"\n\n    state: DictValueType = {}\n    pg_state: ListDictValueType = []\n    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}\n    pg_mapping: dict[int, int] = {}\n\n    if 
all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()):\n        return optim_state_dict\n\n    for param_group in optim.param_groups:\n        pg_state.append({_PARAMS: []})\n        for param in param_group[_PARAMS]:\n            for fqn in info.fqn_param_mapping[param]:\n                if fqn in info.shared_params_mapping:\n                    in_params = False\n                    for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                        if fqn in cast(list[str], loaded_param_group[_PARAMS]):\n                            in_params = True\n                            break\n                else:\n                    in_params = True\n                if not in_params:\n                    continue\n\n                params = pg_state[-1][_PARAMS]\n                assert isinstance(params, list)\n                params.append(fqn)\n                if param.requires_grad:\n                    state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]\n                for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                    if fqn in cast(list[str], loaded_param_group[_PARAMS]):\n                        pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1\n\n        if len(param_group[_PARAMS]) == 0:\n            # Param_group with empty params.\n            ret = []\n            for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n                if len(cast(list[str], loaded_param_group[_PARAMS])) == 0:\n                    ret.append(loaded_param_group)\n            if len(ret) != 1:\n                raise ValueError(\n                    \"There are param groups that have zero parameters. \"\n                    \"In such a case, DSD only support exactly one param group \"\n                    \"with zero parameters.\"\n                    \"But the loaded state_dict has zero or more than one param groups \"\n                    \"that have zero parameters.\"\n                )\n            if len(optim_state_dict[_PG]) != len(optim.param_groups):\n                raise ValueError(\n                    \"When there is a parameter group that has zero parameters, multiple optimizers are not supported.\"\n                )\n            pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1\n\n    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):\n        pg_idx = pg_mapping.get(id(param_group), -1)\n        if pg_idx == -1:\n            continue\n\n        for key, value in param_group.items():\n            if key == _PARAMS:\n                continue\n            # TODO: check if value is the same if exists.\n            pg_state[pg_idx][key] = value\n\n    return return_osd\n\n\n@torch.no_grad()\ndef _load_optim_state_dict(\n    model: nn.Module,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    state_dict: OptimizerStateType,\n    info: _StateDictInfo,\n) -> None:\n    if not info.handle_optim:\n        return\n\n    for optim in optimizers:\n        _init_optim_state(optim)\n        if state_dict:\n            if _STATE in state_dict:\n                optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)\n            else:\n                optim_state_dict = _unflatten_optim_state_dict(optim, cast(dict[str, ValueType], state_dict), info)\n        else:\n            optim_state_dict = {}\n        if info.fsdp_modules:\n            # We need to specially handle FlatParameter FSDP as\n         
   # FlatParameter FSDP converts the FQNs.\n            for original_fqn, _ in model.named_parameters():\n                fqns = _get_fqns(model, original_fqn)\n                fqns_with_compiler = _get_fqns(model, original_fqn, skip_compiler_prefix=False)\n                if fqns == fqns_with_compiler:\n                    continue\n\n                assert len(fqns) == 1\n                fqn = fqns.pop()\n                fqn_with_compiler = fqns_with_compiler.pop()\n                for g in optim_state_dict[_PG]:\n                    val = cast(dict[str, Any], g)\n                    params = [key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]]\n                    val[_PARAMS] = params\n                osd_state = cast(DictValueType, optim_state_dict[_STATE])\n                for k in list(osd_state.keys()):\n                    if fqn in k:\n                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)\n\n            with info.fsdp_context():\n                optim_state_dict = FSDP.optim_state_dict_to_load(model, optim, optim_state_dict)\n        elif info.full_state_dict:\n            info.full_state_dict = False\n            local_state_dict = _get_optim_state_dict(model, (optim,), info)\n            info.full_state_dict = True\n            device = None\n\n            def _device(t):\n                if t.dim() > 0:\n                    nonlocal device\n                    if device is None:\n                        device = t.device\n                    elif device != t.device:\n                        raise ValueError(\"Device mismatch\")\n                return t\n\n            _ = tree_map_only(torch.Tensor, _device, local_state_dict)\n            assert device is not None\n            flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)\n            flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)\n            if info.broadcast_from_rank0:\n                _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)\n            else:\n                _distribute_state_dict(flatten_osd, flatten_local_osd, device=device)\n            # The modifications listed seek to address the problem where optim might possess\n            # dissimilar parameters in comparison to optim_state_dict. 
This is achieved by\n            # incorporating differential parameters within local, which may result in optim\n            # having additional parameters ultimately.\n            for optim_key in flatten_osd.keys():\n                if optim_key not in flatten_local_osd:\n                    assert optim_key in osd_mapping\n                    flatten_local_osd[optim_key] = flatten_osd[optim_key]\n                    local_osd_mapping[optim_key] = osd_mapping[optim_key]\n            optim_state_dict = _unflatten_state_dict(flatten_local_osd, local_osd_mapping)\n            for pg in optim_state_dict[_PG]:\n                if _PARAMS not in pg:\n                    cast(dict[str, ValueType], pg)[_PARAMS] = []\n\n        # Note that we do not have to convert the FQN back to param id here if\n        # order in optim.param_groups[idx][_PARAMS] is the same as the one in\n        # optim_state_dict[_PG][idx][_PARAMS].\n        _state_dict_fn(optim, \"load_state_dict\")(state_dict=optim_state_dict)\n\n\ndef get_model_state_dict(\n    model: nn.Module,\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> dict[str, ValueType]:\n    \"\"\"\n    Return the model state_dict of ``model``.\n\n    See ``get_state_dict`` for the detail usage.\n\n    Args:\n        model (nn.Module): the nn.Module to the model.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        The state_dict for ``model``.\n\n    :rtype: typing.Dict[str, ValueType]\n    \"\"\"\n    with _gc_context():\n        info = _verify_options(\n            model,\n            (),\n            optim_only=False,\n            submodules=submodules,\n            options=options,\n        )\n        model_state_dict = _get_model_state_dict(model, info)\n        _verify_state_dict(model_state_dict, {}, info)\n        return model_state_dict\n\n\ndef get_optimizer_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> OptimizerStateType:\n    \"\"\"\n    Return the combined state_dict for optimizers.\n\n    See ``get_state_dict`` for the detail usage.\n\n    Args:\n        model (nn.Module): the nn.Module to the model.\n        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. 
See\n            `StateDictOptions` for the details.\n\n    Returns:\n        The state_dict for ``optimizers``.\n\n    :rtype: OptimizerStateType\n    \"\"\"\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(\n            model,\n            optimizers,\n            optim_only=True,\n            submodules=submodules,\n            options=options,\n        )\n        optim_state_dict = _get_optim_state_dict(model, optimizers, info)\n        _verify_state_dict({}, optim_state_dict, info)\n        return optim_state_dict\n\n\ndef get_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    submodules: Optional[set[nn.Module]] = None,\n    options: Optional[StateDictOptions] = None,\n) -> tuple[dict[str, ValueType], OptimizerStateType]:\n    \"\"\"\n    Return the model state_dict and optimizers state_dict.\n\n    ``get_state_dict`` can process any module that is parallelized by PyTorch\n    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any\n    combination of these parallelisms. The main functions of ``get_state_dict``\n    are: 1.) returning a model and optimizer state_dict that can be resharded\n    with a different number of trainers and/or different parallelisms.\n    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call\n    these APIs.\n    3.) sanity checking the result state_dict.\n\n    The keys of the result state dictionary are the canonical FQNs (Fully\n    Qualified Names).  A canonical FQN refers to the FQN based on a parameter's\n    position in an nn.Module hierarchy. More specifically, a canonical FQN to a\n    parameter is the FQN returned by ``module.named_parameters()`` or\n    ``module.named_buffers()`` when the module is not distributed by any\n    parallelisms. Since the optimizer internally uses parameter IDs to represent\n    a parameter, there will be a conversion from the parameter IDs to the\n    canonical FQNs when calling this API.\n\n    ``get_state_dict`` can also process a module that is not parallelized. In\n    such a case, ``get_state_dict`` only performs one function -- converting the\n    optimizer parameter IDs to the canonical FQNs.\n\n    Example:\n        >>> # xdoctest: +SKIP\n        >>> import torch\n        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        >>> from torch.nn.parallel import DistributedDataParallel as DDP\n        >>> from torch.distributed.checkpoint.state_dict import get_state_dict\n\n        >>> fsdp_model = FSDP(copy.deepcopy(model))\n        >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)\n        >>> ddp_model = DDP(copy.deepcopy(model))\n        >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)\n\n\n        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)\n        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(\n        ...     fsdp_model, fsdp_optim\n        ... 
)\n\n        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),\n        >>> # the asserts will fail.\n        >>> assert ddp_state_dict == fsdp_state_dict\n        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict\n\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters\n            that belong to the given submodules.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be returned. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``Tuple`` containing the model state_dict and the optimizer state_dict.\n\n    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]\n    \"\"\"\n\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(\n            model,\n            optimizers,\n            optim_only=False,\n            submodules=submodules,\n            options=options,\n        )\n        model_state_dict = _get_model_state_dict(model, info)\n        optim_state_dict = _get_optim_state_dict(model, optimizers, info)\n        _verify_state_dict(model_state_dict, optim_state_dict, info)\n        return model_state_dict, optim_state_dict\n\n\ndef _unflatten_model_state_dict(\n    model: nn.Module,\n    state_dict: dict[nn.Module, dict[str, ValueType]] | dict[str, ValueType],\n) -> dict[str, ValueType]:\n    if not state_dict:\n        return {}\n\n    if isinstance(next(iter(state_dict.keys())), nn.Module):\n        warnings.warn(\n            \"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]`` \"\n            \"is deprecated and will be removed in 2.5. If you need this \"\n            \"feature, please preprocess the model_state_dict to achieve the \"\n            \"same functionality.\",\n            FutureWarning,\n        )\n        cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)\n        new_state_dict: dict[str, ValueType] = {}\n        for submodule, sub_state_dict in cast_state_dict.items():\n            for name, m in model.named_modules():\n                if m != submodule:\n                    continue\n\n                fqns = _get_fqns(model, name)\n                assert len(fqns) == 1, \"FQNs for a submodule should only have 1 element\"\n                prefix = f\"{next(iter(fqns))}.\"\n                new_state_dict.update({prefix + subfqn: value for subfqn, value in sub_state_dict.items()})\n        return new_state_dict\n    else:\n        return cast(dict[str, ValueType], state_dict)\n\n\ndef set_model_state_dict(\n    model: nn.Module,\n    model_state_dict: dict[str, ValueType],\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> _IncompatibleKeys:\n    \"\"\"Load the model state_dict.\n\n    The counterpart of ``get_model_state_dict`` to set the state_dict to the\n    model. See ``set_state_dict`` for detailed usage.\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        model_state_dict: (Dict[str, ValueType]):\n           the model state_dict to load. 
If the key of the ``model_state_dict``\n           is nn.Module, the key is a submodule of ``model`` and the value should\n           be the state_dict of the submodule. When loading the state_dict,\n           the prefix of the submodule will be prepended to the keys of the state_dict.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n            * **missing_keys** is a list of str containing the missing keys\n            * **unexpected_keys** is a list of str containing the unexpected keys\n\n    :type model_state_dict: typing.Dict[str, ValueType]\n    \"\"\"\n    model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict)\n    with _gc_context():\n        info = _verify_options(model, (), optim_only=False, options=options)\n\n        _verify_state_dict(model_state_dict, {}, info)\n        return _load_model_state_dict(model, model_state_dict, info)\n\n\ndef set_optimizer_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    optim_state_dict: OptimizerStateType,\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Load the optimizers state_dict.\n\n    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the\n    optimizers. See ``set_state_dict`` for detailed usage.\n\n    WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after\n        ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be\n        initialized correctly.\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        optimizers (Union[Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        optim_state_dict: OptimizerStateType:\n            the optimizer state_dict to load.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        None\n\n    :type optim_state_dict: typing.OptimizerStateType\n    \"\"\"\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(model, optimizers, optim_only=True, options=options)\n\n        _verify_state_dict({}, optim_state_dict, info)\n        _load_optim_state_dict(model, optimizers, optim_state_dict, info)\n\n\ndef set_state_dict(\n    model: nn.Module,\n    optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer],\n    *,\n    model_state_dict: dict[str, ValueType],\n    optim_state_dict: OptimizerStateType,\n    options: Optional[StateDictOptions] = None,\n) -> _IncompatibleKeys:\n    \"\"\"Load the model state_dict and optimizers state_dict.\n\n    The counterpart of ``get_state_dict`` to set the state_dict to the model and\n    optimizers.  
The given ``model_state_dict`` and ``optim_state_dict`` do not\n    have to be returned by ``get_state_dict`` but must meet the following\n    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,\n    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,\n    3) optimizer state_dict cannot contain the parameter IDs; the keys should be\n    the canonical FQNs.\n\n    WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()``\n        is called on the optimizers. Otherwise, the optimizer states won't be initialized\n        correctly.\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        optimizers (Union[Optimizer, Iterable[Optimizer]]):\n            The optimizers that are used to optimize ``model``.\n        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):\n           the model state_dict to load. If the key of the ``model_state_dict``\n           is nn.Module, the key is a submodule of ``model`` and the value should\n           be the state_dict of the submodule. When loading the state_dict,\n           the prefix of the submodule will be prepended to the keys of the state_dict.\n        optim_state_dict: OptimizerStateType:\n            the optimizer state_dict to load.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n\n    Returns:\n        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n            * **missing_keys** is a list of str containing the missing keys of the model state_dict.\n            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.\n\n    :type model_state_dict: typing.Dict[str, ValueType]\n    :type optim_state_dict: typing.OptimizerStateType\n    \"\"\"\n\n    model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(model, model_state_dict)\n    with _gc_context():\n        optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n        info = _verify_options(model, optimizers, optim_only=not model_state_dict, options=options)\n\n        _verify_state_dict(model_state_dict, optim_state_dict, info)\n        _load_optim_state_dict(model, optimizers, optim_state_dict, info)\n        return _load_model_state_dict(model, model_state_dict, info)\n\n\n# TODO: correct the state_dict function signature.\n# TODO: this API is not yet fully tested. Make it private\n@no_type_check\ndef _patch_model_state_dict(\n    model: nn.Module,\n    *,\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.\n\n    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to\n    be a partial function to call ``get_state_dict`` and ``set_state_dict``.\n\n    Example:\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict\n\n        model = FSDP(model)\n        _patch_model_state_dict(model)\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. 
See\n            `StateDictOptions` for the details.\n    Returns:\n        None\n    \"\"\"\n\n    _state_dict_call = functools.partial(\n        get_model_state_dict,\n        model=model,\n        options=options,\n    )\n\n    def state_dict_call():\n        return _state_dict_call()\n\n    model.state_dict = state_dict_call\n\n    _load_state_dict_call = functools.partial(\n        set_model_state_dict,\n        model=model,\n        options=options,\n    )\n\n    def load_state_dict_call(state_dict: dict[str, Any]):\n        _load_state_dict_call(model_state_dict=state_dict)\n\n    model.load_state_dict = load_state_dict_call\n\n    _patched_state_dict.add(state_dict_call)\n    _patched_state_dict.add(load_state_dict_call)\n\n\n# TODO: correct the load_state_dict function signature.\n# TODO: this API is not yet fully tested. Make it private\n@no_type_check\ndef _patch_optimizer_state_dict(\n    model: nn.Module,\n    *,\n    optimizers: tuple[torch.optim.Optimizer, ...],\n    options: Optional[StateDictOptions] = None,\n) -> None:\n    \"\"\"Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.\n\n    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to\n    be a partial function to call ``get_state_dict`` and ``set_state_dict``.\n\n    Note that if there are multiple optimizers, all of the optimizers will be patched.\n    So users only need to call ``state_dict()`` on one of them to get the full result.\n\n    Example:\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict\n\n        model = FSDP(model)\n        optim = torch.optim.Adam(model.parameters(), lr=1e-3)\n        _patch_optimizer_state_dict(model, optimizers=(optim,))\n\n    Args:\n        model (nn.Module): the root nn.Module of the model.\n        optimizers (tuple[torch.optim.Optimizer, ...]):\n            The optimizers that are used to optimize ``model``.\n        options (StateDictOptions): the options to control how\n            model state_dict and optimizer state_dict should be loaded. See\n            `StateDictOptions` for the details.\n    Returns:\n        None\n    \"\"\"\n\n    _state_dict_call = functools.partial(\n        get_optimizer_state_dict,\n        model=model,\n        optimizers=optimizers,\n        options=options,\n    )\n\n    def state_dict_call():\n        return _state_dict_call()\n\n    _load_state_dict_call = functools.partial(\n        set_optimizer_state_dict,\n        model=model,\n        optimizers=optimizers,\n        options=options,\n    )\n\n    def load_state_dict_call(state_dict: dict[str, Any]):\n        _load_state_dict_call(optim_state_dict=state_dict)\n\n    _patched_state_dict.add(state_dict_call)\n    _patched_state_dict.add(load_state_dict_call)\n    optimizers = (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) else tuple(optimizers)\n    for optim in optimizers:\n        optim.state_dict = state_dict_call\n        optim.load_state_dict = load_state_dict_call\n"
  },
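Taken together, the four public APIs in the file above form a checkpoint round trip. The following is a minimal sketch, not part of the repository, that exercises them on a plain, non-parallelized module, where (per the docstrings) the APIs only normalize optimizer param IDs into canonical FQNs; the layer sizes are illustrative, and with FSDP/DDP the same calls apply to the wrapped model.

```python
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer state exists before it is captured.
model(torch.randn(8, 4)).sum().backward()
optim.step()

# Keys are canonical FQNs ("weight", "bias"); optimizer param IDs become FQNs too.
model_sd, optim_sd = get_state_dict(model, optim)

# Reload into a fresh model/optimizer pair. Per the WARN notes in the docstrings,
# this must happen before backward() or after step() on the new optimizer.
new_model = nn.Linear(4, 2)
new_optim = torch.optim.Adam(new_model.parameters(), lr=1e-3)
set_state_dict(new_model, new_optim, model_state_dict=model_sd, optim_state_dict=optim_sd)
```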
  {
    "path": "verl_rl/verl/third_party/vllm/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom packaging import version as vs\n\nfrom verl.utils.import_utils import is_sglang_available\n\n\ndef get_version(pkg):\n    try:\n        return version(pkg)\n    except PackageNotFoundError:\n        return None\n\n\npackage_name = \"vllm\"\npackage_version = get_version(package_name)\nvllm_version = None\n\nif package_version is None:\n    if not is_sglang_available():\n        raise ValueError(\n            f\"vllm version {package_version} not supported and SGLang also not Found. Currently supported \"\n            f\"vllm versions are 0.7.0+\"\n        )\nelif vs.parse(package_version) >= vs.parse(\"0.7.0\"):\n    vllm_version = package_version\n    from vllm import LLM\n    from vllm.distributed import parallel_state\nelse:\n    if vs.parse(package_version) in [vs.parse(\"0.5.4\"), vs.parse(\"0.6.3\")]:\n        raise ValueError(\n            f\"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer \"\n            f\"supported. Please use vLLM 0.7.0 or later.\"\n        )\n    if not is_sglang_available():\n        raise ValueError(\n            f\"vllm version {package_version} not supported and SGLang also not Found. Currently supported \"\n            f\"vllm versions are 0.7.0+\"\n        )\n\n__all__ = [\"LLM\", \"parallel_state\"]\n"
  },
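The guard in the file above reduces to two ingredients: `importlib.metadata.version` raises `PackageNotFoundError` when the distribution is absent, and `packaging.version` gives semantic rather than lexicographic comparison (so "0.10.0" > "0.7.0"). A minimal sketch of the same decision tree, with illustrative prints standing in for the import/raise behavior:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging import version as vs


def get_version(pkg):
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


v = get_version("vllm")
if v is None:
    print("vllm is not installed; fall back to SGLang if it is available")
elif vs.parse(v) >= vs.parse("0.7.0"):
    print(f"vllm {v} is supported; import LLM and parallel_state")
else:
    print(f"vllm {v} is too old; 0.7.0+ is required")
```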
  {
    "path": "verl_rl/verl/tools/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/tools/base_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .schemas import OpenAIFunctionToolSchema\n\n\nclass BaseTool:\n    \"\"\"Base class for tools.\n\n    A tool should support the following methods:\n\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        self.config = config\n        self.tool_schema = tool_schema or self.get_openai_tool_schema()\n        assert self.tool_schema is not None, \"Tool schema is not set!\"\n        self.name = self.tool_schema.function.name\n        print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2))\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4())\n        else:\n            return instance_id\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        \"\"\"Execute the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n            parameters: The json string of the parameters of the tool.\n\n        Returns: tool_response, tool_reward_score, tool_metrics\n            tool_response: The response str of the tool.\n            tool_reward_score: The step reward score of the tool.\n            tool_metrics: The metrics of the tool.\n        \"\"\"\n        return \"Updated the tool state.\", 0.0, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        \"\"\"Calculate the reward of the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The reward of the tool.\n        \"\"\"\n        return 0.0\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        \"\"\"Release the tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n        \"\"\"\n        pass\n"
  },
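A minimal sketch of the lifecycle that `BaseTool` documents (create → execute → calc_reward → release). `EchoTool` and its schema are hypothetical, written only to make the contract concrete; real tools, like the ones below, override `calc_reward` and keep per-instance state.

```python
import asyncio
from typing import Any

from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema


class EchoTool(BaseTool):
    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
        # Respond with the input text; no step reward, no metrics.
        return f"echo: {parameters.get('text', '')}", 0.0, {}


schema = OpenAIFunctionToolSchema.model_validate(
    {
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Echo the given text back to the model.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string", "description": "Text to echo"}},
                "required": ["text"],
            },
        },
    }
)


async def main():
    tool = EchoTool(config={}, tool_schema=schema)
    instance_id = await tool.create()
    response, step_reward, metrics = await tool.execute(instance_id, {"text": "hi"})
    final_reward = await tool.calc_reward(instance_id)  # 0.0 in the base class
    await tool.release(instance_id)
    print(response, step_reward, final_reward, metrics)


asyncio.run(main())
```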
  {
    "path": "verl_rl/verl/tools/geo3k_tool.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import geo3k\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Geo3kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of geo3k.\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_geo3k_reward\",\n                \"description\": \"A tool for calculating the reward of geo3k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question, enclosed in \\\\boxed{}\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id, None\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n        self._instance_dict[instance_id][\"response\"] = answer\n        reward = await self.calc_reward(instance_id)\n        # penalty for non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        self._instance_dict[instance_id][\"reward\"] = reward\n        return f\"Current parsed {answer=} 
{reward=}\", tool_reward, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        return geo3k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            use_boxed=False,\n            format_score=0.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "verl_rl/verl/tools/gsm8k_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom verl.utils.reward_score import gsm8k\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass Gsm8kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of gsm8k.\n\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_gsm8k_reward\",\n                \"description\": \"A tool for calculating the reward of gsm8k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n\n        if answer.startswith(\"#### \"):\n            self._instance_dict[instance_id][\"response\"] = answer\n        else:\n            self._instance_dict[instance_id][\"response\"] = \"#### \" + answer\n\n        reward = await self.calc_reward(instance_id)\n        # penalty for non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        
self._instance_dict[instance_id][\"reward\"] = reward\n\n        return f\"Current parsed {answer=} {reward=}\", tool_reward, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"flexible\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
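The step-reward logic in `Gsm8kTool.execute` only withholds the -0.05 penalty when a submission strictly improves on the best reward recorded for that instance. Below is a sketch of driving the tool directly, assuming the ground truth is the bare final-answer string that the GSM8K scorer compares against; it is not part of the repository.

```python
import asyncio

from verl.tools.gsm8k_tool import Gsm8kTool
from verl.tools.schemas import OpenAIFunctionToolSchema

schema = OpenAIFunctionToolSchema.model_validate(
    {
        "type": "function",
        "function": {
            "name": "calc_gsm8k_reward",
            "description": "A tool for calculating the reward of gsm8k",
            "parameters": {
                "type": "object",
                "properties": {"answer": {"type": "string", "description": "The answer"}},
                "required": ["answer"],
            },
        },
    }
)


async def main():
    tool = Gsm8kTool(config={}, tool_schema=schema)
    instance_id = await tool.create(ground_truth="42")
    # Wrong answer: reward stays 0.0, so the step reward is the -0.05 penalty.
    _, penalty, _ = await tool.execute(instance_id, {"answer": "41"})
    # Correct answer: reward rises to 1.0 > 0.0, so no penalty is applied.
    _, no_penalty, _ = await tool.execute(instance_id, {"answer": "42"})
    await tool.release(instance_id)
    print(penalty, no_penalty)  # -0.05 0.0


asyncio.run(main())
```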
  {
    "path": "verl_rl/verl/tools/mcp_base_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nfrom typing import Any, Optional\nfrom uuid import uuid4\n\nfrom fastmcp.exceptions import ClientError\n\nfrom verl.tools.utils.mcp_clients.McpClientManager import ClientManager\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MCPBaseTool(BaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        self.timeout = config.get(\"timeout\", 30)\n\n        # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool\n        logger.info(f\"Initialized MCPBaseTool with config: {config}\")\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        \"\"\"Return the OpenAI tool schema.\"\"\"\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"reward\": [],\n        }\n        return instance_id\n\n    async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]:\n        err_msg = \"\"\n        try:\n            call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)\n        except ClientError as e:\n            err_msg = f\"\\n Tool call failed: {e}\"\n        except ConnectionError as e:\n            err_msg = f\"\\n Connection failed: {e}\"\n        except Exception as e:\n            err_msg = f\"\\n An unexpected error occurred: {e}\"\n\n        logger.debug(f\"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}\")\n        result, metadata = self._parse_tool_result(call_tool_result.content)\n        metadata[\"api_request_error\"] += err_msg\n        return result, metadata\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        if self.name == \"\" or self.name is None or parameters is None:\n            error_msg = \"Error: 'parameters' is missing or empty.\"\n            logger.error(f\"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}\")\n            return json.dumps({\"result\": error_msg}), 0.0, {}\n\n        try:\n            result_text, metadata = await self._call_tool(instance_id, parameters)\n\n            # Store results in instance dictionary\n   
         self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\n\n            # Convert metadata to metrics\n            metrics = {\n                \"query_count\": metadata.get(\"query_count\", 0),\n                \"status\": metadata.get(\"status\", \"unknown\"),\n                \"total_results\": metadata.get(\"total_results\", 0),\n                \"api_request_error\": metadata.get(\"api_request_error\"),\n            }\n\n            return result_text, 0.0, metrics\n\n        except Exception as e:\n            error_result = json.dumps({\"result\": f\"Tool execution failed: {e}\"})\n            logger.error(f\"[MCPBaseTool] Execution failed: {e}\")\n            return error_result, 0.0, {\"error\": str(e)}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> str:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        if instance_id in self._instance_dict:\n            del self._instance_dict[instance_id]\n\n    def _parse_tool_result(self, content: list) -> tuple[str, dict]:\n        tools_content = [part.text for part in filter(lambda x: x.type == \"text\", content)]\n        return \" \".join(tools_content), {}\n"
  },
  {
    "path": "verl_rl/verl/tools/mcp_search_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport re\n\nfrom verl.tools.mcp_base_tool import MCPBaseTool\n\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MCPSearchTool(MCPBaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n\n    def _parse_tool_result(self, content: list) -> tuple[str, dict]:\n        res = \"\"\n        res_cnt = 0\n        query_list = []\n        metadata = {\n            \"api_request_error\": \"\",\n            \"status\": \"unknown\",\n            \"total_results\": 0,\n        }\n        try:\n            for part in content:\n                if part.type != \"text\":\n                    continue\n                text = part.text.replace(\"'\", '\"')\n                query_match = re.search(r'query\"\\s*:\\s*\"([^\"]+)\"', text)\n                query = query_match.group(1) if query_match else \"\"\n                query_list.append(query)\n\n                title_matches = re.findall(r'\"title\"\\s*:', text)\n                title_count = len(title_matches)\n\n                results_match = re.search(r'\"results\"\\s*:\\s*(\\[.*?\\])', text, re.DOTALL)\n                results_content = results_match.group(1) if results_match else \"\"\n\n                res += results_content\n                res_cnt += title_count\n        except json.JSONDecodeError:\n            err_msg = \"json parse error.\"\n            logger.error(err_msg)\n            metadata[\"api_request_error\"] = err_msg\n            metadata[\"status\"] = \"error\"\n\n        # update metadata\n        metadata[\"status\"] = \"success\"\n        metadata[\"queries\"] = query_list\n        metadata[\"query_count\"] = len(query_list)\n        metadata[\"total_results\"] = res_cnt\n        return res, metadata\n"
  },
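What `MCPSearchTool._parse_tool_result` recovers from an MCP text part: the query string, a result count (by counting `"title"` keys), and the raw results array. A sketch of calling it directly is below; the MCP content part is faked with a tiny stand-in class and an illustrative payload, and constructing the tool assumes `fastmcp` is importable since the base module imports it.

```python
from dataclasses import dataclass

from verl.tools.mcp_search_tool import MCPSearchTool
from verl.tools.schemas import OpenAIFunctionToolSchema


@dataclass
class TextPart:
    type: str
    text: str


schema = OpenAIFunctionToolSchema.model_validate(
    {
        "type": "function",
        "function": {
            "name": "search",
            "description": "MCP-backed search",
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    }
)

tool = MCPSearchTool(config={}, tool_schema=schema)
part = TextPart(
    type="text",
    text='{"query": "openonerec", "results": [{"title": "OneRec"}, {"title": "RecIF-Bench"}]}',
)
results, metadata = tool._parse_tool_result([part])
# metadata["query_count"] == 1, metadata["total_results"] == 2,
# results is the raw `[... ]` substring captured by the "results" regex.
```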
  {
    "path": "verl_rl/verl/tools/sandbox_fusion_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport threading\nfrom contextlib import ExitStack\nfrom enum import Enum\nfrom typing import Any, Callable, Optional, TypeVar\nfrom uuid import uuid4\n\nimport ray\n\nfrom verl.tools.base_tool import BaseTool\nfrom verl.utils.reward_score.sandbox_fusion.utils import _process_single_case\nfrom verl.utils.rollout_trace import rollout_trace_op\n\nfrom .schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nT = TypeVar(\"T\")\n\n\nclass PoolMode(Enum):\n    ThreadMode = 1\n    ProcessMode = 2\n\n\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\nclass TokenBucketWorker:\n    def __init__(self, rate_limit: int):\n        self.rate_limit = rate_limit\n        # this only used for observalability\n        self.current_count = 0\n        self._semaphore = threading.Semaphore(rate_limit)\n\n    @ray.method(concurrency_group=\"acquire\")\n    def acquire(self):\n        self._semaphore.acquire()\n        self.current_count += 1\n\n    @ray.method(concurrency_group=\"release\")\n    def release(self):\n        self._semaphore.release()\n        self.current_count -= 1\n\n    def get_current_count(self):\n        return self.current_count\n\n\nclass ExecutionWorker:\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n    def _init_rate_limit(self, rate_limit):\n        # TODO validation for rate_limit\n        # A Singleton Rate Limitor\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n    def ping(self):\n        return True\n\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n        with ExitStack() as stack:\n            stack.callback(self.rate_limit_worker.release.remote)\n            ray.get(self.rate_limit_worker.acquire.remote())\n            try:\n                return fn(*fn_args, **fn_kwargs)\n            except Exception as e:\n                # TODO we should make this available to the tool caller\n                logger.warning(f\"Error when executing code: {e}\")\n\n\ndef init_execution_pool(\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\n):\n    if mode == PoolMode.ThreadMode:\n        return (\n            ray.remote(ExecutionWorker)\n            .options(max_concurrency=num_workers)\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\n        )\n    else:\n        raise NotImplementedError(\"Process mode is not implemented yet\")\n        # return ray.util.multiprocessing.Pool(processes=num_workers)\n\n\nclass SandboxFusionTool(BaseTool):\n    \"\"\"A tool for executing the code using sanbox fusion image.\n\n    - 
`to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward with respect to the tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"code_interpreter\",\n                \"description\": \"A tool for executing code\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"code\": {\n                            \"type\": \"string\",\n                            \"description\": \"code that needs to be executed and graded\",\n                        },\n                    },\n                    \"required\": [\"code\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        # TODO: better documentation for the config\n        self.num_workers = config.get(\"num_workers\", 10)\n        self.rate_limit = config.get(\"rate_limit\", 10)\n        self.default_timeout = config.get(\"default_timeout\", 30)\n        self.default_language = config.get(\"default_language\", \"python\")\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\n        self.execution_pool = init_execution_pool(\n            num_workers=self.num_workers,\n            enable_global_rate_limit=self.enable_global_rate_limit,\n            rate_limit=self.rate_limit,\n            mode=PoolMode.ThreadMode,\n        )\n        self.sandbox_fusion_url = config.get(\"sandbox_fusion_url\", \"\")\n        self.memory_limit_mb = config.get(\"memory_limit_mb\", 1024)\n        if self.sandbox_fusion_url == \"\":\n            raise ValueError(\"sandbox_fusion_url is not set\")\n        log_msg = f\"Init SandboxFusionTool with config: {config}\"\n        logger.info(log_msg)\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": [],\n        }\n        return instance_id\n\n    @rollout_trace_op\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\n        code = parameters.get(\"code\", \"\")\n        timeout = parameters.get(\"timeout\", self.default_timeout)\n        language = parameters.get(\"language\", self.default_language)\n        if not isinstance(code, str):\n            code = str(code)\n\n        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n        # sandbox has no score or metrics, use Nones\n        return result, None, None\n\n    def execute_code(self, instance_id, code, timeout=30, language=\"python\"):\n        result_status, metadata = _process_single_case(\n            0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language\n        )\n        # we should always expect 
the \"Finished\" run status here, since there is no ground-truth answer to grade against\n        if metadata[\"run_status\"] == \"Finished\":\n            actual_output = metadata[\"stdout\"] + metadata[\"stderr\"]\n            logger.debug(f\"actual_output from sandbox fusion: {actual_output},{instance_id}\")\n            return actual_output\n        else:\n            return \"no stdout here\"\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> list:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
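The concurrency pattern in the file above, in isolation: a single named `TokenBucketWorker` actor (created with `get_if_exists=True`, so it is shared job-wide) gates how many executions hold a token, while the actor's `max_concurrency` bounds parallelism inside the worker. A minimal sketch that pushes a plain function through the pool instead of a sandbox request; the numbers are illustrative, and a local `ray.init` (or running cluster) is assumed:

```python
import ray

from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool

ray.init(ignore_reinit_error=True)

pool = init_execution_pool(
    num_workers=4,  # max concurrent calls inside the ExecutionWorker actor
    enable_global_rate_limit=True,
    rate_limit=2,  # at most 2 executions hold a token at any moment
    mode=PoolMode.ThreadMode,
)


def square(x: int) -> int:
    return x * x


# Each call acquires a token, runs the function, then releases the token.
futures = [pool.execute.remote(square, i) for i in range(8)]
print(ray.get(futures))  # [0, 1, 4, 9, 16, 25, 36, 49]
```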
  {
    "path": "verl_rl/verl/tools/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel\n\n\nclass OpenAIFunctionPropertySchema(BaseModel):\n    \"\"\"The schema of a parameter in OpenAI format.\"\"\"\n\n    type: str\n    description: str | None = None\n    enum: list[str] | None = None\n\n\nclass OpenAIFunctionParametersSchema(BaseModel):\n    \"\"\"The schema of parameters in OpenAI format.\"\"\"\n\n    type: str\n    properties: dict[str, OpenAIFunctionPropertySchema]\n    required: list[str]\n\n\nclass OpenAIFunctionSchema(BaseModel):\n    \"\"\"The schema of a function in OpenAI format.\"\"\"\n\n    name: str\n    description: str\n    parameters: OpenAIFunctionParametersSchema\n    strict: bool = False\n\n\nclass OpenAIFunctionToolSchema(BaseModel):\n    \"\"\"The schema of a tool in OpenAI format.\"\"\"\n\n    type: str\n    function: OpenAIFunctionSchema\n\n\nclass OpenAIFunctionParsedSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: str  # JSON string\n\n\nclass OpenAIFunctionCallSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: dict[str, Any]\n\n    @staticmethod\n    def from_openai_function_parsed_schema(\n        parsed_schema: OpenAIFunctionParsedSchema,\n    ) -> tuple[\"OpenAIFunctionCallSchema\", bool]:\n        has_decode_error = False\n        try:\n            arguments = json.loads(parsed_schema.arguments)\n        except json.JSONDecodeError:\n            arguments = {}\n            has_decode_error = True\n        # If the arguments is not a dict, it means the arguments is not a valid JSON string\n        if not isinstance(arguments, dict):\n            arguments = {}\n            has_decode_error = True\n\n        return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error\n\n\nclass OpenAIFunctionToolCall(BaseModel):\n    \"\"\"The tool call in OpenAI format.\"\"\"\n\n    id: str\n    type: Literal[\"function\"] = \"function\"\n    function: OpenAIFunctionCallSchema\n"
  },
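The decode-error flag in `OpenAIFunctionCallSchema.from_openai_function_parsed_schema` is what lets rollout code penalize or retry malformed tool calls instead of crashing: arguments that are not valid JSON, or that decode to something other than an object, come back as empty arguments with `has_decode_error=True`. A short demonstration:

```python
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema

ok, err = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(
    OpenAIFunctionParsedSchema(name="search", arguments='{"query_list": ["rec"]}')
)
assert err is False and ok.arguments == {"query_list": ["rec"]}

bad, err = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(
    OpenAIFunctionParsedSchema(name="search", arguments="not json")
)
assert err is True and bad.arguments == {}
```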
  {
    "path": "verl_rl/verl/tools/search_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport os\r\nimport threading\r\nfrom contextlib import ExitStack\r\nfrom enum import Enum\r\nfrom typing import Any, Callable, Optional, TypeVar\r\nfrom uuid import uuid4\r\n\r\nimport ray\r\nimport ray.actor\r\n\r\nfrom verl.tools.utils.search_r1_like_utils import perform_single_search_batch\r\nfrom verl.utils.rollout_trace import rollout_trace_op\r\n\r\nfrom .base_tool import BaseTool\r\nfrom .schemas import OpenAIFunctionToolSchema\r\n\r\nlogger = logging.getLogger(__name__)\r\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\r\n\r\nT = TypeVar(\"T\")\r\n\r\n\r\n# Adapted from verl/tools/sandbox_fusion_tools.py\r\nclass PoolMode(Enum):\r\n    \"\"\"Execution pool mode enumeration.\"\"\"\r\n\r\n    ThreadMode = 1\r\n    ProcessMode = 2\r\n\r\n\r\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\r\nclass TokenBucketWorker:\r\n    \"\"\"Ray actor for rate limiting using token bucket algorithm.\"\"\"\r\n\r\n    def __init__(self, rate_limit: int):\r\n        self.rate_limit = rate_limit\r\n        self.current_count = 0  # For observability\r\n        self._semaphore = threading.Semaphore(rate_limit)\r\n\r\n    @ray.method(concurrency_group=\"acquire\")\r\n    def acquire(self):\r\n        \"\"\"Acquire a token from the bucket.\"\"\"\r\n        self._semaphore.acquire()\r\n        self.current_count += 1\r\n\r\n    @ray.method(concurrency_group=\"release\")\r\n    def release(self):\r\n        \"\"\"Release a token back to the bucket.\"\"\"\r\n        self._semaphore.release()\r\n        self.current_count -= 1\r\n\r\n    def get_current_count(self):\r\n        \"\"\"Get current number of acquired tokens.\"\"\"\r\n        return self.current_count\r\n\r\n\r\nclass SearchExecutionWorker:\r\n    \"\"\"Worker for executing search operations with optional rate limiting.\"\"\"\r\n\r\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\r\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\r\n\r\n    def _init_rate_limit(self, rate_limit):\r\n        \"\"\"Initialize singleton rate limiter.\"\"\"\r\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\r\n\r\n    def ping(self):\r\n        \"\"\"Health check method.\"\"\"\r\n        return True\r\n\r\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\r\n        \"\"\"Execute function with optional rate limiting.\"\"\"\r\n        if self.rate_limit_worker:\r\n            with ExitStack() as stack:\r\n                stack.callback(self.rate_limit_worker.release.remote)\r\n                ray.get(self.rate_limit_worker.acquire.remote())\r\n                try:\r\n                    return fn(*fn_args, **fn_kwargs)\r\n                except Exception as e:\r\n 
                   # TODO we should make this available to the tool caller\r\n                    logger.warning(f\"Error when executing search: {e}\")\r\n        else:\r\n            return fn(*fn_args, **fn_kwargs)\r\n\r\n\r\ndef init_search_execution_pool(\r\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\r\n):\r\n    \"\"\"Initialize search execution pool.\"\"\"\r\n    if mode == PoolMode.ThreadMode:\r\n        return (\r\n            ray.remote(SearchExecutionWorker)\r\n            .options(max_concurrency=num_workers)\r\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\r\n        )\r\n    else:\r\n        raise NotImplementedError(\"Process mode is not implemented yet\")\r\n\r\n\r\nclass SearchTool(BaseTool):\r\n    \"\"\"Search tool for retrieving information using external retrieval services.\r\n\r\n    This tool provides search functionality with rate limiting and concurrent execution\r\n    support through Ray. It integrates with external retrieval services to perform\r\n    semantic search operations.\r\n\r\n    Methods:\r\n        get_openai_tool_schema: Return the tool schema in OpenAI format\r\n        create: Create a tool instance for a trajectory\r\n        execute: Execute the search tool\r\n        calc_reward: Calculate the reward with respect to tool state\r\n        release: Release the tool instance\r\n    \"\"\"\r\n\r\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\r\n        \"\"\"Initialize SearchTool with configuration and schema.\r\n\r\n        Args:\r\n            config: Configuration dictionary containing tool settings\r\n            tool_schema: OpenAI function tool schema definition\r\n\r\n        Example tool_schema:\r\n            {\r\n                \"type\": \"function\",\r\n                \"function\": {\r\n                    \"name\": \"search\",\r\n                    \"description\": \"Searches for relevant information based on queries.\",\r\n                    \"parameters\": {\r\n                        \"type\": \"object\",\r\n                        \"properties\": {\r\n                            \"query_list\": {\r\n                                \"type\": \"array\",\r\n                                \"items\": {\"type\": \"string\"},\r\n                                \"description\": \"List of search queries\"\r\n                            }\r\n                        },\r\n                        \"required\": [\"query_list\"]\r\n                    }\r\n                }\r\n            }\r\n        \"\"\"\r\n        super().__init__(config, tool_schema)\r\n        self._instance_dict = {}\r\n\r\n        # Worker and rate limiting configuration\r\n        self.num_workers = config.get(\"num_workers\", 120)\r\n        self.rate_limit = config.get(\"rate_limit\", 120)\r\n        self.timeout = config.get(\"timeout\", 30)\r\n\r\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\r\n        self.execution_pool = init_search_execution_pool(\r\n            num_workers=self.num_workers,\r\n            enable_global_rate_limit=self.enable_global_rate_limit,\r\n            rate_limit=self.rate_limit,\r\n            mode=PoolMode.ThreadMode,\r\n        )\r\n\r\n        # Retrieval service configuration\r\n        self.retrieval_service_url = config.get(\"retrieval_service_url\")\r\n        assert self.retrieval_service_url, \"Configuration must include 
'retrieval_service_url'\"\r\n        self.topk = config.get(\"topk\", 3)\r\n        if self.retrieval_service_url == \"\":\r\n            raise ValueError(\"retrieval_service_url is not set\")\r\n\r\n        logger.info(f\"Initialized SearchTool with config: {config}\")\r\n\r\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\r\n        \"\"\"Return the OpenAI tool schema.\"\"\"\r\n        return self.tool_schema\r\n\r\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\r\n        \"\"\"Create a tool instance.\r\n\r\n        Args:\r\n            instance_id: The instance id of the tool.\r\n\r\n        Returns:\r\n            The instance id of the tool.\r\n        \"\"\"\r\n        if instance_id is None:\r\n            instance_id = str(uuid4())\r\n        self._instance_dict[instance_id] = {\r\n            \"response\": \"\",\r\n            \"reward\": [],\r\n        }\r\n        return instance_id\r\n\r\n    def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int):\r\n        \"\"\"Execute search operation using retrieval service.\r\n\r\n        Args:\r\n            instance_id: Tool instance ID\r\n            query_list: List of search queries\r\n            retrieval_service_url: URL of the retrieval service\r\n            topk: Number of top results to return\r\n            timeout: Request timeout in seconds\r\n\r\n        Returns:\r\n            Tuple of (result_text, metadata)\r\n        \"\"\"\r\n        result_text, metadata = perform_single_search_batch(\r\n            retrieval_service_url=retrieval_service_url,\r\n            query_list=query_list,\r\n            topk=topk,\r\n            concurrent_semaphore=None,  # Ray handles concurrency control\r\n            timeout=timeout,\r\n        )\r\n        logger.debug(f\"Search result for instance {instance_id}: {result_text}\")\r\n        return result_text, metadata\r\n\r\n    @rollout_trace_op\r\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:\r\n        \"\"\"Execute the search tool.\r\n\r\n        Args:\r\n            instance_id: The instance ID of the tool\r\n            parameters: Tool parameters containing query_list and optional timeout\r\n\r\n        Returns: tool_response, tool_reward_score, tool_metrics\r\n            tool_response: The response str of the tool.\r\n            tool_reward_score: The step reward score of the tool.\r\n            tool_metrics: The metrics of the tool.\r\n        \"\"\"\r\n        timeout = self.timeout\r\n        query_list_from_params = parameters.get(\"query_list\")\r\n\r\n        if not query_list_from_params or not isinstance(query_list_from_params, list):\r\n            error_msg = \"Error: 'query_list' is missing, empty, or not a list in parameters.\"\r\n            logger.error(f\"[SearchTool] {error_msg} Received parameters: {parameters}\")\r\n            return json.dumps({\"result\": error_msg}), 0.0, {}\r\n\r\n        # Execute search using Ray execution pool\r\n        try:\r\n            result_text, metadata = await self.execution_pool.execute.remote(\r\n                self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout\r\n            )\r\n\r\n            # Store results in instance dictionary\r\n            self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\r\n\r\n            # Convert metadata to metrics\r\n            metrics = 
{\r\n                \"query_count\": metadata.get(\"query_count\", 0),\r\n                \"status\": metadata.get(\"status\", \"unknown\"),\r\n                \"total_results\": metadata.get(\"total_results\", 0),\r\n                \"api_request_error\": metadata.get(\"api_request_error\"),\r\n            }\r\n\r\n            return result_text, 0.0, metrics\r\n\r\n        except Exception as e:\r\n            error_result = json.dumps({\"result\": f\"Search execution failed: {e}\"})\r\n            logger.error(f\"[SearchTool] Execution failed: {e}\")\r\n            return error_result, 0.0, {\"error\": str(e)}\r\n\r\n    async def calc_reward(self, instance_id: str, **kwargs) -> str:\r\n        return self._instance_dict[instance_id][\"reward\"]\r\n\r\n    async def release(self, instance_id: str, **kwargs) -> None:\r\n        if instance_id in self._instance_dict:\r\n            del self._instance_dict[instance_id]\r\n"
  },
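A note on the `SearchTool` above: it follows verl's async tool lifecycle of `create` → `execute` → `calc_reward` → `release`. The sketch below is a hypothetical driver for that lifecycle; the `tool` argument is assumed to be an already-constructed `SearchTool` (its config must at least provide `retrieval_service_url`), since in training the rollout worker normally performs these calls between model turns.

```python
# Hypothetical driver for the SearchTool lifecycle above (create ->
# execute -> calc_reward -> release). `tool` is assumed to be an
# initialized SearchTool instance; in training the rollout worker
# drives these steps between model turns.
import asyncio

async def run_one_search(tool):
    instance_id = await tool.create()  # allocate per-rollout state
    response, step_reward, metrics = await tool.execute(
        instance_id,
        {"query_list": ["what is generative recommendation?"]},
    )
    print(metrics.get("status"))       # e.g. "success" or "api_error"
    stored = await tool.calc_reward(instance_id)  # accumulated result texts
    await tool.release(instance_id)    # drop per-instance state
    return response, stored

# asyncio.run(run_one_search(tool))
```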
  {
    "path": "verl_rl/verl/tools/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/tools/utils/mcp_clients/McpClientManager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport asyncio\r\nimport json\r\nimport logging\r\nfrom typing import Any\r\n\r\nfrom fastmcp import Client\r\nfrom fastmcp.client.transports import SSETransport\r\n\r\nfrom verl.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\nclass MCPClientManager:\r\n    rootServerName = \"mcpServers\"\r\n    initialized = False\r\n    clients = []\r\n    tool_client_mapping = {}\r\n    rate_limiter = None\r\n\r\n    async def initialize(self, config_path, rate_limit: float = 10.0):\r\n        if self.initialized:\r\n            return\r\n        \"\"\"Initialize the MCP Client Manager and start all clients\"\"\"\r\n        result = self._load_config(config_path)\r\n        servers = result[self.rootServerName]\r\n        exclude_sse_servers = {self.rootServerName: {}}\r\n        for server_name in servers.keys():\r\n            server = servers[server_name]\r\n            if \"auth_token\" in server:\r\n                transport = SSETransport(url=server[\"url\"], headers={\"Authorization\": f\"Bearer {server['auth_token']}\"})\r\n                client = Client(transport)\r\n                self.clients.append(client)\r\n            else:\r\n                exclude_sse_servers[self.rootServerName][server_name] = server\r\n\r\n        if exclude_sse_servers[self.rootServerName]:\r\n            self.clients.append(Client(exclude_sse_servers))\r\n\r\n        # Initialize rate limiter\r\n        self.rate_limiter = TokenBucket(rate_limit)\r\n        self.initialized = True\r\n\r\n    async def call_tool(self, tool_name, parameters, timeout):\r\n        # Apply rate limiting\r\n        while not self.rate_limiter.acquire():\r\n            await asyncio.sleep(0.1)\r\n\r\n        client = self.get_client_with_tool_name(tool_name)\r\n        async with client:\r\n            return await client.call_tool_mcp(tool_name, parameters)\r\n\r\n    async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]:\r\n        tool_schemas = []\r\n        for client in self.clients:\r\n            async with client:\r\n                tools = await client.list_tools_mcp()\r\n                for tool in tools.tools:\r\n                    if not tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n                    elif tool.name in tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n\r\n        return tool_schemas\r\n\r\n    def get_client_with_tool_name(self, tool_name: str):\r\n        return self.tool_client_mapping[tool_name]\r\n\r\n    def _load_config(self, file: str) -> dict[str, Any]:\r\n        try:\r\n            with open(file) as f:\r\n                return 
json.load(f)\r\n        except FileNotFoundError:\r\n            logger.warning(f'the \"{file}\" file was not found')\r\n        except Exception:\r\n            logger.error(f'there was an error reading the \"{file}\" file')\r\n\r\n        return {}\r\n\r\n\r\nClientManager = MCPClientManager()\r\n"
  },
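For reference, `MCPClientManager.initialize` expects a JSON file whose root key is `mcpServers`; any server entry carrying an `auth_token` gets its own SSE client, and the remaining entries are bundled into a single shared `fastmcp.Client`. The config below is a made-up example (server names, URL, token, and the local command are all placeholders):

```python
# Illustrative mcp_servers config for MCPClientManager.initialize().
# All names, URLs, tokens, and commands below are placeholders.
import json
import tempfile

config = {
    "mcpServers": {
        # has auth_token -> dedicated SSETransport client
        "remote_search": {
            "url": "https://example.com/mcp/sse",
            "auth_token": "<token>",
        },
        # no auth_token -> grouped into one shared Client
        "local_tools": {"command": "python", "args": ["-m", "my_mcp_server"]},
    }
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)

# from verl.tools.utils.mcp_clients.McpClientManager import ClientManager
# await ClientManager.initialize(f.name, rate_limit=10.0)
```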
  {
    "path": "verl_rl/verl/tools/utils/mcp_clients/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport threading\nimport time\n\nfrom mcp import Tool\n\nlogger = logging.getLogger(__file__)\n\n\nclass TokenBucket:\n    def __init__(self, rate_limit: float):\n        self.rate_limit = rate_limit  # tokens per second\n        self.tokens = rate_limit\n        self.last_update = time.time()\n        self.lock = threading.Lock()\n\n    def acquire(self) -> bool:\n        with self.lock:\n            now = time.time()\n            # Add new tokens based on time elapsed\n            new_tokens = (now - self.last_update) * self.rate_limit\n            self.tokens = min(self.rate_limit, self.tokens + new_tokens)\n            self.last_update = now\n\n            if self.tokens >= 1:\n                self.tokens -= 1\n                return True\n            return False\n\n\ndef mcp2openai(mcp_tool: Tool) -> dict:\n    \"\"\"Convert a MCP Tool to an OpenAI ChatCompletionTool.\"\"\"\n    openai_format = {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": mcp_tool.name,\n            \"description\": mcp_tool.description,\n            \"parameters\": mcp_tool.inputSchema,\n            \"strict\": False,\n        },\n    }\n    if not openai_format[\"function\"][\"parameters\"].get(\"required\", None):\n        openai_format[\"function\"][\"parameters\"][\"required\"] = []\n    return openai_format\n"
  },
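`TokenBucket` above is a plain token-bucket limiter: tokens refill continuously at `rate_limit` per second up to a burst capacity of `rate_limit`, and each `acquire()` either spends one token or returns `False` without blocking (callers such as `call_tool` poll it with a short `asyncio.sleep`). A toy demonstration:

```python
# Toy demonstration of the TokenBucket semantics (not part of verl).
import time

from verl.tools.utils.mcp_clients.utils import TokenBucket

bucket = TokenBucket(rate_limit=2.0)  # 2 tokens/second, burst of 2

granted = sum(bucket.acquire() for _ in range(5))
print(granted)           # 2 (modulo timing jitter): the burst is exhausted

time.sleep(1.0)          # ~2 tokens refill while we wait
print(bucket.acquire())  # True again
```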
  {
    "path": "verl_rl/verl/tools/utils/search_r1_like_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport threading\r\nimport time\r\nimport traceback\r\nimport uuid\r\nfrom typing import Any, Optional\r\n\r\nimport requests\r\n\r\nDEFAULT_TIMEOUT = 30  # Default search request timeout\r\nMAX_RETRIES = 10\r\nINITIAL_RETRY_DELAY = 1\r\nAPI_TIMEOUT = 10\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\ndef call_search_api(\r\n    retrieval_service_url: str,\r\n    query_list: list[str],\r\n    topk: int = 3,\r\n    return_scores: bool = True,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> tuple[Optional[dict[str, Any]], Optional[str]]:\r\n    \"\"\"\r\n    Calls the remote search API to perform retrieval with retry logic for various errors,\r\n    using increasing delay between retries. Logs internal calls with a unique ID.\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        return_scores: Whether to return scores.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (response_json, error_message).\r\n        If successful, response_json is the API's returned JSON object, error_message is None.\r\n        If failed after retries, response_json is None, error_message contains the error information.\r\n    \"\"\"\r\n    request_id = str(uuid.uuid4())\r\n    log_prefix = f\"[Search Request ID: {request_id}] \"\r\n\r\n    payload = {\"queries\": query_list, \"topk\": topk, \"return_scores\": return_scores}\r\n\r\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\r\n\r\n    last_error = None\r\n\r\n    for attempt in range(MAX_RETRIES):\r\n        try:\r\n            logger.info(\r\n                f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}\"\r\n            )\r\n            response = requests.post(\r\n                retrieval_service_url,\r\n                headers=headers,\r\n                json=payload,\r\n                timeout=timeout,\r\n            )\r\n\r\n            # Check for Gateway Timeout (504) and other server errors for retrying\r\n            if response.status_code in [500, 502, 503, 504]:\r\n                last_error = (\r\n                    f\"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt \"\r\n                    f\"{attempt + 1}/{MAX_RETRIES}\"\r\n                )\r\n                logger.warning(last_error)\r\n                if attempt < MAX_RETRIES - 1:\r\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                    time.sleep(delay)\r\n                continue\r\n\r\n            # Check for other HTTP errors (e.g., 4xx)\r\n 
           response.raise_for_status()\r\n\r\n            # If successful (status code 2xx)\r\n            logger.info(f\"{log_prefix}Search API call successful on attempt {attempt + 1}\")\r\n            return response.json(), None\r\n\r\n        except requests.exceptions.ConnectionError as e:\r\n            last_error = f\"{log_prefix}Connection Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.Timeout as e:\r\n            last_error = f\"{log_prefix}Timeout Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.RequestException as e:\r\n            last_error = f\"{log_prefix}API Request Error: {e}\"\r\n            break  # Exit retry loop on other request errors\r\n        except json.JSONDecodeError as e:\r\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\r\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}\"\r\n            break  # Exit retry loop on JSON decode errors\r\n        except Exception as e:\r\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"\r\n            break  # Exit retry loop on other unexpected errors\r\n\r\n    # If loop finishes without returning success, return the last recorded error\r\n    logger.error(f\"{log_prefix}Search API call failed. 
Last error: {last_error}\")\r\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\r\n\r\n\r\ndef _passages2string(retrieval_result):\r\n    \"\"\"Convert retrieval results to formatted string.\"\"\"\r\n    format_reference = \"\"\r\n    for idx, doc_item in enumerate(retrieval_result):\r\n        content = doc_item[\"document\"][\"contents\"]\r\n        title = content.split(\"\\n\")[0]\r\n        text = \"\\n\".join(content.split(\"\\n\")[1:])\r\n        format_reference += f\"Doc {idx + 1} (Title: {title})\\n{text}\\n\\n\"\r\n    return format_reference.strip()\r\n\r\n\r\ndef perform_single_search_batch(\r\n    retrieval_service_url: str,\r\n    query_list: list[str],\r\n    topk: int = 3,\r\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> tuple[str, dict[str, Any]]:\r\n    \"\"\"\r\n    Performs a single batch search for multiple queries (original search tool behavior).\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        concurrent_semaphore: Optional semaphore for concurrency control.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (result_text, metadata).\r\n        result_text: The search result JSON string.\r\n        metadata: Metadata dictionary for the batch search.\r\n    \"\"\"\r\n    logger.info(f\"Starting batch search for {len(query_list)} queries.\")\r\n\r\n    api_response = None\r\n    error_msg = None\r\n\r\n    try:\r\n        if concurrent_semaphore:\r\n            with concurrent_semaphore:\r\n                api_response, error_msg = call_search_api(\r\n                    retrieval_service_url=retrieval_service_url,\r\n                    query_list=query_list,\r\n                    topk=topk,\r\n                    return_scores=True,\r\n                    timeout=timeout,\r\n                )\r\n        else:\r\n            api_response, error_msg = call_search_api(\r\n                retrieval_service_url=retrieval_service_url,\r\n                query_list=query_list,\r\n                topk=topk,\r\n                return_scores=True,\r\n                timeout=timeout,\r\n            )\r\n    except Exception as e:\r\n        error_msg = f\"API Request Exception during batch search: {e}\"\r\n        logger.error(f\"Batch search: {error_msg}\")\r\n        traceback.print_exc()\r\n\r\n    metadata = {\r\n        \"query_count\": len(query_list),\r\n        \"queries\": query_list,\r\n        \"api_request_error\": error_msg,\r\n        \"api_response\": None,\r\n        \"status\": \"unknown\",\r\n        \"total_results\": 0,\r\n        \"formatted_result\": None,\r\n    }\r\n\r\n    result_text = json.dumps({\"result\": \"Search request failed or timed out after retries.\"})\r\n\r\n    if error_msg:\r\n        metadata[\"status\"] = \"api_error\"\r\n        result_text = json.dumps({\"result\": f\"Search error: {error_msg}\"})\r\n        logger.error(f\"Batch search: API error occurred: {error_msg}\")\r\n    elif api_response:\r\n        logger.debug(f\"Batch search: API Response: {api_response}\")\r\n        metadata[\"api_response\"] = api_response\r\n\r\n        try:\r\n            raw_results = api_response.get(\"result\", [])\r\n            if raw_results:\r\n                pretty_results = []\r\n                total_results = 0\r\n\r\n  
              for retrieval in raw_results:\r\n                    formatted = _passages2string(retrieval)\r\n                    pretty_results.append(formatted)\r\n                    total_results += len(retrieval) if isinstance(retrieval, list) else 1\r\n\r\n                final_result = \"\\n---\\n\".join(pretty_results)\r\n                result_text = json.dumps({\"result\": final_result})\r\n                metadata[\"status\"] = \"success\"\r\n                metadata[\"total_results\"] = total_results\r\n                metadata[\"formatted_result\"] = final_result\r\n                logger.info(f\"Batch search: Successful, got {total_results} total results\")\r\n            else:\r\n                result_text = json.dumps({\"result\": \"No search results found.\"})\r\n                metadata[\"status\"] = \"no_results\"\r\n                metadata[\"total_results\"] = 0\r\n                logger.info(\"Batch search: No results found\")\r\n        except Exception as e:\r\n            error_msg = f\"Error processing search results: {e}\"\r\n            result_text = json.dumps({\"result\": error_msg})\r\n            metadata[\"status\"] = \"processing_error\"\r\n            logger.error(f\"Batch search: {error_msg}\")\r\n    else:\r\n        metadata[\"status\"] = \"unknown_api_state\"\r\n        result_text = json.dumps({\"result\": \"Unknown API state (no response and no error message).\"})\r\n        logger.error(\"Batch search: Unknown API state.\")\r\n\r\n    return result_text, metadata\r\n"
  },
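Putting `search_r1_like_utils` together: `call_search_api` posts a `{"queries", "topk", "return_scores"}` payload to the service and retries 5xx, connection, and timeout failures with a linearly growing delay (1s, 2s, ...), while `perform_single_search_batch` wraps it and formats each hit list into `Doc i (Title: ...)` blocks. A usage sketch; the local URL is an assumption, and any server accepting that payload works:

```python
# Sketch: querying a retrieval service with the batch helper above.
# The URL is an assumption for illustration.
from verl.tools.utils.search_r1_like_utils import perform_single_search_batch

result_text, metadata = perform_single_search_batch(
    retrieval_service_url="http://127.0.0.1:8000/retrieve",
    query_list=["short video recommendation", "ad ranking"],
    topk=3,
)
print(metadata["status"])  # "success", "no_results", or an error state
print(result_text[:200])   # JSON string: {"result": "Doc 1 (Title: ...)..."}
```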
  {
    "path": "verl_rl/verl/tools/utils/tool_registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport importlib\nimport logging\nimport os\nimport sys\nfrom enum import Enum\n\nfrom omegaconf import OmegaConf\n\nfrom verl.tools.schemas import OpenAIFunctionToolSchema\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass ToolType(Enum):\n    NATIVE = \"native\"\n    MCP = \"mcp\"\n\n\nasync def initialize_mcp_tool(tool_cls, tool_config) -> list:\n    from verl.tools.utils.mcp_clients.McpClientManager import ClientManager\n\n    tool_list = []\n    mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path\n    tool_selected_list = tool_config.mcp.tool_selected_list if \"tool_selected_list\" in tool_config.mcp else None\n    await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit)\n    # Wait for MCP client to be ready\n    max_retries = 10\n    retry_interval = 2  # seconds\n    for i in range(max_retries):\n        tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list)\n        if tool_schemas:\n            break\n        if i < max_retries - 1:\n            logger.debug(f\"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}\")\n            await asyncio.sleep(retry_interval)\n    else:\n        raise RuntimeError(\"Failed to initialize MCP tools after maximum retries\")\n    # mcp registry\n    assert len(tool_schemas), \"mcp tool is empty\"\n    for tool_schema_dict in tool_schemas:\n        logger.debug(f\"tool_schema_dict: {tool_schema_dict}\")\n        tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n        tool = tool_cls(\n            config=OmegaConf.to_container(tool_config.config, resolve=True),\n            tool_schema=tool_schema,\n        )\n        tool_list.append(tool)\n    return tool_list\n\n\ndef get_tool_class(cls_name):\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    tool_cls = getattr(module, class_name)\n    return tool_cls\n\n\ndef initialize_tools_from_config(tools_config_file):\n    tools_config = OmegaConf.load(tools_config_file)\n    tool_list = []\n    for tool_config in tools_config.tools:\n        cls_name = tool_config.class_name\n        tool_type = ToolType(tool_config.config.type)\n        tool_cls = get_tool_class(cls_name)\n\n        match tool_type:\n            case ToolType.NATIVE:\n                if tool_config.get(\"tool_schema\", None) is None:\n                    tool_schema = None\n                else:\n                    tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)\n                    tool_schema = 
OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n                tool = tool_cls(\n                    config=OmegaConf.to_container(tool_config.config, resolve=True),\n                    tool_schema=tool_schema,\n                )\n                tool_list.append(tool)\n            case ToolType.MCP:\n                loop = asyncio.get_event_loop()\n                mcp_tools = loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config))\n                tool_list.extend(mcp_tools)\n            case _:\n                raise NotImplementedError\n    return tool_list\n"
  },
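`initialize_tools_from_config` reads a YAML file with a top-level `tools` list, where each entry names a class and a `config` whose `type` selects the native or MCP path. A minimal native-tool config might look like the sketch below; the class path, URL, and schema fields are assumptions matching the `SearchTool` shown earlier, not shipped defaults:

```python
# Illustrative tools config for initialize_tools_from_config(); the
# class path, URL, and schema below are assumptions for the sketch.
import tempfile

TOOLS_YAML = """
tools:
  - class_name: verl.tools.search_tool.SearchTool
    config:
      type: native
      retrieval_service_url: http://127.0.0.1:8000/retrieve
      topk: 3
    tool_schema:
      type: function
      function:
        name: search
        description: Batch search over an external corpus.
        parameters:
          type: object
          properties:
            query_list:
              type: array
              items:
                type: string
          required: [query_list]
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(TOOLS_YAML)

# from verl.tools.utils.tool_registry import initialize_tools_from_config
# tools = initialize_tools_from_config(f.name)
```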
  {
    "path": "verl_rl/verl/trainer/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig\nfrom .config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig\n\n__all__ = [\n    \"AlgoConfig\",\n    \"CriticConfig\",\n    \"FilterGroupsConfig\",\n    \"FSDPCriticConfig\",\n    \"KLControlConfig\",\n    \"MegatronCriticConfig\",\n    \"PFPPOConfig\",\n]\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/_generated_ppo_megatron_trainer.yaml",
    "content": "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\n# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file.\n# Do not modify this file directly.\n# The file is usually only for reference and never used.\n\nactor_rollout_ref:\n  actor:\n    strategy: megatron\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: false\n    ppo_max_token_len_per_gpu: 16384\n    clip_ratio: 0.2\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    policy_loss:\n      loss_mode: vanilla\n      clip_cov_ratio: 0.0002\n      clip_cov_lb: 1.0\n      clip_cov_ub: 5.0\n      kl_cov_ratio: 0.0002\n      ppo_kl_coef: 0.1\n    clip_ratio_c: 3.0\n    loss_agg_mode: token-mean\n    entropy_coeff: 0\n    use_kl_loss: false\n    use_torch_compile: true\n    kl_loss_coef: 0.001\n    kl_loss_type: low_var_kl\n    ppo_epochs: 1\n    shuffle: false\n    checkpoint:\n      save_contents:\n      - model\n      - optimizer\n      - extra\n      load_contents: ${.save_contents}\n      async_save: false\n    optim:\n      lr: 1.0e-06\n      lr_warmup_steps_ratio: 0.0\n      total_training_steps: -1\n      weight_decay: 0.01\n      optimizer: adam\n      clip_grad: 1.0\n      lr_warmup_init: 0.0\n      lr_warmup_steps: null\n      lr_decay_steps: null\n      lr_decay_style: constant\n      min_lr: 0.0\n      weight_decay_incr_style: constant\n      lr_wsd_decay_style: exponential\n      lr_wsd_decay_steps: null\n      use_checkpoint_opt_param_scheduler: false\n    data_loader_seed: null\n    load_weight: true\n    megatron:\n      param_offload: false\n      grad_offload: false\n      optimizer_offload: false\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: null\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null\n      context_parallel_size: 1\n      sequence_parallel: true\n      use_distributed_optimizer: true\n      use_dist_checkpointing: false\n      dist_checkpointing_path: null\n      seed: 42\n      override_ddp_config: {}\n      override_transformer_config:\n        recompute_granularity: null\n        recompute_modules:\n        - core_attn\n        recompute_method: null\n        recompute_num_layers: null\n      use_mbridge: false\n    profile:\n      use_profile: false\n      profile_ranks: null\n      step_start: -1\n      step_end: -1\n      save_path: null\n  ref:\n    strategy: megatron\n    use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    megatron:\n      param_offload: false\n      tensor_model_parallel_size: 1\n      expert_model_parallel_size: 1\n      expert_tensor_parallel_size: None\n      pipeline_model_parallel_size: 1\n      virtual_pipeline_model_parallel_size: null\n      context_parallel_size: 1\n      sequence_parallel: true\n      use_distributed_optimizer: false\n      use_dist_checkpointing: false\n      dist_checkpointing_path: null\n      seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n      
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n      use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n    profile:\n      use_profile: false\n      profile_ranks: null\n      step_start: -1\n      step_end: -1\n      save_path: null\n    load_weight: true\n  rollout:\n    name: vllm\n    mode: sync\n    temperature: 1.0\n    top_k: -1\n    top_p: 1\n    prompt_length: ${oc.select:data.max_prompt_length,512}\n    response_length: ${oc.select:data.max_response_length,512}\n    dtype: bfloat16\n    gpu_memory_utilization: 0.5\n    ignore_eos: false\n    enforce_eager: true\n    free_cache_engine: true\n    tensor_model_parallel_size: 1\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    disable_log_stats: true\n    do_sample: true\n    'n': 1\n    multi_stage_wake_up: false\n    engine_kwargs:\n      vllm:\n        swap_space: null\n        disable_mm_preprocessor_cache: false\n      sglang:\n        attention_backend: null\n    val_kwargs:\n      top_k: -1\n      top_p: 1.0\n      temperature: 0\n      'n': 1\n      do_sample: false\n    multi_turn:\n      enable: false\n      max_assistant_turns: null\n      tool_config_path: null\n      max_user_turns: null\n      max_parallel_calls: 1\n      max_tool_response_length: 256\n      tool_response_truncate_side: middle\n      interaction_config_path: null\n      completion_callback: null\n      use_inference_chat_template: false\n      tokenization_sanity_check_mode: strict\n      format: hermes\n    calculate_log_probs: false\n    agent:\n      num_workers: 8\n      agent_loop_config_path: null\n      custom_async_server:\n        path: null\n        name: null\n    update_weights_bucket_megabytes: 512\n    trace:\n      backend: null\n      token2text: false\n    enable_chunked_prefill: false\n    load_format: dummy_megatron\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n  hybrid_engine: true\n  nccl_timeout: 600\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    custom_chat_template: null\n    external_lib: null\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: false\n    use_fused_kernels: false\n    trust_remote_code: false\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\ntrainer:\n  npu_profile:\n    options:\n      save_path: ./profiler_data\n      level: level1\n      with_memory: false\n      record_shapes: false\n      with_npu: true\n      with_cpu: true\n      with_module: false\n      with_stack: false\n      analysis: true\n  balance_batch: true\n  total_epochs: 30\n  total_training_steps: null\n  profile_steps: null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger:\n  - console\n  - wandb\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n  resume_mode: auto\n  resume_from_path: null\n  del_local_ckpt_after_load: false\n  val_before_train: true\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  
max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  controller_nsight_options:\n    trace: cuda,nvtx,cublas,ucx\n    cuda-memory-usage: 'true'\n    cuda-graph-trace: graph\n  worker_nsight_options:\n    trace: cuda,nvtx,cublas,ucx\n    cuda-memory-usage: 'true'\n    cuda-graph-trace: graph\n    capture-range: cudaProfilerApi\n    capture-range-end: null\n    kill: none\ndata:\n  tokenizer: null\n  use_shm: false\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null\n  return_raw_input_ids: false\n  return_raw_chat: false\n  return_full_prompt: false\n  shuffle: true\n  dataloader_num_workers: 8\n  validation_shuffle: false\n  filter_overlong_prompts: false\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  image_key: images\n  video_key: videos\n  trust_remote_code: false\n  custom_cls:\n    path: null\n    name: null\n  return_multi_modal_inputs: true\n  sampler:\n    class_path: null\n    class_name: null\n  datagen:\n    path: null\n    name: null\ncritic:\n  rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n  strategy: megatron\n  optim:\n    lr_warmup_steps_ratio: 0.0\n    total_training_steps: -1\n    weight_decay: 0.01\n    optimizer: adam\n    lr: 1.0e-06\n    clip_grad: 1.0\n    lr_warmup_init: 0.0\n    lr_warmup_steps: null\n    lr_decay_steps: null\n    lr_decay_style: linear\n    min_lr: 0.0\n    weight_decay_incr_style: constant\n    lr_wsd_decay_style: exponential\n    lr_wsd_decay_steps: null\n    use_checkpoint_opt_param_scheduler: false\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n    override_config:\n      model_config: {}\n      moe_config:\n        freeze_moe_router: false\n    external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n    trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n  ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n  ppo_micro_batch_size: null\n  ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n  use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n  ppo_max_token_len_per_gpu: 32768\n  forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n  shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n  cliprange_value: 0.5\n  loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n  checkpoint:\n    save_contents:\n    - model\n    - optimizer\n    - extra\n    load_contents: ${.save_contents}\n    async_save: false\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\n  _target_: verl.trainer.config.MegatronCriticConfig\n  nccl_timeout: 600\n  megatron:\n    param_offload: false\n    grad_offload: false\n    optimizer_offload: false\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null\n    context_parallel_size: 1\n    sequence_parallel: true\n    use_distributed_optimizer: true\n    use_dist_checkpointing: false\n    dist_checkpointing_path: null\n 
   seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n    override_ddp_config: ${oc.select:actor_rollout_ref.actor.megatron.override_ddp_config,{}}\n    override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n    use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n  load_weight: true\n  data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}\nreward_model:\n  enable: false\n  strategy: megatron\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: false\n  micro_batch_size: null\n  micro_batch_size_per_gpu: null\n  max_length: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n  launch_reward_fn_async: false\n  sandbox_fusion:\n    url: null\n    max_concurrent: 64\n    memory_limit_mb: 1024\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\n  nccl_timeout: 600\n  megatron:\n    param_offload: false\n    tensor_model_parallel_size: 1\n    expert_model_parallel_size: 1\n    expert_tensor_parallel_size: null\n    pipeline_model_parallel_size: 1\n    virtual_pipeline_model_parallel_size: null\n    context_parallel_size: 1\n    sequence_parallel: true\n    use_distributed_optimizer: false\n    use_dist_checkpointing: false\n    dist_checkpointing_path: null\n    seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n    override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n    use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n  load_weight: true\ncustom_reward_function:\n  path: null\n  name: compute_score\nalgorithm:\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: true\n  use_kl_in_reward: false\n  kl_penalty: kl\n  kl_ctrl:\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: false\n  pf_ppo:\n    _target_: verl.trainer.config.PFPPOConfig\n    reweight_method: pow\n    weight_pow: 2.0\nray_init:\n  num_cpus: null\n  timeline_json_file: null\n"
  },
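Many fields in the generated reference above are OmegaConf interpolations; `${oc.select:path,default}` resolves to the value at `path` when it exists and to the default otherwise, which is how the `ref`, `critic`, and `reward_model` sections inherit actor settings. A small self-contained demonstration (toy keys, requires `omegaconf>=2.1`):

```python
# Minimal demonstration of the ${oc.select:...} resolver used throughout
# the generated trainer configs. Keys here are toy stand-ins.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"use_dynamic_bsz": True}},
    "ref": {
        "log_prob_use_dynamic_bsz": "${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}",
        "seed": "${oc.select:actor_rollout_ref.actor.megatron.seed,42}",
    },
})
print(cfg.ref.log_prob_use_dynamic_bsz)  # True  (path exists)
print(cfg.ref.seed)                      # 42    (path missing -> default)
```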
  {
    "path": "verl_rl/verl/trainer/config/_generated_ppo_trainer.yaml",
    "content": "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'\n# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.\n# Do not modify this file directly.\n# The file is usually only for reference and never used.\n\nactor_rollout_ref:\n  actor:\n    strategy: fsdp\n    ppo_mini_batch_size: 256\n    ppo_micro_batch_size: null\n    ppo_micro_batch_size_per_gpu: null\n    use_dynamic_bsz: false\n    ppo_max_token_len_per_gpu: 16384\n    clip_ratio: 0.2\n    clip_ratio_low: 0.2\n    clip_ratio_high: 0.2\n    policy_loss:\n      loss_mode: vanilla\n      clip_cov_ratio: 0.0002\n      clip_cov_lb: 1.0\n      clip_cov_ub: 5.0\n      kl_cov_ratio: 0.0002\n      ppo_kl_coef: 0.1\n    clip_ratio_c: 3.0\n    loss_agg_mode: token-mean\n    entropy_coeff: 0\n    use_kl_loss: false\n    use_torch_compile: true\n    kl_loss_coef: 0.001\n    kl_loss_type: low_var_kl\n    ppo_epochs: 1\n    shuffle: false\n    checkpoint:\n      save_contents:\n      - model\n      - optimizer\n      - extra\n      load_contents: ${.save_contents}\n    optim:\n      lr: 1.0e-06\n      lr_warmup_steps_ratio: 0.0\n      total_training_steps: -1\n      weight_decay: 0.01\n      lr_warmup_steps: -1\n      min_lr_ratio: 0.0\n      num_cycles: 0.5\n      warmup_style: constant\n    grad_clip: 1.0\n    ulysses_sequence_parallel_size: 1\n    entropy_from_logits_with_chunking: false\n    entropy_checkpointing: false\n    fsdp_config:\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      optimizer_offload: false\n      offload_policy: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n  ref:\n    strategy: ${actor_rollout_ref.actor.strategy}\n    use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    fsdp_config:\n      param_offload: false\n      reshard_after_forward: true\n      forward_prefetch: false\n      wrap_policy:\n        min_num_params: 0\n    ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}\n    entropy_from_logits_with_chunking: false\n    entropy_checkpointing: false\n  rollout:\n    name: vllm\n    mode: sync\n    temperature: 1.0\n    top_k: -1\n    top_p: 1\n    prompt_length: ${oc.select:data.max_prompt_length,512}\n    response_length: ${oc.select:data.max_response_length,512}\n    dtype: bfloat16\n    gpu_memory_utilization: 0.5\n    ignore_eos: false\n    enforce_eager: true\n    free_cache_engine: true\n    tensor_model_parallel_size: 2\n    max_num_batched_tokens: 8192\n    max_model_len: null\n    max_num_seqs: 1024\n    log_prob_micro_batch_size: null\n    log_prob_micro_batch_size_per_gpu: null\n    log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n    log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n    disable_log_stats: true\n    do_sample: true\n    'n': 1\n    multi_stage_wake_up: false\n    engine_kwargs:\n      vllm:\n        swap_space: null\n        disable_mm_preprocessor_cache: false\n      sglang:\n        attention_backend: null\n  
  val_kwargs:\n      top_k: -1\n      top_p: 1.0\n      temperature: 0\n      'n': 1\n      do_sample: false\n    multi_turn:\n      enable: false\n      max_assistant_turns: null\n      tool_config_path: null\n      max_user_turns: null\n      max_parallel_calls: 1\n      max_tool_response_length: 256\n      tool_response_truncate_side: middle\n      interaction_config_path: null\n      completion_callback: null\n      use_inference_chat_template: false\n      tokenization_sanity_check_mode: strict\n      format: hermes\n    calculate_log_probs: false\n    agent:\n      num_workers: 8\n      agent_loop_config_path: null\n      custom_async_server:\n        path: null\n        name: null\n    update_weights_bucket_megabytes: 512\n    trace:\n      backend: null\n      token2text: false\n    enable_chunked_prefill: true\n    load_format: dummy_dtensor\n    layered_summon: false\n  hybrid_engine: true\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    custom_chat_template: null\n    use_shm: false\n    external_lib: null\n    override_config: {}\n    enable_gradient_checkpointing: true\n    enable_activation_offload: false\n    use_remove_padding: false\n    lora_rank: 0\n    lora_alpha: 16\n    target_modules: all-linear\n    exclude_modules: null\n    use_liger: false\n    use_fused_kernels: false\n    fused_kernel_options:\n      impl_backend: torch\n    trust_remote_code: false\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\ntrainer:\n  npu_profile:\n    options:\n      save_path: ./profiler_data\n      level: level1\n      with_memory: false\n      record_shapes: false\n      with_npu: true\n      with_cpu: true\n      with_module: false\n      with_stack: false\n      analysis: true\n  balance_batch: true\n  total_epochs: 30\n  total_training_steps: null\n  profile_steps: null\n  controller_nsight_options:\n    trace: cuda,nvtx,cublas,ucx\n    cuda-memory-usage: 'true'\n    cuda-graph-trace: graph\n  worker_nsight_options:\n    trace: cuda,nvtx,cublas,ucx\n    cuda-memory-usage: 'true'\n    cuda-graph-trace: graph\n    capture-range: cudaProfilerApi\n    capture-range-end: null\n    kill: none\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger:\n  - console\n  - wandb\n  log_val_generations: 0\n  rollout_data_dir: null\n  validation_data_dir: null\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n  resume_mode: auto\n  resume_from_path: null\n  val_before_train: true\n  val_only: false\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  del_local_ckpt_after_load: false\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  use_legacy_worker_impl: auto\ndata:\n  tokenizer: null\n  use_shm: false\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  prompt_key: prompt\n  reward_fn_key: data_source\n  max_prompt_length: 512\n  max_response_length: 512\n  train_batch_size: 1024\n  val_batch_size: null\n  return_raw_input_ids: false\n  return_raw_chat: false\n  return_full_prompt: false\n  shuffle: true\n  dataloader_num_workers: 8\n  validation_shuffle: false\n  filter_overlong_prompts: false\n  filter_overlong_prompts_workers: 1\n  truncation: error\n  image_key: images\n  video_key: videos\n  trust_remote_code: false\n  custom_cls:\n    path: null\n    name: 
null\n  return_multi_modal_inputs: true\n  sampler:\n    class_path: null\n    class_name: null\n  datagen:\n    path: null\n    name: null\ncritic:\n  rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n  strategy: fsdp\n  optim:\n    lr_warmup_steps_ratio: 0.0\n    total_training_steps: -1\n    weight_decay: 0.01\n    lr: 1.0e-05\n    min_lr_ratio: null\n    warmup_style: constant\n  model:\n    path: ~/models/deepseek-llm-7b-chat\n    tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n    override_config: {}\n    external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n    trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n    use_shm: false\n    enable_gradient_checkpointing: true\n    enable_activation_offload: false\n    use_remove_padding: false\n    fsdp_config:\n      param_offload: false\n      optimizer_offload: false\n      offload_policy: false\n      reshard_after_forward: true\n      wrap_policy:\n        min_num_params: 0\n      fsdp_size: -1\n      forward_prefetch: false\n    lora_rank: 0\n    lora_alpha: 16\n    target_modules: all-linear\n  ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n  ppo_micro_batch_size: null\n  ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n  use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n  ppo_max_token_len_per_gpu: 32768\n  forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n  ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n  shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n  cliprange_value: 0.5\n  loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n  checkpoint:\n    save_contents:\n    - model\n    - optimizer\n    - extra\n    load_contents: ${.save_contents}\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\n  _target_: verl.trainer.config.FSDPCriticConfig\n  forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}\n  forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}\n  ulysses_sequence_parallel_size: 1\n  grad_clip: 1.0\nreward_model:\n  enable: false\n  strategy: fsdp\n  model:\n    input_tokenizer: ${actor_rollout_ref.model.path}\n    path: ~/models/FsfairX-LLaMA3-RM-v0.1\n    external_lib: ${actor_rollout_ref.model.external_lib}\n    trust_remote_code: false\n    use_shm: false\n    use_remove_padding: false\n    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n    fsdp_config:\n      wrap_policy:\n        min_num_params: 0\n      param_offload: false\n      reshard_after_forward: true\n      fsdp_size: -1\n      forward_prefetch: false\n  micro_batch_size: null\n  micro_batch_size_per_gpu: null\n  max_length: null\n  use_dynamic_bsz: ${critic.use_dynamic_bsz}\n  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n  reward_manager: naive\n  launch_reward_fn_async: false\n  sandbox_fusion:\n    url: null\n    max_concurrent: 64\n    memory_limit_mb: 1024\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: false\n    all_ranks: false\n    ranks: []\n  ulysses_sequence_parallel_size: 1\ncustom_reward_function:\n  path: null\n  name: compute_score\nalgorithm:\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: true\n  use_kl_in_reward: false\n  kl_penalty: kl\n  
kl_ctrl:\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: false\n  pf_ppo:\n    _target_: verl.trainer.config.PFPPOConfig\n    reweight_method: pow\n    weight_pow: 2.0\nray_init:\n  num_cpus: null\n  timeline_json_file: null\n"
  },
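The batch-size knobs in both generated trainers form a fixed hierarchy: `data.train_batch_size` prompts are rolled out per iteration, the experience is split into `ppo_mini_batch_size` chunks (one optimizer update each), and every mini-batch is further divided per GPU into micro-batches for gradient accumulation. Back-of-the-envelope with the defaults above; the per-GPU micro batch size is an assumption, since its default is `null`:

```python
# Batch hierarchy arithmetic for the default generated config.
train_batch_size = 1024        # data.train_batch_size
ppo_mini_batch_size = 256      # actor_rollout_ref.actor.ppo_mini_batch_size
n_gpus = 8                     # trainer.nnodes * trainer.n_gpus_per_node
micro_per_gpu = 4              # ppo_micro_batch_size_per_gpu (assumed)

optimizer_updates = train_batch_size // ppo_mini_batch_size  # 4 per iteration
per_gpu_mini = ppo_mini_batch_size // n_gpus                 # 32 samples
grad_accum = per_gpu_mini // micro_per_gpu                   # 8 micro steps
print(optimizer_updates, per_gpu_mini, grad_accum)
```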
  {
    "path": "verl_rl/verl/trainer/config/actor/actor.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# the abstract actor configs\n# fsdp, fsdp2 or megatron. must be set.\nstrategy: ???\n\n# Split each sample into sub-batches of this size for PPO\nppo_mini_batch_size: 256\n\n# [Deprecated] Global micro batch size\nppo_micro_batch_size: null\n\n# Local per-GPU micro batch size\nppo_micro_batch_size_per_gpu: null\n\n# Whether to automatically adjust batch size at runtime\n# oc.select: the default val for ref.log_prob_use_dynamic_bsz\nuse_dynamic_bsz: false\n\n# Max tokens per GPU in one PPO batch; affects gradient accumulation\n# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}\n# oc.select: the default val for ref.log_prob_max_token_len_per_gpu\nppo_max_token_len_per_gpu: 16384\n\n# PPO clip ratio\nclip_ratio: 0.2\n\n# Lower bound for asymmetric clipping (used in dual-clip PPO)\nclip_ratio_low: 0.2\n\n# Upper bound for asymmetric clipping (used in dual-clip PPO)\nclip_ratio_high: 0.2\n\n# policy loss config\npolicy_loss:\n\n  # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617\n  loss_mode: \"vanilla\"\n\n  # Ratio of tokens to be clipped for clip-cov loss\n  clip_cov_ratio: 0.0002\n\n  # Lower bound for clip-cov loss\n  clip_cov_lb: 1.0\n\n  # Upper bound for clip-cov loss\n  clip_cov_ub: 5.0\n\n  # Ratio of tokens to be applied kl penalty for kl-cov loss\n  kl_cov_ratio: 0.0002\n\n  # KL divergence penalty coefficient\n  ppo_kl_coef: 0.1\n\n# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C\nclip_ratio_c: 3.0\n\n# Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\nloss_agg_mode: token-mean\n\n# Entropy regularization coefficient in PPO loss\nentropy_coeff: 0\n\n# Whether to use KL loss instead of KL reward penalty. True for GRPO\nuse_kl_loss: false\n\n# Whether to use torch.compile()\n# oc.select: the default val for ref.use_torch_compile\nuse_torch_compile: true\n\n# KL loss coefficient when use_kl_loss is enabled. For GRPO\nkl_loss_coef: 0.001\n\n# Type of KL divergence loss. Options: \"kl\"(k1), \"abs\", \"mse\"(k2), \"low_var_kl\"(k3), \"full\"\nkl_loss_type: low_var_kl\n\n# Number of PPO epochs per batch\nppo_epochs: 1\n\n# Shuffle training data across PPO epochs\nshuffle: false\n\n# checkpoint configs\ncheckpoint:\n\n  # What to include in saved checkpoints\n  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n  save_contents: ['model', 'optimizer', 'extra']\n\n  # For more flexibility, you can specify the contents to load from the checkpoint.\n  # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg\n  load_contents: ${.save_contents}\n\n# optimizer configs\noptim:\n\n  # Learning rate\n  lr: 1e-6\n\n  # Warmup steps ratio (used if lr_warmup_steps is negative)\n  lr_warmup_steps_ratio: 0.0\n\n  # Total training steps (must be overridden at runtime)\n  total_training_steps: -1\n\n  # Weight decay\n  weight_decay: 0.01\n"
  },
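On the `ppo_max_token_len_per_gpu` comment in `actor.yaml`: reading it as a budget of `n` packed sequences of up to `max_prompt_length + max_response_length` tokens each (the parenthesization in the comment is ambiguous, so this is our interpretation), the sizing works out as:

```python
# Token-budget sizing for use_dynamic_bsz, under the interpretation
# n * (max_prompt_length + max_response_length); n = 8 is an assumed
# rollout fan-out (e.g. GRPO-style sampling), not a shipped default.
max_prompt_length = 512
max_response_length = 512
n = 8

ppo_max_token_len_per_gpu = n * (max_prompt_length + max_response_length)
print(ppo_max_token_len_per_gpu)  # 8192, vs. the config default of 16384
```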
  {
    "path": "verl_rl/verl/trainer/config/actor/dp_actor.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/actor/actor.yaml\n  - actor\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# TODO(haibin.lin): switch to fsdp2\nstrategy: fsdp\n\n# Gradient clipping for actor updates, specific to the strategy.\ngrad_clip: 1.0\n\n# Sequence parallelism size for Ulysses-style model parallelism\n# oc.select: the default val for ref.ulysses_sequence_parallel_size\nulysses_sequence_parallel_size: 1\n\n# calculate entropy with chunking to reduce memory peak\nentropy_from_logits_with_chunking: False\n\n# recompute entropy\nentropy_checkpointing: False\n\n# optimizer configs\noptim:\n\n  # Warmup steps; negative value delegates to lr_warmup_steps_ratio\n  lr_warmup_steps: -1\n\n  # Minimum LR ratio for cosine schedule\n  min_lr_ratio: 0.0\n\n  # Number of cosine cycles in LR schedule\n  num_cycles: 0.5\n\n  # LR warmup style: \"constant\" or \"cosine\"\n  warmup_style: constant\n\n# configs for FSDP\nfsdp_config:\n\n  # policy for wrapping the model\n  wrap_policy:\n\n    # Minimum number of parameters to trigger wrapping a layer with FSDP\n    min_num_params: 0\n\n  # Whether to offload model parameters to CPU (trades speed for memory)\n  param_offload: false\n\n  # Whether to offload optimizer state to CPU\n  optimizer_offload: false\n\n  # Only for FSDP2: offload param/grad/optimizer during train\n  offload_policy: false\n\n  # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n  reshard_after_forward: true\n\n  # Number of GPUs in each FSDP shard group; -1 means auto\n  fsdp_size: -1\n\n  # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n  # before the current forward computation.\n  forward_prefetch: False\n"
  },
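The `optim` fields in `dp_actor.yaml` describe a conventional warmup-then-decay shape: a linear ramp over `lr_warmup_steps` (or `lr_warmup_steps_ratio` of total steps when that is negative), then either a constant hold or a cosine decay parameterized by `min_lr_ratio` and `num_cycles`. The sketch below illustrates that shape only; it is not verl's internal scheduler:

```python
# Illustrative warmup + cosine schedule matching the dp_actor optim
# fields; verl's real scheduler lives elsewhere in the codebase.
import math

def lr_at(step, total_steps, lr=1e-6, warmup_steps=100,
          warmup_style="cosine", min_lr_ratio=0.0, num_cycles=0.5):
    if step < warmup_steps:  # linear warmup
        return lr * step / max(1, warmup_steps)
    if warmup_style == "constant":  # hold after warmup
        return lr
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
    return lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)

print(lr_at(50, 1000))    # 5e-07: halfway through warmup
print(lr_at(1000, 1000))  # ~0.0: end of the half-cosine decay
```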
  {
    "path": "verl_rl/verl/trainer/config/actor/megatron_actor.yaml",
    "content": "# megatron actor config, inheriting from trainer/config/actor/actor.yaml\ndefaults:\n  - actor\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\ndata_loader_seed: null\n\nload_weight: True\n\ncheckpoint:\n\n  async_save: False\n\noptim:\n\n  optimizer: adam\n\n  clip_grad: 1.0\n\n  # initial learning rate for warmup, default to 0.0\n  lr_warmup_init: 0.0\n\n  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n  lr_warmup_steps: null\n\n  lr_decay_steps: null\n\n  # select from constant/linear/cosine/inverse_square_root\n  lr_decay_style: constant\n\n  # minimum learning rate, default to 0.0\n  min_lr: 0.0\n\n  # select from constant/linear/cosine\n  weight_decay_incr_style: constant\n\n  # select from constant/exponential/cosine\n  lr_wsd_decay_style: exponential\n\n  lr_wsd_decay_steps: null\n\n  # use checkpoint optimizer parameter scheduler\n  use_checkpoint_opt_param_scheduler: False\n\nmegatron:\n\n  param_offload: False\n\n  grad_offload: False\n\n  optimizer_offload: False\n\n  tensor_model_parallel_size: 1\n\n  expert_model_parallel_size: 1\n\n  expert_tensor_parallel_size: null\n\n  pipeline_model_parallel_size: 1\n\n  virtual_pipeline_model_parallel_size: null\n\n  context_parallel_size: 1\n\n  sequence_parallel: True\n\n  use_distributed_optimizer: True\n\n  use_dist_checkpointing: False\n\n  dist_checkpointing_path: null\n\n  # oc.select: default val for ref.megatron.seed\n  seed: 42\n\n  # Allow to override Distributed Data Parallel (DDP) config\n  override_ddp_config: {}\n\n  # additional transformer config like: num_layers_in_first(/last)_pipeline_stage\n  # oc.select: default val for ref.megatron.override_transformer_config\n  override_transformer_config:\n    # Recompute configuration, same as in megatron.training.arguments\n    # default use minimal performance-interference recompute methods\n    # Recompute granualarity, choices: [\"full\", \"selective\"]\n    recompute_granularity: null\n\n    # Recompute modules, multiple choices: [\"core_attn\", \"moe_act\", \"layernorm\", \"mla_up_proj\", \"mlp\", \"moe\"]\n    # Please use correct module in matched model\n    recompute_modules: [\"core_attn\"]\n\n    # 'uniform', 'block'\n    # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk\n    # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity\n    recompute_method: null\n\n    # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention\n    recompute_num_layers: null\n\n  # oc.select: default val for ref.megatron.use_mbridge\n  use_mbridge: False\n\n# profile the actor model in `update_policy` \nprofile:\n  # turn it on when you want to profile the actor model\n  use_profile: False\n\n  # list, you can specify the ranks to profile\n  profile_ranks: null\n\n  # start step in update_policy\n  step_start: -1\n\n  # end step\n  step_end: -1\n\n  # the path to save the profile result\n  save_path: null\n"
  },
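A sizing note on the `megatron` block: for a dense model the parallel sizes multiply with data parallelism to cover the world size, world = TP × PP × CP × DP (expert parallelism further partitions MoE layers inside that product). Quick arithmetic under assumed sizes, not the shipped defaults of 1:

```python
# World-size arithmetic for the megatron parallel sizes (dense case).
# All sizes below are assumptions for illustration.
tensor_model_parallel_size = 2    # TP
pipeline_model_parallel_size = 2  # PP
context_parallel_size = 1         # CP
world_size = 16                   # e.g. 2 nodes x 8 GPUs

model_parallel = (tensor_model_parallel_size
                  * pipeline_model_parallel_size
                  * context_parallel_size)
assert world_size % model_parallel == 0
dp = world_size // model_parallel
print(dp)  # 4 data-parallel replicas
```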
  {
    "path": "verl_rl/verl/trainer/config/algorithm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom verl.base_config import BaseConfig\n\n\n@dataclass\nclass KLControlConfig(BaseConfig):\n    \"\"\"Configuration for KL control.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        type (str): Type of KL control. Can be \"fixed\" or \"adaptive\".\n        kl_coef (float): Initial coefficient for KL penalty.\n        horizon (int): Horizon value for adaptive controller.\n        target_kl (float): Target KL divergence for adaptive controller.\n    \"\"\"\n\n    _frozen_fields = [\"type\", \"kl_coef\", \"horizon\", \"target_kl\"]\n    type: str = \"fixed\"\n    kl_coef: float = 0.001\n    horizon: int = 10000\n    target_kl: float = 0.1\n\n\n@dataclass\nclass PFPPOConfig(BaseConfig):\n    \"\"\"Configuration for preference feedback PPO.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        reweight_method (str): Method for reweighting samples. Can be \"pow\", \"max_min\", or \"max_random\".\n        weight_pow (float): Power used for weight scaling in \"pow\" method.\n    \"\"\"\n\n    _frozen_fields = [\"reweight_method\", \"weight_pow\"]\n    reweight_method: str = \"pow\"\n    weight_pow: float = 2.0\n\n\n@dataclass\nclass FilterGroupsConfig(BaseConfig):\n    \"\"\"Configuration for filter groups (used in DAPO and Entropy).\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        enable (bool): Whether to enable filter groups.\n        metric (Optional[str]): Metric to use for filtering: \"acc\", \"score\", \"seq_reward\", \"seq_final_reward\", etc.\n        max_num_gen_batches (int): Non-positive values mean no upper limit.\n    \"\"\"\n\n    _frozen_fields = [\"enable\", \"metric\", \"max_num_gen_batches\"]\n\n    enable: bool = False\n    metric: Optional[str] = None\n    max_num_gen_batches: int = 0\n\n\n@dataclass\nclass AlgoConfig(BaseConfig):\n    \"\"\"Configuration for the algorithm.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        gamma (float): Discount factor for future rewards.\n        lam (float): Trade-off between bias and variance in the GAE estimator.\n        adv_estimator (str): Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n        norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO).\n        use_kl_in_reward (bool): Whether to enable in-reward KL penalty.\n        kl_penalty (str): How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\".\n        kl_ctrl (KLControlConfig): KL control configuration.\n        use_pf_ppo (bool): Whether to enable preference feedback PPO.\n    
    pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings.\n        filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy\n    \"\"\"\n\n    _frozen_fields = [\n        \"gamma\",\n        \"lam\",\n        \"adv_estimator\",\n        \"norm_adv_by_std_in_grpo\",\n        \"use_kl_in_reward\",\n        \"kl_penalty\",\n        \"use_pf_ppo\",\n    ]\n\n    gamma: float = 1.0\n    lam: float = 1.0\n    adv_estimator: str = \"gae\"\n    norm_adv_by_std_in_grpo: bool = True\n    use_kl_in_reward: bool = False\n    kl_penalty: str = \"kl\"\n    kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)\n    use_pf_ppo: bool = False\n    pf_ppo: Optional[PFPPOConfig] = None\n    filter_groups: Optional[FilterGroupsConfig] = None\n"
  },
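  {
    "path": "verl_rl/examples_sketch/algo_config_usage.py",
    "content": "# Hypothetical usage sketch, NOT part of the upstream verl distribution: it only\n# illustrates how the AlgoConfig dataclass above composes with KLControlConfig.\n# The import path follows the `_target_: verl.trainer.config.AlgoConfig` entries\n# used throughout the trainer YAMLs; the field values are assumed example values.\nfrom verl.trainer.config import AlgoConfig, KLControlConfig\n\n# GRPO-style settings with an adaptive in-reward KL controller; field names and\n# defaults come from algorithm.py above.\nalgo = AlgoConfig(\n    adv_estimator=\"grpo\",\n    use_kl_in_reward=True,\n    kl_penalty=\"low_var_kl\",\n    kl_ctrl=KLControlConfig(type=\"adaptive\", kl_coef=0.001, target_kl=0.1, horizon=10000),\n)\n\n# BaseConfig advertises an omegaconf.DictConfig-like interface; plain attribute\n# access works as on any dataclass.\nprint(algo.kl_ctrl.target_kl)\n"
  },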
  {
    "path": "verl_rl/verl/trainer/config/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Optional\n\nfrom verl.base_config import BaseConfig\n\n\n@dataclass\nclass CriticConfig(BaseConfig):\n    \"\"\"Configuration for critic model training.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        rollout_n (int): Number of rollouts per update (mirrors actor rollout_n).\n        strategy (str): Strategy used for critic model training (fsdp, fsdp2, megatron).\n        optim (Dict[str, Any]): Optimizer configuration including lr, weight_decay, etc.\n        model (Dict[str, Any]): Model configuration including path, tokenizer_path, etc.\n        ppo_mini_batch_size (int): PPO mini-batch size per update.\n        ppo_micro_batch_size (Optional[int]): Global micro batch size (deprecated).\n        ppo_micro_batch_size_per_gpu (Optional[int]): Local per-GPU micro batch size.\n        use_dynamic_bsz (bool): Whether to automatically adjust batch size at runtime.\n        ppo_max_token_len_per_gpu (int): Max tokens per GPU in one PPO batch.\n        forward_max_token_len_per_gpu (int): Max token length per GPU in forward pass.\n        ppo_epochs (int): Number of PPO epochs per batch.\n        shuffle (bool): Shuffle training data across PPO epochs.\n        cliprange_value (float): PPO value function clipping range.\n        loss_agg_mode (str): Loss aggregation mode.\n        checkpoint (Dict[str, Any]): Checkpoint configuration.\n        profiler (Dict[str, Any]): Profiler configuration.\n    \"\"\"\n\n    # For legacy reason configs related to batch_size are mutated in each role\n    # In the future they will be added to frozen fields instead\n    _frozen_fields = [\n        \"rollout_n\",\n        \"strategy\",\n        \"use_dynamic_bsz\",\n        \"ppo_max_token_len_per_gpu\",\n        \"forward_max_token_len_per_gpu\",\n        \"ppo_epochs\",\n        \"shuffle\",\n        \"cliprange_value\",\n        \"loss_agg_mode\",\n    ]\n\n    rollout_n: int = 1\n    strategy: str = \"fsdp\"\n    optim: dict[str, Any] = field(default_factory=dict)\n    model: dict[str, Any] = field(default_factory=dict)\n    ppo_mini_batch_size: int = 1\n    ppo_micro_batch_size: Optional[int] = None\n    ppo_micro_batch_size_per_gpu: Optional[int] = None\n    use_dynamic_bsz: bool = False\n    ppo_max_token_len_per_gpu: int = 32768\n    forward_max_token_len_per_gpu: int = 32768\n    ppo_epochs: int = 1\n    shuffle: bool = True\n    cliprange_value: float = 0.5\n    loss_agg_mode: str = \"token-mean\"\n    checkpoint: dict[str, Any] = field(default_factory=dict)\n    profiler: dict[str, Any] = field(default_factory=dict)\n\n\n@dataclass\nclass MegatronCriticConfig(CriticConfig):\n    \"\"\"Configuration for Megatron-based critic model training.\n\n    The inheritance from CriticConfig provides all base critic configuration plus 
Megatron-specific settings.\n\n    Args:\n        nccl_timeout (int): NCCL timeout in seconds for distributed operations.\n        megatron (Dict[str, Any]): Megatron-specific parallelism settings.\n        load_weight (bool): Whether to load initial weights.\n        data_loader_seed (Optional[int]): Seed for data loader.\n    \"\"\"\n\n    _frozen_fields = CriticConfig._frozen_fields + [\n        \"nccl_timeout\",\n        \"load_weight\",\n        \"data_loader_seed\",\n    ]\n\n    strategy: str = \"megatron\"\n    nccl_timeout: int = 600\n    megatron: dict[str, Any] = field(default_factory=dict)\n    load_weight: bool = True\n    data_loader_seed: Optional[int] = None\n\n\n@dataclass\nclass FSDPCriticConfig(CriticConfig):\n    \"\"\"Configuration for FSDP-based critic model training.\n\n    The inheritance from CriticConfig provides all base critic configuration plus FSDP-specific settings.\n\n    Args:\n        forward_micro_batch_size (int): Forward-only batch size during inference (global).\n        forward_micro_batch_size_per_gpu (int): Forward-only batch size during inference (per GPU).\n        ulysses_sequence_parallel_size (int): Sequence parallelism size for Ulysses-style model parallelism.\n        grad_clip (float): Gradient clipping for critic updates.\n    \"\"\"\n\n    _frozen_fields = CriticConfig._frozen_fields + [\n        \"ulysses_sequence_parallel_size\",\n        \"grad_clip\",\n    ]\n\n    strategy: str = \"fsdp\"\n    forward_micro_batch_size: int = 1\n    forward_micro_batch_size_per_gpu: int = 1\n    ulysses_sequence_parallel_size: int = 1\n    grad_clip: float = 1.0\n"
  },
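  {
    "path": "verl_rl/examples_sketch/critic_config_inheritance.py",
    "content": "# Hypothetical usage sketch, NOT part of the upstream verl distribution: it only\n# illustrates how FSDPCriticConfig and MegatronCriticConfig in config.py extend\n# CriticConfig. Import paths follow the `_target_` entries in the critic YAMLs;\n# the constructor arguments are assumed example values.\nfrom verl.trainer.config import CriticConfig, FSDPCriticConfig, MegatronCriticConfig\n\n# Each subclass overrides the `strategy` default and appends its own entries to\n# the inherited _frozen_fields list instead of replacing it.\nfsdp_critic = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=2, grad_clip=0.5)\nmegatron_critic = MegatronCriticConfig(nccl_timeout=1800)\n\nprint(fsdp_critic.strategy)      # \"fsdp\"\nprint(megatron_critic.strategy)  # \"megatron\"\nassert set(CriticConfig._frozen_fields) <= set(FSDPCriticConfig._frozen_fields)\n"
  },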
  {
    "path": "verl_rl/verl/trainer/config/critic/critic.yaml",
    "content": "# Number of rollouts per update (mirrors actor rollout_n)\nrollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}\n\n# fsdp or fsdp2 strategy used for critic model training\nstrategy: ???\n\n# optimizer configs\noptim:\n\n  # Warmup steps ratio; total steps will be injected at runtime\n  lr_warmup_steps_ratio: 0.0\n\n  # Total training steps (must be overridden at runtime)\n  total_training_steps: -1\n\n  # Weight decay\n  weight_decay: 0.01\n\n# model config for the critic\nmodel:\n\n  # Path to pretrained model weights\n  path: ~/models/deepseek-llm-7b-chat\n\n  # Tokenizer path (defaults to actor's model path)\n  tokenizer_path: ${oc.select:actor_rollout_ref.model.path,\"~/models/deepseek-llm-7b-chat\"}\n\n  # Hugging Face config override\n  override_config: {}\n\n  # External model implementation (optional)\n  external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}\n\n  # Whether to trust remote code from Hugging Face models\n  trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}\n\n# PPO mini-batch size per update\nppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\n\n# [Deprecated] Global micro batch size\nppo_micro_batch_size: null\n\n# Local per-GPU micro batch size\nppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}\n\n# Whether to automatically adjust batch size at runtime\nuse_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# Max tokens per GPU in one PPO batch (doubled for critic)\nppo_max_token_len_per_gpu: 32768\n\n# Max token length per GPU in forward pass\nforward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}\n\n# Number of PPO epochs per batch\nppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}\n\n# Shuffle training data across PPO epochs\nshuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}\n\n# PPO value function clipping range\ncliprange_value: 0.5\n\n# Loss aggregation mode: \"token-mean\", \"seq-mean-token-sum\", or \"seq-mean-token-mean\"\nloss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}\n\n# checkpoint configs\ncheckpoint:\n\n  # What to include in saved checkpoints\n  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n  save_contents: ['model', 'optimizer', 'extra']\n\n  # What to include when loading checkpoints\n  load_contents: ${.save_contents}\n\n# profiler configs\n# the corresponding dataclass is verl.utils.profiler.ProfilerConfig.\nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # True for each task has its own database, False for all tasks in one training step share one database.\n  discrete: False\n\n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: []\n\n# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n_target_: verl.trainer.config.CriticConfig\n"
  },
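  {
    "path": "verl_rl/examples_sketch/oc_select_demo.py",
    "content": "# Hypothetical sketch, NOT part of the upstream verl distribution: it demonstrates\n# how the `${oc.select:...,default}` interpolations in critic.yaml resolve.\n# `oc.select` is a built-in OmegaConf resolver: it returns the referenced node if\n# it exists and the supplied default otherwise.\nfrom omegaconf import OmegaConf\n\n# When the actor value exists, the critic mirrors it.\ncfg = OmegaConf.create(\n    {\n        \"actor_rollout_ref\": {\"actor\": {\"ppo_mini_batch_size\": 128}},\n        \"critic\": {\"ppo_mini_batch_size\": \"${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\"},\n    }\n)\nprint(cfg.critic.ppo_mini_batch_size)  # 128\n\n# When the referenced node is absent, the default after the comma is used.\ncfg_no_actor = OmegaConf.create(\n    {\"critic\": {\"ppo_mini_batch_size\": \"${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}\"}}\n)\nprint(cfg_no_actor.critic.ppo_mini_batch_size)  # 256\n"
  },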
  {
    "path": "verl_rl/verl/trainer/config/critic/dp_critic.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/critic/critic.yaml\n  - critic\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: fsdp\n\n# optimizer configs\noptim:\n\n  # Learning rate\n  lr: 1e-5\n\n  # Minimum LR ratio for cosine schedule\n  min_lr_ratio: null\n\n  # LR warmup style: \"constant\" or \"cosine\"\n  warmup_style: constant\n\n# model config for the critic\nmodel:\n\n  # Whether to use shared memory for loading the model\n  use_shm: False\n\n  # Enable gradient checkpointing to save memory\n  enable_gradient_checkpointing: True\n\n  # Offload activations to CPU to reduce GPU memory usage\n  enable_activation_offload: False\n\n  # Use remove padding optimization (saves compute)\n  use_remove_padding: False\n\n  # FSDP-specific config\n  fsdp_config:\n\n    # Whether to offload model parameters to CPU\n    param_offload: False\n\n    # Whether to offload optimizer state to CPU\n    optimizer_offload: False\n\n    # Only for FSDP2: offload param/grad/optimizer during train\n    offload_policy: False\n\n    # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n    reshard_after_forward: True\n\n    # Policy for wrapping layers with FSDP\n    wrap_policy:\n\n      # Minimum number of parameters to trigger wrapping\n      min_num_params: 0\n\n    # Number of GPUs in each FSDP shard group; -1 means auto\n    fsdp_size: -1\n\n    # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n    # before the current forward computation.\n    forward_prefetch: False\n\n  # Set to positive value to enable LoRA (e.g., 32)\n  lora_rank: 0\n\n  # LoRA scaling factor\n  lora_alpha: 16\n\n  # LoRA target modules: \"all-linear\" or list of linear projection layers\n  target_modules: all-linear\n\n# Forward-only batch size during inference (global)\nforward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}\n\n# Forward-only batch size during inference (per GPU)\nforward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}\n\n# Sequence parallelism size for Ulysses-style model parallelism\nulysses_sequence_parallel_size: 1\n\n# Gradient clipping for critic updates\ngrad_clip: 1.0\n\n# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n_target_: verl.trainer.config.FSDPCriticConfig\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/critic/megatron_critic.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/critic/critic.yaml\n  - critic\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\n# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\nnccl_timeout: 600\n\n# optimizer configs\noptim:\n\n  # select optimizer, default is Adam\n  optimizer: adam\n\n  # Learning rate\n  lr: 1e-6\n\n  # Clip gradients norm\n  clip_grad: 1.0\n\n  # initial learning rate for warmup, default to 0.0\n  lr_warmup_init: 0.0\n\n  # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.\n  lr_warmup_steps: null\n\n  lr_decay_steps: null\n\n  # select from constant/linear/cosine/inverse_square_root\n  lr_decay_style: linear\n\n  # minimum learning rate, default to 0.0\n  min_lr: 0.0\n\n  # select from constant/linear/cosine\n  weight_decay_incr_style: constant\n\n  # select from constant/exponential/cosine\n  lr_wsd_decay_style: exponential\n\n  # number of steps for weight std decay\n  lr_wsd_decay_steps: null\n\n  # use checkpoint optimizer parameter scheduler\n  use_checkpoint_opt_param_scheduler: False\n\n# model config for the critic\nmodel:\n\n  # override default empty mapping\n  override_config:\n\n    model_config: {}\n\n    moe_config:\n\n      freeze_moe_router: False\n\n# megatron-specific parallelism settings\nmegatron:\n\n  # Whether to offload model parameters to CPU\n  param_offload: False\n\n  # Whether to offload gradients to CPU\n  grad_offload: False\n\n  # Whether to offload optimizer state to CPU\n  optimizer_offload: False\n\n  # size of tensor model parallel group\n  tensor_model_parallel_size: 1\n\n  # size of expert model parallel group\n  expert_model_parallel_size: 1\n\n  # size of expert tensor parallel group\n  expert_tensor_parallel_size: null\n\n  # size of pipeline model parallel group\n  pipeline_model_parallel_size: 1\n\n  # size of virtual pipeline model parallel group\n  virtual_pipeline_model_parallel_size: null\n\n  # size of context parallel group\n  context_parallel_size: 1\n\n  # Whether to use sequence parallelism\n  sequence_parallel: True\n\n  # Whether to use distributed optimizer\n  use_distributed_optimizer: True\n\n  # Whether to use distributed checkpointing\n  use_dist_checkpointing: False\n\n  # Path for distributed checkpointing\n  dist_checkpointing_path: null\n\n  # Random seed for Megatron\n  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n\n  # Allow to override Distributed Data Parallel (DDP) config\n  override_ddp_config: ${oc.select:actor_rollout_ref.actor.megatron.override_ddp_config,{}}\n\n  # Transformer config overrides for Megatron\n  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n\n  # Whether to use mBridge communications\n  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n\n# Whether to load initial weights\nload_weight: True\n\n# seed for data loader\ndata_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}\n\n# Asynchronous checkpoint saving\ncheckpoint:\n  async_save: False\n\n# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n_target_: verl.trainer.config.MegatronCriticConfig\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/data/legacy_data.yaml",
    "content": "# Tokenizer class or path. If null, it will be inferred from the model.\ntokenizer: null\n\n# Whether to use shared memory for data loading.\nuse_shm: False\n\n# Training set parquet. Can be a list or a single file.\n# The program will read all files into memory, so it can't be too large (< 100GB).\n# The path can be either a local path or an HDFS path.\n# For HDFS path, we provide utils to download it to DRAM and convert it to a local path.\ntrain_files: ~/data/rlhf/gsm8k/train.parquet\n\n# Validation parquet. Can be a list or a single file.\nval_files: ~/data/rlhf/gsm8k/test.parquet\n\n# The field in the dataset where the prompt is located. Default is 'prompt'.\nprompt_key: prompt\n\n# The field used to select the reward function (if using different ones per example).\nreward_fn_key: data_source\n\n# Maximum prompt length. All prompts will be left-padded to this length.\n# An error will be reported if the length is too long.\n# oc.select: default val for rollout.prompt_length\nmax_prompt_length: 512\n\n# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.\n# oc.select: default val for rollout.response_length\nmax_response_length: 512\n\n# Batch size sampled for one training iteration of different RL algorithms.\ntrain_batch_size: 1024\n\n# Batch size used during validation. Can be null.\nval_batch_size: null\n\n# Whether to return the original input_ids without adding chat template.\n# This is used when the reward model's chat template differs from the policy.\n# If using a model-based RM with different templates, this should be True.\nreturn_raw_input_ids: False\n\n# Whether to return the original chat (prompt) without applying chat template.\nreturn_raw_chat: False\n\n# Whether to return the full prompt with chat template.\nreturn_full_prompt: False\n\n# Whether to shuffle the data in the dataloader.\nshuffle: True\n\n# num dataloader workers\ndataloader_num_workers: 8\n\n# Whether to shuffle the validation set.\nvalidation_shuffle: False\n\n# Whether to filter overlong prompts.\nfilter_overlong_prompts: False\n\n# Number of workers for filtering overlong prompts.\n# For large-scale datasets, filtering can be time-consuming.\n# Use multiprocessing to speed up. Default is 1.\nfilter_overlong_prompts_workers: 1\n\n# Truncate the input_ids or prompt if they exceed max_prompt_length.\n# Options: 'error', 'left', 'right', 'middle'. Default is 'error'.\ntruncation: error\n\n# The field in the multi-modal dataset where the image is located. Default is 'images'.\nimage_key: images\n\n# The field in the multi-modal dataset where the video is located.\nvideo_key: videos\n\n# If the remote tokenizer has a Python file, this flag determines whether to allow using it.\ntrust_remote_code: False\n\n# Optional: specify a custom dataset class path and name if overriding default loading behavior.\ncustom_cls:\n\n  # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.\n  path: null\n\n  # The name of the dataset class within the specified file.\n  name: null\n\n# Whether to return multi-modal inputs in the dataset. 
Set to False if rollout generates new multi-modal inputs.\nreturn_multi_modal_inputs: True\n\n# settings related to data sampler\nsampler:\n\n  # the path to the module containing a curriculum class which implements the\n  # AbstractSampler interface\n  class_path: null\n\n  # the name of the curriculum class like `MySampler`\n  class_name: null\n\n# Data generation configuration for augmenting the dataset.\ndatagen:\n\n  # The path to the file containing your customized data generation class.\n  # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'\n  path: null\n\n  # The class name of the data generation class within the specified file.\n  # E.g. 'MockDataGenerator'\n  name: null"
  },
  {
    "path": "verl_rl/verl/trainer/config/evaluation.yaml",
    "content": "data:\n  path: /tmp/math_Qwen2-7B-Instruct.parquet\n  prompt_key: prompt\n  response_key: responses\n  data_source_key: data_source\n  reward_model_key: reward_model\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n  timeline_json_file: null\n"
  },
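  {
    "path": "verl_rl/examples_sketch/custom_reward_sketch.py",
    "content": "# Hypothetical sketch, NOT part of the upstream verl distribution: a minimal custom\n# reward function that could be plugged in via `custom_reward_function.path` and\n# `custom_reward_function.name` in evaluation.yaml or ppo_trainer.yaml. The\n# (data_source, solution_str, ground_truth, extra_info) signature is an assumption\n# here; match it to the reward manager you actually use.\n\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info=None):\n    \"\"\"Toy exact-match reward: 1.0 if the response contains the ground truth.\"\"\"\n    del data_source, extra_info  # unused in this sketch\n    return 1.0 if str(ground_truth) in solution_str else 0.0\n"
  },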
  {
    "path": "verl_rl/verl/trainer/config/generation.yaml",
    "content": "trainer:\n  nnodes: 1\n  n_gpus_per_node: 8\n  device: cuda\n\ndata:\n  path: ~/data/rlhf/math/test.parquet\n  prompt_key: prompt\n  n_samples: 5\n  output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet\n  batch_size: 128\n\nmodel:\n  path: ~/models/Qwen2-7B-Instruct\n  external_lib: null\nrollout:\n  name: vllm\n  mode: sync # sync: LLM, async: AsyncLLM\n  temperature: 1.0\n  top_k: 50 # 0 for hf rollout, -1 for vllm rollout\n  top_p: 0.7\n  prompt_length: 1536\n  response_length: 512\n  # for vllm rollout\n  dtype: bfloat16 # should align with FSDP\n  gpu_memory_utilization: 0.5\n  ignore_eos: False\n  enforce_eager: True\n  free_cache_engine: True\n  load_format: dummy_dtensor\n  tensor_model_parallel_size: 1\n  max_num_batched_tokens: 8192\n  max_model_len: null\n  max_num_seqs: 1024\n  log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu\n  log_prob_micro_batch_size_per_gpu: 8\n  # for hf rollout\n  do_sample: True\n  disable_log_stats: True\n  enable_chunked_prefill: True\n  n: 1\n  # support logging rollout prob for debugging purpose\n  calculate_log_probs: False\nactor:\n  strategy: fsdp  # This is for backward-compatibility\n  ulysses_sequence_parallel_size: 1 # sp size\n  entropy_from_logits_with_chunking: False  # calculate entropy with chunking to reduce memory peak\n  entropy_checkpointing: False  # recompute entropy\n  fsdp_config:\n    fsdp_size: -1\n    forward_prefetch: False  # FSDP1 forward_prefetch configuration\n\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/npu_profile/npu_profile.yaml",
    "content": "# Options for the npu profiler\noptions:\n\n  # Storage path of collected data.\n  save_path: ./profiler_data\n\n  # Collection level, optional values: level_none, level0, level1, level2.\n  level: level1\n\n  # Whether to enable memory analysis.\n  with_memory: False\n\n  # Whether to record tensor shape.\n  record_shapes: False\n\n  # Whether to record Device-side performance data.\n  with_npu: True\n\n  # Whether to record Host-side performance data.\n  with_cpu: True\n\n  # Whether to record Python call stack information.\n  with_module: False\n\n  # Whether to record operator call stack information.\n  with_stack: False\n\n  # Whether to automatically parse the data.\n  analysis: True"
  },
  {
    "path": "verl_rl/verl/trainer/config/ppo_megatron_trainer.yaml",
    "content": "# specify the default per-component configs\ndefaults:\n\n  # <folder_name>@<field_name>.<field_name>: <yaml_file_name>\n  # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml\n  - actor@actor_rollout_ref.actor: megatron_actor\n  # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml\n  - npu_profile@trainer.npu_profile: npu_profile\n  # data: trainer/config/data/legacy_data.yaml\n  - data@data: legacy_data\n  # load the reference default config, then apply the fields in the current yaml\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  - ref@actor_rollout_ref.ref: megatron_ref\n  # Rollout model config.\n  - rollout@actor_rollout_ref.rollout: rollout\n  # Critic model config.\n  - critic@critic: megatron_critic\n  # Reward model config.\n  - reward_model@reward_model: megatron_reward_model\n  - _self_\n\nactor_rollout_ref:\n  hybrid_engine: True\n\n  nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron\n\n  model:\n\n    path: ~/models/deepseek-llm-7b-chat\n\n    custom_chat_template: null\n\n    external_lib: null\n\n    override_config:\n\n      model_config: {}\n\n      moe_config:\n\n        freeze_moe_router: False\n\n    use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency)\n\n    trust_remote_code: False\n    \n  rollout:\n    # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len.\n    enable_chunked_prefill: False\n\n    load_format: dummy_megatron\n\n    tensor_model_parallel_size: 1\n\n    layer_name_map:\n      qkv_layer_name: qkv\n      gate_proj_layer_name: gate_up\n\n  profiler:\n    _target_: verl.utils.profiler.ProfilerConfig\n    discrete: False\n    all_ranks: False\n    ranks: []\n\ncustom_reward_function:\n  path: null\n  name: compute_score\n\nalgorithm:\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.trainer.config.AlgoConfig\n  gamma: 1.0\n  lam: 1.0\n  adv_estimator: gae\n  norm_adv_by_std_in_grpo: True\n  use_kl_in_reward: False\n  kl_penalty: kl  # how to estimate kl divergence\n  kl_ctrl:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.KLControlConfig\n    type: fixed\n    kl_coef: 0.001\n    horizon: 10000\n    target_kl: 0.1\n  use_pf_ppo: False\n  pf_ppo:\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.PFPPOConfig\n    reweight_method: pow  # [\"pow\", \"max_min\", \"max_random\"]\n    weight_pow: 2.0\n\ntrainer:\n  balance_batch: True\n  total_epochs: 30\n  total_training_steps: null\n  profile_steps: null # [1,2,5] or [] or null\n  project_name: verl_examples\n  experiment_name: gsm8k\n  logger: ['console', 'wandb']\n  log_val_generations: 0\n  nnodes: 1\n  n_gpus_per_node: 8\n  save_freq: -1\n  esi_redundant_time: 0\n\n  # auto: find the last ckpt to resume. 
If none can be found, start from scratch\n  resume_mode: auto # or disable or resume_path if resume_from_path is set\n  resume_from_path: null\n  del_local_ckpt_after_load: False\n  val_before_train: True\n  test_freq: -1\n  critic_warmup: 0\n  default_hdfs_dir: null\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  max_actor_ckpt_to_keep: null\n  max_critic_ckpt_to_keep: null\n  # The timeout for ray worker group to wait for the register center to be ready\n  ray_wait_register_center_timeout: 300\n  device: cuda\n  # see ppo_trainer.yaml for more details\n  controller_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n  worker_nsight_options:\n    trace: \"cuda,nvtx,cublas,ucx\"\n    cuda-memory-usage: \"true\"\n    cuda-graph-trace: \"graph\"\n    capture-range: \"cudaProfilerApi\"\n    capture-range-end: null\n    kill: none\nray_init:\n  num_cpus: null # `None` means using all CPUs, which might cause a hang if CPUs are limited in systems like SLURM; please set it to an allowed number instead.\n  timeline_json_file: null\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/ppo_trainer.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# specify the default per-component configs\ndefaults:\n\n  # <folder_name>@<field_name>.<field_name>: <yaml_file_name>\n  # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml\n  - actor@actor_rollout_ref.actor: dp_actor\n\n  # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml\n  - npu_profile@trainer.npu_profile: npu_profile\n\n  # data: trainer/config/data/legacy_data.yaml\n  - data@data: legacy_data\n\n  # Reference model config.\n  # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.\n  - ref@actor_rollout_ref.ref: dp_ref\n\n  # Rollout model config.\n  - rollout@actor_rollout_ref.rollout: rollout\n\n  # Critic model config.\n  - critic@critic: dp_critic\n\n  # Reward model config.\n  - reward_model@reward_model: dp_reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  # self config override anything above\n  - _self_\n\n# config for actor, rollout and reference model\nactor_rollout_ref:\n\n  # Whether it's a hybrid engine, currently only supports hybrid engine\n  hybrid_engine: true\n\n  # common configs for the model\n  model:\n\n    # Huggingface model path. This can be either local path or HDFS path.\n    path: ~/models/deepseek-llm-7b-chat\n\n    # Custom chat template for the model.\n    custom_chat_template: null\n\n    # Whether to use shared memory (SHM) for accelerating the loading of model weights\n    use_shm: false\n\n    # Additional Python packages to register huggingface models/tokenizers.\n    external_lib: null\n\n    # Used to override model's original configurations, mainly dropout\n    override_config: {}\n\n    # Enable gradient checkpointing for actor\n    enable_gradient_checkpointing: true\n\n    # Enable activation offloading for actor\n    enable_activation_offload: false\n\n    # Whether to remove padding tokens in inputs during training\n    use_remove_padding: false\n\n    # Set to positive value to enable LoRA (e.g., 32)\n    lora_rank: 0\n\n    # LoRA scaling factor\n    lora_alpha: 16\n\n    # Target modules to apply LoRA. Options: \"all-linear\" (not recommended for VLMs) or\n    # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]\n    target_modules: all-linear\n\n    # Exclude modules from applying Lora. Similar usage to target_modules and Peft.\n    # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora.\n    exclude_modules: null\n\n    # Whether to use Liger for linear layer fusion\n    use_liger: false\n\n    # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)\n    use_fused_kernels: false\n\n    # Options for fused kernels. If use_fused_kernels is true, this will be used.\n    fused_kernel_options:\n\n      # Implementation backend for fused kernels. Options: \"triton\" or \"torch\".\n      impl_backend: torch\n\n    # Whether to enable loading a remote code model\n    trust_remote_code: false\n\n  # Rollout model config.\n  rollout:\n\n    # may get higher throughput when set to True. 
When activated, please increase max_num_batched_tokens or decrease max_model_len.\n    enable_chunked_prefill: True\n\n    # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc.\n    # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight\n    load_format: dummy_dtensor\n\n    # for huge model, layered summon can save memory (prevent OOM) but makes it slower\n    layered_summon: False\n\n  # profiler configs\n  profiler:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.utils.profiler.ProfilerConfig\n\n    # True for each task has its own database, False for all tasks in one training step share one database.\n    discrete: False\n\n    # Whether to profile all ranks.\n    all_ranks: False\n\n    # The ranks that will be profiled. [] or [0,1,...]\n    ranks: []\n\n# custom reward function definition\ncustom_reward_function:\n\n  # The path to the file containing your customized reward function.\n  # If not specified, pre-implemented reward functions will be used.\n  path: null\n\n  # The name of the reward function within the specified file. Default is 'compute_score'.\n  name: compute_score\n\n# config for the algorithm\nalgorithm:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.trainer.config.AlgoConfig\n\n  # Discount factor for future rewards\n  gamma: 1.0\n\n  # Trade-off between bias and variance in the GAE estimator\n  lam: 1.0\n\n  # Advantage estimator type: \"gae\", \"grpo\", \"reinforce_plus_plus\", etc.\n  adv_estimator: gae\n\n  # Whether to normalize advantages by std (specific to GRPO)\n  norm_adv_by_std_in_grpo: True\n\n  # Whether to enable in-reward KL penalty\n  use_kl_in_reward: False\n\n  # How to estimate KL divergence: \"kl\", \"abs\", \"mse\", \"low_var_kl\", or \"full\"\n  kl_penalty: kl\n\n  # KL control configuration\n  kl_ctrl:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.KLControlConfig\n\n    # KL control type: \"fixed\" or \"adaptive\"\n    type: fixed\n\n    # Initial coefficient for KL penalty\n    kl_coef: 0.001\n\n    # Horizon value for adaptive controller (if enabled)\n    horizon: 10000\n\n    # Target KL divergence (used for adaptive controller)\n    target_kl: 0.1\n\n  # Whether to enable preference feedback PPO\n  use_pf_ppo: False\n\n  # Preference feedback PPO settings\n  pf_ppo:\n\n    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n    _target_: verl.trainer.config.PFPPOConfig\n\n    # Method for reweighting samples: \"pow\", \"max_min\", or \"max_random\"\n    reweight_method: pow\n\n    # Power used for weight scaling in \"pow\" method\n    weight_pow: 2.0\n\n# config for the trainer\ntrainer:\n\n  # Whether to balance batch sizes across distributed workers\n  balance_batch: True\n\n  # Number of epochs in training\n  total_epochs: 30\n\n  # Total training steps (can be set explicitly or derived from epochs)\n  total_training_steps: null\n\n  # The steps that will be profiled. null means no profiling. null or [1,2,5,...]\n  profile_steps: null\n\n  # controller Nvidia Nsight Systems Options. 
Must be set when profile_steps is not None.\n  ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html\n  ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html\n  controller_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n  # worker Nvidia Nsight Systems Options. Must be set when profile_steps is not None.\n  worker_nsight_options:\n\n    # Select the API(s) to be traced.\n    trace: \"cuda,nvtx,cublas,ucx\"\n\n    # Track the GPU memory usage by CUDA kernels. Must be string type \"true\" or \"false\".\n    cuda-memory-usage: \"true\"\n\n    # CUDA graphs will be traced as a whole\n    cuda-graph-trace: \"graph\"\n\n    # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.\n    capture-range: \"cudaProfilerApi\"\n\n    # Specify the desired behavior when a capture range ends.\n    # In verl we need the torch.cuda.profiler.start/stop pair to repeat n times.\n    # valid values are \"repeat-shutdown:n\" or null.\n    # For normal whole step profiling, n = len(profile_steps);\n    # but for discrete profiling, n = len(profile_steps) * Number(subtasks).\n    # Or you can just leave it null and the program will use n = len(profile_steps) * 6;\n    capture-range-end: null\n\n    # Send signal to the target application's process group. We let the program exit by itself.\n    kill: none\n\n  # Project name for experiment tracking (e.g., wandb)\n  project_name: verl_examples\n\n  # Experiment name for run identification in tracking tools\n  experiment_name: gsm8k\n\n  # Logging backends to use: \"console\", \"wandb\", etc.\n  logger: [ 'console', 'wandb' ]\n\n  # Number of generations to log during validation\n  log_val_generations: 0\n\n  # Directory for logging rollout data; no dump if null\n  rollout_data_dir: null\n\n  # Directory for logging validation data; no dump if null\n  validation_data_dir: null\n\n  # Number of nodes used in the training\n  nnodes: 1\n\n  # Number of GPUs per node\n  n_gpus_per_node: 8\n\n  # Save frequency (by iteration) for model checkpoints\n  save_freq: -1\n\n  # ESI refers to the elastic server instance used during training, similar to the training plan. 
For example,\n  # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.\n  # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance.\n  # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time.\n  # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.\n  esi_redundant_time: 0\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (only used when resume_mode is \"resume_path\")\n  resume_from_path: null\n\n  # Whether to run validation before training begins\n  val_before_train: True\n\n  # Whether to run validation only\n  val_only: False\n\n  # Validation frequency (in training iterations)\n  test_freq: -1\n\n  # Number of iterations to warm up the critic before updating policy\n  critic_warmup: 0\n\n  # Default path to distributed filesystem for saving checkpoints\n  default_hdfs_dir: null\n\n  # Whether to delete local checkpoints after loading\n  del_local_ckpt_after_load: False\n\n  # Default local directory for saving checkpoints\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n\n  # Maximum number of actor checkpoints to keep\n  max_actor_ckpt_to_keep: null\n\n  # Maximum number of critic checkpoints to keep\n  max_critic_ckpt_to_keep: null\n\n  # Timeout (in seconds) for Ray worker to wait for registration\n  ray_wait_register_center_timeout: 300\n\n  # Device to run training on (e.g., \"cuda\", \"cpu\")\n  device: cuda\n\n  # whether to use legacy worker implementation\n  #  mode: \"auto\", \"enable\", or \"disable\"\n  use_legacy_worker_impl: auto\n\n# configs related to ray initialization\nray_init:\n\n  # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.\n  num_cpus: null\n\n  # Path to save Ray timeline JSON for performance profiling\n  timeline_json_file: null\n"
  },
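  {
    "path": "verl_rl/examples_sketch/compose_ppo_trainer_cfg.py",
    "content": "# Hypothetical sketch, NOT part of the upstream verl distribution: it composes the\n# ppo_trainer config with Hydra's compose API, mirroring how the `defaults` list in\n# ppo_trainer.yaml stitches the per-component YAMLs together. The relative\n# config_path and the override values are assumptions; adjust them to your layout.\nfrom hydra import compose, initialize\n\nwith initialize(config_path=\"../verl/trainer/config\", version_base=None):\n    cfg = compose(\n        config_name=\"ppo_trainer\",\n        overrides=[\n            \"algorithm.adv_estimator=grpo\",\n            \"trainer.n_gpus_per_node=4\",\n        ],\n    )\n\n# The dp_actor/dp_critic defaults fill in per-component fields; _self_ applies last.\nprint(cfg.algorithm.adv_estimator, cfg.critic.strategy)\n"
  },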
  {
    "path": "verl_rl/verl/trainer/config/ref/dp_ref.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # dp ref config, inheriting from trainer/config/ref/ref.yaml\n  - ref\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\n# config for FSDP strategy\nfsdp_config:\n\n  # whether to offload parameters in FSDP\n  param_offload: False\n\n  # whether to perform reshard after model forward to save memory.\n  # only for fsdp2, [True, False, int between 1 and fsdp_size]\n  reshard_after_forward: True\n\n  # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n  # before the current forward computation.\n  forward_prefetch: False\n\n  # the wrap policy for FSDP model\n  wrap_policy:\n\n    # minimum number of params in a wrapped module\n    min_num_params: 0\n\n# sequence parallel size\n# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1\nulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}\n\n# calculate entropy with chunking to reduce memory peak\nentropy_from_logits_with_chunking: False\n\n# recompute entropy\nentropy_checkpointing: False\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/ref/megatron_ref.yaml",
    "content": "# megatron ref config, inheriting from trainer/config/ref/ref.yaml\ndefaults:\n  - ref\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\nmegatron:\n\n  param_offload: False\n\n  tensor_model_parallel_size: 1\n\n  expert_model_parallel_size: 1\n\n  expert_tensor_parallel_size: None\n\n  pipeline_model_parallel_size: 1\n\n  virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests\n\n  context_parallel_size: 1\n\n  sequence_parallel: True\n\n  use_distributed_optimizer: False\n\n  use_dist_checkpointing: False\n\n  dist_checkpointing_path: null\n\n  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n\n  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n\n  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n\nprofile:\n\n  use_profile: False\n\n  profile_ranks: null\n\n  step_start: -1\n\n  step_end: -1\n\n  save_path: null\n\nload_weight: True"
  },
  {
    "path": "verl_rl/verl/trainer/config/ref/ref.yaml",
    "content": "# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default\nstrategy: ${actor_rollout_ref.actor.strategy}\n\n# whether to enable torch.compile\n# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1\nuse_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}\n\n# [Will be deprecated, use log_prob_micro_batch_size_per_gpu]\n# The batch size for one forward pass in the computation of log_prob. Global batch size.\nlog_prob_micro_batch_size: null\n\n# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\nlog_prob_micro_batch_size_per_gpu: null\n\n# enable dynamic batch size (sequence packing) for log_prob computation\n# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false\nlog_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# the max token length per GPU\n# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384\nlog_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/reward_model/dp_reward_model.yaml",
    "content": "# Format checks enforced on CI:\n# 1. Comments must appear above each field.\n# 2. There must be a blank line between each field.\n# 3. Inline comments (after a field on the same line) are not allowed.\n# 4. Indentation level is respected for nested fields.\n\n# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml\n  - reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: fsdp\n\nmodel:\n\n  # Whether to use shared memory for loading the model\n  use_shm: False\n\n  # Use remove padding optimization (saves compute)\n  use_remove_padding: False\n\n  # Whether to use fused reward kernels for speedup\n  use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}\n\n  # FSDP-specific config\n  fsdp_config:\n\n    # Policy for wrapping layers with FSDP\n    wrap_policy:\n      # Minimum number of parameters to trigger wrapping\n      min_num_params: 0\n\n    # Whether to offload model parameters to CPU\n    param_offload: False\n\n    # Only for FSDP2: Reshard after forward pass to reduce memory footprint\n    reshard_after_forward: True\n\n    # Number of GPUs in each FSDP shard group; -1 means auto\n    fsdp_size: -1\n\n    # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather\n    # before the current forward computation.\n    forward_prefetch: False\n\n# Sequence parallelism size for Ulysses-style model parallelism\nulysses_sequence_parallel_size: 1"
  },
  {
    "path": "verl_rl/verl/trainer/config/reward_model/megatron_reward_model.yaml",
    "content": "# defaults specify the default config from each component\ndefaults:\n\n  # dp actor config, inheriting from trainer/config/reward_model/reward_model.yaml\n  - reward_model\n\n  # load the reference default config, then apply the fields in the current yaml\n  - _self_\n\nstrategy: megatron\n\n# seconds, default is 10 minutes for torch, you can set it to a larger value\n# if you have long-running operations like 32B or 72B model using megatron\nnccl_timeout: 600\n\n# Megatron parallelism & checkpointing config\nmegatron:\n  # Whether to offload model parameters to CPU\n  param_offload: False\n\n  # Number of GPUs in tensor model parallel group\n  tensor_model_parallel_size: 1\n\n  # Number of GPUs in expert model parallel group\n  expert_model_parallel_size: 1\n\n  # Expert tensor parallel size\n  expert_tensor_parallel_size: null\n\n  # Number of pipeline model parallel stages\n  pipeline_model_parallel_size: 1\n\n  # change VPP interface for parallelism tests\n  virtual_pipeline_model_parallel_size: null\n\n  # Context parallel size\n  context_parallel_size: 1\n\n  # Whether to use sequence parallelism\n  sequence_parallel: True\n\n  # Whether to use distributed optimizer\n  use_distributed_optimizer: False\n\n  # Whether to enable distributed checkpointing\n  use_dist_checkpointing: False\n\n  # Path for distributed checkpoints\n  dist_checkpointing_path: null\n\n  # RNG seed for megatron\n  seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}\n\n  # Any overrides to transformer config\n  override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}\n\n  # Whether to use mbridge for faster comms\n  use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}\n\n# Whether to load weights (default True)\nload_weight: True"
  },
  {
    "path": "verl_rl/verl/trainer/config/reward_model/reward_model.yaml",
    "content": "# configs for the reward model\n\n# Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions.\n# In GSM8K and Math examples, we disable reward model.\n# For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses.\n# If False, the following parameters are not effective\nenable: False\n\n# FSDP strategy: \"fsdp\" or \"fsdp2\"\nstrategy: ???\n\n# model config for reward scoring\nmodel:\n\n  # Input tokenizer. If the reward model’s chat template is inconsistent with the policy,\n  # we need to first decode to plaintext, then apply the rm’s chat_template.\n  # Then score with RM. If chat_templates are consistent, it can be set to null.\n  # set this to null if the chat template is identical\n  input_tokenizer: ${actor_rollout_ref.model.path}\n\n  # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification.\n  # Other model types need to define their own RewardModelWorker and pass it from the code.\n  path: ~/models/FsfairX-LLaMA3-RM-v0.1\n\n  # External model implementation (optional)\n  external_lib: ${actor_rollout_ref.model.external_lib}\n\n  # Whether to enable loading a remote code model, default to False\n  trust_remote_code: False\n\n# [Deprecated] Global micro batch size\n# will be deprecated, use micro_batch_size_per_gpu\nmicro_batch_size: null\n\n# Local per-GPU micro batch size\nmicro_batch_size_per_gpu: null\n\n# Maximum sequence length to process for scoring\nmax_length: null\n\n# Whether to dynamically adjust batch size at runtime\nuse_dynamic_bsz: ${critic.use_dynamic_bsz}\n\n# Maximum number of tokens per GPU in one forward pass\nforward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}\n\n# Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.\n# Default is naive. If all verification functions are multiprocessing-safe,\n# the reward manager can be set to prime for parallel verification.\nreward_manager: naive\n\n# Whether to launch custom reward function asynchronously during log_prob\n# custom reward function executed async on CPU, during log_prob\nlaunch_reward_fn_async: False\n\n# Cloud/local sandbox fusion configuration for custom reward logic\nsandbox_fusion:\n\n  # Cloud /local function URL for sandbox execution\n  url: null\n\n  # Max concurrent requests allowed to sandbox\n  max_concurrent: 64\n\n  # Max memory limit for each sandbox process in MB\n  memory_limit_mb: 1024\n\n# profiler configs\nprofiler:\n\n  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint\n  _target_: verl.utils.profiler.ProfilerConfig\n\n  # True for each task has its own database, False for all tasks in one training step share one database.\n  discrete: False\n\n  # Whether to profile all ranks.\n  all_ranks: False\n\n  # The ranks that will be profiled. [] or [0,1,...]\n  ranks: []"
  },
  {
    "path": "verl_rl/verl/trainer/config/rollout/rollout.yaml",
    "content": "# actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future\nname: vllm\n\n# sync: LLM, async: AsyncLLM\nmode: sync\n\n# Sampling temperature for rollout.\ntemperature: 1.0\n\n# Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\ntop_k: -1\n\n# Top-p sampling parameter. Default 1.0.\ntop_p: 1\n\n# typically the same as data max prompt length\n# same as data.max_prompt_length if it exists\nprompt_length: ${oc.select:data.max_prompt_length,512}\n\n# typically the same as data max response length\n# same as data.max_response_length if it exists\nresponse_length: ${oc.select:data.max_response_length,512}\n\n# for vllm rollout\n# Rollout model parameters type. Align with actor model's FSDP/Megatron type.\ndtype: bfloat16\n\n# Fraction of GPU memory used by vLLM/SGLang for KV cache.\ngpu_memory_utilization: 0.5\n\n# Whether to ignore EOS and continue generating after EOS is hit.\nignore_eos: False\n\n# Whether to disable CUDA graph. Default True to allow cache freeing.\nenforce_eager: True\n\n# Whether to free engine KVCache after generation. Set enforce_eager=True when enabled.\nfree_cache_engine: True\n\n# TP size for rollout. Not effective for hf\ntensor_model_parallel_size: 2\n\n# max number of tokens in a batch\nmax_num_batched_tokens: 8192\n\n# max length for rollout\nmax_model_len: null\n\n# max length of sequences\nmax_num_seqs: 1024\n\n# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.\nlog_prob_micro_batch_size: null\n\n# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.\nlog_prob_micro_batch_size_per_gpu: null\n\n# enable dynamic batch size (sequence packing) for log_prob computation\n# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false\nlog_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}\n\n# max token length for log_prob computation\n# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384\nlog_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}\n\n# disable logging statistics\ndisable_log_stats: True\n\n# for hf rollout\n# Whether to sample during training rollout. False uses greedy sampling.\ndo_sample: True\n\n# number of responses (i.e. num sample times). > 1 for grpo\nn: 1\n\n# Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache)\nmulti_stage_wake_up: false\n\n# Extra inference engine arguments (vllm, sglang).\nengine_kwargs:\n\n  # for vllm\n  vllm:\n\n    # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB).\n    swap_space: null\n\n    # Whether to disable the preprocessor cache for multimodel models.\n    disable_mm_preprocessor_cache: False\n\n  # for sglang\n  sglang:\n\n    # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default.\n    attention_backend: null\n\n# Sampling parameters used during validation.\nval_kwargs:\n\n  # sampling parameters for validation\n  # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.\n  top_k: -1\n\n  # Top-p sampling parameter. Default 1.0.\n  top_p: 1.0\n\n  # Sampling temperature for rollout.\n  temperature: 0\n\n  # whether to repeat n times for validation\n  n: 1\n\n  # Whether to sample during training rollout. 
False uses greedy sampling.\n  do_sample: False\n\n# Multi-turn interaction config for tools or chat.\nmulti_turn:\n\n  # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well\n  enable: False\n\n  # null for no limit (default max_length // 3)\n  max_assistant_turns: null\n\n  # null for no tool\n  tool_config_path: null\n\n  # null for no limit (default max_length // 3)\n  max_user_turns: null\n\n  # max parallel call for tools in single turn\n  max_parallel_calls: 1\n\n  # max length of tool response\n  max_tool_response_length: 256\n\n  # truncate side of tool response: left, middle, right\n  tool_response_truncate_side: middle\n\n  # null for no interaction\n  interaction_config_path: null\n\n  # null for default callback\n  completion_callback: null\n\n  # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.\n  # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,\n  #   which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts.\n  use_inference_chat_template: False\n\n  # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.\n  # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.\n  # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.\n  # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:\n  # Qwen/QwQ-32B, Qwen/Qwen3-xxB\n  # - disable: disable tokenization sanity check\n  # - strict: enable strict tokenization sanity check (default)\n  # - ignore_strippable: ignore strippable tokens when checking tokenization sanity\n  tokenization_sanity_check_mode: strict\n\n  # Format of the multi-turn interaction. Options: hermes, llama3_json, ...\n  format: hermes\n\n# support logging rollout prob for debugging purpose\ncalculate_log_probs: False\n\n# [Experimental] agent loop based rollout configs\nagent:\n\n  # Number of agent loop workers\n  num_workers: 8\n\n  # custom agent loop config path, which should contain a list of configs to initialize AgentLoop instances.\n  # https://hydra.cc/docs/advanced/instantiate_objects/overview/\n  #\n  # - name: react_agent\n  #   _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop\n  #   tools: [\"get_current_temperature\"]\n  # - name: math_expression\n  #   _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop\n  #   min_terms: 2\n  #   max_terms: 6\n  agent_loop_config_path: null\n\n  # custom async server configs\n  custom_async_server:\n\n    # Path to the custom async server implementation\n    path: null\n\n    # Class name of the custom async server class (e.g. 
AsyncvLLMServer)\n    name: null\n\n# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations.\n# This parameter controls the maximum payload size for a single weight update request.\n# Reference: https://github.com/volcengine/verl/pull/2418\n# Currently only supported in SGLang rollout implementations\n# Larger values may improve throughput but increase memory overhead\n# Detailed performance comparison:\n# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720\n# Default value (512MB) is optimized for typical GPU memory configurations\n# For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`\n# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n# when using Tensor Parallelism (TP) >= 8.\nupdate_weights_bucket_megabytes: 512\n\n# trace rollout data\ntrace:\n  \n  # trace backend, support mlflow, weave\n  backend: null\n\n  # whether translate token id to text in output\n  token2text: False\n"
  },
  {
    "path": "verl_rl/verl/trainer/config/sft_trainer.yaml",
    "content": "data:\n  train_batch_size: 256\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_size_per_gpu: 4  # this is also val batch size\n  train_files: ~/data/gsm8k/train.parquet\n  val_files: ~/data/gsm8k/test.parquet\n  # Single-turn settings\n  prompt_key: question\n  response_key: answer\n  prompt_dict_keys: null\n  response_dict_keys: null\n  # Multi-turn settings\n  multiturn:\n    enable: false  # Set to true to use multi-turn dataset\n    messages_key: messages  # Key for messages list in multi-turn mode\n    tools_key: tools  # Key for tools list in multi-turn mode\n    enable_thinking_key: enable_thinking  # Whether to enable thinking in multi-turn mode\n  max_length: 1024\n  truncation: error\n  balance_dp_token: False\n  chat_template: null\n  custom_cls:\n    path: null\n    name: null\n  use_shm: False\nmodel:\n  partial_pretrain: ~/models/gemma-1.1-7b-it\n  use_shm: False\n  fsdp_config:\n    model_dtype: fp32\n    wrap_policy:\n      min_num_params: 0\n    cpu_offload: False\n    offload_params: False\n  external_lib: null\n  enable_gradient_checkpointing: True\n  trust_remote_code: False\n  lora_rank: 0  # Set to positive value to enable LoRA (e.g., 32)\n  lora_alpha: 16  # LoRA scaling factor\n  target_modules: all-linear  # Target modules for LoRA adaptation\n  use_liger: False\n  strategy: fsdp2\noptim:\n  lr: 1e-5\n  betas: [0.9, 0.95]\n  weight_decay: 0.01\n  warmup_steps_ratio: 0.1\n  clip_grad: 1.0\n  lr_scheduler: cosine\nulysses_sequence_parallel_size: 1\nuse_remove_padding: False\ntrainer:\n  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}\n  default_hdfs_dir: null\n  project_name: gsm8k-sft\n  experiment_name: test\n  total_epochs: 4\n  total_training_steps: null\n  logger: [ 'console', 'wandb' ]\n  seed: 1\n\n  save_freq: -1\n  test_freq: -1\n  nnodes: 1\n  n_gpus_per_node: 8\n  max_ckpt_to_keep: null  # Maximum number of checkpoints to keep, set to null to keep all\n\n  # Resume mode: \"auto\", \"disable\", or \"resume_path\"\n  # \"auto\": resume from last checkpoint if available\n  # \"disable\": start from scratch\n  # \"resume_path\": resume from a user-defined path\n  resume_mode: auto\n\n  # Path to resume training from (used when resume_mode is \"resume_path\" or \"auto\")\n  resume_from_path: null\n\n  # Checkpoint configuration\n  checkpoint:\n    # What to include in saved checkpoints\n    # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space\n    save_contents: [\"model\", \"optimizer\", \"extra\"]\n\n    # For more flexibility, you can specify the contents to load from the checkpoint.\n    load_contents: ${trainer.checkpoint.save_contents}\n  device: cuda\n"
  },
  {
    "path": "verl_rl/verl/trainer/constants_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nPPO_RAY_RUNTIME_ENV = {\n    \"env_vars\": {\n        \"TOKENIZERS_PARALLELISM\": \"true\",\n        \"NCCL_DEBUG\": \"WARN\",\n        \"VLLM_LOGGING_LEVEL\": \"WARN\",\n        \"VLLM_ALLOW_RUNTIME_LORA_UPDATING\": \"true\",\n        \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",\n        \"WANDB_MODE\": \"offline\",\n        \"WANDB_DISABLE_SERVICE\": \"true\",\n    },\n}\n\n\ndef get_ppo_ray_runtime_env():\n    \"\"\"\n    A filter function to return the PPO Ray runtime environment.\n    To avoid repeat of some environment variables that are already set.\n    \"\"\"\n    runtime_env = {\"env_vars\": PPO_RAY_RUNTIME_ENV[\"env_vars\"].copy()}\n    for key in list(runtime_env[\"env_vars\"].keys()):\n        if os.environ.get(key) is not None:\n            runtime_env[\"env_vars\"].pop(key, None)\n    return runtime_env\n"
  },
  {
    "path": "verl_rl/verl/trainer/fsdp_sft_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA lightweight one-file FSDP SFT Trainer\nTODO(zhangchi.usc1992)\n- Add calculation of mfu\n- Add validation\n\"\"\"\n\nimport os\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n\nimport logging\nimport re\nfrom contextlib import nullcontext\n\nimport hydra\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom tensordict import TensorDict\nfrom torch import nn, optim\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.utils.data import Dataset, DistributedSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\nfrom transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel\n\nimport verl.utils.hdfs_io as hdfs_io\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.dataset import SFTDataset\nfrom verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset\nfrom verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available\nfrom verl.utils.distributed import destroy_global_process_group, initialize_global_process_group\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    fsdp2_clip_grad_norm_,\n    fsdp2_load_full_state_dict,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n)\nfrom verl.utils.logger import log_with_rank\nfrom verl.utils.profiler import log_gpu_memory_usage\nfrom verl.utils.py_functional import convert_to_regular_types\nfrom verl.utils.torch_dtypes import PrecisionType\nfrom verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup\nfrom verl.utils.tracking import Tracking\nfrom verl.utils.ulysses import (\n    gather_outputs_and_unpad,\n    get_ulysses_sequence_parallel_world_size,\n    ulysses_pad_and_slice_inputs,\n)\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef extract_step(path):\n    match = re.search(r\"global_step_(\\d+)\", path)\n    if match:\n        return int(match.group(1))\n    return None\n\n\nclass FSDPSFTTrainer:\n    def 
__init__(\n        self,\n        config,\n        device_mesh: DeviceMesh,\n        ulysses_device_mesh: DeviceMesh,\n        tokenizer,\n        train_dataset: Dataset,\n        val_dataset: Dataset,\n    ):\n        self.config = config\n        self.device_mesh = device_mesh\n        self.ulysses_device_mesh = ulysses_device_mesh\n        self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self.tokenizer = tokenizer\n        if self.config.data.chat_template is not None:\n            raise ValueError(\"Applying a chat template from config is not supported yet.\")\n\n        # normalize dp size\n        self._normalize_config_bsz()\n\n        # Set sequence parallel size\n        self.config.ulysses_sequence_parallel_size = getattr(self.config, \"ulysses_sequence_parallel_size\", 1)\n        self.use_remove_padding = getattr(self.config, \"use_remove_padding\", False)\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}\")\n            print(f\"Using remove padding: {self.use_remove_padding}\")\n\n        self._build_dataloader(train_dataset, val_dataset)\n\n        # Initialize resume-related variables\n        self.resume_global_step = 0\n\n        # build model\n        self._build_model_optimizer()\n\n        # Initialize checkpoint manager\n        self._init_checkpoint_manager()\n\n        self.load_checkpoint()\n\n        if self.device_mesh.get_rank() == 0:\n            print(self.config)\n        self.device_name = self.config.trainer.device\n\n    def _normalize_config_bsz(self):\n        dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Normalize batch size by dp {dp_size}\")\n\n        assert self.config.data.train_batch_size % dp_size == 0, (\n            f\"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}\"\n        )\n\n        self.config.data.train_batch_size //= dp_size\n\n        # e.g. with the default train_batch_size=256 and dp_size=8, each rank sees 32\n        # samples per step; micro_batch_size_per_gpu=4 then gives 8 accumulation steps\n        assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0\n\n    def _build_dataloader(self, train_dataset, val_dataset):\n        # build dataset\n        config = self.config\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        # build dataloader\n        # Use data parallel rank and size instead of global rank and world size\n\n        # If doing SP, we need to use the local rank and size\n        if self.config.ulysses_sequence_parallel_size > 1:\n            rank = self.ulysses_device_mesh.get_local_rank(\"dp\")\n            world_size = self.ulysses_device_mesh.size(0)\n            if self.ulysses_device_mesh.get_rank() == 0:\n                print(f\"Using SP rank {rank} and size {world_size} for data distribution\")\n                print(\"Each SP group gets different data, but ranks within the same SP group get the same data\")\n        else:\n            rank = self.device_mesh.get_rank()\n            world_size = self.device_mesh.size()\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Using FSDP rank {rank} and size {world_size} for data distribution\")\n\n        self.train_sampler = DistributedSampler(\n            self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True\n        )\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            
batch_size=config.data.train_batch_size,\n            sampler=self.train_sampler,\n            num_workers=8,\n            pin_memory=True,\n            drop_last=True,\n        )\n\n        self.val_sampler = DistributedSampler(\n            self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True\n        )\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=config.data.micro_batch_size_per_gpu,\n            sampler=self.val_sampler,\n            num_workers=8,\n            pin_memory=True,\n            drop_last=True,\n        )\n\n    def _build_model_optimizer(self):\n        # TODO (zhangchi.usc1992):\n        # 1. support pretrain from random weights\n        # 2. support init directly from sharded weights\n        local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        log_gpu_memory_usage(\"Before model allocation\", logger=logger)\n\n        trust_remote_code = self.config.model.trust_remote_code\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n        # load config first\n        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)\n        self.model_config = config\n        if hasattr(self.model_config, \"max_position_embeddings\"):\n            self.model_config.max_position_embeddings = max(\n                self.model_config.max_position_embeddings, self.config.data.max_length\n            )\n        if self.config.ulysses_sequence_parallel_size > 1:\n            assert self.use_remove_padding, \"Sequence parallel is only supported when remove_padding is enabled\"\n\n        # This may be very large\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context():\n            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(\n                local_model_path,\n                config=config,\n                torch_dtype=torch_dtype,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:\n                from verl.models.transformers.monkey_patch import apply_monkey_patch\n\n                apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)\n\n            # Apply Liger kernel if use_liger is enabled\n            if self.config.model.get(\"use_liger\", False):\n                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n\n                _apply_liger_kernel_to_instance(model=self.model)\n\n            if self.config.model.get(\"lora_rank\", 0) > 0:\n                self.model.enable_input_require_grads()\n                # Convert config to regular Python types before creating PEFT model\n                lora_config = {\n                    \"task_type\": TaskType.CAUSAL_LM,\n                    \"r\": self.config.model.lora_rank,\n                    \"lora_alpha\": 
self.config.model.lora_alpha,\n                    \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                    \"bias\": \"none\",\n                }\n                self.model = get_peft_model(self.model, LoraConfig(**lora_config))\n                self.model = self.model.to(torch_dtype)\n\n        if self.config.model.enable_gradient_checkpointing:\n            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        log_gpu_memory_usage(\"After model allocation\", logger=logger)\n\n        mixed_precision = MixedPrecision(\n            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32\n        )\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            self.model,\n            config=self.config.model.fsdp_config.wrap_policy,\n            is_lora=self.config.model.get(\"lora_rank\", 0) > 0,\n        )\n        if self.device_mesh.get_rank() == 0:\n            print(auto_wrap_policy)\n\n        if not self.config.model.fsdp_config.cpu_offload:\n            cpu_offload = None\n        else:\n            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)\n\n        fsdp_strategy = self.config.model.strategy\n        if fsdp_strategy == \"fsdp\":\n            self.fsdp_model = FSDP(\n                self.model,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=ShardingStrategy.FULL_SHARD,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                forward_prefetch=False,\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True\n            )\n\n            fsdp_kwargs = {\n                \"mesh\": self.device_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": True,\n            }\n            full_state = self.model.state_dict()\n            apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config)\n            fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload)\n            self.fsdp_model = self.model\n        else:\n            raise NotImplementedError(f\"not implement {fsdp_strategy}\")\n\n        log_gpu_memory_usage(\"After FSDP wrapping\", logger=logger)\n\n        self.optimizer = optim.AdamW(\n            self.fsdp_model.parameters(),\n            lr=self.config.optim.lr,\n            betas=self.config.optim.betas,\n            weight_decay=self.config.optim.weight_decay,\n        )\n\n        log_gpu_memory_usage(\"After initialize optimizer\", logger=logger)\n\n        self.steps_per_epoch = len(self.train_dataloader)\n        self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs\n\n        if self.device_mesh.get_rank() == 0:\n            print(\n                f\"Number of steps/epoch {self.steps_per_epoch}, number of epochs \"\n                f\"{self.config.trainer.total_epochs}, total number of 
steps {self.total_steps}\"\n            )\n\n        num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)\n\n        if not hasattr(self.config.optim, \"lr_scheduler\") or self.config.optim.lr_scheduler == \"cosine\":\n            self.lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps\n            )\n        elif self.config.optim.lr_scheduler == \"wsd\":\n            self.lr_scheduler = get_wsd_schedule_with_warmup(\n                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps\n            )\n        else:\n            raise ValueError(f\"Unknown lr scheduler: {self.config.optim.lr_scheduler}\")\n\n    def _compute_loss_and_backward(self, batch, do_backward=True):\n        \"\"\"Compute loss with optional sequence parallelism and remove padding features\"\"\"\n        use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1\n\n        # Move inputs to GPU and prepare loss mask\n        input_ids = batch[\"input_ids\"].to(self.device_name)\n        attention_mask = batch[\"attention_mask\"].to(self.device_name)\n        position_ids = batch[\"position_ids\"].to(self.device_name)\n        loss_mask = batch.pop(\"loss_mask\")[:, :-1].reshape(-1).to(self.device_name)\n        loss_fct = nn.CrossEntropyLoss(reduction=\"none\")\n\n        # Context manager for sequence parallel if needed\n        context = self.sharding_manager if use_sp else nullcontext()\n        with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            if not use_sp:\n                # Standard forward pass without sequence parallel\n                labels = input_ids[:, 1:].contiguous()\n                output = self.fsdp_model(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                logits = output.logits\n\n                shift_logits = logits[..., :-1, :].contiguous()\n                shift_labels = labels.contiguous()\n                # Flatten the tokens\n                shift_logits = shift_logits.view(-1, self.model.config.vocab_size)\n                shift_labels = shift_labels.view(-1)\n                # Enable model parallelism\n                shift_labels = shift_labels.to(shift_logits.device)\n                loss = loss_fct(shift_logits, shift_labels)\n                loss = loss * loss_mask.to(loss.device)\n            else:\n                # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks\n                # i.e., each GPU has <1 sequence, and each SP group has 1 sequence\n                # 1. All SP ranks will receive the *SAME* batch\n                # 2. Different SP groups will receive *DIFFERENT* batches\n                # This is implemented by the DistributedSampler\n\n                batch_size, seqlen = input_ids.shape\n                # Remove padding\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # Unpad position_ids to align rotary\n                position_ids_rmpad = index_first_axis(\n                    rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n                ).transpose(0, 1)\n\n                # Pad and slice inputs for sequence parallelism\n                input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()\n                )\n                # For computing loss\n                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                    input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()\n                )\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n\n                # Forward pass\n                output = self.fsdp_model(\n                    input_ids=input_ids_rmpad_sliced,\n                    attention_mask=None,  # Not needed with flash attention varlen\n                    position_ids=position_ids_rmpad_padded,\n                    use_cache=False,\n                )\n\n                # Compute loss locally then aggregate\n                logits_rmpad = output.logits.squeeze(0)\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)\n                loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)\n                # Gather and unpad for sequence parallelism\n                loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n                # This is the loss collected from all ulysses ranks\n                full_loss = pad_input(\n                    hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n                )\n                full_loss = full_loss.squeeze(-1)[:, :-1]  # Remove last token's loss\n                full_loss = full_loss.reshape(-1)\n                loss_mask = loss_mask.to(full_loss.device)\n                loss = full_loss * loss_mask\n\n            valid_token_this_rank = torch.sum(loss_mask)\n\n            if self.config.data.balance_dp_token:\n                torch.distributed.all_reduce(valid_token_this_rank)\n                dp_size = self.ulysses_device_mesh.size(\"dp\") if use_sp else torch.distributed.get_world_size()\n            else:\n                dp_size = 1\n\n            loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size\n\n            if do_backward:\n                loss.backward()\n            return loss\n\n    def training_step(self, batch: TensorDict):\n        self.fsdp_model.train()\n\n        log_gpu_memory_usage(\"Before optimizer zero_grad\", logger=logger)\n\n        self.optimizer.zero_grad()\n\n        log_gpu_memory_usage(\"After optimizer zero_grad\", logger=logger)\n\n        micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)\n        n_micro_batches = len(micro_batches)\n        step_loss = 0\n        for micro_batch in micro_batches:\n            loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches\n            step_loss += loss.item()\n\n        if self.config.model.strategy == \"fsdp\":\n            grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)\n        elif self.config.model.strategy == \"fsdp2\":\n            grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad)\n        else:\n            raise NotImplementedError(f\"not 
implement {self.config.model.strategy}\")\n\n        log_gpu_memory_usage(\"Before optimizer step\", logger=logger)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.optimizer.zero_grad()\n        else:\n            self.optimizer.step()\n\n        log_gpu_memory_usage(\"After optimizer step\", logger=logger)\n\n        self.lr_scheduler.step()\n\n        # reduce loss across dp ranks\n        lr = self.lr_scheduler.get_last_lr()[0]\n\n        log_gpu_memory_usage(\"After offload weights\", logger=logger)\n\n        step_loss = torch.tensor(step_loss).to(self.device_name)\n        if is_cuda_available:\n            torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)\n        elif is_npu_available:\n            torch.distributed.all_reduce(step_loss)\n            step_loss /= self.device_mesh.size(0)\n        return {\"train/loss\": step_loss.detach().item(), \"train/lr(1e-3)\": lr * 1e3}\n\n    def validation_step(self, batch: TensorDict):\n        self.fsdp_model.eval()\n        with torch.no_grad():\n            loss = self._compute_loss_and_backward(batch, do_backward=False)\n            if is_cuda_available:\n                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)\n            elif is_npu_available:\n                torch.distributed.all_reduce(loss)\n                loss /= self.device_mesh.size(0)\n        return loss\n\n    def save_checkpoint(self, step):\n        \"\"\"Save checkpoint using FSDPCheckpointManager with improved tracking\"\"\"\n        from verl.utils.fs import local_mkdir_safe\n\n        # Determine checkpoint path\n        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f\"global_step_{step}\")\n\n        if self.device_mesh.get_rank() == 0:\n            print(f\"Saving checkpoint to: {local_global_step_folder}\")\n\n        # Get max checkpoints to keep\n        max_ckpt_to_keep = getattr(self.config.trainer, \"max_ckpt_to_keep\", None)\n\n        # Use checkpoint manager to save\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        # Save dataloader state\n        if self.device_mesh.get_rank() == 0:\n            local_mkdir_safe(local_global_step_folder)\n            dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = self.train_dataloader.state_dict()\n            torch.save(dataloader_state_dict, dataloader_local_path)\n            print(f\"Saved dataloader state to: {dataloader_local_path}\")\n\n            # Update latest checkpoint tracker (atomic write)\n            tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir)\n            temp_tracker_file = tracker_file + \".tmp\"\n            with open(temp_tracker_file, \"w\") as f:\n                f.write(str(step))\n            os.rename(temp_tracker_file, tracker_file)\n            print(f\"Updated checkpoint tracker: {tracker_file}\")\n\n        # Copy to HDFS if configured\n        if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, \"default_hdfs_dir\", None):\n            hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)\n            
hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)\n\n        torch.distributed.barrier()\n\n    def _init_checkpoint_manager(self):\n        \"\"\"Initialize checkpoint manager with proper configuration\"\"\"\n        # Get checkpoint configuration from config, with defaults\n        checkpoint_config = getattr(self.config.trainer, \"checkpoint\", {})\n\n        # Set default values if not specified\n        save_contents = checkpoint_config.get(\"save_contents\", [\"model\", \"optimizer\", \"extra\"])\n        load_contents = checkpoint_config.get(\"load_contents\", save_contents)\n\n        # Create checkpoint config dict\n        checkpoint_config_dict = {\n            \"load_contents\": load_contents,\n            \"save_contents\": save_contents,\n        }\n\n        # Convert to DictConfig for compatibility\n        checkpoint_config_dict = DictConfig(checkpoint_config_dict)\n\n        # Initialize checkpoint manager\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.fsdp_model,\n            optimizer=self.optimizer,\n            lr_scheduler=self.lr_scheduler,\n            processing_class=self.tokenizer,\n            checkpoint_config=checkpoint_config_dict,\n        )\n\n    def load_checkpoint(self):\n        # Determine resume path based on configuration\n        checkpoint_path = self._determine_resume_path()\n\n        if checkpoint_path is None:\n            return 0\n\n        # extract resume step from checkpoint path\n        resume_step = extract_step(checkpoint_path)\n        if resume_step is None:\n            log_with_rank(\n                f\"Warning: Could not extract step number from {checkpoint_path}, starting from step 0\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                level=logging.WARNING,\n                log_only_rank_0=True,\n            )\n            return 0\n        self.resume_global_step = resume_step\n\n        # Use checkpoint manager to load model state\n        self.checkpoint_manager.load_checkpoint(checkpoint_path)\n        log_with_rank(\n            f\"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})\",\n            logger=logger,\n            rank=self.device_mesh.get_rank(),\n            log_only_rank_0=True,\n        )\n\n        # Always load dataloader state for StatefulDataLoader\n        self._load_dataloader_state(checkpoint_path)\n\n        return resume_step\n\n    def _load_dataloader_state(self, checkpoint_path: str):\n        \"\"\"Load dataloader state from checkpoint\"\"\"\n        dataloader_path = os.path.join(checkpoint_path, \"data.pt\")\n\n        if os.path.exists(dataloader_path):\n            # Use StatefulDataLoader's built-in state dict functionality\n            dataloader_state_dict = torch.load(dataloader_path, map_location=\"cpu\", weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n\n            log_with_rank(\n                f\"Successfully loaded dataloader state from {dataloader_path}\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                log_only_rank_0=True,\n            )\n\n        else:\n            log_with_rank(\n                f\"Warning: No dataloader state found at {dataloader_path}, will start from scratch\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                level=logging.WARNING,\n          
      log_only_rank_0=True,\n            )\n\n    def _determine_resume_path(self):\n        \"\"\"Determine the path to resume from based on resume_mode configuration\"\"\"\n        resume_mode = getattr(self.config.trainer, \"resume_mode\", \"auto\")\n        resume_from_path = getattr(self.config.trainer, \"resume_from_path\", None)\n\n        if resume_mode == \"disable\":\n            return None\n        elif resume_mode == \"auto\":\n            if resume_from_path is not None:\n                assert os.path.exists(resume_from_path), (\n                    \"resume_from_path must be null or an existing path when resume_mode is 'auto'\"\n                )\n                assert \"global_step_\" in resume_from_path, \"resume_from_path must specify the global_steps\"\n                return resume_from_path\n            # Try to find the latest checkpoint in the default directory\n            return self._find_latest_checkpoint()\n        elif resume_mode == \"resume_path\":\n            assert os.path.exists(resume_from_path), (\n                \"resume_from_path must be an existing path when resume_mode is 'resume_path'\"\n            )\n            assert \"global_step_\" in resume_from_path, \"resume_from_path must specify the global_steps\"\n            return resume_from_path\n        else:\n            raise ValueError(f\"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'\")\n\n    def _find_latest_checkpoint(self):\n        \"\"\"Find the latest checkpoint in the default local directory\"\"\"\n        checkpoint_dir = self.config.trainer.default_local_dir\n\n        if not os.path.exists(checkpoint_dir):\n            return None\n\n        latest_checkpoint = find_latest_ckpt_path(checkpoint_dir)\n\n        if latest_checkpoint and self.device_mesh.get_rank() == 0:\n            step_num = extract_step(latest_checkpoint)\n            print(f\"Found latest checkpoint: {latest_checkpoint} (step {step_num})\")\n\n        return latest_checkpoint\n\n    def fit(self):\n        rank = self.device_mesh.get_rank()\n\n        # TODO: add a unified tracking\n        if rank == 0:\n            tracking = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n            )\n\n        global_step = self.resume_global_step  # Start from resumed step\n        last_valid_metric = None\n        # compute the total training steps.\n        # the total training steps in SFT is mainly for early exit\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n        log_with_rank(\n            f\"Total training steps: {self.total_training_steps},\",\n            logger=logger,\n            rank=self.device_mesh.get_rank(),\n            log_only_rank_0=True,\n        )\n\n        # With StatefulDataLoader, we don't need to manually calculate epochs and steps\n        # The dataloader will automatically resume from where it left off\n        if global_step > 0:\n            log_with_rank(\n                f\"StatefulDataLoader will automatically resume from global step: {global_step}\",\n                logger=logger,\n                rank=self.device_mesh.get_rank(),\n                
log_only_rank_0=True,\n            )\n\n        # Calculate which epoch we're starting from for sampler.set_epoch()\n        start_epoch = global_step // self.steps_per_epoch\n\n        for epoch in range(start_epoch, self.config.trainer.total_epochs):\n            self.train_sampler.set_epoch(epoch=epoch)\n\n            for step_in_epoch, data in enumerate(\n                tqdm(\n                    self.train_dataloader,\n                    initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0,\n                    total=self.steps_per_epoch,\n                    desc=f\"Epoch {epoch + 1}/{self.config.trainer.total_epochs}\",\n                    disable=rank != 0,\n                )\n            ):\n                global_step += 1\n                data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name)\n                metric = self.training_step(data)\n                if rank == 0:\n                    tracking.log(data=metric, step=global_step)\n\n                is_last_step = global_step >= self.total_training_steps\n                is_valid_step = global_step % self.config.trainer.test_freq == 0\n                is_save_step = global_step % self.config.trainer.save_freq == 0\n\n                # early exit or validation step\n                if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step):\n                    # Perform validation\n                    val_losses = []\n                    for val_data in self.val_dataloader:\n                        val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(\n                            self.device_name\n                        )\n                        val_loss = self.validation_step(val_data)\n                        val_losses.append(val_loss)\n                    if rank == 0:\n                        val_loss = torch.mean(torch.stack(val_losses))\n                        metric = {\"val/loss\": val_loss.detach().item()}\n                        tracking.log(data=metric, step=global_step)\n                        last_valid_metric = metric\n                    torch.distributed.barrier()\n\n                if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step):\n                    self.save_checkpoint(step=global_step)\n\n                if is_last_step:\n                    if rank == 0:\n                        print(f\"Final validation metrics: {last_valid_metric}\")\n                    return\n\n\ndef run_sft(config):\n    device_name = get_device_name()\n    local_rank, rank, world_size = initialize_global_process_group()\n\n    device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=(\"fsdp\",))\n    dp_size = world_size // config.ulysses_sequence_parallel_size\n    ulysses_device_mesh = init_device_mesh(\n        device_type=device_name,\n        mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),\n        mesh_dim_names=(\"dp\", \"sp\"),\n    )\n    # build tokenizer and datasets first\n    from verl.utils import hf_tokenizer\n\n    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)\n    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)\n    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)\n    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)\n\n    trainer = FSDPSFTTrainer(\n        config=config,\n     
   device_mesh=device_mesh,\n        ulysses_device_mesh=ulysses_device_mesh,\n        tokenizer=tokenizer,\n        train_dataset=train_dataset,\n        val_dataset=val_dataset,\n    )\n\n    trainer.fit()\n\n    destroy_global_process_group()\n\n\n@hydra.main(config_path=\"config\", config_name=\"sft_trainer\", version_base=None)\ndef main(config):\n    run_sft(config)\n\n\ndef create_sft_dataset(data_paths, data_config, tokenizer):\n    \"\"\"Create an SFT dataset, selecting the dataset class from the config:\n    a custom class if one is specified, MultiTurnSFTDataset if multi-turn data\n    is enabled, and the single-turn SFTDataset otherwise.\n    \"\"\"\n    # build dataset\n    # First check if a custom dataset class is specified\n    if data_config.custom_cls.get(\"path\", None):\n        from verl.utils.import_utils import load_extern_type\n\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n    # Then check if multi-turn dataset should be used\n    elif data_config.get(\"multiturn\", {}).get(\"enable\", False):\n        dataset_cls = MultiTurnSFTDataset\n    # Default to single-turn dataset\n    else:\n        dataset_cls = SFTDataset\n\n    # Instantiate the dataset using the selected class\n    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)\n    return dataset\n\n\nif __name__ == \"__main__\":\n    main()
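\n\n# Illustrative launch (a sketch; paths mirror the defaults in config/sft_trainer.yaml\n# and are placeholders):\n#\n#   torchrun --standalone --nnodes=1 --nproc_per_node=8 \\\n#       -m verl.trainer.fsdp_sft_trainer \\\n#       data.train_files=$HOME/data/gsm8k/train.parquet \\\n#       data.val_files=$HOME/data/gsm8k/test.parquet \\\n#       model.partial_pretrain=$HOME/models/gemma-1.1-7b-it\n"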
  },
  {
    "path": "verl_rl/verl/trainer/main_eval.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nOffline evaluate the performance of a generated file using reward model and ground truth verifier.\nThe input is a parquet file that contains N generated sequences and (optional) the ground truth.\n\n\"\"\"\n\nfrom collections import defaultdict\n\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport ray\nfrom tqdm import tqdm\n\nfrom verl.trainer.ppo.reward import get_custom_reward_fn\nfrom verl.utils.fs import copy_to_local\n\n\n@ray.remote\ndef process_item(reward_fn, data_source, response_lst, reward_data):\n    ground_truth = reward_data[\"ground_truth\"]\n    score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]\n    return data_source, np.mean(score_lst)\n\n\n@hydra.main(config_path=\"config\", config_name=\"evaluation\", version_base=None)\ndef main(config):\n    local_path = copy_to_local(config.data.path, use_shm=config.data.get(\"use_shm\", False))\n    dataset = pd.read_parquet(local_path)\n    responses = dataset[config.data.response_key]\n    data_sources = dataset[config.data.data_source_key]\n    reward_model_data = dataset[config.data.reward_model_key]\n\n    total = len(dataset)\n\n    # Initialize Ray\n    if not ray.is_initialized():\n        ray.init(num_cpus=config.ray_init.num_cpus)\n\n    # evaluate test_score based on data source\n    data_source_reward = defaultdict(list)\n    compute_score = get_custom_reward_fn(config)\n\n    # Create remote tasks\n    remote_tasks = [\n        process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)\n    ]\n\n    # Process results as they come in\n    with tqdm(total=total) as pbar:\n        while len(remote_tasks) > 0:\n            # Use ray.wait to get completed tasks\n            done_ids, remote_tasks = ray.wait(remote_tasks)\n            for result_id in done_ids:\n                data_source, score = ray.get(result_id)\n                data_source_reward[data_source].append(score)\n                pbar.update(1)\n\n    metric_dict = {}\n    for data_source, rewards in data_source_reward.items():\n        metric_dict[f\"test_score/{data_source}\"] = np.mean(rewards)\n\n    print(metric_dict)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/verl/trainer/main_generation.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nGenerate responses given a dataset of prompts\n\"\"\"\n\nimport os\n\nimport hydra\nimport numpy as np\nimport ray\n\nos.environ[\"NCCL_DEBUG\"] = \"WARN\"\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n# os.environ['TORCH_COMPILE_DISABLE'] = '1'\n\nfrom pprint import pprint\n\nimport pandas as pd\nfrom omegaconf import OmegaConf\n\nfrom verl import DataProto\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.hdfs_io import makedirs\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.workers.fsdp_workers import ActorRolloutRefWorker\n\n\n@hydra.main(config_path=\"config\", config_name=\"generation\", version_base=None)\ndef main(config):\n    run_generation(config)\n\n\ndef run_generation(config) -> None:\n    if not ray.is_initialized():\n        # this is for local ray cluster\n        ray.init(\n            runtime_env={\"env_vars\": {\"TOKENIZERS_PARALLELISM\": \"true\", \"NCCL_DEBUG\": \"WARN\"}},\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    ray.get(main_task.remote(config))\n\n\n@ray.remote(num_cpus=1)\ndef main_task(config):\n    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values\n    OmegaConf.resolve(config)\n\n    local_path = copy_to_local(config.model.path)\n    trust_remote_code = config.data.get(\"trust_remote_code\", False)\n    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n\n    if config.rollout.temperature == 0.0:\n        assert config.data.n_samples == 1, \"When temperature=0, n_samples must be 1.\"\n    assert config.data.n_samples >= 1, \"n_samples should always >= 1\"\n\n    # read dataset. 
Note that the dataset should directly contain chat template format (e.g., a list of dictionaries)\n    dataset = pd.read_parquet(config.data.path)\n    chat_lst = dataset[config.data.prompt_key].tolist()\n\n    chat_lst = [chat.tolist() for chat in chat_lst]\n\n    tokenizer.padding_side = \"left\"\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role=\"rollout\")\n    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)\n    wg = RayWorkerGroup(\n        resource_pool=resource_pool,\n        ray_cls_with_init=ray_cls_with_init,\n        device_name=config.trainer.device,\n    )\n    wg.init_model()\n\n    total_samples = len(dataset)\n    config_batch_size = config.data.batch_size\n    num_batch = -(-total_samples // config_batch_size)\n    output_lst = [[] for _ in range(config.data.n_samples)]\n\n    for batch_idx in range(num_batch):\n        print(f\"[{batch_idx + 1}/{num_batch}] Start to process.\")\n        batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size]\n        inputs = tokenizer.apply_chat_template(\n            batch_chat_lst,\n            add_generation_prompt=True,\n            padding=True,\n            truncation=True,\n            max_length=config.rollout.prompt_length,\n            return_tensors=\"pt\",\n            return_dict=True,\n            tokenize=True,\n        )\n        input_ids = inputs[\"input_ids\"]\n        attention_mask = inputs[\"attention_mask\"]\n        position_ids = compute_position_id_with_mask(attention_mask)\n        batch_dict = {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n\n        data = DataProto.from_dict(batch_dict)\n        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)\n\n        # START TO GENERATE FOR n_samples TIMES\n        print(f\"[{batch_idx + 1}/{num_batch}] Start to generate.\")\n        for n_sample in range(config.data.n_samples):\n            output_padded = wg.generate_sequences(data_padded)\n            output = unpad_dataproto(output_padded, pad_size=pad_size)\n\n            output_texts = []\n            for i in range(len(output)):\n                data_item = output[i]\n                prompt_length = data_item.batch[\"prompts\"].shape[-1]\n                valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n                valid_response_ids = data_item.batch[\"responses\"][:valid_response_length]\n                response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n                output_texts.append(response_str)\n\n            output_lst[n_sample].extend(output_texts)\n\n    # convert output_lst from (n_samples, n_data) to (n_data, n_samples)\n    output_lst = np.array(output_lst, dtype=object)\n    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()\n\n    # add to the data frame\n    dataset[\"responses\"] = output_lst\n\n    # write to a new parquet\n    output_dir = os.path.dirname(config.data.output_path)\n    makedirs(output_dir, exist_ok=True)\n    dataset.to_parquet(config.data.output_path)\n\n\nif __name__ == \"__main__\":\n    main()
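\n\n# Illustrative check (a sketch; the actual path is whatever config.data.output_path\n# was set to): each row of the output parquet carries a list of n_samples responses.\n#\n#   import pandas as pd\n#\n#   df = pd.read_parquet(\"outputs/generation.parquet\")\n#   print(df[\"responses\"].iloc[0])  # list with n_samples decoded strings\n"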
  },
  {
    "path": "verl_rl/verl/trainer/main_ppo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nNote that we don't combine the main with ray_trainer as ray_trainer is used by other main.\n\"\"\"\n\nimport os\nimport socket\n\nimport hydra\nimport ray\nfrom omegaconf import OmegaConf\n\nfrom verl.experimental.dataset.sampler import AbstractSampler\nfrom verl.trainer.constants_ppo import get_ppo_ray_runtime_env\nfrom verl.trainer.ppo.ray_trainer import RayPPOTrainer\nfrom verl.trainer.ppo.reward import load_reward_manager\nfrom verl.utils.device import is_cuda_available\nfrom verl.utils.import_utils import load_extern_type\n\n\n@hydra.main(config_path=\"config\", config_name=\"ppo_trainer\", version_base=None)\ndef main(config):\n    \"\"\"Main entry point for PPO training with Hydra configuration management.\n\n    Args:\n        config_dict: Hydra configuration dictionary containing training parameters.\n    \"\"\"\n    run_ppo(config)\n\n\n# Define a function to run the PPO-like training process\ndef run_ppo(config) -> None:\n    \"\"\"Initialize Ray cluster and run distributed PPO training process.\n\n    Args:\n        config: Training configuration object containing all necessary parameters\n                for distributed PPO training including Ray initialization settings,\n                model paths, and training hyperparameters.\n    \"\"\"\n    # Check if Ray is not initialized\n    if not ray.is_initialized():\n        # Initialize Ray with a local cluster configuration\n        # Set environment variables in the runtime environment to control tokenizer parallelism,\n        # NCCL debug level, VLLM logging level, and allow runtime LoRA updating\n        # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration\n        ray.init(\n            runtime_env=get_ppo_ray_runtime_env(),\n            num_cpus=config.ray_init.num_cpus,\n        )\n\n    # Create a remote instance of the TaskRunner class, and\n    # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete\n    if (\n        is_cuda_available\n        and config.trainer.get(\"profile_steps\") is not None\n        and len(config.trainer.get(\"profile_steps\", [])) > 0\n    ):\n        nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)\n        runner = TaskRunner.options(runtime_env={\"nsight\": nsight_options}).remote()\n    else:\n        runner = TaskRunner.remote()\n    ray.get(runner.run.remote(config))\n\n    # [Optional] get the path of the timeline trace file from the configuration, default to None\n    # This file is used for performance analysis\n    timeline_json_file = config.ray_init.get(\"timeline_json_file\", None)\n    if timeline_json_file:\n        ray.timeline(filename=timeline_json_file)\n\n\n@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head\nclass TaskRunner:\n    \"\"\"Ray remote class for executing distributed PPO training tasks.\n\n    This 
class encapsulates the main training logic and runs as a Ray remote actor\n    to enable distributed execution across multiple nodes and GPUs.\n    \"\"\"\n\n    def run(self, config):\n        \"\"\"Execute the main PPO training workflow.\n\n        This method sets up the distributed training environment, initializes\n        workers, datasets, and reward functions, then starts the training process.\n\n        Args:\n            config: Training configuration object containing all parameters needed\n                   for setting up and running the PPO training process.\n        \"\"\"\n        # Print the initial configuration. `resolve=True` will evaluate symbolic values.\n        from pprint import pprint\n\n        from omegaconf import OmegaConf\n\n        from verl.utils.fs import copy_to_local\n\n        print(f\"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}\")\n        pprint(OmegaConf.to_container(config, resolve=True))\n        OmegaConf.resolve(config)\n\n        # Download the checkpoint from HDFS to the local machine.\n        # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on\n        local_path = copy_to_local(\n            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get(\"use_shm\", False)\n        )\n\n        # Instantiate the tokenizer and processor.\n        from verl.utils import hf_processor, hf_tokenizer\n\n        trust_remote_code = config.data.get(\"trust_remote_code\", False)\n        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        # Used for multimodal LLM, could be None\n        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)\n\n        # Define worker classes based on the actor strategy.\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"}:\n            assert config.critic.strategy in {\"fsdp\", \"fsdp2\"}\n            from verl.single_controller.ray import RayWorkerGroup\n            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker\n\n            use_legacy_worker_impl = config.trainer.get(\"use_legacy_worker_impl\", \"auto\")\n            if use_legacy_worker_impl in [\"auto\", \"enable\"]:\n                # import warnings\n                # warnings.warn(f\"Legacy worker impl is going to be deprecated, will be removed in the future. 
\\\n                #   Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.\")\n                from verl.workers.fsdp_workers import CriticWorker\n            elif use_legacy_worker_impl == \"disable\":\n                from verl.workers.roles import CriticWorker\n\n                print(\"Using new worker implementation\")\n            else:\n                raise ValueError(f\"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}\")\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = RayWorkerGroup\n\n        elif config.actor_rollout_ref.actor.strategy == \"megatron\":\n            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy\n            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n            from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker\n\n            actor_rollout_cls = (\n                AsyncActorRolloutRefWorker\n                if config.actor_rollout_ref.rollout.mode == \"async\"\n                else ActorRolloutRefWorker\n            )\n            ray_worker_group_cls = NVMegatronRayWorkerGroup\n\n        else:\n            raise NotImplementedError\n\n        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role\n\n        # Map roles to their corresponding remote worker classes.\n        role_worker_mapping = {\n            Role.ActorRollout: ray.remote(actor_rollout_cls),\n            Role.Critic: ray.remote(CriticWorker),\n        }\n\n        # Define the resource pool specification.\n        # Map roles to the resource pool.\n        global_pool_id = \"global_pool\"\n        resource_pool_spec = {\n            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,\n        }\n        mapping = {\n            Role.ActorRollout: global_pool_id,\n            Role.Critic: global_pool_id,\n        }\n\n        # We should adopt a multi-source reward function here:\n        # - for rule-based rm, we directly call a reward score\n        # - for model-based rm, we call a model\n        # - for code related prompt, we send to a sandbox if there are test cases\n        # finally, we combine all the rewards together\n        # The reward type depends on the tag of the data\n        if config.reward_model.enable:\n            if config.reward_model.strategy in {\"fsdp\", \"fsdp2\"}:\n                from verl.workers.fsdp_workers import RewardModelWorker\n            elif config.reward_model.strategy == \"megatron\":\n                from verl.workers.megatron_workers import RewardModelWorker\n            else:\n                raise NotImplementedError\n            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)\n            mapping[Role.RewardModel] = global_pool_id\n\n        # Add a reference policy worker if KL loss or KL reward is used.\n        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:\n            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)\n            mapping[Role.RefPolicy] = global_pool_id\n\n        # Load the reward manager for training and validation.\n        reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n     
   )\n        val_reward_fn = load_reward_manager(\n            config, tokenizer, num_examine=1, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)\n\n        from verl.utils.dataset.rl_dataset import collate_fn\n\n        # Create training and validation datasets.\n        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)\n        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)\n        train_sampler = create_rl_sampler(config.data, train_dataset)\n\n        # Initialize the PPO trainer.\n        trainer = RayPPOTrainer(\n            config=config,\n            tokenizer=tokenizer,\n            processor=processor,\n            role_worker_mapping=role_worker_mapping,\n            resource_pool_manager=resource_pool_manager,\n            ray_worker_group_cls=ray_worker_group_cls,\n            reward_fn=reward_fn,\n            val_reward_fn=val_reward_fn,\n            train_dataset=train_dataset,\n            val_dataset=val_dataset,\n            collate_fn=collate_fn,\n            train_sampler=train_sampler,\n        )\n        # Initialize the workers of the trainer.\n        trainer.init_workers()\n        # Start the training process.\n        trainer.fit()\n\n\ndef create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):\n    \"\"\"Create a dataset.\n\n    Arguments:\n        data_paths: List of paths to data files.\n        data_config: The data config.\n        tokenizer (Tokenizer): The tokenizer.\n        processor (Processor): The processor.\n        is_train (bool): Whether the dataset is used for training; data generation strategies only apply when True.\n\n    Returns:\n        dataset (Dataset): The dataset.\n    \"\"\"\n    from torch.utils.data import Dataset\n\n    from verl.utils.dataset.rl_dataset import RLHFDataset\n\n    # Check if a custom dataset class is specified in the data configuration\n    # and if the path to the custom class is provided\n    if \"custom_cls\" in data_config and data_config.custom_cls.get(\"path\", None) is not None:\n        # Dynamically load the custom dataset class\n        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)\n        # Verify that the custom dataset class inherits from torch.utils.data.Dataset\n        if not issubclass(dataset_cls, Dataset):\n            raise TypeError(\n                f\"The custom dataset class '{data_config.custom_cls.name}' from \"\n                f\"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset\"\n            )\n    elif \"datagen\" in data_config and data_config.datagen.get(\"path\", None) is not None and is_train:\n        # If a data generation strategy is specified, use the DynamicGenDataset class\n        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset\n\n        dataset_cls = DynamicGenDataset\n        print(\"Using DynamicGenDataset for data generation.\")\n\n    else:\n        # Use the default RLHFDataset class if no custom class is specified\n        dataset_cls = RLHFDataset\n    print(f\"Using dataset class: {dataset_cls.__name__}\")\n\n    # Instantiate the dataset using the determined dataset class\n    dataset = dataset_cls(\n        data_files=data_paths,\n        tokenizer=tokenizer,\n        processor=processor,\n        config=data_config,\n    )\n\n    return dataset\n\n\ndef create_rl_sampler(data_config, dataset):\n    \"\"\"Create a sampler for the 
dataset.\n\n    Arguments:\n        data_config: The data config.\n        dataset (Dataset): The dataset.\n\n    Returns:\n        sampler (Sampler): The sampler.\n    \"\"\"\n    import torch\n    from torch.utils.data import RandomSampler, SequentialSampler\n\n    if data_config.sampler is not None and data_config.sampler.get(\"class_path\", None) is not None:\n        curriculum_class = load_extern_type(\n            data_config.sampler.class_path,\n            data_config.sampler.class_name,\n        )\n        sampler = curriculum_class(\n            data_source=dataset,\n            data_config=data_config,\n        )\n        assert isinstance(sampler, AbstractSampler)\n        assert data_config.get(\"dataloader_num_workers\", 8) == 0, (\n            \"If using curriculum, num_workers must be 0 to prevent data caching. \"\n            \"If the dataloader caches data before the batch is done the \"\n            \"curriculum sampler won't have the opportunity to reorder it. \"\n        )\n\n    # Use a sampler to facilitate checkpoint resumption.\n    # If shuffling is enabled in the data configuration, create a random sampler.\n    elif data_config.shuffle:\n        train_dataloader_generator = torch.Generator()\n        train_dataloader_generator.manual_seed(data_config.get(\"seed\", 1))\n        sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)\n    else:\n        # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.\n        sampler = SequentialSampler(data_source=dataset)\n\n    return sampler\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "verl_rl/verl/trainer/ppo/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/trainer/ppo/core_algos.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCore functions to implement PPO algorithms.\nThe function implemented in this file should be used by trainer with different distributed strategies to\nimplement PPO-like algorithms.\n\"\"\"\n\n__all__ = [\"register_adv_est\", \"get_adv_estimator_fn\", \"AdvantageEstimator\"]\n\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.trainer.config import AlgoConfig\n\nPOLICY_LOSS_REGISTRY = {}\n\n\ndef register_policy_loss(name):\n    \"\"\"Register a policy loss function with the given name.\n\n    Args:\n        name (str): The name to register the policy loss function under.\n\n    Returns:\n        function: Decorator function that registers the policy loss function.\n    \"\"\"\n\n    def decorator(func):\n        POLICY_LOSS_REGISTRY[name] = func\n        return func\n\n    return decorator\n\n\ndef get_policy_loss_fn(name):\n    \"\"\"Get the policy loss with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the policy loss.\n\n    Returns:\n        `(callable)`: The policy loss function.\n    \"\"\"\n    loss_name = name\n    if loss_name not in POLICY_LOSS_REGISTRY:\n        raise ValueError(\n            f\"Unsupported loss mode: {loss_name}. 
Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}\"\n        )\n    return POLICY_LOSS_REGISTRY[loss_name]\n\n\nADV_ESTIMATOR_REGISTRY = {}\n\n\ndef register_adv_est(name_or_enum):\n    \"\"\"Decorator to register an advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    \"\"\"\n\n    def decorator(fn):\n        name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n        if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:\n            raise ValueError(\n                f\"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}\"\n            )\n        ADV_ESTIMATOR_REGISTRY[name] = fn\n        return fn\n\n    return decorator\n\n\ndef get_adv_estimator_fn(name_or_enum):\n    \"\"\"Get the advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    Returns:\n        `(callable)`: The advantage estimator function.\n    \"\"\"\n    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n    if name not in ADV_ESTIMATOR_REGISTRY:\n        raise ValueError(f\"Unknown advantage estimator: {name}\")\n    return ADV_ESTIMATOR_REGISTRY[name]\n\n\nclass AdvantageEstimator(str, Enum):\n    \"\"\"Using an enumeration class to avoid spelling errors in adv_estimator.\n\n    Note(haibin.lin): this enum class is immutable after creation. Extending this\n    enum for new estimators may not be necessary since users can always just call\n    `verl.trainer.ppo.core_algos.register_adv_est` with a string name for a custom\n    advantage estimator instead.\n    \"\"\"\n\n    GAE = \"gae\"\n    GRPO = \"grpo\"\n    REINFORCE_PLUS_PLUS = \"reinforce_plus_plus\"\n    REINFORCE_PLUS_PLUS_BASELINE = \"reinforce_plus_plus_baseline\"\n    REMAX = \"remax\"\n    RLOO = \"rloo\"\n    OPO = \"opo\"\n    GRPO_PASSK = \"grpo_passk\"\n    GPG = \"gpg\"\n\n\nclass AdaptiveKLController:\n    \"\"\"\n    Adaptive KL controller described in the paper:\n    https://arxiv.org/pdf/1909.08593.pdf\n    \"\"\"\n\n    def __init__(self, init_kl_coef, target_kl, horizon):\n        self.value = init_kl_coef\n        self.target = target_kl\n        self.horizon = horizon\n\n    def update(self, current_kl, n_steps):\n        \"\"\"Update the KL coefficient based on current KL divergence.\n\n        Args:\n            current_kl (float): Current KL divergence value.\n            n_steps (int): Number of steps taken.\n        \"\"\"\n        target = self.target\n        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current_kl, n_steps):\n        \"\"\"Update method for fixed KL controller (no-op).\n\n        Args:\n            current_kl (float): Current KL divergence value (unused).\n            n_steps (int): Number of steps taken (unused).\n        \"\"\"\n        pass\n\n\ndef get_kl_controller(kl_ctrl):\n    \"\"\"Factory function to create appropriate KL controller based on configuration.\n\n    Args:\n        kl_ctrl: Configuration object containing KL controller settings.\n\n    Returns:\n        KL controller 
instance (FixedKLController or AdaptiveKLController).\n\n    Raises:\n        NotImplementedError: If controller type is not supported.\n        AssertionError: If adaptive controller horizon is not positive.\n    \"\"\"\n    if kl_ctrl.type == \"fixed\":\n        return FixedKLController(kl_coef=kl_ctrl.kl_coef)\n    elif kl_ctrl.type == \"adaptive\":\n        assert kl_ctrl.horizon > 0, f\"horizon must be larger than 0. Got {kl_ctrl.horizon}\"\n        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)\n    else:\n        raise NotImplementedError\n\n\n@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est(\"gae\")\ndef compute_gae_advantage_return(\n    token_level_rewards: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    gamma: torch.Tensor,\n    lam: torch.Tensor,\n):\n    \"\"\"Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        values: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.\n        gamma: `(float)`\n            discount factor used in RL\n        lam: `(float)`\n            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n\n    \"\"\"\n    with torch.no_grad():\n        nextvalues = 0\n        lastgaelam = 0\n        advantages_reversed = []\n        gen_len = token_level_rewards.shape[-1]\n\n        for t in reversed(range(gen_len)):\n            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]\n            lastgaelam_ = delta + gamma * lam * lastgaelam\n\n            # skip values and TD-error on observation tokens\n            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues\n            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam\n\n            advantages_reversed.append(lastgaelam)\n        advantages = torch.stack(advantages_reversed[::-1], dim=1)\n\n        returns = advantages + values\n        advantages = verl_F.masked_whiten(advantages, response_mask)\n    return advantages, returns\n\n\n# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.\n@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est(\"grpo\")\ndef compute_grpo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for GRPO, operating only on Outcome reward\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length)\n        index: `(np.ndarray)`\n            index array for grouping\n        epsilon: `(float)`\n            small value to avoid division by zero\n        
norm_adv_by_std_in_grpo: `(bool)`\n            whether to scale the GRPO advantage\n        config: `(Optional[AlgoConfig])`\n            algorithm configuration object\n\n    Note:\n        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.\n        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape is (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape is (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))\n                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            if norm_adv_by_std_in_grpo:\n                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)\n            else:\n                scores[i] = scores[i] - id2mean[index[i]]\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est(\"grpo_passk\")\ndef compute_grpo_passk_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for Pass@k using a GRPO-style outcome reward formulation.\n    Only the best response per group gets a non-zero advantage: r_max - r_second_max.\n\n    Implemented as described in https://arxiv.org/abs/2503.19595.\n\n    Args:\n        token_level_rewards: (bs, response_length)\n        response_mask: (bs, response_length)\n        index: (bs,) → group ID per sample\n        epsilon: float for numerical stability\n        config: (AlgoConfig) algorithm settings, which contains \"norm_adv_by_std_in_grpo\"\n\n    Returns:\n        advantages: (bs, response_length)\n        returns: (bs, response_length)\n    \"\"\"\n    assert config is not None\n    # if True, normalize advantage by std within group\n    norm_adv_by_std_in_grpo = config.get(\"norm_adv_by_std_in_grpo\", True)\n    scores = token_level_rewards.sum(dim=-1)  # (bs,)\n    advantages = torch.zeros_like(scores)\n\n    id2scores = defaultdict(list)\n    id2indices = defaultdict(list)\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            idx = index[i]\n            id2scores[idx].append(scores[i])\n            id2indices[idx].append(i)\n\n        for idx in id2scores:\n            rewards = torch.stack(id2scores[idx])  # (k,)\n            if rewards.numel() < 2:\n                raise ValueError(\n                    f\"Pass@k requires at least 2 samples per group. 
Got {rewards.numel()} for group {idx}.\"\n                )\n            topk, topk_idx = torch.topk(rewards, 2)\n            r_max, r_second_max = topk[0], topk[1]\n            i_max = id2indices[idx][topk_idx[0].item()]\n            advantage = r_max - r_second_max\n            if norm_adv_by_std_in_grpo:\n                std = torch.std(rewards)\n                advantage = advantage / (std + epsilon)\n            advantages[i_max] = advantage\n\n    advantages = advantages.unsqueeze(-1) * response_mask\n    return advantages, advantages\n\n\n@register_adv_est(\n    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE\n)  # or simply: @register_adv_est(\"reinforce_plus_plus_baseline\")\ndef compute_reinforce_plus_plus_baseline_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: torch.Tensor,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    response_length = token_level_rewards.shape[-1]\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = scores[i] - id2mean[index[i]]\n\n        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask\n        scores = verl_F.masked_whiten(scores, response_mask) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est(\"rloo\")\ndef compute_rloo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            
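# bucket each response's scalar score by its prompt index; the leave-one-out baseline below is the mean of the other responses in the same group\n            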
id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            response_num = len(id2score[index[i]])\n            if response_num > 1:\n                scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (\n                    response_num - 1\n                )\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est(\"opo\")\ndef compute_opo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    response_length = response_mask.sum(dim=-1)\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2len = defaultdict(list)\n    id2bsl = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n            id2len[index[i]].append(response_length[i])\n\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2bsl[idx] = torch.tensor(0.0)\n            elif len(id2score[idx]) > 1:\n                score_tensor = torch.tensor(id2score[idx])\n                len_tensor = torch.tensor(id2len[idx])\n                id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum()\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = scores[i] - id2bsl[index[i]]\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\n@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est(\"reinforce_plus_plus\")\ndef compute_reinforce_plus_plus_outcome_advantage(\n    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for REINFORCE++.\n    This implementation is based on the paper: https://arxiv.org/abs/2501.03262\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    assert config is not None\n    gamma = config.gamma\n    with torch.no_grad():\n        
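# scan tokens right-to-left: each position's return is its reward plus the discounted running return,\n        # and response_mask zeroes the accumulator so rewards beyond the response do not leak backward\n        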
returns = torch.zeros_like(token_level_rewards)\n        running_return = 0\n\n        for t in reversed(range(token_level_rewards.shape[1])):\n            running_return = token_level_rewards[:, t] + gamma * running_return\n            returns[:, t] = running_return\n            # Reset after EOS\n            running_return = running_return * response_mask[:, t]\n\n        advantages = verl_F.masked_whiten(returns, response_mask)\n        advantages = advantages * response_mask\n\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est(\"remax\")\ndef compute_remax_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    reward_baselines: torch.Tensor,\n    response_mask: torch.Tensor,\n    config: Optional[AlgoConfig] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for ReMax, operating only on Outcome reward\n    This implementation is based on the paper: https://arxiv.org/abs/2310.10505\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        reward_baselines: `(torch.Tensor)`\n            shape: (bs,)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        config: (AlgoConfig) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n\n    with torch.no_grad():\n        returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])\n        advantages = returns - reward_baselines.unsqueeze(-1) * response_mask\n\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est(\"gpg\")\ndef compute_gpg_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    f_norm: float = 1.0,\n    alpha: float = 1.0,\n    config=None,\n    **kwargs,\n):\n    \"\"\"\n    Compute advantage for GPG, operating only on Outcome reward\n    (with only one scalar reward for each response).\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        index: `(np.ndarray)`\n            shape: (bs,)\n        epsilon: (float)\n        f_norm: (float)\n        alpha: (float)\n        config: (dict) algorithm config\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        m = torch.count_nonzero(scores)\n        alpha = bsz / m.clamp(min=1)\n\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))\n                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))\n            else:\n                raise 
ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\ndef compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):\n    \"\"\"Compute token-level rewards with KL penalty.\n\n    Args:\n        token_level_scores (torch.Tensor): Token-level reward scores.\n        old_log_prob (torch.Tensor): Log probabilities from current policy.\n        ref_log_prob (torch.Tensor): Log probabilities from reference policy.\n        kl_ratio (float): KL penalty coefficient.\n\n    Returns:\n        torch.Tensor: Token-level rewards with KL penalty applied.\n    \"\"\"\n    kl = old_log_prob - ref_log_prob\n    return token_level_scores - kl * kl_ratio\n\n\ndef agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):\n    \"\"\"\n    Aggregate the loss matrix into a scalar.\n\n    Args:\n        loss_mat: `(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_mask: `(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_agg_mode: (str) choices:\n            method to aggregate the loss matrix into a scalar.\n    Returns:\n        loss: `a scalar torch.Tensor`\n            aggregated loss\n    \"\"\"\n    if loss_agg_mode == \"token-mean\":\n        loss = verl_F.masked_mean(loss_mat, loss_mask)\n    elif loss_agg_mode == \"seq-mean-token-sum\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n        loss = torch.mean(seq_losses)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-mean\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n        loss = torch.mean(seq_losses)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-sum-norm\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)\n        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor\n        # (loss_mask.shape[-1]) should ideally be constant\n        # throughout training to well-replicate the DrGRPO paper.\n        # TODO: Perhaps add user-defined normalizer argument to\n        # agg_loss to ensure divisor stays constant throughout.\n    else:\n        raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n\n    return loss\n\n@register_policy_loss(\"gspo\")\ndef compute_policy_loss_gspo(\n    old_log_prob,\n    log_prob,\n    advantages,\n    response_mask,\n    cliprange=None,\n    cliprange_low=None,\n    cliprange_high=None,\n    loss_agg_mode=\"seq-mean-token-mean\"\n):\n    clip_ratio_low = cliprange_low if cliprange_low is not None else cliprange\n    clip_ratio_high = cliprange_high if cliprange_high is not None else cliprange\n\n    negative_approx_kl = log_prob - old_log_prob\n\n    # compute sequence-level importance ratio:\n    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =\n    # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]\n    seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)\n    negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths\n\n    # Combined ratio at token level:\n    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]\n    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]\n    log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)\n    log_seq_importance_ratio = 
torch.clamp(log_seq_importance_ratio, max=10.0)  # clamp for numerical stability\n\n    # finally exp() to undo the log\n    seq_importance_ratio = torch.exp(log_seq_importance_ratio)\n\n    pg_losses1 = -advantages * seq_importance_ratio\n    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)\n    pg_losses = torch.maximum(pg_losses1, pg_losses2)\n\n    # for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)\n    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)\n\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\ndef compute_policy_loss(\n    old_log_prob,\n    log_prob,\n    advantages,\n    response_mask,\n    cliprange=None,\n    cliprange_low=None,\n    cliprange_high=None,\n    clip_ratio_c=3.0,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    Compute the clipped policy objective and related metrics for PPO.\n\n    Adapted from\n    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        clip_ratio_c (float, optional):\n            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.\n            Defaults to 3.0.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. 
Defaults to \"token-mean\".\n    \"\"\"\n    assert clip_ratio_c > 1.0, (\n        \"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,\"\n        + f\" but get the value: {clip_ratio_c}.\"\n    )\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability\n    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n    pg_losses2 = -advantages * torch.clamp(\n        ratio, 1 - cliprange_low, 1 + cliprange_high\n    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A\n    clip_pg_losses1 = torch.maximum(\n        pg_losses1, pg_losses2\n    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)\n    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n\n    pg_losses3 = -advantages * clip_ratio_c\n    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n    pg_clipfrac_lower = verl_F.masked_mean(\n        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask\n    )\n\n    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"gpg\")\ndef compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode=\"token-mean\", config=None):\n    \"\"\"Adapted from\n    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495\n    Args:\n        log_prob: `(torch.Tensor)`\n            shape: (bs, response_length)\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n    return:\n        pg_loss: `a scalar torch.Tensor`\n            policy gradient loss computed via GPG\n    \"\"\"\n    pg_losses = -log_prob * advantages\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)\n\n\n@register_policy_loss(\"clip_cov\")\ndef compute_policy_loss_clip_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for Clip-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, 
shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        clip_cov_ratio (float, optional):\n            Ratio for clipping the covariance. Defaults to 0.0002.\n        clip_cov_lb (float, optional):\n            Lower bound for clipping covariance. Defaults to 1.0.\n        clip_cov_ub (float, optional):\n            Upper bound for clipping covariance. Defaults to 5.0.\n    \"\"\"\n    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002\n    cliprange = config.clip_ratio\n    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange\n    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange\n    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0\n    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0\n\n    assert clip_cov_ratio > 0, \"clip_cov_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    corr = torch.ones_like(advantages)\n    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\n    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)\n\n    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (\n        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)\n    )\n    cov_all[response_mask == 0] = -torch.inf\n    cov_all[clip_by_origin] = -torch.inf\n\n    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)\n    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)\n    top_k_idx = torch.nonzero(top_k_idx)\n\n    if len(top_k_idx) > 0:\n        perm = torch.randperm(len(top_k_idx))\n        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]\n    else:\n        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)\n\n    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0\n\n    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)\n\n    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)\n\n\n@register_policy_loss(\"kl_cov\")\ndef compute_policy_loss_kl_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[AlgoConfig] = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related 
metrics for KL-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        kl_cov_ratio (float, optional):\n            Ratio for selecting the top-k covariance values. Defaults to 0.0002.\n        ppo_kl_coef (float, optional):\n            Coefficient for the KL penalty term in the loss. Defaults to 1.\n    \"\"\"\n    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002\n    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0\n\n    assert kl_cov_ratio > 0, \"kl_cov_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    abs_kl = negative_approx_kl.abs()\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)\n    pg_losses1 = -advantages * ratio\n    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl\n    pg_losses = pg_losses1\n\n    all_valid = response_mask > 0\n    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]\n    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()\n    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()\n\n    k = min(kl_cov_ratio, len(all_valid_adv))\n\n    if k != 0:\n        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())\n        k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))\n        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices\n\n        if len(large_cov_idxs) != 0:\n            large_cov_idxs = all_valid_idx[large_cov_idxs]\n            pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[\n                large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]\n            ]\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)\n\n\ndef compute_entropy_loss(logits, response_mask, loss_agg_mode: str = \"token-mean\"):\n    \"\"\"Compute categorical entropy loss (For backward compatibility)\n\n    Args:\n        logits (torch.Tensor): shape is (bs, response_length, vocab_size)\n        response_mask (torch.Tensor): shape is (bs, response_length)\n\n    Returns:\n        entropy: a scalar torch.Tensor\n\n    \"\"\"\n    # compute entropy\n    token_entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)\n    entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return entropy_loss\n\n\ndef compute_value_loss(\n    vpreds: torch.Tensor,\n    returns: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    
cliprange_value: float,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    Compute the clipped value-function loss for PPO.\n\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151\n\n    Args:\n        vpreds (torch.FloatTensor):\n            Predicted values from the value head, shape (batch_size, response_length).\n        returns (torch.FloatTensor):\n            Ground-truth returns, shape (batch_size, response_length).\n        values (torch.FloatTensor):\n            Old (baseline) values from the value head, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the value loss calculation.\n        cliprange_value (float):\n            Clip range for value prediction updates.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n\n    Returns:\n        vf_loss (torch.FloatTensor):\n            A scalar tensor containing the aggregated value-function loss.\n        vf_clipfrac (float):\n            Fraction of elements where the clipped loss was used.\n    \"\"\"\n    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)\n    vf_losses1 = (vpreds - returns) ** 2\n    vf_losses2 = (vpredclipped - returns) ** 2\n    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)\n    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)\n    return vf_loss, vf_clipfrac\n\n\ndef kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:\n    \"\"\"Compute KL divergence given logprob and ref_logprob.\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104\n    See more description in http://joschu.net/blog/kl-approx.html\n\n    Args:\n        logprob: Log-probabilities of the sampled tokens under the current policy.\n        ref_logprob: Log-probabilities of the same tokens under the reference policy.\n\n    Returns:\n        torch.FloatTensor: Token-level KL penalty estimate, same shape as `logprob`.\n    \"\"\"\n    if kl_penalty in (\"kl\", \"k1\"):\n        return logprob - ref_logprob\n\n    if kl_penalty == \"abs\":\n        return (logprob - ref_logprob).abs()\n\n    if kl_penalty in (\"mse\", \"k2\"):\n        return 0.5 * (logprob - ref_logprob).square()\n\n    # J. Schulman. 
Approximating KL divergence, 2020.\n    # URL http://joschu.net/blog/kl-approx.html.\n    if kl_penalty in (\"low_var_kl\", \"k3\"):\n        kl = ref_logprob - logprob\n        # For numerical stability\n        kl = torch.clamp(kl, min=-20, max=20)\n        ratio = torch.exp(kl)\n        kld = (ratio - kl - 1).contiguous()\n        return torch.clamp(kld, min=-10, max=10)\n\n    if kl_penalty == \"full\":\n        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary\n        raise NotImplementedError\n\n    raise NotImplementedError\n\n\ndef compute_pf_ppo_reweight_data(\n    data,\n    reweight_method: str = \"pow\",\n    weight_pow: float = 2.0,\n):\n    \"\"\"Reweight the data based on the token_level_scores.\n\n    Args:\n        data: DataProto object, containing batch, non_tensor_batch and meta_info\n        reweight_method: str, choices: \"pow\", \"max_min\", \"max_random\"\n        weight_pow: float, the power of the weight\n\n    Returns:\n        The resampled DataProto, with batch, non_tensor_batch and meta_info reindexed by the sampled indices.\n    \"\"\"\n\n    @torch.no_grad()\n    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:\n        \"\"\"Compute importance weights for resampling based on scores.\n\n        Args:\n            scores (torch.Tensor): Tensor of scores to compute weights from.\n            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').\n            weight_pow (float): Power exponent for 'pow' method.\n\n        Returns:\n            torch.Tensor: Computed importance weights.\n\n        Raises:\n            ValueError: If reweight_method is not supported.\n        \"\"\"\n        if reweight_method == \"pow\":\n            weights = torch.pow(torch.abs(scores), weight_pow)\n        elif reweight_method == \"max_min\":\n            max_score = torch.max(scores)\n            min_score = torch.min(scores)\n            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)\n        elif reweight_method == \"max_random\":\n            max_score = torch.max(scores)\n            weights = torch.where(scores == max_score, 0.4, 0.1)\n        else:\n            raise ValueError(f\"Unsupported reweight_method: {reweight_method}\")\n        return weights\n\n    scores = data.batch[\"token_level_scores\"].sum(dim=-1)\n    weights = compute_weights(scores, reweight_method, weight_pow)\n    weights = torch.clamp(weights + 1e-8, min=1e-8)\n\n    batch_size = scores.shape[0]\n    sample_indices = torch.multinomial(weights, batch_size, replacement=True)\n\n    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}\n\n    sample_indices_np = sample_indices.numpy()\n    resampled_non_tensor_batch = {}\n    for key, array in data.non_tensor_batch.items():\n        if isinstance(array, np.ndarray):\n            resampled_non_tensor_batch[key] = array[sample_indices_np]\n        else:\n            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]\n\n    resampled_meta_info = {}\n    for key, value in data.meta_info.items():\n        if isinstance(value, list) and len(value) == batch_size:\n            resampled_meta_info[key] = [value[i] for i in sample_indices_np]\n        else:\n            resampled_meta_info[key] = value\n\n    from copy import deepcopy\n\n    resampled_data = deepcopy(data)\n    resampled_data.batch = type(data.batch)(resampled_batch)\n    resampled_data.batch.batch_size = data.batch.batch_size\n    resampled_data.non_tensor_batch = 
resampled_non_tensor_batch\n    resampled_data.meta_info = resampled_meta_info\n\n    return resampled_data\n"
  },
  {
    "path": "verl_rl/verl/trainer/ppo/metric_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMetrics related to the PPO trainer.\n\"\"\"\n\nfrom collections import defaultdict\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport numpy as np\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.import_utils import deprecated\n\n\n@deprecated(\"verl.utils.metric.reduce_metrics\")\ndef reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:\n    \"\"\"\n    Reduces a dictionary of metric lists by computing the mean of each list.\n\n    Args:\n        metrics: A dictionary mapping metric names to lists of metric values.\n\n    Returns:\n        A dictionary with the same keys but with each list replaced by its mean value.\n\n    Example:\n        >>> metrics = {\"loss\": [1.0, 2.0, 3.0], \"accuracy\": [0.8, 0.9, 0.7]}\n        >>> reduce_metrics(metrics)\n        {\"loss\": 2.0, \"accuracy\": 0.8}\n    \"\"\"\n    from verl.utils.metric import reduce_metrics\n\n    return reduce_metrics(metrics)\n\n\ndef _compute_response_info(batch: DataProto) -> dict[str, Any]:\n    \"\"\"\n    Computes information about prompts and responses from a batch.\n\n    This is an internal helper function that extracts masks and lengths for prompts and responses.\n\n    Args:\n        batch: A DataProto object containing batch data with responses and attention masks.\n\n    Returns:\n        A dictionary containing:\n            - response_mask: Attention mask for the response tokens\n            - prompt_length: Tensor of prompt lengths for each item in the batch\n            - response_length: Tensor of response lengths for each item in the batch\n    \"\"\"\n    response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-response_length]\n    response_mask = batch.batch[\"attention_mask\"][:, -response_length:]\n\n    prompt_length = prompt_mask.sum(-1).float()\n    response_length = response_mask.sum(-1).float()  # (batch_size,)\n\n    return dict(\n        response_mask=response_mask,\n        prompt_length=prompt_length,\n        response_length=response_length,\n    )\n\n\ndef compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:\n    \"\"\"\n    Computes various metrics from a batch of data for PPO training.\n\n    This function calculates metrics related to scores, rewards, advantages, returns, values,\n    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)\n    for each metric category.\n\n    Args:\n        batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.\n        use_critic: Whether to include critic-specific metrics. 
Defaults to True.\n\n    Returns:\n        A dictionary of metrics including:\n            - critic/score/mean, max, min: Statistics about sequence scores\n            - critic/rewards/mean, max, min: Statistics about sequence rewards\n            - critic/advantages/mean, max, min: Statistics about advantages\n            - critic/returns/mean, max, min: Statistics about returns\n            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)\n            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)\n            - response_length/mean, max, min, clip_ratio: Statistics about response lengths\n            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths\n            - num_turns/mean, max, min: Statistics about the number of multi-turn conversations\n    \"\"\"\n    sequence_score = batch.batch[\"token_level_scores\"].sum(-1)\n    sequence_reward = batch.batch[\"token_level_rewards\"].sum(-1)\n\n    advantages = batch.batch[\"advantages\"]\n    returns = batch.batch[\"returns\"]\n\n    max_response_length = batch.batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch.batch[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = batch.batch[\"response_mask\"].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(batch)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n\n    if use_critic:\n        values = batch.batch[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n\n    metrics = {\n        # score\n        \"critic/score/mean\": torch.mean(sequence_score).detach().item(),\n        \"critic/score/max\": torch.max(sequence_score).detach().item(),\n        \"critic/score/min\": torch.min(sequence_score).detach().item(),\n        # reward\n        \"critic/rewards/mean\": torch.mean(sequence_reward).detach().item(),\n        \"critic/rewards/max\": torch.max(sequence_reward).detach().item(),\n        \"critic/rewards/min\": torch.min(sequence_reward).detach().item(),\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        \"critic/returns/max\": torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n                \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": torch.min(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if use_critic\n            else {}\n        ),\n        # response length\n        \"response_length/mean\": torch.mean(response_length).detach().item(),\n        \"response_length/max\": 
torch.max(response_length).detach().item(),\n        \"response_length/min\": torch.min(response_length).detach().item(),\n        \"response_length/clip_ratio\": torch.mean(torch.eq(response_length, max_response_length).float())\n        .detach()\n        .item(),\n        # prompt length\n        \"prompt_length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt_length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt_length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt_length/clip_ratio\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n    }\n\n    # multi-turn conversation\n    if \"__num_turns__\" in batch.non_tensor_batch:\n        num_turns = batch.non_tensor_batch[\"__num_turns__\"]\n        metrics[\"num_turns/min\"] = num_turns.min()\n        metrics[\"num_turns/max\"] = num_turns.max()\n        metrics[\"num_turns/mean\"] = num_turns.mean()\n\n    return metrics\n\n\ndef compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:\n    \"\"\"\n    Computes timing metrics for different processing stages in PPO training.\n\n    This function calculates both raw timing metrics (in seconds) and per-token timing metrics\n    (in milliseconds) for various processing stages like generation, reference computation,\n    value computation, advantage computation, and model updates.\n\n    Args:\n        batch: A DataProto object containing batch data with responses and attention masks.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n\n    Returns:\n        A dictionary containing:\n            - timing_s/{name}: Raw timing in seconds for each stage\n            - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage\n\n    Note:\n        Different stages use different token counts for normalization:\n        - \"gen\" uses only response tokens\n        - Other stages (\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\") use all tokens\n          (prompt + response)\n    \"\"\"\n    response_info = _compute_response_info(batch)\n    num_prompt_tokens = torch.sum(response_info[\"prompt_length\"]).item()\n    num_response_tokens = torch.sum(response_info[\"response_length\"]).item()\n    num_overall_tokens = num_prompt_tokens + num_response_tokens\n\n    num_tokens_of_section = {\n        \"gen\": num_response_tokens,\n        **{name: num_overall_tokens for name in [\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\"]},\n    }\n\n    return {\n        **{f\"timing_s/{name}\": value for name, value in timing_raw.items()},\n        **{\n            f\"timing_per_token_ms/{name}\": timing_raw[name] * 1000 / num_tokens_of_section[name]\n            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())\n        },\n    }\n\n\ndef compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:\n    \"\"\"\n    Computes throughput metrics for PPO training.\n\n    This function calculates performance metrics related to token processing speed,\n    including the total number of tokens processed, time per step, and throughput\n    (tokens per second per GPU).\n\n    Args:\n        batch: A DataProto object containing batch data with meta information about token counts.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n                   Must contain a \"step\" key with the total step time.\n 
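                  Other stage timings, if present, are ignored by this function.
 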
       n_gpus: Number of GPUs used for training.\n\n    Returns:\n        A dictionary containing:\n            - perf/total_num_tokens: Total number of tokens processed in the batch\n            - perf/time_per_step: Time taken for the step in seconds\n            - perf/throughput: Tokens processed per second per GPU\n\n    Note:\n        The throughput is calculated as total_tokens / (time * n_gpus) to normalize\n        across different GPU counts.\n    \"\"\"\n    total_num_tokens = sum(batch.meta_info[\"global_token_num\"])\n    time = timing_raw[\"step\"]\n    # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)\n    # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus),\n    # f'Theoretical TFLOPs/s/GPU​': promised_flops,\n    return {\n        \"perf/total_num_tokens\": total_num_tokens,\n        \"perf/time_per_step\": time,\n        \"perf/throughput\": total_num_tokens / (time * n_gpus),\n    }\n\n\ndef bootstrap_metric(\n    data: list[Any],\n    subset_size: int,\n    reduce_fns: list[Callable[[np.ndarray], float]],\n    n_bootstrap: int = 1000,\n    seed: int = 42,\n) -> list[tuple[float, float]]:\n    \"\"\"\n    Performs bootstrap resampling to estimate statistics of metrics.\n\n    This function uses bootstrap resampling to estimate the mean and standard deviation\n    of metrics computed by the provided reduction functions on random subsets of the data.\n\n    Args:\n        data: List of data points to bootstrap from.\n        subset_size: Size of each bootstrap sample.\n        reduce_fns: List of functions that compute a metric from a subset of data.\n        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.\n        seed: Random seed for reproducibility. Defaults to 42.\n\n    Returns:\n        A list of tuples, where each tuple contains (mean, std) for a metric\n        corresponding to each reduction function in reduce_fns.\n\n    Example:\n        >>> data = [1, 2, 3, 4, 5]\n        >>> reduce_fns = [np.mean, np.max]\n        >>> bootstrap_metric(data, 3, reduce_fns)\n        [(3.0, 0.5), (4.5, 0.3)]  # Example values\n    \"\"\"\n    np.random.seed(seed)\n\n    bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]\n    for _ in range(n_bootstrap):\n        bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)\n        bootstrap_data = [data[i] for i in bootstrap_idxs]\n        for i, reduce_fn in enumerate(reduce_fns):\n            bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))\n    return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]\n\n\ndef calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:\n    \"\"\"\n    Calculate a value based on majority voting.\n\n    This function identifies the most common value for a specified vote key\n    in the data, then returns the corresponding value for that majority vote.\n\n    Args:\n        data: List of dictionaries, where each dictionary contains both vote_key and val_key.\n        vote_key: The key in each dictionary used for voting/counting.\n        val_key: The key in each dictionary whose value will be returned for the majority vote.\n\n    Returns:\n        The value associated with the most common vote.\n\n    Example:\n        >>> data = [\n        ...     {\"pred\": \"A\", \"val\": 0.9},\n        ...     {\"pred\": \"B\", \"val\": 0.8},\n        ...     {\"pred\": \"A\", \"val\": 0.7}\n        ... 
]\n        >>> calc_maj_val(data, vote_key=\"pred\", val_key=\"val\")\n        0.9  # Returns the first \"val\" for the majority vote \"A\"\n    \"\"\"\n    vote2vals = defaultdict(list)\n    for d in data:\n        vote2vals[d[vote_key]].append(d[val_key])\n\n    vote2cnt = {k: len(v) for k, v in vote2vals.items()}\n    maj_vote = max(vote2cnt, key=vote2cnt.get)\n\n    maj_val = vote2vals[maj_vote][0]\n\n    return maj_val\n\n\ndef process_validation_metrics(\n    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42\n) -> dict[str, dict[str, dict[str, float]]]:\n    \"\"\"\n    Process validation metrics into a structured format with statistical analysis.\n\n    This function organizes validation metrics by data source and prompt, then computes\n    various statistical measures including means, standard deviations, best/worst values,\n    and majority voting results. It also performs bootstrap sampling to estimate statistics\n    for different sample sizes.\n\n    Args:\n        data_sources: List of data source identifiers for each sample.\n        sample_inputs: List of input prompts corresponding to each sample.\n        infos_dict: Dictionary mapping variable names to lists of values for each sample.\n        seed: Random seed for bootstrap sampling. Defaults to 42.\n\n    Returns:\n        A nested dictionary with the structure:\n        {\n            data_source: {\n                variable_name: {\n                    metric_name: value\n                }\n            }\n        }\n\n        Where metric_name includes:\n        - \"mean@N\": Mean value across N samples\n        - \"std@N\": Standard deviation across N samples\n        - \"best@N/mean\": Mean of the best values in bootstrap samples of size N\n        - \"best@N/std\": Standard deviation of the best values in bootstrap samples\n        - \"worst@N/mean\": Mean of the worst values in bootstrap samples\n        - \"worst@N/std\": Standard deviation of the worst values in bootstrap samples\n        - \"maj@N/mean\": Mean of majority voting results in bootstrap samples (if \"pred\" exists)\n        - \"maj@N/std\": Standard deviation of majority voting results (if \"pred\" exists)\n\n    Example:\n        >>> data_sources = [\"source1\", \"source1\", \"source2\"]\n        >>> sample_inputs = [\"prompt1\", \"prompt1\", \"prompt2\"]\n        >>> infos_dict = {\"score\": [0.8, 0.9, 0.7], \"pred\": [\"A\", \"A\", \"B\"]}\n        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)\n        >>> # result will contain statistics for each data source and variable\n    \"\"\"\n    # Group metrics by data source, prompt and variable\n    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for sample_idx, data_source in enumerate(data_sources):\n        prompt = sample_inputs[sample_idx]\n        var2vals = data_src2prompt2var2vals[data_source][prompt]\n        for var_name, var_vals in infos_dict.items():\n            var2vals[var_name].append(var_vals[sample_idx])\n\n    # Calculate metrics for each group\n    data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))\n    for data_source, prompt2var2vals in data_src2prompt2var2vals.items():\n        for prompt, var2vals in prompt2var2vals.items():\n            for var_name, var_vals in var2vals.items():\n                if isinstance(var_vals[0], str):\n                    continue\n\n                metric = {}\n                
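# Per-prompt statistics: mean@n over all n responses; for n > 1, bootstrap\n                # resampling adds best@N / worst@N (and maj@N when a \"pred\" field exists)\n                # for subset sizes N = 2, 4, ..., n.\n                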
n_resps = len(var_vals)\n                metric[f\"mean@{n_resps}\"] = np.mean(var_vals)\n\n                if n_resps > 1:\n                    metric[f\"std@{n_resps}\"] = np.std(var_vals)\n\n                    ns = []\n                    n = 2\n                    while n < n_resps:\n                        ns.append(n)\n                        n *= 2\n                    ns.append(n_resps)\n\n                    for n in ns:\n                        [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(\n                            data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed\n                        )\n                        metric[f\"best@{n}/mean\"], metric[f\"best@{n}/std\"] = bon_mean, bon_std\n                        metric[f\"worst@{n}/mean\"], metric[f\"worst@{n}/std\"] = won_mean, won_std\n                        if var2vals.get(\"pred\", None) is not None:\n                            vote_data = [\n                                {\"val\": val, \"pred\": pred} for val, pred in zip(var_vals, var2vals[\"pred\"], strict=True)\n                            ]\n                            [(maj_n_mean, maj_n_std)] = bootstrap_metric(\n                                data=vote_data,\n                                subset_size=n,\n                                reduce_fns=[partial(calc_maj_val, vote_key=\"pred\", val_key=\"val\")],\n                                seed=seed,\n                            )\n                            metric[f\"maj@{n}/mean\"], metric[f\"maj@{n}/std\"] = maj_n_mean, maj_n_std\n\n                data_src2prompt2var2metric[data_source][prompt][var_name] = metric\n\n    # Aggregate metrics across prompts\n    data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for data_source, prompt2var2metric in data_src2prompt2var2metric.items():\n        for prompt, var2metric in prompt2var2metric.items():\n            for var_name, metric in var2metric.items():\n                for metric_name, metric_val in metric.items():\n                    data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)\n\n    data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))\n    for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():\n        for var_name, metric2prompt_vals in var2metric2prompt_vals.items():\n            for metric_name, prompt_vals in metric2prompt_vals.items():\n                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)\n\n    return data_src2var2metric2val\n"
  },
  {
    "path": "verl_rl/verl/trainer/ppo/ray_trainer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPPO Trainer with Ray-based single controller.\nThis trainer supports model-agonistic model initialization with huggingface\n\"\"\"\n\nimport json\nimport os\nimport uuid\nfrom collections import defaultdict\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom pprint import pprint\nfrom typing import Optional\n\nimport numpy as np\nimport ray\nimport torch\nimport wandb\nfrom omegaconf import OmegaConf, open_dict\nfrom torch.utils.data import Dataset, Sampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom tqdm import tqdm\n\nfrom verl import DataProto\nfrom verl.experimental.dataset.sampler import AbstractCurriculumSampler\nfrom verl.protocol import pad_dataproto_to_divisor, unpad_dataproto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\nfrom verl.single_controller.ray.base import create_colocated_worker_cls\nfrom verl.trainer.config import AlgoConfig\nfrom verl.trainer.ppo import core_algos\nfrom verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss\nfrom verl.trainer.ppo.metric_utils import (\n    compute_data_metrics,\n    compute_throughout_metrics,\n    compute_timing_metrics,\n    process_validation_metrics,\n)\nfrom verl.trainer.ppo.reward import compute_reward, compute_reward_async\nfrom verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi\nfrom verl.utils.debug import marked_timer\nfrom verl.utils.metric import (\n    reduce_metrics,\n)\nfrom verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.tracking import ValidationGenerationsLogger\n\nWorkerType = type[Worker]\n\n\nclass Role(Enum):\n    \"\"\"\n    To create more roles dynamically, you can subclass Role and add new members\n    \"\"\"\n\n    Actor = 0\n    Rollout = 1\n    ActorRollout = 2\n    Critic = 3\n    RefPolicy = 4\n    RewardModel = 5\n    ActorRolloutRef = 6\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. 
Resource pools will be initialized first.\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    mapping: dict[Role, str]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        \"\"\"Create Ray resource pools for distributed training.\n\n        Initializes resource pools based on the resource pool specification,\n        with each pool managing GPU resources across multiple nodes.\n        For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups.\n        For Megatron backend, uses max_colocate_count>1 for different models.\n        \"\"\"\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1 that merges all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1\n            # that can utilize different WorkerGroups for different models\n            resource_pool = RayResourcePool(\n                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name\n            )\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, role: Role) -> RayResourcePool:\n        \"\"\"Get the resource pool of the worker_cls\"\"\"\n        return self.resource_pool_dict[self.mapping[role]]\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self):\n        \"\"\"Check if the resource pool can be satisfied in this ray cluster.\"\"\"\n        node_available_resources = ray.state.available_resources_per_node()\n        node_available_gpus = {\n            node: node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0)\n            for node, node_info in node_available_resources.items()\n        }\n\n        # check total required gpus can be satisfied\n        total_available_gpus = sum(node_available_gpus.values())\n        total_required_gpus = sum(\n            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]\n        )\n        if total_available_gpus < total_required_gpus:\n            raise ValueError(\n                f\"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}\"\n            )\n\n        # check each resource pool can be satisfied, O(#resource_pools * #nodes)\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)\n            for node, available_gpus in node_available_gpus.items():\n                if available_gpus >= num_gpus:\n                    node_available_gpus[node] -= num_gpus\n                    num_nodes -= 1\n                    if num_nodes == 0:\n                        break\n            if num_nodes > 0:\n                raise ValueError(\n                    f\"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}\"\n                    + \" cannot be satisfied in this ray cluster\"\n                )\n\n\ndef apply_kl_penalty(data: DataProto, kl_ctrl: 
core_algos.AdaptiveKLController, kl_penalty=\"kl\"):\n    \"\"\"Apply KL penalty to the token-level rewards.\n\n    This function computes the KL divergence between the reference policy and current policy,\n    then applies a penalty to the token-level rewards based on this divergence.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.\n        kl_penalty (str, optional): Type of KL penalty to apply. Defaults to \"kl\".\n\n    Returns:\n        tuple: A tuple containing:\n            - The updated data with token-level rewards adjusted by KL penalty\n            - A dictionary of metrics related to the KL penalty\n    \"\"\"\n    response_mask = data.batch[\"response_mask\"]\n    token_level_scores = data.batch[\"token_level_scores\"]\n    batch_size = data.batch.batch_size[0]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = core_algos.kl_penalty(\n        data.batch[\"old_log_probs\"], data.batch[\"ref_log_prob\"], kl_penalty=kl_penalty\n    )  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_response_mask(data: DataProto):\n    \"\"\"Compute the attention mask for the response part of the sequence.\n\n    This function extracts the portion of the attention mask that corresponds to the model's response,\n    which is used for masking computations that should only apply to response tokens.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n\n    Returns:\n        torch.Tensor: The attention mask for the response tokens.\n    \"\"\"\n    responses = data.batch[\"responses\"]\n    response_length = responses.size(1)\n    attention_mask = data.batch[\"attention_mask\"]\n    return attention_mask[:, -response_length:]\n\n\ndef compute_advantage(\n    data: DataProto,\n    adv_estimator: AdvantageEstimator,\n    gamma: float = 1.0,\n    lam: float = 1.0,\n    num_repeat: int = 1,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgoConfig] = None,\n) -> DataProto:\n    \"\"\"Compute advantage estimates for policy optimization.\n\n    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.\n    The advantage estimates are used to guide policy optimization in RL algorithms.\n\n    Args:\n        data (DataProto): The data containing batched model outputs and inputs.\n        adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).\n        gamma (float, optional): Discount factor for future rewards. 
Defaults to 1.0.\n        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.\n        num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.\n        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in\n            GRPO. Defaults to True.\n        config (AlgoConfig, optional): Algorithm configuration used by some estimators. Defaults to None.\n\n    Returns:\n        DataProto: The updated data with computed advantages and returns.\n    \"\"\"\n    # Backward-compatible with trainers that do not compute the response mask in fit\n    if \"response_mask\" not in data.batch.keys():\n        data.batch[\"response_mask\"] = compute_response_mask(data)\n    # prepare response group\n    if adv_estimator == AdvantageEstimator.GAE:\n        # Compute advantages and returns using Generalized Advantage Estimation (GAE)\n        advantages, returns = core_algos.compute_gae_advantage_return(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            values=data.batch[\"values\"],\n            response_mask=data.batch[\"response_mask\"],\n            gamma=gamma,\n            lam=lam,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n        if config.get(\"use_pf_ppo\", False):\n            data = core_algos.compute_pf_ppo_reweight_data(\n                data,\n                config.pf_ppo.reweight_method,\n                config.pf_ppo.weight_pow,\n            )\n    elif adv_estimator == AdvantageEstimator.GRPO:\n        # Initialize the mask for GRPO calculation\n        grpo_calculation_mask = data.batch[\"response_mask\"]\n        # Call compute_grpo_outcome_advantage with parameters matching its definition\n        advantages, returns = core_algos.compute_grpo_outcome_advantage(\n            token_level_rewards=data.batch[\"token_level_rewards\"],\n            response_mask=grpo_calculation_mask,\n            index=data.non_tensor_batch[\"uid\"],\n            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n        )\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    else:\n        # handle all advantage estimator types other than GAE and GRPO\n        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)\n        adv_kwargs = {\n            \"token_level_rewards\": data.batch[\"token_level_rewards\"],\n            \"response_mask\": data.batch[\"response_mask\"],\n            \"config\": config,\n        }\n        if \"uid\" in data.non_tensor_batch:  # optional\n            adv_kwargs[\"index\"] = data.non_tensor_batch[\"uid\"]\n        if \"reward_baselines\" in data.batch:  # optional\n            adv_kwargs[\"reward_baselines\"] = data.batch[\"reward_baselines\"]\n\n        # calculate advantage estimator\n        advantages, returns = adv_estimator_fn(**adv_kwargs)\n        data.batch[\"advantages\"] = advantages\n        data.batch[\"returns\"] = returns\n    return data\n\n\nclass RayPPOTrainer:\n    \"\"\"Distributed PPO trainer using Ray for scalable reinforcement learning.\n\n    This trainer orchestrates distributed PPO training across multiple nodes and GPUs,\n    managing actor rollouts, critic training, and reward computation with Ray backend.\n    Supports various model architectures including FSDP, Megatron, and vLLM integration.\n    \"\"\"\n\n    # TODO: support an individual ray_worker_group_cls per role,\n    # i.e., support a different backend for each 
role\n    def __init__(\n        self,\n        config,\n        tokenizer,\n        role_worker_mapping: dict[Role, WorkerType],\n        resource_pool_manager: ResourcePoolManager,\n        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,\n        processor=None,\n        reward_fn=None,\n        val_reward_fn=None,\n        train_dataset: Optional[Dataset] = None,\n        val_dataset: Optional[Dataset] = None,\n        collate_fn=None,\n        train_sampler: Optional[Sampler] = None,\n        device_name=None,\n    ):\n        \"\"\"\n        Initialize distributed PPO trainer with Ray backend.\n        Note that this trainer runs on the driver process on a single CPU/GPU node.\n\n        Args:\n            config: Configuration object containing training parameters.\n            tokenizer: Tokenizer used for encoding and decoding text.\n            role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.\n            resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.\n            ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.\n            processor: Optional data processor, used for multimodal data\n            reward_fn: Function for computing rewards during training.\n            val_reward_fn: Function for computing rewards during validation.\n            train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.\n            val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.\n            collate_fn: Function to collate data samples into batches.\n            train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.\n            device_name (str, optional): Device name for training (e.g., \"cuda\", \"cpu\"). 
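When omitted, falls back to config.trainer.device. 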
Defaults to None.\n        \"\"\"\n\n        # Store the tokenizer for text processing\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n        self.reward_fn = reward_fn\n        self.val_reward_fn = val_reward_fn\n\n        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine\n        assert self.hybrid_engine, \"Currently, only hybrid engine is supported\"\n\n        if self.hybrid_engine:\n            assert Role.ActorRollout in role_worker_mapping, f\"{role_worker_mapping.keys()=}\"\n\n        self.role_worker_mapping = role_worker_mapping\n        self.resource_pool_manager = resource_pool_manager\n        self.use_reference_policy = Role.RefPolicy in role_worker_mapping\n        self.use_rm = Role.RewardModel in role_worker_mapping\n        self.ray_worker_group_cls = ray_worker_group_cls\n        self.device_name = device_name if device_name else self.config.trainer.device\n        self.validation_generations_logger = ValidationGenerationsLogger(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n        )\n\n        # if ref_in_actor is True, the reference policy will be the actor without LoRA applied\n        self.ref_in_actor = config.actor_rollout_ref.model.get(\"lora_rank\", 0) > 0\n\n        # define in-reward KL control\n        # kl loss control is currently not supported\n        if self.config.algorithm.use_kl_in_reward:\n            self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)\n\n        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            self.use_critic = True\n        elif self.config.algorithm.adv_estimator in [\n            AdvantageEstimator.GRPO,\n            AdvantageEstimator.GRPO_PASSK,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS,\n            AdvantageEstimator.REMAX,\n            AdvantageEstimator.RLOO,\n            AdvantageEstimator.OPO,\n            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,\n            AdvantageEstimator.GPG,\n        ]:\n            self.use_critic = False\n        else:\n            raise NotImplementedError\n\n        self._validate_config()\n        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)\n\n    def _validate_config(self):\n        config = self.config\n        # number of GPUs total\n        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes\n        if config.actor_rollout_ref.actor.strategy == \"megatron\":\n            model_parallel_size = (\n                config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size\n                * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size\n            )\n            assert (\n                n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0\n            ), (\n                f\"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times \"\n                f\"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})\"\n            )\n            megatron_dp = n_gpus // (\n                model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size\n            )\n            minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu\n        else:\n            minimal_bsz = n_gpus\n\n        # 1. 
Check total batch size for data correctness\n        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n\n        assert real_train_batch_size % minimal_bsz == 0, (\n            f\"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size \"\n            f\"({minimal_bsz})\"\n        )\n\n        # A helper function to check \"micro_batch_size\" vs \"micro_batch_size_per_gpu\"\n        # We throw an error if the user sets both. The new convention is \"..._micro_batch_size_per_gpu\".\n        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):\n            \"\"\"Validate mutually exclusive micro batch size configuration options.\n\n            Ensures that users don't set both deprecated micro_batch_size and\n            the new micro_batch_size_per_gpu parameters simultaneously.\n\n            Args:\n                mbs: Deprecated micro batch size parameter value.\n                mbs_per_gpu: New micro batch size per GPU parameter value.\n                name (str): Configuration section name for error messages.\n\n            Raises:\n                ValueError: If both parameters are set or neither is set.\n            \"\"\"\n            settings = {\n                \"actor_rollout_ref.actor\": \"micro_batch_size\",\n                \"critic\": \"micro_batch_size\",\n                \"reward_model\": \"micro_batch_size\",\n                \"actor_rollout_ref.ref\": \"log_prob_micro_batch_size\",\n                \"actor_rollout_ref.rollout\": \"log_prob_micro_batch_size\",\n            }\n\n            if name in settings:\n                param = settings[name]\n                param_per_gpu = f\"{param}_per_gpu\"\n\n                if mbs is None and mbs_per_gpu is None:\n                    raise ValueError(\n                        f\"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.\"\n                    )\n\n                if mbs is not None and mbs_per_gpu is not None:\n                    raise ValueError(\n                        f\"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove \"\n                        f\"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).\"\n                    )\n\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.actor.ppo_micro_batch_size,\n                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.actor\",\n            )\n\n            if self.use_reference_policy:\n                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu\n                check_mutually_exclusive(\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,\n                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,\n                    \"actor_rollout_ref.ref\",\n                )\n\n            #  The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu\n            check_mutually_exclusive(\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size,\n                config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,\n                \"actor_rollout_ref.rollout\",\n            )\n\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            # Check for critic micro-batch size conflicts\n            check_mutually_exclusive(\n                config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, \"critic\"\n            )\n\n        # Check for reward model micro-batch size conflicts\n        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:\n            check_mutually_exclusive(\n                config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, \"reward_model\"\n            )\n\n        # Actor\n        # check if train_batch_size is larger than ppo_mini_batch_size\n        # if NOT dynamic_bsz, we must ensure:\n        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size\n        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus\n        if not config.actor_rollout_ref.actor.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size\n            sp_size = config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:\n                assert (\n                    config.actor_rollout_ref.actor.ppo_mini_batch_size\n                    % config.actor_rollout_ref.actor.ppo_micro_batch_size\n                    == 0\n                )\n                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus\n\n        assert config.actor_rollout_ref.actor.loss_agg_mode in [\n            \"token-mean\",\n            \"seq-mean-token-sum\",\n            \"seq-mean-token-mean\",\n            \"seq-mean-token-sum-norm\",\n        ], f\"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}\"\n\n        if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:\n            print(\"NOTICE: You have both enabled in-reward kl and kl loss.\")\n\n        # critic\n        if self.use_critic and not config.critic.use_dynamic_bsz:\n            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size\n            sp_size = config.critic.get(\"ulysses_sequence_parallel_size\", 1)\n            if config.critic.ppo_micro_batch_size is not None:\n                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0\n                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus\n\n        # Check if use_remove_padding is enabled when using sequence parallelism for fsdp\n        if config.actor_rollout_ref.actor.strategy in {\"fsdp\", \"fsdp2\"} and (\n            config.actor_rollout_ref.actor.get(\"ulysses_sequence_parallel_size\", 1) > 1\n            or config.actor_rollout_ref.ref.get(\"ulysses_sequence_parallel_size\", 1) > 1\n        ):\n            assert config.actor_rollout_ref.model.use_remove_padding, (\n                \"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`.\"\n            )\n\n        if self.use_critic and config.critic.strategy in {\"fsdp\", \"fsdp2\"}:\n            if config.critic.get(\"ulysses_sequence_parallel_size\", 1) > 
1:\n                assert config.critic.model.use_remove_padding, (\n                    \"When using sequence parallelism for critic, you must enable `use_remove_padding`.\"\n                )\n\n        if config.data.get(\"val_batch_size\", None) is not None:\n            print(\n                \"WARNING: val_batch_size is deprecated.\"\n                + \" Validation datasets are sent to inference engines as a whole batch,\"\n                + \" which will schedule the memory themselves.\"\n            )\n\n        # check eval config\n        if config.actor_rollout_ref.rollout.val_kwargs.do_sample:\n            assert config.actor_rollout_ref.rollout.temperature > 0, (\n                \"validation gen temperature should be greater than 0 when enabling do_sample\"\n            )\n\n        print(\"[validate_config] All configuration checks passed successfully!\")\n\n    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):\n        \"\"\"\n        Creates the train and validation dataloaders.\n        \"\"\"\n        # TODO: we have to make sure the batch size is divisible by the dp size\n        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler\n\n        if train_dataset is None:\n            train_dataset = create_rl_dataset(\n                self.config.data.train_files, self.config.data, self.tokenizer, self.processor\n            )\n        if val_dataset is None:\n            val_dataset = create_rl_dataset(\n                self.config.data.val_files, self.config.data, self.tokenizer, self.processor\n            )\n        self.train_dataset, self.val_dataset = train_dataset, val_dataset\n\n        if train_sampler is None:\n            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)\n        if collate_fn is None:\n            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn\n\n            collate_fn = default_collate_fn\n\n        num_workers = self.config.data[\"dataloader_num_workers\"]\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=self.config.data.get(\"gen_batch_size\", self.config.data.train_batch_size),\n            num_workers=num_workers,\n            drop_last=True,\n            collate_fn=collate_fn,\n            sampler=train_sampler,\n        )\n\n        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=num_workers,\n            shuffle=self.config.data.get(\"validation_shuffle\", True),\n            drop_last=False,\n            collate_fn=collate_fn,\n        )\n\n        assert len(self.train_dataloader) >= 1, \"Train dataloader is empty!\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n\n        print(\n            f\"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: \"\n            f\"{len(self.val_dataloader)}\"\n        )\n\n        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs\n\n        if self.config.trainer.total_training_steps is not None:\n            total_training_steps = self.config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n       
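 # Derived from len(train_dataloader) * total_epochs unless trainer.total_training_steps is set explicitly.\n       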
 print(f\"Total training steps: {self.total_training_steps}\")\n\n        try:\n            OmegaConf.set_struct(self.config, True)\n            with open_dict(self.config):\n                if OmegaConf.select(self.config, \"actor_rollout_ref.actor.optim\"):\n                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n                if OmegaConf.select(self.config, \"critic.optim\"):\n                    self.config.critic.optim.total_training_steps = total_training_steps\n        except Exception as e:\n            print(f\"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}\")\n\n    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path, ground_truths=None):\n        \"\"\"Dump rollout/validation samples as JSONL.\"\"\"\n        os.makedirs(dump_path, exist_ok=True)\n        filename = os.path.join(dump_path, f\"{self.global_steps}.jsonl\")\n\n        n = len(inputs)\n        base_data = {\n            \"input\": inputs,\n            \"output\": outputs,\n            \"score\": scores,\n            \"step\": [self.global_steps] * n,\n        }\n\n        if ground_truths and len(ground_truths) == n:\n            base_data[\"ground_truth\"] = ground_truths\n\n        for k, v in reward_extra_infos_dict.items():\n            if len(v) == n:\n                base_data[k] = v\n\n        lines = []\n        for i in range(n):\n            entry = {k: v[i] for k, v in base_data.items()}\n            lines.append(json.dumps(entry, ensure_ascii=False))\n\n        with open(filename, \"w\") as f:\n            f.write(\"\\n\".join(lines) + \"\\n\")\n\n        print(f\"Dumped generations to {filename}\")\n\n    def _maybe_log_val_generations(self, inputs, outputs, scores):\n        \"\"\"Log a table of validation samples to the configured logger (wandb or swanlab)\"\"\"\n\n        generations_to_log = self.config.trainer.log_val_generations\n\n        if generations_to_log == 0:\n            return\n\n        import numpy as np\n\n        # Create tuples of (input, output, score) and sort by input text\n        samples = list(zip(inputs, outputs, scores, strict=True))\n        samples.sort(key=lambda x: x[0])  # Sort by input text\n\n        # Use fixed random seed for deterministic shuffling\n        rng = np.random.RandomState(42)\n        rng.shuffle(samples)\n\n        # Take first N samples after shuffling\n        samples = samples[:generations_to_log]\n\n        # Log to each configured logger\n        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)\n\n    def _validate(self):\n        data_source_lst = []\n        reward_extra_infos_dict: dict[str, list] = defaultdict(list)\n\n        # Lists to collect samples for the table\n        sample_inputs = []\n        sample_outputs = []\n        sample_scores = []\n        sample_turns = []\n        sample_ground_truths = []\n\n        for test_data in self.val_dataloader:\n            test_batch = DataProto.from_single_dict(test_data)\n\n            # repeat test batch\n            test_batch = test_batch.repeat(\n                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True\n            )\n\n            # we only do validation on rule-based rm\n            if self.config.reward_model.enable and test_batch[0].non_tensor_batch[\"reward_model\"][\"style\"] == \"model\":\n                return {}\n\n            # Store original inputs\n            
input_ids = test_batch.batch[\"input_ids\"]\n            # TODO: Can we keep special tokens except for padding tokens?\n            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]\n            sample_inputs.extend(input_texts)\n\n            if \"reward_model\" in test_batch.non_tensor_batch:\n                ground_truths = [item[\"ground_truth\"] for item in test_batch.non_tensor_batch[\"reward_model\"]]\n                sample_ground_truths.extend(ground_truths)\n\n            batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n            non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n            if \"multi_modal_data\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n            if \"raw_prompt\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n            if \"tools_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n            if \"interaction_kwargs\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n            if \"agent_name\" in test_batch.non_tensor_batch:\n                non_tensor_batch_keys_to_pop.append(\"agent_name\")\n            test_gen_batch = test_batch.pop(\n                batch_keys=batch_keys_to_pop,\n                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n            )\n\n            test_gen_batch.meta_info = {\n                \"eos_token_id\": self.tokenizer.eos_token_id,\n                \"pad_token_id\": self.tokenizer.pad_token_id,\n                \"recompute_log_prob\": False,\n                \"do_sample\": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,\n                \"validate\": True,\n                \"global_steps\": self.global_steps,\n            }\n            print(f\"test_gen_batch meta info: {test_gen_batch.meta_info}\")\n\n            # pad to be divisible by dp_size\n            size_divisor = (\n                self.actor_rollout_wg.world_size\n                if not self.async_rollout_mode\n                else self.config.actor_rollout_ref.rollout.agent.num_workers\n            )\n            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)\n            if not self.async_rollout_mode:\n                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)\n            else:\n                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)\n\n            # unpad\n            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)\n\n            print(\"validation generation end\")\n\n            # Store generated outputs\n            output_ids = test_output_gen_batch.batch[\"responses\"]\n            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]\n            sample_outputs.extend(output_texts)\n\n            test_batch = test_batch.union(test_output_gen_batch)\n            test_batch.meta_info[\"validate\"] = True\n\n            # evaluate using reward_function\n            result = self.val_reward_fn(test_batch, return_dict=True)\n            reward_tensor = result[\"reward_tensor\"]\n            scores = reward_tensor.sum(-1).cpu().tolist()\n            sample_scores.extend(scores)\n\n            
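# Collect per-sample rewards plus any extra reward info so they can be\n            # dumped to JSONL and aggregated into validation metrics below.\n            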
reward_extra_infos_dict[\"reward\"].extend(scores)\n            print(f\"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}\")\n            if \"reward_extra_info\" in result:\n                for key, lst in result[\"reward_extra_info\"].items():\n                    reward_extra_infos_dict[key].extend(lst)\n                    print(f\"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}\")\n\n            # collect num_turns of each prompt\n            if \"__num_turns__\" in test_batch.non_tensor_batch:\n                sample_turns.append(test_batch.non_tensor_batch[\"__num_turns__\"])\n\n            data_source_lst.append(test_batch.non_tensor_batch.get(\"data_source\", [\"unknown\"] * reward_tensor.shape[0]))\n\n        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)\n\n        # dump generations\n        val_data_dir = self.config.trainer.get(\"validation_data_dir\", None)\n        if val_data_dir:\n            self._dump_generations(\n                inputs=sample_inputs,\n                outputs=sample_outputs,\n                scores=sample_scores,\n                reward_extra_infos_dict=reward_extra_infos_dict,\n                dump_path=val_data_dir,\n                ground_truths=sample_ground_truths,\n            )\n\n        for key_info, lst in reward_extra_infos_dict.items():\n            assert len(lst) == 0 or len(lst) == len(sample_scores), f\"{key_info}: {len(lst)=}, {len(sample_scores)=}\"\n\n        data_sources = np.concatenate(data_source_lst, axis=0)\n\n        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n                for metric_name, metric_val in metric2val.items():\n                    if (\n                        (var_name == core_var)\n                        and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                        and (f\"@{n_max}\" in metric_name)\n                    ):\n                        metric_sec = \"val-core\"\n                    else:\n                        metric_sec = \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        if len(sample_turns) > 0:\n            sample_turns = np.concatenate(sample_turns)\n            metric_dict[\"val-aux/num_turns/min\"] = sample_turns.min()\n            metric_dict[\"val-aux/num_turns/max\"] = sample_turns.max()\n            metric_dict[\"val-aux/num_turns/mean\"] = sample_turns.mean()\n\n        return metric_dict\n\n    def init_workers(self):\n        \"\"\"Initialize distributed training workers using Ray backend.\n\n        Creates:\n        1. Ray resource pools from configuration\n        2. 
Worker groups for each role (actor, critic, etc.)\n        \"\"\"\n        self.resource_pool_manager.create_resource_pool()\n\n        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}\n\n        # create actor and rollout\n        if self.hybrid_engine:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)\n            actor_rollout_cls = RayClassWithInitArgs(\n                cls=self.role_worker_mapping[Role.ActorRollout],\n                config=self.config.actor_rollout_ref,\n                role=\"actor_rollout\",\n                profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][\"actor_rollout\"] = actor_rollout_cls\n        else:\n            raise NotImplementedError\n\n        # create critic\n        if self.use_critic:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)\n            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)\n            self.resource_pool_to_cls[resource_pool][\"critic\"] = critic_cls\n\n        # create reference policy if needed\n        if self.use_reference_policy:\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)\n            ref_policy_cls = RayClassWithInitArgs(\n                self.role_worker_mapping[Role.RefPolicy],\n                config=self.config.actor_rollout_ref,\n                role=\"ref\",\n                profile_option=self.config.trainer.npu_profile.options,\n            )\n            self.resource_pool_to_cls[resource_pool][\"ref\"] = ref_policy_cls\n\n        # create a reward model if reward_fn is None\n        if self.use_rm:\n            # we create a RM here\n            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)\n            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)\n            self.resource_pool_to_cls[resource_pool][\"rm\"] = rm_cls\n\n        # initialize WorkerGroup\n        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,\n        # you should not use `create_colocated_worker_cls`.\n        # Instead, directly pass different resource pool to different worker groups.\n        # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.\n        all_wg = {}\n        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup\n        if OmegaConf.select(self.config.trainer, \"ray_wait_register_center_timeout\") is not None:\n            wg_kwargs[\"ray_wait_register_center_timeout\"] = self.config.trainer.ray_wait_register_center_timeout\n        if OmegaConf.select(self.config.trainer, \"profile_steps\") is not None:\n            wg_kwargs[\"profile_steps\"] = OmegaConf.select(self.config.trainer, \"profile_steps\")\n            assert OmegaConf.select(self.config.trainer, \"worker_nsight_options\") is not None, (\n                \"worker_nsight_options must be set when profile_steps is set\"\n            )\n            wg_kwargs[\"worker_nsight_options\"] = OmegaConf.to_container(\n                OmegaConf.select(self.config.trainer, \"worker_nsight_options\")\n            )\n        wg_kwargs[\"device_name\"] = self.device_name\n\n        for resource_pool, class_dict in self.resource_pool_to_cls.items():\n            
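# Fuse all roles assigned to this resource pool into one colocated worker\n            # class, then spawn per-role worker group handles from it.\n            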
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)\n            wg_dict = self.ray_worker_group_cls(\n                resource_pool=resource_pool,\n                ray_cls_with_init=worker_dict_cls,\n                **wg_kwargs,\n            )\n            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())\n            all_wg.update(spawn_wg)\n\n        if self.use_critic:\n            self.critic_wg = all_wg[\"critic\"]\n            self.critic_wg.init_model()\n\n        if self.use_reference_policy and not self.ref_in_actor:\n            self.ref_policy_wg = all_wg[\"ref\"]\n            self.ref_policy_wg.init_model()\n\n        if self.use_rm:\n            self.rm_wg = all_wg[\"rm\"]\n            self.rm_wg.init_model()\n\n        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory\n        self.actor_rollout_wg = all_wg[\"actor_rollout\"]\n        self.actor_rollout_wg.init_model()\n\n        # create async rollout manager and request scheduler\n        self.async_rollout_mode = False\n        if self.config.actor_rollout_ref.rollout.mode == \"async\":\n            from verl.experimental.agent_loop import AgentLoopManager\n\n            self.async_rollout_mode = True\n            self.async_rollout_manager = AgentLoopManager(\n                config=self.config,\n                worker_group=self.actor_rollout_wg,\n            )\n\n    def _save_checkpoint(self):\n        from verl.utils.fs import local_mkdir_safe\n\n        # path: given_path + `/global_step_{global_steps}` + `/actor`\n        local_global_step_folder = os.path.join(\n            self.config.trainer.default_local_dir, f\"global_step_{self.global_steps}\"\n        )\n\n        print(f\"local_global_step_folder: {local_global_step_folder}\")\n        actor_local_path = os.path.join(local_global_step_folder, \"actor\")\n\n        actor_remote_path = (\n            None\n            if self.config.trainer.default_hdfs_dir is None\n            else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"actor\")\n        )\n\n        remove_previous_ckpt_in_save = self.config.trainer.get(\"remove_previous_ckpt_in_save\", False)\n        if remove_previous_ckpt_in_save:\n            print(\n                \"Warning: remove_previous_ckpt_in_save is deprecated,\"\n                + \" set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead\"\n            )\n        max_actor_ckpt_to_keep = (\n            self.config.trainer.get(\"max_actor_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n        max_critic_ckpt_to_keep = (\n            self.config.trainer.get(\"max_critic_ckpt_to_keep\", None) if not remove_previous_ckpt_in_save else 1\n        )\n\n        self.actor_rollout_wg.save_checkpoint(\n            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep\n        )\n\n        if self.use_critic:\n            critic_local_path = os.path.join(local_global_step_folder, \"critic\")\n            critic_remote_path = (\n                None\n                if self.config.trainer.default_hdfs_dir is None\n                else os.path.join(self.config.trainer.default_hdfs_dir, f\"global_step_{self.global_steps}\", \"critic\")\n            )\n            self.critic_wg.save_checkpoint(\n                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep\n            )\n\n        # save dataloader\n 
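       # Persist the StatefulDataLoader state so a resumed run can continue\n        # from the same position in the data stream.\n 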
       local_mkdir_safe(local_global_step_folder)\n        dataloader_local_path = os.path.join(local_global_step_folder, \"data.pt\")\n        dataloader_state_dict = self.train_dataloader.state_dict()\n        torch.save(dataloader_state_dict, dataloader_local_path)\n\n        # latest checkpointed iteration tracker (for atomic usage)\n        local_latest_checkpointed_iteration = os.path.join(\n            self.config.trainer.default_local_dir, \"latest_checkpointed_iteration.txt\"\n        )\n        with open(local_latest_checkpointed_iteration, \"w\") as f:\n            f.write(str(self.global_steps))\n\n    def _load_checkpoint(self):\n        if self.config.trainer.resume_mode == \"disable\":\n            return 0\n\n        # load from hdfs\n        if self.config.trainer.default_hdfs_dir is not None:\n            raise NotImplementedError(\"load from hdfs is not implemented yet\")\n        else:\n            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path\n            if not os.path.isabs(checkpoint_folder):\n                working_dir = os.getcwd()\n                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)\n            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest\n\n        # find global_step_folder\n        if self.config.trainer.resume_mode == \"auto\":\n            if global_step_folder is None:\n                print(\"Training from scratch\")\n                return 0\n        else:\n            if self.config.trainer.resume_mode == \"resume_path\":\n                assert isinstance(self.config.trainer.resume_from_path, str), \"resume ckpt must be str type\"\n                assert \"global_step_\" in self.config.trainer.resume_from_path, (\n                    \"resume ckpt must specify the global_steps\"\n                )\n                global_step_folder = self.config.trainer.resume_from_path\n                if not os.path.isabs(global_step_folder):\n                    working_dir = os.getcwd()\n                    global_step_folder = os.path.join(working_dir, global_step_folder)\n        print(f\"Load from checkpoint folder: {global_step_folder}\")\n        # set global step\n        self.global_steps = int(global_step_folder.split(\"global_step_\")[-1])\n\n        print(f\"Setting global step to {self.global_steps}\")\n        print(f\"Resuming from {global_step_folder}\")\n\n        actor_path = os.path.join(global_step_folder, \"actor\")\n        critic_path = os.path.join(global_step_folder, \"critic\")\n        # load actor\n        self.actor_rollout_wg.load_checkpoint(\n            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n        )\n        # load critic\n        if self.use_critic:\n            self.critic_wg.load_checkpoint(\n                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n            )\n\n        # load dataloader,\n        # TODO: from remote not implemented yet\n        dataloader_local_path = os.path.join(global_step_folder, \"data.pt\")\n        if os.path.exists(dataloader_local_path):\n            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)\n            self.train_dataloader.load_state_dict(dataloader_state_dict)\n        else:\n            print(f\"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch\")\n\n    def _start_profiling(self, do_profile: bool) -> None:\n        \"\"\"Start profiling for all 
worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.start_profile(role=\"e2e\", profile_step=self.global_steps)\n            if self.use_reference_policy:\n                self.ref_policy_wg.start_profile()\n            if self.use_critic:\n                self.critic_wg.start_profile()\n            if self.use_rm:\n                self.rm_wg.start_profile()\n\n    def _stop_profiling(self, do_profile: bool) -> None:\n        \"\"\"Stop profiling for all worker groups if profiling is enabled.\"\"\"\n        if do_profile:\n            self.actor_rollout_wg.stop_profile()\n            if self.use_reference_policy:\n                self.ref_policy_wg.stop_profile()\n            if self.use_critic:\n                self.critic_wg.stop_profile()\n            if self.use_rm:\n                self.rm_wg.stop_profile()\n\n    def _balance_batch(self, batch: DataProto, metrics, logging_prefix=\"global_seqlen\"):\n        \"\"\"Reorder the data on the single controller so that each DP rank gets a similar total token count\"\"\"\n        attention_mask = batch.batch[\"attention_mask\"]\n        batch_size = attention_mask.shape[0]\n        global_seqlen_lst = batch.batch[\"attention_mask\"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)\n        world_size = self.actor_rollout_wg.world_size\n        global_partition_lst = get_seqlen_balanced_partitions(\n            global_seqlen_lst, k_partitions=world_size, equal_size=True\n        )\n        # reorder based on index. The data will be automatically equally partitioned by the dispatch function\n        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])\n        batch.reorder(global_idx)\n        global_balance_stats = log_seqlen_unbalance(\n            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix\n        )\n        metrics.update(global_balance_stats)\n\n    def fit(self):\n        \"\"\"\n        The training loop of PPO.\n        The driver process only needs to call the compute functions of the worker group through RPC\n        to construct the PPO dataflow.\n        The lightweight advantage computation is done on the driver process.\n        \"\"\"\n        from omegaconf import OmegaConf\n\n        from verl.utils.tracking import Tracking\n\n        logger = Tracking(\n            project_name=self.config.trainer.project_name,\n            experiment_name=self.config.trainer.experiment_name,\n            default_backend=self.config.trainer.logger,\n            config=OmegaConf.to_container(self.config, resolve=True),\n        )\n\n        self.global_steps = 0\n\n        # load checkpoint before doing anything\n        self._load_checkpoint()\n\n        # perform validation before training\n        # currently, we only support validation using the reward_function.\n        if self.val_reward_fn is not None and self.config.trainer.get(\"val_before_train\", True):\n            val_metrics = self._validate()\n            assert val_metrics, f\"{val_metrics=}\"\n            pprint(f\"Initial validation metrics: {val_metrics}\")\n            logger.log(data=val_metrics, step=self.global_steps)\n            if self.config.trainer.get(\"val_only\", False):\n                return\n\n        # add tqdm\n        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        # we start from step 1\n        self.global_steps += 1\n        last_val_metrics = None\n      
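  # Track the duration of the slowest training step seen so far; should_save_ckpt_esi uses it\n        # below to force an early checkpoint before the ESI (Elastic Server Instance) quota expires.\n      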
  self.max_steps_duration = 0\n\n        for epoch in range(self.config.trainer.total_epochs):\n            for batch_dict in self.train_dataloader:\n                metrics = {}\n                timing_raw = {}\n\n                do_profile = (\n                    self.global_steps in self.config.trainer.profile_steps\n                    if self.config.trainer.profile_steps is not None\n                    else False\n                )\n                with marked_timer(\"start_profile\", timing_raw):\n                    self._start_profiling(do_profile)\n\n                batch: DataProto = DataProto.from_single_dict(batch_dict)\n\n                # pop those keys for generation\n                batch_keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n                non_tensor_batch_keys_to_pop = [\"raw_prompt_ids\"]\n                if \"multi_modal_data\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"multi_modal_data\")\n                if \"raw_prompt\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"raw_prompt\")\n                if \"tools_kwargs\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"tools_kwargs\")\n                if \"interaction_kwargs\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"interaction_kwargs\")\n                if \"index\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"index\")\n                if \"agent_name\" in batch.non_tensor_batch:\n                    non_tensor_batch_keys_to_pop.append(\"agent_name\")\n\n                gen_batch = batch.pop(\n                    batch_keys=batch_keys_to_pop,\n                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,\n                )\n\n                # pass global_steps to trace\n                gen_batch.meta_info[\"global_steps\"] = self.global_steps\n                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                with marked_timer(\"step\", timing_raw):\n                    # generate a batch\n                    with marked_timer(\"gen\", timing_raw, color=\"red\"):\n                        if not self.async_rollout_mode:\n                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)\n                        else:\n                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)\n                        timing_raw.update(gen_batch_output.meta_info[\"timing\"])\n                        gen_batch_output.meta_info.pop(\"timing\", None)\n\n                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:\n                        with marked_timer(\"gen_max\", timing_raw, color=\"purple\"):\n                            gen_baseline_batch = deepcopy(gen_batch)\n                            gen_baseline_batch.meta_info[\"do_sample\"] = False\n                            if not self.async_rollout_mode:\n                                gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)\n                            else:\n                                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)\n                            batch = batch.union(gen_baseline_output)\n     
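                       # REMAX: score a greedy (do_sample=False) rollout of the same prompts and use the\n                            # summed reward as a per-prompt baseline (stored in \"reward_baselines\" below).\n     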
                       reward_baseline_tensor = self.reward_fn(batch)\n                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)\n\n                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))\n\n                            batch.batch[\"reward_baselines\"] = reward_baseline_tensor\n\n                            del gen_baseline_batch, gen_baseline_output\n\n                    batch.non_tensor_batch[\"uid\"] = np.array(\n                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object\n                    )\n                    # repeat to align with repeated responses in rollout\n                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)\n                    batch = batch.union(gen_batch_output)\n\n                    if \"response_mask\" not in batch.batch.keys():\n                        batch.batch[\"response_mask\"] = compute_response_mask(batch)\n                    # Balance the number of valid tokens across DP ranks.\n                    # NOTE: This usually changes the order of data in the `batch`,\n                    # which won't affect the advantage calculation (since it's based on uid),\n                    # but might affect the loss calculation (due to the change of mini-batching).\n                    # TODO: Decouple the DP balancing and mini-batching.\n                    if self.config.trainer.balance_batch:\n                        self._balance_batch(batch, metrics=metrics)\n\n                    # compute global_valid tokens\n                    batch.meta_info[\"global_token_num\"] = torch.sum(batch.batch[\"attention_mask\"], dim=-1).tolist()\n\n                    with marked_timer(\"reward\", timing_raw, color=\"yellow\"):\n                        # compute reward model score\n                        if self.use_rm:\n                            reward_tensor = self.rm_wg.compute_rm_score(batch)\n                            batch = batch.union(reward_tensor)\n\n                        if self.config.reward_model.launch_reward_fn_async:\n                            future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)\n                        else:\n                            reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)\n\n                    # recompute old_log_probs\n                    with marked_timer(\"old_log_prob\", timing_raw, color=\"blue\"):\n                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)\n                        entropys = old_log_prob.batch[\"entropys\"]\n                        response_masks = batch.batch[\"response_mask\"]\n                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode\n                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)\n                        old_log_prob_metrics = {\"actor/entropy\": entropy_agg.detach().item()}\n\n                        # per-position entropy plot\n                        masked_entropys = entropys * response_masks\n                        sum_entropy_per_position = torch.sum(masked_entropys, dim=0)\n                        num_tokens_per_position = torch.sum(response_masks, dim=0)\n                        mean_entropy_per_position = sum_entropy_per_position / torch.clamp(\n                            num_tokens_per_position, min=1\n                        )\n                        
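# Best-effort diagnostic: log the mean entropy at each response position as a line plot.\n                        # This assumes wandb is the active tracking backend; if not, the except branch\n                        # below only prints a warning instead of failing the step.\n                        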
try:\n                            entropy_list = mean_entropy_per_position.cpu().tolist()\n                            table_data = [[i, ent] for i, ent in enumerate(entropy_list)]\n                            table = wandb.Table(data=table_data, columns=[\"position\", \"entropy\"])\n                            old_log_prob_metrics[\"actor/per_position_entropy_plot\"] = wandb.plot.line(\n                                table, \"position\", \"entropy\", title=\"Per-Position Entropy\"\n                            )\n                        except Exception as e:\n                            print(f\"Warning: Could not create wandb per-position entropy plot. Error: {e}\")\n\n                        # token-type entropy\n                        try:\n                            responses = batch.batch[\"responses\"]\n                            # mask for token type 1 (id >= 151669)\n                            type1_mask = (responses >= 151669) * response_masks\n                            # mask for token type 2 (id < 151669)\n                            type2_mask = (responses < 151669) * response_masks\n\n                            count_type1 = type1_mask.sum().item()\n                            count_type2 = type2_mask.sum().item()\n\n                            if count_type1 > 0:\n                                entropy_type1 = masked_mean(entropys, mask=type1_mask, axis=None).item()\n                                old_log_prob_metrics[\"actor/entropy_itemic_token\"] = entropy_type1\n\n                            if count_type2 > 0:\n                                entropy_type2 = masked_mean(entropys, mask=type2_mask, axis=None).item()\n                                old_log_prob_metrics[\"actor/entropy_lang_token\"] = entropy_type2\n\n                            old_log_prob_metrics[\"actor/token_count_itemic_token\"] = count_type1\n                            old_log_prob_metrics[\"actor/token_count_lang_token\"] = count_type2\n                        except Exception as e:\n                            print(f\"Warning: Could not compute token-type entropy metrics. 
Error: {e}\")\n\n                        metrics.update(old_log_prob_metrics)\n                        old_log_prob.batch.pop(\"entropys\")\n                        batch = batch.union(old_log_prob)\n\n                        if \"rollout_log_probs\" in batch.batch.keys():\n                            # TODO: we may want to add diff of probs too.\n                            rollout_old_log_probs = batch.batch[\"rollout_log_probs\"]\n                            actor_old_log_probs = batch.batch[\"old_log_probs\"]\n                            attention_mask = batch.batch[\"attention_mask\"]\n                            responses = batch.batch[\"responses\"]\n                            response_length = responses.size(1)\n                            response_mask = attention_mask[:, -response_length:]\n\n                            rollout_probs = torch.exp(rollout_old_log_probs)\n                            actor_probs = torch.exp(actor_old_log_probs)\n                            rollout_probs_diff = torch.abs(rollout_probs - actor_probs)\n                            rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())\n                            rollout_probs_diff_max = torch.max(rollout_probs_diff)\n                            rollout_probs_diff_mean = torch.mean(rollout_probs_diff)\n                            rollout_probs_diff_std = torch.std(rollout_probs_diff)\n                            metrics.update(\n                                {\n                                    \"training/rollout_probs_diff_max\": rollout_probs_diff_max.detach().item(),\n                                    \"training/rollout_probs_diff_mean\": rollout_probs_diff_mean.detach().item(),\n                                    \"training/rollout_probs_diff_std\": rollout_probs_diff_std.detach().item(),\n                                }\n                            )\n\n                    if self.use_reference_policy:\n                        # compute reference log_prob\n                        with marked_timer(\"ref\", timing_raw, color=\"olive\"):\n                            if not self.ref_in_actor:\n                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)\n                            else:\n                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)\n                            batch = batch.union(ref_log_prob)\n\n                    # compute values\n                    if self.use_critic:\n                        with marked_timer(\"values\", timing_raw, color=\"cyan\"):\n                            values = self.critic_wg.compute_values(batch)\n                            batch = batch.union(values)\n\n                    with marked_timer(\"adv\", timing_raw, color=\"brown\"):\n                        # we combine with rule-based rm\n                        reward_extra_infos_dict: dict[str, list]\n                        if self.config.reward_model.launch_reward_fn_async:\n                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)\n                        batch.batch[\"token_level_scores\"] = reward_tensor\n\n                        if reward_extra_infos_dict:\n                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})\n                            \n                            # 将 reward_extra_infos_dict 中的统计信息添加到 metrics 中\n                            for key, values in 
reward_extra_infos_dict.items():\n                                if values and len(values) > 0:\n                                    values_array = np.array(values)\n                                    # 只记录数值类型的指标\n                                    if np.issubdtype(values_array.dtype, np.number):\n                                        metrics[f\"reward/{key}/mean\"] = float(np.mean(values_array))\n                                        metrics[f\"reward/{key}/max\"] = float(np.max(values_array))\n                                        metrics[f\"reward/{key}/min\"] = float(np.min(values_array))\n\n                        # compute rewards. apply_kl_penalty if available\n                        if self.config.algorithm.use_kl_in_reward:\n                            batch, kl_metrics = apply_kl_penalty(\n                                batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty\n                            )\n                            metrics.update(kl_metrics)\n                        else:\n                            batch.batch[\"token_level_rewards\"] = batch.batch[\"token_level_scores\"]\n\n                        # compute advantages, executed on the driver process\n\n                        norm_adv_by_std_in_grpo = self.config.algorithm.get(\n                            \"norm_adv_by_std_in_grpo\", True\n                        )  # GRPO adv normalization factor\n\n                        batch = compute_advantage(\n                            batch,\n                            adv_estimator=self.config.algorithm.adv_estimator,\n                            gamma=self.config.algorithm.gamma,\n                            lam=self.config.algorithm.lam,\n                            num_repeat=self.config.actor_rollout_ref.rollout.n,\n                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n                            config=self.config.algorithm,\n                        )\n\n                    # update critic\n                    if self.use_critic:\n                        with marked_timer(\"update_critic\", timing_raw, color=\"pink\"):\n                            critic_output = self.critic_wg.update_critic(batch)\n                        critic_output_metrics = reduce_metrics(critic_output.meta_info[\"metrics\"])\n                        metrics.update(critic_output_metrics)\n\n                    # implement critic warmup\n                    if self.config.trainer.critic_warmup <= self.global_steps:\n                        # update actor\n                        with marked_timer(\"update_actor\", timing_raw, color=\"red\"):\n                            batch.meta_info[\"multi_turn\"] = self.config.actor_rollout_ref.rollout.multi_turn.enable\n                            actor_output = self.actor_rollout_wg.update_actor(batch)\n                        actor_output_metrics = reduce_metrics(actor_output.meta_info[\"metrics\"])\n                        metrics.update(actor_output_metrics)\n\n                    # Log rollout generations if enabled\n                    rollout_data_dir = self.config.trainer.get(\"rollout_data_dir\", None)\n                    if rollout_data_dir:\n                        with marked_timer(\"dump_rollout_generations\", timing_raw, color=\"green\"):\n                            inputs = self.tokenizer.batch_decode(batch.batch[\"prompts\"], skip_special_tokens=True)\n                            outputs = self.tokenizer.batch_decode(batch.batch[\"responses\"], 
skip_special_tokens=True)\n                            scores = batch.batch[\"token_level_scores\"].sum(-1).cpu().tolist()\n                            ground_truths = None\n                            if \"reward_model\" in batch.non_tensor_batch:\n                                ground_truths = [item[\"ground_truth\"] for item in batch.non_tensor_batch[\"reward_model\"]]\n                            if \"request_id\" in batch.non_tensor_batch:\n                                reward_extra_infos_dict.setdefault(\n                                    \"request_id\",\n                                    batch.non_tensor_batch[\"request_id\"].tolist(),\n                                )\n                            self._dump_generations(\n                                inputs=inputs,\n                                outputs=outputs,\n                                scores=scores,\n                                reward_extra_infos_dict=reward_extra_infos_dict,\n                                dump_path=rollout_data_dir,\n                                ground_truths=ground_truths,\n                            )\n\n                    # validate\n                    if (\n                        self.val_reward_fn is not None\n                        and self.config.trainer.test_freq > 0\n                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)\n                    ):\n                        with marked_timer(\"testing\", timing_raw, color=\"green\"):\n                            val_metrics: dict = self._validate()\n                            if is_last_step:\n                                last_val_metrics = val_metrics\n                        metrics.update(val_metrics)\n\n                    # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.\n                    esi_close_to_expiration = should_save_ckpt_esi(\n                        max_steps_duration=self.max_steps_duration,\n                        redundant_time=self.config.trainer.esi_redundant_time,\n                    )\n                    # Check if the conditions for saving a checkpoint are met.\n                    # The conditions include a mandatory condition (1) and\n                    # one of the following optional conditions (2/3/4):\n                    # 1. The save frequency is set to a positive value.\n                    # 2. It's the last training step.\n                    # 3. The current step number is a multiple of the save frequency.\n                    # 4. 
The ESI (Elastic Server Instance)/training plan is close to expiration.\n                    if self.config.trainer.save_freq > 0 and (\n                        is_last_step\n                        or self.global_steps % self.config.trainer.save_freq == 0\n                        or esi_close_to_expiration\n                    ):\n                        if esi_close_to_expiration:\n                            print(\"Force saving checkpoint: ESI instance expiration approaching.\")\n                        with marked_timer(\"save_checkpoint\", timing_raw, color=\"green\"):\n                            self._save_checkpoint()\n\n                with marked_timer(\"stop_profile\", timing_raw):\n                    self._stop_profiling(do_profile)\n\n                steps_duration = timing_raw[\"step\"]\n                self.max_steps_duration = max(self.max_steps_duration, steps_duration)\n\n                # training metrics\n                metrics.update(\n                    {\n                        \"training/global_step\": self.global_steps,\n                        \"training/epoch\": epoch,\n                    }\n                )\n                # collect metrics\n                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))\n                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))\n                # TODO: implement actual TFLOPs and theoretical TFLOPs\n                n_gpus = self.resource_pool_manager.get_n_gpus()\n                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))\n\n                # this is experimental and may be changed/removed in the future in favor of a general-purpose one\n                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):\n                    self.train_dataloader.sampler.update(batch=batch)\n\n                # TODO: make a canonical logger that supports various backends\n                logger.log(data=metrics, step=self.global_steps)\n\n                progress_bar.update(1)\n                self.global_steps += 1\n\n                if is_last_step:\n                    pprint(f\"Final validation metrics: {last_val_metrics}\")\n                    progress_bar.close()\n                    return\n\n                # this is experimental and may be changed/removed in the future\n                # in favor of a general-purpose data buffer pool\n                if hasattr(self.train_dataset, \"on_batch_end\"):\n                    # The dataset may be changed after each training batch\n                    self.train_dataset.on_batch_end(batch=batch)\n"
  },
  {
    "path": "verl_rl/verl/trainer/ppo/reward.py",
    "content": "# Copyright 2025 Individual Contributor: Thibaut Barroyer\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport multiprocessing\nimport os\nfrom functools import partial\n\nimport ray\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\n\n\ndef _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs):\n    \"\"\"Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence.\n\n    This function is used to merge additional keyword arguments with the original function's arguments.\n    \"\"\"\n    merged_kwargs = {**kwargs, **extra_kwargs}\n    return raw_fn(*args, **merged_kwargs)\n\n\ndef get_custom_reward_fn(config):\n    \"\"\"Load and return a custom reward function from external file.\n\n    Dynamically imports a reward function from a specified file path and wraps\n    it with additional keyword arguments from the configuration.\n\n    Args:\n        config (dict): Configuration dictionary containing custom_reward_function\n                      settings with 'path', 'name', and 'reward_kwargs' fields.\n\n    Returns:\n        callable or None: Wrapped reward function with merged kwargs, or None\n                         if no custom reward function is configured.\n\n    Raises:\n        FileNotFoundError: If the specified reward function file doesn't exist.\n        RuntimeError: If there's an error loading the module from file.\n        AttributeError: If the specified function name isn't found in the module.\n    \"\"\"\n    import importlib.util\n    import sys\n\n    reward_fn_config = config.get(\"custom_reward_function\") or {}\n    file_path = reward_fn_config.get(\"path\")\n    if not file_path:\n        return None\n\n    if not os.path.exists(file_path):\n        raise FileNotFoundError(f\"Reward function file '{file_path}' not found.\")\n\n    spec = importlib.util.spec_from_file_location(\"custom_module\", file_path)\n    module = importlib.util.module_from_spec(spec)\n    try:\n        sys.modules[\"custom_module\"] = module\n        spec.loader.exec_module(module)\n    except Exception as e:\n        raise RuntimeError(f\"Error loading module from '{file_path}': {e}\") from e\n\n    function_name = reward_fn_config.get(\"name\")\n    if not hasattr(module, function_name):\n        raise AttributeError(f\"Reward function '{function_name}' not found in '{file_path}'.\")\n\n    print(f\"using customized reward function '{function_name}' from '{file_path}'\")\n    raw_fn = getattr(module, function_name)\n\n    reward_kwargs = dict(reward_fn_config.get(\"reward_kwargs\", {}))\n\n    return partial(_call_with_kwargs, raw_fn, reward_kwargs)\n\n\ndef load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):\n    \"\"\"\n    Load and initialize a reward manager based on the configuration.\n\n    Args:\n        config: PPO trainer configuration object containing reward_model fields.\n        tokenizer: Tokenizer object used for processing text.\n        num_examine: 
Number of samples to examine.\n        **reward_kwargs: Additional keyword arguments for the reward manager.\n\n    Returns:\n        An instance of the specified reward manager class.\n    \"\"\"\n    from verl.workers.reward_manager import get_reward_manager_cls\n\n    # The pre-defined reward managers are defined in `verl/workers/reward_manager/`:\n    # naive: NaiveRewardManager\n    # prime: PrimeRewardManager\n    # batch: BatchRewardManager\n    # dapo: DAPORewardManager\n    # Note(haibin.lin): For custom reward managers, please make sure they are imported and\n    # registered via `verl.workers.reward_manager.register`\n    # By default reward_manager is set to naive (NaiveRewardManager)\n    reward_manager_name = config.reward_model.get(\"reward_manager\", \"naive\")\n    reward_manager_cls = get_reward_manager_cls(reward_manager_name)\n\n    # Try to get a custom reward function based on the configuration\n    compute_score = get_custom_reward_fn(config)\n    final_compute_score = compute_score\n\n    if compute_score is None:\n        sandbox_config = config.reward_model.get(\"sandbox_fusion\")\n        sandbox_url = sandbox_config.get(\"url\") if sandbox_config else None\n        memory_limit_mb = sandbox_config.get(\"memory_limit_mb\", 1024) if sandbox_config else 1024\n        if sandbox_url:\n            sandbox_manager = multiprocessing.Manager()\n            # Create a semaphore to control concurrent access to the sandbox\n            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get(\"max_concurrent\", 64))\n            final_compute_score = partial(\n                default_compute_score,\n                sandbox_fusion_url=sandbox_url,\n                concurrent_semaphore=_concurrent_semaphore,\n                memory_limit_mb=memory_limit_mb,\n            )\n        else:\n            final_compute_score = default_compute_score\n\n    # Instantiate and return the reward manager with the specified parameters\n    return reward_manager_cls(\n        tokenizer=tokenizer,\n        num_examine=num_examine,\n        compute_score=final_compute_score,\n        reward_fn_key=config.data.reward_fn_key,\n        **reward_kwargs,\n    )\n\n\ndef compute_reward(data: DataProto, reward_fn):\n    \"\"\"\n    Compute reward for a batch of data.\n    Args:\n        data: DataProto object containing the input data.\n        reward_fn: Reward function to compute the reward.\n    Returns:\n        Tuple of reward tensor and extra info dictionary.\n    \"\"\"\n    try:\n        reward_result = reward_fn(data, return_dict=True)\n        reward_tensor = reward_result[\"reward_tensor\"]\n        reward_extra_infos_dict = reward_result.get(\"reward_extra_info\", {})\n    except Exception as e:\n        print(f\"Error in reward_fn: {e}\")\n        reward_tensor = reward_fn(data)\n        reward_extra_infos_dict = {}\n\n    return reward_tensor, reward_extra_infos_dict\n\n\n@ray.remote(num_cpus=1)\ndef compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None):\n    \"\"\"\n    Load the reward manager and compute the reward for a batch of data.\n    This is meant to be run in a separate Ray worker.\n    \"\"\"\n    if reward_fn is None:\n        assert config is not None and tokenizer is not None, (\n            \"config and tokenizer must not be None when reward_fn is None\"\n        )\n        import warnings\n\n        warnings.warn(\"using config and tokenizer with compute_reward_async is deprecated\", stacklevel=2)\n        reward_fn = load_reward_manager(\n        
    config, tokenizer, num_examine=0, **config.reward_model.get(\"reward_kwargs\", {})\n        )\n\n    return compute_reward(data, reward_fn)\n"
  },
  {
    "path": "verl_rl/verl/trainer/runtime_env.yaml",
    "content": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  CUDA_DEVICE_MAX_CONNECTIONS: \"1\"\n"
  },
  {
    "path": "verl_rl/verl/utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom . import config, tokenizer\nfrom .config import omega_conf_to_dataclass\nfrom .tokenizer import hf_processor, hf_tokenizer\n\n__all__ = tokenizer.__all__ + config.__all__ + [\"hf_processor\", \"hf_tokenizer\", \"omega_conf_to_dataclass\"]\n"
  },
  {
    "path": "verl_rl/verl/utils/activation_offload.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functionality for CPU offloading of tensors saved for backward pass.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport logging\nimport os\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.fsdp_utils import FSDPModule as FSDP2\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef _get_unique_tensor_key(tensor):\n    key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype)\n    return key\n\n\nclass FSDPParameterFilter:\n    def __init__(self):\n        self.model_parameters_storage = set()\n\n    def __call__(self, tensor):\n        return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage\n\n    def update_model_parameters(self, model):\n        new_storage = set()\n        for p in model.parameters():\n            new_storage.add(p.data.untyped_storage().data_ptr())\n        self.model_parameters_storage = new_storage\n\n\nclass CpuOffloadHookWithOffloadHandler:\n    \"\"\"Context-manager that offloads/recovers tensors through an offload hander.\n\n    The hook just offloads/recovers the tensor object to the handler through `tensor_push`\n    and `tensor_pop` interface. 
How the offload-handler manages the offloading, recovering\n    or prefetching timing is transparent to this hook.\n    \"\"\"\n\n    def __init__(\n        self,\n        offload_handler: OffloadHandler,\n        handler_extra_kwargs: Optional[dict[str, Any]] = None,\n    ) -> None:\n        if handler_extra_kwargs is None:\n            handler_extra_kwargs = {}\n        self.offload_handler: OffloadHandler = offload_handler\n        self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs\n        self.inside_context = False\n\n    def __enter__(self):\n        self.inside_context = True\n        torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor)\n\n    def __exit__(self, *args: Any):\n        self.inside_context = False\n        torch._C._autograd._pop_saved_tensors_default_hooks()\n\n    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:\n        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)\n        return retrieve_identifier\n\n    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:\n        tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)\n        return tensor\n\n\nclass OffloadHandler:\n    \"\"\"A base class for CPU offload-handler.\"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        \"\"\"Tensor push.\"\"\"\n        raise NotImplementedError(\n            \"`tensor_push` is not implemented in the OffloadHandler base class. Inherit this class and implement your \"\n            \"custom tensor_push.\"\n        )\n\n    def tensor_pop(self, tensor_tag: Any, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        raise NotImplementedError(\n            \"`tensor_pop` is not implemented in the OffloadHandler base class. Inherit this class and implement your \"\n            \"custom tensor_pop.\"\n        )\n\n\nclass GroupCommitFunction(torch.autograd.Function):\n    \"\"\"This is a dummy op whose output is identical to its input.\n    However, it is necessary for marking a timepoint for the offload handler to\n    accomplish all synchronizations. 
Implementing it as a function is necessary\n    because we need to take actions in both forward and backward.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, tensor, cpu_offload_handler):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler.on_group_commit_forward()\n        ctx.cpu_offload_handler = cpu_offload_handler\n        # return the identical tensor\n        return tensor\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler = ctx.cpu_offload_handler\n        cpu_offload_handler.on_group_commit_backward()\n        return grad_output, None\n\n\ngroup_prefetch_offload_commit = GroupCommitFunction.apply\n\n\nclass SynchronizedGroupOffloadHandler(OffloadHandler):\n    \"\"\"Offload Handler that offloads/reloads in a synchronized way.\n    The device-to-host and host-to-device copying happen in the same stream\n    as the computation kernels, thus the copying will block computation.\n    \"\"\"\n\n    def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None:\n        super().__init__()\n\n        self.num_offload_group = num_offload_group\n        self.tensor_need_offloading_checker = tensor_need_offloading_checker\n\n        self.groupid_reset()\n\n    def groupid_reset(self):\n        \"\"\"Groupid reset.\"\"\"\n        # Data structures to label saved tensors and book-keep their cpu copies.\n        # Currently, on push, we create a new cpu tensor and copy into it; on pop, we copy\n        # the tensor back to gpu and delete the cpu copy.\n        # These counters increment whenever `group_commit()` is invoked\n        self.current_group, self.tensor_count_current_group = (0, 0)\n        self.torch_tensor_count = 0\n        self.tensor_tag_to_state = {}\n\n    def on_group_commit_forward(self):\n        \"\"\"On group commit forward.\"\"\"\n        # finishing up with updating current group and tensor count\n        self.current_group += 1  # increment\n        self.tensor_count_current_group = 0  # reset\n\n    def on_group_commit_backward(self):\n        \"\"\"On group commit backward.\"\"\"\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n    @staticmethod\n    def offload(src_tensor, pin_memory=True):\n        \"\"\"Offload.\"\"\"\n\n        cpu_backup = torch.empty(\n            src_tensor.size(),\n            dtype=src_tensor.dtype,\n            layout=src_tensor.layout,\n            device=\"cpu\",\n            pin_memory=pin_memory,\n        )\n        cpu_backup.copy_(src_tensor, non_blocking=True)\n        state = (src_tensor.device, cpu_backup)\n        return state\n\n    @staticmethod\n    def reload(state, non_blocking=None):\n        \"\"\"Reload.\"\"\"\n        dev, cpu_backup = state\n        if non_blocking is None:\n            non_blocking = cpu_backup.is_pinned()\n        return cpu_backup.to(dev, non_blocking=non_blocking)\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs):\n        \"\"\"Tensor push.\"\"\"\n        # obtain a unique tensor tag\n        tensor_tag = (self.current_group, self.tensor_count_current_group)\n        self.tensor_count_current_group += 1\n        assert tensor_tag not in self.tensor_tag_to_state\n        if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor):\n            state = SynchronizedGroupOffloadHandler.offload(tensor)\n            self.tensor_tag_to_state[tensor_tag] = state\n        else:\n            # 
will be offloaded together after group commit\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        assert tensor_tag in self.tensor_tag_to_state\n        state = self.tensor_tag_to_state.pop(tensor_tag)\n        if isinstance(state, tuple):\n            tensor = SynchronizedGroupOffloadHandler.reload(state)\n        else:\n            tensor = state\n        return tensor\n\n\nclass AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):\n    \"\"\"Compared to the synchronized handler, this uses more memory because of the buffer but\n    achieves better performance due to the overlapping. D2H and H2D copying are\n    completely hidden behind computation if the computation time of a layer is longer\n    than the host-device communication time. Bulk offloading with delay and bulk reloading\n    with prefetch are implemented.\"\"\"\n\n    def __init__(\n        self,\n        num_offload_group,  # must be <= actual number of groups (number of commits)\n        num_model_group,\n        tensor_need_offloading_checker=(lambda t: True),\n    ) -> None:\n        super().__init__(\n            num_offload_group=num_offload_group,\n            tensor_need_offloading_checker=tensor_need_offloading_checker,\n        )\n        # Number of layers in the model\n        self.num_layers = num_model_group\n        # Data Structure to maintain reference to activation tensors\n        self.tensor_tag_to_buf = {}\n        # Tracking the number of layers offloaded\n        self.offloaded_group_count = 0\n        # Core data structure that decides the window for offloading\n        self.layer_window_map = {}\n        self.group_offload_mapping = {}\n\n        # Logic to load-balance offloading across the computation\n        # for optimal CPU/GPU interconnect usage\n        constant = 0\n        for i in range(self.num_offload_group):\n            self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1\n            if i < (self.num_layers % self.num_offload_group):\n                self.layer_window_map[i] += i + 1\n                constant = i + 1\n            else:\n                self.layer_window_map[i] += constant\n\n        # allocate streams and events for synchronization\n        self.d2h_stream = get_torch_device().Stream()\n        self.h2d_stream = get_torch_device().Stream()\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        torch_stray_tensor = isinstance(\n            tensor,\n            torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor,\n        )\n        need_offload = not torch_stray_tensor\n        need_offload = need_offload and self.tensor_need_offloading_checker(tensor)\n\n        if need_offload:\n            # obtain a unique tensor tag\n            tensor_tag = (self.current_group, self.tensor_count_current_group)\n            self.tensor_count_current_group += 1\n\n            assert tensor_tag not in self.tensor_tag_to_state\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n            if self.current_group < self.num_offload_group:\n                self.tensor_tag_to_buf[tensor_tag] = tensor\n        else:\n            tensor_tag = tensor\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        if isinstance(tensor_tag, torch.Tensor):\n            return tensor_tag\n        assert 
tensor_tag in self.tensor_tag_to_state\n        tensor = self.tensor_tag_to_state.pop(tensor_tag)\n        self.tensor_tag_to_buf.pop(tensor_tag, None)\n\n        # the tensor should have been copied back in on_group_commit_backward()\n        # which invokes bulk_reload_group.\n        assert not isinstance(tensor, tuple)\n        return tensor\n\n    def bulk_offload_group(self, group_to_offload):\n        \"\"\"Bulk offload group.\"\"\"\n        offload_mapping = {}\n        offload_size = 0\n        with get_torch_device().stream(self.d2h_stream):\n            for tensor_tag, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_tag\n                if group_id == group_to_offload:\n                    assert not isinstance(state, tuple)\n                    key = _get_unique_tensor_key(state)\n                    if key not in offload_mapping:\n                        offload_mapping[key] = state\n                    # if offload, return the reference to cpu copy\n                    self.tensor_tag_to_state[tensor_tag] = (key, state.shape)\n            for key, tensor in offload_mapping.items():\n                state = SynchronizedGroupOffloadHandler.offload(tensor)\n                offload_size += tensor.numel() * tensor.element_size()\n                offload_mapping[key] = state\n\n            self.group_offload_mapping[group_to_offload] = offload_mapping\n\n    def synchronize_on_group_commit_forward(self, current_group):\n        \"\"\"Synchronize on group commit forward.\"\"\"\n\n        # For the first group, kickstart the offload after we have\n        # the first compute completion\n        if current_group == 0:\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            self.bulk_offload_group(current_group)\n\n        # Window map data structure helps us synchronize based on number\n        # of layers offloaded\n        if self.layer_window_map[self.offloaded_group_count] == current_group:\n            # Stream synchronization both ways\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.d2h_stream)\n\n            # Time to free the activation memory after usage\n            for tensor_tag, _ in self.tensor_tag_to_buf.items():\n                if tensor_tag[0] == self.offloaded_group_count:\n                    self.tensor_tag_to_buf[tensor_tag] = None\n\n            # Time to offload the next group\n            if self.offloaded_group_count < (self.num_offload_group - 1):\n                self.bulk_offload_group(self.offloaded_group_count + 1)\n\n            # Increment the offload group count to keep track\n            self.offloaded_group_count += 1\n\n    def on_group_commit_forward(self):\n        \"\"\"This function will cause host device synchronization\"\"\"\n        # handle synchronization events\n        self.synchronize_on_group_commit_forward(self.current_group)\n\n        super().on_group_commit_forward()\n\n    @torch.no_grad\n    def bulk_reload_group(self, group_to_reload):\n        \"\"\"Bulk reload group.\"\"\"\n        assert group_to_reload < self.num_offload_group\n\n        with get_torch_device().stream(self.h2d_stream):\n            # move back tensors\n            offload_mapping = self.group_offload_mapping.pop(group_to_reload)\n            assert offload_mapping is not None\n            for key, state in offload_mapping.items():\n                offload_mapping[key] = 
SynchronizedGroupOffloadHandler.reload(state)\n            for tensor_label, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_label\n                if group_id == group_to_reload and not isinstance(state, torch.Tensor):\n                    assert isinstance(state, tuple), f\"{group_id} {state}\"\n                    key, shape = state\n                    recovered_tensor = offload_mapping[key].view(shape)\n                    self.tensor_tag_to_state[tensor_label] = recovered_tensor\n\n    def on_group_commit_backward(self):\n        # First decrement the current group: each forward commit incremented it by 1,\n        # each backward commit decrements it by 1, so it finally returns to 0.\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n        # Layer window data structure helps us reload at the right times\n        if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:\n            # Stream synchronization both ways\n            self.h2d_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n\n            # Time to reload the next group\n            self.bulk_reload_group(self.offloaded_group_count - 1)\n\n            # Decrease the offloading group counter\n            self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0\n\n        # Last group computation needs to wait until all the reloads complete\n        if self.current_group == 0:\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n            self.offloaded_group_count = 0\n\n\ndef get_activation_offload_context(\n    num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True)\n):\n    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(\n        num_offload_group=num_layers,\n        num_model_group=model_layers,\n        tensor_need_offloading_checker=tensor_need_offloading_checker,\n    )\n\n    def group_prefetch_offload_commit_async(tensor):\n        return group_prefetch_offload_commit(tensor, cpu_offload_handler)\n\n    return (\n        CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),\n        group_prefetch_offload_commit_async,\n    )\n\n\nclass ActivationHandler:\n    def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt):\n        self._offload_ctx = offload_ctx\n        self._sync_func = sync_func\n        self._enable_ckpt = enable_ckpt\n        self._tensor_filter = tensor_filter\n        if enable_ckpt:\n            self.checkpoint_fn = functools.partial(\n                torch.utils.checkpoint.checkpoint,\n                use_reentrant=True,\n            )\n\n    def pre_forward(self, module):\n        if module.training:\n            self._offload_ctx.__enter__()\n            self._tensor_filter.update_model_parameters(module)\n\n    def post_forward(self, module):\n        if module.training:\n            self._offload_ctx.__exit__(None, None, None)\n\n    def _pack_kwargs(self, *args, **kwargs):\n        kwarg_keys = []\n        flat_args = list(args)\n        for k, v in kwargs.items():\n            kwarg_keys.append(k)\n            flat_args.append(v)\n\n        return tuple(flat_args), tuple(kwarg_keys)\n\n    def _unpack_kwargs(self, flat_args, kwarg_keys):\n        assert len(kwarg_keys) <= len(flat_args), f\"too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}\"\n        if len(kwarg_keys) == 0:\n            return flat_args, {}\n        args = flat_args[: -len(kwarg_keys)]\n        kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True))\n        return args, kwargs\n\n    def _ckpt_forward(self, forward_method, *args, **kwargs):\n        flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs)\n\n        def my_function(*inputs):\n            # unpack back into args and kwargs\n            nonlocal forward_method, kwarg_keys\n            unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys)\n            # run original module\n            return forward_method(*unpacked_args, **unpacked_kwargs)\n\n        return self.checkpoint_fn(\n            my_function,\n            *flat_args,\n        )\n\n    def forward(self, module, forward_method, *args, **kwargs):\n        if not module.training:\n            return forward_method(*args, **kwargs)\n        if not self._enable_ckpt:\n            ret = forward_method(*args, **kwargs)\n        else:\n            ret = self._ckpt_forward(forward_method, *args, **kwargs)\n        binded_tensor = ret\n        if isinstance(ret, tuple):\n            binded_tensor = ret[0]\n        binded_tensor = self._sync_func(binded_tensor)\n        final_ret = binded_tensor\n        if isinstance(ret, tuple):\n            final_ret = (final_ret,) + ret[1:]\n        return final_ret\n\n    def wrap_module_forward_method(self, module):\n        orig_method = module.forward\n        handler = self\n\n        @functools.wraps(orig_method)\n        def wrapped_method(model_self, *args, **kwargs):\n            nonlocal handler\n            handler.pre_forward(model_self)\n            out = handler.forward(model_self, orig_method, *args, **kwargs)\n            handler.post_forward(model_self)\n            return out\n\n        module.forward = wrapped_method.__get__(module, type(module))\n\n\ndef enable_activation_offloading(model, strategy, enable_ckpt=False):\n    \"\"\"\n    Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation\n    groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th\n    activation group happen at the same time, and there are at most two activation groups in GPU memory.\n\n    Args:\n        model: the model to enable activation offloading\n        strategy: the training strategy of the model, such as \"fsdp\"\n        enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model\n\n    Note:\n        For best efficiency, activation offloading is usually combined with activation checkpointing. However, this\n        implementation of activation offloading is conflicted with the implementation of activation checkpointing in\n        some training strategies. 
This function resolves this conflict, and therefore requires the \"strategy\" and\n        \"enable_ckpt\" arguments.\n    \"\"\"\n\n    assert strategy in (\"fsdp\", \"fsdp2\"), \"activation offloading only supports the fsdp and fsdp2 strategies\"\n    layers = []\n\n    def get_layers(module):\n        for _, child in module.named_children():\n            if not isinstance(child, FSDP | FSDP2):\n                get_layers(child)\n            else:\n                wrapped_module = child\n                if isinstance(child, FSDP):\n                    wrapped_module = child._fsdp_wrapped_module\n                # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation\n                # size of torch.nn.Embedding is small, so it's not necessary to offload it.\n                if not isinstance(wrapped_module, torch.nn.Embedding):\n                    layers.append(child)\n\n    get_layers(model)\n    if len(layers) < 3:\n        logger.warning(f\"Found only {len(layers)} FSDP layers; async activation offloading is not necessary, skipping\")\n        return\n\n    tensor_filter = FSDPParameterFilter()\n    context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter)\n    if enable_ckpt:\n        # The activation checkpointing implemented in the transformers library is incompatible with this\n        # activation offloading implementation, so it is disabled here. ActivationHandler applies its own\n        # checkpointing instead, which lets both features be enabled at the same time.\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing_disable\"):\n                module.gradient_checkpointing_disable()\n\n    handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt)\n    for layer in layers:\n        module = layer\n        if isinstance(layer, FSDP):\n            module = module._fsdp_wrapped_module\n        handler.wrap_module_forward_method(module)\n"
  },
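`_pack_kwargs`/`_unpack_kwargs` above exist because `torch.utils.checkpoint.checkpoint` only replays positional arguments through recomputation, so keyword arguments are flattened into the positional tail and rebuilt inside the checkpointed function. A minimal self-contained sketch of that round-trip, with standalone helper names mirroring the methods above:

```python
# Standalone sketch of the kwargs-flattening round-trip used by ActivationHandler.
def pack_kwargs(*args, **kwargs):
    # Flatten kwargs into the positional tail; remember the key order.
    kwarg_keys = tuple(kwargs.keys())
    return tuple(args) + tuple(kwargs.values()), kwarg_keys


def unpack_kwargs(flat_args, kwarg_keys):
    # Rebuild (args, kwargs) from the flattened call.
    if not kwarg_keys:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):], strict=True))
    return args, kwargs


flat, keys = pack_kwargs(1, 2, scale=10, bias=0.5)
args, kwargs = unpack_kwargs(flat, keys)
assert args == (1, 2) and kwargs == {"scale": 10, "bias": 0.5}
```

The `strict=True` in the `zip` guards against a key/value count mismatch when the call is replayed.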
  {
    "path": "verl_rl/verl/utils/checkpoint/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/utils/checkpoint/checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport random\nimport shutil\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom omegaconf import DictConfig\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom verl.utils.device import get_device_name, get_torch_device\n\n\nclass BaseCheckpointManager:\n    \"\"\"\n    A checkpoint manager that saves and loads\n    - model\n    - optimizer\n    - lr_scheduler\n    - extra_states\n    in a SPMD way.\n\n    We save\n    - sharded model states and optimizer states\n    - full lr_scheduler states\n    - huggingface tokenizer and config for ckpt merge\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        optimizer: torch.optim.Optimizer,\n        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,\n        processing_class: PreTrainedTokenizer | ProcessorMixin = None,\n        checkpoint_config: DictConfig = None,\n    ):\n        self.checkpoint_config = checkpoint_config\n        checkpoint_load_contents = checkpoint_config.get(\"load_contents\", None) if checkpoint_config else None\n        checkpoint_save_contents = checkpoint_config.get(\"save_contents\", None) if checkpoint_config else None\n        if checkpoint_load_contents is None:\n            checkpoint_load_contents = [\"model\", \"optimizer\", \"extra\"]\n        if checkpoint_save_contents is None:\n            checkpoint_save_contents = [\"model\", \"optimizer\", \"extra\"]\n        self.previous_global_step = None\n        self.previous_saved_paths = []\n\n        self.model = model\n        self.optimizer = optimizer\n        self.lr_scheduler = lr_scheduler\n        self.processing_class = processing_class\n        self.checkpoint_load_contents = checkpoint_load_contents\n        self.checkpoint_save_contents = checkpoint_save_contents\n\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n\n    @property\n    def should_save_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.\n        \"\"\"\n        return \"model\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.\n        \"\"\"\n        return \"optimizer\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.\n        \"\"\"\n        return \"extra\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_hf_model(self) -> bool:\n        \"\"\"\n        Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf\n        model and 
saved.\n        \"\"\"\n        return \"hf_model\" in self.checkpoint_save_contents\n\n    @property\n    def should_load_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.\n        \"\"\"\n        return \"model\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.\n        \"\"\"\n        return \"optimizer\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.\n        \"\"\"\n        return \"extra\" in self.checkpoint_load_contents\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):\n        raise NotImplementedError\n\n    def save_checkpoint(\n        self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None\n    ):\n        raise NotImplementedError\n\n    @staticmethod\n    def checkpath(local_path: str, hdfs_path: str):\n        assert local_path is not None or hdfs_path is not None, \"local_path and hdfs_path cannot be both None\"\n        return local_path is not None, local_path if local_path is not None else hdfs_path\n\n    def remove_previous_save_local_path(self, path):\n        if isinstance(path, str):\n            path = [path]\n        for p in path:\n            abs_path = os.path.abspath(p)\n            print(f\"Checkpoint manager remove previous save local path: {abs_path}\")\n            if not os.path.exists(abs_path):\n                continue\n            shutil.rmtree(abs_path, ignore_errors=True)\n\n    @staticmethod\n    def get_rng_state():\n        rng_state = {\n            \"cpu\": torch.get_rng_state(),\n            \"numpy\": np.random.get_state(),\n            \"random\": random.getstate(),\n        }\n\n        if get_device_name() != \"cpu\":\n            rng_state[get_device_name()] = get_torch_device().get_rng_state()\n\n        return rng_state\n\n    @staticmethod\n    def load_rng_state(rng_state):\n        torch.set_rng_state(rng_state[\"cpu\"])\n        np.random.set_state(rng_state[\"numpy\"])\n        random.setstate(rng_state[\"random\"])\n\n        if get_device_name() != \"cpu\":\n            get_torch_device().set_rng_state(rng_state[get_device_name()])\n\n\ndef find_latest_ckpt_path(path, directory_format=\"global_step_{}\"):\n    \"\"\"\n    Return the most recent checkpoint directory based on a tracker file.\n\n    Args:\n        path (str): Base directory containing the checkpoint tracker.\n        directory_format (str): Template for checkpoint subfolders with one\n            placeholder for the iteration number (default \"global_step_{}\").\n\n    Returns:\n        str or None: Full path to the latest checkpoint directory, or\n        None if the tracker or checkpoint folder is missing.\n    \"\"\"\n    if path is None:\n        return None\n\n    tracker_file = get_checkpoint_tracker_filename(path)\n    if not os.path.exists(tracker_file):\n        print(f\"Checkpoint tracker file does not exist: {tracker_file}\")\n        return None\n\n    with open(tracker_file, \"rb\") as f:\n        iteration = int(f.read().decode())\n    ckpt_path = os.path.join(path, directory_format.format(iteration))\n    if 
not os.path.exists(ckpt_path):\n        print(f\"Checkpoint does not exist: {ckpt_path}\")\n        return None\n\n    print(f\"Found checkpoint: {ckpt_path}\")\n    return ckpt_path\n\n\ndef get_checkpoint_tracker_filename(root_path: str):\n    \"\"\"\n    The tracker file records the latest checkpoint during training so that training can restart from it.\n    \"\"\"\n    return os.path.join(root_path, \"latest_checkpointed_iteration.txt\")\n\n\ndef should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool:\n    \"\"\"\n    Determine whether a checkpoint should be saved before the current capacity block (ESI) expires.\n\n    Args:\n        max_steps_duration: Max estimated time (seconds) required to complete one training step\n        save_ckpt_duration: Estimated time (seconds) required to save checkpoint (default: 60)\n        redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0)\n    \"\"\"\n    exp_ts_mlp = os.getenv(\"MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP\")  # vemlp\n    exp_ts_aws = os.getenv(\"SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP\")  # aws\n    if exp_ts_mlp:\n        try:\n            import time\n\n            remaining = float(exp_ts_mlp) - time.time()\n        except ValueError:\n            return False\n        return (\n            remaining > 0\n            and max_steps_duration > 0\n            and remaining <= save_ckpt_duration + max_steps_duration + redundant_time\n        )\n    elif exp_ts_aws:\n        from datetime import datetime, timedelta\n\n        expiration_time = datetime.fromtimestamp(int(exp_ts_aws))\n        time_difference = expiration_time - datetime.now()\n        threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60\n        return time_difference < timedelta(minutes=threshold_minutes)\n    else:\n        return False\n"
  },
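The tracker-file convention behind `find_latest_ckpt_path` and `get_checkpoint_tracker_filename` can be exercised standalone; the directory name below is illustrative only:

```python
# Sketch of the tracker-file convention: a plain-text file holding the latest global
# step, alongside per-step checkpoint directories named global_step_{N}.
import os

root = "checkpoints"  # illustrative base directory
os.makedirs(os.path.join(root, "global_step_1000"), exist_ok=True)

# The trainer records the latest iteration ...
with open(os.path.join(root, "latest_checkpointed_iteration.txt"), "w") as f:
    f.write("1000")

# ... and resume-time code reads it back and joins it into the directory format.
with open(os.path.join(root, "latest_checkpointed_iteration.txt"), "rb") as f:
    iteration = int(f.read().decode())
print(os.path.join(root, "global_step_{}".format(iteration)))  # checkpoints/global_step_1000
```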
  {
    "path": "verl_rl/verl/utils/checkpoint/fsdp_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport warnings\nfrom dataclasses import asdict, dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.distributed\nfrom accelerate import init_empty_weights\nfrom omegaconf import DictConfig\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin\n\nfrom verl.utils.device import is_cuda_available\nfrom verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe\nfrom verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx\nfrom verl.utils.logger import log_with_rank\n\nfrom .checkpoint_manager import BaseCheckpointManager\n\n# Setup logging\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"INFO\"))\n\n\n@dataclass\nclass FSDPConfig:\n    \"\"\"Configuration for FSDP checkpointing.\n\n    Args:\n        FSDP_version (int): Version of FSDP being used.\n        world_size (int): Number of processes in the distributed training setup.\n    \"\"\"\n\n    FSDP_version: int\n    world_size: int\n\n\nclass FSDPCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Manage FSDP checkpointing in SPMD training.\n\n    - Saves/loads per-rank sharded model & optimizer states\n    - Persists full lr_scheduler and RNG state\n    - Stores HF tokenizer/processor and model/config for unified restore\n\n    Args:\n        model (FSDP): Wrapped model instance.\n        optimizer (Optimizer): Training optimizer.\n        lr_scheduler (LRScheduler): Learning-rate scheduler.\n        processing_class (PreTrainedTokenizer or ProcessorMixin, optional):\n            Pre-/post-processing artifact handler.\n        checkpoint_contents DictConfig: Configuration for checkpoint contents.\n            - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].\n            - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].\n    \"\"\"\n\n    def __init__(\n        self,\n        model: FSDP,\n        optimizer: Optional[torch.optim.Optimizer] = None,\n        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,\n        processing_class: PreTrainedTokenizer | ProcessorMixin = None,\n        checkpoint_config: DictConfig = None,\n        **kwargs,\n    ):\n        if processing_class is None:\n            assert \"tokenizer\" in kwargs, \"tokenizer or processor must be provided\"\n            warnings.warn(\n                \"`tokenizer` is deprecated. 
use `processing_class` instead.\", DeprecationWarning, stacklevel=2\n            )\n            processing_class = kwargs.pop(\"tokenizer\")\n\n        super().__init__(\n            model,\n            optimizer,\n            lr_scheduler=lr_scheduler,\n            processing_class=processing_class,\n            checkpoint_config=checkpoint_config,\n        )\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        \"\"\"\n        Load an FSDP checkpoint for this rank.\n\n        Downloads and loads:\n          - model and optimizer shards\n          - extra state dict (scheduler + RNG)\n\n        Args:\n            local_path: Directory with per-rank checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            del_local_after_load: Remove local files after loading.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # check if the checkpoint_load_contents is valid\n        if self.should_load_model:\n            assert self.model is not None, \"model must be provided when checkpoint_contents.load includes ['model']\"\n        if self.should_load_optimizer:\n            assert self.optimizer is not None, (\n                \"optimizer must be provided when checkpoint_contents.load includes ['optimizer']\"\n            )\n\n        # every rank downloads its own checkpoint\n        state_dict_cfg = (\n            ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n            if self.should_load_model\n            else None\n        )\n        optim_cfg = (\n            ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n            if self.should_load_optimizer\n            else None\n        )\n        with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):\n            if self.should_load_model:\n                remote_model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                local_model_path = copy_to_local(remote_model_path)\n                model_state_dict = torch.load(local_model_path, weights_only=False)\n                self.model.load_state_dict(model_state_dict)\n                log_with_rank(f\"Loaded model from {remote_model_path}\", rank=self.rank, logger=logger)\n\n            if self.should_load_optimizer:\n                remote_optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                local_optim_path = copy_to_local(remote_optim_path)\n                optimizer_state_dict = torch.load(local_optim_path, weights_only=False)\n                self.optimizer.load_state_dict(optimizer_state_dict)\n                log_with_rank(f\"Loaded optimizer from {remote_optim_path}\", rank=self.rank, logger=logger)\n\n        if self.should_load_extra:\n            remote_extra_state_path = os.path.join(\n                local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\"\n            )\n            local_extra_state_path = copy_to_local(remote_extra_state_path)\n            extra_state_dict = torch.load(local_extra_state_path, weights_only=False)\n            # recover random state\n            if \"rng\" in extra_state_dict:\n                # 'rng' may not exist for backward compatibility\n                self.load_rng_state(extra_state_dict[\"rng\"])\n                log_with_rank(f\"Loaded rng from {remote_extra_state_path}\", rank=self.rank, 
logger=logger)\n\n            lr_scheduler_state_dict = extra_state_dict[\"lr_scheduler\"]\n            if lr_scheduler_state_dict is not None and self.lr_scheduler is not None:\n                self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n                log_with_rank(f\"Loaded lr_scheduler from {remote_extra_state_path}\", rank=self.rank, logger=logger)\n\n        if self.rank == 0 and del_local_after_load:\n            try:\n                # Only remove files that were actually downloaded; otherwise the local path variables are unbound.\n                if self.should_load_model and is_non_local(local_model_path):\n                    os.remove(local_model_path)\n                if self.should_load_optimizer and is_non_local(local_optim_path):\n                    os.remove(local_optim_path)\n                if self.should_load_extra and is_non_local(local_extra_state_path):\n                    os.remove(local_extra_state_path)\n            except Exception as e:\n                log_with_rank(\n                    f\"Removing local resume checkpoint files after loading failed; exception {e} will be ignored\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n        # wait for everyone to load checkpoints\n        torch.distributed.barrier()\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save an FSDP checkpoint for this rank.\n\n        Writes:\n          - model & optimizer shard files\n          - extra state dict (scheduler + RNG)\n          - HF tokenizer/processor and model/config on rank 0\n          - optional full HF model under 'huggingface/' if requested\n\n        Rotates old checkpoints, keeping at most `max_ckpt_to_keep`.\n\n        Args:\n            local_path: Target directory for checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            global_step: Current training step (used for bookkeeping).\n            max_ckpt_to_keep: Number of recent checkpoints to retain.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path, only rank 0 should do this\n        if (\n            self.rank == 0\n            and max_ckpt_to_keep\n            and isinstance(max_ckpt_to_keep, int)\n            and max_ckpt_to_keep > 0\n            and len(self.previous_saved_paths) >= max_ckpt_to_keep\n        ):\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = local_mkdir_safe(local_path)\n        torch.distributed.barrier()\n\n        # check if the checkpoint_save_contents is valid\n        if self.should_save_model:\n            assert self.model is not None, \"model must be provided when checkpoint_contents.save includes ['model']\"\n        if self.should_save_optimizer:\n            assert self.optimizer is not None, (\n                \"optimizer must be provided when checkpoint_contents.save includes ['optimizer']\"\n            )\n\n        # every rank will save its own model and optim shard\n        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, 
optim_cfg):\n                model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                extra_path = os.path.join(local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\")\n\n                if self.should_save_model:\n                    model_state_dict = self.model.state_dict()\n                    torch.save(model_state_dict, model_path)\n                    log_with_rank(f\"Saved model to {os.path.abspath(model_path)}\", rank=self.rank, logger=logger)\n\n                if self.should_save_optimizer:\n                    optimizer_state_dict = self.optimizer.state_dict()\n                    torch.save(optimizer_state_dict, optim_path)\n                    log_with_rank(f\"Saved optim to {os.path.abspath(optim_path)}\", rank=self.rank, logger=logger)\n\n                if self.should_save_extra:\n                    lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None\n                    extra_state_dict = {\n                        \"lr_scheduler\": lr_scheduler_state_dict,\n                        \"rng\": self.get_rng_state(),\n                    }\n                    torch.save(extra_state_dict, extra_path)\n                    log_with_rank(f\"Saved extra_state to {os.path.abspath(extra_path)}\", rank=self.rank, logger=logger)\n\n        if self.rank == 0:\n            # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether\n            # huggingface model is requested to be saved or not.\n\n            if fsdp_version(self.model) == 1:\n                unwrap_model = self.model._fsdp_wrapped_module\n            else:\n                unwrap_model = self.model\n\n            hf_config_tokenizer_path = os.path.join(local_path, \"huggingface\")\n            local_mkdir_safe(hf_config_tokenizer_path)\n            model_config = unwrap_model.config\n            generation_config = None\n            if unwrap_model.can_generate() and hasattr(model_config, \"name_or_path\") and model_config.name_or_path:\n                try:\n                    # Some model's name_or_path is empty if not initialized from pretrained,\n                    # in this cases, we don't save generation config.\n                    generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)\n                    generation_config.save_pretrained(hf_config_tokenizer_path)\n                except Exception:\n                    # if the generation config isn't available, we don't save it\n                    pass\n\n            model_config.save_pretrained(hf_config_tokenizer_path)\n            self.processing_class.save_pretrained(hf_config_tokenizer_path)\n            log_with_rank(\n                f\"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}\",\n                rank=self.rank,\n                logger=logger,\n                log_only_rank_0=True,\n            )\n\n            # Also save runtime FSDP config\n            fsdp_config_path = os.path.join(local_path, \"fsdp_config.json\")\n            fsdp_config = FSDPConfig(\n                FSDP_version=fsdp_version(self.model),\n                world_size=self.world_size,\n            )\n            with open(fsdp_config_path, \"w\") as f:\n                json.dump(asdict(fsdp_config), f, indent=4)\n\n        
# wait for everyone to dump to local\n        torch.distributed.barrier()\n\n        if self.should_save_hf_model:\n            # Only rank 0 saves the hf model; the full state dict is offloaded to CPU so that\n            # LLMs too large to fit in one GPU can still be saved\n            state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)\n\n            if self.rank == 0:\n                hf_local_path = os.path.join(local_path, \"huggingface\")\n                os.makedirs(hf_local_path, exist_ok=True)\n\n                if \"ForTokenClassification\" in model_config.architectures[0]:\n                    from transformers import AutoModelForTokenClassification\n\n                    auto_model_cls = AutoModelForTokenClassification\n                elif \"ForCausalLM\" in model_config.architectures[0]:\n                    from transformers import AutoModelForCausalLM\n\n                    auto_model_cls = AutoModelForCausalLM\n                elif \"ForConditionalGeneration\" in model_config.architectures[0]:\n                    from transformers import AutoModelForVision2Seq\n\n                    auto_model_cls = AutoModelForVision2Seq\n                else:\n                    raise NotImplementedError(f\"Unknown architecture {model_config.architectures}\")\n\n                with init_empty_weights():\n                    save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)\n                save_model.to_empty(device=\"cpu\")\n\n                if save_model.can_generate():\n                    if generation_config is not None:\n                        save_model.generation_config = generation_config\n                    else:\n                        print(\n                            f\"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found; \"\n                            f\"using a generation config created from the model config when saving hf_model.\"\n                        )\n\n                save_model.save_pretrained(hf_local_path, state_dict=state_dict)\n                log_with_rank(\n                    f\"Saved hf_model to {os.path.abspath(hf_local_path)}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n                del state_dict\n                del save_model\n\n            # wait for rank0 to dump hf_model to local\n            torch.distributed.barrier()\n\n        self.previous_saved_paths.append(local_path)\n"
  },
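The rotation arithmetic in `save_checkpoint` is easy to misread, so the self-contained sketch below replays it; the invariant is that after dropping the oldest paths and appending the new checkpoint, exactly `max_ckpt_to_keep` paths remain.

```python
# Standalone replay of the checkpoint-rotation arithmetic in save_checkpoint.
def rotate(previous_saved_paths, max_ckpt_to_keep):
    # Keep max_ckpt_to_keep - 1 old paths so the new checkpoint brings the total
    # back to max_ckpt_to_keep.
    if max_ckpt_to_keep and len(previous_saved_paths) >= max_ckpt_to_keep:
        keep_start = len(previous_saved_paths) - max_ckpt_to_keep + 1
        removed = previous_saved_paths[:keep_start]  # these directories get rmtree'd
        return removed, previous_saved_paths[keep_start:]
    return [], previous_saved_paths


removed, kept = rotate(["step_1", "step_2", "step_3"], max_ckpt_to_keep=3)
assert removed == ["step_1"] and kept == ["step_2", "step_3"]  # + new ckpt -> 3 kept
```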
  {
    "path": "verl_rl/verl/utils/checkpoint/megatron_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport random\nfrom collections.abc import Callable\nfrom dataclasses import asdict\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom megatron.core import mpu, tensor_parallel\nfrom megatron.core.dist_checkpointing.mapping import ShardedObject\nfrom megatron.core.transformer.enums import AttnBackend\nfrom transformers import GenerationConfig\n\nfrom verl.models.weight_loader_registry import get_weight_saver\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.fs import is_non_local, local_mkdir_safe\nfrom verl.utils.logger import log_with_rank\nfrom verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing\nfrom verl.utils.megatron_utils import (\n    get_dist_checkpoint_path,\n    get_hf_model_checkpoint_path,\n    get_transformer_config_checkpoint_path,\n)\n\nfrom .checkpoint_manager import BaseCheckpointManager\n\n# Setup logging\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"INFO\"))\n\n\nclass MegatronCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Checkpoint manager for Megatron-LM distributed training.\n\n    This class manages the saving and loading of model checkpoints in a Megatron-LM\n    distributed training environment. 
It handles various aspects of checkpointing\n    including model states, optimizer states, learning rate schedulers, and random\n    number generator states, ensuring compatibility with HuggingFace formats.\n\n    Key features:\n    - Distributed checkpoint saving and loading using Megatron's dist_checkpointing\n    - Support for tensor parallel, pipeline parallel, and data parallel configurations\n    - Automatic handling of model state dictionaries across multiple pipeline stages\n    - Integration with HuggingFace model configurations and tokenizers\n    - Random number generator state management for reproducibility\n    - Support for both synchronous and asynchronous checkpoint operations\n\n    The manager automatically handles:\n    - Directory structure creation based on global steps and process ranks\n    - Model configuration and tokenizer saving in HuggingFace format\n    - Optimizer and scheduler state persistence\n    - CUDA RNG state management for deterministic training\n    - Checkpoint cleanup and retention policies\n\n    Args:\n        model: The Megatron model instance to checkpoint\n        optimizer: The optimizer instance (optional)\n        lr_scheduler: The learning rate scheduler instance (optional)\n\n    Attributes:\n        model: Reference to the Megatron model being checkpointed\n        optimizer: Reference to the optimizer (if provided)\n        lr_scheduler: Reference to the learning rate scheduler (if provided)\n        rank: Current process rank in the distributed setup\n\n    Example:\n        ```python\n        checkpoint_manager = MegatronCheckpointManager(\n            model=megatron_model,\n            optimizer=optimizer,\n            lr_scheduler=scheduler\n        )\n\n        checkpoint_manager.save_checkpoint(\n            local_path=\"checkpoints/step_1000\",\n            global_step=1000\n        )\n\n        checkpoint_manager.load_checkpoint(\n            local_path=\"checkpoints/step_1000\"\n        )\n        ```\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        checkpoint_config,\n        model_config,\n        transformer_config,\n        role,\n        model: torch.nn.ModuleList,\n        arch: str,\n        hf_config,\n        param_dtype: torch.dtype,\n        share_embeddings_and_output_weights: bool,\n        processing_class,\n        optimizer,\n        optimizer_scheduler,\n        use_distributed_optimizer: bool,\n        use_checkpoint_opt_param_scheduler: bool = False,\n        use_dist_checkpointing: bool = True,\n        bridge=None,\n        **kwargs,\n    ):\n        super().__init__(\n            model,\n            optimizer=optimizer,\n            lr_scheduler=optimizer_scheduler,\n            processing_class=processing_class,\n            checkpoint_config=checkpoint_config,\n        )\n        self.arch = arch\n        self.config = config\n        self.transformer_config = transformer_config\n        self.role = role\n        self.is_value_model = False\n        if self.role in [\"reward\", \"critic\"]:\n            self.is_value_model = True\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.param_dtype = param_dtype\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.model_path = self.config.model.path\n        self.use_distributed_optimizer = use_distributed_optimizer\n        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler\n        self.bridge = bridge\n        self.rank 
= torch.distributed.get_rank()\n        self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model\n        self.use_hf_checkpoint = not self.use_dist_checkpointing\n\n        self.weight_saver = get_weight_saver(self.arch)\n\n    def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False):\n        \"\"\"collect rng state across data parallel ranks\"\"\"\n        rng_state = {\n            \"random_rng_state\": random.getstate(),\n            \"np_rng_state\": np.random.get_state(),\n            \"torch_rng_state\": torch.get_rng_state(),\n            \"rng_tracker_states\": tensor_parallel.get_cuda_rng_tracker().get_states(),\n        }\n\n        if get_device_name() != \"cpu\":\n            rng_state[f\"{get_device_name()}_rng_state\"] = get_torch_device().get_rng_state()\n\n        rng_state_list = None\n        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:\n            rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())]\n            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())\n        else:\n            rng_state_list = [rng_state]\n\n        if use_dist_ckpt:\n            pp_rank = mpu.get_pipeline_model_parallel_rank()\n            pp_size = mpu.get_pipeline_model_parallel_world_size()\n            tp_rank = mpu.get_tensor_model_parallel_rank()\n            tp_size = mpu.get_tensor_model_parallel_world_size()\n            rng_state_list = ShardedObject(\n                \"rng_state\",\n                rng_state_list,\n                (pp_size, tp_size),\n                (pp_rank, tp_rank),\n                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),\n            )\n\n        return rng_state_list\n\n    def get_checkpoint_name(\n        self,\n        checkpoints_path,\n        pipeline_parallel=None,\n        tensor_rank=None,\n        pipeline_rank=None,\n        cp_rank=None,\n        expert_parallel=None,\n        expert_rank=None,\n        return_base_dir=True,\n        basename=\"model.pt\",\n    ):\n        \"\"\"Determine the directory name for this rank's checkpoint.\"\"\"\n        # Use both the tensor and pipeline MP rank.\n        if pipeline_parallel is None:\n            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1\n        if tensor_rank is None:\n            tensor_rank = mpu.get_tensor_model_parallel_rank()\n        if pipeline_rank is None:\n            pipeline_rank = mpu.get_pipeline_model_parallel_rank()\n        if cp_rank is None:\n            cp_rank = mpu.get_context_parallel_rank()\n        if expert_parallel is None:\n            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1\n        if expert_rank is None:\n            expert_rank = mpu.get_expert_model_parallel_rank()\n\n        # Use both the tensor and pipeline MP rank. 
If using the distributed\n        # optimizer, then the optimizer's path must additionally include the\n        # data parallel rank.\n\n        # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path\n        if not pipeline_parallel:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}\")\n        else:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}\")\n\n        if expert_parallel:\n            common_path = common_path + f\"_{expert_rank:03d}\"\n\n        os.makedirs(common_path, exist_ok=True)\n\n        if return_base_dir:\n            return common_path\n        return os.path.join(common_path, basename)\n\n    def generate_state_dict(self):\n        # For save dist checkpointing\n        state_dict = {}\n\n        # All ranks Save Model to reduce memory pressure\n        if self.should_save_model or self.should_load_model:\n            # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) > 1:\n                    mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                    key = f\"model{vpp_rank}\" if len(self.model) > 1 else \"model\"\n                else:\n                    key = \"model\"\n                if hasattr(model, \"module\"):\n                    model = model.module\n                state_dict[key] = model.sharded_state_dict()\n\n        # Optimizer State Dict\n        if self.should_save_optimizer or self.should_load_optimizer:\n            torch.distributed.barrier()\n            optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)\n            state_dict[\"optimizer\"] = optimizer_sharded_states\n\n            if self.lr_scheduler is not None:\n                lr_state_dict = self.lr_scheduler.state_dict()\n                state_dict[\"lr_scheduler\"] = lr_state_dict\n\n        # RNG States State Dict\n        if self.should_save_extra or self.should_load_extra:\n            torch.distributed.barrier()\n            rng_state = self.get_rng_state()\n            state_dict[\"rng_state\"] = rng_state\n\n        return state_dict\n\n    def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True):\n        # access rng_state for data parallel rank\n        if data_parallel_random_init:\n            rng_states = rng_states[mpu.get_data_parallel_rank()]\n        else:\n            rng_states = rng_states[0]\n        random.setstate(rng_states[\"random_rng_state\"])\n        np.random.set_state(rng_states[\"np_rng_state\"])\n        torch.set_rng_state(rng_states[\"torch_rng_state\"])\n\n        if get_device_name() != \"cpu\":\n            get_torch_device().set_rng_state(rng_states[f\"{get_device_name()}_rng_state\"])\n\n        # Check for empty states array\n        if not rng_states[\"rng_tracker_states\"]:\n            raise KeyError\n        tensor_parallel.get_cuda_rng_tracker().set_states(rng_states[\"rng_tracker_states\"])\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        if local_path is not None:\n            assert os.path.exists(local_path), f\"Checkpoint path {local_path} does not exist.\"\n\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        # Get State Dict for loading\n        sharded_state_dict = 
self.generate_state_dict()\n        log_with_rank(f\"Generated sharded state dict for loading: {sharded_state_dict.keys()}\", rank=self.rank, logger=logger)\n        for vpp_rank, model in enumerate(self.model):\n            if len(self.model) > 1:\n                model_i_keys = sharded_state_dict[f\"model{vpp_rank}\"].keys()\n                log_with_rank(f\"Generated sharded state dict for loading: {model_i_keys}\", rank=self.rank, logger=logger)\n            else:\n                log_with_rank(\n                    f\"Generated sharded state dict for loading: {sharded_state_dict['model'].keys()}\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n        # Load Dist Checkpointing\n        state_dict = load_dist_checkpointing(\n            sharded_state_dict=sharded_state_dict,\n            ckpt_dir=dist_checkpoint_path,\n        )\n\n        if self.should_load_model and self.use_dist_checkpointing:\n            assert \"model\" in state_dict or any(\n                f\"model{vpp_rank}\" in state_dict for vpp_rank in range(len(self.model))\n            ), f\"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) == 1:\n                    model_state_dict = state_dict[\"model\"]\n                else:\n                    assert f\"model{vpp_rank}\" in state_dict, f\"model{vpp_rank} not found in state_dict\"\n                    model_state_dict = state_dict[f\"model{vpp_rank}\"]\n                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                self.model[vpp_rank].load_state_dict(model_state_dict)\n            log_with_rank(f\"Loaded sharded model checkpoint from {local_path}\", rank=self.rank, logger=logger)\n        elif self.should_load_model and self.use_hf_checkpoint:\n            hf_model_path = get_hf_model_checkpoint_path(local_path)\n            self.bridge.load_weights(self.model, hf_model_path)\n            log_with_rank(f\"Loaded HF model checkpoint from {hf_model_path} with bridge\", rank=self.rank, logger=logger)\n\n        if self.should_load_optimizer:\n            assert \"optimizer\" in state_dict, (\n                f\"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            )\n            optimizer_state_dict = state_dict[\"optimizer\"]\n            self.optimizer.load_state_dict(optimizer_state_dict)\n            log_with_rank(f\"Loaded optimizer checkpoint from {local_path}\", rank=self.rank, logger=logger)\n            if self.use_checkpoint_opt_param_scheduler:\n                assert \"lr_scheduler\" in state_dict, (\n                    f\"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file \"\n                    f\"{local_path}.\"\n                )\n                lr_scheduler_state_dict = state_dict[\"lr_scheduler\"]\n                if self.lr_scheduler is not None:\n                    self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n                    log_with_rank(f\"Loaded LR scheduler checkpoint from {local_path}\", rank=self.rank, logger=logger)\n\n        if self.should_load_extra:\n            assert \"rng_state\" in state_dict, (\n                f\"RNG state dict not found in {state_dict.keys()}. 
Please check the checkpoint file {local_path}.\"\n            )\n            rng_state = state_dict[\"rng_state\"]\n            self.load_rng_states(rng_state)\n            log_with_rank(f\"Loaded RNG states from {local_path}\", rank=self.rank, logger=logger)\n\n        if del_local_after_load:\n            try:\n                os.remove(local_path) if is_non_local(local_path) else None\n            except Exception as e:\n                log_with_rank(\n                    f\"Removing local resume checkpoint file after loading failed; exception {e} will be ignored\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path\n        if (\n            max_ckpt_to_keep\n            and isinstance(max_ckpt_to_keep, int)\n            and max_ckpt_to_keep > 0\n            and len(self.previous_saved_paths) >= max_ckpt_to_keep\n        ):\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = local_mkdir_safe(local_path)\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        if self.use_dist_checkpointing:\n            # Generate state dict for saving\n            state_dict = self.generate_state_dict()\n            log_with_rank(f\"Generated state dict for saving: {state_dict.keys()}\", rank=self.rank, logger=logger)\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) > 1:\n                    model_i_keys = state_dict[f\"model{vpp_rank}\"].keys()\n                    log_with_rank(f\"Generated state dict for saving: {model_i_keys}\", rank=self.rank, logger=logger)\n                else:\n                    log_with_rank(\n                        f\"Generated state dict for saving: {state_dict['model'].keys()}\", rank=self.rank, logger=logger\n                    )\n            # Start Async save if enabled\n            async_save_request = save_dist_checkpointing(\n                sharded_state_dict=state_dict,\n                ckpt_path=dist_checkpoint_path,\n                async_save=self.checkpoint_config.async_save,\n            )\n\n            # Synchronize all async save requests\n            if not self.checkpoint_config.async_save:\n                assert async_save_request is None, \"Async save request should be None when not using async save.\"\n                torch.distributed.barrier()\n        else:\n            assert self.use_hf_checkpoint, \"use_hf_checkpoint should be True when not using dist checkpointing\"\n            log_with_rank(f\"Saving HF model checkpoint to {local_path} with bridge\", rank=self.rank, logger=logger)\n            hf_ckpt_path = get_hf_model_checkpoint_path(local_path)\n            self.bridge.save_weights(self.model, hf_ckpt_path)\n            log_with_rank(f\"Saved bridge checkpoint to {hf_ckpt_path}\", rank=self.rank, logger=logger)\n\n        if self.should_save_model:\n            # Only rank 0 saves the hf config and tokenizer to huggingface path\n            # No matter whether we save hf model or not\n            if self.rank == 0:\n                # Save tokenizer\n                
hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)\n                self.processing_class.save_pretrained(hf_config_tokenizer_path)\n                # Save huggingface config\n                self.hf_config.save_pretrained(hf_config_tokenizer_path)\n                if hasattr(self.hf_config, \"name_or_path\") and self.hf_config.name_or_path:\n                    try:\n                        generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)\n                        generation_config.save_pretrained(hf_config_tokenizer_path)\n                    except Exception:\n                        # if the generation config isn't available, we don't save it\n                        pass\n                log_with_rank(\n                    f\"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n\n        if self.should_save_extra:\n            if self.rank == 0:\n                # Save transformer config\n                print(self.transformer_config)\n                transformer_config_dict = asdict(self.transformer_config)\n                to_convert_types = {torch.dtype: str, AttnBackend: str}\n                ignore_types = [Callable]\n                pop_keys = []\n                for key, value in transformer_config_dict.items():\n                    if type(value) in to_convert_types:\n                        transformer_config_dict[key] = to_convert_types[type(value)](value)\n                    if type(value) in ignore_types:\n                        pop_keys.append(key)\n                    if callable(value):\n                        pop_keys.append(key)\n                for key in pop_keys:\n                    transformer_config_dict.pop(key)\n                transformer_config_path = get_transformer_config_checkpoint_path(local_path)\n                with open(transformer_config_path, \"w\") as f:\n                    json.dump(transformer_config_dict, f, indent=2)\n\n        if self.should_save_hf_model:\n            # wait for everyone to dump to local\n            state_dict = self.weight_saver(\n                self.model,\n                self.hf_config,\n                dtype=self.param_dtype,\n                is_value_model=self.is_value_model,\n                tie_word_embeddings=self.share_embeddings_and_output_weights,\n            )\n\n            torch.distributed.barrier()\n            if self.rank == 0:\n                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                import warnings\n\n                from accelerate import init_empty_weights\n\n                with init_empty_weights(), warnings.catch_warnings():\n                    warnings.simplefilter(\"ignore\")\n                    if \"mistral7b-rm\" in self.config.model.path:\n                        from transformers import MistralForSequenceClassification\n\n                        model = MistralForSequenceClassification.from_pretrained(\n                            self.config.model.path\n                        )  # use score head instead of lm_head\n                        state_dict[\"score.weight\"] = state_dict[\"score.weight\"]\n                    else:\n                        from transformers import AutoModelForCausalLM\n\n                        model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype=\"auto\")\n                
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)\n                log_with_rank(\n                    f\"Saved Huggingface model to {hf_model_ckpt_path}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n\n                if hdfs_path is not None:\n                    log_with_rank(\n                        f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger, log_only_rank_0=True\n                    )\n                    from verl.utils import hdfs_io\n\n                    hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                    hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)\n                    log_with_rank(\n                        f\"HDFS checkpoint uploaded to {hdfs_path}\", rank=self.rank, logger=logger, log_only_rank_0=True\n                    )\n\n        def finalize_save_fn():\n            # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided\n            log_with_rank(\n                f\"Dist checkpointing save completed for {dist_checkpoint_path}\", rank=self.rank, logger=logger\n            )\n            if self.rank == 0:\n                if hdfs_path is not None:\n                    log_with_rank(f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger)\n                    from verl.utils import hdfs_io\n\n                    hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                    hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True)\n                    hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)\n\n        if self.checkpoint_config.async_save:\n            assert async_save_request is not None, \"Async save request should not be None when using async save.\"\n            async_save_request.add_finalize_fn(finalize_save_fn)\n        else:\n            finalize_save_fn()\n\n        self.previous_saved_paths.append(local_path)\n"
  },
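The transformer-config dump in `save_checkpoint` has to sanitize fields that `json` cannot serialize: `torch.dtype` values are stringified and callables dropped. A self-contained sketch of that sanitization, using an illustrative config dict rather than a real `TransformerConfig`:

```python
# Standalone sketch of the JSON sanitization applied to the transformer config.
import json

import torch

config_dict = {
    "hidden_size": 1024,
    "params_dtype": torch.bfloat16,               # torch.dtype -> str
    "activation_func": torch.nn.functional.gelu,  # callables cannot be serialized
}

to_convert_types = {torch.dtype: str}
pop_keys = [key for key, value in config_dict.items() if callable(value)]
for key, value in config_dict.items():
    if type(value) in to_convert_types:
        config_dict[key] = to_convert_types[type(value)](value)
for key in pop_keys:
    config_dict.pop(key)

print(json.dumps(config_dict, indent=2))  # keeps hidden_size, stringifies params_dtype
```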
  {
    "path": "verl_rl/verl/utils/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import is_dataclass\nfrom typing import Any, Optional\n\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\n\n__all__ = [\"omega_conf_to_dataclass\"]\n\n\ndef omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:\n    \"\"\"\n    Convert an OmegaConf DictConfig to a dataclass.\n\n    Args:\n        config: The OmegaConf DictConfig or dict to convert.\n        dataclass_type: The dataclass type to convert to. When dataclass_type is None,\n            the DictConfig must contain _target_ to be instantiated via hydra.instantiate API.\n\n    Returns:\n        The dataclass instance.\n    \"\"\"\n    # Got an empty config\n    if not config:\n        return dataclass_type if dataclass_type is None else dataclass_type()\n    # Got an object\n    if not isinstance(config, DictConfig | ListConfig | dict | list):\n        return config\n\n    if dataclass_type is None:\n        assert \"_target_\" in config, (\n            \"When dataclass_type is not provided, config must contain _target_.\"\n            \"See trainer/config/ppo_trainer.yaml algorithm section for an example.\"\n        )\n        from hydra.utils import instantiate\n\n        return instantiate(config, _convert_=\"partial\")\n\n    if not is_dataclass(dataclass_type):\n        raise ValueError(f\"{dataclass_type} must be a dataclass\")\n    cfg = OmegaConf.create(config)  # in case it's a dict\n    cfg_from_dataclass = OmegaConf.structured(dataclass_type)\n    # let cfg override the existing vals in `cfg_from_dataclass`\n    cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)\n    # now convert to `dataclass_type`\n    config_object = OmegaConf.to_object(cfg_merged)\n    return config_object\n\n\ndef update_dict_with_config(dictionary: dict, config: DictConfig):\n    for key in dictionary:\n        if hasattr(config, key):\n            dictionary[key] = getattr(config, key)\n"
  },
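A runnable illustration of `omega_conf_to_dataclass` semantics when `dataclass_type` is provided; `OptimCfg` here is a hypothetical dataclass for demonstration, not a verl class:

```python
# User-supplied values are merged over the dataclass defaults, then materialized.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class OptimCfg:  # hypothetical example dataclass
    lr: float = 1e-3
    warmup_steps: int = 0


cfg = OmegaConf.create({"lr": 5e-4})

# Mirrors the function body: structured defaults, merge, then to_object.
merged = OmegaConf.merge(OmegaConf.structured(OptimCfg), cfg)
optim_cfg = OmegaConf.to_object(merged)
assert isinstance(optim_cfg, OptimCfg)
assert optim_cfg.lr == 5e-4 and optim_cfg.warmup_steps == 0
```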
  {
    "path": "verl_rl/verl/utils/dataset/README.md",
    "content": "# Dataset Format\n## RLHF dataset\nWe combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers.\n\nMath problems\n```json\n{\n    \"data_source\": \"openai/gsm8k\",\n    \"prompt\": [{\"role\": \"user\", \"content\": \"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \\\"####\\\"\"}],\n    \"ability\": \"math\",\n    \"reward_model\": {\n        \"style\": \"rule\",\n        \"ground_truth\": [\"72\"]\n    },\n}\n```\n"
  },
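One record in this schema can be written to parquet with pandas; the writer choice (pandas with a parquet engine such as pyarrow) and the output filename are assumptions here, only the schema comes from the README above:

```python
# Sketch: writing a single RLHF record in the documented schema to parquet.
import pandas as pd

record = {
    "data_source": "openai/gsm8k",
    "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April..."}],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": ["72"]},
}
pd.DataFrame([record]).to_parquet("rlhf_sample.parquet")  # illustrative filename
```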
  {
    "path": "verl_rl/verl/utils/dataset/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .rl_dataset import RLHFDataset\nfrom .rm_dataset import RMDataset\nfrom .sft_dataset import SFTDataset\n\n__all__ = [\"RLHFDataset\", \"RMDataset\", \"SFTDataset\"]\n"
  },
  {
    "path": "verl_rl/verl/utils/dataset/multiturn_sft_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n\n#     http://www.apache.org/licenses/LICENSE-2.0\n\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMulti-turn SFT dataset that supports training on conversation data with multiple turns\n\"\"\"\n\nimport logging\nfrom typing import Any, Optional\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_local_path_from_hdfs\n\n\ndef convert_nested_value_to_list_recursive(data_item):\n    if isinstance(data_item, dict):\n        return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()}\n    elif isinstance(data_item, list):\n        return [convert_nested_value_to_list_recursive(elem) for elem in data_item]\n    elif isinstance(data_item, np.ndarray):\n        # Convert to list, then recursively process the elements of the new list\n        return convert_nested_value_to_list_recursive(data_item.tolist())\n    else:\n        # Base case: item is already a primitive type (int, str, float, bool, etc.)\n        return data_item\n\n\nclass MultiTurnSFTDataset(Dataset):\n    \"\"\"\n    Dataset for multi-turn conversations where each assistant response should be trained\n    \"\"\"\n\n    def __init__(self, parquet_files: str | list[str], tokenizer, config=None):\n        # Set defaults and extract parameters from config if provided\n        config = config or {}\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.max_length = config.get(\"max_length\", 1024)\n        # Get messages_key from the new multiturn config structure\n        multiturn_config = config.get(\"multiturn\", {})\n        self.messages_key = multiturn_config.get(\"messages_key\", \"messages\")\n        self.tools_key = multiturn_config.get(\"tools_key\", \"tools\")\n        self.enable_thinking_key = multiturn_config.get(\"enable_thinking_key\", \"enable_thinking\")\n        assert self.truncation in [\"error\", \"left\", \"right\"]\n\n        if not isinstance(parquet_files, list):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n\n        self._download()\n        self._read_files_and_process()\n\n    def _download(self):\n        for i, parquet_file in enumerate(self.parquet_files):\n            self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True)\n\n    def _read_files_and_process(self):\n        def series_to_item(ls):\n            import numpy\n            import pandas\n\n            while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:\n                ls = ls[0]\n            return ls\n\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            dataframe = 
pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n\n        # Extract messages list from dataframe\n        self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()\n\n        # Extract tools list from dataframe\n        if self.tools_key in self.dataframe.columns:\n            self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist()\n        else:\n            self.tools = None\n        # Extract enable_thinking list from dataframe\n        if self.enable_thinking_key in self.dataframe.columns:\n            self.enable_thinking = self.dataframe[self.enable_thinking_key].tolist()\n        else:\n            self.enable_thinking = None\n\n    def __len__(self):\n        return len(self.messages)\n\n    def _process_message_tokens(\n        self,\n        messages: list[dict[str, Any]],\n        start_idx: int,\n        end_idx: int,\n        is_assistant: bool = False,\n        enable_thinking: Optional[bool] = None,\n        tools: Optional[list[dict[str, Any]]] = None,\n    ) -> tuple[list[int], list[int], list[int]]:\n        \"\"\"\n        Process tokens for a single message or a group of messages.\n\n        Args:\n            messages: List of message dictionaries\n            start_idx: Start index in messages list\n            end_idx: End index in messages list\n            is_assistant: Whether this is an assistant message\n            enable_thinking: Whether to enable thinking mode\n\n        Returns:\n            Tuple of (tokens, loss_mask, attention_mask)\n        \"\"\"\n        if start_idx > 0:\n            prev_applied_text = self.tokenizer.apply_chat_template(\n                messages[:start_idx],\n                tokenize=False,\n                add_generation_prompt=False,\n                enable_thinking=enable_thinking,\n                tools=tools,\n            )\n            if is_assistant:\n                prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template(\n                    messages[:start_idx],\n                    tokenize=False,\n                    add_generation_prompt=True,\n                    enable_thinking=enable_thinking,\n                    tools=tools,\n                )\n\n        else:\n            prev_applied_text = \"\"\n\n        cur_applied_text = self.tokenizer.apply_chat_template(\n            messages[:end_idx],\n            tokenize=False,\n            add_generation_prompt=False,\n            enable_thinking=enable_thinking,\n            tools=tools,\n        )\n        # Get tokens for the current message only\n        if is_assistant:\n            generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :]\n            generation_prompt_tokens = self.tokenizer.encode(\n                generation_prompt_text,\n                add_special_tokens=False,\n            )\n            _message_tokens = self.tokenizer.encode(\n                cur_applied_text[len(prev_applied_text_w_generation_prompt) :],\n                add_special_tokens=False,\n            )\n            message_tokens = generation_prompt_tokens + _message_tokens\n            loss_mask = [0] * (len(generation_prompt_tokens)) + [1] * (\n                len(message_tokens) - len(generation_prompt_tokens)\n            )\n        else:\n            message_tokens = self.tokenizer.encode(\n                cur_applied_text[len(prev_applied_text) :],\n                
add_special_tokens=False,\n            )\n            loss_mask = [0] * len(message_tokens)\n\n        attention_mask = [1] * len(message_tokens)\n\n        return message_tokens, loss_mask, attention_mask\n\n    def _validate_and_convert_tokens(\n        self,\n        full_tokens: torch.Tensor,\n        concat_tokens: list[int],\n        concat_loss_mask: list[int],\n        concat_attention_mask: list[int],\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Validate tokenization and convert to tensors.\n\n        Args:\n            full_tokens: Full conversation tokens\n            concat_tokens: Concatenated tokens\n            concat_loss_mask: Concatenated loss mask\n            concat_attention_mask: Concatenated attention mask\n\n        Returns:\n            Tuple of (input_ids, loss_mask, attention_mask) as tensors\n        \"\"\"\n        full_tokens_list = full_tokens.tolist()\n\n        if len(concat_tokens) != len(full_tokens_list) or not all(\n            a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True)\n        ):\n            logging.warning(\n                f\"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens \"\n                f\"length: {len(concat_tokens)}. Using concatenated version.\"\n                # f\"full tokens text: {self.tokenizer.decode(full_tokens_list)}\"\n                # f\"concat tokens text: {self.tokenizer.decode(concat_tokens)}\"\n            )\n            return (\n                torch.tensor(concat_tokens, dtype=torch.long),\n                torch.tensor(concat_loss_mask, dtype=torch.long),\n                torch.tensor(concat_attention_mask, dtype=torch.long),\n            )\n\n        return (\n            full_tokens,\n            torch.tensor(concat_loss_mask, dtype=torch.long),\n            torch.tensor(concat_attention_mask, dtype=torch.long),\n        )\n\n    def __getitem__(self, item):\n        tokenizer = self.tokenizer\n        messages = self.messages[item]\n        tools = self.tools[item] if self.tools is not None else None\n        enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None\n\n        # First, get the full conversation tokens\n        try:\n            full_tokens = tokenizer.apply_chat_template(\n                messages,\n                tools=tools,\n                tokenize=True,\n                return_tensors=\"pt\",\n                add_generation_prompt=False,\n                enable_thinking=enable_thinking,\n            )\n        except Exception as e:\n            logging.error(\n                f\"Error applying chat template: {e}\\nMessages: {messages}\\nTools: {tools}\\nEnable thinking: \"\n                f\"{enable_thinking}\"\n            )\n            raise\n\n        # Track concatenated tokens for validation\n        concat_tokens = []\n        concat_loss_mask = []\n        concat_attention_mask = []\n\n        i = 0\n        while i < len(messages):\n            cur_messages = messages[i]\n            if cur_messages[\"role\"] == \"assistant\":\n                # Process assistant message\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools\n                )\n                concat_tokens.extend(tokens)\n                concat_loss_mask.extend(loss_mask)\n                concat_attention_mask.extend(attention_mask)\n          
      i += 1\n            elif cur_messages[\"role\"] == \"tool\":\n                # Process consecutive tool messages\n                st = i\n                ed = i + 1\n                while ed < len(messages) and messages[ed][\"role\"] == \"tool\":\n                    ed += 1\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, st, ed, enable_thinking=enable_thinking, tools=tools\n                )\n                concat_tokens.extend(tokens)\n                concat_loss_mask.extend(loss_mask)\n                concat_attention_mask.extend(attention_mask)\n                i = ed\n            elif cur_messages[\"role\"] in [\"user\", \"system\"]:\n                # Process user or system message\n                if cur_messages[\"role\"] == \"system\" and i != 0:\n                    raise ValueError(\"System message should be the first message\")\n                tokens, loss_mask, attention_mask = self._process_message_tokens(\n                    messages, i, i + 1, enable_thinking=enable_thinking, tools=tools\n                )\n                concat_tokens.extend(tokens)\n                concat_loss_mask.extend(loss_mask)\n                concat_attention_mask.extend(attention_mask)\n                i += 1\n            else:\n                raise ValueError(f\"Unknown role: {cur_messages['role']}\")\n\n        # Validate and convert tokens\n        input_ids, loss_mask, attention_mask = self._validate_and_convert_tokens(\n            full_tokens[0], concat_tokens, concat_loss_mask, concat_attention_mask\n        )\n\n        # Handle sequence length\n        sequence_length = input_ids.shape[0]\n        if sequence_length < self.max_length:\n            # Pad sequences\n            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0\n            padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype)\n            padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype)\n            padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype)\n\n            input_ids = torch.cat((input_ids, padded_input_ids))\n            attention_mask = torch.cat((attention_mask, padded_attention_mask))\n            loss_mask = torch.cat((loss_mask, padded_loss_mask))\n        elif sequence_length > self.max_length:\n            if self.truncation == \"left\":\n                input_ids = input_ids[-self.max_length :]\n                attention_mask = attention_mask[-self.max_length :]\n                loss_mask = loss_mask[-self.max_length :]\n            elif self.truncation == \"right\":\n                input_ids = input_ids[: self.max_length]\n                attention_mask = attention_mask[: self.max_length]\n                loss_mask = loss_mask[: self.max_length]\n            elif self.truncation == \"error\":\n                raise ValueError(f\"{sequence_length=} is larger than {self.max_length=}\")\n            else:\n                raise ValueError(f\"Unknown truncation method {self.truncation}\")\n\n        # Create position IDs\n        position_ids = torch.arange(len(input_ids), dtype=torch.long)\n        # Zero out position IDs for padding\n        position_ids = position_ids * attention_mask\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n           
 \"loss_mask\": loss_mask,\n        }\n"
  },
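  {
    "path": "verl_rl/examples/sketches/multiturn_loss_mask.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): a simplified version of the\ndelta-tokenization idea in MultiTurnSFTDataset. Tokens are attributed to each\nmessage by applying the chat template to growing prefixes, and the loss mask is 1\nonly on assistant tokens. Unlike the dataset class, this sketch does not separate\ngeneration-prompt tokens. The model name below is an assumption for illustration.\n\"\"\"\n\nfrom transformers import AutoTokenizer\n\n\ndef build_loss_mask(messages, tokenizer):\n    tokens, loss_mask = [], []\n    prev_text = ''\n    for i in range(1, len(messages) + 1):\n        cur_text = tokenizer.apply_chat_template(messages[:i], tokenize=False, add_generation_prompt=False)\n        # Tokenize only the text this message added on top of the previous prefix.\n        delta_ids = tokenizer.encode(cur_text[len(prev_text):], add_special_tokens=False)\n        supervised = 1 if messages[i - 1]['role'] == 'assistant' else 0\n        tokens.extend(delta_ids)\n        loss_mask.extend([supervised] * len(delta_ids))\n        prev_text = cur_text\n    return tokens, loss_mask\n\n\nif __name__ == '__main__':\n    tok = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')\n    msgs = [\n        {'role': 'user', 'content': 'Hi!'},\n        {'role': 'assistant', 'content': 'Hello! How can I help?'},\n    ]\n    ids, mask = build_loss_mask(msgs, tok)\n    print(len(ids), sum(mask))  # total tokens vs. supervised (assistant) tokens\n"
  },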
  {
    "path": "verl_rl/verl/utils/dataset/rl_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport logging\nimport os\nimport re\nfrom collections import defaultdict\nfrom typing import Optional\n\nimport datasets\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig, ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport verl.utils.torch_functional as verl_F\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__name__)\n\n\ndef collate_fn(data_list: list[dict]) -> dict:\n    \"\"\"\n    Collate a batch of sample dicts into batched tensors and arrays.\n\n    Args:\n        data_list: List of dicts mapping feature names to torch.Tensor or other values.\n\n    Returns:\n        Dict where tensor entries are stacked into a torch.Tensor of shape\n        (batch_size, \\*dims) and non-tensor entries are converted to\n        np.ndarray of dtype object with shape (batch_size,).\n    \"\"\"\n    tensors = defaultdict(list)\n    non_tensors = defaultdict(list)\n\n    for data in data_list:\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor):\n                tensors[key].append(val)\n            else:\n                non_tensors[key].append(val)\n\n    for key, val in tensors.items():\n        tensors[key] = torch.stack(val, dim=0)\n\n    for key, val in non_tensors.items():\n        non_tensors[key] = np.array(val, dtype=object)\n\n    return {**tensors, **non_tensors}\n\n\nclass RLHFDataset(Dataset):\n    \"\"\"\n    Load and preprocess RLHF data from Parquet files.\n\n    - Caches files locally.\n    - Reads into a HuggingFace Dataset and tokenizes prompts.\n    - Optionally handles images/videos via a ProcessorMixin.\n    - Filters prompts over a max length.\n    - Supports resuming from checkpoints.\n\n    Args:\n        data_files (str or list): Path(s) to Parquet file(s).\n        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.\n        config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.\n        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_files: str | list[str],\n        tokenizer: PreTrainedTokenizer,\n        config: DictConfig,\n        processor: Optional[ProcessorMixin] = None,\n    ):\n        if not isinstance(data_files, list | ListConfig):\n            data_files = [data_files]\n\n        self.data_files = copy.deepcopy(data_files)\n        self.original_data_files = copy.deepcopy(data_files)  # use for resume\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.config = config\n\n        self.cache_dir = os.path.expanduser(config.get(\"cache_dir\", \"~/.cache/verl/rlhf\"))\n        self.prompt_key = 
config.get(\"prompt_key\", \"prompt\")\n        self.image_key = config.get(\"image_key\", \"images\")\n        self.video_key = config.get(\"video_key\", \"videos\")\n        self.max_prompt_length = config.get(\"max_prompt_length\", 1024)\n        self.return_raw_chat = config.get(\"return_raw_chat\", False)\n        self.return_full_prompt = config.get(\"return_full_prompt\", False)\n        self.truncation = config.get(\"truncation\", \"error\")\n        self.filter_overlong_prompts = config.get(\"filter_overlong_prompts\", True)\n\n        self.num_workers = config.get(\"filter_overlong_prompts_workers\", max(1, os.cpu_count() // 4))\n        self.num_workers = min(self.num_workers, os.cpu_count())\n        self.use_shm = config.get(\"use_shm\", False)\n        self.chat_template_func = config.get(\"chat_template_func\", None)\n        self.need_tools_kwargs = config.get(\"need_tools_kwargs\", False)\n        self.filter_prompts = config.get(\"filter_prompts\", True)\n        self.serialize_dataset = False\n        self.return_multi_modal_inputs = config.get(\"return_multi_modal_inputs\", True)\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self, use_origin_parquet=False):\n        from verl.utils.fs import copy_to_local\n\n        data_files = self.data_files if not use_origin_parquet else self.original_data_files\n        for i, parquet_file in enumerate(data_files):\n            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.data_files:\n            # read parquet files and cache\n            dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_file)[\"train\"]\n            dataframes.append(dataframe)\n        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)\n\n        print(f\"dataset len: {len(self.dataframe)}\")\n\n        self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)\n\n    def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):\n        # filter out too long prompts\n        if self.filter_overlong_prompts:\n            tokenizer = self.tokenizer\n            processor = self.processor\n            prompt_key = self.prompt_key\n            image_key = self.image_key\n            video_key = self.video_key\n\n            if processor is not None:\n                from verl.utils.dataset.vision_utils import process_image, process_video\n\n                def doc2len(doc) -> int:\n                    messages = self._build_messages(doc)\n                    raw_prompt = self.processor.apply_chat_template(\n                        messages, add_generation_prompt=True, tokenize=False\n                    )\n                    images = [process_image(image) for image in doc[image_key]] if image_key in doc else None\n                    videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None\n\n                    return len(processor(text=[raw_prompt], images=images, videos=videos)[\"input_ids\"][0])\n\n            else:\n\n                def doc2len(doc) -> int:\n                    return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))\n\n            dataframe = dataframe.filter(\n                lambda doc: doc2len(doc) <= self.max_prompt_length,\n                num_proc=self.num_workers,\n                desc=f\"Filtering prompts 
longer than {self.max_prompt_length} tokens\",\n            )\n\n            print(f\"filter dataset len: {len(dataframe)}\")\n        return dataframe\n\n    def resume_dataset_state(self):\n        self.serialize_dataset = not hasattr(self, \"original_data_files\")\n        # re-download and re-read the dataframe if it was not serialized in data.pt\n        if not self.serialize_dataset:\n            self._download(use_origin_parquet=True)  # download and resume from original parquet files\n            self._read_files_and_tokenize()\n        else:\n            print(r\"an old dataloader ckpt file is being used; please train from scratch for better ckpt performance\")\n\n    def __len__(self):\n        return len(self.dataframe)\n\n    def _build_messages(self, example: dict):\n        messages: list = example.pop(self.prompt_key)\n\n        if self.image_key in example or self.video_key in example:\n            for message in messages:\n                content = message[\"content\"]\n                content_list = []\n                segments = re.split(\"(<image>|<video>)\", content)\n                segments = [item for item in segments if item != \"\"]\n                for segment in segments:\n                    if segment == \"<image>\":\n                        content_list.append({\"type\": \"image\"})\n                    elif segment == \"<video>\":\n                        content_list.append({\"type\": \"video\"})\n                    else:\n                        content_list.append({\"type\": \"text\", \"text\": segment})\n\n                message[\"content\"] = content_list\n\n        return messages\n\n    def __getitem__(self, item):\n        \"\"\"\n        Note that we also return the raw_input_ids so that they can be combined with other chat templates\n        \"\"\"\n        row_dict: dict = self.dataframe[item]\n        messages = self._build_messages(row_dict)\n        model_inputs = {}\n\n        if self.processor is not None:\n            from verl.utils.dataset.vision_utils import process_image, process_video\n\n            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            multi_modal_data = {}\n\n            images = None\n            if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:\n                images = [process_image(image) for image in row_dict.pop(self.image_key)]\n\n                # vllm expects the key \"image\" rather than \"images\", so we use \"image\" here\n                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n                multi_modal_data[\"image\"] = images\n\n            videos = None\n            if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:\n                videos = [process_video(video) for video in row_dict.pop(self.video_key)]\n\n                # vllm expects the key \"video\" rather than \"videos\", so we use \"video\" here\n                # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205\n                multi_modal_data[\"video\"] = [video.numpy() for video in videos]\n\n            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors=\"pt\")\n\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            if \"second_per_grid_ts\" in 
model_inputs:\n                model_inputs.pop(\"second_per_grid_ts\")\n\n            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature\n            row_dict[\"multi_modal_data\"] = multi_modal_data\n\n            # We will do batch.union() in the trainer,\n            # so we cannot have \"multi_modal_inputs\" in row_dict if rollout generates new multi_modal_inputs\n            if self.return_multi_modal_inputs:\n                row_dict[\"multi_modal_inputs\"] = dict(model_inputs)\n\n                # second_per_grid_ts isn't used for training, just for mrope\n                row_dict[\"multi_modal_inputs\"].pop(\"second_per_grid_ts\", None)\n\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = verl_F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if self.processor is not None and \"Qwen2VLImageProcessor\" in self.processor.image_processor.__class__.__name__:\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            position_ids = [\n                get_rope_index(\n                    self.processor,\n                    input_ids=input_ids[0],\n                    image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                    video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                    second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                    attention_mask=attention_mask[0],\n                )\n            ]  # (1, 3, seq_len)\n\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        row_dict[\"input_ids\"] = input_ids[0]\n        row_dict[\"attention_mask\"] = attention_mask[0]\n        row_dict[\"position_ids\"] = position_ids[0]\n\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        row_dict[\"raw_prompt_ids\"] = raw_prompt_ids\n        # encode prompts without chat template\n        if self.return_raw_chat:\n            row_dict[\"raw_prompt\"] = messages\n\n        # get prompts with chat template\n        if self.return_full_prompt:\n            row_dict[\"full_prompts\"] = raw_prompt  # array of strings\n\n        # add index for each prompt\n        index = row_dict.get(\"extra_info\", 
{}).get(\"index\", 0)\n        tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"tools_kwargs\", {})\n        interaction_kwargs = row_dict.get(\"extra_info\", {}).get(\"interaction_kwargs\", {})\n        need_tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"need_tools_kwargs\", self.need_tools_kwargs)\n        if need_tools_kwargs and not tools_kwargs:\n            logger.warning(\"tools_kwargs is empty for index {}, data source: {}\", index, row_dict[\"data_source\"])\n        row_dict[\"index\"] = index\n        row_dict[\"tools_kwargs\"] = tools_kwargs\n        row_dict[\"interaction_kwargs\"] = interaction_kwargs\n        return row_dict\n\n    def __getstate__(self):\n        if not self.serialize_dataset:\n            state = self.__dict__.copy()\n\n            if \"dataframe\" in state:\n                del state[\"dataframe\"]\n            return state\n\n        return self.__dict__.copy()\n"
  },
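  {
    "path": "verl_rl/examples/sketches/collate_fn_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): a usage example for the documented\ncontract of rl_dataset.collate_fn -- tensor features are stacked along a new\nbatch dimension, while non-tensor features become object-dtype numpy arrays.\n\"\"\"\n\nimport torch\n\nfrom verl.utils.dataset.rl_dataset import collate_fn\n\nif __name__ == '__main__':\n    batch = collate_fn(\n        [\n            {'input_ids': torch.zeros(4, dtype=torch.long), 'raw_prompt_ids': [1, 2]},\n            {'input_ids': torch.ones(4, dtype=torch.long), 'raw_prompt_ids': [3]},\n        ]\n    )\n    print(batch['input_ids'].shape)       # torch.Size([2, 4]): tensors are stacked\n    print(batch['raw_prompt_ids'].dtype)  # object: non-tensors become np.ndarray\n"
  },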
  {
    "path": "verl_rl/verl/utils/dataset/rm_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport pandas as pd\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom verl.utils import hf_tokenizer\n\n\ndef download_files_distributed(download_fn):\n    import torch.distributed\n\n    if torch.distributed.is_initialized():\n        if torch.distributed.get_rank() == 0:\n            # download files\n            download_fn()\n\n        torch.distributed.barrier()\n    else:\n        # download anyway\n        download_fn()\n\n\nclass RMDataset(Dataset):\n    def __init__(\n        self,\n        parquet_files: str | list[str],\n        tokenizer,\n        prompt_key=\"prompt\",\n        chosen_key=\"chosen\",\n        rejected_key=\"rejected\",\n        max_length=1024,\n        add_eos=True,\n        cache_dir=\"~/.cache/verl/rm\",\n    ):\n        if not isinstance(parquet_files, list):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        self.cache_dir = os.path.expanduser(cache_dir)\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer = tokenizer\n\n        self.prompt_key = prompt_key\n        self.chosen_key = chosen_key\n        self.rejected_key = rejected_key\n\n        self.add_eos = add_eos\n        self.max_length = max_length\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self):\n        def _download_files():\n            from verl.utils.fs import copy, is_non_local\n\n            os.makedirs(self.cache_dir, exist_ok=True)\n            assert os.path.exists(self.cache_dir)\n            for i, parquet_file in enumerate(self.parquet_files):\n                if is_non_local(parquet_file):\n                    dst = os.path.join(self.cache_dir, os.path.basename(parquet_file))\n                    if not os.path.exists(dst):\n                        copy(src=parquet_file, dst=dst)\n                    self.parquet_files[i] = dst\n\n        download_files_distributed(_download_files)\n\n    def _read_files_and_tokenize(self):\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            # read parquet files and cache\n            dataframe = pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n        self.prompts = self.dataframe[self.prompt_key].tolist()\n        self.chosen_responses = self.dataframe[self.chosen_key].tolist()\n        self.rejected_responses = self.dataframe[self.rejected_key].tolist()\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def _pad_to_length(self, input_ids, attention_mask):\n        curr_length = input_ids.shape[-1]\n\n        if curr_length < self.max_length:\n            input_ids = torch.cat(\n                (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1\n            )\n            
attention_mask = torch.cat(\n                (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1\n            )\n        elif curr_length > self.max_length:\n            input_ids = input_ids[: self.max_length]\n            attention_mask = attention_mask[: self.max_length]\n\n        return input_ids, attention_mask\n\n    def __getitem__(self, item):\n        prompt = self.prompts[item]\n        chosen_response = self.chosen_responses[item]\n        rejected_response = self.rejected_responses[item]\n\n        prompt_ids = self.tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"][0]\n        chosen_response_ids = self.tokenizer(chosen_response, return_tensors=\"pt\")[\"input_ids\"][0]\n        rejected_response_ids = self.tokenizer(rejected_response, return_tensors=\"pt\")[\"input_ids\"][0]\n\n        if self.add_eos:\n            chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1)\n            rejected_response_ids = torch.cat(\n                (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1\n            )\n\n        chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1)\n        chosen_attention_mask = torch.ones_like(chosen_input_ids)\n\n        rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1)\n        rejected_attention_mask = torch.ones_like(rejected_input_ids)\n\n        chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask)\n        rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask)\n\n        input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0)\n        attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0)\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n        }\n"
  },
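  {
    "path": "verl_rl/examples/sketches/make_rm_parquet.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): writes a minimal preference parquet\ncompatible with RMDataset's default prompt/chosen/rejected columns. The file name\nand the texts are placeholders.\n\"\"\"\n\nimport pandas as pd\n\nif __name__ == '__main__':\n    df = pd.DataFrame(\n        [{'prompt': 'Q: What is 2 + 2?\\nA:', 'chosen': ' 4', 'rejected': ' 5'}]\n    )\n    df.to_parquet('rm_example.parquet')\n    # RMDataset tokenizes prompt+chosen and prompt+rejected, pads both to\n    # max_length, and stacks them into a (2, max_length) tensor per item.\n"
  },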
  {
    "path": "verl_rl/verl/utils/dataset/sft_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSFT dataset\n- We assume user pass a single parquet file.\n- We load all the data into the memory.\nEach parquet file contains\n\"\"\"\n\nimport pandas as pd\nimport torch\nfrom omegaconf.listconfig import ListConfig\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.model import compute_position_id_with_mask\n\n\nclass SFTDataset(Dataset):\n    \"\"\"\n    This is an in-memory SFTDataset\n\n    Arguments:\n        config (OmegaConf): the data config\n    \"\"\"\n\n    def __init__(self, parquet_files: str | ListConfig, tokenizer, config):\n        prompt_key = config.get(\"prompt_key\", \"prompt\")\n        prompt_dict_keys = config.get(\"prompt_dict_keys\", None)\n        response_key = config.get(\"response_key\", \"response\")\n        response_dict_keys = config.get(\"response_dict_keys\", None)\n        max_length = config.get(\"max_length\", 1024)\n        truncation = config.get(\"truncation\", \"error\")\n        use_shm = config.get(\"use_shm\", False)\n\n        assert truncation in [\"error\", \"left\", \"right\"]\n        self.truncation = truncation\n        self.use_shm = use_shm\n\n        if not isinstance(parquet_files, ListConfig):\n            parquet_files = [parquet_files]\n\n        self.parquet_files = parquet_files\n        if isinstance(tokenizer, str):\n            tokenizer = hf_tokenizer(tokenizer)\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n\n        self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key]\n        self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key]\n        self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else []\n        self.response_dict_keys = response_dict_keys if response_dict_keys else []\n\n        self.max_length = max_length\n\n        self._download()\n        self._read_files_and_tokenize()\n\n    def _download(self):\n        for i, parquet_file in enumerate(self.parquet_files):\n            self.parquet_files[i] = copy_to_local(parquet_file, verbose=True, use_shm=self.use_shm)\n\n    def _read_files_and_tokenize(self):\n        def series_to_item(ls):\n            import numpy\n            import pandas\n\n            while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:\n                ls = ls[0]\n            return ls\n\n        dataframes = []\n        for parquet_file in self.parquet_files:\n            # read parquet files and cache\n            dataframe = pd.read_parquet(parquet_file)\n            dataframes.append(dataframe)\n        self.dataframe = pd.concat(dataframes)\n        self.prompts = self.dataframe[self.prompt_key]\n        for key in self.prompt_dict_keys:\n            # type(x): pandas.core.series.Series\n            # 
type(x[0]): numpy.ndarray\n            # type(x[0][0]): dict\n            try:\n                self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1)  # noqa: B023\n            except Exception:\n                print(f\"self.prompts={self.prompts}\")\n                raise\n        if isinstance(self.prompts, pd.DataFrame):\n            self.prompts = self.prompts.squeeze()\n        self.prompts = self.prompts.tolist()\n        self.responses = self.dataframe[self.response_key]\n        for key in self.response_dict_keys:\n            try:\n                self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1)  # noqa: B023\n            except Exception:\n                print(f\"self.responses={self.responses}\")\n                raise\n        if isinstance(self.responses, pd.DataFrame):\n            self.responses = self.responses.squeeze()\n        self.responses = self.responses.tolist()\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, item):\n        tokenizer = self.tokenizer\n\n        prompt = self.prompts[item]\n        response = self.responses[item]\n\n        # apply chat template\n        prompt_chat = [{\"role\": \"user\", \"content\": prompt}]\n\n        # string\n        prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)\n        response_chat_str = response + tokenizer.eos_token\n\n        # tokenize\n        prompt_ids_output = tokenizer(prompt_chat_str, return_tensors=\"pt\", add_special_tokens=False)\n        prompt_ids = prompt_ids_output[\"input_ids\"][0]\n        prompt_attention_mask = prompt_ids_output[\"attention_mask\"][0]\n\n        response_ids_output = tokenizer(response_chat_str, return_tensors=\"pt\", add_special_tokens=False)\n        response_ids = response_ids_output[\"input_ids\"][0]\n        response_attention_mask = response_ids_output[\"attention_mask\"][0]\n\n        prompt_length = prompt_ids.shape[0]\n        response_length = response_ids.shape[0]\n\n        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)\n        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)\n\n        # padding to max length\n        sequence_length = input_ids.shape[0]\n        if sequence_length < self.max_length:\n            padded_input_ids = (\n                torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype)\n                * self.tokenizer.pad_token_id\n            )\n            padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype)\n\n            input_ids = torch.cat((input_ids, padded_input_ids))\n            attention_mask = torch.cat((attention_mask, padded_attention_mask))\n        elif sequence_length > self.max_length:\n            if self.truncation == \"left\":\n                # actually, left truncation may not be reasonable\n                input_ids = input_ids[-self.max_length :]\n                attention_mask = attention_mask[-self.max_length :]\n            elif self.truncation == \"right\":\n                input_ids = input_ids[: self.max_length]\n                attention_mask = attention_mask[: self.max_length]\n            elif self.truncation == \"error\":\n                raise NotImplementedError(f\"{sequence_length=} is larger than {self.max_length=}\")\n            else:\n                raise NotImplementedError(f\"Unknown truncation method {self.truncation}\")\n\n        
position_ids = compute_position_id_with_mask(attention_mask)\n\n        loss_mask = attention_mask.clone()\n        if prompt_length > 1:\n            # mask out prompt for SFT.\n            loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0\n        # mask out the last token in response\n        loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0\n\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"position_ids\": position_ids,\n            \"loss_mask\": loss_mask,\n        }\n"
  },
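  {
    "path": "verl_rl/examples/sketches/sft_loss_mask_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): reproduces SFTDataset's loss-mask rule\non toy lengths. With next-token prediction, position i supervises token i + 1, so\nthe last prompt token stays unmasked (it predicts the first response token) and\nthe final response token is masked (it has nothing left to predict).\n\"\"\"\n\nimport torch\n\nif __name__ == '__main__':\n    prompt_length, response_length, pad = 3, 4, 3\n    attention_mask = torch.tensor([1] * (prompt_length + response_length) + [0] * pad)\n    loss_mask = attention_mask.clone()\n    loss_mask[: prompt_length - 1] = 0  # mask the prompt, except its last token\n    loss_mask[prompt_length + response_length - 1] = 0  # mask the final response token\n    print(loss_mask.tolist())  # [0, 0, 1, 1, 1, 1, 0, 0, 0, 0]\n"
  },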
  {
    "path": "verl_rl/verl/utils/dataset/vision_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom io import BytesIO\nfrom typing import Optional\n\nimport torch\nfrom PIL import Image\nfrom qwen_vl_utils import fetch_image, fetch_video\n\n\ndef process_image(image: dict | Image.Image) -> Image.Image:\n    if isinstance(image, Image.Image):\n        return image.convert(\"RGB\")\n\n    if \"bytes\" in image:\n        assert \"image\" not in image, \"Cannot have both `bytes` and `image`\"\n        image[\"image\"] = BytesIO(image[\"bytes\"])\n\n    return fetch_image(image)\n\n\nVIDEO_FORMAT_HELP = \"\"\"Currently, we only support the video formats introduced in qwen2-vl.\nRefer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.\n\neg.\n{\n    \"type\": \"video\",\n    \"video\": [\n        \"file:///path/to/frame1.jpg\",\n        \"file:///path/to/frame2.jpg\"\n    ]\n}\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\"\n}\n# Defaults to fps=2, min_frames=4, max_frames=768\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\",\n    \"fps\": 2,\n    \"min_frames\": 1,\n    \"max_frames\": 32\n}\n\"\"\"\n\n\ndef process_video(\n    video: dict,\n    nframes: Optional[int] = None,\n    fps: Optional[float] = None,\n    fps_min_frames: Optional[int] = None,\n    fps_max_frames: Optional[int] = None,\n) -> torch.Tensor:\n    \"\"\"Converts a video dict into a [n_frames, 3, H, W] tensor\n\n    Add video sample FPS in a future MR\n    \"\"\"\n\n    if not isinstance(video, dict) or \"video\" not in video:\n        raise NotImplementedError(VIDEO_FORMAT_HELP)\n    assert nframes is None or fps is None, \"Can't use both `nframes` or `fps`\"\n\n    # Shallow copy... 
since we might want to add some keys\n    video = dict(video)\n\n    contains_sampling_rules = \"nframes\" in video or \"fps\" in video\n    if not contains_sampling_rules:\n        if nframes is not None:\n            video[\"nframes\"] = nframes\n        elif fps is not None:\n            video[\"fps\"] = fps\n            if fps_min_frames is not None:\n                video[\"min_frames\"] = fps_min_frames\n            if fps_max_frames is not None:\n                video[\"max_frames\"] = fps_max_frames\n\n    return fetch_video(video)\n\n\ndef process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs):\n    # Adjust image bounds based on left padding and cumulative sequence lengths\n    # This is necessary for MiniCPM-o's vision-language alignment\n    left_padding_length = torch.argmax(attention_mask, dim=1)\n    image_bounds = []\n    for i in range(len(multi_modal_inputs[\"image_bound\"])):\n        image_bound = (\n            multi_modal_inputs[\"image_bound\"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i]\n        )\n        image_bounds.append(image_bound)\n\n    # Flatten pixel values list for MiniCPM-o processing\n    pixel_values = []\n    for i in range(len(multi_modal_inputs[\"pixel_values\"])):\n        pixel_values.extend([p for p in multi_modal_inputs[\"pixel_values\"][i]])\n\n    multi_modal_inputs[\"pixel_values\"] = [pixel_values]\n    multi_modal_inputs[\"image_bound\"] = [torch.vstack(image_bounds)]\n    multi_modal_inputs[\"tgt_sizes\"] = [torch.vstack(multi_modal_inputs[\"tgt_sizes\"])]\n    multi_modal_inputs[\"input_ids\"] = input_ids\n    multi_modal_inputs[\"attention_mask\"] = attention_mask\n    multi_modal_inputs[\"position_ids\"] = position_ids\n    return {\"data\": multi_modal_inputs}\n"
  },
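  {
    "path": "verl_rl/examples/sketches/vision_utils_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): shows the inputs accepted by\nvision_utils.process_image / process_video. The file paths are placeholders, and\nqwen_vl_utils must be installed for the underlying fetch calls.\n\"\"\"\n\nfrom verl.utils.dataset.vision_utils import process_image, process_video\n\nif __name__ == '__main__':\n    # Images may be a PIL.Image, an {'image': ...} dict, or a {'bytes': ...} dict.\n    img = process_image({'image': 'file:///path/to/frame1.jpg'})\n    print(img.size)\n\n    # Video dicts may carry their own sampling rules ('nframes' or 'fps').\n    frames = process_video({'video': 'file:///path/to/video.mp4', 'fps': 2})\n    print(frames.shape)  # [n_frames, 3, H, W]\n"
  },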
  {
    "path": "verl_rl/verl/utils/debug/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# APIs kept for backward compatibility purpose\n# For new features please develop in verl/utils/profiler/\nfrom ..profiler import *  # noqa\n"
  },
  {
    "path": "verl_rl/verl/utils/debug/performance.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# APIs kept for backward compatibility purpose\n# This file is deprecated, for new features please develop in profiler/performance.py\nfrom verl.utils.profiler.performance import simple_timer, reduce_timing  # noqa\n"
  },
  {
    "path": "verl_rl/verl/utils/debug/trajectory_tracker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTrajectory tracker can be inserted into code to save the intermediate results.\nThe results will be dump to hdfs for offline comparison.\nEach process will have a client that first move all the tensors to CPU\n\"\"\"\n\nimport io\nimport os\nimport tempfile\nfrom collections import deque\n\nimport ray\nimport torch\n\nfrom verl.utils.hdfs_io import copy, makedirs\n\nremote_copy = ray.remote(copy)\n\n\n@ray.remote\ndef save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):\n    filename = name + \".pth\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        local_filepath = os.path.join(tmpdirname, filename)\n        with open(local_filepath, \"wb\") as f:\n            f.write(data.getbuffer())\n        # upload to hdfs\n\n        if verbose:\n            print(f\"Saving {local_filepath} to {hdfs_dir}\")\n        try:\n            copy(local_filepath, hdfs_dir)\n        except Exception as e:\n            print(e)\n\n\n@ray.remote\nclass TrajectoryTracker:\n    def __init__(self, hdfs_dir, verbose) -> None:\n        self.hdfs_dir = hdfs_dir\n        makedirs(hdfs_dir)\n        self.verbose = verbose\n\n        self.handle = deque()\n\n    def dump(self, data: io.BytesIO, name):\n        # get a temp file and write to it\n        self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose))\n\n    def wait_for_hdfs(self):\n        while len(self.handle) != 0:\n            future = self.handle.popleft()\n            ray.get(future)\n\n\ndef dump_data(data, name):\n    enable = os.getenv(\"VERL_ENABLE_TRACKER\", \"0\") == \"1\"\n    if not enable:\n        return\n    buffer = io.BytesIO()\n    torch.save(data, buffer)\n    tracker = get_trajectory_tracker()\n    ray.get(tracker.dump.remote(buffer, name))\n\n\ndef get_trajectory_tracker():\n    hdfs_dir = os.getenv(\"VERL_TRACKER_HDFS_DIR\", default=None)\n    verbose = os.getenv(\"VERL_TRACKER_VERBOSE\", default=\"0\") == \"1\"\n    assert hdfs_dir is not None\n    tracker = TrajectoryTracker.options(name=\"global_tracker\", get_if_exists=True, lifetime=\"detached\").remote(\n        hdfs_dir, verbose\n    )\n    return tracker\n\n\nif __name__ == \"__main__\":\n    # testing\n    os.environ[\"VERL_ENABLE_TRACKER\"] = \"1\"\n    os.environ[\"VERL_TRACKER_HDFS_DIR\"] = \"~/debug/test\"\n\n    @ray.remote\n    def process(iter):\n        data = {\"obs\": torch.randn(10, 20)}\n        dump_data(data, f\"process_{iter}_obs\")\n\n    ray.init()\n\n    output_lst = []\n\n    for i in range(10):\n        output_lst.append(process.remote(i))\n\n    out = ray.get(output_lst)\n\n    tracker = get_trajectory_tracker()\n    ray.get(tracker.wait_for_hdfs.remote())\n"
  },
  {
    "path": "verl_rl/verl/utils/device.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# This code is inspired by the torchtune.\n# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE\n\nimport logging\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_torch_npu_available() -> bool:\n    \"\"\"Check the availability of NPU\"\"\"\n    try:\n        import torch_npu  # noqa: F401\n\n        return torch.npu.is_available()\n    except ImportError:\n        return False\n\n\nis_cuda_available = torch.cuda.is_available()\nis_npu_available = is_torch_npu_available()\n\n\ndef get_visible_devices_keyword() -> str:\n    \"\"\"Function that gets visible devices keyword name.\n    Returns:\n        'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES`\n    \"\"\"\n    return \"CUDA_VISIBLE_DEVICES\" if is_cuda_available else \"ASCEND_RT_VISIBLE_DEVICES\"\n\n\ndef get_device_name() -> str:\n    \"\"\"Function that gets the torch.device based on the current machine.\n    This currently only supports CPU, CUDA, NPU.\n    Returns:\n        device\n    \"\"\"\n    if is_cuda_available:\n        device = \"cuda\"\n    elif is_npu_available:\n        device = \"npu\"\n    else:\n        device = \"cpu\"\n    return device\n\n\ndef get_torch_device() -> any:\n    \"\"\"Return the corresponding torch attribute based on the device type string.\n    Returns:\n        module: The corresponding torch device namespace, or torch.cuda if not found.\n    \"\"\"\n    device_name = get_device_name()\n    try:\n        return getattr(torch, device_name)\n    except AttributeError:\n        logger.warning(f\"Device namespace '{device_name}' not found in torch, try to load torch.cuda.\")\n        return torch.cuda\n\n\ndef get_device_id() -> int:\n    \"\"\"Return current device id based on the device type.\n    Returns:\n        device index\n    \"\"\"\n    return get_torch_device().current_device()\n\n\ndef get_nccl_backend() -> str:\n    \"\"\"Return nccl backend type based on the device type.\n    Returns:\n        nccl backend type string.\n    \"\"\"\n    if is_cuda_available:\n        return \"nccl\"\n    elif is_npu_available:\n        return \"hccl\"\n    else:\n        raise RuntimeError(f\"No available nccl backend found on device type {get_device_name()}.\")\n"
  },
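  {
    "path": "verl_rl/examples/sketches/device_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): device-agnostic usage of the helpers\nin utils/device.py. On CUDA machines these resolve to torch.cuda and 'nccl'; on\nAscend NPU machines to torch.npu and 'hccl'.\n\"\"\"\n\nimport torch\n\nfrom verl.utils.device import get_device_name, get_nccl_backend, get_torch_device\n\nif __name__ == '__main__':\n    device = get_device_name()  # 'cuda', 'npu', or 'cpu'\n    x = torch.ones(2, 2, device=device)\n    print(x.device)\n    if device != 'cpu':\n        print(get_torch_device().device_count(), get_nccl_backend())\n"
  },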
  {
    "path": "verl_rl/verl/utils/distributed.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Utilities for distributed training.\"\"\"\n\nimport os\n\nimport torch.distributed\n\nfrom verl.utils.device import get_nccl_backend, get_torch_device\n\n\ndef initialize_global_process_group(timeout_second=36000):\n    from datetime import timedelta\n\n    torch.distributed.init_process_group(\n        get_nccl_backend(),\n        timeout=timedelta(seconds=timeout_second),\n        init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n    )\n    local_rank = int(os.environ[\"LOCAL_RANK\"])\n    rank = int(os.environ[\"RANK\"])\n    world_size = int(os.environ[\"WORLD_SIZE\"])\n\n    if torch.distributed.is_initialized():\n        get_torch_device().set_device(local_rank)\n    return local_rank, rank, world_size\n\n\ndef destroy_global_process_group():\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n"
  },
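  {
    "path": "verl_rl/examples/sketches/distributed_demo.py",
    "content": "\"\"\"Illustrative sketch (not part of verl): a minimal torchrun entry point for\nutils/distributed.py, assuming CUDA. Launch with e.g.\n`torchrun --nproc_per_node=2 distributed_demo.py`; torchrun sets the RANK,\nLOCAL_RANK, and WORLD_SIZE variables the helper reads.\n\"\"\"\n\nimport torch\n\nfrom verl.utils.distributed import destroy_global_process_group, initialize_global_process_group\n\nif __name__ == '__main__':\n    local_rank, rank, world_size = initialize_global_process_group()\n    t = torch.ones(1, device=f'cuda:{local_rank}') * rank\n    torch.distributed.all_reduce(t)  # sums the rank values across the world\n    if rank == 0:\n        print(t.item(), world_size)\n    destroy_global_process_group()\n"
  },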
  {
    "path": "verl_rl/verl/utils/experimental/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/utils/experimental/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\n\ndef _fused_linear_for_ppo_fwd(\n    hidden_states: torch.FloatTensor,\n    vocab_weights: torch.FloatTensor,\n    input_ids: torch.LongTensor,\n    temperature: float = 1.0,\n) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    # Slower but more numerically stable to do log_softmax than probs.log()\n    probs = logits.softmax(dim=-1)\n    log_probs = logits.log_softmax(dim=-1)\n\n    token_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n\n    return token_log_probs.to(orig_dtype), entropy.to(orig_dtype)\n\n\ndef _fused_linear_for_ppo_bwd(\n    dlog_probs: Optional[torch.FloatTensor],\n    dentropy: Optional[torch.FloatTensor],\n    hidden_states: torch.FloatTensor,\n    vocab_weights: torch.FloatTensor,\n    input_ids: torch.LongTensor,\n    temperature: float = 1.0,\n) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    probs = logits.softmax(dim=-1)\n\n    dlogits = 0\n\n    # Gradient from log_probs\n    if dlog_probs is not None:\n        one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1)\n        dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs)\n\n    # Gradient from entropy\n    if dentropy is not None:\n        log_probs = logits.log_softmax(dim=-1)\n        entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n        dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1))\n\n    dlogits = dlogits.to(orig_dtype) / temperature\n\n    dhidden_states = dlogits @ vocab_weights\n    dvocab_weights = dlogits.t() @ hidden_states\n\n    return dhidden_states, dvocab_weights\n\n\nclass FusedLinearForPPOFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n        chunk_size: int = 512,\n    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n        ctx.set_materialize_grads(False)\n\n        # Cast to a 2D tensor of the shape [T, D] for ease of working\n        orig_ndim = hidden_states.ndim\n        assert orig_ndim in (2, 3), f\"Invalid hidden_states shape, received {hidden_states.shape}\"\n\n        orig_batch_size = -1\n        if orig_ndim == 3:\n            assert input_ids.ndim == 2, f\"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}\"\n            orig_batch_size = hidden_states.shape[0]\n            hidden_states = 
hidden_states.flatten(0, 1)\n            input_ids = input_ids.flatten(0, 1)\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad\n        log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n        entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n\n        # Perform forward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n\n            chunk_log_probs, chunk_entropy = _fused_linear_for_ppo_fwd(\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n            log_probs[chunk_start:chunk_end] = chunk_log_probs\n            entropy[chunk_start:chunk_end] = chunk_entropy\n\n        # Cast the output back to the original input dimension\n        if orig_ndim == 3:\n            log_probs = log_probs.view(orig_batch_size, -1)\n            entropy = entropy.view(orig_batch_size, -1)\n\n        ctx.save_for_backward(hidden_states, vocab_weights, input_ids)\n        ctx.orig_batch_size = orig_batch_size\n        ctx.orig_ndim = orig_ndim\n        ctx.temperature = temperature\n        ctx.chunk_size = chunk_size\n\n        return log_probs, entropy\n\n    @staticmethod\n    def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]):\n        assert dlog_probs is not None or dentropy is not None\n\n        hidden_states, vocab_weights, input_ids = ctx.saved_tensors\n        orig_batch_size = ctx.orig_batch_size\n        orig_ndim = ctx.orig_ndim\n        temperature = ctx.temperature\n        chunk_size = ctx.chunk_size\n\n        # Here orig_ndim refers to the orig_ndim of hidden_states\n        if orig_ndim == 3:\n            if dlog_probs is not None:\n                dlog_probs = dlog_probs.flatten()\n            if dentropy is not None:\n                dentropy = dentropy.flatten()\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        dhidden_states = None\n        if hidden_states.requires_grad:\n            dhidden_states = torch.zeros_like(hidden_states)\n        dvocab_weights = None\n        if vocab_weights.requires_grad:\n            dvocab_weights = torch.zeros_like(vocab_weights)\n\n        # Perform backward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n            chunk_dlog_probs = None\n            if dlog_probs is not None:\n                chunk_dlog_probs = dlog_probs[chunk_start:chunk_end]\n            chunk_dentropy = None\n            if dentropy is not None:\n                chunk_dentropy = dentropy[chunk_start:chunk_end]\n\n            h, v = _fused_linear_for_ppo_bwd(\n                dlog_probs=chunk_dlog_probs,\n                dentropy=chunk_dentropy,\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n\n            if hidden_states.requires_grad:\n                dhidden_states[chunk_start:chunk_end] += h\n            if vocab_weights.requires_grad:\n                dvocab_weights += v\n\n        # Cast the 
output back to the original input dimension\n        if orig_ndim == 3 and hidden_states.requires_grad:\n            hidden_size = hidden_states.shape[-1]\n            dhidden_states = dhidden_states.view(orig_batch_size, -1, hidden_size)\n\n        return (\n            dhidden_states,  # hidden_states\n            dvocab_weights,  # vocab_weights\n            None,  # input_ids\n            None,  # temperature\n            None,  # chunk_size\n        )\n\n\nclass FusedLinearForPPO(torch.nn.Module):\n    def __init__(self, chunk_size: int = 512):\n        super().__init__()\n\n        self.chunk_size = chunk_size\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:\n        input_ids = input_ids.to(torch.int64)\n        return FusedLinearForPPOFunction.apply(\n            hidden_states,\n            vocab_weights,\n            input_ids,\n            temperature,\n            self.chunk_size,\n        )\n"
  },
  {
    "path": "verl_rl/verl/utils/flops_counter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import PretrainedConfig\n\nfrom verl.utils.device import get_torch_device\n\nVALID_CONFIG_TYPE = {\n    \"llama\",\n    \"qwen2\",\n    \"qwen2_vl\",\n    \"qwen2_5_vl\",\n    \"qwen3\",\n    \"qwen3_moe\",\n    \"deepseek_v3\",\n    \"minicpmv\",\n    \"minicpmo\",\n    \"mistral\",\n    \"gemma3_text\",\n}\n\n\ndef get_device_flops(unit=\"T\"):\n    def unit_convert(number, level):\n        units = [\"B\", \"K\", \"M\", \"G\", \"T\", \"P\"]\n        if number <= 0:\n            return number\n        ptr = 0\n        while ptr < len(units) and units[ptr] != level:\n            number /= 1000\n            ptr += 1\n        return number\n\n    device_name = get_torch_device().get_device_name()\n    flops = float(\"inf\")  # INF flops for unkown gpu type\n\n    if \"MI300X\" in device_name:\n        flops = 1336e12\n    elif \"H100\" in device_name or \"H800\" in device_name or \"H200\" in device_name:\n        flops = 989e12\n    elif \"A100\" in device_name or \"A800\" in device_name:\n        flops = 312e12\n    elif \"L40\" in device_name:\n        flops = 181.05e12\n    elif \"L20\" in device_name:\n        flops = 119.5e12\n    elif \"H20\" in device_name:\n        flops = 148e12\n    elif \"910B\" in device_name:\n        flops = 354e12\n    elif \"RTX 3070 Ti\" in device_name:\n        flops = 21.75e12\n    flops_unit = unit_convert(flops, unit)\n    return flops_unit\n\n\nclass FlopsCounter:\n    \"\"\"\n    Used to count mfu during training loop\n\n    Example:\n        flops_counter = FlopsCounter(config)\n        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)\n\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig):\n        if config.model_type not in VALID_CONFIG_TYPE:\n            print(\n                f\"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. 
 MFU will always be \"\n                f\"zero.\"\n            )\n\n        self.estimate_func = {\n            \"qwen2\": self._estimate_qwen2_flops,\n            \"llama\": self._estimate_qwen2_flops,\n            \"qwen2_moe\": self._estimate_qwen2_moe_flops,\n            \"qwen2_vl\": self._estimate_qwen2_flops,\n            \"qwen2_5_vl\": self._estimate_qwen2_flops,\n            \"qwen3\": self._estimate_qwen2_flops,\n            \"qwen3_moe\": self._estimate_qwen2_moe_flops,\n            \"deepseek_v3\": self._estimate_deepseek_v3_flops,\n            \"minicpmv\": self._estimate_qwen2_flops,\n            \"minicpmo\": self._estimate_qwen2_flops,\n            \"mistral\": self._estimate_qwen2_flops,\n            \"gemma3_text\": self._estimate_gemma3_flops,\n        }\n        self.config = config\n\n    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):\n        return 0\n\n    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer param\n        # Qwen2/Llama use SwiGLU, having gate, up and down linear layers in the MLP\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer param\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        moe_intermediate_size = self.config.moe_intermediate_size\n        num_hidden_layers = self.config.num_hidden_layers\n        first_k_dense_replace = self.config.first_k_dense_replace\n        num_query_heads = self.config.num_attention_heads\n        moe_num_expert = self.config.n_routed_experts\n\n        moe_topk = self.config.num_experts_per_tok\n        share_expert_num = self.config.n_shared_experts\n\n        # non-attn per layer param\n        moe_gate_N = hidden_size * moe_num_expert\n        # MoE has fc1_1, fc1_2 and fc2 using SwiGLU in the ExpertMlp layer & shared experts\n        moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3\n        # MLA attn\n        attn_linear_N = 0\n
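        # MLA factorizes the Q/K/V projections through low-rank bottlenecks (q_lora_rank / kv_lora_rank),\n        # so the parameter count below follows the DeepSeek-V3 multi-head latent attention layout.\n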
        q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim\n        if self.config.q_lora_rank is None:\n            attn_linear_N += hidden_size * num_query_heads * q_head_dim\n        else:\n            attn_linear_N += hidden_size * self.config.q_lora_rank\n            attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank\n\n        attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim)\n        attn_linear_N += (\n            num_query_heads\n            * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim)\n            * self.config.kv_lora_rank\n        )\n        attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer param\n        moe_N = (\n            (moe_gate_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace)\n            + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace\n            + emd_and_lm_head_N\n        )\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * moe_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen * num_hidden_layers\n\n        attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n\n        return flops_achieved\n\n    def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        moe_intermediate_size = self.config.moe_intermediate_size\n        moe_topk = self.config.num_experts_per_tok\n        num_experts = self.config.num_experts\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer param\n        # gate + moe expert\n        moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer param\n        dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def
 _estimate_gemma3_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(self.config, \"head_dim\", self.config.hidden_size // self.config.num_attention_heads)\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer param\n        # Gemma3 uses GeGLU (gelu_pytorch_tanh), having 3 matrices in MLP (inherited from Gemma2MLP)\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer param\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        # Gemma3 alternates between full and sliding window attention based on layer_types\n        seqlen_square_sum = 0\n\n        layer_types = getattr(self.config, \"layer_types\", None)\n        sliding_window = getattr(self.config, \"sliding_window\", 1024)  # default 1024\n        # default pattern: every 6th layer is full\n        sliding_window_pattern = getattr(self.config, \"sliding_window_pattern\", 6)\n\n        # If layer_types is not provided, generate it based on sliding_window_pattern\n        if layer_types is None and sliding_window is not None and sliding_window_pattern is not None:\n            layer_types = [\n                \"sliding_attention\" if bool((i + 1) % sliding_window_pattern) else \"full_attention\"\n                for i in range(num_hidden_layers)\n            ]\n\n        if layer_types:\n            # Calculate attention flops per layer based on attention type\n            for layer_idx in range(num_hidden_layers):\n                is_sliding = False\n                if layer_types and layer_idx < len(layer_types):\n                    is_sliding = layer_types[layer_idx] == \"sliding_attention\"\n\n                for seqlen in batch_seqlens:\n                    if is_sliding and sliding_window:\n                        # Sliding window limits each token to attend to at most window_size tokens\n                        effective_seqlen = min(seqlen, sliding_window)\n                        seqlen_square_sum += seqlen * effective_seqlen\n                    else:\n                        # Full attention\n                        seqlen_square_sum += seqlen * seqlen\n        else:\n            # If no layer_types config, assume all layers use full attention\n            for seqlen in batch_seqlens:\n                seqlen_square_sum += seqlen * seqlen\n            seqlen_square_sum *= num_hidden_layers\n\n        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def estimate_flops(self, batch_seqlens, delta_time):\n        \"\"\"\n
        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.\n\n        Args:\n            batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the\n                current batch.\n            delta_time (float): The time taken to process the batch, in seconds.\n\n        Returns:\n            estimated_flops (float): The estimated FLOPS based on the input tokens and time.\n            promised_flops (float): The expected FLOPS of the current device.\n        \"\"\"\n        tokens_sum = sum(batch_seqlens)\n        func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)\n        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)\n        promised_flops = get_device_flops()\n        return estimated_flops, promised_flops\n"
  },
  {
    "path": "verl_rl/verl/utils/fs.py",
    "content": "#!/usr/bin/env python\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# -*- coding: utf-8 -*-\n\"\"\"File-system agnostic IO APIs\"\"\"\n\nimport hashlib\nimport os\nimport shutil\nimport tempfile\n\ntry:\n    from hdfs_io import copy, exists, makedirs  # for internal use only\nexcept ImportError:\n    from .hdfs_io import copy, exists, makedirs\n\n__all__ = [\"copy\", \"exists\", \"makedirs\"]\n\n_HDFS_PREFIX = \"hdfs://\"\n\n\ndef is_non_local(path):\n    \"\"\"Check if a path is a non-local (HDFS) path.\n\n    Args:\n        path (str): The path to check.\n\n    Returns:\n        bool: True if the path is an HDFS path, False otherwise.\n    \"\"\"\n    return path.startswith(_HDFS_PREFIX)\n\n\ndef md5_encode(path: str) -> str:\n    \"\"\"Generate an MD5 hash of a path string.\n\n    This function is used to create unique identifiers for paths, typically\n    for creating cache directories or lock files.\n\n    Args:\n        path (str): The path to encode.\n\n    Returns:\n        str: The hexadecimal MD5 hash of the path.\n    \"\"\"\n    return hashlib.md5(path.encode()).hexdigest()\n\n\ndef get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:\n    \"\"\"Generate a unique local cache path for an HDFS resource.\n    Creates a MD5-hashed subdirectory in cache_dir to avoid name conflicts,\n    then returns path combining this subdirectory with the HDFS basename.\n\n    Args:\n        hdfs_path (str): Source HDFS path to be cached\n        cache_dir (str): Local directory for storing cached files\n\n    Returns:\n        str: Absolute local filesystem path in format:\n            {cache_dir}/{md5(hdfs_path)}/{basename(hdfs_path)}\n    \"\"\"\n    # make a base64 encoding of hdfs_path to avoid directory conflict\n    encoded_hdfs_path = md5_encode(hdfs_path)\n    temp_dir = os.path.join(cache_dir, encoded_hdfs_path)\n    os.makedirs(temp_dir, exist_ok=True)\n    dst = os.path.join(temp_dir, os.path.basename(hdfs_path))\n    return dst\n\n\ndef verify_copy(src: str, dest: str) -> bool:\n    \"\"\"\n    verify the copy of src to dest by comparing their sizes and file structures.\n\n    return:\n        bool: True if the copy is verified, False otherwise.\n    \"\"\"\n    if not os.path.exists(src):\n        return False\n    if not os.path.exists(dest):\n        return False\n\n    if os.path.isfile(src) != os.path.isfile(dest):\n        return False\n\n    if os.path.isfile(src):\n        src_size = os.path.getsize(src)\n        dest_size = os.path.getsize(dest)\n        if src_size != dest_size:\n            return False\n        return True\n\n    src_files = set()\n    dest_files = set()\n\n    for root, dirs, files in os.walk(src):\n        rel_path = os.path.relpath(root, src)\n        dest_root = os.path.join(dest, rel_path) if rel_path != \".\" else dest\n\n        if not os.path.exists(dest_root):\n            return False\n\n        for entry in os.listdir(root):\n            src_entry = 
 os.path.join(root, entry)\n            src_files.add(os.path.relpath(src_entry, src))\n\n        for entry in os.listdir(dest_root):\n            dest_entry = os.path.join(dest_root, entry)\n            dest_files.add(os.path.relpath(dest_entry, dest))\n\n    if src_files != dest_files:\n        return False\n\n    for rel_path in src_files:\n        src_entry = os.path.join(src, rel_path)\n        dest_entry = os.path.join(dest, rel_path)\n\n        if os.path.isdir(src_entry) != os.path.isdir(dest_entry):\n            return False\n\n        if os.path.isfile(src_entry):\n            src_size = os.path.getsize(src_entry)\n            dest_size = os.path.getsize(dest_entry)\n            if src_size != dest_size:\n                return False\n\n    return True\n\n\ndef copy_to_shm(src: str):\n    \"\"\"\n    Load the model into /dev/shm to make loading the model multiple times more efficient.\n    \"\"\"\n    shm_model_root = \"/dev/shm/verl-cache/\"\n    src_abs = os.path.abspath(os.path.normpath(src))\n    dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode(\"utf-8\")).hexdigest())\n    os.makedirs(dest, exist_ok=True)\n    dest = os.path.join(dest, os.path.basename(src_abs))\n    if os.path.exists(dest) and verify_copy(src, dest):\n        # inform the user and let them decide\n        print(\n            f\"[WARNING]: The memory model path {dest} already exists. If it is not what you want, please clear it and \"\n            f\"restart the task.\"\n        )\n    else:\n        if os.path.isdir(src):\n            shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True)\n        else:\n            shutil.copy2(src, dest)\n    return dest\n\n\ndef _record_directory_structure(folder_path):\n    record_file = os.path.join(folder_path, \".directory_record.txt\")\n    with open(record_file, \"w\") as f:\n        for root, dirs, files in os.walk(folder_path):\n            for dir_name in dirs:\n                relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n                f.write(f\"dir:{relative_dir}\\n\")\n            for file_name in files:\n                if file_name != \".directory_record.txt\":\n                    relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                    f.write(f\"file:{relative_file}\\n\")\n    return record_file\n\n\ndef _check_directory_structure(folder_path, record_file):\n    if not os.path.exists(record_file):\n        return False\n    existing_entries = set()\n    for root, dirs, files in os.walk(folder_path):\n        for dir_name in dirs:\n            relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n            existing_entries.add(f\"dir:{relative_dir}\")\n        for file_name in files:\n            if file_name != \".directory_record.txt\":\n                relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                existing_entries.add(f\"file:{relative_file}\")\n    with open(record_file) as f:\n        recorded_entries = set(f.read().splitlines())\n    return existing_entries == recorded_entries\n\n\ndef copy_to_local(\n    src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False, use_shm: bool = False\n) -> str:\n    \"\"\"Copy files/directories from HDFS to local cache with validation.\n\n    Args:\n        src (str): Source path - HDFS path (hdfs://...) or local filesystem path\n        cache_dir (str, optional): Local directory for cached files.
 Uses system tempdir if None\n        filelock (str): Base name for file lock. Defaults to \".file.lock\"\n        verbose (bool): Enable copy operation logging. Defaults to False\n        always_recopy (bool): Force fresh copy ignoring cache. Defaults to False\n        use_shm (bool): Enable shared memory copy. Defaults to False\n\n    Returns:\n        str: Local filesystem path to copied resource\n    \"\"\"\n    # Save to a local path for persistence.\n    local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy)\n    # Load into shm to improve efficiency.\n    if use_shm:\n        return copy_to_shm(local_path)\n    return local_path\n\n\ndef copy_local_path_from_hdfs(\n    src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False\n) -> str:\n    \"\"\"Deprecated. Please use copy_to_local instead.\"\"\"\n    from filelock import FileLock\n\n    assert src[-1] != \"/\", f\"Make sure the last char in src is not / because it will cause an error. Got {src}\"\n\n    if is_non_local(src):\n        # download from hdfs to local\n        if cache_dir is None:\n            # get a temp folder\n            cache_dir = tempfile.gettempdir()\n        os.makedirs(cache_dir, exist_ok=True)\n        assert os.path.exists(cache_dir)\n        local_path = get_local_temp_path(src, cache_dir)\n        # get a specific lock\n        filelock = md5_encode(src) + \".lock\"\n        lock_file = os.path.join(cache_dir, filelock)\n        with FileLock(lock_file=lock_file):\n            if always_recopy and os.path.exists(local_path):\n                if os.path.isdir(local_path):\n                    shutil.rmtree(local_path, ignore_errors=True)\n                else:\n                    os.remove(local_path)\n            if not os.path.exists(local_path):\n                if verbose:\n                    print(f\"Copy from {src} to {local_path}\")\n                copy(src, local_path)\n                if os.path.isdir(local_path):\n                    _record_directory_structure(local_path)\n            elif os.path.isdir(local_path):\n                # always_recopy=False, local path exists, and it is a folder: check whether anything is missing\n                record_file = os.path.join(local_path, \".directory_record.txt\")\n                if not _check_directory_structure(local_path, record_file):\n                    if verbose:\n                        print(f\"Recopy from {src} to {local_path} due to missing files or directories.\")\n                    shutil.rmtree(local_path, ignore_errors=True)\n                    copy(src, local_path)\n                    _record_directory_structure(local_path)\n        return local_path\n    else:\n        return src\n\n\ndef local_mkdir_safe(path):\n    \"\"\"Thread-safe directory creation function that ensures the directory is created\n    even if multiple processes attempt to create it simultaneously.\n\n    Args:\n        path (str): The path to create a directory at.\n    \"\"\"\n\n    from filelock import FileLock\n\n    if not os.path.isabs(path):\n        working_dir = os.getcwd()\n        path = os.path.join(working_dir, path)\n\n    # Use a hash of the path as the lock file name to avoid overly long file names\n    lock_filename = f\"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock\"\n    lock_path = os.path.join(tempfile.gettempdir(), lock_filename)\n\n    try:\n        with FileLock(lock_path, timeout=60):  # Add timeout\n            # make a new dir\n            os.makedirs(path,
 exist_ok=True)\n    except Exception as e:\n        print(f\"Warning: Failed to acquire lock for {path}: {e}\")\n        # Even if the lock is not acquired, try to create the directory\n        os.makedirs(path, exist_ok=True)\n\n    return path\n"
  },
  {
    "path": "verl_rl/verl/utils/fsdp_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport itertools\nimport json\nimport math\nimport os\nfrom collections import OrderedDict\nfrom contextlib import contextmanager, nullcontext\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom packaging import version\nfrom torch.distributed import DeviceMesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp._runtime_utils import _lazy_init\nfrom torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\nfrom transformers.trainer_pt_utils import get_module_class_from_name\n\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\n\nif version.parse(torch.__version__) >= version.parse(\"2.6\"):\n    from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\nelif version.parse(torch.__version__) >= version.parse(\"2.4\"):\n    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\nelse:\n    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None\n\n\ndef init_fn(x: torch.nn.Module):\n    if torch.distributed.get_rank() != 0:\n        x = x.to_empty(device=get_device_id(), recurse=False)\n        get_torch_device().empty_cache()\n    return x\n\n\ndef get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):\n    from accelerate import init_empty_weights\n\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    if use_meta_tensor:\n        if mesh is None:\n            init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights\n        else:\n            init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights\n    else:\n        init_context = cpu_init_weights\n    return init_context\n\n\n# Copyright 2020-present the HuggingFace Inc. team.\n# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py\ndef get_fsdp_wrap_policy(module, config=None, is_lora=False):\n    \"\"\"Get FSDP wrap policy for the module.\n\n    Args:\n        module: The module to get wrap policy for\n        config: Configuration for wrap policy\n        is_lora: Whether to enable lambda policy for LoRA modules\n    \"\"\"\n    if config is None:\n        config = {}\n\n    # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. 
We will remove this\n    # once we have migrated all configs in verl from OmegaConf to dataclasses.\n    def _get_attr(attr_name, default_value=None):\n        if hasattr(config, \"get\"):\n            return config.get(attr_name, default_value)\n        else:\n            return config.__getattribute__(attr_name)\n\n    if _get_attr(\"disable\", False):\n        return None\n\n    default_transformer_cls_names_to_wrap = getattr(module, \"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = _get_attr(\n        \"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap\n    )\n    min_num_params = _get_attr(\"min_num_params\", 0)\n    auto_wrap_policy = None\n\n    policies = []\n\n    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy\n\n    # Add lambda policy for LoRA modules if is_lora is True\n    if is_lora:\n\n        def lambda_policy_fn(module):\n            return bool(\n                len(list(module.named_children())) == 0\n                and getattr(module, \"weight\", None) is not None\n                and module.weight.requires_grad\n            )\n\n        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)\n        policies.append(lambda_policy)\n\n    if min_num_params > 0:\n        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)\n        policies.append(size_policy)\n    elif fsdp_transformer_layer_cls_to_wrap is not None:\n        transformer_cls_to_wrap = set()\n        for layer_class in fsdp_transformer_layer_cls_to_wrap:\n            transformer_cls = get_module_class_from_name(module, layer_class)\n            if transformer_cls is None:\n                raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n            else:\n                transformer_cls_to_wrap.add(transformer_cls)\n\n        transformer_policy = functools.partial(\n            transformer_auto_wrap_policy,\n            transformer_layer_cls=transformer_cls_to_wrap,\n        )\n        policies.append(transformer_policy)\n\n    if len(policies) > 0:\n        auto_wrap_policy = functools.partial(_or_policy, policies=policies)\n\n    return auto_wrap_policy\n\n\n@torch.no_grad()\ndef offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):\n    if fsdp_version(model) == 2:\n        offload_fsdp2_model_to_cpu(model, empty_cache)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model offloading to CPU\"\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        assert (\n            flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()\n            and id(flat_param.data) != id(flat_param._local_shard)\n            and flat_param.data.size() == flat_param._local_shard.size()\n        )\n        handle.flat_param_to(torch.device(\"cpu\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n        assert id(flat_param._local_shard) != id(flat_param.data)\n    if empty_cache:\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):\n    for param in model.parameters():\n        param.data = param.data.to(torch.device(\"cpu\"), non_blocking=True)\n    if empty_cache:\n
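        # NOTE: empty_cache() returns cached allocator blocks to the device so that\n        # other consumers (e.g. a colocated rollout engine) can reuse the freed memory.\n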
        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_fsdp_model_to_gpu(model: FSDP):\n    if fsdp_version(model) == 2:\n        load_fsdp2_model_to_gpu(model)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model loading to GPU\"\n    device_id = get_device_id()\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        handle.flat_param_to(torch.device(f\"{get_device_name()}:{device_id}\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n\n\n@torch.no_grad()\ndef load_fsdp2_model_to_gpu(model):\n    device = get_device_id()\n    for param in model.parameters():\n        param.data = param.data.to(device, non_blocking=True)\n\n\n@torch.no_grad()\ndef offload_fsdp_optimizer(optimizer):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(\"cpu\", non_blocking=True)\n\n\n@torch.no_grad()\ndef load_fsdp_optimizer(optimizer, device_id):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(device_id, non_blocking=True)\n\n\n@contextmanager\ndef meta_device_init():\n    \"\"\"\n    Create model parameters with meta device.\n\n    Note: buffers in the model will still be initialized on the default device (e.g., CPU),\n    since the buffers can be non-persistent and filled with expected values that can\n    NOT be captured on the meta device.\n    \"\"\"\n    device = torch.device(\"meta\")\n    old_register_parameter = nn.Module.register_parameter\n    registered = set()\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        # we will skip registering shared parameters as they\n        # are already registered previously\n        if param is not None and param not in registered:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n            registered.add(module._parameters[name])\n\n    try:\n        nn.Module.register_parameter = register_empty_parameter\n        yield\n    finally:\n        registered.clear()\n        nn.Module.register_parameter = old_register_parameter\n\n\ndef parallel_load_safetensors(filepath):\n    \"\"\"\n    Parallel load safetensors from huggingface checkpoint\n\n    Huggingface checkpoint contains:\n\n    - config.json: a json file for model configuration\n    - model.safetensors.index.json: a json file for safetensors (parameters & buffers) index\n    - model-000x-of-00xx.safetensors: a binary file for safetensors (parameters & buffers) chunks\n\n    Or (when model is small),\n\n    - model.safetensors: a binary file for all parameters and buffers\n\n    Each rank will
 own a part of model chunks and load them directly into GPU memory.\n    \"\"\"\n    from safetensors.torch import load_file\n\n    safetensors2param = {}\n\n    index_file = os.path.join(filepath, \"model.safetensors.index.json\")\n    if os.path.exists(index_file):\n        with open(index_file, \"rb\") as f:\n            index = json.load(f)\n        for param_name, filename in index[\"weight_map\"].items():\n            safetensors2param.setdefault(filename, []).append(param_name)\n    else:\n        # in this case, the model is small and we can load it all at once\n        param_file = os.path.join(filepath, \"model.safetensors\")\n        assert os.path.exists(param_file), f\"Cannot find {param_file}\"\n        states = load_file(param_file)\n        for param_name in states:\n            safetensors2param.setdefault(\"model.safetensors\", []).append(param_name)\n        del states\n\n    total_files = len(safetensors2param)\n    ckpt_chunks = sorted(safetensors2param.keys())\n    world_size = dist.get_world_size()\n    size = int(math.ceil(total_files / world_size))\n    ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)]\n\n    shard_states = {}\n    device = get_device_id()\n    for rank, files in enumerate(ckpt_chunks):\n        if rank == dist.get_rank():\n            for file in files:\n                file = os.path.join(filepath, file)\n                states = load_file(file, device=device)\n                # print(f\"rank {rank} loading {file}...\")\n                shard_states.update(states)\n        else:\n            for file in files:\n                for param_name in safetensors2param[file]:\n                    shard_states[param_name] = rank\n    return shard_states\n\n\ndef parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]):\n    \"\"\"\n    Generate a function to initialize sub-modules in the `module` with `shard_states`\n    from huggingface checkpoint.\n\n    Args:\n        module (torch.nn.Module): the global module to be initialized\n        shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint\n\n    Returns:\n        init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`\n    \"\"\"\n\n    state2fqn = {}\n    for name, state in itertools.chain(\n        module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)\n    ):\n        state2fqn.setdefault(state, []).append(name)\n    # collect shared parameters and buffers (states registered under more than one name)\n    shared = {s for s, names in state2fqn.items() if len(names) > 1}\n    materialized_states = {}\n\n    @torch.no_grad()\n    def create_and_sync_state(param_name, state, is_param):\n        assert param_name in shard_states, f\"{param_name} not loaded\"\n        device = get_device_id()\n        if is_param:\n            param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)\n        else:  # buffer\n            param = torch.empty_like(state.data, device=device)\n        loaded = shard_states[param_name]\n        if isinstance(loaded, torch.nn.Parameter | torch.Tensor):\n            # NOTE: loaded.dtype can be different from param.dtype\n            param.data.copy_(loaded.data)\n            dist.broadcast(param.data, src=dist.get_rank())\n        else:\n            assert isinstance(loaded, int)  # the rank that holds the state\n            dist.broadcast(param.data, src=loaded)\n        shard_states.pop(param_name)\n
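        # drop the local reference so the source tensor can be garbage-collected promptly\n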
        del loaded\n        return param\n\n    def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):\n        param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))\n        # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])\n        for name, state in param_and_buffers:\n            if not state.is_meta:\n                continue\n            is_param = name in sub_mod._parameters\n            fqn = state2fqn[state].pop(0)\n            # non-persistent buffers will not be saved in the state dict, so we can safely skip them\n            if (not is_param) and fqn not in shard_states:\n                if state.is_meta:\n                    raise RuntimeError(\n                        f\"found a non-persistent buffer ({fqn}) initialized on the meta device. Such buffers are not saved \"\n                        f\"in the checkpoint, so the user must ensure they are initialized on a CPU / GPU device.\"\n                    )\n                continue\n            # for shared parameters, reuse the state materialized the first time it was encountered\n            if state in shared:\n                if state not in materialized_states:\n                    materialized_states[state] = create_and_sync_state(fqn, state, is_param)\n                else:\n                    if fqn in shard_states:\n                        shard_states.pop(fqn)\n                materialize_state = materialized_states[state]\n            # for non-shared parameters, create the state directly\n            else:\n                materialize_state = create_and_sync_state(fqn, state, is_param)\n            if is_param:\n                sub_mod._parameters[name] = materialize_state\n            else:\n                sub_mod._buffers[name] = materialize_state\n        if recurse:\n            for child in sub_mod.children():\n                init_fn(child, recurse=True)\n\n        # for debug\n        # if len(shard_states) == 0: print(\"clear\")\n        return sub_mod\n\n    return init_fn\n\n\ndef fsdp_version(model):\n    if isinstance(model, FSDP):\n        return 1\n    elif isinstance(model, FSDPModule):\n        return 2\n    else:\n        return 0\n\n\ndef get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):\n    if fsdp_version(model) == 1:\n        return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)\n    else:\n        return nullcontext()\n\n\ndef get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True):\n    \"\"\"\n    Get the full state dict from an FSDP model.\n\n    Args:\n        model (torch.nn.Module): The FSDP model to get state dict from\n        offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True.\n        rank0_only (bool, optional): Whether to only get state dict on rank 0.
 Defaults to True.\n\n    Returns:\n        dict: The full state dict of the model\n\n    Raises:\n        NotImplementedError: If the FSDP version is unknown\n    \"\"\"\n    if fsdp_version(model) == 1:\n        from torch.distributed.fsdp import FullStateDictConfig, StateDictType\n\n        state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)\n        with get_fsdp_state_ctx(\n            model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None\n        ):\n            state_dict = model.state_dict()\n        return state_dict\n    elif fsdp_version(model) == 2:\n        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict\n\n        state_dict_config = StateDictOptions(\n            full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only\n        )\n        state_dict = get_model_state_dict(model, options=state_dict_config)\n        return state_dict\n    else:\n        raise NotImplementedError(f\"Unknown FSDP version {fsdp_version(model)}\")\n\n\ndef fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None):\n    \"\"\"\n    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the\n    parameters from rank 0 to all other ranks. This function modifies the model in-place.\n\n    Args:\n        model (`torch.nn.Module`): The model to load the state dict into\n        full_state (`dict`): The full state dict to load, can only be on rank 0\n    \"\"\"\n\n    if version.parse(torch.__version__) >= version.parse(\"2.7.0\"):\n        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n    else:\n        # official torch 2.6.0 set_model_state_dict API leads to OOM\n        # use torch 2.7.0 copy from verl/third_party/torch/distributed/checkpoint\n        from verl.third_party.torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n\n    # To broadcast, the model needs to be instantiated on the GPU.\n    if dist.get_rank() == 0:\n        model = model.to(device=get_device_id(), non_blocking=True)\n    else:\n        model = model.to_empty(device=get_device_id())\n\n    cpu_offload = cpu_offload is not None\n    options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True)\n    set_model_state_dict(model, full_state, options=options)\n\n    # rotary_emb is not in state_dict, so we need to broadcast it manually\n    for name, buf in model.named_buffers():\n        dist.broadcast(buf, src=0)\n\n    if cpu_offload:\n        model.to(\"cpu\", non_blocking=True)\n        for buf in model.buffers():\n            buf.data = buf.data.to(get_device_id())\n\n\ndef apply_fsdp2(model, fsdp_kwargs, config):\n    \"\"\"model: AutoModelForCausalLM\"\"\"\n    assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n\n    default_transformer_cls_names_to_wrap = getattr(model, \"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = config.get(\"wrap_policy\", {}).get(\n        \"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap\n    )\n\n    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):\n        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]\n\n    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None\n\n
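    # Shard each transformer block (and each untied embedding) as its own FSDP2 unit\n    # before sharding the root model, mirroring the FSDP1 transformer auto-wrap policy.\n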
    modules = []\n    for name, module in model.named_modules():\n        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (\n            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings\n        ):\n            modules.append(module)\n\n    for module in modules:\n        fully_shard(module, **fsdp_kwargs)\n    fully_shard(model, **fsdp_kwargs)  # fsdp2 will not reshard_after_forward for root module\n\n\ndef fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None):\n    \"\"\"torch.nn.utils.clip_grad_norm_ can't run on CPU-parameter DTensors\"\"\"\n    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm\n\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    else:\n        # prevent generators from being exhausted\n        parameters = list(parameters)\n    grads = [p.grad for p in parameters if p.grad is not None]\n    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)\n    total_norm = total_norm.to(get_device_id(), non_blocking=True)\n    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)\n    return total_norm\n\n\ndef layered_summon_lora_params(fsdp_module) -> OrderedDict:\n    from peft.utils.save_and_load import get_peft_model_state_dict\n\n    def __prefix_submodules(module, prefix):\n        for name, submodule in module.named_modules():\n            if name.startswith(prefix) and \".\" not in name[len(prefix) :]:\n                yield name, submodule\n\n    lora_params = OrderedDict()\n    prefix_list = [\n        # fsdp\n        \"_fsdp_wrapped_module.base_model.model.\",\n        \"_fsdp_wrapped_module.base_model.model.model.\",\n        \"_fsdp_wrapped_module.base_model.model.model.layers.\",\n        # fsdp2\n        \"base_model.model.\",\n        \"base_model.model.model.\",\n        \"base_model.model.model.layers.\",\n    ]\n    peft_model = getattr(fsdp_module, \"_fsdp_wrapped_module\", fsdp_module)\n    for prefix in prefix_list:\n        for name, submodule in __prefix_submodules(fsdp_module, prefix):\n            prefix = name.replace(\"_fsdp_wrapped_module.base_model.model.\", \"base_model.model.\")\n            if name.endswith(\".model\") or name.endswith(\".layers\"):\n                continue\n            if fsdp_version(submodule) > 0:\n                with FSDP.summon_full_params(submodule, writeback=False):\n                    sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())\n                    sub_lora_params = {\n                        f\"{prefix}.{name}\": param.full_tensor().detach().cpu()\n                        if hasattr(param, \"full_tensor\")\n                        else param.detach().cpu()\n                        for name, param in sub_lora_params.items()\n                    }\n                    lora_params.update(sub_lora_params)\n                    submodule._is_root = False\n                get_torch_device().empty_cache()\n    return lora_params\n"
  },
  {
    "path": "verl_rl/verl/utils/hdfs_io.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_SFT_LOGGING_LEVEL\", \"WARN\"))\n\n_HDFS_PREFIX = \"hdfs://\"\n\n_HDFS_BIN_PATH = shutil.which(\"hdfs\")\n\n\ndef exists(path: str, **kwargs) -> bool:\n    r\"\"\"Works like os.path.exists() but supports hdfs.\n\n    Test whether a path exists. Returns False for broken symbolic links.\n\n    Args:\n        path (str): path to test\n\n    Returns:\n        bool: True if the path exists, False otherwise\n    \"\"\"\n    if _is_non_local(path):\n        return _exists(path, **kwargs)\n    return os.path.exists(path)\n\n\ndef _exists(file_path: str):\n    \"\"\"hdfs capable to check whether a file_path is exists\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        return _run_cmd(_hdfs_cmd(f\"-test -e {file_path}\")) == 0\n    return os.path.exists(file_path)\n\n\ndef makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:\n    r\"\"\"Works like os.makedirs() but supports hdfs.\n\n    Super-mkdir; create a leaf directory and all intermediate ones.  Works like\n    mkdir, except that any intermediate path segment (not just the rightmost)\n    will be created if it does not exist. If the target directory already\n    exists, raise an OSError if exist_ok is False. Otherwise no exception is\n    raised.  This is recursive.\n\n    Args:\n        name (str): directory to create\n        mode (int): file mode bits\n        exist_ok (bool): if True, do not raise an exception if the directory already exists\n        kwargs: keyword arguments for hdfs\n\n    \"\"\"\n    if _is_non_local(name):\n        # TODO(haibin.lin):\n        # - handle OSError for hdfs(?)\n        # - support exist_ok for hdfs(?)\n        _mkdir(name, **kwargs)\n    else:\n        os.makedirs(name, mode=mode, exist_ok=exist_ok)\n\n\ndef _mkdir(file_path: str) -> bool:\n    \"\"\"hdfs mkdir\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        _run_cmd(_hdfs_cmd(f\"-mkdir -p {file_path}\"))\n    else:\n        os.makedirs(file_path, exist_ok=True)\n    return True\n\n\ndef copy(src: str, dst: str, **kwargs) -> bool:\n    r\"\"\"Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs.\n\n    Copy data and mode bits (\"cp src dst\"). 
Return the file's destination.\n    The destination may be a directory.\n    If source and destination are the same file, a SameFileError will be\n    raised.\n\n    Args:\n        src (str): source file path\n        dst (str): destination file path\n        kwargs: keyword arguments for hdfs copy\n\n    Returns:\n        str: destination file path\n\n    \"\"\"\n    if _is_non_local(src) or _is_non_local(dst):\n        # TODO(haibin.lin):\n        # - handle SameFileError for hdfs files(?)\n        # - return file destination for hdfs files\n        return _copy(src, dst)\n    else:\n        if os.path.isdir(src):\n            return shutil.copytree(src, dst, **kwargs)\n        else:\n            return shutil.copy(src, dst, **kwargs)\n\n\ndef _copy(from_path: str, to_path: str, timeout: int = None) -> bool:\n    if to_path.startswith(\"hdfs\"):\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(_hdfs_cmd(f\"-cp -f {from_path} {to_path}\"), timeout=timeout)\n        else:\n            returncode = _run_cmd(_hdfs_cmd(f\"-put -f {from_path} {to_path}\"), timeout=timeout)\n    else:\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(\n                _hdfs_cmd(\n                    f\"-get \\\n                {from_path} {to_path}\"\n                ),\n                timeout=timeout,\n            )\n        else:\n            try:\n                shutil.copy(from_path, to_path)\n                returncode = 0\n            except shutil.SameFileError:\n                returncode = 0\n            except Exception as e:\n                logger.warning(f\"copy {from_path} {to_path} failed: {e}\")\n                returncode = -1\n    return returncode == 0\n\n\ndef _run_cmd(cmd: str, timeout=None):\n    return os.system(cmd)\n\n\ndef _hdfs_cmd(cmd: str) -> str:\n    return f\"{_HDFS_BIN_PATH} dfs {cmd}\"\n\n\ndef _is_non_local(path: str):\n    return path.startswith(_HDFS_PREFIX)\n"
  },
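A short usage sketch for the HDFS helpers above (my example, not repo code; the checkpoint and HDFS paths are made up, and the `hdfs://` branches require a working `hdfs` binary on PATH, see `_HDFS_BIN_PATH`):

```python
# Hypothetical sketch: mirroring a local checkpoint directory to HDFS.
# Each helper falls back to the plain os/shutil behavior for local paths.
from verl.utils import hdfs_io

local_ckpt = "checkpoints/step_100"                 # made-up local path
remote_dir = "hdfs://namenode/user/onerec/ckpt"     # made-up HDFS path

hdfs_io.makedirs(remote_dir)                        # dispatches to `hdfs dfs -mkdir -p`
if not hdfs_io.exists(f"{remote_dir}/step_100"):    # dispatches to `hdfs dfs -test -e`
    hdfs_io.copy(local_ckpt, remote_dir)            # local -> hdfs uses `hdfs dfs -put -f`
```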
  {
    "path": "verl_rl/verl/utils/import_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to check if packages are available.\nWe assume package availability won't change during runtime.\n\"\"\"\n\nimport importlib\nimport importlib.util\nimport os\nimport warnings\nfrom functools import cache, wraps\nfrom typing import Optional\n\n\n@cache\ndef is_megatron_core_available():\n    try:\n        mcore_spec = importlib.util.find_spec(\"megatron.core\")\n    except ModuleNotFoundError:\n        mcore_spec = None\n    return mcore_spec is not None\n\n\n@cache\ndef is_vllm_available():\n    try:\n        vllm_spec = importlib.util.find_spec(\"vllm\")\n    except ModuleNotFoundError:\n        vllm_spec = None\n    return vllm_spec is not None\n\n\n@cache\ndef is_sglang_available():\n    try:\n        sglang_spec = importlib.util.find_spec(\"sglang\")\n    except ModuleNotFoundError:\n        sglang_spec = None\n    return sglang_spec is not None\n\n\n@cache\ndef is_nvtx_available():\n    try:\n        nvtx_spec = importlib.util.find_spec(\"nvtx\")\n    except ModuleNotFoundError:\n        nvtx_spec = None\n    return nvtx_spec is not None\n\n\n@cache\ndef is_trl_available():\n    try:\n        trl_spec = importlib.util.find_spec(\"trl\")\n    except ModuleNotFoundError:\n        trl_spec = None\n    return trl_spec is not None\n\n\ndef import_external_libs(external_libs=None):\n    if external_libs is None:\n        return\n    if not isinstance(external_libs, list):\n        external_libs = [external_libs]\n    import importlib\n\n    for external_lib in external_libs:\n        importlib.import_module(external_lib)\n\n\ndef load_extern_type(file_path: Optional[str], type_name: Optional[str]) -> type:\n    \"\"\"Load a external data type based on the file path and type name\"\"\"\n    if not file_path:\n        return None\n\n    if file_path.startswith(\"pkg://\"):\n        # pkg://verl.utils.dataset.rl_dataset\n        # pkg://verl/utils/dataset/rl_dataset\n        module_name = file_path[6:].replace(\"/\", \".\")\n        module = importlib.import_module(module_name)\n\n    else:\n        # file://verl/utils/dataset/rl_dataset\n        # file:///path/to/verl/utils/dataset/rl_dataset.py\n        # or without file:// prefix\n        if file_path.startswith(\"file://\"):\n            file_path = file_path[7:]\n\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Custom type file '{file_path}' not found.\")\n\n        spec = importlib.util.spec_from_file_location(\"custom_module\", file_path)\n        module = importlib.util.module_from_spec(spec)\n        try:\n            spec.loader.exec_module(module)\n        except Exception as e:\n            raise RuntimeError(f\"Error loading module from '{file_path}'\") from e\n\n    if not hasattr(module, type_name):\n        raise AttributeError(f\"Custom type '{type_name}' not found in '{file_path}'.\")\n\n    return getattr(module, type_name)\n\n\ndef 
_get_qualified_name(func):\n    \"\"\"Get full qualified name including module and class (if any).\"\"\"\n    module = func.__module__\n    qualname = func.__qualname__\n    return f\"{module}.{qualname}\"\n\n\ndef deprecated(replacement: str = \"\"):\n    \"\"\"Decorator to mark functions or classes as deprecated.\"\"\"\n\n    def decorator(obj):\n        qualified_name = _get_qualified_name(obj)\n\n        if isinstance(obj, type):\n            original_init = obj.__init__\n\n            @wraps(original_init)\n            def wrapped_init(self, *args, **kwargs):\n                msg = f\"Warning: Class '{qualified_name}' is deprecated.\"\n                if replacement:\n                    msg += f\" Please use '{replacement}' instead.\"\n                warnings.warn(msg, category=FutureWarning, stacklevel=2)\n                return original_init(self, *args, **kwargs)\n\n            obj.__init__ = wrapped_init\n            return obj\n\n        else:\n\n            @wraps(obj)\n            def wrapped(*args, **kwargs):\n                msg = f\"Warning: Function '{qualified_name}' is deprecated.\"\n                if replacement:\n                    msg += f\" Please use '{replacement}' instead.\"\n                warnings.warn(msg, category=FutureWarning, stacklevel=2)\n                return obj(*args, **kwargs)\n\n            return wrapped\n\n    return decorator\n"
  },
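A sketch of the two resolution schemes accepted by `load_extern_type`, plus the `@deprecated` decorator defined above (my example, not repo code; the `pkg://` target follows the docstring's own example and `RLHFDataset` is assumed to exist there, while `my_dataset.py`, `MyDataset`, and the collate functions are made-up names):

```python
# Hypothetical sketch: loading external types and marking code as deprecated.
from verl.utils.import_utils import deprecated, load_extern_type

# pkg:// resolves through importlib.import_module; file:// (or a bare path)
# executes the file and looks the symbol up on the resulting module.
dataset_cls = load_extern_type("pkg://verl.utils.dataset.rl_dataset", "RLHFDataset")
# custom_cls = load_extern_type("file:///path/to/my_dataset.py", "MyDataset")

@deprecated(replacement="new_collate_fn")
def old_collate_fn(batch):
    return batch

old_collate_fn([])  # emits a FutureWarning pointing callers at new_collate_fn
```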
  {
    "path": "verl_rl/verl/utils/kernel/__init__.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
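The kernels.py entry below computes per-token label log-probabilities and softmax entropy without ever materializing the full (num_tokens, vocab_size) logits: each program keeps a running (max, accumulator, logit-weighted-exponential) triple per token and merges partial results with `exp(old_max - new_max)` rescaling, as in online softmax. As a dense-PyTorch sketch of what the forward returns (my reference under those assumptions, not repo code; it ignores the tensor-parallel path and assumes `reduction="none"` on a single rank):

```python
# Reference semantics of efficient_entropy_forward; the Triton kernels compute
# the same quantities blockwise over vocab splits.
import torch

def reference_forward(hidden, weight, labels, temperature=1.0):
    logits = (hidden @ weight.T).float() / temperature         # (M, V)
    log_z = torch.logsumexp(logits, dim=-1)                    # = maximum + log(accumulate)
    logprobs = logits.gather(-1, labels[:, None]).squeeze(-1) - log_z
    probs = torch.softmax(logits, dim=-1)
    entropy_b = (probs * logits).sum(-1)                       # E_p[logits], the kernel's entropy_b
    entropy = log_z - entropy_b                                # H(p) = log Z - E_p[logits]
    return logprobs, entropy

def merge_splits(max_a, accu_a, max_b, accu_b):
    # How the epilogue combines per-split partials: rescale both accumulators
    # to the new running maximum, then add.
    new_max = torch.maximum(max_a, max_b)
    new_accu = accu_a * torch.exp(max_a - new_max) + accu_b * torch.exp(max_b - new_max)
    return new_max, new_accu
```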
  {
    "path": "verl_rl/verl/utils/kernel/kernels.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplementations of the linear cross entropy with token entropy kernel.\n\"\"\"\n\nimport typing\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.distributed as dist\nimport triton\nimport triton.language as tl\n\nfrom verl.utils.device import get_torch_device\n\n\n@dataclass\nclass EntropyReductionEnum:\n    \"\"\"\n    Enum for the reduction method of cross entropy.\n    \"\"\"\n\n    _None = 0\n    _Sum = 1\n    _Mean = 2\n\n\ndef get_entropy_reduction_enum_number(reduction: str) -> int:\n    \"\"\"\n    Get the enum number for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if reduction == \"none\":\n        _enum = EntropyReductionEnum._None\n    elif reduction == \"sum\":\n        _enum = EntropyReductionEnum._Sum\n    elif reduction == \"mean\":\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n    return _enum\n\n\ndef get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum:\n    \"\"\"\n    Get the enum for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if ce_reduction == 0:\n        _enum = EntropyReductionEnum._None\n    elif ce_reduction == 1:\n        _enum = EntropyReductionEnum._Sum\n    elif ce_reduction == 2:\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid ce_reduction: {ce_reduction}\")\n    return _enum\n\n\n@dataclass\nclass BackwardEnum:\n    \"\"\"\n    Enum for the backward method.\n    \"\"\"\n\n    _Total_Fuse_MN = (\n        0  # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight\n    )\n    _Total_Separate = 1  # Store d_logits, no special requirements for d_hidden & d_weight\n    _Split_Dlogits_N = 2  # split d_logits along its N dimension, aka. vocab_size\n    _Split_Dlogits_M = 3  # split d_logits along its M dimension, aka. num_tokens\n\n\n@dataclass\nclass Config:\n    \"\"\"Configuration for efficient entropy kernel operations.\n\n    Args:\n        _backward (BackwardEnum): Backward computation method. 
Defaults to BackwardEnum._Split_Dlogits_N.\n        _use_triton (bool): Whether to use Triton kernels for computation. Defaults to True.\n    \"\"\"\n\n    _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N\n    _use_triton: bool = True\n\n\n_config = Config()\n\n\ndef set_backward_method(backward_method: BackwardEnum):\n    \"\"\"\n    Set the backward method.\n    \"\"\"\n    global _config\n    _config._backward = backward_method\n\n\n@triton.autotune(\n    configs=[triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32}, num_stages=3, num_warps=8)],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_kernel_general_mainloop(\n    rank,\n    hidden_ptr,\n    weight_ptr,\n    labels_ptr,\n    num_tokens,\n    hidden_size,\n    vocab_size,\n    vocab_per_split,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    rcp_temperature: tl.float32,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    \"\"\"\n    forward mainloop\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    pid_m = pid % num_pid_m\n    pid_n = pid // num_pid_m\n\n    if pid_m == 0 and pid_n == 0:\n        tl.store(global_logprobs_scalar_ptr, 0.0)\n\n    # create pointers for the first blocks of hidden\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n\n    # load labels for this block\n    labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens)\n\n    # traverse over N dimension\n    # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _max = tl.full((BLOCK_SIZE_M,), -float(\"inf\"), dtype=tl.float32)\n    _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for n in range(0, num_pid_n):\n        offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n        weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over K dimension\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            # load the next block of hidden and weight\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n                other=0.0,\n            )\n            # _weight = tl.load(weight_ptrs,\n            #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < 
(min(\n            #                       (pid_n + 1) * vocab_per_split, vocab_size))),\n            #                   other=0.0)\n\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K)\n                & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))),\n                other=0.0,\n            )\n\n            # GEMM\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            # advance the ptrs to the next K block\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        # reset hidden_ptrs for next iteration\n        hidden_ptrs -= hidden_size * stride_hidden_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        # update global maximum\n        _max_old = _max\n        m_pid_n = tl.max(logits, axis=1)\n        _max = tl.maximum(_max_old, m_pid_n)\n\n        exp_logits = tl.exp(logits - _max[:, None])\n        coeff = tl.exp(_max_old - _max)\n        _accu = coeff * _accu + tl.sum(exp_logits, axis=1)\n\n        _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1)\n\n        label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n        _logprobs += tl.sum(logits * label_mask, axis=1)\n\n    # store maximum\n    offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_max_n = pid_n\n    maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m\n    tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store entropy\n    accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m\n    tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits))\n    entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m\n    tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store logprobs\n    vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size\n    vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size\n    mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx)\n    mask &= offs_am < num_tokens\n    global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs\n    # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask)\n    tl.store(global_logprobs_ptrs, _logprobs, mask=mask)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue(\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    num_tokens,\n    num_splits,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    global_accu_ptr,\n    stride_global_accu: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    global_entropy_ptr,\n    stride_global_entropy: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    \"\"\"\n    forward epilogue\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * 
BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n\n\n        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n\n        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n\n        _entropy_b = tl.load(\n            entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0\n        )\n\n        # local reduction\n        _max_old = global_max\n        _local_max = tl.max(_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        _scale = tl.exp(_max - global_max[:, None])\n        _coeff = tl.exp(_max_old - global_max)\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    maximum_ptrs = global_max_ptr + offs_m * stride_global_max\n    tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens)\n\n    # store entropy_b\n    global_entropy_b = tl.fdiv(global_entropy_b, global_accu)  # entropy_b\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n    # store entropy\n    global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu\n    tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens)\n    global_entropy = tl.log(global_accu) + global_max - global_entropy_b  # entropy_a\n    global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy\n    tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens)\n    # update logprobs\n    global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs\n    global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens)\n    global_logprobs = global_max + tl.log(global_accu) - global_logprobs\n\n    global_logprobs = -1 * global_logprobs\n    if reduction == 0:\n        tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0)\n        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n    elif reduction == 2:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32)\n        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue_tp(\n    num_tokens,\n    num_splits,\n    reduced_max_ptr,\n    stride_reduced_max_m: tl.int64,\n    stride_reduced_max_n: tl.int64,\n    original_max_ptr,\n    stride_original_max_m: tl.int64,\n    stride_original_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    
entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    global_accu_ptr,\n    stride_global_accu: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        _reduced_max = tl.load(\n            reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _original_max = tl.load(\n            original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _accu = tl.load(\n            accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n\n        # local reduce-max\n        _max_old = global_max\n        _local_max = tl.max(_reduced_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        # update accumulate\n        _coeff = tl.exp(_max_old - global_max)\n        _scale = tl.exp(_original_max - global_max[:, None])\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n\n        # update entropy_b\n        _entropy_b = tl.load(\n            entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens)\n    tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens)\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16})], key=[\"num_tokens\"])\n@triton.jit\ndef efficient_entropy_triton_epilogue_tp_update(\n    num_tokens,\n    logprobs_ptr,\n    stride_logprobs: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accumulate_ptr,\n    stride_accumulate: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    entropy_ptr,\n    stride_entropy: tl.int64,\n    logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens)\n    accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens)\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens)\n    entropy_b = tl.fdiv(entropy_b, 
accumulate)\n    tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens)\n\n    entropy = tl.log(accumulate) + maximum - entropy_b\n    tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens)\n\n    logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens)\n    logprobs = maximum + tl.log(accumulate) - logprobs\n\n    logprobs = -1 * logprobs\n    if reduction == 0:\n        tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        logprobs_scalar = tl.sum(logprobs, axis=0)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n    elif reduction == 2:\n        logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n\n\n_dedicated_stream, _dedicated_events = None, None\n\n\ndef efficient_entropy_forward(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    forward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    if dist_process_group is not None and not hasattr(efficient_entropy_forward, \"_initialized\"):\n        global _dedicated_stream, _dedicated_events\n        _dedicated_stream = get_torch_device().Stream(hidden.device)\n        _dedicated_events = [get_torch_device().Event() for _ in range(2)]\n        efficient_entropy_forward._initialized = True\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        if dist_process_group is None:\n            logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n        else:\n            logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32)\n    elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean):\n        logprobs = torch.empty((), device=hidden.device, dtype=torch.float32)\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n\n    entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n    assert logprobs.is_contiguous() and entropy.is_contiguous()\n\n    maximum = torch.empty_like(entropy)\n    accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32)\n    accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens)\n    accumulate = accumulate_and_entropy_b_view[0, :]\n    entropy_b = accumulate_and_entropy_b_view[1, :]\n    assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous()\n\n    vocab_per_split = 
1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        _logprobs = logprobs\n    else:\n        _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n\n    assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous()\n    assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda\n\n    if _config._use_triton:\n        # 1D kernel launch, then split the tile\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * num_splits,)\n\n        efficient_entropy_kernel_general_mainloop[mainloop_grid](\n            _rank,\n            hidden,\n            weight,\n            labels,\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            vocab_per_split,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight.stride(0),\n            weight.stride(1),\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            1.0 / temperature,\n        )\n    else:\n        raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    # reduction on maximum and maximum_indices\n    def epilogue_grid(meta):\n        return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]),)\n\n    if dist_process_group is None:\n        efficient_entropy_triton_kernel_epilogue[epilogue_grid](\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            num_tokens,\n            num_splits,\n            maximum,\n            maximum.stride(0),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            accumulate,\n            accumulate.stride(0),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n    else:\n        # tensor-parallel\n        _max_backup = _max.clone()\n        dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group)\n\n        get_torch_device().current_stream().record_event(_dedicated_events[0])\n        with get_torch_device().stream(_dedicated_stream):\n            _dedicated_stream.wait_event(_dedicated_events[0])\n            dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group)\n            _dedicated_stream.record_event(_dedicated_events[1])\n\n        efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid](\n            num_tokens,\n            num_splits,\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _max_backup,\n            _max_backup.stride(0),\n            _max_backup.stride(1),\n            _accu,\n            _accu.stride(0),\n     
       _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n        )\n        get_torch_device().current_stream().wait_event(_dedicated_events[1])\n\n        dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group)\n\n        # update logprobs & entropy\n        efficient_entropy_triton_epilogue_tp_update[epilogue_grid](\n            num_tokens,\n            _logprobs,\n            _logprobs.stride(0),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n\n    return (logprobs, entropy, maximum, accumulate, entropy_b)\n\n\n# NOTE: merge d_weight & d_hidden here, split along M & N\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        )\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_mainloop_MN(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward mainloop, where d_logits & d_hidden & d_weight are fused\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by 
zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k\n    # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n\n    d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # loop for d_weight & d_hidden\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits)\n        # tl.atomic_add(d_weight_ptrs,\n        #               _d_weight,\n        #          
     mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size))\n        _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32))\n        tl.atomic_add(\n            d_weight_ptrs,\n            _d_weight,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n        )\n\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32))\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n        _d_hidden = tl.dot(d_logits, _weight.to(tl.float32))\n        tl.atomic_add(\n            d_hidden_ptrs,\n            _d_hidden,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n        )\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k\n        d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_hidden(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_hidden\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    pid_m = pid % num_pid_m\n    pid_k = pid // num_pid_m\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n       
 d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n    # iterate over vocab_size\n    d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n    for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)):\n        offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over hidden_size to get logits\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n        mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        # scale d_logits\n        d_logits *= rcp_temperature\n\n        # calculate d_hidden\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k)\n        _weight = tl.load(\n            weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0\n        )\n        d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden)\n\n    # write back\n    tl.store(\n        d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k,\n        d_hidden,\n        mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_weight(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    
stride_entropy_b: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    pid_n = pid % num_pid_n\n    pid_k = pid // num_pid_n\n\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)\n    for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)):\n        offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n        maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0)\n        accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n        accu_rcp = tl.fdiv(1.0, accu)\n        d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n        if reduction == 0:\n            d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n        elif reduction == 1:\n            d_logprobs = tl.load(d_logprobs_ptr)\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        else:\n            d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        d_logprobs = -1 * d_logprobs\n\n        entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n        labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n        mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        d_logits *= rcp_temperature\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k)\n        _hidden = tl.load(\n            hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0\n        )\n        d_weight = 
tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight)\n\n    # write back\n    tl.store(\n        d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k,\n        d_weight,\n        mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n# NOTE: split tile from d_logits' perspective\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_logits\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < 
num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # store d_logits\n    d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n\n    tl.store(\n        d_logits_ptrs,\n        d_logits,  # will be implicitly converted to d_logits_ptrs.dtype.element_ty\n        mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits_split_N(\n    split_idx: int,\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    vocab_per_split: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_m = 
tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n    entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n    vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size)\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound),\n            other=0.0,\n        )\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n    logits *= rcp_temperature\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    d_logits *= rcp_temperature\n\n    # filter d_logits with mask\n    result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split)\n\n    tl.store(\n        d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask\n    )\n\n\ndef efficient_entropy_backward(\n    dlogprobs: torch.Tensor,\n    dentropy: torch.Tensor,\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    maximum: torch.Tensor,\n    acc: torch.Tensor,\n  
  entropy_b: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    should_return_fp32_grad: bool = False,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    backward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        assert dlogprobs.shape == (num_tokens,)\n    else:\n        assert dlogprobs.dim() == 0\n\n    assert dlogprobs.is_contiguous() and dentropy.is_contiguous()\n    assert dlogprobs.is_cuda and dentropy.is_cuda\n    assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device\n    assert dentropy.shape == (num_tokens,)\n\n    d_hidden, d_weight = None, None\n    if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad:\n        d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device)\n        d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device)\n    else:\n        d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device)\n        d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device)\n    assert d_hidden.is_contiguous() and d_weight.is_contiguous()\n\n    assert maximum.is_contiguous() and acc.is_contiguous()\n    assert maximum.device == hidden.device and acc.device == hidden.device\n    assert maximum.shape == labels.shape == acc.shape\n    assert maximum.is_cuda and acc.is_cuda\n\n    vocab_per_split = 1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    assert entropy_b.is_contiguous() and entropy_b.is_cuda\n    assert entropy_b.shape == (num_tokens,)\n\n    if _config._backward == BackwardEnum._Total_Fuse_MN:\n        # --- Triton doesn't materialize d_logits at all. 
Split tiles at the perspective of d_logits.\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n        efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid](\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            _rank,\n            hidden,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight,\n            weight.stride(0),\n            weight.stride(1),\n            labels,\n            labels.stride(0),\n            maximum,\n            maximum.stride(0),\n            acc,\n            acc.stride(0),\n            dentropy,\n            dentropy.stride(0),\n            dlogprobs,\n            dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n            REDUCTION,\n            entropy_b,\n            entropy_b.stride(0),\n            d_hidden,\n            d_hidden.stride(0),\n            d_hidden.stride(1),\n            d_weight,\n            d_weight.stride(0),\n            d_weight.stride(1),\n            1.0 / temperature,\n        )\n\n    elif _config._backward == BackwardEnum._Total_Separate:\n        _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        if _config._use_triton:\n\n            def d_logits_grid(meta):\n                return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n            efficient_entropy_backward_kernel_general_d_logits[d_logits_grid](\n                num_tokens,\n                hidden_size,\n                vocab_size,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            torch.matmul(_d_logits, weight, out=d_hidden)\n            torch.matmul(_d_logits.T, hidden, out=d_weight)\n        else:\n            raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_N:\n        vocab_per_split = 9504\n        num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n        _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        def d_logits_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_per_split, meta[\"BLOCK_SIZE_N\"]),)\n\n        for split_idx in range(num_splits):\n            efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid](\n                split_idx,\n                num_tokens,\n                hidden_size,\n                vocab_size,\n             
   vocab_per_split,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            if split_idx == (num_splits - 1):\n                vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split\n                _d_logits = _d_logits[:, :vocab_right_bound].contiguous()\n\n            if split_idx == 0:\n                torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden\n                )\n            else:\n                d_hidden += torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n                )\n            torch.matmul(\n                _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n            )\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_M:\n        raise NotImplementedError(\"BackwardEnum._Split_Dlogits_M is not implemented yet\")\n\n    return d_hidden, d_weight\n"
  },
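All three backward variants dispatched above (`_Total_Fuse_MN`, `_Total_Separate`, `_Split_Dlogits_N`) realize the same gradient and differ only in whether and how `d_logits` is materialized. The dense PyTorch sketch below is our illustration, not part of the repository (`reference_d_logits` is a hypothetical helper for single-rank, unsharded vocab); it spells out the formula the kernels compute and is handy for unit-testing the fused path on small shapes.

```python
import torch
import torch.nn.functional as F


def reference_d_logits(hidden, weight, labels, dlogprobs, dentropy, temperature=1.0):
    # logits are scaled by 1/temperature before the softmax, as in the kernels
    logits = (hidden.float() @ weight.float().T) / temperature  # (num_tokens, vocab_size)
    p = torch.softmax(logits, dim=-1)
    one_hot = F.one_hot(labels, num_classes=logits.shape[-1]).to(p.dtype)
    entropy_b = (p * logits).sum(dim=-1, keepdim=True)          # E_p[logits], the kernels' entropy_b
    # d(logprobs)/d(logits) = one_hot - p
    d_logits = dlogprobs[:, None] * (one_hot - p)
    # d(entropy)/d(logits) = -p * (logits - E_p[logits])
    d_logits += dentropy[:, None] * (-p) * (logits - entropy_b)
    return d_logits / temperature                               # chain rule through the 1/temperature scaling
```

From `d_logits` the remaining gradients are plain matmuls, `d_hidden = d_logits @ weight` and `d_weight = d_logits.T @ hidden`, which is exactly the pair of `torch.matmul` calls the `_Total_Separate` and `_Split_Dlogits_N` branches issue after each kernel launch.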
  {
    "path": "verl_rl/verl/utils/kernel/linear_cross_entropy.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport typing\n\nimport torch\nimport torch.distributed as dist\n\nfrom . import kernels\n\n\nclass LinearCrossEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden: torch.Tensor,\n        weight: torch.Tensor,\n        labels: torch.Tensor,\n        temperature: typing.Optional[float] = 1.0,\n        reduction: typing.Optional[str] = \"none\",\n        dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"_summary_\n\n        Args:\n            ctx (_type_): _description_\n            hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size)\n            weight (torch.Tensor): (vocab_size, hidden_size)\n            labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, )\n            temperature (typing.Optional[float], optional): _description_. Defaults to 1.0.\n            reduction (typing.Optional[str], optional): _description_. Defaults to \"none\".\n            dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. 
Defaults to None.\n\n        Returns:\n            typing.List[torch.Tensor]: _description_\n        \"\"\"\n\n        assert isinstance(temperature, float), f\"temperature must be a float, but got {type(temperature)}\"\n        assert isinstance(reduction, str), f\"reduction must be a str, but got {type(reduction)}\"\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-forward\"):\n            REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower())\n\n            original_hidden_shape = hidden.shape\n            if len(hidden.shape) != 2:\n                hidden = hidden.view(-1, hidden.shape[-1])  # (batch_size * num_tokens, hidden_size)\n            if len(labels.shape) != 1:\n                labels = labels.view(-1)\n\n            logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(\n                hidden, weight, labels, REDUCTION, temperature, dist_process_group\n            )\n\n            ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b)\n            ctx.original_hidden_shape = original_hidden_shape\n            ctx.REDUCTION = REDUCTION\n            ctx.dist_process_group = dist_process_group\n            ctx.should_return_fp32_grad = False\n            ctx.temperature = temperature\n        return logprobs, entropy\n\n    @staticmethod\n    def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]:\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-backward\"):\n            (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors\n            REDUCTION = ctx.REDUCTION\n            dist_process_group = ctx.dist_process_group\n            should_return_fp32_grad = ctx.should_return_fp32_grad\n            temperature = ctx.temperature\n\n            d_hidden, d_weight = kernels.efficient_entropy_backward(\n                dlogprobs,\n                dentropy,\n                hidden,\n                weight,\n                labels,\n                _maximum,\n                _accumulate,\n                _entropy_b,\n                REDUCTION,\n                should_return_fp32_grad,\n                temperature,\n                dist_process_group,\n            )\n            d_hidden = d_hidden.view(ctx.original_hidden_shape)\n\n        return (d_hidden, d_weight, None, None, None, None)\n\n\nlinear_cross_entropy = LinearCrossEntropy.apply\n"
  },
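A minimal, hypothetical usage sketch of the wrapper above, assuming a single GPU (so `dist_process_group` stays `None`); the batch, sequence, and vocabulary sizes are made up, with `hidden_size` a multiple of 128 to satisfy the backward kernel's assertion.

```python
import torch
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

hidden = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
weight = torch.randn(32000, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (2, 512), device="cuda")

# autograd.Function.apply takes positional args, mirroring forward()'s signature
logprobs, entropy = linear_cross_entropy(hidden, weight, labels, 1.0, "none")

loss = -(logprobs.mean() + 0.01 * entropy.mean())  # NLL with a small entropy bonus
loss.backward()  # backward() reshapes d_hidden back to (2, 512, 1024)
```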
  {
    "path": "verl_rl/verl/utils/logger/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .aggregate_logger import (\n    DecoratorLoggerBase,\n    LocalLogger,\n    log_with_rank,\n    print_rank_0,\n    print_with_rank,\n    print_with_rank_and_timer,\n)\n\n__all__ = [\n    \"LocalLogger\",\n    \"DecoratorLoggerBase\",\n    \"print_rank_0\",\n    \"print_with_rank\",\n    \"print_with_rank_and_timer\",\n    \"log_with_rank\",\n]\n"
  },
  {
    "path": "verl_rl/verl/utils/logger/aggregate_logger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA Ray logger will receive logging info from different processes.\n\"\"\"\n\nimport datetime\nimport logging\nimport numbers\nimport pprint\n\nimport torch\n\n\ndef concat_dict_to_str(dict: dict, step):\n    output = [f\"step:{step}\"]\n    for k, v in dict.items():\n        if isinstance(v, numbers.Number):\n            output.append(f\"{k}:{pprint.pformat(v)}\")\n    output_str = \" - \".join(output)\n    return output_str\n\n\nclass LocalLogger:\n    \"\"\"\n    A local logger that logs messages to the console.\n\n    Args:\n        print_to_console (bool): Whether to print to the console.\n    \"\"\"\n\n    def __init__(self, print_to_console=True):\n        self.print_to_console = print_to_console\n\n    def flush(self):\n        pass\n\n    def log(self, data, step):\n        if self.print_to_console:\n            print(concat_dict_to_str(data, step=step), flush=True)\n\n\nclass DecoratorLoggerBase:\n    \"\"\"\n    Base class for all decorators that log messages.\n\n    Args:\n        role (str): The role (the name) of the logger.\n        logger (logging.Logger): The logger instance to use for logging.\n        level (int): The logging level.\n        rank (int): The rank of the process.\n        log_only_rank_0 (bool): If True, only log for rank 0.\n    \"\"\"\n\n    def __init__(\n        self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True\n    ):\n        self.role = role\n        self.logger = logger\n        self.level = level\n        self.rank = rank\n        self.log_only_rank_0 = log_only_rank_0\n        self.logging_function = self.log_by_logging\n        if logger is None:\n            self.logging_function = self.log_by_print\n\n    def log_by_print(self, log_str):\n        if not self.log_only_rank_0 or self.rank == 0:\n            print(f\"{self.role} {log_str}\", flush=True)\n\n    def log_by_logging(self, log_str):\n        if self.logger is None:\n            raise ValueError(\"Logger is not initialized\")\n        if not self.log_only_rank_0 or self.rank == 0:\n            self.logger.log(self.level, f\"{self.role} {log_str}\")\n\n\ndef print_rank_0(message):\n    \"\"\"If distributed is initialized, print only on rank 0.\"\"\"\n    if torch.distributed.is_initialized():\n        if torch.distributed.get_rank() == 0:\n            print(message, flush=True)\n    else:\n        print(message, flush=True)\n\n\ndef print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Print a message with rank information.\n    This function prints the message only if `log_only_rank_0` is False or if the rank is 0.\n\n    Args:\n        message (str): _description_\n        rank (int, optional): _description_. Defaults to 0.\n        log_only_rank_0 (bool, optional): _description_. 
Defaults to False.\n    \"\"\"\n    if not log_only_rank_0 or rank == 0:\n        print(f\"[Rank {rank}] {message}\", flush=True)\n\n\ndef print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Print a message with rank information and a timestamp.\n    This function prints the message only if `log_only_rank_0` is False or if the rank is 0.\n\n    Args:\n        message (str): _description_\n        rank (int, optional): _description_. Defaults to 0.\n        log_only_rank_0 (bool, optional): _description_. Defaults to False.\n    \"\"\"\n    now = datetime.datetime.now()\n    message = f\"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [Rank {rank}] {message}\"\n    if not log_only_rank_0 or rank == 0:\n        print(message, flush=True)\n\n\ndef log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Log a message with rank information using a logger.\n    This function logs the message only if `log_only_rank_0` is False or if the rank is 0.\n    Args:\n        message (str): The message to log.\n        rank (int): The rank of the process.\n        logger (logging.Logger): The logger instance to use for logging.\n        level (int, optional): The logging level. Defaults to logging.INFO.\n        log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False.\n    \"\"\"\n    if not log_only_rank_0 or rank == 0:\n        logger.log(level, f\"[Rank {rank}] {message}\")\n"
  },
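A short illustration of the helpers above; the metric values are made up, and the comments show what `concat_dict_to_str` actually produces.

```python
from verl.utils.logger.aggregate_logger import LocalLogger, print_with_rank

logger = LocalLogger(print_to_console=True)
logger.log({"loss": 0.42, "lr": 1e-05, "run_name": "debug"}, step=10)
# prints "step:10 - loss:0.42 - lr:1e-05" -- non-numeric values are filtered out

print_with_rank("weights synced", rank=3)                        # prints on every rank
print_with_rank("weights synced", rank=3, log_only_rank_0=True)  # silent unless rank == 0
```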
  {
    "path": "verl_rl/verl/utils/logging_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\n\n\ndef set_basic_config(level):\n    \"\"\"\n    This function sets the global logging format and level. It will be called when import verl\n    \"\"\"\n    logging.basicConfig(format=\"%(levelname)s:%(asctime)s:%(message)s\", level=level)\n\n\ndef log_to_file(string):\n    print(string)\n    if os.path.isdir(\"logs\"):\n        with open(f\"logs/log_{torch.distributed.get_rank()}\", \"a+\") as f:\n            f.write(string + \"\\n\")\n"
  },
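For reference, a sketch of the resulting log format (the logger name and message are illustrative; the timestamp shown in the comment is an example of `asctime`'s default rendering).

```python
import logging

from verl.utils.logging_utils import set_basic_config

set_basic_config(level=logging.INFO)  # per its docstring, verl calls this once at import time
logging.getLogger("verl.demo").info("starting rollout")
# emits something like: INFO:2025-01-01 00:00:00,000:starting rollout
```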
  {
    "path": "verl_rl/verl/utils/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/utils/megatron/dist_checkpointing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import dist_checkpointing, mpu\nfrom megatron.core.dist_checkpointing.serialization import (\n    get_default_load_sharded_strategy,\n    get_default_save_sharded_strategy,\n)\nfrom megatron.core.dist_checkpointing.strategies.fully_parallel import (\n    FullyParallelLoadStrategyWrapper,\n    FullyParallelSaveStrategyWrapper,\n)\n\n\ndef save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save=False):\n    validate_sharding_integrity = True\n    # Get checkpointing strategies\n    save_strategy = get_default_save_sharded_strategy(\"torch_dist\")\n    save_strategy = FullyParallelSaveStrategyWrapper(\n        save_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Save model sharded state dicts\n    async_save_request = dist_checkpointing.save(\n        sharded_state_dict,\n        ckpt_path,\n        sharded_strategy=save_strategy,\n        async_sharded_save=async_save,\n        validate_access_integrity=validate_sharding_integrity,\n    )\n\n    return async_save_request\n\n\ndef load_dist_checkpointing(sharded_state_dict, ckpt_dir):\n    # Get checkpointing strategies\n    load_strategy = get_default_load_sharded_strategy(ckpt_dir)\n    load_strategy = FullyParallelLoadStrategyWrapper(\n        load_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Load model sharded state dicts\n    state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy)\n\n    return state_dict\n"
  },
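A hedged sketch of a save/load round trip with these helpers. It assumes Megatron's parallel state is already initialized and that `model` is a Megatron-Core module exposing `sharded_state_dict()`; `ckpt_dir` is a placeholder path.

```python
from verl.utils.megatron.dist_checkpointing import (
    load_dist_checkpointing,
    save_dist_checkpointing,
)

ckpt_dir = "/tmp/dist_ckpt/step_100"  # hypothetical location

# save: every rank participates; with async_save=True a request handle is returned
save_dist_checkpointing(model.sharded_state_dict(), ckpt_dir, async_save=False)

# load: build a fresh sharded state dict as the template, then restore into the model
state_dict = load_dist_checkpointing(model.sharded_state_dict(), ckpt_dir)
model.load_state_dict(state_dict)
```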
  {
    "path": "verl_rl/verl/utils/megatron/memory.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom verl.utils.device import get_device_id\n\n\nclass MemoryBuffer:\n    def __init__(self, numel, numel_padded, dtype):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n"
  },
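A small illustration of `MemoryBuffer`: one flat, padded allocation backing several parameter-shaped views. The sizes are arbitrary, and since the constructor allocates on the current accelerator via `get_device_id()`, this assumes a visible device.

```python
import torch

from verl.utils.megatron.memory import MemoryBuffer

buf = MemoryBuffer(numel=1024, numel_padded=1024, dtype=torch.float32)

weight_view = buf.get(torch.Size([16, 32]), start_index=0)  # first 512 elements
bias_view = buf.get(torch.Size([512]), start_index=512)     # next 512 elements

weight_view.fill_(1.0)               # views write through to the flat storage
assert buf.data[:512].eq(1.0).all()
buf.zero()                           # zeroes every view at once
```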
  {
    "path": "verl_rl/verl/utils/megatron/optimizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core.optimizer import OptimizerConfig\nfrom megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native\nfrom megatron.core.optimizer_param_scheduler import OptimizerParamScheduler\n\n\ndef get_megatron_optimizer(\n    model,\n    config: OptimizerConfig,\n    no_weight_decay_cond=None,\n    scale_lr_cond=None,\n    lr_mult=1.0,\n):\n    # Base optimizer.\n    return get_megatron_optimizer_native(\n        config=config,\n        model_chunks=model,\n        no_weight_decay_cond=no_weight_decay_cond,\n        scale_lr_cond=scale_lr_cond,\n        lr_mult=lr_mult,\n    )\n\n\ndef get_megatron_optimizer_param_scheduler(\n    optimizer,\n    config,\n):\n    \"\"\"\n    Get the optimizer parameter scheduler for Megatron.\n    \"\"\"\n    if config.get(\"lr_decay_steps\", None) is None:\n        config.lr_decay_steps = config.total_training_steps\n    wsd_decay_steps = None\n    if config.get(\"lr_wsd_decay_steps\", None) is not None:\n        wsd_decay_steps = config.lr_wsd_decay_steps\n    if config.get(\"lr_warmup_steps_ratio\", None) is not None and (\n        config.get(\"lr_warmup_steps\", None) is None or config.lr_warmup_steps <= 0\n    ):\n        config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps)\n\n    opt_param_scheduler = OptimizerParamScheduler(\n        optimizer,\n        init_lr=config.lr_warmup_init,\n        max_lr=config.lr,\n        min_lr=config.min_lr,\n        lr_warmup_steps=config.lr_warmup_steps,\n        lr_decay_steps=config.lr_decay_steps,\n        lr_decay_style=config.lr_decay_style,\n        start_wd=config.weight_decay,\n        end_wd=config.weight_decay,\n        wd_incr_steps=config.total_training_steps,\n        wd_incr_style=config.weight_decay_incr_style,\n        use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler,\n        override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler),\n        wsd_decay_steps=wsd_decay_steps,\n        lr_wsd_decay_style=config.lr_wsd_decay_style,\n    )\n\n    return opt_param_scheduler\n\n\ndef get_megatron_last_lr(optimizer):\n    \"\"\"\n    Get the last learning rate from the optimizer parameter scheduler.\n    \"\"\"\n    return optimizer.param_groups[0][\"lr\"]\n"
  },
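`get_megatron_optimizer_param_scheduler` reads its config through both `.get()` and attribute access (including attribute assignment), so an OmegaConf `DictConfig` fits naturally. A sketch with made-up values, assuming `optimizer` came from `get_megatron_optimizer`:

```python
from omegaconf import OmegaConf

optim_cfg = OmegaConf.create(
    {
        "lr": 1e-5,
        "min_lr": 1e-6,
        "lr_warmup_init": 0.0,
        "lr_warmup_steps": None,
        "lr_warmup_steps_ratio": 0.03,  # -> lr_warmup_steps = int(0.03 * lr_decay_steps)
        "lr_decay_steps": None,         # falls back to total_training_steps
        "lr_decay_style": "cosine",
        "total_training_steps": 10000,
        "weight_decay": 0.01,
        "weight_decay_incr_style": "constant",
        "use_checkpoint_opt_param_scheduler": False,
        "lr_wsd_decay_steps": None,
        "lr_wsd_decay_style": "exponential",
    }
)
# scheduler = get_megatron_optimizer_param_scheduler(optimizer, optim_cfg)
```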
  {
    "path": "verl_rl/verl/utils/megatron/pipeline_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\n\nfrom .sequence_parallel import pad_to_sequence_parallel\n\n\ndef compute_transformers_input_shapes(batches, meta_info):\n    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron\n\n    # pre-compute input shapes for each micro-batch at each pp stage\n    input_shapes = []\n    for model_inputs in batches:\n        input_ids = model_inputs[\"input_ids\"]\n        attention_mask = model_inputs[\"attention_mask\"]\n        input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0]  # (total_nnz, 1)\n        if meta_info[\"sequence_parallel\"]:\n            input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)\n            # compute shapes for model_inputs\n            input_shapes.append(\n                torch.Size(\n                    [\n                        input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(),\n                        1,\n                        meta_info[\"hidden_size\"],\n                    ]\n                )\n            )\n        else:\n            # compute shapes for model_inputs\n            input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info[\"hidden_size\"]]))\n    return input_shapes\n\n\ndef make_batch_generator(batches, vpp_size):\n    \"\"\"\n    Creates a batch generator suitable for Megatron pipeline parallelism,\n    handling virtual pipeline parallelism (VPP).\n\n    If VPP is used (vpp_size > 1), it duplicates the batch iterator for each\n    virtual pipeline stage. Otherwise, it returns a single iterator.\n\n    Args:\n        batches: An iterable (e.g., list) of micro-batches.\n        vpp_size (int): The virtual pipeline model parallel size.\n\n    Returns:\n        An iterator or a list of iterators over the micro-batches.\n    \"\"\"\n    if vpp_size > 1:\n        # has vpp\n        batch_generator = [batches] * vpp_size  # number of vpp chunks\n        batch_generator = [iter(b) for b in batch_generator]\n    else:\n        # no vpp\n        batch_generator = iter(batches)\n    return batch_generator\n"
  },
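The VPP handling in `make_batch_generator` is easiest to see on toy data: with `vpp_size > 1` each virtual stage gets its own iterator over the *same* micro-batches, so the stages can be consumed independently.

```python
from verl.utils.megatron.pipeline_parallel import make_batch_generator

batches = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]

gen = make_batch_generator(batches, vpp_size=1)   # plain iterator
assert next(gen)["input_ids"] == [1, 2]

gens = make_batch_generator(batches, vpp_size=2)  # one iterator per vpp chunk
assert next(gens[0]) is next(gens[1])             # both yield the same first micro-batch
```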
  {
    "path": "verl_rl/verl/utils/megatron/sequence_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import parallel_state as mpu\n\n\ndef mark_parameter_as_sequence_parallel(parameter):\n    parameter.sequence_parallel = True\n\n\ndef is_sequence_parallel_param(param):\n    return hasattr(param, \"sequence_parallel\") and param.sequence_parallel\n\n\ndef pad_to_sequence_parallel(unpad_tokens: torch.Tensor):\n    \"\"\"pad the tokens such that the total length is a multiple of sp world size\n\n    Args:\n        unpad_tokens: (total_nnz, ...). Tokens after removing padding\n\n    Returns:\n        the padded tokens: (total_nnz + pad_size,...)\n\n    \"\"\"\n    total_nnz = unpad_tokens.shape[0]\n    sp_world_size = mpu.get_tensor_model_parallel_world_size()\n\n    pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size\n\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim()} is not supported\")\n\n    return unpad_tokens\n"
  },
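The padding arithmetic in `pad_to_sequence_parallel`, restated standalone for clarity; `pad_to_multiple` is our hypothetical name, taking the world size explicitly instead of querying `mpu`.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(unpad_tokens: torch.Tensor, sp_world_size: int) -> torch.Tensor:
    total_nnz = unpad_tokens.shape[0]
    pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size
    if pad_size > 0:
        # zero-pad the end of dim 0 only; F.pad pads from the last dim backwards
        pad = (0, pad_size) if unpad_tokens.ndim == 1 else (0, 0, 0, pad_size)
        unpad_tokens = F.pad(unpad_tokens, pad)
    return unpad_tokens


x = torch.ones(10, 4)
assert pad_to_multiple(x, 8).shape == (16, 4)  # 10 rounds up to the next multiple of 8
```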
  {
    "path": "verl_rl/verl/utils/megatron/tensor_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for using tensor_parallel in megatron\n\"\"\"\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import parallel_state as mpu\nfrom torch.nn import init\n\nif TYPE_CHECKING:\n    from megatron.core import ModelParallelConfig\n\n\ndef update_kwargs_with_config(dictionary: dict, config: \"ModelParallelConfig\"):\n    dictionary[\"config\"] = config\n    return dictionary\n\n\ndef get_default_kwargs_for_model_parallel_config():\n    model_parallel_config_kwargs = {\n        \"params_dtype\": torch.float32,\n        \"use_cpu_initialization\": False,\n        \"perform_initialization\": True,\n        \"gradient_accumulation_fusion\": False,\n        \"sequence_parallel\": False,\n    }\n    return model_parallel_config_kwargs\n\n\ndef get_default_model_parallel_config():\n    from megatron.core import ModelParallelConfig\n\n    return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config())\n\n\ndef get_common_default_kwargs_for_parallel_linear():\n    default_model_parallel_config = get_default_model_parallel_config()\n    common_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"stride\": 1,\n        \"keep_master_weight_for_test\": False,\n        \"config\": default_model_parallel_config,\n    }\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_column_parallel_linear():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    column_parallel_config_kwargs = {\n        \"async_tensor_model_parallel_allreduce\": False,\n    }\n    model_parallel_config_kwargs.update(column_parallel_config_kwargs)\n    column_default_kwargs = {\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    common_default_kwargs.update(column_default_kwargs)\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_row_parallel_linear():\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_parallel_embedding():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    embedding_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    return embedding_default_kwargs\n\n\ndef is_tensor_parallel_param(param):\n    return hasattr(param, \"tensor_model_parallel\") and param.tensor_model_parallel\n\n\ndef get_tensor_parallel_partition_dim(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_dim\n\n\ndef 
get_tensor_parallel_partition_stride(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_stride\n\n\nclass _VocabParallelEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n        @torch.compile(dynamic=True)\n        def mul_reduce(a, b):\n            return (a * b).sum(dim=-1, keepdim=True)\n\n        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values\n        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())\n        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max\n        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()\n        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)\n        dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())\n        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)\n        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)\n        dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())\n        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits\n        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)\n        return entropy.squeeze(dim=-1)\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:\n        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors\n        # reuse softmax_logits as grad\n        vocab_parallel_logits.sub_(sum_softmax_times_logits)\n        softmax_logits.mul_(vocab_parallel_logits)\n        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))\n        # recover vocab_parallel_logits\n        vocab_parallel_logits.add_(sum_softmax_times_logits)\n        softmax_logits.mul_(-1)\n        return softmax_logits\n\n\ndef vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute entropy when the logits are sharded in tp ranks\n\n    Args:\n        vocab_parallel_logits: (total_nnz, vocab_size // tp_size)\n\n    Returns: (total_nnz,)\n\n    \"\"\"\n    return _VocabParallelEntropy.apply(vocab_parallel_logits)\n\n\ndef vocab_parallel_log_probs_from_logits(logits, labels):\n    \"\"\"TODO(zhangchi.usc1992): We may change the implementation later\"\"\"\n    from megatron.core import tensor_parallel\n\n    return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)\n\n\ndef vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now split across the tensor parallel\n    region.\n    This will further reduce the peak memory usage during training\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size // tp_size]\n        response_length: int\n\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(\n        logits=logits_rmpad, labels=input_ids_rmpad_rolled\n    )  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n"
  },
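A single-rank sanity check for `vocab_parallel_entropy`: with `tp_size == 1` the all-reduces are no-ops, and the forward should agree with this dense reference, which uses the same decomposition H = logsumexp(z) - sum_j p_j * z_j (`entropy_reference` is our name for the sketch).

```python
import torch


def entropy_reference(logits: torch.Tensor) -> torch.Tensor:
    p = torch.softmax(logits, dim=-1)
    # H = logsumexp(z) - sum_j p_j * z_j, matching the TP implementation above
    return torch.logsumexp(logits, dim=-1) - (p * logits).sum(dim=-1)


logits = torch.randn(8, 32000)
expected = -(torch.softmax(logits, -1) * torch.log_softmax(logits, -1)).sum(-1)
assert torch.allclose(entropy_reference(logits), expected, atol=1e-4)
```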
  {
    "path": "verl_rl/verl/utils/megatron_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pretrain utilities.\"\"\"\n\nimport gc\nimport os\nimport warnings\nfrom typing import Any\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import ModelParallelConfig, mpu, tensor_parallel\nfrom megatron.core.distributed import DistributedDataParallel as DDP\nfrom megatron.core.distributed import DistributedDataParallelConfig\nfrom megatron.core.enums import ModelType\nfrom megatron.core.optimizer import ChainedOptimizer, OptimizerConfig\nfrom megatron.core.transformer import TransformerConfig\nfrom megatron.core.transformer.module import Float16Module\nfrom megatron.core.utils import get_attr_wrapped_model\nfrom transformers import PretrainedConfig\n\nimport verl.utils.megatron.tensor_parallel as tp_utils\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.fs import local_mkdir_safe\nfrom verl.utils.model import normalize_model_name\nfrom verl.utils.torch_dtypes import PrecisionType\n\n\ndef get_model_config(model):\n    return get_attr_wrapped_model(model, \"config\", allow_none=False)\n\n\ndef get_model(\n    model_provider_func,\n    model_type=ModelType.encoder_or_decoder,\n    wrap_with_ddp=True,\n    use_distributed_optimizer=True,\n    transformer_config=None,\n    override_ddp_config=None,\n):\n    \"\"\"Build the model.\"\"\"\n    # Build model.\n    if (\n        mpu.get_pipeline_model_parallel_world_size() > 1\n        and mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n    ):\n        assert model_type != ModelType.encoder_and_decoder, (\n            \"Interleaved schedule not supported for model with both encoder and decoder\"\n        )\n        model = []\n        for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):\n            mpu.set_virtual_pipeline_model_parallel_rank(i)\n            # Set pre_process and post_process only after virtual rank is set.\n            pre_process = mpu.is_pipeline_first_stage()\n            post_process = mpu.is_pipeline_last_stage()\n            this_model = model_provider_func(pre_process=pre_process, post_process=post_process)\n            this_model.model_type = model_type\n            model.append(this_model)\n        mpu.set_virtual_pipeline_model_parallel_rank(0)\n    else:\n        pre_process = mpu.is_pipeline_first_stage()\n        post_process = mpu.is_pipeline_last_stage()\n        add_encoder = True\n        add_decoder = True\n        if model_type == ModelType.encoder_and_decoder:\n            if mpu.get_pipeline_model_parallel_world_size() > 1:\n                assert mpu.get_pipeline_model_parallel_split_rank() is not None, (\n                    \"Split rank needs to be specified for model with both encoder and decoder\"\n                )\n        
        rank = mpu.get_pipeline_model_parallel_rank()\n                split_rank = mpu.get_pipeline_model_parallel_split_rank()\n                world_size = mpu.get_pipeline_model_parallel_world_size()\n                pre_process = rank == 0 or rank == split_rank\n                post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))\n                add_encoder = mpu.is_pipeline_stage_before_split()\n                add_decoder = mpu.is_pipeline_stage_after_split()\n            model = model_provider_func(\n                pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder\n            )\n        else:\n            model = model_provider_func(pre_process=pre_process, post_process=post_process)\n        model.model_type = model_type\n\n    if not isinstance(model, list):\n        model = [model]\n\n    # Set tensor model parallel attributes if not set.\n    # Only parameters that are already tensor model parallel have these\n    # attributes set for them. We should make sure the default attributes\n    # are set for all params so the optimizer can use them.\n    for model_module in model:\n        for param in model_module.parameters():\n            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)\n\n    # Print number of parameters.\n    if mpu.get_data_parallel_rank() == 0:\n        print(\n            \" > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}\".format(\n                mpu.get_tensor_model_parallel_rank(),\n                mpu.get_pipeline_model_parallel_rank(),\n                sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),\n            ),\n            flush=True,\n        )\n\n    # GPU allocation.\n    if transformer_config is None or (not transformer_config.use_cpu_initialization):\n        for model_module in model:\n            model_module.to(f\"{get_device_name()}:{get_device_id()}\")\n\n    # Fp16 conversion.\n    config: TransformerConfig = get_model_config(model[0])\n    config.fp8 = None\n    tfconfig: TransformerConfig = model[0].config\n    if config.fp16 or config.bf16:  # the ModelParallelConfig in GPTModel\n        model = [Float16Module(config, model_module) for model_module in model]\n\n    if wrap_with_ddp:\n        ddp_models = []\n        ddp_config_dict = {\n            \"use_distributed_optimizer\": use_distributed_optimizer,\n            \"grad_reduce_in_fp32\": True,\n            \"overlap_grad_reduce\": False,\n        }\n        if override_ddp_config is not None:\n            ddp_config_dict.update(override_ddp_config)\n        ddp_config = DistributedDataParallelConfig(**ddp_config_dict)\n        for model_chunk_idx, model_chunk in enumerate(model):\n            ddp_model = DDP(\n                config=tfconfig,\n                module=model_chunk,\n                disable_bucketing=(model_chunk_idx > 0),\n                ddp_config=ddp_config,\n            )\n            ddp_models.append(ddp_model)\n        model = ddp_models\n        # # Broadcast params from data parallel src rank to other data parallel ranks.\n        # # if args.data_parallel_random_init:\n        for model_module in model:\n            model_module.broadcast_params()\n    return model\n\n\nALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)\n\n\ndef unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):\n    return_list = True\n    if not isinstance(model, list):\n        model = [model]\n  
      return_list = False\n    unwrapped_model = []\n    for model_module in model:\n        while isinstance(model_module, module_instances):\n            model_module = model_module.module\n        unwrapped_model.append(model_module)\n    if not return_list:\n        return unwrapped_model[0]\n    return unwrapped_model\n\n\ndef convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:\n    print(f\"megatron config {megatron_config}\")\n    dt = PrecisionType.to_dtype(megatron_config.params_dtype)\n    print(f\"pipeline_dtype=megatron_config {dt}\")\n    qkv_bias = True if \"Qwen2ForCausalLM\" in hf_config.architectures else getattr(hf_config, \"attention_bias\", False)\n    overlap_p2p_comm = (\n        mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n        and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    )\n    batch_p2p_comm = False\n    transformer_config = TransformerConfig(\n        num_layers=hf_config.num_hidden_layers,\n        hidden_size=hf_config.hidden_size,\n        num_attention_heads=hf_config.num_attention_heads,\n        num_query_groups=hf_config.num_key_value_heads,\n        ffn_hidden_size=hf_config.intermediate_size,\n        #    max_position_embeddings=hf_config.max_position_embeddings,\n        activation_func=F.silu,\n        normalization=\"RMSNorm\",\n        #    rotary_percent=False, # default,\n        gated_linear_unit=True,  # for llama\n        use_cpu_initialization=True,\n        apply_residual_connection_post_layernorm=False,  # check what's this mean\n        add_bias_linear=False,\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        overlap_p2p_comm=overlap_p2p_comm,\n        batch_p2p_comm=batch_p2p_comm,\n        pipeline_dtype=dt,\n        params_dtype=dt,\n        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,\n        variable_seq_lengths=True,\n        masked_softmax_fusion=True,\n        moe_token_dispatcher_type=\"alltoall\",\n        attention_dropout=hf_config.attention_dropout,\n        hidden_dropout=getattr(hf_config, \"hidden_dropout\", 0.0),\n        add_qkv_bias=qkv_bias,\n        bf16=dt is torch.bfloat16,\n    )\n\n    return transformer_config\n\n\ndef init_megatron_optim_config(optim_config: dict) -> OptimizerConfig:\n    config = OptimizerConfig(\n        optimizer=optim_config.get(\"optimizer\", \"adam\"),\n        lr=optim_config.get(\"lr\"),\n        min_lr=optim_config.get(\"min_lr\", None),\n        clip_grad=optim_config.get(\"clip_grad\", 1.0),\n        weight_decay=optim_config.get(\"weight_decay\", 0.01),\n        bf16=True,\n        params_dtype=torch.bfloat16,\n        use_distributed_optimizer=True,\n    )\n    return config\n\n\ndef mcore_model_parallel_config(\n    sequence_parallel: bool,\n    params_dtype: torch.dtype,\n) -> ModelParallelConfig:\n    # WARNING: Code should not reach this point. This function is deprecated and will be removed.\n    # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.\n    warnings.warn(\n        \"Code should not reach this point. This function is deprecated and will be removed. 
Please use \"\n        \"hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n    return ModelParallelConfig(\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        sequence_parallel=sequence_parallel,\n        params_dtype=params_dtype,\n        pipeline_dtype=params_dtype,\n        bf16=True,\n        fp16=False,\n        timers=None,\n    )\n\n\n@torch.no_grad()\ndef offload_megatron_model_to_cpu(models):\n    \"\"\"\n    In megatron, the model and optimizer storage are:\n    - bf16 parameter data chunked in model parallel group\n    - fp32 grad chunked in model parallel group\n    - fp32 main_parameter chunked in model and dp group\n    - fp32 optimizer state chunked in model and dp group\n    \"\"\"\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # offload parameters\n                    if buffer.param_data.storage().size() > 0:\n                        buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory()\n                        buffer.param_data_size = buffer.param_data.storage().size()\n                        buffer.param_data.storage().resize_(0)\n\n                    assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size()\n\n                    if buffer.grad_data.storage().size() > 0:\n                        # if the grad_data size is already zero, we assume that it is already offloaded\n                        buffer.grad_data_size = buffer.grad_data.storage().size()\n                        buffer.grad_data.storage().resize_(0)\n        else:\n            # we need this for ref module\n            for _, param in model_chunk.named_parameters():\n                param.data = param.data.to(\"cpu\", non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(\"cpu\", non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_megatron_model_to_gpu(models, load_grad=True):\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # sometimes, we don't want to load grad for pure inference\n                    if load_grad:\n                        buffer.grad_data.storage().resize_(buffer.grad_data_size)\n                        buffer.grad_data.zero_()\n\n                    if buffer.param_data.storage().size() == 0:\n                        buffer.param_data.storage().resize_(buffer.param_data_size)\n                        # copy data from cpu to cuda\n                        buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)\n        else:\n            # we need this for ref module\n            device_id = get_device_id()\n            for _, param in model_chunk.named_parameters():\n             
   param.data = param.data.to(device_id, non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(device_id, non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_megatron_copy_params(optimizers):\n    \"\"\"\n    Offload optimizer parameters to CPU. Supports both Megatron optimizers\n    and `ChainedOptimizer`, which wraps a list of underlying optimizers.\n\n    Args:\n        optimizers: The optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def offload_tensor_to_cpu(tensor):\n        if tensor is None:\n            return\n        tensor.data = tensor.data.to(\"cpu\", non_blocking=True)\n\n    def offload_group_to_cpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        offload_tensor_to_cpu(param)\n                else:\n                    offload_tensor_to_cpu(param_group)\n        else:\n            offload_tensor_to_cpu(group)\n\n    # Offload all parameter groups to CPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef load_megatron_copy_params(optimizers):\n    \"\"\"\n    Load optimizer parameters back to GPU. Handles ChainedOptimizer.\n\n    Args:\n        optimizers: Optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def load_tensor_to_gpu(tensor):\n        if tensor is None:\n            return\n        device_id = get_device_id()\n        tensor.data = tensor.data.to(device_id, non_blocking=True)\n\n    def load_group_to_gpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        load_tensor_to_gpu(param)\n                else:\n                    load_tensor_to_gpu(param_group)\n        else:\n            load_tensor_to_gpu(group)\n\n    # Load all parameter groups to GPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            load_group_to_gpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef offload_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        offload_megatron_copy_params(_opt)\n        opt_state_dict_values = _opt.optimizer.state.values()\n        for v in opt_state_dict_values:\n            if \"exp_avg\" in v:\n                v[\"exp_avg\"] = v[\"exp_avg\"].to(\"cpu\", non_blocking=True)\n            if \"exp_avg_sq\" in v:\n                v[\"exp_avg_sq\"] = v[\"exp_avg_sq\"].to(\"cpu\", non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef 
load_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        load_megatron_copy_params(_opt)\n        opt_state_dict_values = _opt.optimizer.state.values()\n        for v in opt_state_dict_values:\n            if \"exp_avg\" in v:\n                v[\"exp_avg\"] = v[\"exp_avg\"].to(get_device_id(), non_blocking=True)\n            if \"exp_avg_sq\" in v:\n                v[\"exp_avg_sq\"] = v[\"exp_avg_sq\"].to(get_device_id(), non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\ndef get_dist_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"dist_ckpt\"))\n    return os.path.join(checkpoint_path, \"dist_ckpt\")\n\n\ndef get_hf_model_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"huggingface\"))\n    return os.path.join(checkpoint_path, \"huggingface\")\n\n\ndef get_transformer_config_checkpoint_path(checkpoint_path):\n    os.makedirs(checkpoint_path, exist_ok=True)\n    return os.path.join(checkpoint_path, \"transformer_config.json\")\n\n\ndef convert_megatron_model_to_transformers_model(\n    name,\n    param,\n    config: PretrainedConfig,\n    tp_size: int,\n    num_query_groups: int,\n    convert_qkv_gate_up_by_trunk_concat=False,\n):\n    \"\"\"Convert megatron model to transformers model.\"\"\"\n    new_params = {}\n\n    def convert_qkv_shard(full_tensor, q_name, k_name, v_name):\n        nonlocal config\n        nonlocal tp_size\n        nonlocal num_query_groups\n\n        q_shard_list = []\n        k_shard_list = []\n        v_shard_list = []\n        hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    k_shard_list.append(k_part)\n                    v_shard_list.append(v_part)\n        else:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // 
num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_shard_list.append(k_part)\n                        v_shard_list.append(v_part)\n\n        new_params[q_name] = torch.cat(q_shard_list, dim=0)\n        new_params[k_name] = torch.cat(k_shard_list, dim=0)\n        new_params[v_name] = torch.cat(v_shard_list, dim=0)\n\n    def convert_gate_up_shard(full_tensor, gate_name, up_name):\n        nonlocal config\n        nonlocal tp_size\n\n        intermediate_size_tp = config.intermediate_size // tp_size\n        gate_weight_list = []\n        up_weight_list = []\n        for i in range(tp_size):\n            gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n            gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n            up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n            gate_weight_list.append(gate_weight_tp)\n            up_weight_list.append(up_weight_tp)\n\n        new_params[gate_name] = torch.cat(gate_weight_list, dim=0)\n        new_params[up_name] = torch.cat(up_weight_list, dim=0)\n\n    if name == \"embedding.word_embeddings.weight\":\n        new_params[\"model.embed_tokens.weight\"] = param\n    elif \"self_attention\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_proj\":\n            new_params[f\"model.layers.{layer_number}.self_attn.o_proj.weight\"] = param\n        elif component == \"linear_qkv\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.input_layernorm.weight\"] = param\n            else:\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_qkv_shard(\n                        param,\n                        f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}\"] = param\n        elif component == \"q_layernorm\" or component == \"k_layernorm\":\n            hf_component = component.replace(\"layer\", \"\")\n            new_params[f\"model.layers.{layer_number}.self_attn.{hf_component}.weight\"] = param\n        else:\n            assert isinstance(param, list) and len(param) == 3\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\"] = param[1]\n            new_params[f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\"] = param[2]\n    elif \"mlp\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = 
splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_fc1\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.post_attention_layernorm.weight\"] = param\n            elif param_type == \"weight\":\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_gate_up_shard(\n                        param,\n                        f\"model.layers.{layer_number}.mlp.gate_proj.weight\",\n                        f\"model.layers.{layer_number}.mlp.up_proj.weight\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.mlp.gate_up_proj.weight\"] = param\n        elif component == \"linear_fc1\" and isinstance(param, list):\n            assert len(param) == 2\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.mlp.gate_proj.weight\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.mlp.up_proj.weight\"] = param[1]\n        elif component == \"linear_fc2\":\n            new_params[f\"model.layers.{layer_number}.mlp.down_proj.weight\"] = param\n    elif name == \"decoder.final_layernorm.weight\":\n        new_params[\"model.norm.weight\"] = param\n    elif name == \"output_layer.weight\":\n        new_params[\"lm_head.weight\"] = param\n    else:\n        raise ValueError(f\"Unknown param name: {name}\")\n    return new_params.keys(), new_params.values()\n\n\ndef broadcast_from_megatron_pp(tensor: torch.Tensor):\n    # tensor is not None only in one of the pp ranks\n    if tensor is not None:\n        shape = tensor.shape\n        dtype = tensor.dtype\n        tensor_parallel = getattr(tensor, \"tensor_model_parallel\", None)\n        partition_dim = getattr(tensor, \"partition_dim\", None)\n        tensor_spec = (shape, dtype, tensor_parallel, partition_dim)\n    else:\n        tensor_spec = None\n    tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(\n        object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group()\n    )\n    # find the src rank\n    target_tensor_spec = None\n    src_rank = None\n    for rank, tensor_spec in enumerate(tensor_spec_output):\n        if tensor_spec is not None:\n            if target_tensor_spec is None:\n                target_tensor_spec = tensor_spec\n            else:\n                raise ValueError(\"A tensor exists on two pp ranks\")\n            src_rank = rank\n    assert target_tensor_spec is not None\n    if tensor is None:\n        tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id())\n        if target_tensor_spec[2] is not None:\n            tensor.tensor_model_parallel = target_tensor_spec[2]\n        if target_tensor_spec[3] is not None:\n            tensor.partition_dim = target_tensor_spec[3]\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n    torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group())\n    return tensor\n\n\ndef broadcast_str_from_megatron_pp(obj: Any):\n    obj_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(object_list=obj_output, obj=obj, 
group=mpu.get_pipeline_model_parallel_group())\n\n    src_rank = None\n    target_obj = None\n    for rank, item in enumerate(obj_output):\n        if item is not None:\n            if target_obj is not None:\n                raise ValueError(\"An object exists on two pp ranks\")\n            target_obj = item\n            src_rank = rank\n\n    assert target_obj is not None, \"No valid object found to broadcast.\"\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n\n    obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group())\n    obj_output[0] = target_obj\n    torch.distributed.broadcast_object_list(\n        object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group()\n    )\n\n    return obj_output[0]\n\n\ndef default_tp_concat_fn(\n    layer_name_mapping,\n    name,\n    train_params,\n    infer_params,\n    model_config,\n    hf_config=None,\n    convert_qkv_gate_up_by_simple_split=False,\n):\n    \"\"\"\n    name: name of the parameter\n    train_params: training parameters\n    infer_params (Iterable[torch.Tensor]): an iterator over the list of parameters all-gathered from the micro_dp_group\n    model_config: huggingface model_config\n    TODO(zhangchi.usc1992): currently, the implementation is ad-hoc. We can move this function to the model\n    definition so that it is model-agnostic. If the model doesn't implement this function,\n    we can throw an error to force the user to disable the TP HybridEngine.\n    \"\"\"\n    from megatron.core import mpu\n\n    train_tp_size = mpu.get_tensor_model_parallel_world_size()\n    if layer_name_mapping.get(\"qkv_layer_name\") in name and \"layer_norm\" not in name:\n        # if the tensor is qkv, split each TP shard into q, k, v\n        # and concat q, k, v separately.\n        q_lst = []\n        k_lst = []\n        v_lst = []\n        num_attention_heads = model_config.num_attention_heads\n        num_key_value_heads = model_config.num_key_value_heads\n        if \"vision_model\" in name:\n            num_attention_heads = hf_config.vision_config.num_heads\n            num_key_value_heads = hf_config.vision_config.num_heads\n        assert num_attention_heads % num_key_value_heads == 0\n        num_q_per_kv = num_attention_heads // num_key_value_heads\n        assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, (\n            f\"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}\"\n        )\n        kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)\n        for infer_param in infer_params:\n            num_query_groups_per_partition = num_key_value_heads // train_tp_size\n            for chunk in infer_param.chunk(num_query_groups_per_partition):\n                split_size = [\n                    kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,\n                    kv_size_per_tp // num_query_groups_per_partition,\n                    kv_size_per_tp // num_query_groups_per_partition,\n                ]\n                q, k, v = chunk.split(split_size)\n                q_lst.append(q)\n                k_lst.append(k)\n                v_lst.append(v)\n        q = torch.cat(q_lst, dim=0)\n        k = torch.cat(k_lst, dim=0)\n        v = torch.cat(v_lst, dim=0)\n        infer_params = torch.cat((q, k, v), dim=0) if not 
convert_qkv_gate_up_by_simple_split else [q, k, v]\n\n    elif (\n        layer_name_mapping.get(\"gate_proj_layer_name\") in name\n        and \"layer_norm\" not in name\n        and \"vision_model.projection\" not in name\n    ):\n        # if the tensor is the fused gate and up projection\n        gate_lst = []\n        up_lst = []\n        for infer_param in infer_params:\n            gate, up = infer_param.chunk(2)\n            gate_lst.append(gate)\n            up_lst.append(up)\n        gate = torch.cat(gate_lst, dim=0)\n        up = torch.cat(up_lst, dim=0)\n        infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up]\n\n    elif \"mlp.experts.linear_fc2.weight\" in name:  # moe\n        infer_params = torch.cat(infer_params, dim=1)\n\n    else:\n        # concat tensor\n        infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params))\n\n    return infer_params\n\n\ndef per_tensor_generator(\n    actor_module,\n    model_config,\n    weight_converter,\n    transformer_config,\n    layer_name_mapping,\n    convert_qkv_gate_up_by_simple_split=True,\n):\n    from megatron.core import parallel_state as mpu\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    ep_size = mpu.get_expert_model_parallel_world_size()\n    etp_size = mpu.get_expert_tensor_parallel_world_size()\n    ep_group = mpu.get_expert_model_parallel_group()\n    etp_group = mpu.get_expert_tensor_parallel_group()\n    vpp_size = len(actor_module)\n    all_gather_group = mpu.get_tensor_model_parallel_group()\n    all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)\n\n    def tensor_generator():\n        for scan_vpp_idx in range(vpp_size):\n            existing_keys = set()\n            model = unwrap_model(actor_module[scan_vpp_idx])\n            for name, param in model.named_parameters():\n                existing_keys.add(name)\n                yield name, param\n            # Note: there is a bug in megatron GPTModel: decoder.layers[n].mlp.router.expert_bias is not\n            # registered in named_parameters(), but does appear in state_dict(). For now we patch it by\n            # adding those keys to extra_keys.\n            extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n            for name in extra_keys:\n                yield name, model.state_dict()[name].to(get_device_id())\n\n    # first, make sure every rank sees the full model metadata\n    meta_info = []\n    for scan_vpp_idx in range(vpp_size):\n        existing_keys = set()\n        model = unwrap_model(actor_module[scan_vpp_idx])\n        for idx, (name, _) in enumerate(model.named_parameters()):\n            existing_keys.add(name)\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n        extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n        for name in extra_keys:\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n\n    obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(\n        object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()\n    )\n    layer_list_meta = [item for sublist in obj_spec_output for item in sublist]\n\n    gen_func = tensor_generator()\n\n    # lazy load tensor for full model\n    for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:\n        if model_config.tie_word_embeddings and (\"output_layers\" in name):\n            import warnings\n\n            warnings.warn(\n                \"Current model shares word embedding and output weights; skip output layer conversion\", stacklevel=2\n            )\n            continue\n\n        if cur_pp_rank == pp_rank:\n            try:\n                cur_name, cur_tensor = next(gen_func)\n            except StopIteration:\n                cur_name, cur_tensor = None, None\n            cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config)\n        else:\n            cur_tensor, cur_name = None, None\n\n        # pp broadcast model tensor and name\n        cur_name = broadcast_str_from_megatron_pp(cur_name)\n        broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor)\n\n        # (xya): this is a hack to fix the name of the parameters\n        while cur_name.startswith(\"module.\"):\n            cur_name = cur_name[len(\"module.\") :]\n\n        # EP\n        if \".mlp.experts.linear_fc\" in cur_name and ep_size > 1:\n            num_experts = weight_converter.mcore_config.num_moe_experts\n            num_experts_per_rank = num_experts // ep_size\n            infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)]\n            torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group)\n\n            name_prefix, local_expert_id = cur_name.split(\".weight\")\n            local_expert_id = int(local_expert_id)\n            global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)]\n            global_expert_names = [f\"{name_prefix}.weight{expert_id}\" for expert_id in global_expert_ids]\n\n            for name, param in zip(global_expert_names, infer_params, strict=True):\n                if etp_size > 1:\n                    # gather etp\n                    etp_params = [torch.empty_like(param) for _ in range(etp_size)]\n                    torch.distributed.all_gather(etp_params, param, group=etp_group)\n                    params = etp_params\n                else:\n                    params = [param]\n\n                merge_params = 
default_tp_concat_fn(\n                    layer_name_mapping,\n                    name,\n                    broad_pp_tensor,\n                    params,\n                    model_config,\n                    weight_converter.hf_config,\n                    convert_qkv_gate_up_by_simple_split,\n                )\n                if not isinstance(merge_params, list):\n                    merge_params = [merge_params]\n                converted_names, converted_params = weight_converter.convert_param(name, merge_params)\n\n                yield from zip(converted_names, converted_params, strict=True)\n            continue\n\n        # tp all gather\n        if tp_utils.is_tensor_parallel_param(broad_pp_tensor):\n            # allocate a new tensor with proper size\n            if all_gather_group_size <= 1:\n                infer_params = [broad_pp_tensor]\n            else:\n                infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)]\n                torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group())\n            infer_params = default_tp_concat_fn(\n                layer_name_mapping,\n                cur_name,\n                broad_pp_tensor,\n                infer_params,\n                model_config,\n                weight_converter.hf_config,\n                convert_qkv_gate_up_by_simple_split,\n            )\n        else:\n            infer_params = broad_pp_tensor\n\n        if not isinstance(infer_params, list):\n            infer_params = [infer_params]\n        converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)\n\n        yield from zip(converted_names, converted_params, strict=True)\n\n\ndef get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConfig):\n    '''\n    Get the layer index offset of any pipeline stage, given the level of pipelining.\n\n    Taking pp_rank and vpp_rank as explicit arguments makes this more flexible than the\n    upstream version, which only returns the layer offset for the current pipeline stage.\n\n    Extension of https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset\n    '''\n    if config.pipeline_model_parallel_size > 1:\n        if (\n            config.num_layers_in_first_pipeline_stage is not None\n            or config.num_layers_in_last_pipeline_stage is not None\n        ):\n            # Calculate number of pipeline stages to distribute the remaining Transformer\n            # layers after deducting the Transformer layers in the first or the last stages\n            middle_pipeline_stages = config.pipeline_model_parallel_size\n            middle_pipeline_stages -= sum(\n                [\n                    1 if x is not None else 0\n                    for x in (\n                        config.num_layers_in_first_pipeline_stage,\n                        config.num_layers_in_last_pipeline_stage,\n                    )\n                ]\n            )\n\n            # Calculate layers to distribute in each pipeline stage. If the\n            # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage\n            # are not set, we will not enable uneven pipeline. All layers will be treated\n            # as middle layers.\n            num_layers_in_first_pipeline_stage = (\n                0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage\n            )\n            num_layers_in_last_pipeline_stage = (\n                0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage\n            )\n\n            middle_num_layers = (\n                config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage\n            )\n\n            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:\n                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n                # Calculate number of layers in each virtual model chunk\n                # If the num_layers_in_first_pipeline_stage and\n                # num_layers_in_last_pipeline_stage are not set, all pipeline stages\n                # will be treated as middle pipeline stages in the calculation\n                num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (\n                    0\n                    if config.num_layers_in_first_pipeline_stage is None\n                    else config.num_layers_in_first_pipeline_stage // vp_size\n                )\n\n                num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (\n                    0\n                    if config.num_layers_in_last_pipeline_stage is None\n                    else config.num_layers_in_last_pipeline_stage // vp_size\n                )\n\n                num_layers_per_virtual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size\n\n                # First stage + middle stage + last stage\n                total_virtual_chunks = (\n                    num_layers_per_virtual_model_chunk_in_first_pipeline_stage\n                    + num_layers_per_virtual_model_chunk_in_middle_pipeline_stage\n                    + num_layers_per_virtual_model_chunk_in_last_pipeline_stage\n                )\n\n                # Calculate the layer offset with interleaved uneven pipeline parallelism\n                if pipeline_rank == 0:\n                    offset = vp_rank * total_virtual_chunks\n                else:\n                    offset = (\n                        vp_rank * total_virtual_chunks\n                        + num_layers_per_virtual_model_chunk_in_first_pipeline_stage\n                        + (pipeline_rank - 1)\n                        * (num_layers_per_virtual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages)\n                    )\n            else:\n                if middle_pipeline_stages > 0:\n                    num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages\n                else:\n                    num_layers_per_pipeline_rank = 0\n\n                middle_pipeline_rank = (\n                    pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1\n                )\n\n                if pipeline_rank == 0:\n                    offset = 0\n                else:\n                    offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage\n        else:\n            num_layers = config.num_layers\n\n            # Increase the number of layers by one if we include the embedding (loss)\n            # layer into pipeline parallelism partition and placement\n            if 
config.account_for_embedding_in_pipeline_split:\n                num_layers += 1\n\n            if config.account_for_loss_in_pipeline_split:\n                num_layers += 1\n\n            num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size\n\n            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:\n                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size\n                total_virtual_chunks = num_layers // vp_size\n                offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():\n                    offset -= 1\n            else:\n                offset = pipeline_rank * num_layers_per_pipeline_rank\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():\n                    offset -= 1\n    else:\n        offset = 0\n    return offset\n"
  },
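`offload_megatron_model_to_cpu` / `load_megatron_model_to_gpu` above free and restore GPU memory by resizing a buffer's underlying storage to zero while keeping a pinned CPU copy. A minimal, Megatron-free sketch of that storage trick (the helper names are ours, not part of verl; assumes a CUDA device and a torch version with `untyped_storage`):

```python
import torch


def offload_to_cpu(t: torch.Tensor) -> torch.Tensor:
    """Keep a pinned CPU copy, then free the GPU allocation in place."""
    cpu_copy = t.data.cpu().pin_memory()  # pinned memory enables async H2D copies later
    t.untyped_storage().resize_(0)        # drops the GPU allocation; shape metadata survives
    return cpu_copy


def reload_to_gpu(t: torch.Tensor, cpu_copy: torch.Tensor) -> None:
    """Re-allocate the GPU storage and restore the saved values."""
    t.untyped_storage().resize_(cpu_copy.untyped_storage().size())  # sizes are in bytes
    t.copy_(cpu_copy, non_blocking=True)


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    saved = offload_to_cpu(x)
    assert x.untyped_storage().size() == 0  # GPU memory released
    reload_to_gpu(x, saved)
    torch.cuda.synchronize()
    assert torch.equal(x.cpu(), saved)
```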
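The fused-layout handling in `convert_gate_up_shard` and `default_tp_concat_fn` is easiest to see on toy data: Megatron stores `linear_fc1` as per-TP-rank `[gate_shard; up_shard]` blocks, while HF expects one contiguous `gate_proj` and one contiguous `up_proj`. A CPU-only demo with illustrative sizes (not real model dimensions):

```python
import torch

tp_size, intermediate_size, hidden = 2, 8, 4
intermediate_tp = intermediate_size // tp_size  # gate (and up) rows held by each TP rank

gate = torch.arange(intermediate_size * hidden, dtype=torch.float32).reshape(intermediate_size, hidden)
up = -gate - 1.0  # any tensor distinguishable from `gate`

# Fused Megatron layout: each rank stores its gate shard followed by its up shard.
fused = torch.cat(
    [
        torch.cat(
            [
                gate[r * intermediate_tp : (r + 1) * intermediate_tp],
                up[r * intermediate_tp : (r + 1) * intermediate_tp],
            ]
        )
        for r in range(tp_size)
    ]
)

# De-interleave the same way the converter does: walk the per-rank blocks and
# collect the gate half and the up half separately.
gate_parts, up_parts = [], []
for r in range(tp_size):
    shard = fused[2 * intermediate_tp * r : 2 * intermediate_tp * (r + 1)]
    gate_parts.append(shard[:intermediate_tp])
    up_parts.append(shard[intermediate_tp:])

assert torch.equal(torch.cat(gate_parts), gate)
assert torch.equal(torch.cat(up_parts), up)
```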
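`broadcast_from_megatron_pp` resolves which pipeline rank owns a tensor by all-gathering lightweight specs and then broadcasting from that owner. A stripped-down sketch of the same pattern (the function name is ours; assumes `torch.distributed` is already initialized and `group` is the PP process group):

```python
import torch
import torch.distributed as dist


def broadcast_owned_tensor(tensor, group):
    """Broadcast a tensor that exists on exactly one rank of `group` to all ranks."""
    # 1) every rank shares a cheap spec (or None) so all ranks can locate the owner
    spec = (tuple(tensor.shape), tensor.dtype) if tensor is not None else None
    specs = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(specs, spec, group=group)

    owners = [i for i, s in enumerate(specs) if s is not None]
    assert len(owners) == 1, "the tensor must live on exactly one rank"
    src = owners[0]

    # 2) non-owners allocate an empty tensor matching the owner's spec
    if tensor is None:
        shape, dtype = specs[src]
        tensor = torch.empty(shape, dtype=dtype)

    # 3) broadcast from the owner's global rank (src is a group-local rank)
    dist.broadcast(tensor, src=dist.get_global_rank(group, src), group=group)
    return tensor
```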
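The interleaved (virtual-pipeline) branch of `get_transformer_layer_offset` is easy to sanity-check with concrete numbers. A worked example under assumed sizes (16 layers, `pp=4`, `vp=2`, no uneven first/last stages):

```python
# 16 layers split across pp=4 stages with vp=2 virtual chunks per stage:
# each (pp_rank, vp_rank) chunk owns 16 // (4 * 2) = 2 consecutive layers, and
# offset = vp_rank * (num_layers // vp) + pp_rank * layers_per_virtual_rank.
num_layers, pp, vp = 16, 4, 2
layers_per_virtual_rank = num_layers // pp // vp   # 2
total_virtual_chunks = num_layers // vp            # 8

offsets = {
    (pp_rank, vp_rank): vp_rank * total_virtual_chunks + pp_rank * layers_per_virtual_rank
    for pp_rank in range(pp)
    for vp_rank in range(vp)
}

# vp chunk 0 covers layers 0..7 across the four stages; vp chunk 1 covers 8..15
assert offsets[(0, 0)] == 0 and offsets[(3, 0)] == 6
assert offsets[(0, 1)] == 8 and offsets[(3, 1)] == 14
```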
  {
    "path": "verl_rl/verl/utils/memory_buffer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains utilities to manipulate torch memory buffers\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nfrom torch import nn\n\nfrom verl.utils.device import get_device_name\n\n\nclass MemoryBuffer:\n    \"\"\"\n    A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying\n    memory. It must have a unique type to support this behavior.\n    \"\"\"\n\n    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        if source is not None:\n            self.data = source\n        else:\n            self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n\n\ndef calc_padded_numel(shape: torch.Size, dtype: torch.dtype):\n    \"\"\"for cuda memory alignment, make sure alignment by 128-bits\"\"\"\n    align_numel = 128 // torch.finfo(dtype).bits\n    numel = shape.numel()\n    return (numel + align_numel - 1) // align_numel * align_numel\n\n\ndef get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]:\n    \"\"\"\n    Return a dictionary containing name to a shape and dtype.\n    \"\"\"\n    weight_buffer_meta = {}\n    for name, param in sorted(module.named_parameters()):\n        weight_buffer_meta[name] = {\"shape\": param.shape, \"dtype\": param.dtype}\n    return weight_buffer_meta\n\n\ndef build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]:\n    \"\"\"Build the memory buffer given weight_buffer_meta\n\n    Args:\n        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors\n\n    Returns: a large memory buffer for each dtype that can hold all the tensors\n\n    \"\"\"\n    memory_buffers = {}\n    total_numel_map = {}  # map from dtype to the total numel\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        assert isinstance(shape, torch.Size)\n        assert isinstance(dtype, torch.dtype)\n\n        if dtype not in total_numel_map:\n            total_numel_map[dtype] = 0\n\n        total_numel_map[dtype] += calc_padded_numel(shape, dtype)\n\n    for dtype, 
total_numel in total_numel_map.items():\n        memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)\n\n    return memory_buffers\n\n\ndef build_memory_reference_from_module(\n    module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True\n):\n    start_index = {}\n    for dtype in memory_buffers:\n        start_index[dtype] = 0\n    for name, param in sorted(module.named_parameters()):\n        memory_buffer = memory_buffers[param.dtype]\n        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])\n        # need to increment start_index\n        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)\n        if maintain_weight:\n            buffer.copy_(param.data)\n        param.data = buffer\n\n\ndef build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]):\n    \"\"\"Build the memory references. The memory buffers are built using the build_memory_buffer API.\n    This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.\n\n    Args:\n        weight_buffer_meta:\n        memory_buffers:\n\n    Returns:\n\n    \"\"\"\n    start_idx = {}\n    weight_buffers = {}\n    for dtype in memory_buffers:\n        start_idx[dtype] = 0\n\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])\n        start_idx[dtype] += calc_padded_numel(shape, dtype)\n        weight_buffers[name] = buffer\n\n    return weight_buffers\n\n\nclass MemoryBufferModuleWrapper:\n    \"\"\"\n    Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to\n    - It will change the checkpoint name\n    \"\"\"\n\n    def __init__(self, module: nn.Module):\n        super().__init__()\n        self.module = module\n        self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)\n        self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)\n        build_memory_reference_from_module(self.module, self.memory_buffers)\n\n    def get_memory_buffers(self):\n        return self.memory_buffers\n\n    def get_weight_buffer_meta(self):\n        return self.weight_buffer_meta\n\n\nclass MegatronMemoryBufferForRollout:\n    \"\"\"\n    We assume that\n    - inference engine has tp + dp\n    - actor has tp + pp + dp\n    - the tp between inference engine and actor should be the same\n    - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer\n    - weight_buffers: contains a list of weight_buffers, each is a dict from name to param\n    - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that\n        the named_parameters may not be directly compatible with inference engine. User has to take care of\n        this part such as the layout mismatches. (e.g. 
qkv transpose)\n    - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.\n    - When doing weight sync, the data is transfer via memory buffers\n    \"\"\"\n\n    def __init__(self, transform_memory_param_fn):\n        self._memory_buffers = []\n        self._weight_buffers = []\n        self._named_parameters = {}\n        self.transform_memory_param_fn = transform_memory_param_fn\n\n    def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]):\n        \"\"\"\n        Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct\n        a large buffer for each dtype in the weight_buffer.\n\n        Args:\n            weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from\n\n        Returns: None\n\n        \"\"\"\n        self.weight_buffer_meta_pp = weight_buffer_meta_pp\n\n        for weight_buffer_meta in self.weight_buffer_meta_pp:\n            memory_buffer = build_memory_buffer(weight_buffer_meta)\n            self._memory_buffers.append(memory_buffer)\n            self._weight_buffers.append(None)\n\n    def build_memory_reference(self):\n        for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):\n            self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])\n        self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)\n\n    @property\n    def named_parameters(self):\n        return self._named_parameters\n\n    @property\n    def weight_buffers(self):\n        return self._weight_buffers\n\n    @property\n    def memory_buffers(self):\n        return self._memory_buffers\n"
  },
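A minimal usage sketch of the buffer utilities above, assuming the package is importable as `verl.utils.memory_buffer`. After `build_memory_reference_from_module`, every parameter aliases a 128-bit-aligned slice of one contiguous per-dtype buffer, so weight sync can move a single large tensor instead of many small ones; the buffer lands on whatever device `get_device_name()` reports:

```python
import torch
from torch import nn

from verl.utils.memory_buffer import (
    build_memory_buffer,
    build_memory_reference_from_module,
    get_weight_buffer_meta_from_module,
)

module = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
meta = get_weight_buffer_meta_from_module(module)

buffers = build_memory_buffer(meta)                  # one contiguous buffer per dtype
build_memory_reference_from_module(module, buffers)  # params become views into it

fp32 = buffers[torch.float32].data
for name, param in module.named_parameters():
    # every parameter now shares the fp32 buffer's underlying storage
    assert param.untyped_storage().data_ptr() == fp32.untyped_storage().data_ptr()
```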
  {
    "path": "verl_rl/verl/utils/metric/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .utils import reduce_metrics\n\n__all__ = [\"reduce_metrics\"]\n"
  },
  {
    "path": "verl_rl/verl/utils/metric/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMetrics utils.\n\"\"\"\n\nfrom typing import Any\n\nimport numpy as np\n\n\ndef reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:\n    \"\"\"\n    Reduces a dictionary of metric lists by computing the mean, max, or min of each list.\n    The reduce operation is determined by the key name:\n    - If the key contains \"max\", np.max is used\n    - If the key contains \"min\", np.min is used\n    - Otherwise, np.mean is used\n\n    Args:\n        metrics: A dictionary mapping metric names to lists of metric values.\n\n    Returns:\n        A dictionary with the same keys but with each list replaced by its reduced value.\n\n    Example:\n        >>> metrics = {\n        ...     \"loss\": [1.0, 2.0, 3.0],\n        ...     \"accuracy\": [0.8, 0.9, 0.7],\n        ...     \"max_reward\": [5.0, 8.0, 6.0],\n        ...     \"min_error\": [0.1, 0.05, 0.2]\n        ... }\n        >>> reduce_metrics(metrics)\n        {\"loss\": 2.0, \"accuracy\": 0.8, \"max_reward\": 8.0, \"min_error\": 0.05}\n    \"\"\"\n    for key, val in metrics.items():\n        if \"max\" in key:\n            metrics[key] = np.max(val)\n        elif \"min\" in key:\n            metrics[key] = np.min(val)\n        else:\n            metrics[key] = np.mean(val)\n    return metrics\n"
  },
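One caveat worth knowing when calling `reduce_metrics`: it mutates its argument in place (each list is replaced by its reduced scalar) and returns the same dict. A small usage sketch, assuming the package is importable as `verl.utils.metric`:

```python
import numpy as np

from verl.utils.metric import reduce_metrics

# Pass a shallow copy if the raw metric lists are still needed afterwards:
# reduce_metrics rebinds each value in the dict it receives.
raw = {"loss": [1.0, 2.0, 3.0], "max_reward": [5.0, 8.0, 6.0]}
reduced = reduce_metrics(dict(raw))

assert reduced["loss"] == np.mean([1.0, 2.0, 3.0])  # 2.0
assert reduced["max_reward"] == 8.0                 # "max" in key -> np.max
assert raw["loss"] == [1.0, 2.0, 3.0]               # original lists intact
```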
  {
    "path": "verl_rl/verl/utils/model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to create common models from huggingface\n\"\"\"\n\nimport os\nimport re\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    MistralForSequenceClassification,\n    PretrainedConfig,\n    PreTrainedModel,\n)\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom verl.models.registry import ModelRegistry\nfrom verl.utils.import_utils import is_trl_available\n\n\nclass LambdaLayer(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, *args, **kwargs):\n        return self.fn(*args, **kwargs)\n\n\ndef squeeze(x):\n    return torch.squeeze(x, dim=-1)\n\n\ndef update_model_config(module_config, override_config_kwargs):\n    \"\"\"Update the module config with the override_config_kwargs.\n    Args:\n        module_config: The module config from Huggingface Transformers.\n        override_config_kwargs: The kwargs to override the module config.\n    \"\"\"\n    for key, val in override_config_kwargs.items():\n        if isinstance(val, dict):\n            update_model_config(getattr(module_config, key), val)\n        else:\n            setattr(module_config, key, val)\n\n\ndef get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict:\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    assert isinstance(override_config_kwargs, dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    )\n    module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)\n    update_model_config(module_config, override_config_kwargs)\n\n    return module_config\n\n\ndef get_generation_config(\n    model: str,\n    trust_remote_code: bool = False,\n) -> Optional[GenerationConfig]:\n    try:\n        return GenerationConfig.from_pretrained(model)\n    except OSError:  # Not found\n        try:\n            config = get_huggingface_actor_config(\n                model,\n                trust_remote_code=trust_remote_code,\n            )\n            return GenerationConfig.from_model_config(config)\n        except OSError:  # Not found\n            return None\n\n\ndef create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"\n\n    Args:\n        model_name:\n        override_config_kwargs:\n\n    Returns:\n\n    \"\"\"\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    assert isinstance(override_config_kwargs, dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    
)\n    module_config = get_huggingface_actor_config(\n        model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get(\"trust_remote_code\", False)\n    )\n    module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)\n    return module\n\n\ndef create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"\n\n    Args:\n        model_name:\n        override_config_kwargs:\n\n    Returns:\n\n    \"\"\"\n    critic_module: nn.Module = create_huggingface_actor(\n        model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs\n    )\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    torch_dtype = automodel_kwargs.get(\"torch_dtype\", torch.float32)\n    critic_module.lm_head = nn.Sequential(\n        nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze)\n    )\n    return critic_module\n\n\ndef get_model_size(model: nn.Module, scale=\"auto\"):\n    n_params = sum(p.numel() for p in model.parameters())\n\n    if scale == \"auto\":\n        if n_params > 1e9:\n            scale = \"B\"\n        elif n_params > 1e6:\n            scale = \"M\"\n        elif n_params > 1e3:\n            scale = \"K\"\n        else:\n            scale = \"\"\n\n    if scale == \"B\":\n        n_params = n_params / 1e9\n    elif scale == \"M\":\n        n_params = n_params / 1e6\n    elif scale == \"K\":\n        n_params = n_params / 1e3\n    elif scale == \"\":\n        pass\n    else:\n        raise NotImplementedError(f\"Unknown scale {scale}\")\n\n    return n_params, scale\n\n\ndef print_model_size(model: nn.Module, name: str = None):\n    n_params, scale = get_model_size(model, scale=\"auto\")\n    if name is None:\n        name = model.__class__.__name__\n    print(f\"{name} contains {n_params:.2f}{scale} parameters\")\n\n\ndef create_random_mask(\n    input_ids: torch.Tensor,\n    max_ratio_of_valid_token: float,\n    max_ratio_of_left_padding: float,\n    min_ratio_of_valid_token: float = 0,\n):\n    \"\"\"Create a random mask given input_ids. 
Support left padding and right padding.\n    Process:\n    - Sample valid token length\n    - Sample left_padding length\n    - Generate padding\n\n    Args:\n        input_ids:\n            shape (batch_size, seq_len)\n\n    Returns:\n        masks: shape (batch_size, seq_len); 1 for valid tokens, 0 for padding\n\n    \"\"\"\n    assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0\n    assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0\n    assert min_ratio_of_valid_token <= max_ratio_of_valid_token\n\n    batch_size, sequence_length = input_ids.shape\n    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)\n    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))\n    max_left_padding = int(sequence_length * max_ratio_of_left_padding)\n    assert max_num_valid_tokens + max_left_padding <= sequence_length\n    assert max_num_valid_tokens > 0 and max_num_valid_tokens <= sequence_length\n    masks = torch.ones_like(input_ids, dtype=torch.int64)\n    # TODO: we can make this faster\n    for i in range(batch_size):\n        num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)\n        num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)\n\n        for index in range(num_left_padding):\n            masks[i, index] = 0\n\n        for index in range(num_left_padding + num_valid, sequence_length):\n            masks[i, index] = 0\n    return masks\n\n\ndef compute_position_id_with_mask(mask):\n    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)\n\n\ndef convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel):\n    # convert state dict keys: https://github.com/huggingface/transformers/pull/38385\n    if not hasattr(model, \"_checkpoint_conversion_mapping\"):\n        return state_dict\n\n    reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()}\n    original_weights = {}\n    for key, value in state_dict.items():\n        for pattern, replacement in reverse_key_mapping.items():\n            replacement = replacement.lstrip(\"^\")  # strip off un-needed chars and patterns\n            replacement = re.sub(r\"\\(.*\\)\", \"\", replacement)\n            key, n_replace = re.subn(pattern, replacement, key)\n            # Early exit of the loop\n            if n_replace > 0:\n                break\n\n        original_weights[key] = value\n\n    return original_weights\n\n\ndef check_exclude_modules(config, key: str) -> bool:\n    \"\"\"\n    A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config.\n    Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py\n\n    Args:\n        config (`LoraConfig` | `LycorisConfig`): A config to match exclude modules from\n        key (`str`): A key to search any matches in config\n\n    Returns:\n        True if key matches any exclude modules from config, False if no match found\n    \"\"\"\n    if hasattr(config, \"exclude_modules\") and config.exclude_modules:\n        if isinstance(config.exclude_modules, str):\n            if re.fullmatch(config.exclude_modules, key):\n                return True\n        elif key in config.exclude_modules:\n            return True\n        elif any(key.endswith(f\".{exclude_key}\") for exclude_key in config.exclude_modules):\n            return True\n    return False\n\n\ndef check_target_modules(config, key: str) -> bool:\n    \"\"\"\n    A helper 
method to check if the passed module's key name matches any of the target modules in the adapter_config.\n    Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py\n\n    Args:\n        config (`LoraConfig` | `LycorisConfig`): A config to match target modules from\n        key (`str`): A key to search any matches in config\n\n    Returns:\n        True (or a truthy re.Match object) if key matches any target modules from config, False/None if no match found\n    \"\"\"\n    if isinstance(config.target_modules, str):\n        target_module_found = re.fullmatch(config.target_modules, key)\n    elif key in config.target_modules:\n        # this module is specified directly in target_modules\n        target_module_found = True\n    else:\n        target_module_found = any(key.endswith(f\".{target_key}\") for target_key in config.target_modules)\n\n        layer_indexes = getattr(config, \"layers_to_transform\", None)\n        layers_pattern = getattr(config, \"layers_pattern\", None)\n\n        is_using_layer_indexes = layer_indexes is not None and (\n            len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True\n        )\n        if is_using_layer_indexes and target_module_found:\n            layer_index = None\n            # TODO: It's still unclear how empty layers_pattern (None, [], or \"\") should behave\n            # For now, empty layers_pattern means any layer pattern is ok\n            if layers_pattern is None or len(layers_pattern) == 0:\n                layer_index = re.match(r\".*\\.[^.]*\\.(\\d+)\\.\", key)\n            else:\n                layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern\n                for pattern in layers_pattern:\n                    layer_index = re.match(rf\".*\\.{pattern}\\.(\\d+)\\.\", key)\n                    if layer_index is not None:\n                        break\n\n            if layer_index is None:\n                target_module_found = False\n            else:\n                layer_index = int(layer_index.group(1))\n                if isinstance(layer_indexes, int):\n                    target_module_found = layer_index == layer_indexes\n                else:\n                    target_module_found = layer_index in layer_indexes\n\n    return target_module_found\n\n\ndef normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name=\"layers\"):\n    \"\"\"\n    Transform the model name in each model_chunk in each pp stage into the corresponding name in the inference engine\n    \"\"\"\n    from verl.utils.megatron_utils import get_transformer_layer_offset\n\n    layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config)\n\n    if layer_name in name:  # belong to an intermediate layer\n        split_name = name.split(\".\")\n        # find the layer number that follows layer_name\n        for i, segment in enumerate(split_name):\n            if segment == layer_name:\n                break\n        layer_num_idx = i + 1\n        # check the name\n        assert len(split_name) >= layer_num_idx + 1, f\"split_name = {split_name}\"\n        assert split_name[layer_num_idx].isdigit(), f\"split_name = {split_name}\"\n        # increment layer_num_idx by layer_offset\n        split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset)\n        name = \".\".join(split_name)  # weight name in inference_tp_model\n    return name\n\n\ndef normalize_pp_vpp_params(params, transformer_config, layer_name=\"layers\"):\n    \"\"\"\n    Normalize the pp vpp params into a complete set of named parameters.\n    This is useful when gathering parameters from pp ranks to pass to a model without pp.\n\n    params: Iterable[List[Dict[str, param]]]\n        params contains a list per pp rank, each holding a list of vpp chunks of named_parameters.\n    transformer_config: the mcore TransformerConfig, used to compute per-stage layer offsets.\n    output: Dict[str, param]\n\n    \"\"\"\n    for pp_rank in range(len(params)):\n        vpp_size = len(params[pp_rank])\n        for vpp_rank in range(vpp_size):\n            for name, param in params[pp_rank][vpp_rank].items():\n                normalized_name = normalize_model_name(\n                    name, pp_rank, vpp_rank, transformer_config, layer_name=layer_name\n                )\n                yield normalized_name, param\n\n\ndef get_parallel_model_from_config(\n    config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from megatron.core import ModelParallelConfig\n\n    assert isinstance(megatron_config, ModelParallelConfig)\n    model_class = _get_parallel_model_architecture_from_config(config, value)\n\n    model = model_class(\n        config,\n        megatron_config,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n    )\n    return model\n\n\ndef _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]:\n    architectures = getattr(config, \"architectures\", [])\n    for arch in architectures:\n        model_cls = ModelRegistry.load_model_cls(arch, value)\n        if model_cls is not None:\n            return model_cls\n    raise ValueError(\n        f\"Model architectures {architectures} are not supported for now. 
Supported architectures: \"\n        f\"{ModelRegistry.get_supported_archs()}\"\n    )\n\n\ndef _load_hf_model(config, model_config, is_value_model, local_cache_path):\n    \"\"\"Helper function containing the HF model loading logic\"\"\"\n    from accelerate import init_empty_weights\n    from megatron.core import parallel_state as mpu\n\n    from verl.models.mcore.saver import _megatron_calc_global_rank\n\n    assert hasattr(model_config, \"architectures\"), \"architectures cannot be empty when loading weights!\"\n    architectures = getattr(model_config, \"architectures\", [])\n    local_cache_path = os.path.expanduser(local_cache_path)\n\n    if config.model.path.startswith(\"hdfs:\"):\n        from verl.utils.fs import copy_to_local\n\n        print(f\"start download from {config.model.path}\")\n        local_model_path = copy_to_local(\n            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get(\"use_shm\", False)\n        )\n        print(\"finish download\")\n    else:\n        local_model_path = config.model.path\n        print(f\"load from local dir {local_model_path}\")\n\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank())\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights\n    with init_context(), warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        # TODO: find a better way to load the mistral7b-rm lm_head\n        if \"mistral7b-rm\" in config.model.path:\n            model = MistralForSequenceClassification.from_pretrained(\n                local_model_path,\n                torch_dtype=\"auto\",\n                # device_map=\"auto\",  # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )  # use score head instead of lm_head\n            state_dict = model.state_dict()\n            state_dict[\"lm_head.weight\"] = state_dict[\"score.weight\"]\n            state_dict[\"model.embed_tokens.weight\"] = state_dict[\"model.embed_tokens.weight\"][\n                :32000\n            ]  # workaround, 32001 -> 32000\n            is_value_model = True\n        else:\n            model = AutoModelForCausalLM.from_pretrained(\n                local_model_path,\n                torch_dtype=\"auto\",\n                # device_map=\"auto\", # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )\n            state_dict = model.state_dict()\n\n    return architectures, model, state_dict, is_value_model\n\n\ndef get_hf_model_path(config, local_cache_path=\"~/.cache/verl/rlhf\"):\n    local_cache_path = os.path.expanduser(local_cache_path)\n    if config.model.path.startswith(\"hdfs:\"):\n        from verl.utils.fs import copy_to_local\n\n        local_model_path = copy_to_local(\n            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get(\"use_shm\", False)\n        )\n    else:\n        local_model_path = config.model.path\n    return local_model_path\n\n\ndef load_megatron_model_weights(\n    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path=\"~/.cache/verl/rlhf\"\n):\n    \"\"\"Load weights for verl customized model.\"\"\"\n    architectures, model, state_dict, is_value_model = _load_hf_model(\n        config, model_config, is_value_model, 
local_cache_path\n    )\n\n    from verl.models.weight_loader_registry import get_weight_loader\n\n    print(f\"before weight loader: architectures = {architectures}...\")\n    for arch in architectures:\n        print(f\"call weight loader arch = {arch}, model config = {model.config}\")\n        weight_loader = get_weight_loader(arch)\n        weight_loader(\n            state_dict=state_dict,\n            wrapped_models=parallel_model,\n            config=model.config,\n            params_dtype=params_dtype,\n            is_value_model=is_value_model,\n            tie_word_embeddings=model_config.tie_word_embeddings,\n        )\n    return model.config\n\n\ndef load_megatron_gptmodel_weights(\n    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path=\"~/.cache/verl/rlhf\"\n):\n    \"\"\"Load weights for mcore GPT model.\"\"\"\n    _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path)\n\n    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    load_state_dict_to_megatron_gptmodel(\n        state_dict=state_dict,\n        wrapped_models=parallel_model,\n        config=model.config,\n        params_dtype=params_dtype,\n        is_value_model=is_value_model,\n    )\n    del state_dict, model\n\n\n# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp\ndef pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):\n    \"\"\"pad the tokens such that the total length is a multiple of size.\n    This function is useful when applying sequence parallelism and context parallelism.\n\n    Args:\n        unpad_tokens: (total_nnz, ...). Tokens after removing padding\n        cu_seqlens: (num_seqs + 1,). Cumulative sequence lengths\n        max_seqlen_in_batch: int\n\n    Returns:\n        The padded tokens, the updated cu_seqlens, and the updated max_seqlen_in_batch\n\n    \"\"\"\n    F = nn.functional\n\n    total_nnz = unpad_tokens.shape[0]\n\n    pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size\n\n    # treat the padding as one extra sequence of length pad_size appended to the batch\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim} is not supported\")\n\n        cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])\n        max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)\n\n    return unpad_tokens, cu_seqlens, max_seqlen_in_batch\n\n\ndef load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):\n    from megatron.core import dist_checkpointing\n    from megatron.core.dist_checkpointing.serialization import StrictHandling\n\n    from verl.utils.megatron_utils import unwrap_model\n\n    # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED\n    strict = StrictHandling.ASSUME_OK_UNEXPECTED\n    for model in parallel_model:\n        ssd = unwrap_model(model).sharded_state_dict()\n        if is_value_model:\n            for k in list(ssd.keys()):\n                if \"output_layer\" in k:\n                    ssd.pop(k)\n        dist_checkpointing.load(ssd, dist_weight_path, strict=strict)\n\n    return\n\n\ndef get_parallel_gptmodel_from_config(\n    tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from 
megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec\n    from megatron.core.models.gpt.gpt_model import GPTModel\n\n    use_te = True\n    assert tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n    transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te)\n    rope_scaling_args = {}\n    if hf_config.rope_scaling is not None:\n        assert hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n        rope_scaling_args[\"seq_len_interpolation_factor\"] = hf_config.rope_scaling[\"factor\"]\n    parallel_model = GPTModel(\n        config=tfconfig,\n        transformer_layer_spec=transformer_layer_spec,\n        vocab_size=hf_config.vocab_size,\n        max_sequence_length=hf_config.max_position_embeddings,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        position_embedding_type=\"rope\",\n        rotary_base=hf_config.rope_theta,\n        **rope_scaling_args,\n    )\n    # for layer in parallel_model.decoder.layers:\n    #     layer.self_attention.core_attention.flash_attention.softmax_scale = None\n    if post_process and value:\n        from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n        parallel_model.output_layer = LinearForLastLayer(\n            input_size=tfconfig.hidden_size, output_size=1, config=tfconfig\n        )\n    return parallel_model\n\n\ndef patch_valuehead_model(model) -> None:\n    from types import MethodType\n\n    from transformers import PreTrainedModel\n    from trl import AutoModelForCausalLMWithValueHead\n\n    def tie_weights(self: \"AutoModelForCausalLMWithValueHead\") -> None:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            self.pretrained_model.tie_weights()\n\n    def get_input_embeddings(self: \"AutoModelForCausalLMWithValueHead\") -> torch.nn.Module:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            return self.pretrained_model.get_input_embeddings()\n\n    def get_output_embeddings(self: \"AutoModelForCausalLMWithValueHead\") -> torch.nn.Module:\n        if isinstance(self.pretrained_model, PreTrainedModel):\n            return self.pretrained_model.get_output_embeddings()\n\n    def can_generate(self):\n        return False\n\n    ignore_modules = [name for name, _ in model.named_parameters() if \"pretrained_model\" in name]\n    model._keys_to_ignore_on_save = ignore_modules\n    model.tie_weights = MethodType(tie_weights, model)\n    model.get_input_embeddings = MethodType(get_input_embeddings, model)\n    model.get_output_embeddings = MethodType(get_output_embeddings, model)\n    model.can_generate = MethodType(can_generate, model)\n    model._no_split_modules = getattr(model.pretrained_model, \"_no_split_modules\", [])\n\n\ndef load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):\n    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq\n\n    try:\n        model = AutoModelForTokenClassification.from_pretrained(\n            pretrained_model_name_or_path=local_path,\n            torch_dtype=torch_dtype,\n            config=model_config,\n            attn_implementation=\"flash_attention_2\",\n            trust_remote_code=trust_remote_code,\n        )\n        return model\n    except BaseException as e:\n        if not is_trl_available():\n      
      raise RuntimeError(\n                f\"model({local_path}) is not a value head model; please install trl to load it as one\"\n            ) from e\n\n    assert is_trl_available()\n\n    from trl import AutoModelForCausalLMWithValueHead\n\n    if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():\n        module_class = AutoModelForVision2Seq\n    else:\n        module_class = AutoModelForCausalLM\n    ori_model = module_class.from_pretrained(\n        pretrained_model_name_or_path=local_path,\n        torch_dtype=torch_dtype,\n        config=model_config,\n        attn_implementation=\"flash_attention_2\",\n        trust_remote_code=trust_remote_code,\n    )\n    model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)\n    patch_valuehead_model(model)\n    return model\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n"
  },
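  {
    "path": "verl_rl/examples/pad_packed_inputs_demo.py",
    "content": "# Editor's note: hypothetical usage sketch, not part of the original verl tree;\n# the file path and names are editor assumptions. It re-implements the padding\n# contract of pad_packed_inputs so it runs standalone: pad packed tokens so the\n# total length is a multiple of `size` (e.g. the tp size), modeling the padding\n# as one extra sequence appended to cu_seqlens.\nimport torch\nimport torch.nn.functional as F\n\n\ndef pad_packed_inputs_sketch(unpad_tokens, cu_seqlens, max_seqlen_in_batch, size):\n    total_nnz = unpad_tokens.shape[0]\n    pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size\n    if pad_size > 0:\n        unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        # the padded region counts as one extra sequence of length pad_size\n        cu_seqlens = F.pad(cu_seqlens, (0, 1), value=int(cu_seqlens[-1]) + pad_size)\n        max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)\n    return unpad_tokens, cu_seqlens, max_seqlen_in_batch\n\n\nif __name__ == \"__main__\":\n    # two packed sequences of lengths 3 and 4 -> total_nnz = 7\n    tokens = torch.arange(7)\n    cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)\n    tokens, cu_seqlens, max_len = pad_packed_inputs_sketch(tokens, cu_seqlens, 4, size=4)\n    assert tokens.shape[0] == 8  # padded up to a multiple of size=4\n    assert cu_seqlens.tolist() == [0, 3, 7, 8]  # one extra sequence of length 1\n    assert max_len == 4\n    print(\"padded:\", tokens.tolist(), cu_seqlens.tolist(), max_len)\n"
  },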
  {
    "path": "verl_rl/verl/utils/net_utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport ipaddress\n\n\ndef is_ipv4(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv4 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv4 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv4Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n\n\ndef is_ipv6(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv6 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv6 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv6Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n"
  },
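  {
    "path": "verl_rl/examples/net_utils_demo.py",
    "content": "# Editor's note: hypothetical usage sketch for net_utils, not part of the original\n# verl tree; the file path and import path are inferred from the repo layout.\n# Both helpers return False for malformed input instead of raising, which makes\n# them safe for quick branching on address families.\nfrom verl.utils.net_utils import is_ipv4, is_ipv6\n\nif __name__ == \"__main__\":\n    assert is_ipv4(\"127.0.0.1\") and not is_ipv6(\"127.0.0.1\")\n    assert is_ipv6(\"::1\") and not is_ipv4(\"::1\")\n    # hostnames and malformed strings are rejected rather than raising\n    assert not is_ipv4(\"localhost\") and not is_ipv6(\"localhost\")\n    print(\"net_utils checks passed\")\n"
  },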
  {
    "path": "verl_rl/verl/utils/profiler/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom ..device import is_npu_available\nfrom ..import_utils import is_nvtx_available\nfrom .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom .profile import DistProfilerExtension, ProfilerConfig\n\nif is_nvtx_available():\n    from .nvtx_profile import NsightSystemsProfiler as DistProfiler\n    from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer\nelif is_npu_available:\n    from .mstx_profile import NPUProfiler as DistProfiler\n    from .mstx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer\nelse:\n    from .performance import marked_timer\n    from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range\n\n__all__ = [\n    \"GPUMemoryLogger\",\n    \"log_gpu_memory_usage\",\n    \"mark_start_range\",\n    \"mark_end_range\",\n    \"mark_annotate\",\n    \"DistProfiler\",\n    \"DistProfilerExtension\",\n    \"ProfilerConfig\",\n    \"simple_timer\",\n    \"marked_timer\",\n]\n"
  },
  {
    "path": "verl_rl/verl/utils/profiler/config.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import ClassVar\n\nfrom verl.base_config import BaseConfig\n\n\n@dataclass\nclass ProfilerConfig(BaseConfig):\n    \"\"\"Worker profiler config. Currently only the Nsight Systems profiler is supported.\n\n    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.\n\n    Args:\n        discrete (bool): If True, each task uses its own profiling database; if False, all tasks in one\n          training step share one database.\n        all_ranks (bool): Whether to profile all ranks.\n        ranks (list[int]): The ranks that will be profiled. Defaults to [].\n    \"\"\"\n\n    # the fields expected to be frozen\n    _frozen_fields: ClassVar[set[str]] = {\"discrete\", \"all_ranks\", \"ranks\"}\n\n    discrete: bool = False\n\n    all_ranks: bool = False\n\n    ranks: list[int] = field(default_factory=list)\n\n    def union(self, other: \"ProfilerConfig\") -> \"ProfilerConfig\":\n        return ProfilerConfig(\n            all_ranks=self.all_ranks or other.all_ranks,\n            ranks=list(set(self.ranks or []) | set(other.ranks or [])),\n            discrete=self.discrete or other.discrete,\n        )\n\n    def intersect(self, other: \"ProfilerConfig\") -> \"ProfilerConfig\":\n        return ProfilerConfig(\n            all_ranks=self.all_ranks and other.all_ranks,\n            ranks=list(set(self.ranks or []) & set(other.ranks or [])),\n            discrete=self.discrete and other.discrete,\n        )\n\n    def __post_init__(self) -> None:\n        \"\"\"config validation logics go here\"\"\"\n        assert isinstance(self.ranks, set | list | tuple), (\n            f\"Profiler ranks must be a list, tuple, or set, got {type(self.ranks)}\"\n        )\n"
  },
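  {
    "path": "verl_rl/examples/profiler_config_demo.py",
    "content": "# Editor's note: hypothetical usage sketch, not part of the original verl tree;\n# the file path is an editor assumption. It illustrates the merge semantics of\n# ProfilerConfig: union() ORs the flags and merges the rank sets, intersect()\n# ANDs the flags and keeps only the ranks requested by both sides.\nfrom verl.utils.profiler.config import ProfilerConfig\n\nif __name__ == \"__main__\":\n    actor = ProfilerConfig(discrete=True, ranks=[0, 1])\n    critic = ProfilerConfig(discrete=False, ranks=[1, 2])\n    merged = actor.union(critic)\n    assert merged.discrete and sorted(merged.ranks) == [0, 1, 2]\n    common = actor.intersect(critic)\n    assert not common.discrete and common.ranks == [1]\n    print(\"ProfilerConfig merge semantics verified\")\n"
  },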
  {
    "path": "verl_rl/verl/utils/profiler/empty_annotations.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    pass\n\n\ndef mark_end_range(range_id: str) -> None:\n    pass\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    def decorator(func):\n        return func\n\n    return decorator\n"
  },
  {
    "path": "verl_rl/verl/utils/profiler/mstx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Inspired from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py\nimport functools\nimport logging\nimport os\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Optional\n\nimport torch_npu\nfrom omegaconf import DictConfig\nfrom torch_npu.npu import mstx\n\nfrom .profile import DistProfiler, ProfilerConfig\n\n\ndef mark_start_range(message: Optional[str] = None) -> None:\n    \"\"\"Start a mark range in the profiler.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n    \"\"\"\n    return mstx.range_start(message=message)\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a mark range in the profiler.\n\n    Args:\n        range_id (str):\n            The id of the mark range to end.\n    \"\"\"\n    return mstx.range_end(range_id)\n\n\ndef mark_annotate(message: Optional[str] = None) -> Callable:\n    \"\"\"Decorate a function to annotate a mark range along with the function life cycle.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n    \"\"\"\n\n    def decorator(func):\n        profile_message = message or func.__name__\n        return mstx.mstx_range(profile_message)(func)\n\n    return decorator\n\n\n@contextmanager\ndef marked_timer(name: str, timing_raw: dict[str, float], *args: Any, **kwargs: Any) -> None:\n    \"\"\"Context manager for timing with MSTX markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds MSTX markers for profiling.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    if args:\n        logging.warning(f\"Args are not supported in mstx_profile, but received: {args}\")\n    if kwargs:\n        logging.warning(f\"Kwargs are not supported in mstx_profile, but received: {kwargs}\")\n    mark_range = mark_start_range(message=name)\n    from .performance import _timer\n\n    yield from _timer(name, timing_raw)\n    mark_end_range(mark_range)\n\n\ndef get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_step: Optional[str] = None):\n    \"\"\"Generate and return an NPU profiler object.\n\n    Args:\n        option (DictConfig):\n            The options to control npu profiler.\n        role (str, optional):\n            The role of the current data collection. Defaults to None.\n        profile_step(str, optional):\n            The current training step. 
Defaults to None.\n    \"\"\"\n    if option.level == \"level_none\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level_none\n    elif option.level == \"level0\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level0\n    elif option.level == \"level1\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level1\n    elif option.level == \"level2\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level2\n    else:\n        raise ValueError(f\"level only supports level0, 1, 2, and level_none, but got {option.level}\")\n\n    profile_save_path = option.save_path\n    if profile_step:\n        profile_save_path = os.path.join(profile_save_path, profile_step)\n    if role:\n        profile_save_path = os.path.join(profile_save_path, role)\n\n    experimental_config = torch_npu.profiler._ExperimentalConfig(\n        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n        profiler_level=profile_level,\n        export_type=torch_npu.profiler.ExportType.Text,\n        data_simplification=True,\n        msprof_tx=True,\n    )\n\n    activities = []\n    if option.with_npu:\n        activities.append(torch_npu.profiler.ProfilerActivity.NPU)\n    if option.with_cpu:\n        activities.append(torch_npu.profiler.ProfilerActivity.CPU)\n\n    prof = torch_npu.profiler.profile(\n        with_modules=option.with_module,\n        with_stack=option.with_stack,\n        record_shapes=option.record_shapes,\n        profile_memory=option.with_memory,\n        activities=activities,\n        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=option.analysis),\n        experimental_config=experimental_config,\n    )\n    return prof\n\n\nclass NPUProfiler(DistProfiler):\n    \"\"\"\n    NPU profiler. Initialized in a worker to control the NPU profiler.\n    \"\"\"\n\n    _define_count = 0\n\n    def __init__(self, rank: int, config: ProfilerConfig, **kwargs):\n        \"\"\"Initialize the NPUProfiler.\n\n        Args:\n            rank (int): The rank of the current process.\n            config (Optional[ProfilerConfig]): Configuration for the profiler. 
If None, a default configuration is used.\n        \"\"\"\n        if not config:\n            config = ProfilerConfig(ranks=[])\n        self.this_step: bool = False\n        self.discrete: bool = config.discrete\n        self.this_rank: bool = False\n        self.profile_npu = None\n        self.profile_option = kwargs.get(\"option\", None)\n        if config.all_ranks:\n            self.this_rank = True\n        elif config.ranks:\n            self.this_rank = rank in config.ranks\n\n    def start(self, **kwargs):\n        role, profile_step = kwargs.get(\"role\", None), kwargs.get(\"profile_step\", None)\n        profile_step = str(profile_step) if profile_step is not None else None\n        if self.this_rank and self.profile_option is not None:\n            self.this_step = True\n            if not self.discrete and NPUProfiler._define_count == 0:\n                self.profile_npu = get_npu_profiler(option=self.profile_option, role=role, profile_step=profile_step)\n                self.profile_npu.start()\n                NPUProfiler._define_count += 1\n\n    def stop(self):\n        if self.this_rank and self.profile_option is not None:\n            self.this_step = False\n            if not self.discrete and NPUProfiler._define_count == 1:\n                self.profile_npu.step()\n                self.profile_npu.stop()\n                NPUProfiler._define_count -= 1\n\n    @staticmethod\n    def annotate(message: Optional[str] = None, role: Optional[str] = None, **kwargs) -> Callable:\n        \"\"\"Decorate a Worker member function to profile the current rank in the current training step.\n\n        Requires the target function to be a member function of a Worker,\n        which has a member field `profiler` with NPUProfiler type.\n\n        Args:\n            message (str, optional):\n                The message to be displayed in the profiler. Defaults to None.\n            role (str, optional):\n                The role of the current data collection. Defaults to None.\n        \"\"\"\n\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(self, *args, **kwargs):\n                profile_name = message or func.__name__\n\n                # `self` here is the Worker; the profiler options live on its profiler\n                if self.profiler.this_step and self.profiler.profile_option is not None:\n                    if self.profiler.discrete:\n                        profile_npu = get_npu_profiler(option=self.profiler.profile_option, role=role)\n                        profile_npu.start()\n                    mark_range = mark_start_range(message=profile_name)\n\n                result = func(self, *args, **kwargs)\n\n                if self.profiler.this_step and self.profiler.profile_option is not None:\n                    mark_end_range(mark_range)\n                    if self.profiler.discrete:\n                        profile_npu.step()\n                        profile_npu.stop()\n\n                return result\n\n            return wrapper\n\n        return decorator\n"
  },
  {
    "path": "verl_rl/verl/utils/profiler/nvtx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nfrom contextlib import contextmanager\nfrom typing import Callable, Optional\n\nimport nvtx\nimport torch\n\nfrom .profile import DistProfiler, ProfilerConfig\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    \"\"\"Start a mark range in the profiler.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n        color (str, optional):\n            The color of the range. Defaults to None.\n        domain (str, optional):\n            The domain of the range. Defaults to None.\n        category (str, optional):\n            The category of the range. Defaults to None.\n    \"\"\"\n    return nvtx.start_range(message=message, color=color, domain=domain, category=category)\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a mark range in the profiler.\n\n    Args:\n        range_id (str):\n            The id of the mark range to end.\n    \"\"\"\n    return nvtx.end_range(range_id)\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    \"\"\"Decorate a function to annotate a mark range along with the function life cycle.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n        color (str, optional):\n            The color of the range. Defaults to None.\n        domain (str, optional):\n            The domain of the range. Defaults to None.\n        category (str, optional):\n            The category of the range. Defaults to None.\n    \"\"\"\n\n    def decorator(func):\n        profile_message = message or func.__name__\n        return nvtx.annotate(profile_message, color=color, domain=domain, category=category)(func)\n\n    return decorator\n\n\n@contextmanager\ndef marked_timer(\n    name: str,\n    timing_raw: dict[str, float],\n    color: str = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n):\n    \"\"\"Context manager for timing with NVTX markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds NVTX markers for profiling.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n        color (Optional[str]): Color for the NVTX marker. Defaults to None.\n        domain (Optional[str]): Domain for the NVTX marker. Defaults to None.\n        category (Optional[str]): Category for the NVTX marker. 
Defaults to None.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    mark_range = mark_start_range(message=name, color=color, domain=domain, category=category)\n    from .performance import _timer\n\n    yield from _timer(name, timing_raw)\n    mark_end_range(mark_range)\n\n\nclass NsightSystemsProfiler(DistProfiler):\n    \"\"\"Nsight system profiler. Installed in a worker to control the Nsight system profiler.\"\"\"\n\n    def __init__(self, rank: int, config: Optional[ProfilerConfig], **kwargs):\n        \"\"\"Initialize the NsightSystemsProfiler.\n\n        Args:\n            rank (int): The rank of the current process.\n            config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used.\n        \"\"\"\n        # If no configuration is provided, create a default ProfilerConfig with an empty list of ranks\n        if not config:\n            config = ProfilerConfig(ranks=[])\n        self.this_step: bool = False\n        self.discrete: bool = config.discrete\n        self.this_rank: bool = False\n        if config.all_ranks:\n            self.this_rank = True\n        elif config.ranks:\n            self.this_rank = rank in config.ranks\n\n    def start(self, **kwargs):\n        if self.this_rank:\n            self.this_step = True\n            if not self.discrete:\n                torch.cuda.profiler.start()\n\n    def stop(self):\n        if self.this_rank:\n            self.this_step = False\n            if not self.discrete:\n                torch.cuda.profiler.stop()\n\n    @staticmethod\n    def annotate(\n        message: Optional[str] = None,\n        color: Optional[str] = None,\n        domain: Optional[str] = None,\n        category: Optional[str] = None,\n        **kwargs,\n    ) -> Callable:\n        \"\"\"Decorate a Worker member function to profile the current rank in the current training step.\n\n        Requires the target function to be a member function of a Worker, which has a member field `profiler` with\n        NsightSystemsProfiler type.\n\n        Args:\n            message (str, optional):\n                The message to be displayed in the profiler. Defaults to None.\n            color (str, optional):\n                The color of the range. Defaults to None.\n            domain (str, optional):\n                The domain of the range. Defaults to None.\n            category (str, optional):\n                The category of the range. Defaults to None.\n        \"\"\"\n\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(self, *args, **kwargs):\n                profile_name = message or func.__name__\n\n                if self.profiler.this_step:\n                    if self.profiler.discrete:\n                        torch.cuda.profiler.start()\n                    mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category)\n\n                result = func(self, *args, **kwargs)\n\n                if self.profiler.this_step:\n                    mark_end_range(mark_range)\n                    if self.profiler.discrete:\n                        torch.cuda.profiler.stop()\n\n                return result\n\n            return wrapper\n\n        return decorator\n"
  },
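  {
    "path": "verl_rl/examples/profiler_annotate_demo.py",
    "content": "# Editor's note: hypothetical usage sketch, not part of the original verl tree;\n# the file path and the ToyWorker class are editor assumptions. It shows the\n# worker-side contract shared by the annotate decorators: the decorated method\n# must live on an object exposing a `profiler` attribute. The no-op base\n# DistProfiler is used so the sketch runs without nvtx or torch_npu installed.\nfrom verl.utils.profiler.profile import DistProfiler\n\n\nclass ToyWorker:\n    def __init__(self):\n        # real workers hold an NsightSystemsProfiler / NPUProfiler here\n        self.profiler = DistProfiler(rank=0)\n\n    @DistProfiler.annotate(message=\"toy_update\")\n    def update(self, x):\n        return x + 1\n\n\nif __name__ == \"__main__\":\n    worker = ToyWorker()\n    worker.profiler.start()\n    assert worker.update(41) == 42\n    worker.profiler.stop()\n    print(\"annotate pattern ok\")\n"
  },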
  {
    "path": "verl_rl/verl/utils/profiler/performance.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport datetime\nimport inspect\nimport logging\nfrom contextlib import contextmanager\nfrom typing import Any, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom codetiming import Timer\n\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.logger import DecoratorLoggerBase\n\n\ndef _get_current_mem_info(unit: str = \"GB\", precision: int = 2) -> tuple[str, str, str, str]:\n    \"\"\"Get current memory usage.\"\"\"\n    assert unit in [\"GB\", \"MB\", \"KB\"]\n    divisor = 1024**3 if unit == \"GB\" else 1024**2 if unit == \"MB\" else 1024\n    mem_allocated = get_torch_device().memory_allocated()\n    mem_reserved = get_torch_device().memory_reserved()\n    # use get_torch_device().mem_get_info to profile device memory\n    # since vllm's sleep mode works below pytorch\n    # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119\n    mem_free, mem_total = get_torch_device().mem_get_info()\n    mem_used = mem_total - mem_free\n    mem_allocated = f\"{mem_allocated / divisor:.{precision}f}\"\n    mem_reserved = f\"{mem_reserved / divisor:.{precision}f}\"\n    mem_used = f\"{mem_used / divisor:.{precision}f}\"\n    mem_total = f\"{mem_total / divisor:.{precision}f}\"\n    return mem_allocated, mem_reserved, mem_used, mem_total\n\n\ndef log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):\n    \"\"\"Log GPU memory usage information.\n\n    Args:\n        head (str): A descriptive header for the memory usage log message.\n        logger (logging.Logger, optional): Logger instance to use for logging. If None, prints to stdout.\n        level: Logging level to use. Defaults to logging.DEBUG.\n        rank (int): The rank of the process to log memory for. Defaults to 0.\n    \"\"\"\n    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n\n        if logger is None:\n            print(message)\n        else:\n            logger.log(msg=message, level=level)\n\n\nclass GPUMemoryLogger(DecoratorLoggerBase):\n    \"\"\"A decorator class to log GPU memory usage.\n\n    Example:\n        >>> from verl.utils.profiler.performance import GPUMemoryLogger\n        >>> @GPUMemoryLogger(role=\"actor\")\n        >>> def update_actor(self, batch):\n        ...     # real actor update logics\n        ...     
return\n    \"\"\"\n\n    def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True):\n        if dist.is_initialized() and dist.get_world_size() > 1:\n            rank = dist.get_rank()\n        else:\n            rank = 0\n        super().__init__(role, logger, level, rank, log_only_rank_0)\n\n    def __call__(self, decorated_function: callable):\n        def f(*args, **kwargs):\n            return self.log(decorated_function, *args, **kwargs)\n\n        return f\n\n    def log(self, func, *args, **kwargs):\n        name = func.__name__\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n        self.logging_function(message)\n\n        output = func(*args, **kwargs)\n\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = (\n            f\"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, \"\n            f\"device memory used/total (GB): {mem_used}/{mem_total}\"\n        )\n\n        self.logging_function(message)\n        return output\n\n\ndef log_print(ctn: Any):\n    current_time = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n\n    frame = inspect.currentframe().f_back\n    function_name = frame.f_code.co_name\n    line_number = frame.f_lineno\n    file_name = frame.f_code.co_filename.split(\"/\")[-1]\n    print(f\"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}\")\n\n\ndef _timer(name: str, timing_raw: dict[str, float]):\n    \"\"\"Inner function that handles the core timing logic.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n    \"\"\"\n    with Timer(name=name, logger=None) as timer:\n        yield\n    if name not in timing_raw:\n        timing_raw[name] = 0\n    timing_raw[name] += timer.last\n\n\n@contextmanager\ndef simple_timer(name: str, timing_raw: dict[str, float]):\n    \"\"\"Context manager for basic timing without NVTX markers.\n\n    This utility function measures the execution time of code within its context\n    and accumulates the timing information in the provided dictionary.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    yield from _timer(name, timing_raw)\n\n\n@contextmanager\ndef marked_timer(\n    name: str,\n    timing_raw: dict[str, float],\n    color: str = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n):\n    \"\"\"Context manager for timing with platform markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds platform markers for profiling.\n    This function is a default implementation when hardware profiler is not available.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n        color (Optional[str]): Color for the marker. 
Defaults to None.\n        domain (Optional[str]): Domain for the marker. Defaults to None.\n        category (Optional[str]): Category for the marker. Defaults to None.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    yield from _timer(name, timing_raw)\n\n\ndef reduce_timing(timing_raw: dict[str, float]) -> dict[str, float]:\n    \"\"\"Reduce timing information across all processes.\n\n    This function uses distributed communication to gather and sum the timing\n    information from all processes in a distributed environment.\n\n    Args:\n        timing_raw (Dict[str, float]): Dictionary containing timing information.\n\n    Returns:\n        Dict[str, float]: Reduced timing information.\n    \"\"\"\n    if not dist.is_initialized():\n        return timing_raw\n\n    key_list, timing_list = [], []\n    for key in sorted(timing_raw.keys()):\n        key_list.append(key)\n        timing_list.append(timing_raw[key])\n    timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id())\n    torch.distributed.all_reduce(timing_list, op=torch.distributed.ReduceOp.AVG)\n    timing_list = [tensor.item() for tensor in timing_list.to(\"cpu\")]\n    timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))}\n    return timing_generate\n"
  },
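  {
    "path": "verl_rl/examples/timing_demo.py",
    "content": "# Editor's note: hypothetical usage sketch, not part of the original verl tree;\n# the file path is an editor assumption. simple_timer accumulates wall-clock\n# seconds per key, so re-entering the same name across steps sums the durations;\n# marked_timer behaves the same here and only adds markers when a hardware\n# profiler backend is available.\nimport time\n\nfrom verl.utils.profiler.performance import simple_timer\n\nif __name__ == \"__main__\":\n    timing_raw = {}\n    for _ in range(2):\n        with simple_timer(\"gen\", timing_raw):\n            time.sleep(0.01)\n    assert timing_raw[\"gen\"] >= 0.02  # two entries accumulated under one key\n    print(f\"accumulated gen time: {timing_raw['gen']:.3f}s\")\n"
  },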
  {
    "path": "verl_rl/verl/utils/profiler/profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import Callable, Optional\n\nimport torch\nimport torch.distributed\n\nfrom .config import ProfilerConfig\n\n\nclass Profiler:\n    \"\"\"A PyTorch profiler wrapper class for collecting performance metrics.\n\n    TODO(haibin.lin): this should implement the DistProfiler interface, and the config should be unified.\n\n    This profiler provides a convenient interface for profiling PyTorch operations,\n    with support for:\n\n    - CPU and CUDA activity profiling\n    - Configurable profiling schedule (wait/warmup/active steps)\n    - Multi-rank profiling support\n    - Chrome trace export\n\n    Args:\n        config: Configuration object containing profiling parameters\n    \"\"\"\n\n    def __init__(self, config):\n        # Note: if use_profile is not set, it defaults to None and every profiler method becomes a no-op\n        self.config = config\n        self.skip_prof = False\n        self.saved = False\n        self.prof = None\n        self.rank = torch.distributed.get_rank()\n        # we need to validate the config before using the profiler\n        self._validate()\n        if config.use_profile and self.rank in self.config.profile_ranks:\n            print(f\"[Profiler] Profiler init for rank {self.rank}\")\n\n            self.prof = torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU,\n                    torch.profiler.ProfilerActivity.CUDA,\n                ],\n                schedule=torch.profiler.schedule(\n                    wait=max(self.config.step_start - 1, 0),\n                    warmup=1 if self.config.step_start > 0 else 0,\n                    active=self.config.step_end - self.config.step_start,\n                    repeat=1,\n                ),\n                record_shapes=True,\n                with_stack=True,\n            )\n\n    def _validate(self):\n        if self.config.use_profile:\n            if self.config.profile_ranks is None:\n                print(\"[WARNING] Profile ranks is not set, default to rank 0\")\n                self.config.profile_ranks = [0]\n            assert self.config.step_start >= 0, \"[ERROR] Profile step start must be non-negative\"\n            assert self.config.step_end >= 0, \"[ERROR] Profile step end must be non-negative\"\n            assert self.config.step_start < self.config.step_end, (\n                \"[ERROR] Profile step start must be less than step end\"\n            )\n\n    def check(self):\n        return self.prof is not None and not self.skip_prof\n\n    def start(self):\n        if self.check():\n            print(f\"[Profiler] started for rank {self.rank}\")\n            self.prof.start()\n\n    def step(self):\n        if self.check():\n            self.prof.step()\n\n    def stop(self):\n        if self.check():\n            print(f\"[Profiler] stopped for rank {self.rank}\")\n            
self.prof.stop()\n\n    def save(self):\n        if self.prof is not None and not self.saved:\n            if not os.path.exists(self.config.save_path):\n                os.makedirs(self.config.save_path)\n            save_file_name = f\"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json\"\n            print(f\"[Profiler] Saving trace to {self.config.save_path + save_file_name}\")\n            self.prof.export_chrome_trace(self.config.save_path + save_file_name)\n            self.skip_prof = True\n            self.saved = True\n\n    def stop_and_save(self):\n        if self.check():\n            self.stop()\n            self.save()\n\n    def stop_trace(self):\n        if self.check():\n            print(f\"[Profiler] Trace stopped for rank {self.rank}\")\n            self.skip_prof = True\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    \"\"\"Start a profiling range marker (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the range marker.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n    \"\"\"\n    pass\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a profiling range marker (no-op implementation).\n\n    Args:\n        range_id (str): Identifier of the range to end.\n    \"\"\"\n    pass\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    \"\"\"Decorator to annotate a function with profiling markers (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the annotation.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n\n    Returns:\n        Callable: Decorator function that returns the original function unchanged.\n    \"\"\"\n\n    def decorator(func):\n        return func\n\n    return decorator\n\n\nclass DistProfiler:\n    \"\"\"A distributed profiler class for collecting performance metrics across multiple ranks.\n\n    This profiler is designed to work in distributed training environments, allowing selective\n    profiling of specific ranks or all ranks. 
It provides basic start/stop functionality and\n    supports annotation of code sections for detailed profiling.\n\n    Args:\n        rank (int): The rank of the current process\n        config (ProfilerConfig, optional): Configuration for the profiler.\n    \"\"\"\n\n    def __init__(self, rank: int, config: Optional[ProfilerConfig] = None, **kwargs):\n        pass\n\n    def start(self, **kwargs):\n        pass\n\n    def stop(self):\n        pass\n\n    @staticmethod\n    def annotate(\n        message: Optional[str] = None,\n        color: Optional[str] = None,\n        domain: Optional[str] = None,\n        category: Optional[str] = None,\n        **kwargs,\n    ) -> Callable:\n        def decorator(func):\n            return func\n\n        return decorator\n\n\nclass DistProfilerExtension:\n    \"\"\"An extension class for DistProfiler that provides distributed profiling capabilities.\n\n    It is intended for workers in verl that single controller invokes.\n\n    This class wraps a DistProfiler instance and provides methods to start/stop profiling\n    that can be dispatched across multiple ranks in a distributed training environment.\n\n    Args:\n        profiler (DistProfiler): The base distributed profiler instance to extend\n    \"\"\"\n\n    def __init__(self, profiler: DistProfiler):\n        self.profiler = profiler\n\n    from verl.single_controller.base.decorator import Dispatch, register\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def start_profile(self, **kwargs) -> None:\n        \"\"\"Start profiling for the current rank in the current training step.\"\"\"\n        self.profiler.start(**kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def stop_profile(self) -> None:\n        \"\"\"Stop profiling for the current rank in the current training step.\"\"\"\n        self.profiler.stop()\n"
  },
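  {
    "path": "verl_rl/examples/profiler_schedule_demo.py",
    "content": "# Editor's note: hypothetical usage sketch, not part of the original verl tree;\n# the file path and the example step values are editor assumptions. It spells\n# out how the Profiler wrapper maps (step_start, step_end) onto the phases of\n# torch.profiler.schedule: wait skips steps, one warmup step precedes tracing,\n# and `active` steps are actually recorded.\nimport torch\n\nif __name__ == \"__main__\":\n    step_start, step_end = 2, 4\n    wait = max(step_start - 1, 0)  # steps skipped entirely\n    warmup = 1 if step_start > 0 else 0  # one warmup step before tracing\n    active = step_end - step_start  # steps actually recorded\n    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1)\n    # step indices 0..1 -> not recording, 2..3 -> recording\n    actions = [schedule(i) for i in range(step_end + 1)]\n    assert actions[3] in (torch.profiler.ProfilerAction.RECORD, torch.profiler.ProfilerAction.RECORD_AND_SAVE)\n    print([a.name for a in actions])\n"
  },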
  {
    "path": "verl_rl/verl/utils/py_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small python utility functions\n\"\"\"\n\nimport importlib\nimport multiprocessing\nimport os\nimport queue  # Import the queue module for exception type hint\nimport signal\nfrom contextlib import contextmanager\nfrom functools import wraps\nfrom types import SimpleNamespace\nfrom typing import Any, Callable, Iterator, Optional\n\n\n# --- Top-level helper for multiprocessing timeout ---\n# This function MUST be defined at the top level to be pickleable\ndef _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]):\n    \"\"\"\n    Internal wrapper function executed in the child process.\n    Calls the original target function and puts the result or exception into the queue.\n    \"\"\"\n    try:\n        result = target_func(*args, **kwargs)\n        mp_queue.put((True, result))  # Indicate success and put result\n    except Exception as e:\n        # Ensure the exception is pickleable for the queue\n        try:\n            import pickle\n\n            pickle.dumps(e)  # Test if the exception is pickleable\n            mp_queue.put((False, e))  # Indicate failure and put exception\n        except (pickle.PicklingError, TypeError):\n            # Fallback if the original exception cannot be pickled\n            mp_queue.put((False, RuntimeError(f\"Original exception type {type(e).__name__} not pickleable: {e}\")))\n\n\n# Renamed the function from timeout to timeout_limit\ndef timeout_limit(seconds: float, use_signals: bool = False):\n    \"\"\"\n    Decorator to add a timeout to a function.\n\n    Args:\n        seconds: The timeout duration in seconds.\n        use_signals: (Deprecated) This is deprecated because signals only work reliably in the main thread\n                     and can cause issues in multiprocessing or multithreading contexts.\n                     Defaults to False, which uses the more robust multiprocessing approach.\n\n    Returns:\n        A decorated function with timeout.\n\n    Raises:\n        TimeoutError: If the function execution exceeds the specified time.\n        RuntimeError: If the child process exits with an error (multiprocessing mode).\n        NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX).\n    \"\"\"\n\n    def decorator(func):\n        if use_signals:\n            if os.name != \"posix\":\n                raise NotImplementedError(f\"Unsupported OS: {os.name}\")\n            # Issue deprecation warning if use_signals is explicitly True\n            print(\n                \"WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \"\n                \"Signals are unreliable outside the main thread. \"
\n                \"Please use the default multiprocessing-based timeout (use_signals=False).\"\n            )\n\n            @wraps(func)\n            def wrapper_signal(*args, **kwargs):\n                def handler(signum, frame):\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (signal)!\")\n\n                old_handler = signal.getsignal(signal.SIGALRM)\n                signal.signal(signal.SIGALRM, handler)\n                # Use setitimer for float seconds support, alarm only supports integers\n                signal.setitimer(signal.ITIMER_REAL, seconds)\n\n                try:\n                    result = func(*args, **kwargs)\n                finally:\n                    # Reset timer and handler\n                    signal.setitimer(signal.ITIMER_REAL, 0)\n                    signal.signal(signal.SIGALRM, old_handler)\n                return result\n\n            return wrapper_signal\n        else:\n            # --- Multiprocessing based timeout (existing logic) ---\n            @wraps(func)\n            def wrapper_mp(*args, **kwargs):\n                q = multiprocessing.Queue(maxsize=1)\n                process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs))\n                process.start()\n                process.join(timeout=seconds)\n\n                if process.is_alive():\n                    process.terminate()\n                    process.join(timeout=0.5)  # Give it a moment to terminate\n                    if process.is_alive():\n                        print(f\"Warning: Process {process.pid} did not terminate gracefully after timeout.\")\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!\")\n\n                try:\n                    success, result_or_exc = q.get(timeout=0.1)  # Small timeout for queue read\n                    if success:\n                        return result_or_exc\n                    else:\n                        raise result_or_exc  # Reraise exception from child\n                except queue.Empty as err:\n                    exitcode = process.exitcode\n                    if exitcode is not None and exitcode != 0:\n                        raise RuntimeError(\n                            f\"Child process exited with error (exitcode: {exitcode}) before returning result.\"\n                        ) from err\n                    else:\n                        # Should have timed out if queue is empty after join unless process died unexpectedly\n                        # Update function name in error message if needed (optional but good practice)\n                        raise TimeoutError(\n                            f\"Operation timed out or process finished unexpectedly without result \"\n                            f\"(exitcode: {exitcode}).\"\n                        ) from err\n                finally:\n                    q.close()\n                    q.join_thread()\n\n            return wrapper_mp\n\n    return decorator\n\n\ndef union_two_dict(dict1: dict, dict2: dict):\n    \"\"\"Union two dicts. 
Raises an AssertionError if the two dicts map the same key to different values.\n\n    Args:\n        dict1: The dict that is updated in place.\n        dict2: The dict whose entries are merged into dict1.\n\n    Returns:\n        The updated dict1.\n    \"\"\"\n    for key, val in dict2.items():\n        if key in dict1:\n            assert dict2[key] == dict1[key], f\"{key} in dict1 and dict2 maps to different values\"\n        dict1[key] = val\n\n    return dict1\n\n\ndef append_to_dict(data: dict, new_data: dict):\n    \"\"\"Append values from new_data to lists in data.\n\n    For each key in new_data, this function appends the corresponding value to a list\n    stored under the same key in data. If the key doesn't exist in data, a new list is created.\n\n    Args:\n        data (Dict): The target dictionary containing lists as values.\n        new_data (Dict): The source dictionary with values to append.\n\n    Returns:\n        None: The function modifies data in-place.\n    \"\"\"\n    for key, val in new_data.items():\n        if key not in data:\n            data[key] = []\n        data[key].append(val)\n\n\nclass NestedNamespace(SimpleNamespace):\n    \"\"\"A nested version of SimpleNamespace that recursively converts dictionaries to namespaces.\n\n    This class allows for dot notation access to nested dictionary structures by recursively\n    converting dictionaries to NestedNamespace objects.\n\n    Example:\n        config_dict = {\"a\": 1, \"b\": {\"c\": 2, \"d\": 3}}\n        config = NestedNamespace(config_dict)\n        # Access with: config.a, config.b.c, config.b.d\n\n    Args:\n        dictionary: The dictionary to convert to a nested namespace.\n        **kwargs: Additional attributes to set on the namespace.\n    \"\"\"\n\n    def __init__(self, dictionary, **kwargs):\n        super().__init__(**kwargs)\n        for key, value in dictionary.items():\n            if isinstance(value, dict):\n                self.__setattr__(key, NestedNamespace(value))\n            else:\n                self.__setattr__(key, value)\n\n\nclass DynamicEnumMeta(type):\n    def __iter__(cls) -> Iterator[Any]:\n        return iter(cls._registry.values())\n\n    def __contains__(cls, item: Any) -> bool:\n        # allow `name in EnumClass` or `member in EnumClass`\n        if isinstance(item, str):\n            return item in cls._registry\n        return item in cls._registry.values()\n\n    def __getitem__(cls, name: str) -> Any:\n        return cls._registry[name]\n\n    def __reduce_ex__(cls, protocol):\n        # Always load the existing module and grab the class\n        return getattr, (importlib.import_module(cls.__module__), cls.__name__)\n\n    def names(cls):\n        return list(cls._registry.keys())\n\n    def values(cls):\n        return list(cls._registry.values())\n\n\nclass DynamicEnum(metaclass=DynamicEnumMeta):\n    _registry: dict[str, \"DynamicEnum\"] = {}\n    _next_value: int = 0\n\n    def __init__(self, name: str, value: int):\n        self.name = name\n        self.value = value\n\n    def __repr__(self):\n        return f\"<{self.__class__.__name__}.{self.name}: {self.value}>\"\n\n    def __reduce_ex__(self, protocol):\n        \"\"\"\n        Unpickle via: getattr(import_module(module).Dispatch, 'ONE_TO_ALL')\n        so the existing class is reused instead of re-executed.\n        \"\"\"\n        module = importlib.import_module(self.__class__.__module__)\n        enum_cls = getattr(module, self.__class__.__name__)\n        return getattr, (enum_cls, self.name)\n\n    @classmethod\n    def register(cls, name: str) -> \"DynamicEnum\":\n        key = 
name.upper()\n        if key in cls._registry:\n            raise ValueError(f\"{key} already registered\")\n        member = cls(key, cls._next_value)\n        cls._registry[key] = member\n        setattr(cls, key, member)\n        cls._next_value += 1\n        return member\n\n    @classmethod\n    def remove(cls, name: str):\n        key = name.upper()\n        member = cls._registry.pop(key)\n        delattr(cls, key)\n        return member\n\n    @classmethod\n    def from_name(cls, name: str) -> Optional[\"DynamicEnum\"]:\n        return cls._registry.get(name.upper())\n\n\n@contextmanager\ndef temp_env_var(key: str, value: str):\n    \"\"\"Context manager for temporarily setting an environment variable.\n\n    This context manager ensures that environment variables are properly set and restored,\n    even if an exception occurs during the execution of the code block.\n\n    Args:\n        key: Environment variable name to set\n        value: Value to set the environment variable to\n\n    Yields:\n        None\n\n    Example:\n        >>> with temp_env_var(\"MY_VAR\", \"test_value\"):\n        ...     # MY_VAR is set to \"test_value\"\n        ...     do_something()\n        ... # MY_VAR is restored to its original value or removed if it didn't exist\n    \"\"\"\n    original = os.environ.get(key)\n    os.environ[key] = value\n    try:\n        yield\n    finally:\n        if original is None:\n            os.environ.pop(key, None)\n        else:\n            os.environ[key] = original\n\n\ndef convert_to_regular_types(obj):\n    \"\"\"Convert Hydra configs and other special types to regular Python types.\"\"\"\n    from omegaconf import DictConfig, ListConfig\n\n    if isinstance(obj, ListConfig | DictConfig):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)\n    elif isinstance(obj, list | tuple):\n        return [convert_to_regular_types(x) for x in obj]\n    elif isinstance(obj, dict):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()}\n    return obj\n"
  },
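The utilities in `py_functional.py` above are easiest to see in action. The driver below is a minimal sketch (not part of the repo): it assumes `verl` is installed so the module is importable, and it relies on the fork start method that Linux uses by default, since under spawn the decorated function would need to be importable by the child process.

```python
import time

from verl.utils.py_functional import NestedNamespace, temp_env_var, timeout_limit


@timeout_limit(seconds=2.0)  # default path: multiprocessing-based timeout
def slow_add(a, b):
    time.sleep(1.0)  # finishes well within the 2-second budget
    return a + b


if __name__ == "__main__":
    print(slow_add(1, 2))  # 3; a sleep longer than 2 s would raise TimeoutError

    cfg = NestedNamespace({"a": 1, "b": {"c": 2, "d": 3}})
    print(cfg.b.c)  # 2 -- nested dicts become dot-accessible namespaces

    with temp_env_var("MY_VAR", "test_value"):
        pass  # MY_VAR is set here and restored (or removed) on exit
```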
  {
    "path": "verl_rl/verl/utils/ray_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains commonly used utilities for ray\n\"\"\"\n\nimport concurrent.futures\nimport os\nfrom typing import Any, Optional\n\nimport ray\n\n\ndef ray_noset_visible_devices(env_vars=os.environ):\n    # Refer to\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103\n    # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98\n    NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [\n        \"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES\",\n        \"RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES\",\n        \"RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS\",\n        \"RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR\",\n    ]\n    return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)\n\n\ndef parallel_put(data_list: list[Any], max_workers: Optional[int] = None):\n    \"\"\"\n    Puts a list of data into the Ray object store in parallel using a thread pool.\n\n    Args:\n        data_list (List[Any]): A list of Python objects to be put into the Ray object store.\n        max_workers (int, optional): The maximum number of worker threads to use.\n                                     Defaults to min(len(data_list), 16).\n\n    Returns:\n        List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list,\n                             maintaining the original order.\n    \"\"\"\n    assert len(data_list) > 0, \"data_list must not be empty\"\n\n    def put_data(index, data):\n        return index, ray.put(data)\n\n    if max_workers is None:\n        max_workers = min(len(data_list), 16)\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n        data_list_f = 
[executor.submit(put_data, i, data) for i, data in enumerate(data_list)]\n        res_lst = []\n        for future in concurrent.futures.as_completed(data_list_f):\n            res_lst.append(future.result())\n\n        # reorder based on index\n        output = [None for _ in range(len(data_list))]\n        for res in res_lst:\n            index, data_ref = res\n            output[index] = data_ref\n\n    return output\n"
  },
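A hypothetical driver for `parallel_put` (assumes a local Ray runtime can be started):

```python
import ray

from verl.utils.ray_utils import parallel_put

ray.init()

# Put 32 small payloads into the object store concurrently. Although results
# are collected with as_completed(), the refs are reordered by index, so they
# line up one-to-one with the inputs.
refs = parallel_put([{"rank": i} for i in range(32)])
assert ray.get(refs[3])["rank"] == 3

ray.shutdown()
```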
  {
    "path": "verl_rl/verl/utils/rendezvous/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/utils/rendezvous/ray_backend.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport time\n\nimport ray\nfrom cupy.cuda.nccl import NcclCommunicator, get_unique_id\nfrom ray.util import list_named_actors\n\n\n@ray.remote\nclass NCCLIDStore:\n    def __init__(self, nccl_id):\n        self._nccl_id = nccl_id\n\n    def get(self):\n        return self._nccl_id\n\n\ndef get_nccl_id_store_by_name(name):\n    all_actors = list_named_actors(all_namespaces=True)\n    matched_actors = [actor for actor in all_actors if actor.get(\"name\", None) == name]\n    if len(matched_actors) == 1:\n        actor = matched_actors[0]\n        return ray.get_actor(**actor)\n    elif len(matched_actors) > 1:\n        logging.warning(\"multiple actors with same name found: %s\", matched_actors)\n    elif len(matched_actors) == 0:\n        logging.info(\"failed to get any actor named %s\", name)\n    return None\n\n\ndef create_nccl_communicator_in_ray(\n    rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5\n):\n    if rank == 0:\n        nccl_id = get_unique_id()\n        nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)\n\n        assert ray.get(nccl_id_store.get.remote()) == nccl_id\n        communicator = NcclCommunicator(\n            ndev=world_size,\n            commId=nccl_id,\n            rank=0,\n        )\n        return communicator\n    else:\n        for i in range(max_retries):\n            nccl_id_store = get_nccl_id_store_by_name(group_name)\n            if nccl_id_store is not None:\n                logging.info(\"nccl_id_store %s got\", group_name)\n                nccl_id = ray.get(nccl_id_store.get.remote())\n                logging.info(\"nccl id for %s got: %s\", group_name, nccl_id)\n                communicator = NcclCommunicator(\n                    ndev=world_size,\n                    commId=nccl_id,\n                    rank=rank,\n                )\n                return communicator\n            logging.info(\"failed to get nccl_id for %d time, sleep for %d seconds\", i + 1, interval_s)\n            time.sleep(interval_s)\n"
  },
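The rendezvous above could be driven as follows (an illustrative sketch, not part of the repo: the group name `demo_group` is made up, and the snippet assumes two GPUs plus a `cupy` build with NCCL support):

```python
import ray

from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray


@ray.remote(num_gpus=1)
def build_comm(rank: int, world_size: int) -> bool:
    # Rank 0 publishes the NCCL unique id through the named NCCLIDStore actor;
    # the other ranks poll get_nccl_id_store_by_name() until it shows up.
    comm = create_nccl_communicator_in_ray(rank, world_size, group_name="demo_group")
    return comm is not None


ray.init()
print(ray.get([build_comm.remote(r, 2) for r in range(2)]))  # [True, True]
```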
  {
    "path": "verl_rl/verl/utils/reward_score/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# from . import gsm8k, math, prime_math, prime_code\n\nfrom verl.utils.import_utils import deprecated\n\n\ndef default_compute_score(\n    data_source,\n    solution_str,\n    ground_truth,\n    extra_info=None,\n    sandbox_fusion_url=None,\n    concurrent_semaphore=None,\n    memory_limit_mb=None,\n):\n    \"\"\"Compute the score for a given solution based on the data source.\n\n    Args:\n        data_source (str): The source dataset identifier which determines the scoring method.\n        solution_str (str): The solution string to be evaluated.\n        ground_truth (str): The ground truth answer for comparison.\n        extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.\n\n    Returns:\n        float: The computed score as a floating point number. If the result is a dictionary,\n               it returns the dictionary instead.\n\n    Raises:\n        NotImplementedError: If the reward function is not implemented for the given data source.\n    \"\"\"\n    if data_source == \"openai/gsm8k\":\n        from . import gsm8k\n\n        res = gsm8k.compute_score(solution_str, ground_truth)\n    elif data_source in [\"lighteval/MATH\", \"DigitalLearningGmbH/MATH-lighteval\", \"HuggingFaceH4/MATH-500\"]:\n        from . import math\n\n        res = math.compute_score(solution_str, ground_truth)\n        # [Optional] Math-Verify Integration\n        # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).\n        # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.\n        # To use it, override the `compute_score` function with the following implementation:\n\n        # from . import math_verify\n        # res = math_verify.compute_score(solution_str, ground_truth)\n    elif data_source == \"math_dapo\" or data_source.startswith(\"aime\"):\n        from . import math_dapo\n\n        res = math_dapo.compute_score(solution_str, ground_truth)\n    elif data_source in [\n        \"numina_aops_forum\",\n        \"numina_synthetic_math\",\n        \"numina_amc_aime\",\n        \"numina_synthetic_amc\",\n        \"numina_cn_k12\",\n        \"numina_olympiads\",\n    ]:\n        from . import prime_math\n\n        res = prime_math.compute_score(solution_str, ground_truth)\n    elif data_source in [\"codecontests\", \"apps\", \"codeforces\", \"taco\"]:\n        # Use the passed sandbox_fusion_url if available\n        if sandbox_fusion_url:\n            from . 
import sandbox_fusion\n\n            # Pass the URL directly, ground_truth likely contains test cases here\n            res = sandbox_fusion.compute_score(\n                sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True\n            )\n        else:\n            # If no sandbox URL is provided, fall back to prime_code or raise error\n            from . import prime_code\n\n            # Assuming prime_code doesn't need the URL\n            res = prime_code.compute_score(solution_str, ground_truth, continuous=True)\n    elif data_source in [\"hiyouga/geometry3k\"]:\n        from . import geo3k\n\n        res = geo3k.compute_score(solution_str, ground_truth)\n    elif data_source in [\n        \"searchR1_nq\",\n        \"searchR1_triviaqa\",\n        \"searchR1_popqa\",\n        \"searchR1_hotpotqa\",\n        \"searchR1_2wikimultihopqa\",\n        \"searchR1_musique\",\n        \"searchR1_bamboogle\",\n    ]:\n        from . import search_r1_like_qa_em\n\n        res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)\n\n    else:\n        raise NotImplementedError(f\"Reward function is not implemented for {data_source=}\")\n\n    if isinstance(res, dict):\n        return res\n    elif isinstance(res, int | float | bool):\n        return float(res)\n    else:\n        return float(res[0])\n\n\n@deprecated(\"verl.utils.reward_score.default_compute_score\")\ndef _default_compute_score(\n    data_source,\n    solution_str,\n    ground_truth,\n    extra_info=None,\n    sandbox_fusion_url=None,\n    concurrent_semaphore=None,\n    memory_limit_mb=None,\n):\n    \"\"\"\n    Legacy function API to be deprecated. Please use `default_compute_score` instead.\n    \"\"\"\n    return default_compute_score(\n        data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb\n    )\n\n\n__all__ = [\"default_compute_score\"]\n"
  },
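For example, routing a GSM8K-style completion through the dispatcher (illustrative values):

```python
from verl.utils.reward_score import default_compute_score

solution = "Half of 84 is 42, so the answer is 42.\n#### 42"
score = default_compute_score(
    data_source="openai/gsm8k",
    solution_str=solution,
    ground_truth="42",
)
print(score)  # 1.0 -- gsm8k's strict extractor found "#### 42" and it matched
```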
  {
    "path": "verl_rl/verl/utils/reward_score/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport re\n\nfrom mathruler.grader import extract_boxed_content, grade_answer\n\n\ndef format_reward(predict_str: str) -> float:\n    pattern = re.compile(r\"<think>.*</think>.*\\\\boxed\\{.*\\}.*\", re.DOTALL)\n    match_result = re.fullmatch(pattern, predict_str)\n    return 1.0 if match_result else 0.0\n\n\ndef acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> float:\n    if use_boxed:\n        answer = extract_boxed_content(predict_str)\n    else:\n        answer = predict_str\n    return 1.0 if grade_answer(answer, ground_truth) else 0.0\n\n\ndef compute_score(predict_str: str, ground_truth: str, use_boxed: bool = True, format_score: float = 0.1) -> float:\n    return (1.0 - format_score) * acc_reward(predict_str, ground_truth, use_boxed) + format_score * format_reward(\n        predict_str\n    )\n"
  },
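A quick sanity check of the weighting (requires the `mathruler` package; the strings are illustrative):

```python
from verl.utils.reward_score.geo3k import compute_score

pred = "<think>The angles of a triangle sum to 180.</think> So x = \\boxed{30}."
print(compute_score(pred, "30"))  # 0.9 * acc + 0.1 * format = 1.0

pred_no_think = "x = \\boxed{30}"
print(compute_score(pred_no_think, "30"))  # 0.9 -- correct answer, format miss
```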
  {
    "path": "verl_rl/verl/utils/reward_score/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n_SOLUTION_CLIP_CHARS = 300\n\n\ndef extract_solution(solution_str, method=\"strict\"):\n    assert method in [\"strict\", \"flexible\"]\n\n    # Optimization: Regular expression matching on very long strings can be slow.\n    # For math problems, the final answer is usually at the end.\n    # We only match on the last 300 characters, which is a safe approximation for 300 tokens.\n    if len(solution_str) > _SOLUTION_CLIP_CHARS:\n        solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]\n\n    if method == \"strict\":\n        # this also tests the formatting of the model\n        solutions = re.findall(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        if len(solutions) == 0:\n            final_answer = None\n        else:\n            # take the last solution\n            final_answer = solutions[-1].replace(\",\", \"\").replace(\"$\", \"\")\n    elif method == \"flexible\":\n        answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        final_answer = None\n        if len(answer) == 0:\n            # no reward is there is no answer\n            pass\n        else:\n            invalid_str = [\"\", \".\"]\n            # find the last number that is not '.'\n            for final_answer in reversed(answer):\n                if final_answer not in invalid_str:\n                    break\n    return final_answer\n\n\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\n    \"\"\"The scoring function for GSM8k.\n\n    Reference: Trung, Luong, et al. \"Reft: Reasoning with reinforced fine-tuning.\" Proceedings of the 62nd Annual\n    Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.\n\n    Args:\n        solution_str: the solution text\n        ground_truth: the ground truth\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\n        format_score: the score for the format\n        score: the score for the correct answer\n    \"\"\"\n    answer = extract_solution(solution_str=solution_str, method=method)\n    if answer is None:\n        return 0\n    else:\n        if answer == ground_truth:\n            return score\n        else:\n            return format_score\n"
  },
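The two extraction modes behave as follows (illustrative strings):

```python
from verl.utils.reward_score.gsm8k import compute_score, extract_solution

text = "6 * 7 = 42.\n#### 42"
print(extract_solution(text, method="strict"))           # '42' -- needs the "#### " marker
print(extract_solution("we get 42", method="flexible"))  # '42' -- last plausible number
print(extract_solution("no digits here", method="strict"))  # None -> score 0
print(compute_score(text, "42"))  # 1.0
```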
  {
    "path": "verl_rl/verl/utils/reward_score/math.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\n\ndef compute_score(solution_str, ground_truth) -> float:\n    retval = 0.0\n    try:\n        string_in_last_boxed = last_boxed_only_string(solution_str)\n        if string_in_last_boxed is not None:\n            answer = remove_boxed(string_in_last_boxed)\n            if is_equiv(answer, ground_truth):\n                retval = 1.0\n    except Exception as e:\n        print(e)\n\n    return retval\n\n\n# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py\ndef is_equiv(str1, str2, verbose=False):\n    if str1 is None and str2 is None:\n        print(\"WARNING: Both None\")\n        return True\n    if str1 is None or str2 is None:\n        return False\n\n    try:\n        ss1 = strip_string(str1)\n        ss2 = strip_string(str2)\n        if verbose:\n            print(ss1, ss2)\n        return ss1 == ss2\n    except Exception:\n        return str1 == str2\n\n\ndef remove_boxed(s):\n    if \"\\\\boxed \" in s:\n        left = \"\\\\boxed \"\n        assert s[: len(left)] == left\n        return s[len(left) :]\n\n    left = \"\\\\boxed{\"\n\n    assert s[: len(left)] == left\n    assert s[-1] == \"}\"\n\n    return s[len(left) : -1]\n\n\ndef last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if \"\\\\boxed \" in string:\n        return \"\\\\boxed \" + string.split(\"\\\\boxed \")[-1].split(\"$\")[0]\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1]\n\n    return retval\n\n\ndef fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:  # noqa: E722\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n  
                  else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:  # noqa: E722\n        return string\n\n\ndef remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")  # noqa: W605\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. 
Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = fix_a_slash_b(string)\n\n    return string\n"
  },
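The normalization pipeline above can be probed directly (illustrative):

```python
from verl.utils.reward_score.math import compute_score, is_equiv, strip_string

print(strip_string("\\frac12"))     # \frac{1}{2} -- shorthand fractions are expanded
print(is_equiv("\\frac12", "0.5"))  # True -- both sides normalize to \frac{1}{2}
print(compute_score("Thus the area is \\boxed{\\frac12}.", "0.5"))  # 1.0
```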
  {
    "path": "verl_rl/verl/utils/reward_score/math_batch.py",
    "content": "# Copyright 2025 Individual Contributor: Mert Unsal\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .math import compute_score\n\n\ndef compute_score_batched(data_sources, solution_strs, ground_truths, extra_infos):\n    \"\"\"\n    This is a demonstration of how the batched reward function should look like.\n    Typically, you want to use batched reward to speed up the process with parallelization\n    \"\"\"\n    return [\n        compute_score(solution_str, ground_truth)\n        for solution_str, ground_truth in zip(solution_strs, ground_truths, strict=True)\n    ]\n"
  },
  {
    "path": "verl_rl/verl/utils/reward_score/math_dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\nimport re\nfrom typing import Optional\n\n\ndef last_boxed_only_string(string: str) -> Optional[str]:\n    \"\"\"Extract the last LaTeX boxed expression from a string.\n\n    Args:\n        string: Input string containing LaTeX code\n\n    Returns:\n        The last boxed expression or None if not found\n    \"\"\"\n    idx = string.rfind(\"\\\\boxed{\")\n    if idx < 0:\n        return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None\n\n\ndef remove_boxed(s: str) -> str:\n    \"\"\"Remove the LaTeX boxed command from a string.\n\n    Args:\n        s: String with format \"\\\\boxed{content}\"\n\n    Returns:\n        The content inside the boxed command\n    \"\"\"\n    left = \"\\\\boxed{\"\n    assert s[: len(left)] == left, f\"box error: {s}\"\n    assert s[-1] == \"}\", f\"box error: {s}\"\n    return s[len(left) : -1]\n\n\n# Constants for normalization\nSUBSTITUTIONS = [\n    (\"an \", \"\"),\n    (\"a \", \"\"),\n    (\".$\", \"$\"),\n    (\"\\\\$\", \"\"),\n    (r\"\\ \", \"\"),\n    (\" \", \"\"),\n    (\"mbox\", \"text\"),\n    (\",\\\\text{and}\", \",\"),\n    (\"\\\\text{and}\", \",\"),\n    (\"\\\\text{m}\", \"\\\\text{}\"),\n]\n\nREMOVED_EXPRESSIONS = [\n    \"square\",\n    \"ways\",\n    \"integers\",\n    \"dollars\",\n    \"mph\",\n    \"inches\",\n    \"hours\",\n    \"km\",\n    \"units\",\n    \"\\\\ldots\",\n    \"sue\",\n    \"points\",\n    \"feet\",\n    \"minutes\",\n    \"digits\",\n    \"cents\",\n    \"degrees\",\n    \"cm\",\n    \"gm\",\n    \"pounds\",\n    \"meters\",\n    \"meals\",\n    \"edges\",\n    \"students\",\n    \"childrentickets\",\n    \"multiples\",\n    \"\\\\text{s}\",\n    \"\\\\text{.}\",\n    \"\\\\text{\\ns}\",\n    \"\\\\text{}^2\",\n    \"\\\\text{}^3\",\n    \"\\\\text{\\n}\",\n    \"\\\\text{}\",\n    r\"\\mathrm{th}\",\n    r\"^\\circ\",\n    r\"^{\\circ}\",\n    r\"\\;\",\n    r\",\\!\",\n    \"{,}\",\n    '\"',\n    \"\\\\dots\",\n]\n\n\ndef normalize_final_answer(final_answer: str) -> str:\n    \"\"\"Normalize a final answer to a quantitative reasoning question.\n\n    Args:\n        final_answer: The answer string to normalize\n\n    Returns:\n        Normalized answer string\n    \"\"\"\n    final_answer = final_answer.split(\"=\")[-1]\n\n    # Apply substitutions and removals\n    for before, after in 
SUBSTITUTIONS:\n        final_answer = final_answer.replace(before, after)\n    for expr in REMOVED_EXPRESSIONS:\n        final_answer = final_answer.replace(expr, \"\")\n\n    # Extract and normalize LaTeX math\n    final_answer = re.sub(r\"(.*?)(\\$)(.*?)(\\$)(.*)\", \"$\\\\3$\", final_answer)\n    final_answer = re.sub(r\"(\\\\text\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\textbf\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\overline\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\boxed\\{)(.*)(\\})\", \"\\\\2\", final_answer)\n\n    # Normalize shorthand TeX:\n    #  \\fracab -> \\frac{a}{b}\n    #  \\frac{abc}{bef} -> \\frac{abc}{bef}\n    #  \\fracabc -> \\frac{a}{b}c\n    #  \\sqrta -> \\sqrt{a}\n    #  \\sqrtab -> sqrt{a}b\n    final_answer = re.sub(r\"(frac)([^{])(.)\", \"frac{\\\\2}{\\\\3}\", final_answer)\n    final_answer = re.sub(r\"(sqrt)([^{])\", \"sqrt{\\\\2}\", final_answer)\n    final_answer = final_answer.replace(\"$\", \"\")\n\n    # Normalize numbers\n    if final_answer.replace(\",\", \"\").isdigit():\n        final_answer = final_answer.replace(\",\", \"\")\n\n    return final_answer.strip()\n\n\ndef is_correct_minerva(\n    solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r\"(?i)Answer\\s*:\\s*([^\\n]+)\"\n) -> tuple[bool, str]:\n    \"\"\"Check if the solution is correct according to Minerva criteria.\n\n    Args:\n        solution_str: The solution string to check\n        gt: The ground truth answer\n        gt_need_extract: Whether the ground truth needs extraction\n        answer_pattern: Regex pattern to extract the answer\n\n    Returns:\n        Tuple of (is_correct, normalized_prediction)\n    \"\"\"\n    # Extract answer from solution\n    match = re.findall(answer_pattern, solution_str)\n    extracted_answer = match[-1] if match else \"[INVALID]\"\n    pred = normalize_final_answer(extracted_answer)\n\n    # Process ground truth\n    if gt_need_extract:\n        gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt)))\n    else:\n        gt = normalize_final_answer(gt)\n\n    return (pred == gt), pred\n\n\ndef is_correct_strict_box(\n    pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None\n) -> tuple[int, Optional[str]]:\n    \"\"\"Check if the prediction is correct using strict boxed answer criteria.\n\n    Args:\n        pred: The prediction string\n        gt: The ground truth answer\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Tuple of (score, extracted_prediction)\n    \"\"\"\n    # Extract the relevant part of the prediction\n    if pause_tokens_index is not None:\n        assert len(pause_tokens_index) == 4\n        pred = pred[pause_tokens_index[-1] - 100 :]\n    else:\n        pred = pred[-100:]\n\n    # Extract and check the boxed answer\n    boxed_pred = last_boxed_only_string(pred)\n    extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None\n\n    return 1 if (extracted_pred == gt) else -1, extracted_pred\n\n\ndef verify(\n    solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None\n) -> tuple[bool, Optional[str]]:\n    \"\"\"Verify if the solution is correct.\n\n    Args:\n        solution_str: The solution string to verify\n        answer: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Tuple of (is_correct, extracted_prediction)\n    \"\"\"\n    if strict_box_verify:\n        correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index)\n        return correct == 1, pred\n\n    correct, pred = is_correct_minerva(solution_str, answer)\n    return correct, pred\n\n\ndef compute_score(\n    solution_str: str,\n    ground_truth: str,\n    strict_box_verify: bool = False,\n    pause_tokens_index: Optional[list[int]] = None,\n) -> dict:\n    \"\"\"Compute the reward score for a solution.\n\n    Args:\n        solution_str: The solution string\n        ground_truth: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Dict with \"score\" (1.0 for correct, -1.0 for incorrect), \"acc\", and \"pred\" (the extracted answer)\n    \"\"\"\n    # Limit solution length for efficiency\n    solution_str = solution_str[-300:]  # The longest answer in MATH-500 has 159 characters\n\n    # Verify the solution\n    correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index)\n\n    reward = 1.0 if correct else -1.0\n    acc = correct\n\n    return {\n        \"score\": reward,\n        \"acc\": acc,\n        \"pred\": pred,\n    }\n"
  },
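An illustrative end-to-end call (the default Minerva-style path, which matches on an `Answer: ...` line):

```python
from verl.utils.reward_score.math_dapo import compute_score, normalize_final_answer

print(normalize_final_answer("x = \\boxed{42}"))  # '42'
print(compute_score("Some derivation...\nAnswer: 42", "42"))
# {'score': 1.0, 'acc': True, 'pred': '42'}
```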
  {
    "path": "verl_rl/verl/utils/reward_score/math_verify.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from math_verify.errors import TimeoutException\n    from math_verify.metric import math_metric\n    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig\nexcept ImportError:\n    print(\"To use Math-Verify, please install it first by running `pip install math-verify`.\")\n\n\ndef compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> bool:\n    verify_func = math_metric(\n        gold_extraction_target=(LatexExtractionConfig(),),\n        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),\n    )\n    ret_score = 0.0\n\n    # Wrap the ground truth in \\boxed{} format for verification\n    ground_truth_boxed = \"\\\\boxed{\" + ground_truth + \"}\"\n    try:\n        ret_score, _ = verify_func([ground_truth_boxed], [model_output])\n    except Exception:\n        pass\n    except TimeoutException:\n        ret_score = timeout_score\n\n    return ret_score\n"
  },
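Usage is a single call (needs `pip install math-verify`; the expected values are illustrative, since the verdict on non-trivial pairs depends on Math-Verify's parser):

```python
from verl.utils.reward_score.math_verify import compute_score

# The ground truth is wrapped in \boxed{...} internally before comparison.
print(compute_score("The answer is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}"))  # 1.0
print(compute_score("no boxed answer here", "\\frac{1}{2}"))                 # 0.0
```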
  {
    "path": "verl_rl/verl/utils/reward_score/prime_code/README.md",
    "content": "## LiveCodeBench\n\n### Introduction\n[LiveCodeBench](https://github.com/LiveCodeBench/LiveCodeBench) provides holistic and contamination-free evaluation of coding capabilities of LLMs. Particularly, LiveCodeBench continuously collects new problems over time from contests across three competition platforms -- LeetCode, AtCoder, and CodeForces. \n\n### How to reproduce\nOur evaluation is grounded on the version found in LiveCodeBench.\n> **Installation**\n```bash\n# Make sure the CUDA version > 12.0.\npip install -r requirements.txt\npip install flash-attn --no-build-isolation\n```\n\n### Acknowleage\nThank you to the [LiveCodeBench](https://livecodebench.github.io/leaderboard.html) team for their contributions to the open-source community."
  },
  {
    "path": "verl_rl/verl/utils/reward_score/prime_code/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport traceback\n\nfrom .utils import check_correctness as apps_check_correctness\n\n\ndef compute_score(completion, test_cases, continuous=False):\n    # try to get code solution from completion. if the completion is pure code, this will not take effect.\n    solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    try:\n        try:\n            if not isinstance(test_cases, dict):\n                test_cases = json.loads(test_cases)\n        except Exception as e:\n            print(f\"Error:{e}\")\n\n        # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.\n        try:\n            res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False)\n            metadata = dict(enumerate(metadata))[0]\n            success = all(map(lambda x: x is True, res))\n            if success:\n                return success, metadata\n        except Exception:\n            pass\n\n        test_cases_list = []\n        inputs = test_cases[\"inputs\"]\n        outputs = test_cases[\"outputs\"]\n        for i in range(len(inputs)):\n            test_cases_list.append({\"inputs\": [inputs[i]], \"outputs\": [outputs[i]]})\n\n        if continuous:\n            # per sample test: if continuous score is needed, test first 10 samples regardless of failures\n            # do not test all samples cuz some problems have enormous test cases\n            metadata_list = []\n            res_list = []\n            for test_case_id, test_case in enumerate(test_cases_list):\n                res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=10, debug=False)\n                try:\n                    metadata = dict(enumerate(metadata))[0]  # metadata can be empty occasionally\n                except Exception:\n                    metadata = {}\n                metadata[\"test_case\"] = {}\n                metadata[\"test_case\"][\"input\"] = str(test_case[\"inputs\"][0])\n                metadata[\"test_case\"][\"output\"] = str(test_case[\"outputs\"][0])\n                metadata[\"test_case\"][\"res\"] = str(res)\n                metadata_list.append(metadata)\n                res_list.extend(res)\n\n                if test_case_id >= 9:\n                    break\n            res_count = len(res_list) if len(res_list) > 0 else 1\n            success = sum(map(lambda x: x is True, res_list)) / res_count\n    except Exception:\n        traceback.print_exc(10)\n        success = False\n        metadata_list = None\n    return success, metadata_list\n"
  },
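A sketch of how this scorer consumes a completion plus an APPS-style test-case dict (illustrative; the underlying runner also needs the `pyext` dependency and POSIX signal support):

```python
from verl.utils.reward_score.prime_code import compute_score

# Pure code also works: the markdown-fence split is then a no-op.
completion = "import sys\n\na, b = map(int, sys.stdin.read().split())\nprint(a + b)\n"
test_cases = {"inputs": ["1 2\n", "10 -3\n"], "outputs": ["3\n", "7\n"]}

success, metadata = compute_score(completion, test_cases, continuous=True)
print(success)  # True if the all-at-once check passes; else a pass fraction over <= 10 cases
```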
  {
    "path": "verl_rl/verl/utils/reward_score/prime_code/testing_util.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nimport faulthandler\nimport json\nimport platform\n\n# to run the solution files we're using a timing based approach\nimport signal\nimport sys\nimport traceback\n\n# used for debugging to time steps\nfrom datetime import datetime\nfrom enum import Enum\n\n# for capturing the stdout\nfrom io import StringIO\n\n# used for testing the code that reads from input\nfrom unittest.mock import mock_open, patch\n\nimport numpy as np\nfrom pyext import RuntimeModule\n\n\ndef truncatefn(s, length=300):\n    assert isinstance(s, str)\n    if len(s) <= length:\n        return s\n\n    return s[: length // 2] + \"...(truncated) ...\" + s[-length // 2 :]\n\n\nclass CODE_TYPE(Enum):\n    call_based = 0\n    standard_input = 1\n\n\n# used to capture stdout as a list\n# from https://stackoverflow.com/a/16571630/6416660\n# alternative use redirect_stdout() from contextlib\nclass Capturing(list):\n    def __enter__(self):\n        self._stdout = sys.stdout\n        sys.stdout = self._stringio = StringIO()\n        # Make closing the StringIO a no-op\n        self._stringio.close = lambda x: 1\n        return self\n\n    def __exit__(self, *args):\n        self.append(self._stringio.getvalue())\n        del self._stringio  # free up some memory\n        sys.stdout = self._stdout\n\n\ndef only_int_check(val):\n    return isinstance(val, int)\n\n\ndef string_int_check(val):\n    return isinstance(val, str) and val.isdigit()\n\n\ndef combined_int_check(val):\n    return only_int_check(val) or string_int_check(val)\n\n\ndef clean_traceback(error_traceback):\n    file_start = error_traceback.find('File \"<string>\"')\n    # print(file_start)\n    error_traceback = \"Traceback (most recent call last):\\n  \" + error_traceback[file_start:]\n    return error_traceback\n\n\ndef run_test(in_outs, test=None, debug=False, timeout=15):\n    \"\"\"\n    if test(generated_code) is not None it'll try to run the code.\n    otherwise it'll just return an input and output pair.\n    \"\"\"\n    # Disable functionalities that can make destructive changes to the test.\n    reliability_guard()\n\n    if debug:\n        print(f\"start = {datetime.now().time()}\")\n\n    if in_outs:\n        if in_outs.get(\"fn_name\") is None:\n            which_type = CODE_TYPE.standard_input  # Standard input\n            method_name = None\n        else:\n            which_type = CODE_TYPE.call_based  # Call-based\n            method_name = in_outs[\"fn_name\"]\n\n    if debug:\n        print(f\"loaded input_output = {datetime.now().time()}\")\n\n    if test is None:\n        raise AssertionError(\"should not happen: test code is none\")\n    elif test is not None:\n        results = []\n        sol = \"from string import *\\nfrom re import *\\nfrom datetime import *\\nfrom collections import *\\nfrom heapq import *\\nfrom bisect import *\\nfrom copy import *\\nfrom math import *\\nfrom random import *\\nfrom 
statistics import *\\nfrom itertools import *\\nfrom functools import *\\nfrom operator import *\\nfrom io import *\\nfrom sys import *\\nfrom json import *\\nfrom builtins import *\\nfrom typing import *\\nimport string\\nimport re\\nimport datetime\\nimport collections\\nimport heapq\\nimport bisect\\nimport copy\\nimport math\\nimport random\\nimport statistics\\nimport itertools\\nimport functools\\nimport operator\\nimport io\\nimport sys\\nimport json\\nsys.setrecursionlimit(6*10**5)\\n\"  # noqa: E501\n        if debug:\n            print(f\"loading test code = {datetime.now().time()}\")\n\n        if which_type == CODE_TYPE.call_based:\n            sol += test\n            if debug:\n                print(f\"sol = {sol}\")\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol if \"class Solution\" not in test else tmp_sol.Solution()\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n                if debug:\n                    print(f\"type 0 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n\n        elif which_type == CODE_TYPE.standard_input:\n            # sol\n            # if code has if __name__ == \"__main__\": then remove it\n            try:\n                astree = ast.parse(test)\n                last_block = astree.body[-1]\n                if isinstance(last_block, ast.If):\n                    condition = last_block.test\n                    if ast.unparse(condition).strip() == \"__name__ == '__main__'\":\n                        test = ast.unparse(astree.body[:-1]) + \"\\n\" + ast.unparse(last_block.body)\n            except Exception:\n                pass\n\n            tmp_test = test.split(\"\\n\")\n\n            new_test = []\n            for x in tmp_test:\n                if (not x.startswith(\"from \")) and (not x.startswith(\"import \")):\n                    new_test.append(\"\\t\" + x + \"\\n\")\n                else:\n                    new_test.append(x + \"\\n\")\n            tmp_test = new_test\n\n            new_test = \"\"\n            started = False\n            for i in tmp_test:\n                if i.startswith(\"\\t\") and not started:\n                    new_test += \"stdin = sys.stdin\\nstdout = sys.stdout\\n\"\n                    new_test += \"def code():\\n\"\n                    new_test += i\n                    started = True\n                elif started and ((i.startswith(\"from \")) or (i.startswith(\"import \"))):\n                    new_test += \"\\t\" + i\n                else:\n                    new_test += i\n            tmp_test = new_test\n\n            sol += tmp_test\n            if debug:\n                print(f\"sol = {sol}\")\n            method_name = \"code\"\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n              
  if debug:\n                    print(f\"type 1 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n        if debug:\n            print(f\"get method = {datetime.now().time()}\")\n\n        try:\n            method = getattr(tmp, method_name)  # get_attr second arg must be str\n        except Exception:\n            signal.alarm(0)\n            error_traceback = traceback.format_exc()\n            error_info = sys.exc_info()\n            print(f\"unable to get function error = {error_info}\")\n            results.append(-2)\n            return results, {\n                \"error\": repr(error_info),\n                # \"error_code\": -1,\n                # \"error_message\": \"Unable to extract code\",\n                \"traceback\": clean_traceback(error_traceback),\n            }\n\n        for index, inputs in enumerate(in_outs[\"inputs\"]):\n            raw_inputs = inputs\n            raw_outputs = in_outs[\"outputs\"][index]\n            if which_type == CODE_TYPE.call_based:\n                inputs = [json.loads(line) for line in inputs.split(\"\\n\")]\n                in_outs[\"outputs\"][index] = json.loads(in_outs[\"outputs\"][index])\n\n                truncate_line_size = 300 // (raw_inputs.count(\"\\n\") + 1)\n                raw_inputs = \"\\n\".join(\n                    [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split(\"\\n\")]\n                )\n                raw_outputs = truncatefn(raw_outputs, 200)\n            else:\n                raw_inputs = truncatefn(raw_inputs)\n                raw_outputs = truncatefn(raw_outputs, 200)\n            # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)\n            try:\n                if isinstance(inputs[0], dict):\n                    inputs = [{int(k): v for k, v in inputs[0].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index][0], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index][0].items()}]\n            except Exception:\n                pass\n\n            if debug:\n                print(\n                    f\"time: {datetime.now().time()} testing index = {index}  inputs = {inputs}, {type(inputs)}. 
\"\n                    f\"type = {which_type}\"\n                )\n            if which_type == CODE_TYPE.call_based:  # Call-based\n                signal.alarm(timeout)\n                faulthandler.enable()\n                try:\n                    output = method(*inputs)\n                    raw_true_output = output\n\n                    raw_true_output_copy = json.dumps(output)\n                    raw_true_output_copy = truncatefn(raw_true_output_copy, 200)\n\n                    # ground truth sequences are not tuples\n                    if isinstance(output, tuple):\n                        output = list(output)\n\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                    if isinstance(in_outs[\"outputs\"][index], list) and in_outs[\"outputs\"][index]:\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index][0])\n\n                    # ground truth sequences are not tuples\n                    try:\n                        if isinstance(output[0], tuple):\n                            tmp_result = tmp_result or ([list(x) for x in output] == in_outs[\"outputs\"][index][0])\n                    except Exception:\n                        pass\n                    results.append(tmp_result)\n                    if tmp_result is not True:\n                        return results, {\n                            \"output\": raw_true_output_copy,\n                            \"expected\": raw_outputs,\n                            \"inputs\": raw_inputs,\n                            # \"error_code\": -2,\n                            \"error_message\": \"Wrong Answer\",\n                        }\n                    # reset the alarm\n                    signal.alarm(0)\n                except Exception as e:\n                    signal.alarm(0)\n                    error_traceback = traceback.format_exc()\n                    faulthandler.disable()\n                    if debug:\n                        print(f\"Standard input runtime error or time limit exceeded error = {e}\")\n                    results.append(-1)\n                    return results, {\n                        \"error\": repr(e),\n                        \"traceback\": clean_traceback(error_traceback),\n                    }\n                faulthandler.disable()\n                signal.alarm(0)\n                if debug:\n                    print(\n                        f\"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                        f\"{type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                    )\n            elif which_type == CODE_TYPE.standard_input:  # Standard input\n                faulthandler.enable()\n                passed = False\n\n                if isinstance(inputs, list):\n                    inputs = \"\\n\".join(inputs)\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    in_outs[\"outputs\"][index] = \"\\n\".join(in_outs[\"outputs\"][index])\n\n                signal.alarm(timeout)\n                with Capturing() as output:\n                    try:\n                        call_method(method, inputs)\n                        # reset the alarm\n                        signal.alarm(0)\n                        passed = True\n                    except Exception as e:\n                        # runtime error or took too long\n                        signal.alarm(0)\n                        error_traceback = 
traceback.format_exc()\n                        print(f\"Standard input runtime error or time limit exceeded error = {repr(e)}{e}\")\n                        results.append(-1)\n                        return results, {\n                            \"error\": repr(e),\n                            \"traceback\": clean_traceback(error_traceback),\n                        }\n                    signal.alarm(0)\n                raw_true_output = output[0]\n                raw_true_output_copy = truncatefn(raw_true_output, 200)\n                output = raw_true_output.splitlines()\n                if not passed:\n                    if debug:\n                        nl = \"\\n\"\n                        if not isinstance(inputs, list):\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                                f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                                f\"{output == [in_outs['outputs'][index]]}\"\n                            )\n                        else:\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                                f\"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                            )\n                    continue\n\n                if passed and debug:\n                    print(f\"==> output = {output}, test outputs = {in_outs['outputs'][index]}\")\n\n                if custom_compare_(output, in_outs[\"outputs\"][index]):\n                    tmp_result = True\n                    results.append(tmp_result)\n                    continue\n\n                # ground truth sequences are expressed as lists not tuples\n                if isinstance(output, tuple):\n                    output = list(output)\n\n                tmp_result = False\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                        if isinstance(output[0], str):\n                            tmp_result = tmp_result or ([e.strip() for e in output] == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check1 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                # try one more time without \\n\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = i.split(\"\\n\")\n                        in_outs[\"outputs\"][index][tmp_index] = [\n                            x.strip() for x in in_outs[\"outputs\"][index][tmp_index] if x\n                        ]\n                else:\n                    in_outs[\"outputs\"][index] = in_outs[\"outputs\"][index].split(\"\\n\")\n                    in_outs[\"outputs\"][index] = list(filter(len, in_outs[\"outputs\"][index]))\n                    in_outs[\"outputs\"][index] = list(map(lambda x: x.strip(), in_outs[\"outputs\"][index]))\n\n                try:\n                    
tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check2 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    output = list(filter(len, output))\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                            f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                            f\"{output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n                    else:\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                            f\"{type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n\n                if debug:\n                    print(f\"{tmp_result=} @a\")\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check3 exception = {e}\")\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @b\")\n\n                try:\n                    all_ints = all(\n                        combined_int_check(e1) and combined_int_check(e2)\n                        for e1, e2 in zip(output, in_outs[\"outputs\"][index], strict=True)\n                    )\n                    if not all_ints:\n                        if debug:\n                            print(\n                                [\n                                    combined_int_check(e1) and combined_int_check(e2)\n                                    for e1, e2 in zip(output, in_outs[\"outputs\"][index], strict=True)\n                                ]\n                            )\n                        output_float = [float(e) for e in output]\n                        gt_float = [float(e) for e in in_outs[\"outputs\"][index]]\n                        tmp_result = tmp_result or (\n                            (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)\n                        )\n                except Exception:\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @c\")\n\n                try:\n                    if isinstance(output[0], list):\n                        all_ints = all(\n                            combined_int_check(e1) and combined_int_check(e2)\n                            for e1, e2 in zip(output[0], in_outs[\"outputs\"][index], strict=True)\n                        )\n                        if not all_ints:\n                         
   output_float = [float(e) for e in output[0]]\n                            gt_float = [float(e) for e in in_outs[\"outputs\"][index][0]]\n                            tmp_result = tmp_result or (\n                                (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)\n                            )\n                except Exception:\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @d\")\n                # try by converting the stuff into split up list\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = set(i.split())\n                else:\n                    in_outs[\"outputs\"][index] = set(in_outs[\"outputs\"][index].split())\n\n                if debug:\n                    print(f\"{tmp_result=} @e\")\n\n                try:\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check4 exception = {e}\")\n                    continue\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @f\")\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = i.split()\n                    output = list(filter(len, output))\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = set(i)\n                else:\n                    output = output.split()\n                    output = list(filter(len, output))\n                    output = set(output)\n\n                if debug:\n                    print(f\"{tmp_result=} @g\")\n\n                if tmp_result is True and debug:\n                    print(\"PASSED\")\n\n                results.append(tmp_result)\n                if tmp_result is not True:\n                    return results, {\n                        \"output\": raw_true_output_copy,\n                        \"expected\": raw_outputs,\n                        \"inputs\": raw_inputs,\n                        # \"error_code\": -2,\n                        \"error_message\": \"Wrong Answer\",\n                    }\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, \"\n                            f\"inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, \"\n                            f\"{output == [in_outs['outputs'][index]]}\"\n                        )\n                    else:\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, \"\n                            f\"{type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                        )\n\n                    print(f\"results = {results}\")\n\n    return results, {}\n\n\ndef 
custom_compare_(output, ground_truth):\n    if isinstance(output, list):\n        output_1 = \"\\n\".join(output)\n        if stripped_string_compare(output_1, ground_truth):\n            return True\n\n    if isinstance(output, list):\n        output_2 = [o.lstrip().rstrip() for o in output]\n        output_2 = \"\\n\".join(output_2)\n        if stripped_string_compare(output_2, ground_truth):\n            return True\n\n    return False\n\n\ndef stripped_string_compare(s1, s2):\n    s1 = s1.lstrip().rstrip()\n    s2 = s2.lstrip().rstrip()\n    return s1 == s2\n\n\ndef call_method(method, inputs):\n    if isinstance(inputs, list):\n        inputs = \"\\n\".join(inputs)\n\n    inputs_line_iterator = iter(inputs.split(\"\\n\"))\n\n    # sys.setrecursionlimit(10000)\n\n    # @patch('builtins.input', side_effect=inputs.split(\"\\n\"))\n    @patch(\"builtins.open\", mock_open(read_data=inputs))\n    @patch(\"sys.stdin\", StringIO(inputs))\n    @patch(\"sys.stdin.readline\", lambda *args: next(inputs_line_iterator))\n    @patch(\"sys.stdin.readlines\", lambda *args: inputs.split(\"\\n\"))\n    @patch(\"sys.stdin.read\", lambda *args: inputs)\n    # @patch('sys.stdout.write', print)\n    def _inner_call_method(_method):\n        try:\n            return _method()\n        except SystemExit:\n            pass\n        finally:\n            pass\n\n    return _inner_call_method(method)\n\n\ndef reliability_guard(maximum_memory_bytes=None):\n    \"\"\"\n    This disables various destructive functions and prevents the generated code\n    from interfering with the test (e.g. fork bomb, killing other processes,\n    removing filesystem files, etc.)\n    WARNING\n    This function is NOT a security sandbox. Untrusted code, including model-\n    generated code, should not be blindly executed outside of one. 
See the\n    Codex paper for more information about OpenAI's code sandbox, and proceed\n    with caution.\n    \"\"\"\n\n    if maximum_memory_bytes is not None:\n        import resource\n\n        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))\n        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))\n        if platform.uname().system != \"Darwin\":\n            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))\n\n    faulthandler.disable()\n\n    import builtins\n\n    builtins.exit = None\n    builtins.quit = None\n\n    import os\n\n    os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n    os.kill = None\n    os.system = None  # prevent the code under test from interfering with the REPL-based evaluation\n    os.putenv = None\n    os.remove = None\n    os.removedirs = None\n    os.rmdir = None\n    os.fchdir = None\n    os.setuid = None\n    os.fork = None\n    os.forkpty = None\n    os.killpg = None\n    os.rename = None\n    os.renames = None\n    os.truncate = None\n    os.replace = None\n    os.unlink = None\n    os.fchmod = None\n    os.fchown = None\n    os.chmod = None\n    os.chown = None\n    os.chroot = None\n    os.lchflags = None\n    os.lchmod = None\n    os.lchown = None\n    os.getcwd = None\n    os.chdir = None\n\n    import shutil\n\n    shutil.rmtree = None\n    shutil.move = None\n    shutil.chown = None\n\n    import subprocess\n\n    subprocess.Popen = None  # type: ignore\n\n    __builtins__[\"help\"] = None\n\n    import sys\n\n    sys.modules[\"ipdb\"] = None\n    sys.modules[\"joblib\"] = None\n    sys.modules[\"resource\"] = None\n    sys.modules[\"psutil\"] = None\n    sys.modules[\"tkinter\"] = None\n\n    # Block imports of modules that can be used destructively\n    for mod in [\"subprocess\", \"ctypes\"]:\n        sys.modules[mod] = None\n"
  },
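# --- Hedged usage sketch (illustration only, not part of the repo) ---
# How a caller might interpret the per-test result codes produced by run_test
# above: -2 marks a compilation error, -1 marks a runtime error or timeout,
# and plain booleans mark per-test output comparisons. `summarize_run` is a
# hypothetical helper introduced here purely for clarity.
def summarize_run(results: list) -> float:
    """Fraction of tests passed; error codes (-1/-2) count as failures."""
    return sum(1 for r in results if r is True) / len(results) if results else 0.0

assert summarize_run([True, True, -1]) == 2 / 3   # one test hit a runtime error
assert summarize_run([-2]) == 0.0                 # compilation error fails everything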
  {
    "path": "verl_rl/verl/utils/reward_score/prime_code/utils.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py\n\nimport multiprocessing\nimport os\nimport sys\nimport traceback\nfrom typing import Optional\n\nfrom .testing_util import run_test\n\n\ndef _temp_run(sample, generation, debug, result, metadata_list, timeout):\n    with open(os.devnull, \"w\") as devnull:\n        sys.stdout = devnull\n        sys.stderr = devnull\n        try:\n            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)\n            result.append(res)\n            metadata_list.append(metadata)\n        except Exception:\n            # print(e) # some tracebacks are extremely long.\n            traceback.print_exc(10)\n            result.append([-1 for i in range(len(sample[\"inputs\"]))])\n            metadata_list.append({})\n\n\ndef check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):\n    \"\"\"Check correctness of code generation with a global timeout.\n    The global timeout is to catch some extreme/rare cases not handled by the timeouts\n    inside `run_test`\"\"\"\n\n    manager = multiprocessing.Manager()\n    result = manager.list()\n    metadata_list = manager.list()\n    p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))\n    p.start()\n    p.join(timeout=timeout + 1)\n    if p.is_alive():\n        p.kill()\n        # p.terminate()\n    if not result:\n        # consider that all tests failed\n        result = [[-1 for i in range(len(in_outs[\"inputs\"]))]]\n        if debug:\n            print(\"global timeout\")\n    return result[0], metadata_list\n"
  },
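# --- Hedged usage sketch (illustration only, not part of the repo) ---
# check_correctness above runs the generated code in a separate process so a
# hung submission cannot block the reward worker; the extra second on p.join
# backstops the per-test alarms inside run_test. The in_outs dict follows the
# "inputs"/"outputs" layout consumed by run_test; the values here are made up.
from verl.utils.reward_score.prime_code.utils import check_correctness

in_outs = {"inputs": ["1 2\n"], "outputs": ["3\n"]}
generation = "a, b = map(int, input().split())\nprint(a + b)\n"
results, metadata = check_correctness(in_outs, generation, timeout=10, debug=False)
# results holds one entry per test case: True, False, -1 (runtime), or -2 (compile).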
  {
    "path": "verl_rl/verl/utils/reward_score/prime_math/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAnswer checker API that uses sympy to simplify expressions and check for equality.\n\nCall grade_answer(given_answer: str, ground_truth: str).\n\nFROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py\n\"\"\"\n\nimport contextlib\nimport math\nimport re\n\nimport sympy\nfrom pylatexenc import latex2text\nfrom sympy.parsing import sympy_parser\n\nfrom verl.utils.py_functional import timeout_limit\n\nfrom . import math_normalize\nfrom .grader import math_equal\n\n# import math_normalize\n# from grader import math_equal\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [r\"\\^[0-9]+\\^\", r\"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" \\\\frac\")  # Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x: str) -> bool:\n    try:\n        x = _strip_properly_formatted_commas(x)\n        x = float(x)\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _str_to_int(x: str) -> int:\n    x = x.replace(\",\", \"\")\n    x = float(x)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evaluable\n    e.g. 
7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(\"([0-9]) +([0-9])\")\n    step = p1.sub(\"\\\\1+\\\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(\"\\\\1\\\\3\\\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n        \"liter\",\n    ]:\n        expr = re.sub(f\"{unit}(es)?(s)? *(\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(\"\\^ *\\\\\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        with contextlib.suppress(Exception):\n            expr = _parse_latex(expr)\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = set([x for x in expr if x.isalpha()])\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES)\n\n\n@timeout_limit(seconds=10)\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if (\n        len(expr) > 2\n        and expr[0] in TUPLE_CHARS\n        and expr[-1] in TUPLE_CHARS\n        and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])\n    ):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef grade_answer(given_answer: str, ground_truth: str) -> bool:\n    \"\"\"\n    The answer will be considered correct if:\n    (a) it normalizes to the same string as the ground truth answer\n    OR\n    (b) sympy can simplify the difference between the expressions to 0\n    \"\"\"\n    if given_answer is None:\n        return False\n\n    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)\n    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    if (\n        len(ground_truth_elems) > 1\n        and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1])\n        or len(ground_truth_elems) != len(given_elems)\n    ):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in 
zip(ground_truth_elems, given_elems, strict=True):\n            if _is_frac(ground_truth_elem) and _is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match\n                # (no sympy.simplify)\n                is_correct = False\n            else:\n                try:\n                    is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n                except Exception as e:\n                    # if there's an error, we'll just say it's not correct\n                    is_correct = False\n                    print(f\"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}\")\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef remove_boxed(s):\n    left = \"\\\\boxed{\"\n    try:\n        assert s[: len(left)] == left\n        assert s[-1] == \"}\"\n        return s[len(left) : -1]\n    except Exception:\n        return None\n\n\ndef _last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    left_brace_idx = None\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n            if left_brace_idx is None:\n                left_brace_idx = i\n        elif string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n\n        i += 1\n\n    if left_brace_idx is None or right_brace_idx is None:\n        return None\n\n    return string[left_brace_idx + 1 : right_brace_idx].strip()\n\n\ndef match_answer(response):\n    is_matched = False\n    for ans_marker in [\"answer:\", \"answer is\", \"answers are\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    for ans_marker in [\"is answer\", \"is the answer\", \"are answers\", \"are the answers\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[:ans_idx].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    # Find boxed\n    ans_boxed = _last_boxed_only_string(response)\n    if ans_boxed:\n        is_matched = True\n        response = ans_boxed\n\n    if \". \" in response:\n        dot_idx = response.lower().rfind(\". 
\")\n        if dot_idx != -1:\n            response = response[:dot_idx].strip()\n\n    for ans_marker in [\"be \", \"is \", \"are \", \"=\", \": \", \"get \", \"be\\n\", \"is\\n\", \"are\\n\", \":\\n\", \"get\\n\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    is_matched = is_matched if any([c.isdigit() for c in response]) else False  # answer must have a digit\n    # Grade\n    return is_matched, response\n\n\ndef compute_score(model_output: str, ground_truth: str) -> bool:\n    model_output = str(model_output)\n    ground_truth = str(ground_truth)\n\n    is_matched, extracted_model_output = match_answer(model_output)\n    format_correctness = \"Step 2:\" in model_output and \"\\\\box\" in model_output\n\n    # grade simple algebra questions. if succeeded, return; otherwise, proceed to more complex grading\n    if grade_answer(extracted_model_output, ground_truth):\n        return True, True, extracted_model_output\n\n    try:\n        if \"\\pi\" in extracted_model_output or \"\\pi\" in ground_truth:\n            equivs = []\n            for pi in [math.pi, 3.14]:\n                equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi))\n            is_correct = any(equivs)\n        else:\n            is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)\n    except Exception:\n        is_correct = False\n\n    return is_correct, format_correctness, extracted_model_output\n"
  },
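# --- Hedged usage sketch (illustration only, not part of the repo) ---
# Typical call chain for the math reward above: match_answer extracts the final
# answer (preferring \boxed{...}), then grade_answer compares it to the ground
# truth via string normalization and, failing that, sympy. Example strings are
# made up; expected results are inferred from the code above.
from verl.utils.reward_score.prime_math import grade_answer, match_answer

is_matched, extracted = match_answer("The answer is \\boxed{\\frac{1}{2}}")
print(is_matched, extracted)            # True \frac{1}{2}
print(grade_answer(extracted, "0.5"))   # True: both should normalize to \frac{1}{2}
print(grade_answer("3,000", "3000"))    # True: well-formatted commas are stripped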
  {
    "path": "verl_rl/verl/utils/reward_score/prime_math/grader.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) Microsoft Corporation.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE\n\n# Copyright (c) 2023 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:\n- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py\n- https://github.com/microsoft/ProphetNet/tree/master/CRITIC\n- https://github.com/openai/prm800k\n\"\"\"\n\nimport contextlib\nimport math\nimport re\nfrom math import isclose\n\n# sympy related\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n# verl related\nfrom verl.utils.py_functional import timeout_limit\n\n\ndef is_digit(s):\n    try:\n        if \"{,}\" in str(s):\n            num = float(str(s).replace(\"{,}\", \"\"))\n            return True, num\n\n        num = float(str(s).replace(\",\", \"\"))\n        return True, num\n    except ValueError:\n        return False, None\n\n\ndef normalize(answer, pi) -> str:\n    # checking if answer is $<number> and removing $ in that case to compare\n    if isinstance(answer, str) and bool(re.match(r\"\\$\\d+(\\.\\d+)?\", answer)):\n        return answer[1:]\n\n    # checking if answer is <number>% or <number>\\\\% and removing %\n    if isinstance(answer, str) and (\n        bool(re.match(r\"^\\d+(\\.\\d+)?%$\", answer)) or bool(re.match(r\"^\\d+(\\.\\d+)?\\\\%$\", answer))\n    ):\n        return answer.replace(\"\\\\%\", \"\").replace(\"%\", \"\")\n\n    # handle base\n    answer = handle_base(answer)\n\n    # handle pi\n    answer = handle_pi(answer, pi)\n\n    return answer\n\n\ndef handle_base(x) -> str:\n    if 
isinstance(x, str) and \"_\" in x:\n        # Due to base\n        x = x.split(\"_\")[0]\n        x = float(x)\n        return int(x)\n    return x\n\n\ndef handle_pi(string, pi):\n    if isinstance(string, str) and \"\\pi\" in string:\n        # Find the first occurrence of \"\\pi\"\n        idx = string.find(\"\\pi\")\n\n        # Iterate over the string and find all occurrences of \"\\pi\" with a valid previous character\n        while idx != -1:\n            if idx > 0 and string[idx - 1].isdigit():\n                # Replace \"\\pi\" with \"*math.pi\" if the previous character is a digit\n                string = string[:idx] + f\"*{pi}\" + string[idx + 3 :]\n            else:\n                # Replace \"\\pi\" with \"1*math.pi\" if the previous character is not a digit\n                string = string[:idx] + f\"1*{pi}\" + string[idx + 3 :]\n\n            # Find the next occurrence of \"\\pi\"\n            idx = string.find(\"\\pi\", idx + 1)\n\n        # Evaluate the expression using eval() function\n        with contextlib.suppress(Exception):\n            string = eval(string)\n\n    return string\n\n\ndef math_equal(\n    prediction: bool | float | str,\n    reference: float | str,\n    include_percentage: bool = True,\n    tolerance: float = 1e-4,\n    timeout: float = 10.0,\n    pi: float = math.pi,\n) -> bool:\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n\n    prediction = normalize(prediction, pi)\n    reference = normalize(reference, pi)\n\n    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases\n        prediction = prediction[:1000]\n\n    # 0. string comparison\n    if isinstance(prediction, str) and isinstance(reference, str):\n        if prediction.strip().lower() == reference.strip().lower():\n            return True\n        if prediction.replace(\" \", \"\") == reference.replace(\" \", \"\"):\n            return True\n\n    try:  # 1. numerical equal\n        if is_digit(prediction)[0] and is_digit(reference)[0]:\n            prediction = is_digit(prediction)[1]\n            reference = is_digit(reference)[1]\n            # number questions\n            gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]\n            for item in gt_result:\n                try:\n                    if isclose(item, prediction, rel_tol=tolerance):\n                        return True\n                except Exception:\n                    continue\n            return False\n    except Exception:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    ## deal with [], (), {}\n    prediction = format_intervals(prediction)\n\n    pred_str, ref_str = prediction, reference\n    if (prediction.startswith(\"[\") and prediction.endswith(\"]\") and not reference.startswith(\"(\")) or (\n        prediction.startswith(\"(\") and prediction.endswith(\")\") and not reference.startswith(\"[\")\n    ):\n        pred_str = pred_str.strip(\"[]()\")\n        ref_str = ref_str.strip(\"[]()\")\n    for s in [\"{\", \"}\", \"(\", \")\"]:\n        ref_str = ref_str.replace(s, \"\")\n        pred_str = pred_str.replace(s, \"\")\n    if pred_str == ref_str:\n        return True\n\n    ## [a, b] vs. 
[c, d], return a==c and b==d\n    if (\n        prediction\n        and reference\n        and prediction[0] in \"([\"\n        and prediction[-1] in \")]\"\n        and prediction[0] == reference[0]\n        and prediction[-1] == reference[-1]\n    ):\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)\n            ]\n        ):\n            return True\n\n    if \",\" in prediction and \",\" in reference:\n        pred_parts = [item.strip() for item in prediction.split(\",\")]\n        ref_parts = [item.strip() for item in reference.split(\",\")]\n\n        if len(pred_parts) == len(ref_parts):\n            return bool(\n                all(\n                    [\n                        math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)\n                        for i in range(len(pred_parts))\n                    ]\n                )\n            )\n\n    # if we have point == tuple of values\n    if prediction.startswith(\"Point\") and reference[0] == \"(\" and reference[-1] == \")\":\n        pred_parts = prediction[prediction.find(\"(\") + 1 : -1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all(\n            [\n                math_equal(pred_pt, ref_pt, include_percentage, tolerance)\n                for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=False)\n            ]\n        ):\n            return True\n\n    # if reference is a matrix\n    if \"\\begin{pmatrix}\" in reference and prediction.startswith(\"Matrix\"):\n        try:\n            pred_matrix = parse_expr(prediction)\n            ref_matrix_items = reference.split()[1:-1:2]\n            if len(pred_matrix) == len(ref_matrix_items) and all(\n                [\n                    math_equal(pred, ref, include_percentage, tolerance)\n                    for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)\n                ]\n            ):\n                return True\n        except Exception:\n            pass\n    elif \"\\begin{pmatrix}\" in reference and prediction.startswith(\"[\") and prediction.endswith(\"]\"):\n        if isinstance(eval(prediction), list):\n            try:\n                pred_matrix = eval(prediction)\n                # ref_matrix_items = reference.split()[1:-1:2]\n                ref_matrix_items = (\n                    reference.lstrip(\"\\\\begin{pmatrix}\")  # noqa: B005\n                    .lstrip(\"\\begin{pmatrix}\")\n                    .rstrip(\"\\\\end{pmatrix}\")\n                    .rstrip(\"\\end{pmatrix}\")\n                )  # noqa: B005\n                ref_matrix_items = ref_matrix_items.split(\"\\\\\")\n                ref_matrix_items = [row.split(\"&\") if \"&\" in row else row for row in ref_matrix_items]\n                if len(pred_matrix) == len(ref_matrix_items) and all(\n                    [\n                        math_equal(pred, ref, include_percentage, tolerance)\n                        for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)\n                    ]\n                ):\n                    return True\n            except Exception:\n                pass\n\n    return symbolic_equal(prediction, reference, tolerance, timeout)\n\n\ndef symbolic_equal(a, 
b, tolerance, timeout=10.0):\n    def _parse(s):\n        for f in [parse_expr, parse_latex]:\n            try:\n                with timeout_limit(seconds=timeout):\n                    return f(s)\n            except TimeoutError:\n                print(f\"Parsing timed out for {s}\")\n                continue\n            except Exception:\n                continue\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if simplify(a - b) == 0:\n                return True\n    except TimeoutError:\n        print(f\"Simplification timed out for {a} - {b}\")\n        pass\n    except Exception:\n        pass\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if isclose(N(a), N(b), rel_tol=tolerance):\n                return True\n    except TimeoutError:\n        print(f\"Numerical evaluation timed out for {a}, {b}\")\n        pass\n    except Exception:\n        pass\n    return False\n\n\ndef format_intervals(prediction):\n    patterns = {\n        \"Interval(\": r\"^Interval\\((.*)\\)$\",\n        \"Interval.Ropen(\": r\"^Interval\\.Ropen\\((.*)\\)$\",\n        \"Interval.Lopen(\": r\"^Interval\\.Lopen\\((.*)\\)$\",\n        \"Interval.open(\": r\"^Interval\\.open\\((.*)\\)$\",\n    }\n\n    for key, pattern in patterns.items():\n        match = re.match(pattern, prediction)\n        if match:\n            inner_content = match.group(1)\n\n            if key == \"Interval(\":  # Interval(a, b) == [a, b]\n                return f\"[{inner_content}]\"\n            elif key == \"Interval.Ropen(\":  # Interval.Ropen(a, b) == [a, b)\n                return f\"[{inner_content})\"\n            elif key == \"Interval.Lopen(\":  # Interval.Lopen(a, b) == (a, b]\n                return f\"({inner_content}]\"\n            elif key == \"Interval.open(\":  # Interval.open(a, b) == (a, b)\n                return f\"({inner_content})\"\n\n    return prediction\n"
  },
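# --- Hedged usage sketch (illustration only, not part of the repo) ---
# math_equal above tries string, numeric (with percentage variants), and
# symbolic comparison in turn; these made-up examples trace those branches.
# The second line assumes sympy's LaTeX parser (antlr4 runtime) is available.
from verl.utils.reward_score.prime_math.grader import math_equal

print(math_equal("50", "0.5"))             # True: include_percentage tries 0.5 * 100
print(math_equal("\\frac{1}{2}", "0.5"))   # True: falls through to symbolic_equal
print(math_equal("(1, 2)", "(1,2)"))       # True: whitespace-insensitive string match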
  {
    "path": "verl_rl/verl/utils/reward_score/prime_math/math_normalize.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence).\n\nFrom: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py\n\"\"\"\n\nimport re\nfrom typing import Optional\n\n\ndef normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except Exception:  # noqa: E722\n        return answer\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except Exception:  # noqa: E722\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef 
_fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except Exception:  # noqa: E722\n        return string\n\n\ndef _remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1).\n    # Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n"
  },
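# --- Hedged usage sketch (illustration only, not part of the repo) ---
# Examples of the rewrites normalize_answer applies via _strip_string above;
# the expected outputs in the comments are inferred from the code.
from verl.utils.reward_score.prime_math.math_normalize import normalize_answer

print(normalize_answer("\\tfrac{1}{2}"))  # \frac{1}{2}  (tfrac -> frac)
print(normalize_answer("0.5"))            # \frac{1}{2}  (explicit 0.5 special case)
print(normalize_answer("\\sqrt3"))        # \sqrt{3}     (_fix_sqrt adds braces)
print(normalize_answer("x = 2"))          # 2            (short "k = " prefix dropped)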
  {
    "path": "verl_rl/verl/utils/reward_score/sandbox_fusion/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport logging\nimport traceback\n\nfrom .utils import check_correctness\n\n\"\"\"\nVerify code correctness using the Sandbox Fusion (https://github.com/bytedance/SandboxFusion).\nYou can either deploy the sandbox_fusion service yourself or use the\nFaaS service provided by public cloud, eg: volcengine.com.\n\"\"\"\nlogger = logging.getLogger(__name__)\n\n\ndef compute_score(\n    sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, completion, test_cases, continuous=False, timeout=10\n):\n    \"\"\"\n    Computes the code score using the remote sandbox API.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox_fusion service, eg: \"https://<your service endpoint>/run_code\"\n        concurrent_semaphore: Optional semaphore used to bound concurrent sandbox requests.\n        memory_limit_mb: Memory limit (in MB) applied to each sandbox run.\n        completion: The completion string containing the code.\n        test_cases: JSON string or dictionary containing \"inputs\" and \"outputs\".\n        continuous: Whether to compute a continuous score (based on the first N test cases).\n        timeout: Timeout for each test case.\n\n    Returns:\n        A tuple (score, metadata_list).\n        score: Float score (0.0 to 1.0).\n        metadata_list: List containing execution metadata for each test case.\n    \"\"\"\n    solution = completion\n    if \"```python\" in completion:\n        solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    elif \"```\" in completion:\n        # Handle cases like ```\\ncode\\n```\n        parts = completion.split(\"```\")\n        if len(parts) >= 2:\n            solution = parts[1]\n            # Remove potential language specifier like 'python\\n'\n            if \"\\n\" in solution:\n                first_line, rest = solution.split(\"\\n\", 1)\n                if first_line.strip().isalpha():  # Simple check for language name\n                    solution = rest\n    else:\n        return 0.0, [{\"error\": \"Invalid completion (missing code block)\"}]\n\n    try:\n        if not isinstance(test_cases, dict):\n            try:\n                test_cases = json.loads(test_cases)\n            except json.JSONDecodeError as e:\n                logger.error(f\"Failed to parse test_cases JSON: {e}\")\n                return 0.0, [{\"error\": \"Invalid test_cases JSON format\"}]\n\n        if not test_cases or \"inputs\" not in test_cases or \"outputs\" not in test_cases:\n            logger.error(\"Invalid test_cases structure.\")\n            return 0.0, [{\"error\": \"Invalid test_cases structure (missing inputs/outputs)\"}]\n\n        # Check all test cases\n        # Note: The return value of check_correctness might need adaptation here\n        # Assume check_correctness returns (results_list, metadata_list)\n        # results_list contains True, False, or error codes (-1, -2, -3, etc.)\n        res_list, metadata_list = check_correctness(\n            sandbox_fusion_url=sandbox_fusion_url,\n            in_outs=test_cases,\n            
generation=solution,\n            timeout=timeout,\n            concurrent_semaphore=concurrent_semaphore,\n            memory_limit_mb=memory_limit_mb,\n        )\n\n        # Calculate score\n        if not res_list:  # If there are no results (e.g., invalid input)\n            return 0.0, metadata_list\n\n        if continuous:\n            # Calculate pass rate for the first N (e.g., 10) test cases\n            num_to_consider = min(len(res_list), 10)\n            if num_to_consider == 0:\n                score = 0.0\n            else:\n                passed_count = sum(1 for r in res_list[:num_to_consider] if r is True)\n                score = passed_count / num_to_consider\n            # Return all metadata, even if score is based on the first N\n            final_metadata = metadata_list\n        else:\n            # Calculate pass rate for all test cases\n            passed_count = sum(1 for r in res_list if r is True)\n            total_cases = len(res_list)\n            score = passed_count / total_cases if total_cases > 0 else 0.0\n            final_metadata = metadata_list\n\n    except Exception as e:\n        logger.error(f\"Error during compute_score: {e}\")\n        traceback.print_exc()\n        score = 0.0\n        # Try to return partial metadata if available, otherwise return error info\n        final_metadata = metadata_list if \"metadata_list\" in locals() else [{\"error\": f\"Unhandled exception: {e}\"}]\n\n    # Ensure float and list are returned\n    return float(score), final_metadata if isinstance(final_metadata, list) else [final_metadata]\n"
  },
  {
    "path": "verl_rl/verl/utils/reward_score/sandbox_fusion/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport concurrent.futures  # <-- Import concurrent.futures\nimport json\nimport logging\nimport os\nimport threading\nimport time\nimport traceback\nimport uuid\nfrom typing import Any, Optional\n\nimport requests\n\nDEFAULT_TIMEOUT = 10  # Default compile and run timeout\nMAX_RETRIES = 3\nINITIAL_RETRY_DELAY = 1\nAPI_TIMEOUT = 10\n\nlogger = logging.getLogger(__name__)\n\n# Define supported languages list (optional, for documentation or validation)\nSUPPORTED_LANGUAGES = [\n    \"python\",\n    \"cpp\",\n    \"nodejs\",\n    \"go\",\n    \"go_test\",\n    \"java\",\n    \"php\",\n    \"csharp\",\n    \"bash\",\n    \"typescript\",\n    \"sql\",\n    \"rust\",\n    \"cuda\",\n    \"lua\",\n    \"R\",\n    \"perl\",\n    \"D_ut\",\n    \"ruby\",\n    \"scala\",\n    \"julia\",\n    \"pytest\",\n    \"junit\",\n    \"kotlin_script\",\n    \"jest\",\n    \"verilog\",\n    \"python_gpu\",\n    \"lean\",\n    \"swift\",\n    \"racket\",\n]\n\n\ndef call_sandbox_api(\n    sandbox_fusion_url: str,\n    code: str,\n    stdin: Optional[str],\n    compile_timeout: int,\n    run_timeout: int,\n    memory_limit_mb: int,\n    language: str = \"python\",\n) -> tuple[Optional[dict[str, Any]], Optional[str]]:  # <-- Remove request_id parameter\n    \"\"\"\n    Calls the remote sandbox API to execute code with retry logic for Gateway Timeout,\n    using increasing delay between retries. Logs internal calls with a unique ID.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        code: The code string to execute.\n        stdin: The standard input string.\n        compile_timeout: Compile timeout in seconds.\n        run_timeout: Run timeout in seconds.\n        language: The programming language of the code (e.g., \"python\", \"cpp\", \"java\"). 
Defaults to \"python\".\n\n    Returns:\n        A tuple (response_json, error_message).\n        If successful, response_json is the API's returned JSON object, error_message is None.\n        If failed after retries, response_json is None, error_message contains the error information.\n    \"\"\"\n    request_id = str(uuid.uuid4())  # <-- Generate request_id internally\n    log_prefix = f\"[Request ID: {request_id}] \"  # <-- Create log prefix\n\n    if language not in SUPPORTED_LANGUAGES:\n        error_msg = f\"{log_prefix}Unsupported language: {language}\"\n        logger.error(error_msg)\n        return None, error_msg\n\n    payload = json.dumps(\n        {\n            \"compile_timeout\": compile_timeout,\n            \"run_timeout\": run_timeout,\n            \"code\": code,\n            \"stdin\": stdin,\n            \"memory_limit_MB\": memory_limit_mb,\n            \"language\": language,  # Use the passed language parameter\n            \"files\": {},\n            \"fetch_files\": [],\n        }\n    )\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n    # Calculate a reasonable request timeout based on compile/run timeouts plus a buffer\n    request_timeout = compile_timeout + run_timeout + API_TIMEOUT\n\n    last_error = None  # Store the last error encountered\n\n    for attempt in range(MAX_RETRIES):\n        try:\n            logger.info(\n                f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling sandbox API at {sandbox_fusion_url}\"\n            )  # <-- Use internal log_prefix\n            response = requests.post(\n                sandbox_fusion_url,\n                headers=headers,\n                data=payload,\n                timeout=request_timeout,  # Use the calculated timeout\n            )\n\n            # Check for Gateway Timeout (504) specifically for retrying\n            if response.status_code == 504:\n                last_error = (\n                    f\"{log_prefix}API Request Error: Gateway Timeout (504) on attempt \"\n                    f\"{attempt + 1}/{MAX_RETRIES}\"\n                )  # <-- Use internal log_prefix\n                logger.warning(last_error)\n                if attempt < MAX_RETRIES - 1:  # Don't sleep after the last attempt\n                    # Calculate increasing delay (e.g., 1s, 2s, 4s, ...) 
or (1s, 2s, 3s, ...)\n                    # Simple linear increase: delay = INITIAL_RETRY_DELAY * (attempt + 1)\n                    # Exponential backoff: delay = INITIAL_RETRY_DELAY * (2 ** attempt)\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)  # Using linear increase for simplicity\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")  # <-- Use internal log_prefix\n                    time.sleep(delay)\n                continue  # Go to the next retry attempt\n\n            # Check for other HTTP errors (e.g., 4xx, other 5xx)\n            response.raise_for_status()\n\n            # If successful (status code 2xx)\n            logger.info(\n                f\"{log_prefix}Sandbox API call successful on attempt {attempt + 1}\"\n            )  # <-- Use internal log_prefix\n            return response.json(), None\n\n        except requests.exceptions.RequestException as e:\n            last_error = f\"{log_prefix}API Request Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on non-504 request errors\n        except json.JSONDecodeError as e:\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on JSON decode errors\n        except Exception as e:\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on other unexpected errors\n\n    # If loop finishes without returning success, return the last recorded error\n    logger.error(f\"{log_prefix}Sandbox API call failed. Last error: {last_error}\")  # <-- Use internal log_prefix\n    # Return the error message without the prefix, as the caller doesn't need the internal ID\n    # Ensure API call failure returns error message, leading to -1 in check_correctness\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\n\n\ndef _process_single_case(\n    case_index: int,\n    stdin_data: Any,\n    expected_output: Any,\n    sandbox_fusion_url: str,\n    generation: str,\n    timeout: int,\n    memory_limit_mb: int,\n    language: str,\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\n    fn_name: Optional[str] = None,\n) -> tuple[int, dict[str, Any]]:\n    \"\"\"Helper function to process a single test case.\"\"\"\n    api_response = None\n    error_msg = None\n    logger.info(f\"Processing test case {case_index + 1}.\")\n\n    current_generation_code = generation\n\n    if fn_name and language == \"python\":\n        # Wrapper assumes stdin_data is a JSON string for function arguments.\n        wrapper_code = f\"\"\"\nimport traceback\nfrom string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\n\n# === User's Original Code START ===\n{generation}\n# === 
User's Original Code END ===\n\n_SANDBOX_FN_NAME = \"{fn_name}\"\n\ndef _execute_user_function():\n    # --- Input Parsing ---\n    _raw_input_str = sys.stdin.read()\n    _args = []\n    if _raw_input_str.strip(): # If there's input\n        try:\n            _args = [json.loads(line) for line in _raw_input_str.split('\\\\n')]\n        except json.JSONDecodeError as _je:\n            sys.stderr.write(f\"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\\\nInput was: \"\n                              f\"{{_raw_input_str[:200]}}\\\\n\")\n            return None, True # result, error_occurred\n\n    # --- Function Location and Execution ---\n    try:\n        _target_callable = None\n        # Try global scope first\n        if _SANDBOX_FN_NAME in globals():\n            _target_callable = globals()[_SANDBOX_FN_NAME]\n        # Else, if 'Solution' class exists, try to get its method\n        elif 'Solution' in globals():\n            _Solution_class = globals()['Solution']\n            # Attempt to instantiate and get method.\n            # Errors (e.g., Solution not a class, instantiation fails, method missing)\n            # will be caught by the broad except block below.\n            _solution_instance = _Solution_class()\n            _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)\n\n        if not _target_callable:\n            sys.stderr.write(f\"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\\\n\")\n            return None, True # result, error_occurred\n\n        _fn_result = _target_callable(*_args)\n        return _fn_result, False # result, no_error\n    except Exception: # Catches errors from Solution instantiation, getattr, or function call\n        sys.stderr.write(f\"Error during setup or execution of '{{_SANDBOX_FN_NAME}}':\\\\n{{traceback.format_exc()}}\\\\n\")\n        return None, True # result, error_occurred\n\nif __name__ == '__main__':\n    _result, _error_occurred = _execute_user_function()\n\n    if not _error_occurred:\n        # Serialize result to stdout\n        if isinstance(_result, (dict, list, tuple)) or _result is None or isinstance(_result, bool):\n            print(json.dumps(_result))\n        elif isinstance(_result, (int, float, str)):\n            print(str(_result)) # Ensure string conversion for print\n        else:\n            # For other types, default to string representation.\n            print(str(_result))\n    # Optional: To explicitly exit with an error code if the sandbox relies on it\n    # else:\n    #    sys.exit(1)\n\"\"\"\n        current_generation_code = wrapper_code\n\n    stdin = None if stdin_data is None else str(stdin_data)\n    try:\n        if concurrent_semaphore:\n            # logger.debug(f\"Case {case_index + 1}: Attempting to acquire semaphore.\")\n            with concurrent_semaphore:\n                # logger.debug(f\"Case {case_index + 1}: Semaphore acquired. 
Calling API.\")\n                api_response, error_msg = call_sandbox_api(\n                    sandbox_fusion_url=sandbox_fusion_url,\n                    code=current_generation_code,\n                    stdin=stdin,\n                    compile_timeout=timeout,\n                    run_timeout=timeout,\n                    memory_limit_mb=memory_limit_mb,\n                    language=language,\n                )\n            # logger.debug(f\"Case {case_index + 1}: Semaphore released.\")\n        else:\n            api_response, error_msg = call_sandbox_api(\n                sandbox_fusion_url=sandbox_fusion_url,\n                code=current_generation_code,\n                stdin=stdin,\n                compile_timeout=timeout,\n                run_timeout=timeout,\n                memory_limit_mb=memory_limit_mb,\n                language=language,\n            )\n    except Exception as e:\n        error_msg = f\"API Request Exception during check_correctness for case {case_index + 1}: {e}\"\n        logger.error(f\"Case {case_index + 1}: {error_msg}\")\n        traceback.print_exc()\n\n    metadata = {\n        \"case_index\": case_index,\n        \"input\": stdin,\n        \"expected_output\": str(expected_output),\n        \"api_request_error\": error_msg,\n        \"api_response\": None,\n        \"status\": \"unknown\",\n        \"stdout\": None,\n        \"stderr\": None,\n        \"exit_code\": None,\n        \"duration\": None,\n        \"compile_duration\": None,\n        \"compile_stderr\": None,\n        \"api_status\": None,\n        \"compile_status\": None,\n        \"run_status\": None,\n    }\n    result_status = -1  # Default error: API request error or unknown sandbox error\n\n    if error_msg:\n        metadata[\"status\"] = \"api_error\"\n        result_status = -1  # API request itself failed (includes timeout after retries)\n        logger.error(f\"Case {case_index}: API error occurred: {error_msg}\")\n        # Log code and input only on error for brevity\n        generation_to_log = generation[:200] + \"...\" if len(generation) > 200 else generation\n        logger.error(f\"Case {case_index}: code: {generation_to_log}\")\n        logger.error(f\"Case {case_index}: input: {stdin}\")\n    elif api_response:\n        # --- Add debug logging ---\n        logger.debug(f\"Case {case_index}: API Response: {api_response}\")\n        metadata[\"api_response\"] = api_response\n        metadata[\"api_status\"] = api_response.get(\"status\")\n        compile_result = api_response.get(\"compile_result\")\n        run_result = api_response.get(\"run_result\")\n\n        # Extract compile information\n        if compile_result:\n            metadata[\"compile_status\"] = compile_result.get(\"status\")\n            metadata[\"compile_duration\"] = compile_result.get(\"execution_time\")\n            metadata[\"compile_stderr\"] = compile_result.get(\"stderr\")\n\n        # Extract run information\n        if run_result:\n            metadata[\"run_status\"] = run_result.get(\"status\")\n            metadata[\"stdout\"] = run_result.get(\"stdout\")\n            metadata[\"stderr\"] = run_result.get(\"stderr\")  # stderr during runtime\n            metadata[\"exit_code\"] = run_result.get(\"return_code\")\n            metadata[\"duration\"] = run_result.get(\"execution_time\")\n\n        # --- Determine status based on API response ---\n        api_status = metadata[\"api_status\"]\n\n        if api_status == \"SandboxError\":\n            metadata[\"status\"] = 
\"sandbox_error\"\n            result_status = -1  # Internal sandbox error\n        elif api_status == \"Failed\":\n            # --- Add debug logging ---\n            logger.debug(f\"API returned Failed status. Response: {api_response}\")\n            logger.debug(f\"Compile Result: {compile_result}\")\n            logger.debug(f\"Run Result: {run_result}\")\n            # --- Check the logic here ---\n            # Compile failed or timed out\n            is_compile_error = compile_result and (\n                metadata[\"compile_status\"] in [\"Error\", \"TimeLimitExceeded\"]\n                or (metadata[\"compile_status\"] == \"Finished\" and compile_result.get(\"return_code\") != 0)\n            )\n            if is_compile_error:\n                # Differentiate between compile_error and compile_timeout based on specific status\n                if metadata[\"compile_status\"] == \"TimeLimitExceeded\":\n                    metadata[\"status\"] = \"compile_timeout\"\n                else:  # Includes Error and Finished but return_code != 0 cases\n                    metadata[\"status\"] = \"compile_error\"\n                result_status = -4\n            # Run failed or timed out\n            elif run_result:\n                # Modified condition: Check for TimeLimitExceeded OR (Finished with non-zero exit code) OR Error status\n                is_runtime_error = (\n                    metadata[\"run_status\"] == \"TimeLimitExceeded\"\n                    or metadata[\"run_status\"] == \"Error\"\n                    or (metadata[\"run_status\"] == \"Finished\" and run_result.get(\"return_code\") != 0)\n                )\n                if is_runtime_error:\n                    if metadata[\"run_status\"] == \"TimeLimitExceeded\":\n                        metadata[\"status\"] = \"timeout\"  # Runtime timeout\n                        result_status = -3\n                    else:  # Includes Error and Finished with non-zero return_code\n                        metadata[\"status\"] = \"runtime_error\"\n                        result_status = -2\n                else:\n                    # Other Failed status with run_result, classify as unknown failure\n                    logger.warning(f\"Unknown run_status '{metadata['run_status']}' or state within Failed API status.\")\n                    metadata[\"status\"] = \"unknown_failure\"\n                    result_status = -1  # Default to -1\n            else:\n                # Status is Failed but neither a clear compile error nor run_result exists\n                logger.warning(\"API status Failed but cannot determine specific error type (compile/run).\")\n                metadata[\"status\"] = \"unknown_failure_state\"\n                result_status = -1  # Default to -1\n        elif api_status == \"Success\":\n            # Run completed successfully, now check the answer\n            if run_result and metadata[\"run_status\"] == \"Finished\":\n                actual_output = metadata[\"stdout\"] if metadata[\"stdout\"] is not None else \"\"\n                # Note: Output might contain trailing newlines, need normalization\n                if str(actual_output).rstrip(\"\\n\") == str(expected_output).rstrip(\"\\n\"):\n                    result_status = True\n                    metadata[\"status\"] = \"success\"\n                else:\n                    result_status = False\n                    metadata[\"status\"] = \"wrong_answer\"\n            else:\n                # Status is Success but run_result status is not 
Finished, this is unexpected\n                metadata[\"status\"] = \"unexpected_success_state\"\n                result_status = -1  # Classify as unknown error\n        else:\n            # API returned an unknown top-level status\n            logger.warning(f\"Unknown API status received: {api_status}\")\n            metadata[\"status\"] = f\"unknown_api_status_{api_status}\"\n            result_status = -1  # Default to -1\n    else:  # api_response is None and no error_msg (Should not happen with current call_sandbox_api logic)\n        metadata[\"status\"] = \"unknown_api_state\"\n        result_status = -1\n        logger.error(f\"Case {case_index}: Unknown API state (no response and no error message).\")\n    return result_status, metadata\n\n\ndef check_correctness(\n    sandbox_fusion_url: str,\n    in_outs: Optional[dict],\n    generation: str,\n    timeout: int = DEFAULT_TIMEOUT,\n    memory_limit_mb: int = 1024,\n    language: str = \"python\",\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\n) -> tuple[list[Any], list[dict[str, Any]]]:\n    \"\"\"\n    Checks the correctness of code generation using the remote sandbox API,\n    processing test cases concurrently.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        in_outs: Dictionary containing \"inputs\" and \"outputs\" lists.\n        generation: The generated code string.\n        timeout: Timeout for each test case (compile and run share this timeout).\n        memory_limit_mb: Memory limit (in MB) for each sandbox execution.\n        language: The programming language of the code.\n        concurrent_semaphore: Optional semaphore bounding the number of concurrent sandbox requests.\n\n    Returns:\n        A tuple (results, metadata_list).\n        results: A list containing the test result for each input/output pair\n                 (True/False/-1 api/sandbox err, -2 runtime err, -3 timeout, -4 compile err).\n                 Results are ordered corresponding to the inputs.\n        metadata_list: A list containing metadata dictionaries for each test case,\n                       ordered corresponding to the inputs.\n    \"\"\"\n    logger.info(\"Starting correctness check for generation.\")\n\n    if not in_outs or \"inputs\" not in in_outs or \"outputs\" not in in_outs:\n        logger.warning(\"Invalid in_outs format provided.\")\n        return [-1], [{\"error\": \"Invalid input/output data\"}]\n\n    inputs = in_outs[\"inputs\"]\n    expected_outputs = in_outs[\"outputs\"]\n    fn_name = in_outs.get(\"fn_name\")\n    num_cases = len(inputs)\n    results = [None] * num_cases  # Initialize with placeholders\n    metadata_list = [None] * num_cases  # Initialize with placeholders\n\n    if num_cases == 0:\n        logger.warning(\"Empty inputs provided.\")\n        return [], []\n\n    if len(inputs) != len(expected_outputs):\n        logger.warning(f\"Mismatch between number of inputs ({len(inputs)}) and outputs ({len(expected_outputs)}).\")\n        # Return error based on the number of inputs provided\n        return [-1] * num_cases, [{\"error\": \"Input/output count mismatch\", \"case_index\": i} for i in range(num_cases)]\n\n    first_compile_error_index = -1\n\n    # Effective concurrency is still bounded by concurrent_semaphore (sandbox_fusion_max_concurrent).\n    # os.cpu_count() may return None, hence the fallback to 1.\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max(32, (os.cpu_count() or 1) * 5)) as executor:\n        # Submit all tasks, passing the concurrent_semaphore to _process_single_case\n        future_to_index = {\n            executor.submit(\n                _process_single_case,\n                i,\n                stdin_data,\n                expected_outputs[i],\n                sandbox_fusion_url,\n                generation,\n                timeout,\n                memory_limit_mb,\n                language,\n                concurrent_semaphore,\n                fn_name,\n            ): i\n            for i, stdin_data in enumerate(inputs)\n        }\n\n        # Process results as they complete\n        for future in concurrent.futures.as_completed(future_to_index):\n            index = future_to_index[future]\n            try:\n                result_status, metadata = future.result()\n                results[index] = result_status\n                metadata_list[index] = metadata\n\n                # Check for compile error (-4)\n                if result_status == -4:\n                    if first_compile_error_index == -1 or index < first_compile_error_index:\n                        first_compile_error_index = index\n                    # Optimization: could potentially cancel futures for index > first_compile_error_index\n                    # However, cancellation is not guaranteed. Post-processing is safer.\n\n            except Exception as exc:\n                logger.error(f\"Test case {index} generated an exception: {exc}\")\n                traceback.print_exc()\n                results[index] = -1  # Mark as API/internal error\n                metadata_list[index] = {\n                    \"case_index\": index,\n                    \"input\": str(inputs[index]),\n                    \"expected_output\": str(expected_outputs[index]),\n                    \"api_request_error\": f\"Internal execution error: {exc}\",\n                    \"status\": \"internal_error\",\n                }\n\n    # Post-processing for compile errors\n    if first_compile_error_index != -1:\n        logger.warning(\n            f\"Compile error detected in case {first_compile_error_index}. Marking subsequent cases as compile errors.\"\n        )\n        for i in range(first_compile_error_index + 1, num_cases):\n            # Only update if not already processed (though it should be None or have a result)\n            if results[i] != -4:  # Avoid overwriting if it somehow already got -4\n                results[i] = -4\n                # Update or create metadata for skipped cases due to compile error\n                if metadata_list[i] is None:  # If future failed before returning metadata\n                    metadata_list[i] = {\n                        \"case_index\": i,\n                        \"input\": str(inputs[i]),\n                        \"expected_output\": str(expected_outputs[i]),\n                        \"api_request_error\": None,\n                        \"status\": \"compile_error_skipped\",  # Indicate skipped due to prior compile error\n                    }\n                else:  # If future completed but result is overridden\n                    metadata_list[i][\"status\"] = \"compile_error_skipped\"\n\n    logger.info(f\"Correctness check finished. Results: {results}\")\n    return results, metadata_list\n
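\n\n# Hedged sanity-check sketch: exercises check_correctness against a locally deployed\n# Sandbox Fusion instance (placeholder URL); result codes follow the docstring above.\nif __name__ == \"__main__\":\n    demo_results, demo_meta = check_correctness(\n        sandbox_fusion_url=\"http://localhost:8080/run_code\",  # placeholder endpoint\n        in_outs={\"inputs\": [\"1 2\\n\"], \"outputs\": [\"3\\n\"]},\n        generation=\"a, b = map(int, input().split())\\nprint(a + b)\\n\",\n    )\n    print(demo_results, [m.get(\"status\") for m in demo_meta])\n"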
  },
  {
    "path": "verl_rl/verl/utils/reward_score/search_r1_like_qa_em.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n# Copyright 2025 Search-R1 Contributors\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py\r\n\r\nimport random\r\nimport re\r\nimport string\r\n\r\n\r\ndef normalize_answer(s):\r\n    def remove_articles(text):\r\n        return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\r\n\r\n    def white_space_fix(text):\r\n        return \" \".join(text.split())\r\n\r\n    def remove_punc(text):\r\n        exclude = set(string.punctuation)\r\n        return \"\".join(ch for ch in text if ch not in exclude)\r\n\r\n    def lower(text):\r\n        return text.lower()\r\n\r\n    return white_space_fix(remove_articles(remove_punc(lower(s))))\r\n\r\n\r\ndef em_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer == normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef subem_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer in normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef extract_solution(solution_str):\r\n    \"\"\"Extract the equation from the solution string.\"\"\"\r\n    # Remove everything before the first \"Assistant:\"\r\n    # if \"Assistant:\" in solution_str:\r\n    #     solution_str = solution_str.split(\"Assistant:\", 1)[1]\r\n    # elif \"<|im_start|>assistant\" in solution_str:\r\n    #     solution_str = solution_str.split(\"<|im_start|>assistant\", 1)[1]\r\n    # else:\r\n    #     return None\r\n    # solution_str = solution_str.split('\\n')[-1]\r\n\r\n    answer_pattern = r\"<answer>(.*?)</answer>\"\r\n    match = re.finditer(answer_pattern, solution_str, re.DOTALL)\r\n    matches = list(match)\r\n\r\n    # If there are 0  matches, return None\r\n    if len(matches) < 1:\r\n        return None\r\n\r\n    # If there are 2 or more matches, return the last one\r\n    return matches[-1].group(1).strip()\r\n\r\n\r\ndef count_answer_tags(text):\r\n    opening_tags = text.count(\"<answer>\")\r\n    closing_tags = text.count(\"</answer>\")\r\n\r\n    return opening_tags, closing_tags\r\n\r\n\r\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n  
      ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    open_count, close_count = count_answer_tags(solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        if answer is not None:\r\n            print(f\"Extracted answer is not None: {answer}\")\r\n        else:\r\n            print(\"Extracted answer: None!\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if em_check(answer, ground_truth[\"target\"]):\r\n            if open_count > 10 or close_count > 10:  # prevent output a lot of </answer>\r\n                score = score / 4\r\n                return score\r\n            return score\r\n        else:\r\n            return format_score\r\n\r\n\r\ndef compute_score_subem(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for substring exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n        ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        print(f\"Extracted answer: {answer}\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if subem_check(answer, ground_truth[\"target\"]):\r\n            return score\r\n        else:\r\n            return format_score\r\n"
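\r\n\r\n# Minimal usage sketch: ground_truth mirrors the dict layout compute_score expects\r\n# (a \"target\" field holding one or more golden answers).\r\nif __name__ == \"__main__\":\r\n    demo_response = \"I think <answer>the Eiffel Tower</answer>\"\r\n    print(compute_score(demo_response, {\"target\": [\"Eiffel Tower\"]}))  # 1.0 (articles are stripped)\r\n"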
  },
  {
    "path": "verl_rl/verl/utils/rollout_trace.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport contextlib\nimport functools\nimport inspect\nimport os\nfrom typing import Optional\n\n\nclass RolloutTraceConfig:\n    \"\"\"Configuration for rollout tracing with various backends.\n\n    Singleton configuration class for managing rollout trace settings across different\n    tracing backends like Weave and MLflow.\n\n    Args:\n        backend (Optional[str]): Tracing backend to use ('weave', 'mlflow', or None).\n        client (Optional[object]): Client instance for the selected backend.\n        token2text (bool): Whether to convert tokens to text in traces. Defaults to False.\n        project_name (str): Name of the project for tracing.\n        experiment_name (str): Name of the experiment for tracing.\n    \"\"\"\n\n    _instance: Optional[\"RolloutTraceConfig\"] = None\n    backend: Optional[str] = None\n    client: Optional[object] = None\n    token2text: bool = False\n    _initialized: bool = False\n    project_name: str = None\n    experiment_name: str = None\n\n    def __new__(cls, *args, **kwargs):\n        if cls._instance is None:\n            cls._instance = super().__new__(cls)\n            cls._instance._initialized = False\n        return cls._instance\n\n    @classmethod\n    def get_instance(cls) -> \"RolloutTraceConfig\":\n        if cls._instance is None:\n            cls._instance = cls()\n        return cls._instance\n\n    @classmethod\n    def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False):\n        config = cls.get_instance()\n        if config._initialized:\n            return\n\n        config.backend = backend\n        config.token2text = token2text\n        config.project_name = project_name\n        config.experiment_name = experiment_name\n\n        if backend == \"weave\":\n            import weave\n\n            config.client = weave.init(project_name)\n        elif backend == \"mlflow\":\n            import mlflow\n\n            mlflow.config.enable_async_logging()\n            config.client = mlflow\n\n            MLFLOW_TRACKING_URI = os.environ.get(\"MLFLOW_TRACKING_URI\", \"sqlite:////tmp/mlruns.db\")\n            mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n\n            mlflow.set_experiment(project_name)\n        else:\n            config.client = None\n\n        config._initialized = True\n\n    @classmethod\n    def get_backend(cls) -> Optional[str]:\n        return cls.get_instance().backend\n\n    @classmethod\n    def get_client(cls) -> Optional[object]:\n        return cls.get_instance().client\n\n    @classmethod\n    def enable_token2text(cls) -> Optional[bool]:\n        return cls.get_instance().token2text\n\n    @classmethod\n    def reset(cls):\n        cls._instance = None\n\n\n@contextlib.contextmanager\ndef rollout_trace_attr(sample_index=None, step=None, rollout_n=None, name=\"rollout_trace\", validate=False):\n    \"\"\"A context manager to 
add attributes to a trace for the configured backend.\"\"\"\n    backend = RolloutTraceConfig.get_backend()\n    attributes = {}\n    if backend:\n        if sample_index is not None:\n            attributes[\"sample_index\"] = sample_index\n        if step is not None:\n            attributes[\"step\"] = step\n        if rollout_n is not None:\n            attributes[\"rollout_n\"] = rollout_n\n        attributes[\"validate\"] = validate\n        attributes[\"experiment_name\"] = RolloutTraceConfig.get_instance().experiment_name\n\n    if not attributes or backend is None:\n        yield\n        return\n\n    if backend == \"weave\":\n        import weave\n\n        with weave.attributes(attributes):\n            yield\n    elif backend == \"mlflow\":\n        import mlflow\n\n        with mlflow.start_span(name=name) as span:\n            trace_id = span.trace_id\n            for key, value in attributes.items():\n                mlflow.set_trace_tag(trace_id, str(key), str(value))\n            yield\n    else:\n        yield\n\n\ndef rollout_trace_op(func):\n    @functools.wraps(func)\n    async def async_wrapper(self, *args, **kwargs):\n        backend = RolloutTraceConfig.get_backend()\n        enable_token2text = RolloutTraceConfig.enable_token2text()\n        if backend is None:\n            return await func(self, *args, **kwargs)\n\n        sig = inspect.signature(func)\n        bound_args = sig.bind(self, *args, **kwargs)\n        bound_args.apply_defaults()\n        inputs = dict(bound_args.arguments)\n        del inputs[\"self\"]\n\n        async def add_token2text(self, result):\n            if hasattr(result, \"prompt_ids\") and hasattr(self, \"tokenizer\") and hasattr(self.tokenizer, \"decode\"):\n                _result = vars(result)\n                loop = asyncio.get_running_loop()\n                if hasattr(result, \"prompt_ids\"):\n                    prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids)\n                    _result[\"prompt_text\"] = prompt_text\n\n                if hasattr(result, \"response_ids\"):\n                    response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids)\n                    _result[\"response_text\"] = response_text\n                return _result\n            return result\n\n        if backend == \"weave\":\n            tracer = RolloutTraceConfig.get_client()\n            from weave.trace.context import call_context\n\n            cur_attributes = {**call_context.call_attributes.get()}\n            call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes)\n            try:\n                result = await func(self, *args, **kwargs)\n\n                if enable_token2text:\n                    _result = await add_token2text(self, result)\n                    tracer.finish_call(call, output=_result)\n                else:\n                    tracer.finish_call(call, output=result)\n\n                return result\n\n            except Exception as e:\n                tracer.finish_call(call, exception=e)\n                raise e\n        elif backend == \"mlflow\":\n            import mlflow\n\n            with mlflow.start_span(name=func.__qualname__) as span:\n                span.set_inputs(inputs)\n                result = await func(self, *args, **kwargs)\n                if enable_token2text:\n                    _result = await add_token2text(self, result)\n                    span.set_outputs(_result)\n            
    else:\n                    span.set_outputs(result)\n\n            return result\n\n        else:\n            return await func(self, *args, **kwargs)\n\n    # Synchronous variant of the tracing wrapper (no token2text decoding).\n    @functools.wraps(func)\n    def wrapper(self, *args, **kwargs):\n        backend = RolloutTraceConfig.get_backend()\n        if backend is None:\n            return func(self, *args, **kwargs)\n\n        sig = inspect.signature(func)\n        bound_args = sig.bind(self, *args, **kwargs)\n        bound_args.apply_defaults()\n        inputs = dict(bound_args.arguments)\n        del inputs[\"self\"]\n\n        if backend == \"weave\":\n            tracer = RolloutTraceConfig.get_client()\n            from weave.trace.context import call_context\n\n            cur_attributes = {**call_context.call_attributes.get()}\n            call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes)\n            try:\n                result = func(self, *args, **kwargs)\n                tracer.finish_call(call, output=result)\n                return result\n            except Exception as e:\n                tracer.finish_call(call, exception=e)\n                raise e\n        elif backend == \"mlflow\":\n            import mlflow\n\n            return mlflow.trace(func)(self, *args, **kwargs)\n        else:\n            return func(self, *args, **kwargs)\n\n    return async_wrapper if inspect.iscoroutinefunction(func) else wrapper\n
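\n\n# Minimal illustration: with no backend configured, both the decorator and the\n# context manager are no-ops, so this sketch runs without weave or mlflow installed.\nif __name__ == \"__main__\":\n\n    class _DemoAgent:\n        @rollout_trace_op\n        async def act(self, prompt):\n            return prompt.upper()\n\n    async def _demo():\n        with rollout_trace_attr(sample_index=0, step=1, rollout_n=4):\n            return await _DemoAgent().act(\"hello\")\n\n    print(asyncio.run(_demo()))  # -> HELLO\n"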
  },
  {
    "path": "verl_rl/verl/utils/seqlen_balancing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport heapq\nfrom itertools import chain\n\nimport torch\nfrom torch import distributed as dist\n\nfrom verl.protocol import DataProto\nfrom verl.utils.device import get_device_name\n\n\ndef karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    # see: https://en.wikipedia.org/wiki/Largest_differencing_method\n    class Set:\n        def __init__(self) -> None:\n            self.sum = 0\n            self.items = []\n\n        def add(self, idx: int, val: int):\n            self.items.append((idx, val))\n            self.sum += val\n\n        def merge(self, other):\n            for idx, val in other.items:\n                self.items.append((idx, val))\n                self.sum += val\n\n        def __lt__(self, other):\n            if self.sum != other.sum:\n                return self.sum < other.sum\n            if len(self.items) != len(other.items):\n                return len(self.items) < len(other.items)\n            return self.items < other.items\n\n    class State:\n        def __init__(self, items: list[tuple[int, int]], k: int) -> None:\n            self.k = k\n            # sets should always be decreasing order\n            self.sets = [Set() for _ in range(k)]\n            assert len(items) in [1, k], f\"{len(items)} not in [1, {k}]\"\n            for i, (idx, seqlen) in enumerate(items):\n                self.sets[i].add(idx=idx, val=seqlen)\n            self.sets = sorted(self.sets, reverse=True)\n\n        def get_partitions(self):\n            partitions = []\n            for i in range(len(self.sets)):\n                cur_partition = []\n                for idx, _ in self.sets[i].items:\n                    cur_partition.append(idx)\n                partitions.append(cur_partition)\n            return partitions\n\n        def merge(self, other):\n            for i in range(self.k):\n                self.sets[i].merge(other.sets[self.k - 1 - i])\n            self.sets = sorted(self.sets, reverse=True)\n\n        @property\n        def spread(self) -> int:\n            return self.sets[0].sum - self.sets[-1].sum\n\n        def __lt__(self, other):\n            # least heap, let the state with largest spread to be popped first,\n            # if the spread is the same, let the state who has the largest set\n            # to be popped first.\n            if self.spread != other.spread:\n                return self.spread > other.spread\n            return self.sets[0] > other.sets[0]\n\n        def __repr__(self) -> str:\n            repr_str = \"[\"\n            for i in range(self.k):\n                if i > 0:\n                    repr_str += \",\"\n                repr_str += \"{\"\n                for j, (_, seqlen) in enumerate(self.sets[i].items):\n                    if j > 0:\n                        repr_str += \",\"\n                    repr_str += str(seqlen)\n                repr_str += \"}\"\n  
          repr_str += \"]\"\n            return repr_str\n\n    sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])\n    states_pq = []\n    if equal_size:\n        assert len(seqlen_list) % k_partitions == 0, f\"{len(seqlen_list)} % {k_partitions} != 0\"\n        for offset in range(0, len(sorted_seqlen_list), k_partitions):\n            items = []\n            for i in range(k_partitions):\n                seqlen, idx = sorted_seqlen_list[offset + i]\n                items.append((idx, seqlen))\n            heapq.heappush(states_pq, State(items=items, k=k_partitions))\n    else:\n        for seqlen, idx in sorted_seqlen_list:\n            heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))\n\n    while len(states_pq) > 1:\n        state0 = heapq.heappop(states_pq)\n        state1 = heapq.heappop(states_pq)\n        # merge states\n        state0.merge(state1)\n        heapq.heappush(states_pq, state0)\n\n    final_state = states_pq[0]\n    partitions = final_state.get_partitions()\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    bias = sum(seqlen_list) + 1 if equal_size else 0\n    sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]\n    partitions = [[] for _ in range(k_partitions)]\n    partition_sums = [0 for _ in range(k_partitions)]\n    for seqlen, i in sorted_seqlen:\n        min_idx = None\n        for j in range(k_partitions):\n            if min_idx is None or partition_sums[j] < partition_sums[min_idx]:\n                min_idx = j\n        partitions[min_idx].append(i)\n        partition_sums[min_idx] += seqlen\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    \"\"\"\n    Calculates partitions of indices from seqlen_list such that the sum of sequence lengths\n    in each partition is balanced. Uses the Karmarkar-Karp differencing method.\n\n    This is useful for balancing workload across devices or batches, especially when\n    dealing with variable sequence lengths.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        k_partitions (int): The desired number of partitions.\n        equal_size (bool): If True, ensures that each partition has the same number of items.\n                           Requires len(seqlen_list) to be divisible by k_partitions.\n                           If False, partitions can have varying numbers of items, focusing\n                           only on balancing the sum of sequence lengths.\n\n    Returns:\n        List[List[int]]: A list containing k_partitions lists. Each inner list contains the\n                         original indices of the items assigned to that partition. 
The indices\n                         within each partition list are sorted.\n\n    Raises:\n        AssertionError: If len(seqlen_list) < k_partitions.\n        AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions.\n        AssertionError: If any resulting partition is empty.\n    \"\"\"\n    assert len(seqlen_list) >= k_partitions, f\"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]\"\n\n    def _check_and_sort_partitions(partitions):\n        assert len(partitions) == k_partitions, f\"{len(partitions)} != {k_partitions}\"\n        seen_idx = set()\n        sorted_partitions = [None] * k_partitions\n        for i, partition in enumerate(partitions):\n            assert len(partition) > 0, f\"the {i}-th partition is empty\"\n            for idx in partition:\n                seen_idx.add(idx)\n            sorted_partitions[i] = sorted(partition)\n        assert seen_idx == set(range(len(seqlen_list)))\n        return sorted_partitions\n\n    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)\n    return _check_and_sort_partitions(partitions)\n\n\ndef log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix):\n    \"\"\"\n    Calculate and log metrics related to sequence length imbalance before and after partitioning.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        partitions (List[List[int]]): A list of partitions, where each inner list contains indices\n                                      from seqlen_list assigned to that partition.\n        prefix (str): A prefix to be added to each metric key in the returned dictionary.\n\n    Returns:\n        dict: A dictionary containing metrics related to sequence length imbalance.\n    \"\"\"\n    # Get the number of partitions\n    k_partition = len(partitions)\n    # assert len(seqlen_list) % k_partition == 0\n    batch_size = len(seqlen_list) // k_partition\n    min_sum_seqlen = None\n    max_sum_seqlen = None\n    total_sum_seqlen = 0\n\n    # Iterate over each batch of sequence lengths\n    for offset in range(0, len(seqlen_list), batch_size):\n        cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])\n        if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:\n            min_sum_seqlen = cur_sum_seqlen\n        if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:\n            max_sum_seqlen = cur_sum_seqlen\n        total_sum_seqlen += cur_sum_seqlen\n\n    balanced_sum_seqlen_list = []\n    for partition in partitions:\n        cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])\n        balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)\n    # print(\"balanced_sum_seqlen_list: \", balanced_sum_seqlen_list)\n    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)\n    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)\n\n    return {\n        f\"{prefix}/min\": min_sum_seqlen,\n        f\"{prefix}/max\": max_sum_seqlen,\n        f\"{prefix}/minmax_diff\": max_sum_seqlen - min_sum_seqlen,\n        f\"{prefix}/balanced_min\": min_sum_seqlen_balanced,\n        f\"{prefix}/balanced_max\": max_sum_seqlen_balanced,\n        f\"{prefix}/mean\": total_sum_seqlen / len(partitions),\n    }\n\n\ndef ceildiv(a, b):\n    return -(a // -b)\n\n\ndef roundup_divisible(a, b):\n    return ((a + b - 1) // b) * b\n\n\ndef rearrange_micro_batches(\n    batch,\n    max_token_len,\n    dp_group=None,\n    
num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n):\n    \"\"\"\n    Split a batch into micro-batches by total token count, with optional DP sync and padding.\n\n    Args:\n        batch (TensorDict): must include \"attention_mask\" (B*S); other fields are sliced similarly.\n        max_token_len (int): max sum of attention_mask per micro-batch.\n        dp_group (optional): torch.distributed group for data-parallel sync.\n        num_batches_divided_by (optional): virtual pipeline parallel size, for megatron.\n        same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count.\n        min_num_micro_batch (int, optional): force at least this many splits (pads empty ones).\n        use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches\n\n    Returns:\n        List[TensorDict]: the micro-batches.\n        List[List[int]]: index lists mapping each micro-batch back to original positions.\n    \"\"\"\n    # this is per local micro_bsz\n    max_seq_len = batch[\"attention_mask\"].shape[-1]\n    assert max_token_len >= max_seq_len, (\n        f\"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}\"\n    )\n    seq_len_effective: torch.Tensor = batch[\"attention_mask\"].sum(dim=1)\n    total_seqlen = seq_len_effective.sum().item()\n    # NOTE: num_microbatches <= batch_size, so take the min of these two.\n    num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))\n    if min_num_micro_batch is not None:\n        # used to support pp\n        num_micro_batches = max(min_num_micro_batch, num_micro_batches)\n    if dist.is_initialized() and same_micro_num_in_dp:\n        num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())\n        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)\n        num_micro_batches = num_micro_batches.cpu().item()\n    if num_batches_divided_by is not None:\n        num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)\n\n    seq_len_effective = seq_len_effective.tolist()\n    assert num_micro_batches <= len(seq_len_effective)\n\n    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)\n\n    if use_dynamic_bsz_balance:\n        # Use the sum of squared sequence lengths to approximate attention computation workload\n        micro_bsz_idx.sort(\n            key=lambda partition: (\n                sum(seq_len_effective[idx] ** 2 for idx in partition),\n                min(partition) if partition else 0,\n            ),\n            reverse=True,\n        )\n\n    micro_batches = []\n\n    for partition in micro_bsz_idx:\n        curr_micro_batch = []\n        for idx in partition:\n            curr_micro_batch.append(batch[idx : idx + 1])\n        curr_micro_batch = torch.cat(curr_micro_batch)\n\n        micro_batches.append(curr_micro_batch)\n\n    return micro_batches, micro_bsz_idx\n\n\ndef get_reverse_idx(idx_map):\n    \"\"\"\n    Build the inverse of an index mapping.\n\n    Args:\n        idx_map (Sequence[int]): Sequence where idx_map[i] = j.\n\n    Returns:\n        List[int]: Inverse mapping list such that output[j] = i for each i.\n    \"\"\"\n    reverse_idx_map = copy.deepcopy(idx_map)\n\n    for i, idx in enumerate(idx_map):\n        reverse_idx_map[idx] = i\n\n    return reverse_idx_map\n\n\ndef prepare_dynamic_batch(data: DataProto, max_token_len: int) -> tuple[list[DataProto], list[list[int]]]:\n    \"\"\"\n    Prepare a batch for dynamic batching.\n\n    Args:\n        data (DataProto): The input data.\n        max_token_len (int): The maximum token length for dynamic batching.\n\n    Returns:\n        Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects\n        and a list of index lists.\n    \"\"\"\n    batch, batch_idx_list = rearrange_micro_batches(data.batch, max_token_len=max_token_len)\n    micro_batches = []\n    for i, batch_idx in enumerate(batch_idx_list):\n        tensors = dict(batch[i])\n        # Slice non-tensor fields with the same index list so they stay aligned with the tensors.\n        non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()}\n        micro_batches.append(DataProto.from_dict(tensors, non_tensors))\n\n    return micro_batches, batch_idx_list\n\n\ndef restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor:\n    \"\"\"\n    Restore a batch from dynamic batching.\n\n    Args:\n        data (torch.Tensor): The input data.\n        batch_idx_list (List[List[int]]): The list of index lists.\n\n    Returns:\n        torch.Tensor: The restored data.\n    \"\"\"\n    indices = list(chain.from_iterable(batch_idx_list))\n    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n    return data[revert_indices]\n
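\n\n# Quick sanity-check sketch for the partitioner: with equal_size=True every\n# partition gets the same item count and the per-partition sums stay close.\nif __name__ == \"__main__\":\n    demo_lengths = [7, 1, 5, 3, 8, 2, 4, 6]\n    demo_parts = get_seqlen_balanced_partitions(demo_lengths, k_partitions=2, equal_size=True)\n    print(demo_parts, [sum(demo_lengths[i] for i in p) for p in demo_parts])\n"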
  },
  {
    "path": "verl_rl/verl/utils/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Utils for tokenization.\"\"\"\n\nimport warnings\n\n__all__ = [\"hf_tokenizer\", \"hf_processor\"]\n\n\ndef set_pad_token_id(tokenizer):\n    \"\"\"Set pad_token_id to eos_token_id if it is None.\n\n    Args:\n        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.\n\n    \"\"\"\n    if tokenizer.pad_token_id is None:\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n        warnings.warn(f\"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}\", stacklevel=1)\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n        warnings.warn(f\"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}\", stacklevel=1)\n\n\ndef hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):\n    \"\"\"Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens.\n\n    Args:\n\n        name (str): The name of the tokenizer.\n        correct_pad_token (bool): Whether to correct the pad token id.\n        correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.\n\n    Returns:\n\n        transformers.PreTrainedTokenizer: The pretrained tokenizer.\n\n    \"\"\"\n    from transformers import AutoTokenizer\n\n    if correct_gemma2 and isinstance(name_or_path, str) and \"gemma-2-2b-it\" in name_or_path:\n        # the EOS token in gemma2 is ambiguious, which may worsen RL performance.\n        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a\n        warnings.warn(\n            \"Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.\", stacklevel=1\n        )\n        kwargs[\"eos_token\"] = \"<end_of_turn>\"\n        kwargs[\"eos_token_id\"] = 107\n    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)\n    if correct_pad_token:\n        set_pad_token_id(tokenizer)\n    return tokenizer\n\n\ndef hf_processor(name_or_path, **kwargs):\n    \"\"\"Create a huggingface processor to process multimodal data.\n\n    Args:\n        name_or_path (str): The name of the processor.\n\n    Returns:\n        transformers.ProcessorMixin: The pretrained processor.\n    \"\"\"\n    from transformers import AutoProcessor\n\n    try:\n        processor = AutoProcessor.from_pretrained(name_or_path, **kwargs)\n    except Exception as e:\n        processor = None\n        # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid\n        # silent failure\n        warnings.warn(f\"Failed to create processor: {e}. 
This may affect multimodal processing\", stacklevel=1)\n    # Avoid loading the tokenizer here, see:\n    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344\n    # AutoProcessor may silently fall back to a tokenizer class; keep only real processors.\n    if processor is not None and \"Processor\" not in processor.__class__.__name__:\n        processor = None\n    return processor\n"
  },
  {
    "path": "verl_rl/verl/utils/torch_dtypes.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAdapted from Cruise.\n\"\"\"\n\nimport torch\n\nHALF_LIST = [16, \"16\", \"fp16\", \"float16\", torch.float16]\nFLOAT_LIST = [32, \"32\", \"fp32\", \"float32\", torch.float32]\nBFLOAT_LIST = [\"bf16\", \"bfloat16\", torch.bfloat16]\n\n\nclass PrecisionType:\n    \"\"\"Type of precision used.\n\n    >>> PrecisionType.HALF == 16\n    True\n    >>> PrecisionType.HALF in (16, \"16\")\n    True\n    \"\"\"\n\n    HALF = \"16\"\n    FLOAT = \"32\"\n    FULL = \"64\"\n    BFLOAT = \"bf16\"\n    MIXED = \"mixed\"\n\n    @staticmethod\n    def supported_type(precision: str | int) -> bool:\n        return any(x == precision for x in PrecisionType)\n\n    @staticmethod\n    def supported_types() -> list[str]:\n        return [x.value for x in PrecisionType]\n\n    @staticmethod\n    def is_fp16(precision):\n        return precision in HALF_LIST\n\n    @staticmethod\n    def is_fp32(precision):\n        return precision in FLOAT_LIST\n\n    @staticmethod\n    def is_bf16(precision):\n        return precision in BFLOAT_LIST\n\n    @staticmethod\n    def to_dtype(precision):\n        if precision in HALF_LIST:\n            return torch.float16\n        elif precision in FLOAT_LIST:\n            return torch.float32\n        elif precision in BFLOAT_LIST:\n            return torch.bfloat16\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n\n    @staticmethod\n    def to_str(precision):\n        if precision == torch.float16:\n            return \"fp16\"\n        elif precision == torch.float32:\n            return \"fp32\"\n        elif precision == torch.bfloat16:\n            return \"bf16\"\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n"
  },
  {
    "path": "verl_rl/verl/utils/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small torch utilities\n\"\"\"\n\nimport math\nfrom contextlib import contextmanager\nfrom typing import Optional\n\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import PreTrainedTokenizer\n\nfrom verl.utils.device import get_device_name, get_torch_device\n\ntry:\n    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss\n\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True\nexcept ImportError:\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False\n\n\ntry:\n    import torch_npu\n\n    NPU_CROSS_ENTROPY_LOSS_AVAILABLE = hasattr(torch_npu, \"npu_cross_entropy_loss\")\nexcept ImportError:\n    NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False\n\n\ndef gather_from_labels(data, label):\n    \"\"\"Gather the label from data. The value in label should be [0, vocab_size)\n\n    Args:\n        data: (..., vocab_size)\n        label (torch.IntTensor) : (...,)\n\n    Returns:\n\n    \"\"\"\n\n    output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)\n    return output\n\n\ndef logprobs_from_logits(logits, labels, inplace_backward=True):\n    \"\"\"\n    Compute per-token log-probabilities for the given labels.\n\n    Uses a Flash-Attention–based cross-entropy (if available) for efficient backward,\n    otherwise falls back to a standard log-softmax+gather approach.\n\n    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591\n\n    Args:\n        logits (Tensor): Model outputs of shape (..., vocab_size).\n        labels (LongTensor): True class indices of shape matching logits[..., :-1].\n        inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place.\n\n    Returns:\n        Tensor: Log-probabilities of the target labels, shape logits.shape[:-1].\n    \"\"\"\n    if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:\n        batch_dim = logits.shape[:-1]\n        last_dim = logits.shape[-1]\n        logits = logits.reshape(-1, last_dim)\n        labels = labels.reshape(-1)\n        output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward)\n        output = output.view(*batch_dim)\n    elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE:\n        output = logprobs_from_logits_torch_npu(logits, labels)\n    else:\n        output = logprobs_from_logits_v2(logits, labels)\n    return output\n\n\ndef logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):\n    output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)\n    assert isinstance(output, tuple), (\n        \"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses].\"\n    )\n    return -output[0]\n\n\ndef logprobs_from_logits_torch_npu(logits, labels):\n    batch_dim = 
logits.shape[:-1]\n    logits = logits.reshape(-1, logits.shape[-1])\n    loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction=\"none\")\n    return -loss.view(*batch_dim)\n\n\ndef logprobs_from_logits_naive(logits, labels):\n    logp = F.log_softmax(logits, dim=-1)\n    logpy = gather_from_labels(logp, labels)\n    return logpy\n\n\ndef logprobs_from_logits_v2(logits: torch.FloatTensor, labels):\n    \"\"\"\n    A memory efficient implementation of logprobs_from_logits\n    \"\"\"\n    if logits.dtype in [torch.float32, torch.float64]:\n        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)\n        # loop to reduce peak mem consumption\n        logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])\n        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n    else:\n        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n        logprobs_labels = []\n        for row_logits, row_labels in zip(logits, labels, strict=True):  # loop to reduce peak mem consumption\n            row_logprobs = F.log_softmax(row_logits, dim=-1)\n            row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n            logprobs_labels.append(row_logprobs_labels)\n        logprobs_labels = torch.stack(logprobs_labels)\n    return logprobs_labels\n\n\ndef clip_by_value(x, tensor_min, tensor_max):\n    \"\"\"\n    Tensor extension to torch.clamp\n    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713\n    \"\"\"\n    clipped = torch.max(torch.min(x, tensor_max), tensor_min)\n    return clipped\n\n\ndef entropy_from_logits(logits: torch.Tensor):\n    \"\"\"Calculate entropy from logits.\"\"\"\n    pd = torch.nn.functional.softmax(logits, dim=-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)\n    return entropy\n\n\ndef entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048):\n    \"\"\"Memory-efficient entropy calculation with chunking.\"\"\"\n    entropy = torch.zeros(logits.shape[0], device=logits.device)\n    for i in range(0, logits.shape[0], chunk_size):\n        logits_chunk = logits[i : i + chunk_size].float()\n        pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1)\n        entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1)\n        entropy[i : i + chunk_size] = entropy_chunk\n    return entropy\n\n\ndef masked_sum(values, mask, axis=None):\n    \"\"\"Compute the sum of `values` over elements selected by `mask`.\"\"\"\n    # If NaNs exist out of mask, replace NaNs in values with a value that\n    # won't affect the sum (e.g., 0 for masked regions)\n    valid_values = torch.where(mask.bool(), values, 0.0)\n    return (valid_values * mask).sum(axis=axis)\n\n\ndef masked_mean(values, mask, axis=None):\n    \"\"\"\n    Compute the mean of `values` over elements selected by `mask`.\n\n    Args:\n        values (Tensor): Input tensor.\n        mask (Tensor): Boolean or numeric mask of the same shape as `values`.\n        axis (int or tuple of int, optional): Dimension(s) along which to compute the mean.\n            Defaults to None (over all elements).\n\n    Returns:\n        Tensor: Masked mean, with shape equal to `values` reduced over `axis`.\n    \"\"\"\n    s = masked_sum(values, mask, axis)\n    return s / (mask.sum(axis=axis) + 1e-8)\n\n\ndef 
masked_var(values, mask, unbiased=True):\n    \"\"\"Compute variance of tensor with masked values.\"\"\"\n    mean = masked_mean(values, mask)\n    centered_values = values - mean\n    variance = masked_mean(centered_values**2, mask)\n    if unbiased:\n        mask_sum = mask.sum()\n        if mask_sum == 0:\n            raise ValueError(\"At least one element in the mask has to be 1.\")\n        # note that if mask_sum == 1, then there is a division by zero issue\n        # to avoid it you just need to use a larger minibatch_size\n        if mask_sum == 1:\n            raise ValueError(\"The sum of the mask is one, which can cause a division by zero.\")\n        bessel_correction = mask_sum / (mask_sum - 1)\n        variance = variance * bessel_correction\n    return variance\n\n\ndef masked_whiten(values, mask, shift_mean=True):\n    \"\"\"\n    Whiten `values` by normalizing with mean and variance computed over `mask`.\n\n    Args:\n        values (torch.Tensor): Input tensor.\n        mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats.\n        shift_mean (bool): If True (default), output is zero-mean;\n                           if False, the original mean is re-added after scaling.\n\n    Returns:\n        torch.Tensor: Whitened tensor of same shape as `values`.\n    \"\"\"\n    mean, var = masked_mean(values, mask), masked_var(values, mask)\n    whitened = (values - mean) * torch.rsqrt(var + 1e-8)\n    if not shift_mean:\n        whitened += mean\n    return whitened\n\n\ndef get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):\n    \"\"\"\n    end of sentence token can be int or list: 1 or [1, 2]\n    e.g.\n    response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],\n                                [78, 0, 76, 2, 1, 0, 0],\n                                [23, 98, 1, 0, 0, 0, 0],\n                                [33, 3, 98, 45, 1, 0, 0]])\n    #eos_token=1\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    #eos_token=[1,2]\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    \"\"\"\n    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()\n    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)\n\n\ndef compute_grad_norm(model: nn.Module):\n    total_grad_square = 0\n    for param in model.parameters():\n        if param.grad is not None:\n            total_grad_square += torch.sum(torch.square(param.grad.detach())).item()\n    return total_grad_square\n\n\ndef broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group):\n    \"\"\"\n    TODO: optimize this. 
Technically, we only need one broadcast\n    \"\"\"\n\n    # Sort keys so every rank broadcasts in the same order; works for plain dicts and TensorDicts alike.\n    for key in sorted(tensors.keys()):\n        torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)\n\n\ndef allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0):\n    \"\"\"\n    TODO: optimize this.\n    - We can use async ops\n    - We can use only one allgather\n    Args:\n        tensors: dict of tensors or TensorDict to gather.\n        size: world size of the process group.\n        group: the process group to gather over.\n        dim: dimension along which the gathered shards are concatenated.\n\n    Returns:\n        The gathered dict (or TensorDict), with each tensor concatenated along `dim`.\n    \"\"\"\n    if isinstance(tensors, TensorDict):\n        is_tensor_dict = True\n        tensors_as_dict = tensors.to_dict()\n    else:\n        tensors_as_dict = tensors\n        is_tensor_dict = False\n\n    output = {}\n    sorted_keys = sorted(tensors_as_dict.keys())\n    for key in sorted_keys:\n        val = tensors_as_dict[key]\n        output[key] = [torch.empty_like(val) for _ in range(size)]\n        torch.distributed.all_gather(output[key], val, group=group, async_op=False)\n        output[key] = torch.cat(output[key], dim=dim)\n\n    if is_tensor_dict:\n        output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)\n\n    return output\n\n\ndef split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]:\n    assert tensors.batch_size[0] % batch_size == 0, (\n        f\"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}\"\n    )\n    return tensors.split(batch_size)\n\n\ndef pad_2d_list_to_length(response, pad_token_id, max_length=None):\n    \"\"\"\n    pad a 2D list (e.g. responses, logprobs) to a 2D tensor.\n    \"\"\"\n    response_length = max(len(sub_list) for sub_list in response)\n    target_length = max_length if max_length is not None and max_length > response_length else response_length\n    padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]\n    tensor = torch.tensor(padded_response)\n    return tensor\n\n\ndef pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):\n    \"\"\"\n    pad a 2D tensor (e.g. 
responses, logprobs) in the last dim to max_seq_length.\n    input shape: [bs, seq_length]\n    output shape: [bs, max_seq_length]\n    \"\"\"\n    if tensors.shape[-1] >= max_seq_len:\n        return tensors\n    # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad\n    pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])\n    return F.pad(tensors, pad_tuple, \"constant\", pad_token_id)\n\n\ndef postprocess_data(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    max_length: int,\n    pad_token_id: int,\n    left_pad=True,\n    truncation=\"error\",\n):\n    \"\"\"Process tokenizer outputs to consistent shapes via padding/truncation.\n\n    Args:\n        input_ids: Token indices [batch_size, seq_len]\n        attention_mask: Mask [batch_size, seq_len]\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: \"left\", \"right\", \"middle\" or \"error\"\n\n    Returns:\n        (input_ids, attention_mask) padded/truncated to max_length\n    \"\"\"\n    assert truncation in [\"left\", \"right\", \"middle\", \"error\"]\n    assert input_ids.ndim == 2\n\n    sequence_length = input_ids.shape[-1]\n    if sequence_length < max_length:\n        input_ids = pad_sequence_to_length(\n            input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad\n        )\n        attention_mask = pad_sequence_to_length(\n            attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad\n        )\n    elif sequence_length > max_length:\n        if truncation == \"left\":\n            # actually, left truncation may not be reasonable\n            input_ids = input_ids[:, -max_length:]\n            attention_mask = attention_mask[:, -max_length:]\n        elif truncation == \"right\":\n            input_ids = input_ids[:, :max_length]\n            attention_mask = attention_mask[:, :max_length]\n        elif truncation == \"middle\":\n            left_half = max_length // 2\n            right_half = max_length - left_half\n            input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1)\n            attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1)\n        elif truncation == \"error\":\n            raise NotImplementedError(f\"{sequence_length=} is larger than {max_length=}\")\n        else:\n            raise NotImplementedError(f\"Unknown truncation method {truncation}\")\n\n    return input_ids, attention_mask\n\n\ndef tokenize_and_postprocess_data(\n    prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation=\"error\"\n):\n    \"\"\"Tokenize text and process outputs to consistent tensor shapes.\n\n    Args:\n        prompt: Input text to tokenize\n        tokenizer: HuggingFace tokenizer instance\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: Truncation strategy (\"left\"/\"right\"/\"middle\"/\"error\")\n\n    Returns:\n        Tuple of (input_ids, attention_mask) from postprocess_data\n    \"\"\"\n    input_data = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False)\n    input_ids = input_data[\"input_ids\"]\n    attention_mask = input_data[\"attention_mask\"]\n\n    return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, 
left_pad, truncation)\n\n\ndef remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):\n    \"\"\"Remove the pad token.\n\n    Args:\n        input_ids shape: [bs, seq_length]\n        attention_mask shape: [bs, seq_length]\n    Returns:\n        no_padding_batch(List[List[int]]): contains the rmpad token ids per query.\n    \"\"\"\n    no_padding_batch = []\n    for ids, mask in zip(input_ids, attention_mask, strict=True):\n        no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())\n    return no_padding_batch\n\n\ndef log_probs_from_logits_response(input_ids, logits, response_length):\n    \"\"\"Compute the response log_probs from full logits. Note that logits = model(input_ids)\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        logits: [batch_size, seqlen, vocab_size]\n\n    Returns:\n        response_log_prob:\n    \"\"\"\n    response_logits = logits[:, -response_length - 1 : -1]\n    response = input_ids[:, -response_length:]\n    response_log_prob = logprobs_from_logits(logits=response_logits, labels=response)\n    return response_log_prob\n\n\ndef log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad logits and pad input. Note that\n    logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between\n    logits and input_ids.\n    The reason for this function is to compute logprobs_from_logits in rmpad mode because it is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size]\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad input_ids and logits. Note that\n    logits_rmpad = model(input_ids_rmpad). 
For each sentence, there is a shift between\n    logits and input_ids.\n    The reason for this function is to compute logprobs_from_logits in rmpad mode because it is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids_rmpad: [1, total_nnz]\n        logits_rmpad: [total_nnz, vocab_size]\n        indices: [total_nnz]\n        batch_size: int\n        seqlen: int\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input\n\n    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # transpose back to [total_nnz, 1]\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef post_process_logits(input_ids, logits, temperature, top_k, top_p):\n    if temperature != 1.0:\n        logits = logits.div_(temperature)  # inplace operation to avoid OOM\n    # TODO: add them back\n    # if top_k is not None and top_k > 0:\n    #     logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits)\n    # if top_p is not None and top_p < 1.0 and top_p > 0.0:\n    #     logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits)\n    return logits\n\n\n\"\"\"\nOptimizer related\n\"\"\"\n\n\ndef get_cosine_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function between the\n    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the\n    initial lr set in the optimizer.\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum lr ratio w.r.t the maximum.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0\n            following a half-cosine).\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    min_lr_ratio = 0.0 if min_lr_ratio is None else min_lr_ratio\n    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0\n    coef = (1 - min_lr_ratio) * 0.5\n    intercept = (1 + min_lr_ratio) * 0.5\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return min_lr_ratio + (1.0 - min_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))\n        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        x = math.cos(math.pi * float(num_cycles) * 2.0 * 
progress)\n        return max(min_lr_ratio, x * coef + intercept)\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_constant_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a constant LR schedule with a linear warmup phase.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value.\n        last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1.\n\n    Returns:\n        LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant.\n    \"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1.0, num_warmup_steps))\n        return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):\n    # create causal mask\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask(\n            input_shape,\n            inputs_embeds.dtype,\n            device=inputs_embeds.device,\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n            inputs_embeds.device\n        )\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\ndef get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef get_wsd_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    
num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n    stable_ratio: float = 0.9,\n):\n    \"\"\"\n    Create a Warmup-Stable-Decay learning rate scheduler.\n\n    The schedule follows three phases:\n    1. Warmup: Learning rate increases linearly from 0 to the initial LR\n    2. Stable: Learning rate remains constant at the initial LR\n    3. Decay: Learning rate decreases following a cosine curve to min_lr_ratio * initial LR\n\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum learning rate ratio w.r.t the initial learning rate.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule during decay phase.\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n        stable_ratio (:obj:`float`, `optional`, defaults to 0.9):\n            The ratio of non-warmup steps that should maintain a constant learning rate.\n            Set to 0.0 to behave exactly like a cosine schedule.\n\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    remaining_steps = max(0, num_training_steps - num_warmup_steps)\n    num_stable_steps = int(remaining_steps * stable_ratio)\n    num_decay_steps = remaining_steps - num_stable_steps\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        if current_step < num_warmup_steps + num_stable_steps:\n            return 1.0\n        if current_step < num_training_steps:\n            progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))\n            value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n            return (1.0 - min_lr_ratio) * value + min_lr_ratio\n        return min_lr_ratio\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\n@contextmanager\ndef check_device_is_available():\n    \"\"\"\n    Some modules, such as sglang's sharding manager, must be imported after CUDA is initialized.\n\n    This context manager checks if CUDA is available and raises an error if it is not.\n    \"\"\"\n    if not get_torch_device().is_available():\n        raise RuntimeError(\"Device {} must be initialized before importing this module.\".format(get_device_name()))\n\n    yield\n\n\ndef distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True):\n    \"\"\"Compute distributed statistics across all processes.\n\n    Args:\n        local_tensor: Tensor containing local values\n        compute_max: Include maximum value calculation\n        compute_min: Include minimum value calculation\n        compute_std: Include standard deviation calculation\n\n    Returns:\n        Tuple containing (mean, max, min, std) in this order. 
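All statistics are all-reduced over the default process group, so every rank receives the same global values. 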
None for disabled metrics.\n    \"\"\"\n    # Sum the local tensor across all processes\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name())\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n\n    if compute_max:\n        local_max = torch.max(local_tensor)\n        torch.distributed.all_reduce(local_max, op=torch.distributed.ReduceOp.MAX)\n    else:\n        local_max = None\n\n    if compute_min:\n        local_min = torch.min(local_tensor)\n        torch.distributed.all_reduce(local_min, op=torch.distributed.ReduceOp.MIN)\n    else:\n        local_min = None\n\n    if compute_std:\n        square_diff = torch.sum(torch.pow(local_tensor - global_mean, 2))\n        torch.distributed.all_reduce(square_diff, op=torch.distributed.ReduceOp.SUM)\n        global_std = torch.sqrt(square_diff / (local_num - 1))\n    else:\n        global_std = None\n\n    return global_mean, local_max, local_min, global_std\n\n\ndef distributed_masked_mean(local_tensor, local_mask):\n    \"\"\"Compute global mean of non-masked elements across distributed processes.\n\n    Args:\n        local_tensor (torch.Tensor): Input tensor with local values\n        local_mask (torch.Tensor): Binary mask (1=valid, 0=ignore) matching local_tensor shape\n\n    Returns:\n        torch.Tensor: Global mean of all valid elements across processes\n    \"\"\"\n    local_tensor = local_tensor * local_mask\n\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.sum(local_mask)\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n    return global_mean\n"
  },
  {
    "path": "verl_rl/verl/utils/tracking.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA unified tracking interface that supports logging data to different backend\n\"\"\"\n\nimport dataclasses\nimport os\nfrom enum import Enum\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any\n\n\nclass Tracking:\n    \"\"\"A unified tracking interface for logging experiment data to multiple backends.\n\n    This class provides a centralized way to log experiment metrics, parameters, and artifacts\n    to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console.\n\n    Attributes:\n        supported_backend: List of supported tracking backends.\n        logger: Dictionary of initialized logger instances for each backend.\n    \"\"\"\n\n    supported_backend = [\"wandb\", \"mlflow\", \"swanlab\", \"vemlp_wandb\", \"tensorboard\", \"console\", \"clearml\"]\n\n    def __init__(self, project_name, experiment_name, default_backend: str | list[str] = \"console\", config=None):\n        if isinstance(default_backend, str):\n            default_backend = [default_backend]\n        for backend in default_backend:\n            if backend == \"tracking\":\n                import warnings\n\n                warnings.warn(\"`tracking` logger is deprecated. 
use `wandb` instead.\", DeprecationWarning, stacklevel=2)\n            else:\n                assert backend in self.supported_backend, f\"{backend} is not supported\"\n\n        self.logger = {}\n\n        if \"tracking\" in default_backend or \"wandb\" in default_backend:\n            import wandb\n\n            settings = None\n            if config and config[\"trainer\"].get(\"wandb_proxy\", None):\n                settings = wandb.Settings(https_proxy=config[\"trainer\"][\"wandb_proxy\"])\n            wandb.init(project=project_name, name=experiment_name, config=config, settings=settings, mode='offline')\n            self.logger[\"wandb\"] = wandb\n\n        if \"mlflow\" in default_backend:\n            import os\n\n            import mlflow\n\n            MLFLOW_TRACKING_URI = os.environ.get(\"MLFLOW_TRACKING_URI\", \"sqlite:////tmp/mlruns.db\")\n            mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n\n            # Project_name is actually experiment_name in MLFlow\n            # If experiment does not exist, will create a new experiment\n            experiment = mlflow.set_experiment(project_name)\n            mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)\n            mlflow.log_params(_compute_mlflow_params_from_objects(config))\n            self.logger[\"mlflow\"] = _MlflowLoggingAdapter()\n\n        if \"swanlab\" in default_backend:\n            import os\n\n            import swanlab\n\n            SWANLAB_API_KEY = os.environ.get(\"SWANLAB_API_KEY\", None)\n            SWANLAB_LOG_DIR = os.environ.get(\"SWANLAB_LOG_DIR\", \"swanlog\")\n            SWANLAB_MODE = os.environ.get(\"SWANLAB_MODE\", \"cloud\")\n            if SWANLAB_API_KEY:\n                swanlab.login(SWANLAB_API_KEY)  # NOTE: previous login information will be overwritten\n\n            if config is None:\n                config = {}  # make sure config is not None, otherwise **config will raise error\n            swanlab.init(\n                project=project_name,\n                experiment_name=experiment_name,\n                config={\"FRAMEWORK\": \"verl\", **config},\n                logdir=SWANLAB_LOG_DIR,\n                mode=SWANLAB_MODE,\n            )\n            self.logger[\"swanlab\"] = swanlab\n\n        if \"vemlp_wandb\" in default_backend:\n            import os\n\n            import volcengine_ml_platform\n            from volcengine_ml_platform import wandb as vemlp_wandb\n\n            volcengine_ml_platform.init(\n                ak=os.environ[\"VOLC_ACCESS_KEY_ID\"],\n                sk=os.environ[\"VOLC_SECRET_ACCESS_KEY\"],\n                region=os.environ[\"MLP_TRACKING_REGION\"],\n            )\n\n            vemlp_wandb.init(\n                project=project_name,\n                name=experiment_name,\n                config=config,\n                sync_tensorboard=True,\n            )\n            self.logger[\"vemlp_wandb\"] = vemlp_wandb\n\n        if \"tensorboard\" in default_backend:\n            self.logger[\"tensorboard\"] = _TensorboardAdapter(project_name, experiment_name)\n\n        if \"console\" in default_backend:\n            from verl.utils.logger import LocalLogger\n\n            self.console_logger = LocalLogger(print_to_console=True)\n            self.logger[\"console\"] = self.console_logger\n\n        if \"clearml\" in default_backend:\n            self.logger[\"clearml\"] = ClearMLLogger(project_name, experiment_name, config)\n\n    def log(self, data, step, backend=None):\n        for default_backend, 
logger_instance in self.logger.items():\n            if backend is None or default_backend in backend:\n                logger_instance.log(data=data, step=step)\n\n    def __del__(self):\n        if \"wandb\" in self.logger:\n            self.logger[\"wandb\"].finish(exit_code=0)\n        if \"swanlab\" in self.logger:\n            self.logger[\"swanlab\"].finish()\n        if \"vemlp_wandb\" in self.logger:\n            self.logger[\"vemlp_wandb\"].finish(exit_code=0)\n        if \"tensorboard\" in self.logger:\n            self.logger[\"tensorboard\"].finish()\n\n        if \"clearml\" in self.logger:\n            self.logger[\"clearml\"].finish()\n\n\nclass ClearMLLogger:\n    def __init__(self, project_name: str, experiment_name: str, config):\n        self.project_name = project_name\n        self.experiment_name = experiment_name\n\n        import clearml\n\n        self._task: clearml.Task = clearml.Task.init(\n            task_name=experiment_name,\n            project_name=project_name,\n            continue_last_task=True,\n            output_uri=False,\n        )\n\n        self._task.connect_configuration(config, name=\"Hyperparameters\")\n\n    def _get_logger(self):\n        return self._task.get_logger()\n\n    def log(self, data, step):\n        import numpy as np\n        import pandas as pd\n\n        logger = self._get_logger()\n        for k, v in data.items():\n            title, series = k.split(\"/\", 1)\n\n            if isinstance(v, int | float | np.floating | np.integer):\n                logger.report_scalar(\n                    title=title,\n                    series=series,\n                    value=v,\n                    iteration=step,\n                )\n            elif isinstance(v, pd.DataFrame):\n                logger.report_table(\n                    title=title,\n                    series=series,\n                    table_plot=v,\n                    iteration=step,\n                )\n            else:\n                logger.warning(\n                    f'Trainer is attempting to log a value of \"{v}\" of type {type(v)} for key \"{k}\". This '\n                    f\"invocation of ClearML logger's function is incorrect so this attribute was dropped. 
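Only int/float scalars and pandas DataFrames are supported. 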
\"\n                )\n\n    def finish(self):\n        self._task.mark_completed()\n\n\nclass _TensorboardAdapter:\n    def __init__(self, project_name, experiment_name):\n        import os\n\n        from torch.utils.tensorboard import SummaryWriter\n\n        tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", f\"tensorboard_log/{project_name}/{experiment_name}\")\n        os.makedirs(tensorboard_dir, exist_ok=True)\n        print(f\"Saving tensorboard log to {tensorboard_dir}.\")\n        self.writer = SummaryWriter(tensorboard_dir)\n\n    def log(self, data, step):\n        for key in data:\n            self.writer.add_scalar(key, data[key], step)\n\n    def finish(self):\n        self.writer.close()\n\n\nclass _MlflowLoggingAdapter:\n    def log(self, data, step):\n        import mlflow\n\n        results = {k.replace(\"@\", \"_at_\"): v for k, v in data.items()}\n        mlflow.log_metrics(metrics=results, step=step)\n\n\ndef _compute_mlflow_params_from_objects(params) -> dict[str, Any]:\n    if params is None:\n        return {}\n\n    return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep=\"/\")\n\n\ndef _transform_params_to_json_serializable(x, convert_list_to_dict: bool):\n    _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)\n\n    if dataclasses.is_dataclass(x):\n        return _transform(dataclasses.asdict(x))\n    if isinstance(x, dict):\n        return {k: _transform(v) for k, v in x.items()}\n    if isinstance(x, list):\n        if convert_list_to_dict:\n            return {\"list_len\": len(x)} | {f\"{i}\": _transform(v) for i, v in enumerate(x)}\n        else:\n            return [_transform(v) for v in x]\n    if isinstance(x, Path):\n        return str(x)\n    if isinstance(x, Enum):\n        return x.value\n\n    return x\n\n\ndef _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]:\n    import pandas as pd\n\n    ans = pd.json_normalize(raw, sep=sep).to_dict(orient=\"records\")[0]\n    assert isinstance(ans, dict)\n    return ans\n\n\n@dataclasses.dataclass\nclass ValidationGenerationsLogger:\n    project_name: str = None\n    experiment_name: str = None\n\n    def log(self, loggers, samples, step):\n        if \"wandb\" in loggers:\n            self.log_generations_to_wandb(samples, step)\n        if \"swanlab\" in loggers:\n            self.log_generations_to_swanlab(samples, step)\n        if \"mlflow\" in loggers:\n            self.log_generations_to_mlflow(samples, step)\n\n        if \"clearml\" in loggers:\n            self.log_generations_to_clearml(samples, step)\n        if \"tensorboard\" in loggers:\n            self.log_generations_to_tensorboard(samples, step)\n\n        if \"vemlp_wandb\" in loggers:\n            self.log_generations_to_vemlp_wandb(samples, step)\n\n    def log_generations_to_vemlp_wandb(self, samples, step):\n        from volcengine_ml_platform import wandb as vemlp_wandb\n\n        self._log_generations_to_wandb(samples, step, vemlp_wandb)\n\n    def log_generations_to_wandb(self, samples, step):\n        import wandb\n\n        self._log_generations_to_wandb(samples, step, wandb)\n\n    def _log_generations_to_wandb(self, samples, step, wandb):\n        \"\"\"Log samples to wandb as a table\"\"\"\n\n        # Create column names for all samples\n        columns = [\"step\"] + sum(\n            [[f\"input_{i + 1}\", f\"output_{i + 1}\", f\"score_{i + 1}\"] for i in range(len(samples))], []\n        )\n\n        if not 
hasattr(self, \"validation_table\"):\n            # Initialize the table on first call\n            self.validation_table = wandb.Table(columns=columns)\n\n        # Create a new table with same columns and existing data\n        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737\n        new_table = wandb.Table(columns=columns, data=self.validation_table.data)\n\n        # Add new row with all data\n        row_data = []\n        row_data.append(step)\n        for sample in samples:\n            row_data.extend(sample)\n\n        new_table.add_data(*row_data)\n\n        # Update reference and log\n        wandb.log({\"val/generations\": new_table}, step=step)\n        self.validation_table = new_table\n\n    def log_generations_to_swanlab(self, samples, step):\n        \"\"\"Log samples to swanlab as text\"\"\"\n        import swanlab\n\n        swanlab_table = swanlab.echarts.Table()\n\n        # Create column names\n        headers = [\"step\", \"input\", \"output\", \"score\"]\n\n        swanlab_row_list = [[step, *sample] for sample in samples]\n        swanlab_table.add(headers=headers, rows=swanlab_row_list)\n\n        # Log to swanlab\n        swanlab.log({\"val/generations\": swanlab_table}, step=step)\n\n    def log_generations_to_mlflow(self, samples, step):\n        \"\"\"Log validation generation to mlflow as artifacts\"\"\"\n        # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact\n\n        import json\n        import tempfile\n\n        import mlflow\n\n        try:\n            with tempfile.TemporaryDirectory() as tmp_dir:\n                validation_gen_step_file = Path(tmp_dir, f\"val_step{step}.json\")\n                row_data = []\n                for sample in samples:\n                    data = {\"input\": sample[0], \"output\": sample[1], \"score\": sample[2]}\n                    row_data.append(data)\n                with open(validation_gen_step_file, \"w\") as file:\n                    json.dump(row_data, file)\n                mlflow.log_artifact(validation_gen_step_file)\n        except Exception as e:\n            print(f\"WARNING: save validation generation file to mlflow failed with error {e}\")\n\n    def log_generations_to_clearml(self, samples, step):\n        \"\"\"Log validation generation to clearml as table\"\"\"\n\n        import clearml\n        import pandas as pd\n\n        task: clearml.Task | None = clearml.Task.current_task()\n        if task is None:\n            return\n\n        table = [\n            {\n                \"step\": step,\n                \"input\": sample[0],\n                \"output\": sample[1],\n                \"score\": sample[2],\n            }\n            for sample in samples\n        ]\n\n        logger = task.get_logger()\n        logger.report_table(\n            series=\"Validation generations\",\n            title=\"Validation\",\n            table_plot=pd.DataFrame.from_records(table),\n            iteration=step,\n        )\n\n    def log_generations_to_tensorboard(self, samples, step):\n        \"\"\"Log samples to tensorboard as text\"\"\"\n        # Initialize tensorboard writer if not exists\n        if not hasattr(self, \"writer\"):\n            from torch.utils.tensorboard import SummaryWriter\n\n            # Use the same directory structure as _TensorboardAdapter\n            if self.project_name and self.experiment_name:\n                default_dir = os.path.join(\"tensorboard_log\", 
self.project_name, self.experiment_name)\n            else:\n                default_dir = \"tensorboard_log\"\n\n            tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", default_dir)\n            os.makedirs(tensorboard_dir, exist_ok=True)\n            self.writer = SummaryWriter(log_dir=tensorboard_dir)\n\n        # Format the samples data into readable text\n        text_content = f\"**Generation Results - Step {step}**\\n\\n\"\n\n        for i, sample in enumerate(samples):\n            text_content += f\"### Sample {i + 1}\\n\"\n\n            # Assuming sample contains [input, output, score]\n            if len(sample) >= 3:\n                input_text, output_text, score = sample[0], sample[1], sample[2]\n\n                text_content += f\"**Input:** {input_text}\\n\\n\"\n                text_content += f\"**Output:** {output_text}\\n\\n\"\n                text_content += f\"**Score:** {score}\\n\\n\"\n            else:\n                # Handle cases where sample format might be different\n                text_content += f\"**Data:** {sample}\\n\\n\"\n\n            text_content += \"---\\n\\n\"\n\n        # Log to tensorboard as text\n        self.writer.add_text(\"val/generations\", text_content, step)\n        # Flush to ensure data is written\n        self.writer.flush()\n"
  },
  {
    "path": "verl_rl/verl/utils/ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for DeepSpeed Ulysses Sequence Parallelism.\nDeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509\nInspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py\n\"\"\"\n\nfrom typing import Any, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\n_ULYSSES_SEQUENCE_PARALLEL_GROUP = None\n\n\ndef set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):\n    \"\"\"\n    Set ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group\n\n\ndef get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:\n    \"\"\"\n    Get ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    return _ULYSSES_SEQUENCE_PARALLEL_GROUP\n\n\ndef get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel world size.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_world_size(group) if group else 1\n\n\ndef get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel rank.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_rank(group) if group else 0\n\n\ndef gather_seq_scatter_heads(\n    x: Tensor,\n    seq_dim: int,\n    head_dim: int,\n    unpadded_dim_size: int = 0,\n    group: ProcessGroup = None,\n) -> Tensor:\n    \"\"\"\n    A func to sync embedding input with alltoall in sequence parallel\n    gather sequence dimension and scatter head dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    x = SeqAllToAll.apply(group, x, head_dim, seq_dim)\n    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:\n        padding_size = x.size(seq_dim) - unpadded_dim_size\n        x = _unpad_tensor(x, seq_dim, padding_size)\n    return x\n\n\ndef gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:\n    \"\"\"\n    A func to sync attention result with alltoall in sequence parallel\n    gather head dimension and scatter seq dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq, h/n, ...] 
-> [bsz, seq/n, h, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    dim_size = x.size(seq_dim)\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    if dim_size % sp_world != 0:\n        padding_size = sp_world - (dim_size % sp_world)\n        x = _pad_tensor(x, seq_dim, padding_size)\n    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)\n\n\ndef _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    shape = list(x.shape)\n    shape[dim] = padding_size\n    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)\n    return torch.cat([x, pad], dim=dim)\n\n\ndef _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(0, -padding_size)\n    return x[slc]\n\n\ndef slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group)\n    sp_rank = get_ulysses_sequence_parallel_rank()\n    dim_size = x.size(dim)\n    # pad before slice\n    if padding and dim_size % sp_world_size:\n        padding_size = sp_world_size - (dim_size % sp_world_size)\n        x = _pad_tensor(x, dim, padding_size)\n    # slice the input tensor\n    parts = x.size(dim) // sp_world_size\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)\n    return x[slc].contiguous()\n\n\ndef all_to_all_tensor(\n    local_input: Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n    group: Optional[dist.ProcessGroup] = None,\n    async_op: bool = False,\n):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    seq_world_size = dist.get_world_size(group)\n    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]\n    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]\n    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)\n    if async_op:\n\n        def wait():\n            comm.wait()\n            return torch.cat(output_list, dim=gather_dim).contiguous()\n\n        return wait\n    return torch.cat(output_list, dim=gather_dim).contiguous()\n\n\ndef all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group=group)\n    output_shape = list(local_tensor.shape)\n    output_shape[0] = output_shape[0] * sp_world_size\n    output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)\n    dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)\n    return output\n\n\nclass SeqAllToAll(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_input: Tensor,\n        scatter_dim: int,\n        gather_dim: int,\n        async_op: bool = False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.scatter_dim = scatter_dim\n        ctx.gather_dim = gather_dim\n        ctx.async_op = async_op\n        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)\n\n    @staticmethod\n    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:\n        input_t = 
torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0]\n        return (\n            None,\n            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),\n            None,\n            None,\n            None,\n            None,\n        )\n\n\nclass Gather(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_tensor: Tensor,\n        gather_dim: int,\n        grad_scaler: bool = True,\n        async_op=False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.gather_dim = gather_dim\n        ctx.grad_scaler = grad_scaler\n        ctx.async_op = async_op\n\n        sp_world_size = dist.get_world_size(group=group)\n        ctx.sp_world_size = sp_world_size\n\n        sp_rank = dist.get_rank(group=group)\n        ctx.sp_rank = sp_rank\n\n        local_shape = list(local_tensor.size())\n        split_size = local_shape[0]\n        part_size = local_shape[gather_dim]  # store original size\n        ctx.part_size = part_size\n\n        output = all_gather_tensor(local_tensor, group, async_op)\n        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)\n\n    @staticmethod\n    def backward(ctx: Any, grad_output: Tensor) -> Any:\n        if ctx.grad_scaler:\n            grad_output = grad_output * ctx.sp_world_size\n        return (\n            None,\n            grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(),\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef gather_outpus_and_unpad(*args, **kwargs):\n    raise RuntimeError(\n        \"please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad\"\n    )\n\n\ndef gather_outputs_and_unpad(\n    x: Tensor,\n    gather_dim: int,\n    unpad_dim: int = None,\n    padding_size: int = 0,\n    grad_scaler: bool = True,\n    group: Optional[dist.ProcessGroup] = None,\n):\n    \"\"\"\n    Gather a tensor across a process group and optionally unpad its padded elements.\n\n    Args:\n        x (Tensor): Input tensor to gather.\n        gather_dim (int): Dimension along which to gather across ranks.\n        unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding.\n        padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0.\n        grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True.\n        group (ProcessGroup, optional): Process group for gathering. If None, uses\n            `get_ulysses_sequence_parallel_group()`. 
If still None, returns `x` unchanged.\n\n    Returns:\n        Tensor: The gathered tensor, with padding removed if requested.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if group is None:\n        return x\n    x = Gather.apply(group, x, gather_dim, grad_scaler)\n    if unpad_dim is not None:\n        assert isinstance(padding_size, int), \"padding_size is not given or is not an integer\"\n        if padding_size == 0:\n            return x\n        x = _unpad_tensor(x, unpad_dim, padding_size)\n    return x\n\n\ndef ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1):\n    if position_ids_rmpad is not None:\n        assert position_ids_rmpad.size(-2) == 1\n        assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1)\n    if sp_size <= 1:\n        return input_ids_rmpad, position_ids_rmpad, 0\n    _, total_seq_len = input_ids_rmpad.shape\n    pad_size = (sp_size - total_seq_len % sp_size) % sp_size\n    if pad_size > 0:\n        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)\n        if position_ids_rmpad is not None:\n            pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)\n            if position_ids_rmpad.dim() == 3:\n                pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1)\n            position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef ulysses_pad_and_slice_inputs(\n    input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1\n):\n    \"\"\"\n    Pad and slice input_ids so that the sequence length is divisible by sp_size.\n    Pad and slice position_ids the same way.\n\n    Note that both input_ids_rmpad and position_ids_rmpad will be padded and sliced.\n\n    This is the pre-forward utility for Ulysses sequence parallelism.\n\n    Args:\n        input_ids_rmpad: shape of [bsz, seqlen]\n        position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1\n        sp_size (int): ulysses sequence parallelism size\n\n    Returns:\n        torch.Tensor: padded and sliced input_ids\n        torch.Tensor: padded and sliced position_ids\n        int: pad size\n    \"\"\"\n    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size)\n    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)\n    if position_ids_rmpad is not None:\n        position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef validate_ulysses_config(num_heads, ulysses_sequence_size):\n    if ulysses_sequence_size > 1:\n        assert num_heads % ulysses_sequence_size == 0, (\n            f\"num_heads ({num_heads}) must be divisible by ulysses sequence size ({ulysses_sequence_size})\"\n        )\n"
  },
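The pad-then-slice contract of `ulysses_pad` and `slice_input_tensor` above is easiest to see with concrete numbers. The sketch below mirrors that arithmetic in a single process, with no `torch.distributed` initialization; `pad_and_slice`, `sp_size`, and `sp_rank` are illustrative stand-ins, not the verl implementation.

```python
# A minimal single-process sketch of the pad-then-slice contract used by
# ulysses_pad_and_slice_inputs. It reimplements the arithmetic locally so it
# runs without torch.distributed; this is an illustration, not verl's code.
import torch


def pad_and_slice(input_ids_rmpad: torch.Tensor, sp_size: int, sp_rank: int):
    """input_ids_rmpad: [1, total_nnz] of packed (rm-padded) token ids."""
    total_seq_len = input_ids_rmpad.size(-1)
    # same formula as ulysses_pad: round total_nnz up to a multiple of sp_size
    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
    if pad_size > 0:
        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
    # each rank keeps one contiguous shard of the padded sequence
    parts = input_ids_rmpad.size(-1) // sp_size
    shard = input_ids_rmpad[:, sp_rank * parts : (sp_rank + 1) * parts]
    return shard.contiguous(), pad_size


input_ids = torch.arange(10).unsqueeze(0)            # total_nnz = 10
for rank in range(4):                                # sp_size = 4 -> pad to 12
    shard, pad = pad_and_slice(input_ids, sp_size=4, sp_rank=rank)
    print(rank, shard.tolist(), "pad_size:", pad)    # 3 tokens per rank, pad_size = 2
```

After the model forward, `gather_outputs_and_unpad` reverses the procedure: it all-gathers the per-rank shards along the sequence dimension and drops the `pad_size` tail.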
  {
    "path": "verl_rl/verl/utils/vllm_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom msgspec import field\nfrom packaging import version as vs\nfrom vllm.lora.models import LoRAModel\nfrom vllm.lora.request import LoRARequest\nfrom vllm.lora.utils import get_adapter_absolute_path\nfrom vllm.lora.worker_manager import LRUCacheWorkerLoRAManager\n\nfrom verl.third_party.vllm import get_version\n\n# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering\n# unsupported issues.\nSUPPORTED_MOE_MODELS = []\n\ntry:\n    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)\n    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.mixtral import MixtralForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration\n\n    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)\nexcept ImportError:\n    pass\n\n\ndef patch_vllm_moe_model_weight_loader(model):\n    # this is a work around to load the weight of vllm fused moe model\n    # it is from a bug from vllm 0.8.2\n    # all the weights are supposed to have a weight_loader, but the moe weights\n    # do not have a weight_loader, so we need to patch it\n    # (True, 'model.embed_tokens.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.bias')\n    # (True, 'model.layers.0.self_attn.o_proj.weight')\n    # (True, 'model.layers.0.mlp.gate.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')\n    # (False, 'model.layers.0.mlp.shared_expert_gate.weight')   use default\n    # (False, 'model.layers.0.input_layernorm.weight')          use default\n    # (False, 'model.layers.0.post_attention_layernorm.weight') use default\n    # (False, 'model.layers.0.mlp.experts.w13_weight')          use mlp.experts.weight_loader\n    # (False, 'model.layers.0.mlp.experts.w2_weight')          use mlp.experts.weight_loader\n\n    # Define MLP attribute mapping for different model types\n    MLP_ATTR_MAPPING = {\n        MixtralForCausalLM: \"block_sparse_moe\",\n    }\n    DEFAULT_MLP_ATTR = \"mlp\"\n\n    if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)):\n        return\n\n    model = getattr(model, \"model\", None) or getattr(model, \"language_model\", None)\n    if model is 
None:\n        raise ValueError(\"The provided model does not have a valid 'model' or 'language_model' attribute.\")\n\n    for layer in model.layers:\n        mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR)\n        mlp = getattr(layer, mlp_attr)\n\n        param_dict = dict(mlp.named_parameters())\n        for name, param in param_dict.items():\n            if \"w13_weight\" in name or \"w2_weight\" in name:\n                param.weight_loader = mlp.experts.weight_loader\n\n\nclass TensorLoRARequest(LoRARequest):\n    peft_config: dict = field(default=None)\n    lora_tensors: dict = field(default=None)\n\n\nclass VLLMHijack:\n    @staticmethod\n    def hijack():\n        def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:\n            \"\"\"\n            based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors\n\n            Reason:\n            VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.\n            To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to\n            load memory-based LoRA tensors.\n            \"\"\"\n            try:\n                supported_lora_modules = self._adapter_manager.supported_lora_modules\n                packed_modules_mapping = self._adapter_manager.packed_modules_mapping\n                expected_lora_modules: list[str] = []\n                for module in supported_lora_modules:\n                    if module in packed_modules_mapping:\n                        expected_lora_modules.extend(packed_modules_mapping[module])\n                    else:\n                        expected_lora_modules.append(module)\n\n                expected_lora_modules = list(set(expected_lora_modules))\n\n                lora_tensors = None\n                from vllm.lora.peft_helper import PEFTHelper\n\n                if isinstance(lora_request, TensorLoRARequest):\n                    peft_config = lora_request.peft_config\n                    lora_tensors = lora_request.lora_tensors\n                    peft_helper = PEFTHelper.from_dict(peft_config)\n                else:\n                    lora_path = get_adapter_absolute_path(lora_request.lora_path)\n\n                    peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)\n\n                # Validates the LoRA configuration against requirements before\n                # loading weights, throwing an exception if validation fails.\n                peft_helper.validate_legal(self.lora_config)\n\n                # For some models like Qwen2VL, we need to use hf_to_vllm_mapper\n                # to ensure correct loading of lora weights.\n                model = self._adapter_manager.model\n                hf_to_vllm_mapper = None\n                if hasattr(model, \"hf_to_vllm_mapper\") and model.hf_to_vllm_mapper is not None:\n                    hf_to_vllm_mapper = model.hf_to_vllm_mapper\n\n                if isinstance(lora_request, TensorLoRARequest):\n                    lora = self._lora_model_cls.from_lora_tensors(\n                        lora_model_id=lora_request.lora_int_id,\n                        tensors=lora_tensors,\n                        peft_helper=peft_helper,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        embeddings=None,\n                        target_embedding_padding=self.vocab_size + 
self.lora_config.lora_extra_vocab_size,\n                        embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n                else:\n                    lora = self._lora_model_cls.from_local_checkpoint(\n                        lora_path,\n                        expected_lora_modules,\n                        peft_helper=peft_helper,\n                        lora_model_id=lora_request.lora_int_id,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,\n                        embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n            except Exception as e:\n                raise e\n\n            if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:\n                raise ValueError(\n                    f\"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size \"\n                    f\"{self.lora_config.lora_extra_vocab_size}.\"\n                )\n            return lora\n\n        def do_hijack(target_cls, target_method_name, hooking_method):\n            setattr(target_cls, target_method_name, hooking_method)\n\n        do_hijack(LRUCacheWorkerLoRAManager, \"_load_adapter\", hijack__load_adapter)\n\n\ndef is_version_ge(pkg: str = \"vllm\", minver: str = \"0.7.3\"):\n    \"\"\"check if the package version is greater than or equal to the minimum version\"\"\"\n    return vs.parse(get_version(pkg)) >= vs.parse(minver)\n"
  },
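`VLLMHijack.hijack()` relies on a plain method hijack: assigning a replacement function onto a third-party class with `setattr`, so that every existing and future instance binds the new behavior as a method. A self-contained sketch of the same pattern follows; the toy `Manager` class is a stand-in for `LRUCacheWorkerLoRAManager`, not part of vLLM.

```python
# Self-contained sketch of the method-hijack pattern used by VLLMHijack.
# `Manager` is a hypothetical stand-in class for illustration only.
class Manager:
    def _load_adapter(self, request):
        return f"loaded {request} from disk"


def hijack__load_adapter(self, request):
    # new behavior: accept in-memory payloads in addition to file paths
    if isinstance(request, dict):
        return f"loaded {sorted(request)} from memory"
    return f"loaded {request} from disk"


def do_hijack(target_cls, target_method_name, hooking_method):
    # patch the class, not an instance, so all instances pick it up
    setattr(target_cls, target_method_name, hooking_method)


do_hijack(Manager, "_load_adapter", hijack__load_adapter)
print(Manager()._load_adapter("adapter.safetensors"))            # path branch
print(Manager()._load_adapter({"lora_A": ..., "lora_B": ...}))   # tensor branch
```

Because plain functions are descriptors, assigning one onto the class makes it bind like a regular method on attribute access, which is why `hijack__load_adapter` takes `self` as its first parameter.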
  {
    "path": "verl_rl/verl/version/version",
    "content": "0.5.0\n"
  },
  {
    "path": "verl_rl/verl/workers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/workers/actor/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOActor\nfrom .dp_actor import DataParallelPPOActor\n\n__all__ = [\"BasePPOActor\", \"DataParallelPPOActor\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/actor/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for Actor\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\nfrom verl import DataProto\n\n__all__ = [\"BasePPOActor\"]\n\n\nclass BasePPOActor(ABC):\n    def __init__(self, config):\n        \"\"\"The base class for PPO actor\n\n        Args:\n            config (DictConfig): a config passed to the PPOActor. We expect the type to be\n                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.\n        \"\"\"\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_log_prob(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute logits given a batch of data.\n\n        Args:\n            data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,\n                ```attention_mask``` and ```position_ids```.\n\n        Returns:\n            DataProto: a DataProto containing the key ```log_probs```\n\n\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def update_policy(self, data: DataProto) -> dict:\n        \"\"\"Update the policy with an iterator of DataProto\n\n        Args:\n            data (DataProto): an iterator over the DataProto that returns by\n                ```make_minibatch_iterator```\n\n        Returns:\n            Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model\n            such as ```loss```, ```grad_norm```, etc,.\n\n        \"\"\"\n        pass\n"
  },
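A minimal sketch of a concrete subclass against this interface follows, assuming `verl` is importable; the bodies are stubs meant only to show the expected shapes and metric dict, not a working actor.

```python
# Minimal skeleton of a custom actor against the BasePPOActor interface.
# `RandomActor` is hypothetical; real actors run a model forward pass and
# step an optimizer, as DataParallelPPOActor does.
import torch

from verl import DataProto
from verl.workers.actor import BasePPOActor


class RandomActor(BasePPOActor):
    """A do-nothing actor that satisfies the abstract interface."""

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        # return zeros of the expected (batch_size, response_length) shape
        responses = data.batch["responses"]
        return torch.zeros_like(responses, dtype=torch.float32)

    def update_policy(self, data: DataProto) -> dict:
        # real actors compute the PPO loss, backprop, and step the optimizer,
        # then report training statistics in the returned dict
        return {"actor/pg_loss": 0.0}
```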
  {
    "path": "verl_rl/verl/workers/actor/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSingle Process Actor\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty\nfrom verl.utils.device import get_device_name, is_cuda_available, is_npu_available\nfrom verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch\nfrom verl.utils.torch_functional import logprobs_from_logits\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs\nfrom verl.workers.actor import BasePPOActor\n\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\n\n__all__ = [\"DataParallelPPOActor\"]\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass DataParallelPPOActor(BasePPOActor):\n    def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):\n        \"\"\"When optimizer is None, it is Reference Policy\"\"\"\n        super().__init__(config)\n        self.actor_module = actor_module\n        self.actor_optimizer = actor_optimizer\n\n        self.use_remove_padding = self.config.get(\"use_remove_padding\", False)\n        if torch.distributed.get_rank() == 0:\n            print(f\"Actor use_remove_padding={self.use_remove_padding}\")\n        self.use_fused_kernels = self.config.get(\"use_fused_kernels\", False)\n        if torch.distributed.get_rank() == 0:\n            print(f\"Actor use_fused_kernels={self.use_fused_kernels}\")\n\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1\n\n        if self.config.entropy_from_logits_with_chunking:\n            entropy_from_logits = verl_F.entropy_from_logits_with_chunking\n        else:\n            entropy_from_logits = verl_F.entropy_from_logits\n\n        self.compute_entropy_from_logits = (\n            torch.compile(entropy_from_logits, dynamic=True)\n            if self.config.get(\"use_torch_compile\", True)  #  use torch compile by default\n            else entropy_from_logits\n        )\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(\n        self, micro_batch, temperature, calculate_entropy=False\n    
) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns:\n            entropy: # (bs, response_len)\n            log_probs: # (bs, response_len)\n        \"\"\"\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            if \"image_bound\" in micro_batch[\"multi_modal_inputs\"][0]:  # minicpm-o logic\n                for key in micro_batch[\"multi_modal_inputs\"][0].keys():\n                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch[\"multi_modal_inputs\"]]\n            else:\n                for key in micro_batch[\"multi_modal_inputs\"][0].keys():\n                    multi_modal_inputs[key] = torch.cat(\n                        [inputs[key] for inputs in micro_batch[\"multi_modal_inputs\"]], dim=0\n                    )\n\n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            entropy = None\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                if \"image_bound\" in multi_modal_inputs:\n                    from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo\n\n                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(\n                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs\n                    )\n\n                # for compute the log_prob\n                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n                # pad and slice the inputs if sp > 1\n                if self.use_ulysses_sp:\n                    is_vlm_model = \"multi_modal_inputs\" in micro_batch.keys()\n                    if is_vlm_model:\n                        # vlm model's inputs will be sliced after embedding\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    else:\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad_rolled,\n                        position_ids_rmpad=None,\n                        sp_size=self.ulysses_sequence_parallel_size,\n                    )\n\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n\n                output = self.actor_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent model thinks we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)\n                    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)\n\n                else:\n                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)\n                    logits_rmpad.div_(temperature)\n\n                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n                    inplace_backward = True\n                    if calculate_entropy:\n                        inplace_backward = False\n                    log_probs = logprobs_from_logits(\n                        logits=logits_rmpad,\n                        labels=input_ids_rmpad_rolled,\n                        inplace_backward=inplace_backward,\n                    )\n\n                    # compute entropy\n                    if calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n       
                     entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)\n                        else:\n                            entropy_rmpad = torch.utils.checkpoint.checkpoint(\n                                self.compute_entropy_from_logits, logits_rmpad\n                            )\n\n                # gather log_prob if sp > 1\n                if self.use_ulysses_sp:\n                    # gather and unpad for the ulysses sp\n                    log_probs = gather_outputs_and_unpad(\n                        log_probs,\n                        gather_dim=0,\n                        unpad_dim=0,\n                        padding_size=pad_size,\n                    )\n                    if calculate_entropy:\n                        entropy_rmpad = gather_outputs_and_unpad(\n                            entropy_rmpad,\n                            gather_dim=0,\n                            unpad_dim=0,\n                            padding_size=pad_size,\n                        )\n                # pad back to (bsz, seqlen)\n                if calculate_entropy:\n                    full_entropy = pad_input(\n                        hidden_states=entropy_rmpad.unsqueeze(-1),\n                        indices=indices,\n                        batch=batch_size,\n                        seqlen=seqlen,\n                    )\n                full_log_probs = pad_input(\n                    hidden_states=log_probs.unsqueeze(-1),\n                    indices=indices,\n                    batch=batch_size,\n                    seqlen=seqlen,\n                )\n\n                # only return response part:\n                if calculate_entropy:\n                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n            else:  # not using rmpad and no ulysses sp\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n\n                output = self.actor_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent model thinks we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs[:, -response_length - 1 : -1]\n                    entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n                else:\n                    logits = output.logits\n\n                    logits.div_(temperature)\n                    logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)\n                    log_probs = logprobs_from_logits(logits, micro_batch[\"responses\"])\n                    if calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n                            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                        else:\n                            entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)\n\n            return entropy, log_probs\n\n    def _optimizer_step(self):\n        assert 
self.config.grad_clip is not None\n\n        if isinstance(self.actor_module, FSDP):\n            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)\n        elif isinstance(self.actor_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}\")\n            self.actor_optimizer.zero_grad()\n        else:\n            self.actor_optimizer.step()\n        return grad_norm\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.\n\n        Returns:\n            torch.Tensor: the log_prob tensor\n        \"\"\"\n        # set to eval\n        self.actor_module.eval()\n\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        if use_dynamic_bsz:\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)\n        else:\n            micro_batches = data.split(micro_batch_size)\n\n        log_probs_lst = []\n        entropy_lst = []\n        for micro_batch in micro_batches:\n            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n            with torch.no_grad():\n                entropy, log_probs = self._forward_micro_batch(\n                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy\n                )\n            log_probs_lst.append(log_probs)\n            if calculate_entropy:\n                entropy_lst.append(entropy)\n\n        log_probs = torch.concat(log_probs_lst, dim=0)\n        entropys = None\n        if calculate_entropy:\n            entropys = torch.concat(entropy_lst, dim=0)\n\n        if use_dynamic_bsz:\n            log_probs = restore_dynamic_batch(log_probs, batch_idx_list)\n       
     if calculate_entropy:\n                entropys = restore_dynamic_batch(entropys, batch_idx_list)\n\n        return log_probs, entropys\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def update_policy(self, data: DataProto):\n        # make sure we are in training mode\n        self.actor_module.train()\n\n        temperature = data.meta_info[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n\n        select_keys = [\n            \"responses\",\n            \"response_mask\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. https://arxiv.org/abs/1707.06347\n        mini_batches = data.split(self.config.ppo_mini_batch_size)\n\n        metrics = {}\n        for _ in range(self.config.ppo_epochs):\n            for batch_idx, mini_batch in enumerate(mini_batches):\n                if self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.actor_optimizer.zero_grad()\n\n                for micro_batch in micro_batches:\n                    micro_batch_metrics = {}\n                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n                    response_mask = model_inputs[\"response_mask\"]\n                    old_log_prob = model_inputs[\"old_log_probs\"]\n                    advantages = model_inputs[\"advantages\"]\n\n                    clip_ratio = self.config.clip_ratio\n                    clip_ratio_low = (\n                        self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio\n                    )\n                    clip_ratio_high = (\n                        self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio\n                    )\n                    clip_ratio_c = self.config.get(\"clip_ratio_c\", 3.0)\n                    entropy_coeff = self.config.entropy_coeff\n                    loss_agg_mode = self.config.loss_agg_mode\n\n                    # all return: (bsz, response_length)\n                    calculate_entropy = False\n                    if entropy_coeff != 0:\n                        calculate_entropy = True\n                    entropy, log_prob = self._forward_micro_batch(\n                        model_inputs, temperature=temperature, calculate_entropy=calculate_entropy\n                    )\n\n                    loss_mode = self.config.policy_loss.get(\"loss_mode\", \"vanilla\")\n\n                    if 
self.config.policy_loss.loss_mode == \"vanilla\":\n                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(\n                            old_log_prob=old_log_prob,\n                            log_prob=log_prob,\n                            advantages=advantages,\n                            response_mask=response_mask,\n                            cliprange=clip_ratio,\n                            cliprange_low=clip_ratio_low,\n                            cliprange_high=clip_ratio_high,\n                            clip_ratio_c=clip_ratio_c,\n                            loss_agg_mode=loss_agg_mode,\n                        )\n\n                    else:\n                        policy_loss_fn = get_policy_loss_fn(loss_mode)\n                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n                            old_log_prob=old_log_prob,\n                            log_prob=log_prob,\n                            advantages=advantages,\n                            response_mask=response_mask,\n                            loss_agg_mode=loss_agg_mode,\n                            config=self.config,\n                        )\n\n                    if entropy_coeff != 0:\n                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        # compute policy loss\n                        policy_loss = pg_loss - entropy_loss * entropy_coeff\n                    else:\n                        policy_loss = pg_loss\n\n                    if self.config.use_kl_loss:\n                        ref_log_prob = model_inputs[\"ref_log_prob\"]\n                        # compute kl loss\n                        kld = kl_penalty(\n                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type\n                        )\n                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                        micro_batch_metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n                        micro_batch_metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)\n                    else:\n                        loss = policy_loss / self.gradient_accumulation\n                    loss.backward()\n\n                    micro_batch_metrics.update(\n                        {\n                            \"actor/pg_loss\": pg_loss.detach().item(),\n                            \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                            \"actor/ppo_kl\": ppo_kl.detach().item(),\n                            \"actor/pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n                        }\n                    )\n                    append_to_dict(metrics, micro_batch_metrics)\n\n                grad_norm = self._optimizer_step()\n                mini_batch_metrics = {\"actor/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, mini_batch_metrics)\n        self.actor_optimizer.zero_grad()\n        return metrics\n"
  },
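Much of the indexing in `_forward_micro_batch` hinges on one fact: the logits at position t score token t + 1. The toy check below, with random tensors standing in for model output and all names illustrative, verifies that the `[:, -response_length - 1 : -1]` window combined with rolled labels scores exactly the response tokens.

```python
# Why log-probs are sliced with [:, -response_length - 1 : -1]: the logits at
# position t predict token t + 1, so the response tokens (the last
# `response_length` ids of input_ids) are scored by the window that starts one
# step earlier and stops before the final position. Toy shapes, no real model.
import torch

torch.manual_seed(0)
bsz, seqlen, response_length, vocab = 2, 8, 3, 11
input_ids = torch.randint(vocab, (bsz, seqlen))
responses = input_ids[:, -response_length:]            # convention used in verl

logits = torch.randn(bsz, seqlen, vocab)               # stand-in model output
log_probs_all = logits.log_softmax(dim=-1)

# gather log p(token_{t+1} | tokens <= t) for every position at once
shifted_labels = input_ids.roll(shifts=-1, dims=1)     # label at t is the id at t + 1
token_logps = log_probs_all.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1)

response_logps = token_logps[:, -response_length - 1 : -1]
assert response_logps.shape == (bsz, response_length)

# cross-check against scoring the response positions directly
direct = log_probs_all[:, -response_length - 1 : -1, :].gather(
    -1, responses.unsqueeze(-1)
).squeeze(-1)
assert torch.allclose(response_logps, direct)
```

The same window also explains why the non-rmpad branch slices `logits[:, -response_length - 1 : -1, :]` before calling `logprobs_from_logits` with `micro_batch["responses"]` as labels.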
  {
    "path": "verl_rl/verl/workers/actor/megatron_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Actor.\nIn megatron actor, the differences are:\n1. We only make minibatch\n\nNote that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer\n\"\"\"\n\nimport itertools\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Iterable\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.distributed import finalize_model_grads\n\n# from megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits\nfrom verl.utils.megatron_utils import get_model_config\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.profiler.profile import Profiler\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor\nfrom verl.workers.actor import BasePPOActor\n\n__all__ = [\"MegatronPPOActor\"]\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MegatronPPOActor(BasePPOActor):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        actor_module: nn.ModuleList,\n        actor_optimizer: DistributedOptimizer,\n    ):\n        \"\"\"MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.\n\n        Args:\n            config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain\n\n                ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.\n\n                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.\n\n                ``ppo_epochs``: number of epochs to update the actor using the batch data.\n\n                ``shuffle``: whether to shuffle the data after each ppo epoch.\n\n                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.\n\n                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.\n            model_config (OmegaConf): model configuration. 
It must contains ``model_config.vocab_size`` and\n                ``model_config.hidden_size``\n            hf_config (PretrainedConfig): huggingface config\n            tf_config (TransformerConfig): mcore transformer config\n            actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this\n                pp stage.\n                each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for\n                more details.\n                The actor module has some constraints to follow in order to use the updating logics implemented here\n\n                1. It must implement unpad_input before any computation and pad_input after all the computation.\n                Remove padding is an\n                optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn\n                (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).\n\n                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],\n                where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size\n                of the hidden state is [total_nnz // tp, 1, hidden_size].\n            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron.\n                It implements\n                zero1 optimizer that shards the optimizer state across dp ranks.\n\n        >>> from megatron.training import get_model\n        >>> from megatron.optimizer import get_megatron_optimizer\n        >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)\n        >>> actor_module = nn.ModuleList(actor_module)\n        >>> actor_optimizer = get_megatron_optimizer(actor_module)\n        >>> actor = MegatronPPOActor(config=config,\n        >>>                          model_config=actor_model_config,\n        >>>                          hf_config=hf_config,\n        >>>                          tf_config=tf_config,\n        >>>                          actor_module=actor_module,\n        >>>                          actor_optimizer=actor_optimizer)\n        \"\"\"\n        super().__init__(config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.actor_module = actor_module\n        self.actor_optimizer: DistributedOptimizer = actor_optimizer\n        self.prof = Profiler(self.config.profile)\n        self.use_fused_kernels = self.config.get(\"use_fused_kernels\", False)\n        if self.use_fused_kernels:\n            from verl.models.mcore.model_forward_fused import patch_fused_forward\n\n            for model in self.actor_module:\n                patch_fused_forward(model)\n\n        self.optimizer_step_args = OmegaConf.create(\n            {\n                \"skip_grad\": None,\n                \"overlap_dp_param_comm\": False,\n                \"overlap_dp_grad_comm\": False,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_parallel\": self.tf_config.sequence_parallel,\n                \"DDP_impl\": \"local\",\n                \"layernorm_allreduce_bucket_threshold\": 0,\n                \"pipeline_model_parallel_split_rank\": None,\n                \"reduce_grads_use_alltoall\": False,\n            }\n        )\n\n        config = get_model_config(self.actor_module[0])\n 
       print(config)\n        config.finalize_model_grads_func = finalize_model_grads\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.get(\"ulysses_sequence_parallel_size\", 1) == 1\n        if config.get(\"shuffle\", False):\n            assert config.data_loader_seed is not None, \"If shuffle dataloader, seed must be manually set\"\n        if config.megatron.tensor_model_parallel_size == 1:\n            print(\"[Warining] Because actor tp size == 1, set sp to False\")\n            config.megatron.sequence_parallel = False\n        self.config = config\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.\n\n        Returns:\n            DataProto: torch.Tensor: the log_prob tensor\n        \"\"\"\n        data.to(get_device_id())\n        data.batch = data.batch.contiguous()\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n\n        def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):\n            response = data[\"responses\"]\n            response_length = response.size(1)\n            log_probs = output[\"log_probs\"][:, -response_length - 1 : -1].contiguous()\n            return {\"log_probs\": log_probs}\n\n        # We make recompute_old_log_prob by default here.\n        # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be\n        # handled by user outside\n        recompute_old_log_prob = self.config.get(\"recompute_old_log_prob\", True)\n\n        entropys = torch.Tensor()\n        if recompute_old_log_prob:\n            select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n            batch = data.select(batch_keys=select_keys).batch\n            input_ids = batch[\"input_ids\"]\n            batch_size = input_ids.size(0)\n            response = batch[\"responses\"]\n            response_length = response.size(1)\n            with torch.no_grad():\n                output = self.forward_backward_batch(\n                    data,\n                    forward_only=True,\n                    post_process_fn=compute_logprobs_fn,\n                    calculate_entropy=calculate_entropy,\n    
                use_dynamic_bsz=use_dynamic_bsz,\n                    micro_batch_size=micro_batch_size,\n                    max_token_len=max_token_len,\n                )\n                if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                    # only on last rank. It should be on every tp rank\n                    if calculate_entropy:\n                        log_probs = [o[0][\"log_probs\"] for o in output[\"output\"]]  # (bs, seq_size)\n                    else:\n                        log_probs = [o[\"log_probs\"] for o in output[\"output\"]]  # (bs, seq_size)\n                    log_probs = torch.cat(log_probs, dim=0).to(torch.float32)\n                    if use_dynamic_bsz:\n                        indices = output[\"indices\"]\n                        indices = list(itertools.chain.from_iterable(indices))\n                        assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n                        revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                        log_probs = log_probs[revert_indices]\n                else:\n                    log_probs = torch.empty(\n                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                    )\n\n                # broadcast across pp ranks\n                torch.distributed.broadcast(\n                    tensor=log_probs,\n                    src=mpu.get_pipeline_model_parallel_last_rank(),\n                    group=mpu.get_pipeline_model_parallel_group(),\n                    async_op=False,\n                )\n                if calculate_entropy:\n                    # Note that o[0] is metrics, o[1] is entropy\n                    if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                        entropys = torch.cat([o[1] for o in output[\"output\"]], dim=0)\n                        entropys = entropys.to(torch.float32)\n                        if use_dynamic_bsz:\n                            indices = output[\"indices\"]\n                            indices = list(itertools.chain.from_iterable(indices))\n                            assert len(indices) == entropys.size(0), f\"{len(indices)} vs. {entropys.size()}\"\n                            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                            entropys = entropys[revert_indices]\n                    else:\n                        entropys = torch.empty(\n                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                        )\n                    # broadcast across pp ranks\n                    torch.distributed.broadcast(\n                        tensor=entropys,\n                        src=mpu.get_pipeline_model_parallel_last_rank(),\n                        group=mpu.get_pipeline_model_parallel_group(),\n                        async_op=False,\n                    )\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return log_probs, entropys\n\n    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:\n        \"\"\"Make minibatch iterator for updating the actor\n\n        Args:\n            data (DataProto): a DataProto containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. 
torch.int64, where\n                ``sequence_length = prompt_length + response_length``\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64\n\n                ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that\n                responses = input_ids[:, -response_length:]\n\n                ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability\n                of responses.\n\n                ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of\n                responses.\n                See PPO paper for details. https://arxiv.org/abs/1707.06347\n\n        Returns:\n            Iterable[DataProto]: an iterator that yields minibatches of the selected keys.\n        \"\"\"\n        select_keys = [\n            \"responses\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"response_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            data = data.select(select_keys, [\"multi_modal_inputs\"])\n        else:\n            data = data.select(batch_keys=select_keys)\n        return data.make_iterator(\n            mini_batch_size=self.config.ppo_mini_batch_size,\n            epochs=self.config.ppo_epochs,\n            seed=self.config.data_loader_seed,\n            dataloader_kwargs={\"shuffle\": self.config.shuffle},\n        )\n\n    def forward_backward_batch(\n        self,\n        data: DataProto,\n        forward_only=False,\n        post_process_fn=None,\n        calculate_entropy=False,\n        use_dynamic_bsz=False,\n        micro_batch_size=None,\n        max_token_len=None,\n        mini_batch_size=None,\n    ):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        mini_batch = data\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n        # split into micro-batches\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in mini_batch.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            mini_batch.batch[\"multi_modal_inputs\"] = mini_batch.non_tensor_batch[\"multi_modal_inputs\"]\n            mini_batch.batch[\"multi_modal_inputs_idx\"] = torch.Tensor(\n                list(range(len(mini_batch.non_tensor_batch[\"multi_modal_inputs\"])))\n            ).to(torch.int64)\n\n        if mini_batch.batch[\"position_ids\"].dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            mini_batch.batch[\"position_ids\"] = mini_batch.batch[\"position_ids\"][\n                :, 0\n            ]  # mcore patch recompute qwen2vl's pos ids during forward\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        # compute input shapes for pp stages\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data, meta_info):\n            # For memory efficiency\n            # We move calculation of entropy to compute_log_probs, forward_only == True\n            device = output[\"log_probs\"].device\n            metrics = {}\n            if forward_only:\n                if post_process_fn is None:\n                    pass\n                    # metrics[\"logits\"] = output\n                else:\n                    stats = post_process_fn(output, data)\n                    
metrics.update(stats)\n                if not calculate_entropy:\n                    return torch.tensor(1.0, device=device), metrics\n\n            responses = data[\"responses\"]\n            response_length = responses.size(1)\n            response_mask = data[\"response_mask\"].to(bool)\n            loss_agg_mode = self.config.loss_agg_mode\n\n            # compute policy loss\n            log_prob = output[\"log_probs\"][:, -response_length - 1 : -1].contiguous()\n            ret_entropy = None\n            stats = {}\n            if not forward_only:\n                old_log_prob = data[\"old_log_probs\"]\n                advantages = data[\"advantages\"]\n\n                clip_ratio = self.config.clip_ratio\n                clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio\n                clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio\n\n                clip_ratio_c = self.config.get(\"clip_ratio_c\", 3.0)\n                entropy_coeff = self.config.entropy_coeff\n                loss_agg_mode = self.config.loss_agg_mode\n\n                loss_mode = self.config.policy_loss.get(\"loss_mode\", \"vanilla\")\n\n                if loss_mode == \"vanilla\":\n                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        advantages=advantages,\n                        response_mask=response_mask,\n                        cliprange=clip_ratio,\n                        cliprange_low=clip_ratio_low,\n                        cliprange_high=clip_ratio_high,\n                        clip_ratio_c=clip_ratio_c,\n                        loss_agg_mode=loss_agg_mode,\n                    )\n\n                else:\n                    policy_loss_fn = get_policy_loss_fn(loss_mode)\n                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        advantages=advantages,\n                        response_mask=response_mask,\n                        loss_agg_mode=loss_agg_mode,\n                        config=self.config,\n                    )\n\n                stats.update(\n                    {\n                        \"actor/pg_loss\": pg_loss.detach().item(),\n                        \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                        \"actor/ppo_kl\": ppo_kl.detach().item(),\n                        \"actor/pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n                    }\n                )\n                policy_loss = pg_loss\n\n            if calculate_entropy:\n                entropy = output[\"entropy\"][:, -response_length - 1 : -1].contiguous()\n                if not forward_only:\n                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n                    entropy_coeff = meta_info[\"entropy_coeff\"]\n                    policy_loss = pg_loss - entropy_coeff * entropy_loss\n                else:\n                    ret_entropy = entropy\n\n            if forward_only:\n                policy_loss = torch.tensor(1.0, device=device)\n            else:\n                if self.config.use_kl_loss:\n                    ref_log_prob = data[\"ref_log_prob\"]\n             
       # compute kl loss\n                    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)\n                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)\n\n                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                    metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n                    metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                # return loss and stats\n\n            append_to_dict(metrics, stats)\n            return policy_loss, [metrics, ret_entropy]\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"].to(bool)\n            position_ids = batch[\"position_ids\"]\n\n            multi_modal_inputs = {}\n            if \"multi_modal_inputs\" in batch:\n                for key in batch[\"multi_modal_inputs\"][0].keys():\n                    idxs = batch[\"multi_modal_inputs_idx\"]\n                    mmi = batch[\"multi_modal_inputs\"]\n                    multi_modal_inputs[key] = torch.cat(\n                        [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0\n                    )\n            responses = batch[\"responses\"]\n            response_length = responses.size(1)\n            label = position_ids.clone()\n            label[:, -response_length - 1 : -1] = responses\n            label_mask = attention_mask.clone()\n            label_mask[:, : -response_length - 1] = False\n            label_mask[:, -1] = False\n\n            from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn\n\n            if self.use_fused_kernels:\n                forward_fn = get_mcore_forward_fused_fn(self.hf_config)\n                # return dict of [logits, entropy]\n                output = forward_fn(\n                    model,\n                    input_ids,\n                    position_ids,\n                    attention_mask,\n                    sequence_parallel=self.tf_config.sequence_parallel,\n                    multi_modal_inputs=multi_modal_inputs,\n                    labels=label,\n                    labels_mask=label_mask,\n                )\n            else:\n                forward_fn = get_mcore_forward_fn(self.hf_config)\n\n                def logits_processor(logits, label, label_mask):\n                    assert logits.shape[:2] == label.shape[:2]\n                    assert label.shape == label_mask.shape\n                    ret = {}\n                    if calculate_entropy:\n                        entropy = vocab_parallel_entropy(logits)\n                        ret[\"entropy\"] = entropy\n                    log_probs = vocab_parallel_log_probs_from_logits(logits, label)\n                    log_probs = log_probs.masked_fill(~label_mask, 0.0)\n                    ret[\"log_probs\"] = log_probs\n                    return ret\n\n                logits_processor_args = {\"label\": label, \"label_mask\": label_mask}\n                output = forward_fn(\n                    model,\n                    input_ids,\n                    attention_mask,\n                    position_ids,\n                    sequence_parallel=self.tf_config.sequence_parallel,\n                    multi_modal_inputs=multi_modal_inputs,\n                    logits_processor=logits_processor,\n                    
logits_processor_args=logits_processor_args,\n                )\n\n            if forward_only:\n                meta_info = None\n            else:\n                clip_ratio_c = self.config.get(\"clip_ratio_c\", 3.0)\n                meta_info = {\n                    \"clip_ratio\": self.config.clip_ratio,\n                    \"entropy_coeff\": self.config.entropy_coeff,\n                    \"clip_ratio_c\": clip_ratio_c,\n                }\n            return output, partial(loss_func, data=batch, meta_info=meta_info)\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.actor_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=forward_only,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.actor_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=forward_only,\n            )\n        # loss_reduces contains the stats returned from loss_func\n\n        if self.has_multi_modal_inputs:\n            data.batch.pop(\"multi_modal_inputs\")\n            data.batch.pop(\"multi_modal_inputs_idx\")\n            data.non_tensor_batch.pop(\"multi_modal_inputs\")\n\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def update_policy(self, dataloader: Iterable[DataProto]) -> dict:\n        \"\"\"Update the policy with an iterator of DataProto\n\n        Args:\n            dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator``\n                The keys of each data batch is described in the make_minibatch_iterator.\n\n        Returns:\n            Dict: a dictionary containing the statistics. 
Note that the statistics are only valid in the last pp stage\n            and users have to combine the output in each dp rank manually.\n\n        \"\"\"\n        metrics = {}\n        self.prof.start()\n        for data in dataloader:\n            data.to(get_device_id())\n            self.actor_optimizer.zero_grad()\n            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n            for chunk in self.actor_module:\n                # if using the distributed optimizer, the grad buffer will be zeroed by the optimizer\n                chunk.zero_grad_buffer()\n\n            calculate_entropy = self.config.entropy_coeff != 0\n            if data.meta_info.get(\"micro_batch_size\", None) is not None:\n                micro_batch_size = data.meta_info[\"micro_batch_size\"]\n            else:\n                micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n            max_token_len = None\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n            metric_micro_batch = self.forward_backward_batch(\n                data,\n                calculate_entropy=calculate_entropy,\n                use_dynamic_bsz=self.config.use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=self.config.ppo_mini_batch_size,\n            )\n            metric_micro_batch = metric_micro_batch[\"output\"]\n            for metric in metric_micro_batch:\n                # each element is [metrics, entropy] as returned by loss_func; metric[0] is the metrics dict\n                append_to_dict(metrics, metric[0])  # append the metric from this micro-batch to global metrics.\n\n            update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()\n            data = {\"actor/grad_norm\": grad_norm}\n            append_to_dict(metrics, data)\n\n            if update_successful:\n                # allgather is already executed in optimizer.step in new megatron\n                pass\n            else:\n                raise NotImplementedError\n            self.prof.step()\n        # add empty cache after each compute\n        self.prof.stop_and_save()\n        self.prof.stop_trace()\n        get_torch_device().empty_cache()\n        return metrics\n
  },
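For orientation: in the vanilla branch of `loss_func` above, `compute_policy_loss` consumes `old_log_probs`, `advantages`, and the response mask. Below is a minimal sketch of that clipped PPO surrogate (https://arxiv.org/abs/1707.06347), assuming token-mean aggregation and omitting the dual-clip `clip_ratio_c` branch; `clipped_pg_loss_sketch` is a hypothetical name, not verl's actual implementation.

```python
import torch


def clipped_pg_loss_sketch(old_log_prob, log_prob, advantages, response_mask,
                           cliprange_low=0.2, cliprange_high=0.2):
    """Hypothetical sketch of the clipped PPO policy loss, token-mean aggregated."""
    # importance ratio between the current and behavior policy, per response token
    ratio = torch.exp(log_prob - old_log_prob)
    # unclipped and clipped surrogate terms (negated: we minimize)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange_low, 1.0 + cliprange_high)
    pg_loss_mat = torch.maximum(pg_losses1, pg_losses2)
    mask = response_mask.float()
    denom = mask.sum().clamp(min=1.0)
    # "token-mean": average over valid response tokens only
    pg_loss = (pg_loss_mat * mask).sum() / denom
    # fraction of tokens where the clipped term dominates (reported as actor/pg_clipfrac)
    pg_clipfrac = ((pg_losses2 > pg_losses1).float() * mask).sum() / denom
    return pg_loss, pg_clipfrac
```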
  {
    "path": "verl_rl/verl/workers/critic/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOCritic\nfrom .dp_critic import DataParallelPPOCritic\n\n__all__ = [\"BasePPOCritic\", \"DataParallelPPOCritic\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/critic/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nBase class for a critic\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\nfrom verl import DataProto\n\n__all__ = [\"BasePPOCritic\"]\n\n\nclass BasePPOCritic(ABC):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_values(self, data: DataProto) -> torch.Tensor:\n        \"\"\"Compute values\"\"\"\n        pass\n\n    @abstractmethod\n    def update_critic(self, data: DataProto):\n        \"\"\"Update the critic\"\"\"\n        pass\n"
  },
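The abstract interface above is intentionally thin: a critic only has to map a `DataProto` batch to per-token values and consume one for an update. A toy subclass, purely illustrative (`ZeroCritic` is hypothetical, not part of verl), shows the minimum contract:

```python
import torch

from verl import DataProto
from verl.workers.critic import BasePPOCritic


class ZeroCritic(BasePPOCritic):
    """Hypothetical stub critic that predicts a value of zero for every token."""

    def compute_values(self, data: DataProto) -> torch.Tensor:
        # one value per response token; real critics also mask to action tokens
        response_mask = data.batch["response_mask"]
        return torch.zeros_like(response_mask, dtype=torch.float32)

    def update_critic(self, data: DataProto):
        # a constant critic has nothing to learn; return empty metrics
        return {}
```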
  {
    "path": "verl_rl/verl/workers/critic/dp_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom torch import nn, optim\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom verl import DataProto\nfrom verl.trainer.ppo import core_algos\nfrom verl.utils.device import get_device_name, is_cuda_available, is_npu_available\nfrom verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\nfrom verl.workers.critic import BasePPOCritic\n\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass DataParallelPPOCritic(BasePPOCritic):\n    def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):\n        super().__init__(config=config)\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        print(f\"Critic use_remove_padding={self.use_remove_padding}\")\n\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(self, micro_batch):\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            for key in micro_batch[\"multi_modal_inputs\"][0].keys():\n                multi_modal_inputs[key] = torch.cat(\n                    [inputs[key] for inputs in micro_batch[\"multi_modal_inputs\"]], dim=0\n                )\n\n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # 
(1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.critic_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n\n                if hasattr(self.critic_module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    values_rmpad = output[2].squeeze(0).unsqueeze(-1)\n                else:\n                    values_rmpad = output.logits\n                    values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    values_rmpad = gather_outputs_and_unpad(\n                        values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)\n                values = values[:, -response_length - 1 : -1]\n            else:\n                output = self.critic_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n                if hasattr(self.critic_module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    values = output[2]\n                else:\n                    values = output.logits\n                values = values[:, -response_length - 1 : -1].squeeze(-1)\n            return values\n\n    def _optimizer_step(self):\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.critic_module, FSDP):\n            grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)\n        elif isinstance(self.critic_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n           
 print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.critic_optimizer.zero_grad()\n        else:\n            self.critic_optimizer.step()\n        return grad_norm\n\n    @GPUMemoryLogger(role=\"dp critic\", logger=logger)\n    def compute_values(self, data: DataProto) -> torch.Tensor:\n        self.critic_module.eval()\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        select_keys = [\"responses\", \"input_ids\", \"response_mask\", \"attention_mask\", \"position_ids\"]\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        if use_dynamic_bsz:\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)\n        else:\n            micro_batches = data.split(micro_batch_size)\n\n        values_lst = []\n        for micro_batch in micro_batches:\n            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n            with torch.no_grad():\n                values = self._forward_micro_batch(model_inputs)\n            values_lst.append(values)\n        values = torch.concat(values_lst, dim=0)\n\n        if use_dynamic_bsz:\n            values = restore_dynamic_batch(values, batch_idx_list)\n\n        response_mask = data.batch[\"response_mask\"]\n        values = values * response_mask  # Only action tokens have values\n        return values\n\n    @GPUMemoryLogger(role=\"dp critic\", logger=logger)\n    def update_critic(self, data: DataProto):\n        # make sure we are in training mode\n        self.critic_module.train()\n        metrics = {}\n\n        select_keys = [\"input_ids\", \"responses\", \"response_mask\", \"attention_mask\", \"position_ids\", \"values\", \"returns\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n        non_tensor_select_keys = [\"multi_modal_inputs\"] if has_multi_modal_inputs else []\n\n        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        mini_batches = data.split(self.config.ppo_mini_batch_size)\n\n        for _ in range(self.config.ppo_epochs):\n            for batch_idx, mini_batch in enumerate(mini_batches):\n                if self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.critic_optimizer.zero_grad()\n\n                for micro_batch in micro_batches:\n                    micro_batch_metrics = {}\n                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n                    response_mask = model_inputs[\"response_mask\"]\n                    values = model_inputs[\"values\"]\n                    returns = model_inputs[\"returns\"]\n\n                    vpreds = self._forward_micro_batch(model_inputs)\n                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                        vpreds=vpreds,\n                        values=values,\n                        returns=returns,\n                        response_mask=response_mask,\n                        cliprange_value=self.config.cliprange_value,\n                        loss_agg_mode=self.config.loss_agg_mode,\n                    )\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)\n                    else:\n                        loss = vf_loss / self.gradient_accumulation\n\n                    loss.backward()\n\n                    micro_batch_metrics.update(\n                        {\n                            \"critic/vf_loss\": vf_loss.detach().item(),\n                            \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                            \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n                        }\n                    )\n\n                    append_to_dict(metrics, micro_batch_metrics)\n\n                grad_norm = self._optimizer_step()\n                mini_batch_metrics = {\"critic/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, mini_batch_metrics)\n        self.critic_optimizer.zero_grad()\n        return metrics\n"
  },
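In `update_critic` above, the loss comes from `core_algos.compute_value_loss`, with `cliprange_value` bounding how far new value predictions may move from the rollout-time ones. A minimal sketch of that clipped value loss, assuming token-mean aggregation (`value_loss_sketch` is a hypothetical stand-in, not verl's implementation):

```python
import torch


def value_loss_sketch(vpreds, values, returns, response_mask, cliprange_value=0.5):
    """Hypothetical sketch of PPO's clipped value loss over response tokens."""
    mask = response_mask.float()
    denom = mask.sum().clamp(min=1.0)
    # clip the new prediction to a band around the old (rollout-time) value
    vpreds_clipped = values + (vpreds - values).clamp(-cliprange_value, cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpreds_clipped - returns) ** 2
    vf_loss_mat = torch.maximum(vf_losses1, vf_losses2)
    vf_loss = 0.5 * (vf_loss_mat * mask).sum() / denom
    # fraction of tokens where the clipped branch dominates (critic/vf_clipfrac)
    vf_clipfrac = ((vf_losses2 > vf_losses1).float() * mask).sum() / denom
    return vf_loss, vf_clipfrac
```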
  {
    "path": "verl_rl/verl/workers/critic/megatron_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Iterable\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.optimizer import DistributedOptimizer, OptimizerConfig\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.trainer.ppo import core_algos\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor, masked_mean\nfrom verl.workers.critic import BasePPOCritic\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass MegatronPPOCritic(BasePPOCritic):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        critic_module: nn.ModuleList,\n        critic_optimizer: DistributedOptimizer,\n        critic_optimizer_config: OptimizerConfig,\n    ):\n        super().__init__(config=config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config  # huggingface config\n        self.tf_config = tf_config  # mcore transformer config\n\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.critic_optimizer_config = critic_optimizer_config\n\n        # we create a separate nametuple for optimizer step so that global args won't affect it.\n        self.optimizer_step_args = OmegaConf.create(\n            {\n                \"skip_grad\": None,\n                \"overlap_dp_param_comm\": False,\n                \"overlap_dp_grad_comm\": False,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_parallel\": self.tf_config.sequence_parallel,\n                \"DDP_impl\": \"local\",\n                \"layernorm_allreduce_bucket_threshold\": 0,\n                \"pipeline_model_parallel_split_rank\": None,\n                \"reduce_grads_use_alltoall\": False,\n            }\n        )\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.get(\"ulysses_sequence_parallel_size\", 1) == 1\n        if config.shuffle:\n            assert config.data_loader_seed is not None, \"If shuffle dataloader, seed must be manually set\"\n        if config.megatron.tensor_model_parallel_size == 1:\n            print(\"[Warining] Because critic tp size == 1, 
set sp to False\")\n            config.megatron.sequence_parallel = False\n        self.config = config\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def compute_values(self, data: DataProto) -> DataProto:\n        data.to(get_device_id())\n        responses = data.batch[\"responses\"]\n        attention_mask = data.batch[\"attention_mask\"]\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n        response_length = responses.size(1)\n        with torch.no_grad():\n            output = self.forward_backward_batch(\n                data=data,\n                forward_only=True,\n                use_dynamic_bsz=use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=None,\n            )\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                # only on last rank. It should be on every tp rank\n                values = [o[\"vpreds\"] for o in output[\"output\"]]  # (bs, seq_size, vocal_size)\n                values = torch.cat(values, dim=0).to(torch.float32)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == values.size(0), f\"{len(indices)} vs. 
{values.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    values = values[revert_indices]\n            else:\n                values = torch.empty_like(attention_mask, dtype=torch.float32)\n\n            # each tp ranks should contain the same value\n            values = values[\n                :, -response_length - 1 : -1\n            ]  # Values are predicted at the ends of prefixes, e.g., the last prompt token\n            response_mask = attention_mask[:, -response_length:]\n            values = values * response_mask  # Only action tokens have values\n            values = values.contiguous()\n\n            # sync among pp ranks\n            torch.distributed.broadcast(\n                tensor=values,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n            )\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return values\n\n    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"values\", \"returns\"]\n        data = data.select(batch_keys=select_keys)\n        return data.make_iterator(\n            mini_batch_size=self.config.ppo_mini_batch_size,\n            epochs=self.config.ppo_epochs,\n            seed=self.config.data_loader_seed,\n            dataloader_kwargs={\"shuffle\": self.config.shuffle},\n        )\n\n    def forward_backward_batch(\n        self,\n        data: DataProto,\n        forward_only=False,\n        use_dynamic_bsz=False,\n        micro_batch_size=None,\n        max_token_len=None,\n        mini_batch_size=None,\n    ):\n        # broadcast from last pp rank to all other pp ranks\n        mini_batch = data\n        mini_batch.to(get_device_id())\n        mini_batch.batch = mini_batch.batch.contiguous()\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n        # split into micro-batches\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n          
      \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data, meta_info):\n            nonlocal use_dynamic_bsz\n\n            if forward_only:\n                return torch.tensor(1.0, device=output.device), {\"vpreds\": output}\n\n            responses = data[\"responses\"]\n            attention_mask = data[\"attention_mask\"]\n            values = data[\"values\"]\n            returns = data[\"returns\"]\n            response_length = responses.size(1)\n\n            response_mask = attention_mask[:, -response_length:]\n\n            cliprange_value = self.config.cliprange_value\n\n            vpreds = output  # (bs, sequence_length)\n            vpreds = vpreds[:, -response_length - 1 : -1]\n\n            vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                vpreds=vpreds,\n                values=values,\n                returns=returns,\n                response_mask=response_mask,\n                cliprange_value=cliprange_value,\n                loss_agg_mode=self.config.loss_agg_mode,\n            )\n\n            stats = {\n                \"critic/vf_loss\": vf_loss.detach().item(),\n                \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n            }\n\n            return vf_loss, stats\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from verl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                sequence_parallel=self.tf_config.sequence_parallel,\n                value_model=True,\n            )\n\n            return output, partial(loss_func, data=batch, meta_info={})\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=forward_only,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                
seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=forward_only,\n            )\n        # loss_reduces contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def update_critic(self, dataloader: Iterable[DataProto]):\n        metrics = {}\n\n        for data in dataloader:\n            # data = data.batch.to(self.critic_module.device)\n            self.critic_optimizer.zero_grad()\n            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n            for chunk in self.critic_module:\n                chunk.zero_grad_buffer()\n\n            micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n            max_token_len = None\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n            metric_micro_batch = self.forward_backward_batch(\n                data,\n                forward_only=False,\n                use_dynamic_bsz=self.config.use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n                mini_batch_size=self.config.ppo_mini_batch_size,\n            )\n            metric_micro_batch = metric_micro_batch[\"output\"]\n            update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step()\n            learning_rate = self.critic_optimizer.param_groups[-1][\"lr\"]\n            data = {\"critic/grad_norm\": grad_norm, \"critic/lr\": learning_rate}\n            append_to_dict(metrics, data)\n\n            if update_successful:\n                # allgather already execute in optimizer.step in new megatron\n                pass\n            else:\n                raise NotImplementedError\n\n            for metric in metric_micro_batch:\n                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n        return metrics\n"
  },
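When `use_dynamic_bsz` is set, `compute_values` above must undo the sequence-length-balanced reordering performed by `rearrange_micro_batches`. A minimal sketch of that restore step, assuming `indices` is the list of per-micro-batch original sample indices; `restore_order_sketch` is a hypothetical helper, the file itself relies on verl's `get_reverse_idx`:

```python
import itertools

import torch


def restore_order_sketch(values: torch.Tensor, indices: list[list[int]]) -> torch.Tensor:
    """Hypothetical sketch: undo the dynamic-batch permutation along dim 0."""
    # flatten the per-micro-batch index lists into one permutation;
    # flat[pos] is the original index of the sample now sitting at row pos
    flat = list(itertools.chain.from_iterable(indices))
    assert len(flat) == values.size(0)
    # invert the permutation: output row j should come from the row where sample j landed
    revert = [0] * len(flat)
    for pos, original_idx in enumerate(flat):
        revert[original_idx] = pos
    return values[torch.tensor(revert, dtype=torch.long)]
```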
  {
    "path": "verl_rl/verl/workers/engine/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .base import BaseEngine, EngineRegistry\nfrom .fsdp import FSDPEngine\n\n__all__ = [\"BaseEngine\", \"EngineRegistry\", \"FSDPEngine\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/engine/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe abstract base class defining the interface for model training engines.\n\"\"\"\n\nfrom typing import Callable\n\nimport torch\n\nfrom verl import DataProto\n\n\nclass BaseEngine:\n    \"\"\"\n    Abstract base class defining the interface for model training engines.\n\n    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.\n    \"\"\"\n\n    def __init__(self, config):\n        \"\"\"\n        Initialize the BaseEngine.\n\n        Args:\n            config: Configuration object containing parameters for engine setup.\n        \"\"\"\n        raise NotImplementedError\n\n    def init_model(self):\n        \"\"\"\n        Instantiate or load the model, optimizer, and learning rate scheduler.\n\n        Should prepare all components necessary for training or evaluation.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into training mode.\n\n        Usage:\n            with engine.train_mode():\n                # runs in training mode\n        \"\"\"\n        raise NotImplementedError\n\n    def eval_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into evaluation mode.\n\n        Usage:\n            with engine.eval_mode():\n                # runs in evaluation mode\n        \"\"\"\n        raise NotImplementedError\n\n    def infer_batch(\n        self,\n        data: DataProto,\n        post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform inference on a mini batch of data.\n\n        Args:\n            data: The input data for inference, typically containing tensors and metadata.\n            post_fn: A post-processing function that takes a micro-batch and predictions as input,\n                     and returns a tuple containing processed predictions and a dictionary of outputs.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_batch(\n        self,\n        data: DataProto,\n        loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform a training step on a mini-batch of data.\n\n        Args:\n            data (DataProto): The input data for training, typically containing tensors and metadata.\n            loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.\n        \"\"\"\n        raise NotImplementedError\n\n    def 
optimizer_zero_grad(self):\n        \"\"\"\n        Zero out gradients of all parameters before starting a new backward pass.\n        \"\"\"\n        raise NotImplementedError\n\n    def optimizer_step(self):\n        \"\"\"\n        Perform an optimization step to update model parameters based on accumulated gradients.\n\n        Returns:\n            grad_norm (float): The norm of the gradients before clipping or update.\n        \"\"\"\n        raise NotImplementedError\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance the learning rate scheduler by one step.\n\n        Returns:\n            current_lr (float or list[float]): Updated learning rate(s).\n        \"\"\"\n        raise NotImplementedError\n\n    def shard_data(self, data):\n        \"\"\"\n        Shard or partition data for distributed training or parallel execution.\n\n        Args:\n            data: Data structure to be sharded across devices/workers.\n\n        Returns:\n            Sharded data in the same format as input.\n        \"\"\"\n        raise NotImplementedError\n\n    def unshard_data(self, data):\n        \"\"\"\n        Reconstruct or gather sharded data back to a unified format.\n\n        Args:\n            data: Sharded data structure to reconstruct.\n\n        Returns:\n            Unsharded, combined data.\n        \"\"\"\n        raise NotImplementedError\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move model parameters, optimizer states, or both to the specified device.\n\n        Args:\n            device: Target device identifier.\n            model: If True, move the model.\n            optimizer: If True, move the optimizer states.\n        \"\"\"\n        raise NotImplementedError\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save model, optimizer, and scheduler states to a checkpoint.\n\n        Args:\n            local_path: Local filesystem path to save checkpoint.\n            hdfs_path: Optional HDFS path to copy checkpoint.\n            global_step: Integer training step number for naming.\n            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.\n        \"\"\"\n        raise NotImplementedError\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        \"\"\"\n        Load model, optimizer, and scheduler states from a checkpoint.\n\n        Args:\n            local_path: Local filesystem path of the checkpoint.\n            hdfs_path: Optional HDFS path where checkpoint is stored.\n            del_local_after_load: Whether to delete local copy after loading.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass EngineRegistry:\n    \"\"\"\n    A registry for managing and instantiating different types of training engines.\n\n    This class uses a dictionary to store engine classes, mapping a string key to each class.\n    It provides a decorator `register` to add new engines to the registry and a `new` method\n    to create an instance of a registered engine.\n    \"\"\"\n\n    _engines = {}\n\n    @classmethod\n    def register(cls, key):\n        \"\"\"\n        A class method decorator that registers an engine class with a given key.\n\n        This allows for dynamic instantiation of engine classes by their registered key.\n\n        Args:\n            key (str): The identifier to associate with the engine class.\n\n        Returns:\n            A decorator 
function that takes an engine class and registers it.\n        \"\"\"\n\n        def decorator(engine_class):\n            assert issubclass(engine_class, BaseEngine)\n            cls._engines[key] = engine_class\n            return engine_class\n\n        return decorator\n\n    @classmethod\n    def new(cls, key, *args, **kwargs):\n        \"\"\"\n        Create a new training engine instance for the given registry key.\n        Args:\n            key (str): The identifier of a registered engine class.\n            *args: Positional arguments forwarded to the engine constructor.\n            **kwargs: Keyword arguments forwarded to the engine constructor.\n        Returns:\n            engine: An instance of the engine class registered under ``key``.\n        Raises:\n            NotImplementedError: If ``key`` does not match any registered engine.\n        \"\"\"\n        if key in cls._engines:\n            return cls._engines[key](*args, **kwargs)\n        else:\n            raise NotImplementedError(f\"Unknown engine: {key}\")\n"
  },
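`EngineRegistry` above is a plain decorator-based registry. A short usage sketch follows: register a dummy engine under a new key and instantiate it by name (`"toy"` and `ToyEngine` are hypothetical; the repo itself registers `FSDPEngine` under `"fsdp"`):

```python
from verl.workers.engine import BaseEngine, EngineRegistry


@EngineRegistry.register("toy")
class ToyEngine(BaseEngine):
    """Hypothetical engine used only to illustrate the registry contract."""

    def __init__(self, config):
        # override BaseEngine.__init__, which deliberately raises NotImplementedError
        self.config = config


engine = EngineRegistry.new("toy", config={"lr": 1e-5})
assert isinstance(engine, ToyEngine)
```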
  {
    "path": "verl_rl/verl/workers/engine/fsdp/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .engine_impl import FSDPEngine\n\n__all__ = [\"FSDPEngine\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/engine/fsdp/engine_impl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP)\n\"\"\"\n\nimport gc\nimport itertools\nimport logging\nimport os\nimport warnings\nfrom typing import Callable\n\nimport torch\nimport torch.distributed\nfrom omegaconf import OmegaConf\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom verl import DataProto\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.utils import hf_processor, hf_tokenizer\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.debug import log_gpu_memory_usage\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_torch_device,\n    is_cuda_available,\n    is_npu_available,\n)\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    FSDPModule,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    fsdp2_clip_grad_norm_,\n    fsdp2_load_full_state_dict,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.py_functional import append_to_dict, convert_to_regular_types\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\nfrom ..base import BaseEngine, EngineRegistry\nfrom .utils import create_device_mesh, get_sharding_strategy\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndevice_name = get_device_name()\n\n\n@EngineRegistry.register(\"fsdp\")\nclass FSDPEngine(BaseEngine):\n    \"\"\"\n    Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).\n\n    Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.\n    \"\"\"\n\n    def __init__(self, config):\n        \"\"\"\n        Initialize the FSDPEngine.\n\n        Sets up distributed device meshes, LoRA, and offload policies based on config.\n\n        Args:\n            config: Configuration object with FSDP and model settings.\n        \"\"\"\n        self.config = config\n        self.rank = torch.distributed.get_rank()\n        # 
build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n        self.use_remove_padding = config.model.get(\"use_remove_padding\", False)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n        if self.config.ppo_micro_batch_size is not None:\n            self.config.ppo_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.forward_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size\n\n        if self.config.ppo_micro_batch_size_per_gpu is not None:\n            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n        self._is_lora = self.config.model.get(\"lora_rank\", 0) > 0\n\n    def init_model(self):\n        \"\"\"\n        Build the model, optimizer, and learning rate scheduler under FSDP.\n\n        Applies device, dtype, and precision configurations, including mixed precision.\n        Sets up checkpoint manager and FLOPs counter.\n        \"\"\"\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        self.module, self.optimizer, self.lr_scheduler = self._build_model_optimizer(self.config)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n            log_gpu_memory_usage(\"After offload model during init\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.optimizer)\n            
            log_gpu_memory_usage(\"After offload optimizer during init\", logger=logger)\n\n        self.flops_counter = FlopsCounter(self.model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.module,\n            optimizer=self.optimizer,\n            lr_scheduler=self.lr_scheduler,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            checkpoint_contents=self.config.checkpoint,\n        )\n\n    def _build_model_optimizer(self, config):\n        # local imports: optim and MixedPrecision are only needed when building the model\n        from torch import optim\n        from torch.distributed.fsdp import MixedPrecision\n\n        from verl.utils.model import load_valuehead_model, print_model_size\n        from verl.utils.torch_dtypes import PrecisionType\n\n        use_shm = config.model.get(\"use_shm\", False)\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n        # Note: the actor and critic tokenizers may differ, so the tokenizer info is loaded\n        # from config.model.tokenizer_path rather than assumed to match the actor's.\n\n        tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)\n        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n        self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        override_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        if self.rank == 0:\n            print(f\"Engine overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig\n\n        model_config = AutoConfig.from_pretrained(\n            local_path,\n            attn_implementation=\"flash_attention_2\",\n            trust_remote_code=config.model.get(\"trust_remote_code\", False),\n        )\n        model_config.num_labels = 1\n        # patch for kimi-vl\n        if getattr(model_config, \"model_type\", None) == \"kimi_vl\":\n            model_config.text_config.topk_method = \"greedy\"\n\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            model_config.hidden_dropout = \"0\"\n            model_config.summary_dropout_prob = 0.0\n\n            module = load_valuehead_model(\n                local_path,\n                torch_dtype,\n                model_config,\n                config.model.get(\"trust_remote_code\", False),\n            )\n\n            apply_monkey_patch(\n                model=module,\n                use_remove_padding=self.use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            # some parameters may not be in torch_dtype\n            module.to(torch_dtype)\n\n            if config.model.get(\"enable_gradient_checkpointing\", False):\n                module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        if self._is_lora:\n            print(\"Applying LoRA to the module\")\n            module.enable_input_require_grads()\n            # Convert config to regular Python types before creating PEFT model\n            lora_config = {\n                \"task_type\": TaskType.CAUSAL_LM,\n                \"r\": self.config.model.lora_rank,\n                \"lora_alpha\": self.config.model.lora_alpha,\n                \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                \"bias\": \"none\",\n            }\n            module = get_peft_model(module, LoraConfig(**lora_config))\n\n        if self.rank == 0:\n            print_model_size(module)\n\n        self.model_config = model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=module,\n            config=self.config.model.fsdp_config.wrap_policy,\n            is_lora=self.config.model.get(\"lora_rank\", 0) > 0,\n        )\n\n        log_gpu_memory_usage(\"Before FSDP\", logger=None)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # Note: we forcibly turn off CPUOffload because it causes incorrect results when using grad accumulation\n        if config.strategy == \"fsdp\":\n            module = FSDP(\n                module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,\n                device_mesh=self.device_mesh,\n                cpu_offload=None,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            offload_policy = None\n            if fsdp_config.offload_policy:\n                self._is_offload_param = False\n
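                # With FSDP2's CPUOffloadPolicy active, parameter and optimizer placement is\n                # handled by FSDP itself, so the manual offload flags are cleared to avoid\n                # moving the module and optimizer twice.\n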
                self._is_offload_optimizer = False\n                offload_policy = CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": offload_policy,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = module.state_dict()\n            apply_fsdp2(module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)\n        else:\n            raise NotImplementedError(f\"Unknown strategy {config.strategy}\")\n\n        if config.model.get(\"enable_activation_offload\", False):\n            enable_gradient_checkpointing = config.model.get(\"enable_gradient_checkpointing\", False)\n            enable_activation_offloading(module, config.strategy, enable_gradient_checkpointing)\n\n        log_gpu_memory_usage(\"After FSDP\", logger=None)\n\n        optimizer = optim.AdamW(\n            module.parameters(),\n            lr=config.optim.lr,\n            betas=config.optim.get(\"betas\", (0.9, 0.999)),\n            weight_decay=config.optim.get(\"weight_decay\", 1e-2),\n        )\n\n        total_steps = config.optim.get(\"total_training_steps\", 0)\n        num_warmup_steps = int(config.optim.get(\"lr_warmup_steps\", -1))\n        warmup_style = config.optim.get(\"warmup_style\", \"constant\")\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.optim.get(\"lr_warmup_steps_ratio\", 0.0)\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        if self.rank == 0:\n            print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n        if warmup_style == \"constant\":\n            lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)\n        elif warmup_style == \"cosine\":\n            lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps\n            )\n        else:\n            raise NotImplementedError(f\"Warmup style {warmup_style} is not supported\")\n\n        return module, optimizer, lr_scheduler\n\n    def train_mode(self):\n        \"\"\"\n        Return a context manager that switches to training mode with FSDP-specific handling.\n\n        Includes parameter and optimizer offload entry/exit.\n        \"\"\"\n        return EngineTrainModeCtx(self)\n\n    def eval_mode(self):\n        \"\"\"\n        Return a context manager that switches to evaluation mode with FSDP-specific handling.\n\n        Includes activation offload entry/exit.\n        \"\"\"\n        return EngineEvalModeCtx(self)\n\n    def shard_data(self, data):\n        \"\"\"\n        Preprocess data into sharded format via UlyssesShardingManager.\n        \"\"\"\n        return self.ulysses_sharding_manager.preprocess_data(data)\n\n    def unshard_data(self, data):\n        \"\"\"\n        Postprocess data from sharded format back to full format.\n        \"\"\"\n        return self.ulysses_sharding_manager.postprocess_data(data)\n\n    def get_default_ctx(self):\n        use_value_head_model = hasattr(self.module, \"v_head\")\n        ctx = {\n            \"use_value_head_model\": use_value_head_model,\n            \"ulysses_sequence_parallel_size\": self.ulysses_sequence_parallel_size,\n        }\n        return ctx\n\n    def _forward_micro_batch(self, micro_batch):\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            for key in micro_batch[\"multi_modal_inputs\"][0].keys():\n                multi_modal_inputs[key] = torch.cat(\n                    [inputs[key] for inputs in micro_batch[\"multi_modal_inputs\"]], dim=0\n                )\n\n        with torch.autocast(device_type=device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to keep rotary embeddings aligned\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                preds = self.module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent the model from thinking we are generating\n\n                if hasattr(self.module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    preds_rmpad = preds[2].squeeze(0).unsqueeze(-1)\n                else:\n                    preds_rmpad = preds.logits\n                    preds_rmpad = preds_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    preds_rmpad = gather_outpus_and_unpad(preds_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n                # pad it back\n                preds = pad_input(preds_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)\n            else:\n                preds = self.module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n
                    use_cache=False,\n                )  # prevent the model from thinking we are generating\n                if hasattr(self.module, \"v_head\"):\n                    # For trl.AutoModelForCausalLMWithValueHead\n                    preds = preds[2]\n                else:\n                    preds = preds.logits\n\n            return preds\n\n    def infer_batch(\n        self,\n        data: DataProto,\n        post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform inference on a mini batch of data.\n\n        Args:\n            data: The input data for inference, typically containing tensors and metadata.\n            post_fn: A post-processing function that takes a micro-batch and predictions as input,\n                     and returns a tuple containing processed predictions and a dictionary of outputs.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.\n        \"\"\"\n        assert self.mode == \"eval\"\n        micro_batch_size = data.meta_info[\"micro_batch_size\"]\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        batch = data.select(batch_keys=select_keys).batch\n        use_dynamic_bsz = data.meta_info[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n        if has_multi_modal_inputs:\n            num_micro_batches = data.batch.batch_size[0] // micro_batch_size\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n        elif use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data.meta_info[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        preds_list = {}\n        for micro_batch in micro_batches:\n            if isinstance(micro_batch, DataProto):\n                micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}\n\n            with torch.no_grad():\n                # post_fn converts the raw preds into a dict[str, torch.Tensor] of outputs\n                preds = self._forward_micro_batch(micro_batch)\n                _, outputs = post_fn(micro_batch, preds)\n                assert isinstance(outputs, dict)\n\n            # append micro batch preds to dict[str, List[torch.Tensor]]\n            append_to_dict(preds_list, outputs)\n\n        # reorganize mini batch preds from\n        # dict[str, List[torch.Tensor]] to dict[str, torch.Tensor]\n        mini_batch_preds = {}\n        if use_dynamic_bsz:\n            # flatten the micro-batch permutation once; it is shared by every output key\n            indices = list(itertools.chain.from_iterable(indices))\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n        for key, t_list in preds_list.items():\n            t_concat = torch.concat(t_list, dim=0)\n\n            if use_dynamic_bsz:\n                assert len(indices) == t_concat.size(0), f\"{len(indices)} vs. {t_concat.size()}\"\n                t_concat = t_concat[revert_indices]\n\n            mini_batch_preds[key] = t_concat\n\n        return mini_batch_preds\n\n    def train_batch(\n        self,\n        data: DataProto,\n        loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform a training step on a mini-batch of data.\n\n        Args:\n            data (DataProto): The input data for training, typically containing tensors and metadata.\n            loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.\n        \"\"\"\n        assert self.mode == \"train\"\n        # split batch into micro_batches\n        mini_batch = data\n        select_keys = [\"input_ids\", \"responses\", \"response_mask\", \"attention_mask\", \"position_ids\"]\n        if \"multi_modal_inputs\" in mini_batch:\n            non_tensor_select_keys = [\"multi_modal_inputs\"]\n            num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu\n            micro_batches = mini_batch.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)\n        elif self.config.use_dynamic_bsz:\n            max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n            micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n        else:\n            micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n        mini_batch_metrics = {}\n        for micro_batch in micro_batches:\n            # Support all devices\n            if isinstance(micro_batch, DataProto):\n                micro_batch = {**micro_batch.batch.to(get_device_id()), **micro_batch.non_tensor_batch}\n            else:\n                micro_batch = micro_batch.to(get_device_id())  # critic device is cpu when using offload\n\n            preds = self._forward_micro_batch(micro_batch)\n            loss, micro_batch_metrics = loss_fn(micro_batch, preds)\n            append_to_dict(mini_batch_metrics, micro_batch_metrics)\n            loss.backward()\n\n        return mini_batch_metrics\n\n    def optimizer_zero_grad(self):\n        \"\"\"\n        Zero out gradients of all parameters before starting a new backward pass.\n        \"\"\"\n        self.optimizer.zero_grad()\n\n    def optimizer_step(self):\n        \"\"\"\n        Clip gradients, skip update if non-finite, and step optimizer.\n\n        Returns:\n            grad_norm (float): Norm of gradients before clipping.\n        \"\"\"\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.module, FSDP):\n            grad_norm = self.module.clip_grad_norm_(self.config.grad_clip)\n        elif isinstance(self.module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)\n\n        # if grad_norm is not finite, skip the update\n
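        # (The skipped step leaves parameters untouched; gradients are zeroed so the next\n        # mini-batch starts clean, and grad_norm is still returned for logging.)\n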
        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.optimizer.zero_grad()\n        else:\n            self.optimizer.step()\n        return grad_norm\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance FSDP scheduler and return updated learning rate.\n        \"\"\"\n        self.lr_scheduler.step()\n        lr = self.lr_scheduler.get_last_lr()\n        return lr\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move FSDP model and/or optimizer to CPU or GPU with offload support.\n        \"\"\"\n        assert device in (\"cuda\", \"cpu\")\n        if device == \"cuda\":\n            if not self.config.model.fsdp_config.param_offload:\n                if model:\n                    load_fsdp_model_to_gpu(self.module)\n                if optimizer and self.optimizer is not None:\n                    load_fsdp_optimizer(self.optimizer, device)\n            gc.collect()\n        elif device == \"cpu\":\n            if not self.config.model.fsdp_config.param_offload:\n                if model:\n                    offload_fsdp_model_to_cpu(self.module)\n                if optimizer and self.optimizer is not None:\n                    offload_fsdp_optimizer(self.optimizer)\n        else:\n            raise ValueError(f\"Invalid device type: {device}\")\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save FSDP checkpoint, handling parameter offload as needed.\n        \"\"\"\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        \"\"\"\n        Load FSDP checkpoint, restoring parameters and optimizer state.\n        \"\"\"\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.optimizer)\n\n\nclass EngineEvalModeCtx:\n    def __init__(self, engine):\n        self.engine = engine\n\n    def __enter__(self):\n        self.engine.mode = \"eval\"\n        if self.engine._is_offload_param:\n            load_fsdp_model_to_gpu(self.engine.module)\n\n        self.engine.ulysses_sharding_manager.__enter__()\n        self.engine.module.eval()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)\n        if self.engine._is_offload_param:\n            offload_fsdp_model_to_cpu(self.engine.module)\n        self.engine.mode = None\n\n\nclass EngineTrainModeCtx:\n    def __init__(self, engine):\n        self.engine = engine\n\n    def __enter__(self):\n        self.engine.mode = \"train\"\n        if self.engine._is_offload_param:\n            load_fsdp_model_to_gpu(self.engine.module)\n        if self.engine._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.engine.optimizer, device_id=get_torch_device().current_device())\n\n        self.engine.ulysses_sharding_manager.__enter__()\n        self.engine.module.train()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)\n\n        if self.engine._is_offload_param:\n            offload_fsdp_model_to_cpu(self.engine.module)\n        if self.engine._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.engine.optimizer)\n        self.engine.mode = None
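\n\n\n# Rough usage sketch of the engine lifecycle (illustrative only; config, data, and\n# loss_fn are supplied by the caller):\n#\n#     engine = FSDPEngine(config)\n#     engine.init_model()\n#     with engine.train_mode():\n#         data = engine.shard_data(data)\n#         engine.optimizer_zero_grad()\n#         metrics = engine.train_batch(data, loss_fn)\n#         grad_norm = engine.optimizer_step()\n#         lr = engine.lr_scheduler_step()\n"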
  },
{
    "path": "verl_rl/verl/workers/engine/fsdp/utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom torch.distributed.device_mesh import init_device_mesh\n\nfrom verl.utils.device import get_device_name\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    \"\"\"\n    Create a device mesh for distributed training based on the world size and FSDP size.\n\n    Args:\n        world_size (int): Total number of processes in the distributed training setup.\n        fsdp_size (int): Size of the Fully Sharded Data Parallel (FSDP) group.\n\n    Returns:\n        torch.distributed.device_mesh.DeviceMesh: The initialized device mesh.\n    \"\"\"\n    device_name = get_device_name()\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    \"\"\"\n    Determine the appropriate sharding strategy based on the number of dimensions of the device mesh.\n\n    Args:\n        device_mesh (torch.distributed.device_mesh.DeviceMesh): The device mesh used for distributed training.\n\n    Returns:\n        torch.distributed.fsdp.ShardingStrategy: The sharding strategy to be used with FSDP.\n\n    Raises:\n        NotImplementedError: If the number of dimensions of the device mesh is neither 1 nor 2.\n    \"\"\"\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Got device mesh ndim={device_mesh.ndim}, but only 1 or 2 are supported\")\n    return sharding_strategy
  },
  {
    "path": "verl_rl/verl/workers/engine/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/workers/engine/megatron/engine_impl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Callable\n\nimport torch\n\nfrom verl import DataProto\n\nfrom ..base import BaseEngine, EngineRegistry\n\n\n@EngineRegistry.register(\"megatron\")\nclass MegatronEngine(BaseEngine):\n    def __init__(self, config):\n        raise NotImplementedError\n\n    def init_model(self):\n        raise NotImplementedError\n\n    def train_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into training mode.\n\n        Usage:\n            with engine.train_mode():\n                # runs in training mode\n        \"\"\"\n        raise NotImplementedError\n\n    def eval_mode(self):\n        \"\"\"\n        Context manager entry for switching the engine and model into evaluation mode.\n\n        Usage:\n            with engine.eval_mode():\n                # runs in evaluation mode\n        \"\"\"\n        raise NotImplementedError\n\n    def infer_batch(\n        self,\n        data: DataProto,\n        post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform inference on a mini batch of data.\n\n        Args:\n            data: The input data for inference, typically containing tensors and metadata.\n            post_fn: A post-processing function that takes a micro-batch and predictions as input,\n                     and returns a tuple containing processed predictions and a dictionary of outputs.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_batch(\n        self,\n        data: DataProto,\n        loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],\n    ) -> dict[str, torch.Tensor]:\n        \"\"\"\n        Perform a training step on a mini-batch of data.\n\n        Args:\n            data (DataProto): The input data for training, typically containing tensors and metadata.\n            loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.\n\n        Returns:\n            dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.\n        \"\"\"\n        raise NotImplementedError\n\n    def optimizer_zero_grad(self):\n        \"\"\"\n        Zero out gradients of all parameters before starting a new backward pass.\n        \"\"\"\n        raise NotImplementedError\n\n    def optimizer_step(self):\n        \"\"\"\n        Perform an optimization step to update model parameters based on accumulated gradients.\n\n        Returns:\n            grad_norm (float): The norm of the gradients before clipping or update.\n        \"\"\"\n        raise NotImplementedError\n\n    def lr_scheduler_step(self):\n        \"\"\"\n        Advance the learning rate 
scheduler by one step.\n\n        Returns:\n            current_lr (float or list[float]): Updated learning rate(s).\n        \"\"\"\n        raise NotImplementedError\n\n    def shard_data(self, data):\n        \"\"\"\n        Shard or partition data for distributed training or parallel execution.\n\n        Args:\n            data: Data structure to be sharded across devices/workers.\n\n        Returns:\n            Sharded data in the same format as input.\n        \"\"\"\n        raise NotImplementedError\n\n    def unshard_data(self, data):\n        \"\"\"\n        Reconstruct or gather sharded data back to a unified format.\n\n        Args:\n            data: Sharded data structure to reconstruct.\n\n        Returns:\n            Unsharded, combined data.\n        \"\"\"\n        raise NotImplementedError\n\n    def to(self, device: str, model: bool = True, optimizer: bool = True):\n        \"\"\"\n        Move model parameters, optimizer states, or both to the specified device.\n\n        Args:\n            device: Target device identifier.\n            model: If True, move the model.\n            optimizer: If True, move the optimizer states.\n        \"\"\"\n        raise NotImplementedError\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save model, optimizer, and scheduler states to a checkpoint.\n\n        Args:\n            local_path: Local filesystem path to save checkpoint.\n            hdfs_path: Optional HDFS path to copy checkpoint.\n            global_step: Integer training step number for naming.\n            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.\n        \"\"\"\n        raise NotImplementedError\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        \"\"\"\n        Load model, optimizer, and scheduler states from a checkpoint.\n\n        Args:\n            local_path: Local filesystem path of the checkpoint.\n            hdfs_path: Optional HDFS path where checkpoint is stored.\n            del_local_after_load: Whether to delete local copy after loading.\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/verl/workers/fsdp_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport json\nimport logging\nimport os\nimport warnings\nfrom dataclasses import asdict\nfrom typing import Any\n\nimport psutil\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nfrom codetiming import Timer\nfrom omegaconf import DictConfig, OmegaConf, open_dict\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom safetensors.torch import save_file\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nimport verl.utils.torch_functional as verl_F\nfrom verl import DataProto\nfrom verl.models.transformers.monkey_patch import apply_monkey_patch\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.utils import hf_processor, hf_tokenizer\nfrom verl.utils.activation_offload import enable_activation_offloading\nfrom verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.device import (\n    get_device_id,\n    get_device_name,\n    get_nccl_backend,\n    get_torch_device,\n    is_cuda_available,\n    is_npu_available,\n)\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    fsdp2_load_full_state_dict,\n    fsdp_version,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    layered_summon_lora_params,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom verl.utils.import_utils import import_external_libs\nfrom verl.utils.model import compute_position_id_with_mask\nfrom verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer\nfrom verl.utils.profiler.performance import reduce_timing\nfrom verl.utils.py_functional import convert_to_regular_types\nfrom verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\ndevice_name = get_device_name()\n\n\ndef create_device_mesh(world_size, fsdp_size):\n    if fsdp_size < 0 or fsdp_size >= world_size:\n        device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=[\"fsdp\"])\n    else:\n        device_mesh = init_device_mesh(\n            device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=[\"ddp\", \"fsdp\"]\n        )\n    return device_mesh\n\n\ndef get_sharding_strategy(device_mesh):\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    
elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Got device mesh ndim={device_mesh.ndim}, but only 1 or 2 are supported\")\n    return sharding_strategy\n\n\nclass ActorRolloutRefWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: DictConfig, role: str, **kwargs):\n        Worker.__init__(self)\n\n        self.config = config\n        self.profile_option = kwargs.get(\"profile_option\", None)\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ.get(\"RANK\", 0))\n            world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n            torch.distributed.init_process_group(\n                backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\",\n                rank=rank,\n                world_size=world_size,\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n\n        # build device mesh for FSDP\n        world_size = torch.distributed.get_world_size()\n        # TODO(sgm): support FSDP hybrid shard for larger model\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)\n\n        # build device mesh for Ulysses Sequence Parallel\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.actor.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self._lora_rank = self.config.model.get(\"lora_rank\", 0)\n        self._is_lora = self._lora_rank > 0\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n\n        # TODO(haibin.lin):\n        # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig,\n        # it will actually convert the ProfilerConfig dataclass back to a DictConfig.\n        # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py)\n        # as they provide a DictConfig-like interface\n        # The benefit of creating the dataclass config is to perform validation during __post_init__\n        profiler_config = omega_conf_to_dataclass(config.get(\"profiler\"))\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=profiler_config, option=self.profile_option)\n        )\n\n        self._is_offload_param = False\n        self._is_offload_optimizer = False\n        if self._is_actor:\n            self._is_offload_param = self.config.actor.fsdp_config.get(\"param_offload\", False)\n            self._is_offload_optimizer =
 self.config.actor.fsdp_config.get(\"optimizer_offload\", False)\n        elif self._is_ref:\n            # TODO: it seems that manual offload is slower than FSDP offload\n            self._is_offload_param = self.config.ref.fsdp_config.get(\"param_offload\", False)\n\n        # normalize config\n        if self._is_actor:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            assert self.config.actor.ppo_mini_batch_size > 0, (\n                f\"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after \"\n                f\"normalization\"\n            )\n            # micro bsz\n            if self.config.actor.ppo_micro_batch_size is not None:\n                self.config.actor.ppo_micro_batch_size //= (\n                    self.device_mesh.size() // self.ulysses_sequence_parallel_size\n                )\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n\n            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:\n                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, (\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by \"\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n                )\n                assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, (\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than \"\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n                )\n\n        # normalize rollout config\n        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:\n            self.config.rollout.log_prob_micro_batch_size //= (\n                self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n        # normalize ref config\n        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:\n            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n\n    def _build_model_optimizer(\n        self,\n        model_path,\n        fsdp_config,\n        optim_config,\n        override_model_config,\n        use_remove_padding=False,\n        use_fused_kernels=False,\n        enable_gradient_checkpointing=False,\n        trust_remote_code=False,\n        use_liger=False,\n        role=\"actor\",\n        enable_activation_offload=False,\n    ):\n        from torch import optim\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n        from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq\n\n        from verl.utils.model import get_generation_config, print_model_size, update_model_config\n        from verl.utils.torch_dtypes import PrecisionType\n\n        assert role in [\"actor\", \"ref\"]\n\n        log_gpu_memory_usage(f\"Before init {role} from HF AutoModel\", logger=logger)\n
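        # model_path is expected to already be a local copy; the callers in init_model\n        # run copy_to_local(self.config.model.path) before invoking this helper.\n 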
       local_path = model_path\n\n        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect\n        # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)\n        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        torch_dtype = fsdp_config.get(\"model_dtype\", None)\n        if torch_dtype is None:\n            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16\n        else:\n            torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        # override model kwargs\n        actor_model_config = AutoConfig.from_pretrained(\n            local_path, trust_remote_code=trust_remote_code, attn_implementation=\"flash_attention_2\"\n        )\n\n        # patch for kimi-vl\n        if getattr(actor_model_config, \"model_type\", None) == \"kimi_vl\":\n            actor_model_config.text_config.topk_method = \"greedy\"\n\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config)\n        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)\n        if self.rank == 0:\n            print(f\"Model config after override: {actor_model_config}\")\n\n        # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():\n                actor_module_class = AutoModelForVision2Seq\n            else:\n                actor_module_class = AutoModelForCausalLM\n\n            actor_module = actor_module_class.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                torch_dtype=torch_dtype,\n                config=actor_model_config,\n                trust_remote_code=trust_remote_code,\n            )\n\n            # Apply Liger kernel to the model if use_liger is set to True\n            if use_liger:\n                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n\n                _apply_liger_kernel_to_instance(model=actor_module)\n\n            fused_kernel_options = self.config.model.get(\"fused_kernel_options\", None)\n            fused_kernels_backend = (\n                fused_kernel_options.get(\"impl_backend\", None) if fused_kernel_options is not None else None\n            )\n\n            apply_monkey_patch(\n                model=actor_module,\n                use_remove_padding=use_remove_padding,\n                
ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_fused_kernels=use_fused_kernels,\n                fused_kernels_backend=fused_kernels_backend,\n            )\n\n            # some parameters may not be in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2\n            actor_module.to(torch_dtype)\n\n            if enable_gradient_checkpointing:\n                actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n            if self._is_lora:\n                print(\"Applying LoRA to actor module\")\n                actor_module.enable_input_require_grads()\n                # Convert config to regular Python types before creating PEFT model\n                lora_config = {\n                    \"task_type\": TaskType.CAUSAL_LM,\n                    \"r\": self.config.model.lora_rank,\n                    \"lora_alpha\": self.config.model.lora_alpha,\n                    \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                    \"exclude_modules\": convert_to_regular_types(self.config.model.exclude_modules),\n                    \"bias\": \"none\",\n                }\n                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))\n        torch.distributed.barrier()\n\n        if self.rank == 0:\n            print_model_size(actor_module)\n\n        log_gpu_memory_usage(f\"After init {role} from HF AutoModel\", logger=logger)\n\n        # We wrap FSDP for rollout as well\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=actor_module,\n            config=fsdp_config.get(\"wrap_policy\", None),\n            is_lora=self.config.model.get(\"lora_rank\", 0) > 0,\n        )\n\n        if self._is_rollout and self.config.rollout.name == \"hf\":\n            # TODO(zhangchi.usc1992, shengguangming) fix me. Currently, auto_wrap_policy causes HFRollout to hang in Gemma\n            auto_wrap_policy = None\n\n        if self.rank == 0:\n            print(f\"wrap_policy: {auto_wrap_policy}\")\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # TODO: add transformer policy\n        # We force the reference policy to use CPUOffload to save memory.\n        # We forcibly turn off CPUOffload for the actor because it causes incorrect results when using grad accumulation\n        cpu_offload = None if role == \"actor\" else CPUOffload(offload_params=True)\n        fsdp_strategy = self.config.actor.strategy\n        if fsdp_strategy == \"fsdp\":\n            actor_module_fsdp = FSDP(\n                actor_module,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                use_orig_params=self.config.actor.fsdp_config.get(\"use_orig_params\", False),\n                forward_prefetch=self.config.actor.fsdp_config.get(\"forward_prefetch\", False),\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            if role == \"actor\" and fsdp_config.offload_policy:\n                cpu_offload = CPUOffloadPolicy(pin_memory=True)\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n            else:\n                cpu_offload = None if role == \"actor\" else CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = actor_module.state_dict()\n            apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)\n            actor_module_fsdp = actor_module\n        else:\n            raise NotImplementedError(f\"strategy {fsdp_strategy} is not implemented\")\n\n        if enable_activation_offload:\n            enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)\n\n        log_gpu_memory_usage(f\"After {role} FSDP init\", logger=logger)\n\n        # TODO: add more optimizer args into config\n        if role == \"actor\" and optim_config is not None:\n            from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n            actor_optimizer = optim.AdamW(\n                actor_module_fsdp.parameters(),\n                lr=optim_config.lr,\n                betas=optim_config.get(\"betas\", (0.9, 0.999)),\n                weight_decay=optim_config.get(\"weight_decay\", 1e-2),\n            )\n\n            total_steps = optim_config.get(\"total_training_steps\", 0)\n            num_warmup_steps = int(optim_config.get(\"lr_warmup_steps\", -1))\n            warmup_style = 
optim_config.get(\"warmup_style\", \"constant\")\n            min_lr_ratio = optim_config.get(\"min_lr_ratio\", 0.0)\n            num_cycles = optim_config.get(\"num_cycles\", 0.5)\n            if num_warmup_steps < 0:\n                num_warmup_steps_ratio = optim_config.get(\"lr_warmup_steps_ratio\", 0.0)\n                num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n            if self.rank == 0:\n                print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n            if warmup_style == \"constant\":\n                actor_lr_scheduler = get_constant_schedule_with_warmup(\n                    optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps\n                )\n            elif warmup_style == \"cosine\":\n                actor_lr_scheduler = get_cosine_schedule_with_warmup(\n                    optimizer=actor_optimizer,\n                    num_warmup_steps=num_warmup_steps,\n                    num_training_steps=total_steps,\n                    min_lr_ratio=min_lr_ratio,\n                    num_cycles=num_cycles,\n                )\n            else:\n                raise NotImplementedError(f\"Warmup style {warmup_style} is not supported\")\n\n            log_gpu_memory_usage(f\"After {role} optimizer init\", logger=logger)\n        else:\n            actor_optimizer = None\n            actor_lr_scheduler = None\n\n        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n\n        # TODO(sgm): support FSDP hybrid shard for larger model\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, (\n            f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n        )\n        rollout_device_mesh = init_device_mesh(\n            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n        )\n        rollout_name = self.config.rollout.name\n        if rollout_name == \"hf\":\n            from verl.workers.rollout import HFRollout\n            from verl.workers.sharding_manager.base import BaseShardingManager\n\n            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)\n            rollout_sharding_manager = BaseShardingManager()\n            # TODO: a sharding manager that does nothing?\n\n        elif rollout_name == \"vllm\":\n            from verl.workers.rollout.vllm_rollout import vLLMRollout\n            from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager\n\n            log_gpu_memory_usage(f\"Before building {rollout_name} rollout\", logger=logger)\n            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get(\"use_shm\", False))\n            lora_kwargs = (\n                {\"lora_kwargs\": {\"enable_lora\": True, \"max_loras\": 1, \"max_lora_rank\": self._lora_rank}}\n                if self._is_lora\n                else {}\n            )\n            # lora_kwargs = {}\n            from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout\n\n            vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == \"sync\" else vLLMAsyncRollout\n            rollout = vllm_rollout_cls(\n                model_path=local_path,\n                config=self.config.rollout,\n
tokenizer=self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                device_mesh=rollout_device_mesh,\n                trust_remote_code=trust_remote_code,\n                **lora_kwargs,\n            )\n\n            log_gpu_memory_usage(f\"After building {rollout_name} rollout\", logger=logger)\n            full_params = torch.distributed.get_world_size() == 1\n            rollout_sharding_manager = FSDPVLLMShardingManager(\n                module=self.actor_module_fsdp,\n                inference_engine=rollout.inference_engine,\n                model_config=self.actor_model_config,\n                rollout_config=self.config.rollout,\n                full_params=full_params,\n                device_mesh=rollout_device_mesh,\n                offload_param=self._is_offload_param,\n                load_format=self.config.rollout.load_format,\n                layered_summon=self.config.rollout.get(\"layered_summon\", False),\n            )\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        elif rollout_name == \"sglang\":\n            from verl.workers.rollout.sglang_rollout import SGLangRollout\n\n            # NOTE(linjunrong): Since SGLang added fp8 support, importing any symbol related to\n            # SGLang's model_runner checks the CUDA device capability. However, due to verl's setup,\n            # the main Ray process cannot see any CUDA device, which can lead to:\n            # \"RuntimeError: No CUDA GPUs are available\".\n            # For this reason, sharding_manager.__init__ must not import FSDPSGLangShardingManager;\n            # we import it here via its absolute path instead.\n            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76\n            from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager\n\n            local_path = copy_to_local(self.config.model.path)\n            log_gpu_memory_usage(f\"Before building {rollout_name} rollout\", logger=logger)\n            rollout = SGLangRollout(\n                actor_module=local_path,\n                config=self.config.rollout,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                trust_remote_code=trust_remote_code,\n            )\n            log_gpu_memory_usage(f\"After building {rollout_name} rollout\", logger=logger)\n\n            if torch.distributed.get_world_size() == 1:\n                self.config.rollout.load_format = \"dummy_hf\"\n            rollout_sharding_manager = FSDPSGLangShardingManager(\n                module=self.actor_module_fsdp,\n                inference_engine=rollout._engine,\n                model_config=self.actor_model_config,\n                rollout_config=self.config.rollout,\n                full_params=\"hf\" in self.config.rollout.load_format,\n                device_mesh=rollout_device_mesh,\n                offload_param=self._is_offload_param,\n                multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,\n            )\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        else:\n            raise NotImplementedError(f\"Rollout name: {self.config.rollout.name} is not supported\")\n\n        return rollout, rollout_sharding_manager\n\n    
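# A minimal rollout-config sketch (illustrative values only, not a complete or\n    # verified schema; the key names appear in the branches above):\n    #   rollout:\n    #     name: vllm                      # one of: hf | vllm | sglang\n    #     mode: sync                      # sync -> vLLMRollout, async -> vLLMAsyncRollout\n    #     tensor_model_parallel_size: 2   # infer_tp; must divide world_size\n    #     load_format: dummy_dtensor      # illustrative value\n    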
@register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        from verl.workers.actor import DataParallelPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        use_shm = self.config.model.get(\"use_shm\", False)\n        use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            if self._is_actor:\n                optim_config = self.config.actor.optim\n                fsdp_config = self.config.actor.fsdp_config\n            else:\n                optim_config = None\n                fsdp_config = OmegaConf.create()\n\n            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n            (\n                self.actor_module_fsdp,\n                self.actor_optimizer,\n                self.actor_lr_scheduler,\n                self.actor_model_config,\n            ) = self._build_model_optimizer(\n                model_path=local_path,\n                fsdp_config=fsdp_config,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                enable_gradient_checkpointing=self.config.model.get(\"enable_gradient_checkpointing\", False),\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"actor\",\n                enable_activation_offload=self.config.model.get(\"enable_activation_offload\", False),\n            )\n\n            # get the original unwrapped module\n            if fsdp_version(self.actor_module_fsdp) == 1:\n                self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_param:\n                offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n                log_gpu_memory_usage(\"After offload actor model during init\", logger=logger)\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        if self._is_actor:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                self.config.actor.use_remove_padding = use_remove_padding\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = DataParallelPPOActor(\n                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            self.rollout, self.rollout_sharding_manager = self._build_rollout(\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n\n        if self._is_ref:\n            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=local_path,\n                
fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False),\n                use_liger=self.config.model.get(\"use_liger\", False),\n                role=\"ref\",\n            )[0]\n            OmegaConf.set_struct(self.config.ref, True)\n            with open_dict(self.config.ref):\n                self.config.ref.use_remove_padding = use_remove_padding\n                self.config.ref.use_fused_kernels = use_fused_kernels\n            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=self.actor.actor_optimizer,\n                lr_scheduler=self.actor_lr_scheduler,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=self.config.actor.checkpoint,\n            )\n\n        if not self._is_actor and self._is_rollout:\n            # If ActorRolloutRefWorker is initialized as a standalone rollout,\n            # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout.\n\n            checkpoint_contents = OmegaConf.create({\"load_contents\": [\"model\"], \"save_contents\": []})\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp,\n                optimizer=None,\n                lr_scheduler=None,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                checkpoint_config=checkpoint_contents,\n            )\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"red\", role=\"actor_update\")\n    def update_actor(self, data: DataProto):\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())\n\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            # perform training\n            with Timer(name=\"update_policy\", logger=None) as timer:\n                metrics = self.actor.update_policy(data=data)\n            delta_time = timer.last\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = (\n                estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size\n            )\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n            lr = 
self.actor_lr_scheduler.get_last_lr()[0]\n            metrics[\"actor/lr\"] = lr\n            self.actor_lr_scheduler.step()\n\n            # TODO: here, we should return all metrics\n            output = DataProto(meta_info={\"metrics\": metrics})\n\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n            output = output.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n            log_gpu_memory_usage(\"After offload actor model during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"red\", role=\"rollout_generate\")\n    def generate_sequences(self, prompts: DataProto):\n        # Support all hardwares\n        prompts = prompts.to(get_device_id())\n\n        assert self._is_rollout\n\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id\n            if self.generation_config is not None\n            else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id\n            if self.generation_config is not None\n            else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n        timing_generate = {}\n        with self.rollout_sharding_manager:\n            log_gpu_memory_usage(\"After entering rollout sharding manager\", logger=logger)\n\n            prompts = self.rollout_sharding_manager.preprocess_data(prompts)\n            with simple_timer(\"generate_sequences\", timing_generate):\n                output = self.rollout.generate_sequences(prompts=prompts)\n\n            log_gpu_memory_usage(\"After rollout generation\", logger=logger)\n\n            output = self.rollout_sharding_manager.postprocess_data(output)\n\n        timing_generate.update(self.rollout_sharding_manager.timing)\n        # We calculate the average timing across all ranks\n        # to make sure meta_info[\"timing\"] is the same\n        timing_generate = reduce_timing(timing_generate)\n        output.meta_info[\"timing\"] = timing_generate\n        output = output.to(\"cpu\")\n\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"blue\", role=\"actor_compute_log_prob\")\n    def compute_log_prob(self, data: DataProto):\n        # when is_lora is True, we use the actor without lora applied to calculate the log_prob\n        # which is mostly used for ref log_prob calculation\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        # Support all hardwares\n        from contextlib import nullcontext\n\n        is_lora = data.meta_info.pop(\"is_lora\", False)\n        adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()\n        data = data.to(get_device_id())\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        
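# batching knobs for this recompute pass come from the rollout config (log_prob_*)\n        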
data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        # perform recompute log_prob\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            with adapter_ctx:\n                output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n            output = DataProto.from_dict(\n                tensors={\"old_log_probs\": output, \"entropys\": entropys},\n                meta_info={\"temperature\": self.config.rollout.temperature},\n            )\n            output = self.ulysses_sharding_manager.postprocess_data(output)\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:\n            self.actor.actor_module._handle.reshard(True)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n            log_gpu_memory_usage(\"After offload actor model during compute_log_prob\", logger=logger)\n\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"olive\", role=\"ref_compute_log_prob\")\n    def compute_ref_log_prob(self, data: DataProto):\n        if self._is_lora:\n            # if _is_lora, actor without lora applied is the ref\n            data.meta_info[\"is_lora\"] = True\n            data = self.compute_log_prob(data)\n            # this old_log_probs is in fact ref_log_prob\n            data = DataProto.from_dict(tensors={\"ref_log_prob\": data.batch[\"old_log_probs\"]})\n            return data\n        assert self._is_ref\n        # else:\n        # otherwise, the class have a standalone ref model\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)\n            output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n            output = self.ulysses_sharding_manager.postprocess_data(output)\n\n        output = output.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:\n            self.ref_policy.actor_module._handle.reshard(True)\n\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        from verl.utils.logger import log_with_rank\n\n        # only support save and load ckpt for actor\n        assert self._is_actor\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.save_checkpoint(\n            
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        dist.barrier()\n\n        if self._is_lora and hasattr(getattr(self, \"actor_module\", self.actor_module_fsdp), \"peft_config\"):\n            lora_save_path = os.path.join(local_path, \"lora_adapter\")\n            peft_model = getattr(self, \"actor_module\", self.actor_module_fsdp)\n            peft_config = {}\n            if dist.get_rank() == 0:\n                os.makedirs(lora_save_path, exist_ok=True)\n                peft_config = asdict(peft_model.peft_config.get(\"default\", {}))\n                peft_config[\"task_type\"] = peft_config[\"task_type\"].value\n                peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value\n                peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n            try:\n                if fsdp_version(self.actor_module_fsdp) > 0:\n                    self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())\n                    lora_params = layered_summon_lora_params(self.actor_module_fsdp)\n                    if dist.get_rank() == 0:\n                        save_file(lora_params, os.path.join(lora_save_path, \"adapter_model.safetensors\"))\n                        with open(os.path.join(lora_save_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n                            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n            except Exception as e:\n                log_with_rank(\n                    f\"Save LoRA Adapter Error ({e})\", rank=dist.get_rank(), logger=logger, log_only_rank_0=True\n                )\n\n            dist.barrier()\n            log_with_rank(\n                f\"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}\",\n                rank=dist.get_rank(),\n                logger=logger,\n                log_only_rank_0=True,\n            )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        assert self._is_actor or (not self._is_actor and self._is_rollout), (\n            f\"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got \"\n            f\"{self._is_actor} and {self._is_rollout}\"\n        )\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.actor_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def start_profile(self, **kwargs) -> None:\n        \"\"\"Start profiling for the current rank in the current training step.\"\"\"\n        self.profiler.start(**kwargs)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def stop_profile(self) -> None:\n        \"\"\"Stop profiling for the current rank in the current training step.\"\"\"\n        self.profiler.stop()\n\n\nclass CriticWorker(Worker, DistProfilerExtension):\n    def __init__(self, config):\n        Worker.__init__(self)\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, 
config=omega_conf_to_dataclass(config.get(\"profiler\")))\n        )\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(), init_method=os.environ.get(\"DIST_INIT_METHOD\", None)\n            )\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n        if self.config.ppo_micro_batch_size is not None:\n            self.config.ppo_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.forward_micro_batch_size //= (\n                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size\n            )\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size\n\n        if self.config.ppo_micro_batch_size_per_gpu is not None:\n            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (\n                f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than \"\n                f\"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            )\n        self._is_lora = self.config.model.get(\"lora_rank\", 0) > 0\n\n    def _build_critic_model_optimizer(self, config):\n        # the following line is necessary\n        from torch import optim\n        from torch.distributed.fsdp import MixedPrecision\n\n        from verl.utils.model import load_valuehead_model, print_model_size\n        from verl.utils.torch_dtypes import PrecisionType\n\n        use_shm = config.model.get(\"use_shm\", False)\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n        # note that the tokenizer between actor and critic may be different. 
So override tokenizer info with actor info\n        # using random initialized model from any architecture. May not be the same as Actor.\n\n        tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)\n        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n        self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        if self.config.model.get(\"custom_chat_template\", None) is not None:\n            if self.processor is not None:\n                self.processor.chat_template = self.config.model.custom_chat_template\n            else:\n                self.tokenizer.chat_template = self.config.model.custom_chat_template\n\n        override_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        if self.rank == 0:\n            print(f\"Critic overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.get(\"model_dtype\", \"fp32\")\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig\n\n        critic_model_config = AutoConfig.from_pretrained(\n            local_path,\n            attn_implementation=\"flash_attention_2\",\n            trust_remote_code=config.model.get(\"trust_remote_code\", False),\n        )\n        critic_model_config.num_labels = 1\n        # patch for kimi-vl\n        if getattr(critic_model_config, \"model_type\", None) == \"kimi_vl\":\n            critic_model_config.text_config.topk_method = \"greedy\"\n\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            critic_model_config.classifier_dropout = 0.0\n            critic_model_config.hidden_dropout = \"0\"\n            critic_model_config.summary_dropout_prob = 0.0\n\n            critic_module = load_valuehead_model(\n                local_path,\n                torch_dtype,\n                critic_model_config,\n                config.model.get(\"trust_remote_code\", False),\n            )\n\n            use_remove_padding = config.model.get(\"use_remove_padding\", False)\n\n            apply_monkey_patch(\n                model=critic_module,\n                use_remove_padding=use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            # some parameters may not in torch_dtype\n            critic_module.to(torch_dtype)\n\n            if config.model.get(\"enable_gradient_checkpointing\", False):\n                critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        if self._is_lora:\n            print(\"Applying LoRA to critic module\")\n            critic_module.enable_input_require_grads()\n            # Convert config to regular Python types before creating PEFT model\n            lora_config = {\n                \"task_type\": TaskType.CAUSAL_LM,\n                \"r\": 
self.config.model.lora_rank,\n                \"lora_alpha\": self.config.model.lora_alpha,\n                \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                \"bias\": \"none\",\n            }\n            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))\n\n        if self.rank == 0:\n            print_model_size(critic_module)\n\n        self.critic_model_config = critic_model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.get(\"mixed_precision\", None)\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"param_dtype\", \"bf16\"))\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"reduce_dtype\", \"fp32\"))\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get(\"buffer_dtype\", \"fp32\"))\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=critic_module,\n            config=self.config.model.fsdp_config.wrap_policy,\n            is_lora=self.config.model.get(\"lora_rank\", 0) > 0,\n        )\n\n        log_gpu_memory_usage(\"Before critic FSDP\", logger=None)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation\n        if config.strategy == \"fsdp\":\n            critic_module = FSDP(\n                critic_module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,\n                device_mesh=self.device_mesh,\n                cpu_offload=None,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True\n            )\n            offload_policy = None\n            if fsdp_config.offload_policy:\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n                offload_policy = CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": offload_policy,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = critic_module.state_dict()\n            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)\n        else:\n            raise NotImplementedError(f\"Unknown strategy {config.strategy}\")\n\n        if config.model.get(\"enable_activation_offload\", 
False):\n            enable_gradient_checkpointing = config.model.get(\"enable_gradient_checkpointing\", False)\n            enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing)\n\n        log_gpu_memory_usage(\"After critic FSDP\", logger=None)\n\n        critic_optimizer = optim.AdamW(\n            critic_module.parameters(),\n            lr=config.optim.lr,\n            betas=config.optim.get(\"betas\", (0.9, 0.999)),\n            weight_decay=config.optim.get(\"weight_decay\", 1e-2),\n        )\n\n        total_steps = config.optim.get(\"total_training_steps\", 0)\n        num_warmup_steps = int(config.optim.get(\"lr_warmup_steps\", -1))\n        warmup_style = config.optim.get(\"warmup_style\", \"constant\")\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.optim.get(\"lr_warmup_steps_ratio\", 0.0)\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        if self.rank == 0:\n            print(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n        if warmup_style == \"constant\":\n            critic_lr_scheduler = get_constant_schedule_with_warmup(\n                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps\n            )\n        elif warmup_style == \"cosine\":\n            critic_lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps\n            )\n        else:\n            raise NotImplementedError(f\"Warmup style {warmup_style} is not supported\")\n\n        return critic_module, critic_optimizer, critic_lr_scheduler\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n\n        from verl.workers.critic import DataParallelPPOCritic\n\n        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(\n            self.config\n        )\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n            log_gpu_memory_usage(\"After offload critic model during init\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n            log_gpu_memory_usage(\"After offload critic optimizer during init\", logger=logger)\n\n        self.critic = DataParallelPPOCritic(\n            config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer\n        )\n\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.critic_module,\n            optimizer=self.critic_optimizer,\n            lr_scheduler=self.critic_lr_scheduler,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            checkpoint_config=self.config.checkpoint,\n        )\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"cyan\")\n    def compute_values(self, data: DataProto):\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        if self._is_offload_param:\n            
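# params were offloaded to CPU after init/update; restore them for this forward pass\n            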
load_fsdp_model_to_gpu(self.critic_module)\n        micro_batch_size = self.config.forward_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            values = self.critic.compute_values(data=data)\n            output = DataProto.from_dict(tensors={\"values\": values})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"pink\")\n    def update_critic(self, data: DataProto):\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            with Timer(name=\"update_critic\", logger=None) as timer:\n                metrics = self.critic.update_critic(data=data)\n            delta_time = timer.last\n\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n\n            lr = self.critic_lr_scheduler.get_last_lr()[0]\n            metrics[\"critic/lr\"] = lr\n            self.critic_lr_scheduler.step()\n\n            output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.save_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.load_checkpoint(\n            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n\n        
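# wait for every rank to finish loading before re-offloading params below\n        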
torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.critic_optimizer)\n\n\n# TODO(sgm): we may need to extract it to dp_reward_model.py\nclass RewardModelWorker(Worker, DistProfilerExtension):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.\n    \"\"\"\n\n    def __init__(self, config):\n        Worker.__init__(self)\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get(\"profiler\")))\n        )\n\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(), init_method=os.environ.get(\"DIST_INIT_METHOD\", None)\n            )\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size()\n        from torch.distributed.device_mesh import init_device_mesh\n\n        fsdp_size = self.config.model.fsdp_config.fsdp_size\n        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.get(\"ulysses_sequence_parallel_size\", 1)\n        dp = world_size // self.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = init_device_mesh(\n                device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=[\"dp\", \"sp\"]\n            )\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        self.use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= torch.distributed.get_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_model(self, config):\n        # the following line is necessary\n        from torch.distributed.fsdp import CPUOffload\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        use_shm = config.model.get(\"use_shm\", False)\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n\n        if self.config.model.input_tokenizer is None:\n            self._do_switch_chat_template = False\n        else:\n            self._do_switch_chat_template = True\n            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)\n            self.input_tokenizer = hf_tokenizer(\n                input_tokenizer_local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False)\n            )\n            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get(\"trust_remote_code\", False))\n\n        trust_remote_code = config.model.get(\"trust_remote_code\", False)\n        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        model_config.num_labels = 1\n\n        # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh\n        )\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            reward_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                config=model_config,\n                torch_dtype=torch.bfloat16,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )\n\n            apply_monkey_patch(\n                model=reward_module,\n                use_remove_padding=config.model.get(\"use_remove_padding\", False),\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            reward_module.to(torch.bfloat16)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        if config.strategy == \"fsdp\":\n            reward_module = FSDP(\n                reward_module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                sync_module_states=True,\n                cpu_offload=CPUOffload(offload_params=True),\n                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,\n                device_mesh=self.device_mesh,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            cpu_offload = CPUOffloadPolicy(pin_memory=True)\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": config.model.fsdp_config.reshard_after_forward,\n            }\n            full_state = reward_module.state_dict()\n            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)\n            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)\n        else:\n            raise NotImplementedError(f\"Unknown strategy: {config.strategy}\")\n        return reward_module\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.get(\"external_lib\", None))\n        self.reward_module = self._build_model(config=self.config)\n\n    def _forward_micro_batch(self, micro_batch):\n        if is_cuda_available:\n            from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n        elif is_npu_available:\n            from transformers.integrations.npu_flash_attention import (\n                index_first_axis,\n                pad_input,\n                rearrange,\n                unpad_input,\n            )\n\n        from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs\n\n        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):\n            
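# Two paths follow: with use_remove_padding, sequences are unpadded into a packed\n            # (1, total_nnz) layout for flash-attn varlen (optionally sliced for Ulysses SP)\n            # and re-padded afterwards; otherwise a standard padded forward is used. Either\n            # way, the score at the last valid token of each sequence is returned.\n            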
input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(\n                    input_ids.unsqueeze(-1), attention_mask\n                )  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = (\n                        index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices)\n                        .transpose(0, 1)\n                        .unsqueeze(1)\n                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(\n                        rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices\n                    ).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size\n                    )\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.reward_module(\n                    input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False\n                )\n                reward_rmpad = output.logits\n                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    reward_rmpad = gather_outputs_and_unpad(\n                        reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size\n                    )\n\n                # pad it back\n                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)\n            else:\n                output = self.reward_module(\n                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False\n                )\n                rm_score = output.logits  # (batch_size, seq_len, 1)\n                rm_score = rm_score.squeeze(-1)\n\n            # extract the result of the last valid token\n            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]\n            return rm_score\n\n    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):\n        batch_size = data.batch.batch_size[0]\n        # expand as token_level_reward\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        response_length = data.batch[\"responses\"].shape[-1]\n        if position_ids.dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            position_ids = position_ids[:, 0, :]\n        eos_mask_idx = torch.argmax(position_ids * 
attention_mask, dim=-1)  # (bsz,)\n        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)\n        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores\n\n        # select the response part\n        token_level_scores = token_level_scores[:, -response_length:]\n\n        return token_level_scores\n\n    def _switch_chat_template(self, data: DataProto):\n        src_max_length = data.batch[\"attention_mask\"].shape[-1]\n\n        src_tokenizer = self.input_tokenizer\n        target_tokenizer = self.tokenizer\n\n        rm_input_ids = []\n        rm_attention_mask = []\n\n        for i in range(data.batch.batch_size[0]):\n            # extract raw prompt\n            if isinstance(data.non_tensor_batch[\"raw_prompt\"][i], list):\n                chat: list = data.non_tensor_batch[\"raw_prompt\"][i]\n            else:\n                chat: list = data.non_tensor_batch[\"raw_prompt\"][i].tolist()\n\n            # extract response\n            response_ids = data.batch[\"responses\"][i]\n            response_length = response_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][i][-response_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            response = src_tokenizer.decode(valid_response_ids)\n            # remove bos and eos\n            response = response.replace(src_tokenizer.eos_token, \"\")\n\n            chat.append({\"role\": \"assistant\", \"content\": response})\n\n            prompt_with_chat_template = target_tokenizer.apply_chat_template(\n                chat, add_generation_prompt=False, tokenize=False\n            )\n            if self.rank == 0 and i == 0:\n                # for debugging purpose\n                print(f\"Switch template. 
chat: {prompt_with_chat_template}\")\n\n            # the maximum length is actually determined by the reward model itself\n            max_length = self.config.get(\"max_length\", src_max_length)\n            if max_length is None:\n                max_length = src_max_length\n\n            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids, attention_mask = verl_F.postprocess_data(\n                input_ids=model_inputs[\"input_ids\"],\n                attention_mask=model_inputs[\"attention_mask\"],\n                max_length=max_length,\n                pad_token_id=target_tokenizer.pad_token_id,\n                left_pad=False,  # right padding\n                truncation=self.config.get(\"truncation\", \"right\"),\n            )  # truncate from the right\n\n            rm_input_ids.append(input_ids)\n            rm_attention_mask.append(attention_mask)\n\n        rm_input_ids = torch.cat(rm_input_ids, dim=0)\n        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)\n\n        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)\n\n        rm_inputs = {\"input_ids\": rm_input_ids, \"attention_mask\": rm_attention_mask, \"position_ids\": rm_position_ids}\n\n        return DataProto.from_dict(rm_inputs)\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"brown\")\n    def compute_rm_score(self, data: DataProto):\n        import itertools\n\n        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._do_switch_chat_template:\n            rm_data = self._switch_chat_template(data)\n        else:\n            rm_input_ids = data.batch[\"input_ids\"]\n            rm_attention_mask = data.batch[\"attention_mask\"]\n            rm_position_ids = data.batch[\"position_ids\"]\n            rm_inputs = {\n                \"input_ids\": rm_input_ids,\n                \"attention_mask\": rm_attention_mask,\n                \"position_ids\": rm_position_ids,\n            }\n            rm_data = DataProto.from_dict(rm_inputs)\n\n        # Support all hardwares\n        rm_data.batch = rm_data.batch.to(get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            use_dynamic_bsz = self.config.use_dynamic_bsz\n            if use_dynamic_bsz:\n                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)\n            else:\n                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)\n            output = []\n            for micro_batch in micro_batches:\n                rm_score = self._forward_micro_batch(micro_batch)\n                output.append(rm_score)\n            scores = torch.cat(output, dim=0)  # (batch_size)\n\n            if use_dynamic_bsz:\n                indices = list(itertools.chain.from_iterable(indices))\n                assert len(indices) == scores.size(0), f\"{len(indices)} vs. 
{scores.size()}\"\n                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                scores = scores[revert_indices]\n\n            token_level_scores = self._expand_to_token_level(data, scores)\n            # Note that this is only the scores, may not be the final rewards used to train RL\n            output = DataProto.from_dict(tensors={\"rm_scores\": token_level_scores})\n            output = self.ulysses_sharding_manager.postprocess_data(data=output)\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:\n            self.reward_module._handle.reshard(True)\n\n        output = output.to(\"cpu\")\n        return output\n\n\n# ================================= Async related workers =================================\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def _build_rollout(self, trust_remote_code=False):\n        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)\n\n        # NOTE: rollout is not actually initialized here, it's deferred\n        # to be initialized by AsyncvLLMServer.\n\n        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size\n        self.vllm_dp_rank = int(os.environ[\"RANK\"]) // self.vllm_tp_size\n        self.vllm_tp_rank = int(os.environ[\"RANK\"]) % self.vllm_tp_size\n\n        # used for sleep/wake_up\n        rollout.sharding_manager = rollout_sharding_manager\n\n        return rollout, rollout_sharding_manager\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def generate_sequences(self, prompts: DataProto):\n        raise NotImplementedError(\"AsyncActorRolloutRefWorker does not support generate_sequences\")\n\n    # ============================ vLLM related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def execute_method(self, method: str | bytes, *args, **kwargs):\n        \"\"\"Called by ExternalRayDistributedExecutor collective_rpc.\"\"\"\n        return self.rollout.execute_method(method, *args, **kwargs)\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def get_zeromq_address(self):\n        return self.rollout.get_zeromq_address()\n\n    # ============================ SGLang related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def wake_up(self):\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.wake_up()\n        # return something to block the caller\n        return True\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def sleep(self):\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.sleep()\n        # return something to block the caller\n        return True\n"
  },
  {
    "path": "verl_rl/verl/workers/megatron_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport datetime\nimport logging\nimport os\nimport time\nfrom typing import Any\n\nimport psutil\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom megatron.core import parallel_state as mpu\nfrom omegaconf import DictConfig, OmegaConf, open_dict\n\nfrom verl import DataProto\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.single_controller.base.megatron.worker import MegatronWorker\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device\nfrom verl.utils.flops_counter import FlopsCounter\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.megatron_utils import (\n    load_megatron_model_to_gpu,\n    load_megatron_optimizer,\n    offload_megatron_model_to_cpu,\n    offload_megatron_optimizer,\n)\nfrom verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights\nfrom verl.utils.profiler import (\n    DistProfiler,\n    DistProfilerExtension,\n    GPUMemoryLogger,\n    log_gpu_memory_usage,\n    simple_timer,\n)\nfrom verl.utils.profiler.performance import reduce_timing\nfrom verl.workers.actor.megatron_actor import MegatronPPOActor\nfrom verl.workers.critic.megatron_critic import MegatronPPOCritic\nfrom verl.workers.reward_model.megatron.reward_model import MegatronRewardModel\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef set_random_seed(seed):\n    import random\n\n    import numpy as np\n    import torch\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    if get_torch_device().device_count() > 0:\n        from megatron.core import tensor_parallel\n\n        tensor_parallel.model_parallel_cuda_manual_seed(seed)\n    # FIXME: torch cumsum not support deterministic (used in vllm sampler),\n    # https://github.com/pytorch/pytorch/issues/89492\n    # torch.use_deterministic_algorithms(True, warn_only=True)\n    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n\n\nclass ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: DictConfig, role: str, **kwargs):\n        MegatronWorker.__init__(self)\n        self.config = config\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize 
different parallel strategy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.actor.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.actor.megatron.seed)\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n\n        profiler_config = omega_conf_to_dataclass(config.get(\"profiler\"))\n        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config))\n\n        # TODO(sgm): Currently, we only support reference model param offload\n        # will support other offload later\n        self._is_offload_param = False\n        self._is_offload_grad = False\n        self._is_offload_optimizer = False\n\n        # normalize config\n        if self._is_actor and self._is_rollout:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n            if self.config.actor.get(\"ppo_micro_batch_size\", None):\n                self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n                self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n\n            self._is_offload_param = self.config.actor.megatron.get(\"param_offload\", False)\n            self._is_offload_grad = self.config.actor.megatron.get(\"grad_offload\", False)\n            self._is_offload_optimizer = self.config.actor.megatron.get(\"optimizer_offload\", False)\n        elif self._is_ref:\n            if self.config.ref.get(\"log_prob_micro_batch_size\", None):\n                
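# split the global micro batch size across data-parallel ranks to get the per-GPU value\n                # (e.g. log_prob_micro_batch_size=32 with a DP world size of 4 yields 8 per GPU)\n                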
self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n            else:\n                assert self.config.ref.get(\"log_prob_micro_batch_size_per_gpu\", None) is not None, (\n                    \"Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and \"\n                    \"`log_prob_micro_batch_size` should not be None at the same time.\"\n                )\n            self._ref_is_offload_param = self.config.ref.megatron.get(\"param_offload\", False)\n\n    def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):\n        from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler\n        from verl.utils.megatron_utils import get_model, init_megatron_optim_config\n        from verl.utils.model import get_generation_config, print_model_size\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            model_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            self.config.actor.megatron.use_mbridge,\n        )\n        self.generation_config = get_generation_config(self.local_path)\n\n        def make_model(wrap_with_ddp=False):\n            if self.bridge is not None:\n                from verl.models.mcore.mbridge import freeze_moe_router\n\n                post_model_creation_callbacks = []\n                if override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False):\n                    post_model_creation_callbacks.append(freeze_moe_router)\n                return self.bridge.get_model(\n                    post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=wrap_with_ddp\n                )\n            else:\n\n                def megatron_actor_model_provider(pre_process, post_process):\n                    from verl.models.mcore import init_mcore_model\n\n                    parallel_model = init_mcore_model(\n                        self.tf_config,\n                        self.hf_config,\n                        pre_process,\n                        post_process,\n                        share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                        value=False,\n                        freeze_moe_router=override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False),\n                    )\n                    parallel_model.to(get_device_name())\n                    return parallel_model\n\n                override_ddp_config = OmegaConf.to_container(\n                    self.config.actor.megatron.get(\"override_ddp_config\", OmegaConf.create()), resolve=True\n                )\n                return get_model(\n                    megatron_actor_model_provider,\n                    wrap_with_ddp=wrap_with_ddp,\n                    use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n                    override_ddp_config=override_ddp_config,\n                )\n\n        if self._is_actor or self._is_rollout:\n            actor_module = make_model(wrap_with_ddp=True)\n            print(f\"actor_module: {len(actor_module)}\")\n            if self.config.actor.load_weight:\n                if 
self.config.actor.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(\n                        actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False\n                    )\n                else:\n                    if self.bridge is not None:\n                        local_model_path = get_hf_model_path(self.config)\n                        self.bridge.load_weights(actor_module, local_model_path)\n                    else:\n                        load_megatron_gptmodel_weights(\n                            self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False\n                        )\n\n            if self.rank == 0:\n                print_model_size(actor_module[0])\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n        elif self._is_ref:\n            print(f\"self.config.ref.load_weight: {self.config.ref.load_weight}\")\n            ref_module = make_model(wrap_with_ddp=False)\n            if self.config.ref.load_weight:  # should align with the actor\n                assert self.config.actor.load_weight == self.config.ref.load_weight\n                print(\"load ref weight start\")\n                if self.config.ref.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(\n                        ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False\n                    )\n                else:\n                    if self.bridge is not None:\n                        local_model_path = get_hf_model_path(self.config)\n                        self.bridge.load_weights(ref_module, local_model_path)\n                    else:\n                        load_megatron_gptmodel_weights(\n                            self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False\n                        )\n            log_gpu_memory_usage(\"After ref module init\", logger=logger)\n            return ref_module, self.hf_config\n\n        # TODO: add more optimizer args into config\n        if self._is_actor:\n            optim_config_megatron = init_megatron_optim_config(optim_config)\n            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)\n            actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n                optimizer=actor_optimizer, config=optim_config\n            )\n        else:\n            optim_config = None\n            actor_optimizer = None\n            actor_optimizer_scheduler = None\n\n        log_gpu_memory_usage(\"After actor optimizer init\", logger=logger)\n\n        return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n\n        layer_name_mapping = {\n            \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n            \"gate_proj_layer_name\": \"linear_fc1.\",\n        }\n        if self.config.rollout.name == \"vllm\":\n            from torch.distributed.device_mesh import init_device_mesh\n\n            from verl.workers.rollout.vllm_rollout import vLLMRollout\n            from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager\n\n            # NOTE(sgm): If the QKV and gate_up projection layers are concatenated in the actor,\n            # we will reorganize their weight format when resharding 
from actor to rollout.\n\n            infer_tp = self.config.rollout.tensor_model_parallel_size\n            dp = self.world_size // infer_tp\n            assert self.world_size % infer_tp == 0, (\n                f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n            )\n            rollout_device_mesh = init_device_mesh(\n                get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"]\n            )\n            log_gpu_memory_usage(\"Before building vllm rollout\", logger=None)\n\n            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get(\"use_shm\", False))\n            from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout\n\n            vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == \"sync\" else vLLMAsyncRollout\n            rollout = vllm_rollout_cls(\n                model_path=local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                device_mesh=rollout_device_mesh,\n                trust_remote_code=trust_remote_code,\n            )\n            log_gpu_memory_usage(\"After building vllm rollout\", logger=logger)\n\n            # perform weight resharding between actor and rollout\n            from verl.models.mcore import get_mcore_weight_converter\n\n            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)\n            sharding_manager = MegatronVLLMShardingManager(\n                inference_engine=rollout.inference_engine,\n                model_config=self.actor_model_config,\n                transformer_config=self.tf_config,\n                rollout_config=self.config.rollout,\n                layer_name_mapping=layer_name_mapping,\n                actor_module=self.actor.actor_module,\n                weight_converter=weight_converter,\n                device_mesh=rollout_device_mesh,\n                offload_param=self._is_offload_param,\n                bridge=self.bridge,\n            )\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n\n        elif self.config.rollout.name == \"sglang\":\n            from verl.workers.rollout.sglang_rollout import SGLangRollout\n\n            # NOTE(linjunrong): Due to the recent fp8 support in SGLang, 
importing any symbol related to SGLang's\n            # model_runner now checks CUDA device capability.\n            # However, due to verl's setup, the main Ray process cannot see any CUDA device, which could\n            # potentially lead to: \"RuntimeError: No CUDA GPUs are available\".\n            # For this reason, sharding_manager.__init__ should not import the SGLang sharding manager; we import it\n            # here using its absolute path instead.\n            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76\n            from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager\n\n            infer_tp = self.config.rollout.tensor_model_parallel_size\n            dp = self.world_size // infer_tp\n            assert self.world_size % infer_tp == 0, (\n                f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n            )\n            rollout_device_mesh = init_device_mesh(\n                \"cpu\", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=(\"dp\", \"tp\", \"pp\")\n            )\n\n            local_path = copy_to_local(self.config.model.path)\n            log_gpu_memory_usage(f\"Before building {self.config.rollout.name} rollout\", logger=None)\n            rollout = SGLangRollout(\n                actor_module=local_path,\n                config=self.config.rollout,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                trust_remote_code=trust_remote_code,\n                device_mesh=rollout_device_mesh,\n            )\n            log_gpu_memory_usage(f\"After building {self.config.rollout.name} rollout\", logger=None)\n\n            from verl.models.mcore import get_mcore_weight_converter\n\n            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)\n            sharding_manager = MegatronSGLangShardingManager(\n                actor_module=self.actor.actor_module,\n                inference_engine=rollout._engine,\n                model_config=self.actor_model_config,\n                rollout_config=self.config.rollout,\n                transformer_config=self.tf_config,\n                layer_name_mapping=layer_name_mapping,\n                weight_converter=weight_converter,\n                bridge=self.bridge,\n                device_mesh=rollout_device_mesh,\n                offload_param=self._is_offload_param,\n            )\n            log_gpu_memory_usage(\"After building sharding manager\", logger=logger)\n        else:\n            raise NotImplementedError(\"Only the vllm and sglang rollouts are supported with Megatron for now\")\n        print(f\"rollout and sharding manager init done, sharding_manager: {sharding_manager}\")\n        return rollout, sharding_manager\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        if self._is_actor:\n            override_transformer_config = 
OmegaConf.to_container(\n                self.config.actor.megatron.get(\"override_transformer_config\", OmegaConf.create()), resolve=True\n            )\n        elif self._is_ref:\n            override_transformer_config = OmegaConf.to_container(\n                self.config.ref.megatron.get(\"override_transformer_config\", OmegaConf.create()), resolve=True\n            )\n        else:\n            override_transformer_config = {}\n        self.param_dtype = torch.bfloat16\n        log_gpu_memory_usage(\"Before init actor model and optimizer\", logger=logger)\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            optim_config = self.config.actor.optim if self._is_actor else None\n            (\n                self.actor_module,\n                self.actor_optimizer,\n                self.actor_optimizer_scheduler,\n                self.actor_model_config,\n                self.actor_optim_config,\n            ) = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n            )\n            if self._is_offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n                log_gpu_memory_usage(\"After offload actor params and grad during init\", logger=logger)\n            if self._is_offload_optimizer:\n                offload_megatron_optimizer(self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        if self._is_actor:\n            OmegaConf.set_struct(self.config.actor, True)\n            with open_dict(self.config.actor):\n                use_fused_kernels = self.config.model.get(\"use_fused_kernels\", False)\n                self.config.actor.use_fused_kernels = use_fused_kernels\n            self.actor = MegatronPPOActor(\n                config=self.config.actor,\n                model_config=self.actor_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                actor_module=self.actor_module,\n                actor_optimizer=self.actor_optimizer,\n            )\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n\n        if self._is_rollout:\n            self.rollout, self.sharding_manager = self._build_rollout(\n                trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n            # used for sleep/wake_up\n            self.rollout.sharding_manager = self.sharding_manager\n            log_gpu_memory_usage(\"After rollout init\", logger=logger)\n\n        if self._is_ref:\n            self.ref_module, self.ref_model_config = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=None,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n            )\n            log_gpu_memory_usage(\"After ref model init\", logger=logger)\n            self.ref_policy = MegatronPPOActor(\n                config=self.config.ref,\n                model_config=self.ref_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                
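# the reference policy reuses the actor implementation, but runs inference-only: no optimizer is attached below\n                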
actor_module=self.ref_module,\n                actor_optimizer=None,\n            )\n            if self._ref_is_offload_param:\n                offload_megatron_model_to_cpu(self.ref_module)\n                log_gpu_memory_usage(\"After offload ref params during init\", logger=logger)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_mananager = MegatronCheckpointManager(\n                config=self.config,\n                checkpoint_config=self.config.actor.checkpoint,\n                model_config=self.actor_model_config,\n                transformer_config=self.tf_config,\n                role=\"actor\",\n                model=self.actor_module,\n                arch=self.architectures[0],\n                hf_config=self.hf_config,\n                param_dtype=self.param_dtype,\n                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                optimizer=self.actor_optimizer,\n                optimizer_scheduler=self.actor_optimizer_scheduler,\n                use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n                use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,\n                bridge=self.bridge,\n                use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,\n            )\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After init_model finish\", logger=logger)\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @GPUMemoryLogger(role=\"update_actor\", logger=logger)\n    @DistProfiler.annotate(color=\"red\")\n    def update_actor(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n            log_gpu_memory_usage(\"After load actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After load actor optimizer during update_actor\", logger=logger)\n        data.batch = data.batch.to(get_device_name())\n\n        micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        dataloader = self.actor.make_minibatch_iterator(data=data)\n        with Timer(name=\"update_policy\", logger=None) as timer:\n            metrics = self.actor.update_policy(dataloader=dataloader)\n        delta_time = timer.last\n        global_num_tokens = data.meta_info[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size\n        metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n        metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n        metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n        from verl.utils.megatron.optimizer import get_megatron_last_lr\n\n        metrics[\"actor/lr\"] = get_megatron_last_lr(self.actor_optimizer)\n        self.actor_optimizer_scheduler.step(1)\n\n        # TODO: 
here, we should return all metrics\n        output = DataProto(meta_info={\"metrics\": metrics})\n        output = output.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @GPUMemoryLogger(role=\"generate_sequences\", logger=logger)\n    @DistProfiler.annotate(color=\"red\")\n    def generate_sequences(self, prompts: DataProto):\n        assert self._is_rollout\n        prompts.batch = prompts.batch.to(get_device_name())\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id\n            if self.generation_config is not None\n            else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id\n            if self.generation_config is not None\n            else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n        timing_generate = {}\n        with self.sharding_manager:\n            log_gpu_memory_usage(\"After entering sharding manager\", logger=logger)\n            prompts = self.sharding_manager.preprocess_data(prompts)\n            with simple_timer(\"generate_sequences\", timing_generate):\n                output = self.rollout.generate_sequences(prompts=prompts)\n            output = self.sharding_manager.postprocess_data(output)\n            log_gpu_memory_usage(\"After rollout generation\", logger=logger)\n\n        timing_generate.update(self.sharding_manager.timing)\n        # We calculate the average timing across all ranks\n        # to make sure meta_info[\"timing\"] is the same\n        timing_generate = reduce_timing(timing_generate)\n        output.meta_info[\"timing\"] = timing_generate\n        output = output.to(\"cpu\")\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @GPUMemoryLogger(role=\"compute_ref_log_prob\", logger=logger)\n    @DistProfiler.annotate(color=\"olive\")\n    def compute_ref_log_prob(self, data: DataProto):\n        assert self._is_ref\n        if self._ref_is_offload_param:\n            load_megatron_model_to_gpu(self.ref_module, load_grad=False)\n            log_gpu_memory_usage(\"After load ref params and grad during compute_ref_log_prob\", logger=logger)\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.ref.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.ref.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data = data.to(get_device_id())\n        output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)\n        output = DataProto.from_dict(tensors={\"ref_log_prob\": output})\n        output = output.to(\"cpu\")\n        if 
self._ref_is_offload_param:\n            offload_megatron_model_to_cpu(self.ref_module)\n            log_gpu_memory_usage(\"After offload ref params and grad during compute_ref_log_prob\", logger=logger)\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @GPUMemoryLogger(role=\"compute_log_prob\", logger=logger)\n    @DistProfiler.annotate(color=\"blue\")\n    def compute_log_prob(self, data: DataProto):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n            log_gpu_memory_usage(\"After load actor params and grad during compute_log_prob\", logger=logger)\n        # we should always recompute old_log_probs when it is HybridEngine\n        data.meta_info[\"micro_batch_size\"] = self.config.rollout.log_prob_micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.rollout.log_prob_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.rollout.log_prob_use_dynamic_bsz\n        data.meta_info[\"temperature\"] = self.config.rollout.temperature\n        data = data.to(get_device_id())\n        output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n        output = DataProto.from_dict(\n            tensors={\"old_log_probs\": output, \"entropys\": entropys},\n            meta_info={\"temperature\": self.config.rollout.temperature},\n        )\n        output = output.to(\"cpu\")\n        # clear kv cache\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during compute_log_prob\", logger=logger)\n        get_torch_device().empty_cache()\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.load_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_pretrained_model(self, checkpoint_path, del_local_after_load=True):\n        pass\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.save_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n\n\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def _build_rollout(self, trust_remote_code=False):\n        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)\n\n        # NOTE: rollout is not actually initialized here, it's deferred\n        # to be initialized by AsyncvLLMServer.\n\n        self.vllm_tp_size = 
self.config.rollout.tensor_model_parallel_size\n        self.vllm_dp_rank = int(os.environ[\"RANK\"]) // self.vllm_tp_size\n        self.vllm_tp_rank = int(os.environ[\"RANK\"]) % self.vllm_tp_size\n\n        # used for sleep/wake_up\n        rollout.sharding_manager = rollout_sharding_manager\n\n        return rollout, rollout_sharding_manager\n\n    # ============================ vLLM related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def execute_method(self, method: str | bytes, *args, **kwargs):\n        \"\"\"Called by ExternalRayDistributedExecutor collective_rpc.\"\"\"\n        if self.vllm_tp_rank == 0 and method != \"execute_model\":\n            print(\n                f\"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: \"\n                f\"{method if isinstance(method, str) else 'Callable'}\"\n            )\n        return self.rollout.execute_method(method, *args, **kwargs)\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    def get_zeromq_address(self):\n        return self.rollout.get_zeromq_address()\n\n    # ============================ SGLang related ============================\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)\n    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)\n        return ret\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def wake_up(self):\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.wake_up()\n        # return something to block the caller\n        return True\n\n    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)\n    async def sleep(self):\n        if self.config.rollout.free_cache_engine:\n            await self.rollout.sleep()\n        # return something to block the caller\n        return True\n\n\nclass CriticWorker(MegatronWorker, DistProfilerExtension):\n    def __init__(self, config):\n        MegatronWorker.__init__(self)\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get(\"profiler\")))\n        )\n        self.config = config\n\n        # NOTE(sgm): We utilize a colocated WorkerGroup by default.\n        # As a result, workers for different models share the same process.\n        # Therefore, we only require a single distributed initialization.\n        # To use a different parallel strategy for each model:\n        # 1. users should disable WorkerDict; 2. assign a different ResourcePool to each model;\n        # 3. 
apply the following patch in ray==2.10: https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # set param and optimizer offload flags\n        self._is_offload_param = self.config.megatron.param_offload\n        self._is_offload_optimizer = self.config.megatron.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n        if self.config.get(\"ppo_micro_batch_size\", None):\n            self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n\n        # TODO(sgm): support critic model offload\n\n    def _build_critic_model_optimizer(\n        self, model_path, optim_config, override_model_config, override_transformer_config\n    ):\n        from megatron.core.models.gpt.gpt_model import ModelType\n\n        from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler\n        from verl.utils.megatron_utils import get_model, init_megatron_optim_config\n        from verl.utils.model import print_model_size\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            self.config.model.tokenizer_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            self.config.megatron.use_mbridge,\n        )\n\n        if self.bridge is not None:\n            from verl.models.mcore.mbridge import freeze_moe_router, make_value_model\n\n            post_model_creation_callbacks = [make_value_model]\n            if override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False):\n                post_model_creation_callbacks.append(freeze_moe_router)\n            critic_module = self.bridge.get_model(\n                post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True\n            )\n        else:\n\n            def megatron_critic_model_provider(pre_process, post_process):\n                from verl.models.mcore import init_mcore_model\n\n                parallel_model = init_mcore_model(\n 
                   self.tf_config,\n                    self.hf_config,\n                    pre_process,\n                    post_process,\n                    share_embeddings_and_output_weights=False,\n                    value=True,\n                    freeze_moe_router=override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False),\n                )\n                parallel_model.to(get_device_name())\n                return parallel_model\n\n            override_ddp_config = OmegaConf.to_container(\n                self.config.megatron.get(\"override_ddp_config\", OmegaConf.create()), resolve=True\n            )\n            # Step 3: initialize the megatron model\n            critic_module = get_model(\n                model_provider_func=megatron_critic_model_provider,\n                model_type=ModelType.encoder_or_decoder,\n                wrap_with_ddp=True,\n                use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n                override_ddp_config=override_ddp_config,\n            )\n        # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp).\n        # but here, we do not use pp (vpp) yet. For simplicity, we remove the list\n        # critic_module = nn.ModuleList(critic_module)\n\n        if self.config.load_weight:\n            t0 = time.time()\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(\n                    critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True\n                )\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(critic_module, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True\n                    )\n            t1 = time.time()\n            if torch.distributed.get_rank() == 0:\n                print(f\"critic load_weight time: {t1 - t0}\")\n        if self.rank == 0:\n            print_model_size(critic_module[0])\n\n        # TODO: add more optimizer args into config\n        optim_config_megatron = init_megatron_optim_config(optim_config)\n        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)\n        critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n            optimizer=critic_optimizer, config=optim_config\n        )\n        get_torch_device().empty_cache()\n        return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # create critic\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_transformer_config = OmegaConf.to_container(\n            self.config.megatron.get(\"override_transformer_config\", OmegaConf.create()), resolve=True\n        )\n        self.param_dtype = 
torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        (\n            self.critic_module,\n            self.critic_optimizer,\n            self.critic_optimizer_scheduler,\n            self.critic_model_config,\n            critic_optimizer_config,\n        ) = self._build_critic_model_optimizer(\n            model_path=self.config.model.path,\n            optim_config=self.config.optim,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n        self.critic = MegatronPPOCritic(\n            config=self.config,\n            model_config=self.critic_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            critic_module=self.critic_module,\n            critic_optimizer=self.critic_optimizer,\n            critic_optimizer_config=critic_optimizer_config,\n        )\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_mananager = MegatronCheckpointManager(\n            config=self.config,\n            checkpoint_config=self.config.checkpoint,\n            model_config=self.critic_model_config,\n            transformer_config=self.tf_config,\n            role=\"critic\",\n            model=self.critic_module,\n            arch=self.architectures[0],\n            hf_config=self.hf_config,\n            param_dtype=self.param_dtype,\n            share_embeddings_and_output_weights=False,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            optimizer=self.critic_optimizer,\n            optimizer_scheduler=self.critic_optimizer_scheduler,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n            use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,\n            bridge=self.bridge,\n            use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,\n        )\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"cyan\")\n    def compute_values(self, data: DataProto):\n        micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data = data.to(get_device_id())\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        values = self.critic.compute_values(data=data)\n        output = DataProto.from_dict(tensors={\"values\": values})\n        output = output.to(\"cpu\")\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        return output\n\n    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"pink\")\n    def update_critic(self, data: DataProto):\n        data = data.to(get_device_id())\n\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.critic_optimizer)\n\n        dataloader = 
self.critic.make_minibatch_iterator(data)\n        with Timer(name=\"update_critic\", logger=None) as timer:\n            metrics = self.critic.update_critic(dataloader=dataloader)\n        delta_time = timer.last\n        global_num_tokens = data.meta_info[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n        from verl.utils.megatron.optimizer import get_megatron_last_lr\n\n        metrics[\"critic/lr\"] = get_megatron_last_lr(self.critic_optimizer)\n        self.critic_optimizer_scheduler.step(1)\n\n        output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.load_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.save_checkpoint(\n            local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n\n\nclass RewardModelWorker(MegatronWorker, DistProfilerExtension):\n    \"\"\"\n    Note that we only implement reward models that are subclasses of AutoModelForSequenceClassification.\n    \"\"\"\n\n    def __init__(self, config):\n        MegatronWorker.__init__(self)\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get(\"profiler\")))\n        )\n        self.config = config\n\n        # NOTE(sgm): We utilize a colocated WorkerGroup by default.\n        # As a result, workers for different models share the same process.\n        # Therefore, we only require a single distributed initialization.\n        # To use a different parallel strategy for each model:\n        # 1. users should disable WorkerDict; 2. assign a different ResourcePool to each model;\n        # 3. 
apply the following patch in ray==2.10: https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(\n                backend=get_nccl_backend(),\n                timeout=datetime.timedelta(seconds=self.config.get(\"nccl_timeout\", 600)),\n                init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n            )\n            get_torch_device().set_device(rank)\n\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config):\n        from megatron.core.models.gpt.gpt_model import ModelType\n\n        from verl.utils.megatron_utils import get_model\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            tokenizer,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.get(\"trust_remote_code\", False),\n            self.config.megatron.use_mbridge,\n        )\n        if self.bridge is not None:\n            from verl.models.mcore.mbridge import freeze_moe_router, make_value_model\n\n            post_model_creation_callbacks = [make_value_model]\n            if override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False):\n                post_model_creation_callbacks.append(freeze_moe_router)\n            reward_model = self.bridge.get_model(\n                post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=False\n            )\n        else:\n\n            def megatron_rm_model_provider(pre_process, post_process):\n                from verl.models.mcore import init_mcore_model\n\n                parallel_model = init_mcore_model(\n                    self.tf_config,\n                    self.hf_config,\n                    pre_process,\n                    post_process,\n                    share_embeddings_and_output_weights=False,\n                    value=True,\n                )\n                parallel_model.to(get_device_name())\n                return parallel_model\n\n            # Step 3: initialize the megatron model\n            reward_model = get_model(\n                model_provider_func=megatron_rm_model_provider,\n                model_type=ModelType.encoder_or_decoder,\n                wrap_with_ddp=False,\n                
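# the reward model is inference-only, hence no DDP wrapping\n                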
use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n            )\n            # note that reward_model will be a list, to be compatible with the construction of interleaved pp (vpp);\n            # we do not use pp (vpp) here yet, so the list is passed through as-is\n            # reward_model = nn.ModuleList(reward_model)\n\n        if self.config.load_weight:\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True)\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(reward_model, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True\n                    )\n\n        # TODO: add more optimizer args into config\n        get_torch_device().empty_cache()\n        return reward_model, self.hf_config\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        # create the reward model\n\n        from verl.utils.torch_dtypes import PrecisionType\n\n        if self.config.model.get(\"external_lib\", None) is not None:\n            # This is used to import external_lib into the huggingface systems\n            import importlib\n\n            importlib.import_module(self.config.model.external_lib)\n        override_model_config = OmegaConf.to_container(self.config.model.get(\"override_config\", OmegaConf.create()))\n        override_transformer_config = OmegaConf.to_container(\n            self.config.megatron.get(\"override_transformer_config\", OmegaConf.create()), resolve=True\n        )\n\n        use_shm = self.config.model.get(\"use_shm\", False)\n        sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm)\n        sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path)\n        rm_tokenizer_path = self.config.model.get(\"rm_tokenizer\", None)\n        rm_tokenizer = None\n        if rm_tokenizer_path is not None:\n            rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm)\n            rm_tokenizer = hf_tokenizer(\n                rm_tokenizer_local_path, trust_remote_code=self.config.model.get(\"trust_remote_code\", False)\n            )\n\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        reward_model_module, reward_model_config = self._build_rm_model(\n            model_path=self.config.model.path,\n            tokenizer=rm_tokenizer,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n        )\n        # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel;\n        # it should be implemented in the workers instead\n        self.rm = MegatronRewardModel(\n            config=self.config,\n            reward_model_module=reward_model_module,\n            model_config=reward_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            sft_tokenizer=sft_tokenizer,\n            rm_tokenizer=rm_tokenizer,\n        )\n\n    # TODO: the reward model should use its own tokenizer instead of the sft tokenizer;\n    # the input_ids, responses, attention_mask and position_ids may be different!\n    
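# compute_rm_score forwards the batch through the value-head model; micro-batching and the\n    # token-length budget are driven by the meta_info fields set at the top of the method\n    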
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"brown\")\n    def compute_rm_score(self, data: DataProto):\n        data.meta_info[\"micro_batch_size\"] = self.config.micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data = data.to(get_device_id())\n        output = self.rm.compute_reward(data)\n        output = output.to(\"cpu\")\n        return output\n"
  },
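The worker above divides the global `micro_batch_size` by the data-parallel world size to get a per-GPU value, which `compute_rm_score` later stamps into `data.meta_info`. A minimal sketch of that normalization under the assumption that the global size splits evenly across DP ranks; the helper name is illustrative, not part of verl:

```python
# Hedged sketch of the micro-batch normalization performed in the worker init above.
def normalize_micro_batch_size(global_mbs: int, dp_world_size: int) -> int:
    # Each data-parallel rank processes an equal slice of the global micro-batch.
    assert global_mbs % dp_world_size == 0, "global micro-batch must divide evenly across DP ranks"
    return global_mbs // dp_world_size

# e.g. a global micro-batch of 32 across 8 data-parallel ranks -> 4 sequences per GPU
assert normalize_micro_batch_size(32, 8) == 4
```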
  {
    "path": "verl_rl/verl/workers/reward_manager/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import get_reward_manager_cls, register  # noqa: I001\nfrom .batch import BatchRewardManager\nfrom .dapo import DAPORewardManager\nfrom .naive import NaiveRewardManager\nfrom .prime import PrimeRewardManager\n\n# Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies\n__all__ = [\n    \"BatchRewardManager\",\n    \"DAPORewardManager\",\n    \"NaiveRewardManager\",\n    \"PrimeRewardManager\",\n    \"register\",\n    \"get_reward_manager_cls\",\n]\n"
  },
  {
    "path": "verl_rl/verl/workers/reward_manager/batch.py",
    "content": "# Copyright 2025 Individual Contributor: Mert Unsal\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.workers.reward_manager import register\n\n\n@register(\"batch\")\nclass BatchRewardManager:\n    \"\"\"\n    A batch reward manager that computes rewards for a batch of data.\n\n    Args:\n        tokenizer (Tokenizer): The tokenizer to use for decoding the responses.\n        num_examine (int): The number of responses to examine.\n        compute_score (callable): The function to compute the rewards.\n        reward_fn_key (str): The key to use for the reward function.\n        reward_kwargs (dict): The keyword arguments to pass to the reward function.\n    \"\"\"\n\n    def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key=\"data_source\", **reward_kwargs):\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine\n        self.compute_score = compute_score\n        self.reward_fn_key = reward_fn_key\n        self.reward_kwargs = reward_kwargs\n\n    def verify(self, data):\n        prompt_ids = data.batch[\"prompts\"]\n        response_ids = data.batch[\"responses\"]\n        attention_mask = data.batch[\"attention_mask\"]\n\n        prompt_len = prompt_ids.shape[-1]\n        valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1)\n\n        responses_str = []\n        for i in range(len(data)):\n            valid_len = valid_response_lengths[i]\n            valid_response_ids = response_ids[i][:valid_len]\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n            responses_str.append(response_str)\n\n        ground_truths = [item.non_tensor_batch[\"reward_model\"].get(\"ground_truth\", None) for item in data]\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n        extras = data.non_tensor_batch.get(\"extra_info\", [None] * len(data))\n\n        scores = self.compute_score(\n            data_sources=data_sources,\n            solution_strs=responses_str,\n            ground_truths=ground_truths,\n            extra_infos=extras,\n            **self.reward_kwargs,\n        )\n\n        return scores\n\n    def __call__(self, data: DataProto, return_dict=False):\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                return {\"reward_tensor\": data.batch[\"rm_scores\"]}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n        prompt_ids = data.batch[\"prompts\"]\n        prompt_len = prompt_ids.shape[-1]\n        attention_mask = data.batch[\"attention_mask\"]\n        valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1)\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n\n        scores = self.verify(data)\n        rewards = []\n        already_printed = {}\n\n        for i in range(len(data)):\n            length = valid_response_lengths[i].item()\n            score = scores[i]\n\n            if isinstance(score, dict):\n                reward = score[\"score\"]\n                for key, value in score.items():\n                    reward_extra_info[key].append(value)\n            else:\n                reward = score\n\n            rewards.append(reward)\n            reward_tensor[i, length - 1] = reward\n\n            data_source = data_sources[i]\n            if already_printed.get(data_source, 0) < self.num_examine:\n                response_str = self.tokenizer.decode(data.batch[\"responses\"][i][:length], skip_special_tokens=True)\n                prompt_str = self.tokenizer.decode(data.batch[\"prompts\"][i], skip_special_tokens=True)\n                ground_truth = data[i].non_tensor_batch[\"reward_model\"].get(\"ground_truth\", None)\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                print(\"[score]\", scores[i])\n                already_printed[data_source] = already_printed.get(data_source, 0) + 1\n\n        data.batch[\"acc\"] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device)\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor, \"reward_extra_info\": reward_extra_info}\n        else:\n            return reward_tensor\n"
  },
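`BatchRewardManager.verify` calls `compute_score` once for the whole batch, with the keyword arguments shown above (`data_sources`, `solution_strs`, `ground_truths`, `extra_infos`, plus any `reward_kwargs`). A hedged sketch of a compatible scoring function; the exact-match logic is illustrative, not verl's default scorer:

```python
# Minimal batched scorer matching the call signature used by BatchRewardManager.verify.
# Returning dicts (rather than bare floats) lets extra keys flow into reward_extra_info.
def exact_match_batch_score(data_sources, solution_strs, ground_truths, extra_infos=None, **kwargs):
    scores = []
    for sol, gt in zip(solution_strs, ground_truths):
        correct = gt is not None and sol.strip() == str(gt).strip()
        scores.append({"score": 1.0 if correct else 0.0})
    return scores
```

It would be wired in as `BatchRewardManager(tokenizer, num_examine=1, compute_score=exact_match_batch_score)`.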
  {
    "path": "verl_rl/verl/workers/reward_manager/dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\n\n\n@register(\"dapo\")\nclass DAPORewardManager:\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        num_examine,\n        compute_score=None,\n        reward_fn_key=\"data_source\",\n        max_resp_len=None,\n        overlong_buffer_cfg=None,\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n        self.overlong_buffer_cfg = overlong_buffer_cfg\n        self.max_resp_len = max_resp_len\n\n        if self.overlong_buffer_cfg is not None:\n            assert self.max_resp_len is not None, (\n                f\"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None\"\n            )\n            assert self.max_resp_len >= self.overlong_buffer_cfg.len, (\n                \"max_resp_len must be larger than overlong_buffer.len\"\n            )\n\n    def __call__(self, data: DataProto, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                return {\"reward_tensor\": data.batch[\"rm_scores\"]}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n            eos_token = self.tokenizer.eos_token\n            if response_str.endswith(eos_token):\n                response_str = response_str[: -len(eos_token)]\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n\n            data_source = data_item.non_tensor_batch[self.reward_fn_key]\n\n            extra_info = data_item.non_tensor_batch.get(\"extra_info\", None)\n\n            result = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            score: float\n            if isinstance(result, dict):\n                score = result[\"score\"]\n                # Store the information including original reward\n                for key, value in result.items():\n                    reward_extra_info[key].append(value)\n            else:\n                score = result\n                reward_extra_info[\"acc\"].append(score)\n\n            reward = score\n\n            # guard against the None case explicitly allowed by the constructor\n            if self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable:\n                overlong_buffer_len = self.overlong_buffer_cfg.len\n                expected_len = self.max_resp_len - overlong_buffer_len\n                exceed_len = valid_response_length - expected_len\n                overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n                reward += overlong_reward\n                if self.overlong_buffer_cfg.log:\n                    reward_extra_info[\"overlong_reward\"].append(overlong_reward)\n                    reward_extra_info[\"overlong\"].append(overlong_reward < 0)\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                if isinstance(result, dict):\n
                    for key, value in result.items():\n                        print(f\"[{key}]\", value)\n                else:\n                    print(\"[score]\", score)\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
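The overlong-buffer branch above implements DAPO's soft length penalty: once a response enters the final `overlong_buffer.len` tokens before `max_resp_len`, its reward is reduced linearly, capped at zero from above. A standalone restatement of that arithmetic (the function name is illustrative):

```python
# Soft overlong penalty as computed in DAPORewardManager.__call__ above.
def overlong_penalty(resp_len: int, max_resp_len: int, buffer_len: int, penalty_factor: float) -> float:
    expected_len = max_resp_len - buffer_len  # longest response with zero penalty
    exceed_len = resp_len - expected_len
    return min(-exceed_len / buffer_len * penalty_factor, 0.0)

# With max_resp_len=1024 and buffer_len=256: no penalty at 768 tokens,
# full -penalty_factor at the 1024-token cap.
assert overlong_penalty(768, 1024, 256, 1.0) == 0.0
assert overlong_penalty(1024, 1024, 256, 1.0) == -1.0
```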
  {
    "path": "verl_rl/verl/workers/reward_manager/naive.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nimport torch\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\n\n\n@register(\"naive\")\nclass NaiveRewardManager:\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key=\"data_source\") -> None:\n        \"\"\"\n        Initialize the NaiveRewardManager instance.\n\n        Args:\n            tokenizer: The tokenizer used to decode token IDs into text.\n            num_examine: The number of batches of decoded responses to print to the console for debugging purpose.\n            compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.\n            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to\n                \"data_source\".\n        \"\"\"\n        self.tokenizer = tokenizer  # Store the tokenizer for decoding token IDs\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key  # Store the key for accessing the data source\n\n    def __call__(self, data: DataProto, return_dict=False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                return {\"reward_tensor\": data.batch[\"rm_scores\"]}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # DataProtoItem\n\n            prompt_ids = data_item.batch[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item.batch[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item.batch[\"responses\"]\n            valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n\n            ground_truth = data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"]\n            data_source = data_item.non_tensor_batch[self.reward_fn_key]\n            extra_info = data_item.non_tensor_batch.get(\"extra_info\", {})\n            num_turns = data_item.non_tensor_batch.get(\"__num_turns__\", None)\n            extra_info[\"num_turns\"] = num_turns\n\n            score = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            if isinstance(score, dict):\n                reward = score[\"score\"]\n                # Store the information including original reward\n                for key, value in score.items():\n                    reward_extra_info[key].append(value)\n            else:\n                reward = score\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(\"[prompt]\", prompt_str)\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", ground_truth)\n                if isinstance(score, dict):\n                    for key, value in score.items():\n                        print(f\"[{key}]\", value)\n                else:\n                    print(\"[score]\", score)\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
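All of these managers emit a sparse reward: the scalar score is written at the last valid response token (`reward_tensor[i, valid_response_length - 1]`) and every other position stays zero. A tiny self-contained demo of that placement:

```python
# Demo of the sparse reward placement used by the reward managers above.
import torch

batch_size, response_length = 2, 6
valid_lens = torch.tensor([4, 6])    # per-sample valid response lengths
rewards = torch.tensor([1.0, 0.5])

reward_tensor = torch.zeros(batch_size, response_length)
reward_tensor[torch.arange(batch_size), valid_lens - 1] = rewards
# row 0 -> [0, 0, 0, 1.0, 0, 0]; row 1 -> [0, 0, 0, 0, 0, 0.5]
```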
  {
    "path": "verl_rl/verl/workers/reward_manager/prime.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nfrom concurrent.futures import ProcessPoolExecutor\nfrom functools import partial\nfrom typing import Callable, Optional\n\nimport psutil\nimport torch\nfrom transformers import PreTrainedTokenizer\n\nfrom verl import DataProto\nfrom verl.utils.reward_score import default_compute_score\nfrom verl.workers.reward_manager import register\n\n\nasync def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):\n    loop = asyncio.get_running_loop()\n    try:\n        # Ensure process_completion is called properly\n        future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info))\n        return await asyncio.wait_for(future, timeout=timeout)\n    except asyncio.TimeoutError:\n        print(f\"[Timeout] Task timeout: {completion}\")\n        return None  # Default value for timed-out rows\n    except Exception as e:\n        print(f\"[Error] Task failed: {e}, completion: {completion[:80]}\")\n        return None  # Default value for failed rows\n\n\nasync def parallel_compute_score_async(\n    evaluation_func, completions, references, tasks, extra_info=None, num_processes=64\n):\n    if extra_info is None:\n        extra_info = [None] * len(tasks)\n    scores = []\n    with ProcessPoolExecutor(max_workers=num_processes) as executor:\n        # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the\n        # exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed.\n        try:\n            # Create tasks for all rows\n            tasks_async = [\n                single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=300.0)\n                for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True)\n            ]\n            results = await asyncio.gather(*tasks_async, return_exceptions=False)\n        except Exception as e:\n            print(f\"[Exception] async gather failed: {e}\")\n            raise\n        finally:\n            terminated_count = 0\n            for pid, proc in executor._processes.items():\n                try:\n                    p = psutil.Process(pid)\n                    p.terminate()\n                    try:\n                        p.wait(timeout=5)\n                    except psutil.TimeoutExpired:\n                        p.kill()\n                    terminated_count += 1\n                except Exception:\n                    pass\n            print(f\"[Shutdown] {terminated_count} subprocess(es) terminated.\")\n\n    # Process results\n    for result, completion, reference, task in zip(results, completions, references, tasks, strict=True):\n        if isinstance(result, Exception) or result is None:\n            # Handle failed or timed-out tasks\n            scores.append(0.0)\n        
elif isinstance(result, int | float | bool):\n            scores.append(float(result))\n        else:\n            scores.append(float(result[0]))\n    return scores\n\n\ndef run_reward_scoring(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64):\n    loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(loop)\n    try:\n        return loop.run_until_complete(\n            parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes)\n        )\n    finally:\n        loop.close()\n\n\n@register(\"prime\")\nclass PrimeRewardManager:\n    \"\"\"\n    The Reward Manager used in https://github.com/PRIME-RL/PRIME\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        num_examine: int,\n        compute_score: Optional[Callable] = None,\n        reward_fn_key: str = \"data_source\",\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n\n    def verify(self, data):\n        \"\"\"\n        verify the batch and save as ``acc`` tensor\n        \"\"\"\n        # batched scoring\n        prompt_ids = data.batch[\"prompts\"]\n\n        response_ids = data.batch[\"responses\"]\n        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)\n        ground_truth = [data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"] for data_item in data]\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n        extra_info = data.non_tensor_batch.get(\"extra_info\", None)\n\n        assert len(sequences_str) == len(ground_truth) == len(data_sources)\n        try:\n            scores = run_reward_scoring(\n                self.compute_score,\n                completions=sequences_str,\n                references=ground_truth,\n                tasks=data_sources,\n                extra_info=extra_info,\n                num_processes=64,\n            )\n        except asyncio.TimeoutError:\n            print(\"[Timeout] Global reward scoring timed out. Setting all as 0.\")\n            scores = [0.0 for _ in range(len(sequences_str))]\n        except Exception as e:\n            print(f\"[Error] Unexpected error during scoring. Setting all as 0. {e}\")\n            scores = [0.0 for _ in range(len(sequences_str))]\n        data.batch[\"acc\"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)\n        return scores\n\n    def __call__(self, data: DataProto, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.batch.keys():\n            if return_dict:\n                return {\"reward_tensor\": data.batch[\"rm_scores\"]}\n            else:\n                return data.batch[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data.batch[\"responses\"], dtype=torch.float32)\n\n        already_print_data_sources = {}\n\n        # batched scoring\n        prompt_ids = data.batch[\"prompts\"]\n        prompt_length = prompt_ids.shape[-1]\n\n        response_ids = data.batch[\"responses\"]\n        valid_response_length = data.batch[\"attention_mask\"][:, prompt_length:].sum(dim=-1)\n        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)\n        data_sources = data.non_tensor_batch[self.reward_fn_key]\n\n        scores = self.verify(data)\n\n        for i in range(len(data)):\n            data_source = data_sources[i]\n            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                print(sequences_str[i])\n\n        if return_dict:\n            return {\"reward_tensor\": reward_tensor}\n        else:\n            return reward_tensor\n"
  },
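The core defensive pattern in `prime.py` is scoring each completion in a worker process under a hard deadline and falling back to `None` (scored as 0.0 downstream) on timeout. A hedged, simplified reproduction of that pattern; the real code above additionally terminates leftover worker processes via `psutil`:

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def toy_score(completion: str) -> float:
    # Stand-in for evaluation_func; must be module-level to be picklable.
    return float(len(completion) > 0)


async def guarded_score(executor, completion, timeout=5.0):
    loop = asyncio.get_running_loop()
    try:
        future = loop.run_in_executor(executor, partial(toy_score, completion))
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        return None  # timed-out rows fall back to a 0.0 score downstream


async def main():
    with ProcessPoolExecutor(max_workers=2) as executor:
        print(await asyncio.gather(*(guarded_score(executor, c) for c in ["a", ""])))


if __name__ == "__main__":
    asyncio.run(main())
```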
  {
    "path": "verl_rl/verl/workers/reward_manager/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = [\"register\", \"get_reward_manager_cls\"]\n\nREWARD_MANAGER_REGISTRY = {}\n\n\ndef register(name):\n    \"\"\"Decorator to register a reward manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward manager.\n    \"\"\"\n\n    def decorator(cls):\n        if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls:\n            raise ValueError(\n                f\"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}\"\n            )\n        REWARD_MANAGER_REGISTRY[name] = cls\n        return cls\n\n    return decorator\n\n\ndef get_reward_manager_cls(name):\n    \"\"\"Get the reward manager class with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the reward manager.\n\n    Returns:\n        `(type)`: The reward manager class.\n    \"\"\"\n    if name not in REWARD_MANAGER_REGISTRY:\n        raise ValueError(f\"Unknown reward manager: {name}\")\n    return REWARD_MANAGER_REGISTRY[name]\n"
  },
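A usage sketch of the registry above: register a custom manager under a new name, then resolve it the way a trainer would when reading a `reward_manager` config value. `MyRewardManager` is illustrative, not part of verl:

```python
from verl.workers.reward_manager import get_reward_manager_cls, register


@register("mine")
class MyRewardManager:
    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source"):
        self.compute_score = compute_score

    def __call__(self, data, return_dict=False):
        raise NotImplementedError  # scoring logic would go here


assert get_reward_manager_cls("mine") is MyRewardManager
# Re-registering "mine" with a *different* class raises ValueError; re-registering
# the same class is a no-op, and unknown names raise ValueError on lookup.
```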
  {
    "path": "verl_rl/verl/workers/reward_model/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPORewardModel\n\n__all__ = [\"BasePPORewardModel\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/reward_model/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for reward model\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom verl import DataProto\n\n\nclass BasePPORewardModel(ABC):\n    def __init__(self, config):\n        self.config = config\n\n    @abstractmethod\n    def compute_reward(self, data: DataProto) -> DataProto:\n        \"\"\"Computing reward given input_ids. The transformers should output a tensor with shape\n           [batch_size, sequence_length], and the value at [EOS] mask should be gathered.\n\n        Args:\n            data: must contain keys \"input_ids\", \"attention_mask\" and \"position_ids\".\n                - input_ids: [batch_size, sequence_length]\n                - attention_mask: [batch_size, sequence_length]\n                - position_ids: [batch_size, sequence_length]\n\n        Returns: a data pass protocol containing \"reward\". Only the [EOS] position contains the reward.\n            Other position should have zero reward. Note that this may change in the future if we use\n            dense reward. So, we leave the interface for general case.\n            - reward: [batch_size, sequence_length].\n\n        \"\"\"\n        pass\n"
  },
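A hedged sketch of a `BasePPORewardModel` subclass honoring the contract documented above: return a `DataProto` whose reward tensor is zero everywhere except the last attended ([EOS]) position. The constant reward value is purely illustrative:

```python
import torch
from tensordict import TensorDict

from verl import DataProto
from verl.workers.reward_model.base import BasePPORewardModel


class ConstantRewardModel(BasePPORewardModel):
    def compute_reward(self, data: DataProto) -> DataProto:
        attention_mask = data.batch["attention_mask"]  # (batch_size, sequence_length)
        batch_size = attention_mask.shape[0]
        # cumsum peaks at the last attended token, so argmax recovers its index
        last_valid = attention_mask.cumsum(dim=-1).argmax(dim=-1)
        reward = torch.zeros_like(attention_mask, dtype=torch.float32)
        reward[torch.arange(batch_size), last_valid] = 1.0  # illustrative constant reward
        return DataProto(batch=TensorDict({"reward": reward}, batch_size=[batch_size]))
```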
  {
    "path": "verl_rl/verl/workers/reward_model/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .reward_model import MegatronRewardModel\n\n__all__ = [\"MegatronRewardModel\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/reward_model/megatron/reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Reward Model.\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom tensordict import TensorDict\n\nfrom verl import DataProto\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length\nfrom verl.workers.reward_model.base import BasePPORewardModel\n\n\nclass MegatronRewardModel(BasePPORewardModel):\n    def __init__(\n        self,\n        config,\n        model_config,\n        reward_model_module: torch.nn.ModuleList,\n        hf_config,\n        tf_config,\n        sft_tokenizer=None,\n        rm_tokenizer=None,\n    ):\n        self.config = config\n        self.reward_model_module = reward_model_module\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.model_config = model_config\n        self.device = \"cuda\"\n        self.sft_tokenizer = sft_tokenizer\n        self.rm_tokenizer = rm_tokenizer\n        self.use_different_tokenizer = rm_tokenizer is not None\n\n        print(f\"MegatronRewardModel.config: {self.config}\")\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n\n    def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:\n        assert self.use_different_tokenizer, \"re-encode need rm tokenizer not be None!\"\n        # need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids\n        # 1. remove pad for each sequence\n        # 2. decode by sft_tokenizer, remove sft system prompts\n        # 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids\n        # 4. generate attention_mask and position_ids\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len)\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        ori_values = {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n        _, ori_seqlen = input_ids.size(0), input_ids.size(1)\n        input_ids_for_rm = []\n        attention_mask_for_rm = []\n        position_ids_for_rm = []\n        print_decode = True\n        ori_seqlen = ori_seqlen + 128\n        for id, mask in zip(input_ids, attention_mask, strict=True):\n            # 1. remove pad for each sequence\n            non_zero_indices = torch.nonzero(mask).view(-1)\n            begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item()\n            valid_id = id[begin_pos : end_pos + 1]\n            # 2. 
decode by sft_tokenizer, remove sft system prompts\n            decode_result = self.sft_tokenizer.decode(valid_id)\n            # workaround\n            decode_with_rm_chat = (\n                decode_result.replace(\"<|user|>\\n\", \"[INST] \")\n                .replace(\"</s>\\n<|assistant|>\\n\", \" [/INST]\")\n                .replace(\"</s> \\n<|assistant|>\\n\", \" [/INST]\")\n                + \"</s>\"\n            )\n            if print_decode and torch.distributed.get_rank() == 0:\n                # only print first decode result\n                print(\n                    f\"device {get_device_id()}: sft decode result:\\n{decode_result}\\n \\\n                        \\ndevice {get_device_id()}: sft decode result with \\\n                        rm chat template:\\n{decode_with_rm_chat}\\n\\n\"\n                )\n                print_decode = False\n            # 3. encode by rm_tokenizer\n            rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors=\"pt\")[\"input_ids\"][0].to(\n                input_ids.device\n            )\n            # 4. generate attention_mask and position_ids\n            rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device)\n            cur_seqlen = rm_input_ids.shape[-1]\n            # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128)\n            if cur_seqlen > ori_seqlen:\n                print(f\"warning: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}\")\n                rm_input_ids = rm_input_ids[:ori_seqlen]\n                rm_attention_mask = rm_attention_mask[:ori_seqlen]\n            else:\n                # right padding\n                rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)\n                rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)\n            rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)\n            input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))\n            attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0))\n            position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0))\n        input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0)\n        attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0)\n        position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0)\n\n        # (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change\n        # NOTE(gh): need to restore the original values after computing the reward!\n        data.batch[\"input_ids\"] = input_ids_for_rm\n        data.batch[\"attention_mask\"] = attention_mask_for_rm\n        data.batch[\"position_ids\"] = position_ids_for_rm\n\n        return data, ori_values\n\n    @torch.no_grad()\n    def compute_reward(self, data: DataProto) -> DataProto:\n        if self.config.megatron.param_offload:\n            self.load_params_to_cuda()\n\n        if self.use_different_tokenizer:\n            data, ori_values = self.re_encode_by_rm_tokenizer(data)\n\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len')\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro 
batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"use_dynamic_bsz is True, but max_token_len is None!\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n\n        responses = data.batch[\"responses\"]\n        batch_size = responses.size(0)\n        response_length = responses.size(1)\n\n        with torch.no_grad():\n            output = self.forward_batch(\n                data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len\n            )\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                logits = torch.cat(output[\"output\"], dim=0)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == logits.size(0), f\"{len(indices)} vs. {logits.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    logits = logits[revert_indices]\n            else:\n                logits = torch.empty(\n                    (input_ids.shape[0], input_ids.shape[1]),\n                    device=input_ids.device,\n                )\n            logits = logits.to(torch.float32)\n\n            # broadcast across pp ranks\n            torch.distributed.broadcast(\n                tensor=logits,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n                async_op=False,\n            )\n\n        # (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen')\n        token_level_rewards = logits\n        # find the last token reward\n        ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1)  # (bs, 1)\n        rewards = torch.gather(token_level_rewards, dim=1, index=ends)  # (bs, 1)\n\n        if self.use_different_tokenizer:\n            data.batch.update(ori_values)\n            input_ids = ori_values[\"input_ids\"]\n            attention_mask = ori_values[\"attention_mask\"]\n            position_ids = ori_values[\"position_ids\"]\n\n        token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1])  # (bs, ori_seqlen)\n\n        # assign last valid token reward to ori position\n        if position_ids.dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            position_ids = position_ids[:, 0, :]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bs,)\n        eos_mask = torch.zeros_like(attention_mask)\n        eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0\n\n        token_level_rewards = token_level_rewards * eos_mask\n        token_level_rewards = token_level_rewards[:, -response_length:]\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n        else:\n            # add empty cache after each compute\n            get_torch_device().empty_cache()\n\n        batch = TensorDict({\"rm_scores\": token_level_rewards}, batch_size=input_ids.shape[0])\n\n        return DataProto(batch=batch)\n\n    def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        mini_batch = data\n        mini_batch.batch = mini_batch.batch.contiguous()\n        broadcast_dict_tensor(\n            mini_batch.batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in mini_batch.non_tensor_batch.keys()\n        if self.has_multi_modal_inputs:\n            mini_batch.batch[\"multi_modal_inputs\"] = mini_batch.non_tensor_batch[\"multi_modal_inputs\"]\n            mini_batch.batch[\"multi_modal_inputs_idx\"] = torch.Tensor(\n                list(range(len(mini_batch.non_tensor_batch[\"multi_modal_inputs\"])))\n            ).to(torch.int64)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch.batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        # compute input shapes for pp stages\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output):\n            return torch.tensor(1.0, device=output.device), output\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from verl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            multi_modal_inputs = {}\n            if \"multi_modal_inputs\" in batch:\n                for key in batch[\"multi_modal_inputs\"][0].keys():\n                    multi_modal_inputs[key] = torch.cat(\n                        [batch[\"multi_modal_inputs\"][i][key] for i in 
batch[\"multi_modal_inputs_idx\"]], dim=0\n                    )\n\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                sequence_parallel=self.tf_config.sequence_parallel,\n                value_model=True,\n                multi_modal_inputs=multi_modal_inputs,\n            )\n\n            return output, loss_func\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=True,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=True,\n            )\n\n        if self.has_multi_modal_inputs:\n            data.batch.pop(\"multi_modal_inputs\")\n            data.batch.pop(\"multi_modal_inputs_idx\")\n            data.non_tensor_batch.pop(\"multi_modal_inputs\")\n        # loss_reduces contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    def offload_params_to_cpu(self):\n        if self.device in [\"cuda\", \"npu\"]:\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(\"cpu\", non_blocking=True)\n            self.device = \"cpu\"\n            get_torch_device().empty_cache()\n\n    def load_params_to_cuda(self):\n        if self.device == \"cpu\":\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(get_device_id(), non_blocking=True)\n            self.device = get_device_name()\n"
  },
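`compute_reward` above locates the last attended token twice, with two different index tricks: `attention_mask.cumsum(dim=-1).argmax(dim=-1)` and `torch.argmax(position_ids * attention_mask, dim=-1)`. A tiny demo showing both recover the same index, assuming right padding and monotonically increasing position ids:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
position_ids = torch.arange(6).unsqueeze(0)

# cumsum is [1, 2, 3, 4, 4, 4]; argmax returns the *first* max, i.e. the last 1
ends = attention_mask.cumsum(dim=-1).argmax(dim=-1)
# position_ids * mask is [0, 1, 2, 3, 0, 0]; the largest attended position wins
eos_idx = torch.argmax(position_ids * attention_mask, dim=-1)

assert ends.item() == eos_idx.item() == 3
```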
  {
    "path": "verl_rl/verl/workers/roles/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .critic import CriticWorker\n\n__all__ = [\"CriticWorker\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/roles/actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\n\n\nclass ActorWorker(Worker):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def update_actor(self, data: DataProto):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_log_prob(self, data: DataProto):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    def compute_ref_log_prob(self, data: DataProto):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        raise NotImplementedError\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        raise NotImplementedError\n"
  },
  {
    "path": "verl_rl/verl/workers/roles/critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport logging\nimport os\n\nimport torch\nfrom codetiming import Timer\n\nfrom verl import DataProto\nfrom verl.single_controller.base import Worker\nfrom verl.single_controller.base.decorator import Dispatch, register\nfrom verl.trainer.ppo import core_algos\nfrom verl.utils.config import omega_conf_to_dataclass\nfrom verl.utils.device import (\n    get_device_id,\n    get_nccl_backend,\n)\nfrom verl.utils.profiler import DistProfiler, DistProfilerExtension\nfrom verl.utils.py_functional import append_to_dict\nfrom verl.utils.torch_functional import masked_mean\nfrom verl.workers.engine import EngineRegistry\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass CriticWorker(Worker, DistProfilerExtension):\n    def __init__(self, config):\n        Worker.__init__(self)\n        DistProfilerExtension.__init__(\n            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get(\"profiler\")))\n        )\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n        self.engine = EngineRegistry.new(self.config.strategy, self.config)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def init_model(self):\n        self.engine.init_model()\n\n    def _post_fn_values(self, micro_batch, preds):\n        response_length = micro_batch[\"responses\"].size(-1)\n        values = preds[:, -response_length - 1 : -1]\n\n        use_remove_padding = self.config.model.get(\"use_remove_padding\", False)\n        if not use_remove_padding:\n            values = values.squeeze(-1)\n\n        return values, {\"values\": values.clone().detach()}\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"cyan\")\n    def compute_values(self, data: DataProto):\n        # Support all hardwares\n        data = data.to(get_device_id())\n        micro_batch_size = self.config.forward_micro_batch_size_per_gpu\n        data.meta_info[\"micro_batch_size\"] = micro_batch_size\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n\n        with self.engine.eval_mode():\n            data = self.engine.shard_data(data=data)\n            output = self.engine.infer_batch(data, post_fn=self._post_fn_values)\n            response_mask = data.batch[\"response_mask\"]\n            values = output[\"values\"] * response_mask  # Only action tokens have values\n            output = DataProto.from_dict(tensors={\"values\": values})\n\n            output = self.engine.unshard_data(data=output)\n        output = output.to(\"cpu\")\n        return output\n\n    def loss_fn(\n        
self, batch: DataProto, vpreds: dict[str, torch.Tensor]\n    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        old_values = batch[\"values\"]\n        returns = batch[\"returns\"]\n        response_mask = batch[\"response_mask\"]\n        micro_batch_metrics = {}\n\n        values, _ = self._post_fn_values(batch, vpreds)\n\n        vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n            vpreds=values,\n            values=old_values,\n            returns=returns,\n            response_mask=response_mask,\n            cliprange_value=self.config.cliprange_value,\n            loss_agg_mode=self.config.loss_agg_mode,\n        )\n        if self.config.use_dynamic_bsz:\n            # relative to the dynamic bsz\n            loss = vf_loss * (len(batch) / self.config.ppo_mini_batch_size)\n        else:\n            gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n            loss = vf_loss / gradient_accumulation\n\n        micro_batch_metrics = {\n            \"critic/vf_loss\": vf_loss.detach().item(),\n            \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n            \"critic/vpred_mean\": masked_mean(values, response_mask).detach().item(),\n        }\n\n        return loss, micro_batch_metrics\n\n    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)\n    @DistProfiler.annotate(color=\"pink\")\n    def update_critic(self, data: DataProto):\n        metrics = {}\n        # Support all hardwares\n        data = data.to(get_device_id())\n        # perform forward computation\n        with self.engine.train_mode():\n            data = self.engine.shard_data(data=data)\n\n            with Timer(name=\"update_critic\", logger=None) as timer:\n                select_keys = [\n                    \"input_ids\",\n                    \"responses\",\n                    \"response_mask\",\n                    \"attention_mask\",\n                    \"position_ids\",\n                    \"values\",\n                    \"returns\",\n                ]\n                batch = data.select(batch_keys=select_keys).batch\n                has_multi_modal_inputs = \"multi_modal_inputs\" in data.non_tensor_batch.keys()\n\n                # Split to make minibatch iterator for updating the actor\n                # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n                if has_multi_modal_inputs:\n                    num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size\n                    non_tensor_select_keys = [\"multi_modal_inputs\"]\n                    dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)\n                else:\n                    dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n                for epoch in range(self.config.ppo_epochs):\n                    for batch_idx, mini_batch in enumerate(dataloader):\n                        self.engine.optimizer_zero_grad()\n                        mini_batch_metrics = self.engine.train_batch(mini_batch, self.loss_fn)\n                        grad_norm = self.engine.optimizer_step()\n                        mini_batch_metrics[\"critic/grad_norm\"] = grad_norm.detach().item()\n                        append_to_dict(metrics, mini_batch_metrics)\n                self.engine.optimizer_zero_grad()\n            delta_time = timer.last\n\n            # TODO: should not access engine's flops_counter\n            global_num_tokens = data.meta_info[\"global_token_num\"]\n            estimated_flops, promised_flops = self.engine.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size\n\n            metrics[\"critic/lr\"] = self.engine.lr_scheduler_step()[0]\n            output = DataProto(batch=None, meta_info={\"metrics\": metrics})\n            output = self.engine.unshard_data(data=output)\n\n        output = output.to(\"cpu\")\n        return output\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)\n\n    @register(dispatch_mode=Dispatch.ONE_TO_ALL)\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)\n"
  },
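  {
    "path": "verl_rl/examples/sketches/critic_loss_scaling_sketch.py",
    "content": "\"\"\"Minimal numeric sketch (hypothetical example file, not part of verl): why\nCriticWorker.loss_fn above rescales the value loss before backward.\n\nWithout dynamic batch sizes, each micro-batch loss is divided by the number of\nmicro-batches per mini-batch, so the accumulated gradient equals the gradient of\nthe mini-batch mean. With dynamic batch sizes, micro-batches hold unequal numbers\nof samples, so each loss is weighted by its share of the mini-batch instead. All\nnumbers below are made up.\n\"\"\"\n\nppo_mini_batch_size = 8\nppo_micro_batch_size_per_gpu = 2\n\n# Fixed-size path: 4 equal micro-batches of 2 samples each.\nmicro_losses = [1.0, 2.0, 3.0, 2.0]\ngradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu\naccumulated = sum(loss / gradient_accumulation for loss in micro_losses)\n\n# Dynamic-bsz path: the same 8 samples re-packed into micro-batches of 3 and 5;\n# each mean loss is weighted by len(micro_batch) / ppo_mini_batch_size.\ndynamic = [(4 / 3, 3), (12 / 5, 5)]  # (mean loss over the micro-batch, num samples)\nweighted = sum(loss * (n / ppo_mini_batch_size) for loss, n in dynamic)\n\nprint(accumulated, weighted)  # both equal 2.0, the mini-batch mean loss\n"
  },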
  {
    "path": "verl_rl/verl/workers/rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BaseRollout\nfrom .hf_rollout import HFRollout\nfrom .naive import NaiveRollout\n\n__all__ = [\"BaseRollout\", \"NaiveRollout\", \"HFRollout\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/async_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nimport os\nimport socket\nimport threading\nfrom abc import ABC, abstractmethod\nfrom contextlib import asynccontextmanager\nfrom typing import Any, Optional\n\nimport fastapi\nimport ray\nimport uvicorn\nfrom omegaconf import DictConfig\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse\n\nfrom verl.protocol import DataProto\nfrom verl.single_controller.ray.base import RayWorkerGroup\nfrom verl.workers.rollout.chat_scheduler import ChatCompletionScheduler\n\nlogger = logging.getLogger(__file__)\n\n\ndef _get_free_port():\n    with socket.socket() as sock:\n        sock.bind((\"\", 0))\n        return sock.getsockname()[1]\n\n\nclass AsyncServerBase(ABC):\n    \"\"\"Base class for AsyncServer.\"\"\"\n\n    def __init__(self):\n        self.address = ray.util.get_node_ip_address()\n        self.port = None\n        self.server_ready = asyncio.Event()\n        asyncio.create_task(self._start_fastapi_server())\n\n    async def _start_fastapi_server(self):\n        @asynccontextmanager\n        async def lifespan(app: fastapi.FastAPI):\n            print(f\"FastAPI listen on {self.address}:{self.port}\")\n            self.server_ready.set()\n            yield\n\n            # There's no way to gracefully restart uvicorn server if port is already in use,\n            # so we exit the process directly and let AsyncLLMServerManager restart it.\n            print(\"FastAPI shutdown, maybe address already in use, exit process immediately.\")\n            os._exit(-1)\n\n        app = fastapi.FastAPI(lifespan=lifespan)\n        app.router.add_api_route(\"/v1/chat/completions\", self.chat_completion, methods=[\"POST\"])\n\n        self.port = _get_free_port()\n        config = uvicorn.Config(app, host=[\"::\", \"0.0.0.0\"], port=self.port, log_level=\"warning\")\n        server = uvicorn.Server(config)\n        await server.serve()\n\n    async def get_server_address(self) -> tuple[str, int]:\n        \"\"\"Get FastAPI server address.\"\"\"\n        await self.server_ready.wait()\n        return f\"{self.address}:{self.port}\"\n\n    @abstractmethod\n    async def chat_completion(self, raw_request: Request) -> JSONResponse:\n        \"\"\"OpenAI chat completion API.\n\n        Args:\n            raw_request (Request): raw json request\n\n        Returns:\n            JSONResponse: json response\n\n        API reference: https://platform.openai.com/docs/api-reference/chat/create\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n        \"\"\"Generate response ids given prompt ids.\n\n        Args:\n            prompt_ids (List[int]): prompt ids\n            sampling_params (Dict[str, Any]): sampling params\n            request_id (str): request id\n\n        Returns:\n        
List[int]: response ids\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def init_engine(self):\n        \"\"\"Init async LLM engine.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def wake_up(self):\n        \"\"\"Wake up engine to load model weights and build kv cache.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def sleep(self):\n        \"\"\"Sleep engine to offload model weights and discard kv cache.\"\"\"\n        raise NotImplementedError\n\n\nclass AsyncLLMServerManager:\n    \"\"\"AsyncLLMServerManager manages a group of vLLM instances, i.e. AsyncvLLMServer.\"\"\"\n\n    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):\n        \"\"\"Initialize AsyncLLMServerManager.\n\n        Args:\n            config: DictConfig, actor_rollout_ref config.\n            worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker.\n        \"\"\"\n        self.full_config = config\n        self.config = config.actor_rollout_ref\n        self.worker_group = worker_group\n\n        self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size\n        self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size\n\n        register_center = ray.get_actor(f\"{self.worker_group.name_prefix}_register_center\")\n        workers_info = ray.get(register_center.get_worker_info.remote())\n        assert len(workers_info) == self.worker_group.world_size\n\n        self.async_llm_servers = [None] * self.rollout_dp_size\n        self.server_addresses = [None] * self.rollout_dp_size\n\n        if self.config.rollout.agent.custom_async_server:\n            server_class = async_server_class(\n                rollout_backend=self.config.rollout.name,\n                rollout_backend_module=self.config.rollout.agent.custom_async_server.path,\n                rollout_backend_class=self.config.rollout.agent.custom_async_server.name,\n            )\n        else:\n            server_class = async_server_class(rollout_backend=self.config.rollout.name)\n\n        # Start all server instances, restart if address already in use.\n        unready_dp_ranks = set(range(self.rollout_dp_size))\n        while len(unready_dp_ranks) > 0:\n            servers = {\n                rollout_dp_rank: server_class.options(\n                    # make sure AsyncvLLMServer colocates with its corresponding workers\n                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n                        node_id=workers_info[rollout_dp_rank * self.rollout_tp_size],\n                        soft=False,\n                    ),\n                    name=f\"async_llm_server_{rollout_dp_rank}\",\n                ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix)\n                for rollout_dp_rank in unready_dp_ranks\n            }\n\n            for rollout_dp_rank, server in servers.items():\n                try:\n                    address = ray.get(server.get_server_address.remote())\n                    self.server_addresses[rollout_dp_rank] = address\n                    self.async_llm_servers[rollout_dp_rank] = server\n                    unready_dp_ranks.remove(rollout_dp_rank)\n                except Exception:\n                    ray.kill(server)\n                    print(f\"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...\")\n\n        # All server instances are ready, 
init AsyncLLM engine.\n        ray.get([server.init_engine.remote() for server in self.async_llm_servers])\n\n        # Init user-provided chat scheduler in a separate thread.\n        self.chat_scheduler: ChatCompletionScheduler = None\n        self.chat_scheduler_exception: Exception = None\n        self.chat_scheduler_loop = None\n        self.chat_scheduler_ready = threading.Event()\n        self.chat_scheduler_thread = threading.Thread(target=self._init_chat_scheduler, daemon=True)\n        self.chat_scheduler_thread.start()\n        self.chat_scheduler_ready.wait()\n\n    def _init_chat_scheduler(self):\n        self.chat_scheduler_loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(self.chat_scheduler_loop)\n\n        try:\n            self.chat_scheduler = ChatCompletionScheduler(\n                config=self.full_config,\n                server_addresses=self.server_addresses,\n            )\n        except Exception as e:\n            logger.exception(f\"chat_scheduler init error: {e}\")\n            self.chat_scheduler_exception = e\n        finally:\n            self.chat_scheduler_ready.set()\n        self.chat_scheduler_loop.run_forever()\n\n    def wake_up(self):\n        \"\"\"Wake up all vllm instances.\"\"\"\n        if self.config.rollout.free_cache_engine:\n            ray.get([server.wake_up.remote() for server in self.async_llm_servers])\n\n    def sleep(self):\n        \"\"\"Sleep all vllm instances.\"\"\"\n        if self.config.rollout.free_cache_engine:\n            ray.get([server.sleep.remote() for server in self.async_llm_servers])\n\n    def submit_chat_completions(\n        self,\n        messages: list[dict[str, str]],\n        sampling_params: dict[str, Any],\n    ):\n        \"\"\"Submit a chat completion request to chat scheduler and wait until it is done.\n        To submit multiple requests in parallel, please use `generate_sequences` instead.\n\n        Args: same as ChatCompletionScheduler.submit_chat_completions.\n        \"\"\"\n        assert self.chat_scheduler is not None, \"chat scheduler is not initialized.\"\n        future = asyncio.run_coroutine_threadsafe(\n            self.chat_scheduler._submit_chat_completions_semaphore(\n                messages=messages,\n                request_id=None,\n                sampling_params=sampling_params,\n            ),\n            self.chat_scheduler_loop,\n        )\n        future.result()\n\n    def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:\n        \"\"\"Generate multiple sequences in parallel via chat scheduler.\"\"\"\n        assert self.chat_scheduler is not None, \"chat scheduler is not initialized.\"\n\n        future = asyncio.run_coroutine_threadsafe(\n            self.chat_scheduler.generate_sequences(prompts, **sampling_params), self.chat_scheduler_loop\n        )\n        return future.result()\n\n\ndef async_server_class(\n    rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None\n) -> type[AsyncServerBase]:\n    \"\"\"Get async server class.\n\n    Args:\n        rollout_backend: str, rollout backend type (alias), should be \"vllm\" or \"sglang\".\n        rollout_backend_module: Optional[str], import path of the rollout backend.\n        rollout_backend_class: Optional[str], class name of the rollout backend.\n\n    Returns:\n        Type[AsyncServerBase]: async server class.\n    \"\"\"\n    if rollout_backend_class is None and rollout_backend_module is None:\n        # If both are 
None, use the default backend class\n        # Do not change the original import behavior\n        # importlib.import_module and from ... import ... have subtle differences in ray\n\n        if rollout_backend == \"vllm\":\n            from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer\n\n            return AsyncvLLMServer\n        elif rollout_backend == \"sglang\":\n            from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSGLangServer\n\n            return AsyncSGLangServer\n        else:\n            raise NotImplementedError(f\"rollout backend {rollout_backend} is not supported\")\n\n    if rollout_backend_module is None or rollout_backend_class is None:\n        raise ValueError(\"rollout_backend_module and rollout_backend_class must be both provided for customization\")\n\n    from verl.utils.import_utils import load_extern_type\n\n    return load_extern_type(rollout_backend_module, rollout_backend_class)\n"
  },
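  {
    "path": "verl_rl/examples/sketches/async_server_class_sketch.py",
    "content": "\"\"\"Minimal sketch (hypothetical example file, not part of verl): the backend-selection\npattern behind async_server_class in async_server.py. Known backends resolve through a\ndefault table, while a custom backend is loaded dynamically from a user-supplied module\npath and class name, mirroring verl.utils.import_utils.load_extern_type. The stub\nregistry below is illustrative only.\n\"\"\"\n\nimport importlib\nfrom typing import Optional\n\n\nclass _StubServer:\n    \"\"\"Stand-in for an AsyncServerBase subclass.\"\"\"\n\n\n_DEFAULTS = {\"vllm\": _StubServer, \"sglang\": _StubServer}\n\n\ndef pick_server_class(backend: str, module: Optional[str] = None, name: Optional[str] = None) -> type:\n    if module is None and name is None:\n        if backend not in _DEFAULTS:\n            raise NotImplementedError(f\"rollout backend {backend} is not supported\")\n        return _DEFAULTS[backend]\n    if module is None or name is None:\n        raise ValueError(\"module and name must both be provided for customization\")\n    # the dynamic-import path, equivalent to load_extern_type\n    return getattr(importlib.import_module(module), name)\n\n\nprint(pick_server_class(\"vllm\"))  # resolved from the default table\nprint(pick_server_class(\"custom\", \"collections\", \"OrderedDict\"))  # dynamic import\n"
  },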
  {
    "path": "verl_rl/verl/workers/rollout/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom abc import ABC, abstractmethod\n\nfrom verl import DataProto\n\n__all__ = [\"BaseRollout\"]\n\n\nclass BaseRollout(ABC):\n    \"\"\"Base class for rollout.\"\"\"\n\n    @abstractmethod\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Generate sequences\"\"\"\n        pass\n"
  },
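  {
    "path": "verl_rl/examples/sketches/base_rollout_sketch.py",
    "content": "\"\"\"Minimal sketch (hypothetical example file, not part of verl): the smallest possible\nBaseRollout implementation. It performs no generation and simply echoes the prompt batch\nback, which can serve as a dry-run stub when wiring up a new trainer. Assumes verl is\nimportable, as in the modules above.\n\"\"\"\n\nfrom verl import DataProto\nfrom verl.workers.rollout.base import BaseRollout\n\n\nclass EchoRollout(BaseRollout):\n    \"\"\"A no-op rollout that returns the prompts unchanged.\"\"\"\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        # A real rollout would append sampled response ids plus the extended\n        # attention_mask and position_ids here (see HFRollout and NaiveRollout).\n        return prompts\n"
  },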
  {
    "path": "verl_rl/verl/workers/rollout/chat_scheduler.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport heapq\nimport importlib\nimport itertools\nimport json\nimport logging\nimport time\nfrom abc import ABC, abstractmethod\nfrom typing import Any\nfrom uuid import uuid4\n\nimport aiohttp\nimport numpy as np\nimport torch\nfrom cachetools import LRUCache\nfrom omegaconf import DictConfig\nfrom openai import AsyncOpenAI\nfrom openai.types.chat.chat_completion import ChatCompletion\nfrom tensordict import TensorDict\n\nfrom verl.protocol import DataProto\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\nfrom verl.utils import hf_tokenizer\nfrom verl.utils.fs import copy_to_local\nfrom verl.utils.import_utils import deprecated\n\nlogger = logging.getLogger(__file__)\n\n\nclass CompletionCallback(ABC):\n    def __init__(self, config: DictConfig, scheduler: \"ChatCompletionScheduler\"):\n        self.config = config\n        self.scheduler = scheduler\n\n        # Initialize tools from config file\n        self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns\n        tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []\n        self.tools = {tool.name: tool for tool in tool_list}\n        self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]\n        print(f\"Initialized tools: {self.tools}\", flush=True)\n\n        local_path = copy_to_local(config.actor_rollout_ref.model.path)\n        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)\n\n    @property\n    def tool_schemas(self):\n        \"\"\"OpenAI JSON tool schemas.\"\"\"\n        return self._tool_schemas\n\n    @property\n    def extra_body(self) -> dict[str, Any]:\n        \"\"\"Extra body pass to OpenAI API.\"\"\"\n        return None\n\n    @abstractmethod\n    async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]):\n        \"\"\"Call back function to process completions.\n\n        Args:\n            messages: List of messages including raw prompt and assistant, tool response generated so far.\n            completions: Chat completions from OpenAI compatible server.\n            info: Any other auxiliary information pass across multi-turn.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto:\n        \"\"\"Post process batch data.\n\n        Args:\n            batch: Batch input messages from RLHFDataset.\n            batch_conversations: List of messages including raw prompt, assistant response, tool response.\n                Note that `len(batch_conversations) == len(batch) * n`, e.g n=2,\n                batch_conversations=[messages_0_0, 
messages_0_1, messages_1_0, messages_1_1, ...]\n            n: How many chat completion choices to generate for each input message.\n\n        Returns:\n            Batch data, should include [\"prompts\", \"responses\", \"response_mask\", \"input_ids\", \"attention_mask\",\n            \"position_ids\"].\n        \"\"\"\n        raise NotImplementedError\n\n\nclass ToolCompletionCallback(CompletionCallback):\n    def __init__(self, config: DictConfig, scheduler: \"ChatCompletionScheduler\"):\n        super().__init__(config, scheduler)\n\n        # TODO: add reward manager to calculate reward score once a sample finish\n\n    async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]):\n        message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n        if \"content\" not in message:\n            message[\"content\"] = \"\"\n        messages.append(message)\n        finish_reason = completions.choices[0].finish_reason\n\n        # STEP 0: check if we reach max turns\n        if self.max_assistant_turns and len(messages) >= self.max_assistant_turns:\n            print(f\"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!\")\n            return\n\n        # STEP 1: check if the model called tools\n        if finish_reason != \"tool_calls\":\n            print(f\"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!\")\n            return\n\n        # STEP 2: call tools\n        tool_calls = completions.choices[0].message.tool_calls\n        print(f\"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools\")\n        tasks = []\n        for tool_call in tool_calls:\n            tasks.append(self._call_tool(tool_call))\n        tool_responses = await asyncio.gather(*tasks)\n        if any(isinstance(item, Exception) for item in tool_responses):\n            print(\n                f\"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Error when calling tools, \"\n                f\"done!\"\n            )\n            return\n        messages.extend(tool_responses)\n\n        # STEP 3: resubmit completion request with tool responses\n        self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info)\n\n    async def _call_tool(self, tool_call) -> dict[str, str]:\n        \"\"\"Call tool and return tool response.\"\"\"\n        tool_name = tool_call.function.name\n        tool_args = json.loads(tool_call.function.arguments)\n        tool = self.tools[tool_name]\n\n        instance_id = await tool.create()\n        try:\n            tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args)\n        except Exception as e:\n            logger.exception(f\"Error when executing tool: {e}\")\n            return e\n        finally:\n            await tool.release(instance_id)\n\n        return {\n            \"role\": \"tool\",\n            \"content\": tool_response,\n            \"tool_call_id\": tool_call.id,\n        }\n\n    def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto:\n        # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n        # prompts: left pad\n        # responses: right pad\n        # input_ids: prompt + response\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n  
      # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n\n        # prompts: [prompt] from input dataset\n        prompts = [\n            self.tokenizer.apply_chat_template(\n                prompt, tools=self.tool_schemas, add_generation_prompt=True, tokenize=False\n            )\n            for prompt in batch.non_tensor_batch[\"raw_prompt\"]\n        ]\n        assert len(batch_conversations) == len(prompts) * n\n\n        # sequences: [prompt + response]\n        sequences = [\n            self.tokenizer.apply_chat_template(\n                conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False\n            )\n            for conversation in batch_conversations\n        ]\n\n        # responses: [response]\n        responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]\n\n        prompts = self.tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\", padding_side=\"left\")\n        responses = self.tokenizer(responses, return_tensors=\"pt\", padding=\"longest\", padding_side=\"right\")\n        if n > 1:\n            prompts[\"input_ids\"] = prompts[\"input_ids\"].repeat_interleave(n, dim=0)\n            prompts[\"attention_mask\"] = prompts[\"attention_mask\"].repeat_interleave(n, dim=0)\n\n        # response_mask: response mask with tools calling masked out\n        response_mask = self._mask_out_tools_calling_tokens(\n            batch.non_tensor_batch[\"raw_prompt\"].repeat(n, axis=0),\n            batch_conversations,\n            responses[\"input_ids\"],\n            responses[\"attention_mask\"],\n        )\n\n        input_ids = torch.cat([prompts[\"input_ids\"], responses[\"input_ids\"]], dim=1)\n        attention_mask = torch.cat([prompts[\"attention_mask\"], responses[\"attention_mask\"]], dim=1)\n        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompts[\"input_ids\"],  # [bsz, prompt_length]\n                \"responses\": responses[\"input_ids\"],  # [bsz, response_length]\n                \"response_mask\": response_mask,  # [bsz, response_length]\n                \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n                \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n                \"position_ids\": position_ids,  # [bsz, prompt_length + response_length]\n            },\n            batch_size=len(input_ids),\n        )\n\n        num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32)\n        return DataProto(batch=batch, non_tensor_batch={\"__num_turns__\": num_turns})\n\n    def _mask_out_tools_calling_tokens(\n        self,\n        raw_prompts: list[list[dict[str, str]]],\n        batch_conversations: list[list[dict[str, str]]],\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Mask out tools calling tokens in the responses.\n\n        Args:\n            raw_prompts: [prompt] from input dataset\n            batch_conversations: [prompt + response]\n            input_ids: responses tokens\n            attention_mask: responses attention mask\n\n        Returns:\n            mask: (batch_size, response_length)\n        \"\"\"\n        batch_size = input_ids.size(0)\n        assert len(raw_prompts) == batch_size, f\"{len(raw_prompts)} != {batch_size}\"\n        assert len(batch_conversations) == batch_size, 
f\"{len(batch_conversations)} != {batch_size}\"\n\n        # Deduplicate adjacent tool calls, since they're merged into one turn.\n        # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant]\n        # TODO: it's chat_template specific, find a more generic way to do this.\n        def deduplicate_adjacent_tool_calls(roles):\n            result = []\n            for role, group in itertools.groupby(roles):\n                if role == \"tool\":\n                    result.append(role)\n                else:\n                    result.extend(group)\n            return result\n\n        loss_mask = attention_mask.clone()\n        for i in range(batch_size):\n            responses = batch_conversations[i][len(raw_prompts[i]) :]\n            assert len(responses) > 0, f\"responses is empty: {responses}\"\n\n            roles = deduplicate_adjacent_tool_calls([response[\"role\"] for response in responses])\n            # Each turn should be: [BOS]...[EOS]\n            eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)]\n            for j in range(len(roles)):\n                if roles[j] == \"tool\":\n                    bos = eos_indices[j - 1] + 1 if j > 0 else 0\n                    eos = eos_indices[j]\n                    loss_mask[i, bos : eos + 1] = 0\n\n        return loss_mask\n\n\n@deprecated(\"verl.experimental.agent_loop.AgentLoopManager\")\nclass ChatCompletionScheduler:\n    def __init__(\n        self,\n        config: DictConfig,\n        server_addresses: list[str],\n        max_cache_size: int = 10000,\n    ):\n        \"\"\"\n        Args:\n            config: DictConfig.\n            server_addresses: List[str], OpenAI compatible server addresses.\n            max_cache_size: int, max cache size of request_id to address mapping.\n        \"\"\"\n        self.config = config.actor_rollout_ref.rollout\n        model_path = config.actor_rollout_ref.model.path\n        self.model_name = \"/\".join(model_path.split(\"/\")[-2:])\n\n        # Least requests load balancing\n        self.weighted_addresses = [[0, address] for address in server_addresses]\n        heapq.heapify(self.weighted_addresses)\n\n        # LRU cache to map request_id to address\n        self.request_id_to_address = LRUCache(maxsize=max_cache_size)\n\n        self.background_tasks = set()\n        if self.config.multi_turn.completion_callback is None:\n            self.completion_callback = ToolCompletionCallback(config, self)\n            logger.warning(\"completion_callback is None, use ToolCompletionCallback\")\n        else:\n            module_path, class_name = self.config.multi_turn.completion_callback.rsplit(\".\", 1)\n            module = importlib.import_module(module_path)\n            self.completion_callback = getattr(module, class_name)(config, self)\n\n    def submit_chat_completions(self, *, messages: list[dict[str, str]], request_id: str, info: dict[str, Any]):\n        \"\"\"Submit chat completion request without wait, completion_callback will be called when the request is done.\n\n        Args:\n            messages: List of messages.\n            request_id: Request id.\n            info: Any other auxiliary information pass across multi-turn.\n        \"\"\"\n        info[\"__depth__\"] += 1\n        task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info))\n\n        # “fire-and-forget” background tasks\n        self.background_tasks.add(task)\n        
task.add_done_callback(self.background_tasks.discard)\n\n    async def _submit_chat_completions_and_callback(\n        self,\n        messages: list[dict[str, str]],\n        request_id: str,\n        info: dict[str, Any],\n    ):\n        \"\"\"Submit chat completion request, wait request finish and do callback.\"\"\"\n        if request_id:\n            request_id = request_id.removeprefix(\"chatcmpl-\")\n            assert request_id in self.request_id_to_address\n            address = self.request_id_to_address.pop(request_id)\n        else:\n            address = self.weighted_addresses[0][1]\n            self.weighted_addresses[0][0] += 1\n            heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0])\n\n        # use new request_id to avoid duplicate request_id problem\n        request_id = uuid4().hex\n        self.request_id_to_address[request_id] = address\n\n        completions, exception = None, None\n        try:\n            # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.\n            completions = await self._chat_completions_aiohttp(\n                address,\n                messages=messages,\n                tools=self.completion_callback.tool_schemas,\n                extra_body=self.completion_callback.extra_body,\n                extra_headers={\"x-request-id\": request_id},\n                **info[\"__sampling_params__\"],\n            )\n        except Exception as e:\n            # Let user handle the exception\n            exception = e\n\n        info[\"__depth__\"] -= 1\n\n        if exception is not None:\n            logger.exception(f\"chat completion failed with exception: {exception}\")\n        else:\n            try:\n                await self.completion_callback(messages, completions, info)\n            except Exception as e:\n                logger.exception(f\"completion callback failed with exception: {e}\")\n\n        # No more ongoing completion requests\n        if info[\"__depth__\"] == 0:\n            info[\"__done__\"].set()\n\n    async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:\n        client = AsyncOpenAI(base_url=f\"http://{address}/v1\", api_key=\"token-abc123\", timeout=None, max_retries=0)\n        return await client.chat.completions.create(**chat_complete_request)\n\n    async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:\n        try:\n            extra_body = chat_complete_request.pop(\"extra_body\", {})\n            chat_complete_request.update(extra_body or {})\n            extra_headers = chat_complete_request.pop(\"extra_headers\")\n            timeout = aiohttp.ClientTimeout(total=None)\n            session = aiohttp.ClientSession(timeout=timeout)\n            async with session.post(\n                url=f\"http://{address}/v1/chat/completions\",\n                headers={\"Authorization\": \"Bearer token-abc123\", **extra_headers},\n                json=chat_complete_request,\n            ) as resp:\n                data = await resp.json()\n                return ChatCompletion(**data)\n        finally:\n            await session.close()\n\n    async def generate_sequences(self, batch: DataProto) -> DataProto:\n        t_start = time.time()\n        kwargs = dict(\n            model=self.model_name,\n            temperature=self.config.temperature,\n            top_p=self.config.top_p,\n        )\n\n        # override sampling params for validation\n        if 
batch.meta_info.get(\"validate\", False):\n            kwargs[\"top_p\"] = self.config.val_kwargs.top_p\n            kwargs[\"temperature\"] = self.config.val_kwargs.temperature\n\n        print(f\"[ChatCompletionScheduler] generate_sequences sampling params: {kwargs}\")\n\n        # NOTE: For multi-turn rollout, repeat raw_prompt n times and process each prompt independently,\n        # validation dataset has already been repeated in `PPOTrainer._validate`.\n        n = 1 if batch.meta_info.get(\"validate\", False) else self.config.n\n        tasks, batch_conversations = [], [None] * len(batch) * n\n        for batch_index, conversation in enumerate(batch.non_tensor_batch[\"raw_prompt\"].repeat(n, axis=0)):\n            # raw_prompt: [{\"role\": \"user\", \"content\": \"\"}, [\"role\": \"assistant\", \"content\"], ...]\n            batch_conversations[batch_index] = conversation.tolist()\n\n            tasks.append(\n                asyncio.create_task(\n                    self._submit_chat_completions_semaphore(\n                        messages=batch_conversations[batch_index],\n                        request_id=None,\n                        sampling_params=kwargs,\n                    )\n                )\n            )\n\n        await asyncio.gather(*tasks)\n        output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n)\n        output_batch.meta_info[\"timing\"] = {\"generate_sequences\": time.time() - t_start}\n        print(\"[ChatCompletionScheduler] generate_sequences done\")\n        return output_batch\n\n    async def _submit_chat_completions_semaphore(\n        self, messages: list[dict[str, str]], request_id: str, sampling_params: dict[str, Any]\n    ):\n        done = asyncio.Event()\n\n        info = {\n            \"__done__\": done,\n            \"__depth__\": 0,  # indicate how many ongoing completion requests\n            \"__sampling_params__\": sampling_params,\n        }\n\n        self.submit_chat_completions(messages=messages, request_id=request_id, info=info)\n\n        # Wait until all completion requests are done\n        await done.wait()\n"
  },
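  {
    "path": "verl_rl/examples/sketches/least_requests_lb_sketch.py",
    "content": "\"\"\"Minimal sketch (hypothetical example file, not part of verl): the least-requests load\nbalancing used by ChatCompletionScheduler. Each heap entry is [in_flight_count, address];\nreading the root and pushing it back with an incremented count routes every new request\nto the currently least-loaded server. The addresses below are made up.\n\"\"\"\n\nimport heapq\n\nweighted_addresses = [[0, \"10.0.0.1:8000\"], [0, \"10.0.0.2:8000\"]]\nheapq.heapify(weighted_addresses)\n\n\ndef pick_address() -> str:\n    address = weighted_addresses[0][1]\n    weighted_addresses[0][0] += 1\n    # heapreplace pops and re-inserts the mutated root, restoring the heap invariant\n    heapq.heapreplace(weighted_addresses, weighted_addresses[0])\n    return address\n\n\nfor _ in range(4):\n    print(pick_address())  # alternates between the two servers\n"
  },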
  {
    "path": "verl_rl/verl/workers/rollout/hf_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRollout with huggingface models.\nTODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single\nGPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model\nto perform generation.\n\"\"\"\n\nimport contextlib\n\nimport torch\nimport torch.distributed\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import GenerationConfig\n\nfrom verl import DataProto\nfrom verl.utils.device import get_device_name, get_torch_device\nfrom verl.utils.torch_functional import get_response_mask\n\nfrom .base import BaseRollout\n\n__all__ = [\"HFRollout\"]\n\n\nclass HFRollout(BaseRollout):\n    def __init__(self, module: nn.Module, config):\n        super().__init__()\n        self.config = config\n        self.module = module\n\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        batch_size = prompts.batch.batch_size[0]\n        num_chunks = max(batch_size // self.config.get(\"micro_batch_size\", batch_size), 1)\n        batch_prompts = prompts.chunk(chunks=num_chunks)\n        output = [self._generate_minibatch(p) for p in batch_prompts]\n        output = DataProto.concat(output)\n        return output\n\n    @torch.no_grad()\n    def _generate_minibatch(self, prompts: DataProto) -> DataProto:\n        # make sampling args can be overridden by inputs\n        do_sample = prompts.meta_info.get(\"do_sample\", self.config.do_sample)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n\n        temperature = prompts.meta_info.get(\"temperature\", self.config.temperature)\n        response_length = prompts.meta_info.get(\"response_length\", self.config.response_length)\n        top_p = prompts.meta_info.get(\"top_p\", self.config.get(\"top_p\", 1.0))\n        top_k = max(0, prompts.meta_info.get(\"top_k\", self.config.get(\"top_k\", 0)))  # to be compatible with vllm\n\n        if not do_sample:\n            # do_sample==False -> greedy decoding\n            kwargs = {\n                \"do_sample\": False,\n                \"num_beams\": 1,\n            }\n        elif is_validate:\n            # do validate and do sample -> use val_kwargs\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_k\": max(0, self.config.val_kwargs.top_k),  # to be compatible with vllm\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"num_return_sequences\": 1,  # if validate, already repeat in ray_trainer\n            }\n        else:\n            # do_sample -> use rollout config\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_p\": 
top_p,\n                \"top_k\": top_k,\n                \"temperature\": temperature,\n                \"num_return_sequences\": self.config.n,\n            }\n\n        # make config according to generate mode\n        generation_config = GenerationConfig(**kwargs)\n\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        prompt_length = idx.size(1)\n        attention_mask = prompts.batch[\"attention_mask\"]  # left-padded attention_mask\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n        pad_token_id = prompts.meta_info[\"pad_token_id\"]\n\n        self.module.eval()\n        param_ctx = contextlib.nullcontext()\n\n        if isinstance(self.module, FSDP):\n            # recurse needs to be set to False according to https://github.com/pytorch/pytorch/issues/100069\n            param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)\n        with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n            output = self.module.generate(\n                input_ids=idx,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                do_sample=do_sample,\n                max_new_tokens=response_length,\n                eos_token_id=eos_token_id,\n                pad_token_id=pad_token_id,\n                generation_config=generation_config,\n                output_scores=False,  # this is potentially very large\n                return_dict_in_generate=True,\n                use_cache=True,\n            )\n\n        # TODO: filter out the seq with no answers like ds-chat\n        seq = output.sequences\n        generated_batch_size = seq.size(0)  # bs * num_return_sequences\n\n        # huggingface generate stops once every sequence in the batch reaches [EOS].\n        # We have to pad to response_length\n        sequence_length = prompt_length + self.config.response_length\n        delta_length = sequence_length - seq.shape[1]\n\n        if delta_length > 0:\n            delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype)\n            delta_tokens = pad_token_id * delta_tokens\n            seq = torch.cat((seq, delta_tokens), dim=1)\n        assert seq.shape[1] == sequence_length\n\n        # make necessary repetitions if num_return_sequences > 1\n        num_return_sequences = kwargs.get(\"num_return_sequences\", 1)\n        if num_return_sequences > 1:\n            position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0)\n            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)\n\n        prompt = seq[:, :prompt_length]  # (generated_batch_size, prompt_length)\n        response = seq[:, prompt_length:]  # (generated_batch_size, response_length)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1)\n\n        response_position_ids = position_ids[:, -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, 
response_attention_mask), dim=-1)\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompt,\n                \"responses\": response,\n                \"input_ids\": seq,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=generated_batch_size,\n        )\n\n        # empty cache before compute old_log_prob\n        get_torch_device().empty_cache()\n\n        self.module.train()\n        return DataProto(batch=batch)\n"
  },
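  {
    "path": "verl_rl/examples/sketches/position_ids_extension_sketch.py",
    "content": "\"\"\"Minimal sketch (hypothetical example file, not part of verl): how HFRollout extends\nposition_ids for generated responses. With left padding, the last prompt position sits at\nposition_ids[:, -1], and response positions continue from there. The tensors below are\nmade up (the first prompt has one pad token).\n\"\"\"\n\nimport torch\n\nposition_ids = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]])  # (bs=2, prompt_length=4)\nresponse_length = 3\nbatch_size = position_ids.size(0)\n\ndelta_position_id = torch.arange(1, response_length + 1).unsqueeze(0).repeat(batch_size, 1)\nresponse_position_ids = position_ids[:, -1:] + delta_position_id\nposition_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n\nprint(position_ids)\n# tensor([[0, 0, 1, 2, 3, 4, 5],\n#         [0, 1, 2, 3, 4, 5, 6]])\n"
  },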
  {
    "path": "verl_rl/verl/workers/rollout/naive/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .naive_rollout import NaiveRollout\n\n__all__ = [\"NaiveRollout\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/naive/naive_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nIn single GPU rollout, the sequences are generated directly by sampling from the model.\nThe output will contain\n1. output_ids\n2. attention_masks (left padding)\n3. eos_masks\n4. log_probs\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom tensordict import TensorDict\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.utils.torch_functional import logprobs_from_logits\n\nfrom ..base import BaseRollout\n\n__all__ = [\"NaiveRollout\"]\n\n\nclass NaiveRollout(BaseRollout):\n    def __init__(self, module: nn.Module, config):\n        \"\"\"A naive rollout. It requires the module to be compatible with huggingface APIs. That is:\n        The module should define __call__ to receive input_ids, attention_mask and position_ids.\n        It outputs a structure that contains logits field.\n\n        Args:\n            module: module here follows huggingface APIs\n            config: DictConfig\n        \"\"\"\n        super().__init__()\n        self.config = config\n        self.module = module\n\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto) -> DataProto:\n        \"\"\"Generate sequences\"\"\"\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        attention_mask = prompts.batch[\"attention_mask\"]  # left-padded attention_mask\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n        prompt_length = idx.size(1)\n\n        self.module.eval()\n\n        prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)\n\n        logits_lst = []\n        for _ in range(self.config.response_length):\n            # if the sequence context is growing too long we must crop it at block_size\n            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\n            idx_cond = idx\n            # forward the model to get the logits for the index in the sequence\n            # we use huggingface APIs here\n            output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)\n            logits = output.logits\n            # pluck the logits at the final step and scale by desired temperature\n            logits = logits[:, -1, :] / self.config.temperature  # (bs, vocab_size)\n            # optionally crop the logits to only the top k options\n            if self.config.top_k is not None:\n                v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))\n                logits[logits < v[:, [-1]]] = -float(\"Inf\")\n            # apply softmax to convert logits to (normalized) probabilities\n            probs = F.softmax(logits, dim=-1)\n            # sample from the distribution\n            if 
self.config.do_sample:\n                idx_next = torch.multinomial(probs, num_samples=1)\n            else:\n                idx_next = torch.argmax(probs, dim=-1, keepdim=True)\n\n            attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)\n\n            for token_id in eos_token_id:\n                prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())\n            # Tensor.to is not in-place; assign the result back\n            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)\n\n            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)\n\n            # append sampled index to the running sequence and continue\n            idx = torch.cat((idx, idx_next), dim=1)\n            logits_lst.append(logits)\n\n        logits = torch.stack(logits_lst, dim=1)  # (bs, response_length, vocab_size)\n        prompts = idx[:, :prompt_length]  # (bs, prompt_length)\n        response = idx[:, prompt_length:]  # (bs, response_length)\n        log_probs = logprobs_from_logits(logits=logits, labels=response)\n        batch = TensorDict(\n            {\n                \"input_ids\": prompts,\n                \"responses\": response,\n                \"sequences\": idx,\n                \"old_log_probs\": log_probs,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n\n        self.module.train()\n\n        return DataProto(batch=batch)\n"
  },
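  {
    "path": "verl_rl/examples/sketches/top_k_sampling_sketch.py",
    "content": "\"\"\"Minimal sketch (hypothetical example file, not part of verl): the top-k filtering and\nsampling step inside NaiveRollout.generate_sequences. Logits below the k-th largest value\nare set to -inf so softmax assigns them zero probability. The logits below are made up.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\ntorch.manual_seed(0)\nlogits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 3.0]])  # (bs=1, vocab_size=5)\ntop_k = 2\n\nv, _ = torch.topk(logits, min(top_k, logits.size(-1)))\nlogits[logits < v[:, [-1]]] = -float(\"inf\")  # mask everything below the k-th largest\n\nprobs = F.softmax(logits, dim=-1)\nidx_next = torch.multinomial(probs, num_samples=1)\nprint(probs)  # nonzero only at indices 0 and 4\nprint(idx_next)  # sampled from those two indices\n"
  },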
  {
    "path": "verl_rl/verl/workers/rollout/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport difflib\nimport logging\nimport os\nfrom enum import Enum\nfrom typing import Any, Optional\n\nimport torch\nfrom pydantic import BaseModel, ConfigDict, model_validator\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema\nfrom verl.utils.model import compute_position_id_with_mask\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\nBASE_CHAT_HISTORY = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \"I am a user.\"},\n]\n\n\nclass FinishReasonTypeEnum(str, Enum):\n    \"\"\"The enum for finish reason type.\"\"\"\n\n    LENGTH = \"length\"\n    STOP = \"stop\"\n    TOOL_CALL = \"tool_calls\"\n\n    @classmethod\n    def from_str(cls, value: str) -> \"FinishReasonTypeEnum\":\n        if value == \"stop\":\n            return cls.STOP\n        elif value == \"length\":\n            return cls.LENGTH\n        elif value == \"tool_calls\":\n            return cls.TOOL_CALL\n        else:\n            raise ValueError(f\"Unsupported finish reason type: {value}\")\n\n\nclass Message(BaseModel):\n    role: str\n    content: str | dict[str, Any] | list[dict[str, Any]]\n    tool_calls: Optional[list[OpenAIFunctionToolCall]] = None\n\n\nclass AsyncRolloutRequestStateEnum(str, Enum):\n    \"\"\"The enum for async rollout request state.\"\"\"\n\n    PENDING = \"pending\"\n    RUNNING = \"running\"\n    COMPLETED = \"completed\"\n    FAILED = \"failed\"\n    TOOL_CALLING = \"tool_calling\"\n    INTERACTING = \"interacting\"\n\n\nclass TokenizationSanityCheckModeEnum(str, Enum):\n    \"\"\"The enum for tokenization sanity check mode.\"\"\"\n\n    DISABLE = \"disable\"\n    STRICT = \"strict\"\n    IGNORE_STRIPPABLE = \"ignore_strippable\"\n\n\nclass AsyncRolloutRequest(BaseModel):\n    \"\"\"The data model for async rollout.\"\"\"\n\n    model_config = ConfigDict(arbitrary_types_allowed=True)\n\n    batch_data_id: int = 0\n    rollout_offset: int = 0\n    request_id: str\n    state: AsyncRolloutRequestStateEnum\n    messages: list[Message]\n    multi_modal_keys: Optional[list[str]] = None\n    multi_modal_data: Optional[dict[str, Any]] = None\n    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None\n    tool_schemas: Optional[list[OpenAIFunctionToolSchema]] = None\n    tools_kwargs: dict[str, Any] = {}\n    interaction_kwargs: dict[str, Any] = {}\n    input_ids: Optional[torch.Tensor] = None\n    prompt_ids: Optional[torch.Tensor] = None\n    response_ids: Optional[torch.Tensor] = None\n    attention_mask: Optional[torch.Tensor] = None\n    prompt_attention_mask: Optional[torch.Tensor] = None\n    response_attention_mask: Optional[torch.Tensor] = None\n    position_ids: Optional[torch.Tensor] = 
None\n    prompt_position_ids: Optional[torch.Tensor] = None\n    response_position_ids: Optional[torch.Tensor] = None\n    loss_mask: Optional[torch.Tensor] = None\n    prompt_loss_mask: Optional[torch.Tensor] = None\n    response_loss_mask: Optional[torch.Tensor] = None\n    reward_scores: dict[str, float]\n    max_prompt_len: int\n    max_response_len: int = 8192\n    max_model_len: int = 32768\n    metrics: dict[str, list[Any]] = {}\n\n    use_inference_chat_template: bool\n    tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum\n    generation_prompt_ids: Optional[torch.Tensor] = None\n    base_conv_wo_gen_prompt_end_pos: int\n    base_conv_with_gen_prompt_end_pos: int\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def initialize_request(cls, values):\n        if not (messages := values.get(\"messages\")):\n            raise ValueError(\"messages is required for AsyncRolloutRequest initialization\")\n        if not (max_prompt_len := values.get(\"max_prompt_len\")):\n            raise ValueError(\"max_prompt_len is required for AsyncRolloutRequest initialization\")\n        if not (processing_class := values.pop(\"processing_class\", None)):\n            raise ValueError(\"processing_class is required for AsyncRolloutRequest initialization\")\n\n        values[\"messages\"] = [Message.model_validate(msg) for msg in messages]\n\n        # If there is no multi_modal_keys, we assume the multi-modal data is image and video.\n        if not values.get(\"multi_modal_keys\"):\n            values[\"multi_modal_keys\"] = [\"image\", \"video\"]\n        if not values.get(\"multi_modal_data\"):\n            values[\"multi_modal_data\"] = {key: [] for key in values[\"multi_modal_keys\"]}\n        else:\n            # check if all multi_modal_keys are in multi_modal_data\n            for key in values[\"multi_modal_keys\"]:\n                if key not in values[\"multi_modal_data\"]:\n                    values[\"multi_modal_data\"][key] = []\n        if not values.get(\"multi_modal_inputs\"):\n            values[\"multi_modal_inputs\"] = {}\n\n        tools = (\n            [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get(\"tool_schemas\", [])) else None\n        )\n\n        multi_modal_data = values[\"multi_modal_data\"]\n        tokens_without_prompt = cls._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        )\n        if (\n            values.get(\"input_ids\") is None\n            or values.get(\"attention_mask\") is None\n            or values.get(\"position_ids\") is None\n        ):\n            tokenization_dict_with_prompt = cls._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n                return_dict=True,\n            )\n\n            values[\"input_ids\"], values[\"attention_mask\"] = (\n                tokenization_dict_with_prompt[\"input_ids\"],\n                tokenization_dict_with_prompt[\"attention_mask\"],\n            )\n            if values[\"input_ids\"].shape[-1] > max_prompt_len:\n                # Only log the warning to avoid truncating in the middle of generation prompt. 
Consider raising an\n                # error for this case in the future.\n                logger.warning(\n                    f\"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} \"\n                    f\"which is greater than max_prompt_len {max_prompt_len} after applying the chat template with tools.\"\n                )\n\n            # Process multi_modal_inputs\n            multi_modal_inputs = tokenization_dict_with_prompt.copy()\n            multi_modal_inputs.pop(\"input_ids\", None)\n            multi_modal_inputs.pop(\"attention_mask\", None)\n            values[\"multi_modal_inputs\"] = multi_modal_inputs\n\n            values[\"position_ids\"] = values[\"prompt_position_ids\"] = cls._get_position_ids(\n                processing_class, values[\"input_ids\"], values[\"attention_mask\"], multi_modal_inputs\n            )\n\n        values[\"prompt_ids\"], values[\"prompt_attention_mask\"] = values[\"input_ids\"], values[\"attention_mask\"]\n        values[\"loss_mask\"] = values[\"prompt_loss_mask\"] = torch.zeros_like(values[\"input_ids\"], dtype=torch.bool)\n        values[\"generation_prompt_ids\"] = values[\"input_ids\"][..., tokens_without_prompt.shape[-1] :]\n        values[\"base_conv_wo_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        ).shape[-1]\n\n        values[\"base_conv_with_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=True,\n            tokenize=True,\n        ).shape[-1]\n\n        return values\n\n    @staticmethod\n    def _handle_apply_chat_template(\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        messages: list[Message],\n        multi_modal_data: dict[str, Any],\n        tools: Optional[list[OpenAIFunctionToolSchema]] = None,\n        add_generation_prompt: bool = False,\n        tokenize: bool = False,\n        return_dict: bool = False,\n    ):\n        raw_prompt = processing_class.apply_chat_template(\n            messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False\n        )\n        if not tokenize:\n            return raw_prompt\n\n        if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast):\n            if any(len(values) > 0 for values in multi_modal_data.values()):\n                logger.warning(\n                    \"There is multi_modal_data but you are not using a processor. 
Multi-modal data will be ignored.\"\n                )\n            model_inputs = processing_class(text=[raw_prompt], return_tensors=\"pt\")\n        elif isinstance(processing_class, ProcessorMixin):\n            # When we update multi_modal_keys, we also need to update this logic\n            # Use None rather than empty lists so the processor skips absent modalities.\n            images = images if len(images := multi_modal_data.get(\"image\", [])) > 0 else None\n            videos = videos if len(videos := multi_modal_data.get(\"video\", [])) > 0 else None\n            model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors=\"pt\")\n        else:\n            raise ValueError(f\"Unsupported processing class type: {type(processing_class)}\")\n\n        model_inputs = dict(model_inputs)\n        if return_dict:\n            return model_inputs\n        else:\n            return model_inputs[\"input_ids\"]\n\n    @staticmethod\n    def _get_position_ids(\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        # special case for qwen2vl\n        is_qwen2vl = (\n            hasattr(processing_class, \"image_processor\")\n            and \"Qwen2VLImageProcessor\" in processing_class.image_processor.__class__.__name__\n        )\n        if is_qwen2vl:\n            from verl.models.transformers.qwen2_vl import get_rope_index\n\n            image_grid_thw = video_grid_thw = second_per_grid_ts = None\n            if multi_modal_inputs:\n                image_grid_thw = multi_modal_inputs.get(\"image_grid_thw\")\n                video_grid_thw = multi_modal_inputs.get(\"video_grid_thw\")\n                second_per_grid_ts = multi_modal_inputs.get(\"second_per_grid_ts\")\n\n            assert input_ids.dim() == 2 and input_ids.shape[0] == 1, (\n                f\"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}\"\n            )\n            assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, (\n                f\"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}\"\n            )\n            new_position_ids = get_rope_index(\n                processing_class,\n                input_ids=input_ids.squeeze(0),\n                image_grid_thw=image_grid_thw,\n                video_grid_thw=video_grid_thw,\n                second_per_grid_ts=second_per_grid_ts,\n                attention_mask=attention_mask.squeeze(0),\n            )\n            return new_position_ids  # (3, seq_len)\n        else:\n            return compute_position_id_with_mask(attention_mask)  # (1, seq_len)\n\n    def _update_input_ids(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        new_input_ids: torch.Tensor,\n        attention_mask: bool,\n        loss_mask: bool,\n        new_multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None,\n    ) -> None:\n        \"\"\"\n        Update the input_ids, attention_mask, position_ids, and loss_mask of the request in an additive manner.\n        \"\"\"\n        self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1)\n        attention_mask = torch.ones_like(new_input_ids) * int(attention_mask)\n        self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1)\n        loss_mask = torch.ones_like(new_input_ids) * int(loss_mask)\n        self.loss_mask 
= torch.cat([self.loss_mask, loss_mask], dim=-1)\n\n        if new_multi_modal_inputs:\n            self._update_multi_modal_inputs(new_multi_modal_inputs)\n\n        new_position_ids = self._get_position_ids(\n            processing_class, new_input_ids, attention_mask, new_multi_modal_inputs\n        )\n\n        last_pos = self.position_ids[..., -1:]\n        new_position_ids = new_position_ids + (last_pos + 1)\n\n        self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None:\n        \"\"\"\n        Update the multi_modal_inputs of the request in an additive manner.\n        \"\"\"\n        for key in new_multi_modal_inputs:\n            input_tensor = new_multi_modal_inputs[key]\n            self.multi_modal_inputs[key] = (\n                torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0)\n                if key in self.multi_modal_inputs\n                else input_tensor\n            )\n\n    def get_generation_prompt_ids(\n        self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin\n    ) -> list[int]:\n        \"\"\"\n        Get the generation prompt ids for the rollout engine.\n\n        Because the rollout engine (SGLang) requires the ids to be a list, we need to convert the tensor to a list.\n        \"\"\"\n        generation_prompt_ids = (\n            None\n            if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all()\n            else self.generation_prompt_ids\n        )\n        if generation_prompt_ids is not None:\n            self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False)\n\n        if self.use_inference_chat_template:\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            generation_prompt_ids = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n            )\n            return generation_prompt_ids.squeeze(0).tolist()\n        else:\n            return self.input_ids.squeeze(0).tolist()\n\n    def add_user_message(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        content: str,\n    ) -> None:\n        self.messages.append(Message(role=\"user\", content=content))\n        messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        # We don't need to pass multi_modal_data here because there is no multi-modal data from engine\n        # inference; it is pure text.\n        content_ids = self._handle_apply_chat_template(\n            processing_class, messages, multi_modal_data={}, tools=tools, 
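\n            # The slice below drops the shared BASE_CHAT_HISTORY prefix, keeping only the new message's tokens.\n            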
add_generation_prompt=False, tokenize=True\n        )[..., self.base_conv_wo_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False)\n\n    def add_assistant_message(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        content: str,\n        tool_calls: Optional[list[OpenAIFunctionToolCall]] = None,\n    ) -> None:\n        self.messages.append(Message(role=\"assistant\", content=content, tool_calls=tool_calls))\n\n        messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        # We don't need to pass multi_modal_data here because there is no multi-modal data from engine\n        # inference; it is pure text.\n        content_ids = self._handle_apply_chat_template(\n            processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True\n        )[..., self.base_conv_with_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True)\n\n    def add_tool_response_messages(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        contents: list[str | dict[str, Any]],\n    ) -> None:\n        if not contents:\n            return\n        # We also handle the case when a tool returns images or videos.\n        # We require the processing of images and videos to be done at the tool.execute() level.\n        delta_multi_modal_data = {key: [] for key in self.multi_modal_keys}\n        for content in contents:\n            if isinstance(content, dict):\n                content_list = []\n                # When we update multi_modal_keys, we also need to update this logic\n                if \"image\" in content:\n                    if not isinstance(content[\"image\"], list):\n                        raise ValueError(\n                            f\"Image must be a list, but got {type(content['image'])}. Please check the tool.execute(). \"\n                            f\"For single images, wrap in a list: [image]. \"\n                            f\"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}.\"\n                        )\n\n                    content_list.extend([{\"type\": \"image\"} for _ in content[\"image\"]])\n                    delta_multi_modal_data[\"image\"].extend(content[\"image\"])\n                if \"video\" in content:\n                    if not isinstance(content[\"video\"], list):\n                        raise ValueError(\n                            f\"Video must be a list, but got {type(content['video'])}. Please check the tool.execute(). \"\n                            f\"For single videos, wrap in a list: [video]. 
\"\n                            f\"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}.\"\n                        )\n\n                    content_list.extend([{\"type\": \"video\"} for _ in content[\"video\"]])\n                    delta_multi_modal_data[\"video\"].extend(content[\"video\"])\n                if \"text\" in content:\n                    content_list.append({\"type\": \"text\", \"text\": content[\"text\"]})\n                for key in content:\n                    if key not in [\"image\", \"video\", \"text\"]:\n                        logger.warning(\n                            f\"Tool response message contains unexpected key: {key} \"\n                            f\"while we only support `image`, `video`, and `text`.\"\n                        )\n                self.messages.append(Message(role=\"tool\", content=content_list))\n            else:\n                self.messages.append(Message(role=\"tool\", content=content))\n\n        messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        for key in self.multi_modal_keys:\n            if len(delta_multi_modal_data[key]) > 0:\n                self.multi_modal_data[key].extend(delta_multi_modal_data[key])\n\n        # We just passed the new multi-modal data to the chat template to update the input_ids.\n        content_info = self._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=delta_multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_dict=True,\n        )\n        content_ids = content_info[\"input_ids\"][..., self.base_conv_wo_gen_prompt_end_pos :]\n\n        # process multi_modal_inputs\n        multi_modal_inputs = content_info.copy()\n        multi_modal_inputs.pop(\"input_ids\", None)\n        multi_modal_inputs.pop(\"attention_mask\", None)\n        self._update_input_ids(\n            processing_class,\n            content_ids,\n            attention_mask=True,\n            loss_mask=False,\n            new_multi_modal_inputs=multi_modal_inputs,\n        )\n\n    def update_metrics(self, metrics: Any, tool_id: str) -> None:\n        \"\"\"\n        metrics: should be a dict of tools_name -> Any\n        \"\"\"\n        if self.metrics.get(tool_id) is None:\n            self.metrics[tool_id] = []\n        self.metrics[tool_id].append(metrics)\n\n    def _get_prompt_diffs(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        full_prompt_ids: torch.Tensor,\n        current_prompt_ids: torch.Tensor,\n        diff_surrounding_chars: int = 10,\n    ) -> list[dict[str, Any]]:\n        \"\"\"Get differences between full prompt and current prompt with surrounding context.\n\n        This function helps debug tokenization mismatches by showing the differences between\n        full prompt and current prompt with surrounding context. 
Instead of just showing\n        the exact diff, it includes additional tokens before and after to help locate\n        the issue in the chat template.\n\n        For example, if the actual diff is a newline change from \"\\n\\n\" to \"\\n\", with\n        diff_surrounding_chars the output might look like:\n\n        full_prompt_chunk:    \"<|im_start|>assistant\\n\\nI think...\"\n        current_prompt_chunk: \"<|im_start|>assistant\\nI think...\"\n\n        This context makes it much easier to identify where in the chat template the\n        mismatch occurs.\n\n        Args:\n            processing_class: The processing class to use for decoding the token IDs\n            full_prompt_ids: Token IDs from applying chat template to all messages at once\n            current_prompt_ids: Token IDs from incremental chat template application\n            diff_surrounding_chars: Number of surrounding characters to include for context (default: 10)\n\n        Returns:\n            List of dicts containing the differing chunks with context and their indices\n        \"\"\"\n        full_prompt_ids = full_prompt_ids.squeeze(0)\n        current_prompt_ids = current_prompt_ids.squeeze(0)\n        full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False)\n        current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False)\n        s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False)\n        diffs = []\n        for tag, i1, i2, j1, j2 in s.get_opcodes():\n            if tag == \"equal\":\n                continue\n\n            # Get the surrounding context for better readability\n            start_i = max(0, i1 - diff_surrounding_chars)\n            end_i = min(len(full_prompt), i2 + diff_surrounding_chars)\n            start_j = max(0, j1 - diff_surrounding_chars)\n            end_j = min(len(current_prompt), j2 + diff_surrounding_chars)\n\n            diffs.append(\n                {\n                    \"full_prompt_chunk\": full_prompt[start_i:end_i],\n                    \"current_prompt_chunk\": current_prompt[start_j:end_j],\n                    \"indices\": (start_i, end_i, start_j, end_j),\n                }\n            )\n        return diffs\n\n    def finalize(\n        self,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        reward_scores: dict[str, list[float]],\n        finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP,\n    ) -> None:\n        self.state = AsyncRolloutRequestStateEnum.COMPLETED\n        self.reward_scores = reward_scores\n\n        # In case we failed to generate the assistant message and the generation prompt ids were already added to\n        # input_ids, remove them from the end of input_ids\n        if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all():\n            self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]]\n            self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]]\n\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :]\n\n        if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE:\n            # When there is a diff, we log the diffs with 
diff_surrounding_chars context\n            diff_surrounding_chars = 10\n\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            full_prompt_info = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=False,\n                tokenize=True,\n                return_dict=True,\n            )\n            full_prompt_ids = full_prompt_info[\"input_ids\"]\n\n            # Copy full_prompt_info so we can pop entries without mutating the original BatchFeature\n            # (np.array() only keeps the keys for a BatchFeature).\n            full_prompt_multi_modal_inputs = full_prompt_info.copy()\n            full_prompt_multi_modal_inputs.pop(\"input_ids\", None)\n            full_prompt_multi_modal_inputs.pop(\"attention_mask\", None)\n\n            for multi_modal_inputs_key in self.multi_modal_inputs:\n                if multi_modal_inputs_key in full_prompt_multi_modal_inputs:\n                    if (\n                        not self.multi_modal_inputs[multi_modal_inputs_key]\n                        .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key])\n                        .all()\n                    ):\n                        logger.warning(\n                            f\"Multi-modal data {multi_modal_inputs_key} is not consistent. \"\n                            f\"This may lead to unexpected behavior during training. \"\n                            f\"Please review your multi_modal_inputs logic.\"\n                        )\n                else:\n                    logger.warning(\n                        f\"Multi-modal inputs key {multi_modal_inputs_key} is not found in the full prompt's \"\n                        f\"multi-modal inputs. This may lead to unexpected behavior during training. \"\n                        f\"Please review your multi_modal_inputs logic.\"\n                    )\n\n            if diffs := self._get_prompt_diffs(\n                processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars\n            ):\n                log_warning = False\n                if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT:\n                    log_warning = True\n                elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE:\n                    non_strippable_diffs_exist = any(\n                        d[\"full_prompt_chunk\"].strip() or d[\"current_prompt_chunk\"].strip() for d in diffs\n                    )\n                    if non_strippable_diffs_exist:\n                        log_warning = True\n\n                if log_warning:\n                    mode_str = f\" ({self.tokenization_sanity_check_mode.value})\"\n                    logger.warning(\n                        f\"Inconsistent training and inference tokenization detected{mode_str}. This may lead to \"\n                        f\"unexpected behavior during training. Please review your chat template to determine if this \"\n                        f\"is intentional. 
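Whitespace-only diffs can be ignored via the IGNORE_STRIPPABLE mode. 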
For more information, refer to the multiturn README.md.\"\n                    )\n                    logger.warning(\n                        f\"Showing {diff_surrounding_chars} characters before and after the diffs for context and \"\n                        f\"better readability.\"\n                    )\n                    diff_details_list = []\n                    for d in diffs:\n                        i1, i2, j1, j2 = d[\"indices\"]\n                        diff_details_list.append(\n                            f\"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | \"\n                            f\"current_prompt_chunk: {repr(d['current_prompt_chunk'])}\"\n                        )\n                    diff_details = \"\\n\".join(diff_details_list)\n                    logger.warning(f\"Found differences:\\n{diff_details}\")\n\n        if finish_reason_type == FinishReasonTypeEnum.STOP:\n            pass\n        elif finish_reason_type == FinishReasonTypeEnum.LENGTH:\n            pass\n        else:\n            raise ValueError(f\"Unsupported finalize finish reason type: {finish_reason_type}\")\n        self.truncate_output_ids(processing_class)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def truncate_output_ids(\n        self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin\n    ) -> None:\n        self.input_ids = self.input_ids[..., : self.max_model_len]\n        self.attention_mask = self.attention_mask[..., : self.max_model_len]\n        self.position_ids = self.position_ids[..., : self.max_model_len]\n        self.loss_mask = self.loss_mask[..., : self.max_model_len]\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len]\n        self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len]\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/sglang_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nfrom .sglang_rollout import SGLangRollout\n\n__all__ = [\"SGLangRollout\"]\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/sglang_rollout/async_sglang_server.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nfrom typing import Any\n\nimport ray\nfrom omegaconf import DictConfig\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse\n\nfrom verl.workers.rollout.async_server import AsyncServerBase\n\nlogger = logging.getLogger(__file__)\n\n\n@ray.remote(num_cpus=1)\nclass AsyncSGLangServer(AsyncServerBase):\n    def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):\n        super().__init__()\n        self.config = config.actor_rollout_ref\n        self._tp_size = self.config.rollout.get(\"tensor_model_parallel_size\", 1)\n        self._dp_size = dp_size\n        self._dp_rank = dp_rank\n        self.wg_prefix = wg_prefix\n        self.workers = []\n        self.master_worker = None\n\n    async def init_engine(self):\n        if self.workers:\n            # avoid init twice\n            return\n        all_actors = ray.util.list_named_actors(all_namespaces=True)\n        matched_actors = [\n            actor for actor in all_actors if actor.get(\"name\", None).startswith(self.wg_prefix + \"WorkerDict_\")\n        ]\n\n        gpu_per_node = len(set([actor[\"name\"].split(\":\")[1] for actor in matched_actors]))\n        # total gpu num\n        assert len(matched_actors) == self._dp_size * self._tp_size\n\n        for matched_actor in matched_actors:\n            fields = matched_actor[\"name\"].split(\":\")\n            assert len(fields) == 2, f\"invalid actor name: {matched_actor['name']}\"\n            pg_index, local_rank = int(fields[0].split(\"_\")[-1]), int(fields[1])\n\n            current_global_rank = gpu_per_node * pg_index + local_rank\n            worker_dp_rank = current_global_rank // self._tp_size\n            worker_tp_rank = current_global_rank % self._tp_size\n\n            if worker_dp_rank == self._dp_rank:\n                worker = ray.get_actor(**matched_actor)\n                self.workers.append(worker)\n\n                if worker_tp_rank == 0:\n                    self.master_worker = worker\n\n    async def chat_completion(self, raw_request: Request):\n        request = await raw_request.json()\n\n        # only send request to master worker in tp rank 0\n        output_future = self.master_worker.chat_completion.remote(request)\n        [outputs] = await asyncio.gather(output_future)\n        return JSONResponse(outputs)\n\n    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n        return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id)\n\n    async def wake_up(self):\n        if not self.config.rollout.free_cache_engine:\n            return\n\n        tasks = [worker.wake_up.remote() for worker in self.workers]\n        if tasks:\n            await asyncio.gather(*tasks)\n\n    async def sleep(self):\n        if not 
self.config.rollout.free_cache_engine:\n            return\n\n        tasks = [worker.sleep.remote() for worker in self.workers]\n        if tasks:\n            await asyncio.gather(*tasks)\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/sglang_rollout/sglang_rollout.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport multiprocessing as mp\nimport os\nimport time\nfrom copy import deepcopy\nfrom json import JSONDecodeError\nfrom typing import Any, List, Optional, Tuple\nfrom uuid import uuid4\n\nimport numpy as np\nimport sglang.srt.entrypoints.engine\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\nfrom sglang.srt.managers.tokenizer_manager import (\n    ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.sampling.sampling_params import SamplingParams\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import (\n    MultiprocessingSerializer,\n    assert_pkg_version,\n    get_ip,\n    get_open_port,\n    is_cuda,\n    maybe_set_triton_cache_manager,\n    set_prometheus_multiproc_dir,\n    set_ulimit,\n)\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.nn.utils.rnn import pad_sequence\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom verl import DataProto\nfrom verl.interactions.base import BaseInteraction\nfrom verl.interactions.utils.interaction_registry import initialize_interactions_from_config\nfrom verl.third_party.sglang import parallel_state as sglang_ps\nfrom verl.tools.base_tool import BaseTool\nfrom verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall\nfrom verl.tools.utils.tool_registry import initialize_tools_from_config\nfrom verl.utils.net_utils import is_ipv6\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.torch_functional import get_response_mask, pad_sequence_to_length\nfrom verl.workers.rollout.base import BaseRollout\nfrom verl.workers.rollout.schemas import (\n    AsyncRolloutRequest,\n    AsyncRolloutRequestStateEnum,\n    FinishReasonTypeEnum,\n    Message,\n)\nfrom verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj\n\ntry:\n    from sglang.srt.function_call.function_call_parser import FunctionCallParser\nexcept ImportError:\n    from sglang.srt.function_call_parser import FunctionCallParser\n\ntry:\n    from sglang.srt.entrypoints.openai.protocol import Tool\nexcept ImportError:\n    from sglang.srt.openai_api.protocol import Tool\n\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723\ndef _set_envs_and_config(server_args: ServerArgs):\n    # Set global environments\n    os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n    os.environ[\"NCCL_CUMEM_ENABLE\"] = \"0\"\n    os.environ[\"NCCL_NVLS_ENABLE\"] = str(int(server_args.enable_nccl_nvls))\n    
os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n    os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"4\"\n    os.environ[\"CUDA_MODULE_LOADING\"] = \"AUTO\"\n\n    # Set prometheus env vars\n    if server_args.enable_metrics:\n        set_prometheus_multiproc_dir()\n\n    # Set ulimit\n    set_ulimit()\n\n    # Fix triton bugs\n    if server_args.tp_size * server_args.dp_size > 1:\n        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.\n        maybe_set_triton_cache_manager()\n\n    # Check flashinfer version\n    if server_args.attention_backend == \"flashinfer\":\n        assert_pkg_version(\n            \"flashinfer_python\",\n            \"0.2.5\",\n            \"Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.\",\n        )\n    if is_cuda():\n        assert_pkg_version(\n            \"sgl-kernel\",\n            \"0.1.1\",\n            \"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`\",\n        )\n\n    # Set mp start method\n    mp.set_start_method(\"spawn\", force=True)\n\n\nsglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config\n\n\n# because chatCompletion is an async method, it makes the whole ray actor be an async actor\n# which can not call loop.run_until_complete. So we need to make the engine to be an async class\nclass AsyncEngine(sglang.srt.entrypoints.engine.Engine):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        # default to use dummy load format, which need to reload weights in first time\n        self._need_reload = True\n\n    async def release_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Release GPU occupation temporarily.\"\"\"\n        if tags is None:\n            obj = ReleaseMemoryOccupationReqInput()\n        else:\n            obj = ReleaseMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.release_memory_occupation(obj, None)\n\n    async def resume_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Resume GPU occupation.\"\"\"\n        # because __init__ is a sync method, it can not call the async release_memory_occupation\n        # have to move release_memory_occupation from __init__ to here\n        # For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time.\n        if self._need_reload:\n            await self.release_memory_occupation()\n            self._need_reload = False\n\n        if tags is None:\n            obj = ResumeMemoryOccupationReqInput()\n        else:\n            obj = ResumeMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.resume_memory_occupation(obj, None)\n\n    async def update_weights_from_tensor(\n        self,\n        named_tensors: List[Tuple[str, torch.Tensor]],  # noqa: UP006\n        load_format: Optional[str] = None,\n        flush_cache: bool = True,\n    ):\n        \"\"\"Update weights from distributed source. 
If there are going to be more updates, set `flush_cache` to False\n        to avoid duplicate cache-flush operations.\"\"\"\n        obj = UpdateWeightsFromTensorReqInput(\n            serialized_named_tensors=[\n                MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size)\n            ],\n            load_format=load_format,\n            flush_cache=flush_cache,\n        )\n        return await self.tokenizer_manager.update_weights_from_tensor(obj, None)\n\n    async def flush_cache(self):\n        return await self.tokenizer_manager.flush_cache()\n\n\n# NOTE(sgm): add for verl. We can optimize it by making\n#  the dataloader yield List[int] without padding.\ndef _pre_process_inputs(\n    pad_token_id,\n    prompt_token_ids: torch.Tensor,\n) -> torch.Tensor:\n    # remove the left padding in the prompt token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    return prompt_token_ids[non_pad_index:]\n\n\n# NOTE(linjunrong): adhoc\ndef _post_process_outputs(processing_class, output):\n    try:\n        # This is when processing_class is a processor\n        tokenizer = processing_class.tokenizer\n    except AttributeError:\n        # This is when processing_class is already a tokenizer\n        tokenizer = processing_class\n\n    def _map_each_response(resp):\n        output_token_logprobs = resp[\"meta_info\"][\"output_token_logprobs\"]\n        log_probs, output_token_ids = zip(\n            *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True\n        )\n        return torch.tensor(output_token_ids), torch.tensor(log_probs)\n\n    out_map = map(lambda x: _map_each_response(x), output)\n    batched_output_token_ids = []\n    batched_logprobs = []\n    for output_token_ids, log_probs in out_map:\n        batched_output_token_ids.append(output_token_ids)\n        batched_logprobs.append(log_probs)\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)\n    if len(batched_logprobs) > 0:\n        batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id)\n    return batched_output_token_ids, batched_logprobs\n\n\ndef get_tool_call_parser_type(\n    processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n) -> str:\n    items = FunctionCallParser.ToolCallParserEnum.items()\n    for parser_type, parser_cls in items:\n        parser = parser_cls()\n        try:\n            # This is when processing_class is a tokenizer\n            tokenizer_vocab = processing_class.get_vocab()\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                tokenizer_vocab = processing_class.tokenizer.get_vocab()\n            except AttributeError as e:\n                raise ValueError(f\"Cannot get vocab from processing_class {processing_class}\") from e\n\n        if parser.bot_token.strip() in tokenizer_vocab and (\n            parser.eot_token == \"\" or parser.eot_token.strip() in tokenizer_vocab\n        ):\n            return parser_type\n    else:\n        raise ValueError(f\"No tool call parser found for 
processing_class {processing_class}\")\n\n\nclass SGLangRollout(BaseRollout):\n    def __init__(\n        self,\n        actor_module: str,\n        config: DictConfig,\n        processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin,\n        model_hf_config,\n        port=None,\n        trust_remote_code: bool = False,\n        device_mesh: DeviceMesh | None = None,\n        **kwargs,\n    ):\n        \"\"\"Synchronized SGLang rollout engine.\n\n        Args:\n            actor_module: Huggingface model name or path to the model. The\n                model should be supported by SGLang.\n            config: A DictConfig object containing SGLang-specific operational\n                parameters and rollout settings.\n                Refer to https://docs.sglang.ai/backend/server_arguments.html\n            processing_class: The tokenizer or processor instance compatible with the actor_module.\n            model_hf_config: The Hugging Face model's configuration (e.g.,\n                `transformers.PretrainedConfig`). It provides architectural\n                details and hyperparameters like `max_position_embeddings`,\n                used by SGLang for correct model initialization. This is\n                the model's inherent design, not SGLang's runtime behavior.\n            port: Optional port for multi-node initialization when nnodes > 1.\n            trust_remote_code: Whether or not to allow for custom models\n                defined on the Hub in their own modeling files.\n            device_mesh: Optional `DeviceMesh` object for distributed setup.\n            **kwargs: Additional keyword arguments, primarily `train_tp` for\n                Megatron Backend integration to initialize hybrid engine\n                process groups.\n        \"\"\"\n        super().__init__()\n        self.config = config\n        self._device_mesh_cpu = device_mesh\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n\n        (\n            self._tool_schemas,\n            self._tool_map,\n            self._tool_call_parser_type,\n            self._sgl_tools,\n            self._function_call_parser,\n        ) = self._initialize_tools(config, processing_class)\n        self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config)\n        # If turn on `free_cache_engine`, SGLang engine's KV cache\n        # will be freed after each `generate_sequences` call.\n        logger.info(\n            f\"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: \"\n            f\"{self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: \"\n            f\"{self._function_call_parser}\"\n        )\n\n        self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)\n\n        self._verify_config(model_hf_config=model_hf_config)\n        # initialize the inference engine\n        self._init_inference_engine(trust_remote_code, actor_module, port)\n\n        self._init_sampling_params(**kwargs)\n\n        self.processing_class = processing_class\n\n        try:\n            # This is when processing_class is a tokenizer\n            self.pad_token_id = self.processing_class.pad_token_id\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                self.pad_token_id = self.processing_class.tokenizer.pad_token_id\n            except AttributeError as e:\n                raise ValueError(f\"Cannot get 
pad_token_id from processing_class {self.processing_class}\") from e\n\n    def _init_distributed_env(self, device_mesh_cpu, **kwargs):\n        self._device_mesh_cpu = device_mesh_cpu\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n        self.tensor_parallel_size = self.config.get(\"tensor_model_parallel_size\", 1)\n        assert self.tensor_parallel_size <= dist.get_world_size(), (\n            \"tensor parallel size should be less than or equal to the world size\"\n        )\n        self.train_tp = kwargs.get(\"train_tp\", None)\n        if self.train_tp is not None:\n            # deployed with megatron\n            os.environ[\"CUDA_TIMER_STREAM_KAFKA_ENABLE\"] = \"0\"\n            os.environ[\"MEGATRON_IMPORT_TIMERS\"] = \"0\"\n            train_tp = kwargs.get(\"train_tp\", None)\n            num_tp_per_train_tp = train_tp // self.tensor_parallel_size\n            sglang_ps.initialize_parallel_state(\n                tensor_model_parallel_size=self.tensor_parallel_size,\n                num_tp_per_train_tp=num_tp_per_train_tp,\n            )\n\n        tp_size = self.tensor_parallel_size\n        world_size = int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n\n        # init device mesh\n        if self._device_mesh_cpu is None:\n            device_mesh_kwargs = dict(\n                mesh_shape=(world_size // tp_size, tp_size, 1),\n                mesh_dim_names=[\"dp\", \"tp\", \"pp\"],\n            )\n\n            self._device_mesh_cpu = init_device_mesh(\"cpu\", **device_mesh_kwargs)\n\n        self._rank = self._device_mesh_cpu.get_rank()\n        self._tp_rank = self._device_mesh_cpu[\"tp\"].get_local_rank()\n        self._tp_size = self._device_mesh_cpu[\"tp\"].size()\n        if self._rank == 0:\n            logger.info(f\"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}\")\n        # get tp_rank of this process in this tp group\n        visible_devices = [None] * self._device_mesh_cpu.size(1)\n\n        torch.distributed.all_gather_object(\n            visible_devices, os.environ[\"CUDA_VISIBLE_DEVICES\"], self._device_mesh_cpu.get_group(\"tp\")\n        )\n        self.visible_devices_set = set(\",\".join(visible_devices).split(\",\"))\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(sorted(list(self.visible_devices_set)))\n\n    def _verify_config(self, model_hf_config):\n        if not self.config.get(\"max_model_len\", None):\n            self.config.max_model_len = self.config.prompt_length + self.config.response_length\n        assert (\n            self.config.max_model_len >= self.config.prompt_length + self.config.response_length\n        ), f\"\"\"max_model_len should be greater than total sequence length (prompt_length + response_length): \n            {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}\"\"\"\n        max_position_embeddings = None\n        if hasattr(model_hf_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"llm_config\") and hasattr(model_hf_config.llm_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"text_config\") and hasattr(\n            model_hf_config.text_config, \"max_position_embeddings\"\n        ):\n            max_position_embeddings = model_hf_config.text_config.max_position_embeddings\n        if 
max_position_embeddings is None:\n            raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, (\n                \"model context length should be greater than total sequence length\"\n            )\n        else:\n            # handle the case where there's a length-extension factor (rope scaling)\n            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support\n            # using YaRN as an example\n            rope_scaling_factor = rope_scaling_config.get(\"factor\", 1.0)\n\n            assert (\n                model_hf_config.max_position_embeddings * rope_scaling_factor\n                >= self.config.prompt_length + self.config.response_length\n            ), (\n                f\"model context length should be greater than total sequence length, \"\n                f\"got rope_scaling_factor={rope_scaling_factor} and \"\n                f\"max_position_embeddings={model_hf_config.max_position_embeddings}\"\n            )\n\n        # currently max_assistant_turns stands for the max number of tool calls\n        if self.config.multi_turn.max_assistant_turns is None:\n            self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3\n        if self.config.multi_turn.max_user_turns is None:\n            self.config.multi_turn.max_user_turns = self.config.max_model_len // 3\n\n    def _init_inference_engine(self, trust_remote_code, actor_module, port):\n        # initialize the inference engine\n        nnodes = -(-self._tp_size // len(self.visible_devices_set))\n        if nnodes > 1:\n            ip = get_ip()\n            port = get_open_port() if port is None else port\n            [ip, port] = broadcast_pyobj(\n                [ip, port],\n                rank=self._rank,\n                dist_group=self._device_mesh_cpu.get_group(\"tp\"),\n                src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n                force_cpu_device=False,\n            )\n            dist_init_addr = f\"[{ip}]:{port}\" if is_ipv6(ip) else f\"{ip}:{port}\"\n        else:\n            dist_init_addr = None\n\n        load_format = \"dummy\" if self.config.load_format.startswith(\"dummy\") else self.config.load_format\n        tp_size_per_node = self._tp_size // nnodes\n        node_rank = self._tp_rank // tp_size_per_node\n        first_rank_in_node = self._tp_rank % tp_size_per_node == 0\n\n        if first_rank_in_node:\n            rank = dist.get_rank()\n            os.environ[\"SGLANG_BLOCK_NONZERO_RANK_CHILDREN\"] = \"0\"\n            self._engine = AsyncEngine(\n                model_path=actor_module,\n                dtype=self.config.dtype,\n                mem_fraction_static=self.config.gpu_memory_utilization,\n                enable_memory_saver=True,\n                base_gpu_id=0,\n                gpu_id_step=1,\n                tp_size=self._tp_size,\n                node_rank=node_rank,\n                load_format=load_format,\n                dist_init_addr=dist_init_addr,\n                nnodes=nnodes,\n                trust_remote_code=trust_remote_code,\n                # NOTE(linjunrong): add rank to prevent SGLang from generating the same port inside PortArgs.init_new\n                # when random.seed is being set during training\n                port=30000 + rank,\n                # 
NOTE(Chenyang): if you want to debug the SGLang engine output\n                # please set the following parameters\n                # Otherwise, it will make the engine run too slow\n                # log_level=\"INFO\",\n                # log_requests=True,\n                # log_requests_level=2,\n                # max_running_requests=1,\n                mm_attention_backend=\"fa3\",\n                attention_backend=\"fa3\",\n                # In async mode, we want token in token out.\n                skip_tokenizer_init=self.config.mode == \"async\",\n            )\n        else:\n            self._engine = None\n\n        self.sharding_manager = None\n        self.is_sleep = True\n\n    def _init_sampling_params(self, **kwargs):\n        kwargs = dict(\n            n=1,\n            max_new_tokens=self.config.response_length,\n            presence_penalty=0.0,\n            frequency_penalty=0.0,\n            repetition_penalty=1.0,\n        )\n        # supporting adding any sampling params from the config file\n        for k in self.config.keys():\n            if hasattr(SamplingParams(), str(k)) or \"stop\" in str(k):\n                kwargs[k] = self.config.get(k)\n        kwargs[\"n\"] = 1  # already repeat in ray_trainer\n        self.sampling_params = kwargs\n\n    def _initialize_tools(self, config, processing_class):\n        \"\"\"Initialize tools from configuration.\n\n        Args:\n            config: Configuration object containing tool-related settings,\n                    specifically `config.multi_turn.tool_config_path`.\n            tokenizer: The tokenizer instance used for parsing tool calls from\n                       the model's generated text.\n\n        Returns:\n            tuple: A tuple containing:\n                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas\n                  defining each tool's capabilities.\n                - tool_map (dict[str, BaseTool]): A dictionary mapping tool\n                  names to their executable `BaseTool` objects.\n                - tool_call_parser_type (str): The identifier for the specific\n                  parser type (e.g., 'json_mode', 'tool_code') used to extract\n                  tool calls.\n                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool\n                  definitions optimized for SGLang's internal engine.\n                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):\n                  The active parser instance responsible for extracting\n                  structured tool calls from model outputs.\n        \"\"\"\n        if config.multi_turn.tool_config_path is None:\n            return [], {}, None, [], None\n\n        tools_config_file = config.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tools_config_file)\n\n        logger.info(f\"Initialize tools from configuration.: tool_list: {tool_list}\")\n        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]\n        tool_map = {tool.name: tool for tool in tool_list}\n        tool_call_parser_type = get_tool_call_parser_type(processing_class)\n        sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]\n        function_call_parser = FunctionCallParser(\n            sgl_tools,\n            tool_call_parser_type,\n        )\n\n        return (\n            tool_schemas,\n            tool_map,\n            tool_call_parser_type,\n            sgl_tools,\n            function_call_parser,\n   
     )\n\n    def _initialize_interactions(self, config):\n        \"\"\"Initialize interactions from configuration.\n\n        Returns:\n            dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.\n        \"\"\"\n        if config.multi_turn.interaction_config_path is None:\n            return {}\n\n        interaction_config_file = config.multi_turn.interaction_config_path\n        interaction_map = initialize_interactions_from_config(interaction_config_file)\n\n        logger.info(f\"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}\")\n        return interaction_map\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generate sequences for a batch of prompts.\n\n        Args:\n            batch (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        if self.config.multi_turn.enable:\n            return self._req_level_generate_sequences(prompts, **kwargs)\n        return self._batch_level_generate_sequences(prompts, **kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generates single-turn sequences for a batch of prompts.\n        For single-turn generation, all prompts are processed in one request.\n        `_batch_level_generate_sequences` involves:\n        1.  Extracting and pre-processing prompt token IDs from the input\n            `prompts`. This includes handling padding and preparing raw\n            token ID lists.\n        2.  Preparing inputs for the SGLang engine, including multi-modal\n            data if present.\n        3.  Invoking the SGLang engine (`self._engine.async_generate`,\n            an async coroutine) with the batch of processed inputs and\n            specified sampling parameters on the master TP rank.\n        4.  Broadcasting the results from the master TP rank to all\n            other TP ranks.\n        5.  Post-processing the engine's output to format the generated\n            token IDs and (if applicable) log probabilities.\n        6.  Constructing the final sequences by concatenating original\n            prompts with the generated responses.\n        7.  Updating attention masks and position IDs to reflect the full\n            concatenated sequences.\n        8.  
If `self.config.free_cache_engine` is true, the SGLang engine's\n            KV cache is flushed after generation on the master TP rank.\n        Args:\n            prompts: A `DataProto` object containing the batch of\n              input prompts, including tensor data (like `input_ids`,\n              `attention_mask`) and meta-information (like `eos_token_id`,\n              `do_sample`).\n            **kwargs: Additional keyword arguments that can override the\n              default sampling parameters (e.g., `temperature`, `top_p`,\n              `max_new_tokens`). These are temporarily applied using\n              `update_sampling_params`.\n        Returns:\n            DataProto: A `DataProto` object containing the batch of\n              generated sequences. This includes tensors for `prompts`\n              (original input IDs), `responses` (generated token IDs),\n              `input_ids` (concatenated prompt and response),\n              `attention_mask`, and `position_ids` for the full\n              sequences.\n        Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer.\n        Thus we do not need to repeat the prompts here and set the sampling parameter n to 1.\n        \"\"\"\n        # input ids: (bs, prompt_length), left-padded\n        idx = prompts.batch[\"input_ids\"]\n        # attention_mask: (bs, seq_length), left-padded\n        attention_mask = prompts.batch[\"attention_mask\"]\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to generate attention mask for the\n        # response based on EOS token position\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n\n        # Extract non-tensor data\n        non_tensor_batch = prompts.non_tensor_batch\n        if \"raw_prompt_ids\" not in non_tensor_batch:\n            non_tensor_batch[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)],\n                dtype=object,\n            )\n\n        if \"multi_modal_data\" in non_tensor_batch:\n            sglang_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                non_tensor_batch.pop(\"raw_prompt_ids\"),\n                non_tensor_batch.pop(\"multi_modal_data\"),\n                strict=True,\n            ):\n                sglang_inputs.append(\n                    {\n                        \"prompt_token_ids\": raw_prompt_ids,\n                        \"multi_modal_data\": multi_modal_data,\n                        \"image_data\": (\n                            multi_modal_data.get(\"image\", None) if isinstance(multi_modal_data, dict) else None\n                        ),\n                    }\n                )\n        else:\n            sglang_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop(\"raw_prompt_ids\")\n            ]\n\n        # Ensure token IDs are lists or numpy arrays\n        for input_data in sglang_inputs:\n            if isinstance(input_data[\"prompt_token_ids\"], np.ndarray):\n                input_data[\"prompt_token_ids\"] = input_data[\"prompt_token_ids\"].tolist()\n            elif not isinstance(input_data[\"prompt_token_ids\"], list):\n                raise TypeError(\n                    f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\"\n                )\n\n        # Extract token IDs and image 
data for SGLang Engine\n        idx_list = [input_data[\"prompt_token_ids\"] for input_data in sglang_inputs]\n        image_list = [input_data.get(\"image_data\", None) for input_data in sglang_inputs]\n\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n\n        if self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            output = loop.run_until_complete(\n                self._engine.async_generate(\n                    prompt=None,  # because we have already converted it to prompt token ids\n                    sampling_params=request_sampling_params,\n                    return_logprob=True,\n                    input_ids=idx_list,\n                    image_data=image_list,\n                )\n            )\n        else:\n            output = None\n\n        # Most naive implementation; could extract tensors and send via gloo if this proves too slow\n        dist.barrier()\n        [output] = broadcast_pyobj(\n            data=[output],\n            rank=self._rank,\n            dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        out = _post_process_outputs(self.processing_class, output)\n\n        response = out[0].to(idx.device)\n        rollout_log_probs = None\n        if self.config.calculate_log_probs:\n            rollout_log_probs = out[1].to(idx.device)\n\n        if response.shape[1] < self.config.response_length:\n            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)\n            if self.config.calculate_log_probs:\n                rollout_log_probs = pad_sequence_to_length(\n                    rollout_log_probs, self.config.response_length, self.pad_token_id\n                )\n\n        seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            delta_position_id = 
delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        # all the tp ranks should contain the same data here. data in all ranks are valid\n        batch = TensorDict(\n            {\n                \"prompts\": idx,\n                \"responses\": response,\n                \"input_ids\": seq,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n        if self.config.calculate_log_probs:\n            # we will recompute old log prob with actor\n            batch[\"rollout_log_probs\"] = rollout_log_probs\n\n        # free cache engine\n        if self._engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self._engine.flush_cache())\n\n        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n    async def _async_rollout_a_request(\n        self,\n        req: AsyncRolloutRequest,\n        do_sample: bool = True,\n        is_validate: bool = False,\n        **kwargs,\n    ) -> AsyncRolloutRequest:\n        assert self._tp_rank == 0, \"only the master process can call this function\"\n        _req = deepcopy(req)\n        finish_reason_type = None\n        output = None\n\n        current_turns = 0\n        user_turns = 0\n        user_turn_rewards = []\n\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n\n        while current_turns < self.config.multi_turn.max_assistant_turns:\n            if _req.state == AsyncRolloutRequestStateEnum.PENDING:\n                await 
self._handle_pending_state(_req)\n                _req.state = AsyncRolloutRequestStateEnum.RUNNING\n            elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:\n                if _req.messages[-1].tool_calls is not None:\n                    parsed_tool_calls = _req.messages[-1].tool_calls\n                    tool_call_results = await asyncio.gather(\n                        *[\n                            self._tool_map[tool_call.function.name].execute(\n                                _req.request_id,\n                                tool_call.function.arguments,\n                                **_req.tools_kwargs[tool_call.function.name].get(\"execute_kwargs\", {}),\n                            )\n                            for tool_call in parsed_tool_calls\n                        ]\n                    )\n                    _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results])\n                    for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results, strict=True):\n                        _req.update_metrics(metrics, tool_call.function.name)\n                    if len(_req.input_ids) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    _req.state = AsyncRolloutRequestStateEnum.RUNNING\n                else:\n                    raise ValueError(f\"Unexpected tool calling last message state: {_req.messages[-1]}\")\n            elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:\n                # Only continue the conversation if the prompt length is not greater than max_model_len - 1,\n                # since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra\n                # token accounts for the EOS token).\n                if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len:\n                    finish_reason_type = FinishReasonTypeEnum.LENGTH\n                    break\n\n                # Video support is not implemented yet\n                image_data = (\n                    _req.multi_modal_data[\"image\"]\n                    if _req.multi_modal_data and \"image\" in _req.multi_modal_data\n                    else None\n                )\n                video_data = (\n                    _req.multi_modal_data[\"video\"]\n                    if _req.multi_modal_data and \"video\" in _req.multi_modal_data\n                    else None\n                )\n                if video_data:\n                    logger.warning(\n                        \"video support is not implemented yet, current length of video data is %d\", len(video_data)\n                    )\n\n                output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data)\n                content = output[\"text\"]\n                finish_reason_type = FinishReasonTypeEnum.from_str(output[\"meta_info\"][\"finish_reason\"][\"type\"])\n                current_turns += 1\n                if finish_reason_type == FinishReasonTypeEnum.LENGTH:\n                    _req.add_assistant_message(self.processing_class, content)\n                    break\n                else:\n                    if self._function_call_parser and self._function_call_parser.has_tool_call(content):\n                        finish_reason_type = FinishReasonTypeEnum.TOOL_CALL\n                        _req.state = 
AsyncRolloutRequestStateEnum.TOOL_CALLING\n                        try:\n                            normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)\n                        except (JSONDecodeError, AttributeError):\n                            normed_content = content\n                            tool_calls = []\n                        parsed_tool_calls = []\n                        for tool_call in tool_calls:\n                            function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(\n                                OpenAIFunctionParsedSchema(\n                                    name=tool_call.name,\n                                    arguments=tool_call.parameters,\n                                )\n                            )\n                            # Drop the tool call if its arguments have a decode error\n                            if has_decode_error:\n                                continue\n                            parsed_tool_calls.append(\n                                OpenAIFunctionToolCall(\n                                    id=str(tool_call.tool_index),\n                                    function=function,\n                                )\n                            )\n                        if len(parsed_tool_calls) > 0:\n                            _req.add_assistant_message(\n                                self.processing_class, normed_content, tool_calls=parsed_tool_calls\n                            )\n                        else:\n                            _req.add_assistant_message(self.processing_class, content)\n                            finish_reason_type = FinishReasonTypeEnum.STOP\n                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                            break\n                    else:\n                        _req.add_assistant_message(\n                            self.processing_class,\n                            content,\n                        )\n                        if (\n                            _req.interaction_kwargs\n                            and self.interaction_map\n                            and user_turns < self.config.multi_turn.max_user_turns\n                            and current_turns < self.config.multi_turn.max_assistant_turns\n                        ):\n                            _req.state = AsyncRolloutRequestStateEnum.INTERACTING\n                        else:\n                            break\n            elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:\n                user_turns += 1\n                messages = [{\"role\": x.role, \"content\": x.content} for x in _req.messages]\n\n                # Get interaction by name from interaction_kwargs\n                interaction_name = _req.interaction_kwargs.get(\n                    \"name\", \"gsm8k\"\n                )  # Default to gsm8k for backward compatibility\n                if interaction_name not in self.interaction_map:\n                    raise ValueError(\n                        f\"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: \"\n                        f\"{list(self.interaction_map.keys())}\"\n                    )\n\n                interaction = self.interaction_map[interaction_name]\n                should_terminate_sequence, content, reward, metrics = await interaction.generate_response(\n                    _req.request_id, messages, **_req.interaction_kwargs\n                )\n                user_turn_rewards.append(reward)\n                if should_terminate_sequence:\n                    finish_reason_type = FinishReasonTypeEnum.STOP\n                    _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                    break\n                else:\n                    _req.add_user_message(self.processing_class, content)\n                    if len(_req.input_ids) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    else:\n                        _req.state = AsyncRolloutRequestStateEnum.RUNNING\n\n        if current_turns >= self.config.multi_turn.max_assistant_turns:\n            finish_reason_type = FinishReasonTypeEnum.STOP\n\n        # Calculate the reward for each tool\n        async def calc_reward_and_release_fn(name: str, tool: BaseTool):\n            reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get(\"calc_reward_kwargs\", {}))\n            await tool.release(_req.request_id, **_req.tools_kwargs[name].get(\"release_kwargs\", {}))\n            return name, reward\n\n        tool_reward_tasks = []\n        for name in _req.tools_kwargs.keys():\n            tool = self._tool_map[name]\n            tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))\n        tool_reward_scores = await asyncio.gather(*tool_reward_tasks)\n        tool_reward_scores = dict(tool_reward_scores)\n        all_rewards = {**tool_reward_scores, **{\"user_turn_rewards\": user_turn_rewards}}\n        _req.finalize(self.processing_class, all_rewards, finish_reason_type)\n\n        return _req\n\n    async def _handle_engine_call(\n        self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class)\n        return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data)\n\n    async def _handle_engine_generate(\n        self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)\n        kwargs = sampling_params.copy()\n        kwargs[\"max_new_tokens\"] = max_new_tokens\n        kwargs[\"n\"] = 1  # group size is supported in preprocess\n        output = await self._engine.async_generate(\n            input_ids=generation_prompt_ids,\n            sampling_params=kwargs,\n            return_logprob=False,\n            image_data=image_data,\n        )\n        return output\n\n    async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest:\n        if _req.tool_schemas is not None:\n            tool_creation_coroutines = []\n            for tool_schema in _req.tool_schemas:\n                tool = self._tool_map[tool_schema.function.name]\n                create_kwargs = _req.tools_kwargs[tool.name].get(\"create_kwargs\", {})\n                
tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))\n            await asyncio.gather(*tool_creation_coroutines)\n        if _req.interaction_kwargs and self.interaction_map:\n            interaction_kwargs = _req.interaction_kwargs\n            # Get interaction by name from interaction_kwargs\n            interaction_name = interaction_kwargs.get(\"name\", \"gsm8k\")  # Default to gsm8k for backward compatibility\n            if interaction_name not in self.interaction_map:\n                raise ValueError(\n                    f\"Interaction '{interaction_name}' not found in interaction_map. Available interactions: \"\n                    f\"{list(self.interaction_map.keys())}\"\n                )\n\n            interaction = self.interaction_map[interaction_name]\n            await interaction.start_interaction(_req.request_id, **interaction_kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:\n        logger.warning(\n            \"`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)` instead\",\n            stacklevel=2,\n        )\n        return self._req_level_generate_sequences(prompts, **kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generates multi-turn sequences for a batch of prompts.\n        For multi-turn generation, each prompt is processed separately via\n        `_req_level_generate_sequences` for better tool calling control.\n        Note that in multi-turn generation, ray_trainer has already repeated each prompt rollout.n times,\n        so we do not repeat the prompts here and instead set the sampling parameter n to 1.\n        \"\"\"\n        # Async rollout with tools support\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n        tgt_device = prompts.batch[\"input_ids\"].device\n        if self._tp_rank == 0:\n            req_list = self._preprocess_prompt_to_async_rollout_requests(\n                prompts,\n            )\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(\n                    *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],\n                )\n            )\n            sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))\n        else:\n            sorted_output_req_list = None\n\n        dist.barrier()\n        [sorted_output_req_list] = broadcast_pyobj(\n            data=[sorted_output_req_list],\n            rank=self._rank,\n            dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        # Construct the batch data\n        prompt_ids, response_ids = [], []\n        prompt_attention_mask, response_attention_mask = [], []\n        prompt_position_ids, response_position_ids = [], []\n        prompt_loss_mask, response_loss_mask = [], []\n        messages = []\n        reward_scores = []\n        multi_modal_inputs = []\n        request_ids = []\n\n        for req in sorted_output_req_list:\n            assert 
req.state == AsyncRolloutRequestStateEnum.COMPLETED, f\"Request {req.request_id} is not completed\"\n            assert (\n                req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), f\"\"\"Request {req.request_id} has different length of \n                {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, \n                {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}\"\"\"\n            error_message_lines = [\n                f\"\"\"Request {req.request_id} has input_ids length {req.input_ids.shape[-1]}\n                    greater than max_model_len {self.config.max_model_len}\"\"\",\n                f\"Decoded input_ids: {self.processing_class.decode(req.input_ids.squeeze(0))}\",\n                f\"Decoded prompt_ids: {self.processing_class.decode(req.prompt_ids.squeeze(0))}\",\n                f\"Decoded response_ids: {self.processing_class.decode(req.response_ids.squeeze(0))}\",\n                f\"Messages: {req.messages}\",\n                f\"Max model length: {req.max_model_len}\",\n            ]\n            error_message = \"\\n\".join(error_message_lines)\n            assert req.input_ids.shape[-1] <= self.config.max_model_len, error_message\n\n            prompt_ids.append(req.prompt_ids.to(tgt_device).squeeze(0))\n            response_ids.append(req.response_ids.to(tgt_device).squeeze(0))\n            if req.response_ids.shape[-1] > self.config.response_length:\n                logger.warning(\n                    f\"\"\"{req.request_id=} has response_ids length {req.response_ids.shape[-1]} \n                    greater than max_response_len {self.config.response_length},\\n{req=}\"\"\"\n                )\n            prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0))\n            response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0))\n            prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0))\n            response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0))\n            prompt_loss_mask.append(req.prompt_loss_mask.to(tgt_device).squeeze(0))\n            response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0))\n            messages.append({\"messages\": req.messages})\n            reward_scores.append(req.reward_scores)\n            multi_modal_inputs.append(req.multi_modal_inputs)\n            request_ids.append(req.request_id)\n\n        prompt_ids = pad_sequence(\n            prompt_ids,\n            batch_first=True,\n            padding_value=self.pad_token_id,\n            padding_side=\"left\",\n        )\n        if prompt_ids.shape[-1] < self.config.prompt_length:\n            prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)\n        response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)\n        if response_ids.shape[-1] < self.config.response_length:\n            response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)\n        prompt_attention_mask = pad_sequence(\n            prompt_attention_mask,\n            batch_first=True,\n            padding_value=0,\n            padding_side=\"left\",\n        )\n        if prompt_attention_mask.shape[-1] < self.config.prompt_length:\n            prompt_attention_mask = 
pad_sequence_to_length(\n                prompt_attention_mask, self.config.prompt_length, 0, left_pad=True\n            )\n        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)\n        if response_attention_mask.shape[-1] < self.config.response_length:\n            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)\n\n        # padding prompt_position_ids\n        if prompt_position_ids[0].dim() == 2:\n            # if prompt_position_ids is a 2D tensor\n            # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len)\n            transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids]\n            prompt_position_ids = pad_sequence(\n                transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n            prompt_position_ids = prompt_position_ids.transpose(1, 2)\n        else:\n            prompt_position_ids = pad_sequence(\n                prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n        if prompt_position_ids.shape[-1] < self.config.prompt_length:\n            prompt_position_ids = pad_sequence_to_length(\n                prompt_position_ids, self.config.prompt_length, 0, left_pad=True\n            )\n\n        # padding response_position_ids\n        if response_position_ids[0].dim() == 2:\n            # if response_position_ids is a 2D tensor\n            # e.g. from qwen2vl, response_position_ids.shape = (3, seq_len)\n            transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids]\n            # responses are right-padded, matching the other response tensors above\n            response_position_ids = pad_sequence(\n                transposed_response_position_ids, batch_first=True, padding_value=0\n            )\n            response_position_ids = response_position_ids.transpose(1, 2)\n        else:\n            response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0)\n        if response_position_ids.shape[-1] < self.config.response_length:\n            response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0)\n\n        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side=\"left\")\n        if prompt_loss_mask.shape[1] < self.config.prompt_length:\n            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)\n        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)\n        if response_loss_mask.shape[1] < self.config.response_length:\n            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)\n\n        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)\n        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)\n        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)\n\n        # Construct the batch data\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,\n                \"responses\": response_ids,\n                \"response_mask\": response_loss_mask,\n                \"input_ids\": input_ids,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            
batch_size=len(sorted_output_req_list),\n        )\n\n        # free cache engine\n        if self._engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self._engine.flush_cache())\n\n        non_tensor_batch = {\n            \"messages\": np.array(messages),\n            \"reward_scores\": np.array(reward_scores),\n            \"request_id\": np.array(request_ids),\n        }\n\n        is_multimodal = isinstance(self.processing_class, ProcessorMixin) and (\n            hasattr(self.processing_class, \"image_processor\") or hasattr(self.model_hf_config, \"vision_config\")\n        )\n\n        if is_multimodal:\n            non_tensor_batch[\"multi_modal_inputs\"] = np.array(multi_modal_inputs, dtype=object)\n\n        return DataProto(\n            batch=batch,\n            non_tensor_batch=non_tensor_batch,\n        )\n\n    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]:\n        assert \"raw_prompt\" in prompts.non_tensor_batch, (\n            \"need data.return_raw_chat=True, since there is no official way to parse messages\"\n        )\n        logger.info(\n            \"the n argument is deprecated for SGLang rollout since the ray PPO trainer repeats the prompts rollout.n times\"\n        )\n        req_list = []\n        multi_modal_data_list = prompts.non_tensor_batch.get(\n            \"multi_modal_data\", [None] * len(prompts.non_tensor_batch[\"raw_prompt\"])\n        )\n\n        for data_idx, (raw_prompt, multi_modal_data) in enumerate(\n            zip(prompts.non_tensor_batch[\"raw_prompt\"], multi_modal_data_list, strict=True)\n        ):\n            if self._tool_schemas:\n                _tools_kwargs = prompts.non_tensor_batch[\"tools_kwargs\"][data_idx]\n                _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()]\n                _input_ids = None\n                _attention_mask = None\n            else:\n                _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch[\"input_ids\"][data_idx])\n                _attention_mask = _pre_process_inputs(0, prompts.batch[\"attention_mask\"][data_idx])\n                _tools_kwargs = {}\n                _tool_schemas = None\n\n            if self.interaction_map:\n                _interaction_kwargs = prompts.non_tensor_batch[\"interaction_kwargs\"][data_idx]\n            else:\n                _interaction_kwargs = {}\n\n            req = AsyncRolloutRequest(\n                batch_data_id=data_idx,\n                rollout_offset=0,\n                request_id=str(uuid4()),\n                state=AsyncRolloutRequestStateEnum.PENDING,\n                messages=raw_prompt.tolist(),\n                multi_modal_data=multi_modal_data,\n                tool_schemas=_tool_schemas,\n                tools_kwargs=_tools_kwargs,\n                interaction_kwargs=_interaction_kwargs,\n                input_ids=_input_ids,\n                response_ids=None,\n                attention_mask=_attention_mask,\n                response_attention_mask=None,\n                response_position_ids=None,\n                response_loss_mask=None,\n                reward_scores={},\n                max_prompt_len=self.config.prompt_length,\n                max_response_len=self.config.response_length,\n                max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),\n                
use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,\n                tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode,\n                processing_class=self.processing_class,\n            )\n            error_message = f\"\"\"Request {req.request_id} has mismatched lengths: \n            input_ids={req.input_ids.shape[-1]}, \n            attention_mask={req.attention_mask.shape[-1]}, \n            position_ids={req.position_ids.shape[-1]}, \n            loss_mask={req.loss_mask.shape[-1]}\"\"\"\n            assert (\n                req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), error_message\n            req_list.append(req)\n\n        return req_list\n\n    async def chat_completion(self, json_request):\n        assert self._tp_rank == 0, \"only called in tp rank 0\"\n        _input_ids = None\n        _attention_mask = None\n        _position_ids = None\n        _tool_schemas = []\n        _tools_kwargs = {}\n\n        req = AsyncRolloutRequest(\n            request_id=str(uuid4()),\n            state=AsyncRolloutRequestStateEnum.PENDING,\n            messages=[Message.model_validate(msg) for msg in json_request[\"messages\"]],\n            tool_schemas=_tool_schemas,\n            tools_kwargs=_tools_kwargs,\n            input_ids=_input_ids,\n            prompt_ids=_input_ids,\n            response_ids=None,\n            attention_mask=_attention_mask,\n            prompt_attention_mask=_attention_mask,\n            response_attention_mask=None,\n            position_ids=_position_ids,\n            prompt_position_ids=_position_ids,\n            response_position_ids=None,\n            loss_mask=None,\n            prompt_loss_mask=None,\n            response_loss_mask=None,\n            reward_scores={},\n            max_prompt_len=self.config.prompt_length,\n            max_response_len=self.config.response_length,\n            max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),\n            use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,\n            tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode,\n            processing_class=self.processing_class,\n        )\n\n        # json_request already contains sampling_params\n        # Filter only valid SamplingParams arguments\n        valid_sampling_params = {}\n        temp_sampling_params = SamplingParams()  # Create temporary instance to check valid attributes\n        for k, v in json_request.items():\n            if k not in [\"messages\", \"model\", \"tools\"] and hasattr(temp_sampling_params, k):\n                valid_sampling_params[k] = v\n        output = await self._handle_engine_call(req, valid_sampling_params)\n        # it can be Dict or AsyncIterator[Dict]\n        if isinstance(output, dict):\n            outputs = [output]\n        else:\n            outputs = output\n\n        # build openai chat completion format\n        choices = []\n        id = None\n        for i, content in enumerate(outputs):\n            choices.append(\n                {\n                    \"index\": i,\n                    \"message\": {\n                        \"role\": \"assistant\",\n                        \"content\": content[\"text\"],\n                    },\n                    \"finish_reason\": 
content[\"meta_info\"][\"finish_reason\"][\"type\"],\n                }\n            )\n            id = content[\"meta_info\"][\"id\"]\n\n        return {\n            \"id\": \"chatcmpl-\" + id,\n            \"object\": \"chat.completion\",\n            \"created\": int(time.time()),\n            \"model\": json_request.get(\"model\", \"sglang_model\"),\n            \"choices\": choices,\n        }\n\n    # this function is left for uniform train-inference resharding\n    async def generate(\n        self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str\n    ) -> torch.Tensor:\n        request_sampling_params = self.sampling_params.copy()\n        request_sampling_params.update(sampling_params)\n        output = await self._handle_engine_generate(prompt_ids, request_sampling_params)\n        return output[\"output_ids\"]\n\n    async def wake_up(self):\n        if not self.is_sleep:\n            return\n        await self.sharding_manager.wake_up()  # pylint: disable=C2801\n        self.is_sleep = False\n\n    # this function is left for uniform train-inference resharding\n    async def sleep(self):\n        if self.is_sleep:\n            return\n        await self.sharding_manager.sleep()\n        self.is_sleep = True\n
  },
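In `generate_sequences` above, the multi-turn layout interleaves LLM-generated spans with tool-observation spans inside `responses`, and `response_mask` is what keeps the loss on LLM tokens only. A minimal, self-contained sketch of that layout (the token values and segment lengths are made up for illustration):

```python
import torch

# Hypothetical multi-turn response: 3 LLM tokens, 2 tool-observation
# tokens, 2 more LLM tokens, then right padding up to response_length.
pad_token_id = 0
llm_turn_1 = torch.tensor([11, 12, 13])
tool_obs = torch.tensor([21, 22])
llm_turn_2 = torch.tensor([14, 15])

response = torch.cat([llm_turn_1, tool_obs, llm_turn_2])
response_mask = torch.cat([
    torch.ones_like(llm_turn_1),   # 1: generated by the LLM
    torch.zeros_like(tool_obs),    # 0: observation tokens from tool_calls
    torch.ones_like(llm_turn_2),   # 1: generated by the LLM
])

# Right-pad to the configured response length; padding is masked out too.
response_length = 10
pad = response_length - response.numel()
response = torch.cat([response, torch.full((pad,), pad_token_id)])
response_mask = torch.cat([response_mask, torch.zeros(pad, dtype=torch.long)])

print(response)       # tensor([11, 12, 13, 21, 22, 14, 15,  0,  0,  0])
print(response_mask)  # tensor([1, 1, 1, 0, 0, 1, 1, 0, 0, 0])
```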
  {
    "path": "verl_rl/verl/workers/rollout/sglang_rollout/utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nfrom typing import Any, Iterator, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom verl.utils.device import get_device_name\n\n\ndef broadcast_pyobj(\n    data: list[Any],\n    rank: int,\n    dist_group: Optional[torch.distributed.ProcessGroup] = None,\n    src: int = 0,\n    force_cpu_device: bool = False,\n):\n    \"\"\"from https://github.com/sgl-project/sglang/blob/844e2f227ab0cce6ef818a719170ce37b9eb1e1b/python/sglang/srt/utils.py#L905\n\n    Broadcast inputs from src rank to all other ranks with torch.dist backend.\n    The `rank` here refer to the source rank on global process group (regardless\n    of dist_group argument).\n    \"\"\"\n    device = torch.device(get_device_name() if not force_cpu_device else \"cpu\")\n\n    if rank == src:\n        if len(data) == 0:\n            tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n        else:\n            serialized_data = pickle.dumps(data)\n            size = len(serialized_data)\n\n            tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device)\n            tensor_size = torch.tensor([size], dtype=torch.long, device=device)\n\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n            dist.broadcast(tensor_data, src=src, group=dist_group)\n        return data\n    else:\n        tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n        dist.broadcast(tensor_size, src=src, group=dist_group)\n        size = tensor_size.item()\n\n        if size == 0:\n            return []\n\n        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)\n        dist.broadcast(tensor_data, src=src, group=dist_group)\n\n        serialized_data = bytes(tensor_data.cpu().numpy())\n        data = pickle.loads(serialized_data)\n        return data\n\n\ndef get_named_tensor_buckets(\n    iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int\n) -> Iterator[list[tuple[str, torch.Tensor]]]:\n    \"\"\"\n    Group tensors into buckets based on a specified size in megabytes.\n\n    Args:\n        iterable: An iterator of tuples containing tensor names and tensors.\n        bucket_bytes: The maximum size of each bucket in bytes.\n\n    Yields:\n        Lists of tuples, where each tuple contains a tensor name and its corresponding tensor.\n\n    Example:\n        >>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))]\n        >>> for bucket in get_named_tensor_buckets(tensors, bucket_size_mb=10):\n        ...     
print(bucket)\n        [('tensor1', tensor(...)), ('tensor2', tensor(...))]\n\n    \"\"\"\n    if bucket_bytes <= 0:\n        raise ValueError(f\"bucket_bytes must be greater than 0, got {bucket_bytes}\")\n\n    current_bucket = []\n    current_size = 0\n    for name, tensor in iterable:\n        tensor_size = tensor.element_size() * tensor.numel()\n        if current_size + tensor_size > bucket_bytes:\n            if current_bucket:\n                yield current_bucket\n            current_bucket = [(name, tensor)]\n            current_size = tensor_size\n        else:\n            current_bucket.append((name, tensor))\n            current_size += tensor_size\n\n    if current_bucket:\n        yield current_bucket\n"
  },
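A usage sketch for `get_named_tensor_buckets` from the file above, assuming the `verl` package is importable; the tensor names and shapes are illustrative. One consequence of the greedy logic worth noting: a tensor larger than `bucket_bytes` still gets a bucket of its own rather than being split.

```python
import torch
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

# fp32 tensors: 'a' and 'b' are 4 MiB each, 'c' is 16 MiB.
named_tensors = iter([
    ("a", torch.zeros(1024, 1024)),
    ("b", torch.zeros(1024, 1024)),
    ("c", torch.zeros(2048, 2048)),
])

# With an 8 MiB budget, 'a' and 'b' share a bucket; 'c' overflows into its own.
for bucket in get_named_tensor_buckets(named_tensors, bucket_bytes=8 * 1024**2):
    total = sum(t.element_size() * t.numel() for _, t in bucket)
    print([name for name, _ in bucket], f"{total / 2**20:.0f} MiB")
# ['a', 'b'] 8 MiB
# ['c'] 16 MiB
```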
  {
    "path": "verl_rl/verl/workers/rollout/tokenizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base tokenizer class, required for any hybrid engine based rollout or inference with vLLM.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport numpy as np\nimport torch\n\n__all__ = [\"HybridEngineBaseTokenizer\"]\n\n\nclass HybridEngineBaseTokenizer(ABC):\n    \"\"\"the tokenizer property and function name should align with HF's to meet vllm requirement\"\"\"\n\n    @property\n    @abstractmethod\n    def vocab_size(self):\n        \"\"\"\n        `int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def pad_token_id(self):\n        \"\"\"\n        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def eos_token_id(self):\n        \"\"\"\n        `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been\n        set.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def all_special_ids(self) -> list[int]:\n        \"\"\"\n        `List[int]`: List the ids of the special tokens(`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.\n        \"\"\"\n        pass\n\n    @property\n    @abstractmethod\n    def all_special_tokens(self) -> list[str]:\n        \"\"\"\n        `List[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).\n\n        Convert tokens of `tokenizers.AddedToken` type to string.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def encode(self, text):\n        \"\"\"\n        Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.\n\n        Args:\n            text (`str`, `List[str]` or `List[int]`):\n                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the\n                `tokenize` method) or a list of integers.\n\n            text_pair (`str`, `List[str]` or `List[int]`, *optional*):\n                Optional second sequence to be encoded. 
This can be a string, a list of strings (tokenized string using\n                the `tokenize` method) or a list of integers.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def decode(\n        self,\n        token_ids: int | list[int] | np.ndarray | torch.Tensor,\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: bool = None,\n        **kwargs,\n    ) -> str:\n        \"\"\"\n        Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special\n        tokens and clean up tokenization spaces.\n\n        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.\n\n        Args:\n            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`):\n                List of tokenized input ids. Can be obtained using the `__call__` method.\n            skip_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not to remove special tokens in the decoding.\n            clean_up_tokenization_spaces (`bool`, *optional*):\n                Whether or not to clean up the tokenization spaces. If `None`, will default to\n                `self.clean_up_tokenization_spaces`.\n            kwargs (additional keyword arguments, *optional*):\n                Will be passed to the underlying model specific decode method.\n\n        Returns:\n            `str`: The decoded sentence.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:\n        \"\"\"\n        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and\n        added tokens.\n\n        Args:\n            ids (`int` or `List[int]`):\n                The token id (or token ids) to convert to tokens.\n            skip_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not to remove special tokens in the decoding.\n\n        Returns:\n            `str` or `List[str]`: The decoded token(s).\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_added_vocab(self) -> dict[str, int]:\n        \"\"\"\n        Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from\n        the fast call because for now we always add the tokens even if they are already in the vocabulary. This is\n        something we should change.\n\n        Returns:\n            `Dict[str, int]`: The added tokens.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def convert_tokens_to_string(self, tokens: list[str]) -> str:\n        \"\"\"\n        Converts a sequence of tokens in a single string. The most simple way to do it is `\" \".join(tokens)` but we\n        often want to remove sub-word tokenization artifacts at the same time.\n\n        Args:\n            tokens (`List[str]`): The token to join in a string.\n\n        Returns:\n            `str`: The joined tokens.\n        \"\"\"\n        pass\n\n    @property\n    def is_fast(self):\n        return False\n"
  },
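To make the contract concrete, here is a hypothetical word-level toy that implements every abstract member of `HybridEngineBaseTokenizer`; it is a sketch for illustration, not shipped code, and assumes `verl` is importable:

```python
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class ToyTokenizer(HybridEngineBaseTokenizer):
    # Word-level toy tokenizer over a fixed four-token vocabulary.

    def __init__(self):
        self._vocab = {"<pad>": 0, "<eos>": 1, "hello": 2, "world": 3}
        self._inv = {i: t for t, i in self._vocab.items()}

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def pad_token_id(self):
        return self._vocab["<pad>"]

    @property
    def eos_token_id(self):
        return self._vocab["<eos>"]

    @property
    def all_special_ids(self):
        return [self._vocab["<pad>"], self._vocab["<eos>"]]

    @property
    def all_special_tokens(self):
        return ["<pad>", "<eos>"]

    def encode(self, text):
        return [self._vocab[w] for w in text.split()]

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=None, **kwargs):
        tokens = self.convert_ids_to_tokens(list(token_ids), skip_special_tokens)
        return self.convert_tokens_to_string(tokens)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        if isinstance(ids, int):
            return self._inv[ids]
        return [self._inv[i] for i in ids if not (skip_special_tokens and i in self.all_special_ids)]

    def get_added_vocab(self):
        return {}

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)


tok = ToyTokenizer()
assert tok.decode(tok.encode("hello world") + [tok.eos_token_id], skip_special_tokens=True) == "hello world"
```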
  {
    "path": "verl_rl/verl/workers/rollout/vllm_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout  # noqa: F401\n\n\ndef get_version(pkg):\n    try:\n        return version(pkg)\n    except PackageNotFoundError:\n        return None\n\n\nvllm_package_name = \"vllm\"\nvllm_package_version = get_version(vllm_package_name)\nif vllm_package_version is None:\n    raise PackageNotFoundError(\n        \"To use vllm rollout, please ensure the 'vllm' package is properly installed. See \"\n        \"https://verl.readthedocs.io/en/latest/start/install.html for more details\"\n    )\n\nif \"ROCM_PATH\" in os.environ:\n    import re\n\n    match = re.match(r\"(\\d+\\.\\d+\\.?\\d*)\", vllm_package_version)\n    if match:\n        vllm_package_version = match.group(1)\n    else:\n        raise ValueError(f\"Warning: Could not parse version format: {vllm_package_version}\")\n"
  },
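The ROCm branch above trims platform suffixes off the vllm version string before any further use. A quick sketch of what the regex keeps (the suffixed strings here are hypothetical examples):

```python
import re

# Keep only the leading numeric version, e.g. "0.6.3+rocm624" -> "0.6.3".
for raw in ("0.6.3", "0.6.3+rocm624", "0.6.3.post1"):
    match = re.match(r"(\d+\.\d+\.?\d*)", raw)
    print(raw, "->", match.group(1) if match else None)
# 0.6.3 -> 0.6.3
# 0.6.3+rocm624 -> 0.6.3
# 0.6.3.post1 -> 0.6.3
```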
  {
    "path": "verl_rl/verl/workers/rollout/vllm_rollout/vllm_async_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nimport pickle\nfrom typing import Any, Callable, Optional\n\nimport ray\nimport zmq\nfrom omegaconf import DictConfig\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse, StreamingResponse\nfrom vllm import SamplingParams\nfrom vllm.engine.arg_utils import AsyncEngineArgs\nfrom vllm.entrypoints.logger import RequestLogger\nfrom vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse\nfrom vllm.entrypoints.openai.serving_chat import OpenAIServingChat\nfrom vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels\nfrom vllm.inputs import TokensPrompt\nfrom vllm.outputs import RequestOutput\nfrom vllm.v1.engine.async_llm import AsyncLLM\nfrom vllm.v1.executor.abstract import Executor\nfrom vllm.worker.worker_base import WorkerWrapperBase\n\nfrom verl.utils.fs import copy_to_local\nfrom verl.workers.rollout.async_server import AsyncServerBase\n\nlogger = logging.getLogger(__file__)\n\n\ndef _get_model_runner_workers(vllm_config, init_ray: bool = True):\n    assert vllm_config.instance_id is not None, \"instance_id must be set for external ray actors.\"\n\n    fields = vllm_config.instance_id.split(\":\")\n    assert len(fields) == 4, (\n        f\"instance_id: {vllm_config.instance_id} must be in the format of \"\n        f\"<namespace>:<wg_prefix>:<vllm_dp_size>:<vllm_dp_rank>.\"\n    )\n    namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3])\n\n    # Make sure subprocess in same namespace as parent actor.\n    # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank}\n    if init_ray:\n        ray.init(namespace=namespace)\n    actor_names = [\n        actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f\"{wg_prefix}WorkerDict\")\n    ]\n\n    vllm_tp_size = vllm_config.parallel_config.tensor_parallel_size\n    assert len(actor_names) == vllm_dp_size * vllm_tp_size, (\n        f\"instance_id: {vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: \"\n        f\"{vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected.\"\n    )\n\n    def get_pg_index_and_local_rank(actor_name) -> tuple[int, int]:\n        fields = actor_name.split(\":\")\n        assert len(fields) == 2, f\"invalid actor name: {actor_name}\"\n        pg_index, local_rank = int(fields[0].split(\"_\")[-1]), int(fields[1])\n        return pg_index, local_rank\n\n    # sort actor names by pg_index and local_rank\n    actor_names = sorted(actor_names, key=get_pg_index_and_local_rank)\n    actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size]\n    workers: list[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names]\n    print(f\"instance_id: {vllm_config.instance_id} initializes 
with external actors: {actor_names}\")\n\n    return workers\n\n\nclass ExternalRayDistributedExecutor(Executor):\n    \"\"\"An executor whose engine workers are launched by external ray actors.\"\"\"\n\n    uses_ray: bool = False\n\n    def _init_executor(self) -> None:\n        self.workers = _get_model_runner_workers(vllm_config=self.vllm_config, init_ray=True)\n\n        kwargs = dict(\n            vllm_config=self.vllm_config,\n            local_rank=None,\n            rank=None,\n            distributed_init_method=\"env://\",\n            is_driver_worker=True,\n        )\n        self.collective_rpc(\"init_worker\", args=([kwargs],))\n        self.collective_rpc(\"init_device\")\n        self.collective_rpc(\"load_model\")\n        print(f\"instance_id: {self.vllm_config.instance_id} initialization finished.\")\n\n    def collective_rpc(\n        self,\n        method: str | Callable,\n        timeout: Optional[float] = None,\n        args: tuple = (),\n        kwargs: Optional[dict[str, Any]] = None,\n    ) -> list[Any]:\n        # TODO(wuxibin): support ray compiled graph\n        if isinstance(method, str):\n            sent_method = method\n        else:\n            sent_method = pickle.dumps(method)\n        del method\n\n        # ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization.\n        outputs = ray.get(\n            [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers]\n        )\n        return outputs\n\n    def check_health(self):\n        return\n\n\nclass ExternalZeroMQDistributedExecutor(Executor):\n    \"\"\"An executor whose engine workers are launched by external ray actors and reached via ZeroMQ.\"\"\"\n\n    uses_ray: bool = False\n\n    def _init_executor(self) -> None:\n        addresses = os.environ[\"VERL_VLLM_ZMQ_ADDRESSES\"].split(\",\")\n        self.context = zmq.Context()\n        self.sockets = []\n        for address in addresses:\n            socket = self.context.socket(zmq.REQ)\n            socket.connect(address)\n            self.sockets.append(socket)\n\n        kwargs = dict(\n            vllm_config=self.vllm_config,\n            local_rank=None,\n            rank=None,\n            distributed_init_method=\"env://\",\n            is_driver_worker=True,\n        )\n        self.collective_rpc(\"init_worker\", args=([kwargs],))\n        self.collective_rpc(\"init_device\")\n        self.collective_rpc(\"load_model\")\n\n    def collective_rpc(\n        self,\n        method: str | Callable,\n        timeout: Optional[float] = None,\n        args: tuple = (),\n        kwargs: Optional[dict[str, Any]] = None,\n    ) -> list[Any]:\n        if isinstance(method, str):\n            sent_method = method\n        else:\n            sent_method = pickle.dumps(method)\n        del method\n\n        message = pickle.dumps((sent_method, args, kwargs or {}))\n        for socket in self.sockets:\n            socket.send(message, zmq.DONTWAIT)\n\n        outputs = []\n        for socket in self.sockets:\n            outputs.append(pickle.loads(socket.recv()))\n        return outputs\n\n    def check_health(self):\n        return\n\n\n@ray.remote(num_cpus=1)\nclass AsyncvLLMServer(AsyncServerBase):\n    \"\"\"\n    AsyncvLLMServer is a wrapper around AsyncLLM; it uses ExternalRayDistributedExecutor to launch engines\n    in hybrid rollout workers, i.e. AsyncActorRolloutRefWorker.\n\n    AsyncvLLMServer works as follows:\n    1. Start FastAPI server first.\n    2. 
Initialize AsyncLLM with ExternalRayDistributedExecutor.\n    3. AsyncLLM spawns EngineCore in a subprocess.\n    4. EngineCore initializes ExternalRayDistributedExecutor.\n    5. ExternalRayDistributedExecutor looks up its corresponding actors by name.\n    6. ExternalRayDistributedExecutor initializes the executor: init_worker, init_device, load_model.\n\n    For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826\n    \"\"\"\n\n    def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str):\n        \"\"\"\n        Args:\n            config: DictConfig.\n            vllm_dp_size: int, vllm data parallel size.\n            vllm_dp_rank: int, vllm data parallel rank.\n            wg_prefix: str, worker group prefix, used to look up actors.\n        \"\"\"\n        super().__init__()\n\n        self.config = config.actor_rollout_ref\n        self.vllm_dp_size = vllm_dp_size\n        self.vllm_dp_rank = vllm_dp_rank\n        self.wg_prefix = wg_prefix\n        self.engine: AsyncLLM = None\n\n    async def init_engine(self):\n        \"\"\"Init vLLM AsyncLLM engine.\"\"\"\n        config = self.config\n        model_path = config.model.path\n        model_name = \"/\".join(model_path.split(\"/\")[-2:])\n        local_path = copy_to_local(model_path)\n        trust_remote_code = config.model.get(\"trust_remote_code\", False)\n        config = config.rollout\n\n        tensor_parallel_size = config.get(\"tensor_model_parallel_size\", 1)\n        max_num_batched_tokens = config.get(\"max_num_batched_tokens\", 8192)\n        max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length\n        self.max_model_len = int(max_model_len)\n\n        # Override the default generation config from the Hugging Face model config;\n        # users can still override these by passing kwargs in each request.\n        kwargs = dict(\n            n=1,\n            logprobs=0,\n            repetition_penalty=1.0,\n            max_new_tokens=config.response_length,\n        )\n        for k in config.keys():\n            if hasattr(SamplingParams(), str(k)):\n                kwargs[k] = config.get(k)\n        print(f\"override_generation_config: {kwargs}\")\n\n        backend = os.environ.get(\"VERL_VLLM_DISTRIBUTED_BACKEND\", \"zeromq\")\n        if backend == \"zeromq\":\n            distributed_executor_backend = ExternalZeroMQDistributedExecutor\n        elif backend == \"ray\":\n            distributed_executor_backend = ExternalRayDistributedExecutor\n        else:\n            distributed_executor_backend = None\n\n        engine_args = AsyncEngineArgs(\n            model=local_path,\n            enable_sleep_mode=config.free_cache_engine,\n            override_generation_config=kwargs,\n            tensor_parallel_size=tensor_parallel_size,\n            distributed_executor_backend=distributed_executor_backend,\n            dtype=config.dtype,\n            enforce_eager=config.enforce_eager,\n            gpu_memory_utilization=config.gpu_memory_utilization,\n            disable_custom_all_reduce=True,\n            skip_tokenizer_init=False,\n            max_model_len=self.max_model_len,\n            load_format=\"auto\",\n            disable_log_stats=config.disable_log_stats,\n            max_num_batched_tokens=max_num_batched_tokens,\n            enable_chunked_prefill=config.enable_chunked_prefill,\n            enable_prefix_caching=True,\n            trust_remote_code=trust_remote_code,\n            
seed=config.get(\"seed\", 0),\n        )\n\n        # init async llm engine\n        vllm_config = self._create_engine_config(engine_args)\n        self.engine = AsyncLLM.from_vllm_config(vllm_config)\n\n        # build serving chat\n        model_config = self.engine.model_config\n        BASE_MODEL_PATHS = [BaseModelPath(name=model_name, model_path=model_path)]\n        models = OpenAIServingModels(self.engine, model_config, BASE_MODEL_PATHS)\n        self.openai_serving_chat = OpenAIServingChat(\n            self.engine,\n            model_config,\n            models,\n            \"assistant\",\n            request_logger=RequestLogger(max_log_len=4096),\n            chat_template=None,\n            chat_template_content_format=\"auto\",\n            enable_auto_tools=config.multi_turn.tool_config_path is not None,\n            tool_parser=config.multi_turn.format,  # hermes, llama3_json, ...\n        )\n\n    def _create_engine_config(self, engine_args: AsyncEngineArgs):\n        vllm_config = engine_args.create_engine_config()\n        namespace = ray.get_runtime_context().namespace\n        vllm_config.instance_id = f\"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}\"\n\n        # VERL_VLLM_ZMQ_ADDRESSES\n        if engine_args.distributed_executor_backend == ExternalZeroMQDistributedExecutor:\n            workers = _get_model_runner_workers(vllm_config=vllm_config, init_ray=False)\n            zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in workers])\n            print(f\"VERL_VLLM_ZMQ_ADDRESSES: {zmq_addresses}\")\n            os.environ[\"VERL_VLLM_ZMQ_ADDRESSES\"] = \",\".join(zmq_addresses)\n\n        return vllm_config\n\n    async def chat_completion(self, raw_request: Request):\n        \"\"\"OpenAI-compatible HTTP endpoint.\n\n        API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html\n        \"\"\"\n        request_json = await raw_request.json()\n        request = ChatCompletionRequest(**request_json)\n        generator = await self.openai_serving_chat.create_chat_completion(request, raw_request)\n\n        if isinstance(generator, ErrorResponse):\n            return JSONResponse(content=generator.model_dump(), status_code=generator.code)\n        if request.stream:\n            return StreamingResponse(content=generator, media_type=\"text/event-stream\")\n        else:\n            assert isinstance(generator, ChatCompletionResponse)\n            return JSONResponse(content=generator.model_dump())\n\n    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:\n        max_tokens = self.max_model_len - len(prompt_ids)\n        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)\n        prompt = TokensPrompt(prompt_token_ids=prompt_ids)\n        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)\n\n        # Get final response\n        final_res: Optional[RequestOutput] = None\n        async for output in generator:\n            final_res = output\n        assert final_res is not None\n\n        return final_res.outputs[0].token_ids\n\n    async def wake_up(self):\n        if self.config.rollout.free_cache_engine:\n            await self.engine.wake_up()\n\n    async def sleep(self):\n        # TODO: https://github.com/vllm-project/vllm/issues/17103\n        await self.engine.reset_prefix_cache()\n        if self.config.rollout.free_cache_engine:\n            
await self.engine.sleep()\n"
  },
  {
    "path": "verl_rl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe vllm_rollout that can be applied in different backend\nWhen working with FSDP:\n- Use DTensor weight loader (recommended) or HF weight loader\n- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM\nWhen working with Megatron:\n- Use Megatron weight loader\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank\n  to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\nimport logging\nimport os\nimport pickle\nimport socket\nimport threading\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom types import MethodType\nfrom typing import Any\n\nimport numpy as np\nimport ray\nimport torch\nimport torch.distributed\nimport zmq\nfrom filelock import FileLock\nfrom omegaconf import DictConfig, OmegaConf\nfrom tensordict import TensorDict\nfrom vllm import LLM, SamplingParams\nfrom vllm.distributed import parallel_state as vllm_ps\nfrom vllm.lora.request import LoRARequest\nfrom vllm.model_executor.sampling_metadata import SamplingMetadata\nfrom vllm.worker.worker_base import WorkerWrapperBase\n\nfrom verl import DataProto\nfrom verl.utils.profiler import GPUMemoryLogger\nfrom verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length\nfrom verl.workers.rollout.base import BaseRollout\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n# TODO\n# 1. support pp in vllm\n# 2. passing tokenizer is not necessary? no encoding/decoding is happending here\n# 3. simplify init logics\n\n\n# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.\ndef _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:\n    # remove the left padding in the prompt token_id\n    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id\n    # is not None else self.llm_engine.tokenizer.eos_token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    token_ids = prompt_token_ids[non_pad_index:].tolist()\n    return token_ids\n\n\nclass vLLMRollout(BaseRollout):\n    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):\n        \"\"\"A vLLM rollout. 
It requires the module is supported by the vllm.\n\n        Args:\n            module: module here follows huggingface APIs\n            config: DictConfig\n            tokenizer: the task/model tokenizer\n            model_hf_config: the huggingface config to initiallize the generating model in vllm\n            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group\n        \"\"\"\n        super().__init__()\n        self.config = config\n\n        tensor_parallel_size = self.config.get(\"tensor_model_parallel_size\", 1)\n        assert tensor_parallel_size <= torch.distributed.get_world_size(), (\n            \"tensor parallel size should be less than or equal to the world size\"\n        )\n        max_num_batched_tokens = self.config.get(\"max_num_batched_tokens\", 8192)\n\n        if kwargs.get(\"train_tp\") is not None:\n            # deployed with megatron\n            import os\n\n            os.environ[\"CUDA_TIMER_STREAM_KAFKA_ENABLE\"] = \"0\"\n            os.environ[\"MEGATRON_IMPORT_TIMERS\"] = \"0\"\n            vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)\n\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            max_position_embeddings = None\n            if hasattr(model_hf_config, \"max_position_embeddings\"):\n                max_position_embeddings = model_hf_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"llm_config\") and hasattr(\n                model_hf_config.llm_config, \"max_position_embeddings\"\n            ):\n                max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"text_config\") and hasattr(\n                model_hf_config.text_config, \"max_position_embeddings\"\n            ):\n                max_position_embeddings = model_hf_config.text_config.max_position_embeddings\n            if max_position_embeddings is None:\n                raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n            assert max_position_embeddings >= config.prompt_length + config.response_length, (\n                \"model context length should be greater than total sequence length\"\n            )\n        else:\n            # handle type where there's a length extend factor\n            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support\n            # for using yarn as an example\n            rope_scaling_factor = rope_scaling_config.get(\"factor\", 1.0)\n\n            assert (\n                model_hf_config.max_position_embeddings * rope_scaling_factor\n                >= config.prompt_length + config.response_length\n            ), (\n                \"model context length should be greater than total sequence length, \"\n                + f\"got rope_scaling_factor={rope_scaling_factor} and \"\n                + f\"max_position_embeddings={model_hf_config.max_position_embeddings}\"\n            )\n\n        max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)\n\n        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:\n            raise ValueError(\n                \"Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \\\n                             please increase max_num_batched_tokens or disable chunked prefill\"\n            )\n\n        
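# Worked example (illustrative numbers): with prompt_length=4096 and response_length=4096,\n        # max_model_len is 8192, so enabling chunked prefill requires max_num_batched_tokens >= 8192\n        # (the default above).\n\n        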
trust_remote_code = kwargs.get(\"trust_remote_code\", False)\n        load_format = \"dummy\" if config.load_format.startswith(\"dummy\") else config.load_format\n\n        lora_kwargs = kwargs.pop(\"lora_kwargs\", {})\n        self.lora_kwargs = lora_kwargs\n        # copy it to avoid secretly modifying the engine config\n        engine_kwargs = (\n            {}\n            if \"engine_kwargs\" not in config or \"vllm\" not in config.engine_kwargs\n            else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm))\n        )\n        # For each vLLM engine parameter,\n        # - `None` means not setting it, so we pop it, and leave it to vLLM default value\n        #    (which can vary across different vLLM versions);\n        # - Otherwise it's the desired value we want to explicitly set.\n        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}\n        if config.get(\"limit_images\", None):  # support for multi-image data\n            engine_kwargs[\"limit_mm_per_prompt\"] = {\"image\": config.get(\"limit_images\")}\n\n        self.inference_engine = LLM(\n            model=model_path,\n            enable_sleep_mode=config.free_cache_engine,\n            tensor_parallel_size=tensor_parallel_size,\n            distributed_executor_backend=\"external_launcher\",\n            dtype=config.dtype,\n            enforce_eager=config.enforce_eager,\n            gpu_memory_utilization=config.gpu_memory_utilization,\n            disable_custom_all_reduce=True,\n            skip_tokenizer_init=False,\n            max_model_len=max_model_len,\n            load_format=load_format,\n            disable_log_stats=config.disable_log_stats,\n            max_num_batched_tokens=max_num_batched_tokens,\n            enable_chunked_prefill=config.enable_chunked_prefill,\n            enable_prefix_caching=True,\n            trust_remote_code=trust_remote_code,\n            seed=config.get(\"seed\", 0),\n            **lora_kwargs,\n            **engine_kwargs,\n        )\n\n        # Offload vllm model to reduce peak memory usage\n        if config.free_cache_engine:\n            self.inference_engine.sleep(level=1)\n\n        kwargs = dict(\n            n=1,\n            logprobs=0,  # can be set to 0 and let the actor recompute\n            max_tokens=config.response_length,\n        )\n\n        kwargs[\"detokenize\"] = False\n\n        # support adding any sampling params from the config file\n        for k in config.keys():\n            if hasattr(SamplingParams(), str(k)) and k != \"seed\":\n                kwargs[k] = config.get(k)\n        kwargs[\"n\"] = 1  # already repeated in ray_trainer\n        print(f\"kwargs: {kwargs}\")\n        self.sampling_params = SamplingParams(**kwargs)\n\n        self.pad_token_id = tokenizer.pad_token_id\n\n    @contextmanager\n    def update_sampling_params(self, **kwargs):\n        # update sampling params\n        old_sampling_params_args = {}\n        if kwargs:\n            for key, value in kwargs.items():\n                if hasattr(self.sampling_params, key):\n                    old_value = getattr(self.sampling_params, key)\n                    old_sampling_params_args[key] = old_value\n                    setattr(self.sampling_params, key, value)\n        yield\n        # roll back to previous sampling params\n        for key, value in old_sampling_params_args.items():\n            setattr(self.sampling_params, key, value)\n\n
    @GPUMemoryLogger(role=\"vllm rollout spmd\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:\n        \"\"\"Generate sequences for a batch of prompts.\n\n        Args:\n            prompts (DataProto): Input batch.\n\n        Returns:\n            DataProto: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        idx = prompts.batch[\"input_ids\"]  # (bs, prompt_length)\n        # left-padded attention_mask\n        attention_mask = prompts.batch[\"attention_mask\"]\n        position_ids = prompts.batch[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts.meta_info[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n\n        non_tensor_batch = prompts.non_tensor_batch\n        if \"raw_prompt_ids\" not in non_tensor_batch:\n            non_tensor_batch[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object\n            )\n\n        if batch_size != len(non_tensor_batch[\"raw_prompt_ids\"]):\n            raise RuntimeError(\"vllm sharding manager is not working properly.\")\n\n        if \"multi_modal_data\" in non_tensor_batch:\n            vllm_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                non_tensor_batch.pop(\"raw_prompt_ids\"), non_tensor_batch.pop(\"multi_modal_data\"), strict=True\n            ):\n                vllm_inputs.append({\"prompt_token_ids\": raw_prompt_ids, \"multi_modal_data\": multi_modal_data})\n        else:\n            vllm_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop(\"raw_prompt_ids\")\n            ]\n\n        # ensure the type of `prompt_token_ids` passed to vllm is list[int]\n        # https://github.com/volcengine/verl/pull/772\n        for input_data in vllm_inputs:\n            if isinstance(input_data[\"prompt_token_ids\"], np.ndarray):\n                input_data[\"prompt_token_ids\"] = input_data[\"prompt_token_ids\"].tolist()\n            elif not isinstance(input_data[\"prompt_token_ids\"], list):\n                raise TypeError(\n                    f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\"\n                )\n\n        do_sample = prompts.meta_info.get(\"do_sample\", True)\n        is_validate = prompts.meta_info.get(\"validate\", False)\n        if not do_sample:\n            kwargs = {\n                \"best_of\": 1,\n                \"top_p\": 1.0,\n                \"top_k\": -1,\n                
\"min_p\": 0.0,\n                \"temperature\": 0,\n                \"n\": 1,  # if greedy, only 1 response\n            }\n        elif is_validate:\n            # TODO: try **\n            kwargs = {\n                \"top_k\": self.config.val_kwargs.top_k,\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"n\": 1,  # if validate, already repeat in ray_trainer\n            }\n\n        lora_requests = None\n        if self.lora_kwargs:\n            lora_int_ids = list(self.inference_engine.llm_engine.list_loras())\n            if len(lora_int_ids) > 0:\n                lora_int_id = lora_int_ids[0]\n                lora_requests = [\n                    LoRARequest(lora_name=f\"{lora_int_id}\", lora_int_id=lora_int_id, lora_path=\"/simon-stub-path\")\n                ] * batch_size\n\n        # users can customize different sampling_params at different run\n        with self.update_sampling_params(**kwargs):\n            outputs = self.inference_engine.generate(\n                prompts=vllm_inputs,  # because we have already convert it to prompt token id\n                sampling_params=self.sampling_params,\n                lora_request=lora_requests,\n                use_tqdm=False,\n            )\n\n            # TODO(sgm): disable logprob when recompute_log_prob is enable\n            # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)\n\n            response = []\n            rollout_log_probs = []\n            for output in outputs:\n                for sample_id in range(len(output.outputs)):\n                    response_ids = output.outputs[sample_id].token_ids\n                    response.append(response_ids)\n                    if self.config.calculate_log_probs:\n                        curr_log_prob = []\n                        for i, logprob in enumerate(output.outputs[sample_id].logprobs):\n                            curr_log_prob.append(logprob[response_ids[i]].logprob)\n                        rollout_log_probs.append(curr_log_prob)\n\n            response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(\n                idx.device\n            )\n            if self.config.calculate_log_probs:\n                rollout_log_probs = pad_2d_list_to_length(\n                    rollout_log_probs, -1, max_length=self.config.response_length\n                ).to(idx.device)\n                rollout_log_probs = rollout_log_probs.to(torch.float32)\n\n            seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n      
  attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        # all the tp ranks should contain the same data here. data in all ranks are valid\n        batch = TensorDict(\n            {\n                \"prompts\": idx,\n                \"responses\": response,\n                \"input_ids\": seq,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=batch_size,\n        )\n        if self.config.calculate_log_probs:\n            # we will recompute old log prob with actor\n            batch[\"rollout_log_probs\"] = rollout_log_probs\n\n        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)\n\n\n# https://github.com/vllm-project/vllm/issues/13175\ndef _monkey_patch_compute_logits(model, vocab_size: int):\n    original_compute_logits = model.compute_logits\n\n    def compute_logits(\n        self,\n        hidden_states: torch.Tensor,\n        sampling_metadata: SamplingMetadata,\n    ) -> torch.Tensor:\n        logits = original_compute_logits(hidden_states, sampling_metadata)\n        logits[..., vocab_size:] = float(\"-inf\")\n        return logits\n\n    model.compute_logits = MethodType(compute_logits, model)\n\n\nclass vLLMAsyncRollout:\n    \"\"\"vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase,\n    which is engine in single worker process.\n    \"\"\"\n\n    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):\n        self.tokenizer = tokenizer\n\n        # Engine is deferred to be initialized in init_worker\n        self.config = config\n        self.inference_engine: WorkerWrapperBase = None\n        self.sharding_manager = None\n        self.is_sleep = False\n        self.address = self._init_zeromq()\n\n    def _init_zeromq(self) -> str:\n        tensor_parallel_size = self.config.tensor_model_parallel_size\n\n        # single node: ipc, multi nodes: tcp\n        local_world_size = int(os.environ[\"RAY_LOCAL_WORLD_SIZE\"])\n        socket_type = \"ipc\" if tensor_parallel_size <= local_world_size else \"tcp\"\n\n        # File lock to prevent multiple workers listen to same port\n        with FileLock(\"/tmp/verl_vllm_zmq.lock\"):\n            if socket_type == \"ipc\":\n                pid = os.getpid()\n                address = f\"ipc:///tmp/verl_vllm_zmq_{pid}.ipc\"\n            else:\n                ip, port = self._get_free_port()\n                address = f\"tcp://{ip}:{port}\"\n            context = zmq.Context()\n            self.socket = context.socket(zmq.REP)\n            self.socket.bind(address)\n\n        self.loop_thread = threading.Thread(target=self._loop_forever)\n        self.loop_thread.start()\n\n        return address\n\n    def _get_free_port(self):\n        ip = ray.util.get_node_ip_address()\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            port = sock.getsockname()[1]\n        return ip, port\n\n    def _loop_forever(self):\n        while True:\n            message = self.socket.recv()\n            method, args, kwargs = pickle.loads(message)\n            result = self.execute_method(method, *args, **kwargs)\n            self.socket.send(pickle.dumps(result))\n\n    def get_zeromq_address(self):\n        return self.address\n\n    def init_worker(self, all_kwargs: list[dict[str, Any]]):\n        \"\"\"Initialize worker engine.\"\"\"\n        all_kwargs[0][\"rank\"] = 
int(os.environ[\"RANK\"])\n        all_kwargs[0][\"local_rank\"] = 0\n\n        self.vllm_config = all_kwargs[0][\"vllm_config\"]\n        self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)\n        self.inference_engine.init_worker(all_kwargs)\n\n    def load_model(self, *args, **kwargs):\n        self.inference_engine.load_model(*args, **kwargs)\n\n        # inference engine is initialized now, update sharding manager\n        self.sharding_manager.inference_engine = self.inference_engine\n        self.sharding_manager.model_runner = self.inference_engine.worker.model_runner\n\n        _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))\n\n    def sleep(self, *args, **kwargs):\n        \"\"\"Offload model weights and discard kv cache.\"\"\"\n        if self.is_sleep:\n            return\n        self.sharding_manager.__exit__(None, None, None)\n        self.is_sleep = True\n\n    def wake_up(self, *args, **kwargs):\n        \"\"\"Load model weights and build kv cache.\"\"\"\n        if not self.is_sleep:\n            return\n        self.sharding_manager.__enter__()  # pylint: disable=C2801\n        self.is_sleep = False\n\n    def execute_method(self, method: str | bytes, *args, **kwargs):\n        if method == \"init_worker\":\n            return self.init_worker(*args, **kwargs)\n        elif method == \"load_model\":\n            return self.load_model(*args, **kwargs)\n        elif method == \"sleep\":\n            return self.sleep(*args, **kwargs)\n        elif method == \"wake_up\":\n            return self.wake_up(*args, **kwargs)\n        else:\n            return self.inference_engine.execute_method(method, *args, **kwargs)\n"
  },
  {
    "path": "verl_rl/verl/workers/sharding_manager/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "verl_rl/verl/workers/sharding_manager/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSharding manager to implement HybridEngine\n\"\"\"\n\nfrom verl import DataProto\n\n\nclass BaseShardingManager:\n    def __init__(self):\n        self.timing = {}\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        pass\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        return data\n"
  },
  {
    "path": "verl_rl/verl/workers/sharding_manager/fsdp_sglang.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.model_executor.model_runner import LocalSerializedTensor\n\ntry:\n    from sglang.srt.utils import TorchPatchMultiprocessingSerializer as MultiprocessingSerializer\nexcept ImportError:\n    from sglang.srt.utils import MultiprocessingSerializer\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\nfrom torch.distributed.tensor import DTensor\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.utils.device import get_device_id, get_torch_device\nfrom verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu\nfrom verl.utils.model import convert_weight_keys\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\nfrom .base import BaseShardingManager\n\n# from vllm.distributed import parallel_state as sglang_ps\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\ndef _preprocess_tensor_for_update_weights(tensor: torch.Tensor):\n    if isinstance(tensor, DTensor):\n        return tensor.full_tensor()\n    return tensor\n\n\nclass FSDPSGLangShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(\n        self,\n        module: FSDP,\n        inference_engine: Engine,\n        model_config,\n        rollout_config,\n        full_params: bool = False,\n        device_mesh: DeviceMesh = None,\n        offload_param: bool = False,\n        multi_stage_wake_up: bool = False,\n    ):\n        self.module = module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.device_mesh = device_mesh\n        self.offload_param = offload_param\n        self.multi_stage_wake_up = multi_stage_wake_up\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()\n            )\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                
state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a random rng state\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager enter\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.wake_up())\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager exit\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.sleep())\n\n    async def update_weights(self, params):\n        # Most naive implementation; can be optimized a lot if sglang Engine weight update becomes a bottleneck\n        named_tensors = [(k, v) for k, v in params.items()]\n        load_format = None\n        # convert megabytes to bytes\n        update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20\n        for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):\n            # On each rank, serialize a batch of (name, tensor) tuples.\n            # named_tensors_batch will be a list like:\n            # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]\n            named_tensors_batch = [\n                (name, MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor)))\n                for name, tensor in batch\n            ]\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                # On rank 0, prepare a list to hold the gathered batches from all ranks.\n                gathered_serialized_batches = [None for _ in range(self.device_mesh[\"infer_tp\"].mesh.size()[0])]\n            else:\n                gathered_serialized_batches = None\n\n            # Gather the named_tensors_batch from all ranks to rank 0.\n            # After this, on rank 0, gathered_serialized_batches will be a list of lists:\n            # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ],  # batch from TP rank 0\n            #   [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ],  # batch from TP rank 1\n            #   ... 
]\n            # On other ranks, gathered_serialized_batches will be None.\n            dist.gather_object(\n                obj=named_tensors_batch,\n                object_gather_list=gathered_serialized_batches,\n                dst=self.device_mesh[\"infer_tp\"].mesh.tolist()[0],\n                group=self.device_mesh[\"infer_tp\"].get_group(),\n            )\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                # Use zip(*) to \"transpose\" the data structure.\n                # This groups the serialized parts for each individual tensor across all TP ranks.\n                # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]\n                # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]\n                logical_tensors = zip(*gathered_serialized_batches, strict=True)\n\n                await self.inference_engine.update_weights_from_tensor(\n                    named_tensors=[\n                        # 'tensor_group' represents a single logical tensor's data from all ranks.\n                        (\n                            tensor_group[0][0],  # Get the name from the first rank's data.\n                            LocalSerializedTensor(\n                                # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.\n                                values=[rank_part[1] for rank_part in tensor_group]\n                            ),\n                        )\n                        for tensor_group in logical_tensors\n                        # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )\n                    ],\n                    load_format=load_format,\n                    flush_cache=False,\n                )\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self.inference_engine.flush_cache()\n\n    async def release_memory(self):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.release_memory_occupation()\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager enter\", logger=logger)\n    async def wake_up(self):\n        get_torch_device().empty_cache()\n\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            if self.multi_stage_wake_up:\n                await self.inference_engine.resume_memory_occupation(tags=[\"weights\"])\n                log_gpu_memory_usage(\"Before resume SGLang weights in sharding manager\", logger=logger)\n            else:\n                await self.inference_engine.resume_memory_occupation()\n                log_gpu_memory_usage(\"Before resume SGLang weights + kv_cache in sharding manager\", logger=logger)\n\n        log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n        params = self.module.state_dict()\n        log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n        device = get_device_id()  # used when fsdp2 set cpu_offload_policy\n        params = {\n            k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()\n        }\n\n        # convert weight keys to match the model config\n        params = convert_weight_keys(params, getattr(self.module, \"_fsdp_wrapped_module\", self.module))\n\n        # 
Copy, not share memory\n        await self.update_weights(params)\n        log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n\n        del params\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n        if (\n            self.multi_stage_wake_up\n            and self.rollout_config.free_cache_engine\n            and self.device_mesh[\"infer_tp\"].get_local_rank() == 0\n        ):\n            await self.inference_engine.resume_memory_occupation(tags=[\"kv_cache\"])\n            log_gpu_memory_usage(\"After resume SGLang kv_cache in sharding manager\", logger=logger)\n\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"FSDPSGLangShardingManager exit\", logger=logger)\n    async def sleep(self):\n        if self.rollout_config.free_cache_engine:\n            log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n            await self.release_memory()\n            log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group so that each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = self.device_mesh[\"infer_tp\"].get_group()\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get the chunk of data for this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
  {
    "path": "verl_rl/verl/workers/sharding_manager/fsdp_ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains a resharding manager that binds weights from FSDP zero3 to XPerfGPT\n\"\"\"\n\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group\n\nfrom .base import BaseShardingManager\n\n\nclass FSDPUlyssesShardingManager(BaseShardingManager):\n    \"\"\"\n    Sharding manager to support data resharding when using FSDP + Ulysses\n    \"\"\"\n\n    def __init__(self, device_mesh: DeviceMesh):\n        super().__init__()\n        self.device_mesh = device_mesh\n        self.seed_offset = 12345\n\n    def __enter__(self):\n        if self.device_mesh is not None:\n            # We have a global SP group\n            # so we have to change to use model-specific sp group\n            self.prev_sp_group = get_ulysses_sequence_parallel_group()\n            set_ulysses_sequence_parallel_group(self.device_mesh[\"sp\"].get_group())\n            # TODO: check how to set seed for each model\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        # restore random states\n        if self.device_mesh is not None:\n            # revert to previous sp group\n            set_ulysses_sequence_parallel_group(self.prev_sp_group)\n            # TODO: check how to set seed for each model\n\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"\n        AllGather data from sp region\n        This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE\n        In Ulysses, we need to make sure the same data is used across a SP group\n        \"\"\"\n        if self.device_mesh is not None:\n            group = self.device_mesh[\"sp\"].get_group()\n\n            all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"\n        Split the data to follow FSDP partition\n        \"\"\"\n        if self.device_mesh is not None:\n            sp_size = self.device_mesh[\"sp\"].size()\n            sp_rank = self.device_mesh[\"sp\"].get_local_rank()\n            data = data.chunk(chunks=sp_size)[sp_rank]\n        return data\n"
  },
  {
    "path": "verl_rl/verl/workers/sharding_manager/fsdp_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport logging\nimport os\nimport time\nfrom collections import OrderedDict\n\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom dataclasses import asdict\n\nfrom verl import DataProto\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import LLM\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.device import get_device_id, get_device_name, get_torch_device\nfrom verl.utils.fsdp_utils import (\n    fsdp_version,\n    layered_summon_lora_params,\n    load_fsdp_model_to_gpu,\n    offload_fsdp_model_to_cpu,\n)\nfrom verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\nclass FSDPVLLMShardingManager(BaseShardingManager):\n    \"\"\"Sharding manager for FSDP models with vLLM inference engine integration.\n\n    Manages parameter synchronization between FSDP training models and vLLM\n    inference engines, handling both full parameters and LoRA adapters with\n    efficient memory management and device placement.\n    \"\"\"\n\n    @check_device_is_available()\n    def __init__(\n        self,\n        module: FSDP,\n        inference_engine: LLM,\n        model_config,\n        rollout_config,\n        full_params: bool = False,\n        device_mesh: DeviceMesh = None,\n        offload_param: bool = False,\n        load_format: str = \"dummy_hf\",\n        layered_summon: bool = True,\n    ):\n        self.module = module\n        # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model\n        self.inference_engine = inference_engine\n        # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if\n        # inference_engine else None\n\n        self.model_runner = (\n            self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner\n            if self.inference_engine\n            else None\n        )\n\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.device_mesh = device_mesh\n        self.offload_param = offload_param\n        
self.load_format = load_format\n        self.layered_summon = layered_summon\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()\n            )\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a random rng state\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n        self.base_sync_done: bool = \"dummy\" not in load_format\n        if is_version_ge(pkg=\"vllm\", minver=\"0.7.3\"):\n            VLLMHijack.hijack()\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        def __collect_lora_params() -> OrderedDict:\n            \"\"\"\n            Collect LoRA params, or full params if the base model is not yet loaded in vllm.\n            Only used when isinstance(self.module._fsdp_wrapped_module, PeftModel).\n            \"\"\"\n            from peft.utils.save_and_load import get_peft_model_state_dict\n\n            lora_params = OrderedDict()\n            peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n            if fsdp_version(self.module) > 0:\n                if self.layered_summon:\n                    if not self.base_sync_done:\n                        raise ValueError(\n                            \"To use layered_summon, you must make sure the base model is preloaded in vllm, e.g. 
let \"\n                            \"rollout.load_format=safetensors\"\n                        )\n                    lora_params = layered_summon_lora_params(self.module)\n                else:\n                    with FSDP.summon_full_params(self.module, writeback=False):\n                        if self.base_sync_done:\n                            lora_params = get_peft_model_state_dict(peft_model)\n                            lora_params = {\n                                name: param.full_tensor().detach().cpu()\n                                if hasattr(param, \"full_tensor\")\n                                else param.detach().cpu()\n                                for name, param in lora_params.items()\n                            }\n                        else:\n                            model = peft_model.base_model.model\n                            orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n                            model = model.to(\"cpu\")\n                            for name, param in model.state_dict().items():\n                                if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                                    continue\n                                name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                                lora_params[name] = (\n                                    param.full_tensor().detach().cpu()\n                                    if hasattr(param, \"full_tensor\")\n                                    else param.detach().cpu()\n                                )\n                            model = model.to(orig_dev)\n                    get_torch_device().empty_cache()\n            else:\n                if self.base_sync_done:\n                    lora_params = get_peft_model_state_dict(peft_model)\n                else:\n                    model = peft_model.base_model.model\n                    orig_dev = \"cpu\" if \"cpu\" in str(next(model.parameters()).device) else get_device_name()\n                    model = model.to(\"cpu\")\n                    for name, param in model.state_dict().items():\n                        if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                            continue\n                        name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                        lora_params[name] = param.detach().cpu()\n                    model = model.to(orig_dev)\n            return lora_params\n\n        # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and\n        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.\n        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory\n        # to speed up memory allocations.\n        #\n        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management\n        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            get_torch_device().empty_cache()\n\n            log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n            if self.offload_param:\n                load_fsdp_model_to_gpu(self.module)\n\n            peft_config = None\n            peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n  
          if hasattr(peft_model, \"peft_config\"):\n                peft_config = peft_model.peft_config.get(\"default\", None)\n                params = __collect_lora_params()\n            else:\n                params = self.module.state_dict()\n            params = convert_weight_keys(params, getattr(self.module, \"_fsdp_wrapped_module\", self.module))\n            log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n\n            if self.rollout_config.free_cache_engine:\n                if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n                    self.inference_engine.wake_up(tags=[\"weights\"])\n                else:\n                    self.inference_engine.wake_up()\n\n            # update model params\n            self.update_params(params, peft_config=peft_config)\n            log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n            del params\n            if self.offload_param:\n                offload_fsdp_model_to_cpu(self.module)\n            get_torch_device().empty_cache()\n\n            if (\n                self.rollout_config.free_cache_engine\n                and \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters\n            ):\n                self.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n            log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n            # important: need to manually set the random states of each tp to be identical.\n            if self.device_mesh is not None:\n                self.torch_random_states = get_torch_device().get_rng_state()\n                get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.rollout_config.free_cache_engine:\n            self.inference_engine.sleep(level=1)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n\n    def update_params(self, updated_params, peft_config=None):\n        \"\"\"Update model parameters in the vLLM inference engine.\n\n        Synchronizes parameters from the FSDP training model to the vLLM inference\n        engine, handling both full model parameters and LoRA adapters with proper\n        device placement and 
\n        Args:\n            updated_params (dict): Dictionary of parameter names to tensor values.\n            peft_config (optional): PEFT configuration for LoRA adapters.\n        \"\"\"\n        model = self.model_runner.model\n        if peft_config:\n            if self.base_sync_done:\n                lora_int_id = int(time.time_ns() % 0x7FFFFFFF)\n                lora_request = TensorLoRARequest(\n                    lora_name=f\"{lora_int_id}\",\n                    lora_int_id=lora_int_id,\n                    lora_path=\"simon_lora_path\",\n                    peft_config=asdict(peft_config),\n                    lora_tensors=updated_params,\n                )\n                self.inference_engine.llm_engine.add_lora(lora_request)\n                logger.info(f\"vLLM load weights, loaded_params: {len(updated_params)}\")\n                return\n            else:\n\n                def replace_lora_wrapper(k):\n                    \"\"\"Replace LoRA parameter keys with base layer equivalents.\n\n                    Transforms LoRA parameter names to their corresponding base layer\n                    names for proper weight loading in vLLM when base model sync is not done.\n\n                    Args:\n                        k (str): Original parameter key name.\n\n                    Returns:\n                        str: Transformed parameter key for base layer.\n                    \"\"\"\n                    stacked_params = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n                    if k.endswith(\".weight\"):\n                        module_k = k[: -len(\".weight\")]\n                        if check_exclude_modules(peft_config, module_k):\n                            return k\n                        elif any(module_k.endswith(s) for s in stacked_params) or check_target_modules(\n                            peft_config, module_k\n                        ):\n                            return f\"{module_k}.base_layer.weight\"\n                    if k.endswith(\".bias\"):\n                        module_k = k[: -len(\".bias\")]\n                        if check_exclude_modules(peft_config, module_k):\n                            return k\n                        elif any(module_k.endswith(s) for s in stacked_params) or check_target_modules(\n                            peft_config, module_k\n                        ):\n                            return f\"{module_k}.base_layer.bias\"\n                    return k\n\n                updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}\n\n        patch_vllm_moe_model_weight_loader(model)\n        device = get_device_id()  # used when fsdp2 sets cpu_offload_policy\n        loaded_params = model.load_weights(\n            (\n                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)\n                for name, param in updated_params.items()\n            )\n        )\n\n        self.base_sync_done = True\n        logger.info(f\"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}\")\n"
  },
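  {
    "path": "verl_rl/verl/workers/sharding_manager/_example_sharding_protocol_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only -- NOT part of verl. A miniature, single-process stand-in for the\nsharding managers in this package, showing the enter/exit protocol they share: push the\ntraining weights into an inference-engine stub on enter, swap to a dedicated generation RNG\nstate so sampling is identical across ranks, and restore the training RNG state on exit.\nAll names here (ToyInferenceEngine, ToyShardingManager) are hypothetical.\n\"\"\"\n\nimport torch\nfrom torch import nn\n\n\nclass ToyInferenceEngine:\n    \"\"\"Stub mirroring the (name, tensor) load_weights surface of the real engines.\"\"\"\n\n    def __init__(self):\n        self.weights = {}\n\n    def load_weights(self, named_params):\n        # Consume an iterable of (name, tensor) pairs, like model.load_weights(...).\n        self.weights = {name: tensor.detach().clone() for name, tensor in named_params}\n        return set(self.weights)\n\n\nclass ToyShardingManager:\n    def __init__(self, module: nn.Module, engine: ToyInferenceEngine, seed: int = 1000):\n        self.module = module\n        self.engine = engine\n        # Derive a dedicated generation RNG state (cf. gen_random_states in the real managers).\n        self.torch_random_states = torch.get_rng_state()\n        torch.manual_seed(seed)\n        self.gen_random_states = torch.get_rng_state()\n        torch.set_rng_state(self.torch_random_states)\n\n    def __enter__(self):\n        # 1) Sync training weights into the inference engine.\n        self.engine.load_weights(self.module.state_dict().items())\n        # 2) Switch to the shared generation RNG state.\n        self.torch_random_states = torch.get_rng_state()\n        torch.set_rng_state(self.gen_random_states)\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        # Remember where generation left off, then restore the training RNG state.\n        self.gen_random_states = torch.get_rng_state()\n        torch.set_rng_state(self.torch_random_states)\n        self.module.train()\n\n\nif __name__ == \"__main__\":\n    model = nn.Linear(4, 4)\n    engine = ToyInferenceEngine()\n    with ToyShardingManager(model, engine):\n        print(f\"engine holds {len(engine.weights)} tensors\")\n"
  },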
  {
    "path": "verl_rl/verl/workers/sharding_manager/megatron_sglang.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\n\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.model_executor.model_runner import LocalSerializedTensor\n\ntry:\n    from sglang.srt.utils import TorchPatchMultiprocessingSerializer as MultiprocessingSerializer\nexcept ImportError:\n    from sglang.srt.utils import MultiprocessingSerializer\nfrom torch import nn\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom verl.protocol import DataProto, all_gather_data_proto\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.megatron_utils import (\n    load_megatron_model_to_gpu,\n    offload_megatron_model_to_cpu,\n    per_tensor_generator,\n)\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer\nfrom verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_PPO_LOGGING_LEVEL\", \"WARN\"))\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all \n  the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\n\nclass MegatronSGLangShardingManager(BaseShardingManager):\n    \"\"\"A sharding manager for Megatron-style training & inference with SGLang.\n\n    This class manages the sharding of model parameters between training and inference\n    phases in a Megatron-style parallel setup. 
    - Loading/offloading parameters between CPU/GPU\n    - Updating inference engine weights\n    - Managing random states for reproducibility\n    - Data preprocessing for distributed inference\n\n    Args:\n        actor_module (nn.ModuleList): The actor model modules\n        inference_engine (Engine): The SGLang inference engine\n        model_config: Configuration for the actor's model\n        rollout_config: Configuration for rollout generation\n        transformer_config: Transformer-specific configuration\n        layer_name_mapping: Mapping between layer names and parameters\n        weight_converter: Utility for converting weights between formats\n        device_mesh (DeviceMesh | None): PyTorch device mesh for distributed training\n        offload_param (bool): Whether to offload parameters to CPU when not in use\n    \"\"\"\n\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: Engine,\n        model_config: DictConfig,\n        rollout_config: DictConfig,\n        transformer_config,\n        layer_name_mapping,\n        weight_converter,\n        device_mesh: DeviceMesh | None = None,\n        offload_param: bool = False,\n        bridge=None,\n    ):\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.transformer_config = transformer_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.device_mesh = device_mesh\n        self.bridge = bridge\n        self.offload_param = offload_param\n\n        if self.device_mesh is not None:\n            self.infer_tp_size = self.device_mesh[\"tp\"].mesh.size()[0]\n        else:\n            self.infer_tp_size = self.inference_engine._tp_size\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # derive a dedicated generation rng state\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager enter\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.wake_up())\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager exit\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.sleep())\n\n    async def update_weights(self, params):\n        \"\"\"\n        Update model weights using tensor buckets, similar to THUDM/slime's implementation.\n\n        Notes:\n          - For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n              1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.\n              2. 
Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n            when using Tensor Parallelism (TP >= 8).\n          - See reference implementations in SLIME:\n            - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452\n            - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39\n        \"\"\"\n        if self.device_mesh[\"tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.resume_memory_occupation()\n        named_tensors = params\n        load_format = None\n\n        update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20\n        for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):\n            # On each rank, serialize a batch of (name, tensor) tuples.\n            # named_tensors_batch will be a list like:\n            # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]\n            named_tensors_batch = [\n                (name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch\n            ]\n\n            if self.device_mesh[\"tp\"].get_local_rank() == 0:\n                # On rank 0, prepare a list to hold the gathered batches from all ranks.\n                gathered_serialized_batches = [None for _ in range(self.device_mesh[\"tp\"].mesh.size()[0])]\n            else:\n                gathered_serialized_batches = None\n\n            # Gather the named_tensors_batch from all ranks to rank 0.\n            # After this, on rank 0, gathered_serialized_batches will be a list of lists:\n            # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ],  # batch from TP rank 0\n            #   [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ],  # batch from TP rank 1\n            #   ... 
]\n            # On other ranks, gathered_serialized_batches will be None.\n            dist.gather_object(\n                obj=named_tensors_batch,\n                object_gather_list=gathered_serialized_batches,\n                dst=self.device_mesh[\"tp\"].mesh.tolist()[0],\n                group=self.device_mesh[\"tp\"].get_group(),\n            )\n\n            if self.device_mesh[\"tp\"].get_local_rank() == 0:\n                # Use zip(*) to \"transpose\" the data structure.\n                # This groups the serialized parts for each individual tensor across all TP ranks.\n                # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]\n                # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]\n                logical_tensors = zip(*gathered_serialized_batches, strict=False)\n                await self.inference_engine.update_weights_from_tensor(\n                    named_tensors=[\n                        # 'tensor_group' represents a single logical tensor's data from all ranks.\n                        (\n                            tensor_group[0][0],  # Get the name from the first rank's data.\n                            LocalSerializedTensor(\n                                # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.\n                                values=[rank_part[1] for rank_part in tensor_group]\n                            ),\n                        )\n                        for tensor_group in logical_tensors\n                        # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )\n                    ],\n                    load_format=load_format,\n                    flush_cache=False,\n                )\n\n        if self.device_mesh[\"tp\"].get_local_rank() == 0:\n            await self.inference_engine.flush_cache()\n\n    async def release_memory(self):\n        if self.device_mesh[\"tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.release_memory_occupation()\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager enter\", logger=logger)\n    async def wake_up(self):\n        if self.offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.actor_module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.actor_module,\n                self.model_config,\n                self.weight_converter,\n                self.transformer_config,\n                self.layer_name_mapping,\n            )\n        await self.update_weights(per_tensor_param)\n        if self.offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        get_torch_device().empty_cache()\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"MegatronSGLangShardingManager exit\", logger=logger)\n    async def sleep(self):\n        if self.rollout_config.free_cache_engine:\n            log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n            await self.release_memory()\n            log_gpu_memory_usage(\"After SGLang offload in sharding manager\", 
logger=logger)\n\n        for model in self.actor_module:\n            model.train()\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"megatron sglang sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        all_gather_data_proto(data, self.device_mesh[\"tp\"].get_group())\n        return data\n\n    @GPUMemoryLogger(role=\"megatron sglang sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh[\"tp\"].get_local_rank()]\n"
  },
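  {
    "path": "verl_rl/verl/workers/sharding_manager/_example_bucketed_weight_sync_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only -- NOT part of verl. Mimics the bucketed weight-sync scheme used by\nMegatronSGLangShardingManager.update_weights: chunk (name, tensor) pairs into byte-budgeted\nbuckets, fake a per-rank gather of serialized shards, and use zip(*) to regroup each logical\ntensor across tp ranks. bucket_named_tensors below is a simplified stand-in for verl's\nget_named_tensor_buckets, not its actual implementation.\n\"\"\"\n\nimport torch\n\n\ndef bucket_named_tensors(named_tensors, bucket_bytes):\n    \"\"\"Yield lists of (name, tensor) whose summed byte sizes stay within bucket_bytes.\"\"\"\n    bucket, size = [], 0\n    for name, tensor in named_tensors:\n        nbytes = tensor.element_size() * tensor.numel()\n        if bucket and size + nbytes > bucket_bytes:\n            yield bucket\n            bucket, size = [], 0\n        bucket.append((name, tensor))\n        size += nbytes\n    if bucket:\n        yield bucket\n\n\nif __name__ == \"__main__\":\n    params = [(f\"layers.{i}.weight\", torch.zeros(256, 256)) for i in range(8)]\n    # Each fp32 256x256 tensor is 256 KiB, so a 1 MiB budget packs four per bucket.\n    for batch in bucket_named_tensors(params, 1 << 20):\n        # Stand-in for MultiprocessingSerializer + dist.gather_object: each \"tp rank\"\n        # contributes its serialized shard of every tensor in the batch.\n        gathered = [[(name, f\"rank{rank}:{name}\") for name, _ in batch] for rank in range(2)]\n        # zip(*) transposes the rank-major lists into tensor-major groups, the same\n        # regrouping update_weights performs before update_weights_from_tensor.\n        for tensor_group in zip(*gathered):\n            tensor_name = tensor_group[0][0]\n            shards = [rank_part[1] for rank_part in tensor_group]\n            print(tensor_name, shards)\n"
  },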
  {
    "path": "verl_rl/verl/workers/sharding_manager/megatron_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport inspect\nimport logging\nimport os\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom omegaconf import DictConfig\nfrom torch import nn\n\nfrom verl import DataProto\nfrom verl.models.mcore.weight_converter import McoreToHFWeightConverterBase\nfrom verl.protocol import all_gather_data_proto\nfrom verl.third_party.vllm import LLM\nfrom verl.third_party.vllm import parallel_state as vllm_ps\nfrom verl.utils.device import get_torch_device\nfrom verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator\nfrom verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage\nfrom verl.utils.profiler.performance import simple_timer\nfrom verl.utils.torch_functional import check_device_is_available\nfrom verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader\n\nfrom .base import BaseShardingManager\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"VERL_LOGGING_LEVEL\", \"WARN\"))\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank \n   to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. 
- After inference, all parameters that don't belong to this pp rank are freed.\n\"\"\"\n\n\nclass MegatronVLLMShardingManager(BaseShardingManager):\n    \"\"\"A sharding manager that bridges Megatron-LM training with vLLM inference.\n\n    This class handles the parameter sharding and communication between:\n    - Megatron-LM's tensor/expert parallel training setup\n    - vLLM's tensor parallel inference setup\n\n    Key responsibilities:\n    - Manages parameter broadcasting between training and inference configurations\n    - Handles weight conversion between Megatron and HuggingFace formats\n    - Coordinates memory management between training and inference phases\n    - Maintains random state consistency across different parallel groups\n\n    Args:\n        actor_module (nn.ModuleList): The Megatron-LM model being trained\n        inference_engine (LLM): The vLLM inference engine\n        model_config: Configuration for the actor's model\n        transformer_config: Transformer-specific configuration for the model\n        rollout_config: Configuration for rollout\n        layer_name_mapping: Mapping between Megatron and HF layer names\n        weight_converter (McoreToHFWeightConverterBase): Converts weights between formats\n        device_mesh: Device mesh for parallel operations\n        offload_param (bool): Whether to offload parameters when not in use\n    \"\"\"\n\n    @check_device_is_available()\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: LLM,\n        model_config: DictConfig,\n        transformer_config,\n        rollout_config: DictConfig,\n        layer_name_mapping,\n        weight_converter: McoreToHFWeightConverterBase,\n        device_mesh,\n        offload_param: bool = True,\n        bridge=None,\n    ):\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.offload_param = offload_param\n\n        # For AsyncLLM, inference_engine and model_runner are initialized lazily in vLLMAsyncRollout.load_model\n        self.model_runner = (\n            self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner\n            if self.inference_engine\n            else None\n        )\n\n        self.model_config = model_config\n        self.transformer_config = transformer_config\n        self.rollout_config = rollout_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.bridge = bridge\n        # initialize groups for vllm inference\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n\n        self.device_mesh = device_mesh\n        self.infer_tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.infer_tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        self.train_tp_size = mpu.get_tensor_model_parallel_world_size()\n        self.train_tp_rank = mpu.get_tensor_model_parallel_rank()\n        self.train_tp_group = mpu.get_tensor_model_parallel_group()\n        self.train_ep_size = mpu.get_expert_model_parallel_world_size()\n        self.train_ep_rank = mpu.get_expert_model_parallel_rank()\n        self.train_ep_group = mpu.get_expert_model_parallel_group()\n        self.train_etp_size = mpu.get_expert_tensor_parallel_world_size()\n        self.train_etp_rank = mpu.get_expert_tensor_parallel_rank()\n        self.train_etp_group = 
mpu.get_expert_tensor_parallel_group()\n        self.need_tp_reshard = self.train_tp_size != self.infer_tp_size\n        self.train_tp_larger = self.train_tp_size > self.infer_tp_size\n\n        self.torch_random_states = get_torch_device().get_rng_state()\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        self.timing = {}\n        with simple_timer(\"reshard\", self.timing):\n            get_torch_device().empty_cache()\n\n            log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n            if self.offload_param:\n                load_megatron_model_to_gpu(self.actor_module)\n\n            if self.rollout_config.free_cache_engine:\n                if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n                    self.inference_engine.wake_up(tags=[\"weights\"])\n                else:\n                    self.inference_engine.wake_up()\n            if self.bridge is not None:\n                per_tensor_param = self.bridge.export_weights(self.actor_module)\n            else:\n                per_tensor_param = per_tensor_generator(\n                    self.actor_module,\n                    self.model_config,\n                    self.weight_converter,\n                    self.transformer_config,\n                    self.layer_name_mapping,\n                )\n            model = self.model_runner.model\n            patch_vllm_moe_model_weight_loader(model)\n            loaded_params = model.load_weights(per_tensor_param)\n            info = f\"vLLM load weights, loaded_params: {len(loaded_params)}\"\n            logger.info(info)\n\n            if self.offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n            get_torch_device().empty_cache()\n\n            if (\n                self.rollout_config.free_cache_engine\n                and \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters\n            ):\n                self.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n            # important: need to manually set the random states of each tp to be identical.\n            if self.device_mesh is not None:\n                self.torch_random_states = get_torch_device().get_rng_state()\n                get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.rollout_config.free_cache_engine:\n            self.inference_engine.sleep(level=1)\n        for model in self.actor_module:\n            model.train()\n\n        get_torch_device().empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def preprocess_data(self, data: DataProto) -> DataProto:\n        # 
DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = vllm_ps.get_tensor_model_parallel_group().device_group\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def postprocess_data(self, data: DataProto) -> DataProto:\n        # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp\n        if self.infer_tp_size == 1:\n            return data\n        return data.chunk(chunks=self.infer_tp_size)[self.infer_tp_rank]\n"
  }
]